implementation of bayes ball.

This commit is contained in:
Vitor Santos Costa 2011-05-27 21:34:55 +01:00
parent a57cd039d8
commit e9171547b9
7 changed files with 136 additions and 171 deletions

View File

@ -210,27 +210,21 @@ delete_remaining_edges(SortedVs,Vs0,Vsf) :-
dgraph_transpose(Graph, TGraph) :- dgraph_transpose(Graph, TGraph) :-
rb_visit(Graph, Edges), rb_visit(Graph, Edges),
rb_clone(Graph, TGraph, NewNodes), transpose(Edges, Nodes, TEdges, []),
tedges(Edges,UnsortedTEdges), dgraph_new(G0),
sort(UnsortedTEdges,TEdges), % make sure we have all vertices, even if they are unconnected.
fill_nodes(NewNodes,TEdges). dgraph_add_vertices(G0, Nodes, G1),
dgraph_add_edges(G1, TEdges, TGraph).
tedges([],[]). transpose([], []) --> [].
tedges([V-Vs|Edges],TEdges) :- transpose([V-Edges|MoreVs], [V|Vs]) -->
fill_tedges(Vs, V, TEdges, TEdges0), transpose_edges(Edges, V),
tedges(Edges,TEdges0). transpose(MoreVs, Vs).
fill_tedges([], _, TEdges, TEdges). transpose_edges([], _V) --> [].
fill_tedges([V1|Vs], V, [V1-V|TEdges], TEdges0) :- transpose_edges(E.Edges, V) -->
fill_tedges(Vs, V, TEdges, TEdges0). [E-V],
transpose_edges(Edges, V).
fill_nodes([],[]).
fill_nodes([V-[Child|MoreChildren]|Nodes],[V-Child|Edges]) :- !,
get_extra_children(Edges,V,MoreChildren,REdges),
fill_nodes(Nodes,REdges).
fill_nodes([_-[]|Edges],TEdges) :-
fill_nodes(Edges,TEdges).
dgraph_compose(T1,T2,CT) :- dgraph_compose(T1,T2,CT) :-
rb_visit(T1,Nodes), rb_visit(T1,Nodes),

View File

