%
% each variable is represented by a node in a binary tree.
% each node contains:
% key,
% current_value
% Markov Blanket
%

:- module(clpbn_gibbs,
		[gibbs/3,
		 check_if_gibbs_done/1,
		 init_gibbs_solver/4,
		 run_gibbs_solver/3
		]).

:- use_module(library(rbtrees),
		[rb_new/1,
		 rb_insert/4,
		 rb_lookup/3
		]).

:- use_module(library(lists),
		[member/2,
		 append/3,
		 delete/3,
		 max_list/2,
		 sum_list/2
		]).

:- use_module(library(maplist)).

:- use_module(library(ordsets),
		[ord_subtract/3]).

:- use_module(library('clpbn/matrix_cpt_utils'),
		[project_from_CPT/3,
		 reorder_CPT/5,
		 multiply_possibly_deterministic_factors/3,
		 column_from_possibly_deterministic_CPT/3,
		 normalise_possibly_deterministic_CPT/2,
		 list_from_CPT/2
		]).

:- use_module(library('clpbn/utils'),
		[check_for_hidden_vars/3]).

:- use_module(library('clpbn/dists'),
		[get_possibly_deterministic_dist_matrix/5,
		 get_dist_domain_size/2
		]).

:- use_module(library('clpbn/topsort'),
		[topsort/2]).

:- use_module(library('clpbn/display'),
		[clpbn_bind_vals/3]).

:- use_module(library('clpbn/connected'),
		[influences/3]).

:- dynamic gibbs_params/3.

:- dynamic explicit/1.

% gibbs(LVs, Vs0, AllDiffs)
%
% arguments:
%
% LVs:      list of output variables
% Vs0:      list of attributed variables
% AllDiffs: passed on to init_gibbs_solver/4 and clpbn_bind_vals/3
%
% Top-level entry: prepares the network, runs the sampler and binds the
% estimated distributions to the output variables.
%
gibbs(QueryVars, AttVars, AllDiffs) :-
	init_gibbs_solver(QueryVars, AttVars, AllDiffs, NetVars),
	run_gibbs_solver(QueryVars, Dists, NetVars),
	clpbn_bind_vals(QueryVars, Dists, AllDiffs),
	clean_up.

% init_gibbs_solver(GoalVs, Vs0, AllDiffs, Vs)
% Discards state left by a previous run, then computes the sorted list
% Vs from the result of influences/3 applied to the visible variables
% of Vs0 and the variables occurring in GoalVs.
% The third argument is ignored.
init_gibbs_solver(QueryTerm, AllVars, _, Relevant) :-
	clean_up,
	term_variables(QueryTerm, QueryVars),
	check_for_hidden_vars(AllVars, AllVars, Visible),
	influences(Visible, QueryVars, Unsorted),
	sort(Unsorted, Relevant).

% run_gibbs_solver(LVs, LPs, Vs)
% Builds the graph for the variables in Vs, runs the sampler and
% unifies LPs with the distributions computed by sum_up_all/2 for the
% queries in LVs. Recorded and asserted solver state is discarded
% before returning.
run_gibbs_solver(QueryVars, Dists, NetVars) :-
	initialise(NetVars, Graph, QueryVars, OutIndices, Order),
	process(Order, Graph, OutIndices, Counts),
	sum_up_all(Counts, Dists),
	clean_up.

% initialise(LVs, Graph, GVs, OutputVars, VarOrder)
% Numbers the non-evidence variables of LVs, creates the graph/N term
% with one slot per numbered variable, fills and precompiles the
% nodes, topologically sorts them into VarOrder and translates the
% query variables GVs into node indices (OutputVars).
initialise(NetVars, Graph, QueryVars, OutIndices, Order) :-
	init_keys(EmptyKeys),
	foldl2(gen_key, NetVars, 0, NVars, EmptyKeys, Keys),
	functor(Graph,graph,NVars),
	graph_representation(NetVars, Graph, 0, Keys, Deps),
	compile_graph(Graph),
	topsort(Deps, Order),
%writeln(Deps:Order),
%	show_sorted(Order, Graph),
	add_all_output_vars(QueryVars, Keys, OutIndices).

% init_keys(Keys): Keys is a fresh, empty red-black tree; gen_key/5
% later stores the variable -> index mapping in it.
init_keys(Empty) :-
	rb_new(Empty).

