more EM stuff

This commit is contained in:
Vítor Santos Costa
2012-09-29 11:50:00 +01:00
parent 78a08e1b87
commit 793907f710
11 changed files with 218 additions and 125 deletions

View File

@@ -8,10 +8,17 @@
[append/3,
delete/3]).
:- reexport(library(clpbn),
[
clpbn_flag/2,
clpbn_flag/3]).
:- use_module(library(clpbn),
[clpbn_init_graph/1,
clpbn_init_solver/5,
clpbn_run_solver/4,
pfl_init_solver/6,
pfl_run_solver/4,
clpbn_finalize_solver/1,
conditional_probability/3,
clpbn_flag/2]).
@@ -43,6 +50,8 @@
:- use_module(library(lists),
[member/2]).
:- use_module(library(maplist)).
:- use_module(library(matrix),
[matrix_add/3,
matrix_to_list/2]).
@@ -89,27 +98,22 @@ init_em(Items, State) :-
clpbn_flag(em_solver, Solver),
% only used for PCGs
clpbn_init_graph(Solver),
% create the ground network
call_run_all(Items),
% randomise_all_dists,
% set initial values for distributions
uniformise_all_dists,
setup_em_network(Solver, State).
setup_em_network(Items, Solver, State).
setup_em_network(Solver, state( AllDists, AllDistInstances, MargVars, SolverState)) :-
setup_em_network(Items, Solver, state( AllDists, AllDistInstances, MargKeys, SolverState)) :-
clpbn:use_parfactors(on), !,
% get all variables to marginalise
attributes:all_attvars(AllVars0),
% and order them
sort_vars_by_key(AllVars0,AllVars,[]),
% no, we are in trouble because we don't know the network yet.
% get the ground network
generate_network(AllVars, _, Keys, Factors, EList),
run_examples(Items, Keys, Factors, EList),
% get the EM CPT connections info from the factors
generate_dists(Factors, EList, AllDists, AllDistInstances, MargVars),
generate_dists(Factors, EList, AllDists, AllDistInstances, MargKeys),
% setup solver, if necessary
clpbn_init_solver(Solver, MargVars, _AllVars, ground(MargVars, Keys, Factors, EList), SolverState).
setup_em_network(Solver, state( AllDists, AllDistInstances, MargVars, SolverVars)) :-
pfl_init_solver(MargKeys, Keys, Factors, EList, SolverState, Solver).
setup_em_network(Items, Solver, state( AllDists, AllDistInstances, MargVars, SolverVars)) :-
% create the ground network
call_run_all(Items),
% get all variables to marginalise
attributes:all_attvars(AllVars0),
% and order them
@@ -119,6 +123,45 @@ setup_em_network(Solver, state( AllDists, AllDistInstances, MargVars, SolverVars
% setup solver by doing parameter independent work.
clpbn_init_solver(Solver, MargVars, AllVars, _, SolverVars).
run_examples(user:Exs, Keys, Factors, EList) :-
Exs = [_:_|_], !,
trace,
findall(ex(EKs, EFs, EEs), run_example(Exs, EKs, EFs, EEs),
VExs),
foldl4(join_example, VExs, [], Keys, [], Factors, [], EList, 0, _).
run_examples(Items, Keys, Factors, EList) :-
run_ex(Items, Keys, Factors, EList).
join_example( ex(EKs, EFs, EEs), Keys0, Keys, Factors0, Factors, EList0, EList, I0, I) :-
I is I0+1,
foldl(process_key(I0), EKs, Keys0, Keys),
foldl(process_factor(I0), EFs, Factors0, Factors),
foldl(process_ev(I0), EEs, EList0, EList).
process_key(I0, K, Keys0, [I0:K|Keys0]).
process_factor(I0, f(Type, Id, Keys), Keys0, [f(Type, Id, NKeys)|Keys0]) :-
maplist(update_key(I0), Keys, NKeys).
update_key(I0, K, I0:K).
process_ev(I0, K=V, Es0, [(I0:K)=V|Es0]).
run_example([_:Items|_], Keys, Factors, EList) :-
run_ex(user:Items, Keys, Factors, EList).
run_example([_|LItems], Keys, Factors, EList) :-
run_example(LItems, Keys, Factors, EList).
run_ex(Items, Keys, Factors, EList) :-
% create the ground network
call_run_all(Items),
attributes:all_attvars(AllVars0),
% and order them
sort_vars_by_key(AllVars0,AllVars,[]),
% no, we are in trouble because we don't know the network yet.
% get the ground network
generate_network(AllVars, _, Keys, Factors, EList).
% loop for as long as you want.
em_loop(Its, Likelihood0, State, MaxError, MaxIts, LikelihoodF, FTables) :-
estimate(State, LPs),
@@ -147,37 +190,31 @@ ltables([Id-T|Tables], [Key-LTable|FTables]) :-
generate_dists(Factors, EList, AllDists, AllInfo, MargVars) :-
b_hash_new(Ev0),
elist_to_hash(EList, Ev0, Ev),
process_factors(Factors, Ev, Dists0),
sort(Dists0, Dists1),
group(Dists1, AllDists, AllInfo, MargVars0, []),
sort(MargVars0, MargVars).
b_hash_new(Ev0),
foldl(elist_to_hash, EList, Ev0, Ev),
maplist(process_factor(Ev), Factors, Dists0),
sort(Dists0, Dists1),
group(Dists1, AllDists, AllInfo, MargVars0, []),
sort(MargVars0, MargVars).
elist_to_hash([], Ev, Ev).
elist_to_hash([K=V|EList], Ev0, Ev) :-
b_hash_insert(Ev0, K, V, Evi),
elist_to_hash(EList, Evi, Ev).
elist_to_hash(K=V, Ev0, Ev) :-
b_hash_insert(Ev0, K, V, Ev).
process_factors([], _Ev, []).
process_factors([f(bayes,Id,Ks)|Factors], Ev, [i(Id, Ks, Cases, NonEvs)|AllDistInstances]) :-
fetch_evidence(Ks, Ev, CompactCases, NonEvs),
uncompact_cases(CompactCases, Cases),
process_factors(Factors, Ev, AllDistInstances).
process_factor(Ev, f(bayes,Id,Ks), i(Id, Ks, Cases, NonEvs)) :-
foldl( fetch_evidence(Ev), Ks, CompactCases, [], NonEvs),
uncompact_cases(CompactCases, Cases).
fetch_evidence([], _Ev, [], []).
fetch_evidence([K|Ks], Ev, [E|CompactCases], NonEvs) :-
b_hash_lookup(K, E, Ev), !,
fetch_evidence(Ks, Ev, CompactCases, NonEvs).
fetch_evidence([K|Ks], Ev, [Ns|CompactCases], [K|NonEvs]) :-
fetch_evidence(Ev, K, E, NonEvs, NonEvs) :-
b_hash_lookup(K, E, Ev), !.
fetch_evidence(_Ev, _Id:K, Ns, NonEvs, [K|NonEvs]) :-
pfl:skolem(K,D), !,
foldl(domain_to_number, D, Ns, 0, _).
fetch_evidence(_Ev, K, Ns, NonEvs, [K|NonEvs]) :-
pfl:skolem(K,D),
domain_to_numbers(D,0,Ns),
fetch_evidence(Ks, Ev, CompactCases, NonEvs).
foldl(domain_to_number, D, Ns, 0, _).
domain_to_numbers([],_,[]).
domain_to_numbers([_|D],I0,[I0|Ns]) :-
I is I0+1,
domain_to_numbers(D,I,Ns).
domain_to_number(_, I0, I0, I) :-
I is I0+1.
% collect the different dists we are going to learn next.
@@ -213,24 +250,6 @@ all_dists([V|AllVars], AllVars0, [i(Id, [V|Parents], Cases, Hiddens)|Dists]) :-
uncompact_cases(CompactCases, Cases),
all_dists(AllVars, AllVars0, Dists).
find_variables([], _AllVars0, []).
find_variables([K|PKeys], AllVars0, [Parent|Parents]) :-
find_variable(K, AllVars0, Parent),
find_variables(PKeys, AllVars0, Parents).
%
% in clp(bn) the whole network is constructed when you evaluate EM. In
% pfl, we want to delay execution until as late as possible.
% we just create a new variable and hope for the best.
%
%
find_variable(K, [], Parent) :-
clpbn:put_atts(Parent, [key(K)]).
find_variable(K, [Parent|_AllVars0], Parent) :-
clpbn:get_atts(Parent, [key(K0)]), K0 =@= K, !.
find_variable(K, [_|AllVars0], Parent) :-
find_variable(K, AllVars0, Parent).
generate_hidden_cases([], [], []).
generate_hidden_cases([V|Parents], [P|Cases], Hiddens) :-
clpbn:get_atts(V, [evidence(P)]), !,
@@ -280,19 +299,21 @@ compact_mvars([X1,X2|MargVars], CMVars) :- X1 == X2, !,
compact_mvars([X|MargVars], [X|CMVars]) :- !,
compact_mvars(MargVars, CMVars).
estimate(state(_, _, Margs, SolverState), LPs) :-
clpbn:use_parfactors(on), !,
clpbn_flag(em_solver, Solver),
pfl_run_solver(Margs, LPs, SolverState, Solver).
estimate(state(_, _, Margs, SolverState), LPs) :-
clpbn_flag(em_solver, Solver),
clpbn_run_solver(Solver, Margs, LPs, SolverState).
maximise(state(_,DistInstances,MargVars,_), Tables, LPs, Likelihood) :-
rb_new(MDistTable0),
create_mdist_table(MargVars, LPs, MDistTable0, MDistTable),
foldl(create_mdist_table, MargVars, LPs, MDistTable0, MDistTable),
compute_parameters(DistInstances, Tables, MDistTable, 0.0, Likelihood, LPs:MargVars).
create_mdist_table([],[],MDistTable,MDistTable).
create_mdist_table([Vs|MargVars],[Ps|LPs],MDistTable0,MDistTable) :-
rb_insert(MDistTable0, Vs, Ps, MDistTableI),
create_mdist_table(MargVars, LPs, MDistTableI ,MDistTable).
create_mdist_table(Vs, Ps, MDistTable0, MDistTable) :-
rb_insert(MDistTable0, Vs, Ps, MDistTable).
compute_parameters([], [], _, Lik, Lik, _).
compute_parameters([Id-Samples|Dists], [Id-NewTable|Tables], MDistTable, Lik0, Lik, LPs:MargVars) :-