122 lines
		
	
	
		
			2.8 KiB
		
	
	
	
		
			Plaintext
		
	
	
	
	
	
		
		
			
		
	
	
			122 lines
		
	
	
		
			2.8 KiB
		
	
	
	
		
			Plaintext
		
	
	
	
	
	
%
% Learn the parameters (conditional probability tables) of a CLP(BN)
% network using the MATLAB Bayes Net Toolbox (BNT).
%

% Calls to undefined predicates raise an error instead of failing silently.
% NOTE(review): this flag may apply beyond this module (it is executed
% before module/2 below) -- confirm that is intended.
:- yap_flag(unknown,error).

% Turn on all compile-time style warnings (e.g. singleton variables).
:- style_check(all).

% Only learn_parameters/2 is exported; every other predicate is internal.
% NOTE(review): module/2 is conventionally the first term of a module file;
% here it follows the two directives above -- verify with the target YAP.
:- module(bnt_parameters, [learn_parameters/2]).
|  | 
 | ||
|  | :- use_module(library('clpbn'), [ | ||
|  | 	clpbn_flag/3]). | ||
|  | 
 | ||
|  | :- use_module(library('clpbn/bnt'), [ | ||
|  | 	create_bnt_graph/2]). | ||
|  | 
 | ||
|  | :- use_module(library('clpbn/display'), [ | ||
|  | 	clpbn_bind_vals/3]). | ||
|  | 
 | ||
|  | :- use_module(library('clpbn/dists'), [ | ||
|  | 				       get_dist_domain/2 | ||
|  | 				      ]). | ||
|  | 
 | ||
|  | :- use_module(library(matlab), [matlab_initialized_cells/4, | ||
|  | 				matlab_call/2, | ||
|  | 				matlab_get_variable/2 | ||
|  | 			      ]). | ||
|  | 
 | ||
% bnt_em_max_iter(?MaxIterations)
%
% Upper bound on the number of EM iterations passed to BNT's
% learn_params_em (see bnt_learn_parameters/2). Declared dynamic, so the
% limit can be replaced at run time with retract/assert.
:- dynamic(bnt_em_max_iter/1).

bnt_em_max_iter(10).
|  | 
 | ||
|  | 
 | ||
% Syntactic sugar for matlab_call/2: `Target <-- Expr` hands Expr and Target
% to matlab_call/2. In this file Target is a MATLAB variable name
% (e.g. engine) or a list of names (e.g. [new_bnet, trace]) that later
% calls refer to.
:- op(800, yfx, <--).

Target <-- Expr :-
	matlab_call(Expr, Target).
|  | 
 | ||
|  | 
 | ||
% learn_parameters(+Items, -Tables)
%
% Learn the conditional probability tables of a CLP(BN) network with BNT.
%
%   Items  - list of goals, each called in module user (run_all/1), that
%            set up the network and its evidence.
%   Tables - one CPT per representative returned by create_bnt_graph/2,
%            in the same order.
%
% While learning, the CLP(BN) flags solver and bnt_model are switched to
% bnt and tied. They are restored afterwards on success, on failure and when
% an exception is raised, so a failing MATLAB/BNT call does not leave the
% global solver set to bnt. The learning pipeline is committed to its first
% solution; re-running EM on backtracking is never wanted.
learn_parameters(Items, Tables) :-
	run_all(Items),
	clpbn_flag(solver, OldSolver, bnt),
	clpbn_flag(bnt_model, OldModel, tied),
	(   catch(bnt_learn_tables(Tables), Error, true)
	->  restore_bnt_flags(OldSolver, OldModel),
	    (   var(Error)
	    ->  true
	    ;   throw(Error)	% re-raise after the flags are restored
	    )
	;   restore_bnt_flags(OldSolver, OldModel),
	    fail
	).

% bnt_learn_tables(-Tables)
%
% Build the BNT graph from every CLP(BN) attributed variable, load the
% evidence into MATLAB, run EM and read back the learned CPTs.
bnt_learn_tables(Tables) :-
	attributes:all_attvars(AVars),
	% sort and incorporate evidence
	clpbn_vars(AVars, AllVars),
	length(AllVars, NVars),
	create_bnt_graph(AllVars, Reps),
	mk_sample(AllVars, NVars, EvVars),
	bnt_learn_parameters(NVars, EvVars),
	get_parameters(Reps, Tables).

% restore_bnt_flags(+OldSolver, +OldModel)
%
% Undo the flag changes made by learn_parameters/2.
restore_bnt_flags(OldSolver, OldModel) :-
	clpbn_flag(solver, bnt, OldSolver),
	clpbn_flag(bnt_model, tied, OldModel).
|  | 
 | ||
% run_all(+Goals)
%
% Call each goal of the list, in order, in module user. Bindings made by
% earlier goals are visible to later ones and to the caller.
run_all([]).
run_all([Goal|Pending]) :-
	Qualified = user:Goal,
	call(Qualified),
	run_all(Pending).
|  | 
 | ||
% clpbn_vars(+AttVars, -BVars)
%
% Keep the attributed variables that carry a CLP(BN) key, order them by
% key, and merge variables sharing a key, giving one variable per key in
% key order.
clpbn_vars(AttVars, BVars) :-
	get_clpbn_vars(AttVars, KeyedVars),
	keysort(KeyedVars, ByKey),
	merge_vars(ByKey, BVars).
