make EM work with PFL and BP.
This commit is contained in:
@@ -24,6 +24,15 @@
|
||||
randomise_all_dists/0,
|
||||
uniformise_all_dists/0]).
|
||||
|
||||
:- use_module(library(clpbn/ground_factors),
|
||||
[generate_network/5,
|
||||
f/3]).
|
||||
|
||||
:- use_module(library(bhash), [
|
||||
b_hash_new/1,
|
||||
b_hash_lookup/3,
|
||||
b_hash_insert/4]).
|
||||
|
||||
:- use_module(library('clpbn/learning/learn_utils'),
|
||||
[run_all/1,
|
||||
clpbn_vars/2,
|
||||
@@ -61,9 +70,11 @@ em(_, _, _, Tables, Likelihood) :-
|
||||
retract(em_found(Tables, Likelihood)).
|
||||
|
||||
|
||||
% handle_em(+Error)
%
% React to an exception caught while running EM.
%
% error(repeated_parents) is treated as a recoverable condition: a
% placeholder result whose likelihood is -inf is stored in em_found/2
% and the call then fails.  The cut commits to this clause, so the
% failure cannot backtrack into the catch-all clause below (which would
% re-throw the very error we just handled).
%
% Every other error term is re-thrown unchanged.
handle_em(error(repeated_parents)) :- !,
    assert(em_found(_, -inf)),
    fail.
handle_em(Error) :-
    throw(Error).
|
||||
|
||||
% This gets you an initial configuration. If there is a lot of evidence
|
||||
% tables may be filled in close to optimal, otherwise they may be
|
||||
@@ -74,7 +85,7 @@ handle_em(error(repeated_parents)) :-
|
||||
% it includes the list of variables without evidence,
|
||||
% the list of distributions for which we want to compute parameters,
|
||||
% and more detailed info on distributions, namely with a list of all instances for the distribution.
|
||||
init_em(Items, state( AllDists, AllDistInstances, MargVars, SolverVars)) :-
|
||||
init_em(Items, State) :-
|
||||
clpbn_flag(em_solver, Solver),
|
||||
% only used for PCGs
|
||||
clpbn_init_graph(Solver),
|
||||
@@ -83,12 +94,27 @@ init_em(Items, state( AllDists, AllDistInstances, MargVars, SolverVars)) :-
|
||||
% randomise_all_dists,
|
||||
% set initial values for distributions
|
||||
uniformise_all_dists,
|
||||
% get all variables to marginalise
|
||||
setup_em_network(Solver, State).
|
||||
|
||||
% setup_em_network(+Solver, -State)
%
% Build the EM state term
%     state(AllDists, AllDistInstances, MargVars, SolverState)
% for the given solver.
%
% When clpbn:use_parfactors(on) holds, the attributed variables are
% grounded into a factor network first and the distribution info is
% derived from its factors; otherwise the distribution info is read
% straight from the attributed variables.  The choice is committed:
% a failure inside the selected branch makes the whole call fail.
setup_em_network(Solver, state(AllDists, AllDistInstances, MargVars, SolverState)) :-
    (   clpbn:use_parfactors(on)
    ->  % get all variables to marginalise
        attributes:all_attvars(AttVars0),
        % and order them
        sort_vars_by_key(AttVars0, QueryVars, []),
        % get the ground network
        generate_network([QueryVars], _, Keys, Factors, EList),
        % get the EM CPT connections info from the factors
        generate_dists(Factors, EList, AllDists, AllDistInstances, MargVars),
        % setup solver, if necessary
        clpbn_init_solver(Solver, MargVars, _,
                          ground(MargVars, Keys, Factors, EList),
                          SolverState)
    ;   % get all variables to marginalise
        attributes:all_attvars(AttVars0),
        % and order them
        sort_vars_by_key(AttVars0, SortedVars, []),
        % collect the distributions to learn and their instances
        different_dists(SortedVars, AllDists, AllDistInstances, MargVars),
        % setup solver by doing parameter independent work.
        clpbn_init_solver(Solver, MargVars, SortedVars, _, SolverState)
    ).
|
||||
@@ -97,7 +123,8 @@ init_em(Items, state( AllDists, AllDistInstances, MargVars, SolverVars)) :-
|
||||
em_loop(Its, Likelihood0, State, MaxError, MaxIts, LikelihoodF, FTables) :-
|
||||
estimate(State, LPs),
|
||||
maximise(State, Tables, LPs, Likelihood),
|
||||
writeln(iteration:Its:Likelihood:Its:Likelihood0:Tables),
|
||||
ltables(Tables, F0Tables),
|
||||
writeln(iteration:Its:Likelihood:Its:Likelihood0:F0Tables),
|
||||
(
|
||||
(
|
||||
abs((Likelihood - Likelihood0)/Likelihood) < MaxError
|
||||
@@ -118,6 +145,41 @@ ltables([Id-T|Tables], [Key-LTable|FTables]) :-
|
||||
get_dist_key(Id, Key),
|
||||
ltables(Tables, FTables).
|
||||
|
||||
|
||||
% generate_dists(+Factors, +EList, -AllDists, -AllInfo, -MargVars)
%
% Derive the EM distribution information from a ground factor network.
% The evidence list is loaded into a hash, every factor is turned into
% an i/4 instance record, the records are sorted and handed to group/5,
% and the marginalisation variables it collects are returned sorted.
generate_dists(Factors, EList, AllDists, AllInfo, MargVars) :-
    b_hash_new(EmptyEvidence),
    elist_to_hash(EList, EmptyEvidence, Evidence),
    process_factors(Factors, Evidence, Instances),
    sort(Instances, SortedInstances),
    group(SortedInstances, AllDists, AllInfo, UnsortedMargVars, []),
    sort(UnsortedMargVars, MargVars).
|
||||
|
||||
% elist_to_hash(+EList, +Hash0, -Hash)
%
% Insert every Key=Value pair of EList into the hash table Hash0,
% threading the table through the list, and return the result in Hash.
elist_to_hash([], Hash, Hash).
elist_to_hash([Key=Value|Pairs], Hash0, Hash) :-
    b_hash_insert(Hash0, Key, Value, Hash1),
    elist_to_hash(Pairs, Hash1, Hash).
|
||||
|
||||
% process_factors(+Factors, +Ev, -Instances)
%
% Map each f(bayes, Id, Keys) factor to an instance record
%     i(Id, Keys, Cases, NonEvs)
% where NonEvs are the keys with no entry in the evidence hash Ev and
% Cases is the expansion (via uncompact_cases/2) of the per-key values
% produced by fetch_evidence/4.
% Only factors of type bayes are matched; any other factor makes the
% call fail.
process_factors([], _Ev, []).
process_factors([f(bayes,Id,Keys)|Rest], Ev, [i(Id, Keys, Cases, NonEvs)|Instances]) :-
    fetch_evidence(Keys, Ev, Compact, NonEvs),
    uncompact_cases(Compact, Cases),
    process_factors(Rest, Ev, Instances).
|
||||
|
||||
% fetch_evidence(+Keys, +Ev, -CompactCases, -NonEvs)
%
% Walk the keys of a factor.  A key found in the evidence hash Ev
% contributes its stored value to CompactCases.  A key that is not
% found is added to NonEvs and contributes the list of indices
% 0..N-1, one per element of the domain reported by pfl:skolem/2.
fetch_evidence([], _Ev, [], []).
fetch_evidence([Key|Keys], Ev, [Case|Cases], NonEvs) :-
    (   b_hash_lookup(Key, Case, Ev)
    ->  % observed: the case is the evidence value itself
        fetch_evidence(Keys, Ev, Cases, NonEvs)
    ;   % unobserved: enumerate the whole domain
        NonEvs = [Key|MoreNonEvs],
        pfl:skolem(Key, Domain),
        domain_to_numbers(Domain, 0, Case),
        fetch_evidence(Keys, Ev, Cases, MoreNonEvs)
    ).
|
||||
|
||||
% domain_to_numbers(+Domain, +First, -Numbers)
%
% Numbers holds one consecutive integer per element of Domain,
% starting at First; the elements themselves are ignored.
domain_to_numbers([], _, []).
domain_to_numbers([_|Values], Index, [Index|Indices]) :-
    Next is Index+1,
    domain_to_numbers(Values, Next, Indices).
|
||||
|
||||
|
||||
% collect the different dists we are going to learn next.
|
||||
different_dists(AllVars, AllDists, AllInfo, MargVars) :-
|
||||
all_dists(AllVars, AllVars, Dists0),
|
||||
@@ -134,14 +196,6 @@ different_dists(AllVars, AllDists, AllInfo, MargVars) :-
|
||||
% Hiddens will be C
|
||||
%
|
||||
all_dists([], _, []).
|
||||
all_dists([V|AllVars], AllVars0, [i(Id, [V|Parents], Cases, Hiddens)|Dists]) :-
|
||||
clpbn:use_parfactors(on), !,
|
||||
clpbn:get_atts(V, [key(K)]),
|
||||
pfl:factor(bayes,Id,[K|PKeys],_,_,_),
|
||||
find_variables(PKeys, AllVars0, Parents),
|
||||
generate_hidden_cases([V|Parents], CompactCases, Hiddens),
|
||||
uncompact_cases(CompactCases, Cases),
|
||||
all_dists(AllVars, AllVars0, Dists).
|
||||
all_dists([V|AllVars], AllVars0, [i(Id, [V|Parents], Cases, Hiddens)|Dists]) :-
|
||||
% V is an instance of Id
|
||||
clpbn:get_atts(V, [dist(Id,Parents)]),
|
||||
@@ -181,18 +235,12 @@ generate_hidden_cases([], [], []).
|
||||
generate_hidden_cases([V|Parents], [P|Cases], Hiddens) :-
|
||||
clpbn:get_atts(V, [evidence(P)]), !,
|
||||
generate_hidden_cases(Parents, Cases, Hiddens).
|
||||
generate_hidden_cases([V|Parents], [Cases|MoreCases], [V|Hiddens]) :-
|
||||
clpbn:use_parfactors(on), !,
|
||||
clpbn:get_atts(V, [key(K)]),
|
||||
pfl:skolem(K,D), length(D,Sz),
|
||||
gen_cases(0, Sz, Cases),
|
||||
generate_hidden_cases(Parents, MoreCases, Hiddens).
|
||||
generate_hidden_cases([V|Parents], [Cases|MoreCases], [V|Hiddens]) :-
|
||||
clpbn:get_atts(V, [dist(Id,_)]),
|
||||
get_dist_domain_size(Id, Sz),
|
||||
gen_cases(0, Sz, Cases),
|
||||
generate_hidden_cases(Parents, MoreCases, Hiddens).
|
||||
|
||||
|
||||
gen_cases(Sz, Sz, []) :- !.
|
||||
gen_cases(I, Sz, [I|Cases]) :-
|
||||
I1 is I+1,
|
||||
|
Reference in New Issue
Block a user