2009-02-16 12:23:29 +00:00
|
|
|
%
|
|
|
|
% The world famous EM algorithm, in a nutshell
|
|
|
|
%
|
|
|
|
|
|
|
|
:- module(clpbn_em, [em/5]).
|
|
|
|
|
|
|
|
:- use_module(library(lists),
|
2011-11-30 13:04:13 +00:00
|
|
|
[append/3,
|
|
|
|
delete/3]).
|
2009-02-16 12:23:29 +00:00
|
|
|
|
2012-09-29 11:50:00 +01:00
|
|
|
:- reexport(library(clpbn),
|
|
|
|
[
|
|
|
|
clpbn_flag/2,
|
|
|
|
clpbn_flag/3]).
|
|
|
|
|
2009-02-16 12:23:29 +00:00
|
|
|
:- use_module(library(clpbn),
|
2009-05-26 16:49:04 +01:00
|
|
|
[clpbn_init_graph/1,
|
|
|
|
clpbn_init_solver/5,
|
2009-02-16 12:23:29 +00:00
|
|
|
clpbn_run_solver/4,
|
2012-09-29 11:50:00 +01:00
|
|
|
pfl_init_solver/6,
|
|
|
|
pfl_run_solver/4,
|
2011-05-20 23:56:12 +01:00
|
|
|
clpbn_finalize_solver/1,
|
2011-11-30 13:04:13 +00:00
|
|
|
conditional_probability/3,
|
2009-02-16 12:23:29 +00:00
|
|
|
clpbn_flag/2]).
|
|
|
|
|
|
|
|
:- use_module(library('clpbn/dists'),
|
|
|
|
[get_dist_domain_size/2,
|
|
|
|
empty_dist/2,
|
|
|
|
dist_new_table/2,
|
|
|
|
get_dist_key/2,
|
|
|
|
randomise_all_dists/0,
|
|
|
|
uniformise_all_dists/0]).
|
|
|
|
|
2012-08-15 22:01:45 +01:00
|
|
|
:- use_module(library(clpbn/ground_factors),
|
|
|
|
[generate_network/5,
|
|
|
|
f/3]).
|
|
|
|
|
|
|
|
:- use_module(library(bhash), [
|
|
|
|
b_hash_new/1,
|
|
|
|
b_hash_lookup/3,
|
|
|
|
b_hash_insert/4]).
|
|
|
|
|
2009-02-16 12:23:29 +00:00
|
|
|
:- use_module(library('clpbn/learning/learn_utils'),
|
|
|
|
[run_all/1,
|
|
|
|
clpbn_vars/2,
|
|
|
|
normalise_counts/2,
|
|
|
|
compute_likelihood/3,
|
|
|
|
soften_sample/2]).
|
|
|
|
|
|
|
|
:- use_module(library(lists),
|
|
|
|
[member/2]).
|
|
|
|
|
2012-09-29 11:50:00 +01:00
|
|
|
:- use_module(library(maplist)).
|
|
|
|
|
2009-02-16 12:23:29 +00:00
|
|
|
:- use_module(library(matrix),
|
|
|
|
[matrix_add/3,
|
|
|
|
matrix_to_list/2]).
|
|
|
|
|
|
|
|
:- use_module(library(rbtrees),
|
|
|
|
[rb_new/1,
|
|
|
|
rb_insert/4,
|
|
|
|
rb_lookup/3]).
|
|
|
|
|
|
|
|
:- use_module(library('clpbn/utils'),
|
|
|
|
[
|
|
|
|
check_for_hidden_vars/3,
|
|
|
|
sort_vars_by_key/3]).
|
|
|
|
|
|
|
|
:- meta_predicate em(:,+,+,-,-), init_em(:,-).
|
|
|
|
|
|
|
|
% em(:Items, +MaxError, +MaxIts, -Tables, -Likelihood)
%
% Entry point of the learner. The first clause does all the work and
% then fails on purpose: it builds the EM state, iterates, finalizes
% the solver and records the answer in em_found/2. Failing undoes
% every binding made while running. The second clause hands the
% recorded answer back to the caller.
em(Items, MaxError, MaxIts, Tables, Likelihood) :-
    % errors raised while initialising are routed through handle_em/1
    catch(init_em(Items, EMState), Ball, handle_em(Ball)),
    em_loop(0, 0.0, EMState, MaxError, MaxIts, Likelihood, Tables),
    clpbn_finalize_solver(EMState),
    assert(em_found(Tables, Likelihood)),
    fail.