% gen_key(V, I0, I, Keys0, Keys)
% Fold step used by initialise/5. A variable carrying evidence gets no
% index and leaves counter and tree untouched; any other variable is
% given index I0+1 and stored in the tree under itself as key.
gen_key(Var, Count, Count, Tree, Tree) :-
	clpbn:get_atts(Var,[evidence(_)]), !.
gen_key(Var, Count0, Count, Tree0, Tree) :-
	Count is Count0+1,
	rb_insert(Tree0,Var,Count,Tree).

% graph_representation(Vs, Graph, I0, Keys, TGraph)
% Walks the network variables, filling the var/9 slots of Graph and
% collecting TGraph, a list of Index-ParentIndices pairs that
% initialise/5 hands to topsort/2. I0 is the index that was given to
% the previous non-evidence variable.
graph_representation([],_,_,_,[]).
% Evidence variable: it has no node of its own. Its CPT, with evidence
% projected out, is attached to every variable that remains in it.
graph_representation([EvVar|Rest], Graph, I0, Keys, Deps) :-
	clpbn:get_atts(EvVar,[evidence(_)]), !,
	clpbn:get_atts(EvVar, [dist(DistId,Pars)]),
	get_possibly_deterministic_dist_matrix(DistId, Pars, _, Domain, CPT),
	maplist(get_size, Pars, ParSzs),
	length(Domain,DomSz),
	project_evidence_out([EvVar|Pars],[EvVar|Pars],CPT,[DomSz|ParSzs],Free,FreeCPT),
	% all variables are parents
	propagate2parents(Free, FreeCPT, Free, Graph, Keys),
	graph_representation(Rest, Graph, I0, Keys, Deps).
% Ordinary variable: it becomes node I.
graph_representation([Var|Rest], Graph, I0, Keys, [I-ParIdxs|Deps]) :-
	I is I0+1,
	clpbn:get_atts(Var, [dist(DistId,Pars)]),
	get_possibly_deterministic_dist_matrix(DistId, Pars, _, Domain, CPT),
	maplist( get_size, Pars, ParSzs),
	length(Domain,DomSz),
	project_evidence_out([Var|Pars],[Var|Pars],CPT,[DomSz|ParSzs],Free,FreeCPT),
	Free = [Var|FreePars],
	% the node's own table lists its parents in key order
	sort_according_to_indices(FreePars,Keys,OrdPars,OrdIdxs),
	reorder_CPT(Free,FreeCPT,[Var|OrdPars],OrdCPT,_),
	add2graph(Var, Domain, OrdCPT, OrdIdxs, Graph, Keys),
	% the same CPT is also recorded on each remaining parent
	propagate2parents(FreePars, FreeCPT, Free, Graph,Keys),
	maplist(parent_index(Keys), FreePars, ParIdxs0),
	sort(ParIdxs0, ParIdxs),
	% slots 8 and 9 keep the node's own table and parent indices
	arg(I, Graph, var(_,_,_,_,_,_,_,OrdCPT,OrdIdxs)),
	graph_representation(Rest, Graph, I, Keys, Deps).

% write_pars(Vs): debugging aid; prints Key:DistId, one per line, for
% every variable in Vs.
write_pars([]).
write_pars([Var|Rest]) :-
	clpbn:get_atts(Var, [key(Key),dist(DistId,_)]),
	format('~w~n', [Key:DistId]),
	write_pars(Rest).

% get_size(V, Sz): Sz is the domain size of the distribution attached
% to variable V.
get_size(Var, DomSz) :-
	clpbn:get_atts(Var, [dist(DistId,_)]),
	get_dist_domain_size(DistId, DomSz).

% parent_index(Keys, V, I): I is the index stored for V in the tree
% Keys (argument order suited to maplist/3).
parent_index(Tree, Var, Idx) :-
	rb_lookup(Var, Idx, Tree).

%
% first, remove nodes that have evidence from tables.
%
% project_evidence_out(Vs, Deps, Table, Szs, NewDeps, NewTable)
% Scans Vs; each variable with evidence is projected out of the
% current table (columns Deps, sizes Szs) through project_from_CPT/3.
% Variables without evidence are skipped. NewDeps/NewTable are the
% columns and table left at the end.
%
project_evidence_out([],Cols,CPT,_,Cols,CPT).
project_evidence_out([Var|Rest],Cols,CPT,ColSzs,ColsF,CPTF) :-
	clpbn:get_atts(Var,[evidence(_)]), !,
	project_from_CPT(Var,tab(CPT,Cols,ColSzs),tab(CPT1,Cols1,ColSzs1)),
	project_evidence_out(Rest,Cols1,CPT1,ColSzs1,ColsF,CPTF).
