renamings and delete bn_bp stuff
This commit is contained in:
parent
b28ee8fb3a
commit
d1b25f0864
@ -60,7 +60,7 @@ BayesBall::getMinimalFactorGraph (const VarIds& queryIds)
|
|||||||
void
|
void
|
||||||
BayesBall::constructGraph (FactorGraph* fg) const
|
BayesBall::constructGraph (FactorGraph* fg) const
|
||||||
{
|
{
|
||||||
const FgFacSet& facNodes = fg_.getFactorNodes();
|
const FactorNodes& facNodes = fg_.factorNodes();
|
||||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||||
const DAGraphNode* n = dag_.getNode (
|
const DAGraphNode* n = dag_.getNode (
|
||||||
facNodes[i]->factor()->argument (0));
|
facNodes[i]->factor()->argument (0));
|
||||||
|
@ -6,11 +6,9 @@
|
|||||||
#include <list>
|
#include <list>
|
||||||
#include <map>
|
#include <map>
|
||||||
|
|
||||||
#include "GraphicalModel.h"
|
|
||||||
#include "Horus.h"
|
|
||||||
|
|
||||||
#include "FactorGraph.h"
|
#include "FactorGraph.h"
|
||||||
#include "BayesNet.h"
|
#include "BayesNet.h"
|
||||||
|
#include "Horus.h"
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
|
@ -6,19 +6,19 @@
|
|||||||
#include <list>
|
#include <list>
|
||||||
#include <map>
|
#include <map>
|
||||||
|
|
||||||
#include "GraphicalModel.h"
|
#include "Var.h"
|
||||||
#include "Horus.h"
|
#include "Horus.h"
|
||||||
|
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
|
|
||||||
class VarNode;
|
class Var;
|
||||||
|
|
||||||
class DAGraphNode : public VarNode
|
class DAGraphNode : public Var
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
DAGraphNode (VarNode* vn) : VarNode (vn) , visited_(false),
|
DAGraphNode (Var* v) : Var (v) , visited_(false),
|
||||||
markedOnTop_(false), markedOnBottom_(false) { }
|
markedOnTop_(false), markedOnBottom_(false) { }
|
||||||
|
|
||||||
const vector<DAGraphNode*>& childs (void) const { return childs_; }
|
const vector<DAGraphNode*>& childs (void) const { return childs_; }
|
||||||
|
@ -10,14 +10,14 @@ CFactorGraph::CFactorGraph (const FactorGraph& fg)
|
|||||||
groundFg_ = &fg;
|
groundFg_ = &fg;
|
||||||
freeColor_ = 0;
|
freeColor_ = 0;
|
||||||
|
|
||||||
const FgVarSet& varNodes = fg.getVarNodes();
|
const VarNodes& varNodes = fg.varNodes();
|
||||||
varSignatures_.reserve (varNodes.size());
|
varSignatures_.reserve (varNodes.size());
|
||||||
for (unsigned i = 0; i < varNodes.size(); i++) {
|
for (unsigned i = 0; i < varNodes.size(); i++) {
|
||||||
unsigned c = (varNodes[i]->neighbors().size() * 2) + 1;
|
unsigned c = (varNodes[i]->neighbors().size() * 2) + 1;
|
||||||
varSignatures_.push_back (Signature (c));
|
varSignatures_.push_back (Signature (c));
|
||||||
}
|
}
|
||||||
|
|
||||||
const FgFacSet& facNodes = fg.getFactorNodes();
|
const FactorNodes& facNodes = fg.factorNodes();
|
||||||
factorSignatures_.reserve (facNodes.size());
|
factorSignatures_.reserve (facNodes.size());
|
||||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||||
unsigned c = facNodes[i]->neighbors().size() + 1;
|
unsigned c = facNodes[i]->neighbors().size() + 1;
|
||||||
@ -49,7 +49,7 @@ CFactorGraph::setInitialColors (void)
|
|||||||
{
|
{
|
||||||
// create the initial variable colors
|
// create the initial variable colors
|
||||||
VarColorMap colorMap;
|
VarColorMap colorMap;
|
||||||
const FgVarSet& varNodes = groundFg_->getVarNodes();
|
const VarNodes& varNodes = groundFg_->varNodes();
|
||||||
for (unsigned i = 0; i < varNodes.size(); i++) {
|
for (unsigned i = 0; i < varNodes.size(); i++) {
|
||||||
unsigned dsize = varNodes[i]->range();
|
unsigned dsize = varNodes[i]->range();
|
||||||
VarColorMap::iterator it = colorMap.find (dsize);
|
VarColorMap::iterator it = colorMap.find (dsize);
|
||||||
@ -70,7 +70,7 @@ CFactorGraph::setInitialColors (void)
|
|||||||
setColor (varNodes[i], stateColors[idx]);
|
setColor (varNodes[i], stateColors[idx]);
|
||||||
}
|
}
|
||||||
|
|
||||||
const FgFacSet& facNodes = groundFg_->getFactorNodes();
|
const FactorNodes& facNodes = groundFg_->factorNodes();
|
||||||
if (checkForIdenticalFactors) {
|
if (checkForIdenticalFactors) {
|
||||||
unsigned groupCount = 1;
|
unsigned groupCount = 1;
|
||||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||||
@ -114,8 +114,8 @@ CFactorGraph::createGroups (void)
|
|||||||
FacSignMap factorGroups;
|
FacSignMap factorGroups;
|
||||||
unsigned nIters = 0;
|
unsigned nIters = 0;
|
||||||
bool groupsHaveChanged = true;
|
bool groupsHaveChanged = true;
|
||||||
const FgVarSet& varNodes = groundFg_->getVarNodes();
|
const VarNodes& varNodes = groundFg_->varNodes();
|
||||||
const FgFacSet& facNodes = groundFg_->getFactorNodes();
|
const FactorNodes& facNodes = groundFg_->factorNodes();
|
||||||
|
|
||||||
while (groupsHaveChanged || nIters == 1) {
|
while (groupsHaveChanged || nIters == 1) {
|
||||||
nIters ++;
|
nIters ++;
|
||||||
@ -127,14 +127,14 @@ CFactorGraph::createGroups (void)
|
|||||||
const Signature& signature = getSignature (facNodes[i]);
|
const Signature& signature = getSignature (facNodes[i]);
|
||||||
FacSignMap::iterator it = factorGroups.find (signature);
|
FacSignMap::iterator it = factorGroups.find (signature);
|
||||||
if (it == factorGroups.end()) {
|
if (it == factorGroups.end()) {
|
||||||
it = factorGroups.insert (make_pair (signature, FgFacSet())).first;
|
it = factorGroups.insert (make_pair (signature, FactorNodes())).first;
|
||||||
}
|
}
|
||||||
it->second.push_back (facNodes[i]);
|
it->second.push_back (facNodes[i]);
|
||||||
}
|
}
|
||||||
for (FacSignMap::iterator it = factorGroups.begin();
|
for (FacSignMap::iterator it = factorGroups.begin();
|
||||||
it != factorGroups.end(); it++) {
|
it != factorGroups.end(); it++) {
|
||||||
Color newColor = getFreeColor();
|
Color newColor = getFreeColor();
|
||||||
FgFacSet& groupMembers = it->second;
|
FactorNodes& groupMembers = it->second;
|
||||||
for (unsigned i = 0; i < groupMembers.size(); i++) {
|
for (unsigned i = 0; i < groupMembers.size(); i++) {
|
||||||
setColor (groupMembers[i], newColor);
|
setColor (groupMembers[i], newColor);
|
||||||
}
|
}
|
||||||
@ -147,14 +147,14 @@ CFactorGraph::createGroups (void)
|
|||||||
const Signature& signature = getSignature (varNodes[i]);
|
const Signature& signature = getSignature (varNodes[i]);
|
||||||
VarSignMap::iterator it = varGroups.find (signature);
|
VarSignMap::iterator it = varGroups.find (signature);
|
||||||
if (it == varGroups.end()) {
|
if (it == varGroups.end()) {
|
||||||
it = varGroups.insert (make_pair (signature, FgVarSet())).first;
|
it = varGroups.insert (make_pair (signature, VarNodes())).first;
|
||||||
}
|
}
|
||||||
it->second.push_back (varNodes[i]);
|
it->second.push_back (varNodes[i]);
|
||||||
}
|
}
|
||||||
for (VarSignMap::iterator it = varGroups.begin();
|
for (VarSignMap::iterator it = varGroups.begin();
|
||||||
it != varGroups.end(); it++) {
|
it != varGroups.end(); it++) {
|
||||||
Color newColor = getFreeColor();
|
Color newColor = getFreeColor();
|
||||||
FgVarSet& groupMembers = it->second;
|
VarNodes& groupMembers = it->second;
|
||||||
for (unsigned i = 0; i < groupMembers.size(); i++) {
|
for (unsigned i = 0; i < groupMembers.size(); i++) {
|
||||||
setColor (groupMembers[i], newColor);
|
setColor (groupMembers[i], newColor);
|
||||||
}
|
}
|
||||||
@ -177,7 +177,7 @@ CFactorGraph::createClusters (
|
|||||||
varClusters_.reserve (varGroups.size());
|
varClusters_.reserve (varGroups.size());
|
||||||
for (VarSignMap::const_iterator it = varGroups.begin();
|
for (VarSignMap::const_iterator it = varGroups.begin();
|
||||||
it != varGroups.end(); it++) {
|
it != varGroups.end(); it++) {
|
||||||
const FgVarSet& groupVars = it->second;
|
const VarNodes& groupVars = it->second;
|
||||||
VarCluster* vc = new VarCluster (groupVars);
|
VarCluster* vc = new VarCluster (groupVars);
|
||||||
for (unsigned i = 0; i < groupVars.size(); i++) {
|
for (unsigned i = 0; i < groupVars.size(); i++) {
|
||||||
vid2VarCluster_.insert (make_pair (groupVars[i]->varId(), vc));
|
vid2VarCluster_.insert (make_pair (groupVars[i]->varId(), vc));
|
||||||
@ -188,8 +188,8 @@ CFactorGraph::createClusters (
|
|||||||
facClusters_.reserve (factorGroups.size());
|
facClusters_.reserve (factorGroups.size());
|
||||||
for (FacSignMap::const_iterator it = factorGroups.begin();
|
for (FacSignMap::const_iterator it = factorGroups.begin();
|
||||||
it != factorGroups.end(); it++) {
|
it != factorGroups.end(); it++) {
|
||||||
FgFacNode* groupFactor = it->second[0];
|
FactorNode* groupFactor = it->second[0];
|
||||||
const FgVarSet& neighs = groupFactor->neighbors();
|
const VarNodes& neighs = groupFactor->neighbors();
|
||||||
VarClusterSet varClusters;
|
VarClusterSet varClusters;
|
||||||
varClusters.reserve (neighs.size());
|
varClusters.reserve (neighs.size());
|
||||||
for (unsigned i = 0; i < neighs.size(); i++) {
|
for (unsigned i = 0; i < neighs.size(); i++) {
|
||||||
@ -203,11 +203,11 @@ CFactorGraph::createClusters (
|
|||||||
|
|
||||||
|
|
||||||
const Signature&
|
const Signature&
|
||||||
CFactorGraph::getSignature (const FgVarNode* varNode)
|
CFactorGraph::getSignature (const VarNode* varNode)
|
||||||
{
|
{
|
||||||
Signature& sign = varSignatures_[varNode->getIndex()];
|
Signature& sign = varSignatures_[varNode->getIndex()];
|
||||||
vector<Color>::iterator it = sign.colors.begin();
|
vector<Color>::iterator it = sign.colors.begin();
|
||||||
const FgFacSet& neighs = varNode->neighbors();
|
const FactorNodes& neighs = varNode->neighbors();
|
||||||
for (unsigned i = 0; i < neighs.size(); i++) {
|
for (unsigned i = 0; i < neighs.size(); i++) {
|
||||||
*it = getColor (neighs[i]);
|
*it = getColor (neighs[i]);
|
||||||
it ++;
|
it ++;
|
||||||
@ -221,11 +221,11 @@ CFactorGraph::getSignature (const FgVarNode* varNode)
|
|||||||
|
|
||||||
|
|
||||||
const Signature&
|
const Signature&
|
||||||
CFactorGraph::getSignature (const FgFacNode* facNode)
|
CFactorGraph::getSignature (const FactorNode* facNode)
|
||||||
{
|
{
|
||||||
Signature& sign = factorSignatures_[facNode->getIndex()];
|
Signature& sign = factorSignatures_[facNode->getIndex()];
|
||||||
vector<Color>::iterator it = sign.colors.begin();
|
vector<Color>::iterator it = sign.colors.begin();
|
||||||
const FgVarSet& neighs = facNode->neighbors();
|
const VarNodes& neighs = facNode->neighbors();
|
||||||
for (unsigned i = 0; i < neighs.size(); i++) {
|
for (unsigned i = 0; i < neighs.size(); i++) {
|
||||||
*it = getColor (neighs[i]);
|
*it = getColor (neighs[i]);
|
||||||
it ++;
|
it ++;
|
||||||
@ -241,27 +241,27 @@ CFactorGraph::getCompressedFactorGraph (void)
|
|||||||
{
|
{
|
||||||
FactorGraph* fg = new FactorGraph();
|
FactorGraph* fg = new FactorGraph();
|
||||||
for (unsigned i = 0; i < varClusters_.size(); i++) {
|
for (unsigned i = 0; i < varClusters_.size(); i++) {
|
||||||
FgVarNode* var = varClusters_[i]->getGroundFgVarNodes()[0];
|
VarNode* var = varClusters_[i]->getGroundVarNodes()[0];
|
||||||
FgVarNode* newVar = new FgVarNode (var);
|
VarNode* newVar = new VarNode (var);
|
||||||
varClusters_[i]->setRepresentativeVariable (newVar);
|
varClusters_[i]->setRepresentativeVariable (newVar);
|
||||||
fg->addVariable (newVar);
|
fg->addVariable (newVar);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (unsigned i = 0; i < facClusters_.size(); i++) {
|
for (unsigned i = 0; i < facClusters_.size(); i++) {
|
||||||
const VarClusterSet& myVarClusters = facClusters_[i]->getVarClusters();
|
const VarClusterSet& myVarClusters = facClusters_[i]->getVarClusters();
|
||||||
VarNodes myGroundVars;
|
Vars myGroundVars;
|
||||||
myGroundVars.reserve (myVarClusters.size());
|
myGroundVars.reserve (myVarClusters.size());
|
||||||
for (unsigned j = 0; j < myVarClusters.size(); j++) {
|
for (unsigned j = 0; j < myVarClusters.size(); j++) {
|
||||||
FgVarNode* v = myVarClusters[j]->getRepresentativeVariable();
|
VarNode* v = myVarClusters[j]->getRepresentativeVariable();
|
||||||
myGroundVars.push_back (v);
|
myGroundVars.push_back (v);
|
||||||
}
|
}
|
||||||
Factor* newFactor = new Factor (myGroundVars,
|
Factor* newFactor = new Factor (myGroundVars,
|
||||||
facClusters_[i]->getGroundFactors()[0]->params());
|
facClusters_[i]->getGroundFactors()[0]->params());
|
||||||
FgFacNode* fn = new FgFacNode (newFactor);
|
FactorNode* fn = new FactorNode (newFactor);
|
||||||
facClusters_[i]->setRepresentativeFactor (fn);
|
facClusters_[i]->setRepresentativeFactor (fn);
|
||||||
fg->addFactor (fn);
|
fg->addFactor (fn);
|
||||||
for (unsigned j = 0; j < myGroundVars.size(); j++) {
|
for (unsigned j = 0; j < myGroundVars.size(); j++) {
|
||||||
fg->addEdge (fn, static_cast<FgVarNode*> (myGroundVars[j]));
|
fg->addEdge (fn, static_cast<VarNode*> (myGroundVars[j]));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
fg->setIndexes();
|
fg->setIndexes();
|
||||||
@ -275,17 +275,17 @@ CFactorGraph::getGroundEdgeCount (
|
|||||||
const FacCluster* fc,
|
const FacCluster* fc,
|
||||||
const VarCluster* vc) const
|
const VarCluster* vc) const
|
||||||
{
|
{
|
||||||
const FgFacSet& clusterGroundFactors = fc->getGroundFactors();
|
const FactorNodes& clusterGroundFactors = fc->getGroundFactors();
|
||||||
FgVarNode* varNode = vc->getGroundFgVarNodes()[0];
|
VarNode* varNode = vc->getGroundVarNodes()[0];
|
||||||
unsigned count = 0;
|
unsigned count = 0;
|
||||||
for (unsigned i = 0; i < clusterGroundFactors.size(); i++) {
|
for (unsigned i = 0; i < clusterGroundFactors.size(); i++) {
|
||||||
if (clusterGroundFactors[i]->factor()->indexOf (varNode->varId()) != -1) {
|
if (clusterGroundFactors[i]->factor()->indexOf (varNode->varId()) != -1) {
|
||||||
count ++;
|
count ++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// CFgVarSet vars = vc->getGroundFgVarNodes();
|
// CVarNodes vars = vc->getGroundVarNodes();
|
||||||
// for (unsigned i = 1; i < vars.size(); i++) {
|
// for (unsigned i = 1; i < vars.size(); i++) {
|
||||||
// FgVarNode* var = vc->getGroundFgVarNodes()[i];
|
// VarNode* var = vc->getGroundVarNodes()[i];
|
||||||
// unsigned count2 = 0;
|
// unsigned count2 = 0;
|
||||||
// for (unsigned i = 0; i < clusterGroundFactors.size(); i++) {
|
// for (unsigned i = 0; i < clusterGroundFactors.size(); i++) {
|
||||||
// if (clusterGroundFactors[i]->getPosition (var) != -1) {
|
// if (clusterGroundFactors[i]->getPosition (var) != -1) {
|
||||||
@ -308,7 +308,7 @@ CFactorGraph::printGroups (
|
|||||||
cout << "variable groups:" << endl;
|
cout << "variable groups:" << endl;
|
||||||
for (VarSignMap::const_iterator it = varGroups.begin();
|
for (VarSignMap::const_iterator it = varGroups.begin();
|
||||||
it != varGroups.end(); it++) {
|
it != varGroups.end(); it++) {
|
||||||
const FgVarSet& groupMembers = it->second;
|
const VarNodes& groupMembers = it->second;
|
||||||
if (groupMembers.size() > 0) {
|
if (groupMembers.size() > 0) {
|
||||||
cout << count << ": " ;
|
cout << count << ": " ;
|
||||||
for (unsigned i = 0; i < groupMembers.size(); i++) {
|
for (unsigned i = 0; i < groupMembers.size(); i++) {
|
||||||
@ -323,7 +323,7 @@ CFactorGraph::printGroups (
|
|||||||
cout << endl << "factor groups:" << endl;
|
cout << endl << "factor groups:" << endl;
|
||||||
for (FacSignMap::const_iterator it = factorGroups.begin();
|
for (FacSignMap::const_iterator it = factorGroups.begin();
|
||||||
it != factorGroups.end(); it++) {
|
it != factorGroups.end(); it++) {
|
||||||
const FgFacSet& groupMembers = it->second;
|
const FactorNodes& groupMembers = it->second;
|
||||||
if (groupMembers.size() > 0) {
|
if (groupMembers.size() > 0) {
|
||||||
cout << ++count << ": " ;
|
cout << ++count << ": " ;
|
||||||
for (unsigned i = 0; i < groupMembers.size(); i++) {
|
for (unsigned i = 0; i < groupMembers.size(); i++) {
|
||||||
|
@ -25,8 +25,8 @@ typedef unordered_map<VarId, VarCluster*> VarId2VarCluster;
|
|||||||
typedef vector<VarCluster*> VarClusterSet;
|
typedef vector<VarCluster*> VarClusterSet;
|
||||||
typedef vector<FacCluster*> FacClusterSet;
|
typedef vector<FacCluster*> FacClusterSet;
|
||||||
|
|
||||||
typedef unordered_map<Signature, FgVarSet, SignatureHash> VarSignMap;
|
typedef unordered_map<Signature, VarNodes, SignatureHash> VarSignMap;
|
||||||
typedef unordered_map<Signature, FgFacSet, SignatureHash> FacSignMap;
|
typedef unordered_map<Signature, FactorNodes, SignatureHash> FacSignMap;
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ -87,7 +87,7 @@ struct SignatureHash
|
|||||||
class VarCluster
|
class VarCluster
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
VarCluster (const FgVarSet& vs)
|
VarCluster (const VarNodes& vs)
|
||||||
{
|
{
|
||||||
for (unsigned i = 0; i < vs.size(); i++) {
|
for (unsigned i = 0; i < vs.size(); i++) {
|
||||||
groundVars_.push_back (vs[i]);
|
groundVars_.push_back (vs[i]);
|
||||||
@ -104,21 +104,21 @@ class VarCluster
|
|||||||
return facClusters_;
|
return facClusters_;
|
||||||
}
|
}
|
||||||
|
|
||||||
FgVarNode* getRepresentativeVariable (void) const { return representVar_; }
|
VarNode* getRepresentativeVariable (void) const { return representVar_; }
|
||||||
void setRepresentativeVariable (FgVarNode* v) { representVar_ = v; }
|
void setRepresentativeVariable (VarNode* v) { representVar_ = v; }
|
||||||
const FgVarSet& getGroundFgVarNodes (void) const { return groundVars_; }
|
const VarNodes& getGroundVarNodes (void) const { return groundVars_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
FgVarSet groundVars_;
|
VarNodes groundVars_;
|
||||||
FacClusterSet facClusters_;
|
FacClusterSet facClusters_;
|
||||||
FgVarNode* representVar_;
|
VarNode* representVar_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
class FacCluster
|
class FacCluster
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
FacCluster (const FgFacSet& groundFactors, const VarClusterSet& vcs)
|
FacCluster (const FactorNodes& groundFactors, const VarClusterSet& vcs)
|
||||||
{
|
{
|
||||||
groundFactors_ = groundFactors;
|
groundFactors_ = groundFactors;
|
||||||
varClusters_ = vcs;
|
varClusters_ = vcs;
|
||||||
@ -132,7 +132,7 @@ class FacCluster
|
|||||||
return varClusters_;
|
return varClusters_;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool containsGround (const FgFacNode* fn)
|
bool containsGround (const FactorNode* fn)
|
||||||
{
|
{
|
||||||
for (unsigned i = 0; i < groundFactors_.size(); i++) {
|
for (unsigned i = 0; i < groundFactors_.size(); i++) {
|
||||||
if (groundFactors_[i] == fn) {
|
if (groundFactors_[i] == fn) {
|
||||||
@ -142,26 +142,26 @@ class FacCluster
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
FgFacNode* getRepresentativeFactor (void) const
|
FactorNode* getRepresentativeFactor (void) const
|
||||||
{
|
{
|
||||||
return representFactor_;
|
return representFactor_;
|
||||||
}
|
}
|
||||||
|
|
||||||
void setRepresentativeFactor (FgFacNode* fn)
|
void setRepresentativeFactor (FactorNode* fn)
|
||||||
{
|
{
|
||||||
representFactor_ = fn;
|
representFactor_ = fn;
|
||||||
}
|
}
|
||||||
|
|
||||||
const FgFacSet& getGroundFactors (void) const
|
const FactorNodes& getGroundFactors (void) const
|
||||||
{
|
{
|
||||||
return groundFactors_;
|
return groundFactors_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
FgFacSet groundFactors_;
|
FactorNodes groundFactors_;
|
||||||
VarClusterSet varClusters_;
|
VarClusterSet varClusters_;
|
||||||
FgFacNode* representFactor_;
|
FactorNode* representFactor_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
@ -176,7 +176,7 @@ class CFactorGraph
|
|||||||
|
|
||||||
const FacClusterSet& getFacClusters (void) { return facClusters_; }
|
const FacClusterSet& getFacClusters (void) { return facClusters_; }
|
||||||
|
|
||||||
FgVarNode* getEquivalentVariable (VarId vid)
|
VarNode* getEquivalentVariable (VarId vid)
|
||||||
{
|
{
|
||||||
VarCluster* vc = vid2VarCluster_.find (vid)->second;
|
VarCluster* vc = vid2VarCluster_.find (vid)->second;
|
||||||
return vc->getRepresentativeVariable();
|
return vc->getRepresentativeVariable();
|
||||||
@ -195,20 +195,20 @@ class CFactorGraph
|
|||||||
return freeColor_ - 1;
|
return freeColor_ - 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
Color getColor (const FgVarNode* vn) const
|
Color getColor (const VarNode* vn) const
|
||||||
{
|
{
|
||||||
return varColors_[vn->getIndex()];
|
return varColors_[vn->getIndex()];
|
||||||
}
|
}
|
||||||
Color getColor (const FgFacNode* fn) const {
|
Color getColor (const FactorNode* fn) const {
|
||||||
return factorColors_[fn->getIndex()];
|
return factorColors_[fn->getIndex()];
|
||||||
}
|
}
|
||||||
|
|
||||||
void setColor (const FgVarNode* vn, Color c)
|
void setColor (const VarNode* vn, Color c)
|
||||||
{
|
{
|
||||||
varColors_[vn->getIndex()] = c;
|
varColors_[vn->getIndex()] = c;
|
||||||
}
|
}
|
||||||
|
|
||||||
void setColor (const FgFacNode* fn, Color c)
|
void setColor (const FactorNode* fn, Color c)
|
||||||
{
|
{
|
||||||
factorColors_[fn->getIndex()] = c;
|
factorColors_[fn->getIndex()] = c;
|
||||||
}
|
}
|
||||||
@ -224,9 +224,9 @@ class CFactorGraph
|
|||||||
|
|
||||||
void createClusters (const VarSignMap&, const FacSignMap&);
|
void createClusters (const VarSignMap&, const FacSignMap&);
|
||||||
|
|
||||||
const Signature& getSignature (const FgVarNode*);
|
const Signature& getSignature (const VarNode*);
|
||||||
|
|
||||||
const Signature& getSignature (const FgFacNode*);
|
const Signature& getSignature (const FactorNode*);
|
||||||
|
|
||||||
void printGroups (const VarSignMap&, const FacSignMap&) const;
|
void printGroups (const VarSignMap&, const FacSignMap&) const;
|
||||||
|
|
||||||
|
@ -17,7 +17,7 @@ Params
|
|||||||
CbpSolver::getPosterioriOf (VarId vid)
|
CbpSolver::getPosterioriOf (VarId vid)
|
||||||
{
|
{
|
||||||
assert (lfg_->getEquivalentVariable (vid));
|
assert (lfg_->getEquivalentVariable (vid));
|
||||||
FgVarNode* var = lfg_->getEquivalentVariable (vid);
|
VarNode* var = lfg_->getEquivalentVariable (vid);
|
||||||
Params probs;
|
Params probs;
|
||||||
if (var->hasEvidence()) {
|
if (var->hasEvidence()) {
|
||||||
probs.resize (var->range(), LogAware::noEvidence());
|
probs.resize (var->range(), LogAware::noEvidence());
|
||||||
@ -52,7 +52,7 @@ CbpSolver::getJointDistributionOf (const VarIds& jointVarIds)
|
|||||||
for (unsigned i = 0; i < jointVarIds.size(); i++) {
|
for (unsigned i = 0; i < jointVarIds.size(); i++) {
|
||||||
eqVarIds.push_back (lfg_->getEquivalentVariable (jointVarIds[i])->varId());
|
eqVarIds.push_back (lfg_->getEquivalentVariable (jointVarIds[i])->varId());
|
||||||
}
|
}
|
||||||
return FgBpSolver::getJointDistributionOf (eqVarIds);
|
return BpSolver::getJointDistributionOf (eqVarIds);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -63,12 +63,12 @@ CbpSolver::initializeSolver (void)
|
|||||||
{
|
{
|
||||||
unsigned nGroundVars, nGroundFacs, nWithoutNeighs;
|
unsigned nGroundVars, nGroundFacs, nWithoutNeighs;
|
||||||
if (Constants::COLLECT_STATS) {
|
if (Constants::COLLECT_STATS) {
|
||||||
nGroundVars = factorGraph_->getVarNodes().size();
|
nGroundVars = factorGraph_->varNodes().size();
|
||||||
nGroundFacs = factorGraph_->getFactorNodes().size();
|
nGroundFacs = factorGraph_->factorNodes().size();
|
||||||
const FgVarSet& vars = factorGraph_->getVarNodes();
|
const VarNodes& vars = factorGraph_->varNodes();
|
||||||
nWithoutNeighs = 0;
|
nWithoutNeighs = 0;
|
||||||
for (unsigned i = 0; i < vars.size(); i++) {
|
for (unsigned i = 0; i < vars.size(); i++) {
|
||||||
const FgFacSet& factors = vars[i]->neighbors();
|
const FactorNodes& factors = vars[i]->neighbors();
|
||||||
if (factors.size() == 1 && factors[0]->neighbors().size() == 1) {
|
if (factors.size() == 1 && factors[0]->neighbors().size() == 1) {
|
||||||
nWithoutNeighs ++;
|
nWithoutNeighs ++;
|
||||||
}
|
}
|
||||||
@ -78,23 +78,23 @@ CbpSolver::initializeSolver (void)
|
|||||||
lfg_ = new CFactorGraph (*factorGraph_);
|
lfg_ = new CFactorGraph (*factorGraph_);
|
||||||
|
|
||||||
// cout << "Uncompressed Factor Graph" << endl;
|
// cout << "Uncompressed Factor Graph" << endl;
|
||||||
// factorGraph_->printGraphicalModel();
|
// factorGraph_->print();
|
||||||
// factorGraph_->exportToGraphViz ("uncompressed_fg.dot");
|
// factorGraph_->exportToGraphViz ("uncompressed_fg.dot");
|
||||||
factorGraph_ = lfg_->getCompressedFactorGraph();
|
factorGraph_ = lfg_->getCompressedFactorGraph();
|
||||||
|
|
||||||
if (Constants::COLLECT_STATS) {
|
if (Constants::COLLECT_STATS) {
|
||||||
unsigned nClusterVars = factorGraph_->getVarNodes().size();
|
unsigned nClusterVars = factorGraph_->varNodes().size();
|
||||||
unsigned nClusterFacs = factorGraph_->getFactorNodes().size();
|
unsigned nClusterFacs = factorGraph_->factorNodes().size();
|
||||||
Statistics::updateCompressingStatistics (nGroundVars, nGroundFacs,
|
Statistics::updateCompressingStatistics (nGroundVars, nGroundFacs,
|
||||||
nClusterVars, nClusterFacs,
|
nClusterVars, nClusterFacs,
|
||||||
nWithoutNeighs);
|
nWithoutNeighs);
|
||||||
}
|
}
|
||||||
|
|
||||||
// cout << "Compressed Factor Graph" << endl;
|
// cout << "Compressed Factor Graph" << endl;
|
||||||
// factorGraph_->printGraphicalModel();
|
// factorGraph_->print();
|
||||||
// factorGraph_->exportToGraphViz ("compressed_fg.dot");
|
// factorGraph_->exportToGraphViz ("compressed_fg.dot");
|
||||||
// abort();
|
// abort();
|
||||||
FgBpSolver::initializeSolver();
|
BpSolver::initializeSolver();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -154,7 +154,7 @@ CbpSolver::maxResidualSchedule (void)
|
|||||||
linkMap_.find (link)->second = sortedOrder_.insert (link);
|
linkMap_.find (link)->second = sortedOrder_.insert (link);
|
||||||
|
|
||||||
// update the messages that depend on message source --> destin
|
// update the messages that depend on message source --> destin
|
||||||
const FgFacSet& factorNeighbors = link->getVariable()->neighbors();
|
const FactorNodes& factorNeighbors = link->getVariable()->neighbors();
|
||||||
for (unsigned i = 0; i < factorNeighbors.size(); i++) {
|
for (unsigned i = 0; i < factorNeighbors.size(); i++) {
|
||||||
const SpLinkSet& links = ninf(factorNeighbors[i])->getLinks();
|
const SpLinkSet& links = ninf(factorNeighbors[i])->getLinks();
|
||||||
for (unsigned j = 0; j < links.size(); j++) {
|
for (unsigned j = 0; j < links.size(); j++) {
|
||||||
@ -192,8 +192,8 @@ Params
|
|||||||
CbpSolver::getVar2FactorMsg (const SpLink* link) const
|
CbpSolver::getVar2FactorMsg (const SpLink* link) const
|
||||||
{
|
{
|
||||||
Params msg;
|
Params msg;
|
||||||
const FgVarNode* src = link->getVariable();
|
const VarNode* src = link->getVariable();
|
||||||
const FgFacNode* dst = link->getFactor();
|
const FactorNode* dst = link->getFactor();
|
||||||
const CbpSolverLink* l = static_cast<const CbpSolverLink*> (link);
|
const CbpSolverLink* l = static_cast<const CbpSolverLink*> (link);
|
||||||
if (src->hasEvidence()) {
|
if (src->hasEvidence()) {
|
||||||
msg.resize (src->range(), LogAware::noEvidence());
|
msg.resize (src->range(), LogAware::noEvidence());
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
#ifndef HORUS_CBP_H
|
#ifndef HORUS_CBP_H
|
||||||
#define HORUS_CBP_H
|
#define HORUS_CBP_H
|
||||||
|
|
||||||
#include "FgBpSolver.h"
|
#include "BpSolver.h"
|
||||||
#include "CFactorGraph.h"
|
#include "CFactorGraph.h"
|
||||||
|
|
||||||
class Factor;
|
class Factor;
|
||||||
@ -9,7 +9,7 @@ class Factor;
|
|||||||
class CbpSolverLink : public SpLink
|
class CbpSolverLink : public SpLink
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
CbpSolverLink (FgFacNode* fn, FgVarNode* vn, unsigned c) : SpLink (fn, vn)
|
CbpSolverLink (FactorNode* fn, VarNode* vn, unsigned c) : SpLink (fn, vn)
|
||||||
{
|
{
|
||||||
edgeCount_ = c;
|
edgeCount_ = c;
|
||||||
poweredMsg_.resize (vn->range(), LogAware::one());
|
poweredMsg_.resize (vn->range(), LogAware::one());
|
||||||
@ -34,10 +34,10 @@ class CbpSolverLink : public SpLink
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
class CbpSolver : public FgBpSolver
|
class CbpSolver : public BpSolver
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
CbpSolver (FactorGraph& fg) : FgBpSolver (fg) { }
|
CbpSolver (FactorGraph& fg) : BpSolver (fg) { }
|
||||||
|
|
||||||
~CbpSolver (void);
|
~CbpSolver (void);
|
||||||
|
|
||||||
|
@ -17,10 +17,10 @@ enum ElimHeuristic
|
|||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
class EgNode : public VarNode
|
class EgNode : public Var
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
EgNode (VarId vid, unsigned range) : VarNode (vid, range) { }
|
EgNode (VarId vid, unsigned range) : Var (vid, range) { }
|
||||||
|
|
||||||
void addNeighbor (EgNode* n) { neighs_.push_back (n); }
|
void addNeighbor (EgNode* n) { neighs_.push_back (n); }
|
||||||
|
|
||||||
|
@ -29,7 +29,7 @@ Factor::Factor (VarId vid, unsigned range)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
Factor::Factor (const VarNodes& vars)
|
Factor::Factor (const Vars& vars)
|
||||||
{
|
{
|
||||||
int nrParams = 1;
|
int nrParams = 1;
|
||||||
for (unsigned i = 0; i < vars.size(); i++) {
|
for (unsigned i = 0; i < vars.size(); i++) {
|
||||||
@ -60,7 +60,7 @@ Factor::Factor (
|
|||||||
|
|
||||||
|
|
||||||
Factor::Factor (
|
Factor::Factor (
|
||||||
const VarNodes& vars,
|
const Vars& vars,
|
||||||
const Params& params,
|
const Params& params,
|
||||||
unsigned distId)
|
unsigned distId)
|
||||||
{
|
{
|
||||||
@ -267,7 +267,7 @@ Factor::getLabel (void) const
|
|||||||
ss << "f(" ;
|
ss << "f(" ;
|
||||||
for (unsigned i = 0; i < args_.size(); i++) {
|
for (unsigned i = 0; i < args_.size(); i++) {
|
||||||
if (i != 0) ss << "," ;
|
if (i != 0) ss << "," ;
|
||||||
ss << VarNode (args_[i], ranges_[i]).label();
|
ss << Var (args_[i], ranges_[i]).label();
|
||||||
}
|
}
|
||||||
ss << ")" ;
|
ss << ")" ;
|
||||||
return ss.str();
|
return ss.str();
|
||||||
@ -278,9 +278,9 @@ Factor::getLabel (void) const
|
|||||||
void
|
void
|
||||||
Factor::print (void) const
|
Factor::print (void) const
|
||||||
{
|
{
|
||||||
VarNodes vars;
|
Vars vars;
|
||||||
for (unsigned i = 0; i < args_.size(); i++) {
|
for (unsigned i = 0; i < args_.size(); i++) {
|
||||||
vars.push_back (new VarNode (args_[i], ranges_[i]));
|
vars.push_back (new Var (args_[i], ranges_[i]));
|
||||||
}
|
}
|
||||||
vector<string> jointStrings = Util::getJointStateStrings (vars);
|
vector<string> jointStrings = Util::getJointStateStrings (vars);
|
||||||
for (unsigned i = 0; i < params_.size(); i++) {
|
for (unsigned i = 0; i < params_.size(); i++) {
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "VarNode.h"
|
#include "Var.h"
|
||||||
#include "Indexer.h"
|
#include "Indexer.h"
|
||||||
#include "Util.h"
|
#include "Util.h"
|
||||||
|
|
||||||
@ -260,11 +260,11 @@ class Factor : public TFactor<VarId>
|
|||||||
|
|
||||||
Factor (VarId, unsigned);
|
Factor (VarId, unsigned);
|
||||||
|
|
||||||
Factor (const VarNodes&);
|
Factor (const Vars&);
|
||||||
|
|
||||||
Factor (VarId, unsigned, const Params&);
|
Factor (VarId, unsigned, const Params&);
|
||||||
|
|
||||||
Factor (const VarNodes&, const Params&,
|
Factor (const Vars&, const Params&,
|
||||||
unsigned = Util::maxUnsigned());
|
unsigned = Util::maxUnsigned());
|
||||||
|
|
||||||
Factor (const VarIds&, const Ranges&, const Params&,
|
Factor (const VarIds&, const Ranges&, const Params&,
|
||||||
|
@ -18,17 +18,17 @@ bool FactorGraph::orderFactorVariables = false;
|
|||||||
|
|
||||||
FactorGraph::FactorGraph (const FactorGraph& fg)
|
FactorGraph::FactorGraph (const FactorGraph& fg)
|
||||||
{
|
{
|
||||||
const FgVarSet& vars = fg.getVarNodes();
|
const VarNodes& vars = fg.varNodes();
|
||||||
for (unsigned i = 0; i < vars.size(); i++) {
|
for (unsigned i = 0; i < vars.size(); i++) {
|
||||||
FgVarNode* varNode = new FgVarNode (vars[i]);
|
VarNode* varNode = new VarNode (vars[i]);
|
||||||
addVariable (varNode);
|
addVariable (varNode);
|
||||||
}
|
}
|
||||||
|
|
||||||
const FgFacSet& facs = fg.getFactorNodes();
|
const FactorNodes& facs = fg.factorNodes();
|
||||||
for (unsigned i = 0; i < facs.size(); i++) {
|
for (unsigned i = 0; i < facs.size(); i++) {
|
||||||
FgFacNode* facNode = new FgFacNode (facs[i]);
|
FactorNode* facNode = new FactorNode (facs[i]);
|
||||||
addFactor (facNode);
|
addFactor (facNode);
|
||||||
const FgVarSet& neighs = facs[i]->neighbors();
|
const VarNodes& neighs = facs[i]->neighbors();
|
||||||
for (unsigned j = 0; j < neighs.size(); j++) {
|
for (unsigned j = 0; j < neighs.size(); j++) {
|
||||||
addEdge (facNode, varNodes_[neighs[j]->getIndex()]);
|
addEdge (facNode, varNodes_[neighs[j]->getIndex()]);
|
||||||
}
|
}
|
||||||
@ -68,7 +68,7 @@ FactorGraph::readFromUaiFormat (const char* fileName)
|
|||||||
|
|
||||||
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
|
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
|
||||||
for (unsigned i = 0; i < nVars; i++) {
|
for (unsigned i = 0; i < nVars; i++) {
|
||||||
addVariable (new FgVarNode (i, domainSizes[i]));
|
addVariable (new VarNode (i, domainSizes[i]));
|
||||||
}
|
}
|
||||||
|
|
||||||
unsigned nFactors;
|
unsigned nFactors;
|
||||||
@ -77,21 +77,21 @@ FactorGraph::readFromUaiFormat (const char* fileName)
|
|||||||
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
|
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
|
||||||
unsigned nFactorVars;
|
unsigned nFactorVars;
|
||||||
is >> nFactorVars;
|
is >> nFactorVars;
|
||||||
VarNodes neighs;
|
Vars neighs;
|
||||||
for (unsigned j = 0; j < nFactorVars; j++) {
|
for (unsigned j = 0; j < nFactorVars; j++) {
|
||||||
unsigned vid;
|
unsigned vid;
|
||||||
is >> vid;
|
is >> vid;
|
||||||
FgVarNode* neigh = getFgVarNode (vid);
|
VarNode* neigh = getVarNode (vid);
|
||||||
if (!neigh) {
|
if (!neigh) {
|
||||||
cerr << "error: invalid variable identifier (" << vid << ")" << endl;
|
cerr << "error: invalid variable identifier (" << vid << ")" << endl;
|
||||||
abort();
|
abort();
|
||||||
}
|
}
|
||||||
neighs.push_back (neigh);
|
neighs.push_back (neigh);
|
||||||
}
|
}
|
||||||
FgFacNode* fn = new FgFacNode (new Factor (neighs));
|
FactorNode* fn = new FactorNode (new Factor (neighs));
|
||||||
addFactor (fn);
|
addFactor (fn);
|
||||||
for (unsigned j = 0; j < neighs.size(); j++) {
|
for (unsigned j = 0; j < neighs.size(); j++) {
|
||||||
addEdge (fn, static_cast<FgVarNode*> (neighs[j]));
|
addEdge (fn, static_cast<VarNode*> (neighs[j]));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -162,15 +162,15 @@ FactorGraph::readFromLibDaiFormat (const char* fileName)
|
|||||||
vids.push_back (vid);
|
vids.push_back (vid);
|
||||||
}
|
}
|
||||||
|
|
||||||
VarNodes neighs;
|
Vars neighs;
|
||||||
unsigned nParams = 1;
|
unsigned nParams = 1;
|
||||||
for (unsigned j = 0; j < nVars; j++) {
|
for (unsigned j = 0; j < nVars; j++) {
|
||||||
unsigned dsize;
|
unsigned dsize;
|
||||||
while ((is.peek()) == '#') getline (is, line);
|
while ((is.peek()) == '#') getline (is, line);
|
||||||
is >> dsize;
|
is >> dsize;
|
||||||
FgVarNode* var = getFgVarNode (vids[j]);
|
VarNode* var = getVarNode (vids[j]);
|
||||||
if (var == 0) {
|
if (var == 0) {
|
||||||
var = new FgVarNode (vids[j], dsize);
|
var = new VarNode (vids[j], dsize);
|
||||||
addVariable (var);
|
addVariable (var);
|
||||||
} else {
|
} else {
|
||||||
if (var->range() != dsize) {
|
if (var->range() != dsize) {
|
||||||
@ -199,10 +199,10 @@ FactorGraph::readFromLibDaiFormat (const char* fileName)
|
|||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
Util::toLog (params);
|
Util::toLog (params);
|
||||||
}
|
}
|
||||||
FgFacNode* fn = new FgFacNode (new Factor (neighs, params));
|
FactorNode* fn = new FactorNode (new Factor (neighs, params));
|
||||||
addFactor (fn);
|
addFactor (fn);
|
||||||
for (unsigned j = 0; j < neighs.size(); j++) {
|
for (unsigned j = 0; j < neighs.size(); j++) {
|
||||||
addEdge (fn, static_cast<FgVarNode*> (neighs[j]));
|
addEdge (fn, static_cast<VarNode*> (neighs[j]));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
is.close();
|
is.close();
|
||||||
@ -224,7 +224,7 @@ FactorGraph::~FactorGraph (void)
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
FactorGraph::addVariable (FgVarNode* vn)
|
FactorGraph::addVariable (VarNode* vn)
|
||||||
{
|
{
|
||||||
varNodes_.push_back (vn);
|
varNodes_.push_back (vn);
|
||||||
vn->setIndex (varNodes_.size() - 1);
|
vn->setIndex (varNodes_.size() - 1);
|
||||||
@ -234,7 +234,7 @@ FactorGraph::addVariable (FgVarNode* vn)
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
FactorGraph::addFactor (FgFacNode* fn)
|
FactorGraph::addFactor (FactorNode* fn)
|
||||||
{
|
{
|
||||||
facNodes_.push_back (fn);
|
facNodes_.push_back (fn);
|
||||||
fn->setIndex (facNodes_.size() - 1);
|
fn->setIndex (facNodes_.size() - 1);
|
||||||
@ -245,7 +245,7 @@ FactorGraph::addFactor (FgFacNode* fn)
|
|||||||
void
|
void
|
||||||
FactorGraph::addFactor (const Factor& factor)
|
FactorGraph::addFactor (const Factor& factor)
|
||||||
{
|
{
|
||||||
FgFacNode* fn = new FgFacNode (factor);
|
FactorNode* fn = new FactorNode (factor);
|
||||||
addFactor (fn);
|
addFactor (fn);
|
||||||
const VarIds& vids = factor.arguments();
|
const VarIds& vids = factor.arguments();
|
||||||
for (unsigned i = 0; i < vids.size(); i++) {
|
for (unsigned i = 0; i < vids.size(); i++) {
|
||||||
@ -257,7 +257,7 @@ FactorGraph::addFactor (const Factor& factor)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (found == false) {
|
if (found == false) {
|
||||||
FgVarNode* vn = new FgVarNode (vids[i], factor.range (i));
|
VarNode* vn = new VarNode (vids[i], factor.range (i));
|
||||||
addVariable (vn);
|
addVariable (vn);
|
||||||
addEdge (vn, fn);
|
addEdge (vn, fn);
|
||||||
}
|
}
|
||||||
@ -267,7 +267,7 @@ FactorGraph::addFactor (const Factor& factor)
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
FactorGraph::addEdge (FgVarNode* vn, FgFacNode* fn)
|
FactorGraph::addEdge (VarNode* vn, FactorNode* fn)
|
||||||
{
|
{
|
||||||
vn->addNeighbor (fn);
|
vn->addNeighbor (fn);
|
||||||
fn->addNeighbor (vn);
|
fn->addNeighbor (vn);
|
||||||
@ -276,7 +276,7 @@ FactorGraph::addEdge (FgVarNode* vn, FgFacNode* fn)
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
FactorGraph::addEdge (FgFacNode* fn, FgVarNode* vn)
|
FactorGraph::addEdge (FactorNode* fn, VarNode* vn)
|
||||||
{
|
{
|
||||||
fn->addNeighbor (vn);
|
fn->addNeighbor (vn);
|
||||||
vn->addNeighbor (fn);
|
vn->addNeighbor (fn);
|
||||||
@ -284,28 +284,6 @@ FactorGraph::addEdge (FgFacNode* fn, FgVarNode* vn)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
VarNode*
|
|
||||||
FactorGraph::getVariableNode (VarId vid) const
|
|
||||||
{
|
|
||||||
FgVarNode* vn = getFgVarNode (vid);
|
|
||||||
assert (vn);
|
|
||||||
return vn;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
VarNodes
|
|
||||||
FactorGraph::getVariableNodes (void) const
|
|
||||||
{
|
|
||||||
VarNodes vars;
|
|
||||||
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
|
||||||
vars.push_back (varNodes_[i]);
|
|
||||||
}
|
|
||||||
return vars;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
bool
|
bool
|
||||||
FactorGraph::isTree (void) const
|
FactorGraph::isTree (void) const
|
||||||
{
|
{
|
||||||
@ -348,7 +326,7 @@ FactorGraph::setIndexes (void)
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
FactorGraph::printGraphicalModel (void) const
|
FactorGraph::print (void) const
|
||||||
{
|
{
|
||||||
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
||||||
cout << "var id = " << varNodes_[i]->varId() << endl;
|
cout << "var id = " << varNodes_[i]->varId() << endl;
|
||||||
@ -390,7 +368,7 @@ FactorGraph::exportToGraphViz (const char* fileName) const
|
|||||||
out << "\"" << ", shape=box]" << endl;
|
out << "\"" << ", shape=box]" << endl;
|
||||||
}
|
}
|
||||||
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
||||||
const FgVarSet& myVars = facNodes_[i]->neighbors();
|
const VarNodes& myVars = facNodes_[i]->neighbors();
|
||||||
for (unsigned j = 0; j < myVars.size(); j++) {
|
for (unsigned j = 0; j < myVars.size(); j++) {
|
||||||
out << '"' << facNodes_[i]->getLabel() << '"' ;
|
out << '"' << facNodes_[i]->getLabel() << '"' ;
|
||||||
out << " -- " ;
|
out << " -- " ;
|
||||||
@ -422,7 +400,7 @@ FactorGraph::exportToUaiFormat (const char* fileName) const
|
|||||||
|
|
||||||
out << facNodes_.size() << endl;
|
out << facNodes_.size() << endl;
|
||||||
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
||||||
const FgVarSet& factorVars = facNodes_[i]->neighbors();
|
const VarNodes& factorVars = facNodes_[i]->neighbors();
|
||||||
out << factorVars.size();
|
out << factorVars.size();
|
||||||
for (unsigned j = 0; j < factorVars.size(); j++) {
|
for (unsigned j = 0; j < factorVars.size(); j++) {
|
||||||
out << " " << factorVars[j]->getIndex();
|
out << " " << factorVars[j]->getIndex();
|
||||||
@ -458,7 +436,7 @@ FactorGraph::exportToLibDaiFormat (const char* fileName) const
|
|||||||
}
|
}
|
||||||
out << facNodes_.size() << endl << endl;
|
out << facNodes_.size() << endl << endl;
|
||||||
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
||||||
const FgVarSet& factorVars = facNodes_[i]->neighbors();
|
const VarNodes& factorVars = facNodes_[i]->neighbors();
|
||||||
out << factorVars.size() << endl;
|
out << factorVars.size() << endl;
|
||||||
for (int j = factorVars.size() - 1; j >= 0; j--) {
|
for (int j = factorVars.size() - 1; j >= 0; j--) {
|
||||||
out << factorVars[j]->varId() << " " ;
|
out << factorVars[j]->varId() << " " ;
|
||||||
@ -503,13 +481,13 @@ FactorGraph::containsCycle (void) const
|
|||||||
|
|
||||||
bool
|
bool
|
||||||
FactorGraph::containsCycle (
|
FactorGraph::containsCycle (
|
||||||
const FgVarNode* v,
|
const VarNode* v,
|
||||||
const FgFacNode* p,
|
const FactorNode* p,
|
||||||
vector<bool>& visitedVars,
|
vector<bool>& visitedVars,
|
||||||
vector<bool>& visitedFactors) const
|
vector<bool>& visitedFactors) const
|
||||||
{
|
{
|
||||||
visitedVars[v->getIndex()] = true;
|
visitedVars[v->getIndex()] = true;
|
||||||
const FgFacSet& adjacencies = v->neighbors();
|
const FactorNodes& adjacencies = v->neighbors();
|
||||||
for (unsigned i = 0; i < adjacencies.size(); i++) {
|
for (unsigned i = 0; i < adjacencies.size(); i++) {
|
||||||
int w = adjacencies[i]->getIndex();
|
int w = adjacencies[i]->getIndex();
|
||||||
if (!visitedFactors[w]) {
|
if (!visitedFactors[w]) {
|
||||||
@ -528,13 +506,13 @@ FactorGraph::containsCycle (
|
|||||||
|
|
||||||
bool
|
bool
|
||||||
FactorGraph::containsCycle (
|
FactorGraph::containsCycle (
|
||||||
const FgFacNode* v,
|
const FactorNode* v,
|
||||||
const FgVarNode* p,
|
const VarNode* p,
|
||||||
vector<bool>& visitedVars,
|
vector<bool>& visitedVars,
|
||||||
vector<bool>& visitedFactors) const
|
vector<bool>& visitedFactors) const
|
||||||
{
|
{
|
||||||
visitedFactors[v->getIndex()] = true;
|
visitedFactors[v->getIndex()] = true;
|
||||||
const FgVarSet& adjacencies = v->neighbors();
|
const VarNodes& adjacencies = v->neighbors();
|
||||||
for (unsigned i = 0; i < adjacencies.size(); i++) {
|
for (unsigned i = 0; i < adjacencies.size(); i++) {
|
||||||
int w = adjacencies[i]->getIndex();
|
int w = adjacencies[i]->getIndex();
|
||||||
if (!visitedVars[w]) {
|
if (!visitedVars[w]) {
|
||||||
|
@ -4,52 +4,51 @@
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "Factor.h"
|
#include "Factor.h"
|
||||||
#include "GraphicalModel.h"
|
|
||||||
#include "BayesNet.h"
|
#include "BayesNet.h"
|
||||||
#include "Horus.h"
|
#include "Horus.h"
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
|
|
||||||
class FgFacNode;
|
class FactorNode;
|
||||||
|
|
||||||
|
|
||||||
class FgVarNode : public VarNode
|
class VarNode : public Var
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
FgVarNode (VarId varId, unsigned nrStates) : VarNode (varId, nrStates) { }
|
VarNode (VarId varId, unsigned nrStates) : Var (varId, nrStates) { }
|
||||||
|
|
||||||
FgVarNode (const VarNode* v) : VarNode (v) { }
|
VarNode (const Var* v) : Var (v) { }
|
||||||
|
|
||||||
void addNeighbor (FgFacNode* fn) { neighs_.push_back (fn); }
|
void addNeighbor (FactorNode* fn) { neighs_.push_back (fn); }
|
||||||
|
|
||||||
const FgFacSet& neighbors (void) const { return neighs_; }
|
const FactorNodes& neighbors (void) const { return neighs_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
DISALLOW_COPY_AND_ASSIGN (FgVarNode);
|
DISALLOW_COPY_AND_ASSIGN (VarNode);
|
||||||
|
|
||||||
FgFacSet neighs_;
|
FactorNodes neighs_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
class FgFacNode
|
class FactorNode
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
FgFacNode (const FgFacNode* fn)
|
FactorNode (const FactorNode* fn)
|
||||||
{
|
{
|
||||||
factor_ = new Factor (*fn->factor());
|
factor_ = new Factor (*fn->factor());
|
||||||
index_ = -1;
|
index_ = -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
FgFacNode (Factor* f) : factor_(new Factor(*f)), index_(-1) { }
|
FactorNode (Factor* f) : factor_(new Factor(*f)), index_(-1) { }
|
||||||
|
|
||||||
FgFacNode (const Factor& f) : factor_(new Factor (f)), index_(-1) { }
|
FactorNode (const Factor& f) : factor_(new Factor (f)), index_(-1) { }
|
||||||
|
|
||||||
Factor* factor() const { return factor_; }
|
Factor* factor() const { return factor_; }
|
||||||
|
|
||||||
void addNeighbor (FgVarNode* vn) { neighs_.push_back (vn); }
|
void addNeighbor (VarNode* vn) { neighs_.push_back (vn); }
|
||||||
|
|
||||||
const FgVarSet& neighbors (void) const { return neighs_; }
|
const VarNodes& neighbors (void) const { return neighs_; }
|
||||||
|
|
||||||
int getIndex (void) const
|
int getIndex (void) const
|
||||||
{
|
{
|
||||||
@ -73,24 +72,24 @@ class FgFacNode
|
|||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
DISALLOW_COPY_AND_ASSIGN (FgFacNode);
|
DISALLOW_COPY_AND_ASSIGN (FactorNode);
|
||||||
|
|
||||||
Factor* factor_;
|
Factor* factor_;
|
||||||
FgVarSet neighs_;
|
VarNodes neighs_;
|
||||||
int index_;
|
int index_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
struct CompVarId
|
struct CompVarId
|
||||||
{
|
{
|
||||||
bool operator() (const VarNode* vn1, const VarNode* vn2) const
|
bool operator() (const Var* v1, const Var* v2) const
|
||||||
{
|
{
|
||||||
return vn1->varId() < vn2->varId();
|
return v1->varId() < v2->varId();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
class FactorGraph : public GraphicalModel
|
class FactorGraph
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
FactorGraph (void) { }
|
FactorGraph (void) { }
|
||||||
@ -99,15 +98,15 @@ class FactorGraph : public GraphicalModel
|
|||||||
|
|
||||||
~FactorGraph (void);
|
~FactorGraph (void);
|
||||||
|
|
||||||
const FgVarSet& getVarNodes (void) const { return varNodes_; }
|
const VarNodes& varNodes (void) const { return varNodes_; }
|
||||||
|
|
||||||
const FgFacSet& getFactorNodes (void) const { return facNodes_; }
|
const FactorNodes& factorNodes (void) const { return facNodes_; }
|
||||||
|
|
||||||
void setFromBayesNetwork (void) { fromBayesNet_ = true; }
|
void setFromBayesNetwork (void) { fromBayesNet_ = true; }
|
||||||
|
|
||||||
bool isFromBayesNetwork (void) const { return fromBayesNet_ ; }
|
bool isFromBayesNetwork (void) const { return fromBayesNet_ ; }
|
||||||
|
|
||||||
FgVarNode* getFgVarNode (VarId vid) const
|
VarNode* getVarNode (VarId vid) const
|
||||||
{
|
{
|
||||||
IndexMap::const_iterator it = varMap_.find (vid);
|
IndexMap::const_iterator it = varMap_.find (vid);
|
||||||
return (it != varMap_.end()) ? varNodes_[it->second] : 0;
|
return (it != varMap_.end()) ? varNodes_[it->second] : 0;
|
||||||
@ -117,19 +116,15 @@ class FactorGraph : public GraphicalModel
|
|||||||
|
|
||||||
void readFromLibDaiFormat (const char*);
|
void readFromLibDaiFormat (const char*);
|
||||||
|
|
||||||
void addVariable (FgVarNode*);
|
void addVariable (VarNode*);
|
||||||
|
|
||||||
void addFactor (FgFacNode*);
|
void addFactor (FactorNode*);
|
||||||
|
|
||||||
void addFactor (const Factor& factor);
|
void addFactor (const Factor& factor);
|
||||||
|
|
||||||
void addEdge (FgVarNode*, FgFacNode*);
|
void addEdge (VarNode*, FactorNode*);
|
||||||
|
|
||||||
void addEdge (FgFacNode*, FgVarNode*);
|
void addEdge (FactorNode*, VarNode*);
|
||||||
|
|
||||||
VarNode* getVariableNode (unsigned) const;
|
|
||||||
|
|
||||||
VarNodes getVariableNodes (void) const;
|
|
||||||
|
|
||||||
bool isTree (void) const;
|
bool isTree (void) const;
|
||||||
|
|
||||||
@ -137,7 +132,7 @@ class FactorGraph : public GraphicalModel
|
|||||||
|
|
||||||
void setIndexes (void);
|
void setIndexes (void);
|
||||||
|
|
||||||
void printGraphicalModel (void) const;
|
void print (void) const;
|
||||||
|
|
||||||
void exportToGraphViz (const char*) const;
|
void exportToGraphViz (const char*) const;
|
||||||
|
|
||||||
@ -152,14 +147,14 @@ class FactorGraph : public GraphicalModel
|
|||||||
|
|
||||||
bool containsCycle (void) const;
|
bool containsCycle (void) const;
|
||||||
|
|
||||||
bool containsCycle (const FgVarNode*, const FgFacNode*,
|
bool containsCycle (const VarNode*, const FactorNode*,
|
||||||
vector<bool>&, vector<bool>&) const;
|
vector<bool>&, vector<bool>&) const;
|
||||||
|
|
||||||
bool containsCycle (const FgFacNode*, const FgVarNode*,
|
bool containsCycle (const FactorNode*, const VarNode*,
|
||||||
vector<bool>&, vector<bool>&) const;
|
vector<bool>&, vector<bool>&) const;
|
||||||
|
|
||||||
FgVarSet varNodes_;
|
VarNodes varNodes_;
|
||||||
FgFacSet facNodes_;
|
FactorNodes facNodes_;
|
||||||
|
|
||||||
bool fromBayesNet_;
|
bool fromBayesNet_;
|
||||||
DAGraph structure_;
|
DAGraph structure_;
|
||||||
|
@ -1,496 +0,0 @@
|
|||||||
#include <cassert>
|
|
||||||
#include <limits>
|
|
||||||
|
|
||||||
#include <algorithm>
|
|
||||||
|
|
||||||
#include <iostream>
|
|
||||||
|
|
||||||
#include "FgBpSolver.h"
|
|
||||||
#include "FactorGraph.h"
|
|
||||||
#include "Factor.h"
|
|
||||||
#include "Indexer.h"
|
|
||||||
#include "Horus.h"
|
|
||||||
|
|
||||||
|
|
||||||
FgBpSolver::FgBpSolver (const FactorGraph& fg) : Solver (&fg)
|
|
||||||
{
|
|
||||||
factorGraph_ = &fg;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
FgBpSolver::~FgBpSolver (void)
|
|
||||||
{
|
|
||||||
for (unsigned i = 0; i < varsI_.size(); i++) {
|
|
||||||
delete varsI_[i];
|
|
||||||
}
|
|
||||||
for (unsigned i = 0; i < facsI_.size(); i++) {
|
|
||||||
delete facsI_[i];
|
|
||||||
}
|
|
||||||
for (unsigned i = 0; i < links_.size(); i++) {
|
|
||||||
delete links_[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
FgBpSolver::runSolver (void)
|
|
||||||
{
|
|
||||||
clock_t start;
|
|
||||||
if (Constants::COLLECT_STATS) {
|
|
||||||
start = clock();
|
|
||||||
}
|
|
||||||
runLoopySolver();
|
|
||||||
if (Constants::DEBUG >= 2) {
|
|
||||||
cout << endl;
|
|
||||||
if (nIters_ < BpOptions::maxIter) {
|
|
||||||
cout << "Sum-Product converged in " ;
|
|
||||||
cout << nIters_ << " iterations" << endl;
|
|
||||||
} else {
|
|
||||||
cout << "The maximum number of iterations was hit, terminating..." ;
|
|
||||||
cout << endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
unsigned size = factorGraph_->getVarNodes().size();
|
|
||||||
if (Constants::COLLECT_STATS) {
|
|
||||||
unsigned nIters = 0;
|
|
||||||
bool loopy = factorGraph_->isTree() == false;
|
|
||||||
if (loopy) nIters = nIters_;
|
|
||||||
double time = (double (clock() - start)) / CLOCKS_PER_SEC;
|
|
||||||
Statistics::updateStatistics (size, loopy, nIters, time);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Params
|
|
||||||
FgBpSolver::getPosterioriOf (VarId vid)
|
|
||||||
{
|
|
||||||
assert (factorGraph_->getFgVarNode (vid));
|
|
||||||
FgVarNode* var = factorGraph_->getFgVarNode (vid);
|
|
||||||
Params probs;
|
|
||||||
if (var->hasEvidence()) {
|
|
||||||
probs.resize (var->range(), LogAware::noEvidence());
|
|
||||||
probs[var->getEvidence()] = LogAware::withEvidence();
|
|
||||||
} else {
|
|
||||||
probs.resize (var->range(), LogAware::multIdenty());
|
|
||||||
const SpLinkSet& links = ninf(var)->getLinks();
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
for (unsigned i = 0; i < links.size(); i++) {
|
|
||||||
Util::add (probs, links[i]->getMessage());
|
|
||||||
}
|
|
||||||
LogAware::normalize (probs);
|
|
||||||
Util::fromLog (probs);
|
|
||||||
} else {
|
|
||||||
for (unsigned i = 0; i < links.size(); i++) {
|
|
||||||
Util::multiply (probs, links[i]->getMessage());
|
|
||||||
}
|
|
||||||
LogAware::normalize (probs);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return probs;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Params
|
|
||||||
FgBpSolver::getJointDistributionOf (const VarIds& jointVarIds)
|
|
||||||
{
|
|
||||||
int idx = -1;
|
|
||||||
FgVarNode* vn = factorGraph_->getFgVarNode (jointVarIds[0]);
|
|
||||||
const FgFacSet& factorNodes = vn->neighbors();
|
|
||||||
for (unsigned i = 0; i < factorNodes.size(); i++) {
|
|
||||||
if (factorNodes[i]->factor()->contains (jointVarIds)) {
|
|
||||||
idx = i;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (idx == -1) {
|
|
||||||
return getJointByConditioning (jointVarIds);
|
|
||||||
} else {
|
|
||||||
Factor res (*factorNodes[idx]->factor());
|
|
||||||
const SpLinkSet& links = ninf(factorNodes[idx])->getLinks();
|
|
||||||
for (unsigned i = 0; i < links.size(); i++) {
|
|
||||||
Factor msg (links[i]->getVariable()->varId(),
|
|
||||||
links[i]->getVariable()->range(),
|
|
||||||
getVar2FactorMsg (links[i]));
|
|
||||||
res.multiply (msg);
|
|
||||||
}
|
|
||||||
res.sumOutAllExcept (jointVarIds);
|
|
||||||
res.reorderArguments (jointVarIds);
|
|
||||||
res.normalize();
|
|
||||||
Params jointDist = res.params();
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
Util::fromLog (jointDist);
|
|
||||||
}
|
|
||||||
return jointDist;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
FgBpSolver::runLoopySolver (void)
|
|
||||||
{
|
|
||||||
initializeSolver();
|
|
||||||
nIters_ = 0;
|
|
||||||
|
|
||||||
while (!converged() && nIters_ < BpOptions::maxIter) {
|
|
||||||
|
|
||||||
nIters_ ++;
|
|
||||||
if (Constants::DEBUG >= 2) {
|
|
||||||
Util::printHeader (" Iteration " + nIters_);
|
|
||||||
cout << endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
switch (BpOptions::schedule) {
|
|
||||||
case BpOptions::Schedule::SEQ_RANDOM:
|
|
||||||
random_shuffle (links_.begin(), links_.end());
|
|
||||||
// no break
|
|
||||||
|
|
||||||
case BpOptions::Schedule::SEQ_FIXED:
|
|
||||||
for (unsigned i = 0; i < links_.size(); i++) {
|
|
||||||
calculateAndUpdateMessage (links_[i]);
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
|
|
||||||
case BpOptions::Schedule::PARALLEL:
|
|
||||||
for (unsigned i = 0; i < links_.size(); i++) {
|
|
||||||
calculateMessage (links_[i]);
|
|
||||||
}
|
|
||||||
for (unsigned i = 0; i < links_.size(); i++) {
|
|
||||||
updateMessage(links_[i]);
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
|
|
||||||
case BpOptions::Schedule::MAX_RESIDUAL:
|
|
||||||
maxResidualSchedule();
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
if (Constants::DEBUG >= 2) {
|
|
||||||
cout << endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
FgBpSolver::initializeSolver (void)
|
|
||||||
{
|
|
||||||
const FgVarSet& varNodes = factorGraph_->getVarNodes();
|
|
||||||
for (unsigned i = 0; i < varsI_.size(); i++) {
|
|
||||||
delete varsI_[i];
|
|
||||||
}
|
|
||||||
varsI_.reserve (varNodes.size());
|
|
||||||
for (unsigned i = 0; i < varNodes.size(); i++) {
|
|
||||||
varsI_.push_back (new SPNodeInfo());
|
|
||||||
}
|
|
||||||
|
|
||||||
const FgFacSet& facNodes = factorGraph_->getFactorNodes();
|
|
||||||
for (unsigned i = 0; i < facsI_.size(); i++) {
|
|
||||||
delete facsI_[i];
|
|
||||||
}
|
|
||||||
facsI_.reserve (facNodes.size());
|
|
||||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
|
||||||
facsI_.push_back (new SPNodeInfo());
|
|
||||||
}
|
|
||||||
|
|
||||||
for (unsigned i = 0; i < links_.size(); i++) {
|
|
||||||
delete links_[i];
|
|
||||||
}
|
|
||||||
createLinks();
|
|
||||||
|
|
||||||
for (unsigned i = 0; i < links_.size(); i++) {
|
|
||||||
FgFacNode* src = links_[i]->getFactor();
|
|
||||||
FgVarNode* dst = links_[i]->getVariable();
|
|
||||||
ninf (dst)->addSpLink (links_[i]);
|
|
||||||
ninf (src)->addSpLink (links_[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
FgBpSolver::createLinks (void)
|
|
||||||
{
|
|
||||||
const FgFacSet& facNodes = factorGraph_->getFactorNodes();
|
|
||||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
|
||||||
const FgVarSet& neighbors = facNodes[i]->neighbors();
|
|
||||||
for (unsigned j = 0; j < neighbors.size(); j++) {
|
|
||||||
links_.push_back (new SpLink (facNodes[i], neighbors[j]));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
bool
|
|
||||||
FgBpSolver::converged (void)
|
|
||||||
{
|
|
||||||
if (links_.size() == 0) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
if (nIters_ == 0 || nIters_ == 1) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
bool converged = true;
|
|
||||||
if (BpOptions::schedule == BpOptions::Schedule::MAX_RESIDUAL) {
|
|
||||||
double maxResidual = (*(sortedOrder_.begin()))->getResidual();
|
|
||||||
if (maxResidual > BpOptions::accuracy) {
|
|
||||||
converged = false;
|
|
||||||
} else {
|
|
||||||
converged = true;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (unsigned i = 0; i < links_.size(); i++) {
|
|
||||||
double residual = links_[i]->getResidual();
|
|
||||||
if (Constants::DEBUG >= 2) {
|
|
||||||
cout << links_[i]->toString() + " residual = " << residual << endl;
|
|
||||||
}
|
|
||||||
if (residual > BpOptions::accuracy) {
|
|
||||||
converged = false;
|
|
||||||
if (Constants::DEBUG == 0) break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return converged;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
FgBpSolver::maxResidualSchedule (void)
|
|
||||||
{
|
|
||||||
if (nIters_ == 1) {
|
|
||||||
for (unsigned i = 0; i < links_.size(); i++) {
|
|
||||||
calculateMessage (links_[i]);
|
|
||||||
SortedOrder::iterator it = sortedOrder_.insert (links_[i]);
|
|
||||||
linkMap_.insert (make_pair (links_[i], it));
|
|
||||||
}
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (unsigned c = 0; c < links_.size(); c++) {
|
|
||||||
if (Constants::DEBUG >= 2) {
|
|
||||||
cout << "current residuals:" << endl;
|
|
||||||
for (SortedOrder::iterator it = sortedOrder_.begin();
|
|
||||||
it != sortedOrder_.end(); it ++) {
|
|
||||||
cout << " " << setw (30) << left << (*it)->toString();
|
|
||||||
cout << "residual = " << (*it)->getResidual() << endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
SortedOrder::iterator it = sortedOrder_.begin();
|
|
||||||
SpLink* link = *it;
|
|
||||||
if (link->getResidual() < BpOptions::accuracy) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
updateMessage (link);
|
|
||||||
link->clearResidual();
|
|
||||||
sortedOrder_.erase (it);
|
|
||||||
linkMap_.find (link)->second = sortedOrder_.insert (link);
|
|
||||||
|
|
||||||
// update the messages that depend on message source --> destin
|
|
||||||
const FgFacSet& factorNeighbors = link->getVariable()->neighbors();
|
|
||||||
for (unsigned i = 0; i < factorNeighbors.size(); i++) {
|
|
||||||
if (factorNeighbors[i] != link->getFactor()) {
|
|
||||||
const SpLinkSet& links = ninf(factorNeighbors[i])->getLinks();
|
|
||||||
for (unsigned j = 0; j < links.size(); j++) {
|
|
||||||
if (links[j]->getVariable() != link->getVariable()) {
|
|
||||||
calculateMessage (links[j]);
|
|
||||||
SpLinkMap::iterator iter = linkMap_.find (links[j]);
|
|
||||||
sortedOrder_.erase (iter->second);
|
|
||||||
iter->second = sortedOrder_.insert (links[j]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (Constants::DEBUG >= 2) {
|
|
||||||
Util::printDashedLine();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
FgBpSolver::calculateFactor2VariableMsg (SpLink* link) const
|
|
||||||
{
|
|
||||||
const FgFacNode* src = link->getFactor();
|
|
||||||
const FgVarNode* dst = link->getVariable();
|
|
||||||
const SpLinkSet& links = ninf(src)->getLinks();
|
|
||||||
// calculate the product of messages that were sent
|
|
||||||
// to factor `src', except from var `dst'
|
|
||||||
unsigned msgSize = 1;
|
|
||||||
for (unsigned i = 0; i < links.size(); i++) {
|
|
||||||
msgSize *= links[i]->getVariable()->range();
|
|
||||||
}
|
|
||||||
unsigned repetitions = 1;
|
|
||||||
Params msgProduct (msgSize, LogAware::multIdenty());
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
for (int i = links.size() - 1; i >= 0; i--) {
|
|
||||||
if (links[i]->getVariable() != dst) {
|
|
||||||
Util::add (msgProduct, getVar2FactorMsg (links[i]), repetitions);
|
|
||||||
repetitions *= links[i]->getVariable()->range();
|
|
||||||
} else {
|
|
||||||
unsigned ds = links[i]->getVariable()->range();
|
|
||||||
Util::add (msgProduct, Params (ds, 1.0), repetitions);
|
|
||||||
repetitions *= ds;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (int i = links.size() - 1; i >= 0; i--) {
|
|
||||||
if (links[i]->getVariable() != dst) {
|
|
||||||
if (Constants::DEBUG >= 5) {
|
|
||||||
cout << " message from " << links[i]->getVariable()->label();
|
|
||||||
cout << ": " << endl;
|
|
||||||
}
|
|
||||||
Util::multiply (msgProduct, getVar2FactorMsg (links[i]), repetitions);
|
|
||||||
repetitions *= links[i]->getVariable()->range();
|
|
||||||
} else {
|
|
||||||
unsigned ds = links[i]->getVariable()->range();
|
|
||||||
Util::multiply (msgProduct, Params (ds, 1.0), repetitions);
|
|
||||||
repetitions *= ds;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Factor result (src->factor()->arguments(),
|
|
||||||
src->factor()->ranges(),
|
|
||||||
msgProduct);
|
|
||||||
result.multiply (*(src->factor()));
|
|
||||||
if (Constants::DEBUG >= 5) {
|
|
||||||
cout << " message product: " << msgProduct << endl;
|
|
||||||
cout << " original factor: " << src->params() << endl;
|
|
||||||
cout << " factor product: " << result.params() << endl;
|
|
||||||
}
|
|
||||||
result.sumOutAllExcept (dst->varId());
|
|
||||||
if (Constants::DEBUG >= 5) {
|
|
||||||
cout << " marginalized: " ;
|
|
||||||
cout << result.params() << endl;
|
|
||||||
}
|
|
||||||
const Params& resultParams = result.params();
|
|
||||||
Params& message = link->getNextMessage();
|
|
||||||
for (unsigned i = 0; i < resultParams.size(); i++) {
|
|
||||||
message[i] = resultParams[i];
|
|
||||||
}
|
|
||||||
LogAware::normalize (message);
|
|
||||||
if (Constants::DEBUG >= 5) {
|
|
||||||
cout << " curr msg: " << link->getMessage() << endl;
|
|
||||||
cout << " next msg: " << message << endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Params
|
|
||||||
FgBpSolver::getVar2FactorMsg (const SpLink* link) const
|
|
||||||
{
|
|
||||||
const FgVarNode* src = link->getVariable();
|
|
||||||
const FgFacNode* dst = link->getFactor();
|
|
||||||
Params msg;
|
|
||||||
if (src->hasEvidence()) {
|
|
||||||
msg.resize (src->range(), LogAware::noEvidence());
|
|
||||||
msg[src->getEvidence()] = LogAware::withEvidence();
|
|
||||||
if (Constants::DEBUG >= 5) {
|
|
||||||
cout << msg;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
msg.resize (src->range(), LogAware::one());
|
|
||||||
}
|
|
||||||
if (Constants::DEBUG >= 5) {
|
|
||||||
cout << msg;
|
|
||||||
}
|
|
||||||
const SpLinkSet& links = ninf (src)->getLinks();
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
for (unsigned i = 0; i < links.size(); i++) {
|
|
||||||
if (links[i]->getFactor() != dst) {
|
|
||||||
Util::add (msg, links[i]->getMessage());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (unsigned i = 0; i < links.size(); i++) {
|
|
||||||
if (links[i]->getFactor() != dst) {
|
|
||||||
Util::multiply (msg, links[i]->getMessage());
|
|
||||||
if (Constants::DEBUG >= 5) {
|
|
||||||
cout << " x " << links[i]->getMessage();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (Constants::DEBUG >= 5) {
|
|
||||||
cout << " = " << msg;
|
|
||||||
}
|
|
||||||
return msg;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Params
|
|
||||||
FgBpSolver::getJointByConditioning (const VarIds& jointVarIds) const
|
|
||||||
{
|
|
||||||
FgVarSet jointVars;
|
|
||||||
for (unsigned i = 0; i < jointVarIds.size(); i++) {
|
|
||||||
assert (factorGraph_->getFgVarNode (jointVarIds[i]));
|
|
||||||
jointVars.push_back (factorGraph_->getFgVarNode (jointVarIds[i]));
|
|
||||||
}
|
|
||||||
|
|
||||||
FactorGraph* fg = new FactorGraph (*factorGraph_);
|
|
||||||
FgBpSolver solver (*fg);
|
|
||||||
solver.runSolver();
|
|
||||||
Params prevBeliefs = solver.getPosterioriOf (jointVarIds[0]);
|
|
||||||
|
|
||||||
VarIds observedVids = {jointVars[0]->varId()};
|
|
||||||
|
|
||||||
for (unsigned i = 1; i < jointVarIds.size(); i++) {
|
|
||||||
assert (jointVars[i]->hasEvidence() == false);
|
|
||||||
Params newBeliefs;
|
|
||||||
VarNodes observedVars;
|
|
||||||
for (unsigned j = 0; j < observedVids.size(); j++) {
|
|
||||||
observedVars.push_back (fg->getFgVarNode (observedVids[j]));
|
|
||||||
}
|
|
||||||
StatesIndexer idx (observedVars, false);
|
|
||||||
while (idx.valid()) {
|
|
||||||
for (unsigned j = 0; j < observedVars.size(); j++) {
|
|
||||||
observedVars[j]->setEvidence (idx[j]);
|
|
||||||
}
|
|
||||||
++ idx;
|
|
||||||
FgBpSolver solver (*fg);
|
|
||||||
solver.runSolver();
|
|
||||||
Params beliefs = solver.getPosterioriOf (jointVarIds[i]);
|
|
||||||
for (unsigned k = 0; k < beliefs.size(); k++) {
|
|
||||||
newBeliefs.push_back (beliefs[k]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
int count = -1;
|
|
||||||
for (unsigned j = 0; j < newBeliefs.size(); j++) {
|
|
||||||
if (j % jointVars[i]->range() == 0) {
|
|
||||||
count ++;
|
|
||||||
}
|
|
||||||
newBeliefs[j] *= prevBeliefs[count];
|
|
||||||
}
|
|
||||||
prevBeliefs = newBeliefs;
|
|
||||||
observedVids.push_back (jointVars[i]->varId());
|
|
||||||
}
|
|
||||||
return prevBeliefs;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
FgBpSolver::printLinkInformation (void) const
|
|
||||||
{
|
|
||||||
for (unsigned i = 0; i < links_.size(); i++) {
|
|
||||||
SpLink* l = links_[i];
|
|
||||||
cout << l->toString() << ":" << endl;
|
|
||||||
cout << " curr msg = " ;
|
|
||||||
cout << l->getMessage() << endl;
|
|
||||||
cout << " next msg = " ;
|
|
||||||
cout << l->getNextMessage() << endl;
|
|
||||||
cout << " residual = " << l->getResidual() << endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,186 +0,0 @@
|
|||||||
#ifndef HORUS_FGBPSOLVER_H
|
|
||||||
#define HORUS_FGBPSOLVER_H
|
|
||||||
|
|
||||||
#include <set>
|
|
||||||
#include <vector>
|
|
||||||
#include <sstream>
|
|
||||||
|
|
||||||
#include "Solver.h"
|
|
||||||
#include "Factor.h"
|
|
||||||
#include "FactorGraph.h"
|
|
||||||
#include "Util.h"
|
|
||||||
|
|
||||||
using namespace std;
|
|
||||||
|
|
||||||
|
|
||||||
class SpLink
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
SpLink (FgFacNode* fn, FgVarNode* vn)
|
|
||||||
{
|
|
||||||
fac_ = fn;
|
|
||||||
var_ = vn;
|
|
||||||
v1_.resize (vn->range(), LogAware::tl (1.0 / vn->range()));
|
|
||||||
v2_.resize (vn->range(), LogAware::tl (1.0 / vn->range()));
|
|
||||||
currMsg_ = &v1_;
|
|
||||||
nextMsg_ = &v2_;
|
|
||||||
msgSended_ = false;
|
|
||||||
residual_ = 0.0;
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual ~SpLink (void) { };
|
|
||||||
|
|
||||||
FgFacNode* getFactor (void) const { return fac_; }
|
|
||||||
|
|
||||||
FgVarNode* getVariable (void) const { return var_; }
|
|
||||||
|
|
||||||
const Params& getMessage (void) const { return *currMsg_; }
|
|
||||||
|
|
||||||
Params& getNextMessage (void) { return *nextMsg_; }
|
|
||||||
|
|
||||||
bool messageWasSended (void) const { return msgSended_; }
|
|
||||||
|
|
||||||
double getResidual (void) const { return residual_; }
|
|
||||||
|
|
||||||
void clearResidual (void) { residual_ = 0.0; }
|
|
||||||
|
|
||||||
void updateResidual (void)
|
|
||||||
{
|
|
||||||
residual_ = LogAware::getMaxNorm (v1_,v2_);
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual void updateMessage (void)
|
|
||||||
{
|
|
||||||
swap (currMsg_, nextMsg_);
|
|
||||||
msgSended_ = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
string toString (void) const
|
|
||||||
{
|
|
||||||
stringstream ss;
|
|
||||||
ss << fac_->getLabel();
|
|
||||||
ss << " -- " ;
|
|
||||||
ss << var_->label();
|
|
||||||
return ss.str();
|
|
||||||
}
|
|
||||||
|
|
||||||
protected:
|
|
||||||
FgFacNode* fac_;
|
|
||||||
FgVarNode* var_;
|
|
||||||
Params v1_;
|
|
||||||
Params v2_;
|
|
||||||
Params* currMsg_;
|
|
||||||
Params* nextMsg_;
|
|
||||||
bool msgSended_;
|
|
||||||
double residual_;
|
|
||||||
};
|
|
||||||
|
|
||||||
typedef vector<SpLink*> SpLinkSet;
|
|
||||||
|
|
||||||
|
|
||||||
class SPNodeInfo
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
void addSpLink (SpLink* link) { links_.push_back (link); }
|
|
||||||
const SpLinkSet& getLinks (void) { return links_; }
|
|
||||||
private:
|
|
||||||
SpLinkSet links_;
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
class FgBpSolver : public Solver
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
FgBpSolver (const FactorGraph&);
|
|
||||||
|
|
||||||
virtual ~FgBpSolver (void);
|
|
||||||
|
|
||||||
void runSolver (void);
|
|
||||||
|
|
||||||
virtual Params getPosterioriOf (VarId);
|
|
||||||
|
|
||||||
virtual Params getJointDistributionOf (const VarIds&);
|
|
||||||
|
|
||||||
protected:
|
|
||||||
virtual void initializeSolver (void);
|
|
||||||
|
|
||||||
virtual void createLinks (void);
|
|
||||||
|
|
||||||
virtual void maxResidualSchedule (void);
|
|
||||||
|
|
||||||
virtual void calculateFactor2VariableMsg (SpLink*) const;
|
|
||||||
|
|
||||||
virtual Params getVar2FactorMsg (const SpLink*) const;
|
|
||||||
|
|
||||||
virtual Params getJointByConditioning (const VarIds&) const;
|
|
||||||
|
|
||||||
virtual void printLinkInformation (void) const;
|
|
||||||
|
|
||||||
SPNodeInfo* ninf (const FgVarNode* var) const
|
|
||||||
{
|
|
||||||
return varsI_[var->getIndex()];
|
|
||||||
}
|
|
||||||
|
|
||||||
SPNodeInfo* ninf (const FgFacNode* fac) const
|
|
||||||
{
|
|
||||||
return facsI_[fac->getIndex()];
|
|
||||||
}
|
|
||||||
|
|
||||||
void calculateAndUpdateMessage (SpLink* link, bool calcResidual = true)
|
|
||||||
{
|
|
||||||
if (Constants::DEBUG >= 3) {
|
|
||||||
cout << "calculating & updating " << link->toString() << endl;
|
|
||||||
}
|
|
||||||
calculateFactor2VariableMsg (link);
|
|
||||||
if (calcResidual) {
|
|
||||||
link->updateResidual();
|
|
||||||
}
|
|
||||||
link->updateMessage();
|
|
||||||
}
|
|
||||||
|
|
||||||
void calculateMessage (SpLink* link, bool calcResidual = true)
|
|
||||||
{
|
|
||||||
if (Constants::DEBUG >= 3) {
|
|
||||||
cout << "calculating " << link->toString() << endl;
|
|
||||||
}
|
|
||||||
calculateFactor2VariableMsg (link);
|
|
||||||
if (calcResidual) {
|
|
||||||
link->updateResidual();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void updateMessage (SpLink* link)
|
|
||||||
{
|
|
||||||
link->updateMessage();
|
|
||||||
if (Constants::DEBUG >= 3) {
|
|
||||||
cout << "updating " << link->toString() << endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
struct CompareResidual
|
|
||||||
{
|
|
||||||
inline bool operator() (const SpLink* link1, const SpLink* link2)
|
|
||||||
{
|
|
||||||
return link1->getResidual() > link2->getResidual();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
SpLinkSet links_;
|
|
||||||
unsigned nIters_;
|
|
||||||
vector<SPNodeInfo*> varsI_;
|
|
||||||
vector<SPNodeInfo*> facsI_;
|
|
||||||
const FactorGraph* factorGraph_;
|
|
||||||
|
|
||||||
typedef multiset<SpLink*, CompareResidual> SortedOrder;
|
|
||||||
SortedOrder sortedOrder_;
|
|
||||||
|
|
||||||
typedef unordered_map<SpLink*, SortedOrder::iterator> SpLinkMap;
|
|
||||||
SpLinkMap linkMap_;
|
|
||||||
|
|
||||||
private:
|
|
||||||
void runLoopySolver (void);
|
|
||||||
bool converged (void);
|
|
||||||
};
|
|
||||||
|
|
||||||
#endif // HORUS_FGBPSOLVER_H
|
|
||||||
|
|
@ -1,64 +0,0 @@
|
|||||||
#ifndef HORUS_GRAPHICALMODEL_H
|
|
||||||
#define HORUS_GRAPHICALMODEL_H
|
|
||||||
|
|
||||||
#include <cassert>
|
|
||||||
|
|
||||||
#include <unordered_map>
|
|
||||||
|
|
||||||
#include <sstream>
|
|
||||||
|
|
||||||
#include "VarNode.h"
|
|
||||||
#include "Util.h"
|
|
||||||
#include "Horus.h"
|
|
||||||
|
|
||||||
using namespace std;
|
|
||||||
|
|
||||||
|
|
||||||
struct VarInfo
|
|
||||||
{
|
|
||||||
VarInfo (string l, const States& sts) : label(l), states(sts) { }
|
|
||||||
string label;
|
|
||||||
States states;
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
class GraphicalModel
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
virtual ~GraphicalModel (void) { };
|
|
||||||
|
|
||||||
virtual VarNode* getVariableNode (VarId) const = 0;
|
|
||||||
|
|
||||||
virtual VarNodes getVariableNodes (void) const = 0;
|
|
||||||
|
|
||||||
virtual void printGraphicalModel (void) const = 0;
|
|
||||||
|
|
||||||
static void addVariableInformation (
|
|
||||||
VarId vid, string label, const States& states)
|
|
||||||
{
|
|
||||||
assert (Util::contains (varsInfo_, vid) == false);
|
|
||||||
varsInfo_.insert (make_pair (vid, VarInfo (label, states)));
|
|
||||||
}
|
|
||||||
|
|
||||||
static VarInfo getVarInformation (VarId vid)
|
|
||||||
{
|
|
||||||
assert (Util::contains (varsInfo_, vid));
|
|
||||||
return varsInfo_.find (vid)->second;
|
|
||||||
}
|
|
||||||
|
|
||||||
static bool variablesHaveInformation (void)
|
|
||||||
{
|
|
||||||
return varsInfo_.size() != 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
static void clearVariablesInformation (void)
|
|
||||||
{
|
|
||||||
varsInfo_.clear();
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
static unordered_map<VarId,VarInfo> varsInfo_;
|
|
||||||
};
|
|
||||||
|
|
||||||
#endif // HORUS_GRAPHICALMODEL_H
|
|
||||||
|
|
@ -11,20 +11,18 @@
|
|||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
class VarNode;
|
class Var;
|
||||||
class BayesNode;
|
|
||||||
class FgVarNode;
|
|
||||||
class FgFacNode;
|
|
||||||
class Factor;
|
class Factor;
|
||||||
|
class VarNode;
|
||||||
|
class FactorNode;
|
||||||
|
|
||||||
typedef vector<double> Params;
|
typedef vector<double> Params;
|
||||||
typedef unsigned VarId;
|
typedef unsigned VarId;
|
||||||
typedef vector<VarId> VarIds;
|
typedef vector<VarId> VarIds;
|
||||||
|
typedef vector<Var*> Vars;
|
||||||
typedef vector<VarNode*> VarNodes;
|
typedef vector<VarNode*> VarNodes;
|
||||||
typedef vector<BayesNode*> BnNodeSet;
|
typedef vector<FactorNode*> FactorNodes;
|
||||||
typedef vector<FgVarNode*> FgVarSet;
|
typedef vector<Factor*> Factors;
|
||||||
typedef vector<FgFacNode*> FgFacSet;
|
|
||||||
typedef vector<Factor*> FactorSet;
|
|
||||||
typedef vector<string> States;
|
typedef vector<string> States;
|
||||||
typedef vector<unsigned> Ranges;
|
typedef vector<unsigned> Ranges;
|
||||||
|
|
||||||
@ -32,9 +30,8 @@ typedef vector<unsigned> Ranges;
|
|||||||
enum InfAlgorithms
|
enum InfAlgorithms
|
||||||
{
|
{
|
||||||
VE, // variable elimination
|
VE, // variable elimination
|
||||||
BN_BP, // bayesian network belief propagation
|
BP, // belief propagation
|
||||||
FG_BP, // factor graph belief propagation
|
CBP // counting belief propagation
|
||||||
CBP // counting bp solver
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
@ -5,13 +5,13 @@
|
|||||||
|
|
||||||
#include "FactorGraph.h"
|
#include "FactorGraph.h"
|
||||||
#include "VarElimSolver.h"
|
#include "VarElimSolver.h"
|
||||||
#include "FgBpSolver.h"
|
#include "BpSolver.h"
|
||||||
#include "CbpSolver.h"
|
#include "CbpSolver.h"
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
void processArguments (FactorGraph&, int, const char* []);
|
void processArguments (FactorGraph&, int, const char* []);
|
||||||
void runSolver (Solver*, const VarNodes&);
|
void runSolver (Solver*, const VarIds&);
|
||||||
|
|
||||||
const string USAGE = "usage: \
|
const string USAGE = "usage: \
|
||||||
./hcli FILE [VARIABLE | OBSERVED_VARIABLE=EVIDENCE]..." ;
|
./hcli FILE [VARIABLE | OBSERVED_VARIABLE=EVIDENCE]..." ;
|
||||||
@ -48,7 +48,7 @@ main (int argc, const char* argv[])
|
|||||||
void
|
void
|
||||||
processArguments (FactorGraph& fg, int argc, const char* argv[])
|
processArguments (FactorGraph& fg, int argc, const char* argv[])
|
||||||
{
|
{
|
||||||
VarNodes queryVars;
|
VarIds queryIds;
|
||||||
for (int i = 2; i < argc; i++) {
|
for (int i = 2; i < argc; i++) {
|
||||||
const string& arg = argv[i];
|
const string& arg = argv[i];
|
||||||
if (arg.find ('=') == std::string::npos) {
|
if (arg.find ('=') == std::string::npos) {
|
||||||
@ -62,9 +62,9 @@ processArguments (FactorGraph& fg, int argc, const char* argv[])
|
|||||||
stringstream ss;
|
stringstream ss;
|
||||||
ss << arg;
|
ss << arg;
|
||||||
ss >> vid;
|
ss >> vid;
|
||||||
VarNode* queryVar = fg.getFgVarNode (vid);
|
VarNode* queryVar = fg.getVarNode (vid);
|
||||||
if (queryVar) {
|
if (queryVar) {
|
||||||
queryVars.push_back (queryVar);
|
queryIds.push_back (vid);
|
||||||
} else {
|
} else {
|
||||||
cerr << "error: there isn't a variable with " ;
|
cerr << "error: there isn't a variable with " ;
|
||||||
cerr << "`" << vid << "' as id" ;
|
cerr << "`" << vid << "' as id" ;
|
||||||
@ -93,7 +93,7 @@ processArguments (FactorGraph& fg, int argc, const char* argv[])
|
|||||||
stringstream ss;
|
stringstream ss;
|
||||||
ss << arg.substr (0, pos);
|
ss << arg.substr (0, pos);
|
||||||
ss >> vid;
|
ss >> vid;
|
||||||
VarNode* var = fg.getFgVarNode (vid);
|
VarNode* var = fg.getVarNode (vid);
|
||||||
if (var) {
|
if (var) {
|
||||||
if (!Util::isInteger (arg.substr (pos + 1))) {
|
if (!Util::isInteger (arg.substr (pos + 1))) {
|
||||||
cerr << "error: `" << arg.substr (pos + 1) << "' " ;
|
cerr << "error: `" << arg.substr (pos + 1) << "' " ;
|
||||||
@ -127,8 +127,8 @@ processArguments (FactorGraph& fg, int argc, const char* argv[])
|
|||||||
case InfAlgorithms::VE:
|
case InfAlgorithms::VE:
|
||||||
solver = new VarElimSolver (fg);
|
solver = new VarElimSolver (fg);
|
||||||
break;
|
break;
|
||||||
case InfAlgorithms::FG_BP:
|
case InfAlgorithms::BP:
|
||||||
solver = new FgBpSolver (fg);
|
solver = new BpSolver (fg);
|
||||||
break;
|
break;
|
||||||
case InfAlgorithms::CBP:
|
case InfAlgorithms::CBP:
|
||||||
solver = new CbpSolver (fg);
|
solver = new CbpSolver (fg);
|
||||||
@ -136,27 +136,23 @@ processArguments (FactorGraph& fg, int argc, const char* argv[])
|
|||||||
default:
|
default:
|
||||||
assert (false);
|
assert (false);
|
||||||
}
|
}
|
||||||
runSolver (solver, queryVars);
|
runSolver (solver, queryIds);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
runSolver (Solver* solver, const VarNodes& queryVars)
|
runSolver (Solver* solver, const VarIds& queryIds)
|
||||||
{
|
{
|
||||||
VarIds vids;
|
if (queryIds.size() == 0) {
|
||||||
for (unsigned i = 0; i < queryVars.size(); i++) {
|
|
||||||
vids.push_back (queryVars[i]->varId());
|
|
||||||
}
|
|
||||||
if (queryVars.size() == 0) {
|
|
||||||
solver->runSolver();
|
solver->runSolver();
|
||||||
solver->printAllPosterioris();
|
solver->printAllPosterioris();
|
||||||
} else if (queryVars.size() == 1) {
|
} else if (queryIds.size() == 1) {
|
||||||
solver->runSolver();
|
solver->runSolver();
|
||||||
solver->printPosterioriOf (vids[0]);
|
solver->printPosterioriOf (queryIds[0]);
|
||||||
} else {
|
} else {
|
||||||
solver->runSolver();
|
solver->runSolver();
|
||||||
solver->printJointDistributionOf (vids);
|
solver->printJointDistributionOf (queryIds);
|
||||||
}
|
}
|
||||||
delete solver;
|
delete solver;
|
||||||
}
|
}
|
||||||
|
@ -11,7 +11,7 @@
|
|||||||
#include "FactorGraph.h"
|
#include "FactorGraph.h"
|
||||||
#include "FoveSolver.h"
|
#include "FoveSolver.h"
|
||||||
#include "VarElimSolver.h"
|
#include "VarElimSolver.h"
|
||||||
#include "FgBpSolver.h"
|
#include "BpSolver.h"
|
||||||
#include "CbpSolver.h"
|
#include "CbpSolver.h"
|
||||||
#include "ElimGraph.h"
|
#include "ElimGraph.h"
|
||||||
#include "BayesBall.h"
|
#include "BayesBall.h"
|
||||||
@ -241,8 +241,8 @@ createGroundNetwork (void)
|
|||||||
unsigned vid = (unsigned) YAP_IntOfTerm ((YAP_ArgOfTerm (1, evTerm)));
|
unsigned vid = (unsigned) YAP_IntOfTerm ((YAP_ArgOfTerm (1, evTerm)));
|
||||||
unsigned ev = (unsigned) YAP_IntOfTerm ((YAP_ArgOfTerm (2, evTerm)));
|
unsigned ev = (unsigned) YAP_IntOfTerm ((YAP_ArgOfTerm (2, evTerm)));
|
||||||
cout << vid << " == " << ev << endl;
|
cout << vid << " == " << ev << endl;
|
||||||
assert (fg->getFgVarNode (vid));
|
assert (fg->getVarNode (vid));
|
||||||
fg->getFgVarNode (vid)->setEvidence (ev);
|
fg->getVarNode (vid)->setEvidence (ev);
|
||||||
evidenceList = YAP_TailOfTerm (evidenceList);
|
evidenceList = YAP_TailOfTerm (evidenceList);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -348,7 +348,6 @@ runGroundSolver (void)
|
|||||||
taskList = YAP_TailOfTerm (taskList);
|
taskList = YAP_TailOfTerm (taskList);
|
||||||
}
|
}
|
||||||
|
|
||||||
fg->printGraphicalModel();
|
|
||||||
vector<Params> results;
|
vector<Params> results;
|
||||||
if (Globals::infAlgorithm == InfAlgorithms::VE) {
|
if (Globals::infAlgorithm == InfAlgorithms::VE) {
|
||||||
runVeSolver (fg, tasks, results);
|
runVeSolver (fg, tasks, results);
|
||||||
@ -414,8 +413,8 @@ void runBpSolver (
|
|||||||
mfg = BayesBall::getMinimalFactorGraph (
|
mfg = BayesBall::getMinimalFactorGraph (
|
||||||
*fg, VarIds (vids.begin(),vids.end()));
|
*fg, VarIds (vids.begin(),vids.end()));
|
||||||
}
|
}
|
||||||
if (Globals::infAlgorithm == InfAlgorithms::FG_BP) {
|
if (Globals::infAlgorithm == InfAlgorithms::BP) {
|
||||||
solver = new FgBpSolver (*mfg);
|
solver = new BpSolver (*mfg);
|
||||||
} else if (Globals::infAlgorithm == InfAlgorithms::CBP) {
|
} else if (Globals::infAlgorithm == InfAlgorithms::CBP) {
|
||||||
CFactorGraph::checkForIdenticalFactors = false;
|
CFactorGraph::checkForIdenticalFactors = false;
|
||||||
solver = new CbpSolver (*mfg);
|
solver = new CbpSolver (*mfg);
|
||||||
@ -494,7 +493,7 @@ setBayesNetParams (void)
|
|||||||
int
|
int
|
||||||
setExtraVarsInfo (void)
|
setExtraVarsInfo (void)
|
||||||
{
|
{
|
||||||
GraphicalModel::clearVariablesInformation();
|
Var::clearVariablesInformation();
|
||||||
YAP_Term varsInfoL = YAP_ARG2;
|
YAP_Term varsInfoL = YAP_ARG2;
|
||||||
while (varsInfoL != YAP_TermNil()) {
|
while (varsInfoL != YAP_TermNil()) {
|
||||||
YAP_Term head = YAP_HeadOfTerm (varsInfoL);
|
YAP_Term head = YAP_HeadOfTerm (varsInfoL);
|
||||||
@ -507,7 +506,7 @@ setExtraVarsInfo (void)
|
|||||||
states.push_back ((char*) YAP_AtomName (atom));
|
states.push_back ((char*) YAP_AtomName (atom));
|
||||||
statesL = YAP_TailOfTerm (statesL);
|
statesL = YAP_TailOfTerm (statesL);
|
||||||
}
|
}
|
||||||
GraphicalModel::addVariableInformation (vid,
|
Var::addVariableInformation (vid,
|
||||||
(char*) YAP_AtomName (label), states);
|
(char*) YAP_AtomName (label), states);
|
||||||
varsInfoL = YAP_TailOfTerm (varsInfoL);
|
varsInfoL = YAP_TailOfTerm (varsInfoL);
|
||||||
}
|
}
|
||||||
@ -524,10 +523,8 @@ setHorusFlag (void)
|
|||||||
string value ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG2)));
|
string value ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG2)));
|
||||||
if ( value == "ve") {
|
if ( value == "ve") {
|
||||||
Globals::infAlgorithm = InfAlgorithms::VE;
|
Globals::infAlgorithm = InfAlgorithms::VE;
|
||||||
} else if (value == "bn_bp") {
|
} else if (value == "bp") {
|
||||||
Globals::infAlgorithm = InfAlgorithms::BN_BP;
|
Globals::infAlgorithm = InfAlgorithms::BP;
|
||||||
} else if (value == "fg_bp") {
|
|
||||||
Globals::infAlgorithm = InfAlgorithms::FG_BP;
|
|
||||||
} else if (value == "cbp") {
|
} else if (value == "cbp") {
|
||||||
Globals::infAlgorithm = InfAlgorithms::CBP;
|
Globals::infAlgorithm = InfAlgorithms::CBP;
|
||||||
} else {
|
} else {
|
||||||
|
@ -8,7 +8,7 @@
|
|||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <iomanip>
|
#include <iomanip>
|
||||||
|
|
||||||
#include "VarNode.h"
|
#include "Var.h"
|
||||||
#include "Util.h"
|
#include "Util.h"
|
||||||
|
|
||||||
|
|
||||||
@ -31,7 +31,7 @@ class StatesIndexer
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
StatesIndexer (const VarNodes& vars, bool calcOffsets = true)
|
StatesIndexer (const Vars& vars, bool calcOffsets = true)
|
||||||
{
|
{
|
||||||
size_ = 1;
|
size_ = 1;
|
||||||
indices_.resize (vars.size(), 0);
|
indices_.resize (vars.size(), 0);
|
||||||
|
@ -45,7 +45,6 @@ CWD=$(PWD)
|
|||||||
|
|
||||||
|
|
||||||
HEADERS = \
|
HEADERS = \
|
||||||
$(srcdir)/GraphicalModel.h \
|
|
||||||
$(srcdir)/BayesNet.h \
|
$(srcdir)/BayesNet.h \
|
||||||
$(srcdir)/BayesBall.h \
|
$(srcdir)/BayesBall.h \
|
||||||
$(srcdir)/ElimGraph.h \
|
$(srcdir)/ElimGraph.h \
|
||||||
@ -55,10 +54,10 @@ HEADERS = \
|
|||||||
$(srcdir)/ConstraintTree.h \
|
$(srcdir)/ConstraintTree.h \
|
||||||
$(srcdir)/Solver.h \
|
$(srcdir)/Solver.h \
|
||||||
$(srcdir)/VarElimSolver.h \
|
$(srcdir)/VarElimSolver.h \
|
||||||
$(srcdir)/FgBpSolver.h \
|
$(srcdir)/BpSolver.h \
|
||||||
$(srcdir)/CbpSolver.h \
|
$(srcdir)/CbpSolver.h \
|
||||||
$(srcdir)/FoveSolver.h \
|
$(srcdir)/FoveSolver.h \
|
||||||
$(srcdir)/VarNode.h \
|
$(srcdir)/Var.h \
|
||||||
$(srcdir)/Indexer.h \
|
$(srcdir)/Indexer.h \
|
||||||
$(srcdir)/Parfactor.h \
|
$(srcdir)/Parfactor.h \
|
||||||
$(srcdir)/ProbFormula.h \
|
$(srcdir)/ProbFormula.h \
|
||||||
@ -77,10 +76,10 @@ CPP_SOURCES = \
|
|||||||
$(srcdir)/Factor.cpp \
|
$(srcdir)/Factor.cpp \
|
||||||
$(srcdir)/CFactorGraph.cpp \
|
$(srcdir)/CFactorGraph.cpp \
|
||||||
$(srcdir)/ConstraintTree.cpp \
|
$(srcdir)/ConstraintTree.cpp \
|
||||||
$(srcdir)/VarNode.cpp \
|
$(srcdir)/Var.cpp \
|
||||||
$(srcdir)/Solver.cpp \
|
$(srcdir)/Solver.cpp \
|
||||||
$(srcdir)/VarElimSolver.cpp \
|
$(srcdir)/VarElimSolver.cpp \
|
||||||
$(srcdir)/FgBpSolver.cpp \
|
$(srcdir)/BpSolver.cpp \
|
||||||
$(srcdir)/CbpSolver.cpp \
|
$(srcdir)/CbpSolver.cpp \
|
||||||
$(srcdir)/FoveSolver.cpp \
|
$(srcdir)/FoveSolver.cpp \
|
||||||
$(srcdir)/Parfactor.cpp \
|
$(srcdir)/Parfactor.cpp \
|
||||||
@ -100,10 +99,10 @@ OBJS = \
|
|||||||
Factor.o \
|
Factor.o \
|
||||||
CFactorGraph.o \
|
CFactorGraph.o \
|
||||||
ConstraintTree.o \
|
ConstraintTree.o \
|
||||||
VarNode.o \
|
Var.o \
|
||||||
Solver.o \
|
Solver.o \
|
||||||
VarElimSolver.o \
|
VarElimSolver.o \
|
||||||
FgBpSolver.o \
|
BpSolver.o \
|
||||||
CbpSolver.o \
|
CbpSolver.o \
|
||||||
FoveSolver.o \
|
FoveSolver.o \
|
||||||
Parfactor.o \
|
Parfactor.o \
|
||||||
@ -122,10 +121,10 @@ HCLI_OBJS = \
|
|||||||
Factor.o \
|
Factor.o \
|
||||||
CFactorGraph.o \
|
CFactorGraph.o \
|
||||||
ConstraintTree.o \
|
ConstraintTree.o \
|
||||||
VarNode.o \
|
Var.o \
|
||||||
Solver.o \
|
Solver.o \
|
||||||
VarElimSolver.o \
|
VarElimSolver.o \
|
||||||
FgBpSolver.o \
|
BpSolver.o \
|
||||||
CbpSolver.o \
|
CbpSolver.o \
|
||||||
FoveSolver.o \
|
FoveSolver.o \
|
||||||
Parfactor.o \
|
Parfactor.o \
|
||||||
|
@ -5,7 +5,7 @@
|
|||||||
void
|
void
|
||||||
Solver::printAllPosterioris (void)
|
Solver::printAllPosterioris (void)
|
||||||
{
|
{
|
||||||
const VarNodes& vars = gm_->getVariableNodes();
|
const VarNodes& vars = fg_.varNodes();
|
||||||
for (unsigned i = 0; i < vars.size(); i++) {
|
for (unsigned i = 0; i < vars.size(); i++) {
|
||||||
printPosterioriOf (vars[i]->varId());
|
printPosterioriOf (vars[i]->varId());
|
||||||
}
|
}
|
||||||
@ -16,11 +16,11 @@ Solver::printAllPosterioris (void)
|
|||||||
void
|
void
|
||||||
Solver::printPosterioriOf (VarId vid)
|
Solver::printPosterioriOf (VarId vid)
|
||||||
{
|
{
|
||||||
VarNode* var = gm_->getVariableNode (vid);
|
VarNode* vn = fg_.getVarNode (vid);
|
||||||
const Params& posterioriDist = getPosterioriOf (vid);
|
const Params& posterioriDist = getPosterioriOf (vid);
|
||||||
const States& states = var->states();
|
const States& states = vn->states();
|
||||||
for (unsigned i = 0; i < states.size(); i++) {
|
for (unsigned i = 0; i < states.size(); i++) {
|
||||||
cout << "P(" << var->label() << "=" << states[i] << ") = " ;
|
cout << "P(" << vn->label() << "=" << states[i] << ") = " ;
|
||||||
cout << setprecision (Constants::PRECISION) << posterioriDist[i];
|
cout << setprecision (Constants::PRECISION) << posterioriDist[i];
|
||||||
cout << endl;
|
cout << endl;
|
||||||
}
|
}
|
||||||
@ -32,12 +32,12 @@ Solver::printPosterioriOf (VarId vid)
|
|||||||
void
|
void
|
||||||
Solver::printJointDistributionOf (const VarIds& vids)
|
Solver::printJointDistributionOf (const VarIds& vids)
|
||||||
{
|
{
|
||||||
VarNodes vars;
|
Vars vars;
|
||||||
VarIds vidsWithoutEvidence;
|
VarIds vidsWithoutEvidence;
|
||||||
for (unsigned i = 0; i < vids.size(); i++) {
|
for (unsigned i = 0; i < vids.size(); i++) {
|
||||||
VarNode* var = gm_->getVariableNode (vids[i]);
|
VarNode* vn = fg_.getVarNode (vids[i]);
|
||||||
if (var->hasEvidence() == false) {
|
if (vn->hasEvidence() == false) {
|
||||||
vars.push_back (var);
|
vars.push_back (vn);
|
||||||
vidsWithoutEvidence.push_back (vids[i]);
|
vidsWithoutEvidence.push_back (vids[i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -3,15 +3,16 @@
|
|||||||
|
|
||||||
#include <iomanip>
|
#include <iomanip>
|
||||||
|
|
||||||
#include "GraphicalModel.h"
|
#include "Var.h"
|
||||||
#include "VarNode.h"
|
#include "FactorGraph.h"
|
||||||
|
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
class Solver
|
class Solver
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
Solver (const GraphicalModel* gm) : gm_(gm) { }
|
Solver (const FactorGraph& fg) : fg_(fg) { }
|
||||||
|
|
||||||
virtual ~Solver() { } // ensure that subclass destructor is called
|
virtual ~Solver() { } // ensure that subclass destructor is called
|
||||||
|
|
||||||
@ -28,7 +29,7 @@ class Solver
|
|||||||
void printJointDistributionOf (const VarIds& vids);
|
void printJointDistributionOf (const VarIds& vids);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
const GraphicalModel* gm_;
|
const FactorGraph& fg_;
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif // HORUS_SOLVER_H
|
#endif // HORUS_SOLVER_H
|
||||||
|
@ -5,7 +5,6 @@
|
|||||||
|
|
||||||
#include "Util.h"
|
#include "Util.h"
|
||||||
#include "Indexer.h"
|
#include "Indexer.h"
|
||||||
#include "GraphicalModel.h"
|
|
||||||
|
|
||||||
|
|
||||||
namespace Globals {
|
namespace Globals {
|
||||||
@ -30,7 +29,6 @@ unsigned maxIter = 1000;
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
unordered_map<VarId, VarInfo> GraphicalModel::varsInfo_;
|
|
||||||
|
|
||||||
vector<NetInfo> Statistics::netInfo_;
|
vector<NetInfo> Statistics::netInfo_;
|
||||||
vector<CompressInfo> Statistics::compressInfo_;
|
vector<CompressInfo> Statistics::compressInfo_;
|
||||||
@ -139,7 +137,7 @@ parametersToString (const Params& v, unsigned precision)
|
|||||||
|
|
||||||
|
|
||||||
vector<string>
|
vector<string>
|
||||||
getJointStateStrings (const VarNodes& vars)
|
getJointStateStrings (const Vars& vars)
|
||||||
{
|
{
|
||||||
StatesIndexer idx (vars);
|
StatesIndexer idx (vars);
|
||||||
vector<string> jointStrings;
|
vector<string> jointStrings;
|
||||||
@ -402,8 +400,7 @@ Statistics::getStatisticString (void)
|
|||||||
ss1 << "running mode: " ;
|
ss1 << "running mode: " ;
|
||||||
switch (Globals::infAlgorithm) {
|
switch (Globals::infAlgorithm) {
|
||||||
case InfAlgorithms::VE: ss1 << "ve" << endl; break;
|
case InfAlgorithms::VE: ss1 << "ve" << endl; break;
|
||||||
case InfAlgorithms::BN_BP: ss1 << "bn_bp" << endl; break;
|
case InfAlgorithms::BP: ss1 << "bp" << endl; break;
|
||||||
case InfAlgorithms::FG_BP: ss1 << "fg_bp" << endl; break;
|
|
||||||
case InfAlgorithms::CBP: ss1 << "cbp" << endl; break;
|
case InfAlgorithms::CBP: ss1 << "cbp" << endl; break;
|
||||||
}
|
}
|
||||||
ss1 << "message schedule: " ;
|
ss1 << "message schedule: " ;
|
||||||
|
@ -62,7 +62,7 @@ bool isInteger (const string&);
|
|||||||
|
|
||||||
string parametersToString (const Params&, unsigned = Constants::PRECISION);
|
string parametersToString (const Params&, unsigned = Constants::PRECISION);
|
||||||
|
|
||||||
vector<string> getJointStateStrings (const VarNodes&);
|
vector<string> getJointStateStrings (const Vars&);
|
||||||
|
|
||||||
void printHeader (string, std::ostream& os = std::cout);
|
void printHeader (string, std::ostream& os = std::cout);
|
||||||
|
|
||||||
|
@ -6,7 +6,7 @@
|
|||||||
#include "Util.h"
|
#include "Util.h"
|
||||||
|
|
||||||
|
|
||||||
VarElimSolver::VarElimSolver (const FactorGraph& fg) : Solver (&fg)
|
VarElimSolver::VarElimSolver (const FactorGraph& fg) : Solver (fg)
|
||||||
{
|
{
|
||||||
factorGraph_ = &fg;
|
factorGraph_ = &fg;
|
||||||
}
|
}
|
||||||
@ -23,8 +23,8 @@ VarElimSolver::~VarElimSolver (void)
|
|||||||
Params
|
Params
|
||||||
VarElimSolver::getPosterioriOf (VarId vid)
|
VarElimSolver::getPosterioriOf (VarId vid)
|
||||||
{
|
{
|
||||||
assert (factorGraph_->getFgVarNode (vid));
|
assert (factorGraph_->getVarNode (vid));
|
||||||
FgVarNode* vn = factorGraph_->getFgVarNode (vid);
|
VarNode* vn = factorGraph_->getVarNode (vid);
|
||||||
if (vn->hasEvidence()) {
|
if (vn->hasEvidence()) {
|
||||||
Params params (vn->range(), 0.0);
|
Params params (vn->range(), 0.0);
|
||||||
params[vn->getEvidence()] = 1.0;
|
params[vn->getEvidence()] = 1.0;
|
||||||
@ -57,11 +57,11 @@ VarElimSolver::getJointDistributionOf (const VarIds& vids)
|
|||||||
void
|
void
|
||||||
VarElimSolver::createFactorList (void)
|
VarElimSolver::createFactorList (void)
|
||||||
{
|
{
|
||||||
const FgFacSet& factorNodes = factorGraph_->getFactorNodes();
|
const FactorNodes& factorNodes = factorGraph_->factorNodes();
|
||||||
factorList_.reserve (factorNodes.size() * 2);
|
factorList_.reserve (factorNodes.size() * 2);
|
||||||
for (unsigned i = 0; i < factorNodes.size(); i++) {
|
for (unsigned i = 0; i < factorNodes.size(); i++) {
|
||||||
factorList_.push_back (new Factor (*factorNodes[i]->factor()));
|
factorList_.push_back (new Factor (*factorNodes[i]->factor()));
|
||||||
const FgVarSet& neighs = factorNodes[i]->neighbors();
|
const VarNodes& neighs = factorNodes[i]->neighbors();
|
||||||
for (unsigned j = 0; j < neighs.size(); j++) {
|
for (unsigned j = 0; j < neighs.size(); j++) {
|
||||||
unordered_map<VarId,vector<unsigned> >::iterator it
|
unordered_map<VarId,vector<unsigned> >::iterator it
|
||||||
= varFactors_.find (neighs[j]->varId());
|
= varFactors_.find (neighs[j]->varId());
|
||||||
@ -79,7 +79,7 @@ VarElimSolver::createFactorList (void)
|
|||||||
void
|
void
|
||||||
VarElimSolver::absorveEvidence (void)
|
VarElimSolver::absorveEvidence (void)
|
||||||
{
|
{
|
||||||
const FgVarSet& varNodes = factorGraph_->getVarNodes();
|
const VarNodes& varNodes = factorGraph_->varNodes();
|
||||||
for (unsigned i = 0; i < varNodes.size(); i++) {
|
for (unsigned i = 0; i < varNodes.size(); i++) {
|
||||||
if (varNodes[i]->hasEvidence()) {
|
if (varNodes[i]->hasEvidence()) {
|
||||||
const vector<unsigned>& idxs =
|
const vector<unsigned>& idxs =
|
||||||
@ -126,7 +126,7 @@ VarElimSolver::processFactorList (const VarIds& vids)
|
|||||||
|
|
||||||
VarIds unobservedVids;
|
VarIds unobservedVids;
|
||||||
for (unsigned i = 0; i < vids.size(); i++) {
|
for (unsigned i = 0; i < vids.size(); i++) {
|
||||||
if (factorGraph_->getFgVarNode (vids[i])->hasEvidence() == false) {
|
if (factorGraph_->getVarNode (vids[i])->hasEvidence() == false) {
|
||||||
unobservedVids.push_back (vids[i]);
|
unobservedVids.push_back (vids[i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,100 +0,0 @@
|
|||||||
#include <algorithm>
|
|
||||||
#include <sstream>
|
|
||||||
|
|
||||||
#include "VarNode.h"
|
|
||||||
#include "GraphicalModel.h"
|
|
||||||
|
|
||||||
using namespace std;
|
|
||||||
|
|
||||||
|
|
||||||
VarNode::VarNode (const VarNode* v)
|
|
||||||
{
|
|
||||||
varId_ = v->varId();
|
|
||||||
range_ = v->range();
|
|
||||||
evidence_ = v->getEvidence();
|
|
||||||
index_ = std::numeric_limits<unsigned>::max();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
VarNode::VarNode (VarId varId, unsigned range, int evidence)
|
|
||||||
{
|
|
||||||
assert (range != 0);
|
|
||||||
assert (evidence < (int) range);
|
|
||||||
varId_ = varId;
|
|
||||||
range_ = range;
|
|
||||||
evidence_ = evidence;
|
|
||||||
index_ = std::numeric_limits<unsigned>::max();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
bool
|
|
||||||
VarNode::isValidState (int stateIndex)
|
|
||||||
{
|
|
||||||
return stateIndex >= 0 && stateIndex < (int) range_;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
bool
|
|
||||||
VarNode::isValidState (const string& stateName)
|
|
||||||
{
|
|
||||||
States states = GraphicalModel::getVarInformation (varId_).states;
|
|
||||||
return Util::contains (states, stateName);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
VarNode::setEvidence (int ev)
|
|
||||||
{
|
|
||||||
assert (ev < (int) range_);
|
|
||||||
evidence_ = ev;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
VarNode::setEvidence (const string& ev)
|
|
||||||
{
|
|
||||||
States states = GraphicalModel::getVarInformation (varId_).states;
|
|
||||||
for (unsigned i = 0; i < states.size(); i++) {
|
|
||||||
if (states[i] == ev) {
|
|
||||||
evidence_ = i;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
assert (false);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
string
|
|
||||||
VarNode::label (void) const
|
|
||||||
{
|
|
||||||
if (GraphicalModel::variablesHaveInformation()) {
|
|
||||||
return GraphicalModel::getVarInformation (varId_).label;
|
|
||||||
}
|
|
||||||
stringstream ss;
|
|
||||||
ss << "x" << varId_;
|
|
||||||
return ss.str();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
States
|
|
||||||
VarNode::states (void) const
|
|
||||||
{
|
|
||||||
if (GraphicalModel::variablesHaveInformation()) {
|
|
||||||
return GraphicalModel::getVarInformation (varId_).states;
|
|
||||||
}
|
|
||||||
States states;
|
|
||||||
for (unsigned i = 0; i < range_; i++) {
|
|
||||||
stringstream ss;
|
|
||||||
ss << i ;
|
|
||||||
states.push_back (ss.str());
|
|
||||||
}
|
|
||||||
return states;
|
|
||||||
}
|
|
||||||
|
|
@ -1,71 +0,0 @@
|
|||||||
#ifndef HORUS_VARNODE_H
|
|
||||||
#define HORUS_VARNODE_H
|
|
||||||
|
|
||||||
#include <cassert>
|
|
||||||
|
|
||||||
#include <iostream>
|
|
||||||
|
|
||||||
#include "Horus.h"
|
|
||||||
|
|
||||||
using namespace std;
|
|
||||||
|
|
||||||
class VarNode
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
VarNode (const VarNode*);
|
|
||||||
|
|
||||||
VarNode (VarId, unsigned, int = Constants::NO_EVIDENCE);
|
|
||||||
|
|
||||||
virtual ~VarNode (void) { };
|
|
||||||
|
|
||||||
unsigned varId (void) const { return varId_; }
|
|
||||||
|
|
||||||
unsigned range (void) const { return range_; }
|
|
||||||
|
|
||||||
int getEvidence (void) const { return evidence_; }
|
|
||||||
|
|
||||||
unsigned getIndex (void) const { return index_; }
|
|
||||||
|
|
||||||
void setIndex (unsigned idx) { index_ = idx; }
|
|
||||||
|
|
||||||
operator unsigned () const { return index_; }
|
|
||||||
|
|
||||||
bool hasEvidence (void) const
|
|
||||||
{
|
|
||||||
return evidence_ != Constants::NO_EVIDENCE;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool operator== (const VarNode& var) const
|
|
||||||
{
|
|
||||||
assert (!(varId_ == var.varId() && range_ != var.range()));
|
|
||||||
return varId_ == var.varId();
|
|
||||||
}
|
|
||||||
|
|
||||||
bool operator!= (const VarNode& var) const
|
|
||||||
{
|
|
||||||
assert (!(varId_ == var.varId() && range_ != var.range()));
|
|
||||||
return varId_ != var.varId();
|
|
||||||
}
|
|
||||||
|
|
||||||
bool isValidState (int);
|
|
||||||
|
|
||||||
bool isValidState (const string&);
|
|
||||||
|
|
||||||
void setEvidence (int);
|
|
||||||
|
|
||||||
void setEvidence (const string&);
|
|
||||||
|
|
||||||
string label (void) const;
|
|
||||||
|
|
||||||
States states (void) const;
|
|
||||||
|
|
||||||
private:
|
|
||||||
VarId varId_;
|
|
||||||
unsigned range_;
|
|
||||||
int evidence_;
|
|
||||||
unsigned index_;
|
|
||||||
|
|
||||||
};
|
|
||||||
|
|
||||||
#endif // BP_VARNODE_H
|
|
||||||
|
|
@ -1,50 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
|
|
||||||
cp ~/bin/yap ~/bin/town_fgbp
|
|
||||||
YAP=~/bin/town_fgbp
|
|
||||||
|
|
||||||
#OUT_FILE_NAME=results`date "+ %H:%M:%S %d-%m-%Y"`.log
|
|
||||||
OUT_FILE_NAME=fb_bp.log
|
|
||||||
rm -f $OUT_FILE_NAME
|
|
||||||
rm -f ignore.$OUT_FILE_NAME
|
|
||||||
|
|
||||||
|
|
||||||
function run_solver
|
|
||||||
{
|
|
||||||
if [ $2 = bp ]
|
|
||||||
then
|
|
||||||
extra_flag1=clpbn_horus:set_horus_flag\(inf_alg,$4\)
|
|
||||||
extra_flag2=clpbn_horus:set_horus_flag\(schedule,$5\)
|
|
||||||
else
|
|
||||||
extra_flag1=true
|
|
||||||
extra_flag2=true
|
|
||||||
fi
|
|
||||||
/usr/bin/time -o $OUT_FILE_NAME -a -f "real:%E\tuser:%U\tsys:%S" $YAP << EOF >> $OUT_FILE_NAME 2>> ignore.$OUT_FILE_NAME
|
|
||||||
[$1].
|
|
||||||
clpbn:set_clpbn_flag(solver,$2),
|
|
||||||
clpbn_horus:set_horus_flag(use_logarithms, true),
|
|
||||||
$extra_flag1, $extra_flag2,
|
|
||||||
run_query(_R),
|
|
||||||
open("$OUT_FILE_NAME", 'append',S),
|
|
||||||
format(S, '$3: ~15+ ',[]),
|
|
||||||
close(S).
|
|
||||||
EOF
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
function run_all_graphs
|
|
||||||
{
|
|
||||||
echo "*******************************************************************" >> "$OUT_FILE_NAME"
|
|
||||||
echo "results for solver $2" >> $OUT_FILE_NAME
|
|
||||||
echo "*******************************************************************" >> "$OUT_FILE_NAME"
|
|
||||||
run_solver town_1000 $1 town_1000 $3 $4 $5
|
|
||||||
#run_solver town_5000 $1 town_5000 $3 $4 $5
|
|
||||||
#run_solver town_10000 $1 town_10000 $3 $4 $5
|
|
||||||
#run_solver town_50000 $1 town_50000 $3 $4 $5
|
|
||||||
#run_solver town_100000 $1 town_100000 $3 $4 $5
|
|
||||||
#run_solver town_500000 $1 town_500000 $3 $4 $5
|
|
||||||
#run_solver town_1000000 $1 town_1000000 $3 $4 $5
|
|
||||||
}
|
|
||||||
|
|
||||||
run_all_graphs bp "fg_bp(seq_fixed) " fg_bp seq_fixed
|
|
||||||
|
|
@ -9,7 +9,7 @@ OUT_FILE_NAME=results.log
|
|||||||
rm -f $OUT_FILE_NAME
|
rm -f $OUT_FILE_NAME
|
||||||
rm -f ignore.$OUT_FILE_NAME
|
rm -f ignore.$OUT_FILE_NAME
|
||||||
|
|
||||||
# yap -g "['../../../../examples/School/sch32'], [missing5], use_module(library(clpbn/learning/em)), graph(L), clpbn:set_clpbn_flag(em_solver,bp), clpbn_horus:set_horus_flag(inf_alg,fg_bp), statistics(runtime, _), em(L,0.01,10,_,Lik), statistics(runtime, [T,_])."
|
# yap -g "['../../../../examples/School/sch32'], [missing5], use_module(library(clpbn/learning/em)), graph(L), clpbn:set_clpbn_flag(em_solver,bp), clpbn_horus:set_horus_flag(inf_alg, bp), statistics(runtime, _), em(L,0.01,10,_,Lik), statistics(runtime, [T,_])."
|
||||||
|
|
||||||
function run_solver
|
function run_solver
|
||||||
{
|
{
|
||||||
@ -59,8 +59,7 @@ function run_all_graphs
|
|||||||
|
|
||||||
|
|
||||||
#run_all_graphs bp "hve(min_neighbors) " ve min_neighbors
|
#run_all_graphs bp "hve(min_neighbors) " ve min_neighbors
|
||||||
#run_all_graphs bp "bn_bp(seq_fixed) " bn_bp seq_fixed
|
#run_all_graphs bp "bp(seq_fixed) " bp seq_fixed
|
||||||
run_all_graphs bp "fg_bp(seq_fixed) " fg_bp seq_fixed
|
|
||||||
#run_all_graphs bp "cbp(seq_fixed) " cbp seq_fixed
|
#run_all_graphs bp "cbp(seq_fixed) " cbp seq_fixed
|
||||||
exit
|
exit
|
||||||
|
|
||||||
@ -69,10 +68,8 @@ run_all_graphs bp "hve(min_neighbors) " ve min_neighbors
|
|||||||
run_all_graphs bp "hve(min_weight) " ve min_weight
|
run_all_graphs bp "hve(min_weight) " ve min_weight
|
||||||
run_all_graphs bp "hve(min_fill) " ve min_fill
|
run_all_graphs bp "hve(min_fill) " ve min_fill
|
||||||
run_all_graphs bp "hve(w_min_fill) " ve weighted_min_fill
|
run_all_graphs bp "hve(w_min_fill) " ve weighted_min_fill
|
||||||
run_all_graphs bp "bn_bp(seq_fixed) " bn_bp seq_fixed
|
run_all_graphs bp "bp(seq_fixed) " bp seq_fixed
|
||||||
run_all_graphs bp "bn_bp(max_residual) " bn_bp max_residual
|
run_all_graphs bp "bp(max_residual) " bp max_residual
|
||||||
run_all_graphs bp "fg_bp(seq_fixed) " fg_bp seq_fixed
|
|
||||||
run_all_graphs bp "fg_bp(max_residual) " fg_bp max_residual
|
|
||||||
run_all_graphs bp "cbp(seq_fixed) " cbp seq_fixed
|
run_all_graphs bp "cbp(seq_fixed) " cbp seq_fixed
|
||||||
run_all_graphs bp "cbp(max_residual) " cbp max_residual
|
run_all_graphs bp "cbp(max_residual) " cbp max_residual
|
||||||
run_all_graphs gibbs "gibbs "
|
run_all_graphs gibbs "gibbs "
|
||||||
|
@ -1,81 +0,0 @@
|
|||||||
<?xml version="1.0" encoding="US-ASCII"?>
|
|
||||||
|
|
||||||
<!--
|
|
||||||
|
|
||||||
B E
|
|
||||||
\ /
|
|
||||||
\ /
|
|
||||||
A
|
|
||||||
/ \
|
|
||||||
/ \
|
|
||||||
J M
|
|
||||||
|
|
||||||
-->
|
|
||||||
|
|
||||||
|
|
||||||
<BIF VERSION="0.3">
|
|
||||||
<NETWORK>
|
|
||||||
<NAME>Simple Loop</NAME>
|
|
||||||
|
|
||||||
<VARIABLE TYPE="nature">
|
|
||||||
<NAME>B</NAME>
|
|
||||||
<OUTCOME>b1</OUTCOME>
|
|
||||||
<OUTCOME>b2</OUTCOME>
|
|
||||||
</VARIABLE>
|
|
||||||
|
|
||||||
<VARIABLE TYPE="nature">
|
|
||||||
<NAME>E</NAME>
|
|
||||||
<OUTCOME>e1</OUTCOME>
|
|
||||||
<OUTCOME>e2</OUTCOME>
|
|
||||||
</VARIABLE>
|
|
||||||
|
|
||||||
<VARIABLE TYPE="nature">
|
|
||||||
<NAME>A</NAME>
|
|
||||||
<OUTCOME>a1</OUTCOME>
|
|
||||||
<OUTCOME>a2</OUTCOME>
|
|
||||||
</VARIABLE>
|
|
||||||
|
|
||||||
<VARIABLE TYPE="nature">
|
|
||||||
<NAME>J</NAME>
|
|
||||||
<OUTCOME>j1</OUTCOME>
|
|
||||||
<OUTCOME>j2</OUTCOME>
|
|
||||||
</VARIABLE>
|
|
||||||
|
|
||||||
<VARIABLE TYPE="nature">
|
|
||||||
<NAME>M</NAME>
|
|
||||||
<OUTCOME>m1</OUTCOME>
|
|
||||||
<OUTCOME>m2</OUTCOME>
|
|
||||||
</VARIABLE>
|
|
||||||
|
|
||||||
<DEFINITION>
|
|
||||||
<FOR>B</FOR>
|
|
||||||
<TABLE> .001 .999 </TABLE>
|
|
||||||
</DEFINITION>
|
|
||||||
|
|
||||||
<DEFINITION>
|
|
||||||
<FOR>E</FOR>
|
|
||||||
<TABLE> .002 .998 </TABLE>
|
|
||||||
</DEFINITION>
|
|
||||||
|
|
||||||
<DEFINITION>
|
|
||||||
<FOR>A</FOR>
|
|
||||||
<GIVEN>B</GIVEN>
|
|
||||||
<GIVEN>E</GIVEN>
|
|
||||||
<TABLE> .95 .05 .94 .06 .29 .71 .001 .999 </TABLE>
|
|
||||||
</DEFINITION>
|
|
||||||
|
|
||||||
<DEFINITION>
|
|
||||||
<FOR>J</FOR>
|
|
||||||
<GIVEN>A</GIVEN>
|
|
||||||
<TABLE> .9 .1 .05 .95 </TABLE>
|
|
||||||
</DEFINITION>
|
|
||||||
|
|
||||||
<DEFINITION>
|
|
||||||
<FOR>M</FOR>
|
|
||||||
<GIVEN>A</GIVEN>
|
|
||||||
<TABLE> .7 .3 .01 .99 </TABLE>
|
|
||||||
</DEFINITION>
|
|
||||||
|
|
||||||
</NETWORK>
|
|
||||||
</BIF>
|
|
||||||
|
|
@ -3,7 +3,7 @@
|
|||||||
|
|
||||||
%:- set_pfl_flag(solver,ve).
|
%:- set_pfl_flag(solver,ve).
|
||||||
:- set_pfl_flag(solver,bp), clpbn_horus:set_horus_flag(inf_alg,ve).
|
:- set_pfl_flag(solver,bp), clpbn_horus:set_horus_flag(inf_alg,ve).
|
||||||
%:- set_pfl_flag(solver,bp), clpbn_horus:set_horus_flag(inf_alg,fg_bp).
|
%:- set_pfl_flag(solver,bp), clpbn_horus:set_horus_flag(inf_alg,bp).
|
||||||
%:- set_pfl_flag(solver,fove).
|
%:- set_pfl_flag(solver,fove).
|
||||||
|
|
||||||
% :- yap_flag(write_strings, off).
|
% :- yap_flag(write_strings, off).
|
||||||
|
Reference in New Issue
Block a user