renamings and delete bn_bp stuff
This commit is contained in:
parent
b28ee8fb3a
commit
d1b25f0864
@ -60,7 +60,7 @@ BayesBall::getMinimalFactorGraph (const VarIds& queryIds)
|
||||
void
|
||||
BayesBall::constructGraph (FactorGraph* fg) const
|
||||
{
|
||||
const FgFacSet& facNodes = fg_.getFactorNodes();
|
||||
const FactorNodes& facNodes = fg_.factorNodes();
|
||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||
const DAGraphNode* n = dag_.getNode (
|
||||
facNodes[i]->factor()->argument (0));
|
||||
|
@ -6,11 +6,9 @@
|
||||
#include <list>
|
||||
#include <map>
|
||||
|
||||
#include "GraphicalModel.h"
|
||||
#include "Horus.h"
|
||||
|
||||
#include "FactorGraph.h"
|
||||
#include "BayesNet.h"
|
||||
#include "Horus.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
|
@ -6,19 +6,19 @@
|
||||
#include <list>
|
||||
#include <map>
|
||||
|
||||
#include "GraphicalModel.h"
|
||||
#include "Var.h"
|
||||
#include "Horus.h"
|
||||
|
||||
|
||||
using namespace std;
|
||||
|
||||
|
||||
class VarNode;
|
||||
class Var;
|
||||
|
||||
class DAGraphNode : public VarNode
|
||||
class DAGraphNode : public Var
|
||||
{
|
||||
public:
|
||||
DAGraphNode (VarNode* vn) : VarNode (vn) , visited_(false),
|
||||
DAGraphNode (Var* v) : Var (v) , visited_(false),
|
||||
markedOnTop_(false), markedOnBottom_(false) { }
|
||||
|
||||
const vector<DAGraphNode*>& childs (void) const { return childs_; }
|
||||
|
@ -10,14 +10,14 @@ CFactorGraph::CFactorGraph (const FactorGraph& fg)
|
||||
groundFg_ = &fg;
|
||||
freeColor_ = 0;
|
||||
|
||||
const FgVarSet& varNodes = fg.getVarNodes();
|
||||
const VarNodes& varNodes = fg.varNodes();
|
||||
varSignatures_.reserve (varNodes.size());
|
||||
for (unsigned i = 0; i < varNodes.size(); i++) {
|
||||
unsigned c = (varNodes[i]->neighbors().size() * 2) + 1;
|
||||
varSignatures_.push_back (Signature (c));
|
||||
}
|
||||
|
||||
const FgFacSet& facNodes = fg.getFactorNodes();
|
||||
const FactorNodes& facNodes = fg.factorNodes();
|
||||
factorSignatures_.reserve (facNodes.size());
|
||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||
unsigned c = facNodes[i]->neighbors().size() + 1;
|
||||
@ -49,7 +49,7 @@ CFactorGraph::setInitialColors (void)
|
||||
{
|
||||
// create the initial variable colors
|
||||
VarColorMap colorMap;
|
||||
const FgVarSet& varNodes = groundFg_->getVarNodes();
|
||||
const VarNodes& varNodes = groundFg_->varNodes();
|
||||
for (unsigned i = 0; i < varNodes.size(); i++) {
|
||||
unsigned dsize = varNodes[i]->range();
|
||||
VarColorMap::iterator it = colorMap.find (dsize);
|
||||
@ -70,7 +70,7 @@ CFactorGraph::setInitialColors (void)
|
||||
setColor (varNodes[i], stateColors[idx]);
|
||||
}
|
||||
|
||||
const FgFacSet& facNodes = groundFg_->getFactorNodes();
|
||||
const FactorNodes& facNodes = groundFg_->factorNodes();
|
||||
if (checkForIdenticalFactors) {
|
||||
unsigned groupCount = 1;
|
||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||
@ -114,8 +114,8 @@ CFactorGraph::createGroups (void)
|
||||
FacSignMap factorGroups;
|
||||
unsigned nIters = 0;
|
||||
bool groupsHaveChanged = true;
|
||||
const FgVarSet& varNodes = groundFg_->getVarNodes();
|
||||
const FgFacSet& facNodes = groundFg_->getFactorNodes();
|
||||
const VarNodes& varNodes = groundFg_->varNodes();
|
||||
const FactorNodes& facNodes = groundFg_->factorNodes();
|
||||
|
||||
while (groupsHaveChanged || nIters == 1) {
|
||||
nIters ++;
|
||||
@ -127,14 +127,14 @@ CFactorGraph::createGroups (void)
|
||||
const Signature& signature = getSignature (facNodes[i]);
|
||||
FacSignMap::iterator it = factorGroups.find (signature);
|
||||
if (it == factorGroups.end()) {
|
||||
it = factorGroups.insert (make_pair (signature, FgFacSet())).first;
|
||||
it = factorGroups.insert (make_pair (signature, FactorNodes())).first;
|
||||
}
|
||||
it->second.push_back (facNodes[i]);
|
||||
}
|
||||
for (FacSignMap::iterator it = factorGroups.begin();
|
||||
it != factorGroups.end(); it++) {
|
||||
Color newColor = getFreeColor();
|
||||
FgFacSet& groupMembers = it->second;
|
||||
FactorNodes& groupMembers = it->second;
|
||||
for (unsigned i = 0; i < groupMembers.size(); i++) {
|
||||
setColor (groupMembers[i], newColor);
|
||||
}
|
||||
@ -147,14 +147,14 @@ CFactorGraph::createGroups (void)
|
||||
const Signature& signature = getSignature (varNodes[i]);
|
||||
VarSignMap::iterator it = varGroups.find (signature);
|
||||
if (it == varGroups.end()) {
|
||||
it = varGroups.insert (make_pair (signature, FgVarSet())).first;
|
||||
it = varGroups.insert (make_pair (signature, VarNodes())).first;
|
||||
}
|
||||
it->second.push_back (varNodes[i]);
|
||||
}
|
||||
for (VarSignMap::iterator it = varGroups.begin();
|
||||
it != varGroups.end(); it++) {
|
||||
Color newColor = getFreeColor();
|
||||
FgVarSet& groupMembers = it->second;
|
||||
VarNodes& groupMembers = it->second;
|
||||
for (unsigned i = 0; i < groupMembers.size(); i++) {
|
||||
setColor (groupMembers[i], newColor);
|
||||
}
|
||||
@ -177,7 +177,7 @@ CFactorGraph::createClusters (
|
||||
varClusters_.reserve (varGroups.size());
|
||||
for (VarSignMap::const_iterator it = varGroups.begin();
|
||||
it != varGroups.end(); it++) {
|
||||
const FgVarSet& groupVars = it->second;
|
||||
const VarNodes& groupVars = it->second;
|
||||
VarCluster* vc = new VarCluster (groupVars);
|
||||
for (unsigned i = 0; i < groupVars.size(); i++) {
|
||||
vid2VarCluster_.insert (make_pair (groupVars[i]->varId(), vc));
|
||||
@ -188,8 +188,8 @@ CFactorGraph::createClusters (
|
||||
facClusters_.reserve (factorGroups.size());
|
||||
for (FacSignMap::const_iterator it = factorGroups.begin();
|
||||
it != factorGroups.end(); it++) {
|
||||
FgFacNode* groupFactor = it->second[0];
|
||||
const FgVarSet& neighs = groupFactor->neighbors();
|
||||
FactorNode* groupFactor = it->second[0];
|
||||
const VarNodes& neighs = groupFactor->neighbors();
|
||||
VarClusterSet varClusters;
|
||||
varClusters.reserve (neighs.size());
|
||||
for (unsigned i = 0; i < neighs.size(); i++) {
|
||||
@ -203,11 +203,11 @@ CFactorGraph::createClusters (
|
||||
|
||||
|
||||
const Signature&
|
||||
CFactorGraph::getSignature (const FgVarNode* varNode)
|
||||
CFactorGraph::getSignature (const VarNode* varNode)
|
||||
{
|
||||
Signature& sign = varSignatures_[varNode->getIndex()];
|
||||
vector<Color>::iterator it = sign.colors.begin();
|
||||
const FgFacSet& neighs = varNode->neighbors();
|
||||
const FactorNodes& neighs = varNode->neighbors();
|
||||
for (unsigned i = 0; i < neighs.size(); i++) {
|
||||
*it = getColor (neighs[i]);
|
||||
it ++;
|
||||
@ -221,11 +221,11 @@ CFactorGraph::getSignature (const FgVarNode* varNode)
|
||||
|
||||
|
||||
const Signature&
|
||||
CFactorGraph::getSignature (const FgFacNode* facNode)
|
||||
CFactorGraph::getSignature (const FactorNode* facNode)
|
||||
{
|
||||
Signature& sign = factorSignatures_[facNode->getIndex()];
|
||||
vector<Color>::iterator it = sign.colors.begin();
|
||||
const FgVarSet& neighs = facNode->neighbors();
|
||||
const VarNodes& neighs = facNode->neighbors();
|
||||
for (unsigned i = 0; i < neighs.size(); i++) {
|
||||
*it = getColor (neighs[i]);
|
||||
it ++;
|
||||
@ -241,27 +241,27 @@ CFactorGraph::getCompressedFactorGraph (void)
|
||||
{
|
||||
FactorGraph* fg = new FactorGraph();
|
||||
for (unsigned i = 0; i < varClusters_.size(); i++) {
|
||||
FgVarNode* var = varClusters_[i]->getGroundFgVarNodes()[0];
|
||||
FgVarNode* newVar = new FgVarNode (var);
|
||||
VarNode* var = varClusters_[i]->getGroundVarNodes()[0];
|
||||
VarNode* newVar = new VarNode (var);
|
||||
varClusters_[i]->setRepresentativeVariable (newVar);
|
||||
fg->addVariable (newVar);
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < facClusters_.size(); i++) {
|
||||
const VarClusterSet& myVarClusters = facClusters_[i]->getVarClusters();
|
||||
VarNodes myGroundVars;
|
||||
Vars myGroundVars;
|
||||
myGroundVars.reserve (myVarClusters.size());
|
||||
for (unsigned j = 0; j < myVarClusters.size(); j++) {
|
||||
FgVarNode* v = myVarClusters[j]->getRepresentativeVariable();
|
||||
VarNode* v = myVarClusters[j]->getRepresentativeVariable();
|
||||
myGroundVars.push_back (v);
|
||||
}
|
||||
Factor* newFactor = new Factor (myGroundVars,
|
||||
facClusters_[i]->getGroundFactors()[0]->params());
|
||||
FgFacNode* fn = new FgFacNode (newFactor);
|
||||
FactorNode* fn = new FactorNode (newFactor);
|
||||
facClusters_[i]->setRepresentativeFactor (fn);
|
||||
fg->addFactor (fn);
|
||||
for (unsigned j = 0; j < myGroundVars.size(); j++) {
|
||||
fg->addEdge (fn, static_cast<FgVarNode*> (myGroundVars[j]));
|
||||
fg->addEdge (fn, static_cast<VarNode*> (myGroundVars[j]));
|
||||
}
|
||||
}
|
||||
fg->setIndexes();
|
||||
@ -275,17 +275,17 @@ CFactorGraph::getGroundEdgeCount (
|
||||
const FacCluster* fc,
|
||||
const VarCluster* vc) const
|
||||
{
|
||||
const FgFacSet& clusterGroundFactors = fc->getGroundFactors();
|
||||
FgVarNode* varNode = vc->getGroundFgVarNodes()[0];
|
||||
const FactorNodes& clusterGroundFactors = fc->getGroundFactors();
|
||||
VarNode* varNode = vc->getGroundVarNodes()[0];
|
||||
unsigned count = 0;
|
||||
for (unsigned i = 0; i < clusterGroundFactors.size(); i++) {
|
||||
if (clusterGroundFactors[i]->factor()->indexOf (varNode->varId()) != -1) {
|
||||
count ++;
|
||||
}
|
||||
}
|
||||
// CFgVarSet vars = vc->getGroundFgVarNodes();
|
||||
// CVarNodes vars = vc->getGroundVarNodes();
|
||||
// for (unsigned i = 1; i < vars.size(); i++) {
|
||||
// FgVarNode* var = vc->getGroundFgVarNodes()[i];
|
||||
// VarNode* var = vc->getGroundVarNodes()[i];
|
||||
// unsigned count2 = 0;
|
||||
// for (unsigned i = 0; i < clusterGroundFactors.size(); i++) {
|
||||
// if (clusterGroundFactors[i]->getPosition (var) != -1) {
|
||||
@ -308,7 +308,7 @@ CFactorGraph::printGroups (
|
||||
cout << "variable groups:" << endl;
|
||||
for (VarSignMap::const_iterator it = varGroups.begin();
|
||||
it != varGroups.end(); it++) {
|
||||
const FgVarSet& groupMembers = it->second;
|
||||
const VarNodes& groupMembers = it->second;
|
||||
if (groupMembers.size() > 0) {
|
||||
cout << count << ": " ;
|
||||
for (unsigned i = 0; i < groupMembers.size(); i++) {
|
||||
@ -323,7 +323,7 @@ CFactorGraph::printGroups (
|
||||
cout << endl << "factor groups:" << endl;
|
||||
for (FacSignMap::const_iterator it = factorGroups.begin();
|
||||
it != factorGroups.end(); it++) {
|
||||
const FgFacSet& groupMembers = it->second;
|
||||
const FactorNodes& groupMembers = it->second;
|
||||
if (groupMembers.size() > 0) {
|
||||
cout << ++count << ": " ;
|
||||
for (unsigned i = 0; i < groupMembers.size(); i++) {
|
||||
|
@ -25,8 +25,8 @@ typedef unordered_map<VarId, VarCluster*> VarId2VarCluster;
|
||||
typedef vector<VarCluster*> VarClusterSet;
|
||||
typedef vector<FacCluster*> FacClusterSet;
|
||||
|
||||
typedef unordered_map<Signature, FgVarSet, SignatureHash> VarSignMap;
|
||||
typedef unordered_map<Signature, FgFacSet, SignatureHash> FacSignMap;
|
||||
typedef unordered_map<Signature, VarNodes, SignatureHash> VarSignMap;
|
||||
typedef unordered_map<Signature, FactorNodes, SignatureHash> FacSignMap;
|
||||
|
||||
|
||||
|
||||
@ -87,7 +87,7 @@ struct SignatureHash
|
||||
class VarCluster
|
||||
{
|
||||
public:
|
||||
VarCluster (const FgVarSet& vs)
|
||||
VarCluster (const VarNodes& vs)
|
||||
{
|
||||
for (unsigned i = 0; i < vs.size(); i++) {
|
||||
groundVars_.push_back (vs[i]);
|
||||
@ -104,21 +104,21 @@ class VarCluster
|
||||
return facClusters_;
|
||||
}
|
||||
|
||||
FgVarNode* getRepresentativeVariable (void) const { return representVar_; }
|
||||
void setRepresentativeVariable (FgVarNode* v) { representVar_ = v; }
|
||||
const FgVarSet& getGroundFgVarNodes (void) const { return groundVars_; }
|
||||
VarNode* getRepresentativeVariable (void) const { return representVar_; }
|
||||
void setRepresentativeVariable (VarNode* v) { representVar_ = v; }
|
||||
const VarNodes& getGroundVarNodes (void) const { return groundVars_; }
|
||||
|
||||
private:
|
||||
FgVarSet groundVars_;
|
||||
VarNodes groundVars_;
|
||||
FacClusterSet facClusters_;
|
||||
FgVarNode* representVar_;
|
||||
VarNode* representVar_;
|
||||
};
|
||||
|
||||
|
||||
class FacCluster
|
||||
{
|
||||
public:
|
||||
FacCluster (const FgFacSet& groundFactors, const VarClusterSet& vcs)
|
||||
FacCluster (const FactorNodes& groundFactors, const VarClusterSet& vcs)
|
||||
{
|
||||
groundFactors_ = groundFactors;
|
||||
varClusters_ = vcs;
|
||||
@ -132,7 +132,7 @@ class FacCluster
|
||||
return varClusters_;
|
||||
}
|
||||
|
||||
bool containsGround (const FgFacNode* fn)
|
||||
bool containsGround (const FactorNode* fn)
|
||||
{
|
||||
for (unsigned i = 0; i < groundFactors_.size(); i++) {
|
||||
if (groundFactors_[i] == fn) {
|
||||
@ -142,26 +142,26 @@ class FacCluster
|
||||
return false;
|
||||
}
|
||||
|
||||
FgFacNode* getRepresentativeFactor (void) const
|
||||
FactorNode* getRepresentativeFactor (void) const
|
||||
{
|
||||
return representFactor_;
|
||||
}
|
||||
|
||||
void setRepresentativeFactor (FgFacNode* fn)
|
||||
void setRepresentativeFactor (FactorNode* fn)
|
||||
{
|
||||
representFactor_ = fn;
|
||||
}
|
||||
|
||||
const FgFacSet& getGroundFactors (void) const
|
||||
const FactorNodes& getGroundFactors (void) const
|
||||
{
|
||||
return groundFactors_;
|
||||
}
|
||||
|
||||
|
||||
private:
|
||||
FgFacSet groundFactors_;
|
||||
FactorNodes groundFactors_;
|
||||
VarClusterSet varClusters_;
|
||||
FgFacNode* representFactor_;
|
||||
FactorNode* representFactor_;
|
||||
};
|
||||
|
||||
|
||||
@ -176,7 +176,7 @@ class CFactorGraph
|
||||
|
||||
const FacClusterSet& getFacClusters (void) { return facClusters_; }
|
||||
|
||||
FgVarNode* getEquivalentVariable (VarId vid)
|
||||
VarNode* getEquivalentVariable (VarId vid)
|
||||
{
|
||||
VarCluster* vc = vid2VarCluster_.find (vid)->second;
|
||||
return vc->getRepresentativeVariable();
|
||||
@ -195,20 +195,20 @@ class CFactorGraph
|
||||
return freeColor_ - 1;
|
||||
}
|
||||
|
||||
Color getColor (const FgVarNode* vn) const
|
||||
Color getColor (const VarNode* vn) const
|
||||
{
|
||||
return varColors_[vn->getIndex()];
|
||||
}
|
||||
Color getColor (const FgFacNode* fn) const {
|
||||
Color getColor (const FactorNode* fn) const {
|
||||
return factorColors_[fn->getIndex()];
|
||||
}
|
||||
|
||||
void setColor (const FgVarNode* vn, Color c)
|
||||
void setColor (const VarNode* vn, Color c)
|
||||
{
|
||||
varColors_[vn->getIndex()] = c;
|
||||
}
|
||||
|
||||
void setColor (const FgFacNode* fn, Color c)
|
||||
void setColor (const FactorNode* fn, Color c)
|
||||
{
|
||||
factorColors_[fn->getIndex()] = c;
|
||||
}
|
||||
@ -224,9 +224,9 @@ class CFactorGraph
|
||||
|
||||
void createClusters (const VarSignMap&, const FacSignMap&);
|
||||
|
||||
const Signature& getSignature (const FgVarNode*);
|
||||
const Signature& getSignature (const VarNode*);
|
||||
|
||||
const Signature& getSignature (const FgFacNode*);
|
||||
const Signature& getSignature (const FactorNode*);
|
||||
|
||||
void printGroups (const VarSignMap&, const FacSignMap&) const;
|
||||
|
||||
|
@ -17,7 +17,7 @@ Params
|
||||
CbpSolver::getPosterioriOf (VarId vid)
|
||||
{
|
||||
assert (lfg_->getEquivalentVariable (vid));
|
||||
FgVarNode* var = lfg_->getEquivalentVariable (vid);
|
||||
VarNode* var = lfg_->getEquivalentVariable (vid);
|
||||
Params probs;
|
||||
if (var->hasEvidence()) {
|
||||
probs.resize (var->range(), LogAware::noEvidence());
|
||||
@ -52,7 +52,7 @@ CbpSolver::getJointDistributionOf (const VarIds& jointVarIds)
|
||||
for (unsigned i = 0; i < jointVarIds.size(); i++) {
|
||||
eqVarIds.push_back (lfg_->getEquivalentVariable (jointVarIds[i])->varId());
|
||||
}
|
||||
return FgBpSolver::getJointDistributionOf (eqVarIds);
|
||||
return BpSolver::getJointDistributionOf (eqVarIds);
|
||||
}
|
||||
|
||||
|
||||
@ -63,12 +63,12 @@ CbpSolver::initializeSolver (void)
|
||||
{
|
||||
unsigned nGroundVars, nGroundFacs, nWithoutNeighs;
|
||||
if (Constants::COLLECT_STATS) {
|
||||
nGroundVars = factorGraph_->getVarNodes().size();
|
||||
nGroundFacs = factorGraph_->getFactorNodes().size();
|
||||
const FgVarSet& vars = factorGraph_->getVarNodes();
|
||||
nGroundVars = factorGraph_->varNodes().size();
|
||||
nGroundFacs = factorGraph_->factorNodes().size();
|
||||
const VarNodes& vars = factorGraph_->varNodes();
|
||||
nWithoutNeighs = 0;
|
||||
for (unsigned i = 0; i < vars.size(); i++) {
|
||||
const FgFacSet& factors = vars[i]->neighbors();
|
||||
const FactorNodes& factors = vars[i]->neighbors();
|
||||
if (factors.size() == 1 && factors[0]->neighbors().size() == 1) {
|
||||
nWithoutNeighs ++;
|
||||
}
|
||||
@ -78,23 +78,23 @@ CbpSolver::initializeSolver (void)
|
||||
lfg_ = new CFactorGraph (*factorGraph_);
|
||||
|
||||
// cout << "Uncompressed Factor Graph" << endl;
|
||||
// factorGraph_->printGraphicalModel();
|
||||
// factorGraph_->print();
|
||||
// factorGraph_->exportToGraphViz ("uncompressed_fg.dot");
|
||||
factorGraph_ = lfg_->getCompressedFactorGraph();
|
||||
|
||||
if (Constants::COLLECT_STATS) {
|
||||
unsigned nClusterVars = factorGraph_->getVarNodes().size();
|
||||
unsigned nClusterFacs = factorGraph_->getFactorNodes().size();
|
||||
unsigned nClusterVars = factorGraph_->varNodes().size();
|
||||
unsigned nClusterFacs = factorGraph_->factorNodes().size();
|
||||
Statistics::updateCompressingStatistics (nGroundVars, nGroundFacs,
|
||||
nClusterVars, nClusterFacs,
|
||||
nWithoutNeighs);
|
||||
}
|
||||
|
||||
// cout << "Compressed Factor Graph" << endl;
|
||||
// factorGraph_->printGraphicalModel();
|
||||
// factorGraph_->print();
|
||||
// factorGraph_->exportToGraphViz ("compressed_fg.dot");
|
||||
// abort();
|
||||
FgBpSolver::initializeSolver();
|
||||
BpSolver::initializeSolver();
|
||||
}
|
||||
|
||||
|
||||
@ -154,7 +154,7 @@ CbpSolver::maxResidualSchedule (void)
|
||||
linkMap_.find (link)->second = sortedOrder_.insert (link);
|
||||
|
||||
// update the messages that depend on message source --> destin
|
||||
const FgFacSet& factorNeighbors = link->getVariable()->neighbors();
|
||||
const FactorNodes& factorNeighbors = link->getVariable()->neighbors();
|
||||
for (unsigned i = 0; i < factorNeighbors.size(); i++) {
|
||||
const SpLinkSet& links = ninf(factorNeighbors[i])->getLinks();
|
||||
for (unsigned j = 0; j < links.size(); j++) {
|
||||
@ -192,8 +192,8 @@ Params
|
||||
CbpSolver::getVar2FactorMsg (const SpLink* link) const
|
||||
{
|
||||
Params msg;
|
||||
const FgVarNode* src = link->getVariable();
|
||||
const FgFacNode* dst = link->getFactor();
|
||||
const VarNode* src = link->getVariable();
|
||||
const FactorNode* dst = link->getFactor();
|
||||
const CbpSolverLink* l = static_cast<const CbpSolverLink*> (link);
|
||||
if (src->hasEvidence()) {
|
||||
msg.resize (src->range(), LogAware::noEvidence());
|
||||
|
@ -1,7 +1,7 @@
|
||||
#ifndef HORUS_CBP_H
|
||||
#define HORUS_CBP_H
|
||||
|
||||
#include "FgBpSolver.h"
|
||||
#include "BpSolver.h"
|
||||
#include "CFactorGraph.h"
|
||||
|
||||
class Factor;
|
||||
@ -9,7 +9,7 @@ class Factor;
|
||||
class CbpSolverLink : public SpLink
|
||||
{
|
||||
public:
|
||||
CbpSolverLink (FgFacNode* fn, FgVarNode* vn, unsigned c) : SpLink (fn, vn)
|
||||
CbpSolverLink (FactorNode* fn, VarNode* vn, unsigned c) : SpLink (fn, vn)
|
||||
{
|
||||
edgeCount_ = c;
|
||||
poweredMsg_.resize (vn->range(), LogAware::one());
|
||||
@ -34,10 +34,10 @@ class CbpSolverLink : public SpLink
|
||||
|
||||
|
||||
|
||||
class CbpSolver : public FgBpSolver
|
||||
class CbpSolver : public BpSolver
|
||||
{
|
||||
public:
|
||||
CbpSolver (FactorGraph& fg) : FgBpSolver (fg) { }
|
||||
CbpSolver (FactorGraph& fg) : BpSolver (fg) { }
|
||||
|
||||
~CbpSolver (void);
|
||||
|
||||
|
@ -17,10 +17,10 @@ enum ElimHeuristic
|
||||
};
|
||||
|
||||
|
||||
class EgNode : public VarNode
|
||||
class EgNode : public Var
|
||||
{
|
||||
public:
|
||||
EgNode (VarId vid, unsigned range) : VarNode (vid, range) { }
|
||||
EgNode (VarId vid, unsigned range) : Var (vid, range) { }
|
||||
|
||||
void addNeighbor (EgNode* n) { neighs_.push_back (n); }
|
||||
|
||||
|
@ -29,7 +29,7 @@ Factor::Factor (VarId vid, unsigned range)
|
||||
|
||||
|
||||
|
||||
Factor::Factor (const VarNodes& vars)
|
||||
Factor::Factor (const Vars& vars)
|
||||
{
|
||||
int nrParams = 1;
|
||||
for (unsigned i = 0; i < vars.size(); i++) {
|
||||
@ -60,7 +60,7 @@ Factor::Factor (
|
||||
|
||||
|
||||
Factor::Factor (
|
||||
const VarNodes& vars,
|
||||
const Vars& vars,
|
||||
const Params& params,
|
||||
unsigned distId)
|
||||
{
|
||||
@ -267,7 +267,7 @@ Factor::getLabel (void) const
|
||||
ss << "f(" ;
|
||||
for (unsigned i = 0; i < args_.size(); i++) {
|
||||
if (i != 0) ss << "," ;
|
||||
ss << VarNode (args_[i], ranges_[i]).label();
|
||||
ss << Var (args_[i], ranges_[i]).label();
|
||||
}
|
||||
ss << ")" ;
|
||||
return ss.str();
|
||||
@ -278,9 +278,9 @@ Factor::getLabel (void) const
|
||||
void
|
||||
Factor::print (void) const
|
||||
{
|
||||
VarNodes vars;
|
||||
Vars vars;
|
||||
for (unsigned i = 0; i < args_.size(); i++) {
|
||||
vars.push_back (new VarNode (args_[i], ranges_[i]));
|
||||
vars.push_back (new Var (args_[i], ranges_[i]));
|
||||
}
|
||||
vector<string> jointStrings = Util::getJointStateStrings (vars);
|
||||
for (unsigned i = 0; i < params_.size(); i++) {
|
||||
|
@ -3,7 +3,7 @@
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "VarNode.h"
|
||||
#include "Var.h"
|
||||
#include "Indexer.h"
|
||||
#include "Util.h"
|
||||
|
||||
@ -260,11 +260,11 @@ class Factor : public TFactor<VarId>
|
||||
|
||||
Factor (VarId, unsigned);
|
||||
|
||||
Factor (const VarNodes&);
|
||||
Factor (const Vars&);
|
||||
|
||||
Factor (VarId, unsigned, const Params&);
|
||||
|
||||
Factor (const VarNodes&, const Params&,
|
||||
Factor (const Vars&, const Params&,
|
||||
unsigned = Util::maxUnsigned());
|
||||
|
||||
Factor (const VarIds&, const Ranges&, const Params&,
|
||||
|
@ -18,17 +18,17 @@ bool FactorGraph::orderFactorVariables = false;
|
||||
|
||||
FactorGraph::FactorGraph (const FactorGraph& fg)
|
||||
{
|
||||
const FgVarSet& vars = fg.getVarNodes();
|
||||
const VarNodes& vars = fg.varNodes();
|
||||
for (unsigned i = 0; i < vars.size(); i++) {
|
||||
FgVarNode* varNode = new FgVarNode (vars[i]);
|
||||
VarNode* varNode = new VarNode (vars[i]);
|
||||
addVariable (varNode);
|
||||
}
|
||||
|
||||
const FgFacSet& facs = fg.getFactorNodes();
|
||||
const FactorNodes& facs = fg.factorNodes();
|
||||
for (unsigned i = 0; i < facs.size(); i++) {
|
||||
FgFacNode* facNode = new FgFacNode (facs[i]);
|
||||
FactorNode* facNode = new FactorNode (facs[i]);
|
||||
addFactor (facNode);
|
||||
const FgVarSet& neighs = facs[i]->neighbors();
|
||||
const VarNodes& neighs = facs[i]->neighbors();
|
||||
for (unsigned j = 0; j < neighs.size(); j++) {
|
||||
addEdge (facNode, varNodes_[neighs[j]->getIndex()]);
|
||||
}
|
||||
@ -68,7 +68,7 @@ FactorGraph::readFromUaiFormat (const char* fileName)
|
||||
|
||||
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
|
||||
for (unsigned i = 0; i < nVars; i++) {
|
||||
addVariable (new FgVarNode (i, domainSizes[i]));
|
||||
addVariable (new VarNode (i, domainSizes[i]));
|
||||
}
|
||||
|
||||
unsigned nFactors;
|
||||
@ -77,21 +77,21 @@ FactorGraph::readFromUaiFormat (const char* fileName)
|
||||
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
|
||||
unsigned nFactorVars;
|
||||
is >> nFactorVars;
|
||||
VarNodes neighs;
|
||||
Vars neighs;
|
||||
for (unsigned j = 0; j < nFactorVars; j++) {
|
||||
unsigned vid;
|
||||
is >> vid;
|
||||
FgVarNode* neigh = getFgVarNode (vid);
|
||||
VarNode* neigh = getVarNode (vid);
|
||||
if (!neigh) {
|
||||
cerr << "error: invalid variable identifier (" << vid << ")" << endl;
|
||||
abort();
|
||||
}
|
||||
neighs.push_back (neigh);
|
||||
}
|
||||
FgFacNode* fn = new FgFacNode (new Factor (neighs));
|
||||
FactorNode* fn = new FactorNode (new Factor (neighs));
|
||||
addFactor (fn);
|
||||
for (unsigned j = 0; j < neighs.size(); j++) {
|
||||
addEdge (fn, static_cast<FgVarNode*> (neighs[j]));
|
||||
addEdge (fn, static_cast<VarNode*> (neighs[j]));
|
||||
}
|
||||
}
|
||||
|
||||
@ -162,15 +162,15 @@ FactorGraph::readFromLibDaiFormat (const char* fileName)
|
||||
vids.push_back (vid);
|
||||
}
|
||||
|
||||
VarNodes neighs;
|
||||
Vars neighs;
|
||||
unsigned nParams = 1;
|
||||
for (unsigned j = 0; j < nVars; j++) {
|
||||
unsigned dsize;
|
||||
while ((is.peek()) == '#') getline (is, line);
|
||||
is >> dsize;
|
||||
FgVarNode* var = getFgVarNode (vids[j]);
|
||||
VarNode* var = getVarNode (vids[j]);
|
||||
if (var == 0) {
|
||||
var = new FgVarNode (vids[j], dsize);
|
||||
var = new VarNode (vids[j], dsize);
|
||||
addVariable (var);
|
||||
} else {
|
||||
if (var->range() != dsize) {
|
||||
@ -199,10 +199,10 @@ FactorGraph::readFromLibDaiFormat (const char* fileName)
|
||||
if (Globals::logDomain) {
|
||||
Util::toLog (params);
|
||||
}
|
||||
FgFacNode* fn = new FgFacNode (new Factor (neighs, params));
|
||||
FactorNode* fn = new FactorNode (new Factor (neighs, params));
|
||||
addFactor (fn);
|
||||
for (unsigned j = 0; j < neighs.size(); j++) {
|
||||
addEdge (fn, static_cast<FgVarNode*> (neighs[j]));
|
||||
addEdge (fn, static_cast<VarNode*> (neighs[j]));
|
||||
}
|
||||
}
|
||||
is.close();
|
||||
@ -224,7 +224,7 @@ FactorGraph::~FactorGraph (void)
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::addVariable (FgVarNode* vn)
|
||||
FactorGraph::addVariable (VarNode* vn)
|
||||
{
|
||||
varNodes_.push_back (vn);
|
||||
vn->setIndex (varNodes_.size() - 1);
|
||||
@ -234,7 +234,7 @@ FactorGraph::addVariable (FgVarNode* vn)
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::addFactor (FgFacNode* fn)
|
||||
FactorGraph::addFactor (FactorNode* fn)
|
||||
{
|
||||
facNodes_.push_back (fn);
|
||||
fn->setIndex (facNodes_.size() - 1);
|
||||
@ -245,7 +245,7 @@ FactorGraph::addFactor (FgFacNode* fn)
|
||||
void
|
||||
FactorGraph::addFactor (const Factor& factor)
|
||||
{
|
||||
FgFacNode* fn = new FgFacNode (factor);
|
||||
FactorNode* fn = new FactorNode (factor);
|
||||
addFactor (fn);
|
||||
const VarIds& vids = factor.arguments();
|
||||
for (unsigned i = 0; i < vids.size(); i++) {
|
||||
@ -257,7 +257,7 @@ FactorGraph::addFactor (const Factor& factor)
|
||||
}
|
||||
}
|
||||
if (found == false) {
|
||||
FgVarNode* vn = new FgVarNode (vids[i], factor.range (i));
|
||||
VarNode* vn = new VarNode (vids[i], factor.range (i));
|
||||
addVariable (vn);
|
||||
addEdge (vn, fn);
|
||||
}
|
||||
@ -267,7 +267,7 @@ FactorGraph::addFactor (const Factor& factor)
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::addEdge (FgVarNode* vn, FgFacNode* fn)
|
||||
FactorGraph::addEdge (VarNode* vn, FactorNode* fn)
|
||||
{
|
||||
vn->addNeighbor (fn);
|
||||
fn->addNeighbor (vn);
|
||||
@ -276,7 +276,7 @@ FactorGraph::addEdge (FgVarNode* vn, FgFacNode* fn)
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::addEdge (FgFacNode* fn, FgVarNode* vn)
|
||||
FactorGraph::addEdge (FactorNode* fn, VarNode* vn)
|
||||
{
|
||||
fn->addNeighbor (vn);
|
||||
vn->addNeighbor (fn);
|
||||
@ -284,28 +284,6 @@ FactorGraph::addEdge (FgFacNode* fn, FgVarNode* vn)
|
||||
|
||||
|
||||
|
||||
VarNode*
|
||||
FactorGraph::getVariableNode (VarId vid) const
|
||||
{
|
||||
FgVarNode* vn = getFgVarNode (vid);
|
||||
assert (vn);
|
||||
return vn;
|
||||
}
|
||||
|
||||
|
||||
|
||||
VarNodes
|
||||
FactorGraph::getVariableNodes (void) const
|
||||
{
|
||||
VarNodes vars;
|
||||
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
||||
vars.push_back (varNodes_[i]);
|
||||
}
|
||||
return vars;
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
FactorGraph::isTree (void) const
|
||||
{
|
||||
@ -348,7 +326,7 @@ FactorGraph::setIndexes (void)
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::printGraphicalModel (void) const
|
||||
FactorGraph::print (void) const
|
||||
{
|
||||
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
||||
cout << "var id = " << varNodes_[i]->varId() << endl;
|
||||
@ -390,7 +368,7 @@ FactorGraph::exportToGraphViz (const char* fileName) const
|
||||
out << "\"" << ", shape=box]" << endl;
|
||||
}
|
||||
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
||||
const FgVarSet& myVars = facNodes_[i]->neighbors();
|
||||
const VarNodes& myVars = facNodes_[i]->neighbors();
|
||||
for (unsigned j = 0; j < myVars.size(); j++) {
|
||||
out << '"' << facNodes_[i]->getLabel() << '"' ;
|
||||
out << " -- " ;
|
||||
@ -422,7 +400,7 @@ FactorGraph::exportToUaiFormat (const char* fileName) const
|
||||
|
||||
out << facNodes_.size() << endl;
|
||||
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
||||
const FgVarSet& factorVars = facNodes_[i]->neighbors();
|
||||
const VarNodes& factorVars = facNodes_[i]->neighbors();
|
||||
out << factorVars.size();
|
||||
for (unsigned j = 0; j < factorVars.size(); j++) {
|
||||
out << " " << factorVars[j]->getIndex();
|
||||
@ -458,7 +436,7 @@ FactorGraph::exportToLibDaiFormat (const char* fileName) const
|
||||
}
|
||||
out << facNodes_.size() << endl << endl;
|
||||
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
||||
const FgVarSet& factorVars = facNodes_[i]->neighbors();
|
||||
const VarNodes& factorVars = facNodes_[i]->neighbors();
|
||||
out << factorVars.size() << endl;
|
||||
for (int j = factorVars.size() - 1; j >= 0; j--) {
|
||||
out << factorVars[j]->varId() << " " ;
|
||||
@ -503,13 +481,13 @@ FactorGraph::containsCycle (void) const
|
||||
|
||||
bool
|
||||
FactorGraph::containsCycle (
|
||||
const FgVarNode* v,
|
||||
const FgFacNode* p,
|
||||
const VarNode* v,
|
||||
const FactorNode* p,
|
||||
vector<bool>& visitedVars,
|
||||
vector<bool>& visitedFactors) const
|
||||
{
|
||||
visitedVars[v->getIndex()] = true;
|
||||
const FgFacSet& adjacencies = v->neighbors();
|
||||
const FactorNodes& adjacencies = v->neighbors();
|
||||
for (unsigned i = 0; i < adjacencies.size(); i++) {
|
||||
int w = adjacencies[i]->getIndex();
|
||||
if (!visitedFactors[w]) {
|
||||
@ -528,13 +506,13 @@ FactorGraph::containsCycle (
|
||||
|
||||
bool
|
||||
FactorGraph::containsCycle (
|
||||
const FgFacNode* v,
|
||||
const FgVarNode* p,
|
||||
const FactorNode* v,
|
||||
const VarNode* p,
|
||||
vector<bool>& visitedVars,
|
||||
vector<bool>& visitedFactors) const
|
||||
{
|
||||
visitedFactors[v->getIndex()] = true;
|
||||
const FgVarSet& adjacencies = v->neighbors();
|
||||
const VarNodes& adjacencies = v->neighbors();
|
||||
for (unsigned i = 0; i < adjacencies.size(); i++) {
|
||||
int w = adjacencies[i]->getIndex();
|
||||
if (!visitedVars[w]) {
|
||||
|
@ -4,52 +4,51 @@
|
||||
#include <vector>
|
||||
|
||||
#include "Factor.h"
|
||||
#include "GraphicalModel.h"
|
||||
#include "BayesNet.h"
|
||||
#include "Horus.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
|
||||
class FgFacNode;
|
||||
class FactorNode;
|
||||
|
||||
|
||||
class FgVarNode : public VarNode
|
||||
class VarNode : public Var
|
||||
{
|
||||
public:
|
||||
FgVarNode (VarId varId, unsigned nrStates) : VarNode (varId, nrStates) { }
|
||||
VarNode (VarId varId, unsigned nrStates) : Var (varId, nrStates) { }
|
||||
|
||||
FgVarNode (const VarNode* v) : VarNode (v) { }
|
||||
VarNode (const Var* v) : Var (v) { }
|
||||
|
||||
void addNeighbor (FgFacNode* fn) { neighs_.push_back (fn); }
|
||||
void addNeighbor (FactorNode* fn) { neighs_.push_back (fn); }
|
||||
|
||||
const FgFacSet& neighbors (void) const { return neighs_; }
|
||||
const FactorNodes& neighbors (void) const { return neighs_; }
|
||||
|
||||
private:
|
||||
DISALLOW_COPY_AND_ASSIGN (FgVarNode);
|
||||
DISALLOW_COPY_AND_ASSIGN (VarNode);
|
||||
|
||||
FgFacSet neighs_;
|
||||
FactorNodes neighs_;
|
||||
};
|
||||
|
||||
|
||||
class FgFacNode
|
||||
class FactorNode
|
||||
{
|
||||
public:
|
||||
FgFacNode (const FgFacNode* fn)
|
||||
FactorNode (const FactorNode* fn)
|
||||
{
|
||||
factor_ = new Factor (*fn->factor());
|
||||
index_ = -1;
|
||||
}
|
||||
|
||||
FgFacNode (Factor* f) : factor_(new Factor(*f)), index_(-1) { }
|
||||
FactorNode (Factor* f) : factor_(new Factor(*f)), index_(-1) { }
|
||||
|
||||
FgFacNode (const Factor& f) : factor_(new Factor (f)), index_(-1) { }
|
||||
FactorNode (const Factor& f) : factor_(new Factor (f)), index_(-1) { }
|
||||
|
||||
Factor* factor() const { return factor_; }
|
||||
|
||||
void addNeighbor (FgVarNode* vn) { neighs_.push_back (vn); }
|
||||
void addNeighbor (VarNode* vn) { neighs_.push_back (vn); }
|
||||
|
||||
const FgVarSet& neighbors (void) const { return neighs_; }
|
||||
const VarNodes& neighbors (void) const { return neighs_; }
|
||||
|
||||
int getIndex (void) const
|
||||
{
|
||||
@ -73,24 +72,24 @@ class FgFacNode
|
||||
}
|
||||
|
||||
private:
|
||||
DISALLOW_COPY_AND_ASSIGN (FgFacNode);
|
||||
DISALLOW_COPY_AND_ASSIGN (FactorNode);
|
||||
|
||||
Factor* factor_;
|
||||
FgVarSet neighs_;
|
||||
VarNodes neighs_;
|
||||
int index_;
|
||||
};
|
||||
|
||||
|
||||
struct CompVarId
|
||||
{
|
||||
bool operator() (const VarNode* vn1, const VarNode* vn2) const
|
||||
bool operator() (const Var* v1, const Var* v2) const
|
||||
{
|
||||
return vn1->varId() < vn2->varId();
|
||||
return v1->varId() < v2->varId();
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
class FactorGraph : public GraphicalModel
|
||||
class FactorGraph
|
||||
{
|
||||
public:
|
||||
FactorGraph (void) { }
|
||||
@ -99,15 +98,15 @@ class FactorGraph : public GraphicalModel
|
||||
|
||||
~FactorGraph (void);
|
||||
|
||||
const FgVarSet& getVarNodes (void) const { return varNodes_; }
|
||||
const VarNodes& varNodes (void) const { return varNodes_; }
|
||||
|
||||
const FgFacSet& getFactorNodes (void) const { return facNodes_; }
|
||||
const FactorNodes& factorNodes (void) const { return facNodes_; }
|
||||
|
||||
void setFromBayesNetwork (void) { fromBayesNet_ = true; }
|
||||
|
||||
bool isFromBayesNetwork (void) const { return fromBayesNet_ ; }
|
||||
|
||||
FgVarNode* getFgVarNode (VarId vid) const
|
||||
VarNode* getVarNode (VarId vid) const
|
||||
{
|
||||
IndexMap::const_iterator it = varMap_.find (vid);
|
||||
return (it != varMap_.end()) ? varNodes_[it->second] : 0;
|
||||
@ -117,19 +116,15 @@ class FactorGraph : public GraphicalModel
|
||||
|
||||
void readFromLibDaiFormat (const char*);
|
||||
|
||||
void addVariable (FgVarNode*);
|
||||
void addVariable (VarNode*);
|
||||
|
||||
void addFactor (FgFacNode*);
|
||||
void addFactor (FactorNode*);
|
||||
|
||||
void addFactor (const Factor& factor);
|
||||
|
||||
void addEdge (FgVarNode*, FgFacNode*);
|
||||
void addEdge (VarNode*, FactorNode*);
|
||||
|
||||
void addEdge (FgFacNode*, FgVarNode*);
|
||||
|
||||
VarNode* getVariableNode (unsigned) const;
|
||||
|
||||
VarNodes getVariableNodes (void) const;
|
||||
void addEdge (FactorNode*, VarNode*);
|
||||
|
||||
bool isTree (void) const;
|
||||
|
||||
@ -137,7 +132,7 @@ class FactorGraph : public GraphicalModel
|
||||
|
||||
void setIndexes (void);
|
||||
|
||||
void printGraphicalModel (void) const;
|
||||
void print (void) const;
|
||||
|
||||
void exportToGraphViz (const char*) const;
|
||||
|
||||
@ -152,14 +147,14 @@ class FactorGraph : public GraphicalModel
|
||||
|
||||
bool containsCycle (void) const;
|
||||
|
||||
bool containsCycle (const FgVarNode*, const FgFacNode*,
|
||||
bool containsCycle (const VarNode*, const FactorNode*,
|
||||
vector<bool>&, vector<bool>&) const;
|
||||
|
||||
bool containsCycle (const FgFacNode*, const FgVarNode*,
|
||||
bool containsCycle (const FactorNode*, const VarNode*,
|
||||
vector<bool>&, vector<bool>&) const;
|
||||
|
||||
FgVarSet varNodes_;
|
||||
FgFacSet facNodes_;
|
||||
VarNodes varNodes_;
|
||||
FactorNodes facNodes_;
|
||||
|
||||
bool fromBayesNet_;
|
||||
DAGraph structure_;
|
||||
|
@ -1,496 +0,0 @@
|
||||
#include <cassert>
|
||||
#include <limits>
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "FgBpSolver.h"
|
||||
#include "FactorGraph.h"
|
||||
#include "Factor.h"
|
||||
#include "Indexer.h"
|
||||
#include "Horus.h"
|
||||
|
||||
|
||||
FgBpSolver::FgBpSolver (const FactorGraph& fg) : Solver (&fg)
|
||||
{
|
||||
factorGraph_ = &fg;
|
||||
}
|
||||
|
||||
|
||||
|
||||
FgBpSolver::~FgBpSolver (void)
|
||||
{
|
||||
for (unsigned i = 0; i < varsI_.size(); i++) {
|
||||
delete varsI_[i];
|
||||
}
|
||||
for (unsigned i = 0; i < facsI_.size(); i++) {
|
||||
delete facsI_[i];
|
||||
}
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
delete links_[i];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FgBpSolver::runSolver (void)
|
||||
{
|
||||
clock_t start;
|
||||
if (Constants::COLLECT_STATS) {
|
||||
start = clock();
|
||||
}
|
||||
runLoopySolver();
|
||||
if (Constants::DEBUG >= 2) {
|
||||
cout << endl;
|
||||
if (nIters_ < BpOptions::maxIter) {
|
||||
cout << "Sum-Product converged in " ;
|
||||
cout << nIters_ << " iterations" << endl;
|
||||
} else {
|
||||
cout << "The maximum number of iterations was hit, terminating..." ;
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
unsigned size = factorGraph_->getVarNodes().size();
|
||||
if (Constants::COLLECT_STATS) {
|
||||
unsigned nIters = 0;
|
||||
bool loopy = factorGraph_->isTree() == false;
|
||||
if (loopy) nIters = nIters_;
|
||||
double time = (double (clock() - start)) / CLOCKS_PER_SEC;
|
||||
Statistics::updateStatistics (size, loopy, nIters, time);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
Params
|
||||
FgBpSolver::getPosterioriOf (VarId vid)
|
||||
{
|
||||
assert (factorGraph_->getFgVarNode (vid));
|
||||
FgVarNode* var = factorGraph_->getFgVarNode (vid);
|
||||
Params probs;
|
||||
if (var->hasEvidence()) {
|
||||
probs.resize (var->range(), LogAware::noEvidence());
|
||||
probs[var->getEvidence()] = LogAware::withEvidence();
|
||||
} else {
|
||||
probs.resize (var->range(), LogAware::multIdenty());
|
||||
const SpLinkSet& links = ninf(var)->getLinks();
|
||||
if (Globals::logDomain) {
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
Util::add (probs, links[i]->getMessage());
|
||||
}
|
||||
LogAware::normalize (probs);
|
||||
Util::fromLog (probs);
|
||||
} else {
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
Util::multiply (probs, links[i]->getMessage());
|
||||
}
|
||||
LogAware::normalize (probs);
|
||||
}
|
||||
}
|
||||
return probs;
|
||||
}
|
||||
|
||||
|
||||
|
||||
Params
|
||||
FgBpSolver::getJointDistributionOf (const VarIds& jointVarIds)
|
||||
{
|
||||
int idx = -1;
|
||||
FgVarNode* vn = factorGraph_->getFgVarNode (jointVarIds[0]);
|
||||
const FgFacSet& factorNodes = vn->neighbors();
|
||||
for (unsigned i = 0; i < factorNodes.size(); i++) {
|
||||
if (factorNodes[i]->factor()->contains (jointVarIds)) {
|
||||
idx = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (idx == -1) {
|
||||
return getJointByConditioning (jointVarIds);
|
||||
} else {
|
||||
Factor res (*factorNodes[idx]->factor());
|
||||
const SpLinkSet& links = ninf(factorNodes[idx])->getLinks();
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
Factor msg (links[i]->getVariable()->varId(),
|
||||
links[i]->getVariable()->range(),
|
||||
getVar2FactorMsg (links[i]));
|
||||
res.multiply (msg);
|
||||
}
|
||||
res.sumOutAllExcept (jointVarIds);
|
||||
res.reorderArguments (jointVarIds);
|
||||
res.normalize();
|
||||
Params jointDist = res.params();
|
||||
if (Globals::logDomain) {
|
||||
Util::fromLog (jointDist);
|
||||
}
|
||||
return jointDist;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FgBpSolver::runLoopySolver (void)
|
||||
{
|
||||
initializeSolver();
|
||||
nIters_ = 0;
|
||||
|
||||
while (!converged() && nIters_ < BpOptions::maxIter) {
|
||||
|
||||
nIters_ ++;
|
||||
if (Constants::DEBUG >= 2) {
|
||||
Util::printHeader (" Iteration " + nIters_);
|
||||
cout << endl;
|
||||
}
|
||||
|
||||
switch (BpOptions::schedule) {
|
||||
case BpOptions::Schedule::SEQ_RANDOM:
|
||||
random_shuffle (links_.begin(), links_.end());
|
||||
// no break
|
||||
|
||||
case BpOptions::Schedule::SEQ_FIXED:
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
calculateAndUpdateMessage (links_[i]);
|
||||
}
|
||||
break;
|
||||
|
||||
case BpOptions::Schedule::PARALLEL:
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
calculateMessage (links_[i]);
|
||||
}
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
updateMessage(links_[i]);
|
||||
}
|
||||
break;
|
||||
|
||||
case BpOptions::Schedule::MAX_RESIDUAL:
|
||||
maxResidualSchedule();
|
||||
break;
|
||||
}
|
||||
if (Constants::DEBUG >= 2) {
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FgBpSolver::initializeSolver (void)
|
||||
{
|
||||
const FgVarSet& varNodes = factorGraph_->getVarNodes();
|
||||
for (unsigned i = 0; i < varsI_.size(); i++) {
|
||||
delete varsI_[i];
|
||||
}
|
||||
varsI_.reserve (varNodes.size());
|
||||
for (unsigned i = 0; i < varNodes.size(); i++) {
|
||||
varsI_.push_back (new SPNodeInfo());
|
||||
}
|
||||
|
||||
const FgFacSet& facNodes = factorGraph_->getFactorNodes();
|
||||
for (unsigned i = 0; i < facsI_.size(); i++) {
|
||||
delete facsI_[i];
|
||||
}
|
||||
facsI_.reserve (facNodes.size());
|
||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||
facsI_.push_back (new SPNodeInfo());
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
delete links_[i];
|
||||
}
|
||||
createLinks();
|
||||
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
FgFacNode* src = links_[i]->getFactor();
|
||||
FgVarNode* dst = links_[i]->getVariable();
|
||||
ninf (dst)->addSpLink (links_[i]);
|
||||
ninf (src)->addSpLink (links_[i]);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FgBpSolver::createLinks (void)
|
||||
{
|
||||
const FgFacSet& facNodes = factorGraph_->getFactorNodes();
|
||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||
const FgVarSet& neighbors = facNodes[i]->neighbors();
|
||||
for (unsigned j = 0; j < neighbors.size(); j++) {
|
||||
links_.push_back (new SpLink (facNodes[i], neighbors[j]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
FgBpSolver::converged (void)
|
||||
{
|
||||
if (links_.size() == 0) {
|
||||
return true;
|
||||
}
|
||||
if (nIters_ == 0 || nIters_ == 1) {
|
||||
return false;
|
||||
}
|
||||
bool converged = true;
|
||||
if (BpOptions::schedule == BpOptions::Schedule::MAX_RESIDUAL) {
|
||||
double maxResidual = (*(sortedOrder_.begin()))->getResidual();
|
||||
if (maxResidual > BpOptions::accuracy) {
|
||||
converged = false;
|
||||
} else {
|
||||
converged = true;
|
||||
}
|
||||
} else {
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
double residual = links_[i]->getResidual();
|
||||
if (Constants::DEBUG >= 2) {
|
||||
cout << links_[i]->toString() + " residual = " << residual << endl;
|
||||
}
|
||||
if (residual > BpOptions::accuracy) {
|
||||
converged = false;
|
||||
if (Constants::DEBUG == 0) break;
|
||||
}
|
||||
}
|
||||
}
|
||||
return converged;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FgBpSolver::maxResidualSchedule (void)
|
||||
{
|
||||
if (nIters_ == 1) {
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
calculateMessage (links_[i]);
|
||||
SortedOrder::iterator it = sortedOrder_.insert (links_[i]);
|
||||
linkMap_.insert (make_pair (links_[i], it));
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
for (unsigned c = 0; c < links_.size(); c++) {
|
||||
if (Constants::DEBUG >= 2) {
|
||||
cout << "current residuals:" << endl;
|
||||
for (SortedOrder::iterator it = sortedOrder_.begin();
|
||||
it != sortedOrder_.end(); it ++) {
|
||||
cout << " " << setw (30) << left << (*it)->toString();
|
||||
cout << "residual = " << (*it)->getResidual() << endl;
|
||||
}
|
||||
}
|
||||
|
||||
SortedOrder::iterator it = sortedOrder_.begin();
|
||||
SpLink* link = *it;
|
||||
if (link->getResidual() < BpOptions::accuracy) {
|
||||
return;
|
||||
}
|
||||
updateMessage (link);
|
||||
link->clearResidual();
|
||||
sortedOrder_.erase (it);
|
||||
linkMap_.find (link)->second = sortedOrder_.insert (link);
|
||||
|
||||
// update the messages that depend on message source --> destin
|
||||
const FgFacSet& factorNeighbors = link->getVariable()->neighbors();
|
||||
for (unsigned i = 0; i < factorNeighbors.size(); i++) {
|
||||
if (factorNeighbors[i] != link->getFactor()) {
|
||||
const SpLinkSet& links = ninf(factorNeighbors[i])->getLinks();
|
||||
for (unsigned j = 0; j < links.size(); j++) {
|
||||
if (links[j]->getVariable() != link->getVariable()) {
|
||||
calculateMessage (links[j]);
|
||||
SpLinkMap::iterator iter = linkMap_.find (links[j]);
|
||||
sortedOrder_.erase (iter->second);
|
||||
iter->second = sortedOrder_.insert (links[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (Constants::DEBUG >= 2) {
|
||||
Util::printDashedLine();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FgBpSolver::calculateFactor2VariableMsg (SpLink* link) const
|
||||
{
|
||||
const FgFacNode* src = link->getFactor();
|
||||
const FgVarNode* dst = link->getVariable();
|
||||
const SpLinkSet& links = ninf(src)->getLinks();
|
||||
// calculate the product of messages that were sent
|
||||
// to factor `src', except from var `dst'
|
||||
unsigned msgSize = 1;
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
msgSize *= links[i]->getVariable()->range();
|
||||
}
|
||||
unsigned repetitions = 1;
|
||||
Params msgProduct (msgSize, LogAware::multIdenty());
|
||||
if (Globals::logDomain) {
|
||||
for (int i = links.size() - 1; i >= 0; i--) {
|
||||
if (links[i]->getVariable() != dst) {
|
||||
Util::add (msgProduct, getVar2FactorMsg (links[i]), repetitions);
|
||||
repetitions *= links[i]->getVariable()->range();
|
||||
} else {
|
||||
unsigned ds = links[i]->getVariable()->range();
|
||||
Util::add (msgProduct, Params (ds, 1.0), repetitions);
|
||||
repetitions *= ds;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int i = links.size() - 1; i >= 0; i--) {
|
||||
if (links[i]->getVariable() != dst) {
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << " message from " << links[i]->getVariable()->label();
|
||||
cout << ": " << endl;
|
||||
}
|
||||
Util::multiply (msgProduct, getVar2FactorMsg (links[i]), repetitions);
|
||||
repetitions *= links[i]->getVariable()->range();
|
||||
} else {
|
||||
unsigned ds = links[i]->getVariable()->range();
|
||||
Util::multiply (msgProduct, Params (ds, 1.0), repetitions);
|
||||
repetitions *= ds;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Factor result (src->factor()->arguments(),
|
||||
src->factor()->ranges(),
|
||||
msgProduct);
|
||||
result.multiply (*(src->factor()));
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << " message product: " << msgProduct << endl;
|
||||
cout << " original factor: " << src->params() << endl;
|
||||
cout << " factor product: " << result.params() << endl;
|
||||
}
|
||||
result.sumOutAllExcept (dst->varId());
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << " marginalized: " ;
|
||||
cout << result.params() << endl;
|
||||
}
|
||||
const Params& resultParams = result.params();
|
||||
Params& message = link->getNextMessage();
|
||||
for (unsigned i = 0; i < resultParams.size(); i++) {
|
||||
message[i] = resultParams[i];
|
||||
}
|
||||
LogAware::normalize (message);
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << " curr msg: " << link->getMessage() << endl;
|
||||
cout << " next msg: " << message << endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
Params
|
||||
FgBpSolver::getVar2FactorMsg (const SpLink* link) const
|
||||
{
|
||||
const FgVarNode* src = link->getVariable();
|
||||
const FgFacNode* dst = link->getFactor();
|
||||
Params msg;
|
||||
if (src->hasEvidence()) {
|
||||
msg.resize (src->range(), LogAware::noEvidence());
|
||||
msg[src->getEvidence()] = LogAware::withEvidence();
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << msg;
|
||||
}
|
||||
} else {
|
||||
msg.resize (src->range(), LogAware::one());
|
||||
}
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << msg;
|
||||
}
|
||||
const SpLinkSet& links = ninf (src)->getLinks();
|
||||
if (Globals::logDomain) {
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
if (links[i]->getFactor() != dst) {
|
||||
Util::add (msg, links[i]->getMessage());
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
if (links[i]->getFactor() != dst) {
|
||||
Util::multiply (msg, links[i]->getMessage());
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << " x " << links[i]->getMessage();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << " = " << msg;
|
||||
}
|
||||
return msg;
|
||||
}
|
||||
|
||||
|
||||
|
||||
Params
|
||||
FgBpSolver::getJointByConditioning (const VarIds& jointVarIds) const
|
||||
{
|
||||
FgVarSet jointVars;
|
||||
for (unsigned i = 0; i < jointVarIds.size(); i++) {
|
||||
assert (factorGraph_->getFgVarNode (jointVarIds[i]));
|
||||
jointVars.push_back (factorGraph_->getFgVarNode (jointVarIds[i]));
|
||||
}
|
||||
|
||||
FactorGraph* fg = new FactorGraph (*factorGraph_);
|
||||
FgBpSolver solver (*fg);
|
||||
solver.runSolver();
|
||||
Params prevBeliefs = solver.getPosterioriOf (jointVarIds[0]);
|
||||
|
||||
VarIds observedVids = {jointVars[0]->varId()};
|
||||
|
||||
for (unsigned i = 1; i < jointVarIds.size(); i++) {
|
||||
assert (jointVars[i]->hasEvidence() == false);
|
||||
Params newBeliefs;
|
||||
VarNodes observedVars;
|
||||
for (unsigned j = 0; j < observedVids.size(); j++) {
|
||||
observedVars.push_back (fg->getFgVarNode (observedVids[j]));
|
||||
}
|
||||
StatesIndexer idx (observedVars, false);
|
||||
while (idx.valid()) {
|
||||
for (unsigned j = 0; j < observedVars.size(); j++) {
|
||||
observedVars[j]->setEvidence (idx[j]);
|
||||
}
|
||||
++ idx;
|
||||
FgBpSolver solver (*fg);
|
||||
solver.runSolver();
|
||||
Params beliefs = solver.getPosterioriOf (jointVarIds[i]);
|
||||
for (unsigned k = 0; k < beliefs.size(); k++) {
|
||||
newBeliefs.push_back (beliefs[k]);
|
||||
}
|
||||
}
|
||||
|
||||
int count = -1;
|
||||
for (unsigned j = 0; j < newBeliefs.size(); j++) {
|
||||
if (j % jointVars[i]->range() == 0) {
|
||||
count ++;
|
||||
}
|
||||
newBeliefs[j] *= prevBeliefs[count];
|
||||
}
|
||||
prevBeliefs = newBeliefs;
|
||||
observedVids.push_back (jointVars[i]->varId());
|
||||
}
|
||||
return prevBeliefs;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FgBpSolver::printLinkInformation (void) const
|
||||
{
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
SpLink* l = links_[i];
|
||||
cout << l->toString() << ":" << endl;
|
||||
cout << " curr msg = " ;
|
||||
cout << l->getMessage() << endl;
|
||||
cout << " next msg = " ;
|
||||
cout << l->getNextMessage() << endl;
|
||||
cout << " residual = " << l->getResidual() << endl;
|
||||
}
|
||||
}
|
||||
|
@ -1,186 +0,0 @@
|
||||
#ifndef HORUS_FGBPSOLVER_H
|
||||
#define HORUS_FGBPSOLVER_H
|
||||
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include <sstream>
|
||||
|
||||
#include "Solver.h"
|
||||
#include "Factor.h"
|
||||
#include "FactorGraph.h"
|
||||
#include "Util.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
|
||||
class SpLink
|
||||
{
|
||||
public:
|
||||
SpLink (FgFacNode* fn, FgVarNode* vn)
|
||||
{
|
||||
fac_ = fn;
|
||||
var_ = vn;
|
||||
v1_.resize (vn->range(), LogAware::tl (1.0 / vn->range()));
|
||||
v2_.resize (vn->range(), LogAware::tl (1.0 / vn->range()));
|
||||
currMsg_ = &v1_;
|
||||
nextMsg_ = &v2_;
|
||||
msgSended_ = false;
|
||||
residual_ = 0.0;
|
||||
}
|
||||
|
||||
virtual ~SpLink (void) { };
|
||||
|
||||
FgFacNode* getFactor (void) const { return fac_; }
|
||||
|
||||
FgVarNode* getVariable (void) const { return var_; }
|
||||
|
||||
const Params& getMessage (void) const { return *currMsg_; }
|
||||
|
||||
Params& getNextMessage (void) { return *nextMsg_; }
|
||||
|
||||
bool messageWasSended (void) const { return msgSended_; }
|
||||
|
||||
double getResidual (void) const { return residual_; }
|
||||
|
||||
void clearResidual (void) { residual_ = 0.0; }
|
||||
|
||||
void updateResidual (void)
|
||||
{
|
||||
residual_ = LogAware::getMaxNorm (v1_,v2_);
|
||||
}
|
||||
|
||||
virtual void updateMessage (void)
|
||||
{
|
||||
swap (currMsg_, nextMsg_);
|
||||
msgSended_ = true;
|
||||
}
|
||||
|
||||
string toString (void) const
|
||||
{
|
||||
stringstream ss;
|
||||
ss << fac_->getLabel();
|
||||
ss << " -- " ;
|
||||
ss << var_->label();
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
protected:
|
||||
FgFacNode* fac_;
|
||||
FgVarNode* var_;
|
||||
Params v1_;
|
||||
Params v2_;
|
||||
Params* currMsg_;
|
||||
Params* nextMsg_;
|
||||
bool msgSended_;
|
||||
double residual_;
|
||||
};
|
||||
|
||||
typedef vector<SpLink*> SpLinkSet;
|
||||
|
||||
|
||||
class SPNodeInfo
|
||||
{
|
||||
public:
|
||||
void addSpLink (SpLink* link) { links_.push_back (link); }
|
||||
const SpLinkSet& getLinks (void) { return links_; }
|
||||
private:
|
||||
SpLinkSet links_;
|
||||
};
|
||||
|
||||
|
||||
class FgBpSolver : public Solver
|
||||
{
|
||||
public:
|
||||
FgBpSolver (const FactorGraph&);
|
||||
|
||||
virtual ~FgBpSolver (void);
|
||||
|
||||
void runSolver (void);
|
||||
|
||||
virtual Params getPosterioriOf (VarId);
|
||||
|
||||
virtual Params getJointDistributionOf (const VarIds&);
|
||||
|
||||
protected:
|
||||
virtual void initializeSolver (void);
|
||||
|
||||
virtual void createLinks (void);
|
||||
|
||||
virtual void maxResidualSchedule (void);
|
||||
|
||||
virtual void calculateFactor2VariableMsg (SpLink*) const;
|
||||
|
||||
virtual Params getVar2FactorMsg (const SpLink*) const;
|
||||
|
||||
virtual Params getJointByConditioning (const VarIds&) const;
|
||||
|
||||
virtual void printLinkInformation (void) const;
|
||||
|
||||
SPNodeInfo* ninf (const FgVarNode* var) const
|
||||
{
|
||||
return varsI_[var->getIndex()];
|
||||
}
|
||||
|
||||
SPNodeInfo* ninf (const FgFacNode* fac) const
|
||||
{
|
||||
return facsI_[fac->getIndex()];
|
||||
}
|
||||
|
||||
void calculateAndUpdateMessage (SpLink* link, bool calcResidual = true)
|
||||
{
|
||||
if (Constants::DEBUG >= 3) {
|
||||
cout << "calculating & updating " << link->toString() << endl;
|
||||
}
|
||||
calculateFactor2VariableMsg (link);
|
||||
if (calcResidual) {
|
||||
link->updateResidual();
|
||||
}
|
||||
link->updateMessage();
|
||||
}
|
||||
|
||||
void calculateMessage (SpLink* link, bool calcResidual = true)
|
||||
{
|
||||
if (Constants::DEBUG >= 3) {
|
||||
cout << "calculating " << link->toString() << endl;
|
||||
}
|
||||
calculateFactor2VariableMsg (link);
|
||||
if (calcResidual) {
|
||||
link->updateResidual();
|
||||
}
|
||||
}
|
||||
|
||||
void updateMessage (SpLink* link)
|
||||
{
|
||||
link->updateMessage();
|
||||
if (Constants::DEBUG >= 3) {
|
||||
cout << "updating " << link->toString() << endl;
|
||||
}
|
||||
}
|
||||
|
||||
struct CompareResidual
|
||||
{
|
||||
inline bool operator() (const SpLink* link1, const SpLink* link2)
|
||||
{
|
||||
return link1->getResidual() > link2->getResidual();
|
||||
}
|
||||
};
|
||||
|
||||
SpLinkSet links_;
|
||||
unsigned nIters_;
|
||||
vector<SPNodeInfo*> varsI_;
|
||||
vector<SPNodeInfo*> facsI_;
|
||||
const FactorGraph* factorGraph_;
|
||||
|
||||
typedef multiset<SpLink*, CompareResidual> SortedOrder;
|
||||
SortedOrder sortedOrder_;
|
||||
|
||||
typedef unordered_map<SpLink*, SortedOrder::iterator> SpLinkMap;
|
||||
SpLinkMap linkMap_;
|
||||
|
||||
private:
|
||||
void runLoopySolver (void);
|
||||
bool converged (void);
|
||||
};
|
||||
|
||||
#endif // HORUS_FGBPSOLVER_H
|
||||
|
@ -1,64 +0,0 @@
|
||||
#ifndef HORUS_GRAPHICALMODEL_H
|
||||
#define HORUS_GRAPHICALMODEL_H
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include <unordered_map>
|
||||
|
||||
#include <sstream>
|
||||
|
||||
#include "VarNode.h"
|
||||
#include "Util.h"
|
||||
#include "Horus.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
|
||||
struct VarInfo
|
||||
{
|
||||
VarInfo (string l, const States& sts) : label(l), states(sts) { }
|
||||
string label;
|
||||
States states;
|
||||
};
|
||||
|
||||
|
||||
class GraphicalModel
|
||||
{
|
||||
public:
|
||||
virtual ~GraphicalModel (void) { };
|
||||
|
||||
virtual VarNode* getVariableNode (VarId) const = 0;
|
||||
|
||||
virtual VarNodes getVariableNodes (void) const = 0;
|
||||
|
||||
virtual void printGraphicalModel (void) const = 0;
|
||||
|
||||
static void addVariableInformation (
|
||||
VarId vid, string label, const States& states)
|
||||
{
|
||||
assert (Util::contains (varsInfo_, vid) == false);
|
||||
varsInfo_.insert (make_pair (vid, VarInfo (label, states)));
|
||||
}
|
||||
|
||||
static VarInfo getVarInformation (VarId vid)
|
||||
{
|
||||
assert (Util::contains (varsInfo_, vid));
|
||||
return varsInfo_.find (vid)->second;
|
||||
}
|
||||
|
||||
static bool variablesHaveInformation (void)
|
||||
{
|
||||
return varsInfo_.size() != 0;
|
||||
}
|
||||
|
||||
static void clearVariablesInformation (void)
|
||||
{
|
||||
varsInfo_.clear();
|
||||
}
|
||||
|
||||
private:
|
||||
static unordered_map<VarId,VarInfo> varsInfo_;
|
||||
};
|
||||
|
||||
#endif // HORUS_GRAPHICALMODEL_H
|
||||
|
@ -11,30 +11,27 @@
|
||||
|
||||
using namespace std;
|
||||
|
||||
class VarNode;
|
||||
class BayesNode;
|
||||
class FgVarNode;
|
||||
class FgFacNode;
|
||||
class Var;
|
||||
class Factor;
|
||||
class VarNode;
|
||||
class FactorNode;
|
||||
|
||||
typedef vector<double> Params;
|
||||
typedef unsigned VarId;
|
||||
typedef vector<VarId> VarIds;
|
||||
typedef vector<Var*> Vars;
|
||||
typedef vector<VarNode*> VarNodes;
|
||||
typedef vector<BayesNode*> BnNodeSet;
|
||||
typedef vector<FgVarNode*> FgVarSet;
|
||||
typedef vector<FgFacNode*> FgFacSet;
|
||||
typedef vector<Factor*> FactorSet;
|
||||
typedef vector<FactorNode*> FactorNodes;
|
||||
typedef vector<Factor*> Factors;
|
||||
typedef vector<string> States;
|
||||
typedef vector<unsigned> Ranges;
|
||||
|
||||
|
||||
enum InfAlgorithms
|
||||
{
|
||||
VE, // variable elimination
|
||||
BN_BP, // bayesian network belief propagation
|
||||
FG_BP, // factor graph belief propagation
|
||||
CBP // counting bp solver
|
||||
VE, // variable elimination
|
||||
BP, // belief propagation
|
||||
CBP // counting belief propagation
|
||||
};
|
||||
|
||||
|
||||
|
@ -5,13 +5,13 @@
|
||||
|
||||
#include "FactorGraph.h"
|
||||
#include "VarElimSolver.h"
|
||||
#include "FgBpSolver.h"
|
||||
#include "BpSolver.h"
|
||||
#include "CbpSolver.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
void processArguments (FactorGraph&, int, const char* []);
|
||||
void runSolver (Solver*, const VarNodes&);
|
||||
void runSolver (Solver*, const VarIds&);
|
||||
|
||||
const string USAGE = "usage: \
|
||||
./hcli FILE [VARIABLE | OBSERVED_VARIABLE=EVIDENCE]..." ;
|
||||
@ -48,7 +48,7 @@ main (int argc, const char* argv[])
|
||||
void
|
||||
processArguments (FactorGraph& fg, int argc, const char* argv[])
|
||||
{
|
||||
VarNodes queryVars;
|
||||
VarIds queryIds;
|
||||
for (int i = 2; i < argc; i++) {
|
||||
const string& arg = argv[i];
|
||||
if (arg.find ('=') == std::string::npos) {
|
||||
@ -62,9 +62,9 @@ processArguments (FactorGraph& fg, int argc, const char* argv[])
|
||||
stringstream ss;
|
||||
ss << arg;
|
||||
ss >> vid;
|
||||
VarNode* queryVar = fg.getFgVarNode (vid);
|
||||
VarNode* queryVar = fg.getVarNode (vid);
|
||||
if (queryVar) {
|
||||
queryVars.push_back (queryVar);
|
||||
queryIds.push_back (vid);
|
||||
} else {
|
||||
cerr << "error: there isn't a variable with " ;
|
||||
cerr << "`" << vid << "' as id" ;
|
||||
@ -93,7 +93,7 @@ processArguments (FactorGraph& fg, int argc, const char* argv[])
|
||||
stringstream ss;
|
||||
ss << arg.substr (0, pos);
|
||||
ss >> vid;
|
||||
VarNode* var = fg.getFgVarNode (vid);
|
||||
VarNode* var = fg.getVarNode (vid);
|
||||
if (var) {
|
||||
if (!Util::isInteger (arg.substr (pos + 1))) {
|
||||
cerr << "error: `" << arg.substr (pos + 1) << "' " ;
|
||||
@ -127,8 +127,8 @@ processArguments (FactorGraph& fg, int argc, const char* argv[])
|
||||
case InfAlgorithms::VE:
|
||||
solver = new VarElimSolver (fg);
|
||||
break;
|
||||
case InfAlgorithms::FG_BP:
|
||||
solver = new FgBpSolver (fg);
|
||||
case InfAlgorithms::BP:
|
||||
solver = new BpSolver (fg);
|
||||
break;
|
||||
case InfAlgorithms::CBP:
|
||||
solver = new CbpSolver (fg);
|
||||
@ -136,27 +136,23 @@ processArguments (FactorGraph& fg, int argc, const char* argv[])
|
||||
default:
|
||||
assert (false);
|
||||
}
|
||||
runSolver (solver, queryVars);
|
||||
runSolver (solver, queryIds);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
runSolver (Solver* solver, const VarNodes& queryVars)
|
||||
runSolver (Solver* solver, const VarIds& queryIds)
|
||||
{
|
||||
VarIds vids;
|
||||
for (unsigned i = 0; i < queryVars.size(); i++) {
|
||||
vids.push_back (queryVars[i]->varId());
|
||||
}
|
||||
if (queryVars.size() == 0) {
|
||||
if (queryIds.size() == 0) {
|
||||
solver->runSolver();
|
||||
solver->printAllPosterioris();
|
||||
} else if (queryVars.size() == 1) {
|
||||
} else if (queryIds.size() == 1) {
|
||||
solver->runSolver();
|
||||
solver->printPosterioriOf (vids[0]);
|
||||
solver->printPosterioriOf (queryIds[0]);
|
||||
} else {
|
||||
solver->runSolver();
|
||||
solver->printJointDistributionOf (vids);
|
||||
solver->printJointDistributionOf (queryIds);
|
||||
}
|
||||
delete solver;
|
||||
}
|
||||
|
@ -11,7 +11,7 @@
|
||||
#include "FactorGraph.h"
|
||||
#include "FoveSolver.h"
|
||||
#include "VarElimSolver.h"
|
||||
#include "FgBpSolver.h"
|
||||
#include "BpSolver.h"
|
||||
#include "CbpSolver.h"
|
||||
#include "ElimGraph.h"
|
||||
#include "BayesBall.h"
|
||||
@ -241,8 +241,8 @@ createGroundNetwork (void)
|
||||
unsigned vid = (unsigned) YAP_IntOfTerm ((YAP_ArgOfTerm (1, evTerm)));
|
||||
unsigned ev = (unsigned) YAP_IntOfTerm ((YAP_ArgOfTerm (2, evTerm)));
|
||||
cout << vid << " == " << ev << endl;
|
||||
assert (fg->getFgVarNode (vid));
|
||||
fg->getFgVarNode (vid)->setEvidence (ev);
|
||||
assert (fg->getVarNode (vid));
|
||||
fg->getVarNode (vid)->setEvidence (ev);
|
||||
evidenceList = YAP_TailOfTerm (evidenceList);
|
||||
}
|
||||
|
||||
@ -348,7 +348,6 @@ runGroundSolver (void)
|
||||
taskList = YAP_TailOfTerm (taskList);
|
||||
}
|
||||
|
||||
fg->printGraphicalModel();
|
||||
vector<Params> results;
|
||||
if (Globals::infAlgorithm == InfAlgorithms::VE) {
|
||||
runVeSolver (fg, tasks, results);
|
||||
@ -414,8 +413,8 @@ void runBpSolver (
|
||||
mfg = BayesBall::getMinimalFactorGraph (
|
||||
*fg, VarIds (vids.begin(),vids.end()));
|
||||
}
|
||||
if (Globals::infAlgorithm == InfAlgorithms::FG_BP) {
|
||||
solver = new FgBpSolver (*mfg);
|
||||
if (Globals::infAlgorithm == InfAlgorithms::BP) {
|
||||
solver = new BpSolver (*mfg);
|
||||
} else if (Globals::infAlgorithm == InfAlgorithms::CBP) {
|
||||
CFactorGraph::checkForIdenticalFactors = false;
|
||||
solver = new CbpSolver (*mfg);
|
||||
@ -494,7 +493,7 @@ setBayesNetParams (void)
|
||||
int
|
||||
setExtraVarsInfo (void)
|
||||
{
|
||||
GraphicalModel::clearVariablesInformation();
|
||||
Var::clearVariablesInformation();
|
||||
YAP_Term varsInfoL = YAP_ARG2;
|
||||
while (varsInfoL != YAP_TermNil()) {
|
||||
YAP_Term head = YAP_HeadOfTerm (varsInfoL);
|
||||
@ -507,7 +506,7 @@ setExtraVarsInfo (void)
|
||||
states.push_back ((char*) YAP_AtomName (atom));
|
||||
statesL = YAP_TailOfTerm (statesL);
|
||||
}
|
||||
GraphicalModel::addVariableInformation (vid,
|
||||
Var::addVariableInformation (vid,
|
||||
(char*) YAP_AtomName (label), states);
|
||||
varsInfoL = YAP_TailOfTerm (varsInfoL);
|
||||
}
|
||||
@ -524,10 +523,8 @@ setHorusFlag (void)
|
||||
string value ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG2)));
|
||||
if ( value == "ve") {
|
||||
Globals::infAlgorithm = InfAlgorithms::VE;
|
||||
} else if (value == "bn_bp") {
|
||||
Globals::infAlgorithm = InfAlgorithms::BN_BP;
|
||||
} else if (value == "fg_bp") {
|
||||
Globals::infAlgorithm = InfAlgorithms::FG_BP;
|
||||
} else if (value == "bp") {
|
||||
Globals::infAlgorithm = InfAlgorithms::BP;
|
||||
} else if (value == "cbp") {
|
||||
Globals::infAlgorithm = InfAlgorithms::CBP;
|
||||
} else {
|
||||
|
@ -8,7 +8,7 @@
|
||||
#include <sstream>
|
||||
#include <iomanip>
|
||||
|
||||
#include "VarNode.h"
|
||||
#include "Var.h"
|
||||
#include "Util.h"
|
||||
|
||||
|
||||
@ -31,7 +31,7 @@ class StatesIndexer
|
||||
}
|
||||
}
|
||||
|
||||
StatesIndexer (const VarNodes& vars, bool calcOffsets = true)
|
||||
StatesIndexer (const Vars& vars, bool calcOffsets = true)
|
||||
{
|
||||
size_ = 1;
|
||||
indices_.resize (vars.size(), 0);
|
||||
|
@ -45,7 +45,6 @@ CWD=$(PWD)
|
||||
|
||||
|
||||
HEADERS = \
|
||||
$(srcdir)/GraphicalModel.h \
|
||||
$(srcdir)/BayesNet.h \
|
||||
$(srcdir)/BayesBall.h \
|
||||
$(srcdir)/ElimGraph.h \
|
||||
@ -55,10 +54,10 @@ HEADERS = \
|
||||
$(srcdir)/ConstraintTree.h \
|
||||
$(srcdir)/Solver.h \
|
||||
$(srcdir)/VarElimSolver.h \
|
||||
$(srcdir)/FgBpSolver.h \
|
||||
$(srcdir)/BpSolver.h \
|
||||
$(srcdir)/CbpSolver.h \
|
||||
$(srcdir)/FoveSolver.h \
|
||||
$(srcdir)/VarNode.h \
|
||||
$(srcdir)/Var.h \
|
||||
$(srcdir)/Indexer.h \
|
||||
$(srcdir)/Parfactor.h \
|
||||
$(srcdir)/ProbFormula.h \
|
||||
@ -77,10 +76,10 @@ CPP_SOURCES = \
|
||||
$(srcdir)/Factor.cpp \
|
||||
$(srcdir)/CFactorGraph.cpp \
|
||||
$(srcdir)/ConstraintTree.cpp \
|
||||
$(srcdir)/VarNode.cpp \
|
||||
$(srcdir)/Var.cpp \
|
||||
$(srcdir)/Solver.cpp \
|
||||
$(srcdir)/VarElimSolver.cpp \
|
||||
$(srcdir)/FgBpSolver.cpp \
|
||||
$(srcdir)/BpSolver.cpp \
|
||||
$(srcdir)/CbpSolver.cpp \
|
||||
$(srcdir)/FoveSolver.cpp \
|
||||
$(srcdir)/Parfactor.cpp \
|
||||
@ -100,10 +99,10 @@ OBJS = \
|
||||
Factor.o \
|
||||
CFactorGraph.o \
|
||||
ConstraintTree.o \
|
||||
VarNode.o \
|
||||
Var.o \
|
||||
Solver.o \
|
||||
VarElimSolver.o \
|
||||
FgBpSolver.o \
|
||||
BpSolver.o \
|
||||
CbpSolver.o \
|
||||
FoveSolver.o \
|
||||
Parfactor.o \
|
||||
@ -122,10 +121,10 @@ HCLI_OBJS = \
|
||||
Factor.o \
|
||||
CFactorGraph.o \
|
||||
ConstraintTree.o \
|
||||
VarNode.o \
|
||||
Var.o \
|
||||
Solver.o \
|
||||
VarElimSolver.o \
|
||||
FgBpSolver.o \
|
||||
BpSolver.o \
|
||||
CbpSolver.o \
|
||||
FoveSolver.o \
|
||||
Parfactor.o \
|
||||
|
@ -5,7 +5,7 @@
|
||||
void
|
||||
Solver::printAllPosterioris (void)
|
||||
{
|
||||
const VarNodes& vars = gm_->getVariableNodes();
|
||||
const VarNodes& vars = fg_.varNodes();
|
||||
for (unsigned i = 0; i < vars.size(); i++) {
|
||||
printPosterioriOf (vars[i]->varId());
|
||||
}
|
||||
@ -16,11 +16,11 @@ Solver::printAllPosterioris (void)
|
||||
void
|
||||
Solver::printPosterioriOf (VarId vid)
|
||||
{
|
||||
VarNode* var = gm_->getVariableNode (vid);
|
||||
VarNode* vn = fg_.getVarNode (vid);
|
||||
const Params& posterioriDist = getPosterioriOf (vid);
|
||||
const States& states = var->states();
|
||||
const States& states = vn->states();
|
||||
for (unsigned i = 0; i < states.size(); i++) {
|
||||
cout << "P(" << var->label() << "=" << states[i] << ") = " ;
|
||||
cout << "P(" << vn->label() << "=" << states[i] << ") = " ;
|
||||
cout << setprecision (Constants::PRECISION) << posterioriDist[i];
|
||||
cout << endl;
|
||||
}
|
||||
@ -32,12 +32,12 @@ Solver::printPosterioriOf (VarId vid)
|
||||
void
|
||||
Solver::printJointDistributionOf (const VarIds& vids)
|
||||
{
|
||||
VarNodes vars;
|
||||
Vars vars;
|
||||
VarIds vidsWithoutEvidence;
|
||||
for (unsigned i = 0; i < vids.size(); i++) {
|
||||
VarNode* var = gm_->getVariableNode (vids[i]);
|
||||
if (var->hasEvidence() == false) {
|
||||
vars.push_back (var);
|
||||
VarNode* vn = fg_.getVarNode (vids[i]);
|
||||
if (vn->hasEvidence() == false) {
|
||||
vars.push_back (vn);
|
||||
vidsWithoutEvidence.push_back (vids[i]);
|
||||
}
|
||||
}
|
||||
|
@ -3,15 +3,16 @@
|
||||
|
||||
#include <iomanip>
|
||||
|
||||
#include "GraphicalModel.h"
|
||||
#include "VarNode.h"
|
||||
#include "Var.h"
|
||||
#include "FactorGraph.h"
|
||||
|
||||
|
||||
using namespace std;
|
||||
|
||||
class Solver
|
||||
{
|
||||
public:
|
||||
Solver (const GraphicalModel* gm) : gm_(gm) { }
|
||||
Solver (const FactorGraph& fg) : fg_(fg) { }
|
||||
|
||||
virtual ~Solver() { } // ensure that subclass destructor is called
|
||||
|
||||
@ -28,7 +29,7 @@ class Solver
|
||||
void printJointDistributionOf (const VarIds& vids);
|
||||
|
||||
private:
|
||||
const GraphicalModel* gm_;
|
||||
const FactorGraph& fg_;
|
||||
};
|
||||
|
||||
#endif // HORUS_SOLVER_H
|
||||
|
@ -5,7 +5,6 @@
|
||||
|
||||
#include "Util.h"
|
||||
#include "Indexer.h"
|
||||
#include "GraphicalModel.h"
|
||||
|
||||
|
||||
namespace Globals {
|
||||
@ -25,12 +24,11 @@ Schedule schedule = BpOptions::Schedule::SEQ_FIXED;
|
||||
//Schedule schedule = BpOptions::Schedule::SEQ_RANDOM;
|
||||
//Schedule schedule = BpOptions::Schedule::PARALLEL;
|
||||
//Schedule schedule = BpOptions::Schedule::MAX_RESIDUAL;
|
||||
double accuracy = 0.0001;
|
||||
unsigned maxIter = 1000;
|
||||
double accuracy = 0.0001;
|
||||
unsigned maxIter = 1000;
|
||||
}
|
||||
|
||||
|
||||
unordered_map<VarId, VarInfo> GraphicalModel::varsInfo_;
|
||||
|
||||
vector<NetInfo> Statistics::netInfo_;
|
||||
vector<CompressInfo> Statistics::compressInfo_;
|
||||
@ -139,7 +137,7 @@ parametersToString (const Params& v, unsigned precision)
|
||||
|
||||
|
||||
vector<string>
|
||||
getJointStateStrings (const VarNodes& vars)
|
||||
getJointStateStrings (const Vars& vars)
|
||||
{
|
||||
StatesIndexer idx (vars);
|
||||
vector<string> jointStrings;
|
||||
@ -401,10 +399,9 @@ Statistics::getStatisticString (void)
|
||||
stringstream ss2, ss3, ss4, ss1;
|
||||
ss1 << "running mode: " ;
|
||||
switch (Globals::infAlgorithm) {
|
||||
case InfAlgorithms::VE: ss1 << "ve" << endl; break;
|
||||
case InfAlgorithms::BN_BP: ss1 << "bn_bp" << endl; break;
|
||||
case InfAlgorithms::FG_BP: ss1 << "fg_bp" << endl; break;
|
||||
case InfAlgorithms::CBP: ss1 << "cbp" << endl; break;
|
||||
case InfAlgorithms::VE: ss1 << "ve" << endl; break;
|
||||
case InfAlgorithms::BP: ss1 << "bp" << endl; break;
|
||||
case InfAlgorithms::CBP: ss1 << "cbp" << endl; break;
|
||||
}
|
||||
ss1 << "message schedule: " ;
|
||||
switch (BpOptions::schedule) {
|
||||
|
@ -62,7 +62,7 @@ bool isInteger (const string&);
|
||||
|
||||
string parametersToString (const Params&, unsigned = Constants::PRECISION);
|
||||
|
||||
vector<string> getJointStateStrings (const VarNodes&);
|
||||
vector<string> getJointStateStrings (const Vars&);
|
||||
|
||||
void printHeader (string, std::ostream& os = std::cout);
|
||||
|
||||
|
@ -6,7 +6,7 @@
|
||||
#include "Util.h"
|
||||
|
||||
|
||||
VarElimSolver::VarElimSolver (const FactorGraph& fg) : Solver (&fg)
|
||||
VarElimSolver::VarElimSolver (const FactorGraph& fg) : Solver (fg)
|
||||
{
|
||||
factorGraph_ = &fg;
|
||||
}
|
||||
@ -23,8 +23,8 @@ VarElimSolver::~VarElimSolver (void)
|
||||
Params
|
||||
VarElimSolver::getPosterioriOf (VarId vid)
|
||||
{
|
||||
assert (factorGraph_->getFgVarNode (vid));
|
||||
FgVarNode* vn = factorGraph_->getFgVarNode (vid);
|
||||
assert (factorGraph_->getVarNode (vid));
|
||||
VarNode* vn = factorGraph_->getVarNode (vid);
|
||||
if (vn->hasEvidence()) {
|
||||
Params params (vn->range(), 0.0);
|
||||
params[vn->getEvidence()] = 1.0;
|
||||
@ -57,11 +57,11 @@ VarElimSolver::getJointDistributionOf (const VarIds& vids)
|
||||
void
|
||||
VarElimSolver::createFactorList (void)
|
||||
{
|
||||
const FgFacSet& factorNodes = factorGraph_->getFactorNodes();
|
||||
const FactorNodes& factorNodes = factorGraph_->factorNodes();
|
||||
factorList_.reserve (factorNodes.size() * 2);
|
||||
for (unsigned i = 0; i < factorNodes.size(); i++) {
|
||||
factorList_.push_back (new Factor (*factorNodes[i]->factor()));
|
||||
const FgVarSet& neighs = factorNodes[i]->neighbors();
|
||||
const VarNodes& neighs = factorNodes[i]->neighbors();
|
||||
for (unsigned j = 0; j < neighs.size(); j++) {
|
||||
unordered_map<VarId,vector<unsigned> >::iterator it
|
||||
= varFactors_.find (neighs[j]->varId());
|
||||
@ -79,7 +79,7 @@ VarElimSolver::createFactorList (void)
|
||||
void
|
||||
VarElimSolver::absorveEvidence (void)
|
||||
{
|
||||
const FgVarSet& varNodes = factorGraph_->getVarNodes();
|
||||
const VarNodes& varNodes = factorGraph_->varNodes();
|
||||
for (unsigned i = 0; i < varNodes.size(); i++) {
|
||||
if (varNodes[i]->hasEvidence()) {
|
||||
const vector<unsigned>& idxs =
|
||||
@ -126,7 +126,7 @@ VarElimSolver::processFactorList (const VarIds& vids)
|
||||
|
||||
VarIds unobservedVids;
|
||||
for (unsigned i = 0; i < vids.size(); i++) {
|
||||
if (factorGraph_->getFgVarNode (vids[i])->hasEvidence() == false) {
|
||||
if (factorGraph_->getVarNode (vids[i])->hasEvidence() == false) {
|
||||
unobservedVids.push_back (vids[i]);
|
||||
}
|
||||
}
|
||||
|
@ -1,100 +0,0 @@
|
||||
#include <algorithm>
|
||||
#include <sstream>
|
||||
|
||||
#include "VarNode.h"
|
||||
#include "GraphicalModel.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
|
||||
VarNode::VarNode (const VarNode* v)
|
||||
{
|
||||
varId_ = v->varId();
|
||||
range_ = v->range();
|
||||
evidence_ = v->getEvidence();
|
||||
index_ = std::numeric_limits<unsigned>::max();
|
||||
}
|
||||
|
||||
|
||||
|
||||
VarNode::VarNode (VarId varId, unsigned range, int evidence)
|
||||
{
|
||||
assert (range != 0);
|
||||
assert (evidence < (int) range);
|
||||
varId_ = varId;
|
||||
range_ = range;
|
||||
evidence_ = evidence;
|
||||
index_ = std::numeric_limits<unsigned>::max();
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
VarNode::isValidState (int stateIndex)
|
||||
{
|
||||
return stateIndex >= 0 && stateIndex < (int) range_;
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
VarNode::isValidState (const string& stateName)
|
||||
{
|
||||
States states = GraphicalModel::getVarInformation (varId_).states;
|
||||
return Util::contains (states, stateName);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
VarNode::setEvidence (int ev)
|
||||
{
|
||||
assert (ev < (int) range_);
|
||||
evidence_ = ev;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
VarNode::setEvidence (const string& ev)
|
||||
{
|
||||
States states = GraphicalModel::getVarInformation (varId_).states;
|
||||
for (unsigned i = 0; i < states.size(); i++) {
|
||||
if (states[i] == ev) {
|
||||
evidence_ = i;
|
||||
return;
|
||||
}
|
||||
}
|
||||
assert (false);
|
||||
}
|
||||
|
||||
|
||||
|
||||
string
|
||||
VarNode::label (void) const
|
||||
{
|
||||
if (GraphicalModel::variablesHaveInformation()) {
|
||||
return GraphicalModel::getVarInformation (varId_).label;
|
||||
}
|
||||
stringstream ss;
|
||||
ss << "x" << varId_;
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
|
||||
|
||||
States
|
||||
VarNode::states (void) const
|
||||
{
|
||||
if (GraphicalModel::variablesHaveInformation()) {
|
||||
return GraphicalModel::getVarInformation (varId_).states;
|
||||
}
|
||||
States states;
|
||||
for (unsigned i = 0; i < range_; i++) {
|
||||
stringstream ss;
|
||||
ss << i ;
|
||||
states.push_back (ss.str());
|
||||
}
|
||||
return states;
|
||||
}
|
||||
|
@ -1,71 +0,0 @@
|
||||
#ifndef HORUS_VARNODE_H
|
||||
#define HORUS_VARNODE_H
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "Horus.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
class VarNode
|
||||
{
|
||||
public:
|
||||
VarNode (const VarNode*);
|
||||
|
||||
VarNode (VarId, unsigned, int = Constants::NO_EVIDENCE);
|
||||
|
||||
virtual ~VarNode (void) { };
|
||||
|
||||
unsigned varId (void) const { return varId_; }
|
||||
|
||||
unsigned range (void) const { return range_; }
|
||||
|
||||
int getEvidence (void) const { return evidence_; }
|
||||
|
||||
unsigned getIndex (void) const { return index_; }
|
||||
|
||||
void setIndex (unsigned idx) { index_ = idx; }
|
||||
|
||||
operator unsigned () const { return index_; }
|
||||
|
||||
bool hasEvidence (void) const
|
||||
{
|
||||
return evidence_ != Constants::NO_EVIDENCE;
|
||||
}
|
||||
|
||||
bool operator== (const VarNode& var) const
|
||||
{
|
||||
assert (!(varId_ == var.varId() && range_ != var.range()));
|
||||
return varId_ == var.varId();
|
||||
}
|
||||
|
||||
bool operator!= (const VarNode& var) const
|
||||
{
|
||||
assert (!(varId_ == var.varId() && range_ != var.range()));
|
||||
return varId_ != var.varId();
|
||||
}
|
||||
|
||||
bool isValidState (int);
|
||||
|
||||
bool isValidState (const string&);
|
||||
|
||||
void setEvidence (int);
|
||||
|
||||
void setEvidence (const string&);
|
||||
|
||||
string label (void) const;
|
||||
|
||||
States states (void) const;
|
||||
|
||||
private:
|
||||
VarId varId_;
|
||||
unsigned range_;
|
||||
int evidence_;
|
||||
unsigned index_;
|
||||
|
||||
};
|
||||
|
||||
#endif // BP_VARNODE_H
|
||||
|
@ -1,50 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
cp ~/bin/yap ~/bin/town_fgbp
|
||||
YAP=~/bin/town_fgbp
|
||||
|
||||
#OUT_FILE_NAME=results`date "+ %H:%M:%S %d-%m-%Y"`.log
|
||||
OUT_FILE_NAME=fb_bp.log
|
||||
rm -f $OUT_FILE_NAME
|
||||
rm -f ignore.$OUT_FILE_NAME
|
||||
|
||||
|
||||
function run_solver
|
||||
{
|
||||
if [ $2 = bp ]
|
||||
then
|
||||
extra_flag1=clpbn_horus:set_horus_flag\(inf_alg,$4\)
|
||||
extra_flag2=clpbn_horus:set_horus_flag\(schedule,$5\)
|
||||
else
|
||||
extra_flag1=true
|
||||
extra_flag2=true
|
||||
fi
|
||||
/usr/bin/time -o $OUT_FILE_NAME -a -f "real:%E\tuser:%U\tsys:%S" $YAP << EOF >> $OUT_FILE_NAME 2>> ignore.$OUT_FILE_NAME
|
||||
[$1].
|
||||
clpbn:set_clpbn_flag(solver,$2),
|
||||
clpbn_horus:set_horus_flag(use_logarithms, true),
|
||||
$extra_flag1, $extra_flag2,
|
||||
run_query(_R),
|
||||
open("$OUT_FILE_NAME", 'append',S),
|
||||
format(S, '$3: ~15+ ',[]),
|
||||
close(S).
|
||||
EOF
|
||||
}
|
||||
|
||||
|
||||
function run_all_graphs
|
||||
{
|
||||
echo "*******************************************************************" >> "$OUT_FILE_NAME"
|
||||
echo "results for solver $2" >> $OUT_FILE_NAME
|
||||
echo "*******************************************************************" >> "$OUT_FILE_NAME"
|
||||
run_solver town_1000 $1 town_1000 $3 $4 $5
|
||||
#run_solver town_5000 $1 town_5000 $3 $4 $5
|
||||
#run_solver town_10000 $1 town_10000 $3 $4 $5
|
||||
#run_solver town_50000 $1 town_50000 $3 $4 $5
|
||||
#run_solver town_100000 $1 town_100000 $3 $4 $5
|
||||
#run_solver town_500000 $1 town_500000 $3 $4 $5
|
||||
#run_solver town_1000000 $1 town_1000000 $3 $4 $5
|
||||
}
|
||||
|
||||
run_all_graphs bp "fg_bp(seq_fixed) " fg_bp seq_fixed
|
||||
|
@ -9,7 +9,7 @@ OUT_FILE_NAME=results.log
|
||||
rm -f $OUT_FILE_NAME
|
||||
rm -f ignore.$OUT_FILE_NAME
|
||||
|
||||
# yap -g "['../../../../examples/School/sch32'], [missing5], use_module(library(clpbn/learning/em)), graph(L), clpbn:set_clpbn_flag(em_solver,bp), clpbn_horus:set_horus_flag(inf_alg,fg_bp), statistics(runtime, _), em(L,0.01,10,_,Lik), statistics(runtime, [T,_])."
|
||||
# yap -g "['../../../../examples/School/sch32'], [missing5], use_module(library(clpbn/learning/em)), graph(L), clpbn:set_clpbn_flag(em_solver,bp), clpbn_horus:set_horus_flag(inf_alg, bp), statistics(runtime, _), em(L,0.01,10,_,Lik), statistics(runtime, [T,_])."
|
||||
|
||||
function run_solver
|
||||
{
|
||||
@ -58,24 +58,21 @@ function run_all_graphs
|
||||
}
|
||||
|
||||
|
||||
#run_all_graphs bp "hve(min_neighbors) " ve min_neighbors
|
||||
#run_all_graphs bp "bn_bp(seq_fixed) " bn_bp seq_fixed
|
||||
run_all_graphs bp "fg_bp(seq_fixed) " fg_bp seq_fixed
|
||||
#run_all_graphs bp "cbp(seq_fixed) " cbp seq_fixed
|
||||
#run_all_graphs bp "hve(min_neighbors) " ve min_neighbors
|
||||
#run_all_graphs bp "bp(seq_fixed) " bp seq_fixed
|
||||
#run_all_graphs bp "cbp(seq_fixed) " cbp seq_fixed
|
||||
exit
|
||||
|
||||
|
||||
run_all_graphs bp "hve(min_neighbors) " ve min_neighbors
|
||||
run_all_graphs bp "hve(min_weight) " ve min_weight
|
||||
run_all_graphs bp "hve(min_fill) " ve min_fill
|
||||
run_all_graphs bp "hve(w_min_fill) " ve weighted_min_fill
|
||||
run_all_graphs bp "bn_bp(seq_fixed) " bn_bp seq_fixed
|
||||
run_all_graphs bp "bn_bp(max_residual) " bn_bp max_residual
|
||||
run_all_graphs bp "fg_bp(seq_fixed) " fg_bp seq_fixed
|
||||
run_all_graphs bp "fg_bp(max_residual) " fg_bp max_residual
|
||||
run_all_graphs bp "cbp(seq_fixed) " cbp seq_fixed
|
||||
run_all_graphs bp "cbp(max_residual) " cbp max_residual
|
||||
run_all_graphs gibbs "gibbs "
|
||||
run_all_graphs bp "hve(min_neighbors) " ve min_neighbors
|
||||
run_all_graphs bp "hve(min_weight) " ve min_weight
|
||||
run_all_graphs bp "hve(min_fill) " ve min_fill
|
||||
run_all_graphs bp "hve(w_min_fill) " ve weighted_min_fill
|
||||
run_all_graphs bp "bp(seq_fixed) " bp seq_fixed
|
||||
run_all_graphs bp "bp(max_residual) " bp max_residual
|
||||
run_all_graphs bp "cbp(seq_fixed) " cbp seq_fixed
|
||||
run_all_graphs bp "cbp(max_residual) " cbp max_residual
|
||||
run_all_graphs gibbs "gibbs "
|
||||
echo "************************************************************************" >> "$OUT_FILE_NAME"
|
||||
echo "results for solver ve" >> "$OUT_FILE_NAME"
|
||||
echo "************************************************************************" >> "$OUT_FILE_NAME"
|
||||
|
@ -1,81 +0,0 @@
|
||||
<?xml version="1.0" encoding="US-ASCII"?>
|
||||
|
||||
<!--
|
||||
|
||||
B E
|
||||
\ /
|
||||
\ /
|
||||
A
|
||||
/ \
|
||||
/ \
|
||||
J M
|
||||
|
||||
-->
|
||||
|
||||
|
||||
<BIF VERSION="0.3">
|
||||
<NETWORK>
|
||||
<NAME>Simple Loop</NAME>
|
||||
|
||||
<VARIABLE TYPE="nature">
|
||||
<NAME>B</NAME>
|
||||
<OUTCOME>b1</OUTCOME>
|
||||
<OUTCOME>b2</OUTCOME>
|
||||
</VARIABLE>
|
||||
|
||||
<VARIABLE TYPE="nature">
|
||||
<NAME>E</NAME>
|
||||
<OUTCOME>e1</OUTCOME>
|
||||
<OUTCOME>e2</OUTCOME>
|
||||
</VARIABLE>
|
||||
|
||||
<VARIABLE TYPE="nature">
|
||||
<NAME>A</NAME>
|
||||
<OUTCOME>a1</OUTCOME>
|
||||
<OUTCOME>a2</OUTCOME>
|
||||
</VARIABLE>
|
||||
|
||||
<VARIABLE TYPE="nature">
|
||||
<NAME>J</NAME>
|
||||
<OUTCOME>j1</OUTCOME>
|
||||
<OUTCOME>j2</OUTCOME>
|
||||
</VARIABLE>
|
||||
|
||||
<VARIABLE TYPE="nature">
|
||||
<NAME>M</NAME>
|
||||
<OUTCOME>m1</OUTCOME>
|
||||
<OUTCOME>m2</OUTCOME>
|
||||
</VARIABLE>
|
||||
|
||||
<DEFINITION>
|
||||
<FOR>B</FOR>
|
||||
<TABLE> .001 .999 </TABLE>
|
||||
</DEFINITION>
|
||||
|
||||
<DEFINITION>
|
||||
<FOR>E</FOR>
|
||||
<TABLE> .002 .998 </TABLE>
|
||||
</DEFINITION>
|
||||
|
||||
<DEFINITION>
|
||||
<FOR>A</FOR>
|
||||
<GIVEN>B</GIVEN>
|
||||
<GIVEN>E</GIVEN>
|
||||
<TABLE> .95 .05 .94 .06 .29 .71 .001 .999 </TABLE>
|
||||
</DEFINITION>
|
||||
|
||||
<DEFINITION>
|
||||
<FOR>J</FOR>
|
||||
<GIVEN>A</GIVEN>
|
||||
<TABLE> .9 .1 .05 .95 </TABLE>
|
||||
</DEFINITION>
|
||||
|
||||
<DEFINITION>
|
||||
<FOR>M</FOR>
|
||||
<GIVEN>A</GIVEN>
|
||||
<TABLE> .7 .3 .01 .99 </TABLE>
|
||||
</DEFINITION>
|
||||
|
||||
</NETWORK>
|
||||
</BIF>
|
||||
|
@ -3,7 +3,7 @@
|
||||
|
||||
%:- set_pfl_flag(solver,ve).
|
||||
:- set_pfl_flag(solver,bp), clpbn_horus:set_horus_flag(inf_alg,ve).
|
||||
%:- set_pfl_flag(solver,bp), clpbn_horus:set_horus_flag(inf_alg,fg_bp).
|
||||
%:- set_pfl_flag(solver,bp), clpbn_horus:set_horus_flag(inf_alg,bp).
|
||||
%:- set_pfl_flag(solver,fove).
|
||||
|
||||
% :- yap_flag(write_strings, off).
|
||||
|
Reference in New Issue
Block a user