project_evidence_out([_NoEvidence|Rest],Cols,CPT,ColSzs,ColsF,CPTF) :-
	project_evidence_out(Rest,Cols,CPT,ColSzs,ColsF,CPTF).

% propagate2parents(Targets, Table, Variables, Graph, Keys)
% For every variable T in Targets: reorder Table (whose columns are
% Variables) so that T comes first and the remaining columns follow in
% key order, then record the reordered table on T's node.
propagate2parents([], _, _, _, _).
propagate2parents([Target|More], CPT, Cols, Graph, Keys) :-
	delete(Cols,Target,Others),
	sort_according_to_indices(Others,Keys,OrdOthers,OrdIdxs),
	reorder_CPT(Cols,CPT,[Target|OrdOthers],TargetCPT,_),
	% the domain is left unbound: only the table is added here
	add2graph(Target, _, TargetCPT, OrdIdxs, Graph, Keys),
	propagate2parents(More,CPT, Cols, Graph, Keys).

% add2graph(V, Vals, Table, IParents, Graph, Keys)
% Finds the node of V and places tabular(Table,Index,IParents) in the
% open-ended list kept in the node's sixth slot. When Vals is bound the
% node's domain (slot 4) and its length (slot 5) are set as well.
add2graph(Var, Domain, CPT, ParIdxs, Graph, Keys) :-
	rb_lookup(Var, Idx, Keys),
	( nonvar(Domain) -> length(Domain,DomSz) ; true ),
	arg(Idx, Graph, var(Var,Idx,_,Domain,DomSz,CPTSlot,_,_,_)),
	% member/2 on a partial list: commit to the first position that unifies
	member(tabular(CPT,Idx,ParIdxs), CPTSlot), !.

% sort_according_to_indices(NVs, Keys, SortedNVs, SortedIndices)
% Orders the variables NVs by the index stored for them in Keys and
% returns the ordered variables together with those indices.
sort_according_to_indices(Vars,Keys,OrdVars,OrdIdxs) :-
	maplist(var2index(Keys), Vars, Pairs),
	keysort(Pairs, OrdPairs),
	maplist(split_parent, OrdPairs, OrdVars, OrdIdxs).

% split_parent(Index-Var, Var, Index): take an Index-Var pair apart.
split_parent(Idx-Var, Var, Idx).

% var2index(Keys, V, I-V): pair V with the index stored for it in Keys.
var2index(Keys, Var, Idx-Var) :-
	rb_lookup(Var, Idx, Keys).

%
% This is the really cool bit.
%
% compile_graph(Graph): run compile_var/2 on every argument (node) of
% the Graph term.
%
compile_graph(Graph) :-
	Graph =.. [_|Nodes],
	maplist( compile_var(Graph), Nodes).

% compile_var(Graph, Node)
% Gathers, over all CPTs recorded on the node, the indices they mention
% and the matching domain sizes; the index list is stored in the node's
% seventh slot. The product of the sizes is then used by compile_var/8
% to decide whether the node is precompiled.
compile_var(Graph, var(_,Idx,_,Domain,DomSz,CPTs,Blanket,_,_)) :-
	foldl2( fetch_parent(Graph), CPTs, [], Blanket, [], BlanketSzs),
	foldl( mult_list, BlanketSzs,1,NCombos),
	compile_var(NCombos,Idx,Domain,DomSz,CPTs,Blanket,BlanketSzs,Graph).

% fetch_parent(Graph, CPT, Parents0, Parents, Sizes0, Sizes)
% Fold step: merges the indices listed in one tabular/3 entry into the
% accumulated index list and its parallel list of domain sizes.
fetch_parent(Graph, tabular(_,_,Idxs), Acc0, Acc, Szs0, Szs) :-
	foldl2( merge_these_parents(Graph), Idxs, Acc0, Acc, Szs0, Szs).

% merge_these_parents(Graph, I, Parents0, Parents, Sizes0, Sizes)
% Adds index I unless it is already present; the size is the length of
% the domain stored in node I.
merge_these_parents(_Graph, Idx,Acc,Acc,Szs,Szs) :-
	member(Idx,Acc), !.