@ -1,172 +1,144 @@
:- module(clpbn_connected, :- module(clpbn_connected,
[clpbn_subgraphs/2, [influences/3,
influences/4,
init_influences/3, init_influences/3,
influences/5]). influences/4]).
:- use_module(library(dgraphs), :- use_module(library(dgraphs),
[dgraph_new/1, [dgraph_new/1,
dgraph_add_edges/3, dgraph_add_edges/3,
dgraph_add_vertex/3, dgraph_add_vertex/3,
dgraph_neighbors/3, dgraph_neighbors/3,
dgraph_edge/3]). dgraph_edge/3,
dgraph_transpose/2]).
:- use_module(library(rbtrees), :- use_module(library(rbtrees),
[rb_new/1, [rb_new/1,
rb_lookup/3,
rb_insert/4, rb_insert/4,
rb_lookup/3]). rb_visit/2]).
:- attribute component/1. influences(Vs, QVars, LV) :-
% search for connected components, that is, where we know that A influences B or B influences A.
clpbn_subgraphs(Vs, Gs) :-
mark_components(Vs, Components),
keysort(Components, Ordered),
same_key(Ordered, Gs).
% ignore variables with evidence,
% the others mark the MB.
mark_components([], []).
mark_components([V|Vs], Components) :-
clpbn:get_atts(V, [evidence(_),dist(_,Parents)]), !,
merge_parents(Parents, _),
mark_components(Vs, Components).
mark_components([V|Vs], [Mark-V|Components]) :-
mark_var(V, Mark),
mark_components(Vs, Components).
mark_var(V, Mark) :-
get_atts(V, [component(Mark)]), !,
clpbn:get_atts(V, [dist(_,Parents)]), !,
merge_parents(Parents, Mark).
mark_var(V, Mark) :-
clpbn:get_atts(V, [dist(_,Parents)]), !,
put_atts(V,[component(Mark)]),
merge_parents(Parents, Mark).
merge_parents([], _).
merge_parents([V|Parents], Mark) :-
clpbn:get_atts(V,[evidence(_)]), !,
merge_parents(Parents, Mark).
merge_parents([V|Parents], Mark) :-
get_atts(V,[component(Mark)]), !,
merge_parents(Parents, Mark).
merge_parents([V|Parents], Mark) :-
put_atts(V,[component(Mark)]),
merge_parents(Parents, Mark).
same_key([],[]).
same_key([K-El|More],[[El|Els]|Gs]) :-
same_keys(More, K, Els, Rest),
same_key(Rest,Gs).
same_keys([], _, [], []).
same_keys([K1-El|More], K, [El|Els], Rest) :-
K == K1, !,
same_keys(More, K, Els, Rest).
same_keys(Rest, _, [], Rest).
influences_more([], _, _, Is, Is, Evs, Evs, V2, V2).
influences_more([V|LV], G, RG, Is0, Is, Evs0, Evs, GV0, GV2) :-
rb_lookup(V, _, GV0), !,
influences_more(LV, G, RG, Is0, Is, Evs0, Evs, GV0, GV2).
influences_more([V|LV], G, RG, Is0, Is, Evs0, Evs, GV0, GV3) :-
rb_insert(GV0, V, _, GV1),
follow_dgraph(V, G, RG, [V|Is0], Is1, [V|Evs0], Evs1, GV1, GV2),
influences_more(LV, G, RG, Is1, Is, Evs1, Evs, GV2, GV3).
% search for the set of variables that influence V
influences(Vs, LV, Is, Evs) :-
init_influences(Vs, G, RG), init_influences(Vs, G, RG),
influences(LV, Is, Evs, G, RG). influences(QVars, G, RG, LV).
init_influences(Vs, G, RG) :- init_influences(Vs, G, RG) :-
dgraph_new(G0), dgraph_new(G0),
dgraph_new(RG0), to_dgraph(Vs, G0, G),
to_dgraph(Vs, G0, G, RG0, RG). dgraph_transpose(G, RG).
influences([], [], [], _, _). to_dgraph([], G, G).
influences([V|LV], Is, Evs, G, RG) :- to_dgraph([V|Vs], G0, G) :-
rb_new(V0), clpbn:get_atts(V, [dist(_,Parents)]), !,
rb_insert(V0, V, _, V1), dgraph_add_vertex(G0, V, G00),
follow_dgraph(V, G, RG, [V], Is1, [V], Evs1, V1, V2), build_edges(Parents, V, Edges),
influences_more(LV, G, RG, Is1, Is, Evs1, Evs, V2, _). dgraph_add_edges(G00, Edges, G1),
to_dgraph(Vs, G1, G).
to_dgraph([], G, G, RG, RG). build_edges([], _, []).
to_dgraph([V|Vs], G0, G, RG0, RG) :- build_edges([P|Parents], V, [P-V|Edges]) :-
clpbn:get_atts(V, [evidence(_),dist(_,Parents)]), !, build_edges(Parents, V, Edges).
build_edges(Parents, V, Edges, REdges),
dgraph_add_edges(G0,[V-e|Edges],G1),
dgraph_add_edges(RG0,REdges,RG1),
to_dgraph(Vs, G1, G, RG1, RG).
to_dgraph([V|Vs], G0, G, RG0, RG) :-
clpbn:get_atts(V, [dist(_,Parents)]),
build_edges(Parents, V, Edges, REdges),
dgraph_add_vertex(G0,V,G1),
dgraph_add_edges(G1, Edges, G2),
dgraph_add_vertex(RG0,V,RG1),
dgraph_add_edges(RG1, REdges, RG2),
to_dgraph(Vs, G2, G, RG2, RG).
% search for the set of variables that influence V
influences(Vs, G, RG, Vars) :-
rb_new(Visited0),
influences(Vs, G, RG, Visited0, Visited),
all_top(Visited, Vars),
length(Vars,Leng), writeln(done:Leng).
build_edges([], _, [], []). influences([], _, _, Visited, Visited).
build_edges([P|Parents], V, [P-V|Edges], [V-P|REdges]) :- influences([V|LV], G, RG, Vs, NVs) :-
build_edges(Parents, V, Edges, REdges). rb_lookup(V, T.B, Vs), T == t, B == b, !,
influences(LV, G, RG, Vs, NVs).
influences([V|LV], G, RG, Vs0, Vs3) :-
rb_insert(Vs0, V, t.b, Vs1),
process_new_variable(V, G, RG, Vs1, Vs2),
influences(LV, G, RG, Vs2, Vs3).
follow_dgraph(V, G, RG, Is0, IsF, Evs0, EvsF, Visited0, Visited) :- process_new_variable(V, _G, _RG, _Vs0, _Vs1) :-
clpbn:get_atts(V,[evidence(Ev)]), !,
throw(error(bound_to_evidence(V/Ev))).
process_new_variable(V, G, RG, Vs0, Vs2) :-
dgraph_neighbors(V, G, Children),
throw_all_below(Children, G, RG, Vs0, Vs1),
dgraph_neighbors(V, RG, Parents), dgraph_neighbors(V, RG, Parents),
add_parents(Parents, G, RG, Is0, IsI, Evs0, EvsI, Visited0, Visited1), throw_all_above(Parents, G, RG, Vs1, Vs2).
dgraph_neighbors(V, G, Kids),
add_kids(Kids, G, RG, IsI, IsF, EvsI, EvsF, Visited1, Visited).
add_parents([], _, _, Is, Is, Evs, Evs, Visited, Visited). throw_all_below([], _, _, Vs, Vs).
% been here already, can safely ignore. throw_all_below(Child.Children, G, RG, Vs0, Vs2) :-
add_parents([V|Vs], G, RG, Is0, IsF, Evs0, EvsF, Visited0, VisitedF) :- % clpbn:get_atts(Child,[key(K)]), rb_visit(Vs0, Pairs), writeln(down:Child:K:Pairs),
rb_lookup(V, _, Visited0), !, throw_below(Child, G, RG, Vs0, Vs1),
add_parents(Vs, G, RG, Is0, IsF, Evs0, EvsF, Visited0, VisitedF). throw_all_below(Children, G, RG, Vs1, Vs2).
% evidence node,
% just say that we visited it
add_parents([V|Vs], G, RG, Is0, IsF, Evs0, EvsF, Visited0, VisitedF) :-
dgraph_edge(V,e,G), !, % has evidence
rb_insert(Visited0, V, _, VisitedI),
add_parents(Vs, G, RG, Is0, IsF, [V|Evs0], EvsF, VisitedI, VisitedF).
% non-evidence node,
% we will need to find its parents.
add_parents([V|Vs], G, RG, Is0, IsF, Evs0, EvsF, Visited0, VisitedF) :-
rb_insert(Visited0, V, _, VisitedI),
follow_dgraph(V, G, RG, [V|Is0], IsI, [V|Evs0], EvsI, VisitedI, VisitedII),
add_parents(Vs, G, RG, IsI, IsF, EvsI, EvsF, VisitedII, VisitedF).
add_kids([], _, _, Is, Is, Evs, Evs, Visited, Visited). % visited
add_kids([V|Vs], G, RG, Is0, IsF, Evs0, EvsF, Visited0, VisitedF) :- throw_below(Child, G, RG, Vs0, Vs1) :-
dgraph_edge(V,e,G), % has evidence rb_lookup(Child, _.B, Vs0), !,
% we will go there even if it was visited (
( rb_insert(Visited0, V, _, Visited1) -> B == b ->
true Vs0 = Vs1 % been there before
; ;
% we've been there, but were we there as a father or as a kid? B = b, % mark it
not_in(Evs0, V), handle_ball_from_above(Child, G, RG, Vs0, Vs1)
Visited1 = Visited0
),
!,
dgraph_neighbors(V, RG, Parents),
add_parents(Parents, G, RG, Is0, Is1, [V|Evs0], EvsI, Visited1, VisitedI),
(Is1 = Is0 ->
% ignore whatever we did with this node,
% it didn't lead anywhere (all parents have evidence).
add_kids(Vs, G, RG, Is0, IsF, [V|Evs0], EvsF, Visited1, VisitedF)
;
% insert parents
add_kids(Vs, G, RG, Is1, IsF, EvsI, EvsF, VisitedI, VisitedF)
). ).
add_kids([_|Vs], G, RG, Is0, IsF, Evs0, EvsF, Visited0, VisitedF) :- throw_below(Child, G, RG, Vs0, Vs2) :-
add_kids(Vs, G, RG, Is0, IsF, Evs0, EvsF, Visited0, VisitedF). rb_insert(Vs0, Child, _.b, Vs1),
handle_ball_from_above(Child, G, RG, Vs1, Vs2).
% share this with parents, if we have evidence
handle_ball_from_above(V, G, RG, Vs0, Vs1) :-
clpbn:get_atts(V,[evidence(_)]), !,
dgraph_neighbors(V, RG, Parents),
throw_all_above(Parents, G, RG, Vs0, Vs1).
% propagate to kids, if we do not
handle_ball_from_above(V, G, RG, Vs0, Vs1) :-
dgraph_neighbors(V, G, Children),
throw_all_below(Children, G, RG, Vs0, Vs1).
not_in([V1|_], V) :- V1 == V, !, fail. throw_all_above([], _, _, Vs, Vs).
not_in([_|Evs0], V) :- throw_all_above(Parent.Parentren, G, RG, Vs0, Vs2) :-
not_in(Evs0, V). % clpbn:get_atts(Parent,[key(K)]), rb_visit(Vs0, Pairs), writeln(up:Parent:K:Pairs),
throw_above(Parent, G, RG, Vs0, Vs1),
throw_all_above(Parentren, G, RG, Vs1, Vs2).
% visited
throw_above(Parent, G, RG, Vs0, Vs1) :-
rb_lookup(Parent, T._, Vs0), !,
(
T == t ->
Vs1 = Vs0 % been there before
;
T = t, % mark it
handle_ball_from_below(Parent, G, RG, Vs0, Vs1)
).
throw_above(Parent, G, RG, Vs0, Vs2) :-
rb_insert(Vs0, Parent, t._, Vs1),
handle_ball_from_below(Parent, G, RG, Vs1, Vs2).
% share this with parents, if we have evidence
handle_ball_from_below(V, _, _, Vs, Vs) :-
clpbn:get_atts(V,[evidence(_)]), !.
% propagate to kids, if we do not
handle_ball_from_below(V, G, RG, Vs0, Vs1) :-
dgraph_neighbors(V, RG, Parents),
propagate_ball_from_below(Parents, V, G, RG, Vs0, Vs1).
propagate_ball_from_below([], V, G, RG, Vs0, Vs1) :- !,
dgraph_neighbors(V, G, Children),
throw_all_below(Children, G, RG, Vs0, Vs1).
propagate_ball_from_below(Parents, _V, G, RG, Vs0, Vs1) :-
throw_all_above(Parents, G, RG, Vs0, Vs1).
all_top(T, Vs) :-
rb_visit(T, Pairs),
get_tops(Pairs, Vs).
get_tops([], []).
get_tops([V-(T._)|Pairs], V.Vs) :-
T == t, !,
get_tops(Pairs, Vs).
get_tops([V-_|Pairs], V.Vs) :-
clpbn:get_atts(V,[evidence(_)]), !,
get_tops(Pairs, Vs).
get_tops(_.Pairs, Vs) :-
get_tops(Pairs, Vs).