% get rid of new random variables the easy way :)
em(_Items, _MaxError, _MaxIts, Tables, Likelihood) :-
    retract(em_found(Tables, Likelihood)).
|
|
|
|
|
|
|
|
|
2012-08-15 22:01:45 +01:00
|
|
|
% handle_em(+Error)
%
% Recovery goal for errors caught around init_em/2.
%  - error(repeated_parents): record a result with unbound tables and
%    likelihood -inf in em_found/2, then fail.
%  - anything else: rethrow unchanged.
handle_em(Error) :-
    (   Error = error(repeated_parents)
    ->  assert(em_found(_, -inf)),
        fail
    ;   throw(Error)
    ).
|
2009-02-16 12:23:29 +00:00
|
|
|
|
|
|
|
% This gets you an initial configuration. If there is a lot of evidence
|
|
|
|
% tables may be filled in close to optimal, otherwise they may be
|
|
|
|
% close to uniform.
|
|
|
|
% it also gets you a run for random variables
|
|
|
|
|
|
|
|
% state collects all Info we need for the EM algorithm
|
|
|
|
% it includes the list of variables without evidence,
|
|
|
|
% the list of distributions for which we want to compute parameters,
|
|
|
|
% and more detailed info on distributions, namely with a list of all instances for the distribution.
|
2012-08-15 22:01:45 +01:00
|
|
|
% init_em(:Items, -State)
%
% Builds the initial EM state: reads the em_solver flag, initialises
% the graph for that solver, resets all distributions with
% uniformise_all_dists/0 and lets setup_em_network/3 build State.
init_em(Items, State) :-
    clpbn_flag(em_solver, EMSolver),
    % only used for PCGs
    clpbn_init_graph(EMSolver),
    % randomise_all_dists,
    % set initial values for distributions
    uniformise_all_dists,
    setup_em_network(Items, EMSolver, State).
|
2012-08-15 22:01:45 +01:00
|
|
|
|
2012-09-29 11:50:00 +01:00
|
|
|
% setup_em_network(:Items, +Solver, -State)
%
% Builds State = state(Dists, DistInstances, Marginals, SolverInfo).
% There are two ways to get there, selected by clpbn:use_parfactors/1:
%  - parfactors on:  run the examples to get keys, factors and
%    evidence, derive the distribution info from the factors and
%    initialise the PFL solver;
%  - otherwise:      ground the network by running Items, collect and
%    order the attributed variables, derive the distribution info from
%    them and initialise the CLP(BN) solver.
setup_em_network(Items, Solver, state(Dists, DistInstances, Marginals, SolverInfo)) :-
    (   clpbn:use_parfactors(on)
    ->  % get all variables to marginalise
        run_examples(Items, Keys, Factors, Evidence),
        % get the EM CPT connections info from the factors
        generate_dists(Factors, Evidence, Dists, DistInstances, Marginals),
        % setup solver, if necessary
        pfl_init_solver(Marginals, Keys, Factors, Evidence, SolverInfo, Solver)
    ;   % create the ground network
        call_run_all(Items),
        % get all variables to marginalise
        attributes:all_attvars(AttVars),
        % and order them
        sort_vars_by_key(AttVars, OrderedVars, []),
        % remove variables that do not have to do with this query.
        different_dists(OrderedVars, Dists, DistInstances, Marginals),
        % setup solver by doing parameter independent work.
        clpbn_init_solver(Solver, Marginals, OrderedVars, _, SolverInfo)
    ).
|
|
|
|
|
2012-09-29 11:50:00 +01:00
|
|
|
% run_examples(:Items, -Keys, -Factors, -EList)
%
% Produces the keys, factors and evidence list for the training data.
%
% First clause: Items is user:Exs and Exs is a list whose first element
% has the shape _:_ (a list of tagged examples). Every example is run
% independently through run_example/4 (findall/3 copies the results, so
% no bindings leak between examples) and the per-example results are
% merged by join_example/9, which tags each key with the example index,
% counting from 0.
%
% Second clause: any other input is treated as a single example and
% handed to run_ex/4.
%
% A stray call to trace/0 used to sit in the first clause; it dropped
% every multi-example run into the interactive debugger and has been
% removed.
run_examples(user:Exs, Keys, Factors, EList) :-
    Exs = [_:_|_], !,
    findall(ex(EKs, EFs, EEs),
            run_example(Exs, EKs, EFs, EEs),
            VExs),
    % thread three accumulators (keys, factors, evidence) plus the
    % running example index through the collected examples
    foldl4(join_example, VExs, [], Keys, [], Factors, [], EList, 0, _).