merge_these_parents(Graph, Idx, Acc0,Acc,Szs0,Szs) :-
	arg(Idx,Graph,var(_,Idx,_,Domain,_,_,_,_,_)),
	length(Domain, DomSz),
	add_parent(Acc0,Idx,Acc,Szs0,DomSz,Szs).

% add_parent(Parents0, I, Parents, Sizes0, Sz, Sizes)
% Inserts I before the first larger element of Parents0, and Sz at the
% same position of Sizes0; appends both when no larger element exists.
add_parent([],Idx,[Idx],[],Sz,[Sz]).
add_parent([Head|Tail],Idx,[Idx,Head|Tail],Szs0,Sz,[Sz|Szs0]) :-
	Head > Idx, !.
add_parent([Head|Tail0],Idx,[Head|Tail],[HeadSz|Szs0],Sz,[HeadSz|Szs]) :-
	add_parent(Tail0,Idx,Tail,Szs0,Sz,Szs).


% mult_list(Sz, Mult0, Mult): fold step, Mult is the running product
% Mult0 multiplied by Sz.
mult_list(Factor,Product0,Product) :-
	Product is Factor*Product0.

% compile_var(TotSize, I, Vals, Sz, CPTs, Parents, Sizes, Graph)
% compile node as set of facts, faster execution; this is only done
% when TotSize is positive and below 64K.
compile_var(NCombos,Idx,_Domain,DomSz,CPTs,Blanket,_Szs,Graph) :-
	NCombos > 0, NCombos < 1024*64, !,
	multiply_all(Idx,Blanket,CPTs,DomSz,Graph).
% do it dynamically: nothing is stored for this node here
compile_var(_,_,_,_,_,_,_,_).

% multiply_all(I, Parents, CPTs, Sz, Graph)
% Failure-driven loop. Each iteration binds the temp slots of the
% nodes in Parents to one combination of values, computes the
% cumulative distribution of node I for that combination and records
% it. Once every combination has been tried, I is asserted as
% explicit/1. A combination whose product cannot be computed raises a
% domain_error.
multiply_all(Idx,Blanket,CPTs,DomSz,Graph) :-
	maplist( markov_blanket_instance(Graph), Blanket, Assignment),
	(
	  multiply_all(CPTs,Graph,CumProbs)
	->
	  store_mblanket(Idx,Assignment,CumProbs)
	;
	  throw(error(domain_error(bayesian_domain),gibbs_cpt(Idx,Blanket,Assignment,DomSz)))
	),
	fail.
multiply_all(Idx,_,_,_,_) :-
	assert(explicit(Idx)).

% markov_blanket_instance(Graph, I, Pos)
% note: what matters is how this predicate instantiates the temp
% slot in the graph! On backtracking, Pos (the third slot of node I)
% takes every position of the node's domain in turn.
markov_blanket_instance(Graph, Idx, Value) :-
	arg(Idx, Graph, var(_,Idx,Value,Domain,_,_,_,_,_)),
	fetch_val(Domain, 0, Value).

% backtrack through every value in domain
%
% fetch_val(Vals, I0, Pos): Pos enumerates I0, I0+1, ... with one
% solution per element of Vals.
fetch_val([_|_],Here,Here).
fetch_val([_|Rest],Here,Pos) :-
	Next is Here+1,
	fetch_val(Rest,Next,Pos).

% multiply_all(CPTs, Graph, Probs)
% Selects from the first CPT the column matching the values currently
% held in the temp slots of its parents, then lets multiply_more/4
% multiply in the remaining CPTs and finish the result. There is no
% clause for an empty CPT list.
multiply_all([tabular(CPT,_,ParIdxs)|More],Graph,CumProbs) :-
	maplist( fetch_parent(Graph), ParIdxs, ParVals),
	column_from_possibly_deterministic_CPT(CPT,ParVals,Column),
	multiply_more(More,Graph,Column,CumProbs).

% fetch_parent(Graph, P, Val): Val is the content of the temp (third)
% slot of node P.
fetch_parent(Graph, Idx, Current) :-
	arg(Idx,Graph,var(_,_,Current,_,_,_,_,_,_)).

% multiply_more(CPTs, Graph, Probs0, Probs)
% Multiplies the running factor Probs0 by the selected column of every
% remaining CPT. When none is left the factor is normalised, turned
% into a list and replaced by its running (cumulative) sums.
multiply_more([],_,Factor,CumProbs) :-
	normalise_possibly_deterministic_CPT(Factor, Normal),
	list_from_CPT(Normal, ProbList),
	accumulate_up_list(ProbList, 0.0, CumProbs).
