Introduce a lifted solver class

This commit is contained in:
Tiago Gomes 2012-11-14 21:55:51 +00:00
parent 6e7d0d1d0a
commit 64a27847cc
17 changed files with 53 additions and 197 deletions

View File

@ -12,7 +12,7 @@
#include "Horus.h" #include "Horus.h"
BeliefProp::BeliefProp (const FactorGraph& fg) : Solver (fg) BeliefProp::BeliefProp (const FactorGraph& fg) : GroundSolver (fg)
{ {
runned_ = false; runned_ = false;
} }
@ -377,7 +377,8 @@ BeliefProp::getVarToFactorMsg (const BpLink* link) const
Params Params
BeliefProp::getJointByConditioning (const VarIds& jointVarIds) const BeliefProp::getJointByConditioning (const VarIds& jointVarIds) const
{ {
return Solver::getJointByConditioning (GroundSolver::BP, fg, jointVarIds); return GroundSolver::getJointByConditioning (
GroundSolverType::BP, fg, jointVarIds);
} }

View File

@ -5,7 +5,7 @@
#include <vector> #include <vector>
#include <sstream> #include <sstream>
#include "Solver.h" #include "GroundSolver.h"
#include "Factor.h" #include "Factor.h"
#include "FactorGraph.h" #include "FactorGraph.h"
#include "Util.h" #include "Util.h"
@ -83,7 +83,7 @@ class SPNodeInfo
}; };
class BeliefProp : public Solver class BeliefProp : public GroundSolver
{ {
public: public:
BeliefProp (const FactorGraph&); BeliefProp (const FactorGraph&);

View File

@ -6,7 +6,7 @@ bool CountingBp::checkForIdenticalFactors = true;
CountingBp::CountingBp (const FactorGraph& fg) CountingBp::CountingBp (const FactorGraph& fg)
: Solver (fg), freeColor_(0) : GroundSolver (fg), freeColor_(0)
{ {
findIdenticalFactors(); findIdenticalFactors();
setInitialColors(); setInitialColors();
@ -74,8 +74,8 @@ CountingBp::solveQuery (VarIds queryVids)
cout << endl; cout << endl;
} }
if (idx == facNodes.size()) { if (idx == facNodes.size()) {
res = Solver::getJointByConditioning ( res = GroundSolver::getJointByConditioning (
GroundSolver::CBP, fg, queryVids); GroundSolverType::CBP, fg, queryVids);
} else { } else {
VarIds reprArgs; VarIds reprArgs;
for (size_t i = 0; i < queryVids.size(); i++) { for (size_t i = 0; i < queryVids.size(); i++) {

View File

@ -3,7 +3,7 @@
#include <unordered_map> #include <unordered_map>
#include "Solver.h" #include "GroundSolver.h"
#include "FactorGraph.h" #include "FactorGraph.h"
#include "Util.h" #include "Util.h"
#include "Horus.h" #include "Horus.h"
@ -102,7 +102,7 @@ class FacCluster
}; };
class CountingBp : public Solver class CountingBp : public GroundSolver
{ {
public: public:
CountingBp (const FactorGraph& fg); CountingBp (const FactorGraph& fg);

View File

@ -28,7 +28,7 @@ typedef vector<unsigned> Ranges;
typedef unsigned long long ullong; typedef unsigned long long ullong;
enum LiftedSolver enum LiftedSolverType
{ {
LVE, // generalized counting first-order variable elimination (GC-FOVE) LVE, // generalized counting first-order variable elimination (GC-FOVE)
LBP, // lifted first-order belief propagation LBP, // lifted first-order belief propagation
@ -36,7 +36,7 @@ enum LiftedSolver
}; };
enum GroundSolver enum GroundSolverType
{ {
VE, // variable elimination VE, // variable elimination
BP, // belief propagation BP, // belief propagation
@ -51,8 +51,8 @@ extern bool logDomain;
// level of debug information // level of debug information
extern unsigned verbosity; extern unsigned verbosity;
extern LiftedSolver liftedSolver; extern LiftedSolverType liftedSolver;
extern GroundSolver groundSolver; extern GroundSolverType groundSolver;
}; };

View File

@ -160,15 +160,15 @@ readQueryAndEvidence (
void void
runSolver (const FactorGraph& fg, const VarIds& queryIds) runSolver (const FactorGraph& fg, const VarIds& queryIds)
{ {
Solver* solver = 0; GroundSolver* solver = 0;
switch (Globals::groundSolver) { switch (Globals::groundSolver) {
case GroundSolver::VE: case GroundSolverType::VE:
solver = new VarElim (fg); solver = new VarElim (fg);
break; break;
case GroundSolver::BP: case GroundSolverType::BP:
solver = new BeliefProp (fg); solver = new BeliefProp (fg);
break; break;
case GroundSolver::CBP: case GroundSolverType::CBP:
solver = new CountingBp (fg); solver = new CountingBp (fg);
break; break;
default: default:

View File

@ -308,21 +308,21 @@ runLiftedSolver (void)
} }
jointList = YAP_TailOfTerm (jointList); jointList = YAP_TailOfTerm (jointList);
} }
if (Globals::liftedSolver == LiftedSolver::LVE) { if (Globals::liftedSolver == LiftedSolverType::LVE) {
LiftedVe solver (pfListCopy); LiftedVe solver (pfListCopy);
if (Globals::verbosity > 0 && taskList == YAP_ARG2) { if (Globals::verbosity > 0 && taskList == YAP_ARG2) {
solver.printSolverFlags(); solver.printSolverFlags();
cout << endl; cout << endl;
} }
results.push_back (solver.solveQuery (queryVars)); results.push_back (solver.solveQuery (queryVars));
} else if (Globals::liftedSolver == LiftedSolver::LBP) { } else if (Globals::liftedSolver == LiftedSolverType::LBP) {
LiftedBp solver (pfListCopy); LiftedBp solver (pfListCopy);
if (Globals::verbosity > 0 && taskList == YAP_ARG2) { if (Globals::verbosity > 0 && taskList == YAP_ARG2) {
solver.printSolverFlags(); solver.printSolverFlags();
cout << endl; cout << endl;
} }
results.push_back (solver.solveQuery (queryVars)); results.push_back (solver.solveQuery (queryVars));
} else if (Globals::liftedSolver == LiftedSolver::LKC) { } else if (Globals::liftedSolver == LiftedSolverType::LKC) {
LiftedKc solver (pfListCopy); LiftedKc solver (pfListCopy);
if (Globals::verbosity > 0 && taskList == YAP_ARG2) { if (Globals::verbosity > 0 && taskList == YAP_ARG2) {
solver.printSolverFlags(); solver.printSolverFlags();
@ -369,18 +369,18 @@ runGroundSolver (void)
for (size_t i = 0; i < tasks.size(); i++) { for (size_t i = 0; i < tasks.size(); i++) {
Util::addToSet (vids, tasks[i]); Util::addToSet (vids, tasks[i]);
} }
Solver* solver = 0; GroundSolver* solver = 0;
FactorGraph* mfg = fg; FactorGraph* mfg = fg;
if (fg->bayesianFactors()) { if (fg->bayesianFactors()) {
mfg = BayesBall::getMinimalFactorGraph ( mfg = BayesBall::getMinimalFactorGraph (
*fg, VarIds (vids.begin(), vids.end())); *fg, VarIds (vids.begin(), vids.end()));
} }
if (Globals::groundSolver == GroundSolver::VE) { if (Globals::groundSolver == GroundSolverType::VE) {
solver = new VarElim (*mfg); solver = new VarElim (*mfg);
} else if (Globals::groundSolver == GroundSolver::BP) { } else if (Globals::groundSolver == GroundSolverType::BP) {
solver = new BeliefProp (*mfg); solver = new BeliefProp (*mfg);
} else if (Globals::groundSolver == GroundSolver::CBP) { } else if (Globals::groundSolver == GroundSolverType::CBP) {
CountingBp::checkForIdenticalFactors = false; CountingBp::checkForIdenticalFactors = false;
solver = new CountingBp (*mfg); solver = new CountingBp (*mfg);
} else { } else {

View File

@ -5,7 +5,7 @@
LiftedBp::LiftedBp (const ParfactorList& pfList) LiftedBp::LiftedBp (const ParfactorList& pfList)
: pfList_(pfList) : LiftedSolver (pfList), pfList_(pfList)
{ {
refineParfactors(); refineParfactors();
createFactorGraph(); createFactorGraph();

View File

@ -1,12 +1,13 @@
#ifndef HORUS_LIFTEDBP_H #ifndef HORUS_LIFTEDBP_H
#define HORUS_LIFTEDBP_H #define HORUS_LIFTEDBP_H
#include "LiftedSolver.h"
#include "ParfactorList.h" #include "ParfactorList.h"
class FactorGraph; class FactorGraph;
class WeightedBp; class WeightedBp;
class LiftedBp class LiftedBp : public LiftedSolver
{ {
public: public:
LiftedBp (const ParfactorList& pfList); LiftedBp (const ParfactorList& pfList);

View File

@ -5,14 +5,6 @@
#include "Indexer.h" #include "Indexer.h"
LiftedKc::LiftedKc (const ParfactorList& pfList)
: pfList_(pfList)
{
}
LiftedKc::~LiftedKc (void) LiftedKc::~LiftedKc (void)
{ {
delete lwcnf_; delete lwcnf_;

View File

@ -1,15 +1,18 @@
#ifndef HORUS_LIFTEDKC_H #ifndef HORUS_LIFTEDKC_H
#define HORUS_LIFTEDKC_H #define HORUS_LIFTEDKC_H
#include "LiftedSolver.h"
#include "ParfactorList.h" #include "ParfactorList.h"
class LiftedWCNF; class LiftedWCNF;
class LiftedCircuit; class LiftedCircuit;
class LiftedKc
class LiftedKc : public LiftedSolver
{ {
public: public:
LiftedKc (const ParfactorList& pfList); LiftedKc (const ParfactorList& pfList)
: LiftedSolver(pfList), pfList_(pfList) { }
~LiftedKc (void); ~LiftedKc (void);

View File

@ -1,7 +1,7 @@
#ifndef HORUS_LIFTEDVE_H #ifndef HORUS_LIFTEDVE_H
#define HORUS_LIFTEDVE_H #define HORUS_LIFTEDVE_H
#include "LiftedSolver.h"
#include "ParfactorList.h" #include "ParfactorList.h"
@ -132,10 +132,11 @@ class GroundOperator : public LiftedOperator
class LiftedVe class LiftedVe : public LiftedSolver
{ {
public: public:
LiftedVe (const ParfactorList& pfList) : pfList_(pfList) { } LiftedVe (const ParfactorList& pfList)
: LiftedSolver(pfList), pfList_(pfList) { }
Params solveQuery (const Grounds&); Params solveQuery (const Grounds&);

View File

@ -60,13 +60,14 @@ HEADERS = \
$(srcdir)/LiftedCircuit.h \ $(srcdir)/LiftedCircuit.h \
$(srcdir)/LiftedKc.h \ $(srcdir)/LiftedKc.h \
$(srcdir)/LiftedOperations.h \ $(srcdir)/LiftedOperations.h \
$(srcdir)/LiftedSolver.h \
$(srcdir)/LiftedUtils.h \ $(srcdir)/LiftedUtils.h \
$(srcdir)/LiftedVe.h \ $(srcdir)/LiftedVe.h \
$(srcdir)/LiftedWCNF.h \ $(srcdir)/LiftedWCNF.h \
$(srcdir)/Parfactor.h \ $(srcdir)/Parfactor.h \
$(srcdir)/ParfactorList.h \ $(srcdir)/ParfactorList.h \
$(srcdir)/ProbFormula.h \ $(srcdir)/ProbFormula.h \
$(srcdir)/Solver.h \ $(srcdir)/GroundSolver.h \
$(srcdir)/TinySet.h \ $(srcdir)/TinySet.h \
$(srcdir)/Util.h \ $(srcdir)/Util.h \
$(srcdir)/Var.h \ $(srcdir)/Var.h \
@ -95,7 +96,7 @@ CPP_SOURCES = \
$(srcdir)/Parfactor.cpp \ $(srcdir)/Parfactor.cpp \
$(srcdir)/ParfactorList.cpp \ $(srcdir)/ParfactorList.cpp \
$(srcdir)/ProbFormula.cpp \ $(srcdir)/ProbFormula.cpp \
$(srcdir)/Solver.cpp \ $(srcdir)/GroundSolver.cpp \
$(srcdir)/Util.cpp \ $(srcdir)/Util.cpp \
$(srcdir)/Var.cpp \ $(srcdir)/Var.cpp \
$(srcdir)/VarElim.cpp \ $(srcdir)/VarElim.cpp \
@ -122,7 +123,7 @@ OBJS = \
ProbFormula.o \ ProbFormula.o \
Parfactor.o \ Parfactor.o \
ParfactorList.o \ ParfactorList.o \
Solver.o \ GroundSolver.o \
Util.o \ Util.o \
Var.o \ Var.o \
VarElim.o \ VarElim.o \
@ -137,7 +138,7 @@ HCLI_OBJS = \
Factor.o \ Factor.o \
FactorGraph.o \ FactorGraph.o \
HorusCli.o \ HorusCli.o \
Solver.o \ GroundSolver.o \
Util.o \ Util.o \
Var.o \ Var.o \
VarElim.o \ VarElim.o \

View File

@ -1,107 +0,0 @@
#include "Solver.h"
#include "Util.h"
#include "BeliefProp.h"
#include "CountingBp.h"
#include "VarElim.h"
void
Solver::printAnswer (const VarIds& vids)
{
Vars unobservedVars;
VarIds unobservedVids;
for (size_t i = 0; i < vids.size(); i++) {
VarNode* vn = fg.getVarNode (vids[i]);
if (vn->hasEvidence() == false) {
unobservedVars.push_back (vn);
unobservedVids.push_back (vids[i]);
}
}
if (unobservedVids.empty() == false) {
Params res = solveQuery (unobservedVids);
vector<string> stateLines = Util::getStateLines (unobservedVars);
for (size_t i = 0; i < res.size(); i++) {
cout << "P(" << stateLines[i] << ") = " ;
cout << std::setprecision (Constants::PRECISION) << res[i];
cout << endl;
}
cout << endl;
}
}
void
Solver::printAllPosterioris (void)
{
VarNodes vars = fg.varNodes();
std::sort (vars.begin(), vars.end(), sortByVarId());
for (size_t i = 0; i < vars.size(); i++) {
printAnswer ({vars[i]->varId()});
}
}
Params
Solver::getJointByConditioning (
GroundSolver solverType,
FactorGraph fg,
const VarIds& jointVarIds) const
{
VarNodes jointVars;
for (size_t i = 0; i < jointVarIds.size(); i++) {
assert (fg.getVarNode (jointVarIds[i]));
jointVars.push_back (fg.getVarNode (jointVarIds[i]));
}
Solver* solver = 0;
switch (solverType) {
case GroundSolver::BP: solver = new BeliefProp (fg); break;
case GroundSolver::CBP: solver = new CountingBp (fg); break;
case GroundSolver::VE: solver = new VarElim (fg); break;
}
Params prevBeliefs = solver->solveQuery ({jointVarIds[0]});
VarIds observedVids = {jointVars[0]->varId()};
for (size_t i = 1; i < jointVarIds.size(); i++) {
assert (jointVars[i]->hasEvidence() == false);
Params newBeliefs;
Vars observedVars;
Ranges observedRanges;
for (size_t j = 0; j < observedVids.size(); j++) {
observedVars.push_back (fg.getVarNode (observedVids[j]));
observedRanges.push_back (observedVars.back()->range());
}
Indexer indexer (observedRanges, false);
while (indexer.valid()) {
for (size_t j = 0; j < observedVars.size(); j++) {
observedVars[j]->setEvidence (indexer[j]);
}
delete solver;
switch (solverType) {
case GroundSolver::BP: solver = new BeliefProp (fg); break;
case GroundSolver::CBP: solver = new CountingBp (fg); break;
case GroundSolver::VE: solver = new VarElim (fg); break;
}
Params beliefs = solver->solveQuery ({jointVarIds[i]});
for (size_t k = 0; k < beliefs.size(); k++) {
newBeliefs.push_back (beliefs[k]);
}
++ indexer;
}
int count = -1;
for (size_t j = 0; j < newBeliefs.size(); j++) {
if (j % jointVars[i]->range() == 0) {
count ++;
}
newBeliefs[j] *= prevBeliefs[count];
}
prevBeliefs = newBeliefs;
observedVids.push_back (jointVars[i]->varId());
}
delete solver;
return prevBeliefs;
}

View File

@ -1,36 +0,0 @@
#ifndef HORUS_SOLVER_H
#define HORUS_SOLVER_H
#include <iomanip>
#include "FactorGraph.h"
#include "Var.h"
#include "Horus.h"
using namespace std;
class Solver
{
public:
Solver (const FactorGraph& factorGraph) : fg(factorGraph) { }
virtual ~Solver() { } // ensure that subclass destructor is called
virtual Params solveQuery (VarIds queryVids) = 0;
virtual void printSolverFlags (void) const = 0;
void printAnswer (const VarIds& vids);
void printAllPosterioris (void);
Params getJointByConditioning (GroundSolver,
FactorGraph, const VarIds& jointVarIds) const;
protected:
const FactorGraph& fg;
};
#endif // HORUS_SOLVER_H

View File

@ -13,9 +13,9 @@ bool logDomain = false;
unsigned verbosity = 0; unsigned verbosity = 0;
LiftedSolver liftedSolver = LiftedSolver::LVE; LiftedSolverType liftedSolver = LiftedSolverType::LVE;
GroundSolver groundSolver = GroundSolver::VE; GroundSolverType groundSolver = GroundSolverType::VE;
}; };
@ -211,11 +211,11 @@ setHorusFlag (string key, string value)
ss >> Globals::verbosity; ss >> Globals::verbosity;
} else if (key == "lifted_solver") { } else if (key == "lifted_solver") {
if ( value == "lve") { if ( value == "lve") {
Globals::liftedSolver = LiftedSolver::LVE; Globals::liftedSolver = LiftedSolverType::LVE;
} else if (value == "lbp") { } else if (value == "lbp") {
Globals::liftedSolver = LiftedSolver::LBP; Globals::liftedSolver = LiftedSolverType::LBP;
} else if (value == "lkc") { } else if (value == "lkc") {
Globals::liftedSolver = LiftedSolver::LKC; Globals::liftedSolver = LiftedSolverType::LKC;
} else { } else {
cerr << "warning: invalid value `" << value << "' " ; cerr << "warning: invalid value `" << value << "' " ;
cerr << "for `" << key << "'" << endl; cerr << "for `" << key << "'" << endl;
@ -223,11 +223,11 @@ setHorusFlag (string key, string value)
} }
} else if (key == "ground_solver") { } else if (key == "ground_solver") {
if ( value == "ve") { if ( value == "ve") {
Globals::groundSolver = GroundSolver::VE; Globals::groundSolver = GroundSolverType::VE;
} else if (value == "bp") { } else if (value == "bp") {
Globals::groundSolver = GroundSolver::BP; Globals::groundSolver = GroundSolverType::BP;
} else if (value == "cbp") { } else if (value == "cbp") {
Globals::groundSolver = GroundSolver::CBP; Globals::groundSolver = GroundSolverType::CBP;
} else { } else {
cerr << "warning: invalid value `" << value << "' " ; cerr << "warning: invalid value `" << value << "' " ;
cerr << "for `" << key << "'" << endl; cerr << "for `" << key << "'" << endl;

View File

@ -3,7 +3,7 @@
#include "unordered_map" #include "unordered_map"
#include "Solver.h" #include "GroundSolver.h"
#include "FactorGraph.h" #include "FactorGraph.h"
#include "Horus.h" #include "Horus.h"
@ -11,10 +11,10 @@
using namespace std; using namespace std;
class VarElim : public Solver class VarElim : public GroundSolver
{ {
public: public:
VarElim (const FactorGraph& fg) : Solver (fg) { } VarElim (const FactorGraph& fg) : GroundSolver (fg) { }
~VarElim (void); ~VarElim (void);