run_examples(Items, Keys, Factors, EList) :-
    run_ex(Items, Keys, Factors, EList).
|
|
|
|
|
|
|
|
% join_example(+Example, +Keys0, -Keys, +Factors0, -Factors,
%              +EList0, -EList, +I0, -I)
%
% Folds one example ex(Keys, Factors, Evidence) into the three
% accumulators. Every key of the example is rewritten as I0:Key so that
% keys coming from different examples stay distinct. I is the index
% for the next example (I0+1). New entries are pushed on the front of
% each accumulator.
join_example(ex(ExKeys, ExFactors, ExEvidence),
             Keys0, Keys, Factors0, Factors, EList0, EList, I0, I) :-
    foldl(process_key(I0), ExKeys, Keys0, Keys),
    foldl(process_factor(I0), ExFactors, Factors0, Factors),
    foldl(process_ev(I0), ExEvidence, EList0, EList),
    I is I0+1.

% process_key(+Index, +Key, +Acc0, -Acc): push Index:Key on the accumulator.
process_key(Index, Key, Acc, [Index:Key|Acc]).

% process_factor(+Index, +Factor, +Acc0, -Acc): push a copy of the
% factor whose keys are all tagged with Index.
process_factor(Index, f(Type, Id, FactorKeys), Acc, [f(Type, Id, TaggedKeys)|Acc]) :-
    maplist(update_key(Index), FactorKeys, TaggedKeys).

% update_key(+Index, +Key, -Tagged): Tagged is Index:Key.
update_key(Index, Key, Index:Key).

% process_ev(+Index, +Key=Value, +Acc0, -Acc): push the evidence pair
% with its key tagged with Index.
process_ev(Index, Key=Value, Acc, [(Index:Key)=Value|Acc]).
|
|
|
|
|
|
|
|
% run_example(+Examples, -Keys, -Factors, -EList)
%
% Nondeterministic: on backtracking visits, in list order, every
% element of Examples that has the shape _:Items and runs it through
% run_ex/4 in module user. Elements of any other shape are skipped.
run_example(Examples, Keys, Factors, EList) :-
    member(_:Items, Examples),
    run_ex(user:Items, Keys, Factors, EList).
|
|
|
|
|
|
|
|
% run_ex(:Items, -Keys, -Factors, -EList)
%
% Runs one example: executes Items to create the ground network,
% collects all attributed variables, orders them by key and asks
% generate_network/5 for the keys, factors and evidence list.
run_ex(Items, Keys, Factors, EList) :-
    % create the ground network
    call_run_all(Items),
    attributes:all_attvars(AttVars),
    % and order them
    sort_vars_by_key(AttVars, OrderedVars, []),
    % no, we are in trouble because we don't know the network yet.
    % get the ground network
    generate_network(OrderedVars, _, Keys, Factors, EList).
|
|
|
|
|
2009-02-16 12:23:29 +00:00
|
|
|
% loop for as long as you want.
|
|
|
|
% em_loop(+Its, +Likelihood0, +State, +MaxError, +MaxIts,
%         -LikelihoodF, -FTables)
%
% One EM iteration followed by the stop test.
%   Its          iteration counter, starts at 0
%   Likelihood0  likelihood obtained in the previous iteration
%   State        the state/4 term built by init_em/2
%   MaxError     stop when the relative likelihood change drops below it
%   MaxIts       stop when Its == MaxIts
%   LikelihoodF  likelihood of the final iteration
%   FTables      final tables as Key-List pairs (see ltables/2)
%
% The E-step (estimate/2) yields the marginals, the M-step (maximise/4)
% yields new tables and their likelihood.
%
% Earlier revisions printed an "iteration:..." term with writeln/1 on
% every pass and converted all tables to lists just to print them.
% That debugging output polluted the caller's output stream and did the
% conversion once per iteration for nothing; the tables are now
% converted only once, when the loop stops.
%
% NOTE(review): the relative-change test divides by Likelihood; a
% likelihood of exactly zero is not guarded against here.
em_loop(Its, Likelihood0, State, MaxError, MaxIts, LikelihoodF, FTables) :-
    estimate(State, LPs),
    maximise(State, Tables, LPs, Likelihood),
    (
        (   abs((Likelihood - Likelihood0)/Likelihood) < MaxError
        ;   Its == MaxIts
        )
    ->
        ltables(Tables, FTables),
        LikelihoodF = Likelihood
    ;
        Its1 is Its+1,
        em_loop(Its1, Likelihood, State, MaxError, MaxIts, LikelihoodF, FTables)
    ).