multiply_more([tabular(CPT,_,ParIdxs)|More],Graph,Factor0,CumProbs) :-
	maplist( fetch_parent(Graph), ParIdxs, ParVals),
	column_from_possibly_deterministic_CPT(CPT, ParVals, Column),
	multiply_possibly_deterministic_factors(Factor0, Column, Factor),
	multiply_more(More,Graph,Factor,CumProbs).

% accumulate_up_list(Ps, Acc0, Sums): Sums is the list of running
% totals of Ps, starting from Acc0.
accumulate_up_list([], _, []).
accumulate_up_list([Prob|Probs], SoFar, [Total|Totals]) :-
	Total is SoFar+Prob,
	accumulate_up_list(Probs, Total, Totals).


% store_mblanket(I, Values, Probs): record, under key mblanket, the
% cumulative distribution Probs of node I for the blanket assignment
% Values. do_var/4 reads these records back.
store_mblanket(Idx,Assignment,CumProbs) :-
	recordz(mblanket,m(Idx,Assignment,CumProbs),_).

% add_all_output_vars(LVs, Keys, OutputVars)
% LVs is a list of lists of query variables; OutputVars has the same
% shape with every variable replaced by its index in Keys.
add_all_output_vars([], _, []).
add_all_output_vars([Query|Queries], Keys, [Idxs|AllIdxs]) :-
	add_output_vars(Query, Keys, Idxs),
	add_all_output_vars(Queries, Keys, AllIdxs).

% add_output_vars(Vs, Keys, Is): index lookup for one list of variables.
add_output_vars([], _, []).
add_output_vars([Var|Vars], Keys, [Idx|Idxs]) :-
	rb_lookup(Var, Idx, Keys),
	add_output_vars(Vars, Keys, Idxs).

% process(VarOrder, Graph, OutputVars, Estimates)
% Reads the sampler settings from gibbs_params/3, creates the initial
% chains and zeroed counters, runs BurnIn sweeps whose counters are
% thrown away, then NSamples sweeps starting again from the zeroed
% counters; the counters of that second phase are the Estimates.
process(Order, Graph, OutIndices, Counts) :-
	gibbs_params(NChains,BurnIn,NSamples),
	functor(Graph,_,NNodes),
	init_chains(NChains,Order,NNodes,Graph,FreshChains),
	init_estimates(NChains,OutIndices,Graph,ZeroCounts),
	% burn-in: only the final chain states are kept
	process_chains(BurnIn,Order,WarmChains,FreshChains,Graph,NNodes,ZeroCounts,_),
	process_chains(NSamples,Order,_,WarmChains,Graph,NNodes,ZeroCounts,Counts).

%
% A uniform distribution is used to generate the initial sample.
%
% init_chains(N, VarOrder, Len, Graph, Chains): Chains is a list of N
% freshly generated samples.
init_chains(0,_,_,_,[]) :- !.
init_chains(Left,Order,NNodes,Graph,[Chain|Chains]) :-
	init_chain(Order,NNodes,Graph,Chain),
	Left1 is Left-1,
	init_chains(Left1,Order,NNodes,Graph,Chains).


% init_chain(VarOrder, Len, Graph, Chain): Chain is a sample/Len term
% whose arguments listed in VarOrder are filled by gen_sample/3.
init_chain(Order,NNodes,Graph,Chain) :-
	functor(Chain,sample,NNodes),
	gen_sample(Order,Graph,Chain).

% gen_sample(VarOrder, Graph, Chain)
% For every node index in VarOrder, stores in the matching argument of
% Chain the value integer(random*Sz), where Sz is the node's domain
% size.
gen_sample([],_,_) :- !.
gen_sample([Idx|Idxs],Graph,Chain) :-
	arg(Idx,Graph,var(_,Idx,_,_,DomSz,_,_,_,_)),
	Value is integer(random*DomSz),
	arg(Idx,Chain,Value),
	gen_sample(Idxs,Graph,Chain).


% init_estimates(NChains, OutputVars, Graph, Ests): one list of zeroed
% counters per chain.
init_estimates(0,_,_,[]) :- !.
init_estimates(Left,OutIndices,Graph,[ChainEst|Ests]) :-
	Left1 is Left-1,
	init_estimate_all_outvs(OutIndices,Graph,ChainEst),
	init_estimates(Left1,OutIndices,Graph,Ests).

