fix EM support
This commit is contained in:
		@@ -1,5 +1,7 @@
 | 
			
		||||
 | 
			
		||||
:- module(jt, [jt/3]).
 | 
			
		||||
:- module(jt, [jt/3,
 | 
			
		||||
	       init_jt_solver/4,
 | 
			
		||||
	       run_jt_solver/3]).
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
:- use_module(library(dgraphs),
 | 
			
		||||
@@ -51,6 +53,9 @@
 | 
			
		||||
:- use_module(library(lists),
 | 
			
		||||
	      [reverse/2]).
 | 
			
		||||
 | 
			
		||||
:- use_module(library('clpbn/aggregates'),
 | 
			
		||||
	      [check_for_agg_vars/2]).
 | 
			
		||||
 | 
			
		||||
:- use_module(library('clpbn/dists'),
 | 
			
		||||
	      [get_dist_domain_size/2,
 | 
			
		||||
	       get_dist_domain/2,
 | 
			
		||||
@@ -72,19 +77,43 @@
 | 
			
		||||
:- use_module(library('clpbn/display'), [
 | 
			
		||||
	clpbn_bind_vals/3]).
 | 
			
		||||
 | 
			
		||||
:- use_module(library('clpbn/connected'),
 | 
			
		||||
	      [
 | 
			
		||||
	       init_influences/3,
 | 
			
		||||
	       influences/5
 | 
			
		||||
	      ]).
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
jt([[]],_,_) :- !.
 | 
			
		||||
jt([LVs],Vs0,AllDiffs) :-
 | 
			
		||||
	get_graph(Vs0, BayesNet, CPTs, Evidence),
 | 
			
		||||
jt(LLVs,Vs0,AllDiffs) :-
 | 
			
		||||
	init_jt_solver(LLVs, Vs0, AllDiffs, State),
 | 
			
		||||
	run_jt_solver(LLVs, LLPs, State),
 | 
			
		||||
	clpbn_bind_vals(LLVs,LLPs,AllDiffs).
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
init_jt_solver(LLVs, Vs0, _, State) :-
 | 
			
		||||
	check_for_agg_vars(Vs0, Vs1), 
 | 
			
		||||
	init_influences(Vs1, G, RG),
 | 
			
		||||
	init_jt_solver_for_questions(LLVs, G, RG, State).
 | 
			
		||||
 | 
			
		||||
init_jt_solver_for_questions([], _, _, []).
 | 
			
		||||
init_jt_solver_for_questions([LLVs|MoreLLVs], G, RG, [state(JTree, Evidence)|State]) :-
 | 
			
		||||
	influences(LLVs, _, NVs0, G, RG),
 | 
			
		||||
	sort(NVs0, NVs),
 | 
			
		||||
	get_graph(NVs, BayesNet, CPTs, Evidence),
 | 
			
		||||
	build_jt(BayesNet, CPTs, JTree),
 | 
			
		||||
	init_jt_solver_for_questions(MoreLLVs, G, RG, State).
 | 
			
		||||
 | 
			
		||||
run_jt_solver([], [], []).
 | 
			
		||||
run_jt_solver([LVs|MoreLVs], [LPs|MorePs], [state(JTree, Evidence)|MoreState]) :-
 | 
			
		||||
	% JTree is a dgraph
 | 
			
		||||
	% now our tree has cpts
 | 
			
		||||
	fill_with_cpts(JTree, NewTree),
 | 
			
		||||
%	write_tree(NewTree,0),
 | 
			
		||||
	propagate_evidence(Evidence, NewTree, EvTree),
 | 
			
		||||
	message_passing(EvTree, MTree),
 | 
			
		||||
	get_margins(MTree, LVs, LPs),
 | 
			
		||||
	clpbn_bind_vals([LVs],[LPs],AllDiffs).
 | 
			
		||||
 | 
			
		||||
	get_margin(MTree, LVs, LPs),
 | 
			
		||||
	run_jt_solver(MoreLVs, MorePs, MoreState).
 | 
			
		||||
 | 
			
		||||
get_graph(LVs, BayesNet, CPTs, Evidence) :-
 | 
			
		||||
	run_vars(LVs, Edges, Vertices, CPTs, Evidence),
 | 
			
		||||
@@ -462,12 +491,7 @@ downward([tree(Clique1-(Dist1,Msg1),DistKids)|Kids], Clique, Tab, [tree(Clique1-
 | 
			
		||||
	downward(Kids, Clique, Tab, NKids).
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
get_margins(_, [], []).
 | 
			
		||||
get_margins(NewTree, [Vs|LVs], [LP|LPs]) :-
 | 
			
		||||
	get_margin(NewTree, Vs, LP),
 | 
			
		||||
	get_margins(NewTree, LVs, LPs).
 | 
			
		||||
 | 
			
		||||
get_margin(NewTree, LVs, LPs) :-
 | 
			
		||||
get_margin(NewTree, LVs0, LPs) :-
 | 
			
		||||
	sort(LVs0, LVs),
 | 
			
		||||
	find_clique(NewTree, LVs, Clique, Dist),
 | 
			
		||||
	sum_out_from_CPT(LVs, Dist, Clique, tab(TAB,_,_)),
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user