Introduce a lifted solver class
This commit is contained in:
parent
6e7d0d1d0a
commit
64a27847cc
@ -12,7 +12,7 @@
|
||||
#include "Horus.h"
|
||||
|
||||
|
||||
BeliefProp::BeliefProp (const FactorGraph& fg) : Solver (fg)
|
||||
BeliefProp::BeliefProp (const FactorGraph& fg) : GroundSolver (fg)
|
||||
{
|
||||
runned_ = false;
|
||||
}
|
||||
@ -377,7 +377,8 @@ BeliefProp::getVarToFactorMsg (const BpLink* link) const
|
||||
Params
|
||||
BeliefProp::getJointByConditioning (const VarIds& jointVarIds) const
|
||||
{
|
||||
return Solver::getJointByConditioning (GroundSolver::BP, fg, jointVarIds);
|
||||
return GroundSolver::getJointByConditioning (
|
||||
GroundSolverType::BP, fg, jointVarIds);
|
||||
}
|
||||
|
||||
|
||||
|
@ -5,7 +5,7 @@
|
||||
#include <vector>
|
||||
#include <sstream>
|
||||
|
||||
#include "Solver.h"
|
||||
#include "GroundSolver.h"
|
||||
#include "Factor.h"
|
||||
#include "FactorGraph.h"
|
||||
#include "Util.h"
|
||||
@ -83,7 +83,7 @@ class SPNodeInfo
|
||||
};
|
||||
|
||||
|
||||
class BeliefProp : public Solver
|
||||
class BeliefProp : public GroundSolver
|
||||
{
|
||||
public:
|
||||
BeliefProp (const FactorGraph&);
|
||||
|
@ -6,7 +6,7 @@ bool CountingBp::checkForIdenticalFactors = true;
|
||||
|
||||
|
||||
CountingBp::CountingBp (const FactorGraph& fg)
|
||||
: Solver (fg), freeColor_(0)
|
||||
: GroundSolver (fg), freeColor_(0)
|
||||
{
|
||||
findIdenticalFactors();
|
||||
setInitialColors();
|
||||
@ -74,8 +74,8 @@ CountingBp::solveQuery (VarIds queryVids)
|
||||
cout << endl;
|
||||
}
|
||||
if (idx == facNodes.size()) {
|
||||
res = Solver::getJointByConditioning (
|
||||
GroundSolver::CBP, fg, queryVids);
|
||||
res = GroundSolver::getJointByConditioning (
|
||||
GroundSolverType::CBP, fg, queryVids);
|
||||
} else {
|
||||
VarIds reprArgs;
|
||||
for (size_t i = 0; i < queryVids.size(); i++) {
|
||||
|
@ -3,7 +3,7 @@
|
||||
|
||||
#include <unordered_map>
|
||||
|
||||
#include "Solver.h"
|
||||
#include "GroundSolver.h"
|
||||
#include "FactorGraph.h"
|
||||
#include "Util.h"
|
||||
#include "Horus.h"
|
||||
@ -102,7 +102,7 @@ class FacCluster
|
||||
};
|
||||
|
||||
|
||||
class CountingBp : public Solver
|
||||
class CountingBp : public GroundSolver
|
||||
{
|
||||
public:
|
||||
CountingBp (const FactorGraph& fg);
|
||||
|
@ -28,7 +28,7 @@ typedef vector<unsigned> Ranges;
|
||||
typedef unsigned long long ullong;
|
||||
|
||||
|
||||
enum LiftedSolver
|
||||
enum LiftedSolverType
|
||||
{
|
||||
LVE, // generalized counting first-order variable elimination (GC-FOVE)
|
||||
LBP, // lifted first-order belief propagation
|
||||
@ -36,7 +36,7 @@ enum LiftedSolver
|
||||
};
|
||||
|
||||
|
||||
enum GroundSolver
|
||||
enum GroundSolverType
|
||||
{
|
||||
VE, // variable elimination
|
||||
BP, // belief propagation
|
||||
@ -51,8 +51,8 @@ extern bool logDomain;
|
||||
// level of debug information
|
||||
extern unsigned verbosity;
|
||||
|
||||
extern LiftedSolver liftedSolver;
|
||||
extern GroundSolver groundSolver;
|
||||
extern LiftedSolverType liftedSolver;
|
||||
extern GroundSolverType groundSolver;
|
||||
|
||||
};
|
||||
|
||||
|
@ -160,15 +160,15 @@ readQueryAndEvidence (
|
||||
void
|
||||
runSolver (const FactorGraph& fg, const VarIds& queryIds)
|
||||
{
|
||||
Solver* solver = 0;
|
||||
GroundSolver* solver = 0;
|
||||
switch (Globals::groundSolver) {
|
||||
case GroundSolver::VE:
|
||||
case GroundSolverType::VE:
|
||||
solver = new VarElim (fg);
|
||||
break;
|
||||
case GroundSolver::BP:
|
||||
case GroundSolverType::BP:
|
||||
solver = new BeliefProp (fg);
|
||||
break;
|
||||
case GroundSolver::CBP:
|
||||
case GroundSolverType::CBP:
|
||||
solver = new CountingBp (fg);
|
||||
break;
|
||||
default:
|
||||
|
@ -308,21 +308,21 @@ runLiftedSolver (void)
|
||||
}
|
||||
jointList = YAP_TailOfTerm (jointList);
|
||||
}
|
||||
if (Globals::liftedSolver == LiftedSolver::LVE) {
|
||||
if (Globals::liftedSolver == LiftedSolverType::LVE) {
|
||||
LiftedVe solver (pfListCopy);
|
||||
if (Globals::verbosity > 0 && taskList == YAP_ARG2) {
|
||||
solver.printSolverFlags();
|
||||
cout << endl;
|
||||
}
|
||||
results.push_back (solver.solveQuery (queryVars));
|
||||
} else if (Globals::liftedSolver == LiftedSolver::LBP) {
|
||||
} else if (Globals::liftedSolver == LiftedSolverType::LBP) {
|
||||
LiftedBp solver (pfListCopy);
|
||||
if (Globals::verbosity > 0 && taskList == YAP_ARG2) {
|
||||
solver.printSolverFlags();
|
||||
cout << endl;
|
||||
}
|
||||
results.push_back (solver.solveQuery (queryVars));
|
||||
} else if (Globals::liftedSolver == LiftedSolver::LKC) {
|
||||
} else if (Globals::liftedSolver == LiftedSolverType::LKC) {
|
||||
LiftedKc solver (pfListCopy);
|
||||
if (Globals::verbosity > 0 && taskList == YAP_ARG2) {
|
||||
solver.printSolverFlags();
|
||||
@ -369,18 +369,18 @@ runGroundSolver (void)
|
||||
for (size_t i = 0; i < tasks.size(); i++) {
|
||||
Util::addToSet (vids, tasks[i]);
|
||||
}
|
||||
Solver* solver = 0;
|
||||
GroundSolver* solver = 0;
|
||||
FactorGraph* mfg = fg;
|
||||
if (fg->bayesianFactors()) {
|
||||
mfg = BayesBall::getMinimalFactorGraph (
|
||||
*fg, VarIds (vids.begin(), vids.end()));
|
||||
}
|
||||
|
||||
if (Globals::groundSolver == GroundSolver::VE) {
|
||||
if (Globals::groundSolver == GroundSolverType::VE) {
|
||||
solver = new VarElim (*mfg);
|
||||
} else if (Globals::groundSolver == GroundSolver::BP) {
|
||||
} else if (Globals::groundSolver == GroundSolverType::BP) {
|
||||
solver = new BeliefProp (*mfg);
|
||||
} else if (Globals::groundSolver == GroundSolver::CBP) {
|
||||
} else if (Globals::groundSolver == GroundSolverType::CBP) {
|
||||
CountingBp::checkForIdenticalFactors = false;
|
||||
solver = new CountingBp (*mfg);
|
||||
} else {
|
||||
|
@ -5,7 +5,7 @@
|
||||
|
||||
|
||||
LiftedBp::LiftedBp (const ParfactorList& pfList)
|
||||
: pfList_(pfList)
|
||||
: LiftedSolver (pfList), pfList_(pfList)
|
||||
{
|
||||
refineParfactors();
|
||||
createFactorGraph();
|
||||
|
@ -1,12 +1,13 @@
|
||||
#ifndef HORUS_LIFTEDBP_H
|
||||
#define HORUS_LIFTEDBP_H
|
||||
|
||||
#include "LiftedSolver.h"
|
||||
#include "ParfactorList.h"
|
||||
|
||||
class FactorGraph;
|
||||
class WeightedBp;
|
||||
|
||||
class LiftedBp
|
||||
class LiftedBp : public LiftedSolver
|
||||
{
|
||||
public:
|
||||
LiftedBp (const ParfactorList& pfList);
|
||||
|
@ -5,14 +5,6 @@
|
||||
#include "Indexer.h"
|
||||
|
||||
|
||||
LiftedKc::LiftedKc (const ParfactorList& pfList)
|
||||
: pfList_(pfList)
|
||||
{
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
LiftedKc::~LiftedKc (void)
|
||||
{
|
||||
delete lwcnf_;
|
||||
|
@ -1,15 +1,18 @@
|
||||
#ifndef HORUS_LIFTEDKC_H
|
||||
#define HORUS_LIFTEDKC_H
|
||||
|
||||
#include "LiftedSolver.h"
|
||||
#include "ParfactorList.h"
|
||||
|
||||
class LiftedWCNF;
|
||||
class LiftedCircuit;
|
||||
|
||||
class LiftedKc
|
||||
|
||||
class LiftedKc : public LiftedSolver
|
||||
{
|
||||
public:
|
||||
LiftedKc (const ParfactorList& pfList);
|
||||
LiftedKc (const ParfactorList& pfList)
|
||||
: LiftedSolver(pfList), pfList_(pfList) { }
|
||||
|
||||
~LiftedKc (void);
|
||||
|
||||
|
@ -1,7 +1,7 @@
|
||||
#ifndef HORUS_LIFTEDVE_H
|
||||
#define HORUS_LIFTEDVE_H
|
||||
|
||||
|
||||
#include "LiftedSolver.h"
|
||||
#include "ParfactorList.h"
|
||||
|
||||
|
||||
@ -132,10 +132,11 @@ class GroundOperator : public LiftedOperator
|
||||
|
||||
|
||||
|
||||
class LiftedVe
|
||||
class LiftedVe : public LiftedSolver
|
||||
{
|
||||
public:
|
||||
LiftedVe (const ParfactorList& pfList) : pfList_(pfList) { }
|
||||
LiftedVe (const ParfactorList& pfList)
|
||||
: LiftedSolver(pfList), pfList_(pfList) { }
|
||||
|
||||
Params solveQuery (const Grounds&);
|
||||
|
||||
|
@ -60,13 +60,14 @@ HEADERS = \
|
||||
$(srcdir)/LiftedCircuit.h \
|
||||
$(srcdir)/LiftedKc.h \
|
||||
$(srcdir)/LiftedOperations.h \
|
||||
$(srcdir)/LiftedSolver.h \
|
||||
$(srcdir)/LiftedUtils.h \
|
||||
$(srcdir)/LiftedVe.h \
|
||||
$(srcdir)/LiftedWCNF.h \
|
||||
$(srcdir)/Parfactor.h \
|
||||
$(srcdir)/ParfactorList.h \
|
||||
$(srcdir)/ProbFormula.h \
|
||||
$(srcdir)/Solver.h \
|
||||
$(srcdir)/GroundSolver.h \
|
||||
$(srcdir)/TinySet.h \
|
||||
$(srcdir)/Util.h \
|
||||
$(srcdir)/Var.h \
|
||||
@ -95,7 +96,7 @@ CPP_SOURCES = \
|
||||
$(srcdir)/Parfactor.cpp \
|
||||
$(srcdir)/ParfactorList.cpp \
|
||||
$(srcdir)/ProbFormula.cpp \
|
||||
$(srcdir)/Solver.cpp \
|
||||
$(srcdir)/GroundSolver.cpp \
|
||||
$(srcdir)/Util.cpp \
|
||||
$(srcdir)/Var.cpp \
|
||||
$(srcdir)/VarElim.cpp \
|
||||
@ -122,7 +123,7 @@ OBJS = \
|
||||
ProbFormula.o \
|
||||
Parfactor.o \
|
||||
ParfactorList.o \
|
||||
Solver.o \
|
||||
GroundSolver.o \
|
||||
Util.o \
|
||||
Var.o \
|
||||
VarElim.o \
|
||||
@ -137,7 +138,7 @@ HCLI_OBJS = \
|
||||
Factor.o \
|
||||
FactorGraph.o \
|
||||
HorusCli.o \
|
||||
Solver.o \
|
||||
GroundSolver.o \
|
||||
Util.o \
|
||||
Var.o \
|
||||
VarElim.o \
|
||||
|
@ -1,107 +0,0 @@
|
||||
#include "Solver.h"
|
||||
#include "Util.h"
|
||||
#include "BeliefProp.h"
|
||||
#include "CountingBp.h"
|
||||
#include "VarElim.h"
|
||||
|
||||
|
||||
void
|
||||
Solver::printAnswer (const VarIds& vids)
|
||||
{
|
||||
Vars unobservedVars;
|
||||
VarIds unobservedVids;
|
||||
for (size_t i = 0; i < vids.size(); i++) {
|
||||
VarNode* vn = fg.getVarNode (vids[i]);
|
||||
if (vn->hasEvidence() == false) {
|
||||
unobservedVars.push_back (vn);
|
||||
unobservedVids.push_back (vids[i]);
|
||||
}
|
||||
}
|
||||
if (unobservedVids.empty() == false) {
|
||||
Params res = solveQuery (unobservedVids);
|
||||
vector<string> stateLines = Util::getStateLines (unobservedVars);
|
||||
for (size_t i = 0; i < res.size(); i++) {
|
||||
cout << "P(" << stateLines[i] << ") = " ;
|
||||
cout << std::setprecision (Constants::PRECISION) << res[i];
|
||||
cout << endl;
|
||||
}
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Solver::printAllPosterioris (void)
|
||||
{
|
||||
VarNodes vars = fg.varNodes();
|
||||
std::sort (vars.begin(), vars.end(), sortByVarId());
|
||||
for (size_t i = 0; i < vars.size(); i++) {
|
||||
printAnswer ({vars[i]->varId()});
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
Params
|
||||
Solver::getJointByConditioning (
|
||||
GroundSolver solverType,
|
||||
FactorGraph fg,
|
||||
const VarIds& jointVarIds) const
|
||||
{
|
||||
VarNodes jointVars;
|
||||
for (size_t i = 0; i < jointVarIds.size(); i++) {
|
||||
assert (fg.getVarNode (jointVarIds[i]));
|
||||
jointVars.push_back (fg.getVarNode (jointVarIds[i]));
|
||||
}
|
||||
|
||||
Solver* solver = 0;
|
||||
switch (solverType) {
|
||||
case GroundSolver::BP: solver = new BeliefProp (fg); break;
|
||||
case GroundSolver::CBP: solver = new CountingBp (fg); break;
|
||||
case GroundSolver::VE: solver = new VarElim (fg); break;
|
||||
}
|
||||
Params prevBeliefs = solver->solveQuery ({jointVarIds[0]});
|
||||
VarIds observedVids = {jointVars[0]->varId()};
|
||||
|
||||
for (size_t i = 1; i < jointVarIds.size(); i++) {
|
||||
assert (jointVars[i]->hasEvidence() == false);
|
||||
Params newBeliefs;
|
||||
Vars observedVars;
|
||||
Ranges observedRanges;
|
||||
for (size_t j = 0; j < observedVids.size(); j++) {
|
||||
observedVars.push_back (fg.getVarNode (observedVids[j]));
|
||||
observedRanges.push_back (observedVars.back()->range());
|
||||
}
|
||||
Indexer indexer (observedRanges, false);
|
||||
while (indexer.valid()) {
|
||||
for (size_t j = 0; j < observedVars.size(); j++) {
|
||||
observedVars[j]->setEvidence (indexer[j]);
|
||||
}
|
||||
delete solver;
|
||||
switch (solverType) {
|
||||
case GroundSolver::BP: solver = new BeliefProp (fg); break;
|
||||
case GroundSolver::CBP: solver = new CountingBp (fg); break;
|
||||
case GroundSolver::VE: solver = new VarElim (fg); break;
|
||||
}
|
||||
Params beliefs = solver->solveQuery ({jointVarIds[i]});
|
||||
for (size_t k = 0; k < beliefs.size(); k++) {
|
||||
newBeliefs.push_back (beliefs[k]);
|
||||
}
|
||||
++ indexer;
|
||||
}
|
||||
|
||||
int count = -1;
|
||||
for (size_t j = 0; j < newBeliefs.size(); j++) {
|
||||
if (j % jointVars[i]->range() == 0) {
|
||||
count ++;
|
||||
}
|
||||
newBeliefs[j] *= prevBeliefs[count];
|
||||
}
|
||||
prevBeliefs = newBeliefs;
|
||||
observedVids.push_back (jointVars[i]->varId());
|
||||
}
|
||||
delete solver;
|
||||
return prevBeliefs;
|
||||
}
|
||||
|
@ -1,36 +0,0 @@
|
||||
#ifndef HORUS_SOLVER_H
|
||||
#define HORUS_SOLVER_H
|
||||
|
||||
#include <iomanip>
|
||||
|
||||
#include "FactorGraph.h"
|
||||
#include "Var.h"
|
||||
#include "Horus.h"
|
||||
|
||||
|
||||
using namespace std;
|
||||
|
||||
class Solver
|
||||
{
|
||||
public:
|
||||
Solver (const FactorGraph& factorGraph) : fg(factorGraph) { }
|
||||
|
||||
virtual ~Solver() { } // ensure that subclass destructor is called
|
||||
|
||||
virtual Params solveQuery (VarIds queryVids) = 0;
|
||||
|
||||
virtual void printSolverFlags (void) const = 0;
|
||||
|
||||
void printAnswer (const VarIds& vids);
|
||||
|
||||
void printAllPosterioris (void);
|
||||
|
||||
Params getJointByConditioning (GroundSolver,
|
||||
FactorGraph, const VarIds& jointVarIds) const;
|
||||
|
||||
protected:
|
||||
const FactorGraph& fg;
|
||||
};
|
||||
|
||||
#endif // HORUS_SOLVER_H
|
||||
|
@ -13,9 +13,9 @@ bool logDomain = false;
|
||||
|
||||
unsigned verbosity = 0;
|
||||
|
||||
LiftedSolver liftedSolver = LiftedSolver::LVE;
|
||||
LiftedSolverType liftedSolver = LiftedSolverType::LVE;
|
||||
|
||||
GroundSolver groundSolver = GroundSolver::VE;
|
||||
GroundSolverType groundSolver = GroundSolverType::VE;
|
||||
|
||||
};
|
||||
|
||||
@ -211,11 +211,11 @@ setHorusFlag (string key, string value)
|
||||
ss >> Globals::verbosity;
|
||||
} else if (key == "lifted_solver") {
|
||||
if ( value == "lve") {
|
||||
Globals::liftedSolver = LiftedSolver::LVE;
|
||||
Globals::liftedSolver = LiftedSolverType::LVE;
|
||||
} else if (value == "lbp") {
|
||||
Globals::liftedSolver = LiftedSolver::LBP;
|
||||
Globals::liftedSolver = LiftedSolverType::LBP;
|
||||
} else if (value == "lkc") {
|
||||
Globals::liftedSolver = LiftedSolver::LKC;
|
||||
Globals::liftedSolver = LiftedSolverType::LKC;
|
||||
} else {
|
||||
cerr << "warning: invalid value `" << value << "' " ;
|
||||
cerr << "for `" << key << "'" << endl;
|
||||
@ -223,11 +223,11 @@ setHorusFlag (string key, string value)
|
||||
}
|
||||
} else if (key == "ground_solver") {
|
||||
if ( value == "ve") {
|
||||
Globals::groundSolver = GroundSolver::VE;
|
||||
Globals::groundSolver = GroundSolverType::VE;
|
||||
} else if (value == "bp") {
|
||||
Globals::groundSolver = GroundSolver::BP;
|
||||
Globals::groundSolver = GroundSolverType::BP;
|
||||
} else if (value == "cbp") {
|
||||
Globals::groundSolver = GroundSolver::CBP;
|
||||
Globals::groundSolver = GroundSolverType::CBP;
|
||||
} else {
|
||||
cerr << "warning: invalid value `" << value << "' " ;
|
||||
cerr << "for `" << key << "'" << endl;
|
||||
|
@ -3,7 +3,7 @@
|
||||
|
||||
#include "unordered_map"
|
||||
|
||||
#include "Solver.h"
|
||||
#include "GroundSolver.h"
|
||||
#include "FactorGraph.h"
|
||||
#include "Horus.h"
|
||||
|
||||
@ -11,10 +11,10 @@
|
||||
using namespace std;
|
||||
|
||||
|
||||
class VarElim : public Solver
|
||||
class VarElim : public GroundSolver
|
||||
{
|
||||
public:
|
||||
VarElim (const FactorGraph& fg) : Solver (fg) { }
|
||||
VarElim (const FactorGraph& fg) : GroundSolver (fg) { }
|
||||
|
||||
~VarElim (void);
|
||||
|
||||
|
Reference in New Issue
Block a user