122 lines
2.7 KiB
Prolog
122 lines
2.7 KiB
Prolog
%
|
|
% Learn CLP(BN) network parameters with EM, using the MATLAB Bayes Net Toolbox (BNT).
|
|
%
|
|
|
|
:- yap_flag(unknown,error).
|
|
|
|
:- style_check(all).
|
|
|
|
:- module(bnt_parameters, [learn_parameters/2]).
|
|
|
|
:- use_module(library('clpbn'),
|
|
[clpbn_flag/3]).
|
|
|
|
:- use_module(library('clpbn/bnt'),
|
|
[create_bnt_graph/2]).
|
|
|
|
:- use_module(library('clpbn/display'),
|
|
[clpbn_bind_vals/3]).
|
|
|
|
:- use_module(library('clpbn/dists'),
|
|
[get_dist_domain/2]).
|
|
|
|
:- use_module(library(matlab),
|
|
[matlab_initialized_cells/4,
|
|
matlab_call/2,
|
|
matlab_get_variable/2
|
|
]).
|
|
|
|
% bnt_em_max_iter(?MaxIters)
%
% Maximum number of EM iterations handed to BNT's learn_params_em by
% bnt_learn_parameters/2. Declared dynamic so the limit can be changed at
% run time (e.g. with retract/1 and assert/1).
:- dynamic(bnt_em_max_iter/1).

bnt_em_max_iter(10).
|
|
|
|
|
|
% syntactic sugar for matlab_call.
|
|
:- op(800,yfx,<--).
|
|
|
|
G <-- Y :-
|
|
matlab_call(Y,G).
|
|
|
|
|
|
% learn_parameters(+Items, -Tables)
%
% Learn the parameters of a CLP(BN) network with BNT's EM algorithm.
%
% Items is a list of goals, run in module user, that set up the network and
% its evidence. While the BNT graph is built and EM runs, the clpbn flag
% `solver` is temporarily set to bnt and `bnt_model` to tied. Tables
% receives one learned CPT per representative node returned by
% create_bnt_graph/2, in the same order.
%
% Both flags are put back to their previous values whether learning
% succeeds, fails or raises an exception. Otherwise a failure would leave
% the global solver switched to bnt for every later query. A failure can
% come from evidence that is not in a variable's domain, or from a MATLAB
% error. Exceptions are re-thrown once the flags have been restored.
%
% Only the first solution of the learning step is kept, because that step
% is driven by side-effecting MATLAB calls.
learn_parameters(Items, Tables) :-
    run_all(Items),
    clpbn_flag(solver, OldSolver, bnt),
    clpbn_flag(bnt_model, OldModel, tied),
    (   catch(learn_tables_with_bnt(Tables), Error, true)
    ->  restore_bnt_flags(OldSolver, OldModel),
        (   var(Error)
        ->  true
        ;   throw(Error)
        )
    ;   restore_bnt_flags(OldSolver, OldModel),
        fail
    ).

% learn_tables_with_bnt(-Tables)
%
% Core of learn_parameters/2, run while the BNT flags are in effect. It
% proceeds in these steps:
%   1. Collect all attributed variables.
%   2. Keep one CLP(BN) variable per key.
%   3. Build the BNT graph.
%   4. Export the evidence as the MATLAB cell array `sample`.
%   5. Run EM.
%   6. Read the learned CPTs back into Prolog.
learn_tables_with_bnt(Tables) :-
    attributes:all_attvars(AVars),
    % sort by key, merge duplicates, and incorporate evidence
    clpbn_vars(AVars, AllVars),
    length(AllVars, NVars),
    create_bnt_graph(AllVars, Reps),
    mk_sample(AllVars, NVars, EvVars),
    bnt_learn_parameters(NVars, EvVars),
    get_parameters(Reps, Tables).

% restore_bnt_flags(+OldSolver, +OldModel)
%
% Put the `solver` and `bnt_model` flags back to the values saved by
% learn_parameters/2, whatever their current values are.
restore_bnt_flags(OldSolver, OldModel) :-
    clpbn_flag(solver, _, OldSolver),
    clpbn_flag(bnt_model, _, OldModel).
|
|
|
|
% run_all(+Goals)
%
% Call every goal of Goals, left to right, in module user. Bindings made by
% one goal are visible to the goals after it. The whole call fails as soon
% as one goal fails.
run_all([]).
run_all([Goal|Rest]) :-
    call(user:Goal),
    run_all(Rest).
|
|
|
|
% clpbn_vars(+AttVars, -Vars)
%
% From a list of attributed variables:
%   - select those that carry a CLP(BN) key,
%   - order them by key with keysort/2,
%   - keep a single variable per key (see merge_vars/2).
clpbn_vars(AttVars, Vars) :-
    get_clpbn_vars(AttVars, Keyed),
    keysort(Keyed, ByKey),
    merge_vars(ByKey, Vars).
|
|
|
|
% get_clpbn_vars(+Vars, -KeyedVars)
%
% Pair every variable in Vars that has a CLP(BN) key attribute with that key
% as Key-Var. Variables without a key attribute are left out. Input order is
% preserved.
get_clpbn_vars([], []).
get_clpbn_vars([Var|Vars], Keyed) :-
    (   clpbn:get_atts(Var, [key(Key)])
    ->  Keyed = [Key-Var|Rest]
    ;   Keyed = Rest
    ),
    get_clpbn_vars(Vars, Rest).
