From 78e86a6330a6d5d676da3ce00368752a7f03f22d Mon Sep 17 00:00:00 2001 From: Tiago Gomes Date: Tue, 10 Apr 2012 15:00:18 +0100 Subject: [PATCH] refactor ground solver interface --- packages/CLPBN/clpbn/bp/BpSolver.cpp | 229 ++++++++++++---------- packages/CLPBN/clpbn/bp/BpSolver.h | 4 +- packages/CLPBN/clpbn/bp/CbpSolver.cpp | 5 +- packages/CLPBN/clpbn/bp/CbpSolver.h | 6 +- packages/CLPBN/clpbn/bp/ElimGraph.cpp | 12 +- packages/CLPBN/clpbn/bp/ElimGraph.h | 2 - packages/CLPBN/clpbn/bp/Factor.cpp | 2 +- packages/CLPBN/clpbn/bp/HorusCli.cpp | 31 ++- packages/CLPBN/clpbn/bp/HorusYap.cpp | 13 +- packages/CLPBN/clpbn/bp/Solver.cpp | 66 +++---- packages/CLPBN/clpbn/bp/Solver.h | 12 +- packages/CLPBN/clpbn/bp/Util.cpp | 2 +- packages/CLPBN/clpbn/bp/Util.h | 2 +- packages/CLPBN/clpbn/bp/VarElimSolver.cpp | 38 +--- packages/CLPBN/clpbn/bp/VarElimSolver.h | 9 +- 15 files changed, 191 insertions(+), 242 deletions(-) diff --git a/packages/CLPBN/clpbn/bp/BpSolver.cpp b/packages/CLPBN/clpbn/bp/BpSolver.cpp index ab05205c9..d645dbce3 100644 --- a/packages/CLPBN/clpbn/bp/BpSolver.cpp +++ b/packages/CLPBN/clpbn/bp/BpSolver.cpp @@ -15,6 +15,7 @@ BpSolver::BpSolver (const FactorGraph& fg) : Solver (fg) { factorGraph_ = &fg; + runned_ = false; } @@ -34,31 +35,14 @@ BpSolver::~BpSolver (void) -void -BpSolver::runSolver (void) +Params +BpSolver::solveQuery (VarIds queryVids) { - clock_t start; - if (Constants::COLLECT_STATS) { - start = clock(); - } - runLoopySolver(); - if (Constants::DEBUG >= 2) { - cout << endl; - if (nIters_ < BpOptions::maxIter) { - cout << "Sum-Product converged in " ; - cout << nIters_ << " iterations" << endl; - } else { - cout << "The maximum number of iterations was hit, terminating..." ; - cout << endl; - } - } - unsigned size = factorGraph_->varNodes().size(); - if (Constants::COLLECT_STATS) { - unsigned nIters = 0; - bool loopy = factorGraph_->isTree() == false; - if (loopy) nIters = nIters_; - double time = (double (clock() - start)) / CLOCKS_PER_SEC; - Statistics::updateStatistics (size, loopy, nIters, time); + assert (queryVids.empty() == false); + if (queryVids.size() == 1) { + return getPosterioriOf (queryVids[0]); + } else { + return getJointDistributionOf (queryVids); } } @@ -67,6 +51,9 @@ BpSolver::runSolver (void) Params BpSolver::getPosterioriOf (VarId vid) { + if (runned_ == false) { + runSolver(); + } assert (factorGraph_->getVarNode (vid)); VarNode* var = factorGraph_->getVarNode (vid); Params probs; @@ -97,6 +84,9 @@ BpSolver::getPosterioriOf (VarId vid) Params BpSolver::getJointDistributionOf (const VarIds& jointVarIds) { + if (runned_ == false) { + runSolver(); + } int idx = -1; VarNode* vn = factorGraph_->getVarNode (jointVarIds[0]); const FacNodes& facNodes = vn->neighbors(); @@ -130,52 +120,6 @@ BpSolver::getJointDistributionOf (const VarIds& jointVarIds) -void -BpSolver::runLoopySolver (void) -{ - initializeSolver(); - nIters_ = 0; - - while (!converged() && nIters_ < BpOptions::maxIter) { - - nIters_ ++; - if (Constants::DEBUG >= 2) { - Util::printHeader (" Iteration " + nIters_); - cout << endl; - } - - switch (BpOptions::schedule) { - case BpOptions::Schedule::SEQ_RANDOM: - random_shuffle (links_.begin(), links_.end()); - // no break - - case BpOptions::Schedule::SEQ_FIXED: - for (unsigned i = 0; i < links_.size(); i++) { - calculateAndUpdateMessage (links_[i]); - } - break; - - case BpOptions::Schedule::PARALLEL: - for (unsigned i = 0; i < links_.size(); i++) { - calculateMessage (links_[i]); - } - for (unsigned i = 0; i < links_.size(); i++) { - updateMessage(links_[i]); - } - break; - - case BpOptions::Schedule::MAX_RESIDUAL: - maxResidualSchedule(); - break; - } - if (Constants::DEBUG >= 2) { - cout << endl; - } - } -} - - - void BpSolver::initializeSolver (void) { @@ -226,40 +170,6 @@ BpSolver::createLinks (void) -bool -BpSolver::converged (void) -{ - if (links_.size() == 0) { - return true; - } - if (nIters_ == 0 || nIters_ == 1) { - return false; - } - bool converged = true; - if (BpOptions::schedule == BpOptions::Schedule::MAX_RESIDUAL) { - double maxResidual = (*(sortedOrder_.begin()))->getResidual(); - if (maxResidual > BpOptions::accuracy) { - converged = false; - } else { - converged = true; - } - } else { - for (unsigned i = 0; i < links_.size(); i++) { - double residual = links_[i]->getResidual(); - if (Constants::DEBUG >= 2) { - cout << links_[i]->toString() + " residual = " << residual << endl; - } - if (residual > BpOptions::accuracy) { - converged = false; - if (Constants::DEBUG == 0) break; - } - } - } - return converged; -} - - - void BpSolver::maxResidualSchedule (void) { @@ -493,3 +403,114 @@ BpSolver::printLinkInformation (void) const } } + + +void +BpSolver::runSolver (void) +{ + clock_t start; + if (Constants::COLLECT_STATS) { + start = clock(); + } + runLoopySolver(); + if (Constants::DEBUG >= 2) { + cout << endl; + if (nIters_ < BpOptions::maxIter) { + cout << "Sum-Product converged in " ; + cout << nIters_ << " iterations" << endl; + } else { + cout << "The maximum number of iterations was hit, terminating..." ; + cout << endl; + } + } + unsigned size = factorGraph_->varNodes().size(); + if (Constants::COLLECT_STATS) { + unsigned nIters = 0; + bool loopy = factorGraph_->isTree() == false; + if (loopy) nIters = nIters_; + double time = (double (clock() - start)) / CLOCKS_PER_SEC; + Statistics::updateStatistics (size, loopy, nIters, time); + } + runned_ = true; +} + + + +void +BpSolver::runLoopySolver (void) +{ + initializeSolver(); + nIters_ = 0; + + while (!converged() && nIters_ < BpOptions::maxIter) { + + nIters_ ++; + if (Constants::DEBUG >= 2) { + Util::printHeader (" Iteration " + nIters_); + cout << endl; + } + + switch (BpOptions::schedule) { + case BpOptions::Schedule::SEQ_RANDOM: + random_shuffle (links_.begin(), links_.end()); + // no break + + case BpOptions::Schedule::SEQ_FIXED: + for (unsigned i = 0; i < links_.size(); i++) { + calculateAndUpdateMessage (links_[i]); + } + break; + + case BpOptions::Schedule::PARALLEL: + for (unsigned i = 0; i < links_.size(); i++) { + calculateMessage (links_[i]); + } + for (unsigned i = 0; i < links_.size(); i++) { + updateMessage(links_[i]); + } + break; + + case BpOptions::Schedule::MAX_RESIDUAL: + maxResidualSchedule(); + break; + } + if (Constants::DEBUG >= 2) { + cout << endl; + } + } +} + + + +bool +BpSolver::converged (void) +{ + if (links_.size() == 0) { + return true; + } + if (nIters_ == 0 || nIters_ == 1) { + return false; + } + bool converged = true; + if (BpOptions::schedule == BpOptions::Schedule::MAX_RESIDUAL) { + double maxResidual = (*(sortedOrder_.begin()))->getResidual(); + if (maxResidual > BpOptions::accuracy) { + converged = false; + } else { + converged = true; + } + } else { + for (unsigned i = 0; i < links_.size(); i++) { + double residual = links_[i]->getResidual(); + if (Constants::DEBUG >= 2) { + cout << links_[i]->toString() + " residual = " << residual << endl; + } + if (residual > BpOptions::accuracy) { + converged = false; + if (Constants::DEBUG == 0) break; + } + } + } + return converged; +} + diff --git a/packages/CLPBN/clpbn/bp/BpSolver.h b/packages/CLPBN/clpbn/bp/BpSolver.h index f7b77e6d9..ebd5c0c93 100644 --- a/packages/CLPBN/clpbn/bp/BpSolver.h +++ b/packages/CLPBN/clpbn/bp/BpSolver.h @@ -95,7 +95,7 @@ class BpSolver : public Solver virtual ~BpSolver (void); - void runSolver (void); + Params solveQuery (VarIds); virtual Params getPosterioriOf (VarId); @@ -169,6 +169,7 @@ class BpSolver : public Solver unsigned nIters_; vector varsI_; vector facsI_; + bool runned_; const FactorGraph* factorGraph_; typedef multiset SortedOrder; @@ -178,6 +179,7 @@ class BpSolver : public Solver SpLinkMap linkMap_; private: + void runSolver (void); void runLoopySolver (void); bool converged (void); }; diff --git a/packages/CLPBN/clpbn/bp/CbpSolver.cpp b/packages/CLPBN/clpbn/bp/CbpSolver.cpp index e5adbf9cc..1437638f8 100644 --- a/packages/CLPBN/clpbn/bp/CbpSolver.cpp +++ b/packages/CLPBN/clpbn/bp/CbpSolver.cpp @@ -85,9 +85,8 @@ CbpSolver::initializeSolver (void) if (Constants::COLLECT_STATS) { unsigned nClusterVars = factorGraph_->varNodes().size(); unsigned nClusterFacs = factorGraph_->facNodes().size(); - Statistics::updateCompressingStatistics (nGroundVars, nGroundFacs, - nClusterVars, nClusterFacs, - nWithoutNeighs); + Statistics::updateCompressingStatistics (nGroundVars, + nGroundFacs, nClusterVars, nClusterFacs, nWithoutNeighs); } // cout << "Compressed Factor Graph" << endl; diff --git a/packages/CLPBN/clpbn/bp/CbpSolver.h b/packages/CLPBN/clpbn/bp/CbpSolver.h index 7118f4443..c270e9d4b 100644 --- a/packages/CLPBN/clpbn/bp/CbpSolver.h +++ b/packages/CLPBN/clpbn/bp/CbpSolver.h @@ -37,7 +37,7 @@ class CbpSolverLink : public SpLink class CbpSolver : public BpSolver { public: - CbpSolver (FactorGraph& fg) : BpSolver (fg) { } + CbpSolver (const FactorGraph& fg) : BpSolver (fg) { } ~CbpSolver (void); @@ -47,13 +47,17 @@ class CbpSolver : public BpSolver private: void initializeSolver (void); + void createLinks (void); void maxResidualSchedule (void); + Params getVar2FactorMsg (const SpLink*) const; + void printLinkInformation (void) const; CFactorGraph* lfg_; + FactorGraph* factorGraph_; }; #endif // HORUS_CBP_H diff --git a/packages/CLPBN/clpbn/bp/ElimGraph.cpp b/packages/CLPBN/clpbn/bp/ElimGraph.cpp index cbf5eae82..9ede9677e 100644 --- a/packages/CLPBN/clpbn/bp/ElimGraph.cpp +++ b/packages/CLPBN/clpbn/bp/ElimGraph.cpp @@ -34,7 +34,6 @@ ElimGraph::ElimGraph (const vector& factors) } } } - setIndexes(); } @@ -148,6 +147,7 @@ void ElimGraph::addNode (EgNode* n) { nodes_.push_back (n); + n->setIndex (nodes_.size() - 1); varMap_.insert (make_pair (n->varId(), n)); } @@ -301,13 +301,3 @@ ElimGraph::neighbors (const EgNode* n1, const EgNode* n2) const return false; } - - -void -ElimGraph::setIndexes (void) -{ - for (unsigned i = 0; i < nodes_.size(); i++) { - nodes_[i]->setIndex (i); - } -} - diff --git a/packages/CLPBN/clpbn/bp/ElimGraph.h b/packages/CLPBN/clpbn/bp/ElimGraph.h index 05e17c8d4..564ddf278 100644 --- a/packages/CLPBN/clpbn/bp/ElimGraph.h +++ b/packages/CLPBN/clpbn/bp/ElimGraph.h @@ -78,8 +78,6 @@ class ElimGraph bool neighbors (const EgNode*, const EgNode*) const; - void setIndexes (void); - vector nodes_; vector marked_; unordered_map varMap_; diff --git a/packages/CLPBN/clpbn/bp/Factor.cpp b/packages/CLPBN/clpbn/bp/Factor.cpp index 1755876ae..2b5d22068 100644 --- a/packages/CLPBN/clpbn/bp/Factor.cpp +++ b/packages/CLPBN/clpbn/bp/Factor.cpp @@ -241,7 +241,7 @@ Factor::print (void) const for (unsigned i = 0; i < args_.size(); i++) { vars.push_back (new Var (args_[i], ranges_[i])); } - vector jointStrings = Util::getJointStateStrings (vars); + vector jointStrings = Util::getStateLines (vars); for (unsigned i = 0; i < params_.size(); i++) { cout << "f(" << jointStrings[i] << ")" ; cout << " = " << params_[i] << endl; diff --git a/packages/CLPBN/clpbn/bp/HorusCli.cpp b/packages/CLPBN/clpbn/bp/HorusCli.cpp index 75f25215b..4b6153a77 100644 --- a/packages/CLPBN/clpbn/bp/HorusCli.cpp +++ b/packages/CLPBN/clpbn/bp/HorusCli.cpp @@ -11,7 +11,7 @@ using namespace std; void processArguments (FactorGraph&, int, const char* []); -void runSolver (Solver*, const VarIds&); +void runSolver (const FactorGraph&, const VarIds&); const string USAGE = "usage: \ ./hcli FILE [VARIABLE | OBSERVED_VARIABLE=EVIDENCE]..." ; @@ -25,8 +25,8 @@ main (int argc, const char* argv[]) cerr << USAGE << endl; exit (0); } - const string& fileName = argv[1]; - const string& extension = fileName.substr ( + string fileName = argv[1]; + string extension = fileName.substr ( fileName.find_last_of ('.') + 1); FactorGraph fg; if (extension == "uai") { @@ -38,8 +38,6 @@ main (int argc, const char* argv[]) cerr << "in a UAI or libDAI file" << endl; exit (0); } - fg.print(); - assert (false); processArguments (fg, argc, argv); return 0; } @@ -123,6 +121,14 @@ processArguments (FactorGraph& fg, int argc, const char* argv[]) } } } + runSolver (fg, queryIds); +} + + + +void +runSolver (const FactorGraph& fg, const VarIds& queryIds) +{ Solver* solver = 0; switch (Globals::infAlgorithm) { case InfAlgorithms::VE: @@ -137,23 +143,10 @@ processArguments (FactorGraph& fg, int argc, const char* argv[]) default: assert (false); } - runSolver (solver, queryIds); -} - - - -void -runSolver (Solver* solver, const VarIds& queryIds) -{ if (queryIds.size() == 0) { - solver->runSolver(); solver->printAllPosterioris(); - } else if (queryIds.size() == 1) { - solver->runSolver(); - solver->printPosterioriOf (queryIds[0]); } else { - solver->runSolver(); - solver->printJointDistributionOf (queryIds); + solver->printAnswer (queryIds); } delete solver; } diff --git a/packages/CLPBN/clpbn/bp/HorusYap.cpp b/packages/CLPBN/clpbn/bp/HorusYap.cpp index 833ae2cea..ee326c729 100644 --- a/packages/CLPBN/clpbn/bp/HorusYap.cpp +++ b/packages/CLPBN/clpbn/bp/HorusYap.cpp @@ -379,11 +379,7 @@ void runVeSolver ( mfg = BayesBall::getMinimalFactorGraph (*fg, tasks[i]); } VarElimSolver solver (*mfg); - if (tasks[i].size() == 1) { - results.push_back (solver.getPosterioriOf (tasks[i][0])); - } else { - results.push_back (solver.getJointDistributionOf (tasks[i])); - } + results.push_back (solver.solveQuery (tasks[i])); if (fg->isFromBayesNetwork()) { delete mfg; } @@ -416,14 +412,9 @@ void runBpSolver ( cerr << "error: unknow solver" << endl; abort(); } - solver->runSolver(); results.reserve (tasks.size()); for (unsigned i = 0; i < tasks.size(); i++) { - if (tasks[i].size() == 1) { - results.push_back (solver->getPosterioriOf (tasks[i][0])); - } else { - results.push_back (solver->getJointDistributionOf (tasks[i])); - } + results.push_back (solver->solveQuery (tasks[i])); } if (fg->isFromBayesNetwork()) { delete mfg; diff --git a/packages/CLPBN/clpbn/bp/Solver.cpp b/packages/CLPBN/clpbn/bp/Solver.cpp index 1fa57e9ee..3e67c1e49 100644 --- a/packages/CLPBN/clpbn/bp/Solver.cpp +++ b/packages/CLPBN/clpbn/bp/Solver.cpp @@ -2,52 +2,36 @@ #include "Util.h" +void +Solver::printAnswer (const VarIds& vids) +{ + Vars unobservedVars; + VarIds unobservedVids; + for (unsigned i = 0; i < vids.size(); i++) { + VarNode* vn = fg_.getVarNode (vids[i]); + if (vn->hasEvidence() == false) { + unobservedVars.push_back (vn); + unobservedVids.push_back (vids[i]); + } + } + Params res = solveQuery (unobservedVids); + vector stateLines = Util::getStateLines (unobservedVars); + for (unsigned i = 0; i < res.size(); i++) { + cout << "P(" << stateLines[i] << ") = " ; + cout << std::setprecision (Constants::PRECISION) << res[i]; + cout << endl; + } + cout << endl; +} + + + void Solver::printAllPosterioris (void) { const VarNodes& vars = fg_.varNodes(); for (unsigned i = 0; i < vars.size(); i++) { - printPosterioriOf (vars[i]->varId()); + printAnswer ({vars[i]->varId()}); } } - - -void -Solver::printPosterioriOf (VarId vid) -{ - VarNode* vn = fg_.getVarNode (vid); - const Params& posterioriDist = getPosterioriOf (vid); - const States& states = vn->states(); - for (unsigned i = 0; i < states.size(); i++) { - cout << "P(" << vn->label() << "=" << states[i] << ") = " ; - cout << setprecision (Constants::PRECISION) << posterioriDist[i]; - cout << endl; - } - cout << endl; -} - - - -void -Solver::printJointDistributionOf (const VarIds& vids) -{ - Vars vars; - VarIds vidsWithoutEvidence; - for (unsigned i = 0; i < vids.size(); i++) { - VarNode* vn = fg_.getVarNode (vids[i]); - if (vn->hasEvidence() == false) { - vars.push_back (vn); - vidsWithoutEvidence.push_back (vids[i]); - } - } - const Params& jointDist = getJointDistributionOf (vidsWithoutEvidence); - vector jointStrings = Util::getJointStateStrings (vars); - for (unsigned i = 0; i < jointDist.size(); i++) { - cout << "P(" << jointStrings[i] << ") = " ; - cout << setprecision (Constants::PRECISION) << jointDist[i]; - cout << endl; - } - cout << endl; -} - diff --git a/packages/CLPBN/clpbn/bp/Solver.h b/packages/CLPBN/clpbn/bp/Solver.h index 826f25e0d..219299673 100644 --- a/packages/CLPBN/clpbn/bp/Solver.h +++ b/packages/CLPBN/clpbn/bp/Solver.h @@ -16,19 +16,13 @@ class Solver virtual ~Solver() { } // ensure that subclass destructor is called - virtual void runSolver (void) = 0; + virtual Params solveQuery (VarIds queryVids) = 0; - virtual Params getPosterioriOf (VarId) = 0; - - virtual Params getJointDistributionOf (const VarIds&) = 0; + void printAnswer (const VarIds& vids); void printAllPosterioris (void); - - void printPosterioriOf (VarId vid); - - void printJointDistributionOf (const VarIds& vids); - private: + protected: const FactorGraph& fg_; }; diff --git a/packages/CLPBN/clpbn/bp/Util.cpp b/packages/CLPBN/clpbn/bp/Util.cpp index 1c3d6c441..4a07af4cb 100644 --- a/packages/CLPBN/clpbn/bp/Util.cpp +++ b/packages/CLPBN/clpbn/bp/Util.cpp @@ -137,7 +137,7 @@ parametersToString (const Params& v, unsigned precision) vector -getJointStateStrings (const Vars& vars) +getStateLines (const Vars& vars) { StatesIndexer idx (vars); vector jointStrings; diff --git a/packages/CLPBN/clpbn/bp/Util.h b/packages/CLPBN/clpbn/bp/Util.h index e7af16311..42ab1d18e 100644 --- a/packages/CLPBN/clpbn/bp/Util.h +++ b/packages/CLPBN/clpbn/bp/Util.h @@ -62,7 +62,7 @@ bool isInteger (const string&); string parametersToString (const Params&, unsigned = Constants::PRECISION); -vector getJointStateStrings (const Vars&); +vector getStateLines (const Vars&); void printHeader (string, std::ostream& os = std::cout); diff --git a/packages/CLPBN/clpbn/bp/VarElimSolver.cpp b/packages/CLPBN/clpbn/bp/VarElimSolver.cpp index 9bd32ecc8..0356ad0f0 100644 --- a/packages/CLPBN/clpbn/bp/VarElimSolver.cpp +++ b/packages/CLPBN/clpbn/bp/VarElimSolver.cpp @@ -6,13 +6,6 @@ #include "Util.h" -VarElimSolver::VarElimSolver (const FactorGraph& fg) : Solver (fg) -{ - factorGraph_ = &fg; -} - - - VarElimSolver::~VarElimSolver (void) { delete factorList_.back(); @@ -21,30 +14,15 @@ VarElimSolver::~VarElimSolver (void) Params -VarElimSolver::getPosterioriOf (VarId vid) -{ - assert (factorGraph_->getVarNode (vid)); - VarNode* vn = factorGraph_->getVarNode (vid); - if (vn->hasEvidence()) { - Params params (vn->range(), 0.0); - params[vn->getEvidence()] = 1.0; - return params; - } - return getJointDistributionOf (VarIds() = {vid}); -} - - - -Params -VarElimSolver::getJointDistributionOf (const VarIds& vids) +VarElimSolver::solveQuery (VarIds queryVids) { factorList_.clear(); varFactors_.clear(); elimOrder_.clear(); createFactorList(); absorveEvidence(); - findEliminationOrder (vids); - processFactorList (vids); + findEliminationOrder (queryVids); + processFactorList (queryVids); Params params = factorList_.back()->params(); if (Globals::logDomain) { Util::fromLog (params); @@ -57,10 +35,10 @@ VarElimSolver::getJointDistributionOf (const VarIds& vids) void VarElimSolver::createFactorList (void) { - const FacNodes& facNodes = factorGraph_->facNodes(); + const FacNodes& facNodes = fg_.facNodes(); factorList_.reserve (facNodes.size() * 2); for (unsigned i = 0; i < facNodes.size(); i++) { - factorList_.push_back (new Factor (facNodes[i]->factor())); // FIXME + factorList_.push_back (new Factor (facNodes[i]->factor())); const VarNodes& neighs = facNodes[i]->neighbors(); for (unsigned j = 0; j < neighs.size(); j++) { unordered_map >::iterator it @@ -79,7 +57,7 @@ VarElimSolver::createFactorList (void) void VarElimSolver::absorveEvidence (void) { - const VarNodes& varNodes = factorGraph_->varNodes(); + const VarNodes& varNodes = fg_.varNodes(); for (unsigned i = 0; i < varNodes.size(); i++) { if (varNodes[i]->hasEvidence()) { const vector& idxs = @@ -125,7 +103,7 @@ VarElimSolver::processFactorList (const VarIds& vids) VarIds unobservedVids; for (unsigned i = 0; i < vids.size(); i++) { - if (factorGraph_->getVarNode (vids[i])->hasEvidence() == false) { + if (fg_.getVarNode (vids[i])->hasEvidence() == false) { unobservedVids.push_back (vids[i]); } } @@ -146,7 +124,7 @@ VarElimSolver::eliminate (VarId elimVar) unsigned idx = idxs[i]; if (factorList_[idx]) { if (result == 0) { - result = new Factor(*factorList_[idx]); + result = new Factor (*factorList_[idx]); } else { result->multiply (*factorList_[idx]); } diff --git a/packages/CLPBN/clpbn/bp/VarElimSolver.h b/packages/CLPBN/clpbn/bp/VarElimSolver.h index 5014e07f4..effe6838b 100644 --- a/packages/CLPBN/clpbn/bp/VarElimSolver.h +++ b/packages/CLPBN/clpbn/bp/VarElimSolver.h @@ -14,15 +14,11 @@ using namespace std; class VarElimSolver : public Solver { public: - VarElimSolver (const FactorGraph&); + VarElimSolver (const FactorGraph& fg) : Solver (fg) { } ~VarElimSolver (void); - void runSolver (void) { } - - Params getPosterioriOf (VarId); - - Params getJointDistributionOf (const VarIds&); + Params solveQuery (VarIds); private: void createFactorList (void); @@ -37,7 +33,6 @@ class VarElimSolver : public Solver void printActiveFactors (void); - const FactorGraph* factorGraph_; vector factorList_; VarIds elimOrder_; unordered_map> varFactors_;