View File

@ -51,7 +51,7 @@
:- use_module(library('clpbn/connected'), :- use_module(library('clpbn/connected'),
[ [
influences/4 influences/3
]). ]).
:- dynamic gibbs_params/3. :- dynamic gibbs_params/3.
@ -73,7 +73,7 @@ init_gibbs_solver(GoalVs, Vs0, _, Vs) :-
clean_up, clean_up,
term_variables(GoalVs, LVs), term_variables(GoalVs, LVs),
check_for_hidden_vars(Vs0, Vs0, Vs1), check_for_hidden_vars(Vs0, Vs0, Vs1),
influences(Vs1, LVs, _, Vs2), influences(Vs1, LVs, Vs2),
sort(Vs2,Vs). sort(Vs2,Vs).
run_gibbs_solver(LVs, LPs, Vs) :- run_gibbs_solver(LVs, LPs, Vs) :-

View File

@ -80,7 +80,7 @@
:- use_module(library('clpbn/connected'), :- use_module(library('clpbn/connected'),
[ [
init_influences/3, init_influences/3,
influences/5 influences/4
]). ]).
@ -98,7 +98,7 @@ init_jt_solver(LLVs, Vs0, _, State) :-
init_jt_solver_for_questions([], _, _, []). init_jt_solver_for_questions([], _, _, []).
init_jt_solver_for_questions([LLVs|MoreLLVs], G, RG, [state(JTree, Evidence)|State]) :- init_jt_solver_for_questions([LLVs|MoreLLVs], G, RG, [state(JTree, Evidence)|State]) :-
influences(LLVs, _, NVs0, G, RG), influences(LLVs, G, RG, NVs0),
sort(NVs0, NVs), sort(NVs0, NVs),
get_graph(NVs, BayesNet, CPTs, Evidence), get_graph(NVs, BayesNet, CPTs, Evidence),
build_jt(BayesNet, CPTs, JTree), build_jt(BayesNet, CPTs, JTree),