|
|
|
|
% merge_vars(+SortedPairs, -Vars)
%
% SortedPairs is a keysorted list of Key-Var pairs. Vars gets one variable
% per run of equal keys. Every later variable in a run is unified with the
% first one of that run.
merge_vars([], []).
merge_vars([Key-Var|Pairs], [Var|Vars]) :-
    get_var_has_same_key(Pairs, Key, Var, Remaining),
    merge_vars(Remaining, Vars).

% get_var_has_same_key(+Pairs, +Key, ?Var, -Remaining)
%
% Drop the leading pairs of Pairs whose key unifies with Key, unifying each
% of their variables with Var, and return what is left of the list.
% Unifying attributed variables may trigger attribute hooks. If such a
% unification fails, the first clause does not apply, so the run stops at
% that pair and it is kept as a separate entry.
get_var_has_same_key([Key-Var|Pairs], Key, Var, Remaining) :- !,
    get_var_has_same_key(Pairs, Key, Var, Remaining).
get_var_has_same_key(Pairs, _, _, Pairs).
|
|
|
|
|
|
% mk_sample(+Vars, +NVars, -NEvidence)
%
% Create the MATLAB variable `sample` from the evidence entries that
% add2sample/2 computes for Vars. It is presumably an NVars-by-1 cell array,
% going by the (NVars, 1) arguments to matlab_initialized_cells/4.
% NEvidence is the number of evidence entries written.
mk_sample(Vars, NVars, NEvidence) :-
    add2sample(Vars, Cells),
    length(Cells, NEvidence),
    matlab_initialized_cells(NVars, 1, Cells, sample).
|
|
|
|
% add2sample(+Vars, -Cells)
%
% Build the entries of the MATLAB sample cell array. This covers every
% variable in Vars that carries both a CLP(BN) evidence attribute and a dist
% attribute. Each such variable yields val(NodeId, 1, Index), where:
%   - NodeId is the variable's BNT node id;
%   - Index is the 1-based position of the observed value in the domain of
%     its distribution.
% Other variables are skipped. Fails if the evidence value is not in the
% domain.
add2sample([], []).
add2sample([Var|Vars], Cells) :-
    (   clpbn:get_atts(Var, [evidence(Ev), dist(DistId, _)])
    ->  bnt:get_atts(Var, [bnt_id(NodeId)]),
        get_dist_domain(DistId, Domain),
        evidence_val(Ev, 1, Domain, Index),
        Cells = [val(NodeId, 1, Index)|Rest]
    ;   Cells = Rest
    ),
    add2sample(Vars, Rest).
|
|
|
|
% evidence_val(+Ev, +Start, +Domain, -Index)
%
% Index is the position of the first element of Domain that unifies with
% Ev, counting from Start (1 gives a 1-based index). Fails if no element of
% Domain unifies with Ev.
evidence_val(Ev, Pos, [Value|Values], Index) :-
    (   Value = Ev,
        Index = Pos
    ->  true
    ;   Next is Pos + 1,
        evidence_val(Ev, Next, Values, Index)
    ).
|
|
|
|
% bnt_learn_parameters(+NVars, +NEvidence)
%
% Run BNT's EM parameter learner on the MATLAB network `bnet`. The data is
% the cell array `sample` built by mk_sample/3. Results are left in two
% MATLAB variables:
%   - `new_bnet`, from which get_new_table/2 reads the learned CPDs;
%   - `trace`, presumably EM's log-likelihood trace (TODO confirm).
% Both Prolog arguments are currently ignored.
%
% A junction-tree inference engine is used. Alternative BNT engines that
% could be swapped in here are var_elim_inf_engine, gibbs_sampling_inf_engine,
% belprop_inf_engine and pearl_inf_engine.
bnt_learn_parameters(_NVars, _NEvidence) :-
    engine <-- jtree_inf_engine(bnet),
    % EM iteration cap; adjustable through the dynamic bnt_em_max_iter/1.
    bnt_em_max_iter(MaxIterations),
    [new_bnet, trace] <-- learn_params_em(engine, sample, MaxIterations).
|
|
|
|
|
|
% get_parameters(+Reps, -CPTs)
%
% Map every Node-v(_,_,_) entry of Reps to the CPT that EM learned for
% Node. Reps is presumably the representative list returned by
% create_bnt_graph/2. CPTs is in the same order as Reps.
get_parameters([], []).
get_parameters([Node-v(_, _, _)|Nodes], [Table|Tables]) :-
    get_new_table(Node, Table),
    get_parameters(Nodes, Tables).
|
|
|
|
% get_new_table(+Node, -Table)
%
% Fetch the learned CPT of BNT node Node from the MATLAB network new_bnet.
% This is done in two steps:
%   1. The node's CPD object (new_bnet.CPD{Node}) is converted with struct()
%      into the MATLAB variable `s`.
%   2. The CPT field of `s` is read back into Prolog.
get_new_table(Node, Table) :-
    s <-- struct(new_bnet.'CPD'({Node})),
    matlab_get_variable(s.'CPT', Table).
|
|
|
|
|