/*******************************************************

 Interface to Horus Lifted Solvers. Used by:
	- Generalized Counting First-Order Variable Elimination (GC-FOVE)
	- Lifted First-Order Belief Propagation
	- Lifted First-Order Knowledge Compilation

********************************************************/

:- module(clpbn_horus_lifted,
		[call_horus_lifted_solver/3,
		 check_if_horus_lifted_solver_done/1,
		 init_horus_lifted_solver/4,
		 run_horus_lifted_solver/3,
		 end_horus_lifted_solver/1
		]).

:- use_module(horus,
		[cpp_create_lifted_network/3,
		 cpp_set_parfactors_params/3,
		 cpp_run_lifted_solver/3,
		 cpp_free_lifted_network/1
		]).

:- use_module(library('clpbn/display'),
		[clpbn_bind_vals/3]).

:- use_module(library(pfl),
		[factor/6,
		 skolem/2,
		 get_pfl_parameters/3
		]).

:- use_module(library(maplist)).


call_horus_lifted_solver(QueryVars, AllVars, Output) :-
	init_horus_lifted_solver(_, AllVars, _, State),
	run_horus_lifted_solver(QueryVars, Solutions, State),
	clpbn_bind_vals(QueryVars, Solutions, Output),
	end_horus_lifted_solver(State).


init_horus_lifted_solver(_, AllVars, _, state(Network, DistIds)) :-
	get_parfactors(Parfactors),
	get_observed_keys(AllVars, ObservedKeys),
	%writeln(network:(parfactors=Parfactors, evidence=ObservedKeys)), nl,
	cpp_create_lifted_network(Parfactors, ObservedKeys, Network),
	maplist(get_dist_id, Parfactors, DistIds0),
	sort(DistIds0, DistIds).


run_horus_lifted_solver(QueryVars, Solutions, state(Network, _DistIds)) :-
	maplist(get_query_keys, QueryVars, QueryKeys),
	%maplist(get_pfl_parameters, DistIds, _, DistsParams),
	%cpp_set_parfactors_params(Network, DistIds, DistsParams),
	cpp_run_lifted_solver(Network, QueryKeys, Solutions).


end_horus_lifted_solver(state(Network, _)) :-
	cpp_free_lifted_network(Network).

%
% Enumerate all parfactors and enumerate their domain as tuples.
%
:- table get_parfactors/1.

get_parfactors(Factors) :-
	findall(F, is_factor(F), Factors).


is_factor(pf(Id, Ks, Rs, Phi, Tuples)) :-
	factor(_Type, Id, Ks, Vs, Table, Constraints),
	maplist(get_range, Ks, Rs),
	Table \= avg,
	gen_table(Table, Phi),
	all_tuples(Constraints, Vs, Tuples).


get_range(K, Range) :-
	skolem(K, Domain),
	length(Domain, Range).


gen_table(Table, Phi) :-
	( is_list(Table) -> Phi = Table ; call(user:Table, Phi) ).


all_tuples(Constraints, Tuple, Tuples) :-
	findall(Tuple, run(Constraints), Tuples0),
	sort(Tuples0, Tuples).


run([]).
run([Goal|Constraints]) :-
	user:Goal,
	run(Constraints).


get_dist_id(pf(DistId, _, _, _, _), DistId).


get_observed_keys([], []).
get_observed_keys(V.AllAttVars, [K:E|ObservedKeys]) :-
	clpbn:get_atts(V,[key(K)]),
	( clpbn:get_atts(V,[evidence(E)]) ; pfl:evidence(K,E) ), !,
	get_observed_keys(AllAttVars, ObservedKeys).
get_observed_keys(_V.AllAttVars, ObservedKeys) :-
	get_observed_keys(AllAttVars, ObservedKeys).


get_query_keys([], []).
get_query_keys(V.AttVars, K.Ks) :-
	clpbn:get_atts(V,[key(K)]), !,
	get_query_keys(AttVars, Ks).