new version of CLP(BN) with EM learning

Vítor Santos de Costa 2008-10-22 00:44:02 +01:00
parent b3cb7b1071
commit 592fe9e366
13 changed files with 2978 additions and 231 deletions

View File

@ -5,8 +5,8 @@
set_clpbn_flag/2,
clpbn_flag/3,
clpbn_key/2,
clpbn_marginalise/2,
call_solver/2]).
clpbn_init_solver/3,
clpbn_run_solver/3]).
:- use_module(library(atts)).
:- use_module(library(lists)).
@ -42,7 +42,9 @@
:- use_module('clpbn/gibbs',
[gibbs/3,
check_if_gibbs_done/1
check_if_gibbs_done/1,
init_gibbs_solver/3,
run_gibbs_solver/3
]).
:- use_module('clpbn/graphs',
@ -52,7 +54,7 @@
:- use_module('clpbn/dists',
[
dist/3,
dist/4,
get_dist/4,
get_evidence_position/3,
get_evidence_from_position/3
@ -106,7 +108,7 @@ clpbn_flag(suppress_attribute_display,Before,After) :-
{Var = Key with Dist} :-
put_atts(El,[key(Key),dist(DistInfo,Parents)]),
dist(Dist, DistInfo, Parents),
dist(Dist, DistInfo, Key, Parents),
add_evidence(Var,DistInfo,El).
check_constraint(Constraint, _, _, Constraint) :- var(Constraint), !.
@ -158,7 +160,7 @@ call_solver(GVars, AVars) :-
clpbn_vars(AVars, DiffVars, AllVars),
get_clpbn_vars(GVars,CLPBNGVars0),
simplify_query_vars(CLPBNGVars0, CLPBNGVars),
write_out(Solver,CLPBNGVars, AllVars, DiffVars).
write_out(Solver,[CLPBNGVars], AllVars, DiffVars).
@ -332,3 +334,32 @@ user:term_expansion((A :- {}), ( :- true )) :- !, % evidence
clpbn_key(Var,Key) :-
get_atts(Var, [key(Key)]).
%
% This routine starts a solver; it is called by the learning procedures (i.e., EM).
% LVs is a list of lists of variables we eventually want to marginalise out.
% Vs0 gives the original graph.
% AllDiffs gives variables that are not fully constrained, i.e., we do not fully
% know their key. In this case, we assume different instances will end up bound
% to different values.
%
clpbn_init_solver(LVs,Vs0,VarsWithUnboundKeys) :-
solver(Solver),
clpbn_init_known_solver(Solver,LVs,Vs0,VarsWithUnboundKeys).
clpbn_init_known_solver(gibbs, LVs, Vs0, VarsWithUnboundKeys) :- !,
init_gibbs_solver(LVs, Vs0, VarsWithUnboundKeys).
clpbn_init_known_solver(_, _, _, _).
%
% LVs is the list of lists of variables to marginalise.
% Vs is the full graph.
% LPs are the resulting probabilities on LVs.
%
clpbn_run_solver(LVs,Vs,LPs) :-
solver(Solver),
clpbn_run_known_solver(Solver,LVs,Vs,LPs).
clpbn_run_known_solver(gibbs,LVs, Vs, LPs) :- !,
run_gibbs_solver(LVs, Vs, LPs).
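%
% Usage sketch for this interface (illustrative only): the goal
% professor_ability/2 and the resulting numbers are invented, and
% attributes:all_attvars/1 is relied on exactly as in learning/em.yap below.
%
%   ?- professor_ability(p1, A),              % build part of the network
%      attributes:all_attvars(AllVars),       % collect the attributed variables
%      clpbn_init_solver([[A]], AllVars, _),
%      clpbn_run_solver([[A]], AllVars, [Ps]).
%   Ps = [0.5,0.3,0.2]                        % marginal of A (made-up values)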

View File

@ -390,12 +390,12 @@ evidence_val(Ev,I0,[_|Domain],Val) :-
I1 is I0+1,
evidence_val(Ev,I1,Domain,Val).
marginalize([V], _SortedVars,_NunmberedVars, Ps) :- !,
marginalize([[V]], _SortedVars,_NunmberedVars, Ps) :- !,
v2number(V,Pos),
marg <-- marginal_nodes(engine_ev, Pos),
matlab_get_variable( marg.'T', Ps).
marginalize(Vs, SortedVars, NumberedVars,Ps) :-
marginalize([Vs], SortedVars, NumberedVars,Ps) :-
bnt_solver(jtree),!,
matlab_get_variable(loglik, Den),
clpbn_display:get_all_combs(Vs, Vals),

View File

@ -37,14 +37,19 @@ add_alldiffs([],Eqs,Eqs).
add_alldiffs(AllDiffs,Eqs,(Eqs/alldiff(AllDiffs))).
clpbn_bind_vals([],_,_) :- !.
clpbn_bind_vals([],[],_).
clpbn_bind_vals([Vs|MoreVs],[Ps|MorePs],AllDiffs) :-
clpbn_bind_vals2(Vs, Ps, AllDiffs),
clpbn_bind_vals(MoreVs,MorePs,AllDiffs).
clpbn_bind_vals2([],_,_) :- !.
% simple case, we want a distribution on a single variable.
%bind_vals([V],Ps) :- !,
% clpbn:get_atts(V, [dist(Vals,_,_)]),
% put_atts(V, posterior([V], Vals, Ps)).
% complex case: we want a joint distribution; compute it on a leader variable.
% should we split on cliques?
clpbn_bind_vals(Vs,Ps,AllDiffs) :-
clpbn_bind_vals2(Vs,Ps,AllDiffs) :-
get_all_combs(Vs, Vals),
Vs = [V|_],
put_atts(V, posterior(Vs, Vals, Ps, AllDiffs)).

View File

@ -5,7 +5,7 @@
:- module(clpbn_dist,
[
dist/1,
dist/3,
dist/4,
dists/1,
dist_new_table/2,
get_dist/4,
@ -51,22 +51,23 @@
/*******************************************
store stuff in a DB of the form:
db(Id, CPT, Type, Domain, CPTSize, DSize)
db(Id, Key, CPT, Type, Domain, CPTSize, DSize)
where Id is the id,
cptsize is the table size or -1,
DSize is the domain size,
Type is
tab for tabular
trans for HMMs
continuous
Domain is
a list of values
bool for [t,f]
aminoacids for [a,c,d,e,f,g,h,i,k,l,m,n,p,q,r,s,t,v,w,y]
dna for [a,c,g,t]
rna for [a,c,g,u]
reals
Key is a skeleton of the key (main functor only)
cptsize is the table size or -1,
DSize is the domain size,
Type is
tab for tabular
trans for HMMs
continuous
Domain is
a list of values
bool for [t,f]
aminoacids for [a,c,d,e,f,g,h,i,k,l,m,n,p,q,r,s,t,v,w,y]
dna for [a,c,g,t]
rna for [a,c,g,u]
reals
********************************************/
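% As a concrete illustration of the new record layout (the id, key skeleton
% and numbers are invented), a parentless boolean CPT keyed on
% professor_ability/1 would be stored roughly as:
%
%	db(3, professor_ability(_), [0.7,0.3], tab, [t,f], 2, 2)
%
% i.e. id 3, key skeleton professor_ability/1, a two-entry CPT, tabular type,
% boolean domain, CPTSize 2 and DSize 2.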
@ -82,13 +83,20 @@ new_id(Id) :-
dists(X) :- id(X1), X is X1-1.
dist(V, Id, Parents) :-
dist(V, Id, Key, Parents) :-
dist_unbound(V, Culprit), !,
when(Culprit, dist(V, Id, Parents)).
dist(p(Type, CPT), Id, FParents) :-
distribution(Type, CPT, Id, [], FParents).
dist(p(Type, CPT, Parents), Id, FParents) :-
distribution(Type, CPT, Id, Parents, FParents).
when(Culprit, dist(V, Id, Key, Parents)).
dist(V, Id, Key, Parents) :-
var(Key), !,
when(Key, dist(V, Id, Key, Parents)).
dist(p(Type, CPT), Id, Key, FParents) :-
functor(Key, Na, Ar),
functor(Key0, Na, Ar),
distribution(Type, CPT, Id, Key0, [], FParents).
dist(p(Type, CPT, Parents), Id, Key, FParents) :-
functor(Key, Na, Ar),
functor(Key0, Na, Ar),
distribution(Type, CPT, Id, Key0, Parents, FParents).
dist_unbound(V, ground(V)) :-
var(V), !.
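% To illustrate the new Key argument (the domain, CPT and key below are
% invented): only the key's skeleton, a fresh term with the same name and
% arity, is kept, so every instance of the same predicate shares one
% learnable distribution.
%
%   ?- dist(p([h,m,l], [0.5,0.3,0.2]), Id, professor_ability(p1), []).
%
% records db(Id, professor_ability(_), [0.5,0.3,0.2], tab, [h,m,l], 3, 3).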
@ -101,52 +109,52 @@ dist_unbound(p(Type,_,_), ground(Type)) :-
dist_unbound(p(_,CPT,_), ground(CPT)) :-
\+ ground(CPT).
distribution(bool, trans(CPT), Id, Parents, FParents) :-
distribution(bool, trans(CPT), Id, Key, Parents, FParents) :-
is_list(CPT), !,
compress_hmm_table(CPT, Parents, Tab, FParents),
add_dist([t,f], trans, Tab, Parents, Id).
distribution(bool, CPT, Id, Parents, Parents) :-
add_dist([t,f], trans, Tab, Parents, Key, Id).
distribution(bool, CPT, Id, Key, Parents, Parents) :-
is_list(CPT), !,
add_dist([t,f], tab, CPT, Parents, Id).
distribution(aminoacids, trans(CPT), Id, Parents, FParents) :-
add_dist([t,f], tab, CPT, Parents, Key, Id).
distribution(aminoacids, trans(CPT), Id, Key, Parents, FParents) :-
is_list(CPT), !,
compress_hmm_table(CPT, Parents, Tab, FParents),
add_dist([a,c,d,e,f,g,h,i,k,l,m,n,p,q,r,s,t,v,w,y], trans, Tab, FParents, Id).
distribution(aminoacids, CPT, Id, Parents, Parents) :-
add_dist([a,c,d,e,f,g,h,i,k,l,m,n,p,q,r,s,t,v,w,y], trans, Tab, FParents, Key, Id).
distribution(aminoacids, CPT, Id, Key, Parents, Parents) :-
is_list(CPT), !,
add_dist([a,c,d,e,f,g,h,i,k,l,m,n,p,q,r,s,t,v,w,y], tab, CPT, Parents, Id).
distribution(dna, trans(CPT), Id, Parents, FParents) :-
add_dist([a,c,d,e,f,g,h,i,k,l,m,n,p,q,r,s,t,v,w,y], tab, CPT, Parents, Key, Id).
distribution(dna, trans(CPT), Key, Id, Parents, FParents) :-
is_list(CPT), !,
compress_hmm_table(CPT, Parents, Tab, FParents),
add_dist([a,c,g,t], trans, Tab, FParents, Id).
distribution(dna, CPT, Id, Parents, Parents) :-
add_dist([a,c,g,t], trans, Tab, FParents, Key, Id).
distribution(dna, CPT, Id, Key, Parents, Parents) :-
is_list(CPT), !,
add_dist([a,c,g,t], tab, CPT, Id).
distribution(rna, trans(CPT), Id, Parents, FParents) :-
add_dist([a,c,g,t], tab, CPT, Key, Id).
distribution(rna, trans(CPT), Id, Key, Parents, FParents) :-
is_list(CPT), !,
compress_hmm_table(CPT, Parents, Tab, FParents, FParents),
add_dist([a,c,g,u], trans, Tab, Id).
distribution(rna, CPT, Id, Parents, Parents) :-
add_dist([a,c,g,u], trans, Tab, Key, Id).
distribution(rna, CPT, Id, Key, Parents, Parents) :-
is_list(CPT), !,
add_dist([a,c,g,u], tab, CPT, Parents, Id).
distribution(Domain, trans(CPT), Id, Parents, FParents) :-
add_dist([a,c,g,u], tab, CPT, Parents, Key, Id).
distribution(Domain, trans(CPT), Id, Key, Parents, FParents) :-
is_list(Domain),
is_list(CPT), !,
compress_hmm_table(CPT, Parents, Tab, FParents),
add_dist(Domain, trans, Tab, FParents, Id).
distribution(Domain, CPT, Id, Parents, Parents) :-
add_dist(Domain, trans, Tab, FParents, Key, Id).
distribution(Domain, CPT, Id, Key, Parents, Parents) :-
is_list(Domain),
is_list(CPT), !,
add_dist(Domain, tab, CPT, Parents, Id).
add_dist(Domain, tab, CPT, Parents, Key, Id).
add_dist(Domain, Type, CPT, _, Id) :-
recorded(clpbn_dist_db, db(Id, CPT, Type, Domain, _, _), _), !.
add_dist(Domain, Type, CPT, Parents, Id) :-
add_dist(Domain, Type, CPT, _, Key, Id) :-
recorded(clpbn_dist_db, db(Id, Key, CPT, Type, Domain, _, _), _), !.
add_dist(Domain, Type, CPT, Parents, Key, Id) :-
length(CPT, CPTSize),
length(Domain, DSize),
new_id(Id),
record_parent_sizes(Parents, Id, PSizes, [DSize|PSizes]),
recordz(clpbn_dist_db,db(Id, CPT, Type, Domain, CPTSize, DSize),_).
recordz(clpbn_dist_db,db(Id, Key, CPT, Type, Domain, CPTSize, DSize),_).
record_parent_sizes([], Id, [], DSizes) :-
@ -167,13 +175,13 @@ compress_hmm_table([Prob|L],[P|Parents],[Prob|NL],[P|NParents]) :-
compress_hmm_table(L,Parents,NL,NParents).
dist(Id) :-
recorded(clpbn_dist_db, db(Id, _, _, _, _, _), _).
recorded(clpbn_dist_db, db(Id, _, _, _, _, _, _), _).
get_dist(Id, Type, Domain, Tab) :-
recorded(clpbn_dist_db, db(Id, Tab, Type, Domain, _, _), _).
recorded(clpbn_dist_db, db(Id, _, Tab, Type, Domain, _, _), _).
get_dist_matrix(Id, Parents, Type, Domain, Mat) :-
recorded(clpbn_dist_db, db(Id, Tab, Type, Domain, _, DomainSize), _),
recorded(clpbn_dist_db, db(Id, _, Tab, Type, Domain, _, DomainSize), _),
get_dsizes(Parents, Sizes, []),
matrix_new(floats, [DomainSize|Sizes], Tab, Mat),
matrix_to_logs(Mat).
@ -185,31 +193,31 @@ get_dsizes([P|Parents], [Sz|Sizes], Sizes0) :-
get_dsizes(Parents, Sizes, Sizes0).
get_dist_params(Id, Parms) :-
recorded(clpbn_dist_db, db(Id, Parms, _, _, _, _), _).
recorded(clpbn_dist_db, db(Id, _, Parms, _, _, _, _), _).
get_dist_domain_size(Id, DSize) :-
recorded(clpbn_dist_db, db(Id, _, _, _, _, DSize), _).
recorded(clpbn_dist_db, db(Id, _, _, _, _, _, DSize), _).
get_dist_domain(Id, Domain) :-
recorded(clpbn_dist_db, db(Id, _, _, Domain, _, _), _).
recorded(clpbn_dist_db, db(Id, _, _, _, Domain, _, _), _).
get_dist_nparams(Id, NParms) :-
recorded(clpbn_dist_db, db(Id, _, _, _, NParms, _), _).
recorded(clpbn_dist_db, db(Id, _, _, _, _, NParms, _), _).
get_evidence_position(El, Id, Pos) :-
recorded(clpbn_dist_db, db(Id, _, _, Domain, _, _), _),
recorded(clpbn_dist_db, db(Id, _, _, _, Domain, _, _), _),
nth0(Pos, Domain, El), !.
get_evidence_position(El, Id, Pos) :-
recorded(clpbn_dist_db, db(Id, _, _, _, _, _), _), !,
recorded(clpbn_dist_db, db(Id, _, _, _, _, _, _), _), !,
throw(error(domain_error(evidence,Id),get_evidence_position(El, Id, Pos))).
get_evidence_position(El, Id, Pos) :-
throw(error(domain_error(no_distribution,Id),get_evidence_position(El, Id, Pos))).
get_evidence_from_position(El, Id, Pos) :-
recorded(clpbn_dist_db, db(Id, _, _, Domain, _, _), _),
recorded(clpbn_dist_db, db(Id, _, _, _, Domain, _, _), _),
nth0(Pos, Domain, El), !.
get_evidence_from_position(El, Id, Pos) :-
recorded(clpbn_dist_db, db(Id, _, _, _, _, _), _), !,
recorded(clpbn_dist_db, db(Id, _, _, _, _, _, _), _), !,
throw(error(domain_error(evidence,Id),get_evidence_from_position(El, Id, Pos))).
get_evidence_from_position(El, Id, Pos) :-
throw(error(domain_error(no_distribution,Id),get_evidence_from_position(El, Id, Pos))).
@ -224,9 +232,9 @@ empty_dist(Dist, TAB) :-
dist_new_table(Id, NewMat) :-
matrix_to_list(NewMat, List),
recorded(clpbn_dist_db, db(Id, _, A, B, C, D), R),
recorded(clpbn_dist_db, db(Id, Key, _, A, B, C, D), R),
erase(R),
recorda(clpbn_dist_db, db(Id, List, A, B, C, D), R),
recorda(clpbn_dist_db, db(Id, Key, List, A, B, C, D), _),
fail.
dist_new_table(_, _).

View File

@ -7,9 +7,11 @@
% Markov Blanket
%
:- module(gibbs, [gibbs/3,
check_if_gibbs_done/1,
init_gibbs_solver/3]).
:- module(clpbn_gibbs,
[gibbs/3,
check_if_gibbs_done/1,
init_gibbs_solver/3,
run_gibbs_solver/3]).
:- use_module(library(rbtrees),
[rb_new/1,
@ -47,27 +49,35 @@
:- dynamic implicit/1.
gibbs([],_,_) :- !.
% arguments:
%
% list of lists of output (query) variables
% list of attributed variables (the full network)
% list of all-diff variables
%
gibbs([[]],_,_) :- !.
gibbs(LVs,Vs0,AllDiffs) :-
LVs = [_], !,
init_gibbs_solver(Vs0, LVs, Gibbs),
init_gibbs_solver(LVs, Vs0, Vs),
(clpbn:output(xbif(XBifStream)) -> clpbn2xbif(XBifStream,vel,Vs) ; true),
(clpbn:output(gviz(XBifStream)) -> clpbn2gviz(XBifStream,vel,Vs,LVs) ; true),
initialise(Vs, Graph, LVs, OutputVars, VarOrder),
% write(Graph),nl,
process(VarOrder, Graph, OutputVars, Estimates),
sum_up(Estimates, [LPs]),
% write(Estimates),nl,
run_gibbs_solver(LVs, Vs, LPs),
clpbn_bind_vals(LVs,LPs,AllDiffs),
clean_up.
gibbs(LVs,_,_) :-
throw(error(domain_error(solver,LVs),solver(gibbs))).
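% Sketch of the new calling convention (variables are illustrative): the first
% argument is now a list of lists; gibbs/3 itself handles a single inner list,
% while run_gibbs_solver/3 below accepts several.
%
%   gibbs([[X,Y]], AllAttVars, AllDiffs)
%
% computes the joint distribution over X and Y and binds it to a posterior/4
% attribute via clpbn_bind_vals/3.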
init_gibbs_solver(LVs, Vs0, Gibbs) :-
init_gibbs_solver(_, Vs0, Vs) :-
clean_up,
check_for_hidden_vars(Vs0, Vs0, Vs1),
sort(Vs1,Vs).
run_gibbs_solver(LVs, Vs, LPs) :-
initialise(Vs, Graph, LVs, OutputVars, VarOrder),
% writeln(Graph),
% write_pars(Vs),
process(VarOrder, Graph, OutputVars, Estimates),
% writeln(Estimates),
sum_up_all(Estimates, LPs),
clean_up.
% writeln(Estimates).
initialise(LVs, Graph, GVs, OutputVars, VarOrder) :-
init_keys(Keys0),
gen_keys(LVs, 0, VLen, Keys0, Keys),
@ -76,7 +86,7 @@ initialise(LVs, Graph, GVs, OutputVars, VarOrder) :-
compile_graph(Graph),
topsort(TGraph, VarOrder),
% show_sorted(VarOrder, Graph),
add_output_vars(GVs, Keys, OutputVars).
add_all_output_vars(GVs, Keys, OutputVars).
init_keys(Keys0) :-
rb_new(Keys0).
@ -120,7 +130,7 @@ graph_representation([V|Vs], Graph, I0, Keys, [I-IParents|TGraph]) :-
write_pars([]).
write_pars([V|Parents]) :-
clpbn:get_atts(V, [key(K)]),write(K),nl,
clpbn:get_atts(V, [key(K),dist(I,_)]),write(K:I),nl,
write_pars(Parents).
get_sizes([], []).
@ -310,9 +320,12 @@ normalise_factors([F|Factors],S0,S,[P0|Probs],PF) :-
PF is P0-F/S.
store_mblanket(I,Values,Probs) :-
append(Values,Probs,Args),
Rule =.. [mblanket,I|Args],
assert(Rule).
recordz(mblanket,m(I,Values,Probs),_).
add_all_output_vars([], _, []).
add_all_output_vars([Vs|LVs], Keys, [Is|OutputVars]) :-
add_output_vars(Vs, Keys, Is),
add_all_output_vars(LVs, Keys, OutputVars).
add_output_vars([], _, []).
add_output_vars([V|LVs], Keys, [I|OutputVars]) :-
@ -382,15 +395,29 @@ iparents_pos_sz([I|IPars], Chain, [P|IPos], Graph, [Sz|Sizes]) :-
init_estimates(0,_,_,[]) :- !.
init_estimates(NChains,OutputVars,Graph,[Est|Est0]) :-
NChainsI is NChains-1,
init_estimate(OutputVars,Graph,Est),
init_estimate_all_outvs(OutputVars,Graph,Est),
init_estimates(NChainsI,OutputVars,Graph,Est0).
init_estimate([],_,[]).
init_estimate([V|OutputVars],Graph,[[I|E0L]|Est]) :-
arg(V,Graph,var(_,I,_,_,Sz,_,_,_,_)),
gen_e0(Sz,E0L),
init_estimate(OutputVars,Graph,Est).
init_estimate_all_outvs([],_,[]).
init_estimate_all_outvs([Vs|OutputVars],Graph,[E|Est]) :-
init_estimate(Vs, Graph, E),
init_estimate_all_outvs(OutputVars,Graph,Est).
init_estimate([],_,[]).
init_estimate([V],Graph,[I|E0L]) :- !,
arg(V,Graph,var(_,I,_,_,Sz,_,_,_,_)),
gen_e0(Sz,E0L).
init_estimate(Vs,Graph,me(Is,Mults,Es)) :-
generate_est_mults(Vs, Is, Graph, Mults, Sz),
gen_e0(Sz,Es).
generate_est_mults([], [], _, [], 1).
generate_est_mults([V|Vs], [I|Is], Graph, [M0|Mults], M) :-
arg(V,Graph,var(_,I,_,_,Sz,_,_,_,_)),
generate_est_mults(Vs, Is, Graph, Mults, M0),
M is M0*Sz.
gen_e0(0,[]) :- !.
gen_e0(Sz,[0|E0L]) :-
Sz1 is Sz-1,
@ -398,8 +425,9 @@ gen_e0(Sz,[0|E0L]) :-
process_chains(0,_,F,F,_,_,Est,Est) :- !.
process_chains(ToDo,VarOrder,End,Start,Graph,Len,Est0,Estf) :-
%format('ToDo = ~d~n',[ToDo]),
process_chains(Start,VarOrder,Int,Graph,Len,Est0,Esti),
% (ToDo mod 100 =:= 0 -> statistics,cvt2problist(Esti, Probs), Int =[S|_], format('did ~d: ~w~n ~w~n',[ToDo,Probs,S]) ; true),
% (ToDo mod 100 =:= 1 -> statistics,cvt2problist(Esti, Probs), Int =[S|_], format('did ~d: ~w~n ~w~n',[ToDo,Probs,S]) ; true),
ToDo1 is ToDo-1,
process_chains(ToDo1,VarOrder,End,Int,Graph,Len,Esti,Estf).
@ -408,8 +436,8 @@ process_chains([], _, [], _, _,[],[]).
process_chains([Sample0|Samples0], VarOrder, [Sample|Samples], Graph, SampLen,[E0|E0s],[Ef|Efs]) :-
functor(Sample,sample,SampLen),
do_sample(VarOrder,Sample,Sample0,Graph),
%format('Sample = ~w~n',[Sample]),
update_estimate(E0,Sample,Ef),
% format('Sample = ~w~n',[Sample]),
update_estimates(E0,Sample,Ef),
process_chains(Samples0, VarOrder, Samples, Graph, SampLen,E0s,Efs).
do_sample([],_,_,_).
@ -418,15 +446,14 @@ do_sample([I|VarOrder],Sample,Sample0,Graph) :-
do_sample(VarOrder,Sample,Sample0,Graph).
do_var(I,Sample,Sample0,Graph) :-
arg(I,Graph,var(_,I,_,_,Sz,CPTs,Parents,_,_)),
( implicit(I) ->
fetch_parents(Parents,I,Sample,Sample0,Bindings,[]),
multiply_all_in_context(Parents,Bindings,CPTs,Sz,Graph,Vals)
arg(I,Graph,var(_,_,_,_,Sz,CPTs,Parents,_,_)),
fetch_parents(Parents,I,Sample,Sample0,Bindings),
multiply_all_in_context(Parents,Bindings,CPTs,Sz,Graph,Vals)
;
length(Vals,Sz),
fetch_parents(Parents,I,Sample,Sample0,Args,Vals),
Goal =.. [mblanket,I|Args],
call(Goal)
arg(I,Graph,var(_,_,_,_,_,_,Parents,_,_)),
fetch_parents(Parents,I,Sample,Sample0,Args),
recorded(mblanket,m(I,Args,Vals),_)
),
X is random,
pick_new_value(Vals,X,0,Val),
@ -444,31 +471,50 @@ set_pos([I|Is],[Pos|Args],Graph) :-
arg(I,Graph,var(_,I,Pos,_,_,_,_,_,_)),
set_pos(Is,Args,Graph).
fetch_parents([],_,_,_,Args,Args).
fetch_parents([P|Parents],I,Sample,Sample0,[VP|Args],Vals) :-
fetch_parents([],_,_,_,[]).
fetch_parents([P|Parents],I,Sample,Sample0,[VP|Args]) :-
arg(P,Sample,VP),
nonvar(VP), !,
fetch_parents(Parents,I,Sample,Sample0,Args,Vals).
fetch_parents([P|Parents],I,Sample,Sample0,[VP|Args],Vals) :-
fetch_parents(Parents,I,Sample,Sample0,Args).
fetch_parents([P|Parents],I,Sample,Sample0,[VP|Args]) :-
arg(P,Sample0,VP),
fetch_parents(Parents,I,Sample,Sample0,Args,Vals).
fetch_parents(Parents,I,Sample,Sample0,Args).
pick_new_value([V|_],X,Val,Val) :-
X < V, !.
pick_new_value([_|Vals],X,I0,Val) :-
I is I0+1,
pick_new_value(Vals,X,I,Val).
pick_new_value([V|Vals],X,I0,Val) :-
( X < V ->
Val = I0
;
I is I0+1,
pick_new_value(Vals,X,I,Val)
).
update_estimate([],_,[]).
update_estimate([[I|E]|E0],Sample,[[I|NE]|Ef]) :-
update_estimates([],_,[]).
update_estimates([Est|E0],Sample,[NEst|Ef]) :-
update_estimate(Est,Sample,NEst),
update_estimates(E0,Sample,Ef).
update_estimate([I|E],Sample,[I|NE]) :-
arg(I,Sample,V),
update_estimate_for_var(V,E,NE),
update_estimate(E0,Sample,Ef).
update_estimate_for_var(V,E,NE).
update_estimate(me(Is,Mult,E),Sample,me(Is,Mult,NE)) :-
get_estimate_pos(Is, Sample, Mult, 0, V),
update_estimate_for_var(V,E,NE).
update_estimate_for_var(0,[X|T],[X1|T]) :- !, X1 is X+1.
update_estimate_for_var(V,[E|Es],[E|NEs]) :-
V1 is V-1,
update_estimate_for_var(V1,Es,NEs).
get_estimate_pos([], _, [], V, V).
get_estimate_pos([I|Is], Sample, [M|Mult], V0, V) :-
arg(I,Sample,VV),
VI is VV*M+V0,
get_estimate_pos(Is, Sample, Mult, VI, V).
update_estimate_for_var(V0,[X|T],[X1|NT]) :-
( V0 == 0 ->
X1 is X+1,
NT = T
;
V1 is V0-1,
X1 = X,
update_estimate_for_var(V1,T,NT)
).
@ -476,16 +522,14 @@ check_if_gibbs_done(Var) :-
get_atts(Var, [dist(_)]), !.
clean_up :-
current_predicate(mblanket,P),
retractall(P),
eraseall(mblanket),
fail.
clean_up :-
retractall(implicit(_)),
fail.
clean_up.
gibbs_params(5,10000,100000).
gibbs_params(5,1000,10000).
cvt2problist([], []).
cvt2problist([[[_|E]]|Est0], [Ps|Probs]) :-
@ -510,16 +554,32 @@ show_sorted([I|VarOrder], Graph) :-
format('~w ',[K]),
show_sorted(VarOrder, Graph).
sum_up([[]|_], []).
sum_up([[[Id|Counts]|More]|Chains], [Dist|Dists]) :-
add_up(Counts,Chains, Id, Add,RChains),
normalise(Add, Dist),
sum_up([More|RChains], Dists).
sum_up_all([[]|_], []).
sum_up_all([[C|MoreC]|Chains], [Dist|Dists]) :-
extract_sums(Chains, CurrentChains, LeftChains),
sum_up([C|CurrentChains], Dist),
sum_up_all([MoreC|LeftChains], Dists).
add_up(Counts,[],_,Counts,[]).
add_up(Counts,[[[Id|Cs]|MoreVars]|Chains],Id, Add, [MoreVars|RChains]) :-
extract_sums([], [], []).
extract_sums([[C|Chains]|MoreChains], [C|CurrentChains], [Chains|LeftChains]) :-
extract_sums(MoreChains, CurrentChains, LeftChains).
sum_up([[_|Counts]|Chains], Dist) :-
add_up(Counts,Chains, Add),
normalise(Add, Dist).
sum_up([me(_,_,Counts)|Chains], Dist) :-
add_up_mes(Counts,Chains, Add),
normalise(Add, Dist).
add_up(Counts,[],Counts).
add_up(Counts,[[_|Cs]|Chains], Add) :-
sum_lists(Counts, Cs, NCounts),
add_up(NCounts, Chains, Id, Add, RChains).
add_up(NCounts, Chains, Add).
add_up_mes(Counts,[],Counts).
add_up_mes(Counts,[me(_,_,Cs)|Chains], Add) :-
sum_lists(Counts, Cs, NCounts),
add_up_mes(NCounts, Chains, Add).
sum_lists([],[],[]).
sum_lists([Count|Counts], [C|Cs], [NC|NCounts]) :-

View File

@ -72,6 +72,7 @@
:- use_module(library('clpbn/display'), [
clpbn_bind_vals/3]).
jt([[]],_,_) :- !.
jt(LVs,Vs0,AllDiffs) :-
get_graph(Vs0, BayesNet, CPTs, Evidence),
build_jt(BayesNet, CPTs, JTree),
@ -81,7 +82,7 @@ jt(LVs,Vs0,AllDiffs) :-
% write_tree(NewTree,0),
propagate_evidence(Evidence, NewTree, EvTree),
message_passing(EvTree, MTree),
get_margin(MTree, LVs, LPs),
get_margins(MTree, LVs, LPs),
clpbn_bind_vals(LVs,LPs,AllDiffs).
@ -461,7 +462,12 @@ downward([tree(Clique1-(Dist1,Msg1),DistKids)|Kids], Clique, Tab, [tree(Clique1-
downward(Kids, Clique, Tab, NKids).
get_margin(NewTree, LVs0, LPs) :-
get_margins(_, [], []).
get_margins(NewTree, [Vs|LVs], [LP|LPs]) :-
get_margin(NewTree, Vs, LP),
get_margins(NewTree, LVs, LPs).
get_margin(NewTree, LVs, LPs) :-
sort(LVs0, LVs),
find_clique(NewTree, LVs, Clique, Dist),
sum_out_from_CPT(LVs, Dist, Clique, tab(TAB,_,_)),

View File

@ -37,7 +37,6 @@ topsort(Graph0, Sort0, Found0, Sort) :-
add_nodes([], Found, Sort, [], Found, Sort).
add_nodes([N-Ns|Graph0], Found0, SortI, NewGraph, Found, NSort) :-
(N=1600 -> write(Ns), nl ; true),
delete_nodes(Ns, Found0, NNs),
( NNs == [] ->
NewGraph = IGraph,

View File

@ -55,8 +55,8 @@ check_if_vel_done(Var) :-
%
% implementation of the well known variable elimination algorithm
%
vel([],_,_) :- !.
vel(LVs,Vs0,AllDiffs) :-
vel([[]],_,_) :- !.
vel([LVs],Vs0,AllDiffs) :-
check_for_hidden_vars(Vs0, Vs0, Vs1),
sort(Vs1,Vs),
% LVi will have a list of CLPBN variables

View File

@ -2,113 +2,166 @@
% The world famous EM algorithm, in a nutshell
%
:- module(clpbn_em, [em/6]).
:- module(clpbn_em, [em/5]).
:- use_module(library(lists),
[append/3]).
:- use_module(library(clpbn),
[clpbn_init_solver/3,
clpbn_run_solver/3]).
:- use_module(library('clpbn/dists'),
[get_dist_domain_size/2,
empty_dist/2,
dist_new_table/2]).
:- use_module(library('clpbn/learning/learn_utils'),
[run_all/1,
clpbn_vars/2,
normalise_counts/2]).
normalise_counts/2,
compute_likelihood/3]).
:- use_module(library(lists),
[member/2]).
:- use_module(library(matrix),
[matrix_add/3,
matrix_to_list/2]).
:- use_module(library('clpbn/utils'), [
check_for_hidden_vars/3]).
:- meta_predicate em(:,+,+,-,-), init_em(:,-).
em(Items, MaxError, MaxIts, Tables, Likelihood) :-
init_em(Items, State),
em_loop(0, 0.0, state(AllVars,AllDists), MaxError, MaxIts, Likelihood),
get_tables(State, Tables).
em_loop(0, 0.0, State, MaxError, MaxIts, Likelihood, Tables).
% This gets you an initial configuration. If there is a lot of evidence,
% the tables may be filled in close to optimal; otherwise they may be
% close to uniform.
% It also gets you a run over the random variables.
init_em(Items, state(AllVars, AllDists, AllDistInstances)) :-
% The state term collects all the info we need for the EM algorithm:
% the list of variables without evidence,
% the list of distributions for which we want to compute parameters,
% and more detailed info on the distributions, namely a list of all instances
% of each distribution.
init_em(Items, state(AllVars, AllDists, AllDistInstances, MargVars)) :-
run_all(Items),
different_dists(AllVars, AllDists, AllDistInstances).
attributes:all_attvars(AllVars0),
% remove variables that do not have to do with this query.
check_for_hidden_vars(AllVars0, AllVars0, AllVars),
different_dists(AllVars, AllDists, AllDistInstances, MargVars),
clpbn_init_solver(MargVars, AllVars, _).
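% For orientation, a hedged sketch of the state term built here (all ids,
% variables and instance terms are invented; see different_dists/4 below):
%
%   state(AllVars,                % attributed variables kept for this query
%         [2,7],                  % ids of the distributions being learnt
%         [2-[i(2,[V|Ps],Cases,[V])|_], 7-[...]],  % instances grouped by id
%         [[V],[W1,W2]|_])        % hidden variables to marginalise, per instance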
% Loop until convergence or until we hit the maximum number of iterations.
em_loop(MaxIts, Likelihood, State, _, _, MaxIts, Likelihood) :- !.
em_loop(Its, Likelihood0, State, MaxError, MaxIts, LikelihoodF) :-
estimate(State),
maximise(State, Likelihood),
em_loop(Its, Likelihood0, State, MaxError, MaxIts, LikelihoodF, FTables) :-
estimate(State, LPs),
maximise(State, Tables, LPs, Likelihood),
(
(
(
(Likelihood - Likelihood0)/Likelihood < MaxError
;
;
Its == MaxIts
)
)
->
ltables(Tables, FTables),
LikelihoodF = Likelihood
;
Its1 is Its+1,
em_loop(Its1, Likelihood, State, MaxError, MaxIts, LikelihoodF)
em_loop(Its1, Likelihood, State, MaxError, MaxIts, LikelihoodF, FTables)
).
ltables([], []).
ltables([Id-T|Tables], [Id-LTable|FTables]) :-
matrix_to_list(T,LTable),
ltables(Tables, FTables).
% collect the different dists we are going to learn next.
different_dists(AllVars, AllDists, AllInfo) :-
all_dists(AllVars, Dists0, AllInfo),
different_dists(AllVars, AllDists, AllInfo, MargVars) :-
all_dists(AllVars, Dists0),
sort(Dists0, Dists1),
group(Dists1, AllInfo).
group(Dists1, AllDists, AllInfo, MargVars, []).
group([], []).
group([i(Id,V,Ps)|Dists1], [Id-[[V|Ps]|Extra]|AllInfo]) :-
same_id(Dists1, Id, Extra, Rest),
group(Rest, AllInfo).
all_dists([], []).
all_dists([V|AllVars], [i(Id, [V|Parents], Cases, Hiddens)|Dists]) :-
clpbn:get_atts(V, [dist(Id,Parents)]),
generate_hidden_cases([V|Parents], CompactCases, Hiddens),
uncompact_cases(CompactCases, Cases),
all_dists(AllVars, Dists).
same_id([i(Id,V,Ps)|Dists1], Id, [[V|Ps]|Extra], Rest) :- !,
same_id(Dists1, Id, Extra, Rest).
same_id(Dists, _, [], Dists).
all_dists([], [], []).
all_dists([V|AllVars], Dists, [i(Id, AllInfo, Parents)|AllInfo]) :-
generate_hidden_cases([], [], []).
generate_hidden_cases([V|Parents], [P|Cases], Hiddens) :-
clpbn:get_atts(V, [evidence(P)]), !,
generate_hidden_cases(Parents, Cases, Hiddens).
generate_hidden_cases([V|Parents], [Cases|MoreCases], [V|Hiddens]) :-
clpbn:get_atts(V, [dist(Id,_)]),
with_evidence(V, Id, Dists, Dists0), !,
all_dists(AllVars, Dists0, AllInfo).
with_evidence(V, Id) -->
{clpbn:get_atts(V, [evidence(Pos)]) }, !,
{ dist_pos2bin(Pos, Id, Bin) }.
with_evidence(V, Id) -->
[d(V,Id)].
estimate(state(Vars,Info,_)) :-
clpbn_solve_graph(Vars, OVars),
marg_vars(Info, Vars).
marg_vars([], _).
marg_vars([d(V,Id)|Vars], AllVs) :-
clpbn_marginalise_in_vars(V, AllVs),
marg_vars(Vars, AllVs).
maximise(state(_,_,DistInstances), Tables, Likelihood) :-
compute_parameters(DistInstances, Tables, 0.0, Likelihood).
compute_parameters([], [], Lik, Lik).
compute_parameters([Id-Samples|Dists], [Tab|Tables], Lik0, Lik) :-
empty_dist(Id, NewTable),
add_samples(Samples, NewTable),
normalise_table(Id, NewTable),
compute_parameters(Dists, Tables, Lik0, Lik).
add_samples([], _).
add_samples([S|Samples], Table) :-
run_sample(S, 1.0, Pos, Tot),
matrix_add(Table, Pos, Tot),
fail.
add_samples([_|Samples], Table) :-
add_samples(Samples, Table).
get_dist_domain_size(Id, Sz),
gen_cases(0, Sz, Cases),
generate_hidden_cases(Parents, MoreCases, Hiddens).
run_sample([], Tot, [], Tot).
run_sample([V|S], W0, [P|Pos], Tot) :-
{clpbn:get_atts(V, [evidence(P)]) }, !,
run_sample(S, W0, Pos, Tot).
run_sample([V|S], W0, [P|Pos], Tot) :-
{clpbn_display:get_atts(V, [posterior,(_,_,Ps,_)]) },
count_cases(Ps, 0, D0, P),
W1 is D0*W0,
run_sample(S, W1, Pos, Tot).
count_cases([D0|Ps], I0, D0, I0).
count_cases([_|Ps], I0, P, W1) :-
I is I0+1,
count_cases(Ps, I, P, W1).
gen_cases(Sz, Sz, []) :- !.
gen_cases(I, Sz, [I|Cases]) :-
I1 is I+1,
gen_cases(I1, Sz, Cases).
uncompact_cases(CompactCases, Cases) :-
findall(Case, is_case(CompactCases, Case), Cases).
is_case([], []).
is_case([A|CompactCases], [A|Case]) :-
integer(A), !,
is_case(CompactCases, Case).
is_case([L|CompactCases], [C|Case]) :-
member(C, L),
is_case(CompactCases, Case).
group([], [], []) --> [].
group([i(Id,Ps,Cs,[])|Dists1], [Id|Ids], [Id-[i(Id,Ps,Cs,[])|Extra]|AllInfo]) --> !,
same_id(Dists1, Id, Extra, Rest),
group(Rest, Ids, AllInfo).
group([i(Id,Ps,Cs,Hs)|Dists1], [Id|Ids], [Id-[i(Id,Ps,Cs,Hs)|Extra]|AllInfo]) -->
[Hs],
same_id(Dists1, Id, Extra, Rest),
group(Rest, Ids, AllInfo).
same_id([i(Id,Vs,Cases,[])|Dists1], Id, [i(Id, Vs, Cases, [])|Extra], Rest) --> !,
same_id(Dists1, Id, Extra, Rest).
same_id([i(Id,Vs,Cases,Hs)|Dists1], Id, [i(Id, Vs, Cases, Hs)|Extra], Rest) --> !,
[Hs],
same_id(Dists1, Id, Extra, Rest).
same_id(Dists, _, [], Dists) --> [].
estimate(state(Vars, _, _, Margs), LPs) :-
clpbn_run_solver(Margs, Vars, LPs).
maximise(state(_,_,DistInstances,_), Tables, LPs, Likelihood) :-
compute_parameters(DistInstances, Tables, LPs, 0.0, Likelihood).
compute_parameters([], [], [], Lik, Lik).
compute_parameters([Id-Samples|Dists], [Id-NewTable|Tables], Ps, Lik0, Lik) :-
empty_dist(Id, Table0),
add_samples(Samples, Table0, Ps, MorePs),
normalise_counts(Table0, NewTable),
compute_likelihood(Table0, NewTable, DeltaLik),
dist_new_table(Id, NewTable),
NewLik is Lik0+DeltaLik,
compute_parameters(Dists, Tables, MorePs, NewLik, Lik).
add_samples([], _, Ps, Ps).
add_samples([i(_,_,[Case],[])|Samples], Table, AllPs, RPs) :- !,
matrix_add(Table,Case,1.0),
add_samples(Samples, Table, AllPs, RPs).
add_samples([i(_,_,Cases,_)|Samples], Table, [Ps|AllPs], RPs) :-
run_sample(Cases, Ps, Table),
add_samples(Samples, Table, AllPs, RPs).
run_sample([], [], _).
run_sample([C|Cases], [P|Ps], Table) :-
matrix_add(Table, C, P),
run_sample(Cases, Ps, Table).

View File

@ -0,0 +1,34 @@
% Learn the distributions for the school database.
:- [pos:train].
:- ['~/Yap/work/CLPBN/clpbn/examples/School/school_32'].
:- ['~/Yap/work/CLPBN/learning/em'].
main :-
findall(X,goal(X),L),
em(L,0.1,10,CPTs,Lik),
writeln(Lik:CPTs).
% miss 30% of the examples.
goal(professor_ability(P,V)) :-
pos:professor_ability(P,V1),
( random > 0.3 -> V = V1 ; true).
% miss 30% of the examples.
goal(professor_popularity(P,V)) :-
pos:professor_popularity(P,V1),
( random > 0.3 -> V = V1 ; true).
goal(registration_grade(P,V)) :-
pos:registration_grade(P,V1),
( random > 0.1 -> V = V1 ; true).
goal(student_intelligence(P,V)) :-
pos:student_intelligence(P,V1),
( random > 0.1 -> V = V1 ; true).
goal(course_difficulty(P,V)) :-
pos:course_difficulty(P,V1),
( random > 0.1 -> V = V1 ; true).
/*
goal(registration_satisfaction(P,V)) :-
pos:registration_satisfaction(P,V).
*/

File diff suppressed because it is too large

View File

@ -2,13 +2,31 @@
% Utilities for learning
%
:- module(bnt_learn_utils, [run_all/1,
clpbn_vars/2]).
:- module(clpbn_learn_utils, [run_all/1,
clpbn_vars/2,
normalise_counts/2,
compute_likelihood/3]).
:- use_module(library(matrix),
[matrix_agg_lines/3,
matrix_op_to_lines/4,
matrix_to_logs/2,
matrix_op/4,
matrix_sum/2]).
:- meta_predicate run_all(:).
run_all([]).
run_all([G|Gs]) :-
call(user:G),
call(G),
run_all(Gs).
run_all(M:Gs) :-
run_all(Gs,M).
run_all([],_).
run_all([G|Gs],M) :-
call(M:G),
run_all(Gs,M).
clpbn_vars(Vs,BVars) :-
get_clpbn_vars(Vs,CVs),
@ -31,5 +49,15 @@ get_var_has_same_key([K-V|KVs],K,V,KVs0) :- !,
get_var_has_same_key(KVs,K,V,KVs0).
get_var_has_same_key(KVs,_,_,KVs).
normalise_counts(MAT,NMAT) :-
matrix_agg_lines(MAT, +, Sum),
matrix_op_to_lines(MAT, Sum, /, NMAT).
compute_likelihood(Table0, NewTable, DeltaLik) :-
matrix:matrix_to_list(Table0,L0), writeln(L0),
matrix:matrix_to_list(NewTable,L1), writeln(L1),
matrix_to_logs(NewTable, Logs),
matrix_op(Table0, Logs, *, Logs),
matrix_sum(Logs, DeltaLik).
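% Usage sketch with invented counts, assuming matrix_new/4 from the YAP matrix
% library (as used elsewhere in this commit); normalise_counts/2 rescales the
% counts line by line and compute_likelihood/3 yields the log-likelihood term:
%
%   ?- matrix:matrix_new(floats, [2,2], [3.0,1.0,2.0,6.0], M),
%      normalise_counts(M, NM),
%      compute_likelihood(M, NM, LogLik).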

View File

@ -1,12 +1,32 @@
%
% Maximum likelihood estimator and friends.
%
%
% This assumes we have a single big example.
%
:- use_module(library('clpbn_learning/utils'),
[run_all/1,
clpbn_vars/2]).
:- module(clpbn_mle, [learn_parameters/2,
learn_parameters/3,
parameters_from_evidence/3]).
:- module(bnt_mle, [learn_parameters/2]).
:- use_module(library('clpbn')).
:- use_module(library('clpbn/learning/learn_utils'),
[run_all/1,
clpbn_vars/2,
normalise_counts/2]).
:- use_module(library('clpbn/dists'),
[empty_dist/2,
dist_new_table/2]).
:- use_module(library(matrix),
[matrix_inc/2,
matrix_op_to_all/4]).
learn_parameters(Items, Tables) :-
learn_parameters(Items, Tables, []).
%
% full evidence learning
@ -14,16 +34,26 @@
learn_parameters(Items, Tables, Extras) :-
run_all(Items),
attributes:all_attvars(AVars),
% sort and incorporte evidence
% sort and incorporate evidence
clpbn_vars(AVars, AllVars),
mk_sample(AllVars, Sample),
compute_tables(Extras, Sample, Tables).
mk_sample(AllVars, NVars, LL) :-
add2sample(AllVars, Sample),
msort(Sample, AddL),
compute_params(AddL, Parms).
parameters_from_evidence(AllVars, Sample, Extras) :-
mk_sample_from_evidence(AllVars, Sample),
compute_tables(Extras, Sample, Tables).
mk_sample_from_evidence(AllVars, SortedSample) :-
add_evidence2sample(AllVars, Sample),
msort(Sample, SortedSample).
mk_sample(AllVars, SortedSample) :-
add2sample(AllVars, Sample),
msort(Sample, SortedSample).
%
% assumes we have full data, meaning evidence for every variable
%
add2sample([], []).
add2sample([V|Vs],[val(Id,[Ev|EParents])|Vals]) :-
clpbn:get_atts(V, [evidence(Ev),dist(Id,Parents)]),
@ -31,12 +61,72 @@ add2sample([V|Vs],[val(Id,[Ev|EParents])|Vals]) :-
add2sample(Vs, Vals).
get_eparents([P|Parents], [E|EParents]) :-
clpbn:get_atts(V, [evidence(Ev)]),
clpbn:get_atts(P, [evidence(E)]),
get_eparents(Parents, EParents).
get_eparents([], []).
compute_tables([], Sample, Tables) :-
mle(Sample, Tables).
compute_tables([laplace|_], Sample, Tables) :-
laplace(Sample, Tables).
%
% assumes we ignore variables that have no evidence, or that have a parent
% with no evidence!
%
add_evidence2sample([], []).
add_evidence2sample([V|Vs],[val(Id,[Ev|EParents])|Vals]) :-
clpbn:get_atts(V, [evidence(Ev),dist(Id,Parents)]),
get_eveparents(Parents, EParents), !,
add_evidence2sample(Vs, Vals).
add_evidence2sample([_|Vs],Vals) :-
add_evidence2sample(Vs, Vals).
get_eveparents([P|Parents], [E|EParents]) :-
clpbn:get_atts(P, [evidence(E)]),
get_eparents(Parents, EParents).
get_eveparents([], []).
compute_tables(Parameters, Sample, NewTables) :-
estimator(Sample, Tables),
add_priors(Parameters, Tables, NewTables).
estimator([], []).
estimator([val(Id,Sample)|Samples], [NewTable|Tables]) :-
empty_dist(Id, NewTable),
id_samples(Id, Samples, IdSamples, MoreSamples),
mle([Sample|IdSamples], NewTable),
% replace matrix in distribution
dist_new_table(Id, NewTable),
estimator(MoreSamples, Tables).
id_samples(_, [], [], []).
id_samples(Id, [val(Id,Sample)|Samples], [Sample|IdSamples], MoreSamples) :- !,
id_samples(Id, Samples, IdSamples, MoreSamples).
id_samples(_, Samples, [], Samples).
mle([Sample|IdSamples], Table) :-
matrix_inc(Table, Sample),
mle(IdSamples, Table).
mle([], _).
add_priors([], Tables, NewTables) :-
normalise(Tables, NewTables).
add_priors([laplace|_], Tables, NewTables) :- !,
laplace(Tables, TablesI),
normalise(TablesI, NewTables).
add_priors([m_estimate(M)|_], Tables, NewTables) :- !,
add_mestimate(Tables, M, TablesI),
normalise(TablesI, NewTables).
add_priors([_|Parms], Tables, NewTables) :-
add_priors(Parms, Tables, NewTables).
normalise([], []).
normalise([T0|TablesI], [T|NewTables]) :-
normalise_counts(T0, T),
normalise(TablesI, NewTables).
laplace([], []).
laplace([T0|TablesI], [T|NewTables]) :-
matrix_op_to_all(T0, +, 1, T),
laplace(TablesI, NewTables).
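% Usage sketch (the goals in Items are invented and must provide evidence for
% every variable, as required above):
%
%   ?- Items = [professor_ability(p0,h), professor_ability(p1,m)],
%      learn_parameters(Items, Tables, [laplace]),
%      writeln(Tables).
%
% Tables is a list of smoothed, normalised tables (matrices), one per
% distribution id; parameters_from_evidence/3 plays the same role starting
% from an already-built network.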