|
|
|
|
|
|
|
|
% ltables(+Tables, -ListTables)
%
% Maps every Id-Matrix pair of Tables to Key-List, where List is the
% matrix turned into a list by matrix_to_list/2 and Key is what
% get_dist_key/2 returns for Id. Order is preserved.
ltables(Tables, ListTables) :-
    maplist(ltables_entry, Tables, ListTables).

% ltables_entry(+Id-Matrix, -Key-List): convert a single table.
ltables_entry(Id-Matrix, Key-List) :-
    matrix_to_list(Matrix, List),
    get_dist_key(Id, Key).
|
|
|
|
|
2012-08-15 22:01:45 +01:00
|
|
|
|
|
|
|
% generate_dists(+Factors, +EList, -AllDists, -AllInfo, -MargVars)
%
% Parfactor counterpart of different_dists/4. Loads the evidence list
% into a hash table, turns every factor into an i/4 record
% (process_factor/3), sorts the records and groups them by
% distribution id. MargVars is the sorted, duplicate-free list of the
% non-evidence key lists emitted while grouping.
generate_dists(Factors, EList, AllDists, AllInfo, MargVars) :-
    b_hash_new(EmptyEvidence),
    foldl(elist_to_hash, EList, EmptyEvidence, Evidence),
    maplist(process_factor(Evidence), Factors, Records),
    sort(Records, SortedRecords),
    group(SortedRecords, AllDists, AllInfo, RawMargVars, []),
    sort(RawMargVars, MargVars).
|
|
|
|
|
|
|
|
% elist_to_hash(+Key=Value, +Hash0, -Hash)
%
% Fold step: Hash is Hash0 after inserting Value under Key.
elist_to_hash(Key=Value, Hash0, Hash) :-
    b_hash_insert(Hash0, Key, Value, Hash).
|
|
|
|
|
|
|
|
% process_factor(+Evidence, +Factor, -Record)
%
% Turns a factor f(bayes, Id, Keys) into Record = i(Id, Keys, Cases,
% NonEvs). For every key fetch_evidence/5 supplies one slot of the
% compact case description and accumulates the keys without evidence
% in NonEvs. uncompact_cases/2 then expands the compact description
% into the explicit list of cases.
process_factor(Evidence, f(bayes, Id, Keys), i(Id, Keys, Cases, NonEvs)) :-
    foldl(fetch_evidence(Evidence), Keys, Compact, [], NonEvs),
    uncompact_cases(Compact, Cases).
|
|
|
|
|
|
|
|
% fetch_evidence(+Evidence, +Key, -Slot, +NonEvs0, -NonEvs)
%
% Fold step over the keys of a factor. Clause order matters:
%  1. Key has evidence in the hash table: Slot is the evidence value
%     and the accumulator is left unchanged.
%  2. Key has the shape _:K and pfl:skolem/2 knows K: Slot is the list
%     0..N-1 for a domain with N values, and the bare K is pushed on
%     the accumulator.
%  3. Otherwise pfl:skolem/2 is asked about Key itself: Slot is built
%     the same way and Key is pushed on the accumulator.
fetch_evidence(Evidence, Key, Value, NonEvs, NonEvs) :-
    b_hash_lookup(Key, Value, Evidence), !.
fetch_evidence(_Evidence, _Example:Key, Numbers, NonEvs, [Key|NonEvs]) :-
    pfl:skolem(Key, Domain), !,
    foldl(domain_to_number, Domain, Numbers, 0, _).
fetch_evidence(_Evidence, Key, Numbers, NonEvs, [Key|NonEvs]) :-
    pfl:skolem(Key, Domain),
    foldl(domain_to_number, Domain, Numbers, 0, _).
