%
% Learn parameters using the BNT toolkit
%

:- yap_flag(unknown,error).

:- style_check(all).

:- module(bnt_parameters, [learn_parameters/2]).

:- use_module(library('clpbn'), [
	clpbn_flag/3]).

:- use_module(library('clpbn/bnt'), [
	create_bnt_graph/2]).

:- use_module(library('clpbn/display'), [
	clpbn_bind_vals/3]).

:- use_module(library('clpbn/dists'), [
				       get_dist_domain/2
				      ]).

:- use_module(library(matlab), [matlab_initialized_cells/4,
				matlab_call/2,
				matlab_get_variable/2
			      ]).

:- dynamic bnt_em_max_iter/1.
bnt_em_max_iter(10).


% syntactic sugar for matlab_call.
:- op(800,yfx,<--).

G <-- Y :-
	matlab_call(Y,G).


learn_parameters(Items, Tables) :-
	run_all(Items),
	clpbn_flag(solver, OldSolver, bnt),
	clpbn_flag(bnt_model, Old, tied),
	attributes:all_attvars(AVars),
	% sort and incorporte evidence
	clpbn_vars(AVars, AllVars),
	length(AllVars,NVars),
	create_bnt_graph(AllVars, Reps),
	mk_sample(AllVars,NVars,EvVars),
	bnt_learn_parameters(NVars,EvVars),
	get_parameters(Reps, Tables),
	clpbn_flag(solver, bnt, OldSolver),
	clpbn_flag(bnt_model, tied, Old).

run_all([]).
run_all([G|Gs]) :-
	call(user:G),
	run_all(Gs).

clpbn_vars(Vs,BVars) :-
	get_clpbn_vars(Vs,CVs),
	keysort(CVs,KVs),
	merge_vars(KVs,BVars).
	
get_clpbn_vars([],[]).
get_clpbn_vars([V|GVars],[K-V|CLPBNGVars]) :-
	clpbn:get_atts(V, [key(K)]), !,
	get_clpbn_vars(GVars,CLPBNGVars).
get_clpbn_vars([_|GVars],CLPBNGVars) :-
	get_clpbn_vars(GVars,CLPBNGVars).

merge_vars([],[]).
merge_vars([K-V|KVs],[V|BVars]) :-
	get_var_has_same_key(KVs,K,V,KVs0),
	merge_vars(KVs0,BVars).
	
get_var_has_same_key([K-V|KVs],K,V,KVs0)  :- !,
	get_var_has_same_key(KVs,K,V,KVs0).
get_var_has_same_key(KVs,_,_,KVs).


mk_sample(AllVars,NVars, LL) :-
	add2sample(AllVars, LN),
	length(LN,LL),
	matlab_initialized_cells( NVars, 1, LN, sample).

add2sample([],  []).
add2sample([V|Vs],[val(VId,1,Val)|Vals]) :-
	clpbn:get_atts(V, [evidence(Ev),dist(Id,_)]), !,
	bnt:get_atts(V,[bnt_id(VId)]),
	get_dist_domain(Id, Domain),
	evidence_val(Ev,1,Domain,Val),
	add2sample(Vs, Vals).
add2sample([_V|Vs],Vals) :-
	add2sample(Vs, Vals).

evidence_val(Ev,Val,[Ev|_],Val) :- !.
evidence_val(Ev,I0,[_|Domain],Val) :-
	I1 is I0+1,
	evidence_val(Ev,I1,Domain,Val).

bnt_learn_parameters(_,_) :-
	engine <-- jtree_inf_engine(bnet),
%	engine <-- var_elim_inf_engine(bnet),
%	engine <-- gibbs_sampling_inf_engine(bnet),
%	engine <-- belprop_inf_engine(bnet),
%	engine <-- pearl_inf_engine(bnet),
	bnt_em_max_iter(MaxIters),
	[new_bnet, trace] <-- learn_params_em(engine, sample, MaxIters).


get_parameters([],[]).
get_parameters([Rep-v(_,_,_)|Reps],[CPT|CPTs]) :-
	get_new_table(Rep,CPT),
	get_parameters(Reps,CPTs).
	
get_new_table(Rep,CPT) :-
	s <-- struct(new_bnet.'CPD'({Rep})),
	matlab_get_variable( s.'CPT', CPT).