View File

@ -44,7 +44,7 @@
:- use_module(library('clpbn/connected'), :- use_module(library('clpbn/connected'),
[ [
init_influences/3, init_influences/3,
influences/5 influences/4
]). ]).
:- use_module(library('clpbn/matrix_cpt_utils'), :- use_module(library('clpbn/matrix_cpt_utils'),
@ -87,7 +87,7 @@ init_ve_solver(Qs, Vs0, _, LVis) :-
init_ve_solver_for_questions([], _, _, [], []). init_ve_solver_for_questions([], _, _, [], []).
init_ve_solver_for_questions([Vs|MVs], G, RG, [NVs|MNVs0], [NVs|LVis]) :- init_ve_solver_for_questions([Vs|MVs], G, RG, [NVs|MNVs0], [NVs|LVis]) :-
influences(Vs, _, NVs0, G, RG), influences(Vs, G, RG, NVs0),
sort(NVs0, NVs), sort(NVs0, NVs),
%clpbn_gviz:clpbn2gviz(user_error, test, NVs, Vs), %clpbn_gviz:clpbn2gviz(user_error, test, NVs, Vs),
init_ve_solver_for_questions(MVs, G, RG, MNVs0, LVis). init_ve_solver_for_questions(MVs, G, RG, MNVs0, LVis).

View File

@ -1,4 +1,6 @@
:- style_check(all).
:- ensure_loaded(library(clpbn)). :- ensure_loaded(library(clpbn)).
wet_grass(W) :- wet_grass(W) :-

View File

@ -22,9 +22,6 @@
randomise_all_dists/0, randomise_all_dists/0,
uniformise_all_dists/0]). uniformise_all_dists/0]).
:- use_module(library('clpbn/connected'),
[clpbn_subgraphs/2]).
:- use_module(library('clpbn/learning/learn_utils'), :- use_module(library('clpbn/learning/learn_utils'),
[run_all/1, [run_all/1,
clpbn_vars/2, clpbn_vars/2,