From 64a27847cc0a9df49be59a389e4200032035d422 Mon Sep 17 00:00:00 2001 From: Tiago Gomes Date: Wed, 14 Nov 2012 21:55:51 +0000 Subject: [PATCH] Introduce a lifted solver class --- packages/CLPBN/horus/BeliefProp.cpp | 5 +- packages/CLPBN/horus/BeliefProp.h | 4 +- packages/CLPBN/horus/CountingBp.cpp | 6 +- packages/CLPBN/horus/CountingBp.h | 4 +- packages/CLPBN/horus/Horus.h | 8 +-- packages/CLPBN/horus/HorusCli.cpp | 8 +-- packages/CLPBN/horus/HorusYap.cpp | 14 ++-- packages/CLPBN/horus/LiftedBp.cpp | 2 +- packages/CLPBN/horus/LiftedBp.h | 3 +- packages/CLPBN/horus/LiftedKc.cpp | 8 --- packages/CLPBN/horus/LiftedKc.h | 7 +- packages/CLPBN/horus/LiftedVe.h | 7 +- packages/CLPBN/horus/Makefile.in | 9 +-- packages/CLPBN/horus/Solver.cpp | 107 ---------------------------- packages/CLPBN/horus/Solver.h | 36 ---------- packages/CLPBN/horus/Util.cpp | 16 ++--- packages/CLPBN/horus/VarElim.h | 6 +- 17 files changed, 53 insertions(+), 197 deletions(-) delete mode 100644 packages/CLPBN/horus/Solver.cpp delete mode 100644 packages/CLPBN/horus/Solver.h diff --git a/packages/CLPBN/horus/BeliefProp.cpp b/packages/CLPBN/horus/BeliefProp.cpp index 314f4a6c5..d96384cfd 100644 --- a/packages/CLPBN/horus/BeliefProp.cpp +++ b/packages/CLPBN/horus/BeliefProp.cpp @@ -12,7 +12,7 @@ #include "Horus.h" -BeliefProp::BeliefProp (const FactorGraph& fg) : Solver (fg) +BeliefProp::BeliefProp (const FactorGraph& fg) : GroundSolver (fg) { runned_ = false; } @@ -377,7 +377,8 @@ BeliefProp::getVarToFactorMsg (const BpLink* link) const Params BeliefProp::getJointByConditioning (const VarIds& jointVarIds) const { - return Solver::getJointByConditioning (GroundSolver::BP, fg, jointVarIds); + return GroundSolver::getJointByConditioning ( + GroundSolverType::BP, fg, jointVarIds); } diff --git a/packages/CLPBN/horus/BeliefProp.h b/packages/CLPBN/horus/BeliefProp.h index 1545abfc8..6c1d5c46b 100644 --- a/packages/CLPBN/horus/BeliefProp.h +++ b/packages/CLPBN/horus/BeliefProp.h @@ -5,7 +5,7 @@ #include #include -#include "Solver.h" +#include "GroundSolver.h" #include "Factor.h" #include "FactorGraph.h" #include "Util.h" @@ -83,7 +83,7 @@ class SPNodeInfo }; -class BeliefProp : public Solver +class BeliefProp : public GroundSolver { public: BeliefProp (const FactorGraph&); diff --git a/packages/CLPBN/horus/CountingBp.cpp b/packages/CLPBN/horus/CountingBp.cpp index 365ff7098..d248c602c 100644 --- a/packages/CLPBN/horus/CountingBp.cpp +++ b/packages/CLPBN/horus/CountingBp.cpp @@ -6,7 +6,7 @@ bool CountingBp::checkForIdenticalFactors = true; CountingBp::CountingBp (const FactorGraph& fg) - : Solver (fg), freeColor_(0) + : GroundSolver (fg), freeColor_(0) { findIdenticalFactors(); setInitialColors(); @@ -74,8 +74,8 @@ CountingBp::solveQuery (VarIds queryVids) cout << endl; } if (idx == facNodes.size()) { - res = Solver::getJointByConditioning ( - GroundSolver::CBP, fg, queryVids); + res = GroundSolver::getJointByConditioning ( + GroundSolverType::CBP, fg, queryVids); } else { VarIds reprArgs; for (size_t i = 0; i < queryVids.size(); i++) { diff --git a/packages/CLPBN/horus/CountingBp.h b/packages/CLPBN/horus/CountingBp.h index 7bc45c632..a553e9307 100644 --- a/packages/CLPBN/horus/CountingBp.h +++ b/packages/CLPBN/horus/CountingBp.h @@ -3,7 +3,7 @@ #include -#include "Solver.h" +#include "GroundSolver.h" #include "FactorGraph.h" #include "Util.h" #include "Horus.h" @@ -102,7 +102,7 @@ class FacCluster }; -class CountingBp : public Solver +class CountingBp : public GroundSolver { public: CountingBp (const FactorGraph& fg); diff --git a/packages/CLPBN/horus/Horus.h b/packages/CLPBN/horus/Horus.h index 2c8d20e1e..7e5f12c8e 100644 --- a/packages/CLPBN/horus/Horus.h +++ b/packages/CLPBN/horus/Horus.h @@ -28,7 +28,7 @@ typedef vector Ranges; typedef unsigned long long ullong; -enum LiftedSolver +enum LiftedSolverType { LVE, // generalized counting first-order variable elimination (GC-FOVE) LBP, // lifted first-order belief propagation @@ -36,7 +36,7 @@ enum LiftedSolver }; -enum GroundSolver +enum GroundSolverType { VE, // variable elimination BP, // belief propagation @@ -51,8 +51,8 @@ extern bool logDomain; // level of debug information extern unsigned verbosity; -extern LiftedSolver liftedSolver; -extern GroundSolver groundSolver; +extern LiftedSolverType liftedSolver; +extern GroundSolverType groundSolver; }; diff --git a/packages/CLPBN/horus/HorusCli.cpp b/packages/CLPBN/horus/HorusCli.cpp index 4c3f8e7fc..639b91739 100644 --- a/packages/CLPBN/horus/HorusCli.cpp +++ b/packages/CLPBN/horus/HorusCli.cpp @@ -160,15 +160,15 @@ readQueryAndEvidence ( void runSolver (const FactorGraph& fg, const VarIds& queryIds) { - Solver* solver = 0; + GroundSolver* solver = 0; switch (Globals::groundSolver) { - case GroundSolver::VE: + case GroundSolverType::VE: solver = new VarElim (fg); break; - case GroundSolver::BP: + case GroundSolverType::BP: solver = new BeliefProp (fg); break; - case GroundSolver::CBP: + case GroundSolverType::CBP: solver = new CountingBp (fg); break; default: diff --git a/packages/CLPBN/horus/HorusYap.cpp b/packages/CLPBN/horus/HorusYap.cpp index cd31c0612..2fa0008fb 100644 --- a/packages/CLPBN/horus/HorusYap.cpp +++ b/packages/CLPBN/horus/HorusYap.cpp @@ -308,21 +308,21 @@ runLiftedSolver (void) } jointList = YAP_TailOfTerm (jointList); } - if (Globals::liftedSolver == LiftedSolver::LVE) { + if (Globals::liftedSolver == LiftedSolverType::LVE) { LiftedVe solver (pfListCopy); if (Globals::verbosity > 0 && taskList == YAP_ARG2) { solver.printSolverFlags(); cout << endl; } results.push_back (solver.solveQuery (queryVars)); - } else if (Globals::liftedSolver == LiftedSolver::LBP) { + } else if (Globals::liftedSolver == LiftedSolverType::LBP) { LiftedBp solver (pfListCopy); if (Globals::verbosity > 0 && taskList == YAP_ARG2) { solver.printSolverFlags(); cout << endl; } results.push_back (solver.solveQuery (queryVars)); - } else if (Globals::liftedSolver == LiftedSolver::LKC) { + } else if (Globals::liftedSolver == LiftedSolverType::LKC) { LiftedKc solver (pfListCopy); if (Globals::verbosity > 0 && taskList == YAP_ARG2) { solver.printSolverFlags(); @@ -369,18 +369,18 @@ runGroundSolver (void) for (size_t i = 0; i < tasks.size(); i++) { Util::addToSet (vids, tasks[i]); } - Solver* solver = 0; + GroundSolver* solver = 0; FactorGraph* mfg = fg; if (fg->bayesianFactors()) { mfg = BayesBall::getMinimalFactorGraph ( *fg, VarIds (vids.begin(), vids.end())); } - if (Globals::groundSolver == GroundSolver::VE) { + if (Globals::groundSolver == GroundSolverType::VE) { solver = new VarElim (*mfg); - } else if (Globals::groundSolver == GroundSolver::BP) { + } else if (Globals::groundSolver == GroundSolverType::BP) { solver = new BeliefProp (*mfg); - } else if (Globals::groundSolver == GroundSolver::CBP) { + } else if (Globals::groundSolver == GroundSolverType::CBP) { CountingBp::checkForIdenticalFactors = false; solver = new CountingBp (*mfg); } else { diff --git a/packages/CLPBN/horus/LiftedBp.cpp b/packages/CLPBN/horus/LiftedBp.cpp index 468e79d72..3fada048d 100644 --- a/packages/CLPBN/horus/LiftedBp.cpp +++ b/packages/CLPBN/horus/LiftedBp.cpp @@ -5,7 +5,7 @@ LiftedBp::LiftedBp (const ParfactorList& pfList) - : pfList_(pfList) + : LiftedSolver (pfList), pfList_(pfList) { refineParfactors(); createFactorGraph(); diff --git a/packages/CLPBN/horus/LiftedBp.h b/packages/CLPBN/horus/LiftedBp.h index 29edf0ac8..cb6e9f3a4 100644 --- a/packages/CLPBN/horus/LiftedBp.h +++ b/packages/CLPBN/horus/LiftedBp.h @@ -1,12 +1,13 @@ #ifndef HORUS_LIFTEDBP_H #define HORUS_LIFTEDBP_H +#include "LiftedSolver.h" #include "ParfactorList.h" class FactorGraph; class WeightedBp; -class LiftedBp +class LiftedBp : public LiftedSolver { public: LiftedBp (const ParfactorList& pfList); diff --git a/packages/CLPBN/horus/LiftedKc.cpp b/packages/CLPBN/horus/LiftedKc.cpp index 69aed3f33..64e651379 100644 --- a/packages/CLPBN/horus/LiftedKc.cpp +++ b/packages/CLPBN/horus/LiftedKc.cpp @@ -5,14 +5,6 @@ #include "Indexer.h" -LiftedKc::LiftedKc (const ParfactorList& pfList) - : pfList_(pfList) -{ - -} - - - LiftedKc::~LiftedKc (void) { delete lwcnf_; diff --git a/packages/CLPBN/horus/LiftedKc.h b/packages/CLPBN/horus/LiftedKc.h index 4b3065c1d..52138c985 100644 --- a/packages/CLPBN/horus/LiftedKc.h +++ b/packages/CLPBN/horus/LiftedKc.h @@ -1,15 +1,18 @@ #ifndef HORUS_LIFTEDKC_H #define HORUS_LIFTEDKC_H +#include "LiftedSolver.h" #include "ParfactorList.h" class LiftedWCNF; class LiftedCircuit; -class LiftedKc + +class LiftedKc : public LiftedSolver { public: - LiftedKc (const ParfactorList& pfList); + LiftedKc (const ParfactorList& pfList) + : LiftedSolver(pfList), pfList_(pfList) { } ~LiftedKc (void); diff --git a/packages/CLPBN/horus/LiftedVe.h b/packages/CLPBN/horus/LiftedVe.h index 9a464a348..cdaf39823 100644 --- a/packages/CLPBN/horus/LiftedVe.h +++ b/packages/CLPBN/horus/LiftedVe.h @@ -1,7 +1,7 @@ #ifndef HORUS_LIFTEDVE_H #define HORUS_LIFTEDVE_H - +#include "LiftedSolver.h" #include "ParfactorList.h" @@ -132,10 +132,11 @@ class GroundOperator : public LiftedOperator -class LiftedVe +class LiftedVe : public LiftedSolver { public: - LiftedVe (const ParfactorList& pfList) : pfList_(pfList) { } + LiftedVe (const ParfactorList& pfList) + : LiftedSolver(pfList), pfList_(pfList) { } Params solveQuery (const Grounds&); diff --git a/packages/CLPBN/horus/Makefile.in b/packages/CLPBN/horus/Makefile.in index a87b3574e..59936c776 100644 --- a/packages/CLPBN/horus/Makefile.in +++ b/packages/CLPBN/horus/Makefile.in @@ -60,13 +60,14 @@ HEADERS = \ $(srcdir)/LiftedCircuit.h \ $(srcdir)/LiftedKc.h \ $(srcdir)/LiftedOperations.h \ + $(srcdir)/LiftedSolver.h \ $(srcdir)/LiftedUtils.h \ $(srcdir)/LiftedVe.h \ $(srcdir)/LiftedWCNF.h \ $(srcdir)/Parfactor.h \ $(srcdir)/ParfactorList.h \ $(srcdir)/ProbFormula.h \ - $(srcdir)/Solver.h \ + $(srcdir)/GroundSolver.h \ $(srcdir)/TinySet.h \ $(srcdir)/Util.h \ $(srcdir)/Var.h \ @@ -95,7 +96,7 @@ CPP_SOURCES = \ $(srcdir)/Parfactor.cpp \ $(srcdir)/ParfactorList.cpp \ $(srcdir)/ProbFormula.cpp \ - $(srcdir)/Solver.cpp \ + $(srcdir)/GroundSolver.cpp \ $(srcdir)/Util.cpp \ $(srcdir)/Var.cpp \ $(srcdir)/VarElim.cpp \ @@ -122,7 +123,7 @@ OBJS = \ ProbFormula.o \ Parfactor.o \ ParfactorList.o \ - Solver.o \ + GroundSolver.o \ Util.o \ Var.o \ VarElim.o \ @@ -137,7 +138,7 @@ HCLI_OBJS = \ Factor.o \ FactorGraph.o \ HorusCli.o \ - Solver.o \ + GroundSolver.o \ Util.o \ Var.o \ VarElim.o \ diff --git a/packages/CLPBN/horus/Solver.cpp b/packages/CLPBN/horus/Solver.cpp deleted file mode 100644 index 4cb3b6768..000000000 --- a/packages/CLPBN/horus/Solver.cpp +++ /dev/null @@ -1,107 +0,0 @@ -#include "Solver.h" -#include "Util.h" -#include "BeliefProp.h" -#include "CountingBp.h" -#include "VarElim.h" - - -void -Solver::printAnswer (const VarIds& vids) -{ - Vars unobservedVars; - VarIds unobservedVids; - for (size_t i = 0; i < vids.size(); i++) { - VarNode* vn = fg.getVarNode (vids[i]); - if (vn->hasEvidence() == false) { - unobservedVars.push_back (vn); - unobservedVids.push_back (vids[i]); - } - } - if (unobservedVids.empty() == false) { - Params res = solveQuery (unobservedVids); - vector stateLines = Util::getStateLines (unobservedVars); - for (size_t i = 0; i < res.size(); i++) { - cout << "P(" << stateLines[i] << ") = " ; - cout << std::setprecision (Constants::PRECISION) << res[i]; - cout << endl; - } - cout << endl; - } -} - - - -void -Solver::printAllPosterioris (void) -{ - VarNodes vars = fg.varNodes(); - std::sort (vars.begin(), vars.end(), sortByVarId()); - for (size_t i = 0; i < vars.size(); i++) { - printAnswer ({vars[i]->varId()}); - } -} - - - -Params -Solver::getJointByConditioning ( - GroundSolver solverType, - FactorGraph fg, - const VarIds& jointVarIds) const -{ - VarNodes jointVars; - for (size_t i = 0; i < jointVarIds.size(); i++) { - assert (fg.getVarNode (jointVarIds[i])); - jointVars.push_back (fg.getVarNode (jointVarIds[i])); - } - - Solver* solver = 0; - switch (solverType) { - case GroundSolver::BP: solver = new BeliefProp (fg); break; - case GroundSolver::CBP: solver = new CountingBp (fg); break; - case GroundSolver::VE: solver = new VarElim (fg); break; - } - Params prevBeliefs = solver->solveQuery ({jointVarIds[0]}); - VarIds observedVids = {jointVars[0]->varId()}; - - for (size_t i = 1; i < jointVarIds.size(); i++) { - assert (jointVars[i]->hasEvidence() == false); - Params newBeliefs; - Vars observedVars; - Ranges observedRanges; - for (size_t j = 0; j < observedVids.size(); j++) { - observedVars.push_back (fg.getVarNode (observedVids[j])); - observedRanges.push_back (observedVars.back()->range()); - } - Indexer indexer (observedRanges, false); - while (indexer.valid()) { - for (size_t j = 0; j < observedVars.size(); j++) { - observedVars[j]->setEvidence (indexer[j]); - } - delete solver; - switch (solverType) { - case GroundSolver::BP: solver = new BeliefProp (fg); break; - case GroundSolver::CBP: solver = new CountingBp (fg); break; - case GroundSolver::VE: solver = new VarElim (fg); break; - } - Params beliefs = solver->solveQuery ({jointVarIds[i]}); - for (size_t k = 0; k < beliefs.size(); k++) { - newBeliefs.push_back (beliefs[k]); - } - ++ indexer; - } - - int count = -1; - for (size_t j = 0; j < newBeliefs.size(); j++) { - if (j % jointVars[i]->range() == 0) { - count ++; - } - newBeliefs[j] *= prevBeliefs[count]; - } - prevBeliefs = newBeliefs; - observedVids.push_back (jointVars[i]->varId()); - } - delete solver; - return prevBeliefs; -} - diff --git a/packages/CLPBN/horus/Solver.h b/packages/CLPBN/horus/Solver.h deleted file mode 100644 index a378b2419..000000000 --- a/packages/CLPBN/horus/Solver.h +++ /dev/null @@ -1,36 +0,0 @@ -#ifndef HORUS_SOLVER_H -#define HORUS_SOLVER_H - -#include - -#include "FactorGraph.h" -#include "Var.h" -#include "Horus.h" - - -using namespace std; - -class Solver -{ - public: - Solver (const FactorGraph& factorGraph) : fg(factorGraph) { } - - virtual ~Solver() { } // ensure that subclass destructor is called - - virtual Params solveQuery (VarIds queryVids) = 0; - - virtual void printSolverFlags (void) const = 0; - - void printAnswer (const VarIds& vids); - - void printAllPosterioris (void); - - Params getJointByConditioning (GroundSolver, - FactorGraph, const VarIds& jointVarIds) const; - - protected: - const FactorGraph& fg; -}; - -#endif // HORUS_SOLVER_H - diff --git a/packages/CLPBN/horus/Util.cpp b/packages/CLPBN/horus/Util.cpp index bef3414c6..d3dbd588d 100644 --- a/packages/CLPBN/horus/Util.cpp +++ b/packages/CLPBN/horus/Util.cpp @@ -13,9 +13,9 @@ bool logDomain = false; unsigned verbosity = 0; -LiftedSolver liftedSolver = LiftedSolver::LVE; +LiftedSolverType liftedSolver = LiftedSolverType::LVE; -GroundSolver groundSolver = GroundSolver::VE; +GroundSolverType groundSolver = GroundSolverType::VE; }; @@ -211,11 +211,11 @@ setHorusFlag (string key, string value) ss >> Globals::verbosity; } else if (key == "lifted_solver") { if ( value == "lve") { - Globals::liftedSolver = LiftedSolver::LVE; + Globals::liftedSolver = LiftedSolverType::LVE; } else if (value == "lbp") { - Globals::liftedSolver = LiftedSolver::LBP; + Globals::liftedSolver = LiftedSolverType::LBP; } else if (value == "lkc") { - Globals::liftedSolver = LiftedSolver::LKC; + Globals::liftedSolver = LiftedSolverType::LKC; } else { cerr << "warning: invalid value `" << value << "' " ; cerr << "for `" << key << "'" << endl; @@ -223,11 +223,11 @@ setHorusFlag (string key, string value) } } else if (key == "ground_solver") { if ( value == "ve") { - Globals::groundSolver = GroundSolver::VE; + Globals::groundSolver = GroundSolverType::VE; } else if (value == "bp") { - Globals::groundSolver = GroundSolver::BP; + Globals::groundSolver = GroundSolverType::BP; } else if (value == "cbp") { - Globals::groundSolver = GroundSolver::CBP; + Globals::groundSolver = GroundSolverType::CBP; } else { cerr << "warning: invalid value `" << value << "' " ; cerr << "for `" << key << "'" << endl; diff --git a/packages/CLPBN/horus/VarElim.h b/packages/CLPBN/horus/VarElim.h index 6fbaded8c..fe1327fc0 100644 --- a/packages/CLPBN/horus/VarElim.h +++ b/packages/CLPBN/horus/VarElim.h @@ -3,7 +3,7 @@ #include "unordered_map" -#include "Solver.h" +#include "GroundSolver.h" #include "FactorGraph.h" #include "Horus.h" @@ -11,10 +11,10 @@ using namespace std; -class VarElim : public Solver +class VarElim : public GroundSolver { public: - VarElim (const FactorGraph& fg) : Solver (fg) { } + VarElim (const FactorGraph& fg) : GroundSolver (fg) { } ~VarElim (void);