minor improvements
This commit is contained in:
parent
ee5b8e693a
commit
57edd6adb9
@ -15,7 +15,7 @@
|
||||
cpp_run_ground_solver/3,
|
||||
cpp_set_vars_information/2,
|
||||
cpp_set_horus_flag/2,
|
||||
cpp_free_parfactors/1,
|
||||
cpp_free_lifted_network/1,
|
||||
cpp_free_ground_network/1
|
||||
]).
|
||||
|
||||
|
@ -17,7 +17,7 @@
|
||||
[cpp_create_lifted_network/3,
|
||||
cpp_set_parfactors_params/2,
|
||||
cpp_run_lifted_solver/3,
|
||||
cpp_free_parfactors/1
|
||||
cpp_free_lifted_network/1
|
||||
]).
|
||||
|
||||
:- use_module(library('clpbn/display'),
|
||||
@ -144,5 +144,5 @@ run_horus_lifted_solver(QueryVarsAtts, Solutions, fove(ParfactorList, DistIds))
|
||||
|
||||
|
||||
finalize_horus_lifted_solver(fove(ParfactorList, _)) :-
|
||||
cpp_free_parfactors(ParfactorList).
|
||||
cpp_free_lifted_network(ParfactorList).
|
||||
|
||||
|
@ -9,12 +9,8 @@
|
||||
#include "Var.h"
|
||||
#include "Horus.h"
|
||||
|
||||
|
||||
using namespace std;
|
||||
|
||||
|
||||
class Var;
|
||||
|
||||
class BBNode : public Var
|
||||
{
|
||||
public:
|
||||
|
@ -20,10 +20,8 @@
|
||||
|
||||
using namespace std;
|
||||
|
||||
|
||||
typedef std::pair<ParfactorList*, ObservedFormulas*> LiftedNetwork;
|
||||
|
||||
|
||||
Params readParameters (YAP_Term);
|
||||
|
||||
vector<unsigned> readUnsignedList (YAP_Term);
|
||||
@ -32,14 +30,6 @@ void readLiftedEvidence (YAP_Term, ObservedFormulas&);
|
||||
|
||||
Parfactor* readParfactor (YAP_Term);
|
||||
|
||||
void runVeSolver (FactorGraph* fg, const vector<VarIds>& tasks,
|
||||
vector<Params>& results);
|
||||
|
||||
void runBeliefProp (FactorGraph* fg, const vector<VarIds>& tasks,
|
||||
vector<Params>& results);
|
||||
|
||||
|
||||
|
||||
|
||||
vector<unsigned>
|
||||
readUnsignedList (YAP_Term list)
|
||||
@ -54,7 +44,8 @@ readUnsignedList (YAP_Term list)
|
||||
|
||||
|
||||
|
||||
int createLiftedNetwork (void)
|
||||
int
|
||||
createLiftedNetwork (void)
|
||||
{
|
||||
Parfactors parfactors;
|
||||
YAP_Term parfactorList = YAP_ARG1;
|
||||
@ -91,7 +82,8 @@ int createLiftedNetwork (void)
|
||||
|
||||
|
||||
|
||||
Parfactor* readParfactor (YAP_Term pfTerm)
|
||||
Parfactor*
|
||||
readParfactor (YAP_Term pfTerm)
|
||||
{
|
||||
// read dist id
|
||||
unsigned distId = YAP_IntOfTerm (YAP_ArgOfTerm (1, pfTerm));
|
||||
@ -171,7 +163,8 @@ Parfactor* readParfactor (YAP_Term pfTerm)
|
||||
|
||||
|
||||
|
||||
void readLiftedEvidence (
|
||||
void
|
||||
readLiftedEvidence (
|
||||
YAP_Term observedList,
|
||||
ObservedFormulas& obsFormulas)
|
||||
{
|
||||
@ -360,11 +353,42 @@ runGroundSolver (void)
|
||||
taskList = YAP_TailOfTerm (taskList);
|
||||
}
|
||||
|
||||
vector<Params> results;
|
||||
std::set<VarId> vids;
|
||||
for (size_t i = 0; i < tasks.size(); i++) {
|
||||
Util::addToSet (vids, tasks[i]);
|
||||
}
|
||||
Solver* solver = 0;
|
||||
FactorGraph* mfg = fg;
|
||||
if (fg->bayesianFactors()) {
|
||||
mfg = BayesBall::getMinimalFactorGraph (
|
||||
*fg, VarIds (vids.begin(), vids.end()));
|
||||
}
|
||||
|
||||
if (Globals::groundSolver == GroundSolver::VE) {
|
||||
runVeSolver (fg, tasks, results);
|
||||
solver = new VarElim (*mfg);
|
||||
} else if (Globals::groundSolver == GroundSolver::BP) {
|
||||
solver = new BeliefProp (*mfg);
|
||||
} else if (Globals::groundSolver == GroundSolver::CBP) {
|
||||
CountingBp::checkForIdenticalFactors = false;
|
||||
solver = new CountingBp (*mfg);
|
||||
} else {
|
||||
runBeliefProp (fg, tasks, results);
|
||||
assert (false);
|
||||
}
|
||||
|
||||
if (Globals::verbosity > 0) {
|
||||
solver->printSolverFlags();
|
||||
cout << endl;
|
||||
}
|
||||
|
||||
vector<Params> results;
|
||||
results.reserve (tasks.size());
|
||||
for (size_t i = 0; i < tasks.size(); i++) {
|
||||
results.push_back (solver->solveQuery (tasks[i]));
|
||||
}
|
||||
|
||||
delete solver;
|
||||
if (fg->bayesianFactors()) {
|
||||
delete mfg;
|
||||
}
|
||||
|
||||
YAP_Term list = YAP_TermNil();
|
||||
@ -385,71 +409,6 @@ runGroundSolver (void)
|
||||
|
||||
|
||||
|
||||
void runVeSolver (
|
||||
FactorGraph* fg,
|
||||
const vector<VarIds>& tasks,
|
||||
vector<Params>& results)
|
||||
{
|
||||
results.reserve (tasks.size());
|
||||
for (size_t i = 0; i < tasks.size(); i++) {
|
||||
FactorGraph* mfg = fg;
|
||||
if (fg->bayesianFactors()) {
|
||||
mfg = BayesBall::getMinimalFactorGraph (*fg, tasks[i]);
|
||||
}
|
||||
VarElim solver (*mfg);
|
||||
if (Globals::verbosity > 0 && i == 0) {
|
||||
solver.printSolverFlags();
|
||||
cout << endl;
|
||||
}
|
||||
results.push_back (solver.solveQuery (tasks[i]));
|
||||
if (fg->bayesianFactors()) {
|
||||
delete mfg;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void runBeliefProp (
|
||||
FactorGraph* fg,
|
||||
const vector<VarIds>& tasks,
|
||||
vector<Params>& results)
|
||||
{
|
||||
std::set<VarId> vids;
|
||||
for (size_t i = 0; i < tasks.size(); i++) {
|
||||
Util::addToSet (vids, tasks[i]);
|
||||
}
|
||||
Solver* solver = 0;
|
||||
FactorGraph* mfg = fg;
|
||||
if (fg->bayesianFactors()) {
|
||||
mfg = BayesBall::getMinimalFactorGraph (
|
||||
*fg, VarIds (vids.begin(),vids.end()));
|
||||
}
|
||||
if (Globals::groundSolver == GroundSolver::BP) {
|
||||
solver = new BeliefProp (*mfg);
|
||||
} else if (Globals::groundSolver == GroundSolver::CBP) {
|
||||
CountingBp::checkForIdenticalFactors = false;
|
||||
solver = new CountingBp (*mfg);
|
||||
} else {
|
||||
cerr << "error: unknow solver" << endl;
|
||||
abort();
|
||||
}
|
||||
if (Globals::verbosity > 0) {
|
||||
solver->printSolverFlags();
|
||||
cout << endl;
|
||||
}
|
||||
results.reserve (tasks.size());
|
||||
for (size_t i = 0; i < tasks.size(); i++) {
|
||||
results.push_back (solver->solveQuery (tasks[i]));
|
||||
}
|
||||
if (fg->bayesianFactors()) {
|
||||
delete mfg;
|
||||
}
|
||||
delete solver;
|
||||
}
|
||||
|
||||
|
||||
|
||||
int
|
||||
setParfactorsParams (void)
|
||||
{
|
||||
@ -565,7 +524,7 @@ freeGroundNetwork (void)
|
||||
|
||||
|
||||
int
|
||||
freeParfactors (void)
|
||||
freeLiftedNetwork (void)
|
||||
{
|
||||
LiftedNetwork* network = (LiftedNetwork*) YAP_IntOfTerm (YAP_ARG1);
|
||||
delete network->first;
|
||||
@ -587,7 +546,7 @@ init_predicates (void)
|
||||
YAP_UserCPredicate ("cpp_cpp_set_factors_params", setFactorsParams, 2);
|
||||
YAP_UserCPredicate ("cpp_set_vars_information", setVarsInformation, 2);
|
||||
YAP_UserCPredicate ("cpp_set_horus_flag", setHorusFlag, 2);
|
||||
YAP_UserCPredicate ("cpp_free_parfactors", freeParfactors, 1);
|
||||
YAP_UserCPredicate ("cpp_free_lifted_network", freeLiftedNetwork, 1);
|
||||
YAP_UserCPredicate ("cpp_free_ground_network", freeGroundNetwork, 1);
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user