From 3b811d0d703724820ed04b3229010928b4616df6 Mon Sep 17 00:00:00 2001 From: Vitor Santos Costa Date: Tue, 30 Sep 2008 00:02:31 +0100 Subject: [PATCH] changes to support em: step 1 --- CLPBN/clpbn/dists.yap | 4 +- CLPBN/clpbn/gibbs.yap | 12 +++-- CLPBN/learning/em.yap | 114 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 124 insertions(+), 6 deletions(-) create mode 100644 CLPBN/learning/em.yap diff --git a/CLPBN/clpbn/dists.yap b/CLPBN/clpbn/dists.yap index f441b96a3..a303f8e4f 100644 --- a/CLPBN/clpbn/dists.yap +++ b/CLPBN/clpbn/dists.yap @@ -18,7 +18,8 @@ get_evidence_from_position/3, dist_to_term/2, empty_dist/2, - dist_new_table/2 + dist_new_table/2, + all_dist_ids/1 ]). :- use_module(library(lists),[is_list/1,nth0/3]). @@ -229,4 +230,3 @@ dist_new_table(Id, NewMat) :- fail. dist_new_table(_, _). - diff --git a/CLPBN/clpbn/gibbs.yap b/CLPBN/clpbn/gibbs.yap index 22018a280..a08425ade 100644 --- a/CLPBN/clpbn/gibbs.yap +++ b/CLPBN/clpbn/gibbs.yap @@ -8,7 +8,8 @@ % :- module(gibbs, [gibbs/3, - check_if_gibbs_done/1]). + check_if_gibbs_done/1, + init_gibbs_solver/3]). :- use_module(library(rbtrees), [rb_new/1, @@ -49,9 +50,7 @@ gibbs([],_,_) :- !. gibbs(LVs,Vs0,AllDiffs) :- LVs = [_], !, - clean_up, - check_for_hidden_vars(Vs0, Vs0, Vs1), - sort(Vs1,Vs), + init_gibbs_solver(LVs, Vs0, Vs), (clpbn:output(xbif(XBifStream)) -> clpbn2xbif(XBifStream,vel,Vs) ; true), (clpbn:output(gviz(XBifStream)) -> clpbn2gviz(XBifStream,vel,Vs,LVs) ; true), initialise(Vs, Graph, LVs, OutputVars, VarOrder), @@ -64,6 +63,11 @@ gibbs(LVs,Vs0,AllDiffs) :- gibbs(LVs,_,_) :- throw(error(domain_error(solver,LVs),solver(gibbs))). +init_gibbs_solver(_LVs, Vs0, Vs) :- + clean_up, + check_for_hidden_vars(Vs0, Vs0, Vs1), + sort(Vs1,Vs). 
+ initialise(LVs, Graph, GVs, OutputVars, VarOrder) :- init_keys(Keys0), gen_keys(LVs, 0, VLen, Keys0, Keys), diff --git a/CLPBN/learning/em.yap b/CLPBN/learning/em.yap new file mode 100644 index 000000000..4539b4e82 --- /dev/null +++ b/CLPBN/learning/em.yap @@ -0,0 +1,114 @@ +% +% The world famous EM algorithm, in a nutshell +% + +:- module(clpbn_em, [em/5]). + +:- use_module(library(lists), + [append/3]). + +:- use_module(library('clpbn/learning/learn_utils'), + [run_all/1, + clpbn_vars/2, + normalise_counts/2]). + +em(Items, MaxError, MaxIts, Tables, Likelihood) :- + init_em(Items, State), + em_loop(0, 0.0, State, MaxError, MaxIts, Likelihood), + get_tables(State, Tables). + +% This gets you an initial configuration. If there is a lot of evidence +% tables may be filled in close to optimal, otherwise they may be +% close to uniform. +% it also gets you a run for random variables +init_em(Items, state(AllVars, AllDists, AllDistInstances)) :- + run_all(Items), + different_dists(AllVars, AllDists, AllDistInstances). + +% loop for as long as you want. +em_loop(MaxIts, Likelihood, _State, _MaxError, MaxIts, Likelihood) :- !. +em_loop(Its, Likelihood0, State, MaxError, MaxIts, LikelihoodF) :- + estimate(State), + maximise(State, _Tables, Likelihood), + ( + ( + (Likelihood - Likelihood0)/Likelihood < MaxError + ; + Its == MaxIts + ) + -> + LikelihoodF = Likelihood + ; + Its1 is Its+1, + em_loop(Its1, Likelihood, State, MaxError, MaxIts, LikelihoodF) + ). + +% collect the different dists we are going to learn next. +different_dists(AllVars, AllDists, AllInfo) :- + all_dists(AllVars, AllDists, Dists0), + sort(Dists0, Dists1), + group(Dists1, AllInfo). + +group([], []). +group([i(Id,V,Ps)|Dists1], [Id-[[V|Ps]|Extra]|AllInfo]) :- + same_id(Dists1, Id, Extra, Rest), + group(Rest, AllInfo). + +same_id([i(Id,V,Ps)|Dists1], Id, [[V|Ps]|Extra], Rest) :- !, + same_id(Dists1, Id, Extra, Rest). +same_id(Dists, _, [], Dists). + +all_dists([], [], []). 
+all_dists([V|AllVars], Dists, [i(Id, V, Parents)|AllInfo]) :- + clpbn:get_atts(V, [dist(Id,Parents)]), + with_evidence(V, Id, Dists, Dists0), !, + all_dists(AllVars, Dists0, AllInfo). + +with_evidence(V, Id) --> + {clpbn:get_atts(V, [evidence(Pos)]) }, !, + { dist_pos2bin(Pos, Id, Bin) }. +with_evidence(V, Id) --> + [d(V,Id)]. + +estimate(state(Vars,Info,_)) :- + clpbn_solve_graph(Vars, OVars), + marg_vars(Info, Vars). + +marg_vars([], _). +marg_vars([d(V,Id)|Vars], AllVs) :- + clpbn_marginalise_in_vars(V, AllVs), + marg_vars(Vars, AllVs). + +maximise(state(_,_,DistInstances), Tables, Likelihood) :- + compute_parameters(DistInstances, Tables, 0.0, Likelihood). + +compute_parameters([], [], Lik, Lik). +compute_parameters([Id-Samples|Dists], [NewTable|Tables], Lik0, Lik) :- + empty_dist(Id, NewTable), + add_samples(Samples, NewTable), + normalise_table(Id, NewTable), + compute_parameters(Dists, Tables, Lik0, Lik). + +add_samples([], _). +add_samples([S|Samples], Table) :- + run_sample(S, 1.0, Pos, Tot), + matrix_add(Table, Pos, Tot), + fail. +add_samples([_|Samples], Table) :- + add_samples(Samples, Table). + +run_sample([], Tot, [], Tot). +run_sample([V|S], W0, [P|Pos], Tot) :- + clpbn:get_atts(V, [evidence(P)]), !, + run_sample(S, W0, Pos, Tot). +run_sample([V|S], W0, [P|Pos], Tot) :- + clpbn_display:get_atts(V, [posterior(_,_,Ps,_)]), + count_cases(Ps, 0, D0, P), + W1 is D0*W0, + run_sample(S, W1, Pos, Tot). + +count_cases([D0|Ps], I0, D0, I0). +count_cases([_|Ps], I0, P, W1) :- + I is I0+1, + count_cases(Ps, I, P, W1). 