minor improvements
This commit is contained in:
parent
ee5b8e693a
commit
57edd6adb9
@ -15,7 +15,7 @@
|
|||||||
cpp_run_ground_solver/3,
|
cpp_run_ground_solver/3,
|
||||||
cpp_set_vars_information/2,
|
cpp_set_vars_information/2,
|
||||||
cpp_set_horus_flag/2,
|
cpp_set_horus_flag/2,
|
||||||
cpp_free_parfactors/1,
|
cpp_free_lifted_network/1,
|
||||||
cpp_free_ground_network/1
|
cpp_free_ground_network/1
|
||||||
]).
|
]).
|
||||||
|
|
||||||
|
@ -17,7 +17,7 @@
|
|||||||
[cpp_create_lifted_network/3,
|
[cpp_create_lifted_network/3,
|
||||||
cpp_set_parfactors_params/2,
|
cpp_set_parfactors_params/2,
|
||||||
cpp_run_lifted_solver/3,
|
cpp_run_lifted_solver/3,
|
||||||
cpp_free_parfactors/1
|
cpp_free_lifted_network/1
|
||||||
]).
|
]).
|
||||||
|
|
||||||
:- use_module(library('clpbn/display'),
|
:- use_module(library('clpbn/display'),
|
||||||
@ -144,5 +144,5 @@ run_horus_lifted_solver(QueryVarsAtts, Solutions, fove(ParfactorList, DistIds))
|
|||||||
|
|
||||||
|
|
||||||
finalize_horus_lifted_solver(fove(ParfactorList, _)) :-
|
finalize_horus_lifted_solver(fove(ParfactorList, _)) :-
|
||||||
cpp_free_parfactors(ParfactorList).
|
cpp_free_lifted_network(ParfactorList).
|
||||||
|
|
||||||
|
@ -9,12 +9,8 @@
|
|||||||
#include "Var.h"
|
#include "Var.h"
|
||||||
#include "Horus.h"
|
#include "Horus.h"
|
||||||
|
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
|
|
||||||
class Var;
|
|
||||||
|
|
||||||
class BBNode : public Var
|
class BBNode : public Var
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
|
@ -20,10 +20,8 @@
|
|||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
|
|
||||||
typedef std::pair<ParfactorList*, ObservedFormulas*> LiftedNetwork;
|
typedef std::pair<ParfactorList*, ObservedFormulas*> LiftedNetwork;
|
||||||
|
|
||||||
|
|
||||||
Params readParameters (YAP_Term);
|
Params readParameters (YAP_Term);
|
||||||
|
|
||||||
vector<unsigned> readUnsignedList (YAP_Term);
|
vector<unsigned> readUnsignedList (YAP_Term);
|
||||||
@ -32,14 +30,6 @@ void readLiftedEvidence (YAP_Term, ObservedFormulas&);
|
|||||||
|
|
||||||
Parfactor* readParfactor (YAP_Term);
|
Parfactor* readParfactor (YAP_Term);
|
||||||
|
|
||||||
void runVeSolver (FactorGraph* fg, const vector<VarIds>& tasks,
|
|
||||||
vector<Params>& results);
|
|
||||||
|
|
||||||
void runBeliefProp (FactorGraph* fg, const vector<VarIds>& tasks,
|
|
||||||
vector<Params>& results);
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
vector<unsigned>
|
vector<unsigned>
|
||||||
readUnsignedList (YAP_Term list)
|
readUnsignedList (YAP_Term list)
|
||||||
@ -54,7 +44,8 @@ readUnsignedList (YAP_Term list)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
int createLiftedNetwork (void)
|
int
|
||||||
|
createLiftedNetwork (void)
|
||||||
{
|
{
|
||||||
Parfactors parfactors;
|
Parfactors parfactors;
|
||||||
YAP_Term parfactorList = YAP_ARG1;
|
YAP_Term parfactorList = YAP_ARG1;
|
||||||
@ -91,7 +82,8 @@ int createLiftedNetwork (void)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
Parfactor* readParfactor (YAP_Term pfTerm)
|
Parfactor*
|
||||||
|
readParfactor (YAP_Term pfTerm)
|
||||||
{
|
{
|
||||||
// read dist id
|
// read dist id
|
||||||
unsigned distId = YAP_IntOfTerm (YAP_ArgOfTerm (1, pfTerm));
|
unsigned distId = YAP_IntOfTerm (YAP_ArgOfTerm (1, pfTerm));
|
||||||
@ -171,7 +163,8 @@ Parfactor* readParfactor (YAP_Term pfTerm)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
void readLiftedEvidence (
|
void
|
||||||
|
readLiftedEvidence (
|
||||||
YAP_Term observedList,
|
YAP_Term observedList,
|
||||||
ObservedFormulas& obsFormulas)
|
ObservedFormulas& obsFormulas)
|
||||||
{
|
{
|
||||||
@ -360,11 +353,42 @@ runGroundSolver (void)
|
|||||||
taskList = YAP_TailOfTerm (taskList);
|
taskList = YAP_TailOfTerm (taskList);
|
||||||
}
|
}
|
||||||
|
|
||||||
vector<Params> results;
|
std::set<VarId> vids;
|
||||||
|
for (size_t i = 0; i < tasks.size(); i++) {
|
||||||
|
Util::addToSet (vids, tasks[i]);
|
||||||
|
}
|
||||||
|
Solver* solver = 0;
|
||||||
|
FactorGraph* mfg = fg;
|
||||||
|
if (fg->bayesianFactors()) {
|
||||||
|
mfg = BayesBall::getMinimalFactorGraph (
|
||||||
|
*fg, VarIds (vids.begin(), vids.end()));
|
||||||
|
}
|
||||||
|
|
||||||
if (Globals::groundSolver == GroundSolver::VE) {
|
if (Globals::groundSolver == GroundSolver::VE) {
|
||||||
runVeSolver (fg, tasks, results);
|
solver = new VarElim (*mfg);
|
||||||
|
} else if (Globals::groundSolver == GroundSolver::BP) {
|
||||||
|
solver = new BeliefProp (*mfg);
|
||||||
|
} else if (Globals::groundSolver == GroundSolver::CBP) {
|
||||||
|
CountingBp::checkForIdenticalFactors = false;
|
||||||
|
solver = new CountingBp (*mfg);
|
||||||
} else {
|
} else {
|
||||||
runBeliefProp (fg, tasks, results);
|
assert (false);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (Globals::verbosity > 0) {
|
||||||
|
solver->printSolverFlags();
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
vector<Params> results;
|
||||||
|
results.reserve (tasks.size());
|
||||||
|
for (size_t i = 0; i < tasks.size(); i++) {
|
||||||
|
results.push_back (solver->solveQuery (tasks[i]));
|
||||||
|
}
|
||||||
|
|
||||||
|
delete solver;
|
||||||
|
if (fg->bayesianFactors()) {
|
||||||
|
delete mfg;
|
||||||
}
|
}
|
||||||
|
|
||||||
YAP_Term list = YAP_TermNil();
|
YAP_Term list = YAP_TermNil();
|
||||||
@ -385,71 +409,6 @@ runGroundSolver (void)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
void runVeSolver (
|
|
||||||
FactorGraph* fg,
|
|
||||||
const vector<VarIds>& tasks,
|
|
||||||
vector<Params>& results)
|
|
||||||
{
|
|
||||||
results.reserve (tasks.size());
|
|
||||||
for (size_t i = 0; i < tasks.size(); i++) {
|
|
||||||
FactorGraph* mfg = fg;
|
|
||||||
if (fg->bayesianFactors()) {
|
|
||||||
mfg = BayesBall::getMinimalFactorGraph (*fg, tasks[i]);
|
|
||||||
}
|
|
||||||
VarElim solver (*mfg);
|
|
||||||
if (Globals::verbosity > 0 && i == 0) {
|
|
||||||
solver.printSolverFlags();
|
|
||||||
cout << endl;
|
|
||||||
}
|
|
||||||
results.push_back (solver.solveQuery (tasks[i]));
|
|
||||||
if (fg->bayesianFactors()) {
|
|
||||||
delete mfg;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void runBeliefProp (
|
|
||||||
FactorGraph* fg,
|
|
||||||
const vector<VarIds>& tasks,
|
|
||||||
vector<Params>& results)
|
|
||||||
{
|
|
||||||
std::set<VarId> vids;
|
|
||||||
for (size_t i = 0; i < tasks.size(); i++) {
|
|
||||||
Util::addToSet (vids, tasks[i]);
|
|
||||||
}
|
|
||||||
Solver* solver = 0;
|
|
||||||
FactorGraph* mfg = fg;
|
|
||||||
if (fg->bayesianFactors()) {
|
|
||||||
mfg = BayesBall::getMinimalFactorGraph (
|
|
||||||
*fg, VarIds (vids.begin(),vids.end()));
|
|
||||||
}
|
|
||||||
if (Globals::groundSolver == GroundSolver::BP) {
|
|
||||||
solver = new BeliefProp (*mfg);
|
|
||||||
} else if (Globals::groundSolver == GroundSolver::CBP) {
|
|
||||||
CountingBp::checkForIdenticalFactors = false;
|
|
||||||
solver = new CountingBp (*mfg);
|
|
||||||
} else {
|
|
||||||
cerr << "error: unknow solver" << endl;
|
|
||||||
abort();
|
|
||||||
}
|
|
||||||
if (Globals::verbosity > 0) {
|
|
||||||
solver->printSolverFlags();
|
|
||||||
cout << endl;
|
|
||||||
}
|
|
||||||
results.reserve (tasks.size());
|
|
||||||
for (size_t i = 0; i < tasks.size(); i++) {
|
|
||||||
results.push_back (solver->solveQuery (tasks[i]));
|
|
||||||
}
|
|
||||||
if (fg->bayesianFactors()) {
|
|
||||||
delete mfg;
|
|
||||||
}
|
|
||||||
delete solver;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
int
|
int
|
||||||
setParfactorsParams (void)
|
setParfactorsParams (void)
|
||||||
{
|
{
|
||||||
@ -565,7 +524,7 @@ freeGroundNetwork (void)
|
|||||||
|
|
||||||
|
|
||||||
int
|
int
|
||||||
freeParfactors (void)
|
freeLiftedNetwork (void)
|
||||||
{
|
{
|
||||||
LiftedNetwork* network = (LiftedNetwork*) YAP_IntOfTerm (YAP_ARG1);
|
LiftedNetwork* network = (LiftedNetwork*) YAP_IntOfTerm (YAP_ARG1);
|
||||||
delete network->first;
|
delete network->first;
|
||||||
@ -587,7 +546,7 @@ init_predicates (void)
|
|||||||
YAP_UserCPredicate ("cpp_cpp_set_factors_params", setFactorsParams, 2);
|
YAP_UserCPredicate ("cpp_cpp_set_factors_params", setFactorsParams, 2);
|
||||||
YAP_UserCPredicate ("cpp_set_vars_information", setVarsInformation, 2);
|
YAP_UserCPredicate ("cpp_set_vars_information", setVarsInformation, 2);
|
||||||
YAP_UserCPredicate ("cpp_set_horus_flag", setHorusFlag, 2);
|
YAP_UserCPredicate ("cpp_set_horus_flag", setHorusFlag, 2);
|
||||||
YAP_UserCPredicate ("cpp_free_parfactors", freeParfactors, 1);
|
YAP_UserCPredicate ("cpp_free_lifted_network", freeLiftedNetwork, 1);
|
||||||
YAP_UserCPredicate ("cpp_free_ground_network", freeGroundNetwork, 1);
|
YAP_UserCPredicate ("cpp_free_ground_network", freeGroundNetwork, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user