fix EM support
This commit is contained in:
		@@ -1,5 +1,7 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
:- module(jt, [jt/3]).
 | 
					:- module(jt, [jt/3,
 | 
				
			||||||
 | 
						       init_jt_solver/4,
 | 
				
			||||||
 | 
						       run_jt_solver/3]).
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
:- use_module(library(dgraphs),
 | 
					:- use_module(library(dgraphs),
 | 
				
			||||||
@@ -51,6 +53,9 @@
 | 
				
			|||||||
:- use_module(library(lists),
 | 
					:- use_module(library(lists),
 | 
				
			||||||
	      [reverse/2]).
 | 
						      [reverse/2]).
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					:- use_module(library('clpbn/aggregates'),
 | 
				
			||||||
 | 
						      [check_for_agg_vars/2]).
 | 
				
			||||||
 | 
					
 | 
				
			||||||
:- use_module(library('clpbn/dists'),
 | 
					:- use_module(library('clpbn/dists'),
 | 
				
			||||||
	      [get_dist_domain_size/2,
 | 
						      [get_dist_domain_size/2,
 | 
				
			||||||
	       get_dist_domain/2,
 | 
						       get_dist_domain/2,
 | 
				
			||||||
@@ -72,19 +77,43 @@
 | 
				
			|||||||
:- use_module(library('clpbn/display'), [
 | 
					:- use_module(library('clpbn/display'), [
 | 
				
			||||||
	clpbn_bind_vals/3]).
 | 
						clpbn_bind_vals/3]).
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					:- use_module(library('clpbn/connected'),
 | 
				
			||||||
 | 
						      [
 | 
				
			||||||
 | 
						       init_influences/3,
 | 
				
			||||||
 | 
						       influences/5
 | 
				
			||||||
 | 
						      ]).
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
jt([[]],_,_) :- !.
 | 
					jt([[]],_,_) :- !.
 | 
				
			||||||
jt([LVs],Vs0,AllDiffs) :-
 | 
					jt(LLVs,Vs0,AllDiffs) :-
 | 
				
			||||||
	get_graph(Vs0, BayesNet, CPTs, Evidence),
 | 
						init_jt_solver(LLVs, Vs0, AllDiffs, State),
 | 
				
			||||||
 | 
						run_jt_solver(LLVs, LLPs, State),
 | 
				
			||||||
 | 
						clpbn_bind_vals(LLVs,LLPs,AllDiffs).
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					init_jt_solver(LLVs, Vs0, _, State) :-
 | 
				
			||||||
 | 
						check_for_agg_vars(Vs0, Vs1), 
 | 
				
			||||||
 | 
						init_influences(Vs1, G, RG),
 | 
				
			||||||
 | 
						init_jt_solver_for_questions(LLVs, G, RG, State).
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					init_jt_solver_for_questions([], _, _, []).
 | 
				
			||||||
 | 
					init_jt_solver_for_questions([LLVs|MoreLLVs], G, RG, [state(JTree, Evidence)|State]) :-
 | 
				
			||||||
 | 
						influences(LLVs, _, NVs0, G, RG),
 | 
				
			||||||
 | 
						sort(NVs0, NVs),
 | 
				
			||||||
 | 
						get_graph(NVs, BayesNet, CPTs, Evidence),
 | 
				
			||||||
	build_jt(BayesNet, CPTs, JTree),
 | 
						build_jt(BayesNet, CPTs, JTree),
 | 
				
			||||||
 | 
						init_jt_solver_for_questions(MoreLLVs, G, RG, State).
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					run_jt_solver([], [], []).
 | 
				
			||||||
 | 
					run_jt_solver([LVs|MoreLVs], [LPs|MorePs], [state(JTree, Evidence)|MoreState]) :-
 | 
				
			||||||
	% JTree is a dgraph
 | 
						% JTree is a dgraph
 | 
				
			||||||
	% now our tree has cpts
 | 
						% now our tree has cpts
 | 
				
			||||||
	fill_with_cpts(JTree, NewTree),
 | 
						fill_with_cpts(JTree, NewTree),
 | 
				
			||||||
%	write_tree(NewTree,0),
 | 
					%	write_tree(NewTree,0),
 | 
				
			||||||
	propagate_evidence(Evidence, NewTree, EvTree),
 | 
						propagate_evidence(Evidence, NewTree, EvTree),
 | 
				
			||||||
	message_passing(EvTree, MTree),
 | 
						message_passing(EvTree, MTree),
 | 
				
			||||||
	get_margins(MTree, LVs, LPs),
 | 
						get_margin(MTree, LVs, LPs),
 | 
				
			||||||
	clpbn_bind_vals([LVs],[LPs],AllDiffs).
 | 
						run_jt_solver(MoreLVs, MorePs, MoreState).
 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
get_graph(LVs, BayesNet, CPTs, Evidence) :-
 | 
					get_graph(LVs, BayesNet, CPTs, Evidence) :-
 | 
				
			||||||
	run_vars(LVs, Edges, Vertices, CPTs, Evidence),
 | 
						run_vars(LVs, Edges, Vertices, CPTs, Evidence),
 | 
				
			||||||
@@ -462,12 +491,7 @@ downward([tree(Clique1-(Dist1,Msg1),DistKids)|Kids], Clique, Tab, [tree(Clique1-
 | 
				
			|||||||
	downward(Kids, Clique, Tab, NKids).
 | 
						downward(Kids, Clique, Tab, NKids).
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
get_margins(_, [], []).
 | 
					get_margin(NewTree, LVs0, LPs) :-
 | 
				
			||||||
get_margins(NewTree, [Vs|LVs], [LP|LPs]) :-
 | 
					 | 
				
			||||||
	get_margin(NewTree, Vs, LP),
 | 
					 | 
				
			||||||
	get_margins(NewTree, LVs, LPs).
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
get_margin(NewTree, LVs, LPs) :-
 | 
					 | 
				
			||||||
	sort(LVs0, LVs),
 | 
						sort(LVs0, LVs),
 | 
				
			||||||
	find_clique(NewTree, LVs, Clique, Dist),
 | 
						find_clique(NewTree, LVs, Clique, Dist),
 | 
				
			||||||
	sum_out_from_CPT(LVs, Dist, Clique, tab(TAB,_,_)),
 | 
						sum_out_from_CPT(LVs, Dist, Clique, tab(TAB,_,_)),
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user