|
2012-08-15 22:01:45 +01:00
|
|
|
|
2012-09-29 11:50:00 +01:00
|
|
|
% domain_to_number(+DomainValue, -Number, +Count0, -Count)
%
% Fold step that numbers domain values: the current value gets the
% running count as its Number and the count is incremented. The domain
% value itself is ignored.
domain_to_number(_DomainValue, Number, Count0, Count) :-
    Number = Count0,
    Count is Count0+1.
|
2012-08-15 22:01:45 +01:00
|
|
|
|
|
|
|
|
2009-02-16 12:23:29 +00:00
|
|
|
% collect the different dists we are going to learn next.
|
|
|
|
% different_dists(+AllVars, -AllDists, -AllInfo, -MargVars)
%
% Builds one i/4 record per variable (all_dists/3), sorts the records
% and groups them by distribution id. MargVars is the sorted,
% duplicate-free list of the hidden-variable lists emitted by group//3.
different_dists(AllVars, AllDists, AllInfo, MargVars) :-
    all_dists(AllVars, AllVars, Records),
    sort(Records, SortedRecords),
    group(SortedRecords, AllDists, AllInfo, RawMargVars, []),
    sort(RawMargVars, MargVars).
|
|
|
|
|
2012-06-22 19:00:12 +01:00
|
|
|
%
% V -> to Id defining V. We get:
% the random variables that are parents
% the cases that can happen, eg if we have A <- B, C
% A and B are boolean w/o evidence, and C has evidence f (say, value 1),
% the cases would be
% [0,0,1], [0,1,1], [1,0,1], [1,1,1],
% Hiddens will be A and B (the variables without evidence)
%
|
|
|
|
% all_dists(+Vars, +AllVars, -Records)
%
% For every variable V in Vars produces a record
%     i(Id, [V|Parents], Cases, Hiddens)
% where Id and Parents come from V's dist/2 attribute, and Cases and
% Hiddens are computed over the family [V|Parents] by
% generate_hidden_cases/3 and uncompact_cases/2.
%
% Throws error(repeated_parents) when the family contains the same
% variable more than once (sorting it removes an element). AllVars is
% only passed along to the recursive call.
all_dists([], _, []).
all_dists([V|Vars], AllVars, [i(Id, Family, Cases, Hiddens)|Records]) :-
    % V is an instance of Id
    clpbn:get_atts(V, [dist(Id,Parents)]),
    Family = [V|Parents],
    sort(Family, Unique),
    length(Unique, UniqueCount),
    length(Parents, ParentCount),
    (   ParentCount+1 =\= UniqueCount
    ->  throw(error(repeated_parents))
    ;   true
    ),
    generate_hidden_cases(Family, Compact, Hiddens),
    uncompact_cases(Compact, Cases),
    all_dists(Vars, AllVars, Records).
|
|
|
|
|
2009-02-16 12:23:29 +00:00
|
|
|
% generate_hidden_cases(+Vars, -CompactCases, -Hiddens)
%
% Walks the variables of a family and produces one slot per variable:
%  - a variable with an evidence/1 attribute contributes its evidence
%    value;
%  - any other variable contributes the list 0..Sz-1, Sz being the
%    domain size of its distribution, and is added to Hiddens.
generate_hidden_cases([], [], []).
generate_hidden_cases([Var|Vars], [Observed|Slots], Hiddens) :-
    clpbn:get_atts(Var, [evidence(Observed)]), !,
    generate_hidden_cases(Vars, Slots, Hiddens).
generate_hidden_cases([Var|Vars], [Values|Slots], [Var|Hiddens]) :-
    clpbn:get_atts(Var, [dist(DistId,_)]),
    get_dist_domain_size(DistId, Size),
    gen_cases(0, Size, Values),
    generate_hidden_cases(Vars, Slots, Hiddens).
|
2012-08-15 22:01:45 +01:00
|
|
|
|
2009-02-16 12:23:29 +00:00
|
|
|
% gen_cases(+From, +Limit, -Values)
%
% Values is the list of integers From, From+1, ..., Limit-1; it is
% empty when From already equals Limit.
gen_cases(From, Limit, Values) :-
    (   From = Limit,
        Values = []
    ->  true
    ;   Values = [From|Rest],
        Next is From+1,
        gen_cases(Next, Limit, Rest)
    ).
