/*********************************** Variable Elimination in Prolog How to do it Three steps: build the graph: - for all variables, find out all tables they connect to; multiply their size order by size *********************************/ :- module(clpbn_ve, [ve/3, check_if_ve_done/1, init_ve_solver/4, run_ve_solver/3]). :- attribute size/1, all_diffs/1. :- use_module(library(ordsets), [ord_union/3, ord_member/2]). :- use_module(library('clpbn/xbif'), [clpbn2xbif/3]). :- use_module(library('clpbn/graphviz'), [clpbn2gviz/4]). :- use_module(library('clpbn/dists'), [ dist/4, get_dist_domain_size/2, get_dist_params/2, get_dist_domain_size/2, get_dist_matrix/5]). :- use_module(library('clpbn/utils'), [ clpbn_not_var_member/2]). :- use_module(library('clpbn/display'), [ clpbn_bind_vals/3]). :- use_module(library('clpbn/connected'), [ init_influences/3, influences/4 ]). :- use_module(library(clpbn/matrix_cpt_utils)). :- use_module(library(lists), [ member/2, append/3, delete/3 ]). :- use_module(library(maplist)). :- use_module(library(rbtrees)). :- use_module(library(clpbn/vmap)). :- use_module(library('clpbn/aggregates'), [check_for_agg_vars/2]). % % uses a bipartite graph where bigraph(Vs, NFs, Fs) % Vs=map variables to lists of factors % NFs=number of factors % Fs=map factor id -> f(Id, Vars, Table) % check_if_ve_done(Var) :- get_atts(Var, [size(_)]), !. % % implementation of the well known variable elimination algorithm % ve([[]],_,_) :- !. ve(LLVs,Vs0,AllDiffs) :- init_ve_solver(LLVs, Vs0, AllDiffs, State), % variable elimination proper run_ve_solver(LLVs, LLPs, State), % bind Probs back to variables so that they can be output. clpbn_bind_vals(LLVs,LLPs,AllDiffs). % % Qs is a list of lists with all query vars (marginals) % IQs is the corresponding list of integers % LVis is a list of lists with all variables reachable from the query % ILVis is the corresponding list of integers % Vmap is the map V->I % init_ve_solver(Qs, Vs0, _, state(IQs, LVIs, VMap, Bigraph, Ev)) :- check_for_agg_vars(Vs0, Vs1), % LVi will have a list of CLPBN variables init_influences(Vs1, Graph, TGraph), maplist(init_ve_solver_for_question(Graph, TGraph), Qs, LVs), init_vmap(VMap0), lvars_to_numbers(LVs, LVIs, VMap0, VMap1), lvars_to_numbers(Qs, IQs, VMap1, VMap), vars_to_bigraph(VMap, Bigraph, Ev). init_ve_solver_for_question(G, RG, Vs, NVs) :- influences(Vs, G, RG, NVs0), sort(NVs0, NVs). % % construct a bipartite graph with vars and factors % the nodes of the var graph just contain pointer to the factors % the nodes of the factors contain alist of variables and a matrix % also provide a matrix with evidence % vars_to_bigraph(VMap, bigraph(VInfo, IF, Fs), Evs) :- rb_new(Fs0), vmap_to_list(VMap, VIds), foldl3(id_to_factor(VMap), VIds, 0, IF, Fs0, Fs, [], Evs), factors_to_vs(Fs, VInfo). id_to_factor(VMap, V-I, IF0, IF, Fs0, Fs, Evs0, Evs) :- % process evidence for variable clpbn:get_atts(V, [evidence(E), dist(_,Ps)]), checklist(noparent_of_interest(VMap), Ps), !, % I don't need to get a factor here Evs = [I=E|Evs0], IF = IF0, Fs = Fs0. id_to_factor(VMap, V-I, IF0, IF, Fs0, Fs, Evs0, Evs) :- % process distribution/factors ( clpbn:get_atts(V, [evidence(E)]) -> Evs = [I=E|Evs0] ; Evs = Evs0 ), clpbn:get_atts(V, [dist(D, Ps)]), get_dist_params(D, Pars0), get_dist_domain_size(D, DS), maplist(parent_to_id(VMap), Ps, Sizes, IPs), init_CPT(Pars0, [DS|Sizes], CPT0), reorder_CPT([I|IPs], CPT0, FIPs, CPT, _), rb_insert(Fs0, IF0, f(IF0, FIPs, CPT), Fs), IF is IF0+1. noparent_of_interest(VMap, P) :- \+ get_from_vmap(P, _, VMap). parent_to_id(VMap, V, DS, I) :- clpbn:get_atts(V, [dist(D, _Ps)]), get_dist_domain_size(D, DS), get_from_vmap(V, I, VMap). factors_to_vs(Fs, VInfo) :- rb_visit(Fs, L), foldl(fsvs, L, FVs, []), sort(FVs, SFVs), rb_new(VInfo0), add_vs(SFVs, Fs, VInfo0, VInfo). fsvs(F-f(_, IVs, _)) --> fvs(IVs, F). fvs([], _F) --> []. fvs([I|IVs], F) --> [I-F], fvs(IVs, F). % % construct variable nodes % add_vs([], _, VInfo, VInfo). add_vs([V-F|SFVs], Fs, VInfo0, VInfo) :- rb_lookup(F, FInfo, Fs), collect_factors(SFVs, Fs, V, Fs0, R), rb_insert(VInfo0, V, [FInfo|Fs0], VInfoI), add_vs(R, Fs, VInfoI, VInfo). collect_factors([], _Fs, _V, [], []). collect_factors([V-F|SFVs], Fs, V, [FInfo|FInfos], R):- !, rb_lookup(F, FInfo, Fs), collect_factors(SFVs, Fs, V, FInfos, R). collect_factors(SFVs, _Fs, _V, [], SFVs). % solve each query independently % use a findall to recover space without needing for GC run_ve_solver(_, LLPs, state(LQVs, LVs, _VMap, Bigraph, Ev)) :- findall(LPs, solve_ve(LQVs, LVs, Bigraph, Ev, LPs), LLPs). % % IQVs are the current marginal, % IVs are all variables related to that % IFVs are the factors % SVs are the variables % solve_ve([IQVs|_], [IVs|_], bigraph(OldVs, IF, _Fs), Ev, Ps) :- % get only what is relevant to query, project_to_query_related(IVs, OldVs, SVs, Fs1), % and also prune using evidence foldl2(clean_v_ev, Ev, Fs1, Fs2, SVs, EVs), % eliminate eliminate(IQVs, digraph(EVs, IF, Fs2), Dist), % writeln(m:Dist),matrix:matrix_to_list(Dist,LD),writeln(LD), %exps(LD,LDE),writeln(LDE), % move from potentials back to probabilities normalise_CPT(Dist,MPs), list_from_CPT(MPs, Ps). solve_ve([_|MoreLVs], [_|MoreLVis], Digraph, Ev, Ps) :- solve_ve(MoreLVs, MoreLVis, Digraph, Ev, Ps). % % given our input queries, sort them and obtain the subgraphs of vars and facs. % project_to_query_related(IVs0, OldVs, NVs, NFs) :- sort(IVs0, IVs), rb_new(Vs0), foldl(cp_to_vs, IVs, Vs0, AuxVs), rb_new(NFs0), foldl(simplify_graph_node(OldVs, AuxVs), IVs, VFs, NFs0, NFs), list_to_rbtree(VFs, NVs). % % auxiliary tree for fast access to vars. % cp_to_vs(V, Vs0, Vs) :- rb_insert(Vs0, V, _, Vs). % % construct a new, hopefully much smaller, graph % simplify_graph_node(OldVs, NVs, V, V-RemFs, NFs0, NFs) :- rb_lookup(V, Fs, OldVs), foldl2(check_factor(V, NVs), Fs, NFs0, NFs, [], RemFs). % % check if a factor belongs to the subgraph. % % % Two cases: first time factor comes up: all its vars must be in subgraph % second case: second time it comes up, it must be already in graph % % args: +Factor F, +current V (int), +rbtree with all Vs, % -Factors in new Graph, +factors in current graph, -rbtree of factors % % check_factor(V, NVs, F, NFs0, NFs, RemFs, NewRemFs) :- F = f(IF, [V|More], _), !, ( checklist(check_v(NVs), More) -> rb_insert(NFs0, IF, F, NFs), NewRemFs = [F|RemFs] ; NFs0 = NFs, NewRemFs = RemFs ). check_factor(_V, _NVs, F, NFs, NFs, RemFs, NewRemFs) :- F = f(Id, _, _), ( rb_lookup(Id, F, NFs) -> NewRemFs = [F|RemFs] ; NewRemFs = RemFs ). check_v(NVs, V) :- rb_lookup(V, _, NVs). % % simplify a variable with evidence % clean_v_ev(V=E, FVs0, FVs, Vs0, Vs) :- rb_delete(Vs0, V, Fs, Vs1), foldl2(simplify_f_ev(V, E), Fs, FVs0, FVs, Vs1, Vs). % % % tricky: clean a factor means also cleaning all back references. % simplify_f_ev(V, E, F, Fs0, Fs, Vs0, Vs) :- F = f(Id, FVs, CPT), NF = f(Id, NFVs, NCPT), project_from_CPT(V, E, CPT, FVs, NCPT, NFVs), % update factor rb_update(Fs0, Id, NF, Fs), foldl(update_factors(F,NF), NFVs, Vs0, Vs). % update all instances of F in var graph update_factors(F, NF, V, Vs0, Vs) :- rb_update(Vs0, V, Fs, NFs, Vs), maplist(replace_factor(F,NF), Fs, NFs). replace_factor(F, NF, F, NF) :- !. replace_factor(_F,_NF,OF, OF). eliminate(QVs, digraph(Vs0, I, Fs0), Dist) :- find_best(Vs0, QVs, BestV, VFs), !, %writeln(best:BestV:QVs), % delete all factors that touched the variable foldl2(del_fac, VFs, Fs0, Fs1, Vs0, Vs1), % delete current variable rb_delete(Vs1, BestV, Vs2), I1 is I+1, % construct new table multiply_and_delete(VFs, BestV, NewFVs, NewCPT), % insert new factor in graph insert_fac(I, NewFVs, NewCPT, Fs1, Fs, Vs2, Vs), eliminate(QVs, digraph(Vs, I1, Fs), Dist). eliminate(_QVs, digraph(_, _, Fs), Dist) :- combine_factors(Fs, Dist). find_best(Vs, QVs, BestV, VFs) :- rb_key_fold(best_var(QVs), Vs, i(+inf,-1,[]), i(_Cost,BestV,VFs)), BestV \= -1, !. % do not eliminate marginalised variables best_var(QVs, I, _Node, Info, Info) :- member(I, QVs), !. % pick the variable with less factors best_var(_Qs, I, Node, i(ValSoFar,_,_), i(NewVal,I,Node)) :- length(Node, NewVal), NewVal < ValSoFar, !. best_var(_, _I, _Node, Info, Info). % delete one factor, need to also touch all variables del_fac(f(I,FVs,_), Fs0, Fs, Vs0, Vs) :- rb_delete(Fs0, I, Fs), foldl(delete_fac_from_v(I), FVs, Vs0, Vs). delete_fac_from_v(I, FV, Vs0, Vs) :- rb_update(Vs0, FV, Fs, NFs, Vs), exclude(factor_name(I), Fs, NFs). factor_name(I, f(I,_,_)). % insert one factor, need to touch all corresponding variables insert_fac(I, FVs, CPT, Fs0, Fs, Vs0, Vs) :- F = f(I, FVs, CPT), rb_insert(Fs0, I, F, Fs), foldl(insert_fac_in_v(F), FVs, Vs0, Vs). insert_fac_in_v(F, FV, Vs0, Vs) :- rb_update(Vs0, FV, Fs, [F|Fs], Vs). combine_factors(Fs, Dist) :- rb_visit(Fs,Els), maplist(extract_factor,Els,Factors), multiply(Factors, _, Dist). extract_factor(_-Factor, Factor). multiply_and_delete([f(I,Vs0,T0)|Fs], V, Vs, T) :- foldl(multiply_factor, Fs, f(I,Vs0,T0), f(_,Vs1,T1)), sum_out_from_CPT(V, T1, Vs1, T, Vs). multiply([F0|Fs], Vs, T) :- foldl(multiply_factor, Fs, F0, f(_,Vs,T)). multiply_factor(f(_,Vs1,T1), f(_,Vs0,T0), f(_,Vs,T)) :- multiply_CPTs(T1, Vs1, T0, Vs0, T, Vs).