diff --git a/CLPBN/learning/em.yap b/CLPBN/learning/em.yap index b723cf182..9d1d4fa07 100644 --- a/CLPBN/learning/em.yap +++ b/CLPBN/learning/em.yap @@ -8,13 +8,15 @@ [append/3]). :- use_module(library(clpbn), - [clpbn_init_solver/3, - clpbn_run_solver/3]). + [clpbn_init_solver/5, + clpbn_run_solver/4, + clpbn_flag/2]). :- use_module(library('clpbn/dists'), [get_dist_domain_size/2, empty_dist/2, - dist_new_table/2]). + dist_new_table/2, + get_dist_key/2]). :- use_module(library('clpbn/connected'), [clpbn_subgraphs/2]). @@ -33,6 +35,11 @@ [matrix_add/3, matrix_to_list/2]). +:- use_module(library(rbtrees), + [rb_new/1, + rb_insert/4, + rb_lookup/3]). + :- use_module(library('clpbn/utils'), [ check_for_hidden_vars/3, @@ -53,14 +60,15 @@ em(Items, MaxError, MaxIts, Tables, Likelihood) :- % it includes the list of variables without evidence, % the list of distributions for which we want to compute parameters, % and more detailed info on distributions, namely with a list of all instances for the distribution. -init_em(Items, state(AllVars, AllDists, AllDistInstances, MargVars)) :- +init_em(Items, state( AllDists, AllDistInstances, MargVars, SolverVars)) :- run_all(Items), attributes:all_attvars(AllVars0), sort_vars_by_key(AllVars0,AllVars1,[]), % remove variables that do not have to do with this query. check_for_hidden_vars(AllVars1, AllVars1, AllVars), different_dists(AllVars, AllDists, AllDistInstances, MargVars), - clpbn_init_solver(MargVars, AllVars, _). + clpbn_flag(em_solver, Solver), + clpbn_init_solver(Solver, MargVars, AllVars, _, SolverVars). % loop for as long as you want. em_loop(Its, Likelihood0, State, MaxError, MaxIts, LikelihoodF, FTables) :- @@ -82,8 +90,9 @@ em_loop(Its, Likelihood0, State, MaxError, MaxIts, LikelihoodF, FTables) :- ). ltables([], []). -ltables([Id-T|Tables], [Id-LTable|FTables]) :- +ltables([Id-T|Tables], [Key-LTable|FTables]) :- matrix_to_list(T,LTable), + get_dist_key(Id, Key), ltables(Tables, FTables). @@ -92,7 +101,8 @@ ltables([Id-T|Tables], [Id-LTable|FTables]) :- different_dists(AllVars, AllDists, AllInfo, MargVars) :- all_dists(AllVars, Dists0), sort(Dists0, Dists1), - group(Dists1, AllDists, AllInfo, MargVars, []). + group(Dists1, AllDists, AllInfo, MargVars0, []), + sort(MargVars0, MargVars). all_dists([], []). all_dists([V|AllVars], [i(Id, [V|Parents], Cases, Hiddens)|Dists]) :- @@ -143,30 +153,46 @@ same_id([i(Id,Vs,Cases,Hs)|Dists1], Id, [i(Id, Vs, Cases, Hs)|Extra], Rest) --> same_id(Dists1, Id, Extra, Rest). same_id(Dists, _, [], Dists) --> []. -estimate(state(Vars, _, _, Margs), LPs) :- - clpbn_run_solver(Margs, Vars, LPs). -maximise(state(_,_,DistInstances,_), Tables, LPs, Likelihood) :- - compute_parameters(DistInstances, Tables, LPs, 0.0, Likelihood). +compact_mvars([], []). +compact_mvars([X1,X2|MargVars], CMVars) :- X1 == X2, !, + compact_mvars([X2|MargVars], CMVars). +compact_mvars([X|MargVars], [X|CMVars]) :- !, + compact_mvars(MargVars, CMVars). -compute_parameters([], [], [], Lik, Lik). -compute_parameters([Id-Samples|Dists], [Id-NewTable|Tables], Ps, Lik0, Lik) :- +estimate(state(_, _, Margs, SolverState), LPs) :- + clpbn_flag(em_solver, Solver), + clpbn_run_solver(Solver, Margs, LPs, SolverState). + +maximise(state(_,DistInstances,MargVars,_), Tables, LPs, Likelihood) :- + rb_new(MDistTable0), + create_mdist_table(MargVars,LPs,MDistTable0,MDistTable), + compute_parameters(DistInstances, Tables, MDistTable, 0.0, Likelihood). + +create_mdist_table([],[],MDistTable,MDistTable). +create_mdist_table([Vs|MargVars],[Ps|LPs],MDistTable0,MDistTable) :- + rb_insert(MDistTable0, Vs, Ps, MDistTableI), + create_mdist_table(MargVars, LPs, MDistTableI ,MDistTable). + +compute_parameters([], [], _, Lik, Lik). +compute_parameters([Id-Samples|Dists], [Id-NewTable|Tables], MDistTable, Lik0, Lik) :- empty_dist(Id, Table0), - add_samples(Samples, Table0, Ps, MorePs), + add_samples(Samples, Table0, MDistTable), soften_sample(Table0, SoftenedTable), normalise_counts(SoftenedTable, NewTable), compute_likelihood(Table0, NewTable, DeltaLik), dist_new_table(Id, NewTable), NewLik is Lik0+DeltaLik, - compute_parameters(Dists, Tables, MorePs, NewLik, Lik). + compute_parameters(Dists, Tables, MDistTable, NewLik, Lik). -add_samples([], _, Ps, Ps). -add_samples([i(_,_,[Case],[])|Samples], Table, AllPs, RPs) :- !, +add_samples([], _, _). +add_samples([i(_,_,[Case],[])|Samples], Table, MDistTable) :- !, matrix_add(Table,Case,1.0), - add_samples(Samples, Table, AllPs, RPs). -add_samples([i(_,_,Cases,_)|Samples], Table, [Ps|AllPs], RPs) :- + add_samples(Samples, Table, MDistTable). +add_samples([i(_,_,Cases,Hiddens)|Samples], Table, MDistTable) :- + rb_lookup(Hiddens, Ps, MDistTable), run_sample(Cases, Ps, Table), - add_samples(Samples, Table, AllPs, RPs). + add_samples(Samples, Table, MDistTable). run_sample([], [], _). run_sample([C|Cases], [P|Ps], Table) :-