From 8697fcd2b4f4ab95b912597e4acb307c8ab016ab Mon Sep 17 00:00:00 2001 From: Tiago Gomes Date: Tue, 10 Apr 2012 20:43:08 +0100 Subject: [PATCH] refactorings --- packages/CLPBN/clpbn/bp/BpSolver.cpp | 214 ++++++++++------------ packages/CLPBN/clpbn/bp/BpSolver.h | 18 +- packages/CLPBN/clpbn/bp/CFactorGraph.cpp | 50 ++--- packages/CLPBN/clpbn/bp/CFactorGraph.h | 38 ++-- packages/CLPBN/clpbn/bp/CbpSolver.cpp | 125 ++++++------- packages/CLPBN/clpbn/bp/CbpSolver.h | 28 ++- packages/CLPBN/clpbn/bp/HorusCli.cpp | 30 ++- packages/CLPBN/clpbn/bp/Solver.cpp | 4 +- packages/CLPBN/clpbn/bp/Solver.h | 4 +- packages/CLPBN/clpbn/bp/VarElimSolver.cpp | 6 +- 10 files changed, 251 insertions(+), 266 deletions(-) diff --git a/packages/CLPBN/clpbn/bp/BpSolver.cpp b/packages/CLPBN/clpbn/bp/BpSolver.cpp index d645dbce3..e97608776 100644 --- a/packages/CLPBN/clpbn/bp/BpSolver.cpp +++ b/packages/CLPBN/clpbn/bp/BpSolver.cpp @@ -14,7 +14,7 @@ BpSolver::BpSolver (const FactorGraph& fg) : Solver (fg) { - factorGraph_ = &fg; + fg_ = &fg; runned_ = false; } @@ -54,8 +54,8 @@ BpSolver::getPosterioriOf (VarId vid) if (runned_ == false) { runSolver(); } - assert (factorGraph_->getVarNode (vid)); - VarNode* var = factorGraph_->getVarNode (vid); + assert (fg_->getVarNode (vid)); + VarNode* var = fg_->getVarNode (vid); Params probs; if (var->hasEvidence()) { probs.resize (var->range(), LogAware::noEvidence()); @@ -88,7 +88,7 @@ BpSolver::getJointDistributionOf (const VarIds& jointVarIds) runSolver(); } int idx = -1; - VarNode* vn = factorGraph_->getVarNode (jointVarIds[0]); + VarNode* vn = fg_->getVarNode (jointVarIds[0]); const FacNodes& facNodes = vn->neighbors(); for (unsigned i = 0; i < facNodes.size(); i++) { if (facNodes[i]->factor().contains (jointVarIds)) { @@ -121,37 +121,64 @@ BpSolver::getJointDistributionOf (const VarIds& jointVarIds) void -BpSolver::initializeSolver (void) +BpSolver::runSolver (void) { - const VarNodes& varNodes = factorGraph_->varNodes(); - for (unsigned i = 0; i < varsI_.size(); i++) { - delete varsI_[i]; + clock_t start; + if (Constants::COLLECT_STATS) { + start = clock(); } - varsI_.reserve (varNodes.size()); - for (unsigned i = 0; i < varNodes.size(); i++) { - varsI_.push_back (new SPNodeInfo()); + initializeSolver(); + nIters_ = 0; + while (!converged() && nIters_ < BpOptions::maxIter) { + nIters_ ++; + if (Constants::DEBUG >= 2) { + Util::printHeader (" Iteration " + nIters_); + cout << endl; + } + switch (BpOptions::schedule) { + case BpOptions::Schedule::SEQ_RANDOM: + random_shuffle (links_.begin(), links_.end()); + // no break + case BpOptions::Schedule::SEQ_FIXED: + for (unsigned i = 0; i < links_.size(); i++) { + calculateAndUpdateMessage (links_[i]); + } + break; + case BpOptions::Schedule::PARALLEL: + for (unsigned i = 0; i < links_.size(); i++) { + calculateMessage (links_[i]); + } + for (unsigned i = 0; i < links_.size(); i++) { + updateMessage(links_[i]); + } + break; + case BpOptions::Schedule::MAX_RESIDUAL: + maxResidualSchedule(); + break; + } + if (Constants::DEBUG >= 2) { + cout << endl; + } } - - const FacNodes& facNodes = factorGraph_->facNodes(); - for (unsigned i = 0; i < facsI_.size(); i++) { - delete facsI_[i]; + if (Constants::DEBUG >= 2) { + cout << endl; + if (nIters_ < BpOptions::maxIter) { + cout << "Sum-Product converged in " ; + cout << nIters_ << " iterations" << endl; + } else { + cout << "The maximum number of iterations was hit, terminating..." ; + cout << endl; + } } - facsI_.reserve (facNodes.size()); - for (unsigned i = 0; i < facNodes.size(); i++) { - facsI_.push_back (new SPNodeInfo()); - } - - for (unsigned i = 0; i < links_.size(); i++) { - delete links_[i]; - } - createLinks(); - - for (unsigned i = 0; i < links_.size(); i++) { - FacNode* src = links_[i]->getFactor(); - VarNode* dst = links_[i]->getVariable(); - ninf (dst)->addSpLink (links_[i]); - ninf (src)->addSpLink (links_[i]); + unsigned size = fg_->varNodes().size(); + if (Constants::COLLECT_STATS) { + unsigned nIters = 0; + bool loopy = fg_->isTree() == false; + if (loopy) nIters = nIters_; + double time = (double (clock() - start)) / CLOCKS_PER_SEC; + Statistics::updateStatistics (size, loopy, nIters, time); } + runned_ = true; } @@ -159,7 +186,7 @@ BpSolver::initializeSolver (void) void BpSolver::createLinks (void) { - const FacNodes& facNodes = factorGraph_->facNodes(); + const FacNodes& facNodes = fg_->facNodes(); for (unsigned i = 0; i < facNodes.size(); i++) { const VarNodes& neighbors = facNodes[i]->neighbors(); for (unsigned j = 0; j < neighbors.size(); j++) { @@ -342,11 +369,11 @@ BpSolver::getJointByConditioning (const VarIds& jointVarIds) const { VarNodes jointVars; for (unsigned i = 0; i < jointVarIds.size(); i++) { - assert (factorGraph_->getVarNode (jointVarIds[i])); - jointVars.push_back (factorGraph_->getVarNode (jointVarIds[i])); + assert (fg_->getVarNode (jointVarIds[i])); + jointVars.push_back (fg_->getVarNode (jointVarIds[i])); } - FactorGraph* fg = new FactorGraph (*factorGraph_); + FactorGraph* fg = new FactorGraph (*fg_); BpSolver solver (*fg); solver.runSolver(); Params prevBeliefs = solver.getPosterioriOf (jointVarIds[0]); @@ -390,93 +417,24 @@ BpSolver::getJointByConditioning (const VarIds& jointVarIds) const void -BpSolver::printLinkInformation (void) const +BpSolver::initializeSolver (void) { + const VarNodes& varNodes = fg_->varNodes(); + varsI_.reserve (varNodes.size()); + for (unsigned i = 0; i < varNodes.size(); i++) { + varsI_.push_back (new SPNodeInfo()); + } + const FacNodes& facNodes = fg_->facNodes(); + facsI_.reserve (facNodes.size()); + for (unsigned i = 0; i < facNodes.size(); i++) { + facsI_.push_back (new SPNodeInfo()); + } + createLinks(); for (unsigned i = 0; i < links_.size(); i++) { - SpLink* l = links_[i]; - cout << l->toString() << ":" << endl; - cout << " curr msg = " ; - cout << l->getMessage() << endl; - cout << " next msg = " ; - cout << l->getNextMessage() << endl; - cout << " residual = " << l->getResidual() << endl; - } -} - - - -void -BpSolver::runSolver (void) -{ - clock_t start; - if (Constants::COLLECT_STATS) { - start = clock(); - } - runLoopySolver(); - if (Constants::DEBUG >= 2) { - cout << endl; - if (nIters_ < BpOptions::maxIter) { - cout << "Sum-Product converged in " ; - cout << nIters_ << " iterations" << endl; - } else { - cout << "The maximum number of iterations was hit, terminating..." ; - cout << endl; - } - } - unsigned size = factorGraph_->varNodes().size(); - if (Constants::COLLECT_STATS) { - unsigned nIters = 0; - bool loopy = factorGraph_->isTree() == false; - if (loopy) nIters = nIters_; - double time = (double (clock() - start)) / CLOCKS_PER_SEC; - Statistics::updateStatistics (size, loopy, nIters, time); - } - runned_ = true; -} - - - -void -BpSolver::runLoopySolver (void) -{ - initializeSolver(); - nIters_ = 0; - - while (!converged() && nIters_ < BpOptions::maxIter) { - - nIters_ ++; - if (Constants::DEBUG >= 2) { - Util::printHeader (" Iteration " + nIters_); - cout << endl; - } - - switch (BpOptions::schedule) { - case BpOptions::Schedule::SEQ_RANDOM: - random_shuffle (links_.begin(), links_.end()); - // no break - - case BpOptions::Schedule::SEQ_FIXED: - for (unsigned i = 0; i < links_.size(); i++) { - calculateAndUpdateMessage (links_[i]); - } - break; - - case BpOptions::Schedule::PARALLEL: - for (unsigned i = 0; i < links_.size(); i++) { - calculateMessage (links_[i]); - } - for (unsigned i = 0; i < links_.size(); i++) { - updateMessage(links_[i]); - } - break; - - case BpOptions::Schedule::MAX_RESIDUAL: - maxResidualSchedule(); - break; - } - if (Constants::DEBUG >= 2) { - cout << endl; - } + FacNode* src = links_[i]->getFactor(); + VarNode* dst = links_[i]->getVariable(); + ninf (dst)->addSpLink (links_[i]); + ninf (src)->addSpLink (links_[i]); } } @@ -488,7 +446,7 @@ BpSolver::converged (void) if (links_.size() == 0) { return true; } - if (nIters_ == 0 || nIters_ == 1) { + if (nIters_ <= 1) { return false; } bool converged = true; @@ -514,3 +472,19 @@ BpSolver::converged (void) return converged; } + + +void +BpSolver::printLinkInformation (void) const +{ + for (unsigned i = 0; i < links_.size(); i++) { + SpLink* l = links_[i]; + cout << l->toString() << ":" << endl; + cout << " curr msg = " ; + cout << l->getMessage() << endl; + cout << " next msg = " ; + cout << l->getNextMessage() << endl; + cout << " residual = " << l->getResidual() << endl; + } +} + diff --git a/packages/CLPBN/clpbn/bp/BpSolver.h b/packages/CLPBN/clpbn/bp/BpSolver.h index ebd5c0c93..c9260b013 100644 --- a/packages/CLPBN/clpbn/bp/BpSolver.h +++ b/packages/CLPBN/clpbn/bp/BpSolver.h @@ -1,5 +1,5 @@ -#ifndef HORUS_BpSolver_H -#define HORUS_BpSolver_H +#ifndef HORUS_BPSOLVER_H +#define HORUS_BPSOLVER_H #include #include @@ -102,7 +102,7 @@ class BpSolver : public Solver virtual Params getJointDistributionOf (const VarIds&); protected: - virtual void initializeSolver (void); + void runSolver (void); virtual void createLinks (void); @@ -114,8 +114,6 @@ class BpSolver : public Solver virtual Params getJointByConditioning (const VarIds&) const; - virtual void printLinkInformation (void) const; - SPNodeInfo* ninf (const VarNode* var) const { return varsI_[var->getIndex()]; @@ -170,7 +168,7 @@ class BpSolver : public Solver vector varsI_; vector facsI_; bool runned_; - const FactorGraph* factorGraph_; + const FactorGraph* fg_; typedef multiset SortedOrder; SortedOrder sortedOrder_; @@ -179,10 +177,12 @@ class BpSolver : public Solver SpLinkMap linkMap_; private: - void runSolver (void); - void runLoopySolver (void); + void initializeSolver (void); + bool converged (void); + + void printLinkInformation (void) const; }; -#endif // HORUS_BpSolver_H +#endif // HORUS_BPSOLVER_H diff --git a/packages/CLPBN/clpbn/bp/CFactorGraph.cpp b/packages/CLPBN/clpbn/bp/CFactorGraph.cpp index 84b58e60f..975b14225 100644 --- a/packages/CLPBN/clpbn/bp/CFactorGraph.cpp +++ b/packages/CLPBN/clpbn/bp/CFactorGraph.cpp @@ -18,14 +18,14 @@ CFactorGraph::CFactorGraph (const FactorGraph& fg) } const FacNodes& facNodes = fg.facNodes(); - factorSignatures_.reserve (facNodes.size()); + facSignatures_.reserve (facNodes.size()); for (unsigned i = 0; i < facNodes.size(); i++) { unsigned c = facNodes[i]->neighbors().size() + 1; - factorSignatures_.push_back (Signature (c)); + facSignatures_.push_back (Signature (c)); } varColors_.resize (varNodes.size()); - factorColors_.resize (facNodes.size()); + facColors_.resize (facNodes.size()); setInitialColors(); createGroups(); } @@ -111,7 +111,7 @@ void CFactorGraph::createGroups (void) { VarSignMap varGroups; - FacSignMap factorGroups; + FacSignMap facGroups; unsigned nIters = 0; bool groupsHaveChanged = true; const VarNodes& varNodes = groundFg_->varNodes(); @@ -120,19 +120,19 @@ CFactorGraph::createGroups (void) while (groupsHaveChanged || nIters == 1) { nIters ++; - unsigned prevFactorGroupsSize = factorGroups.size(); - factorGroups.clear(); + unsigned prevFactorGroupsSize = facGroups.size(); + facGroups.clear(); // set a new color to the factors with the same signature for (unsigned i = 0; i < facNodes.size(); i++) { const Signature& signature = getSignature (facNodes[i]); - FacSignMap::iterator it = factorGroups.find (signature); - if (it == factorGroups.end()) { - it = factorGroups.insert (make_pair (signature, FacNodes())).first; + FacSignMap::iterator it = facGroups.find (signature); + if (it == facGroups.end()) { + it = facGroups.insert (make_pair (signature, FacNodes())).first; } it->second.push_back (facNodes[i]); } - for (FacSignMap::iterator it = factorGroups.begin(); - it != factorGroups.end(); it++) { + for (FacSignMap::iterator it = facGroups.begin(); + it != facGroups.end(); it++) { Color newColor = getFreeColor(); FacNodes& groupMembers = it->second; for (unsigned i = 0; i < groupMembers.size(); i++) { @@ -161,10 +161,10 @@ CFactorGraph::createGroups (void) } groupsHaveChanged = prevVarGroupsSize != varGroups.size() - || prevFactorGroupsSize != factorGroups.size(); + || prevFactorGroupsSize != facGroups.size(); } - //printGroups (varGroups, factorGroups); - createClusters (varGroups, factorGroups); + //printGroups (varGroups, facGroups); + createClusters (varGroups, facGroups); } @@ -172,7 +172,7 @@ CFactorGraph::createGroups (void) void CFactorGraph::createClusters ( const VarSignMap& varGroups, - const FacSignMap& factorGroups) + const FacSignMap& facGroups) { varClusters_.reserve (varGroups.size()); for (VarSignMap::const_iterator it = varGroups.begin(); @@ -185,12 +185,12 @@ CFactorGraph::createClusters ( varClusters_.push_back (vc); } - facClusters_.reserve (factorGroups.size()); - for (FacSignMap::const_iterator it = factorGroups.begin(); - it != factorGroups.end(); it++) { + facClusters_.reserve (facGroups.size()); + for (FacSignMap::const_iterator it = facGroups.begin(); + it != facGroups.end(); it++) { FacNode* groupFactor = it->second[0]; const VarNodes& neighs = groupFactor->neighbors(); - VarClusterSet varClusters; + VarClusters varClusters; varClusters.reserve (neighs.size()); for (unsigned i = 0; i < neighs.size(); i++) { VarId vid = neighs[i]->varId(); @@ -223,7 +223,7 @@ CFactorGraph::getSignature (const VarNode* varNode) const Signature& CFactorGraph::getSignature (const FacNode* facNode) { - Signature& sign = factorSignatures_[facNode->getIndex()]; + Signature& sign = facSignatures_[facNode->getIndex()]; vector::iterator it = sign.colors.begin(); const VarNodes& neighs = facNode->neighbors(); for (unsigned i = 0; i < neighs.size(); i++) { @@ -237,7 +237,7 @@ CFactorGraph::getSignature (const FacNode* facNode) FactorGraph* -CFactorGraph::getCompressedFactorGraph (void) +CFactorGraph::getGroundFactorGraph (void) const { FactorGraph* fg = new FactorGraph(); for (unsigned i = 0; i < varClusters_.size(); i++) { @@ -248,7 +248,7 @@ CFactorGraph::getCompressedFactorGraph (void) } for (unsigned i = 0; i < facClusters_.size(); i++) { - const VarClusterSet& myVarClusters = facClusters_[i]->getVarClusters(); + const VarClusters& myVarClusters = facClusters_[i]->getVarClusters(); Vars myGroundVars; myGroundVars.reserve (myVarClusters.size()); for (unsigned j = 0; j < myVarClusters.size(); j++) { @@ -300,7 +300,7 @@ CFactorGraph::getGroundEdgeCount ( void CFactorGraph::printGroups ( const VarSignMap& varGroups, - const FacSignMap& factorGroups) const + const FacSignMap& facGroups) const { unsigned count = 1; cout << "variable groups:" << endl; @@ -319,8 +319,8 @@ CFactorGraph::printGroups ( count = 1; cout << endl << "factor groups:" << endl; - for (FacSignMap::const_iterator it = factorGroups.begin(); - it != factorGroups.end(); it++) { + for (FacSignMap::const_iterator it = facGroups.begin(); + it != facGroups.end(); it++) { const FacNodes& groupMembers = it->second; if (groupMembers.size() > 0) { cout << ++count << ": " ; diff --git a/packages/CLPBN/clpbn/bp/CFactorGraph.h b/packages/CLPBN/clpbn/bp/CFactorGraph.h index 68e29720d..80b93d5c3 100644 --- a/packages/CLPBN/clpbn/bp/CFactorGraph.h +++ b/packages/CLPBN/clpbn/bp/CFactorGraph.h @@ -22,8 +22,8 @@ typedef unordered_map> VarColorMap; typedef unordered_map DistColorMap; typedef unordered_map VarId2VarCluster; -typedef vector VarClusterSet; -typedef vector FacClusterSet; +typedef vector VarClusters; +typedef vector FacClusters; typedef unordered_map VarSignMap; typedef unordered_map FacSignMap; @@ -99,18 +99,20 @@ class VarCluster facClusters_.push_back (fc); } - const FacClusterSet& getFacClusters (void) const + const FacClusters& getFacClusters (void) const { return facClusters_; } VarNode* getRepresentativeVariable (void) const { return representVar_; } - void setRepresentativeVariable (VarNode* v) { representVar_ = v; } - const VarNodes& getGroundVarNodes (void) const { return groundVars_; } + + void setRepresentativeVariable (VarNode* v) { representVar_ = v; } + + const VarNodes& getGroundVarNodes (void) const { return groundVars_; } private: VarNodes groundVars_; - FacClusterSet facClusters_; + FacClusters facClusters_; VarNode* representVar_; }; @@ -118,7 +120,7 @@ class VarCluster class FacCluster { public: - FacCluster (const FacNodes& groundFactors, const VarClusterSet& vcs) + FacCluster (const FacNodes& groundFactors, const VarClusters& vcs) { groundFactors_ = groundFactors; varClusters_ = vcs; @@ -127,7 +129,7 @@ class FacCluster } } - const VarClusterSet& getVarClusters (void) const + const VarClusters& getVarClusters (void) const { return varClusters_; } @@ -160,7 +162,7 @@ class FacCluster private: FacNodes groundFactors_; - VarClusterSet varClusters_; + VarClusters varClusters_; FacNode* representFactor_; }; @@ -172,9 +174,9 @@ class CFactorGraph ~CFactorGraph (void); - const VarClusterSet& getVarClusters (void) { return varClusters_; } + const VarClusters& getVarClusters (void) { return varClusters_; } - const FacClusterSet& getFacClusters (void) { return facClusters_; } + const FacClusters& getFacClusters (void) { return facClusters_; } VarNode* getEquivalentVariable (VarId vid) { @@ -182,7 +184,7 @@ class CFactorGraph return vc->getRepresentativeVariable(); } - FactorGraph* getCompressedFactorGraph (void); + FactorGraph* getGroundFactorGraph (void) const; unsigned getGroundEdgeCount (const FacCluster*, const VarCluster*) const; @@ -200,7 +202,7 @@ class CFactorGraph return varColors_[vn->getIndex()]; } Color getColor (const FacNode* fn) const { - return factorColors_[fn->getIndex()]; + return facColors_[fn->getIndex()]; } void setColor (const VarNode* vn, Color c) @@ -210,7 +212,7 @@ class CFactorGraph void setColor (const FacNode* fn, Color c) { - factorColors_[fn->getIndex()] = c; + facColors_[fn->getIndex()] = c; } VarCluster* getVariableCluster (VarId vid) const @@ -232,11 +234,11 @@ class CFactorGraph Color freeColor_; vector varColors_; - vector factorColors_; + vector facColors_; vector varSignatures_; - vector factorSignatures_; - VarClusterSet varClusters_; - FacClusterSet facClusters_; + vector facSignatures_; + VarClusters varClusters_; + FacClusters facClusters_; VarId2VarCluster vid2VarCluster_; const FactorGraph* groundFg_; }; diff --git a/packages/CLPBN/clpbn/bp/CbpSolver.cpp b/packages/CLPBN/clpbn/bp/CbpSolver.cpp index 1437638f8..c9ca34683 100644 --- a/packages/CLPBN/clpbn/bp/CbpSolver.cpp +++ b/packages/CLPBN/clpbn/bp/CbpSolver.cpp @@ -1,10 +1,41 @@ #include "CbpSolver.h" +CbpSolver::CbpSolver (const FactorGraph& fg) : BpSolver (fg) +{ + unsigned nGroundVars, nGroundFacs, nWithoutNeighs; + if (Constants::COLLECT_STATS) { + nGroundVars = fg_->varNodes().size(); + nGroundFacs = fg_->facNodes().size(); + const VarNodes& vars = fg_->varNodes(); + nWithoutNeighs = 0; + for (unsigned i = 0; i < vars.size(); i++) { + const FacNodes& factors = vars[i]->neighbors(); + if (factors.size() == 1 && factors[0]->neighbors().size() == 1) { + nWithoutNeighs ++; + } + } + } + cfg_ = new CFactorGraph (fg); + fg_ = cfg_->getGroundFactorGraph(); + if (Constants::COLLECT_STATS) { + unsigned nClusterVars = fg_->varNodes().size(); + unsigned nClusterFacs = fg_->facNodes().size(); + Statistics::updateCompressingStatistics (nGroundVars, + nGroundFacs, nClusterVars, nClusterFacs, nWithoutNeighs); + } + // Util::printHeader ("Uncompressed Factor Graph"); + // fg->print(); + // Util::printHeader ("Compressed Factor Graph"); + // fg_->print(); +} + + + CbpSolver::~CbpSolver (void) { - delete lfg_; - delete factorGraph_; + delete cfg_; + delete fg_; for (unsigned i = 0; i < links_.size(); i++) { delete links_[i]; } @@ -16,8 +47,11 @@ CbpSolver::~CbpSolver (void) Params CbpSolver::getPosterioriOf (VarId vid) { - assert (lfg_->getEquivalentVariable (vid)); - VarNode* var = lfg_->getEquivalentVariable (vid); + if (runned_ == false) { + runSolver(); + } + assert (cfg_->getEquivalentVariable (vid)); + VarNode* var = cfg_->getEquivalentVariable (vid); Params probs; if (var->hasEvidence()) { probs.resize (var->range(), LogAware::noEvidence()); @@ -26,16 +60,16 @@ CbpSolver::getPosterioriOf (VarId vid) probs.resize (var->range(), LogAware::multIdenty()); const SpLinkSet& links = ninf(var)->getLinks(); if (Globals::logDomain) { - for (unsigned i = 0; i < links.size(); i++) { - CbpSolverLink* l = static_cast (links[i]); - Util::add (probs, l->getPoweredMessage()); - } - LogAware::normalize (probs); - Util::fromLog (probs); + for (unsigned i = 0; i < links.size(); i++) { + CbpSolverLink* l = static_cast (links[i]); + Util::add (probs, l->poweredMessage()); + } + LogAware::normalize (probs); + Util::fromLog (probs); } else { for (unsigned i = 0; i < links.size(); i++) { CbpSolverLink* l = static_cast (links[i]); - Util::multiply (probs, l->getPoweredMessage()); + Util::multiply (probs, l->poweredMessage()); } LogAware::normalize (probs); } @@ -46,67 +80,28 @@ CbpSolver::getPosterioriOf (VarId vid) Params -CbpSolver::getJointDistributionOf (const VarIds& jointVarIds) +CbpSolver::getJointDistributionOf (const VarIds& jointVids) { VarIds eqVarIds; - for (unsigned i = 0; i < jointVarIds.size(); i++) { - eqVarIds.push_back (lfg_->getEquivalentVariable (jointVarIds[i])->varId()); + for (unsigned i = 0; i < jointVids.size(); i++) { + VarNode* vn = cfg_->getEquivalentVariable (jointVids[i]); + eqVarIds.push_back (vn->varId()); } return BpSolver::getJointDistributionOf (eqVarIds); } - -void -CbpSolver::initializeSolver (void) -{ - unsigned nGroundVars, nGroundFacs, nWithoutNeighs; - if (Constants::COLLECT_STATS) { - nGroundVars = factorGraph_->varNodes().size(); - nGroundFacs = factorGraph_->facNodes().size(); - const VarNodes& vars = factorGraph_->varNodes(); - nWithoutNeighs = 0; - for (unsigned i = 0; i < vars.size(); i++) { - const FacNodes& factors = vars[i]->neighbors(); - if (factors.size() == 1 && factors[0]->neighbors().size() == 1) { - nWithoutNeighs ++; - } - } - } - - lfg_ = new CFactorGraph (*factorGraph_); - - // cout << "Uncompressed Factor Graph" << endl; - // factorGraph_->print(); - // factorGraph_->exportToGraphViz ("uncompressed_fg.dot"); - factorGraph_ = lfg_->getCompressedFactorGraph(); - - if (Constants::COLLECT_STATS) { - unsigned nClusterVars = factorGraph_->varNodes().size(); - unsigned nClusterFacs = factorGraph_->facNodes().size(); - Statistics::updateCompressingStatistics (nGroundVars, - nGroundFacs, nClusterVars, nClusterFacs, nWithoutNeighs); - } - - // cout << "Compressed Factor Graph" << endl; - // factorGraph_->print(); - // factorGraph_->exportToGraphViz ("compressed_fg.dot"); - // abort(); - BpSolver::initializeSolver(); -} - - - void CbpSolver::createLinks (void) { - const FacClusterSet fcs = lfg_->getFacClusters(); + const FacClusters& fcs = cfg_->getFacClusters(); for (unsigned i = 0; i < fcs.size(); i++) { - const VarClusterSet vcs = fcs[i]->getVarClusters(); + const VarClusters& vcs = fcs[i]->getVarClusters(); for (unsigned j = 0; j < vcs.size(); j++) { - unsigned c = lfg_->getGroundEdgeCount (fcs[i], vcs[j]); - links_.push_back (new CbpSolverLink (fcs[i]->getRepresentativeFactor(), + unsigned c = cfg_->getGroundEdgeCount (fcs[i], vcs[j]); + links_.push_back (new CbpSolverLink ( + fcs[i]->getRepresentativeFactor(), vcs[j]->getRepresentativeVariable(), c)); } } @@ -197,10 +192,10 @@ CbpSolver::getVar2FactorMsg (const SpLink* link) const if (src->hasEvidence()) { msg.resize (src->range(), LogAware::noEvidence()); double value = link->getMessage()[src->getEvidence()]; - msg[src->getEvidence()] = LogAware::pow (value, l->getNumberOfEdges() - 1); + msg[src->getEvidence()] = LogAware::pow (value, l->nrEdges() - 1); } else { msg = link->getMessage(); - LogAware::pow (msg, l->getNumberOfEdges() - 1); + LogAware::pow (msg, l->nrEdges() - 1); } if (Constants::DEBUG >= 5) { cout << " " << "init: " << msg << endl; @@ -210,17 +205,17 @@ CbpSolver::getVar2FactorMsg (const SpLink* link) const for (unsigned i = 0; i < links.size(); i++) { if (links[i]->getFactor() != dst) { CbpSolverLink* l = static_cast (links[i]); - Util::add (msg, l->getPoweredMessage()); + Util::add (msg, l->poweredMessage()); } } } else { for (unsigned i = 0; i < links.size(); i++) { if (links[i]->getFactor() != dst) { CbpSolverLink* l = static_cast (links[i]); - Util::multiply (msg, l->getPoweredMessage()); + Util::multiply (msg, l->poweredMessage()); if (Constants::DEBUG >= 5) { cout << " msg from " << l->getFactor()->getLabel() << ": " ; - cout << l->getPoweredMessage() << endl; + cout << l->poweredMessage() << endl; } } } @@ -242,7 +237,7 @@ CbpSolver::printLinkInformation (void) const cout << l->toString() << ":" << endl; cout << " curr msg = " << l->getMessage() << endl; cout << " next msg = " << l->getNextMessage() << endl; - cout << " powered = " << l->getPoweredMessage() << endl; + cout << " powered = " << l->poweredMessage() << endl; cout << " residual = " << l->getResidual() << endl; } } diff --git a/packages/CLPBN/clpbn/bp/CbpSolver.h b/packages/CLPBN/clpbn/bp/CbpSolver.h index c270e9d4b..638e68fb9 100644 --- a/packages/CLPBN/clpbn/bp/CbpSolver.h +++ b/packages/CLPBN/clpbn/bp/CbpSolver.h @@ -9,27 +9,25 @@ class Factor; class CbpSolverLink : public SpLink { public: - CbpSolverLink (FacNode* fn, VarNode* vn, unsigned c) : SpLink (fn, vn) - { - edgeCount_ = c; - poweredMsg_.resize (vn->range(), LogAware::one()); - } + CbpSolverLink (FacNode* fn, VarNode* vn, unsigned c) + : SpLink (fn, vn), nrEdges_(c), + pwdMsg_(vn->range(), LogAware::one()) { } - unsigned getNumberOfEdges (void) const { return edgeCount_; } + unsigned nrEdges (void) const { return nrEdges_; } - const Params& getPoweredMessage (void) const { return poweredMsg_; } + const Params& poweredMessage (void) const { return pwdMsg_; } void updateMessage (void) { - poweredMsg_ = *nextMsg_; + pwdMsg_ = *nextMsg_; swap (currMsg_, nextMsg_); msgSended_ = true; - LogAware::pow (poweredMsg_, edgeCount_); + LogAware::pow (pwdMsg_, nrEdges_); } private: - Params poweredMsg_; - unsigned edgeCount_; + unsigned nrEdges_; + Params pwdMsg_; }; @@ -37,16 +35,15 @@ class CbpSolverLink : public SpLink class CbpSolver : public BpSolver { public: - CbpSolver (const FactorGraph& fg) : BpSolver (fg) { } + CbpSolver (const FactorGraph& fg); ~CbpSolver (void); - + Params getPosterioriOf (VarId); Params getJointDistributionOf (const VarIds&); private: - void initializeSolver (void); void createLinks (void); @@ -56,8 +53,7 @@ class CbpSolver : public BpSolver void printLinkInformation (void) const; - CFactorGraph* lfg_; - FactorGraph* factorGraph_; + CFactorGraph* cfg_; }; #endif // HORUS_CBP_H diff --git a/packages/CLPBN/clpbn/bp/HorusCli.cpp b/packages/CLPBN/clpbn/bp/HorusCli.cpp index 4b6153a77..cf3333d16 100644 --- a/packages/CLPBN/clpbn/bp/HorusCli.cpp +++ b/packages/CLPBN/clpbn/bp/HorusCli.cpp @@ -14,25 +14,43 @@ void processArguments (FactorGraph&, int, const char* []); void runSolver (const FactorGraph&, const VarIds&); const string USAGE = "usage: \ -./hcli FILE [VARIABLE | OBSERVED_VARIABLE=EVIDENCE]..." ; +./hcli ve|bp|cbp NETWORK_FILE [VARIABLE | OBSERVED_VARIABLE=EVIDENCE]..." ; int main (int argc, const char* argv[]) { - if (!argv[1]) { + if (argc <= 1) { + cerr << "error: no solver specified" << endl; cerr << "error: no graphical model specified" << endl; cerr << USAGE << endl; exit (0); } - string fileName = argv[1]; + if (argc <= 2) { + cerr << "error: no graphical model specified" << endl; + cerr << USAGE << endl; + exit (0); + } + string solver (argv[1]); + if (solver == "ve") { + Globals::infAlgorithm = InfAlgorithms::VE; + } else if (solver == "bp") { + Globals::infAlgorithm = InfAlgorithms::BP; + } else if (solver == "cbp") { + Globals::infAlgorithm = InfAlgorithms::CBP; + } else { + cerr << "error: unknow solver `" << solver << "'" << endl ; + cerr << USAGE << endl; + exit(0); + } + string fileName (argv[2]); string extension = fileName.substr ( fileName.find_last_of ('.') + 1); FactorGraph fg; if (extension == "uai") { - fg.readFromUaiFormat (argv[1]); + fg.readFromUaiFormat (fileName.c_str()); } else if (extension == "fg") { - fg.readFromLibDaiFormat (argv[1]); + fg.readFromLibDaiFormat (fileName.c_str()); } else { cerr << "error: the graphical model must be defined either " ; cerr << "in a UAI or libDAI file" << endl; @@ -48,7 +66,7 @@ void processArguments (FactorGraph& fg, int argc, const char* argv[]) { VarIds queryIds; - for (int i = 2; i < argc; i++) { + for (int i = 3; i < argc; i++) { const string& arg = argv[i]; if (arg.find ('=') == std::string::npos) { if (!Util::isInteger (arg)) { diff --git a/packages/CLPBN/clpbn/bp/Solver.cpp b/packages/CLPBN/clpbn/bp/Solver.cpp index 3e67c1e49..44d61db2e 100644 --- a/packages/CLPBN/clpbn/bp/Solver.cpp +++ b/packages/CLPBN/clpbn/bp/Solver.cpp @@ -8,7 +8,7 @@ Solver::printAnswer (const VarIds& vids) Vars unobservedVars; VarIds unobservedVids; for (unsigned i = 0; i < vids.size(); i++) { - VarNode* vn = fg_.getVarNode (vids[i]); + VarNode* vn = fg.getVarNode (vids[i]); if (vn->hasEvidence() == false) { unobservedVars.push_back (vn); unobservedVids.push_back (vids[i]); @@ -29,7 +29,7 @@ Solver::printAnswer (const VarIds& vids) void Solver::printAllPosterioris (void) { - const VarNodes& vars = fg_.varNodes(); + const VarNodes& vars = fg.varNodes(); for (unsigned i = 0; i < vars.size(); i++) { printAnswer ({vars[i]->varId()}); } diff --git a/packages/CLPBN/clpbn/bp/Solver.h b/packages/CLPBN/clpbn/bp/Solver.h index 219299673..00d584128 100644 --- a/packages/CLPBN/clpbn/bp/Solver.h +++ b/packages/CLPBN/clpbn/bp/Solver.h @@ -12,7 +12,7 @@ using namespace std; class Solver { public: - Solver (const FactorGraph& fg) : fg_(fg) { } + Solver (const FactorGraph& factorGraph) : fg(factorGraph) { } virtual ~Solver() { } // ensure that subclass destructor is called @@ -23,7 +23,7 @@ class Solver void printAllPosterioris (void); protected: - const FactorGraph& fg_; + const FactorGraph& fg; }; #endif // HORUS_SOLVER_H diff --git a/packages/CLPBN/clpbn/bp/VarElimSolver.cpp b/packages/CLPBN/clpbn/bp/VarElimSolver.cpp index 0356ad0f0..75af57868 100644 --- a/packages/CLPBN/clpbn/bp/VarElimSolver.cpp +++ b/packages/CLPBN/clpbn/bp/VarElimSolver.cpp @@ -35,7 +35,7 @@ VarElimSolver::solveQuery (VarIds queryVids) void VarElimSolver::createFactorList (void) { - const FacNodes& facNodes = fg_.facNodes(); + const FacNodes& facNodes = fg.facNodes(); factorList_.reserve (facNodes.size() * 2); for (unsigned i = 0; i < facNodes.size(); i++) { factorList_.push_back (new Factor (facNodes[i]->factor())); @@ -57,7 +57,7 @@ VarElimSolver::createFactorList (void) void VarElimSolver::absorveEvidence (void) { - const VarNodes& varNodes = fg_.varNodes(); + const VarNodes& varNodes = fg.varNodes(); for (unsigned i = 0; i < varNodes.size(); i++) { if (varNodes[i]->hasEvidence()) { const vector& idxs = @@ -103,7 +103,7 @@ VarElimSolver::processFactorList (const VarIds& vids) VarIds unobservedVids; for (unsigned i = 0; i < vids.size(); i++) { - if (fg_.getVarNode (vids[i])->hasEvidence() == false) { + if (fg.getVarNode (vids[i])->hasEvidence() == false) { unobservedVids.push_back (vids[i]); } }