% % each variable is represented by a node in a binary tree. % each node contains: % key, % current_value % Markov Blanket % :- module(gibbs, [gibbs/3, check_if_gibbs_done/1]). :- use_module(library(rbtrees), [rb_new/1, rb_insert/4, rb_lookup/3]). :- use_module(library(lists), [member/2, append/3, delete/3, max_list/2]). :- use_module(library(ordsets), [ord_subtract/3]). :- use_module(library('clpbn/discrete_utils'), [ project_from_CPT/3, reorder_CPT/5]). :- use_module(library('clpbn/utils'), [ check_for_hidden_vars/3]). :- use_module(library('clpbn/topsort'), [ topsort/2]). :- dynamic gibbs_params/3. :- dynamic implicit/1. gibbs([],_,_) :- !. gibbs(LVs,Vs0,_) :- clean_up, check_for_hidden_vars(Vs0, Vs0, Vs1), sort(Vs1,Vs), (clpbn:output(xbif(XBifStream)) -> clpbn2xbif(XBifStream,vel,Vs) ; true), (clpbn:output(gviz(XBifStream)) -> clpbn2gviz(XBifStream,vel,Vs,LVs) ; true), initialise(Vs, Graph, LVs, OutputVars, VarOrder), % write(Graph),nl, process(VarOrder, Graph, OutputVars, Estimates), write(Estimates),nl, clean_up. initialise(LVs, Graph, GVs, OutputVars, VarOrder) :- init_keys(Keys0), gen_keys(LVs, 0, VLen, Keys0, Keys), functor(Graph,graph,VLen), graph_representation(LVs, Graph, 0, Keys, TGraph), compile_graph(Graph), topsort(TGraph, VarOrder), % show_sorted(VarOrder, Graph), add_output_vars(GVs, Keys, OutputVars). init_keys(Keys0) :- rb_new(Keys0). gen_keys([], I, I, Keys, Keys). gen_keys([V|Vs], I0, If, Keys0, Keys) :- clpbn:get_atts(V,[evidence(_)]), !, gen_keys(Vs, I0, If, Keys0, Keys). gen_keys([V|Vs], I0, If, Keys0, Keys) :- I is I0+1, rb_insert(Keys0,V,I,KeysI), gen_keys(Vs, I, If, KeysI, Keys). graph_representation([],_,_,_,[]). graph_representation([V|Vs], Graph, I0, Keys, TGraph) :- clpbn:get_atts(V,[evidence(_)]), !, clpbn:get_atts(V, [dist(Vals,Table,Parents)]), get_sizes(Parents, Szs), length(Vals,Sz), project_evidence_out([V|Parents],[V|Parents],Table,[Sz|Szs],Variables,NewTable), % all variables are parents propagate2parents(Variables, NewTable, Variables, Graph, Keys), graph_representation(Vs, Graph, I0, Keys, TGraph). graph_representation([V|Vs], Graph, I0, Keys, [I-IParents|TGraph]) :- I is I0+1, clpbn:get_atts(V, [dist(Vals,Table,Parents)]), get_sizes(Parents, Szs), length(Vals,Sz), project_evidence_out([V|Parents],[V|Parents],Table,[Sz|Szs],Variables,NewTable), Variables = [V|NewParents], sort_according_to_indices(NewParents,Keys,SortedNVs,SortedIndices), reorder_CPT(Variables,NewTable,[V|SortedNVs],NewTable2,_), add2graph(V, Vals, NewTable2, SortedIndices, Graph, Keys), propagate2parents(NewParents, NewTable, Variables, Graph,Keys), parent_indices(NewParents, Keys, IVariables0), sort(IVariables0, IParents), arg(I, Graph, var(_,_,_,_,_,_,_,NewTable2,SortedIndices)), graph_representation(Vs, Graph, I, Keys, TGraph). write_pars([]). write_pars([V|Parents]) :- clpbn:get_atts(V, [key(K)]),write(K),nl, write_pars(Parents). get_sizes([], []). get_sizes([V|Parents], [Sz|Szs]) :- clpbn:get_atts(V, [dist(Vals,_,_)]), length(Vals,Sz), get_sizes(Parents, Szs). parent_indices([], _, []). parent_indices([V|Parents], Keys, [I|IParents]) :- rb_lookup(V, I, Keys), parent_indices(Parents, Keys, IParents). % % first, remove nodes that have evidence from tables. % project_evidence_out([],Deps,Table,_,Deps,Table). project_evidence_out([V|Parents],Deps,Table,Szs,NewDeps,NewTable) :- clpbn:get_atts(V,[evidence(_)]), !, NTab =.. [t|Table], project_from_CPT(V,tab(NTab,Deps,Szs),tab(ITable,IDeps,ISzs)), ITable =.. [_|LITable], project_evidence_out(Parents,IDeps,LITable,ISzs,NewDeps,NewTable). project_evidence_out([_Par|Parents],Deps,Table,Szs,NewDeps,NewTable) :- project_evidence_out(Parents,Deps,Table,Szs,NewDeps,NewTable). propagate2parents([], _, _, _, _). propagate2parents([V|NewParents], Table, Variables, Graph, Keys) :- delete(Variables,V,NVs), sort_according_to_indices(NVs,Keys,SortedNVs,SortedIndices), reorder_CPT(Variables,Table,[V|SortedNVs],NewTable,_), add2graph(V, _, NewTable, SortedIndices, Graph, Keys), propagate2parents(NewParents,Table, Variables, Graph, Keys). add2graph(V, Vals, Table, IParents, Graph, Keys) :- rb_lookup(V, Index, Keys), (var(Vals) -> true ; length(Vals,Sz)), arg(Index, Graph, var(V,Index,_,Vals,Sz,VarSlot,_,_,_)), member(tabular(Table,Index,IParents), VarSlot), !. sort_according_to_indices(NVs,Keys,SortedNVs,SortedIndices) :- vars2indices(NVs,Keys,ToSort), keysort(ToSort, Sorted), split_parents(Sorted, SortedNVs,SortedIndices). split_parents([], [], []). split_parents([I-V|Sorted], [V|SortedNVs],[I|SortedIndices]) :- split_parents(Sorted, SortedNVs, SortedIndices). vars2indices([],_,[]). vars2indices([V|Parents],Keys,[I-V|IParents]) :- rb_lookup(V, I, Keys), vars2indices(Parents,Keys,IParents). compact_table(NewTable, RepTable) :- NewTable = [_|_], !, RepTable =.. [t|NewTable]. % % This is the really cool bit. % compile_graph(Graph) :- Graph =.. [_|VarsInfo], compile_vars(VarsInfo,Graph). compile_vars([],_). compile_vars([var(_,I,_,Vals,Sz,VarSlot,Parents,_,_)|VarsInfo],Graph) :- compile_var(I,Vals,Sz,VarSlot,Parents,Graph), compile_vars(VarsInfo,Graph). compile_var(I,Vals,Sz,VarSlot,Parents,Graph) :- fetch_all_parents(VarSlot,Graph,[],Parents,[],Sizes), mult_list(Sizes,1,TotSize), compile_var(TotSize,I,Vals,Sz,VarSlot,Parents,Sizes,Graph). fetch_all_parents([],_,Parents,Parents,Sizes,Sizes). fetch_all_parents([tabular(_,_,Ps)|CPTs],Graph,Parents0,ParentsF,Sizes0,SizesF) :- merge_these_parents(Ps,Graph,Parents0,ParentsI,Sizes0,SizesI), fetch_all_parents(CPTs,Graph,ParentsI,ParentsF,SizesI,SizesF). merge_these_parents([],_,Parents,Parents,Sizes,Sizes). merge_these_parents([I|Ps],Graph,Parents0,ParentsF,Sizes0,SizesF) :- member(I,Parents0), !, merge_these_parents(Ps,Graph,Parents0,ParentsF,Sizes0,SizesF). merge_these_parents([I|Ps],Graph,Parents0,ParentsF,Sizes0,SizesF) :- arg(I,Graph,var(_,I,_,Vals,_,_,_,_,_)), length(Vals, Sz), add_parent(Parents0,I,ParentsI,Sizes0,Sz,SizesI), merge_these_parents(Ps,Graph,ParentsI,ParentsF,SizesI,SizesF). add_parent([],I,[I],[],Sz,[Sz]). add_parent([P|Parents0],I,[I,P|Parents0],Sizes0,Sz,[Sz|Sizes0]) :- P > I, !. add_parent([P|Parents0],I,[P|ParentsI],[S|Sizes0],Sz,[S|SizesI]) :- add_parent(Parents0,I,ParentsI,Sizes0,Sz,SizesI). mult_list([],Mult,Mult). mult_list([Sz|Sizes],Mult0,Mult) :- MultI is Sz*Mult0, mult_list(Sizes,MultI,Mult). % compile node as set of facts, faster execution compile_var(TotSize,I,_Vals,Sz,CPTs,Parents,_Sizes,Graph) :- TotSize < 1024*64, TotSize > 0, !, multiply_all(I,Parents,CPTs,Sz,Graph). compile_var(_,I,_,_,_,_,_,_) :- assert(implicit(I)). multiply_all(I,Parents,CPTs,Sz,Graph) :- markov_blanket_instance(Parents,Graph,Values), multiply_all(CPTs,Sz,Graph,Probs), store_mblanket(I,Values,Probs), fail. multiply_all(_,_,_,_,_). % note: what matters is how this predicate instantiates the temp % slot in the graph! markov_blanket_instance([],_,[]). markov_blanket_instance([I|Parents],Graph,[Pos|Values]) :- arg(I,Graph,var(_,I,Pos,Vals,_,_,_,_,_)), fetch_val(Vals,0,Pos), markov_blanket_instance(Parents,Graph,Values). % backtrack through every value in domain % fetch_val([_|_],Pos,Pos). fetch_val([_|Vals],I0,Pos) :- I is I0+1, fetch_val(Vals,I,Pos). :- dynamic a/0. multiply_all(CPTs,Size,Graph,Probs) :- init_factors(Size,Factors0), mult_factors(CPTs,Size,Graph,Factors0,Factors), normalise_factors(Factors,Probs). init_factors(0,[]) :- !. init_factors(I0,[0.0|Factors]) :- I is I0-1, init_factors(I,Factors). mult_factors([],_,_,Factors,Factors). mult_factors([tabular(Table,_,Parents)|CPTs],Size,Graph,Factors0,Factors) :- functor(Table,_,CPTSize), Off is CPTSize//Size, factor(Parents,Table,Graph,0,Off,Indx0), Indx is Indx0+1, mult_with_probs(Factors0,Indx,Off,Table,FactorsI), mult_factors(CPTs,Size,Graph,FactorsI,Factors). factor([],_,_,Arg,_,Arg). factor([I|Parents],Table,Graph,Pos0,Weight0,Pos) :- arg(I,Graph,var(_,I,CurPos,_,Sz,_,_,_,_)), NWeight is Weight0 // Sz, PosI is Pos0+(NWeight*CurPos), factor(Parents,Table,Graph,PosI,NWeight,Pos). mult_with_probs([],_,_,_,[]). mult_with_probs([F0|Factors0],Indx,Off,Table,[F|Factors]) :- arg(Indx,Table,P1), F is F0+log(P1), Indx1 is Indx+Off, mult_with_probs(Factors0,Indx1,Off,Table,Factors). normalise_factors(Factors,Probs) :- max_list(Factors,Max), logs2list(Factors,Max,NFactors), normalise_factors(NFactors,0,_,Probs,_). logs2list([],_,[]). logs2list([Log|Factors],Max,[P|NFactors]) :- P is exp(Log+Max), logs2list(Factors,Max,NFactors). normalise_factors([],Sum,Sum,[],1.0) :- Sum > 0.0. normalise_factors([F|Factors],S0,S,[P0|Probs],PF) :- Si is S0+F, normalise_factors(Factors,Si,S,Probs,P0), PF is P0-F/S. store_mblanket(I,Values,Probs) :- append(Values,Probs,Args), Rule =.. [mblanket,I|Args], assert(Rule). add_output_vars([], _, []). add_output_vars([V|LVs], Keys, [I|OutputVars]) :- rb_lookup(V, I, Keys), add_output_vars(LVs, Keys, OutputVars). process(VarOrder, Graph, OutputVars, Estimates) :- gibbs_params(NChains,BurnIn,NSamples), functor(Graph,_,Len), init_chains(NChains,VarOrder,Len,Graph,Chains0), init_estimates(NChains,OutputVars,Graph,Est0), process_chains(BurnIn,VarOrder,BurnedIn,Chains0,Graph,Len,Est0,_), process_chains(NSamples,VarOrder,_,BurnedIn,Graph,Len,Est0,Estimates). % % I use an uniform distribution to generate the initial sample. % init_chains(0,_,_,_,[]) :- !. init_chains(I,VarOrder,Len,Graph,[Chain|Chains]) :- init_chain(VarOrder,Len,Graph,Chain), I1 is I-1, init_chains(I1,VarOrder,Len,Graph,Chains). init_chain(VarOrder,Len,Graph,Chain) :- functor(Chain,sample,Len), gen_sample(VarOrder,Graph,Chain). gen_sample([],_,_) :- !. gen_sample([I|Vs],Graph,Chain) :- arg(I,Graph,var(_,I,_,_,Sz,_,_,Table,IPars)), functor(Table,_,CPTSize), Off is CPTSize//Sz, iparents_pos_sz(IPars, Chain, IPos, Graph, ISz), R is random, project(IPos, ISz, Table,0,Off,Indx0), Indx is Indx0+1, fetch_from_dist(Table,R,Indx,Off,0,Pos), arg(I,Chain,Pos), gen_sample(Vs,Graph,Chain). project([],[],_,Arg,_,Arg). project([CurPos|Parents],[Sz|Sizes],Table,Pos0,Weight0,Pos) :- NWeight is Weight0 // Sz, PosI is Pos0+(NWeight*CurPos), project(Parents,Sizes,Table,PosI,NWeight,Pos). fetch_from_dist(Table,R,Indx,Off,IPos,Pos) :- arg(Indx,Table,P), ( P >= R -> Pos = IPos ; NR is R-P, NIndx is Indx+Off, NPos is IPos+1, fetch_from_dist(Table,NR,NIndx,Off,NPos,Pos) ). iparents_pos_sz([], _, [], _, []). iparents_pos_sz([I|IPars], Chain, [P|IPos], Graph, [Sz|Sizes]) :- arg(I,Chain,P), arg(I,Graph, var(_,I,_,_,Sz,_,_,_,_)), iparents_pos_sz(IPars, Chain, IPos, Graph, Sizes). init_estimates(0,_,_,[]) :- !. init_estimates(NChains,OutputVars,Graph,[Est|Est0]) :- NChainsI is NChains-1, init_estimate(OutputVars,Graph,Est), init_estimates(NChainsI,OutputVars,Graph,Est0). init_estimate([],_,[]). init_estimate([V|OutputVars],Graph,[[I|E0L]|Est]) :- arg(V,Graph,var(_,I,_,_,Sz,_,_,_,_)), gen_e0(Sz,E0L), init_estimate(OutputVars,Graph,Est). gen_e0(0,[]) :- !. gen_e0(Sz,[0|E0L]) :- Sz1 is Sz-1, gen_e0(Sz1,E0L). process_chains(0,_,F,F,_,_,Est,Est) :- !. process_chains(ToDo,VarOrder,End,Start,Graph,Len,Est0,Estf) :- process_chains(Start,VarOrder,Int,Graph,Len,Est0,Esti), (ToDo mod 100 =:= 0 -> statistics,cvt2problist(Esti, Probs), Int =[S|_], format('did ~d: ~w~n ~w~n',[ToDo,Probs,S]) ; true), ToDo1 is ToDo-1, process_chains(ToDo1,VarOrder,End,Int,Graph,Len,Esti,Estf). process_chains([], _, [], _, _,[],[]). process_chains([Sample0|Samples0], VarOrder, [Sample|Samples], Graph, SampLen,[E0|E0s],[Ef|Efs]) :- functor(Sample,sample,SampLen), do_sample(VarOrder,Sample,Sample0,Graph), %format('Sample = ~w~n',[Sample]), update_estimate(E0,Sample,Ef), process_chains(Samples0, VarOrder, Samples, Graph, SampLen,E0s,Efs). do_sample([],_,_,_). do_sample([I|VarOrder],Sample,Sample0,Graph) :- do_var(I,Sample,Sample0,Graph), do_sample(VarOrder,Sample,Sample0,Graph). do_var(I,Sample,Sample0,Graph) :- arg(I,Graph,var(_,I,_,_,Sz,CPTs,Parents,_,_)), ( implicit(I) -> fetch_parents(Parents,I,Sample,Sample0,Bindings,[]), multiply_all_in_context(Parents,Bindings,CPTs,Sz,Graph,Vals) ; length(Vals,Sz), fetch_parents(Parents,I,Sample,Sample0,Args,Vals), Goal =.. [mblanket,I|Args], call(Goal) ), X is random, pick_new_value(Vals,X,0,Val), arg(I,Sample,Val). multiply_all_in_context(Parents,Args,CPTs,Sz,Graph,Vals) :- set_pos(Parents,Args,Graph), multiply_all(CPTs,Sz,Graph,Vals), assert(mall(Vals)), fail. multiply_all_in_context(_,_,_,_,_,Vals) :- retract(mall(Vals)). set_pos([],[],_). set_pos([I|Is],[Pos|Args],Graph) :- arg(I,Graph,var(_,I,Pos,_,_,_,_,_,_)), set_pos(Is,Args,Graph). fetch_parents([],_,_,_,Args,Args). fetch_parents([P|Parents],I,Sample,Sample0,[VP|Args],Vals) :- arg(P,Sample,VP), nonvar(VP), !, fetch_parents(Parents,I,Sample,Sample0,Args,Vals). fetch_parents([P|Parents],I,Sample,Sample0,[VP|Args],Vals) :- arg(P,Sample0,VP), fetch_parents(Parents,I,Sample,Sample0,Args,Vals). pick_new_value([V|_],X,Val,Val) :- X < V, !. pick_new_value([_|Vals],X,I0,Val) :- I is I0+1, pick_new_value(Vals,X,I,Val). update_estimate([],_,[]). update_estimate([[I|E]|E0],Sample,[[I|NE]|Ef]) :- arg(I,Sample,V), update_estimate_for_var(V,E,NE), update_estimate(E0,Sample,Ef). update_estimate_for_var(0,[X|T],[X1|T]) :- !, X1 is X+1. update_estimate_for_var(V,[E|Es],[E|NEs]) :- V1 is V-1, update_estimate_for_var(V1,Es,NEs). check_if_gibbs_done(Var) :- get_atts(Var, [dist(_)]), !. clean_up :- current_predicate(mblanket,P), retractall(P), fail. clean_up :- retractall(implicit(_)), fail. clean_up. gibbs_params(5,10000,100000). cvt2problist([], []). cvt2problist([[[_|E]]|Est0], [Ps|Probs]) :- sum_all(E,0,Sum), do_probs(E,Sum,Ps), cvt2problist(Est0, Probs) . sum_all([],Sum,Sum). sum_all([E|Es],S0,Sum) :- SI is S0+E, sum_all(Es,SI,Sum). do_probs([],_,[]). do_probs([E|Es],Sum,[P|Ps]) :- P is E/Sum, do_probs(Es,Sum,Ps). show_sorted([], _) :- nl. show_sorted([I|VarOrder], Graph) :- arg(I,Graph,var(V,I,_,_,_,_,_,_,_)), clpbn:get_atts(V,[key(K)]), format('~w ',[K]), show_sorted(VarOrder, Graph).