More cleanups
This commit is contained in:
parent
4b0acbf8c1
commit
0a661b0462
@ -8,8 +8,8 @@
|
|||||||
[set_horus_flag/1,
|
[set_horus_flag/1,
|
||||||
cpp_create_lifted_network/3,
|
cpp_create_lifted_network/3,
|
||||||
cpp_create_ground_network/4,
|
cpp_create_ground_network/4,
|
||||||
cpp_set_parfactors_params/2,
|
cpp_set_parfactors_params/3,
|
||||||
cpp_set_factors_params/2,
|
cpp_set_factors_params/3,
|
||||||
cpp_run_lifted_solver/3,
|
cpp_run_lifted_solver/3,
|
||||||
cpp_run_ground_solver/3,
|
cpp_run_ground_solver/3,
|
||||||
cpp_set_vars_information/2,
|
cpp_set_vars_information/2,
|
||||||
@ -18,8 +18,9 @@
|
|||||||
cpp_free_ground_network/1
|
cpp_free_ground_network/1
|
||||||
]).
|
]).
|
||||||
|
|
||||||
:- use_module(library(clpbn),
|
|
||||||
[set_clpbn_flag/2]).
|
:- catch(load_foreign_files([horus], [], init_predicates), _, patch_things_up)
|
||||||
|
-> true ; warning.
|
||||||
|
|
||||||
|
|
||||||
patch_things_up :-
|
patch_things_up :-
|
||||||
@ -27,11 +28,7 @@ patch_things_up :-
|
|||||||
|
|
||||||
|
|
||||||
warning :-
|
warning :-
|
||||||
format(user_error,"Horus library not installed: cannot use bp, fove~n.",[]).
|
format(user_error,"Horus library not installed: cannot use hve, bp, cbp, lve, lkc and lbp~n.",[]).
|
||||||
|
|
||||||
|
|
||||||
:- catch(load_foreign_files([horus], [], init_predicates), _, patch_things_up)
|
|
||||||
-> true ; warning.
|
|
||||||
|
|
||||||
|
|
||||||
set_horus_flag(K,V) :- cpp_set_horus_flag(K,V).
|
set_horus_flag(K,V) :- cpp_set_horus_flag(K,V).
|
||||||
|
@ -17,19 +17,24 @@
|
|||||||
|
|
||||||
:- use_module(horus,
|
:- use_module(horus,
|
||||||
[cpp_create_ground_network/4,
|
[cpp_create_ground_network/4,
|
||||||
cpp_set_factors_params/2,
|
cpp_set_factors_params/3,
|
||||||
cpp_run_ground_solver/3,
|
cpp_run_ground_solver/3,
|
||||||
cpp_free_ground_network/1,
|
cpp_free_ground_network/1,
|
||||||
cpp_set_vars_information/2
|
cpp_set_vars_information/2
|
||||||
]).
|
]).
|
||||||
|
|
||||||
:- use_module(library('clpbn/numbers')).
|
:- use_module(library('clpbn/numbers'),
|
||||||
|
[lists_of_keys_to_ids/6,
|
||||||
|
keys_to_numbers/7
|
||||||
|
]).
|
||||||
|
|
||||||
:- use_module(library('clpbn/display'),
|
:- use_module(library('clpbn/display'),
|
||||||
[clpbn_bind_vals/3]).
|
[clpbn_bind_vals/3]).
|
||||||
|
|
||||||
:- use_module(library(pfl),
|
:- use_module(library(pfl),
|
||||||
[skolem/2]).
|
[get_pfl_parameters/2,
|
||||||
|
skolem/2
|
||||||
|
]).
|
||||||
|
|
||||||
:- use_module(library(charsio),
|
:- use_module(library(charsio),
|
||||||
[term_to_atom/2]).
|
[term_to_atom/2]).
|
||||||
@ -37,40 +42,48 @@
|
|||||||
:- use_module(library(maplist)).
|
:- use_module(library(maplist)).
|
||||||
|
|
||||||
|
|
||||||
call_horus_ground_solver(QueryVars, QueryKeys, AllKeys, Factors, Evidence, Output) :-
|
call_horus_ground_solver(QueryVars, QueryKeys, AllKeys, Factors, Evidence,
|
||||||
|
Output) :-
|
||||||
init_horus_ground_solver(QueryKeys, AllKeys, Factors, Evidence, State),
|
init_horus_ground_solver(QueryKeys, AllKeys, Factors, Evidence, State),
|
||||||
run_horus_ground_solver([QueryKeys], Solutions, State),
|
run_horus_ground_solver([QueryKeys], Solutions, State),
|
||||||
clpbn_bind_vals([QueryVars], Solutions, Output),
|
clpbn_bind_vals([QueryVars], Solutions, Output),
|
||||||
end_horus_ground_solver(State).
|
end_horus_ground_solver(State).
|
||||||
|
|
||||||
|
|
||||||
run_horus_ground_solver(QueryKeys, Solutions, state(Network,Hash,Id)) :-
|
init_horus_ground_solver(QueryKeys, AllKeys, Factors, Evidence,
|
||||||
%maplist(get_dists_parameters, DistIds, DistsParams),
|
state(Network,Hash,Id,DistIds)) :-
|
||||||
%cpp_set_factors_params(Network, DistsParams),
|
factors_type(Factors, Type),
|
||||||
|
keys_to_numbers(AllKeys, Factors, Evidence, Hash, Id, FacIds, EvIds),
|
||||||
|
%writeln(network:(type=Type, factors=FacIds, evidence=EvIds)), nl,
|
||||||
|
cpp_create_ground_network(Type, FacIds, EvIds, Network),
|
||||||
|
%maplist(term_to_atom, AllKeys, VarNames),
|
||||||
|
%maplist(get_domain, AllKeys, Domains),
|
||||||
|
%cpp_set_vars_information(VarNames, Domains),
|
||||||
|
maplist(get_dist_id, FacIds, DistIds0),
|
||||||
|
sort(DistIds0, DistIds).
|
||||||
|
|
||||||
|
|
||||||
|
run_horus_ground_solver(QueryKeys, Solutions,
|
||||||
|
state(Network,Hash,Id, DistIds)) :-
|
||||||
lists_of_keys_to_ids(QueryKeys, QueryIds, Hash, _, Id, _),
|
lists_of_keys_to_ids(QueryKeys, QueryIds, Hash, _, Id, _),
|
||||||
|
%maplist(get_pfl_parameters, DistIds, DistParams),
|
||||||
|
%cpp_set_factors_params(Network, DistIds, DistParams),
|
||||||
cpp_run_ground_solver(Network, QueryIds, Solutions).
|
cpp_run_ground_solver(Network, QueryIds, Solutions).
|
||||||
|
|
||||||
|
|
||||||
init_horus_ground_solver(QueryKeys, AllKeys, Factors, Evidence, state(Network,Hash4,Id4)) :-
|
end_horus_ground_solver(state(Network,_Hash,_Id, _DistIds)) :-
|
||||||
get_factors_type(Factors, Type),
|
|
||||||
keys_to_numbers(AllKeys, Factors, Evidence, Hash4, Id4, FactorIds, EvidenceIds),
|
|
||||||
cpp_create_ground_network(Type, FactorIds, EvidenceIds, Network),
|
|
||||||
%writeln(network:(Type, FactorIds, EvidenceIds, Network)), writeln(''),
|
|
||||||
maplist(get_var_information, AllKeys, StatesNames),
|
|
||||||
maplist(term_to_atom, AllKeys, KeysAtoms),
|
|
||||||
cpp_set_vars_information(KeysAtoms, StatesNames).
|
|
||||||
|
|
||||||
|
|
||||||
end_horus_ground_solver(state(Network,_Hash,_Id)) :-
|
|
||||||
cpp_free_ground_network(Network).
|
cpp_free_ground_network(Network).
|
||||||
|
|
||||||
|
|
||||||
get_factors_type([f(bayes, _, _)|_], bayes) :- ! .
|
factors_type([f(bayes, _, _)|_], bayes) :- ! .
|
||||||
get_factors_type([f(markov, _, _)|_], markov) :- ! .
|
factors_type([f(markov, _, _)|_], markov) :- ! .
|
||||||
|
|
||||||
|
|
||||||
get_var_information(_:Key, Domain) :- !,
|
get_dist_id(f(_, _, _, DistId), DistId).
|
||||||
|
|
||||||
|
|
||||||
|
get_domain(_:Key, Domain) :- !,
|
||||||
skolem(Key, Domain).
|
skolem(Key, Domain).
|
||||||
get_var_information(Key, Domain) :-
|
get_domain(Key, Domain) :-
|
||||||
skolem(Key, Domain).
|
skolem(Key, Domain).
|
||||||
|
|
||||||
|
@ -17,7 +17,7 @@
|
|||||||
|
|
||||||
:- use_module(horus,
|
:- use_module(horus,
|
||||||
[cpp_create_lifted_network/3,
|
[cpp_create_lifted_network/3,
|
||||||
cpp_set_parfactors_params/2,
|
cpp_set_parfactors_params/3,
|
||||||
cpp_run_lifted_solver/3,
|
cpp_run_lifted_solver/3,
|
||||||
cpp_free_lifted_network/1
|
cpp_free_lifted_network/1
|
||||||
]).
|
]).
|
||||||
@ -43,22 +43,17 @@ call_horus_lifted_solver(QueryVars, AllVars, Output) :-
|
|||||||
|
|
||||||
init_horus_lifted_solver(_, AllVars, _, state(Network, DistIds)) :-
|
init_horus_lifted_solver(_, AllVars, _, state(Network, DistIds)) :-
|
||||||
get_parfactors(Parfactors),
|
get_parfactors(Parfactors),
|
||||||
get_dist_ids(Parfactors, DistIds0),
|
|
||||||
sort(DistIds0, DistIds),
|
|
||||||
get_observed_keys(AllVars, ObservedKeys),
|
get_observed_keys(AllVars, ObservedKeys),
|
||||||
%writeln(parfactors:Parfactors:'\n'),
|
%writeln(network:(parfactors=Parfactors, evidence=ObservedKeys)), nl,
|
||||||
%writeln(evidence:ObservedKeys:'\n'),
|
cpp_create_lifted_network(Parfactors, ObservedKeys, Network),
|
||||||
cpp_create_lifted_network(Parfactors, ObservedKeys, Network).
|
maplist(get_dist_id, Parfactors, DistIds0),
|
||||||
|
sort(DistIds0, DistIds).
|
||||||
|
|
||||||
|
|
||||||
run_horus_lifted_solver(QueryVars, Solutions, state(Network, DistIds)) :-
|
run_horus_lifted_solver(QueryVars, Solutions, state(Network, DistIds)) :-
|
||||||
maplist(get_query_keys, QueryVars, QueryKeys),
|
maplist(get_query_keys, QueryVars, QueryKeys),
|
||||||
get_dists_parameters(DistIds, DistsParams),
|
%maplist(get_pfl_parameters, DistIds,DistsParams),
|
||||||
%writeln(distparams1:DistsParams),
|
%cpp_set_parfactors_params(Network, DistIds, DistsParams),
|
||||||
%maplist(get_pfl_parameters, DistIds,DistsParams2),
|
|
||||||
%writeln(distparams1:DistsParams2),
|
|
||||||
%writeln(dists:DistsParams), writeln(''),
|
|
||||||
cpp_set_parfactors_params(Network, DistsParams),
|
|
||||||
cpp_run_lifted_solver(Network, QueryKeys, Solutions).
|
cpp_run_lifted_solver(Network, QueryKeys, Solutions).
|
||||||
|
|
||||||
|
|
||||||
@ -68,13 +63,6 @@ end_horus_lifted_solver(state(Network, _)) :-
|
|||||||
%
|
%
|
||||||
% Enumerate all parfactors and enumerate their domain as tuples.
|
% Enumerate all parfactors and enumerate their domain as tuples.
|
||||||
%
|
%
|
||||||
% output is list of pf(
|
|
||||||
% Id: an unique number
|
|
||||||
% Ks: a list of keys, also known as the pf formula [a(X),b(Y),c(X,Y)]
|
|
||||||
% Vs: the list of free variables [X,Y]
|
|
||||||
% Phi: the table following usual CLP(BN) convention
|
|
||||||
% Tuples: ground bindings for variables in Vs, of the form [fv(x,y)]
|
|
||||||
%
|
|
||||||
:- table get_parfactors/1.
|
:- table get_parfactors/1.
|
||||||
|
|
||||||
get_parfactors(Factors) :-
|
get_parfactors(Factors) :-
|
||||||
@ -90,8 +78,8 @@ is_factor(pf(Id, Ks, Rs, Phi, Tuples)) :-
|
|||||||
|
|
||||||
|
|
||||||
get_range(K, Range) :-
|
get_range(K, Range) :-
|
||||||
skolem(K,Domain),
|
skolem(K, Domain),
|
||||||
length(Domain,Range).
|
length(Domain, Range).
|
||||||
|
|
||||||
|
|
||||||
gen_table(Table, Phi) :-
|
gen_table(Table, Phi) :-
|
||||||
@ -108,9 +96,7 @@ run(Goal.Constraints) :-
|
|||||||
run(Constraints).
|
run(Constraints).
|
||||||
|
|
||||||
|
|
||||||
get_dist_ids([], []).
|
get_dist_id(pf(DistId, _, _, _, _), DistId).
|
||||||
get_dist_ids(pf(Id, _, _, _, _).Parfactors, Id.DistIds) :-
|
|
||||||
get_dist_ids(Parfactors, DistIds).
|
|
||||||
|
|
||||||
|
|
||||||
get_observed_keys([], []).
|
get_observed_keys([], []).
|
||||||
@ -118,8 +104,7 @@ get_observed_keys(V.AllAttVars, [K:E|ObservedKeys]) :-
|
|||||||
clpbn:get_atts(V,[key(K)]),
|
clpbn:get_atts(V,[key(K)]),
|
||||||
( clpbn:get_atts(V,[evidence(E)]) ; pfl:evidence(K,E) ), !,
|
( clpbn:get_atts(V,[evidence(E)]) ; pfl:evidence(K,E) ), !,
|
||||||
get_observed_keys(AllAttVars, ObservedKeys).
|
get_observed_keys(AllAttVars, ObservedKeys).
|
||||||
get_observed_keys(V.AllAttVars, ObservedKeys) :-
|
get_observed_keys(_V.AllAttVars, ObservedKeys) :-
|
||||||
clpbn:get_atts(V,[key(_K)]), !,
|
|
||||||
get_observed_keys(AllAttVars, ObservedKeys).
|
get_observed_keys(AllAttVars, ObservedKeys).
|
||||||
|
|
||||||
|
|
||||||
@ -128,9 +113,3 @@ get_query_keys(V.AttVars, K.Ks) :-
|
|||||||
clpbn:get_atts(V,[key(K)]), !,
|
clpbn:get_atts(V,[key(K)]), !,
|
||||||
get_query_keys(AttVars, Ks).
|
get_query_keys(AttVars, Ks).
|
||||||
|
|
||||||
|
|
||||||
get_dists_parameters([], []).
|
|
||||||
get_dists_parameters([Id|Ids], [dist(Id, Params)|DistsInfo]) :-
|
|
||||||
get_pfl_parameters(Id, Params),
|
|
||||||
get_dists_parameters(Ids, DistsInfo).
|
|
||||||
|
|
||||||
|
@ -233,19 +233,21 @@ setParfactorsParams (void)
|
|||||||
{
|
{
|
||||||
LiftedNetwork* network = (LiftedNetwork*) YAP_IntOfTerm (YAP_ARG1);
|
LiftedNetwork* network = (LiftedNetwork*) YAP_IntOfTerm (YAP_ARG1);
|
||||||
ParfactorList* pfList = network->first;
|
ParfactorList* pfList = network->first;
|
||||||
YAP_Term distList = YAP_ARG2;
|
YAP_Term distIdsList = YAP_ARG2;
|
||||||
|
YAP_Term paramsList = YAP_ARG3;
|
||||||
unordered_map<unsigned, Params> paramsMap;
|
unordered_map<unsigned, Params> paramsMap;
|
||||||
while (distList != YAP_TermNil()) {
|
while (distIdsList != YAP_TermNil()) {
|
||||||
YAP_Term dist = YAP_HeadOfTerm (distList);
|
unsigned distId = (unsigned) YAP_IntOfTerm (
|
||||||
unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (1, dist));
|
YAP_HeadOfTerm (distIdsList));
|
||||||
assert (Util::contains (paramsMap, distId) == false);
|
assert (Util::contains (paramsMap, distId) == false);
|
||||||
paramsMap[distId] = readParameters (YAP_ArgOfTerm (2, dist));
|
paramsMap[distId] = readParameters (YAP_HeadOfTerm (paramsList));
|
||||||
distList = YAP_TailOfTerm (distList);
|
distIdsList = YAP_TailOfTerm (distIdsList);
|
||||||
|
paramsList = YAP_TailOfTerm (paramsList);
|
||||||
}
|
}
|
||||||
ParfactorList::iterator it = pfList->begin();
|
ParfactorList::iterator it = pfList->begin();
|
||||||
while (it != pfList->end()) {
|
while (it != pfList->end()) {
|
||||||
assert (Util::contains (paramsMap, (*it)->distId()));
|
assert (Util::contains (paramsMap, (*it)->distId()));
|
||||||
// (*it)->setParams (paramsMap[(*it)->distId()]);
|
(*it)->setParams (paramsMap[(*it)->distId()]);
|
||||||
++ it;
|
++ it;
|
||||||
}
|
}
|
||||||
return TRUE;
|
return TRUE;
|
||||||
@ -256,16 +258,17 @@ setParfactorsParams (void)
|
|||||||
int
|
int
|
||||||
setFactorsParams (void)
|
setFactorsParams (void)
|
||||||
{
|
{
|
||||||
return TRUE; // TODO
|
|
||||||
FactorGraph* fg = (FactorGraph*) YAP_IntOfTerm (YAP_ARG1);
|
FactorGraph* fg = (FactorGraph*) YAP_IntOfTerm (YAP_ARG1);
|
||||||
YAP_Term distList = YAP_ARG2;
|
YAP_Term distIdsList = YAP_ARG2;
|
||||||
|
YAP_Term paramsList = YAP_ARG3;
|
||||||
unordered_map<unsigned, Params> paramsMap;
|
unordered_map<unsigned, Params> paramsMap;
|
||||||
while (distList != YAP_TermNil()) {
|
while (distIdsList != YAP_TermNil()) {
|
||||||
YAP_Term dist = YAP_HeadOfTerm (distList);
|
unsigned distId = (unsigned) YAP_IntOfTerm (
|
||||||
unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (1, dist));
|
YAP_HeadOfTerm (distIdsList));
|
||||||
assert (Util::contains (paramsMap, distId) == false);
|
assert (Util::contains (paramsMap, distId) == false);
|
||||||
paramsMap[distId] = readParameters (YAP_ArgOfTerm (2, dist));
|
paramsMap[distId] = readParameters (YAP_HeadOfTerm (paramsList));
|
||||||
distList = YAP_TailOfTerm (distList);
|
distIdsList = YAP_TailOfTerm (distIdsList);
|
||||||
|
paramsList = YAP_TailOfTerm (paramsList);
|
||||||
}
|
}
|
||||||
const FacNodes& facNodes = fg->facNodes();
|
const FacNodes& facNodes = fg->facNodes();
|
||||||
for (size_t i = 0; i < facNodes.size(); i++) {
|
for (size_t i = 0; i < facNodes.size(); i++) {
|
||||||
@ -534,15 +537,34 @@ fillAnswersPrologList (vector<Params>& results)
|
|||||||
extern "C" void
|
extern "C" void
|
||||||
init_predicates (void)
|
init_predicates (void)
|
||||||
{
|
{
|
||||||
YAP_UserCPredicate ("cpp_create_lifted_network", createLiftedNetwork, 3);
|
YAP_UserCPredicate ("cpp_create_lifted_network",
|
||||||
YAP_UserCPredicate ("cpp_create_ground_network", createGroundNetwork, 4);
|
createLiftedNetwork, 3);
|
||||||
YAP_UserCPredicate ("cpp_run_lifted_solver", runLiftedSolver, 3);
|
|
||||||
YAP_UserCPredicate ("cpp_run_ground_solver", runGroundSolver, 3);
|
YAP_UserCPredicate ("cpp_create_ground_network",
|
||||||
YAP_UserCPredicate ("cpp_set_parfactors_params", setParfactorsParams, 2);
|
createGroundNetwork, 4);
|
||||||
YAP_UserCPredicate ("cpp_cpp_set_factors_params", setFactorsParams, 2);
|
|
||||||
YAP_UserCPredicate ("cpp_set_vars_information", setVarsInformation, 2);
|
YAP_UserCPredicate ("cpp_run_lifted_solver",
|
||||||
YAP_UserCPredicate ("cpp_set_horus_flag", setHorusFlag, 2);
|
runLiftedSolver, 3);
|
||||||
YAP_UserCPredicate ("cpp_free_lifted_network", freeLiftedNetwork, 1);
|
|
||||||
YAP_UserCPredicate ("cpp_free_ground_network", freeGroundNetwork, 1);
|
YAP_UserCPredicate ("cpp_run_ground_solver",
|
||||||
|
runGroundSolver, 3);
|
||||||
|
|
||||||
|
YAP_UserCPredicate ("cpp_set_parfactors_params",
|
||||||
|
setParfactorsParams, 3);
|
||||||
|
|
||||||
|
YAP_UserCPredicate ("cpp_set_factors_params",
|
||||||
|
setFactorsParams, 3);
|
||||||
|
|
||||||
|
YAP_UserCPredicate ("cpp_set_vars_information",
|
||||||
|
setVarsInformation, 2);
|
||||||
|
|
||||||
|
YAP_UserCPredicate ("cpp_set_horus_flag",
|
||||||
|
setHorusFlag, 2);
|
||||||
|
|
||||||
|
YAP_UserCPredicate ("cpp_free_lifted_network",
|
||||||
|
freeLiftedNetwork, 1);
|
||||||
|
|
||||||
|
YAP_UserCPredicate ("cpp_free_ground_network",
|
||||||
|
freeGroundNetwork, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user