|  | 	 | ||
% get_clpbn_vars(+Vars, -Pairs)
%
% Pairs holds Key-Var for every Var in Vars that has a CLP(BN) key(Key)
% attribute, in input order. Variables without a key are dropped.
get_clpbn_vars([], []).
get_clpbn_vars([Var|Vars], Pairs) :-
	(   clpbn:get_atts(Var, [key(Key)])
	->  Pairs = [Key-Var|Rest]
	;   Pairs = Rest
	),
	get_clpbn_vars(Vars, Rest).
|  | 
 | ||
% merge_vars(+SortedPairs, -Vars)
%
% Walk a key-sorted Key-Var list and emit one variable per run of equal
% keys. The later variables of each run are unified with the first one by
% get_var_has_same_key/4.
merge_vars([], []).
merge_vars([Pair|Pairs], [Var|Vars]) :-
	Pair = Key-Var,
	get_var_has_same_key(Pairs, Key, Var, Remaining),
	merge_vars(Remaining, Vars).
|  | 	 | ||
% get_var_has_same_key(+KVs, +K, ?V, -Rest)
%
% Skip the leading pairs of KVs (presumably key-sorted, as produced for
% merge_vars/2) whose key unifies with K, unifying each of their variables
% with V so that all variables sharing key K become one. Rest is the list
% left after those pairs.
% NOTE(review): if unifying a later variable with V fails in the head, the
% second clause is taken and that pair stays in Rest, so merge_vars/2 emits
% it as a separate variable. Attribute hooks triggered by the head
% unification may run after the cut. Clause order and head unification are
% significant here; change with care.
get_var_has_same_key([K-V|KVs],K,V,KVs0)  :- !,
	get_var_has_same_key(KVs,K,V,KVs0).
get_var_has_same_key(KVs,_,_,KVs).
|  | 
 | ||
|  | 
 | ||
% mk_sample(+AllVars, +NVars, -NEvidence)
%
% Create the MATLAB cell array `sample` (presumably NVars x 1) holding, at
% entry (bnt_id, 1), the domain index of every observed variable.
% NEvidence is the number of observed variables.
mk_sample(AllVars, NVars, NEvidence) :-
	add2sample(AllVars, Cells),
	length(Cells, NEvidence),
	matlab_initialized_cells(NVars, 1, Cells, sample).
|  | 
 | ||
% add2sample(+Vars, -Cells)
%
% For every variable that has both evidence(Ev) and dist(Id,_) attributes,
% produce val(BntId, 1, Index). BntId is its bnt_id attribute and Index is
% the 1-based position of Ev in the domain of distribution Id. Variables
% without evidence are skipped. Fails if an observed value is not in its
% domain.
add2sample([], []).
add2sample([Var|Vars], Cells) :-
	(   clpbn:get_atts(Var, [evidence(Ev),dist(Id,_)])
	->  bnt:get_atts(Var, [bnt_id(BntId)]),
	    get_dist_domain(Id, Domain),
	    evidence_val(Ev, 1, Domain, Index),
	    Cells = [val(BntId,1,Index)|Rest]
	;   Cells = Rest
	),
	add2sample(Vars, Rest).
|  | 
 | ||
% evidence_val(+Ev, +Pos0, +Domain, -Pos)
%
% Pos is Pos0 plus the 0-based index of the first element of Domain that
% unifies with Ev. Called with Pos0 = 1, this gives a 1-based position.
% Fails if no element of Domain unifies with Ev.
evidence_val(Ev, Pos0, [Elem|Elems], Pos) :-
	(   Ev = Elem,
	    Pos = Pos0
	->  true
	;   Pos1 is Pos0 + 1,
	    evidence_val(Ev, Pos1, Elems, Pos)
	).
|  | 
 | ||
% bnt_learn_parameters(+NVars, +NEvidence)
%
% Run BNT's learn_params_em on the MATLAB inference engine `engine` (built
% from the MATLAB network `bnet`) and the evidence cell array `sample`, for
% at most bnt_em_max_iter/1 iterations. The two outputs go to the MATLAB
% variables new_bnet (read back by get_new_table/2) and trace (presumably
% the EM log-likelihood trace). Both arguments are currently unused.
bnt_learn_parameters(_NVars, _NEvidence) :-
	% Junction-tree inference. Alternatives tried before:
	%	var_elim_inf_engine(bnet)
	%	gibbs_sampling_inf_engine(bnet)
	%	belprop_inf_engine(bnet)
	%	pearl_inf_engine(bnet)
	engine <-- jtree_inf_engine(bnet),
	bnt_em_max_iter(MaxIterations),
	EMCall = learn_params_em(engine, sample, MaxIterations),
	[new_bnet, trace] <-- EMCall.
|  | 
 | ||
|  | 
 | ||
% get_parameters(+Reps, -CPTs)
%
% For each Rep-v(_,_,_) entry of Reps (as returned by create_bnt_graph/2),
% read the learned CPT of node Rep with get_new_table/2. The CPTs come back
% in the same order.
get_parameters([], []).
get_parameters([Entry|Entries], [Table|Tables]) :-
	Entry = Rep-v(_,_,_),
	get_new_table(Rep, Table),
	get_parameters(Entries, Tables).
|  | 	 | ||
% get_new_table(+Rep, -CPT)
%
% Fetch the learned conditional probability table of node Rep. It copies
% new_bnet.CPD{Rep} into the MATLAB variable s (overwritten on every call)
% via struct(...), then reads s.CPT back into Prolog as CPT.
% The dotted terms below use YAP's MATLAB-interface syntax; keep verbatim.
get_new_table(Rep,CPT) :-
	s <-- struct(new_bnet.'CPD'({Rep})),
	matlab_get_variable( s.'CPT', CPT).
|  | 	 | ||
|  | 	 |