refactorings
This commit is contained in:
parent
78e86a6330
commit
8697fcd2b4
@ -14,7 +14,7 @@
|
||||
|
||||
BpSolver::BpSolver (const FactorGraph& fg) : Solver (fg)
|
||||
{
|
||||
factorGraph_ = &fg;
|
||||
fg_ = &fg;
|
||||
runned_ = false;
|
||||
}
|
||||
|
||||
@ -54,8 +54,8 @@ BpSolver::getPosterioriOf (VarId vid)
|
||||
if (runned_ == false) {
|
||||
runSolver();
|
||||
}
|
||||
assert (factorGraph_->getVarNode (vid));
|
||||
VarNode* var = factorGraph_->getVarNode (vid);
|
||||
assert (fg_->getVarNode (vid));
|
||||
VarNode* var = fg_->getVarNode (vid);
|
||||
Params probs;
|
||||
if (var->hasEvidence()) {
|
||||
probs.resize (var->range(), LogAware::noEvidence());
|
||||
@ -88,7 +88,7 @@ BpSolver::getJointDistributionOf (const VarIds& jointVarIds)
|
||||
runSolver();
|
||||
}
|
||||
int idx = -1;
|
||||
VarNode* vn = factorGraph_->getVarNode (jointVarIds[0]);
|
||||
VarNode* vn = fg_->getVarNode (jointVarIds[0]);
|
||||
const FacNodes& facNodes = vn->neighbors();
|
||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||
if (facNodes[i]->factor().contains (jointVarIds)) {
|
||||
@ -121,37 +121,64 @@ BpSolver::getJointDistributionOf (const VarIds& jointVarIds)
|
||||
|
||||
|
||||
void
|
||||
BpSolver::initializeSolver (void)
|
||||
BpSolver::runSolver (void)
|
||||
{
|
||||
const VarNodes& varNodes = factorGraph_->varNodes();
|
||||
for (unsigned i = 0; i < varsI_.size(); i++) {
|
||||
delete varsI_[i];
|
||||
clock_t start;
|
||||
if (Constants::COLLECT_STATS) {
|
||||
start = clock();
|
||||
}
|
||||
varsI_.reserve (varNodes.size());
|
||||
for (unsigned i = 0; i < varNodes.size(); i++) {
|
||||
varsI_.push_back (new SPNodeInfo());
|
||||
initializeSolver();
|
||||
nIters_ = 0;
|
||||
while (!converged() && nIters_ < BpOptions::maxIter) {
|
||||
nIters_ ++;
|
||||
if (Constants::DEBUG >= 2) {
|
||||
Util::printHeader (" Iteration " + nIters_);
|
||||
cout << endl;
|
||||
}
|
||||
switch (BpOptions::schedule) {
|
||||
case BpOptions::Schedule::SEQ_RANDOM:
|
||||
random_shuffle (links_.begin(), links_.end());
|
||||
// no break
|
||||
case BpOptions::Schedule::SEQ_FIXED:
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
calculateAndUpdateMessage (links_[i]);
|
||||
}
|
||||
break;
|
||||
case BpOptions::Schedule::PARALLEL:
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
calculateMessage (links_[i]);
|
||||
}
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
updateMessage(links_[i]);
|
||||
}
|
||||
break;
|
||||
case BpOptions::Schedule::MAX_RESIDUAL:
|
||||
maxResidualSchedule();
|
||||
break;
|
||||
}
|
||||
if (Constants::DEBUG >= 2) {
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
|
||||
const FacNodes& facNodes = factorGraph_->facNodes();
|
||||
for (unsigned i = 0; i < facsI_.size(); i++) {
|
||||
delete facsI_[i];
|
||||
if (Constants::DEBUG >= 2) {
|
||||
cout << endl;
|
||||
if (nIters_ < BpOptions::maxIter) {
|
||||
cout << "Sum-Product converged in " ;
|
||||
cout << nIters_ << " iterations" << endl;
|
||||
} else {
|
||||
cout << "The maximum number of iterations was hit, terminating..." ;
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
facsI_.reserve (facNodes.size());
|
||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||
facsI_.push_back (new SPNodeInfo());
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
delete links_[i];
|
||||
}
|
||||
createLinks();
|
||||
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
FacNode* src = links_[i]->getFactor();
|
||||
VarNode* dst = links_[i]->getVariable();
|
||||
ninf (dst)->addSpLink (links_[i]);
|
||||
ninf (src)->addSpLink (links_[i]);
|
||||
unsigned size = fg_->varNodes().size();
|
||||
if (Constants::COLLECT_STATS) {
|
||||
unsigned nIters = 0;
|
||||
bool loopy = fg_->isTree() == false;
|
||||
if (loopy) nIters = nIters_;
|
||||
double time = (double (clock() - start)) / CLOCKS_PER_SEC;
|
||||
Statistics::updateStatistics (size, loopy, nIters, time);
|
||||
}
|
||||
runned_ = true;
|
||||
}
|
||||
|
||||
|
||||
@ -159,7 +186,7 @@ BpSolver::initializeSolver (void)
|
||||
void
|
||||
BpSolver::createLinks (void)
|
||||
{
|
||||
const FacNodes& facNodes = factorGraph_->facNodes();
|
||||
const FacNodes& facNodes = fg_->facNodes();
|
||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||
const VarNodes& neighbors = facNodes[i]->neighbors();
|
||||
for (unsigned j = 0; j < neighbors.size(); j++) {
|
||||
@ -342,11 +369,11 @@ BpSolver::getJointByConditioning (const VarIds& jointVarIds) const
|
||||
{
|
||||
VarNodes jointVars;
|
||||
for (unsigned i = 0; i < jointVarIds.size(); i++) {
|
||||
assert (factorGraph_->getVarNode (jointVarIds[i]));
|
||||
jointVars.push_back (factorGraph_->getVarNode (jointVarIds[i]));
|
||||
assert (fg_->getVarNode (jointVarIds[i]));
|
||||
jointVars.push_back (fg_->getVarNode (jointVarIds[i]));
|
||||
}
|
||||
|
||||
FactorGraph* fg = new FactorGraph (*factorGraph_);
|
||||
FactorGraph* fg = new FactorGraph (*fg_);
|
||||
BpSolver solver (*fg);
|
||||
solver.runSolver();
|
||||
Params prevBeliefs = solver.getPosterioriOf (jointVarIds[0]);
|
||||
@ -390,93 +417,24 @@ BpSolver::getJointByConditioning (const VarIds& jointVarIds) const
|
||||
|
||||
|
||||
void
|
||||
BpSolver::printLinkInformation (void) const
|
||||
BpSolver::initializeSolver (void)
|
||||
{
|
||||
const VarNodes& varNodes = fg_->varNodes();
|
||||
varsI_.reserve (varNodes.size());
|
||||
for (unsigned i = 0; i < varNodes.size(); i++) {
|
||||
varsI_.push_back (new SPNodeInfo());
|
||||
}
|
||||
const FacNodes& facNodes = fg_->facNodes();
|
||||
facsI_.reserve (facNodes.size());
|
||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||
facsI_.push_back (new SPNodeInfo());
|
||||
}
|
||||
createLinks();
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
SpLink* l = links_[i];
|
||||
cout << l->toString() << ":" << endl;
|
||||
cout << " curr msg = " ;
|
||||
cout << l->getMessage() << endl;
|
||||
cout << " next msg = " ;
|
||||
cout << l->getNextMessage() << endl;
|
||||
cout << " residual = " << l->getResidual() << endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BpSolver::runSolver (void)
|
||||
{
|
||||
clock_t start;
|
||||
if (Constants::COLLECT_STATS) {
|
||||
start = clock();
|
||||
}
|
||||
runLoopySolver();
|
||||
if (Constants::DEBUG >= 2) {
|
||||
cout << endl;
|
||||
if (nIters_ < BpOptions::maxIter) {
|
||||
cout << "Sum-Product converged in " ;
|
||||
cout << nIters_ << " iterations" << endl;
|
||||
} else {
|
||||
cout << "The maximum number of iterations was hit, terminating..." ;
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
unsigned size = factorGraph_->varNodes().size();
|
||||
if (Constants::COLLECT_STATS) {
|
||||
unsigned nIters = 0;
|
||||
bool loopy = factorGraph_->isTree() == false;
|
||||
if (loopy) nIters = nIters_;
|
||||
double time = (double (clock() - start)) / CLOCKS_PER_SEC;
|
||||
Statistics::updateStatistics (size, loopy, nIters, time);
|
||||
}
|
||||
runned_ = true;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BpSolver::runLoopySolver (void)
|
||||
{
|
||||
initializeSolver();
|
||||
nIters_ = 0;
|
||||
|
||||
while (!converged() && nIters_ < BpOptions::maxIter) {
|
||||
|
||||
nIters_ ++;
|
||||
if (Constants::DEBUG >= 2) {
|
||||
Util::printHeader (" Iteration " + nIters_);
|
||||
cout << endl;
|
||||
}
|
||||
|
||||
switch (BpOptions::schedule) {
|
||||
case BpOptions::Schedule::SEQ_RANDOM:
|
||||
random_shuffle (links_.begin(), links_.end());
|
||||
// no break
|
||||
|
||||
case BpOptions::Schedule::SEQ_FIXED:
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
calculateAndUpdateMessage (links_[i]);
|
||||
}
|
||||
break;
|
||||
|
||||
case BpOptions::Schedule::PARALLEL:
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
calculateMessage (links_[i]);
|
||||
}
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
updateMessage(links_[i]);
|
||||
}
|
||||
break;
|
||||
|
||||
case BpOptions::Schedule::MAX_RESIDUAL:
|
||||
maxResidualSchedule();
|
||||
break;
|
||||
}
|
||||
if (Constants::DEBUG >= 2) {
|
||||
cout << endl;
|
||||
}
|
||||
FacNode* src = links_[i]->getFactor();
|
||||
VarNode* dst = links_[i]->getVariable();
|
||||
ninf (dst)->addSpLink (links_[i]);
|
||||
ninf (src)->addSpLink (links_[i]);
|
||||
}
|
||||
}
|
||||
|
||||
@ -488,7 +446,7 @@ BpSolver::converged (void)
|
||||
if (links_.size() == 0) {
|
||||
return true;
|
||||
}
|
||||
if (nIters_ == 0 || nIters_ == 1) {
|
||||
if (nIters_ <= 1) {
|
||||
return false;
|
||||
}
|
||||
bool converged = true;
|
||||
@ -514,3 +472,19 @@ BpSolver::converged (void)
|
||||
return converged;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BpSolver::printLinkInformation (void) const
|
||||
{
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
SpLink* l = links_[i];
|
||||
cout << l->toString() << ":" << endl;
|
||||
cout << " curr msg = " ;
|
||||
cout << l->getMessage() << endl;
|
||||
cout << " next msg = " ;
|
||||
cout << l->getNextMessage() << endl;
|
||||
cout << " residual = " << l->getResidual() << endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
#ifndef HORUS_BpSolver_H
|
||||
#define HORUS_BpSolver_H
|
||||
#ifndef HORUS_BPSOLVER_H
|
||||
#define HORUS_BPSOLVER_H
|
||||
|
||||
#include <set>
|
||||
#include <vector>
|
||||
@ -102,7 +102,7 @@ class BpSolver : public Solver
|
||||
virtual Params getJointDistributionOf (const VarIds&);
|
||||
|
||||
protected:
|
||||
virtual void initializeSolver (void);
|
||||
void runSolver (void);
|
||||
|
||||
virtual void createLinks (void);
|
||||
|
||||
@ -114,8 +114,6 @@ class BpSolver : public Solver
|
||||
|
||||
virtual Params getJointByConditioning (const VarIds&) const;
|
||||
|
||||
virtual void printLinkInformation (void) const;
|
||||
|
||||
SPNodeInfo* ninf (const VarNode* var) const
|
||||
{
|
||||
return varsI_[var->getIndex()];
|
||||
@ -170,7 +168,7 @@ class BpSolver : public Solver
|
||||
vector<SPNodeInfo*> varsI_;
|
||||
vector<SPNodeInfo*> facsI_;
|
||||
bool runned_;
|
||||
const FactorGraph* factorGraph_;
|
||||
const FactorGraph* fg_;
|
||||
|
||||
typedef multiset<SpLink*, CompareResidual> SortedOrder;
|
||||
SortedOrder sortedOrder_;
|
||||
@ -179,10 +177,12 @@ class BpSolver : public Solver
|
||||
SpLinkMap linkMap_;
|
||||
|
||||
private:
|
||||
void runSolver (void);
|
||||
void runLoopySolver (void);
|
||||
void initializeSolver (void);
|
||||
|
||||
bool converged (void);
|
||||
|
||||
void printLinkInformation (void) const;
|
||||
};
|
||||
|
||||
#endif // HORUS_BpSolver_H
|
||||
#endif // HORUS_BPSOLVER_H
|
||||
|
||||
|
@ -18,14 +18,14 @@ CFactorGraph::CFactorGraph (const FactorGraph& fg)
|
||||
}
|
||||
|
||||
const FacNodes& facNodes = fg.facNodes();
|
||||
factorSignatures_.reserve (facNodes.size());
|
||||
facSignatures_.reserve (facNodes.size());
|
||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||
unsigned c = facNodes[i]->neighbors().size() + 1;
|
||||
factorSignatures_.push_back (Signature (c));
|
||||
facSignatures_.push_back (Signature (c));
|
||||
}
|
||||
|
||||
varColors_.resize (varNodes.size());
|
||||
factorColors_.resize (facNodes.size());
|
||||
facColors_.resize (facNodes.size());
|
||||
setInitialColors();
|
||||
createGroups();
|
||||
}
|
||||
@ -111,7 +111,7 @@ void
|
||||
CFactorGraph::createGroups (void)
|
||||
{
|
||||
VarSignMap varGroups;
|
||||
FacSignMap factorGroups;
|
||||
FacSignMap facGroups;
|
||||
unsigned nIters = 0;
|
||||
bool groupsHaveChanged = true;
|
||||
const VarNodes& varNodes = groundFg_->varNodes();
|
||||
@ -120,19 +120,19 @@ CFactorGraph::createGroups (void)
|
||||
while (groupsHaveChanged || nIters == 1) {
|
||||
nIters ++;
|
||||
|
||||
unsigned prevFactorGroupsSize = factorGroups.size();
|
||||
factorGroups.clear();
|
||||
unsigned prevFactorGroupsSize = facGroups.size();
|
||||
facGroups.clear();
|
||||
// set a new color to the factors with the same signature
|
||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||
const Signature& signature = getSignature (facNodes[i]);
|
||||
FacSignMap::iterator it = factorGroups.find (signature);
|
||||
if (it == factorGroups.end()) {
|
||||
it = factorGroups.insert (make_pair (signature, FacNodes())).first;
|
||||
FacSignMap::iterator it = facGroups.find (signature);
|
||||
if (it == facGroups.end()) {
|
||||
it = facGroups.insert (make_pair (signature, FacNodes())).first;
|
||||
}
|
||||
it->second.push_back (facNodes[i]);
|
||||
}
|
||||
for (FacSignMap::iterator it = factorGroups.begin();
|
||||
it != factorGroups.end(); it++) {
|
||||
for (FacSignMap::iterator it = facGroups.begin();
|
||||
it != facGroups.end(); it++) {
|
||||
Color newColor = getFreeColor();
|
||||
FacNodes& groupMembers = it->second;
|
||||
for (unsigned i = 0; i < groupMembers.size(); i++) {
|
||||
@ -161,10 +161,10 @@ CFactorGraph::createGroups (void)
|
||||
}
|
||||
|
||||
groupsHaveChanged = prevVarGroupsSize != varGroups.size()
|
||||
|| prevFactorGroupsSize != factorGroups.size();
|
||||
|| prevFactorGroupsSize != facGroups.size();
|
||||
}
|
||||
//printGroups (varGroups, factorGroups);
|
||||
createClusters (varGroups, factorGroups);
|
||||
//printGroups (varGroups, facGroups);
|
||||
createClusters (varGroups, facGroups);
|
||||
}
|
||||
|
||||
|
||||
@ -172,7 +172,7 @@ CFactorGraph::createGroups (void)
|
||||
void
|
||||
CFactorGraph::createClusters (
|
||||
const VarSignMap& varGroups,
|
||||
const FacSignMap& factorGroups)
|
||||
const FacSignMap& facGroups)
|
||||
{
|
||||
varClusters_.reserve (varGroups.size());
|
||||
for (VarSignMap::const_iterator it = varGroups.begin();
|
||||
@ -185,12 +185,12 @@ CFactorGraph::createClusters (
|
||||
varClusters_.push_back (vc);
|
||||
}
|
||||
|
||||
facClusters_.reserve (factorGroups.size());
|
||||
for (FacSignMap::const_iterator it = factorGroups.begin();
|
||||
it != factorGroups.end(); it++) {
|
||||
facClusters_.reserve (facGroups.size());
|
||||
for (FacSignMap::const_iterator it = facGroups.begin();
|
||||
it != facGroups.end(); it++) {
|
||||
FacNode* groupFactor = it->second[0];
|
||||
const VarNodes& neighs = groupFactor->neighbors();
|
||||
VarClusterSet varClusters;
|
||||
VarClusters varClusters;
|
||||
varClusters.reserve (neighs.size());
|
||||
for (unsigned i = 0; i < neighs.size(); i++) {
|
||||
VarId vid = neighs[i]->varId();
|
||||
@ -223,7 +223,7 @@ CFactorGraph::getSignature (const VarNode* varNode)
|
||||
const Signature&
|
||||
CFactorGraph::getSignature (const FacNode* facNode)
|
||||
{
|
||||
Signature& sign = factorSignatures_[facNode->getIndex()];
|
||||
Signature& sign = facSignatures_[facNode->getIndex()];
|
||||
vector<Color>::iterator it = sign.colors.begin();
|
||||
const VarNodes& neighs = facNode->neighbors();
|
||||
for (unsigned i = 0; i < neighs.size(); i++) {
|
||||
@ -237,7 +237,7 @@ CFactorGraph::getSignature (const FacNode* facNode)
|
||||
|
||||
|
||||
FactorGraph*
|
||||
CFactorGraph::getCompressedFactorGraph (void)
|
||||
CFactorGraph::getGroundFactorGraph (void) const
|
||||
{
|
||||
FactorGraph* fg = new FactorGraph();
|
||||
for (unsigned i = 0; i < varClusters_.size(); i++) {
|
||||
@ -248,7 +248,7 @@ CFactorGraph::getCompressedFactorGraph (void)
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < facClusters_.size(); i++) {
|
||||
const VarClusterSet& myVarClusters = facClusters_[i]->getVarClusters();
|
||||
const VarClusters& myVarClusters = facClusters_[i]->getVarClusters();
|
||||
Vars myGroundVars;
|
||||
myGroundVars.reserve (myVarClusters.size());
|
||||
for (unsigned j = 0; j < myVarClusters.size(); j++) {
|
||||
@ -300,7 +300,7 @@ CFactorGraph::getGroundEdgeCount (
|
||||
void
|
||||
CFactorGraph::printGroups (
|
||||
const VarSignMap& varGroups,
|
||||
const FacSignMap& factorGroups) const
|
||||
const FacSignMap& facGroups) const
|
||||
{
|
||||
unsigned count = 1;
|
||||
cout << "variable groups:" << endl;
|
||||
@ -319,8 +319,8 @@ CFactorGraph::printGroups (
|
||||
|
||||
count = 1;
|
||||
cout << endl << "factor groups:" << endl;
|
||||
for (FacSignMap::const_iterator it = factorGroups.begin();
|
||||
it != factorGroups.end(); it++) {
|
||||
for (FacSignMap::const_iterator it = facGroups.begin();
|
||||
it != facGroups.end(); it++) {
|
||||
const FacNodes& groupMembers = it->second;
|
||||
if (groupMembers.size() > 0) {
|
||||
cout << ++count << ": " ;
|
||||
|
@ -22,8 +22,8 @@ typedef unordered_map<unsigned, vector<Color>> VarColorMap;
|
||||
typedef unordered_map<unsigned, Color> DistColorMap;
|
||||
typedef unordered_map<VarId, VarCluster*> VarId2VarCluster;
|
||||
|
||||
typedef vector<VarCluster*> VarClusterSet;
|
||||
typedef vector<FacCluster*> FacClusterSet;
|
||||
typedef vector<VarCluster*> VarClusters;
|
||||
typedef vector<FacCluster*> FacClusters;
|
||||
|
||||
typedef unordered_map<Signature, VarNodes, SignatureHash> VarSignMap;
|
||||
typedef unordered_map<Signature, FacNodes, SignatureHash> FacSignMap;
|
||||
@ -99,18 +99,20 @@ class VarCluster
|
||||
facClusters_.push_back (fc);
|
||||
}
|
||||
|
||||
const FacClusterSet& getFacClusters (void) const
|
||||
const FacClusters& getFacClusters (void) const
|
||||
{
|
||||
return facClusters_;
|
||||
}
|
||||
|
||||
VarNode* getRepresentativeVariable (void) const { return representVar_; }
|
||||
void setRepresentativeVariable (VarNode* v) { representVar_ = v; }
|
||||
const VarNodes& getGroundVarNodes (void) const { return groundVars_; }
|
||||
|
||||
void setRepresentativeVariable (VarNode* v) { representVar_ = v; }
|
||||
|
||||
const VarNodes& getGroundVarNodes (void) const { return groundVars_; }
|
||||
|
||||
private:
|
||||
VarNodes groundVars_;
|
||||
FacClusterSet facClusters_;
|
||||
FacClusters facClusters_;
|
||||
VarNode* representVar_;
|
||||
};
|
||||
|
||||
@ -118,7 +120,7 @@ class VarCluster
|
||||
class FacCluster
|
||||
{
|
||||
public:
|
||||
FacCluster (const FacNodes& groundFactors, const VarClusterSet& vcs)
|
||||
FacCluster (const FacNodes& groundFactors, const VarClusters& vcs)
|
||||
{
|
||||
groundFactors_ = groundFactors;
|
||||
varClusters_ = vcs;
|
||||
@ -127,7 +129,7 @@ class FacCluster
|
||||
}
|
||||
}
|
||||
|
||||
const VarClusterSet& getVarClusters (void) const
|
||||
const VarClusters& getVarClusters (void) const
|
||||
{
|
||||
return varClusters_;
|
||||
}
|
||||
@ -160,7 +162,7 @@ class FacCluster
|
||||
|
||||
private:
|
||||
FacNodes groundFactors_;
|
||||
VarClusterSet varClusters_;
|
||||
VarClusters varClusters_;
|
||||
FacNode* representFactor_;
|
||||
};
|
||||
|
||||
@ -172,9 +174,9 @@ class CFactorGraph
|
||||
|
||||
~CFactorGraph (void);
|
||||
|
||||
const VarClusterSet& getVarClusters (void) { return varClusters_; }
|
||||
const VarClusters& getVarClusters (void) { return varClusters_; }
|
||||
|
||||
const FacClusterSet& getFacClusters (void) { return facClusters_; }
|
||||
const FacClusters& getFacClusters (void) { return facClusters_; }
|
||||
|
||||
VarNode* getEquivalentVariable (VarId vid)
|
||||
{
|
||||
@ -182,7 +184,7 @@ class CFactorGraph
|
||||
return vc->getRepresentativeVariable();
|
||||
}
|
||||
|
||||
FactorGraph* getCompressedFactorGraph (void);
|
||||
FactorGraph* getGroundFactorGraph (void) const;
|
||||
|
||||
unsigned getGroundEdgeCount (const FacCluster*, const VarCluster*) const;
|
||||
|
||||
@ -200,7 +202,7 @@ class CFactorGraph
|
||||
return varColors_[vn->getIndex()];
|
||||
}
|
||||
Color getColor (const FacNode* fn) const {
|
||||
return factorColors_[fn->getIndex()];
|
||||
return facColors_[fn->getIndex()];
|
||||
}
|
||||
|
||||
void setColor (const VarNode* vn, Color c)
|
||||
@ -210,7 +212,7 @@ class CFactorGraph
|
||||
|
||||
void setColor (const FacNode* fn, Color c)
|
||||
{
|
||||
factorColors_[fn->getIndex()] = c;
|
||||
facColors_[fn->getIndex()] = c;
|
||||
}
|
||||
|
||||
VarCluster* getVariableCluster (VarId vid) const
|
||||
@ -232,11 +234,11 @@ class CFactorGraph
|
||||
|
||||
Color freeColor_;
|
||||
vector<Color> varColors_;
|
||||
vector<Color> factorColors_;
|
||||
vector<Color> facColors_;
|
||||
vector<Signature> varSignatures_;
|
||||
vector<Signature> factorSignatures_;
|
||||
VarClusterSet varClusters_;
|
||||
FacClusterSet facClusters_;
|
||||
vector<Signature> facSignatures_;
|
||||
VarClusters varClusters_;
|
||||
FacClusters facClusters_;
|
||||
VarId2VarCluster vid2VarCluster_;
|
||||
const FactorGraph* groundFg_;
|
||||
};
|
||||
|
@ -1,10 +1,41 @@
|
||||
#include "CbpSolver.h"
|
||||
|
||||
|
||||
CbpSolver::CbpSolver (const FactorGraph& fg) : BpSolver (fg)
|
||||
{
|
||||
unsigned nGroundVars, nGroundFacs, nWithoutNeighs;
|
||||
if (Constants::COLLECT_STATS) {
|
||||
nGroundVars = fg_->varNodes().size();
|
||||
nGroundFacs = fg_->facNodes().size();
|
||||
const VarNodes& vars = fg_->varNodes();
|
||||
nWithoutNeighs = 0;
|
||||
for (unsigned i = 0; i < vars.size(); i++) {
|
||||
const FacNodes& factors = vars[i]->neighbors();
|
||||
if (factors.size() == 1 && factors[0]->neighbors().size() == 1) {
|
||||
nWithoutNeighs ++;
|
||||
}
|
||||
}
|
||||
}
|
||||
cfg_ = new CFactorGraph (fg);
|
||||
fg_ = cfg_->getGroundFactorGraph();
|
||||
if (Constants::COLLECT_STATS) {
|
||||
unsigned nClusterVars = fg_->varNodes().size();
|
||||
unsigned nClusterFacs = fg_->facNodes().size();
|
||||
Statistics::updateCompressingStatistics (nGroundVars,
|
||||
nGroundFacs, nClusterVars, nClusterFacs, nWithoutNeighs);
|
||||
}
|
||||
// Util::printHeader ("Uncompressed Factor Graph");
|
||||
// fg->print();
|
||||
// Util::printHeader ("Compressed Factor Graph");
|
||||
// fg_->print();
|
||||
}
|
||||
|
||||
|
||||
|
||||
CbpSolver::~CbpSolver (void)
|
||||
{
|
||||
delete lfg_;
|
||||
delete factorGraph_;
|
||||
delete cfg_;
|
||||
delete fg_;
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
delete links_[i];
|
||||
}
|
||||
@ -16,8 +47,11 @@ CbpSolver::~CbpSolver (void)
|
||||
Params
|
||||
CbpSolver::getPosterioriOf (VarId vid)
|
||||
{
|
||||
assert (lfg_->getEquivalentVariable (vid));
|
||||
VarNode* var = lfg_->getEquivalentVariable (vid);
|
||||
if (runned_ == false) {
|
||||
runSolver();
|
||||
}
|
||||
assert (cfg_->getEquivalentVariable (vid));
|
||||
VarNode* var = cfg_->getEquivalentVariable (vid);
|
||||
Params probs;
|
||||
if (var->hasEvidence()) {
|
||||
probs.resize (var->range(), LogAware::noEvidence());
|
||||
@ -26,16 +60,16 @@ CbpSolver::getPosterioriOf (VarId vid)
|
||||
probs.resize (var->range(), LogAware::multIdenty());
|
||||
const SpLinkSet& links = ninf(var)->getLinks();
|
||||
if (Globals::logDomain) {
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
|
||||
Util::add (probs, l->getPoweredMessage());
|
||||
}
|
||||
LogAware::normalize (probs);
|
||||
Util::fromLog (probs);
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
|
||||
Util::add (probs, l->poweredMessage());
|
||||
}
|
||||
LogAware::normalize (probs);
|
||||
Util::fromLog (probs);
|
||||
} else {
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
|
||||
Util::multiply (probs, l->getPoweredMessage());
|
||||
Util::multiply (probs, l->poweredMessage());
|
||||
}
|
||||
LogAware::normalize (probs);
|
||||
}
|
||||
@ -46,67 +80,28 @@ CbpSolver::getPosterioriOf (VarId vid)
|
||||
|
||||
|
||||
Params
|
||||
CbpSolver::getJointDistributionOf (const VarIds& jointVarIds)
|
||||
CbpSolver::getJointDistributionOf (const VarIds& jointVids)
|
||||
{
|
||||
VarIds eqVarIds;
|
||||
for (unsigned i = 0; i < jointVarIds.size(); i++) {
|
||||
eqVarIds.push_back (lfg_->getEquivalentVariable (jointVarIds[i])->varId());
|
||||
for (unsigned i = 0; i < jointVids.size(); i++) {
|
||||
VarNode* vn = cfg_->getEquivalentVariable (jointVids[i]);
|
||||
eqVarIds.push_back (vn->varId());
|
||||
}
|
||||
return BpSolver::getJointDistributionOf (eqVarIds);
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
void
|
||||
CbpSolver::initializeSolver (void)
|
||||
{
|
||||
unsigned nGroundVars, nGroundFacs, nWithoutNeighs;
|
||||
if (Constants::COLLECT_STATS) {
|
||||
nGroundVars = factorGraph_->varNodes().size();
|
||||
nGroundFacs = factorGraph_->facNodes().size();
|
||||
const VarNodes& vars = factorGraph_->varNodes();
|
||||
nWithoutNeighs = 0;
|
||||
for (unsigned i = 0; i < vars.size(); i++) {
|
||||
const FacNodes& factors = vars[i]->neighbors();
|
||||
if (factors.size() == 1 && factors[0]->neighbors().size() == 1) {
|
||||
nWithoutNeighs ++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
lfg_ = new CFactorGraph (*factorGraph_);
|
||||
|
||||
// cout << "Uncompressed Factor Graph" << endl;
|
||||
// factorGraph_->print();
|
||||
// factorGraph_->exportToGraphViz ("uncompressed_fg.dot");
|
||||
factorGraph_ = lfg_->getCompressedFactorGraph();
|
||||
|
||||
if (Constants::COLLECT_STATS) {
|
||||
unsigned nClusterVars = factorGraph_->varNodes().size();
|
||||
unsigned nClusterFacs = factorGraph_->facNodes().size();
|
||||
Statistics::updateCompressingStatistics (nGroundVars,
|
||||
nGroundFacs, nClusterVars, nClusterFacs, nWithoutNeighs);
|
||||
}
|
||||
|
||||
// cout << "Compressed Factor Graph" << endl;
|
||||
// factorGraph_->print();
|
||||
// factorGraph_->exportToGraphViz ("compressed_fg.dot");
|
||||
// abort();
|
||||
BpSolver::initializeSolver();
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
CbpSolver::createLinks (void)
|
||||
{
|
||||
const FacClusterSet fcs = lfg_->getFacClusters();
|
||||
const FacClusters& fcs = cfg_->getFacClusters();
|
||||
for (unsigned i = 0; i < fcs.size(); i++) {
|
||||
const VarClusterSet vcs = fcs[i]->getVarClusters();
|
||||
const VarClusters& vcs = fcs[i]->getVarClusters();
|
||||
for (unsigned j = 0; j < vcs.size(); j++) {
|
||||
unsigned c = lfg_->getGroundEdgeCount (fcs[i], vcs[j]);
|
||||
links_.push_back (new CbpSolverLink (fcs[i]->getRepresentativeFactor(),
|
||||
unsigned c = cfg_->getGroundEdgeCount (fcs[i], vcs[j]);
|
||||
links_.push_back (new CbpSolverLink (
|
||||
fcs[i]->getRepresentativeFactor(),
|
||||
vcs[j]->getRepresentativeVariable(), c));
|
||||
}
|
||||
}
|
||||
@ -197,10 +192,10 @@ CbpSolver::getVar2FactorMsg (const SpLink* link) const
|
||||
if (src->hasEvidence()) {
|
||||
msg.resize (src->range(), LogAware::noEvidence());
|
||||
double value = link->getMessage()[src->getEvidence()];
|
||||
msg[src->getEvidence()] = LogAware::pow (value, l->getNumberOfEdges() - 1);
|
||||
msg[src->getEvidence()] = LogAware::pow (value, l->nrEdges() - 1);
|
||||
} else {
|
||||
msg = link->getMessage();
|
||||
LogAware::pow (msg, l->getNumberOfEdges() - 1);
|
||||
LogAware::pow (msg, l->nrEdges() - 1);
|
||||
}
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << " " << "init: " << msg << endl;
|
||||
@ -210,17 +205,17 @@ CbpSolver::getVar2FactorMsg (const SpLink* link) const
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
if (links[i]->getFactor() != dst) {
|
||||
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
|
||||
Util::add (msg, l->getPoweredMessage());
|
||||
Util::add (msg, l->poweredMessage());
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
if (links[i]->getFactor() != dst) {
|
||||
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
|
||||
Util::multiply (msg, l->getPoweredMessage());
|
||||
Util::multiply (msg, l->poweredMessage());
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << " msg from " << l->getFactor()->getLabel() << ": " ;
|
||||
cout << l->getPoweredMessage() << endl;
|
||||
cout << l->poweredMessage() << endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -242,7 +237,7 @@ CbpSolver::printLinkInformation (void) const
|
||||
cout << l->toString() << ":" << endl;
|
||||
cout << " curr msg = " << l->getMessage() << endl;
|
||||
cout << " next msg = " << l->getNextMessage() << endl;
|
||||
cout << " powered = " << l->getPoweredMessage() << endl;
|
||||
cout << " powered = " << l->poweredMessage() << endl;
|
||||
cout << " residual = " << l->getResidual() << endl;
|
||||
}
|
||||
}
|
||||
|
@ -9,27 +9,25 @@ class Factor;
|
||||
class CbpSolverLink : public SpLink
|
||||
{
|
||||
public:
|
||||
CbpSolverLink (FacNode* fn, VarNode* vn, unsigned c) : SpLink (fn, vn)
|
||||
{
|
||||
edgeCount_ = c;
|
||||
poweredMsg_.resize (vn->range(), LogAware::one());
|
||||
}
|
||||
CbpSolverLink (FacNode* fn, VarNode* vn, unsigned c)
|
||||
: SpLink (fn, vn), nrEdges_(c),
|
||||
pwdMsg_(vn->range(), LogAware::one()) { }
|
||||
|
||||
unsigned getNumberOfEdges (void) const { return edgeCount_; }
|
||||
unsigned nrEdges (void) const { return nrEdges_; }
|
||||
|
||||
const Params& getPoweredMessage (void) const { return poweredMsg_; }
|
||||
const Params& poweredMessage (void) const { return pwdMsg_; }
|
||||
|
||||
void updateMessage (void)
|
||||
{
|
||||
poweredMsg_ = *nextMsg_;
|
||||
pwdMsg_ = *nextMsg_;
|
||||
swap (currMsg_, nextMsg_);
|
||||
msgSended_ = true;
|
||||
LogAware::pow (poweredMsg_, edgeCount_);
|
||||
LogAware::pow (pwdMsg_, nrEdges_);
|
||||
}
|
||||
|
||||
private:
|
||||
Params poweredMsg_;
|
||||
unsigned edgeCount_;
|
||||
unsigned nrEdges_;
|
||||
Params pwdMsg_;
|
||||
};
|
||||
|
||||
|
||||
@ -37,16 +35,15 @@ class CbpSolverLink : public SpLink
|
||||
class CbpSolver : public BpSolver
|
||||
{
|
||||
public:
|
||||
CbpSolver (const FactorGraph& fg) : BpSolver (fg) { }
|
||||
CbpSolver (const FactorGraph& fg);
|
||||
|
||||
~CbpSolver (void);
|
||||
|
||||
|
||||
Params getPosterioriOf (VarId);
|
||||
|
||||
Params getJointDistributionOf (const VarIds&);
|
||||
|
||||
private:
|
||||
void initializeSolver (void);
|
||||
|
||||
void createLinks (void);
|
||||
|
||||
@ -56,8 +53,7 @@ class CbpSolver : public BpSolver
|
||||
|
||||
void printLinkInformation (void) const;
|
||||
|
||||
CFactorGraph* lfg_;
|
||||
FactorGraph* factorGraph_;
|
||||
CFactorGraph* cfg_;
|
||||
};
|
||||
|
||||
#endif // HORUS_CBP_H
|
||||
|
@ -14,25 +14,43 @@ void processArguments (FactorGraph&, int, const char* []);
|
||||
void runSolver (const FactorGraph&, const VarIds&);
|
||||
|
||||
const string USAGE = "usage: \
|
||||
./hcli FILE [VARIABLE | OBSERVED_VARIABLE=EVIDENCE]..." ;
|
||||
./hcli ve|bp|cbp NETWORK_FILE [VARIABLE | OBSERVED_VARIABLE=EVIDENCE]..." ;
|
||||
|
||||
|
||||
int
|
||||
main (int argc, const char* argv[])
|
||||
{
|
||||
if (!argv[1]) {
|
||||
if (argc <= 1) {
|
||||
cerr << "error: no solver specified" << endl;
|
||||
cerr << "error: no graphical model specified" << endl;
|
||||
cerr << USAGE << endl;
|
||||
exit (0);
|
||||
}
|
||||
string fileName = argv[1];
|
||||
if (argc <= 2) {
|
||||
cerr << "error: no graphical model specified" << endl;
|
||||
cerr << USAGE << endl;
|
||||
exit (0);
|
||||
}
|
||||
string solver (argv[1]);
|
||||
if (solver == "ve") {
|
||||
Globals::infAlgorithm = InfAlgorithms::VE;
|
||||
} else if (solver == "bp") {
|
||||
Globals::infAlgorithm = InfAlgorithms::BP;
|
||||
} else if (solver == "cbp") {
|
||||
Globals::infAlgorithm = InfAlgorithms::CBP;
|
||||
} else {
|
||||
cerr << "error: unknow solver `" << solver << "'" << endl ;
|
||||
cerr << USAGE << endl;
|
||||
exit(0);
|
||||
}
|
||||
string fileName (argv[2]);
|
||||
string extension = fileName.substr (
|
||||
fileName.find_last_of ('.') + 1);
|
||||
FactorGraph fg;
|
||||
if (extension == "uai") {
|
||||
fg.readFromUaiFormat (argv[1]);
|
||||
fg.readFromUaiFormat (fileName.c_str());
|
||||
} else if (extension == "fg") {
|
||||
fg.readFromLibDaiFormat (argv[1]);
|
||||
fg.readFromLibDaiFormat (fileName.c_str());
|
||||
} else {
|
||||
cerr << "error: the graphical model must be defined either " ;
|
||||
cerr << "in a UAI or libDAI file" << endl;
|
||||
@ -48,7 +66,7 @@ void
|
||||
processArguments (FactorGraph& fg, int argc, const char* argv[])
|
||||
{
|
||||
VarIds queryIds;
|
||||
for (int i = 2; i < argc; i++) {
|
||||
for (int i = 3; i < argc; i++) {
|
||||
const string& arg = argv[i];
|
||||
if (arg.find ('=') == std::string::npos) {
|
||||
if (!Util::isInteger (arg)) {
|
||||
|
@ -8,7 +8,7 @@ Solver::printAnswer (const VarIds& vids)
|
||||
Vars unobservedVars;
|
||||
VarIds unobservedVids;
|
||||
for (unsigned i = 0; i < vids.size(); i++) {
|
||||
VarNode* vn = fg_.getVarNode (vids[i]);
|
||||
VarNode* vn = fg.getVarNode (vids[i]);
|
||||
if (vn->hasEvidence() == false) {
|
||||
unobservedVars.push_back (vn);
|
||||
unobservedVids.push_back (vids[i]);
|
||||
@ -29,7 +29,7 @@ Solver::printAnswer (const VarIds& vids)
|
||||
void
|
||||
Solver::printAllPosterioris (void)
|
||||
{
|
||||
const VarNodes& vars = fg_.varNodes();
|
||||
const VarNodes& vars = fg.varNodes();
|
||||
for (unsigned i = 0; i < vars.size(); i++) {
|
||||
printAnswer ({vars[i]->varId()});
|
||||
}
|
||||
|
@ -12,7 +12,7 @@ using namespace std;
|
||||
class Solver
|
||||
{
|
||||
public:
|
||||
Solver (const FactorGraph& fg) : fg_(fg) { }
|
||||
Solver (const FactorGraph& factorGraph) : fg(factorGraph) { }
|
||||
|
||||
virtual ~Solver() { } // ensure that subclass destructor is called
|
||||
|
||||
@ -23,7 +23,7 @@ class Solver
|
||||
void printAllPosterioris (void);
|
||||
|
||||
protected:
|
||||
const FactorGraph& fg_;
|
||||
const FactorGraph& fg;
|
||||
};
|
||||
|
||||
#endif // HORUS_SOLVER_H
|
||||
|
@ -35,7 +35,7 @@ VarElimSolver::solveQuery (VarIds queryVids)
|
||||
void
|
||||
VarElimSolver::createFactorList (void)
|
||||
{
|
||||
const FacNodes& facNodes = fg_.facNodes();
|
||||
const FacNodes& facNodes = fg.facNodes();
|
||||
factorList_.reserve (facNodes.size() * 2);
|
||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||
factorList_.push_back (new Factor (facNodes[i]->factor()));
|
||||
@ -57,7 +57,7 @@ VarElimSolver::createFactorList (void)
|
||||
void
|
||||
VarElimSolver::absorveEvidence (void)
|
||||
{
|
||||
const VarNodes& varNodes = fg_.varNodes();
|
||||
const VarNodes& varNodes = fg.varNodes();
|
||||
for (unsigned i = 0; i < varNodes.size(); i++) {
|
||||
if (varNodes[i]->hasEvidence()) {
|
||||
const vector<unsigned>& idxs =
|
||||
@ -103,7 +103,7 @@ VarElimSolver::processFactorList (const VarIds& vids)
|
||||
|
||||
VarIds unobservedVids;
|
||||
for (unsigned i = 0; i < vids.size(); i++) {
|
||||
if (fg_.getVarNode (vids[i])->hasEvidence() == false) {
|
||||
if (fg.getVarNode (vids[i])->hasEvidence() == false) {
|
||||
unobservedVids.push_back (vids[i]);
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user