fixes to work with gibbs and vel.

This commit is contained in:
Vitor Santos Costa 2008-11-01 11:52:54 +00:00
parent de79f30e45
commit 11e02be540

View File

@ -8,13 +8,15 @@
[append/3]). [append/3]).
:- use_module(library(clpbn), :- use_module(library(clpbn),
[clpbn_init_solver/3, [clpbn_init_solver/5,
clpbn_run_solver/3]). clpbn_run_solver/4,
clpbn_flag/2]).
:- use_module(library('clpbn/dists'), :- use_module(library('clpbn/dists'),
[get_dist_domain_size/2, [get_dist_domain_size/2,
empty_dist/2, empty_dist/2,
dist_new_table/2]). dist_new_table/2,
get_dist_key/2]).
:- use_module(library('clpbn/connected'), :- use_module(library('clpbn/connected'),
[clpbn_subgraphs/2]). [clpbn_subgraphs/2]).
@ -33,6 +35,11 @@
[matrix_add/3, [matrix_add/3,
matrix_to_list/2]). matrix_to_list/2]).
:- use_module(library(rbtrees),
[rb_new/1,
rb_insert/4,
rb_lookup/3]).
:- use_module(library('clpbn/utils'), :- use_module(library('clpbn/utils'),
[ [
check_for_hidden_vars/3, check_for_hidden_vars/3,
@ -53,14 +60,15 @@ em(Items, MaxError, MaxIts, Tables, Likelihood) :-
% it includes the list of variables without evidence, % it includes the list of variables without evidence,
% the list of distributions for which we want to compute parameters, % the list of distributions for which we want to compute parameters,
% and more detailed info on distributions, namely with a list of all instances for the distribution. % and more detailed info on distributions, namely with a list of all instances for the distribution.
init_em(Items, state(AllVars, AllDists, AllDistInstances, MargVars)) :- init_em(Items, state( AllDists, AllDistInstances, MargVars, SolverVars)) :-
run_all(Items), run_all(Items),
attributes:all_attvars(AllVars0), attributes:all_attvars(AllVars0),
sort_vars_by_key(AllVars0,AllVars1,[]), sort_vars_by_key(AllVars0,AllVars1,[]),
% remove variables that do not have to do with this query. % remove variables that do not have to do with this query.
check_for_hidden_vars(AllVars1, AllVars1, AllVars), check_for_hidden_vars(AllVars1, AllVars1, AllVars),
different_dists(AllVars, AllDists, AllDistInstances, MargVars), different_dists(AllVars, AllDists, AllDistInstances, MargVars),
clpbn_init_solver(MargVars, AllVars, _). clpbn_flag(em_solver, Solver),
clpbn_init_solver(Solver, MargVars, AllVars, _, SolverVars).
% loop for as long as you want. % loop for as long as you want.
em_loop(Its, Likelihood0, State, MaxError, MaxIts, LikelihoodF, FTables) :- em_loop(Its, Likelihood0, State, MaxError, MaxIts, LikelihoodF, FTables) :-
@ -82,8 +90,9 @@ em_loop(Its, Likelihood0, State, MaxError, MaxIts, LikelihoodF, FTables) :-
). ).
ltables([], []). ltables([], []).
ltables([Id-T|Tables], [Id-LTable|FTables]) :- ltables([Id-T|Tables], [Key-LTable|FTables]) :-
matrix_to_list(T,LTable), matrix_to_list(T,LTable),
get_dist_key(Id, Key),
ltables(Tables, FTables). ltables(Tables, FTables).
@ -92,7 +101,8 @@ ltables([Id-T|Tables], [Id-LTable|FTables]) :-
different_dists(AllVars, AllDists, AllInfo, MargVars) :- different_dists(AllVars, AllDists, AllInfo, MargVars) :-
all_dists(AllVars, Dists0), all_dists(AllVars, Dists0),
sort(Dists0, Dists1), sort(Dists0, Dists1),
group(Dists1, AllDists, AllInfo, MargVars, []). group(Dists1, AllDists, AllInfo, MargVars0, []),
sort(MargVars0, MargVars).
all_dists([], []). all_dists([], []).
all_dists([V|AllVars], [i(Id, [V|Parents], Cases, Hiddens)|Dists]) :- all_dists([V|AllVars], [i(Id, [V|Parents], Cases, Hiddens)|Dists]) :-
@ -143,30 +153,46 @@ same_id([i(Id,Vs,Cases,Hs)|Dists1], Id, [i(Id, Vs, Cases, Hs)|Extra], Rest) -->
same_id(Dists1, Id, Extra, Rest). same_id(Dists1, Id, Extra, Rest).
same_id(Dists, _, [], Dists) --> []. same_id(Dists, _, [], Dists) --> [].
estimate(state(Vars, _, _, Margs), LPs) :-
clpbn_run_solver(Margs, Vars, LPs).
maximise(state(_,_,DistInstances,_), Tables, LPs, Likelihood) :- compact_mvars([], []).
compute_parameters(DistInstances, Tables, LPs, 0.0, Likelihood). compact_mvars([X1,X2|MargVars], CMVars) :- X1 == X2, !,
compact_mvars([X2|MargVars], CMVars).
compact_mvars([X|MargVars], [X|CMVars]) :- !,
compact_mvars(MargVars, CMVars).
compute_parameters([], [], [], Lik, Lik). estimate(state(_, _, Margs, SolverState), LPs) :-
compute_parameters([Id-Samples|Dists], [Id-NewTable|Tables], Ps, Lik0, Lik) :- clpbn_flag(em_solver, Solver),
clpbn_run_solver(Solver, Margs, LPs, SolverState).
maximise(state(_,DistInstances,MargVars,_), Tables, LPs, Likelihood) :-
rb_new(MDistTable0),
create_mdist_table(MargVars,LPs,MDistTable0,MDistTable),
compute_parameters(DistInstances, Tables, MDistTable, 0.0, Likelihood).
create_mdist_table([],[],MDistTable,MDistTable).
create_mdist_table([Vs|MargVars],[Ps|LPs],MDistTable0,MDistTable) :-
rb_insert(MDistTable0, Vs, Ps, MDistTableI),
create_mdist_table(MargVars, LPs, MDistTableI ,MDistTable).
compute_parameters([], [], _, Lik, Lik).
compute_parameters([Id-Samples|Dists], [Id-NewTable|Tables], MDistTable, Lik0, Lik) :-
empty_dist(Id, Table0), empty_dist(Id, Table0),
add_samples(Samples, Table0, Ps, MorePs), add_samples(Samples, Table0, MDistTable),
soften_sample(Table0, SoftenedTable), soften_sample(Table0, SoftenedTable),
normalise_counts(SoftenedTable, NewTable), normalise_counts(SoftenedTable, NewTable),
compute_likelihood(Table0, NewTable, DeltaLik), compute_likelihood(Table0, NewTable, DeltaLik),
dist_new_table(Id, NewTable), dist_new_table(Id, NewTable),
NewLik is Lik0+DeltaLik, NewLik is Lik0+DeltaLik,
compute_parameters(Dists, Tables, MorePs, NewLik, Lik). compute_parameters(Dists, Tables, MDistTable, NewLik, Lik).
add_samples([], _, Ps, Ps). add_samples([], _, _).
add_samples([i(_,_,[Case],[])|Samples], Table, AllPs, RPs) :- !, add_samples([i(_,_,[Case],[])|Samples], Table, MDistTable) :- !,
matrix_add(Table,Case,1.0), matrix_add(Table,Case,1.0),
add_samples(Samples, Table, AllPs, RPs). add_samples(Samples, Table, MDistTable).
add_samples([i(_,_,Cases,_)|Samples], Table, [Ps|AllPs], RPs) :- add_samples([i(_,_,Cases,Hiddens)|Samples], Table, MDistTable) :-
rb_lookup(Hiddens, Ps, MDistTable),
run_sample(Cases, Ps, Table), run_sample(Cases, Ps, Table),
add_samples(Samples, Table, AllPs, RPs). add_samples(Samples, Table, MDistTable).
run_sample([], [], _). run_sample([], [], _).
run_sample([C|Cases], [P|Ps], Table) :- run_sample([C|Cases], [P|Ps], Table) :-