% init_estimate_all_outvs(OutputVars, Graph, Ests): one zeroed counter
% per query in OutputVars.
init_estimate_all_outvs([],_,[]).
init_estimate_all_outvs([Query|Queries],Graph,[Est|Ests]) :-
	init_estimate(Query, Graph, Est),
	init_estimate_all_outvs(Queries,Graph,Ests).

% init_estimate(Vs, Graph, Est)
% Zeroed counter for one query. A single node V gives [I|Zeros] with
% one zero per domain value. Several nodes give me(Is,Mults,Zeros)
% with one zero per joint value; Mults are the strides used by
% get_estimate_pos/5 to locate a joint value.
init_estimate([],_,[]).
init_estimate([Node],Graph,[Idx|Zeros]) :- !,
	arg(Node,Graph,var(_,Idx,_,_,DomSz,_,_,_,_)),
	gen_e0(DomSz,Zeros).
init_estimate(Nodes,Graph,me(Idxs,Strides,Zeros)) :-
	generate_est_mults(Nodes, Idxs, Graph, Strides, JointSz),
	gen_e0(JointSz,Zeros).


% generate_est_mults(Vs, Is, Graph, Mults, M)
% M is the product of the domain sizes of Vs; the multiplier of each
% node is the product of the sizes of the nodes that follow it.
generate_est_mults([], [], _, [], 1).
generate_est_mults([Node|Nodes], [Idx|Idxs], Graph, [Stride|Strides], Total) :-
	arg(Node,Graph,var(_,Idx,_,_,DomSz,_,_,_,_)),
	generate_est_mults(Nodes, Idxs, Graph, Strides, Stride),
	Total is Stride*DomSz.

% gen_e0(N, Zeros): Zeros is a list of N zeros.
gen_e0(0,[]) :- !.
gen_e0(Left,[0|Zeros]) :-
	Left1 is Left-1,
	gen_e0(Left1,Zeros).

% process_chains(ToDo, VarOrder, End, Start, Graph, Len, Est0, Estf)
% Performs ToDo sweeps over all chains: Start/Est0 are the chain states
% and counters before, End/Estf after the last sweep.
process_chains(0,_,Chains,Chains,_,_,Counts,Counts) :- !.
process_chains(Left,Order,ChainsF,Chains0,Graph,NNodes,Counts0,CountsF) :-
%format('ToDo = ~d~n',[Left]),
	process_chains(Chains0,Order,Chains1,Graph,NNodes,Counts0,Counts1),
% (Left mod 100 =:= 1 -> statistics,maplist(cvt2prob, Counts1, Probs), Chains1 =[S|_], format('did ~d: ~w~n ~w~n',[Left,Probs,S]) ; true),
	Left1 is Left-1,
	process_chains(Left1,Order,ChainsF,Chains1,Graph,NNodes,Counts1,CountsF).


% process_chains(Samples0, VarOrder, Samples, Graph, SampLen, E0s, Efs)
% One sweep: each chain gets a new sample term derived from its
% previous one, and each counter of that chain is updated with the new
% sample.
process_chains([], _, [], _, _,[],[]).
process_chains([Prev|Prevs], Order, [Next|Nexts], Graph, NNodes,[Counts0|Cs0],[Counts|Cs]) :-
	functor(Next,sample,NNodes),
	do_sample(Order,Next,Prev,Graph),
% format('Sample = ~w~n',[Next]),
	maplist(update_estimate(Next), Counts0, Counts),
	process_chains(Prevs, Order, Nexts, Graph, NNodes,Cs0,Cs).

% do_sample(VarOrder, Sample, Sample0, Graph): resample every node in
% VarOrder, in that order, filling the new Sample from the old Sample0.
do_sample([],_,_,_).
do_sample([Idx|Idxs],New,Old,Graph) :-
	do_var(Idx,New,Old,Graph),
	do_sample(Idxs,New,Old,Graph).

