Introduce a lifted solver class

This commit is contained in:
Tiago Gomes 2012-11-14 21:55:51 +00:00
parent 6e7d0d1d0a
commit 64a27847cc
17 changed files with 53 additions and 197 deletions

View File

@ -12,7 +12,7 @@
#include "Horus.h"
BeliefProp::BeliefProp (const FactorGraph& fg) : Solver (fg)
BeliefProp::BeliefProp (const FactorGraph& fg) : GroundSolver (fg)
{
runned_ = false;
}
@ -377,7 +377,8 @@ BeliefProp::getVarToFactorMsg (const BpLink* link) const
Params
BeliefProp::getJointByConditioning (const VarIds& jointVarIds) const
{
return Solver::getJointByConditioning (GroundSolver::BP, fg, jointVarIds);
return GroundSolver::getJointByConditioning (
GroundSolverType::BP, fg, jointVarIds);
}

View File

@ -5,7 +5,7 @@
#include <vector>
#include <sstream>
#include "Solver.h"
#include "GroundSolver.h"
#include "Factor.h"
#include "FactorGraph.h"
#include "Util.h"
@ -83,7 +83,7 @@ class SPNodeInfo
};
class BeliefProp : public Solver
class BeliefProp : public GroundSolver
{
public:
BeliefProp (const FactorGraph&);

View File

@ -6,7 +6,7 @@ bool CountingBp::checkForIdenticalFactors = true;
CountingBp::CountingBp (const FactorGraph& fg)
: Solver (fg), freeColor_(0)
: GroundSolver (fg), freeColor_(0)
{
findIdenticalFactors();
setInitialColors();
@ -74,8 +74,8 @@ CountingBp::solveQuery (VarIds queryVids)
cout << endl;
}
if (idx == facNodes.size()) {
res = Solver::getJointByConditioning (
GroundSolver::CBP, fg, queryVids);
res = GroundSolver::getJointByConditioning (
GroundSolverType::CBP, fg, queryVids);
} else {
VarIds reprArgs;
for (size_t i = 0; i < queryVids.size(); i++) {

View File

@ -3,7 +3,7 @@
#include <unordered_map>
#include "Solver.h"
#include "GroundSolver.h"
#include "FactorGraph.h"
#include "Util.h"
#include "Horus.h"
@ -102,7 +102,7 @@ class FacCluster
};
class CountingBp : public Solver
class CountingBp : public GroundSolver
{
public:
CountingBp (const FactorGraph& fg);

View File

@ -28,7 +28,7 @@ typedef vector<unsigned> Ranges;
typedef unsigned long long ullong;
enum LiftedSolver
enum LiftedSolverType
{
LVE, // generalized counting first-order variable elimination (GC-FOVE)
LBP, // lifted first-order belief propagation
@ -36,7 +36,7 @@ enum LiftedSolver
};
enum GroundSolver
enum GroundSolverType
{
VE, // variable elimination
BP, // belief propagation
@ -51,8 +51,8 @@ extern bool logDomain;
// level of debug information
extern unsigned verbosity;
extern LiftedSolver liftedSolver;
extern GroundSolver groundSolver;
extern LiftedSolverType liftedSolver;
extern GroundSolverType groundSolver;
};

View File

@ -160,15 +160,15 @@ readQueryAndEvidence (
void
runSolver (const FactorGraph& fg, const VarIds& queryIds)
{
Solver* solver = 0;
GroundSolver* solver = 0;
switch (Globals::groundSolver) {
case GroundSolver::VE:
case GroundSolverType::VE:
solver = new VarElim (fg);
break;
case GroundSolver::BP:
case GroundSolverType::BP:
solver = new BeliefProp (fg);
break;
case GroundSolver::CBP:
case GroundSolverType::CBP:
solver = new CountingBp (fg);
break;
default:

View File

@ -308,21 +308,21 @@ runLiftedSolver (void)
}
jointList = YAP_TailOfTerm (jointList);
}
if (Globals::liftedSolver == LiftedSolver::LVE) {
if (Globals::liftedSolver == LiftedSolverType::LVE) {
LiftedVe solver (pfListCopy);
if (Globals::verbosity > 0 && taskList == YAP_ARG2) {
solver.printSolverFlags();
cout << endl;
}
results.push_back (solver.solveQuery (queryVars));
} else if (Globals::liftedSolver == LiftedSolver::LBP) {
} else if (Globals::liftedSolver == LiftedSolverType::LBP) {
LiftedBp solver (pfListCopy);
if (Globals::verbosity > 0 && taskList == YAP_ARG2) {
solver.printSolverFlags();
cout << endl;
}
results.push_back (solver.solveQuery (queryVars));
} else if (Globals::liftedSolver == LiftedSolver::LKC) {
} else if (Globals::liftedSolver == LiftedSolverType::LKC) {
LiftedKc solver (pfListCopy);
if (Globals::verbosity > 0 && taskList == YAP_ARG2) {
solver.printSolverFlags();
@ -369,18 +369,18 @@ runGroundSolver (void)
for (size_t i = 0; i < tasks.size(); i++) {
Util::addToSet (vids, tasks[i]);
}
Solver* solver = 0;
GroundSolver* solver = 0;
FactorGraph* mfg = fg;
if (fg->bayesianFactors()) {
mfg = BayesBall::getMinimalFactorGraph (
*fg, VarIds (vids.begin(), vids.end()));
}
if (Globals::groundSolver == GroundSolver::VE) {
if (Globals::groundSolver == GroundSolverType::VE) {
solver = new VarElim (*mfg);
} else if (Globals::groundSolver == GroundSolver::BP) {
} else if (Globals::groundSolver == GroundSolverType::BP) {
solver = new BeliefProp (*mfg);
} else if (Globals::groundSolver == GroundSolver::CBP) {
} else if (Globals::groundSolver == GroundSolverType::CBP) {
CountingBp::checkForIdenticalFactors = false;
solver = new CountingBp (*mfg);
} else {

View File

@ -5,7 +5,7 @@
LiftedBp::LiftedBp (const ParfactorList& pfList)
: pfList_(pfList)
: LiftedSolver (pfList), pfList_(pfList)
{
refineParfactors();
createFactorGraph();

View File

@ -1,12 +1,13 @@
#ifndef HORUS_LIFTEDBP_H
#define HORUS_LIFTEDBP_H
#include "LiftedSolver.h"
#include "ParfactorList.h"
class FactorGraph;
class WeightedBp;
class LiftedBp
class LiftedBp : public LiftedSolver
{
public:
LiftedBp (const ParfactorList& pfList);

View File

@ -5,14 +5,6 @@
#include "Indexer.h"
LiftedKc::LiftedKc (const ParfactorList& pfList)
: pfList_(pfList)
{
}
LiftedKc::~LiftedKc (void)
{
delete lwcnf_;

View File

@ -1,15 +1,18 @@
#ifndef HORUS_LIFTEDKC_H
#define HORUS_LIFTEDKC_H
#include "LiftedSolver.h"
#include "ParfactorList.h"
class LiftedWCNF;
class LiftedCircuit;
class LiftedKc
class LiftedKc : public LiftedSolver
{
public:
LiftedKc (const ParfactorList& pfList);
LiftedKc (const ParfactorList& pfList)
: LiftedSolver(pfList), pfList_(pfList) { }
~LiftedKc (void);

View File

@ -1,7 +1,7 @@
#ifndef HORUS_LIFTEDVE_H
#define HORUS_LIFTEDVE_H
#include "LiftedSolver.h"
#include "ParfactorList.h"
@ -132,10 +132,11 @@ class GroundOperator : public LiftedOperator
class LiftedVe
class LiftedVe : public LiftedSolver
{
public:
LiftedVe (const ParfactorList& pfList) : pfList_(pfList) { }
LiftedVe (const ParfactorList& pfList)
: LiftedSolver(pfList), pfList_(pfList) { }
Params solveQuery (const Grounds&);

View File

@ -60,13 +60,14 @@ HEADERS = \
$(srcdir)/LiftedCircuit.h \
$(srcdir)/LiftedKc.h \
$(srcdir)/LiftedOperations.h \
$(srcdir)/LiftedSolver.h \
$(srcdir)/LiftedUtils.h \
$(srcdir)/LiftedVe.h \
$(srcdir)/LiftedWCNF.h \
$(srcdir)/Parfactor.h \
$(srcdir)/ParfactorList.h \
$(srcdir)/ProbFormula.h \
$(srcdir)/Solver.h \
$(srcdir)/GroundSolver.h \
$(srcdir)/TinySet.h \
$(srcdir)/Util.h \
$(srcdir)/Var.h \
@ -95,7 +96,7 @@ CPP_SOURCES = \
$(srcdir)/Parfactor.cpp \
$(srcdir)/ParfactorList.cpp \
$(srcdir)/ProbFormula.cpp \
$(srcdir)/Solver.cpp \
$(srcdir)/GroundSolver.cpp \
$(srcdir)/Util.cpp \
$(srcdir)/Var.cpp \
$(srcdir)/VarElim.cpp \
@ -122,7 +123,7 @@ OBJS = \
ProbFormula.o \
Parfactor.o \
ParfactorList.o \
Solver.o \
GroundSolver.o \
Util.o \
Var.o \
VarElim.o \
@ -137,7 +138,7 @@ HCLI_OBJS = \
Factor.o \
FactorGraph.o \
HorusCli.o \
Solver.o \
GroundSolver.o \
Util.o \
Var.o \
VarElim.o \

View File

@ -1,107 +0,0 @@
#include "Solver.h"
#include "Util.h"
#include "BeliefProp.h"
#include "CountingBp.h"
#include "VarElim.h"
void
Solver::printAnswer (const VarIds& vids)
{
Vars unobservedVars;
VarIds unobservedVids;
for (size_t i = 0; i < vids.size(); i++) {
VarNode* vn = fg.getVarNode (vids[i]);
if (vn->hasEvidence() == false) {
unobservedVars.push_back (vn);
unobservedVids.push_back (vids[i]);
}
}
if (unobservedVids.empty() == false) {
Params res = solveQuery (unobservedVids);
vector<string> stateLines = Util::getStateLines (unobservedVars);
for (size_t i = 0; i < res.size(); i++) {
cout << "P(" << stateLines[i] << ") = " ;
cout << std::setprecision (Constants::PRECISION) << res[i];
cout << endl;
}
cout << endl;
}
}
void
Solver::printAllPosterioris (void)
{
VarNodes vars = fg.varNodes();
std::sort (vars.begin(), vars.end(), sortByVarId());
for (size_t i = 0; i < vars.size(); i++) {
printAnswer ({vars[i]->varId()});
}
}
Params
Solver::getJointByConditioning (
GroundSolver solverType,
FactorGraph fg,
const VarIds& jointVarIds) const
{
VarNodes jointVars;
for (size_t i = 0; i < jointVarIds.size(); i++) {
assert (fg.getVarNode (jointVarIds[i]));
jointVars.push_back (fg.getVarNode (jointVarIds[i]));
}
Solver* solver = 0;
switch (solverType) {
case GroundSolver::BP: solver = new BeliefProp (fg); break;
case GroundSolver::CBP: solver = new CountingBp (fg); break;
case GroundSolver::VE: solver = new VarElim (fg); break;
}
Params prevBeliefs = solver->solveQuery ({jointVarIds[0]});
VarIds observedVids = {jointVars[0]->varId()};
for (size_t i = 1; i < jointVarIds.size(); i++) {
assert (jointVars[i]->hasEvidence() == false);
Params newBeliefs;
Vars observedVars;
Ranges observedRanges;
for (size_t j = 0; j < observedVids.size(); j++) {
observedVars.push_back (fg.getVarNode (observedVids[j]));
observedRanges.push_back (observedVars.back()->range());
}
Indexer indexer (observedRanges, false);
while (indexer.valid()) {
for (size_t j = 0; j < observedVars.size(); j++) {
observedVars[j]->setEvidence (indexer[j]);
}
delete solver;
switch (solverType) {
case GroundSolver::BP: solver = new BeliefProp (fg); break;
case GroundSolver::CBP: solver = new CountingBp (fg); break;
case GroundSolver::VE: solver = new VarElim (fg); break;
}
Params beliefs = solver->solveQuery ({jointVarIds[i]});
for (size_t k = 0; k < beliefs.size(); k++) {
newBeliefs.push_back (beliefs[k]);
}
++ indexer;
}
int count = -1;
for (size_t j = 0; j < newBeliefs.size(); j++) {
if (j % jointVars[i]->range() == 0) {
count ++;
}
newBeliefs[j] *= prevBeliefs[count];
}
prevBeliefs = newBeliefs;
observedVids.push_back (jointVars[i]->varId());
}
delete solver;
return prevBeliefs;
}

View File

@ -1,36 +0,0 @@
#ifndef HORUS_SOLVER_H
#define HORUS_SOLVER_H
#include <iomanip>
#include "FactorGraph.h"
#include "Var.h"
#include "Horus.h"
using namespace std;
class Solver
{
public:
Solver (const FactorGraph& factorGraph) : fg(factorGraph) { }
virtual ~Solver() { } // ensure that subclass destructor is called
virtual Params solveQuery (VarIds queryVids) = 0;
virtual void printSolverFlags (void) const = 0;
void printAnswer (const VarIds& vids);
void printAllPosterioris (void);
Params getJointByConditioning (GroundSolver,
FactorGraph, const VarIds& jointVarIds) const;
protected:
const FactorGraph& fg;
};
#endif // HORUS_SOLVER_H

View File

@ -13,9 +13,9 @@ bool logDomain = false;
unsigned verbosity = 0;
LiftedSolver liftedSolver = LiftedSolver::LVE;
LiftedSolverType liftedSolver = LiftedSolverType::LVE;
GroundSolver groundSolver = GroundSolver::VE;
GroundSolverType groundSolver = GroundSolverType::VE;
};
@ -211,11 +211,11 @@ setHorusFlag (string key, string value)
ss >> Globals::verbosity;
} else if (key == "lifted_solver") {
if ( value == "lve") {
Globals::liftedSolver = LiftedSolver::LVE;
Globals::liftedSolver = LiftedSolverType::LVE;
} else if (value == "lbp") {
Globals::liftedSolver = LiftedSolver::LBP;
Globals::liftedSolver = LiftedSolverType::LBP;
} else if (value == "lkc") {
Globals::liftedSolver = LiftedSolver::LKC;
Globals::liftedSolver = LiftedSolverType::LKC;
} else {
cerr << "warning: invalid value `" << value << "' " ;
cerr << "for `" << key << "'" << endl;
@ -223,11 +223,11 @@ setHorusFlag (string key, string value)
}
} else if (key == "ground_solver") {
if ( value == "ve") {
Globals::groundSolver = GroundSolver::VE;
Globals::groundSolver = GroundSolverType::VE;
} else if (value == "bp") {
Globals::groundSolver = GroundSolver::BP;
Globals::groundSolver = GroundSolverType::BP;
} else if (value == "cbp") {
Globals::groundSolver = GroundSolver::CBP;
Globals::groundSolver = GroundSolverType::CBP;
} else {
cerr << "warning: invalid value `" << value << "' " ;
cerr << "for `" << key << "'" << endl;

View File

@ -3,7 +3,7 @@
#include "unordered_map"
#include "Solver.h"
#include "GroundSolver.h"
#include "FactorGraph.h"
#include "Horus.h"
@ -11,10 +11,10 @@
using namespace std;
class VarElim : public Solver
class VarElim : public GroundSolver
{
public:
VarElim (const FactorGraph& fg) : Solver (fg) { }
VarElim (const FactorGraph& fg) : GroundSolver (fg) { }
~VarElim (void);