new version of CLP(BN) with EM learning

This commit is contained in:
Vítor Santos de Costa 2008-10-22 00:44:02 +01:00
parent b3cb7b1071
commit 592fe9e366
13 changed files with 2978 additions and 231 deletions

View File

@ -5,8 +5,8 @@
set_clpbn_flag/2, set_clpbn_flag/2,
clpbn_flag/3, clpbn_flag/3,
clpbn_key/2, clpbn_key/2,
clpbn_marginalise/2, clpbn_init_solver/3,
call_solver/2]). clpbn_run_solver/3]).
:- use_module(library(atts)). :- use_module(library(atts)).
:- use_module(library(lists)). :- use_module(library(lists)).
@ -42,7 +42,9 @@
:- use_module('clpbn/gibbs', :- use_module('clpbn/gibbs',
[gibbs/3, [gibbs/3,
check_if_gibbs_done/1 check_if_gibbs_done/1,
init_gibbs_solver/3,
run_gibbs_solver/3
]). ]).
:- use_module('clpbn/graphs', :- use_module('clpbn/graphs',
@ -52,7 +54,7 @@
:- use_module('clpbn/dists', :- use_module('clpbn/dists',
[ [
dist/3, dist/4,
get_dist/4, get_dist/4,
get_evidence_position/3, get_evidence_position/3,
get_evidence_from_position/3 get_evidence_from_position/3
@ -106,7 +108,7 @@ clpbn_flag(suppress_attribute_display,Before,After) :-
{Var = Key with Dist} :- {Var = Key with Dist} :-
put_atts(El,[key(Key),dist(DistInfo,Parents)]), put_atts(El,[key(Key),dist(DistInfo,Parents)]),
dist(Dist, DistInfo, Parents), dist(Dist, DistInfo, Key, Parents),
add_evidence(Var,DistInfo,El). add_evidence(Var,DistInfo,El).
check_constraint(Constraint, _, _, Constraint) :- var(Constraint), !. check_constraint(Constraint, _, _, Constraint) :- var(Constraint), !.
@ -158,7 +160,7 @@ call_solver(GVars, AVars) :-
clpbn_vars(AVars, DiffVars, AllVars), clpbn_vars(AVars, DiffVars, AllVars),
get_clpbn_vars(GVars,CLPBNGVars0), get_clpbn_vars(GVars,CLPBNGVars0),
simplify_query_vars(CLPBNGVars0, CLPBNGVars), simplify_query_vars(CLPBNGVars0, CLPBNGVars),
write_out(Solver,CLPBNGVars, AllVars, DiffVars). write_out(Solver,[CLPBNGVars], AllVars, DiffVars).
@ -332,3 +334,32 @@ user:term_expansion((A :- {}), ( :- true )) :- !, % evidence
clpbn_key(Var,Key) :- clpbn_key(Var,Key) :-
get_atts(Var, [key(Key)]). get_atts(Var, [key(Key)]).
%
% This is a routine to start a solver, called by the learning procedures (i.e., em).
% LVs is a list of lists of variables one is interested in eventually marginalising out.
% Vs0 gives the original graph.
% VarsWithUnboundKeys (the AllDiffs list) gives variables that are not fully
% constrained, i.e., we don't fully know the key. In this case, we assume different
% instances will be bound to different values at the end of the day.
%
% clpbn_init_solver(LVs, Vs0, VarsWithUnboundKeys)
% Ask solver/1 which solver is currently selected and forward the three
% arguments to that solver's initialisation step.
clpbn_init_solver(LVs, Vs0, VarsWithUnboundKeys) :-
    solver(Which),
    clpbn_init_known_solver(Which, LVs, Vs0, VarsWithUnboundKeys).

% clpbn_init_known_solver(Which, LVs, Vs0, VarsWithUnboundKeys)
% Only gibbs has an explicit initialisation routine; for every other
% solver this succeeds without doing anything.
clpbn_init_known_solver(Which, LVs, Vs0, VarsWithUnboundKeys) :-
    (   Which = gibbs
    ->  init_gibbs_solver(LVs, Vs0, VarsWithUnboundKeys)
    ;   true
    ).
%
% LVs is the list of lists of variables to marginalise.
% Vs is the full graph.
% LPs are the probabilities on LVs.
%
%
% clpbn_run_solver(LVs, Vs, LPs)
% Ask solver/1 which solver is currently selected and run it on the
% given query variables and graph.
clpbn_run_solver(LVs, Vs, LPs) :-
    solver(Which),
    clpbn_run_known_solver(Which, LVs, Vs, LPs).

% clpbn_run_known_solver(Which, LVs, Vs, LPs)
% gibbs is the only solver handled here; for any other solver the
% unification below fails and so does the whole call.
clpbn_run_known_solver(Which, LVs, Vs, LPs) :-
    Which = gibbs, !,
    run_gibbs_solver(LVs, Vs, LPs).

View File

@ -390,12 +390,12 @@ evidence_val(Ev,I0,[_|Domain],Val) :-
I1 is I0+1, I1 is I0+1,
evidence_val(Ev,I1,Domain,Val). evidence_val(Ev,I1,Domain,Val).
marginalize([V], _SortedVars,_NunmberedVars, Ps) :- !, marginalize([[V]], _SortedVars,_NunmberedVars, Ps) :- !,
v2number(V,Pos), v2number(V,Pos),
marg <-- marginal_nodes(engine_ev, Pos), marg <-- marginal_nodes(engine_ev, Pos),
matlab_get_variable( marg.'T', Ps). matlab_get_variable( marg.'T', Ps).
marginalize(Vs, SortedVars, NumberedVars,Ps) :- marginalize([Vs], SortedVars, NumberedVars,Ps) :-
bnt_solver(jtree),!, bnt_solver(jtree),!,
matlab_get_variable(loglik, Den), matlab_get_variable(loglik, Den),
clpbn_display:get_all_combs(Vs, Vals), clpbn_display:get_all_combs(Vs, Vals),

View File

@ -37,14 +37,19 @@ add_alldiffs([],Eqs,Eqs).
add_alldiffs(AllDiffs,Eqs,(Eqs/alldiff(AllDiffs))). add_alldiffs(AllDiffs,Eqs,(Eqs/alldiff(AllDiffs))).
clpbn_bind_vals([],_,_) :- !. clpbn_bind_vals([],[],_).
clpbn_bind_vals([Vs|MoreVs],[Ps|MorePs],AllDiffs) :-
clpbn_bind_vals2(Vs, Ps, AllDiffs),
clpbn_bind_vals(MoreVs,MorePs,AllDiffs).
clpbn_bind_vals2([],_,_) :- !.
% simple case, we want a distribution on a single variable. % simple case, we want a distribution on a single variable.
%bind_vals([V],Ps) :- !, %bind_vals([V],Ps) :- !,
% clpbn:get_atts(V, [dist(Vals,_,_)]), % clpbn:get_atts(V, [dist(Vals,_,_)]),
% put_atts(V, posterior([V], Vals, Ps)). % put_atts(V, posterior([V], Vals, Ps)).
% complex case, we want a joint distribution, do it on a leader. % complex case, we want a joint distribution, do it on a leader.
% should split on cliques ? % should split on cliques ?
clpbn_bind_vals(Vs,Ps,AllDiffs) :- clpbn_bind_vals2(Vs,Ps,AllDiffs) :-
get_all_combs(Vs, Vals), get_all_combs(Vs, Vals),
Vs = [V|_], Vs = [V|_],
put_atts(V, posterior(Vs, Vals, Ps, AllDiffs)). put_atts(V, posterior(Vs, Vals, Ps, AllDiffs)).

View File

@ -5,7 +5,7 @@
:- module(clpbn_dist, :- module(clpbn_dist,
[ [
dist/1, dist/1,
dist/3, dist/4,
dists/1, dists/1,
dist_new_table/2, dist_new_table/2,
get_dist/4, get_dist/4,
@ -51,22 +51,23 @@
/******************************************* /*******************************************
store stuff in a DB of the form: store stuff in a DB of the form:
db(Id, CPT, Type, Domain, CPTSize, DSize) db(Id, Key, CPT, Type, Domain, CPTSize, DSize)
where Id is the id, where Id is the id,
cptsize is the table size or -1, Key is a skeleton of the key(main functor only)
DSize is the domain size, cptsize is the table size or -1,
Type is DSize is the domain size,
tab for tabular Type is
trans for HMMs tab for tabular
continuous trans for HMMs
Domain is continuous
a list of values Domain is
bool for [t,f] a list of values
aminoacids for [a,c,d,e,f,g,h,i,k,l,m,n,p,q,r,s,t,v,w,y] bool for [t,f]
dna for [a,c,g,t] aminoacids for [a,c,d,e,f,g,h,i,k,l,m,n,p,q,r,s,t,v,w,y]
rna for [a,c,g,u] dna for [a,c,g,t]
reals rna for [a,c,g,u]
reals
********************************************/ ********************************************/
@ -82,13 +83,20 @@ new_id(Id) :-
dists(X) :- id(X1), X is X1-1. dists(X) :- id(X1), X is X1-1.
dist(V, Id, Parents) :- dist(V, Id, Key, Parents) :-
dist_unbound(V, Culprit), !, dist_unbound(V, Culprit), !,
when(Culprit, dist(V, Id, Parents)). when(Culprit, dist(V, Id, Key, Parents)).
dist(p(Type, CPT), Id, FParents) :- dist(V, Id, Key, Parents) :-
distribution(Type, CPT, Id, [], FParents). var(Key), !,
dist(p(Type, CPT, Parents), Id, FParents) :- when(Key, dist(V, Id, Key, Parents)).
distribution(Type, CPT, Id, Parents, FParents). dist(p(Type, CPT), Id, Key, FParents) :-
functor(Key, Na, Ar),
functor(Key0, Na, Ar),
distribution(Type, CPT, Id, Key0, [], FParents).
dist(p(Type, CPT, Parents), Id, Key, FParents) :-
functor(Key, Na, Ar),
functor(Key0, Na, Ar),
distribution(Type, CPT, Id, Key0, Parents, FParents).
dist_unbound(V, ground(V)) :- dist_unbound(V, ground(V)) :-
var(V), !. var(V), !.
@ -101,52 +109,52 @@ dist_unbound(p(Type,_,_), ground(Type)) :-
dist_unbound(p(_,CPT,_), ground(CPT)) :- dist_unbound(p(_,CPT,_), ground(CPT)) :-
\+ ground(CPT). \+ ground(CPT).
distribution(bool, trans(CPT), Id, Parents, FParents) :- distribution(bool, trans(CPT), Id, Key, Parents, FParents) :-
is_list(CPT), !, is_list(CPT), !,
compress_hmm_table(CPT, Parents, Tab, FParents), compress_hmm_table(CPT, Parents, Tab, FParents),
add_dist([t,f], trans, Tab, Parents, Id). add_dist([t,f], trans, Tab, Parents, Key, Id).
distribution(bool, CPT, Id, Parents, Parents) :- distribution(bool, CPT, Id, Key, Parents, Parents) :-
is_list(CPT), !, is_list(CPT), !,
add_dist([t,f], tab, CPT, Parents, Id). add_dist([t,f], tab, CPT, Parents, Key, Id).
distribution(aminoacids, trans(CPT), Id, Parents, FParents) :- distribution(aminoacids, trans(CPT), Id, Key, Parents, FParents) :-
is_list(CPT), !, is_list(CPT), !,
compress_hmm_table(CPT, Parents, Tab, FParents), compress_hmm_table(CPT, Parents, Tab, FParents),
add_dist([a,c,d,e,f,g,h,i,k,l,m,n,p,q,r,s,t,v,w,y], trans, Tab, FParents, Id). add_dist([a,c,d,e,f,g,h,i,k,l,m,n,p,q,r,s,t,v,w,y], trans, Tab, FParents, Key, Id).
distribution(aminoacids, CPT, Id, Parents, Parents) :- distribution(aminoacids, CPT, Id, Key, Parents, Parents) :-
is_list(CPT), !, is_list(CPT), !,
add_dist([a,c,d,e,f,g,h,i,k,l,m,n,p,q,r,s,t,v,w,y], tab, CPT, Parents, Id). add_dist([a,c,d,e,f,g,h,i,k,l,m,n,p,q,r,s,t,v,w,y], tab, CPT, Parents, Key, Id).
distribution(dna, trans(CPT), Id, Parents, FParents) :- distribution(dna, trans(CPT), Key, Id, Parents, FParents) :-
is_list(CPT), !, is_list(CPT), !,
compress_hmm_table(CPT, Parents, Tab, FParents), compress_hmm_table(CPT, Parents, Tab, FParents),
add_dist([a,c,g,t], trans, Tab, FParents, Id). add_dist([a,c,g,t], trans, Tab, FParents, Key, Id).
distribution(dna, CPT, Id, Parents, Parents) :- distribution(dna, CPT, Id, Key, Parents, Parents) :-
is_list(CPT), !, is_list(CPT), !,
add_dist([a,c,g,t], tab, CPT, Id). add_dist([a,c,g,t], tab, CPT, Key, Id).
distribution(rna, trans(CPT), Id, Parents, FParents) :- distribution(rna, trans(CPT), Id, Key, Parents, FParents) :-
is_list(CPT), !, is_list(CPT), !,
compress_hmm_table(CPT, Parents, Tab, FParents, FParents), compress_hmm_table(CPT, Parents, Tab, FParents, FParents),
add_dist([a,c,g,u], trans, Tab, Id). add_dist([a,c,g,u], trans, Tab, Key, Id).
distribution(rna, CPT, Id, Parents, Parents) :- distribution(rna, CPT, Id, Key, Parents, Parents) :-
is_list(CPT), !, is_list(CPT), !,
add_dist([a,c,g,u], tab, CPT, Parents, Id). add_dist([a,c,g,u], tab, CPT, Parents, Key, Id).
distribution(Domain, trans(CPT), Id, Parents, FParents) :- distribution(Domain, trans(CPT), Id, Key, Parents, FParents) :-
is_list(Domain), is_list(Domain),
is_list(CPT), !, is_list(CPT), !,
compress_hmm_table(CPT, Parents, Tab, FParents), compress_hmm_table(CPT, Parents, Tab, FParents),
add_dist(Domain, trans, Tab, FParents, Id). add_dist(Domain, trans, Tab, FParents, Key, Id).
distribution(Domain, CPT, Id, Parents, Parents) :- distribution(Domain, CPT, Id, Key, Parents, Parents) :-
is_list(Domain), is_list(Domain),
is_list(CPT), !, is_list(CPT), !,
add_dist(Domain, tab, CPT, Parents, Id). add_dist(Domain, tab, CPT, Parents, Key, Id).
add_dist(Domain, Type, CPT, _, Id) :- add_dist(Domain, Type, CPT, _, Key, Id) :-
recorded(clpbn_dist_db, db(Id, CPT, Type, Domain, _, _), _), !. recorded(clpbn_dist_db, db(Id, Key, CPT, Type, Domain, _, _), _), !.
add_dist(Domain, Type, CPT, Parents, Id) :- add_dist(Domain, Type, CPT, Parents, Key, Id) :-
length(CPT, CPTSize), length(CPT, CPTSize),
length(Domain, DSize), length(Domain, DSize),
new_id(Id), new_id(Id),
record_parent_sizes(Parents, Id, PSizes, [DSize|PSizes]), record_parent_sizes(Parents, Id, PSizes, [DSize|PSizes]),
recordz(clpbn_dist_db,db(Id, CPT, Type, Domain, CPTSize, DSize),_). recordz(clpbn_dist_db,db(Id, Key, CPT, Type, Domain, CPTSize, DSize),_).
record_parent_sizes([], Id, [], DSizes) :- record_parent_sizes([], Id, [], DSizes) :-
@ -167,13 +175,13 @@ compress_hmm_table([Prob|L],[P|Parents],[Prob|NL],[P|NParents]) :-
compress_hmm_table(L,Parents,NL,NParents). compress_hmm_table(L,Parents,NL,NParents).
dist(Id) :- dist(Id) :-
recorded(clpbn_dist_db, db(Id, _, _, _, _, _), _). recorded(clpbn_dist_db, db(Id, _, _, _, _, _, _), _).
get_dist(Id, Type, Domain, Tab) :- get_dist(Id, Type, Domain, Tab) :-
recorded(clpbn_dist_db, db(Id, Tab, Type, Domain, _, _), _). recorded(clpbn_dist_db, db(Id, _, Tab, Type, Domain, _, _), _).
get_dist_matrix(Id, Parents, Type, Domain, Mat) :- get_dist_matrix(Id, Parents, Type, Domain, Mat) :-
recorded(clpbn_dist_db, db(Id, Tab, Type, Domain, _, DomainSize), _), recorded(clpbn_dist_db, db(Id, _, Tab, Type, Domain, _, DomainSize), _),
get_dsizes(Parents, Sizes, []), get_dsizes(Parents, Sizes, []),
matrix_new(floats, [DomainSize|Sizes], Tab, Mat), matrix_new(floats, [DomainSize|Sizes], Tab, Mat),
matrix_to_logs(Mat). matrix_to_logs(Mat).
@ -185,31 +193,31 @@ get_dsizes([P|Parents], [Sz|Sizes], Sizes0) :-
get_dsizes(Parents, Sizes, Sizes0). get_dsizes(Parents, Sizes, Sizes0).
get_dist_params(Id, Parms) :- get_dist_params(Id, Parms) :-
recorded(clpbn_dist_db, db(Id, Parms, _, _, _, _), _). recorded(clpbn_dist_db, db(Id, _, Parms, _, _, _, _), _).
get_dist_domain_size(Id, DSize) :- get_dist_domain_size(Id, DSize) :-
recorded(clpbn_dist_db, db(Id, _, _, _, _, DSize), _). recorded(clpbn_dist_db, db(Id, _, _, _, _, _, DSize), _).
get_dist_domain(Id, Domain) :- get_dist_domain(Id, Domain) :-
recorded(clpbn_dist_db, db(Id, _, _, Domain, _, _), _). recorded(clpbn_dist_db, db(Id, _, _, _, Domain, _, _), _).
get_dist_nparams(Id, NParms) :- get_dist_nparams(Id, NParms) :-
recorded(clpbn_dist_db, db(Id, _, _, _, NParms, _), _). recorded(clpbn_dist_db, db(Id, _, _, _, _, NParms, _), _).
get_evidence_position(El, Id, Pos) :- get_evidence_position(El, Id, Pos) :-
recorded(clpbn_dist_db, db(Id, _, _, Domain, _, _), _), recorded(clpbn_dist_db, db(Id, _, _, _, Domain, _, _), _),
nth0(Pos, Domain, El), !. nth0(Pos, Domain, El), !.
get_evidence_position(El, Id, Pos) :- get_evidence_position(El, Id, Pos) :-
recorded(clpbn_dist_db, db(Id, _, _, _, _, _), _), !, recorded(clpbn_dist_db, db(Id, _, _, _, _, _, _), _), !,
throw(error(domain_error(evidence,Id),get_evidence_position(El, Id, Pos))). throw(error(domain_error(evidence,Id),get_evidence_position(El, Id, Pos))).
get_evidence_position(El, Id, Pos) :- get_evidence_position(El, Id, Pos) :-
throw(error(domain_error(no_distribution,Id),get_evidence_position(El, Id, Pos))). throw(error(domain_error(no_distribution,Id),get_evidence_position(El, Id, Pos))).
get_evidence_from_position(El, Id, Pos) :- get_evidence_from_position(El, Id, Pos) :-
recorded(clpbn_dist_db, db(Id, _, _, Domain, _, _), _), recorded(clpbn_dist_db, db(Id, _, _, _, Domain, _, _), _),
nth0(Pos, Domain, El), !. nth0(Pos, Domain, El), !.
get_evidence_from_position(El, Id, Pos) :- get_evidence_from_position(El, Id, Pos) :-
recorded(clpbn_dist_db, db(Id, _, _, _, _, _), _), !, recorded(clpbn_dist_db, db(Id, _, _, _, _, _, _), _), !,
throw(error(domain_error(evidence,Id),get_evidence_from_position(El, Id, Pos))). throw(error(domain_error(evidence,Id),get_evidence_from_position(El, Id, Pos))).
get_evidence_from_position(El, Id, Pos) :- get_evidence_from_position(El, Id, Pos) :-
throw(error(domain_error(no_distribution,Id),get_evidence_from_position(El, Id, Pos))). throw(error(domain_error(no_distribution,Id),get_evidence_from_position(El, Id, Pos))).
@ -224,9 +232,9 @@ empty_dist(Dist, TAB) :-
dist_new_table(Id, NewMat) :- dist_new_table(Id, NewMat) :-
matrix_to_list(NewMat, List), matrix_to_list(NewMat, List),
recorded(clpbn_dist_db, db(Id, _, A, B, C, D), R), recorded(clpbn_dist_db, db(Id, Key, _, A, B, C, D), R),
erase(R), erase(R),
recorda(clpbn_dist_db, db(Id, List, A, B, C, D), R), recorda(clpbn_dist_db, db(Id, Key, List, A, B, C, D), _),
fail. fail.
dist_new_table(_, _). dist_new_table(_, _).

View File

@ -7,9 +7,11 @@
% Markov Blanket % Markov Blanket
% %
:- module(gibbs, [gibbs/3, :- module(clpbn_gibbs,
check_if_gibbs_done/1, [gibbs/3,
init_gibbs_solver/3]). check_if_gibbs_done/1,
init_gibbs_solver/3,
run_gibbs_solver/3]).
:- use_module(library(rbtrees), :- use_module(library(rbtrees),
[rb_new/1, [rb_new/1,
@ -47,27 +49,35 @@
:- dynamic implicit/1. :- dynamic implicit/1.
gibbs([],_,_) :- !. % arguments:
%
% list of output variables
% list of attributed variables
%
gibbs([[]],_,_) :- !.
gibbs(LVs,Vs0,AllDiffs) :- gibbs(LVs,Vs0,AllDiffs) :-
LVs = [_], !, init_gibbs_solver(LVs, Vs0, Vs),
init_gibbs_solver(Vs0, LVs, Gibbs),
(clpbn:output(xbif(XBifStream)) -> clpbn2xbif(XBifStream,vel,Vs) ; true), (clpbn:output(xbif(XBifStream)) -> clpbn2xbif(XBifStream,vel,Vs) ; true),
(clpbn:output(gviz(XBifStream)) -> clpbn2gviz(XBifStream,vel,Vs,LVs) ; true), (clpbn:output(gviz(XBifStream)) -> clpbn2gviz(XBifStream,vel,Vs,LVs) ; true),
initialise(Vs, Graph, LVs, OutputVars, VarOrder), run_gibbs_solver(LVs, Vs, LPs),
% write(Graph),nl,
process(VarOrder, Graph, OutputVars, Estimates),
sum_up(Estimates, [LPs]),
% write(Estimates),nl,
clpbn_bind_vals(LVs,LPs,AllDiffs), clpbn_bind_vals(LVs,LPs,AllDiffs),
clean_up. clean_up.
gibbs(LVs,_,_) :-
throw(error(domain_error(solver,LVs),solver(gibbs))).
init_gibbs_solver(LVs, Vs0, Gibbs) :- init_gibbs_solver(_, Vs0, Vs) :-
clean_up, clean_up,
check_for_hidden_vars(Vs0, Vs0, Vs1), check_for_hidden_vars(Vs0, Vs0, Vs1),
sort(Vs1,Vs). sort(Vs1,Vs).
% run_gibbs_solver(LVs, Vs, LPs)
% Runs the Gibbs solver on an already prepared variable list.
%   LVs  groups of output variables (handed to initialise/5 as its third argument)
%   Vs   the variable list from which initialise/5 builds Graph
%   LPs  whatever sum_up_all/2 produces from the estimates returned by process/4
run_gibbs_solver(LVs, Vs, LPs) :-
initialise(Vs, Graph, LVs, OutputVars, VarOrder),
% writeln(Graph),
% write_pars(Vs),
% NOTE(review): process/4 is not visible here; presumably it runs the sampling chains - confirm.
process(VarOrder, Graph, OutputVars, Estimates),
% writeln(Estimates),
sum_up_all(Estimates, LPs),
% clean_up drops the mblanket records and the implicit/1 facts left behind by the run.
clean_up.
% writeln(Estimates).
initialise(LVs, Graph, GVs, OutputVars, VarOrder) :- initialise(LVs, Graph, GVs, OutputVars, VarOrder) :-
init_keys(Keys0), init_keys(Keys0),
gen_keys(LVs, 0, VLen, Keys0, Keys), gen_keys(LVs, 0, VLen, Keys0, Keys),
@ -76,7 +86,7 @@ initialise(LVs, Graph, GVs, OutputVars, VarOrder) :-
compile_graph(Graph), compile_graph(Graph),
topsort(TGraph, VarOrder), topsort(TGraph, VarOrder),
% show_sorted(VarOrder, Graph), % show_sorted(VarOrder, Graph),
add_output_vars(GVs, Keys, OutputVars). add_all_output_vars(GVs, Keys, OutputVars).
init_keys(Keys0) :- init_keys(Keys0) :-
rb_new(Keys0). rb_new(Keys0).
@ -120,7 +130,7 @@ graph_representation([V|Vs], Graph, I0, Keys, [I-IParents|TGraph]) :-
write_pars([]). write_pars([]).
write_pars([V|Parents]) :- write_pars([V|Parents]) :-
clpbn:get_atts(V, [key(K)]),write(K),nl, clpbn:get_atts(V, [key(K),dist(I,_)]),write(K:I),nl,
write_pars(Parents). write_pars(Parents).
get_sizes([], []). get_sizes([], []).
@ -310,9 +320,12 @@ normalise_factors([F|Factors],S0,S,[P0|Probs],PF) :-
PF is P0-F/S. PF is P0-F/S.
store_mblanket(I,Values,Probs) :- store_mblanket(I,Values,Probs) :-
append(Values,Probs,Args), recordz(mblanket,m(I,Values,Probs),_).
Rule =.. [mblanket,I|Args],
assert(Rule). add_all_output_vars([], _, []).
add_all_output_vars([Vs|LVs], Keys, [Is|OutputVars]) :-
add_output_vars(Vs, Keys, Is),
add_all_output_vars(LVs, Keys, OutputVars).
add_output_vars([], _, []). add_output_vars([], _, []).
add_output_vars([V|LVs], Keys, [I|OutputVars]) :- add_output_vars([V|LVs], Keys, [I|OutputVars]) :-
@ -382,15 +395,29 @@ iparents_pos_sz([I|IPars], Chain, [P|IPos], Graph, [Sz|Sizes]) :-
init_estimates(0,_,_,[]) :- !. init_estimates(0,_,_,[]) :- !.
init_estimates(NChains,OutputVars,Graph,[Est|Est0]) :- init_estimates(NChains,OutputVars,Graph,[Est|Est0]) :-
NChainsI is NChains-1, NChainsI is NChains-1,
init_estimate(OutputVars,Graph,Est), init_estimate_all_outvs(OutputVars,Graph,Est),
init_estimates(NChainsI,OutputVars,Graph,Est0). init_estimates(NChainsI,OutputVars,Graph,Est0).
init_estimate([],_,[]). init_estimate_all_outvs([],_,[]).
init_estimate([V|OutputVars],Graph,[[I|E0L]|Est]) :- init_estimate_all_outvs([Vs|OutputVars],Graph,[E|Est]) :-
arg(V,Graph,var(_,I,_,_,Sz,_,_,_,_)), init_estimate(Vs, Graph, E),
gen_e0(Sz,E0L), init_estimate_all_outvs(OutputVars,Graph,Est).
init_estimate(OutputVars,Graph,Est).
% init_estimate(Vs, Graph, Est)
% Builds the initial (all-zero) estimate for one group of output variables.
% Each V in Vs is an argument index into Graph, whose entries are var/9 terms.
% No variables: the estimate is the empty list.
init_estimate([],_,[]).
% Exactly one variable: the estimate is [I|Zeros], where I and Sz are read from
% the var/9 entry and Zeros holds Sz zeros (built by gen_e0/2).
init_estimate([V],Graph,[I|E0L]) :- !,
arg(V,Graph,var(_,I,_,_,Sz,_,_,_,_)),
gen_e0(Sz,E0L).
% Several variables: the estimate is me(Is,Mults,Es). generate_est_mults/5 gives
% the ids, the per-variable multipliers and the product Sz of all the sizes;
% Es holds Sz zeros.
init_estimate(Vs,Graph,me(Is,Mults,Es)) :-
generate_est_mults(Vs, Is, Graph, Mults, Sz),
gen_e0(Sz,Es).
% generate_est_mults(Vs, Is, Graph, Mults, Total)
% Walks the argument indices in Vs, reading the id and size stored in each
% var/9 entry of Graph.
%   Is     the ids, in the same order as Vs
%   Mults  for each variable, the product of the sizes of the variables after it
%   Total  the product of all the sizes (1 for the empty list)
generate_est_mults([], [], _Graph, [], 1).
generate_est_mults([Node|Nodes], [Id|Ids], Graph, [Stride|Strides], Total) :-
    arg(Node, Graph, var(_,Id,_,_,Size,_,_,_,_)),
    % the multiplier of this variable is the total of everything after it
    generate_est_mults(Nodes, Ids, Graph, Strides, Stride),
    Total is Stride*Size.
gen_e0(0,[]) :- !. gen_e0(0,[]) :- !.
gen_e0(Sz,[0|E0L]) :- gen_e0(Sz,[0|E0L]) :-
Sz1 is Sz-1, Sz1 is Sz-1,
@ -398,8 +425,9 @@ gen_e0(Sz,[0|E0L]) :-
process_chains(0,_,F,F,_,_,Est,Est) :- !. process_chains(0,_,F,F,_,_,Est,Est) :- !.
process_chains(ToDo,VarOrder,End,Start,Graph,Len,Est0,Estf) :- process_chains(ToDo,VarOrder,End,Start,Graph,Len,Est0,Estf) :-
%format('ToDo = ~d~n',[ToDo]),
process_chains(Start,VarOrder,Int,Graph,Len,Est0,Esti), process_chains(Start,VarOrder,Int,Graph,Len,Est0,Esti),
% (ToDo mod 100 =:= 0 -> statistics,cvt2problist(Esti, Probs), Int =[S|_], format('did ~d: ~w~n ~w~n',[ToDo,Probs,S]) ; true), % (ToDo mod 100 =:= 1 -> statistics,cvt2problist(Esti, Probs), Int =[S|_], format('did ~d: ~w~n ~w~n',[ToDo,Probs,S]) ; true),
ToDo1 is ToDo-1, ToDo1 is ToDo-1,
process_chains(ToDo1,VarOrder,End,Int,Graph,Len,Esti,Estf). process_chains(ToDo1,VarOrder,End,Int,Graph,Len,Esti,Estf).
@ -408,8 +436,8 @@ process_chains([], _, [], _, _,[],[]).
process_chains([Sample0|Samples0], VarOrder, [Sample|Samples], Graph, SampLen,[E0|E0s],[Ef|Efs]) :- process_chains([Sample0|Samples0], VarOrder, [Sample|Samples], Graph, SampLen,[E0|E0s],[Ef|Efs]) :-
functor(Sample,sample,SampLen), functor(Sample,sample,SampLen),
do_sample(VarOrder,Sample,Sample0,Graph), do_sample(VarOrder,Sample,Sample0,Graph),
%format('Sample = ~w~n',[Sample]), % format('Sample = ~w~n',[Sample]),
update_estimate(E0,Sample,Ef), update_estimates(E0,Sample,Ef),
process_chains(Samples0, VarOrder, Samples, Graph, SampLen,E0s,Efs). process_chains(Samples0, VarOrder, Samples, Graph, SampLen,E0s,Efs).
do_sample([],_,_,_). do_sample([],_,_,_).
@ -418,15 +446,14 @@ do_sample([I|VarOrder],Sample,Sample0,Graph) :-
do_sample(VarOrder,Sample,Sample0,Graph). do_sample(VarOrder,Sample,Sample0,Graph).
do_var(I,Sample,Sample0,Graph) :- do_var(I,Sample,Sample0,Graph) :-
arg(I,Graph,var(_,I,_,_,Sz,CPTs,Parents,_,_)),
( implicit(I) -> ( implicit(I) ->
fetch_parents(Parents,I,Sample,Sample0,Bindings,[]), arg(I,Graph,var(_,_,_,_,Sz,CPTs,Parents,_,_)),
multiply_all_in_context(Parents,Bindings,CPTs,Sz,Graph,Vals) fetch_parents(Parents,I,Sample,Sample0,Bindings),
multiply_all_in_context(Parents,Bindings,CPTs,Sz,Graph,Vals)
; ;
length(Vals,Sz), arg(I,Graph,var(_,_,_,_,_,_,Parents,_,_)),
fetch_parents(Parents,I,Sample,Sample0,Args,Vals), fetch_parents(Parents,I,Sample,Sample0,Args),
Goal =.. [mblanket,I|Args], recorded(mblanket,m(I,Args,Vals),_)
call(Goal)
), ),
X is random, X is random,
pick_new_value(Vals,X,0,Val), pick_new_value(Vals,X,0,Val),
@ -444,31 +471,50 @@ set_pos([I|Is],[Pos|Args],Graph) :-
arg(I,Graph,var(_,I,Pos,_,_,_,_,_,_)), arg(I,Graph,var(_,I,Pos,_,_,_,_,_,_)),
set_pos(Is,Args,Graph). set_pos(Is,Args,Graph).
fetch_parents([],_,_,_,Args,Args). fetch_parents([],_,_,_,[]).
fetch_parents([P|Parents],I,Sample,Sample0,[VP|Args],Vals) :- fetch_parents([P|Parents],I,Sample,Sample0,[VP|Args]) :-
arg(P,Sample,VP), arg(P,Sample,VP),
nonvar(VP), !, nonvar(VP), !,
fetch_parents(Parents,I,Sample,Sample0,Args,Vals). fetch_parents(Parents,I,Sample,Sample0,Args).
fetch_parents([P|Parents],I,Sample,Sample0,[VP|Args],Vals) :- fetch_parents([P|Parents],I,Sample,Sample0,[VP|Args]) :-
arg(P,Sample0,VP), arg(P,Sample0,VP),
fetch_parents(Parents,I,Sample,Sample0,Args,Vals). fetch_parents(Parents,I,Sample,Sample0,Args).
pick_new_value([V|_],X,Val,Val) :- pick_new_value([V|Vals],X,I0,Val) :-
X < V, !. ( X < V ->
pick_new_value([_|Vals],X,I0,Val) :- Val = I0
I is I0+1, ;
pick_new_value(Vals,X,I,Val). I is I0+1,
pick_new_value(Vals,X,I,Val)
).
update_estimate([],_,[]). update_estimates([],_,[]).
update_estimate([[I|E]|E0],Sample,[[I|NE]|Ef]) :- update_estimates([Est|E0],Sample,[NEst|Ef]) :-
update_estimate(Est,Sample,NEst),
update_estimates(E0,Sample,Ef).
update_estimate([I|E],Sample,[I|NE]) :-
arg(I,Sample,V), arg(I,Sample,V),
update_estimate_for_var(V,E,NE), update_estimate_for_var(V,E,NE).
update_estimate(E0,Sample,Ef). update_estimate(me(Is,Mult,E),Sample,me(Is,Mult,NE)) :-
get_estimate_pos(Is, Sample, Mult, 0, V),
update_estimate_for_var(V,E,NE).
update_estimate_for_var(0,[X|T],[X1|T]) :- !, X1 is X+1. get_estimate_pos([], _, [], V, V).
update_estimate_for_var(V,[E|Es],[E|NEs]) :- get_estimate_pos([I|Is], Sample, [M|Mult], V0, V) :-
V1 is V-1, arg(I,Sample,VV),
update_estimate_for_var(V1,Es,NEs). VI is VV*M+V0,
get_estimate_pos(Is, Sample, Mult, VI, V).
% update_estimate_for_var(Pos, Counts, NewCounts)
% NewCounts is Counts with the counter at 0-based position Pos incremented
% by one. Fails if Counts has no element at position Pos.
update_estimate_for_var(Pos, [Count|Counts], [NewCount|NewCounts]) :-
    Pos == 0, !,
    % reached the target position: bump it and keep the tail as is
    NewCount is Count+1,
    NewCounts = Counts.
update_estimate_for_var(Pos, [Count|Counts], [NewCount|NewCounts]) :-
    Pos1 is Pos-1,
    NewCount = Count,
    update_estimate_for_var(Pos1, Counts, NewCounts).
@ -476,16 +522,14 @@ check_if_gibbs_done(Var) :-
get_atts(Var, [dist(_)]), !. get_atts(Var, [dist(_)]), !.
clean_up :- clean_up :-
current_predicate(mblanket,P), eraseall(mblanket),
retractall(P),
fail. fail.
clean_up :- clean_up :-
retractall(implicit(_)), retractall(implicit(_)),
fail. fail.
clean_up. clean_up.
gibbs_params(5,1000,10000).
gibbs_params(5,10000,100000).
cvt2problist([], []). cvt2problist([], []).
cvt2problist([[[_|E]]|Est0], [Ps|Probs]) :- cvt2problist([[[_|E]]|Est0], [Ps|Probs]) :-
@ -510,16 +554,32 @@ show_sorted([I|VarOrder], Graph) :-
format('~w ',[K]), format('~w ',[K]),
show_sorted(VarOrder, Graph). show_sorted(VarOrder, Graph).
sum_up([[]|_], []). sum_up_all([[]|_], []).
sum_up([[[Id|Counts]|More]|Chains], [Dist|Dists]) :- sum_up_all([[C|MoreC]|Chains], [Dist|Dists]) :-
add_up(Counts,Chains, Id, Add,RChains), extract_sums(Chains, CurrentChains, LeftChains),
normalise(Add, Dist), sum_up([C|CurrentChains], Dist),
sum_up([More|RChains], Dists). sum_up_all([MoreC|LeftChains], Dists).
add_up(Counts,[],_,Counts,[]). extract_sums([], [], []).
add_up(Counts,[[[Id|Cs]|MoreVars]|Chains],Id, Add, [MoreVars|RChains]) :- extract_sums([[C|Chains]|MoreChains], [C|CurrentChains], [Chains|LeftChains]) :-
extract_sums(MoreChains, CurrentChains, LeftChains).
% sum_up(Estimates, Dist)
% Estimates holds the estimate of one output group, one entry per chain.
% The counters of all entries are added together and the total is passed
% through normalise/2 to give Dist.
%
% Entries of the form [Id|Counts]:
sum_up([[_Id|Head]|Rest], Dist) :-
    add_up(Head, Rest, Total),
    normalise(Total, Dist).
% Entries of the form me(Ids,Mults,Counts):
sum_up([me(_Ids,_Mults,Head)|Rest], Dist) :-
    add_up_mes(Head, Rest, Total),
    normalise(Total, Dist).
add_up(Counts,[],Counts).
add_up(Counts,[[_|Cs]|Chains], Add) :-
sum_lists(Counts, Cs, NCounts), sum_lists(Counts, Cs, NCounts),
add_up(NCounts, Chains, Id, Add, RChains). add_up(NCounts, Chains, Add).
% add_up_mes(Acc, Estimates, Total)
% Adds the counter list of every me(_,_,Counts) entry in Estimates to Acc,
% element by element (via sum_lists/3), giving Total.
add_up_mes(Total, [], Total).
add_up_mes(Acc, [me(_Ids,_Mults,Counts)|Rest], Total) :-
    sum_lists(Acc, Counts, Acc1),
    add_up_mes(Acc1, Rest, Total).
sum_lists([],[],[]). sum_lists([],[],[]).
sum_lists([Count|Counts], [C|Cs], [NC|NCounts]) :- sum_lists([Count|Counts], [C|Cs], [NC|NCounts]) :-

View File

@ -72,6 +72,7 @@
:- use_module(library('clpbn/display'), [ :- use_module(library('clpbn/display'), [
clpbn_bind_vals/3]). clpbn_bind_vals/3]).
jt([[]],_,_) :- !.
jt(LVs,Vs0,AllDiffs) :- jt(LVs,Vs0,AllDiffs) :-
get_graph(Vs0, BayesNet, CPTs, Evidence), get_graph(Vs0, BayesNet, CPTs, Evidence),
build_jt(BayesNet, CPTs, JTree), build_jt(BayesNet, CPTs, JTree),
@ -81,7 +82,7 @@ jt(LVs,Vs0,AllDiffs) :-
% write_tree(NewTree,0), % write_tree(NewTree,0),
propagate_evidence(Evidence, NewTree, EvTree), propagate_evidence(Evidence, NewTree, EvTree),
message_passing(EvTree, MTree), message_passing(EvTree, MTree),
get_margin(MTree, LVs, LPs), get_margins(MTree, LVs, LPs),
clpbn_bind_vals(LVs,LPs,AllDiffs). clpbn_bind_vals(LVs,LPs,AllDiffs).
@ -461,7 +462,12 @@ downward([tree(Clique1-(Dist1,Msg1),DistKids)|Kids], Clique, Tab, [tree(Clique1-
downward(Kids, Clique, Tab, NKids). downward(Kids, Clique, Tab, NKids).
get_margin(NewTree, LVs0, LPs) :- get_margins(_, [], []).
get_margins(NewTree, [Vs|LVs], [LP|LPs]) :-
get_margin(NewTree, Vs, LP),
get_margins(NewTree, LVs, LPs).
get_margin(NewTree, LVs, LPs) :-
sort(LVs0, LVs), sort(LVs0, LVs),
find_clique(NewTree, LVs, Clique, Dist), find_clique(NewTree, LVs, Clique, Dist),
sum_out_from_CPT(LVs, Dist, Clique, tab(TAB,_,_)), sum_out_from_CPT(LVs, Dist, Clique, tab(TAB,_,_)),

View File

@ -37,7 +37,6 @@ topsort(Graph0, Sort0, Found0, Sort) :-
add_nodes([], Found, Sort, [], Found, Sort). add_nodes([], Found, Sort, [], Found, Sort).
add_nodes([N-Ns|Graph0], Found0, SortI, NewGraph, Found, NSort) :- add_nodes([N-Ns|Graph0], Found0, SortI, NewGraph, Found, NSort) :-
(N=1600 -> write(Ns), nl ; true),
delete_nodes(Ns, Found0, NNs), delete_nodes(Ns, Found0, NNs),
( NNs == [] -> ( NNs == [] ->
NewGraph = IGraph, NewGraph = IGraph,

View File

@ -55,8 +55,8 @@ check_if_vel_done(Var) :-
% %
% implementation of the well known variable elimination algorithm % implementation of the well known variable elimination algorithm
% %
vel([],_,_) :- !. vel([[]],_,_) :- !.
vel(LVs,Vs0,AllDiffs) :- vel([LVs],Vs0,AllDiffs) :-
check_for_hidden_vars(Vs0, Vs0, Vs1), check_for_hidden_vars(Vs0, Vs0, Vs1),
sort(Vs1,Vs), sort(Vs1,Vs),
% LVi will have a list of CLPBN variables % LVi will have a list of CLPBN variables

View File

@ -2,113 +2,166 @@
% The world famous EM algorithm, in a nutshell % The world famous EM algorithm, in a nutshell
% %
:- module(clpbn_em, [em/6]). :- module(clpbn_em, [em/5]).
:- use_module(library(lists), :- use_module(library(lists),
[append/3]). [append/3]).
:- use_module(library(clpbn),
[clpbn_init_solver/3,
clpbn_run_solver/3]).
:- use_module(library('clpbn/dists'),
[get_dist_domain_size/2,
empty_dist/2,
dist_new_table/2]).
:- use_module(library('clpbn/learning/learn_utils'), :- use_module(library('clpbn/learning/learn_utils'),
[run_all/1, [run_all/1,
clpbn_vars/2, clpbn_vars/2,
normalise_counts/2]). normalise_counts/2,
compute_likelihood/3]).
:- use_module(library(lists),
[member/2]).
:- use_module(library(matrix),
[matrix_add/3,
matrix_to_list/2]).
:- use_module(library('clpbn/utils'), [
check_for_hidden_vars/3]).
:- meta_predicate em(:,+,+,-,-), init_em(:,-).
em(Items, MaxError, MaxIts, Tables, Likelihood) :- em(Items, MaxError, MaxIts, Tables, Likelihood) :-
init_em(Items, State), init_em(Items, State),
em_loop(0, 0.0, state(AllVars,AllDists), MaxError, MaxIts, Likelihood), em_loop(0, 0.0, State, MaxError, MaxIts, Likelihood, Tables).
get_tables(State, Tables).
% This gets you an initial configuration. If there is a lot of evidence % This gets you an initial configuration. If there is a lot of evidence
% tables may be filled in close to optimal, otherwise they may be % tables may be filled in close to optimal, otherwise they may be
% close to uniform. % close to uniform.
% it also gets you a run for random variables % it also gets you a run for random variables
init_em(Items, state(AllVars, AllDists, AllDistInstances)) :-
% state collects all Info we need for the EM algorithm
% it includes the list of variables without evidence,
% the list of distributions for which we want to compute parameters,
% and more detailed info on distributions, namely with a list of all instances for the distribution.
init_em(Items, state(AllVars, AllDists, AllDistInstances, MargVars)) :-
run_all(Items), run_all(Items),
different_dists(AllVars, AllDists, AllDistInstances). attributes:all_attvars(AllVars0),
% remove variables that do not have to do with this query.
check_for_hidden_vars(AllVars0, AllVars0, AllVars),
different_dists(AllVars, AllDists, AllDistInstances, MargVars),
clpbn_init_solver(MargVars, AllVars, _).
% loop for as long as you want. % loop for as long as you want.
em_loop(MaxIts, Likelihood, State, _, _, MaxIts, Likelihood) :- !. em_loop(Its, Likelihood0, State, MaxError, MaxIts, LikelihoodF, FTables) :-
em_loop(Its, Likelihood0, State, MaxError, MaxIts, LikelihoodF) :- estimate(State, LPs),
estimate(State), maximise(State, Tables, LPs, Likelihood),
maximise(State, Likelihood),
( (
( (
(Likelihood - Likelihood0)/Likelihood < MaxError (Likelihood - Likelihood0)/Likelihood < MaxError
; ;
Its == MaxIts Its == MaxIts
) )
-> ->
ltables(Tables, FTables),
LikelihoodF = Likelihood LikelihoodF = Likelihood
; ;
Its1 is Its+1, Its1 is Its+1,
em_loop(Its1, Likelihood, State, MaxError, MaxIts, LikelihoodF) em_loop(Its1, Likelihood, State, MaxError, MaxIts, LikelihoodF, FTables)
). ).
% ltables(+Tables, -ListTables)
% Maps a list of Id-Matrix pairs onto Id-List pairs, converting every
% matrix with matrix_to_list/2.  Ids and ordering are preserved.
ltables([], []).
ltables([Pair|Pairs], [FlatPair|FlatPairs]) :-
	Pair = Key-Matrix,
	FlatPair = Key-Flat,
	matrix_to_list(Matrix, Flat),
	ltables(Pairs, FlatPairs).
% collect the different dists we are going to learn next. % collect the different dists we are going to learn next.
different_dists(AllVars, AllDists, AllInfo) :- different_dists(AllVars, AllDists, AllInfo, MargVars) :-
all_dists(AllVars, Dists0, AllInfo), all_dists(AllVars, Dists0),
sort(Dists0, Dists1), sort(Dists0, Dists1),
group(Dists1, AllInfo). group(Dists1, AllDists, AllInfo, MargVars, []).
group([], []). all_dists([], []).
group([i(Id,V,Ps)|Dists1], [Id-[[V|Ps]|Extra]|AllInfo]) :- all_dists([V|AllVars], [i(Id, [V|Parents], Cases, Hiddens)|Dists]) :-
same_id(Dists1, Id, Extra, Rest), clpbn:get_atts(V, [dist(Id,Parents)]),
group(Rest, AllInfo). generate_hidden_cases([V|Parents], CompactCases, Hiddens),
uncompact_cases(CompactCases, Cases),
all_dists(AllVars, Dists).
same_id([i(Id,V,Ps)|Dists1], Id, [[V|Ps]|Extra], Rest) :- !, generate_hidden_cases([], [], []).
same_id(Dists1, Id, Extra, Rest). generate_hidden_cases([V|Parents], [P|Cases], Hiddens) :-
same_id(Dists, _, [], Dists). clpbn:get_atts(V, [evidence(P)]), !,
generate_hidden_cases(Parents, Cases, Hiddens).
all_dists([], [], []). generate_hidden_cases([V|Parents], [Cases|MoreCases], [V|Hiddens]) :-
all_dists([V|AllVars], Dists, [i(Id, AllInfo, Parents)|AllInfo]) :-
clpbn:get_atts(V, [dist(Id,_)]), clpbn:get_atts(V, [dist(Id,_)]),
with_evidence(V, Id, Dists, Dists0), !, get_dist_domain_size(Id, Sz),
all_dists(AllVars, Dists0, AllInfo). gen_cases(0, Sz, Cases),
generate_hidden_cases(Parents, MoreCases, Hiddens).
with_evidence(V, Id) -->
{clpbn:get_atts(V, [evidence(Pos)]) }, !,
{ dist_pos2bin(Pos, Id, Bin) }.
with_evidence(V, Id) -->
[d(V,Id)].
estimate(state(Vars,Info,_)) :-
clpbn_solve_graph(Vars, OVars),
marg_vars(Info, Vars).
marg_vars([], _).
marg_vars([d(V,Id)|Vars], AllVs) :-
clpbn_marginalise_in_vars(V, AllVs),
marg_vars(Vars, AllVs).
maximise(state(_,_,DistInstances), Tables, Likelihood) :-
compute_parameters(DistInstances, Tables, 0.0, Likelihood).
compute_parameters([], [], Lik, Lik).
compute_parameters([Id-Samples|Dists], [Tab|Tables], Lik0, Lik) :-
empty_dist(Id, NewTable),
add_samples(Samples, NewTable),
normalise_table(Id, NewTable),
compute_parameters(Dists, Tables, Lik0, Lik).
add_samples([], _).
add_samples([S|Samples], Table) :-
run_sample(S, 1.0, Pos, Tot),
matrix_add(Table, Pos, Tot),
fail.
add_samples([_|Samples], Table) :-
add_samples(Samples, Table).
run_sample([], Tot, [], Tot). gen_cases(Sz, Sz, []) :- !.
run_sample([V|S], W0, [P|Pos], Tot) :- gen_cases(I, Sz, [I|Cases]) :-
{clpbn:get_atts(V, [evidence(P)]) }, !, I1 is I+1,
run_sample(S, W0, Pos, Tot). gen_cases(I1, Sz, Cases).
run_sample([V|S], W0, [P|Pos], Tot) :-
{clpbn_display:get_atts(V, [posterior,(_,_,Ps,_)]) }, uncompact_cases(CompactCases, Cases) :-
count_cases(Ps, 0, D0, P), findall(Case, is_case(CompactCases, Case), Cases).
W1 is D0*W0,
run_sample(S, W1, Pos, Tot). is_case([], []).
is_case([A|CompactCases], [A|Case]) :-
count_cases([D0|Ps], I0, D0, I0). integer(A), !,
count_cases([_|Ps], I0, P, W1) :- is_case(CompactCases, Case).
I is I0+1, is_case([L|CompactCases], [C|Case]) :-
count_cases(Ps, I, P, W1). member(C, L),
is_case(CompactCases, Case).
% group(+SortedDists, -Ids, -Id-Instances pairs)//
% Partitions a list of i(Id,Vars,Cases,Hiddens) records into runs sharing
% the same Id (the run is collected by same_id//4).  The DCG output list
% receives the Hiddens list of every record that has at least one hidden
% variable; records with Hiddens = [] emit nothing.
group([], [], []) --> [].
group([i(Id,Vars,Cases,Hiddens)|Dists], [Id|Ids], [Id-[i(Id,Vars,Cases,Hiddens)|SameId]|Groups]) -->
	% only instances with hidden variables contribute to the output list
	( { Hiddens = [] } -> [] ; [Hiddens] ),
	same_id(Dists, Id, SameId, Others),
	group(Others, Ids, Groups).
% same_id(+Dists, +Id, -SameId, -Rest)//
% Takes the leading records of Dists whose first field is Id and returns
% them in SameId; Rest is what remains from the first record with another
% id onwards.  As a DCG it also emits the Hiddens list of each collected
% record whose Hiddens is not [].
same_id([i(Id,Vars,Cases,Hiddens)|Dists], Id, [i(Id,Vars,Cases,Hiddens)|SameId], Rest) --> !,
	( { Hiddens = [] } -> [] ; [Hiddens] ),
	same_id(Dists, Id, SameId, Rest).
same_id(Dists, _, [], Dists) --> [].
% estimate(+State, -LPs)
% E-step: hands the variables to marginalise (4th field of the state) and
% the full variable list (1st field) to clpbn_run_solver/3; LPs is whatever
% the solver returns for them.
estimate(State, LPs) :-
	State = state(AllVars, _, _, MargVars),
	clpbn_run_solver(MargVars, AllVars, LPs).
% maximise(+State, -Tables, +LPs, -Likelihood)
% M-step: recomputes one table per distribution from the grouped
% instances kept in the 3rd field of the state, starting the likelihood
% accumulator at 0.0.
maximise(State, Tables, LPs, Likelihood) :-
	State = state(_, _, Instances, _),
	compute_parameters(Instances, Tables, LPs, 0.0, Likelihood).
% compute_parameters(+DistInstances, -Tables, +LPs, +Lik0, -Lik)
% For every Id-Instances group: obtain a table from empty_dist/2
% (presumably zero-initialised -- confirm in clpbn/dists), accumulate the
% instances into it, normalise it, install the result with
% dist_new_table/2 and add this table's contribution to the likelihood.
% LPs is threaded through add_samples/4, which consumes one entry per
% instance that has hidden variables.
compute_parameters([], [], [], Lik, Lik).
compute_parameters([Id-Instances|Groups], [Id-Probs|Tables], LPs0, LikIn, LikOut) :-
	empty_dist(Id, Counts),
	add_samples(Instances, Counts, LPs0, LPs1),
	normalise_counts(Counts, Probs),
	compute_likelihood(Counts, Probs, Delta),
	dist_new_table(Id, Probs),
	LikMid is LikIn+Delta,
	compute_parameters(Groups, Tables, LPs1, LikMid, LikOut).
% add_samples(+Instances, +Table, +LPs0, -LPs)
% Adds every instance of one distribution to Table (updated through
% matrix_add/3).  An instance with a single case and no hidden variables
% counts 1.0 at that case and uses no probabilities; any other instance
% takes the next probability list from LPs0 and spreads it over its cases.
add_samples([], _, LPs, LPs).
add_samples([i(_,_,[Observed],[])|Instances], Counts, LPs0, LPs) :- !,
	matrix_add(Counts, Observed, 1.0),
	add_samples(Instances, Counts, LPs0, LPs).
add_samples([i(_,_,Candidates,_)|Instances], Counts, [Weights|LPs0], LPs) :-
	run_sample(Candidates, Weights, Counts),
	add_samples(Instances, Counts, LPs0, LPs).
% run_sample(+Cases, +Weights, +Table)
% Walks Cases and Weights in lockstep, adding each weight to Table at the
% position named by the corresponding case (via matrix_add/3).
run_sample([], [], _).
run_sample([Position|Positions], [Weight|Weights], Counts) :-
	matrix_add(Counts, Position, Weight),
	run_sample(Positions, Weights, Counts).

View File

@ -0,0 +1,34 @@
% learn distribution for school database.
:- [pos:train].
:- ['~/Yap/work/CLPBN/clpbn/examples/School/school_32'].
:- ['~/Yap/work/CLPBN/learning/em'].
% main
% Collects every query produced by goal/1, runs em/5 on them with a
% maximum error of 0.1 and at most 10 iterations, and prints the final
% likelihood together with the learned tables.
main :-
	findall(Query, goal(Query), Queries),
	em(Queries, 0.1, 10, Tables, Likelihood),
	writeln(Likelihood:Tables).
% maybe_observe(+MissRate, +Value, -Observed)
% Copies Value into Observed when a fresh random draw exceeds MissRate and
% leaves Observed unbound otherwise (assumes the evaluable `random' yields
% a uniform float in [0,1) -- confirm for the Prolog system in use).
maybe_observe(MissRate, Value, Observed) :-
	( random > MissRate -> Observed = Value ; true ).

% goal(-Query)
% Enumerates the training queries given to em/5.  Every fact of module pos
% becomes one query; its value is either kept as evidence or left unbound
% so that it has to be estimated.

% professor attributes: miss 30% of the examples.
goal(professor_ability(P,V)) :-
	pos:professor_ability(P,V1),
	maybe_observe(0.3, V1, V).
goal(professor_popularity(P,V)) :-
	pos:professor_popularity(P,V1),
	maybe_observe(0.3, V1, V).
% remaining attributes: miss 10% of the examples.
goal(registration_grade(P,V)) :-
	pos:registration_grade(P,V1),
	maybe_observe(0.1, V1, V).
goal(student_intelligence(P,V)) :-
	pos:student_intelligence(P,V1),
	maybe_observe(0.1, V1, V).
goal(course_difficulty(P,V)) :-
	pos:course_difficulty(P,V1),
	maybe_observe(0.1, V1, V).
/*
goal(registration_satisfaction(P,V)) :-
pos:registration_satisfaction(P,V).
*/

File diff suppressed because it is too large Load Diff

View File

@ -2,13 +2,31 @@
% Utilities for learning % Utilities for learning
% %
:- module(bnt_learn_utils, [run_all/1, :- module(clpbn_learn_utils, [run_all/1,
clpbn_vars/2]). clpbn_vars/2,
normalise_counts/2,
compute_likelihood/3]).
:- use_module(library(matrix),
[matrix_agg_lines/3,
matrix_op_to_lines/4,
matrix_to_logs/2,
matrix_op/4,
matrix_sum/2]).
:- meta_predicate run_all(:).
run_all([]). run_all([]).
run_all([G|Gs]) :- run_all([G|Gs]) :-
call(user:G), call(G),
run_all(Gs). run_all(Gs).
% run_all(+Module:Goals)
% Module-qualified entry: strips the qualifier and runs the goals in it.
run_all(Module:Goals) :-
	run_all(Goals, Module).

% run_all(+Goals, +Module)
% Calls every goal of the list, in order, inside Module; fails as soon as
% one of them fails.
run_all([], _).
run_all([Goal|Goals], Module) :-
	call(Module:Goal),
	run_all(Goals, Module).
clpbn_vars(Vs,BVars) :- clpbn_vars(Vs,BVars) :-
get_clpbn_vars(Vs,CVs), get_clpbn_vars(Vs,CVs),
@ -31,5 +49,15 @@ get_var_has_same_key([K-V|KVs],K,V,KVs0) :- !,
get_var_has_same_key(KVs,K,V,KVs0). get_var_has_same_key(KVs,K,V,KVs0).
get_var_has_same_key(KVs,_,_,KVs). get_var_has_same_key(KVs,_,_,KVs).
% normalise_counts(+Counts, -Normalised)
% Aggregates Counts with + along its lines (matrix_agg_lines/3) and then
% divides Counts by those totals (matrix_op_to_lines/4), giving a new
% matrix Normalised.
normalise_counts(Counts, Normalised) :-
	matrix_agg_lines(Counts, +, Totals),
	matrix_op_to_lines(Counts, Totals, /, Normalised).
% compute_likelihood(+Table0, +NewTable, -DeltaLik)
% DeltaLik is the sum, over all cells, of the count in Table0 times the
% logarithm of the matching entry of NewTable, i.e. the contribution of
% one distribution to the log-likelihood.
% NOTE(review): a zero entry in NewTable presumably yields a non-finite
% logarithm; confirm how matrix_to_logs/2 treats zeros.
compute_likelihood(Table0, NewTable, DeltaLik) :-
	matrix_to_logs(NewTable, Logs),
	% the result of matrix_op/4 is an output argument, so it must be a
	% fresh variable and not the already bound Logs
	matrix_op(Table0, Logs, *, WeightedLogs),
	matrix_sum(WeightedLogs, DeltaLik).

View File

@ -1,12 +1,32 @@
% %
% Maximum likelihood estimator and friends.
%
%
% This assumes we have a single big example. % This assumes we have a single big example.
% %
:- use_module(library('clpbn_learning/utils'), :- module(clpbn_mle, [learn_parameters/2,
[run_all/1, learn_parameters/3,
clpbn_vars/2]). parameters_from_evidence/3]).
:- module(bnt_mle, [learn_parameters/2]). :- use_module(library('clpbn')).
:- use_module(library('clpbn/learning/learn_utils'),
[run_all/1,
clpbn_vars/2,
normalise_counts/2]).
:- use_module(library('clpbn/dists'),
[empty_dist/2,
dist_new_table/2]).
:- use_module(library(matrix),
[matrix_inc/2,
matrix_op_to_all/4]).
% learn_parameters(+Items, -Tables)
% Convenience entry point: same as learn_parameters/3 with an empty list
% of extra options.
learn_parameters(Items, Tables) :-
	NoExtras = [],
	learn_parameters(Items, Tables, NoExtras).
% %
% full evidence learning % full evidence learning
@ -14,16 +34,26 @@
learn_parameters(Items, Tables, Extras) :- learn_parameters(Items, Tables, Extras) :-
run_all(Items), run_all(Items),
attributes:all_attvars(AVars), attributes:all_attvars(AVars),
% sort and incorporte evidence % sort and incorporate evidence
clpbn_vars(AVars, AllVars), clpbn_vars(AVars, AllVars),
mk_sample(AllVars, Sample), mk_sample(AllVars, Sample),
compute_tables(Extras, Sample, Tables). compute_tables(Extras, Sample, Tables).
mk_sample(AllVars, NVars, LL) :- parameters_from_evidence(AllVars, Sample, Extras) :-
add2sample(AllVars, Sample), mk_sample_from_evidence(AllVars, Sample),
msort(Sample, AddL), compute_tables(Extras, Sample, Tables).
compute_params(AddL, Parms).
% mk_sample_from_evidence(+AllVars, -SortedSample)
% Builds the sample with add_evidence2sample/2 and orders it with msort/2,
% which keeps duplicate entries.
mk_sample_from_evidence(AllVars, SortedSample) :-
	add_evidence2sample(AllVars, Unsorted),
	msort(Unsorted, SortedSample).
% mk_sample(+AllVars, -SortedSample)
% Builds the sample with add2sample/2 and orders it with msort/2, which
% keeps duplicate entries.
mk_sample(AllVars, SortedSample) :-
	add2sample(AllVars, Unsorted),
	msort(Unsorted, SortedSample).
%
% assumes we have full data, meaning evidence for every variable
%
add2sample([], []). add2sample([], []).
add2sample([V|Vs],[val(Id,[Ev|EParents])|Vals]) :- add2sample([V|Vs],[val(Id,[Ev|EParents])|Vals]) :-
clpbn:get_atts(V, [evidence(Ev),dist(Id,Parents)]), clpbn:get_atts(V, [evidence(Ev),dist(Id,Parents)]),
@ -31,12 +61,72 @@ add2sample([V|Vs],[val(Id,[Ev|EParents])|Vals]) :-
add2sample(Vs, Vals). add2sample(Vs, Vals).
get_eparents([P|Parents], [E|EParents]) :- get_eparents([P|Parents], [E|EParents]) :-
clpbn:get_atts(V, [evidence(Ev)]), clpbn:get_atts(P, [evidence(E)]),
get_eparents(Parents, EParents). get_eparents(Parents, EParents).
get_eparents([], []). get_eparents([], []).
compute_tables([], Sample, Tables) :- %
mle(Sample, Tables). % assumes we ignore variables without evidence or without evidence
compute_tables([laplace|_], Sample, Tables) :- % on a parent!
laplace(Sample, Tables). %
% add_evidence2sample(+Vars, -Sample)
% Produces one val(Id,[Ev|ParentEvs]) entry for each variable that carries
% both an evidence and a dist attribute and whose parents all have
% evidence; every other variable is silently dropped.
add_evidence2sample([], []).
add_evidence2sample([Var|Vars], Sample) :-
	(   clpbn:get_atts(Var, [evidence(Ev),dist(Id,Parents)]),
	    get_eveparents(Parents, ParentEvs)
	->  Sample = [val(Id,[Ev|ParentEvs])|Entries]
	;   Sample = Entries
	),
	add_evidence2sample(Vars, Entries).
% get_eveparents(+Parents, -Evidences)
% Evidences holds the evidence attribute of every parent, in order.  Fails
% as soon as a parent has no evidence attribute, which is how the caller
% detects a variable that must be skipped.
get_eveparents([P|Parents], [E|EParents]) :-
	clpbn:get_atts(P, [evidence(E)]),
	% recurse on this predicate itself rather than on get_eparents/2
	get_eveparents(Parents, EParents).
get_eveparents([], []).
% compute_tables(+Parameters, +Sample, -NewTables)
% Two stages: estimator/2 turns the sample into one table per
% distribution, then add_priors/3 applies whatever prior Parameters asks
% for and normalises the result.
compute_tables(Parameters, Sample, NewTables) :-
	estimator(Sample, CountTables),
	add_priors(Parameters, CountTables, NewTables).
% estimator(+Samples, -Tables)
% Consumes the val(Id,Sample) list one run of equal ids at a time (the
% run is split off by id_samples/4, so equal ids must be adjacent).  For
% each run a table is obtained from empty_dist/2, filled by mle/2 and
% installed with dist_new_table/2.
estimator([], []).
estimator([val(Id,First)|Others], [Counts|Tables]) :-
	empty_dist(Id, Counts),
	id_samples(Id, Others, SameId, Remaining),
	mle([First|SameId], Counts),
	% replace matrix in distribution
	dist_new_table(Id, Counts),
	estimator(Remaining, Tables).
% id_samples(+Id, +Samples, -IdSamples, -MoreSamples)
% Splits off the leading val(Id,Sample) entries of Samples: IdSamples gets
% their Sample parts and MoreSamples is the remainder, starting at the
% first entry with another id (or [] when the list is exhausted).
% The last clause also covers the empty list, so no separate [] clause is
% kept: having both produced the same answer twice on backtracking.
id_samples(Id, [val(Id,Sample)|Samples], [Sample|IdSamples], MoreSamples) :- !,
	id_samples(Id, Samples, IdSamples, MoreSamples).
id_samples(_, Samples, [], Samples).
% mle(+Samples, +Table)
% Calls matrix_inc/2 on Table once per sample, using the sample as the
% position to increment.
mle([Position|Positions], Counts) :-
	matrix_inc(Counts, Position),
	mle(Positions, Counts).
mle([], _).
% add_priors(+Options, +Tables, -NewTables)
% Scans Options for the first prior it knows about and applies it before
% normalising:
%   laplace        -> laplace/2
%   m_estimate(M)  -> add_mestimate/3
% Unknown options are skipped; with no recognised option the tables are
% only normalised.
% NOTE(review): add_mestimate/3 is not defined in the visible part of this
% file -- confirm it exists.
add_priors([], Counts, NewTables) :-
	normalise(Counts, NewTables).
add_priors([laplace|_], Counts, NewTables) :- !,
	laplace(Counts, Smoothed),
	normalise(Smoothed, NewTables).
add_priors([m_estimate(M)|_], Counts, NewTables) :- !,
	add_mestimate(Counts, M, Smoothed),
	normalise(Smoothed, NewTables).
add_priors([_|Options], Counts, NewTables) :-
	add_priors(Options, Counts, NewTables).
% normalise(+Tables, -NewTables)
% Applies normalise_counts/2 to every table of the list, in order.
normalise([], []).
normalise([Counts|CountTables], [Probs|ProbTables]) :-
	normalise_counts(Counts, Probs),
	normalise(CountTables, ProbTables).
% laplace(+Tables, -NewTables)
% Adds 1 to every cell of every table (matrix_op_to_all/4 with +), the
% add-one correction selected by the laplace option.
laplace([], []).
laplace([Counts|CountTables], [Smoothed|SmoothedTables]) :-
	matrix_op_to_all(Counts, +, 1, Smoothed),
	laplace(CountTables, SmoothedTables).