% do_var(I, Sample, Sample0, Graph)
% Draws a new value for node I. The values of the nodes in I's seventh
% slot are taken from the new sample where already set, otherwise from
% the old one. The cumulative distribution is read from the mblanket
% records when I was precompiled (explicit/1) and computed on the fly
% otherwise; a random number then selects the value.
do_var(Idx,New,Old,Graph) :-
	arg(Idx,Graph,var(_,_,_,_,_,CPTs,Blanket,_,_)),
	fetch_parents(Blanket,Idx,New,Old,Context),
	( explicit(Idx) ->
	  recorded(mblanket,m(Idx,Context,CumProbs),_)
	;
	  multiply_all_in_context(Blanket,Context,CPTs,Graph,CumProbs)
	),
	Draw is random,
	pick_new_value(CumProbs,Draw,0,Value),
	arg(Idx,New,Value).

% multiply_all_in_context(Parents, Args, CPTs, Graph, Vals)
% Temporarily binds the temp slots of Parents to Args, computes the
% cumulative distribution and saves it as a mall/1 fact; the forced
% failure undoes the slot bindings, and the second clause retrieves
% (and removes) the saved result.
multiply_all_in_context(Blanket,Context,CPTs,Graph,CumProbs) :-
	set_pos(Blanket,Context,Graph),
	multiply_all(CPTs,Graph,CumProbs),
	assert(mall(CumProbs)),
	fail.
multiply_all_in_context(_,_,_,_,CumProbs) :-
	retract(mall(CumProbs)).

% set_pos(Is, Args, Graph): unify the temp (third) slot of each node
% in Is with the corresponding element of Args.
set_pos([],[],_).
set_pos([Idx|Idxs],[Value|Values],Graph) :-
	arg(Idx,Graph,var(_,Idx,Value,_,_,_,_,_,_)),
	set_pos(Idxs,Values,Graph).

% fetch_parents(Parents, I, Sample, Sample0, Args)
% Args holds the current value of every node in Parents: the value in
% the new Sample when it is already bound, otherwise the value in the
% previous Sample0. The second argument is not used.
fetch_parents([],_,_,_,[]).
fetch_parents([Idx|Idxs],Self,New,Old,[Value|Values]) :-
	arg(Idx,New,Value),
	nonvar(Value), !,
	fetch_parents(Idxs,Self,New,Old,Values).
fetch_parents([Idx|Idxs],Self,New,Old,[Value|Values]) :-
	arg(Idx,Old,Value),
	fetch_parents(Idxs,Self,New,Old,Values).

% pick_new_value(Vals, X, I0, Val)
% Vals is a cumulative distribution (non-decreasing running sums) and
% X a random number. Val is I0 plus the position of the first entry
% that is greater than X.
% The last entry always matches: the running sums are floating point,
% so the final one may end up marginally at or below X, and without
% this fallback the predicate (and the whole sampling sweep) would
% simply fail. Fails only for an empty Vals.
pick_new_value([V|Vals],X,I0,Val) :-
	( X < V ->
	  Val = I0
	; Vals == [] ->
	  % rounding guard: nothing left to try, take the last value
	  Val = I0
	;
	  I is I0+1,
	  pick_new_value(Vals,X,I,Val)
	).

% update_estimate(Sample, Est0, Est)
% Adds one observation to a counter. For a single-node counter [I|E]
% the slot is the value of node I in Sample; for a joint counter
% me(Is,Mult,E) the slot is computed from the values of all nodes Is.
update_estimate(Sample, [Idx|Counts0],[Idx|Counts]) :-
	arg(Idx,Sample,Value),
	update_estimate_for_var(Value,Counts0,Counts).
update_estimate(Sample,me(Idxs,Strides,Counts0),me(Idxs,Strides,Counts)) :-
	get_estimate_pos(Idxs, Sample, Strides, 0, Slot),
	update_estimate_for_var(Slot,Counts0,Counts).

% get_estimate_pos(Is, Sample, Mults, V0, V): V is V0 plus the sum of
% value*multiplier over the nodes Is.
get_estimate_pos([], _, [], Slot, Slot).
get_estimate_pos([Idx|Idxs], Sample, [Stride|Strides], Slot0, Slot) :-
	arg(Idx,Sample,Value),
	Slot1 is Value*Stride+Slot0,
	get_estimate_pos(Idxs, Sample, Strides, Slot1, Slot).

% update_estimate_for_var(Pos, Counts0, Counts): Counts is Counts0
% with the element at 0-based position Pos incremented by one.
update_estimate_for_var(Pos,[Count|Rest],[NewCount|NewRest]) :-
	( Pos \== 0 ->
	  Pos1 is Pos-1,
	  NewCount = Count,
	  update_estimate_for_var(Pos1,Rest,NewRest)
	;
	  NewCount is Count+1,
	  NewRest = Rest
	).


