fix EM support
This commit is contained in:
parent
43365003dc
commit
c81bc96fd0
@ -1,5 +1,7 @@
|
||||
|
||||
:- module(jt, [jt/3]).
|
||||
:- module(jt, [jt/3,
|
||||
init_jt_solver/4,
|
||||
run_jt_solver/3]).
|
||||
|
||||
|
||||
:- use_module(library(dgraphs),
|
||||
@ -51,6 +53,9 @@
|
||||
:- use_module(library(lists),
|
||||
[reverse/2]).
|
||||
|
||||
:- use_module(library('clpbn/aggregates'),
|
||||
[check_for_agg_vars/2]).
|
||||
|
||||
:- use_module(library('clpbn/dists'),
|
||||
[get_dist_domain_size/2,
|
||||
get_dist_domain/2,
|
||||
@ -72,19 +77,43 @@
|
||||
:- use_module(library('clpbn/display'), [
|
||||
clpbn_bind_vals/3]).
|
||||
|
||||
:- use_module(library('clpbn/connected'),
|
||||
[
|
||||
init_influences/3,
|
||||
influences/5
|
||||
]).
|
||||
|
||||
|
||||
jt([[]],_,_) :- !.
|
||||
jt([LVs],Vs0,AllDiffs) :-
|
||||
get_graph(Vs0, BayesNet, CPTs, Evidence),
|
||||
jt(LLVs,Vs0,AllDiffs) :-
|
||||
init_jt_solver(LLVs, Vs0, AllDiffs, State),
|
||||
run_jt_solver(LLVs, LLPs, State),
|
||||
clpbn_bind_vals(LLVs,LLPs,AllDiffs).
|
||||
|
||||
|
||||
init_jt_solver(LLVs, Vs0, _, State) :-
|
||||
check_for_agg_vars(Vs0, Vs1),
|
||||
init_influences(Vs1, G, RG),
|
||||
init_jt_solver_for_questions(LLVs, G, RG, State).
|
||||
|
||||
init_jt_solver_for_questions([], _, _, []).
|
||||
init_jt_solver_for_questions([LLVs|MoreLLVs], G, RG, [state(JTree, Evidence)|State]) :-
|
||||
influences(LLVs, _, NVs0, G, RG),
|
||||
sort(NVs0, NVs),
|
||||
get_graph(NVs, BayesNet, CPTs, Evidence),
|
||||
build_jt(BayesNet, CPTs, JTree),
|
||||
init_jt_solver_for_questions(MoreLLVs, G, RG, State).
|
||||
|
||||
run_jt_solver([], [], []).
|
||||
run_jt_solver([LVs|MoreLVs], [LPs|MorePs], [state(JTree, Evidence)|MoreState]) :-
|
||||
% JTree is a dgraph
|
||||
% now our tree has cpts
|
||||
fill_with_cpts(JTree, NewTree),
|
||||
% write_tree(NewTree,0),
|
||||
propagate_evidence(Evidence, NewTree, EvTree),
|
||||
message_passing(EvTree, MTree),
|
||||
get_margins(MTree, LVs, LPs),
|
||||
clpbn_bind_vals([LVs],[LPs],AllDiffs).
|
||||
|
||||
get_margin(MTree, LVs, LPs),
|
||||
run_jt_solver(MoreLVs, MorePs, MoreState).
|
||||
|
||||
get_graph(LVs, BayesNet, CPTs, Evidence) :-
|
||||
run_vars(LVs, Edges, Vertices, CPTs, Evidence),
|
||||
@ -462,12 +491,7 @@ downward([tree(Clique1-(Dist1,Msg1),DistKids)|Kids], Clique, Tab, [tree(Clique1-
|
||||
downward(Kids, Clique, Tab, NKids).
|
||||
|
||||
|
||||
get_margins(_, [], []).
|
||||
get_margins(NewTree, [Vs|LVs], [LP|LPs]) :-
|
||||
get_margin(NewTree, Vs, LP),
|
||||
get_margins(NewTree, LVs, LPs).
|
||||
|
||||
get_margin(NewTree, LVs, LPs) :-
|
||||
get_margin(NewTree, LVs0, LPs) :-
|
||||
sort(LVs0, LVs),
|
||||
find_clique(NewTree, LVs, Clique, Dist),
|
||||
sum_out_from_CPT(LVs, Dist, Clique, tab(TAB,_,_)),
|
||||
|
Reference in New Issue
Block a user