|
|
|
|
|
|
|
|
% uncompact_cases(+CompactCases, -Cases)
%
% Expands a compact description into the explicit list of cases. Each
% slot of CompactCases is either an integer (a fixed value) or a list
% of alternatives; Cases holds every combination, in the order
% produced by backtracking over is_case/2.
uncompact_cases(CompactCases, Cases) :-
    findall(Case, is_case(CompactCases, Case), Cases).

% is_case(+CompactCases, -Case)
%
% Nondeterministically picks one value per slot: an integer slot is
% copied as is, any other slot is enumerated with member/2.
is_case([], []).
is_case([Slot|Slots], [Value|Values]) :-
    (   integer(Slot)
    ->  Value = Slot
    ;   member(Value, Slot)
    ),
    is_case(Slots, Values).
|
|
|
|
|
|
|
|
% group(+Records, -Ids, -Info)// 
%
% DCG over a list of i(Id, Vars, Cases, Hiddens) records that is
% already sorted, so records with the same Id are adjacent.
%   Ids   one entry per run of records sharing an Id
%   Info  Id-RecordsOfThatId pairs, in the same order
% The DCG output list receives the Hiddens list of every record whose
% Hiddens is not [].
group([], [], []) --> [].
group([i(Id,Vars,Cases,[])|Records], [Id|Ids],
      [Id-[i(Id,Vars,Cases,[])|Same]|Info]) --> !,
    % nothing hidden in this record: emit nothing
    same_id(Records, Id, Same, Others),
    group(Others, Ids, Info).
group([i(Id,Vars,Cases,Hiddens)|Records], [Id|Ids],
      [Id-[i(Id,Vars,Cases,Hiddens)|Same]|Info]) -->
    [Hiddens],
    same_id(Records, Id, Same, Others),
    group(Others, Ids, Info).

% same_id(+Records, +Id, -Same, -Others)//
%
% Splits off the leading records that carry Id into Same; Others is
% what follows them. Non-empty Hiddens lists are emitted, as in group//3.
same_id([i(Id,Vars,Cases,[])|Records], Id,
        [i(Id,Vars,Cases,[])|Same], Others) --> !,
    same_id(Records, Id, Same, Others).
same_id([i(Id,Vars,Cases,Hiddens)|Records], Id,
        [i(Id,Vars,Cases,Hiddens)|Same], Others) --> !,
    [Hiddens],
    same_id(Records, Id, Same, Others).
same_id(Records, _, [], Records) --> [].
|
|
|
|
|
|
|
|
|
|
|
|
% compact_mvars(+MargVars, -CMVars)
%
% Collapses every run of adjacent elements that are identical (==)
% into a single element. Elements that are equal but not adjacent are
% kept.
compact_mvars([], []).
compact_mvars([X|Rest], CMVars) :-
    (   Rest = [Y|_],
        X == Y
    ->  % X is repeated right after: drop this copy
        compact_mvars(Rest, CMVars)
    ;   CMVars = [X|Tail],
        compact_mvars(Rest, Tail)
    ).
|
|
|
|
|
2012-09-29 11:50:00 +01:00
|
|
|
% estimate(+State, -LPs)
%
% E-step: runs the solver named by the em_solver flag on the marginals
% stored in State. The PFL entry point is used when
% clpbn:use_parfactors(on) holds, the CLP(BN) one otherwise.
estimate(state(_, _, Marginals, SolverInfo), LPs) :-
    (   clpbn:use_parfactors(on)
    ->  clpbn_flag(em_solver, EMSolver),
        pfl_run_solver(Marginals, LPs, SolverInfo, EMSolver)
    ;   clpbn_flag(em_solver, EMSolver),
        clpbn_run_solver(EMSolver, Marginals, LPs, SolverInfo)
    ).
|
|
|
|
|
|
|
|
% maximise(+State, -Tables, +LPs, -Likelihood)
%
% M-step. Stores the E-step results in a red-black tree keyed by the
% marginal they belong to (MargVars and LPs are walked in parallel),
% then lets compute_parameters/6 rebuild every table, starting from a
% likelihood of 0.0.
maximise(state(_, DistInstances, MargVars, _), Tables, LPs, Likelihood) :-
    rb_new(EmptyTree),
    foldl(create_mdist_table, MargVars, LPs, EmptyTree, MargTree),
    compute_parameters(DistInstances, Tables, MargTree, 0.0, Likelihood, LPs:MargVars).