% check_if_gibbs_done(Var): succeeds, without leaving a choice point,
% when Var carries a dist/1 attribute.
% NOTE(review): every other attribute access in this file goes through
% clpbn:get_atts/2 and uses dist/2; here get_atts/2 is unqualified and
% the attribute has arity 1 - confirm this is intended.
check_if_gibbs_done(Var) :-
	( get_atts(Var, [dist(_)]) -> true ).

% clean_up: discard all mblanket records and all explicit/1 facts.
% Each step is followed by fail so that the predicate succeeds exactly
% once whatever the individual steps do.
clean_up :-
	(
	  eraseall(mblanket),
	  fail
	;
	  retractall(explicit(_)),
	  fail
	;
	  true
	).

% default sampler settings, read by process/4:
% gibbs_params(NChains, BurnIn, NSamples).
gibbs_params(5,100,1000).

% cvt2prob([[I|Counts]], Ps): debugging aid; turns the single counter
% of a chain into relative frequencies.
cvt2prob([[_|Counts]], Freqs) :-
	foldl(sum_all, Counts, 0, Total),
	maplist( do_prob(Total), Counts, Freqs).

% sum_all(E, S0, Sum): fold step adding E to the running total S0.
sum_all(Count, Total0, Total) :-
	Total is Total0+Count.

% do_prob(Sum, E, P): P is E divided by Sum.
do_prob(Total, Count, Freq) :-
	Freq is Count/Total.

% show_sorted(VarOrder, Graph): debugging aid; prints the key attribute
% of every node in VarOrder on one line, separated by blanks.
show_sorted([], _) :-
	nl.
show_sorted([Idx|Idxs], Graph) :-
	arg(Idx,Graph,var(Var,Idx,_,_,_,_,_,_,_)),
	clpbn:get_atts(Var,[key(Key)]),
	format('~w ',[Key]),
	show_sorted(Idxs, Graph).

% sum_up_all(Chains, Dists)
% Chains holds, for every chain, a list with one counter per query.
% The counters are consumed position by position: the first counter of
% every chain is combined into the first distribution, and so on,
% until the first chain has no counter left.
sum_up_all([[]|_], []).
sum_up_all([[First|FirstRest]|Others], [Dist|Dists]) :-
	maplist( extract_sum, Others, Heads, Tails),
	sum_up([First|Heads], Dist),
	sum_up_all([FirstRest|Tails], Dists).

% extract_sum(List, Head, Tail): split a non-empty list.
extract_sum([Head|Tail], Head, Tail).

% sum_up(Counters, Dist)
% Counters are the counters of all chains for one query, either
% [I|Counts] or me(Is,Mults,Counts). The counts are added across
% chains and divided by their total.
sum_up([[_|First]|Others], Dist) :-
	add_up(First,Others, Totals),
	normalise(Totals, Dist).
sum_up([me(_,_,First)|Others], Dist) :-
	add_up_mes(First,Others, Totals),
	normalise(Totals, Dist).

% add_up(Counts, Counters, Add): Add is Counts plus, element by
% element, the counts of every [I|Cs] counter in Counters.
add_up(Totals,[],Totals).
add_up(Totals0,[[_|Counts]|Others], Totals) :-
	maplist(sum, Totals0, Counts, Totals1),
	add_up(Totals1, Others, Totals).

% add_up_mes(Counts, Counters, Add): Add is Counts plus, element by
% element, the counts of every me(Is,Mults,Cs) counter in Counters.
% The element-wise step is sum/3, exactly as in add_up/3: the elements
% are plain numbers, so the list-summing sum_list/3 (which is not even
% imported by this module) is not applicable here.
add_up_mes(Counts,[],Counts).
add_up_mes(Counts,[me(_,_,Cs)|Chains], Add) :-
	maplist( sum, Counts, Cs, NCounts),
	add_up_mes(NCounts, Chains, Add).

% sum(Count, C, NC): NC is Count+C.
sum(Count, C, NC) :-
	NC is Count+C.

% normalise(Add, Dist): Dist is Add with every element divided by the
% sum of all elements.
normalise(Totals, Dist) :-
	sum_list(Totals, Grand),
	maplist(divide(Grand), Totals, Dist).

% divide(Sum, C, P): P is C divided by Sum.
divide(Grand, Count, Share) :-
	Share is Count/Grand.