fix EM support
This commit is contained in:
parent
43365003dc
commit
c81bc96fd0
@ -1,5 +1,7 @@
|
|||||||
|
|
||||||
:- module(jt, [jt/3]).
|
:- module(jt, [jt/3,
|
||||||
|
init_jt_solver/4,
|
||||||
|
run_jt_solver/3]).
|
||||||
|
|
||||||
|
|
||||||
:- use_module(library(dgraphs),
|
:- use_module(library(dgraphs),
|
||||||
@ -51,6 +53,9 @@
|
|||||||
:- use_module(library(lists),
|
:- use_module(library(lists),
|
||||||
[reverse/2]).
|
[reverse/2]).
|
||||||
|
|
||||||
|
:- use_module(library('clpbn/aggregates'),
|
||||||
|
[check_for_agg_vars/2]).
|
||||||
|
|
||||||
:- use_module(library('clpbn/dists'),
|
:- use_module(library('clpbn/dists'),
|
||||||
[get_dist_domain_size/2,
|
[get_dist_domain_size/2,
|
||||||
get_dist_domain/2,
|
get_dist_domain/2,
|
||||||
@ -72,19 +77,43 @@
|
|||||||
:- use_module(library('clpbn/display'), [
|
:- use_module(library('clpbn/display'), [
|
||||||
clpbn_bind_vals/3]).
|
clpbn_bind_vals/3]).
|
||||||
|
|
||||||
|
:- use_module(library('clpbn/connected'),
|
||||||
|
[
|
||||||
|
init_influences/3,
|
||||||
|
influences/5
|
||||||
|
]).
|
||||||
|
|
||||||
|
|
||||||
jt([[]],_,_) :- !.
|
jt([[]],_,_) :- !.
|
||||||
jt([LVs],Vs0,AllDiffs) :-
|
jt(LLVs,Vs0,AllDiffs) :-
|
||||||
get_graph(Vs0, BayesNet, CPTs, Evidence),
|
init_jt_solver(LLVs, Vs0, AllDiffs, State),
|
||||||
|
run_jt_solver(LLVs, LLPs, State),
|
||||||
|
clpbn_bind_vals(LLVs,LLPs,AllDiffs).
|
||||||
|
|
||||||
|
|
||||||
|
init_jt_solver(LLVs, Vs0, _, State) :-
|
||||||
|
check_for_agg_vars(Vs0, Vs1),
|
||||||
|
init_influences(Vs1, G, RG),
|
||||||
|
init_jt_solver_for_questions(LLVs, G, RG, State).
|
||||||
|
|
||||||
|
init_jt_solver_for_questions([], _, _, []).
|
||||||
|
init_jt_solver_for_questions([LLVs|MoreLLVs], G, RG, [state(JTree, Evidence)|State]) :-
|
||||||
|
influences(LLVs, _, NVs0, G, RG),
|
||||||
|
sort(NVs0, NVs),
|
||||||
|
get_graph(NVs, BayesNet, CPTs, Evidence),
|
||||||
build_jt(BayesNet, CPTs, JTree),
|
build_jt(BayesNet, CPTs, JTree),
|
||||||
|
init_jt_solver_for_questions(MoreLLVs, G, RG, State).
|
||||||
|
|
||||||
|
run_jt_solver([], [], []).
|
||||||
|
run_jt_solver([LVs|MoreLVs], [LPs|MorePs], [state(JTree, Evidence)|MoreState]) :-
|
||||||
% JTree is a dgraph
|
% JTree is a dgraph
|
||||||
% now our tree has cpts
|
% now our tree has cpts
|
||||||
fill_with_cpts(JTree, NewTree),
|
fill_with_cpts(JTree, NewTree),
|
||||||
% write_tree(NewTree,0),
|
% write_tree(NewTree,0),
|
||||||
propagate_evidence(Evidence, NewTree, EvTree),
|
propagate_evidence(Evidence, NewTree, EvTree),
|
||||||
message_passing(EvTree, MTree),
|
message_passing(EvTree, MTree),
|
||||||
get_margins(MTree, LVs, LPs),
|
get_margin(MTree, LVs, LPs),
|
||||||
clpbn_bind_vals([LVs],[LPs],AllDiffs).
|
run_jt_solver(MoreLVs, MorePs, MoreState).
|
||||||
|
|
||||||
|
|
||||||
get_graph(LVs, BayesNet, CPTs, Evidence) :-
|
get_graph(LVs, BayesNet, CPTs, Evidence) :-
|
||||||
run_vars(LVs, Edges, Vertices, CPTs, Evidence),
|
run_vars(LVs, Edges, Vertices, CPTs, Evidence),
|
||||||
@ -462,12 +491,7 @@ downward([tree(Clique1-(Dist1,Msg1),DistKids)|Kids], Clique, Tab, [tree(Clique1-
|
|||||||
downward(Kids, Clique, Tab, NKids).
|
downward(Kids, Clique, Tab, NKids).
|
||||||
|
|
||||||
|
|
||||||
get_margins(_, [], []).
|
get_margin(NewTree, LVs0, LPs) :-
|
||||||
get_margins(NewTree, [Vs|LVs], [LP|LPs]) :-
|
|
||||||
get_margin(NewTree, Vs, LP),
|
|
||||||
get_margins(NewTree, LVs, LPs).
|
|
||||||
|
|
||||||
get_margin(NewTree, LVs, LPs) :-
|
|
||||||
sort(LVs0, LVs),
|
sort(LVs0, LVs),
|
||||||
find_clique(NewTree, LVs, Clique, Dist),
|
find_clique(NewTree, LVs, Clique, Dist),
|
||||||
sum_out_from_CPT(LVs, Dist, Clique, tab(TAB,_,_)),
|
sum_out_from_CPT(LVs, Dist, Clique, tab(TAB,_,_)),
|
||||||
|
Reference in New Issue
Block a user