refactorings

This commit is contained in:
Tiago Gomes 2012-04-10 20:43:08 +01:00
parent 78e86a6330
commit 8697fcd2b4
10 changed files with 251 additions and 266 deletions

View File

@ -14,7 +14,7 @@
BpSolver::BpSolver (const FactorGraph& fg) : Solver (fg)
{
factorGraph_ = &fg;
fg_ = &fg;
runned_ = false;
}
@ -54,8 +54,8 @@ BpSolver::getPosterioriOf (VarId vid)
if (runned_ == false) {
runSolver();
}
assert (factorGraph_->getVarNode (vid));
VarNode* var = factorGraph_->getVarNode (vid);
assert (fg_->getVarNode (vid));
VarNode* var = fg_->getVarNode (vid);
Params probs;
if (var->hasEvidence()) {
probs.resize (var->range(), LogAware::noEvidence());
@ -88,7 +88,7 @@ BpSolver::getJointDistributionOf (const VarIds& jointVarIds)
runSolver();
}
int idx = -1;
VarNode* vn = factorGraph_->getVarNode (jointVarIds[0]);
VarNode* vn = fg_->getVarNode (jointVarIds[0]);
const FacNodes& facNodes = vn->neighbors();
for (unsigned i = 0; i < facNodes.size(); i++) {
if (facNodes[i]->factor().contains (jointVarIds)) {
@ -121,37 +121,64 @@ BpSolver::getJointDistributionOf (const VarIds& jointVarIds)
void
BpSolver::initializeSolver (void)
BpSolver::runSolver (void)
{
const VarNodes& varNodes = factorGraph_->varNodes();
for (unsigned i = 0; i < varsI_.size(); i++) {
delete varsI_[i];
clock_t start;
if (Constants::COLLECT_STATS) {
start = clock();
}
varsI_.reserve (varNodes.size());
for (unsigned i = 0; i < varNodes.size(); i++) {
varsI_.push_back (new SPNodeInfo());
initializeSolver();
nIters_ = 0;
while (!converged() && nIters_ < BpOptions::maxIter) {
nIters_ ++;
if (Constants::DEBUG >= 2) {
Util::printHeader (" Iteration " + nIters_);
cout << endl;
}
switch (BpOptions::schedule) {
case BpOptions::Schedule::SEQ_RANDOM:
random_shuffle (links_.begin(), links_.end());
// no break
case BpOptions::Schedule::SEQ_FIXED:
for (unsigned i = 0; i < links_.size(); i++) {
calculateAndUpdateMessage (links_[i]);
}
break;
case BpOptions::Schedule::PARALLEL:
for (unsigned i = 0; i < links_.size(); i++) {
calculateMessage (links_[i]);
}
for (unsigned i = 0; i < links_.size(); i++) {
updateMessage(links_[i]);
}
break;
case BpOptions::Schedule::MAX_RESIDUAL:
maxResidualSchedule();
break;
}
if (Constants::DEBUG >= 2) {
cout << endl;
}
}
const FacNodes& facNodes = factorGraph_->facNodes();
for (unsigned i = 0; i < facsI_.size(); i++) {
delete facsI_[i];
if (Constants::DEBUG >= 2) {
cout << endl;
if (nIters_ < BpOptions::maxIter) {
cout << "Sum-Product converged in " ;
cout << nIters_ << " iterations" << endl;
} else {
cout << "The maximum number of iterations was hit, terminating..." ;
cout << endl;
}
}
facsI_.reserve (facNodes.size());
for (unsigned i = 0; i < facNodes.size(); i++) {
facsI_.push_back (new SPNodeInfo());
}
for (unsigned i = 0; i < links_.size(); i++) {
delete links_[i];
}
createLinks();
for (unsigned i = 0; i < links_.size(); i++) {
FacNode* src = links_[i]->getFactor();
VarNode* dst = links_[i]->getVariable();
ninf (dst)->addSpLink (links_[i]);
ninf (src)->addSpLink (links_[i]);
unsigned size = fg_->varNodes().size();
if (Constants::COLLECT_STATS) {
unsigned nIters = 0;
bool loopy = fg_->isTree() == false;
if (loopy) nIters = nIters_;
double time = (double (clock() - start)) / CLOCKS_PER_SEC;
Statistics::updateStatistics (size, loopy, nIters, time);
}
runned_ = true;
}
@ -159,7 +186,7 @@ BpSolver::initializeSolver (void)
void
BpSolver::createLinks (void)
{
const FacNodes& facNodes = factorGraph_->facNodes();
const FacNodes& facNodes = fg_->facNodes();
for (unsigned i = 0; i < facNodes.size(); i++) {
const VarNodes& neighbors = facNodes[i]->neighbors();
for (unsigned j = 0; j < neighbors.size(); j++) {
@ -342,11 +369,11 @@ BpSolver::getJointByConditioning (const VarIds& jointVarIds) const
{
VarNodes jointVars;
for (unsigned i = 0; i < jointVarIds.size(); i++) {
assert (factorGraph_->getVarNode (jointVarIds[i]));
jointVars.push_back (factorGraph_->getVarNode (jointVarIds[i]));
assert (fg_->getVarNode (jointVarIds[i]));
jointVars.push_back (fg_->getVarNode (jointVarIds[i]));
}
FactorGraph* fg = new FactorGraph (*factorGraph_);
FactorGraph* fg = new FactorGraph (*fg_);
BpSolver solver (*fg);
solver.runSolver();
Params prevBeliefs = solver.getPosterioriOf (jointVarIds[0]);
@ -390,93 +417,24 @@ BpSolver::getJointByConditioning (const VarIds& jointVarIds) const
void
BpSolver::printLinkInformation (void) const
BpSolver::initializeSolver (void)
{
const VarNodes& varNodes = fg_->varNodes();
varsI_.reserve (varNodes.size());
for (unsigned i = 0; i < varNodes.size(); i++) {
varsI_.push_back (new SPNodeInfo());
}
const FacNodes& facNodes = fg_->facNodes();
facsI_.reserve (facNodes.size());
for (unsigned i = 0; i < facNodes.size(); i++) {
facsI_.push_back (new SPNodeInfo());
}
createLinks();
for (unsigned i = 0; i < links_.size(); i++) {
SpLink* l = links_[i];
cout << l->toString() << ":" << endl;
cout << " curr msg = " ;
cout << l->getMessage() << endl;
cout << " next msg = " ;
cout << l->getNextMessage() << endl;
cout << " residual = " << l->getResidual() << endl;
}
}
void
BpSolver::runSolver (void)
{
clock_t start;
if (Constants::COLLECT_STATS) {
start = clock();
}
runLoopySolver();
if (Constants::DEBUG >= 2) {
cout << endl;
if (nIters_ < BpOptions::maxIter) {
cout << "Sum-Product converged in " ;
cout << nIters_ << " iterations" << endl;
} else {
cout << "The maximum number of iterations was hit, terminating..." ;
cout << endl;
}
}
unsigned size = factorGraph_->varNodes().size();
if (Constants::COLLECT_STATS) {
unsigned nIters = 0;
bool loopy = factorGraph_->isTree() == false;
if (loopy) nIters = nIters_;
double time = (double (clock() - start)) / CLOCKS_PER_SEC;
Statistics::updateStatistics (size, loopy, nIters, time);
}
runned_ = true;
}
void
BpSolver::runLoopySolver (void)
{
initializeSolver();
nIters_ = 0;
while (!converged() && nIters_ < BpOptions::maxIter) {
nIters_ ++;
if (Constants::DEBUG >= 2) {
Util::printHeader (" Iteration " + nIters_);
cout << endl;
}
switch (BpOptions::schedule) {
case BpOptions::Schedule::SEQ_RANDOM:
random_shuffle (links_.begin(), links_.end());
// no break
case BpOptions::Schedule::SEQ_FIXED:
for (unsigned i = 0; i < links_.size(); i++) {
calculateAndUpdateMessage (links_[i]);
}
break;
case BpOptions::Schedule::PARALLEL:
for (unsigned i = 0; i < links_.size(); i++) {
calculateMessage (links_[i]);
}
for (unsigned i = 0; i < links_.size(); i++) {
updateMessage(links_[i]);
}
break;
case BpOptions::Schedule::MAX_RESIDUAL:
maxResidualSchedule();
break;
}
if (Constants::DEBUG >= 2) {
cout << endl;
}
FacNode* src = links_[i]->getFactor();
VarNode* dst = links_[i]->getVariable();
ninf (dst)->addSpLink (links_[i]);
ninf (src)->addSpLink (links_[i]);
}
}
@ -488,7 +446,7 @@ BpSolver::converged (void)
if (links_.size() == 0) {
return true;
}
if (nIters_ == 0 || nIters_ == 1) {
if (nIters_ <= 1) {
return false;
}
bool converged = true;
@ -514,3 +472,19 @@ BpSolver::converged (void)
return converged;
}
void
BpSolver::printLinkInformation (void) const
{
for (unsigned i = 0; i < links_.size(); i++) {
SpLink* l = links_[i];
cout << l->toString() << ":" << endl;
cout << " curr msg = " ;
cout << l->getMessage() << endl;
cout << " next msg = " ;
cout << l->getNextMessage() << endl;
cout << " residual = " << l->getResidual() << endl;
}
}

View File

@ -1,5 +1,5 @@
#ifndef HORUS_BpSolver_H
#define HORUS_BpSolver_H
#ifndef HORUS_BPSOLVER_H
#define HORUS_BPSOLVER_H
#include <set>
#include <vector>
@ -102,7 +102,7 @@ class BpSolver : public Solver
virtual Params getJointDistributionOf (const VarIds&);
protected:
virtual void initializeSolver (void);
void runSolver (void);
virtual void createLinks (void);
@ -114,8 +114,6 @@ class BpSolver : public Solver
virtual Params getJointByConditioning (const VarIds&) const;
virtual void printLinkInformation (void) const;
SPNodeInfo* ninf (const VarNode* var) const
{
return varsI_[var->getIndex()];
@ -170,7 +168,7 @@ class BpSolver : public Solver
vector<SPNodeInfo*> varsI_;
vector<SPNodeInfo*> facsI_;
bool runned_;
const FactorGraph* factorGraph_;
const FactorGraph* fg_;
typedef multiset<SpLink*, CompareResidual> SortedOrder;
SortedOrder sortedOrder_;
@ -179,10 +177,12 @@ class BpSolver : public Solver
SpLinkMap linkMap_;
private:
void runSolver (void);
void runLoopySolver (void);
void initializeSolver (void);
bool converged (void);
void printLinkInformation (void) const;
};
#endif // HORUS_BpSolver_H
#endif // HORUS_BPSOLVER_H

View File

@ -18,14 +18,14 @@ CFactorGraph::CFactorGraph (const FactorGraph& fg)
}
const FacNodes& facNodes = fg.facNodes();
factorSignatures_.reserve (facNodes.size());
facSignatures_.reserve (facNodes.size());
for (unsigned i = 0; i < facNodes.size(); i++) {
unsigned c = facNodes[i]->neighbors().size() + 1;
factorSignatures_.push_back (Signature (c));
facSignatures_.push_back (Signature (c));
}
varColors_.resize (varNodes.size());
factorColors_.resize (facNodes.size());
facColors_.resize (facNodes.size());
setInitialColors();
createGroups();
}
@ -111,7 +111,7 @@ void
CFactorGraph::createGroups (void)
{
VarSignMap varGroups;
FacSignMap factorGroups;
FacSignMap facGroups;
unsigned nIters = 0;
bool groupsHaveChanged = true;
const VarNodes& varNodes = groundFg_->varNodes();
@ -120,19 +120,19 @@ CFactorGraph::createGroups (void)
while (groupsHaveChanged || nIters == 1) {
nIters ++;
unsigned prevFactorGroupsSize = factorGroups.size();
factorGroups.clear();
unsigned prevFactorGroupsSize = facGroups.size();
facGroups.clear();
// set a new color to the factors with the same signature
for (unsigned i = 0; i < facNodes.size(); i++) {
const Signature& signature = getSignature (facNodes[i]);
FacSignMap::iterator it = factorGroups.find (signature);
if (it == factorGroups.end()) {
it = factorGroups.insert (make_pair (signature, FacNodes())).first;
FacSignMap::iterator it = facGroups.find (signature);
if (it == facGroups.end()) {
it = facGroups.insert (make_pair (signature, FacNodes())).first;
}
it->second.push_back (facNodes[i]);
}
for (FacSignMap::iterator it = factorGroups.begin();
it != factorGroups.end(); it++) {
for (FacSignMap::iterator it = facGroups.begin();
it != facGroups.end(); it++) {
Color newColor = getFreeColor();
FacNodes& groupMembers = it->second;
for (unsigned i = 0; i < groupMembers.size(); i++) {
@ -161,10 +161,10 @@ CFactorGraph::createGroups (void)
}
groupsHaveChanged = prevVarGroupsSize != varGroups.size()
|| prevFactorGroupsSize != factorGroups.size();
|| prevFactorGroupsSize != facGroups.size();
}
//printGroups (varGroups, factorGroups);
createClusters (varGroups, factorGroups);
//printGroups (varGroups, facGroups);
createClusters (varGroups, facGroups);
}
@ -172,7 +172,7 @@ CFactorGraph::createGroups (void)
void
CFactorGraph::createClusters (
const VarSignMap& varGroups,
const FacSignMap& factorGroups)
const FacSignMap& facGroups)
{
varClusters_.reserve (varGroups.size());
for (VarSignMap::const_iterator it = varGroups.begin();
@ -185,12 +185,12 @@ CFactorGraph::createClusters (
varClusters_.push_back (vc);
}
facClusters_.reserve (factorGroups.size());
for (FacSignMap::const_iterator it = factorGroups.begin();
it != factorGroups.end(); it++) {
facClusters_.reserve (facGroups.size());
for (FacSignMap::const_iterator it = facGroups.begin();
it != facGroups.end(); it++) {
FacNode* groupFactor = it->second[0];
const VarNodes& neighs = groupFactor->neighbors();
VarClusterSet varClusters;
VarClusters varClusters;
varClusters.reserve (neighs.size());
for (unsigned i = 0; i < neighs.size(); i++) {
VarId vid = neighs[i]->varId();
@ -223,7 +223,7 @@ CFactorGraph::getSignature (const VarNode* varNode)
const Signature&
CFactorGraph::getSignature (const FacNode* facNode)
{
Signature& sign = factorSignatures_[facNode->getIndex()];
Signature& sign = facSignatures_[facNode->getIndex()];
vector<Color>::iterator it = sign.colors.begin();
const VarNodes& neighs = facNode->neighbors();
for (unsigned i = 0; i < neighs.size(); i++) {
@ -237,7 +237,7 @@ CFactorGraph::getSignature (const FacNode* facNode)
FactorGraph*
CFactorGraph::getCompressedFactorGraph (void)
CFactorGraph::getGroundFactorGraph (void) const
{
FactorGraph* fg = new FactorGraph();
for (unsigned i = 0; i < varClusters_.size(); i++) {
@ -248,7 +248,7 @@ CFactorGraph::getCompressedFactorGraph (void)
}
for (unsigned i = 0; i < facClusters_.size(); i++) {
const VarClusterSet& myVarClusters = facClusters_[i]->getVarClusters();
const VarClusters& myVarClusters = facClusters_[i]->getVarClusters();
Vars myGroundVars;
myGroundVars.reserve (myVarClusters.size());
for (unsigned j = 0; j < myVarClusters.size(); j++) {
@ -300,7 +300,7 @@ CFactorGraph::getGroundEdgeCount (
void
CFactorGraph::printGroups (
const VarSignMap& varGroups,
const FacSignMap& factorGroups) const
const FacSignMap& facGroups) const
{
unsigned count = 1;
cout << "variable groups:" << endl;
@ -319,8 +319,8 @@ CFactorGraph::printGroups (
count = 1;
cout << endl << "factor groups:" << endl;
for (FacSignMap::const_iterator it = factorGroups.begin();
it != factorGroups.end(); it++) {
for (FacSignMap::const_iterator it = facGroups.begin();
it != facGroups.end(); it++) {
const FacNodes& groupMembers = it->second;
if (groupMembers.size() > 0) {
cout << ++count << ": " ;

View File

@ -22,8 +22,8 @@ typedef unordered_map<unsigned, vector<Color>> VarColorMap;
typedef unordered_map<unsigned, Color> DistColorMap;
typedef unordered_map<VarId, VarCluster*> VarId2VarCluster;
typedef vector<VarCluster*> VarClusterSet;
typedef vector<FacCluster*> FacClusterSet;
typedef vector<VarCluster*> VarClusters;
typedef vector<FacCluster*> FacClusters;
typedef unordered_map<Signature, VarNodes, SignatureHash> VarSignMap;
typedef unordered_map<Signature, FacNodes, SignatureHash> FacSignMap;
@ -99,18 +99,20 @@ class VarCluster
facClusters_.push_back (fc);
}
const FacClusterSet& getFacClusters (void) const
const FacClusters& getFacClusters (void) const
{
return facClusters_;
}
VarNode* getRepresentativeVariable (void) const { return representVar_; }
void setRepresentativeVariable (VarNode* v) { representVar_ = v; }
const VarNodes& getGroundVarNodes (void) const { return groundVars_; }
void setRepresentativeVariable (VarNode* v) { representVar_ = v; }
const VarNodes& getGroundVarNodes (void) const { return groundVars_; }
private:
VarNodes groundVars_;
FacClusterSet facClusters_;
FacClusters facClusters_;
VarNode* representVar_;
};
@ -118,7 +120,7 @@ class VarCluster
class FacCluster
{
public:
FacCluster (const FacNodes& groundFactors, const VarClusterSet& vcs)
FacCluster (const FacNodes& groundFactors, const VarClusters& vcs)
{
groundFactors_ = groundFactors;
varClusters_ = vcs;
@ -127,7 +129,7 @@ class FacCluster
}
}
const VarClusterSet& getVarClusters (void) const
const VarClusters& getVarClusters (void) const
{
return varClusters_;
}
@ -160,7 +162,7 @@ class FacCluster
private:
FacNodes groundFactors_;
VarClusterSet varClusters_;
VarClusters varClusters_;
FacNode* representFactor_;
};
@ -172,9 +174,9 @@ class CFactorGraph
~CFactorGraph (void);
const VarClusterSet& getVarClusters (void) { return varClusters_; }
const VarClusters& getVarClusters (void) { return varClusters_; }
const FacClusterSet& getFacClusters (void) { return facClusters_; }
const FacClusters& getFacClusters (void) { return facClusters_; }
VarNode* getEquivalentVariable (VarId vid)
{
@ -182,7 +184,7 @@ class CFactorGraph
return vc->getRepresentativeVariable();
}
FactorGraph* getCompressedFactorGraph (void);
FactorGraph* getGroundFactorGraph (void) const;
unsigned getGroundEdgeCount (const FacCluster*, const VarCluster*) const;
@ -200,7 +202,7 @@ class CFactorGraph
return varColors_[vn->getIndex()];
}
Color getColor (const FacNode* fn) const {
return factorColors_[fn->getIndex()];
return facColors_[fn->getIndex()];
}
void setColor (const VarNode* vn, Color c)
@ -210,7 +212,7 @@ class CFactorGraph
void setColor (const FacNode* fn, Color c)
{
factorColors_[fn->getIndex()] = c;
facColors_[fn->getIndex()] = c;
}
VarCluster* getVariableCluster (VarId vid) const
@ -232,11 +234,11 @@ class CFactorGraph
Color freeColor_;
vector<Color> varColors_;
vector<Color> factorColors_;
vector<Color> facColors_;
vector<Signature> varSignatures_;
vector<Signature> factorSignatures_;
VarClusterSet varClusters_;
FacClusterSet facClusters_;
vector<Signature> facSignatures_;
VarClusters varClusters_;
FacClusters facClusters_;
VarId2VarCluster vid2VarCluster_;
const FactorGraph* groundFg_;
};

View File

@ -1,10 +1,41 @@
#include "CbpSolver.h"
CbpSolver::CbpSolver (const FactorGraph& fg) : BpSolver (fg)
{
unsigned nGroundVars, nGroundFacs, nWithoutNeighs;
if (Constants::COLLECT_STATS) {
nGroundVars = fg_->varNodes().size();
nGroundFacs = fg_->facNodes().size();
const VarNodes& vars = fg_->varNodes();
nWithoutNeighs = 0;
for (unsigned i = 0; i < vars.size(); i++) {
const FacNodes& factors = vars[i]->neighbors();
if (factors.size() == 1 && factors[0]->neighbors().size() == 1) {
nWithoutNeighs ++;
}
}
}
cfg_ = new CFactorGraph (fg);
fg_ = cfg_->getGroundFactorGraph();
if (Constants::COLLECT_STATS) {
unsigned nClusterVars = fg_->varNodes().size();
unsigned nClusterFacs = fg_->facNodes().size();
Statistics::updateCompressingStatistics (nGroundVars,
nGroundFacs, nClusterVars, nClusterFacs, nWithoutNeighs);
}
// Util::printHeader ("Uncompressed Factor Graph");
// fg->print();
// Util::printHeader ("Compressed Factor Graph");
// fg_->print();
}
CbpSolver::~CbpSolver (void)
{
delete lfg_;
delete factorGraph_;
delete cfg_;
delete fg_;
for (unsigned i = 0; i < links_.size(); i++) {
delete links_[i];
}
@ -16,8 +47,11 @@ CbpSolver::~CbpSolver (void)
Params
CbpSolver::getPosterioriOf (VarId vid)
{
assert (lfg_->getEquivalentVariable (vid));
VarNode* var = lfg_->getEquivalentVariable (vid);
if (runned_ == false) {
runSolver();
}
assert (cfg_->getEquivalentVariable (vid));
VarNode* var = cfg_->getEquivalentVariable (vid);
Params probs;
if (var->hasEvidence()) {
probs.resize (var->range(), LogAware::noEvidence());
@ -26,16 +60,16 @@ CbpSolver::getPosterioriOf (VarId vid)
probs.resize (var->range(), LogAware::multIdenty());
const SpLinkSet& links = ninf(var)->getLinks();
if (Globals::logDomain) {
for (unsigned i = 0; i < links.size(); i++) {
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
Util::add (probs, l->getPoweredMessage());
}
LogAware::normalize (probs);
Util::fromLog (probs);
for (unsigned i = 0; i < links.size(); i++) {
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
Util::add (probs, l->poweredMessage());
}
LogAware::normalize (probs);
Util::fromLog (probs);
} else {
for (unsigned i = 0; i < links.size(); i++) {
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
Util::multiply (probs, l->getPoweredMessage());
Util::multiply (probs, l->poweredMessage());
}
LogAware::normalize (probs);
}
@ -46,67 +80,28 @@ CbpSolver::getPosterioriOf (VarId vid)
Params
CbpSolver::getJointDistributionOf (const VarIds& jointVarIds)
CbpSolver::getJointDistributionOf (const VarIds& jointVids)
{
VarIds eqVarIds;
for (unsigned i = 0; i < jointVarIds.size(); i++) {
eqVarIds.push_back (lfg_->getEquivalentVariable (jointVarIds[i])->varId());
for (unsigned i = 0; i < jointVids.size(); i++) {
VarNode* vn = cfg_->getEquivalentVariable (jointVids[i]);
eqVarIds.push_back (vn->varId());
}
return BpSolver::getJointDistributionOf (eqVarIds);
}
void
CbpSolver::initializeSolver (void)
{
unsigned nGroundVars, nGroundFacs, nWithoutNeighs;
if (Constants::COLLECT_STATS) {
nGroundVars = factorGraph_->varNodes().size();
nGroundFacs = factorGraph_->facNodes().size();
const VarNodes& vars = factorGraph_->varNodes();
nWithoutNeighs = 0;
for (unsigned i = 0; i < vars.size(); i++) {
const FacNodes& factors = vars[i]->neighbors();
if (factors.size() == 1 && factors[0]->neighbors().size() == 1) {
nWithoutNeighs ++;
}
}
}
lfg_ = new CFactorGraph (*factorGraph_);
// cout << "Uncompressed Factor Graph" << endl;
// factorGraph_->print();
// factorGraph_->exportToGraphViz ("uncompressed_fg.dot");
factorGraph_ = lfg_->getCompressedFactorGraph();
if (Constants::COLLECT_STATS) {
unsigned nClusterVars = factorGraph_->varNodes().size();
unsigned nClusterFacs = factorGraph_->facNodes().size();
Statistics::updateCompressingStatistics (nGroundVars,
nGroundFacs, nClusterVars, nClusterFacs, nWithoutNeighs);
}
// cout << "Compressed Factor Graph" << endl;
// factorGraph_->print();
// factorGraph_->exportToGraphViz ("compressed_fg.dot");
// abort();
BpSolver::initializeSolver();
}
void
CbpSolver::createLinks (void)
{
const FacClusterSet fcs = lfg_->getFacClusters();
const FacClusters& fcs = cfg_->getFacClusters();
for (unsigned i = 0; i < fcs.size(); i++) {
const VarClusterSet vcs = fcs[i]->getVarClusters();
const VarClusters& vcs = fcs[i]->getVarClusters();
for (unsigned j = 0; j < vcs.size(); j++) {
unsigned c = lfg_->getGroundEdgeCount (fcs[i], vcs[j]);
links_.push_back (new CbpSolverLink (fcs[i]->getRepresentativeFactor(),
unsigned c = cfg_->getGroundEdgeCount (fcs[i], vcs[j]);
links_.push_back (new CbpSolverLink (
fcs[i]->getRepresentativeFactor(),
vcs[j]->getRepresentativeVariable(), c));
}
}
@ -197,10 +192,10 @@ CbpSolver::getVar2FactorMsg (const SpLink* link) const
if (src->hasEvidence()) {
msg.resize (src->range(), LogAware::noEvidence());
double value = link->getMessage()[src->getEvidence()];
msg[src->getEvidence()] = LogAware::pow (value, l->getNumberOfEdges() - 1);
msg[src->getEvidence()] = LogAware::pow (value, l->nrEdges() - 1);
} else {
msg = link->getMessage();
LogAware::pow (msg, l->getNumberOfEdges() - 1);
LogAware::pow (msg, l->nrEdges() - 1);
}
if (Constants::DEBUG >= 5) {
cout << " " << "init: " << msg << endl;
@ -210,17 +205,17 @@ CbpSolver::getVar2FactorMsg (const SpLink* link) const
for (unsigned i = 0; i < links.size(); i++) {
if (links[i]->getFactor() != dst) {
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
Util::add (msg, l->getPoweredMessage());
Util::add (msg, l->poweredMessage());
}
}
} else {
for (unsigned i = 0; i < links.size(); i++) {
if (links[i]->getFactor() != dst) {
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
Util::multiply (msg, l->getPoweredMessage());
Util::multiply (msg, l->poweredMessage());
if (Constants::DEBUG >= 5) {
cout << " msg from " << l->getFactor()->getLabel() << ": " ;
cout << l->getPoweredMessage() << endl;
cout << l->poweredMessage() << endl;
}
}
}
@ -242,7 +237,7 @@ CbpSolver::printLinkInformation (void) const
cout << l->toString() << ":" << endl;
cout << " curr msg = " << l->getMessage() << endl;
cout << " next msg = " << l->getNextMessage() << endl;
cout << " powered = " << l->getPoweredMessage() << endl;
cout << " powered = " << l->poweredMessage() << endl;
cout << " residual = " << l->getResidual() << endl;
}
}

View File

@ -9,27 +9,25 @@ class Factor;
class CbpSolverLink : public SpLink
{
public:
CbpSolverLink (FacNode* fn, VarNode* vn, unsigned c) : SpLink (fn, vn)
{
edgeCount_ = c;
poweredMsg_.resize (vn->range(), LogAware::one());
}
CbpSolverLink (FacNode* fn, VarNode* vn, unsigned c)
: SpLink (fn, vn), nrEdges_(c),
pwdMsg_(vn->range(), LogAware::one()) { }
unsigned getNumberOfEdges (void) const { return edgeCount_; }
unsigned nrEdges (void) const { return nrEdges_; }
const Params& getPoweredMessage (void) const { return poweredMsg_; }
const Params& poweredMessage (void) const { return pwdMsg_; }
void updateMessage (void)
{
poweredMsg_ = *nextMsg_;
pwdMsg_ = *nextMsg_;
swap (currMsg_, nextMsg_);
msgSended_ = true;
LogAware::pow (poweredMsg_, edgeCount_);
LogAware::pow (pwdMsg_, nrEdges_);
}
private:
Params poweredMsg_;
unsigned edgeCount_;
unsigned nrEdges_;
Params pwdMsg_;
};
@ -37,16 +35,15 @@ class CbpSolverLink : public SpLink
class CbpSolver : public BpSolver
{
public:
CbpSolver (const FactorGraph& fg) : BpSolver (fg) { }
CbpSolver (const FactorGraph& fg);
~CbpSolver (void);
Params getPosterioriOf (VarId);
Params getJointDistributionOf (const VarIds&);
private:
void initializeSolver (void);
void createLinks (void);
@ -56,8 +53,7 @@ class CbpSolver : public BpSolver
void printLinkInformation (void) const;
CFactorGraph* lfg_;
FactorGraph* factorGraph_;
CFactorGraph* cfg_;
};
#endif // HORUS_CBP_H

View File

@ -14,25 +14,43 @@ void processArguments (FactorGraph&, int, const char* []);
void runSolver (const FactorGraph&, const VarIds&);
const string USAGE = "usage: \
./hcli FILE [VARIABLE | OBSERVED_VARIABLE=EVIDENCE]..." ;
./hcli ve|bp|cbp NETWORK_FILE [VARIABLE | OBSERVED_VARIABLE=EVIDENCE]..." ;
int
main (int argc, const char* argv[])
{
if (!argv[1]) {
if (argc <= 1) {
cerr << "error: no solver specified" << endl;
cerr << "error: no graphical model specified" << endl;
cerr << USAGE << endl;
exit (0);
}
string fileName = argv[1];
if (argc <= 2) {
cerr << "error: no graphical model specified" << endl;
cerr << USAGE << endl;
exit (0);
}
string solver (argv[1]);
if (solver == "ve") {
Globals::infAlgorithm = InfAlgorithms::VE;
} else if (solver == "bp") {
Globals::infAlgorithm = InfAlgorithms::BP;
} else if (solver == "cbp") {
Globals::infAlgorithm = InfAlgorithms::CBP;
} else {
cerr << "error: unknow solver `" << solver << "'" << endl ;
cerr << USAGE << endl;
exit(0);
}
string fileName (argv[2]);
string extension = fileName.substr (
fileName.find_last_of ('.') + 1);
FactorGraph fg;
if (extension == "uai") {
fg.readFromUaiFormat (argv[1]);
fg.readFromUaiFormat (fileName.c_str());
} else if (extension == "fg") {
fg.readFromLibDaiFormat (argv[1]);
fg.readFromLibDaiFormat (fileName.c_str());
} else {
cerr << "error: the graphical model must be defined either " ;
cerr << "in a UAI or libDAI file" << endl;
@ -48,7 +66,7 @@ void
processArguments (FactorGraph& fg, int argc, const char* argv[])
{
VarIds queryIds;
for (int i = 2; i < argc; i++) {
for (int i = 3; i < argc; i++) {
const string& arg = argv[i];
if (arg.find ('=') == std::string::npos) {
if (!Util::isInteger (arg)) {

View File

@ -8,7 +8,7 @@ Solver::printAnswer (const VarIds& vids)
Vars unobservedVars;
VarIds unobservedVids;
for (unsigned i = 0; i < vids.size(); i++) {
VarNode* vn = fg_.getVarNode (vids[i]);
VarNode* vn = fg.getVarNode (vids[i]);
if (vn->hasEvidence() == false) {
unobservedVars.push_back (vn);
unobservedVids.push_back (vids[i]);
@ -29,7 +29,7 @@ Solver::printAnswer (const VarIds& vids)
void
Solver::printAllPosterioris (void)
{
const VarNodes& vars = fg_.varNodes();
const VarNodes& vars = fg.varNodes();
for (unsigned i = 0; i < vars.size(); i++) {
printAnswer ({vars[i]->varId()});
}

View File

@ -12,7 +12,7 @@ using namespace std;
class Solver
{
public:
Solver (const FactorGraph& fg) : fg_(fg) { }
Solver (const FactorGraph& factorGraph) : fg(factorGraph) { }
virtual ~Solver() { } // ensure that subclass destructor is called
@ -23,7 +23,7 @@ class Solver
void printAllPosterioris (void);
protected:
const FactorGraph& fg_;
const FactorGraph& fg;
};
#endif // HORUS_SOLVER_H

View File

@ -35,7 +35,7 @@ VarElimSolver::solveQuery (VarIds queryVids)
void
VarElimSolver::createFactorList (void)
{
const FacNodes& facNodes = fg_.facNodes();
const FacNodes& facNodes = fg.facNodes();
factorList_.reserve (facNodes.size() * 2);
for (unsigned i = 0; i < facNodes.size(); i++) {
factorList_.push_back (new Factor (facNodes[i]->factor()));
@ -57,7 +57,7 @@ VarElimSolver::createFactorList (void)
void
VarElimSolver::absorveEvidence (void)
{
const VarNodes& varNodes = fg_.varNodes();
const VarNodes& varNodes = fg.varNodes();
for (unsigned i = 0; i < varNodes.size(); i++) {
if (varNodes[i]->hasEvidence()) {
const vector<unsigned>& idxs =
@ -103,7 +103,7 @@ VarElimSolver::processFactorList (const VarIds& vids)
VarIds unobservedVids;
for (unsigned i = 0; i < vids.size(); i++) {
if (fg_.getVarNode (vids[i])->hasEvidence() == false) {
if (fg.getVarNode (vids[i])->hasEvidence() == false) {
unobservedVids.push_back (vids[i]);
}
}