refactorings

This commit is contained in:
Tiago Gomes 2012-04-10 20:43:08 +01:00
parent 78e86a6330
commit 8697fcd2b4
10 changed files with 251 additions and 266 deletions

View File

@ -14,7 +14,7 @@
BpSolver::BpSolver (const FactorGraph& fg) : Solver (fg) BpSolver::BpSolver (const FactorGraph& fg) : Solver (fg)
{ {
factorGraph_ = &fg; fg_ = &fg;
runned_ = false; runned_ = false;
} }
@ -54,8 +54,8 @@ BpSolver::getPosterioriOf (VarId vid)
if (runned_ == false) { if (runned_ == false) {
runSolver(); runSolver();
} }
assert (factorGraph_->getVarNode (vid)); assert (fg_->getVarNode (vid));
VarNode* var = factorGraph_->getVarNode (vid); VarNode* var = fg_->getVarNode (vid);
Params probs; Params probs;
if (var->hasEvidence()) { if (var->hasEvidence()) {
probs.resize (var->range(), LogAware::noEvidence()); probs.resize (var->range(), LogAware::noEvidence());
@ -88,7 +88,7 @@ BpSolver::getJointDistributionOf (const VarIds& jointVarIds)
runSolver(); runSolver();
} }
int idx = -1; int idx = -1;
VarNode* vn = factorGraph_->getVarNode (jointVarIds[0]); VarNode* vn = fg_->getVarNode (jointVarIds[0]);
const FacNodes& facNodes = vn->neighbors(); const FacNodes& facNodes = vn->neighbors();
for (unsigned i = 0; i < facNodes.size(); i++) { for (unsigned i = 0; i < facNodes.size(); i++) {
if (facNodes[i]->factor().contains (jointVarIds)) { if (facNodes[i]->factor().contains (jointVarIds)) {
@ -121,37 +121,64 @@ BpSolver::getJointDistributionOf (const VarIds& jointVarIds)
void void
BpSolver::initializeSolver (void) BpSolver::runSolver (void)
{ {
const VarNodes& varNodes = factorGraph_->varNodes(); clock_t start;
for (unsigned i = 0; i < varsI_.size(); i++) { if (Constants::COLLECT_STATS) {
delete varsI_[i]; start = clock();
} }
varsI_.reserve (varNodes.size()); initializeSolver();
for (unsigned i = 0; i < varNodes.size(); i++) { nIters_ = 0;
varsI_.push_back (new SPNodeInfo()); while (!converged() && nIters_ < BpOptions::maxIter) {
nIters_ ++;
if (Constants::DEBUG >= 2) {
Util::printHeader (" Iteration " + nIters_);
cout << endl;
}
switch (BpOptions::schedule) {
case BpOptions::Schedule::SEQ_RANDOM:
random_shuffle (links_.begin(), links_.end());
// no break
case BpOptions::Schedule::SEQ_FIXED:
for (unsigned i = 0; i < links_.size(); i++) {
calculateAndUpdateMessage (links_[i]);
}
break;
case BpOptions::Schedule::PARALLEL:
for (unsigned i = 0; i < links_.size(); i++) {
calculateMessage (links_[i]);
}
for (unsigned i = 0; i < links_.size(); i++) {
updateMessage(links_[i]);
}
break;
case BpOptions::Schedule::MAX_RESIDUAL:
maxResidualSchedule();
break;
}
if (Constants::DEBUG >= 2) {
cout << endl;
}
} }
if (Constants::DEBUG >= 2) {
const FacNodes& facNodes = factorGraph_->facNodes(); cout << endl;
for (unsigned i = 0; i < facsI_.size(); i++) { if (nIters_ < BpOptions::maxIter) {
delete facsI_[i]; cout << "Sum-Product converged in " ;
cout << nIters_ << " iterations" << endl;
} else {
cout << "The maximum number of iterations was hit, terminating..." ;
cout << endl;
}
} }
facsI_.reserve (facNodes.size()); unsigned size = fg_->varNodes().size();
for (unsigned i = 0; i < facNodes.size(); i++) { if (Constants::COLLECT_STATS) {
facsI_.push_back (new SPNodeInfo()); unsigned nIters = 0;
} bool loopy = fg_->isTree() == false;
if (loopy) nIters = nIters_;
for (unsigned i = 0; i < links_.size(); i++) { double time = (double (clock() - start)) / CLOCKS_PER_SEC;
delete links_[i]; Statistics::updateStatistics (size, loopy, nIters, time);
}
createLinks();
for (unsigned i = 0; i < links_.size(); i++) {
FacNode* src = links_[i]->getFactor();
VarNode* dst = links_[i]->getVariable();
ninf (dst)->addSpLink (links_[i]);
ninf (src)->addSpLink (links_[i]);
} }
runned_ = true;
} }
@ -159,7 +186,7 @@ BpSolver::initializeSolver (void)
void void
BpSolver::createLinks (void) BpSolver::createLinks (void)
{ {
const FacNodes& facNodes = factorGraph_->facNodes(); const FacNodes& facNodes = fg_->facNodes();
for (unsigned i = 0; i < facNodes.size(); i++) { for (unsigned i = 0; i < facNodes.size(); i++) {
const VarNodes& neighbors = facNodes[i]->neighbors(); const VarNodes& neighbors = facNodes[i]->neighbors();
for (unsigned j = 0; j < neighbors.size(); j++) { for (unsigned j = 0; j < neighbors.size(); j++) {
@ -342,11 +369,11 @@ BpSolver::getJointByConditioning (const VarIds& jointVarIds) const
{ {
VarNodes jointVars; VarNodes jointVars;
for (unsigned i = 0; i < jointVarIds.size(); i++) { for (unsigned i = 0; i < jointVarIds.size(); i++) {
assert (factorGraph_->getVarNode (jointVarIds[i])); assert (fg_->getVarNode (jointVarIds[i]));
jointVars.push_back (factorGraph_->getVarNode (jointVarIds[i])); jointVars.push_back (fg_->getVarNode (jointVarIds[i]));
} }
FactorGraph* fg = new FactorGraph (*factorGraph_); FactorGraph* fg = new FactorGraph (*fg_);
BpSolver solver (*fg); BpSolver solver (*fg);
solver.runSolver(); solver.runSolver();
Params prevBeliefs = solver.getPosterioriOf (jointVarIds[0]); Params prevBeliefs = solver.getPosterioriOf (jointVarIds[0]);
@ -390,93 +417,24 @@ BpSolver::getJointByConditioning (const VarIds& jointVarIds) const
void void
BpSolver::printLinkInformation (void) const BpSolver::initializeSolver (void)
{ {
const VarNodes& varNodes = fg_->varNodes();
varsI_.reserve (varNodes.size());
for (unsigned i = 0; i < varNodes.size(); i++) {
varsI_.push_back (new SPNodeInfo());
}
const FacNodes& facNodes = fg_->facNodes();
facsI_.reserve (facNodes.size());
for (unsigned i = 0; i < facNodes.size(); i++) {
facsI_.push_back (new SPNodeInfo());
}
createLinks();
for (unsigned i = 0; i < links_.size(); i++) { for (unsigned i = 0; i < links_.size(); i++) {
SpLink* l = links_[i]; FacNode* src = links_[i]->getFactor();
cout << l->toString() << ":" << endl; VarNode* dst = links_[i]->getVariable();
cout << " curr msg = " ; ninf (dst)->addSpLink (links_[i]);
cout << l->getMessage() << endl; ninf (src)->addSpLink (links_[i]);
cout << " next msg = " ;
cout << l->getNextMessage() << endl;
cout << " residual = " << l->getResidual() << endl;
}
}
void
BpSolver::runSolver (void)
{
clock_t start;
if (Constants::COLLECT_STATS) {
start = clock();
}
runLoopySolver();
if (Constants::DEBUG >= 2) {
cout << endl;
if (nIters_ < BpOptions::maxIter) {
cout << "Sum-Product converged in " ;
cout << nIters_ << " iterations" << endl;
} else {
cout << "The maximum number of iterations was hit, terminating..." ;
cout << endl;
}
}
unsigned size = factorGraph_->varNodes().size();
if (Constants::COLLECT_STATS) {
unsigned nIters = 0;
bool loopy = factorGraph_->isTree() == false;
if (loopy) nIters = nIters_;
double time = (double (clock() - start)) / CLOCKS_PER_SEC;
Statistics::updateStatistics (size, loopy, nIters, time);
}
runned_ = true;
}
void
BpSolver::runLoopySolver (void)
{
initializeSolver();
nIters_ = 0;
while (!converged() && nIters_ < BpOptions::maxIter) {
nIters_ ++;
if (Constants::DEBUG >= 2) {
Util::printHeader (" Iteration " + nIters_);
cout << endl;
}
switch (BpOptions::schedule) {
case BpOptions::Schedule::SEQ_RANDOM:
random_shuffle (links_.begin(), links_.end());
// no break
case BpOptions::Schedule::SEQ_FIXED:
for (unsigned i = 0; i < links_.size(); i++) {
calculateAndUpdateMessage (links_[i]);
}
break;
case BpOptions::Schedule::PARALLEL:
for (unsigned i = 0; i < links_.size(); i++) {
calculateMessage (links_[i]);
}
for (unsigned i = 0; i < links_.size(); i++) {
updateMessage(links_[i]);
}
break;
case BpOptions::Schedule::MAX_RESIDUAL:
maxResidualSchedule();
break;
}
if (Constants::DEBUG >= 2) {
cout << endl;
}
} }
} }
@ -488,7 +446,7 @@ BpSolver::converged (void)
if (links_.size() == 0) { if (links_.size() == 0) {
return true; return true;
} }
if (nIters_ == 0 || nIters_ == 1) { if (nIters_ <= 1) {
return false; return false;
} }
bool converged = true; bool converged = true;
@ -514,3 +472,19 @@ BpSolver::converged (void)
return converged; return converged;
} }
void
BpSolver::printLinkInformation (void) const
{
for (unsigned i = 0; i < links_.size(); i++) {
SpLink* l = links_[i];
cout << l->toString() << ":" << endl;
cout << " curr msg = " ;
cout << l->getMessage() << endl;
cout << " next msg = " ;
cout << l->getNextMessage() << endl;
cout << " residual = " << l->getResidual() << endl;
}
}

View File

@ -1,5 +1,5 @@
#ifndef HORUS_BpSolver_H #ifndef HORUS_BPSOLVER_H
#define HORUS_BpSolver_H #define HORUS_BPSOLVER_H
#include <set> #include <set>
#include <vector> #include <vector>
@ -102,7 +102,7 @@ class BpSolver : public Solver
virtual Params getJointDistributionOf (const VarIds&); virtual Params getJointDistributionOf (const VarIds&);
protected: protected:
virtual void initializeSolver (void); void runSolver (void);
virtual void createLinks (void); virtual void createLinks (void);
@ -114,8 +114,6 @@ class BpSolver : public Solver
virtual Params getJointByConditioning (const VarIds&) const; virtual Params getJointByConditioning (const VarIds&) const;
virtual void printLinkInformation (void) const;
SPNodeInfo* ninf (const VarNode* var) const SPNodeInfo* ninf (const VarNode* var) const
{ {
return varsI_[var->getIndex()]; return varsI_[var->getIndex()];
@ -170,7 +168,7 @@ class BpSolver : public Solver
vector<SPNodeInfo*> varsI_; vector<SPNodeInfo*> varsI_;
vector<SPNodeInfo*> facsI_; vector<SPNodeInfo*> facsI_;
bool runned_; bool runned_;
const FactorGraph* factorGraph_; const FactorGraph* fg_;
typedef multiset<SpLink*, CompareResidual> SortedOrder; typedef multiset<SpLink*, CompareResidual> SortedOrder;
SortedOrder sortedOrder_; SortedOrder sortedOrder_;
@ -179,10 +177,12 @@ class BpSolver : public Solver
SpLinkMap linkMap_; SpLinkMap linkMap_;
private: private:
void runSolver (void); void initializeSolver (void);
void runLoopySolver (void);
bool converged (void); bool converged (void);
void printLinkInformation (void) const;
}; };
#endif // HORUS_BpSolver_H #endif // HORUS_BPSOLVER_H

View File

@ -18,14 +18,14 @@ CFactorGraph::CFactorGraph (const FactorGraph& fg)
} }
const FacNodes& facNodes = fg.facNodes(); const FacNodes& facNodes = fg.facNodes();
factorSignatures_.reserve (facNodes.size()); facSignatures_.reserve (facNodes.size());
for (unsigned i = 0; i < facNodes.size(); i++) { for (unsigned i = 0; i < facNodes.size(); i++) {
unsigned c = facNodes[i]->neighbors().size() + 1; unsigned c = facNodes[i]->neighbors().size() + 1;
factorSignatures_.push_back (Signature (c)); facSignatures_.push_back (Signature (c));
} }
varColors_.resize (varNodes.size()); varColors_.resize (varNodes.size());
factorColors_.resize (facNodes.size()); facColors_.resize (facNodes.size());
setInitialColors(); setInitialColors();
createGroups(); createGroups();
} }
@ -111,7 +111,7 @@ void
CFactorGraph::createGroups (void) CFactorGraph::createGroups (void)
{ {
VarSignMap varGroups; VarSignMap varGroups;
FacSignMap factorGroups; FacSignMap facGroups;
unsigned nIters = 0; unsigned nIters = 0;
bool groupsHaveChanged = true; bool groupsHaveChanged = true;
const VarNodes& varNodes = groundFg_->varNodes(); const VarNodes& varNodes = groundFg_->varNodes();
@ -120,19 +120,19 @@ CFactorGraph::createGroups (void)
while (groupsHaveChanged || nIters == 1) { while (groupsHaveChanged || nIters == 1) {
nIters ++; nIters ++;
unsigned prevFactorGroupsSize = factorGroups.size(); unsigned prevFactorGroupsSize = facGroups.size();
factorGroups.clear(); facGroups.clear();
// set a new color to the factors with the same signature // set a new color to the factors with the same signature
for (unsigned i = 0; i < facNodes.size(); i++) { for (unsigned i = 0; i < facNodes.size(); i++) {
const Signature& signature = getSignature (facNodes[i]); const Signature& signature = getSignature (facNodes[i]);
FacSignMap::iterator it = factorGroups.find (signature); FacSignMap::iterator it = facGroups.find (signature);
if (it == factorGroups.end()) { if (it == facGroups.end()) {
it = factorGroups.insert (make_pair (signature, FacNodes())).first; it = facGroups.insert (make_pair (signature, FacNodes())).first;
} }
it->second.push_back (facNodes[i]); it->second.push_back (facNodes[i]);
} }
for (FacSignMap::iterator it = factorGroups.begin(); for (FacSignMap::iterator it = facGroups.begin();
it != factorGroups.end(); it++) { it != facGroups.end(); it++) {
Color newColor = getFreeColor(); Color newColor = getFreeColor();
FacNodes& groupMembers = it->second; FacNodes& groupMembers = it->second;
for (unsigned i = 0; i < groupMembers.size(); i++) { for (unsigned i = 0; i < groupMembers.size(); i++) {
@ -161,10 +161,10 @@ CFactorGraph::createGroups (void)
} }
groupsHaveChanged = prevVarGroupsSize != varGroups.size() groupsHaveChanged = prevVarGroupsSize != varGroups.size()
|| prevFactorGroupsSize != factorGroups.size(); || prevFactorGroupsSize != facGroups.size();
} }
//printGroups (varGroups, factorGroups); //printGroups (varGroups, facGroups);
createClusters (varGroups, factorGroups); createClusters (varGroups, facGroups);
} }
@ -172,7 +172,7 @@ CFactorGraph::createGroups (void)
void void
CFactorGraph::createClusters ( CFactorGraph::createClusters (
const VarSignMap& varGroups, const VarSignMap& varGroups,
const FacSignMap& factorGroups) const FacSignMap& facGroups)
{ {
varClusters_.reserve (varGroups.size()); varClusters_.reserve (varGroups.size());
for (VarSignMap::const_iterator it = varGroups.begin(); for (VarSignMap::const_iterator it = varGroups.begin();
@ -185,12 +185,12 @@ CFactorGraph::createClusters (
varClusters_.push_back (vc); varClusters_.push_back (vc);
} }
facClusters_.reserve (factorGroups.size()); facClusters_.reserve (facGroups.size());
for (FacSignMap::const_iterator it = factorGroups.begin(); for (FacSignMap::const_iterator it = facGroups.begin();
it != factorGroups.end(); it++) { it != facGroups.end(); it++) {
FacNode* groupFactor = it->second[0]; FacNode* groupFactor = it->second[0];
const VarNodes& neighs = groupFactor->neighbors(); const VarNodes& neighs = groupFactor->neighbors();
VarClusterSet varClusters; VarClusters varClusters;
varClusters.reserve (neighs.size()); varClusters.reserve (neighs.size());
for (unsigned i = 0; i < neighs.size(); i++) { for (unsigned i = 0; i < neighs.size(); i++) {
VarId vid = neighs[i]->varId(); VarId vid = neighs[i]->varId();
@ -223,7 +223,7 @@ CFactorGraph::getSignature (const VarNode* varNode)
const Signature& const Signature&
CFactorGraph::getSignature (const FacNode* facNode) CFactorGraph::getSignature (const FacNode* facNode)
{ {
Signature& sign = factorSignatures_[facNode->getIndex()]; Signature& sign = facSignatures_[facNode->getIndex()];
vector<Color>::iterator it = sign.colors.begin(); vector<Color>::iterator it = sign.colors.begin();
const VarNodes& neighs = facNode->neighbors(); const VarNodes& neighs = facNode->neighbors();
for (unsigned i = 0; i < neighs.size(); i++) { for (unsigned i = 0; i < neighs.size(); i++) {
@ -237,7 +237,7 @@ CFactorGraph::getSignature (const FacNode* facNode)
FactorGraph* FactorGraph*
CFactorGraph::getCompressedFactorGraph (void) CFactorGraph::getGroundFactorGraph (void) const
{ {
FactorGraph* fg = new FactorGraph(); FactorGraph* fg = new FactorGraph();
for (unsigned i = 0; i < varClusters_.size(); i++) { for (unsigned i = 0; i < varClusters_.size(); i++) {
@ -248,7 +248,7 @@ CFactorGraph::getCompressedFactorGraph (void)
} }
for (unsigned i = 0; i < facClusters_.size(); i++) { for (unsigned i = 0; i < facClusters_.size(); i++) {
const VarClusterSet& myVarClusters = facClusters_[i]->getVarClusters(); const VarClusters& myVarClusters = facClusters_[i]->getVarClusters();
Vars myGroundVars; Vars myGroundVars;
myGroundVars.reserve (myVarClusters.size()); myGroundVars.reserve (myVarClusters.size());
for (unsigned j = 0; j < myVarClusters.size(); j++) { for (unsigned j = 0; j < myVarClusters.size(); j++) {
@ -300,7 +300,7 @@ CFactorGraph::getGroundEdgeCount (
void void
CFactorGraph::printGroups ( CFactorGraph::printGroups (
const VarSignMap& varGroups, const VarSignMap& varGroups,
const FacSignMap& factorGroups) const const FacSignMap& facGroups) const
{ {
unsigned count = 1; unsigned count = 1;
cout << "variable groups:" << endl; cout << "variable groups:" << endl;
@ -319,8 +319,8 @@ CFactorGraph::printGroups (
count = 1; count = 1;
cout << endl << "factor groups:" << endl; cout << endl << "factor groups:" << endl;
for (FacSignMap::const_iterator it = factorGroups.begin(); for (FacSignMap::const_iterator it = facGroups.begin();
it != factorGroups.end(); it++) { it != facGroups.end(); it++) {
const FacNodes& groupMembers = it->second; const FacNodes& groupMembers = it->second;
if (groupMembers.size() > 0) { if (groupMembers.size() > 0) {
cout << ++count << ": " ; cout << ++count << ": " ;

View File

@ -22,8 +22,8 @@ typedef unordered_map<unsigned, vector<Color>> VarColorMap;
typedef unordered_map<unsigned, Color> DistColorMap; typedef unordered_map<unsigned, Color> DistColorMap;
typedef unordered_map<VarId, VarCluster*> VarId2VarCluster; typedef unordered_map<VarId, VarCluster*> VarId2VarCluster;
typedef vector<VarCluster*> VarClusterSet; typedef vector<VarCluster*> VarClusters;
typedef vector<FacCluster*> FacClusterSet; typedef vector<FacCluster*> FacClusters;
typedef unordered_map<Signature, VarNodes, SignatureHash> VarSignMap; typedef unordered_map<Signature, VarNodes, SignatureHash> VarSignMap;
typedef unordered_map<Signature, FacNodes, SignatureHash> FacSignMap; typedef unordered_map<Signature, FacNodes, SignatureHash> FacSignMap;
@ -99,18 +99,20 @@ class VarCluster
facClusters_.push_back (fc); facClusters_.push_back (fc);
} }
const FacClusterSet& getFacClusters (void) const const FacClusters& getFacClusters (void) const
{ {
return facClusters_; return facClusters_;
} }
VarNode* getRepresentativeVariable (void) const { return representVar_; } VarNode* getRepresentativeVariable (void) const { return representVar_; }
void setRepresentativeVariable (VarNode* v) { representVar_ = v; }
const VarNodes& getGroundVarNodes (void) const { return groundVars_; } void setRepresentativeVariable (VarNode* v) { representVar_ = v; }
const VarNodes& getGroundVarNodes (void) const { return groundVars_; }
private: private:
VarNodes groundVars_; VarNodes groundVars_;
FacClusterSet facClusters_; FacClusters facClusters_;
VarNode* representVar_; VarNode* representVar_;
}; };
@ -118,7 +120,7 @@ class VarCluster
class FacCluster class FacCluster
{ {
public: public:
FacCluster (const FacNodes& groundFactors, const VarClusterSet& vcs) FacCluster (const FacNodes& groundFactors, const VarClusters& vcs)
{ {
groundFactors_ = groundFactors; groundFactors_ = groundFactors;
varClusters_ = vcs; varClusters_ = vcs;
@ -127,7 +129,7 @@ class FacCluster
} }
} }
const VarClusterSet& getVarClusters (void) const const VarClusters& getVarClusters (void) const
{ {
return varClusters_; return varClusters_;
} }
@ -160,7 +162,7 @@ class FacCluster
private: private:
FacNodes groundFactors_; FacNodes groundFactors_;
VarClusterSet varClusters_; VarClusters varClusters_;
FacNode* representFactor_; FacNode* representFactor_;
}; };
@ -172,9 +174,9 @@ class CFactorGraph
~CFactorGraph (void); ~CFactorGraph (void);
const VarClusterSet& getVarClusters (void) { return varClusters_; } const VarClusters& getVarClusters (void) { return varClusters_; }
const FacClusterSet& getFacClusters (void) { return facClusters_; } const FacClusters& getFacClusters (void) { return facClusters_; }
VarNode* getEquivalentVariable (VarId vid) VarNode* getEquivalentVariable (VarId vid)
{ {
@ -182,7 +184,7 @@ class CFactorGraph
return vc->getRepresentativeVariable(); return vc->getRepresentativeVariable();
} }
FactorGraph* getCompressedFactorGraph (void); FactorGraph* getGroundFactorGraph (void) const;
unsigned getGroundEdgeCount (const FacCluster*, const VarCluster*) const; unsigned getGroundEdgeCount (const FacCluster*, const VarCluster*) const;
@ -200,7 +202,7 @@ class CFactorGraph
return varColors_[vn->getIndex()]; return varColors_[vn->getIndex()];
} }
Color getColor (const FacNode* fn) const { Color getColor (const FacNode* fn) const {
return factorColors_[fn->getIndex()]; return facColors_[fn->getIndex()];
} }
void setColor (const VarNode* vn, Color c) void setColor (const VarNode* vn, Color c)
@ -210,7 +212,7 @@ class CFactorGraph
void setColor (const FacNode* fn, Color c) void setColor (const FacNode* fn, Color c)
{ {
factorColors_[fn->getIndex()] = c; facColors_[fn->getIndex()] = c;
} }
VarCluster* getVariableCluster (VarId vid) const VarCluster* getVariableCluster (VarId vid) const
@ -232,11 +234,11 @@ class CFactorGraph
Color freeColor_; Color freeColor_;
vector<Color> varColors_; vector<Color> varColors_;
vector<Color> factorColors_; vector<Color> facColors_;
vector<Signature> varSignatures_; vector<Signature> varSignatures_;
vector<Signature> factorSignatures_; vector<Signature> facSignatures_;
VarClusterSet varClusters_; VarClusters varClusters_;
FacClusterSet facClusters_; FacClusters facClusters_;
VarId2VarCluster vid2VarCluster_; VarId2VarCluster vid2VarCluster_;
const FactorGraph* groundFg_; const FactorGraph* groundFg_;
}; };

View File

@ -1,10 +1,41 @@
#include "CbpSolver.h" #include "CbpSolver.h"
CbpSolver::CbpSolver (const FactorGraph& fg) : BpSolver (fg)
{
unsigned nGroundVars, nGroundFacs, nWithoutNeighs;
if (Constants::COLLECT_STATS) {
nGroundVars = fg_->varNodes().size();
nGroundFacs = fg_->facNodes().size();
const VarNodes& vars = fg_->varNodes();
nWithoutNeighs = 0;
for (unsigned i = 0; i < vars.size(); i++) {
const FacNodes& factors = vars[i]->neighbors();
if (factors.size() == 1 && factors[0]->neighbors().size() == 1) {
nWithoutNeighs ++;
}
}
}
cfg_ = new CFactorGraph (fg);
fg_ = cfg_->getGroundFactorGraph();
if (Constants::COLLECT_STATS) {
unsigned nClusterVars = fg_->varNodes().size();
unsigned nClusterFacs = fg_->facNodes().size();
Statistics::updateCompressingStatistics (nGroundVars,
nGroundFacs, nClusterVars, nClusterFacs, nWithoutNeighs);
}
// Util::printHeader ("Uncompressed Factor Graph");
// fg->print();
// Util::printHeader ("Compressed Factor Graph");
// fg_->print();
}
CbpSolver::~CbpSolver (void) CbpSolver::~CbpSolver (void)
{ {
delete lfg_; delete cfg_;
delete factorGraph_; delete fg_;
for (unsigned i = 0; i < links_.size(); i++) { for (unsigned i = 0; i < links_.size(); i++) {
delete links_[i]; delete links_[i];
} }
@ -16,8 +47,11 @@ CbpSolver::~CbpSolver (void)
Params Params
CbpSolver::getPosterioriOf (VarId vid) CbpSolver::getPosterioriOf (VarId vid)
{ {
assert (lfg_->getEquivalentVariable (vid)); if (runned_ == false) {
VarNode* var = lfg_->getEquivalentVariable (vid); runSolver();
}
assert (cfg_->getEquivalentVariable (vid));
VarNode* var = cfg_->getEquivalentVariable (vid);
Params probs; Params probs;
if (var->hasEvidence()) { if (var->hasEvidence()) {
probs.resize (var->range(), LogAware::noEvidence()); probs.resize (var->range(), LogAware::noEvidence());
@ -26,16 +60,16 @@ CbpSolver::getPosterioriOf (VarId vid)
probs.resize (var->range(), LogAware::multIdenty()); probs.resize (var->range(), LogAware::multIdenty());
const SpLinkSet& links = ninf(var)->getLinks(); const SpLinkSet& links = ninf(var)->getLinks();
if (Globals::logDomain) { if (Globals::logDomain) {
for (unsigned i = 0; i < links.size(); i++) { for (unsigned i = 0; i < links.size(); i++) {
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]); CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
Util::add (probs, l->getPoweredMessage()); Util::add (probs, l->poweredMessage());
} }
LogAware::normalize (probs); LogAware::normalize (probs);
Util::fromLog (probs); Util::fromLog (probs);
} else { } else {
for (unsigned i = 0; i < links.size(); i++) { for (unsigned i = 0; i < links.size(); i++) {
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]); CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
Util::multiply (probs, l->getPoweredMessage()); Util::multiply (probs, l->poweredMessage());
} }
LogAware::normalize (probs); LogAware::normalize (probs);
} }
@ -46,67 +80,28 @@ CbpSolver::getPosterioriOf (VarId vid)
Params Params
CbpSolver::getJointDistributionOf (const VarIds& jointVarIds) CbpSolver::getJointDistributionOf (const VarIds& jointVids)
{ {
VarIds eqVarIds; VarIds eqVarIds;
for (unsigned i = 0; i < jointVarIds.size(); i++) { for (unsigned i = 0; i < jointVids.size(); i++) {
eqVarIds.push_back (lfg_->getEquivalentVariable (jointVarIds[i])->varId()); VarNode* vn = cfg_->getEquivalentVariable (jointVids[i]);
eqVarIds.push_back (vn->varId());
} }
return BpSolver::getJointDistributionOf (eqVarIds); return BpSolver::getJointDistributionOf (eqVarIds);
} }
void
CbpSolver::initializeSolver (void)
{
unsigned nGroundVars, nGroundFacs, nWithoutNeighs;
if (Constants::COLLECT_STATS) {
nGroundVars = factorGraph_->varNodes().size();
nGroundFacs = factorGraph_->facNodes().size();
const VarNodes& vars = factorGraph_->varNodes();
nWithoutNeighs = 0;
for (unsigned i = 0; i < vars.size(); i++) {
const FacNodes& factors = vars[i]->neighbors();
if (factors.size() == 1 && factors[0]->neighbors().size() == 1) {
nWithoutNeighs ++;
}
}
}
lfg_ = new CFactorGraph (*factorGraph_);
// cout << "Uncompressed Factor Graph" << endl;
// factorGraph_->print();
// factorGraph_->exportToGraphViz ("uncompressed_fg.dot");
factorGraph_ = lfg_->getCompressedFactorGraph();
if (Constants::COLLECT_STATS) {
unsigned nClusterVars = factorGraph_->varNodes().size();
unsigned nClusterFacs = factorGraph_->facNodes().size();
Statistics::updateCompressingStatistics (nGroundVars,
nGroundFacs, nClusterVars, nClusterFacs, nWithoutNeighs);
}
// cout << "Compressed Factor Graph" << endl;
// factorGraph_->print();
// factorGraph_->exportToGraphViz ("compressed_fg.dot");
// abort();
BpSolver::initializeSolver();
}
void void
CbpSolver::createLinks (void) CbpSolver::createLinks (void)
{ {
const FacClusterSet fcs = lfg_->getFacClusters(); const FacClusters& fcs = cfg_->getFacClusters();
for (unsigned i = 0; i < fcs.size(); i++) { for (unsigned i = 0; i < fcs.size(); i++) {
const VarClusterSet vcs = fcs[i]->getVarClusters(); const VarClusters& vcs = fcs[i]->getVarClusters();
for (unsigned j = 0; j < vcs.size(); j++) { for (unsigned j = 0; j < vcs.size(); j++) {
unsigned c = lfg_->getGroundEdgeCount (fcs[i], vcs[j]); unsigned c = cfg_->getGroundEdgeCount (fcs[i], vcs[j]);
links_.push_back (new CbpSolverLink (fcs[i]->getRepresentativeFactor(), links_.push_back (new CbpSolverLink (
fcs[i]->getRepresentativeFactor(),
vcs[j]->getRepresentativeVariable(), c)); vcs[j]->getRepresentativeVariable(), c));
} }
} }
@ -197,10 +192,10 @@ CbpSolver::getVar2FactorMsg (const SpLink* link) const
if (src->hasEvidence()) { if (src->hasEvidence()) {
msg.resize (src->range(), LogAware::noEvidence()); msg.resize (src->range(), LogAware::noEvidence());
double value = link->getMessage()[src->getEvidence()]; double value = link->getMessage()[src->getEvidence()];
msg[src->getEvidence()] = LogAware::pow (value, l->getNumberOfEdges() - 1); msg[src->getEvidence()] = LogAware::pow (value, l->nrEdges() - 1);
} else { } else {
msg = link->getMessage(); msg = link->getMessage();
LogAware::pow (msg, l->getNumberOfEdges() - 1); LogAware::pow (msg, l->nrEdges() - 1);
} }
if (Constants::DEBUG >= 5) { if (Constants::DEBUG >= 5) {
cout << " " << "init: " << msg << endl; cout << " " << "init: " << msg << endl;
@ -210,17 +205,17 @@ CbpSolver::getVar2FactorMsg (const SpLink* link) const
for (unsigned i = 0; i < links.size(); i++) { for (unsigned i = 0; i < links.size(); i++) {
if (links[i]->getFactor() != dst) { if (links[i]->getFactor() != dst) {
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]); CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
Util::add (msg, l->getPoweredMessage()); Util::add (msg, l->poweredMessage());
} }
} }
} else { } else {
for (unsigned i = 0; i < links.size(); i++) { for (unsigned i = 0; i < links.size(); i++) {
if (links[i]->getFactor() != dst) { if (links[i]->getFactor() != dst) {
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]); CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
Util::multiply (msg, l->getPoweredMessage()); Util::multiply (msg, l->poweredMessage());
if (Constants::DEBUG >= 5) { if (Constants::DEBUG >= 5) {
cout << " msg from " << l->getFactor()->getLabel() << ": " ; cout << " msg from " << l->getFactor()->getLabel() << ": " ;
cout << l->getPoweredMessage() << endl; cout << l->poweredMessage() << endl;
} }
} }
} }
@ -242,7 +237,7 @@ CbpSolver::printLinkInformation (void) const
cout << l->toString() << ":" << endl; cout << l->toString() << ":" << endl;
cout << " curr msg = " << l->getMessage() << endl; cout << " curr msg = " << l->getMessage() << endl;
cout << " next msg = " << l->getNextMessage() << endl; cout << " next msg = " << l->getNextMessage() << endl;
cout << " powered = " << l->getPoweredMessage() << endl; cout << " powered = " << l->poweredMessage() << endl;
cout << " residual = " << l->getResidual() << endl; cout << " residual = " << l->getResidual() << endl;
} }
} }

View File

@ -9,27 +9,25 @@ class Factor;
class CbpSolverLink : public SpLink class CbpSolverLink : public SpLink
{ {
public: public:
CbpSolverLink (FacNode* fn, VarNode* vn, unsigned c) : SpLink (fn, vn) CbpSolverLink (FacNode* fn, VarNode* vn, unsigned c)
{ : SpLink (fn, vn), nrEdges_(c),
edgeCount_ = c; pwdMsg_(vn->range(), LogAware::one()) { }
poweredMsg_.resize (vn->range(), LogAware::one());
}
unsigned getNumberOfEdges (void) const { return edgeCount_; } unsigned nrEdges (void) const { return nrEdges_; }
const Params& getPoweredMessage (void) const { return poweredMsg_; } const Params& poweredMessage (void) const { return pwdMsg_; }
void updateMessage (void) void updateMessage (void)
{ {
poweredMsg_ = *nextMsg_; pwdMsg_ = *nextMsg_;
swap (currMsg_, nextMsg_); swap (currMsg_, nextMsg_);
msgSended_ = true; msgSended_ = true;
LogAware::pow (poweredMsg_, edgeCount_); LogAware::pow (pwdMsg_, nrEdges_);
} }
private: private:
Params poweredMsg_; unsigned nrEdges_;
unsigned edgeCount_; Params pwdMsg_;
}; };
@ -37,16 +35,15 @@ class CbpSolverLink : public SpLink
class CbpSolver : public BpSolver class CbpSolver : public BpSolver
{ {
public: public:
CbpSolver (const FactorGraph& fg) : BpSolver (fg) { } CbpSolver (const FactorGraph& fg);
~CbpSolver (void); ~CbpSolver (void);
Params getPosterioriOf (VarId); Params getPosterioriOf (VarId);
Params getJointDistributionOf (const VarIds&); Params getJointDistributionOf (const VarIds&);
private: private:
void initializeSolver (void);
void createLinks (void); void createLinks (void);
@ -56,8 +53,7 @@ class CbpSolver : public BpSolver
void printLinkInformation (void) const; void printLinkInformation (void) const;
CFactorGraph* lfg_; CFactorGraph* cfg_;
FactorGraph* factorGraph_;
}; };
#endif // HORUS_CBP_H #endif // HORUS_CBP_H

View File

@ -14,25 +14,43 @@ void processArguments (FactorGraph&, int, const char* []);
void runSolver (const FactorGraph&, const VarIds&); void runSolver (const FactorGraph&, const VarIds&);
const string USAGE = "usage: \ const string USAGE = "usage: \
./hcli FILE [VARIABLE | OBSERVED_VARIABLE=EVIDENCE]..." ; ./hcli ve|bp|cbp NETWORK_FILE [VARIABLE | OBSERVED_VARIABLE=EVIDENCE]..." ;
int int
main (int argc, const char* argv[]) main (int argc, const char* argv[])
{ {
if (!argv[1]) { if (argc <= 1) {
cerr << "error: no solver specified" << endl;
cerr << "error: no graphical model specified" << endl; cerr << "error: no graphical model specified" << endl;
cerr << USAGE << endl; cerr << USAGE << endl;
exit (0); exit (0);
} }
string fileName = argv[1]; if (argc <= 2) {
cerr << "error: no graphical model specified" << endl;
cerr << USAGE << endl;
exit (0);
}
string solver (argv[1]);
if (solver == "ve") {
Globals::infAlgorithm = InfAlgorithms::VE;
} else if (solver == "bp") {
Globals::infAlgorithm = InfAlgorithms::BP;
} else if (solver == "cbp") {
Globals::infAlgorithm = InfAlgorithms::CBP;
} else {
cerr << "error: unknow solver `" << solver << "'" << endl ;
cerr << USAGE << endl;
exit(0);
}
string fileName (argv[2]);
string extension = fileName.substr ( string extension = fileName.substr (
fileName.find_last_of ('.') + 1); fileName.find_last_of ('.') + 1);
FactorGraph fg; FactorGraph fg;
if (extension == "uai") { if (extension == "uai") {
fg.readFromUaiFormat (argv[1]); fg.readFromUaiFormat (fileName.c_str());
} else if (extension == "fg") { } else if (extension == "fg") {
fg.readFromLibDaiFormat (argv[1]); fg.readFromLibDaiFormat (fileName.c_str());
} else { } else {
cerr << "error: the graphical model must be defined either " ; cerr << "error: the graphical model must be defined either " ;
cerr << "in a UAI or libDAI file" << endl; cerr << "in a UAI or libDAI file" << endl;
@ -48,7 +66,7 @@ void
processArguments (FactorGraph& fg, int argc, const char* argv[]) processArguments (FactorGraph& fg, int argc, const char* argv[])
{ {
VarIds queryIds; VarIds queryIds;
for (int i = 2; i < argc; i++) { for (int i = 3; i < argc; i++) {
const string& arg = argv[i]; const string& arg = argv[i];
if (arg.find ('=') == std::string::npos) { if (arg.find ('=') == std::string::npos) {
if (!Util::isInteger (arg)) { if (!Util::isInteger (arg)) {

View File

@ -8,7 +8,7 @@ Solver::printAnswer (const VarIds& vids)
Vars unobservedVars; Vars unobservedVars;
VarIds unobservedVids; VarIds unobservedVids;
for (unsigned i = 0; i < vids.size(); i++) { for (unsigned i = 0; i < vids.size(); i++) {
VarNode* vn = fg_.getVarNode (vids[i]); VarNode* vn = fg.getVarNode (vids[i]);
if (vn->hasEvidence() == false) { if (vn->hasEvidence() == false) {
unobservedVars.push_back (vn); unobservedVars.push_back (vn);
unobservedVids.push_back (vids[i]); unobservedVids.push_back (vids[i]);
@ -29,7 +29,7 @@ Solver::printAnswer (const VarIds& vids)
void void
Solver::printAllPosterioris (void) Solver::printAllPosterioris (void)
{ {
const VarNodes& vars = fg_.varNodes(); const VarNodes& vars = fg.varNodes();
for (unsigned i = 0; i < vars.size(); i++) { for (unsigned i = 0; i < vars.size(); i++) {
printAnswer ({vars[i]->varId()}); printAnswer ({vars[i]->varId()});
} }

View File

@ -12,7 +12,7 @@ using namespace std;
class Solver class Solver
{ {
public: public:
Solver (const FactorGraph& fg) : fg_(fg) { } Solver (const FactorGraph& factorGraph) : fg(factorGraph) { }
virtual ~Solver() { } // ensure that subclass destructor is called virtual ~Solver() { } // ensure that subclass destructor is called
@ -23,7 +23,7 @@ class Solver
void printAllPosterioris (void); void printAllPosterioris (void);
protected: protected:
const FactorGraph& fg_; const FactorGraph& fg;
}; };
#endif // HORUS_SOLVER_H #endif // HORUS_SOLVER_H

View File

@ -35,7 +35,7 @@ VarElimSolver::solveQuery (VarIds queryVids)
void void
VarElimSolver::createFactorList (void) VarElimSolver::createFactorList (void)
{ {
const FacNodes& facNodes = fg_.facNodes(); const FacNodes& facNodes = fg.facNodes();
factorList_.reserve (facNodes.size() * 2); factorList_.reserve (facNodes.size() * 2);
for (unsigned i = 0; i < facNodes.size(); i++) { for (unsigned i = 0; i < facNodes.size(); i++) {
factorList_.push_back (new Factor (facNodes[i]->factor())); factorList_.push_back (new Factor (facNodes[i]->factor()));
@ -57,7 +57,7 @@ VarElimSolver::createFactorList (void)
void void
VarElimSolver::absorveEvidence (void) VarElimSolver::absorveEvidence (void)
{ {
const VarNodes& varNodes = fg_.varNodes(); const VarNodes& varNodes = fg.varNodes();
for (unsigned i = 0; i < varNodes.size(); i++) { for (unsigned i = 0; i < varNodes.size(); i++) {
if (varNodes[i]->hasEvidence()) { if (varNodes[i]->hasEvidence()) {
const vector<unsigned>& idxs = const vector<unsigned>& idxs =
@ -103,7 +103,7 @@ VarElimSolver::processFactorList (const VarIds& vids)
VarIds unobservedVids; VarIds unobservedVids;
for (unsigned i = 0; i < vids.size(); i++) { for (unsigned i = 0; i < vids.size(); i++) {
if (fg_.getVarNode (vids[i])->hasEvidence() == false) { if (fg.getVarNode (vids[i])->hasEvidence() == false) {
unobservedVids.push_back (vids[i]); unobservedVids.push_back (vids[i]);
} }
} }