changes to support em: step 1
This commit is contained in:
parent
09ccb295c2
commit
3b811d0d70
@ -18,7 +18,8 @@
|
|||||||
get_evidence_from_position/3,
|
get_evidence_from_position/3,
|
||||||
dist_to_term/2,
|
dist_to_term/2,
|
||||||
empty_dist/2,
|
empty_dist/2,
|
||||||
dist_new_table/2
|
dist_new_table/2,
|
||||||
|
all_dist_ids/1
|
||||||
]).
|
]).
|
||||||
|
|
||||||
:- use_module(library(lists),[is_list/1,nth0/3]).
|
:- use_module(library(lists),[is_list/1,nth0/3]).
|
||||||
@ -229,4 +230,3 @@ dist_new_table(Id, NewMat) :-
|
|||||||
fail.
|
fail.
|
||||||
dist_new_table(_, _).
|
dist_new_table(_, _).
|
||||||
|
|
||||||
|
|
||||||
|
@ -8,7 +8,8 @@
|
|||||||
%
|
%
|
||||||
|
|
||||||
:- module(gibbs, [gibbs/3,
|
:- module(gibbs, [gibbs/3,
|
||||||
check_if_gibbs_done/1]).
|
check_if_gibbs_done/1,
|
||||||
|
init_gibbs_solver/3]).
|
||||||
|
|
||||||
:- use_module(library(rbtrees),
|
:- use_module(library(rbtrees),
|
||||||
[rb_new/1,
|
[rb_new/1,
|
||||||
@ -49,9 +50,7 @@
|
|||||||
gibbs([],_,_) :- !.
|
gibbs([],_,_) :- !.
|
||||||
gibbs(LVs,Vs0,AllDiffs) :-
|
gibbs(LVs,Vs0,AllDiffs) :-
|
||||||
LVs = [_], !,
|
LVs = [_], !,
|
||||||
clean_up,
|
init_gibbs_solver(Vs0, LVs, Gibbs),
|
||||||
check_for_hidden_vars(Vs0, Vs0, Vs1),
|
|
||||||
sort(Vs1,Vs),
|
|
||||||
(clpbn:output(xbif(XBifStream)) -> clpbn2xbif(XBifStream,vel,Vs) ; true),
|
(clpbn:output(xbif(XBifStream)) -> clpbn2xbif(XBifStream,vel,Vs) ; true),
|
||||||
(clpbn:output(gviz(XBifStream)) -> clpbn2gviz(XBifStream,vel,Vs,LVs) ; true),
|
(clpbn:output(gviz(XBifStream)) -> clpbn2gviz(XBifStream,vel,Vs,LVs) ; true),
|
||||||
initialise(Vs, Graph, LVs, OutputVars, VarOrder),
|
initialise(Vs, Graph, LVs, OutputVars, VarOrder),
|
||||||
@ -64,6 +63,11 @@ gibbs(LVs,Vs0,AllDiffs) :-
|
|||||||
gibbs(LVs,_,_) :-
|
gibbs(LVs,_,_) :-
|
||||||
throw(error(domain_error(solver,LVs),solver(gibbs))).
|
throw(error(domain_error(solver,LVs),solver(gibbs))).
|
||||||
|
|
||||||
|
init_gibbs_solver(LVs, Vs0, Gibbs) :-
|
||||||
|
clean_up,
|
||||||
|
check_for_hidden_vars(Vs0, Vs0, Vs1),
|
||||||
|
sort(Vs1,Vs).
|
||||||
|
|
||||||
initialise(LVs, Graph, GVs, OutputVars, VarOrder) :-
|
initialise(LVs, Graph, GVs, OutputVars, VarOrder) :-
|
||||||
init_keys(Keys0),
|
init_keys(Keys0),
|
||||||
gen_keys(LVs, 0, VLen, Keys0, Keys),
|
gen_keys(LVs, 0, VLen, Keys0, Keys),
|
||||||
|
114
CLPBN/learning/em.yap
Normal file
114
CLPBN/learning/em.yap
Normal file
@ -0,0 +1,114 @@
|
|||||||
|
%
|
||||||
|
% The world famous EM algorithm, in a nutshell
|
||||||
|
%
|
||||||
|
|
||||||
|
:- module(clpbn_em, [em/6]).
|
||||||
|
|
||||||
|
:- use_module(library(lists),
|
||||||
|
[append/3]).
|
||||||
|
|
||||||
|
:- use_module(library('clpbn/learning/learn_utils'),
|
||||||
|
[run_all/1,
|
||||||
|
clpbn_vars/2,
|
||||||
|
normalise_counts/2]).
|
||||||
|
|
||||||
|
em(Items, MaxError, MaxIts, Tables, Likelihood) :-
|
||||||
|
init_em(Items, State),
|
||||||
|
em_loop(0, 0.0, state(AllVars,AllDists), MaxError, MaxIts, Likelihood),
|
||||||
|
get_tables(State, Tables).
|
||||||
|
|
||||||
|
% This gets you an initial configuration. If there is a lot of evidence
|
||||||
|
% tables may be filled in close to optimal, otherwise they may be
|
||||||
|
% close to uniform.
|
||||||
|
% it also gets you a run for random variables
|
||||||
|
init_em(Items, state(AllVars, AllDists, AllDistInstances)) :-
|
||||||
|
run_all(Items),
|
||||||
|
different_dists(AllVars, AllDists, AllDistInstances).
|
||||||
|
|
||||||
|
% loop for as long as you want.
|
||||||
|
em_loop(MaxIts, Likelihood State, _, _ MaxIts, Likelihood) :- !.
|
||||||
|
em_loop(Its, Likelihood0, State, MaxError, MaxIts, LikelihoodF) :-
|
||||||
|
estimate(State),
|
||||||
|
maximise(State, Likelihood),
|
||||||
|
(
|
||||||
|
(
|
||||||
|
(Likelihood - Likelihood0)/Likelihood < MaxError
|
||||||
|
;
|
||||||
|
Its == MaxIts
|
||||||
|
)
|
||||||
|
->
|
||||||
|
LikelihoodF = Likelihood
|
||||||
|
;
|
||||||
|
Its1 is Its+1,
|
||||||
|
em_loop(Its1, Likelihood, State, MaxError, MaxIts, LikelihoodF)
|
||||||
|
).
|
||||||
|
|
||||||
|
% collect the different dists we are going to learn next.
|
||||||
|
different_dists(AllVars, AllDists, AllInfo) :-
|
||||||
|
all_dists(AllVars, Dists0, AllInfo),
|
||||||
|
sort(Dists0, Dists1),
|
||||||
|
group(Dists1, AllInfo).
|
||||||
|
|
||||||
|
group([], []) :-
|
||||||
|
group([i(Id,V,Ps)|Dists1], [Id-[[V|Ps]|Extra]|AllInfo]) :-
|
||||||
|
same_id(Dists1, Id, Extra, Rest),
|
||||||
|
group(Rest, AllInfo).
|
||||||
|
|
||||||
|
same_id([i(Id,V,Ps)|Dists1], Id, [[V|Ps]|Extra], Rest) :- !,
|
||||||
|
same_id(Dists1, Id, Extra, Rest).
|
||||||
|
same_id(Dists, _, [], Dists).
|
||||||
|
|
||||||
|
all_dists([], [], []).
|
||||||
|
all_dists([V|AllVars], Dists, [i(Id, AllInfo, Parents)|AllInfo]) :-
|
||||||
|
clpbn:get_atts(V, [dist(Id,_)]),
|
||||||
|
with_evidence(V, Id, Dists, Dists0), !,
|
||||||
|
all_dists(AllVars, Dists0, AllInfo).
|
||||||
|
|
||||||
|
with_evidence(V, Id) -->
|
||||||
|
{clpbn:get_atts(V, [evidence(Pos)]) }, !,
|
||||||
|
{ dist_pos2bin(Pos, Id, Bin) }.
|
||||||
|
with_evidence(V, Id) -->
|
||||||
|
[d(V,Id)].
|
||||||
|
|
||||||
|
estimate(state(Vars,Info,_)) :-
|
||||||
|
clpbn_solve_graph(Vars, OVars),
|
||||||
|
marg_vars(Info, Vars).
|
||||||
|
|
||||||
|
marg_vars([], _).
|
||||||
|
marg_vars([d(V,Id)|Vars], AllVs) :-
|
||||||
|
clpbn_marginalise_in_vars(V, AllVs),
|
||||||
|
marg_vars(Vars, AllVs).
|
||||||
|
|
||||||
|
maximise(state(_,_,DistInstances), Tables, Likelihood) :-
|
||||||
|
compute_parameters(DistInstances, Tables, 0.0, Likelihood).
|
||||||
|
|
||||||
|
compute_parameters([], [], Lik, Lik).
|
||||||
|
compute_parameters([Id-Samples|Dists], [Tab|Tables], Lik0, Lik) :-
|
||||||
|
empty_dist(Id, NewTable),
|
||||||
|
add_samples(Samples, NewTable).
|
||||||
|
normalise_table(Id, NewTable),
|
||||||
|
compute_parameters(Dists, Tables, Lik0, Lik).
|
||||||
|
|
||||||
|
add_samples([], _).
|
||||||
|
add_samples([S|Samples], Table) :-
|
||||||
|
run_sample(S, 1.0, Pos, Tot),
|
||||||
|
matrix_add(Table, Pos, Tot),
|
||||||
|
fail.
|
||||||
|
add_samples([_|Samples], Table) :-
|
||||||
|
add_samples(Samples, Table)
|
||||||
|
|
||||||
|
run_sample([], Tot, [], Tot).
|
||||||
|
run_sample([V|S], W0, [P|Pos], Tot) :-
|
||||||
|
{clpbn:get_atts(V, [evidence(P)]) }, !,
|
||||||
|
run_sample(S, W0, Pos, Tot).
|
||||||
|
run_sample([V|S], W0, [P|Pos], Tot) :-
|
||||||
|
{clpbn_display:get_atts(V, [posterior,(_,_,Ps,_)]) },
|
||||||
|
count_cases(Ps, 0, D0, P),
|
||||||
|
W1 is D0*W0,
|
||||||
|
run_sample(S, W1, Pos, Tot).
|
||||||
|
|
||||||
|
count_cases([D0|Ps], I0, D0, I0).
|
||||||
|
count_cases([_|Ps], I0, P, W1) :-
|
||||||
|
I is I0+1,
|
||||||
|
count_cases(Ps, I, P, W1).
|
||||||
|
|
Reference in New Issue
Block a user