This repository has been archived on 2023-08-20. You can view files and clone it, but cannot push or open issues or pull requests.
yap-6.3/packages/CLPBN/clpbn/horus_lifted.yap

117 lines
2.9 KiB
Plaintext
Raw Normal View History

2012-03-31 23:27:37 +01:00
/*******************************************************

 Interface to the following Horus lifted solvers:
 - Generalized Counting First-Order Variable Elimination (GC-FOVE)
 - Lifted First-Order Belief Propagation
 - Lifted First-Order Knowledge Compilation

********************************************************/
2012-03-22 11:29:46 +00:00
2012-05-23 19:15:23 +01:00
:- module(clpbn_horus_lifted,
[call_horus_lifted_solver/3,
check_if_horus_lifted_solver_done/1,
init_horus_lifted_solver/4,
run_horus_lifted_solver/3,
2012-12-18 12:11:45 +00:00
end_horus_lifted_solver/1
]).
2012-01-10 17:01:06 +00:00
2012-05-23 19:15:23 +01:00
:- use_module(horus,
[cpp_create_lifted_network/3,
2012-12-18 22:47:43 +00:00
cpp_set_parfactors_params/3,
cpp_run_lifted_solver/3,
cpp_free_lifted_network/1
]).
2012-01-10 17:01:06 +00:00
2012-03-22 11:29:46 +00:00
:- use_module(library('clpbn/display'),
[clpbn_bind_vals/3]).
2012-01-10 17:01:06 +00:00
2012-03-22 11:29:46 +00:00
:- use_module(library(pfl),
[factor/6,
skolem/2,
get_pfl_parameters/3
]).
2012-03-22 11:29:46 +00:00
2012-12-18 12:11:45 +00:00
:- use_module(library(maplist)).
2012-03-22 11:29:46 +00:00
2012-11-22 16:33:22 +00:00
%% call_horus_lifted_solver(+QueryVars, +AllVars, -Output)
%
%  One-shot driver: builds a lifted network from the current parfactors
%  and the evidence found on AllVars, solves the queries in QueryVars,
%  binds the results through clpbn_bind_vals/3 and frees the network.
%
%  The network is freed on every exit path: on success, when the solver
%  or the binding step fails, and when either raises an exception (the
%  exception is re-thrown after the network has been freed). Solving and
%  binding are committed to their first solution, so backtracking can
%  never re-enter code that uses the already-freed network.
%  NOTE(review): assumes no caller relied on alternative solutions of
%  clpbn_bind_vals/3 here -- confirm.
call_horus_lifted_solver(QueryVars, AllVars, Output) :-
    init_horus_lifted_solver(_, AllVars, _, State),
    catch(( run_horus_lifted_solver(QueryVars, Solutions, State),
            clpbn_bind_vals(QueryVars, Solutions, Output)
          ->  Solved = true
          ;   Solved = false
          ),
          Error,
          ( end_horus_lifted_solver(State), throw(Error) )),
    % Normal exit (success or failure): release the network exactly once.
    end_horus_lifted_solver(State),
    Solved == true.
2012-03-31 23:27:37 +01:00
2012-12-18 12:11:45 +00:00
%% init_horus_lifted_solver(+_, +AllVars, +_, -State)
%
%  Builds the lifted network for the Horus solvers from every parfactor
%  (get_parfactors/1) and from the K:E evidence pairs carried by AllVars
%  (get_observed_keys/2). State is state(Network, DistIds), where DistIds
%  is the sorted, duplicate-free list of the parfactors' distribution ids.
init_horus_lifted_solver(_, AllVars, _, state(Network, DistIds)) :-
    get_parfactors(Parfactors),
    % Pure bookkeeping first; it cannot fail for pf/5 terms.
    maplist(get_dist_id, Parfactors, UnsortedIds),
    sort(UnsortedIds, DistIds),
    get_observed_keys(AllVars, Evidence),
    cpp_create_lifted_network(Parfactors, Evidence, Network).
2012-11-22 16:33:22 +00:00
%% run_horus_lifted_solver(+QueryVars, -Solutions, +State)
%
%  QueryVars is a list of queries, each one a list of clpbn variables.
%  Every query is translated into its list of keys (get_query_keys/2) and
%  the whole batch is handed to the lifted solver held in State.
%
%  The distribution ids stored in State are currently unused: refreshing
%  the parfactor parameters before solving is disabled, i.e.
%    maplist(get_pfl_parameters, DistIds, _, DistsParams),
%    cpp_set_parfactors_params(Network, DistIds, DistsParams)
%  is not run.
run_horus_lifted_solver(QueryVars, Solutions, state(Network, _DistIds)) :-
    maplist(get_query_keys, QueryVars, KeysPerQuery),
    cpp_run_lifted_solver(Network, KeysPerQuery, Solutions).
2012-11-22 16:33:22 +00:00
2012-01-10 17:01:06 +00:00
2012-12-18 12:11:45 +00:00
%% end_horus_lifted_solver(+State)
%
%  Frees the lifted network stored in a state(Network, DistIds) term
%  produced by init_horus_lifted_solver/4.
end_horus_lifted_solver(State) :-
    State = state(Net, _Ids),
    cpp_free_lifted_network(Net).
