| 
									
										
										
										
											2009-02-16 12:23:29 +00:00
										 |  |  | % | 
					
						
							|  |  |  | % The world famous EM algorithm, in a nutshell | 
					
						
							|  |  |  | % | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | :- module(clpbn_em, [em/5]). | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2012-09-29 11:50:00 +01:00
										 |  |  | :- reexport(library(clpbn), | 
					
						
							| 
									
										
										
										
											2012-12-17 14:50:12 +00:00
										 |  |  | 		[clpbn_flag/2, | 
					
						
							|  |  |  | 		 clpbn_flag/3 | 
					
						
							|  |  |  | 		]). | 
					
						
							| 
									
										
										
										
											2012-09-29 11:50:00 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2009-02-16 12:23:29 +00:00
										 |  |  | :- use_module(library(clpbn), | 
					
						
							| 
									
										
										
										
											2012-12-17 14:50:12 +00:00
										 |  |  | 		[clpbn_init_graph/1, | 
					
						
							|  |  |  | 		 clpbn_init_solver/4, | 
					
						
							|  |  |  | 		 clpbn_run_solver/3, | 
					
						
							|  |  |  | 		 pfl_init_solver/5, | 
					
						
							|  |  |  | 		 pfl_run_solver/3, | 
					
						
							| 
									
										
										
										
											2012-12-17 23:59:52 +00:00
										 |  |  | 		 pfl_end_solver/1, | 
					
						
							| 
									
										
										
										
											2012-12-17 14:50:12 +00:00
										 |  |  | 		 conditional_probability/3, | 
					
						
							|  |  |  | 		 clpbn_flag/2 | 
					
						
							|  |  |  | 		]). | 
					
						
							| 
									
										
										
										
											2009-02-16 12:23:29 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  | :- use_module(library('clpbn/dists'), | 
					
						
							| 
									
										
										
										
											2012-12-17 14:50:12 +00:00
										 |  |  | 		[get_dist_domain_size/2, | 
					
						
							|  |  |  | 		 empty_dist/2, | 
					
						
							|  |  |  | 		 dist_new_table/2, | 
					
						
							|  |  |  | 		 get_dist_key/2, | 
					
						
							|  |  |  | 		 randomise_all_dists/0, | 
					
						
							|  |  |  | 		 uniformise_all_dists/0 | 
					
						
							|  |  |  | 		]). | 
					
						
							| 
									
										
										
										
											2012-12-17 11:53:57 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  | :- use_module(library('clpbn/ground_factors'), | 
					
						
							| 
									
										
										
										
											2012-12-17 14:50:12 +00:00
										 |  |  | 		[generate_network/5, | 
					
						
							|  |  |  | 		 f/3 | 
					
						
							|  |  |  | 		]). | 
					
						
							| 
									
										
										
										
											2012-12-20 23:19:10 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2012-12-17 11:53:57 +00:00
										 |  |  | :- use_module(library('clpbn/utils'), | 
					
						
							| 
									
										
										
										
											2012-12-17 14:50:12 +00:00
										 |  |  | 		[check_for_hidden_vars/3, | 
					
						
							|  |  |  | 		 sort_vars_by_key/3 | 
					
						
							|  |  |  | 		]). | 
					
						
							| 
									
										
										
										
											2012-08-15 16:01:45 -05:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2009-02-16 12:23:29 +00:00
										 |  |  | :- use_module(library('clpbn/learning/learn_utils'), | 
					
						
							| 
									
										
										
										
											2012-12-17 14:50:12 +00:00
										 |  |  | 		[run_all/1, | 
					
						
							|  |  |  | 		 clpbn_vars/2, | 
					
						
							|  |  |  | 		 normalise_counts/2, | 
					
						
							|  |  |  | 		 compute_likelihood/3, | 
					
						
							|  |  |  | 		 soften_sample/2 | 
					
						
							|  |  |  | 		]). | 
					
						
							| 
									
										
										
										
											2012-12-20 23:19:10 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2012-12-17 11:53:57 +00:00
										 |  |  | :- use_module(library(bhash), | 
					
						
							| 
									
										
										
										
											2012-12-17 14:50:12 +00:00
										 |  |  | 		[b_hash_new/1, | 
					
						
							|  |  |  | 		 b_hash_lookup/3, | 
					
						
							|  |  |  | 		 b_hash_insert/4 | 
					
						
							|  |  |  | 		]). | 
					
						
							| 
									
										
										
										
											2009-02-16 12:23:29 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2012-12-17 11:53:57 +00:00
										 |  |  | :- use_module(library(matrix), | 
					
						
							| 
									
										
										
										
											2012-12-17 14:50:12 +00:00
										 |  |  | 		[matrix_add/3, | 
					
						
							|  |  |  | 		 matrix_to_list/2 | 
					
						
							|  |  |  | 		]). | 
					
						
							| 
									
										
										
										
											2012-12-17 17:57:00 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2009-02-16 12:23:29 +00:00
										 |  |  | :- use_module(library(lists), | 
					
						
							| 
									
										
										
										
											2012-12-17 14:50:12 +00:00
										 |  |  | 		[member/2]). | 
					
						
							| 
									
										
										
										
											2012-12-17 17:57:00 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2012-12-17 11:53:57 +00:00
										 |  |  | :- use_module(library(rbtrees), | 
					
						
							| 
									
										
										
										
											2012-12-17 14:50:12 +00:00
										 |  |  | 		[rb_new/1, | 
					
						
							|  |  |  | 		 rb_insert/4, | 
					
						
							|  |  |  | 		 rb_lookup/3 | 
					
						
							|  |  |  | 		]). | 
					
						
							| 
									
										
										
										
											2009-02-16 12:23:29 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2012-09-29 11:50:00 +01:00
										 |  |  | :- use_module(library(maplist)). | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2009-02-16 12:23:29 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  | :- meta_predicate em(:,+,+,-,-), init_em(:,-). | 
					
						
							|  |  |  | 
 | 
					
						
							% em(:Items, +MaxError, +MaxIts, -Tables, -Likelihood)
							%
							% Entry point of the learner. The first clause does all the work inside
							% a failure-driven block: it builds the EM state (errors are routed to
							% handle_em/1), iterates em_loop/7, lets end_em/1 release the solver,
							% records the answer in em_found/2 and then fails, so every binding
							% created while learning is undone on backtracking.
							em(Items, MaxError, MaxIts, Tables, Likelihood) :-
								catch(init_em(Items, EMState), Ball, handle_em(Ball)),
								em_loop(0, 0.0, EMState, MaxError, MaxIts, Likelihood, Tables),
								end_em(EMState),
								assert(em_found(Tables, Likelihood)),
								fail.
							% get rid of new random variables the easy way :)
							% (the recorded answer is handed back from the database)
							em(_Items, _MaxError, _MaxIts, Tables, Likelihood) :-
								retract(em_found(Tables, Likelihood)).
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2012-08-15 16:01:45 -05:00
										 % handle_em(+Ball): recovery goal used by the catch/3 in em/5.
										 % A repeated_parents error is recorded as an em_found/2 fact with
										 % likelihood -inf and the handler then fails; any other ball is
										 % thrown again unchanged.
										 handle_em(Ball) :-
										 	(   Ball = error(repeated_parents)
										 	->  assert(em_found(_, -inf)),
										 	    fail
										 	;   throw(Ball)
										 	).
					
						
							| 
									
										
										
										
											2009-02-16 12:23:29 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2012-12-17 23:59:52 +00:00
										 |  |  | 
 | 
					
						
							% end_em(+State): called by em/5 after the EM loop has finished.
							% When parfactors are switched on, the solver state stored in the
							% fourth field of state/4 is passed to pfl_end_solver/1; in every
							% other situation nothing is done.
							end_em(State) :-
								(   State = state(_Dists, _Instances, _Margs, Solver),
								    clpbn:use_parfactors(on)
								->  pfl_end_solver(Solver)
								;   true
								).
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2009-02-16 12:23:29 +00:00
										 |  |  | % This gets you an initial configuration. If there is a lot of evidence | 
					
						
							|  |  |  | % tables may be filled in close to optimal, otherwise they may be | 
					
						
							|  |  |  | % close to uniform. | 
					
						
							|  |  |  | % it also gets you a run for random variables | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | % state collects all Info we need for the EM algorithm | 
					
						
							|  |  |  | % it includes the list of variables without evidence, | 
					
						
							|  |  |  | % the list of distributions for which we want to compute parameters, | 
					
						
							|  |  |  | % and more detailed info on distributions, namely with a list of all instances for the distribution. | 
					
						
							| 
									
										
										
										
											2012-08-15 16:01:45 -05:00
										 % init_em(:Items, -State)
										 % Prepare everything the EM loop needs and return it as State.
										 init_em(Items, State) :-
										 	% the solver is selected through the em_solver flag
										 	clpbn_flag(em_solver, EMSolver),
										 	% only used for PCGs
										 	clpbn_init_graph(EMSolver),
										 	% set initial values for distributions
										 	% (randomise_all_dists/0 is the alternative, currently disabled)
										 	uniformise_all_dists,
										 	setup_em_network(Items, State).
					
						
							| 
									
										
										
										
											2012-08-15 16:01:45 -05:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2012-12-17 11:53:57 +00:00
										 % setup_em_network(+Items, -State)
										 % Build the state(Dists, DistInstances, Margs, Solver) term.
										 % With parfactors on, the ground factors are produced by
										 % run_examples/4 and handed to the PFL solver; otherwise the
										 % attributed variables created by call_run_all/1 are used with
										 % the CLP(BN) solver.
										 setup_em_network(Items, state(Dists, DistInstances, Margs, Solver)) :-
										 	(   clpbn:use_parfactors(on)
										 	->  % get all variables to marginalise
										 	    run_examples(Items, Keys, Factors, Evidence),
										 	    % get the EM CPT connections info from the factors
										 	    generate_dists(Factors, Evidence, Dists, DistInstances, Margs),
										 	    % setup solver, if necessary
										 	    pfl_init_solver(Margs, Keys, Factors, Evidence, Solver)
										 	;   % create the ground network
										 	    call_run_all(Items),
										 	    % get all variables to marginalise
										 	    attributes:all_attvars(Unsorted),
										 	    % and order them
										 	    sort_vars_by_key(Unsorted, Vars, []),
										 	    % remove variables that do not have to do with this query.
										 	    different_dists(Vars, Dists, DistInstances, Margs),
										 	    % setup solver by doing parameter independent work.
										 	    clpbn_init_solver(Margs, Vars, _, Solver)
										 	).
					
						
							| 
									
										
										
										
											2009-02-16 12:23:29 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2012-09-29 11:50:00 +01:00
										 % run_examples(+Items, -Keys, -Factors, -EList)
										 % If Items is user:Examples and Examples is a list of non-empty
										 % lists, each example is numbered, run on its own inside findall/3
										 % and the per-example results are merged by join_example/9.
										 % Anything else is passed to run_ex/4 as a single example.
										 run_examples(user:Examples, Keys, Factors, EList) :-
										 	Examples = [[_|_]|_],
										 	!,
										 	foldl(add_key, Examples, Numbered, 1, _),
										 	findall(ex(Ks, Fs, Es),
										 		run_example(Numbered, Ks, Fs, Es),
										 		Runs),
										 	foldl4(join_example, Runs, [], Keys, [], Factors, [], EList, 0, _).
										 run_examples(Items, Keys, Factors, EList) :-
										 	run_ex(Items, Keys, Factors, EList).
					
						
							| 
									
										
										
										
											2013-01-09 18:34:19 +00:00
										 |  |  |   | 
					
						
							| 
									
										
										
										
											2013-01-10 13:45:24 +00:00
										 % add_key(+Example, -Index:Example, +Index, -NextIndex)
										 % foldl/5 step that tags an example with the running index.
										 add_key(Example, Index:Example, Index, NextIndex) :-
										 	NextIndex is Index+1.
					
						
							| 
									
										
										
										
											2012-09-29 11:50:00 +01:00
										 |  |  | 
 | 
					
						
							% join_example(+ex(Ks,Fs,Es), +Keys0,-Keys, +Factors0,-Factors,
							%              +EList0,-EList, +ExampleNo,-NextExampleNo)
							% foldl4/10 step: prefixes the keys, factors and evidence of one
							% example with its number and pushes them onto the accumulators.
							join_example(ex(ExKeys, ExFactors, ExEvidence), Keys0, Keys, Factors0, Factors, EList0, EList, ExNo, NextExNo) :-
								NextExNo is ExNo+1,
								foldl(process_key(ExNo), ExKeys, Keys0, Keys),
								foldl(process_factor(ExNo), ExFactors, Factors0, Factors),
								foldl(process_ev(ExNo), ExEvidence, EList0, EList).
					
						
							| 
									
										
										
										
											2012-09-29 11:50:00 +01:00
										 |  |  | 
 | 
					
						
							% process_key(+ExNo, +Key, +Acc, -[ExNo:Key|Acc])
							process_key(ExNo, Key, Acc, [ExNo:Key|Acc]).
					
						
							|  |  |  | 
 | 
					
						
							% process_factor(+ExNo, +f(Type,Id,Keys), +Acc, -[f(Type,Id,NewKeys)|Acc])
							% Every key of the factor is rewritten by update_key/3.
							process_factor(ExNo, f(Type, Id, FactorKeys), Acc, [f(Type, Id, Tagged)|Acc]) :-
								maplist(update_key(ExNo), FactorKeys, Tagged).
					
						
							| 
									
										
										
										
											2012-09-29 11:50:00 +01:00
										 |  |  | 
 | 
					
						
							% update_key(+ExNo, +Key, -ExNo:Key)
							update_key(ExNo, Key, ExNo:Key).
					
						
							|  |  |  | 
 | 
					
						
							% process_ev(+ExNo, +Key=Value, +Acc, -[(ExNo:Key)=Value|Acc])
							process_ev(ExNo, Key=Value, Acc, [(ExNo:Key)=Value|Acc]).
					
						
							|  |  |  | 
 | 
					
						
							% run_example(+NumberedExamples, -Keys, -Factors, -EList)
							% Non-deterministic: on backtracking each Index:Items entry of the list
							% is run in turn through run_ex/4, always in module user.
							run_example(Numbered, Keys, Factors, EList) :-
								member(_Index:Items, Numbered),
								run_ex(user:Items, Keys, Factors, EList).
					
						
							| 
									
										
										
										
											2012-09-29 11:50:00 +01:00
										 |  |  | 
 | 
					
						
							% run_ex(+Items, -Keys, -Factors, -EList)
							% Run one example and return its ground network.
							run_ex(Items, Keys, Factors, EList) :-
								% create the ground network
								call_run_all(Items),
								attributes:all_attvars(Unsorted),
								% and order them
								sort_vars_by_key(Unsorted, Ordered, []),
								% no, we are in trouble because we don't know the network yet.
								% get the ground network
								generate_network(Ordered, _, Keys, Factors, EList).
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2009-02-16 12:23:29 +00:00
										 |  |  | % loop for as long as you want. | 
					
						
							% em_loop(+Its, +Likelihood0, +State, +MaxError, +MaxIts,
							%         -LikelihoodF, -FTables)
							%
							% One EM iteration per call: estimate/2 runs the solver (E-step),
							% maximise/4 recomputes the tables and the likelihood (M-step).
							% The loop stops when em_should_stop/5 succeeds; the tables are then
							% converted to Key-List pairs by ltables/2 and returned together with
							% the last likelihood. Otherwise the iteration counter is incremented
							% and the loop continues with the new likelihood as reference.
							%
							% The list form of the tables is only computed once, on exit.
							em_loop(Its, Likelihood0, State, MaxError, MaxIts, LikelihoodF, FTables) :-
								estimate(State, LPs),
								maximise(State, Tables, LPs, Likelihood),
								(   em_should_stop(Its, MaxIts, Likelihood0, Likelihood, MaxError)
								->  ltables(Tables, FTables),
								    LikelihoodF = Likelihood
								;   Its1 is Its+1,
								    em_loop(Its1, Likelihood, State, MaxError, MaxIts, LikelihoodF, FTables)
								).

							% em_should_stop(+Its, +MaxIts, +Lik0, +Lik, +MaxError)
							% Succeeds if the relative change |Lik-Lik0|/|Lik| is below MaxError,
							% or if the iteration limit has been reached.
							em_should_stop(_Its, _MaxIts, Lik0, Lik, MaxError) :-
								Delta is abs(Lik - Lik0),
								(   Lik =:= 0
								->  % avoid dividing by zero: only an unchanged value counts
								    Delta =:= 0
								;   Delta / abs(Lik) < MaxError
								),
								!.
							em_should_stop(Its, MaxIts, _Lik0, _Lik, _MaxError) :-
								(   number(MaxIts)
								->  % >= also stops when the limit is a float or was already passed
								    Its >= MaxIts
								;   % non-numeric limits keep the original identity test
								    Its == MaxIts
								).
					
						
							| 
									
										
										
										
											2009-02-16 12:23:29 +00:00
										 |  |  | 
 | 
					
						
							% ltables(+Tables, -FTables)
							% Maps every Id-Matrix pair to Key-List, where List comes from
							% matrix_to_list/2 and Key from get_dist_key/2.
							ltables(Tables, FTables) :-
								maplist(ltable_entry, Tables, FTables).

							% ltable_entry(+Id-Matrix, -Key-List): conversion of a single table.
							ltable_entry(DistId-Matrix, DistKey-Values) :-
								matrix_to_list(Matrix, Values),
								get_dist_key(DistId, DistKey).
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2012-08-15 16:01:45 -05:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2012-12-17 17:57:00 +00:00
										 % generate_dists(+Factors, +EList, -AllDists, -AllInfo, -MargVars)
										 % The evidence list is loaded into a hash table, each factor is
										 % turned into an i/4 record by process_factor/3, and the sorted
										 % records are grouped by distribution id with group//3.
										 generate_dists(Factors, EList, AllDists, AllInfo, MargVars) :-
										 	b_hash_new(EmptyHash),
										 	foldl(elist_to_hash, EList, EmptyHash, EvidenceHash),
										 	maplist(process_factor(EvidenceHash), Factors, Records),
										 	sort(Records, SortedRecords),
										 	group(SortedRecords, AllDists, AllInfo, Margs, []),
										 	sort(Margs, MargVars).
					
						
							| 
									
										
										
										
											2012-09-29 11:50:00 +01:00
										 |  |  | 
 | 
					
						
							% elist_to_hash(+Key=Value, +Hash0, -Hash)
							% foldl/4 step that stores one evidence pair in the hash table.
							elist_to_hash(Key=Value, Hash0, Hash) :-
								b_hash_insert(Hash0, Key, Value, Hash).
					
						
							|  |  |  | 
 | 
					
						
							% process_factor(+EvidenceHash, +f(bayes,Id,Keys), -i(Id,Keys,Cases,NonEvs))
							% Each key contributes either its evidence value or its whole domain
							% (see fetch_evidence/5); the compact form is then expanded into the
							% explicit list of cases.
							process_factor(EvidenceHash, f(bayes,Id,Keys), i(Id, Keys, Cases, NonEvs)) :-
								foldl(fetch_evidence(EvidenceHash), Keys, Compact, [], NonEvs),
								uncompact_cases(Compact, Cases).
					
						
							|  |  |  | 
 | 
					
						
							% fetch_evidence(+EvidenceHash, +Key, -Case, +NonEvs0, -NonEvs)
							% If Key is found in the evidence table, Case is the stored value and
							% the accumulator is left untouched. Otherwise Key is pushed onto the
							% accumulator and Case becomes the list 0..N-1, one number per element
							% of the domain given by pfl:skolem/2.
							fetch_evidence(EvidenceHash, Key, Case, NonEvs0, NonEvs) :-
								(   NonEvs = NonEvs0,
								    b_hash_lookup(Key, Case, EvidenceHash)
								->  true
								;   NonEvs = [Key|NonEvs0],
								    pfl:skolem(Key, Domain),
								    foldl(domain_to_number, Domain, Case, 0, _)
								).
					
						
							| 
									
										
										
										
											2012-08-15 16:01:45 -05:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2012-09-29 11:50:00 +01:00
										 % domain_to_number(+DomainElement, -Index, +Index, -NextIndex)
										 % foldl/5 step: the domain element itself is ignored, only its
										 % position is returned.
										 domain_to_number(_Element, Index, Index, NextIndex) :-
										 	NextIndex is Index+1.
					
						
							| 
									
										
										
										
											2012-12-20 23:19:10 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2012-08-15 16:01:45 -05:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2009-02-16 12:23:29 +00:00
										 |  |  | % collect the different dists we are going to learn next. | 
					
						
							% different_dists(+AllVars, -AllDists, -AllInfo, -MargVars)
							% One i/4 record is built per variable by all_dists/3; the sorted
							% records are grouped by distribution id with group//3 and the
							% collected hidden-variable lists are sorted.
							different_dists(AllVars, AllDists, AllInfo, MargVars) :-
								all_dists(AllVars, AllVars, Records),
								sort(Records, SortedRecords),
								group(SortedRecords, AllDists, AllInfo, Margs, []),
								sort(Margs, MargVars).
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2012-06-22 19:00:12 +01:00
										 |  |  | % | 
					
						
							|  |  |  | % V -> to Id defining V. We get: | 
					
						
							|  |  |  | % the random variables that are parents | 
					
						
							| 
									
										
										
										
											2012-12-20 23:19:10 +00:00
										 |  |  | % the cases that can happen, eg if we have A <- B, C | 
					
						
							| 
									
										
										
										
											2012-06-22 19:00:12 +01:00
										 |  |  | % A and B are boolean w/o evidence, and C is f, the cases could be | 
					
						
							| 
									
										
										
										
											2012-12-20 23:19:10 +00:00
										 |  |  | % [0,0,1], [0,1,1], [1,0,0], [1,1,0], | 
					
						
							| 
									
										
										
										
											2012-06-22 19:00:12 +01:00
							|  |  |  | % Hiddens will be A and B (the variables without evidence) | 
					
						
							|  |  |  | % | 
					
						
							% all_dists(+Vars, +AllVars, -Records)
							% For every variable a record i(Id, [V|Parents], Cases, Hiddens) is
							% produced, where Id and Parents come from the variable's dist/2
							% attribute. error(repeated_parents) is thrown when sorting
							% [V|Parents] removes an element, i.e. some variable occurs twice.
							all_dists([], _, []).
							all_dists([Var|Vars], Everything, [i(DistId, [Var|Parents], Cases, Hiddens)|Records]) :-
								% Var is an instance of DistId
								clpbn:get_atts(Var, [dist(DistId,Parents)]),
								Family = [Var|Parents],
								sort(Family, Unique),
								length(Unique, NUnique),
								length(Parents, NParents),
								(   NUnique =\= NParents+1
								->  throw(error(repeated_parents))
								;   true
								),
								generate_hidden_cases(Family, Compact, Hiddens),
								uncompact_cases(Compact, Cases),
								all_dists(Vars, Everything, Records).
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2009-02-16 12:23:29 +00:00
										 % generate_hidden_cases(+Vars, -CompactCases, -Hiddens)
										 % A variable with an evidence/1 attribute contributes its evidence
										 % value; any other variable is added to Hiddens and contributes
										 % the list 0..Sz-1, Sz being the domain size of its distribution.
										 generate_hidden_cases([], [], []).
										 generate_hidden_cases([Var|Vars], [Case|Cases], Hiddens) :-
										 	(   clpbn:get_atts(Var, [evidence(Case)])
										 	->  generate_hidden_cases(Vars, Cases, Hiddens)
										 	;   Hiddens = [Var|MoreHiddens],
										 	    clpbn:get_atts(Var, [dist(DistId,_)]),
										 	    get_dist_domain_size(DistId, Size),
										 	    gen_cases(0, Size, Case),
										 	    generate_hidden_cases(Vars, Cases, MoreHiddens)
										 	).
					
						
							| 
									
										
										
										
											2012-08-15 16:01:45 -05:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2009-02-16 12:23:29 +00:00
										 % gen_cases(+I, +Sz, -Cases)
										 % Cases is the list of integers I, I+1, ..., Sz-1.
										 % The bound is checked with an arithmetic comparison, so the
										 % recursion also terminates (with the empty list) when I is
										 % already past Sz, and works when Sz is given as a float.
										 gen_cases(I, Sz, Cases) :-
										 	(   I < Sz
										 	->  Cases = [I|Rest],
										 	    I1 is I+1,
										 	    gen_cases(I1, Sz, Rest)
										 	;   Cases = []
										 	).
					
						
							|  |  |  | 
 | 
					
						
							% uncompact_cases(+CompactCases, -Cases)
							% Cases is the list of all solutions of is_case/2 for CompactCases.
							uncompact_cases(Compact, Expanded) :-
								findall(OneCase, is_case(Compact, OneCase), Expanded).
					
						
							|  |  |  | 
 | 
					
						
							% is_case(+CompactCases, -Case)
							% An integer slot is copied as is; a list slot is enumerated on
							% backtracking through member/2.
							is_case([], []).
							is_case([Slot|Slots], [Value|Values]) :-
								(   integer(Slot)
								->  Value = Slot
								;   member(Value, Slot)
								),
								is_case(Slots, Values).
					
						
							|  |  |  | 
 | 
					
						
							% group(+Records, -Ids, -Id-Records pairs)//
							% Splits a list of i/4 records into runs that share the same id
							% (the runs are collected by same_id//4). The DCG output list gets
							% the hidden-variable list of every record whose list is not empty.
							group([], [], []) --> [].
							group([i(Id,Vars,Cases,[])|Records], [Id|Ids], [Id-[i(Id,Vars,Cases,[])|SameId]|Grouped]) --> !,
								same_id(Records, Id, SameId, Others),
								group(Others, Ids, Grouped).
							group([i(Id,Vars,Cases,Hidden)|Records], [Id|Ids], [Id-[i(Id,Vars,Cases,Hidden)|SameId]|Grouped]) -->
								[Hidden],
								same_id(Records, Id, SameId, Others),
								group(Others, Ids, Grouped).
					
						
							|  |  |  | 
 | 
					
						
							% same_id(+Records, +Id, -SameId, -Others)//
							% SameId is the leading run of records carrying Id, Others is what
							% follows it. Non-empty hidden-variable lists are emitted to the
							% DCG output list.
							same_id([i(Id,Vars,Cases,[])|Records], Id, [i(Id, Vars, Cases, [])|SameId], Others) --> !,
								same_id(Records, Id, SameId, Others).
							same_id([i(Id,Vars,Cases,Hidden)|Records], Id, [i(Id, Vars, Cases, Hidden)|SameId], Others) --> !,
								[Hidden],
								same_id(Records, Id, SameId, Others).
							same_id(Records, _, [], Records) --> [].
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							% compact_mvars(+List, -Compacted)
							% Drops an element when it is identical (==) to the one that follows
							% it, so each run of identical neighbours is reduced to one element.
							compact_mvars([], []).
							compact_mvars([Head|Tail], Compacted) :-
								(   Tail = [Next|_],
								    Head == Next
								->  compact_mvars(Tail, Compacted)
								;   Compacted = [Head|Rest],
								    compact_mvars(Tail, Rest)
								).
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2012-09-29 11:50:00 +01:00
										 % estimate(+State, -LPs)
										 % Runs the solver on the variables to marginalise stored in State:
										 % pfl_run_solver/3 when parfactors are on, clpbn_run_solver/3
										 % otherwise.
										 estimate(state(_, _, Margs, Solver), LPs) :-
										 	(   clpbn:use_parfactors(on)
										 	->  pfl_run_solver(Margs, LPs, Solver)
										 	;   clpbn_run_solver(Margs, LPs, Solver)
										 	).
					
						
							| 
									
										
										
										
											2009-02-16 12:23:29 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  | maximise(state(_,DistInstances,MargVars,_), Tables, LPs, Likelihood) :- | 
					
						
							|  |  |  | 	rb_new(MDistTable0), | 
					
						
							| 
									
										
										
										
											2012-09-29 11:50:00 +01:00
										 |  |  | 	foldl(create_mdist_table, MargVars, LPs, MDistTable0, MDistTable), | 
					
						
							| 
									
										
										
										
											2009-02-16 12:23:29 +00:00
										 |  |  | 	compute_parameters(DistInstances, Tables, MDistTable, 0.0, Likelihood, LPs:MargVars). | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2012-09-29 11:50:00 +01:00
										 |  |  | create_mdist_table(Vs, Ps, MDistTable0, MDistTable) :- | 
					
						
							|  |  |  | 	rb_insert(MDistTable0, Vs, Ps, MDistTable). | 
					
						
							| 
									
										
										
										
											2009-02-16 12:23:29 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  | compute_parameters([], [], _, Lik, Lik, _). | 
					
						
							| 
									
										
										
										
											2012-12-17 17:57:00 +00:00
										 |  |  | compute_parameters([Id-Samples|Dists], [Id-NewTable|Tables], MDistTable, Lik0, Lik, LPs:MargVars) :- | 
					
						
							| 
									
										
										
										
											2009-02-16 12:23:29 +00:00
										 |  |  | 	empty_dist(Id, Table0), | 
					
						
							|  |  |  | 	add_samples(Samples, Table0, MDistTable), | 
					
						
							| 
									
										
										
										
											2011-11-30 13:04:13 +00:00
										 |  |  | %matrix_to_list(Table0,Mat), lists:sumlist(Mat, Sum), format(user_error, 'FINAL ~d ~w ~w~n', [Id,Sum,Mat]), | 
					
						
							| 
									
										
										
										
											2009-02-16 12:23:29 +00:00
										 |  |  | 	soften_sample(Table0, SoftenedTable), | 
					
						
							| 
									
										
										
										
											2011-05-20 23:56:12 +01:00
										 |  |  | %	matrix:matrix_sum(Table0,TotM), | 
					
						
							| 
									
										
										
										
											2009-02-16 12:23:29 +00:00
										 |  |  | 	normalise_counts(SoftenedTable, NewTable), | 
					
						
							|  |  |  | 	compute_likelihood(Table0, NewTable, DeltaLik), | 
					
						
							|  |  |  | 	dist_new_table(Id, NewTable), | 
					
						
							|  |  |  | 	NewLik is Lik0+DeltaLik, | 
					
						
							| 
									
										
										
										
											2012-12-17 17:57:00 +00:00
										 |  |  | 	compute_parameters(Dists, Tables, MDistTable, NewLik, Lik, LPs:MargVars). | 
					
						
							| 
									
										
										
										
											2009-02-16 12:23:29 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  | add_samples([], _, _). | 
					
						
							|  |  |  | add_samples([i(_,_,[Case],[])|Samples], Table, MDistTable) :- !, | 
					
						
							|  |  |  | 	matrix_add(Table,Case,1.0), | 
					
						
							|  |  |  | 	add_samples(Samples, Table, MDistTable). | 
					
						
							|  |  |  | add_samples([i(_,_,Cases,Hiddens)|Samples], Table, MDistTable) :- | 
					
						
							|  |  |  | 	rb_lookup(Hiddens, Ps, MDistTable), | 
					
						
							|  |  |  | 	run_sample(Cases, Ps, Table), | 
					
						
							| 
									
										
										
										
											2011-09-24 21:39:37 +01:00
										 |  |  | %matrix_to_list(Table,M), format(user_error, '~w ~w~n', [Cases,Ps]), | 
					
						
							| 
									
										
										
										
											2009-02-16 12:23:29 +00:00
										 |  |  | 	add_samples(Samples, Table, MDistTable). | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | run_sample([], [], _). | 
					
						
							|  |  |  | run_sample([C|Cases], [P|Ps], Table) :- | 
					
						
							|  |  |  | 	matrix_add(Table, C, P), | 
					
						
							|  |  |  | 	run_sample(Cases, Ps, Table). | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2009-05-26 10:49:04 -05:00
										 |  |  | call_run_all(Mod:Items) :- | 
					
						
							| 
									
										
										
										
											2009-10-21 00:05:23 +01:00
										 |  |  | 	clpbn_flag(em_solver, pcg), !, | 
					
						
							| 
									
										
										
										
											2009-05-26 10:49:04 -05:00
										 |  |  | 	backtrack_run_all(Items, Mod). | 
					
						
							| 
									
										
										
										
											2009-10-21 00:05:23 +01:00
										 |  |  | call_run_all(Mod:Items) :- | 
					
						
							|  |  |  | 	run_all(Mod:Items). | 
					
						
							| 
									
										
										
										
											2009-02-16 12:23:29 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2009-05-26 10:49:04 -05:00
										 |  |  | backtrack_run_all([Item|_], Mod) :- | 
					
						
							|  |  |  | 	call(Mod:Item), | 
					
						
							|  |  |  | 	fail. | 
					
						
							|  |  |  | backtrack_run_all([_|Items], Mod) :- | 
					
						
							|  |  |  | 	backtrack_run_all(Items, Mod). | 
					
						
							|  |  |  | backtrack_run_all([], _). | 
					
						
							| 
									
										
										
										
											2011-11-30 13:04:13 +00:00
										 |  |  | 
 |