% create_mdist_table(+Vars, +Probs, +Tree0, -Tree)
%
% Fold step: Tree is Tree0 with Probs inserted under key Vars.
create_mdist_table(Vars, Probs, Tree0, Tree) :-
    rb_insert(Tree0, Vars, Probs, Tree).
|
2009-02-16 12:23:29 +00:00
|
|
|
|
|
|
|
% compute_parameters(+DistInstances, -Tables, +MDistTable,
%                    +Lik0, -Lik, +Extra)
%
% For every Id-Samples pair: start from the empty table of Id, add the
% samples, soften and normalise the counts, install the result as the
% new table of Id and add this table's likelihood contribution to the
% running total. Tables gets one Id-NewTable pair per distribution.
% Extra is only passed along to the recursive call.
compute_parameters([], [], _, Lik, Lik, _).
compute_parameters([Id-Samples|Rest], [Id-NewTable|Tables], MDistTable, LikSoFar, Lik, Extra) :-
    empty_dist(Id, Counts),
    add_samples(Samples, Counts, MDistTable),
%matrix_to_list(Table0,Mat), lists:sumlist(Mat, Sum), format(user_error, 'FINAL ~d ~w ~w~n', [Id,Sum,Mat]),
    soften_sample(Counts, Softened),
%    matrix:matrix_sum(Table0,TotM),
    normalise_counts(Softened, NewTable),
    % the likelihood is computed against the raw counts, not the softened ones
    compute_likelihood(Counts, NewTable, Delta),
    dist_new_table(Id, NewTable),
    LikNext is LikSoFar+Delta,
    compute_parameters(Rest, Tables, MDistTable, LikNext, Lik, Extra).
|
|
|
|
|
|
|
|
% add_samples(+Samples, +Table, +MDistTable)
%
% Adds the contribution of every i/4 record to Table through
% matrix_add/3.
%  - A record with exactly one case and no hidden variables adds 1.0 to
%    that case.
%  - Any other record looks up the probabilities stored under its
%    Hiddens list in MDistTable and adds them case by case
%    (run_sample/3).
% NOTE(review): Table is not threaded through, so matrix_add/3 is
% presumably updating it in place - confirm against library(matrix).
add_samples([], _, _).
add_samples([i(_,_,[OnlyCase],[])|Samples], Table, MDistTable) :- !,
    matrix_add(Table, OnlyCase, 1.0),
    add_samples(Samples, Table, MDistTable).
add_samples([i(_,_,Cases,Hiddens)|Samples], Table, MDistTable) :-
    rb_lookup(Hiddens, Probs, MDistTable),
    run_sample(Cases, Probs, Table),
%matrix_to_list(Table,M), format(user_error, '~w ~w~n', [Cases,Ps]),
    add_samples(Samples, Table, MDistTable).
|
|
|
|
|
|
|
|
% run_sample(+Cases, +Probs, +Table)
%
% Walks Cases and Probs in parallel and adds each probability to the
% Table entry of the matching case. Fails if the two lists differ in
% length.
run_sample([], [], _).
run_sample([Case|Cases], [Prob|Probs], Table) :-
    matrix_add(Table, Case, Prob),
    run_sample(Cases, Probs, Table).
|
|
|
|
|
2009-05-26 16:49:04 +01:00
|
|
|
% call_run_all(+Mod:Items)
%
% Runs the example goals. When the em_solver flag is pcg the goals are
% run through the failure-driven backtrack_run_all/2; otherwise they
% are handed to run_all/1.
call_run_all(Mod:Items) :-
    (   clpbn_flag(em_solver, pcg)
    ->  backtrack_run_all(Items, Mod)
    ;   run_all(Mod:Items)
    ).
|
2009-02-16 12:23:29 +00:00
|
|
|
|
2009-05-26 16:49:04 +01:00
|
|
|
% backtrack_run_all(+Items, +Mod)
%
% Failure-driven loop: for each item in turn, calls Mod:Item, forces
% backtracking through all of its solutions and then moves on to the
% next item. Bindings made by the goals are undone; only their side
% effects remain. Succeeds once the list is exhausted.
backtrack_run_all([Item|Items], Mod) :-
    (   call(Mod:Item),
        fail
    ;   backtrack_run_all(Items, Mod)
    ).
backtrack_run_all([], _).
|
2011-11-30 13:04:13 +00:00
|
|
|
|