2012-03-31 23:27:37 +01:00
2012-01-10 17:01:06 +00:00
%
% Collect every parfactor as a pf(Id, Keys, Ranges, Phi, Tuples) term
% (see is_factor/1); Tuples enumerates the factor's domain as tuples.
% NOTE(review): the predicate is tabled, so the list is built once and
% reused; factors added after the first call are not seen -- confirm this
% is intended.
%
:- table get_parfactors/1.
get_parfactors(Parfactors) :-
    findall(Parfactor, is_factor(Parfactor), Parfactors).
2012-03-31 23:27:37 +01:00
2012-01-10 17:01:06 +00:00
2012-03-22 11:29:46 +00:00
%% is_factor(-Parfactor)
%
%  Enumerates the PFL factors (factor/6) as pf(Id, Ks, Rs, Phi, Tuples):
%    Id     - the factor's distribution id (see get_dist_id/2)
%    Ks     - the factor's keys
%    Rs     - domain size of each key (get_range/2)
%    Phi    - the factor's table (gen_table/2)
%    Tuples - sorted, duplicate-free instances of the factor's variables
%             that satisfy its constraints (all_tuples/3)
%  Factors whose table is the atom avg are skipped.
is_factor(pf(Id, Ks, Rs, Phi, Tuples)) :-
    factor(_Type, Id, Ks, Vs, Table, Constraints),
    % Reject avg factors before doing any per-key skolem/2 lookups.
    Table \= avg,
    maplist(get_range, Ks, Rs),
    gen_table(Table, Phi),
    all_tuples(Constraints, Vs, Tuples).
2012-01-10 17:01:06 +00:00
2012-03-22 11:29:46 +00:00
2012-12-18 12:11:45 +00:00
%% get_range(+Key, -Range)
%
%  Range is the number of values in the domain that skolem/2 declares
%  for Key.
get_range(Key, Range) :-
    skolem(Key, Values),
    length(Values, Range).
2012-03-22 11:29:46 +00:00
2012-01-10 17:01:06 +00:00
%% gen_table(+Table, -Phi)
%
%  A list Table is used as-is. Anything else is treated as a closure that
%  is called in module user with Phi as its extra argument.
gen_table(Table, Phi) :-
    is_list(Table), !,
    Phi = Table.
gen_table(Table, Phi) :-
    call(user:Table, Phi).
2012-03-31 23:27:37 +01:00
2012-03-22 11:29:46 +00:00
2012-01-10 17:01:06 +00:00
%% all_tuples(+Constraints, +Template, -Tuples)
%
%  Tuples is the sorted, duplicate-free list of the instances of Template
%  for which the constraint goals (run/1) succeed.
all_tuples(Constraints, Template, Tuples) :-
    findall(Template, run(Constraints), Found),
    sort(Found, Tuples).
2012-01-10 17:01:06 +00:00
2012-03-22 11:29:46 +00:00
2012-01-10 17:01:06 +00:00
%% run(+Goals)
%
%  Calls every goal in the list Goals in module user, left to right,
%  backtracking into earlier goals for alternative solutions.
run([]).
run([Goal|Goals]) :-
    call(user:Goal),
    run(Goals).
2012-01-10 17:01:06 +00:00
2012-03-22 11:29:46 +00:00
2012-12-18 22:47:43 +00:00
%% get_dist_id(+Parfactor, -DistId)
%
%  DistId is the first argument of a pf/5 parfactor term.
get_dist_id(Parfactor, DistId) :-
    Parfactor = pf(DistId, _Keys, _Ranges, _Phi, _Tuples).
2012-03-22 11:29:46 +00:00
2012-12-18 12:11:45 +00:00
%% get_observed_keys(+AttVars, -ObservedKeys)
%
%  ObservedKeys holds K:E, in input order, for each variable in AttVars
%  that has a clpbn key(K) attribute and evidence E. The evidence comes
%  either from a clpbn evidence(E) attribute or from pfl:evidence(K, E),
%  and only the first one found is used. Other variables are skipped.
get_observed_keys([], []).
get_observed_keys([V|AttVars], ObservedKeys) :-
    (   clpbn:get_atts(V, [key(K)]),
        (   clpbn:get_atts(V, [evidence(E)])
        ;   pfl:evidence(K, E)
        )
    ->  ObservedKeys = [K:E|Rest]
    ;   ObservedKeys = Rest
    ),
    get_observed_keys(AttVars, Rest).
2012-03-22 11:29:46 +00:00
2012-11-22 16:33:22 +00:00
%% get_query_keys(+AttVars, -Keys)
%
%  Keys holds the clpbn key of every variable in AttVars, in order.
%  Fails if any variable lacks a key attribute.
get_query_keys(AttVars, Keys) :-
    maplist(attvar_key, AttVars, Keys).

% attvar_key(+V, -K): K is V's clpbn key attribute (first answer only).
attvar_key(V, K) :-
    clpbn:get_atts(V, [key(K)]),
    !.
2012-03-22 11:29:46 +00:00