fix EM support

This commit is contained in:
Vitor Santos Costa 2008-11-04 03:30:12 +00:00
parent 43365003dc
commit c81bc96fd0

View File

@ -1,5 +1,7 @@
:- module(jt, [jt/3]).
:- module(jt, [jt/3,
init_jt_solver/4,
run_jt_solver/3]).
:- use_module(library(dgraphs),
@ -51,6 +53,9 @@
:- use_module(library(lists),
[reverse/2]).
:- use_module(library('clpbn/aggregates'),
[check_for_agg_vars/2]).
:- use_module(library('clpbn/dists'),
[get_dist_domain_size/2,
get_dist_domain/2,
@ -72,19 +77,43 @@
:- use_module(library('clpbn/display'), [
clpbn_bind_vals/3]).
:- use_module(library('clpbn/connected'),
[
init_influences/3,
influences/5
]).
jt([[]],_,_) :- !.
jt([LVs],Vs0,AllDiffs) :-
get_graph(Vs0, BayesNet, CPTs, Evidence),
jt(LLVs,Vs0,AllDiffs) :-
init_jt_solver(LLVs, Vs0, AllDiffs, State),
run_jt_solver(LLVs, LLPs, State),
clpbn_bind_vals(LLVs,LLPs,AllDiffs).
init_jt_solver(LLVs, Vs0, _, State) :-
check_for_agg_vars(Vs0, Vs1),
init_influences(Vs1, G, RG),
init_jt_solver_for_questions(LLVs, G, RG, State).
init_jt_solver_for_questions([], _, _, []).
init_jt_solver_for_questions([LLVs|MoreLLVs], G, RG, [state(JTree, Evidence)|State]) :-
influences(LLVs, _, NVs0, G, RG),
sort(NVs0, NVs),
get_graph(NVs, BayesNet, CPTs, Evidence),
build_jt(BayesNet, CPTs, JTree),
init_jt_solver_for_questions(MoreLLVs, G, RG, State).
run_jt_solver([], [], []).
run_jt_solver([LVs|MoreLVs], [LPs|MorePs], [state(JTree, Evidence)|MoreState]) :-
% JTree is a dgraph
% now our tree has cpts
fill_with_cpts(JTree, NewTree),
% write_tree(NewTree,0),
propagate_evidence(Evidence, NewTree, EvTree),
message_passing(EvTree, MTree),
get_margins(MTree, LVs, LPs),
clpbn_bind_vals([LVs],[LPs],AllDiffs).
get_margin(MTree, LVs, LPs),
run_jt_solver(MoreLVs, MorePs, MoreState).
get_graph(LVs, BayesNet, CPTs, Evidence) :-
run_vars(LVs, Edges, Vertices, CPTs, Evidence),
@ -462,12 +491,7 @@ downward([tree(Clique1-(Dist1,Msg1),DistKids)|Kids], Clique, Tab, [tree(Clique1-
downward(Kids, Clique, Tab, NKids).
get_margins(_, [], []).
get_margins(NewTree, [Vs|LVs], [LP|LPs]) :-
get_margin(NewTree, Vs, LP),
get_margins(NewTree, LVs, LPs).
get_margin(NewTree, LVs, LPs) :-
get_margin(NewTree, LVs0, LPs) :-
sort(LVs0, LVs),
find_clique(NewTree, LVs, Clique, Dist),
sum_out_from_CPT(LVs, Dist, Clique, tab(TAB,_,_)),