some renamings

This commit is contained in:
Tiago Gomes 2012-04-10 11:51:56 +01:00
parent 6986e8c0d7
commit aa1b2e40ea
11 changed files with 96 additions and 113 deletions

View File

@ -60,7 +60,7 @@ BayesBall::getMinimalFactorGraph (const VarIds& queryIds)
void void
BayesBall::constructGraph (FactorGraph* fg) const BayesBall::constructGraph (FactorGraph* fg) const
{ {
const FactorNodes& facNodes = fg_.factorNodes(); const FacNodes& facNodes = fg_.facNodes();
for (unsigned i = 0; i < facNodes.size(); i++) { for (unsigned i = 0; i < facNodes.size(); i++) {
const DAGraphNode* n = dag_.getNode ( const DAGraphNode* n = dag_.getNode (
facNodes[i]->factor().argument (0)); facNodes[i]->factor().argument (0));

View File

@ -99,9 +99,9 @@ BpSolver::getJointDistributionOf (const VarIds& jointVarIds)
{ {
int idx = -1; int idx = -1;
VarNode* vn = factorGraph_->getVarNode (jointVarIds[0]); VarNode* vn = factorGraph_->getVarNode (jointVarIds[0]);
const FactorNodes& factorNodes = vn->neighbors(); const FacNodes& facNodes = vn->neighbors();
for (unsigned i = 0; i < factorNodes.size(); i++) { for (unsigned i = 0; i < facNodes.size(); i++) {
if (factorNodes[i]->factor().contains (jointVarIds)) { if (facNodes[i]->factor().contains (jointVarIds)) {
idx = i; idx = i;
break; break;
} }
@ -109,8 +109,8 @@ BpSolver::getJointDistributionOf (const VarIds& jointVarIds)
if (idx == -1) { if (idx == -1) {
return getJointByConditioning (jointVarIds); return getJointByConditioning (jointVarIds);
} else { } else {
Factor res (factorNodes[idx]->factor()); Factor res (facNodes[idx]->factor());
const SpLinkSet& links = ninf(factorNodes[idx])->getLinks(); const SpLinkSet& links = ninf(facNodes[idx])->getLinks();
for (unsigned i = 0; i < links.size(); i++) { for (unsigned i = 0; i < links.size(); i++) {
Factor msg (links[i]->getVariable()->varId(), Factor msg (links[i]->getVariable()->varId(),
links[i]->getVariable()->range(), links[i]->getVariable()->range(),
@ -188,7 +188,7 @@ BpSolver::initializeSolver (void)
varsI_.push_back (new SPNodeInfo()); varsI_.push_back (new SPNodeInfo());
} }
const FactorNodes& facNodes = factorGraph_->factorNodes(); const FacNodes& facNodes = factorGraph_->facNodes();
for (unsigned i = 0; i < facsI_.size(); i++) { for (unsigned i = 0; i < facsI_.size(); i++) {
delete facsI_[i]; delete facsI_[i];
} }
@ -203,7 +203,7 @@ BpSolver::initializeSolver (void)
createLinks(); createLinks();
for (unsigned i = 0; i < links_.size(); i++) { for (unsigned i = 0; i < links_.size(); i++) {
FactorNode* src = links_[i]->getFactor(); FacNode* src = links_[i]->getFactor();
VarNode* dst = links_[i]->getVariable(); VarNode* dst = links_[i]->getVariable();
ninf (dst)->addSpLink (links_[i]); ninf (dst)->addSpLink (links_[i]);
ninf (src)->addSpLink (links_[i]); ninf (src)->addSpLink (links_[i]);
@ -215,7 +215,7 @@ BpSolver::initializeSolver (void)
void void
BpSolver::createLinks (void) BpSolver::createLinks (void)
{ {
const FactorNodes& facNodes = factorGraph_->factorNodes(); const FacNodes& facNodes = factorGraph_->facNodes();
for (unsigned i = 0; i < facNodes.size(); i++) { for (unsigned i = 0; i < facNodes.size(); i++) {
const VarNodes& neighbors = facNodes[i]->neighbors(); const VarNodes& neighbors = facNodes[i]->neighbors();
for (unsigned j = 0; j < neighbors.size(); j++) { for (unsigned j = 0; j < neighbors.size(); j++) {
@ -293,7 +293,7 @@ BpSolver::maxResidualSchedule (void)
linkMap_.find (link)->second = sortedOrder_.insert (link); linkMap_.find (link)->second = sortedOrder_.insert (link);
// update the messages that depend on message source --> destin // update the messages that depend on message source --> destin
const FactorNodes& factorNeighbors = link->getVariable()->neighbors(); const FacNodes& factorNeighbors = link->getVariable()->neighbors();
for (unsigned i = 0; i < factorNeighbors.size(); i++) { for (unsigned i = 0; i < factorNeighbors.size(); i++) {
if (factorNeighbors[i] != link->getFactor()) { if (factorNeighbors[i] != link->getFactor()) {
const SpLinkSet& links = ninf(factorNeighbors[i])->getLinks(); const SpLinkSet& links = ninf(factorNeighbors[i])->getLinks();
@ -318,7 +318,7 @@ BpSolver::maxResidualSchedule (void)
void void
BpSolver::calculateFactor2VariableMsg (SpLink* link) BpSolver::calculateFactor2VariableMsg (SpLink* link)
{ {
FactorNode* src = link->getFactor(); FacNode* src = link->getFactor();
const VarNode* dst = link->getVariable(); const VarNode* dst = link->getVariable();
const SpLinkSet& links = ninf(src)->getLinks(); const SpLinkSet& links = ninf(src)->getLinks();
// calculate the product of messages that were sent // calculate the product of messages that were sent
@ -388,7 +388,7 @@ Params
BpSolver::getVar2FactorMsg (const SpLink* link) const BpSolver::getVar2FactorMsg (const SpLink* link) const
{ {
const VarNode* src = link->getVariable(); const VarNode* src = link->getVariable();
const FactorNode* dst = link->getFactor(); const FacNode* dst = link->getFactor();
Params msg; Params msg;
if (src->hasEvidence()) { if (src->hasEvidence()) {
msg.resize (src->range(), LogAware::noEvidence()); msg.resize (src->range(), LogAware::noEvidence());

View File

@ -16,7 +16,7 @@ using namespace std;
class SpLink class SpLink
{ {
public: public:
SpLink (FactorNode* fn, VarNode* vn) SpLink (FacNode* fn, VarNode* vn)
{ {
fac_ = fn; fac_ = fn;
var_ = vn; var_ = vn;
@ -30,7 +30,7 @@ class SpLink
virtual ~SpLink (void) { }; virtual ~SpLink (void) { };
FactorNode* getFactor (void) const { return fac_; } FacNode* getFactor (void) const { return fac_; }
VarNode* getVariable (void) const { return var_; } VarNode* getVariable (void) const { return var_; }
@ -65,14 +65,14 @@ class SpLink
} }
protected: protected:
FactorNode* fac_; FacNode* fac_;
VarNode* var_; VarNode* var_;
Params v1_; Params v1_;
Params v2_; Params v2_;
Params* currMsg_; Params* currMsg_;
Params* nextMsg_; Params* nextMsg_;
bool msgSended_; bool msgSended_;
double residual_; double residual_;
}; };
typedef vector<SpLink*> SpLinkSet; typedef vector<SpLink*> SpLinkSet;
@ -121,7 +121,7 @@ class BpSolver : public Solver
return varsI_[var->getIndex()]; return varsI_[var->getIndex()];
} }
SPNodeInfo* ninf (const FactorNode* fac) const SPNodeInfo* ninf (const FacNode* fac) const
{ {
return facsI_[fac->getIndex()]; return facsI_[fac->getIndex()];
} }

View File

@ -17,7 +17,7 @@ CFactorGraph::CFactorGraph (const FactorGraph& fg)
varSignatures_.push_back (Signature (c)); varSignatures_.push_back (Signature (c));
} }
const FactorNodes& facNodes = fg.factorNodes(); const FacNodes& facNodes = fg.facNodes();
factorSignatures_.reserve (facNodes.size()); factorSignatures_.reserve (facNodes.size());
for (unsigned i = 0; i < facNodes.size(); i++) { for (unsigned i = 0; i < facNodes.size(); i++) {
unsigned c = facNodes[i]->neighbors().size() + 1; unsigned c = facNodes[i]->neighbors().size() + 1;
@ -70,7 +70,7 @@ CFactorGraph::setInitialColors (void)
setColor (varNodes[i], stateColors[idx]); setColor (varNodes[i], stateColors[idx]);
} }
const FactorNodes& facNodes = groundFg_->factorNodes(); const FacNodes& facNodes = groundFg_->facNodes();
if (checkForIdenticalFactors) { if (checkForIdenticalFactors) {
unsigned groupCount = 1; unsigned groupCount = 1;
for (unsigned i = 0; i < facNodes.size(); i++) { for (unsigned i = 0; i < facNodes.size(); i++) {
@ -115,7 +115,7 @@ CFactorGraph::createGroups (void)
unsigned nIters = 0; unsigned nIters = 0;
bool groupsHaveChanged = true; bool groupsHaveChanged = true;
const VarNodes& varNodes = groundFg_->varNodes(); const VarNodes& varNodes = groundFg_->varNodes();
const FactorNodes& facNodes = groundFg_->factorNodes(); const FacNodes& facNodes = groundFg_->facNodes();
while (groupsHaveChanged || nIters == 1) { while (groupsHaveChanged || nIters == 1) {
nIters ++; nIters ++;
@ -127,14 +127,14 @@ CFactorGraph::createGroups (void)
const Signature& signature = getSignature (facNodes[i]); const Signature& signature = getSignature (facNodes[i]);
FacSignMap::iterator it = factorGroups.find (signature); FacSignMap::iterator it = factorGroups.find (signature);
if (it == factorGroups.end()) { if (it == factorGroups.end()) {
it = factorGroups.insert (make_pair (signature, FactorNodes())).first; it = factorGroups.insert (make_pair (signature, FacNodes())).first;
} }
it->second.push_back (facNodes[i]); it->second.push_back (facNodes[i]);
} }
for (FacSignMap::iterator it = factorGroups.begin(); for (FacSignMap::iterator it = factorGroups.begin();
it != factorGroups.end(); it++) { it != factorGroups.end(); it++) {
Color newColor = getFreeColor(); Color newColor = getFreeColor();
FactorNodes& groupMembers = it->second; FacNodes& groupMembers = it->second;
for (unsigned i = 0; i < groupMembers.size(); i++) { for (unsigned i = 0; i < groupMembers.size(); i++) {
setColor (groupMembers[i], newColor); setColor (groupMembers[i], newColor);
} }
@ -188,7 +188,7 @@ CFactorGraph::createClusters (
facClusters_.reserve (factorGroups.size()); facClusters_.reserve (factorGroups.size());
for (FacSignMap::const_iterator it = factorGroups.begin(); for (FacSignMap::const_iterator it = factorGroups.begin();
it != factorGroups.end(); it++) { it != factorGroups.end(); it++) {
FactorNode* groupFactor = it->second[0]; FacNode* groupFactor = it->second[0];
const VarNodes& neighs = groupFactor->neighbors(); const VarNodes& neighs = groupFactor->neighbors();
VarClusterSet varClusters; VarClusterSet varClusters;
varClusters.reserve (neighs.size()); varClusters.reserve (neighs.size());
@ -207,7 +207,7 @@ CFactorGraph::getSignature (const VarNode* varNode)
{ {
Signature& sign = varSignatures_[varNode->getIndex()]; Signature& sign = varSignatures_[varNode->getIndex()];
vector<Color>::iterator it = sign.colors.begin(); vector<Color>::iterator it = sign.colors.begin();
const FactorNodes& neighs = varNode->neighbors(); const FacNodes& neighs = varNode->neighbors();
for (unsigned i = 0; i < neighs.size(); i++) { for (unsigned i = 0; i < neighs.size(); i++) {
*it = getColor (neighs[i]); *it = getColor (neighs[i]);
it ++; it ++;
@ -221,7 +221,7 @@ CFactorGraph::getSignature (const VarNode* varNode)
const Signature& const Signature&
CFactorGraph::getSignature (const FactorNode* facNode) CFactorGraph::getSignature (const FacNode* facNode)
{ {
Signature& sign = factorSignatures_[facNode->getIndex()]; Signature& sign = factorSignatures_[facNode->getIndex()];
vector<Color>::iterator it = sign.colors.begin(); vector<Color>::iterator it = sign.colors.begin();
@ -255,12 +255,12 @@ CFactorGraph::getCompressedFactorGraph (void)
VarNode* v = myVarClusters[j]->getRepresentativeVariable(); VarNode* v = myVarClusters[j]->getRepresentativeVariable();
myGroundVars.push_back (v); myGroundVars.push_back (v);
} }
FactorNode* fn = new FactorNode (Factor (myGroundVars, FacNode* fn = new FacNode (Factor (myGroundVars,
facClusters_[i]->getGroundFactors()[0]->params())); facClusters_[i]->getGroundFactors()[0]->params()));
facClusters_[i]->setRepresentativeFactor (fn); facClusters_[i]->setRepresentativeFactor (fn);
fg->addFactorNode (fn); fg->addFacNode (fn);
for (unsigned j = 0; j < myGroundVars.size(); j++) { for (unsigned j = 0; j < myGroundVars.size(); j++) {
fg->addEdge (fn, static_cast<VarNode*> (myGroundVars[j])); fg->addEdge (static_cast<VarNode*> (myGroundVars[j]), fn);
} }
} }
fg->setIndexes(); fg->setIndexes();
@ -274,7 +274,7 @@ CFactorGraph::getGroundEdgeCount (
const FacCluster* fc, const FacCluster* fc,
const VarCluster* vc) const const VarCluster* vc) const
{ {
const FactorNodes& clusterGroundFactors = fc->getGroundFactors(); const FacNodes& clusterGroundFactors = fc->getGroundFactors();
VarNode* varNode = vc->getGroundVarNodes()[0]; VarNode* varNode = vc->getGroundVarNodes()[0];
unsigned count = 0; unsigned count = 0;
for (unsigned i = 0; i < clusterGroundFactors.size(); i++) { for (unsigned i = 0; i < clusterGroundFactors.size(); i++) {
@ -322,7 +322,7 @@ CFactorGraph::printGroups (
cout << endl << "factor groups:" << endl; cout << endl << "factor groups:" << endl;
for (FacSignMap::const_iterator it = factorGroups.begin(); for (FacSignMap::const_iterator it = factorGroups.begin();
it != factorGroups.end(); it++) { it != factorGroups.end(); it++) {
const FactorNodes& groupMembers = it->second; const FacNodes& groupMembers = it->second;
if (groupMembers.size() > 0) { if (groupMembers.size() > 0) {
cout << ++count << ": " ; cout << ++count << ": " ;
for (unsigned i = 0; i < groupMembers.size(); i++) { for (unsigned i = 0; i < groupMembers.size(); i++) {

View File

@ -26,7 +26,7 @@ typedef vector<VarCluster*> VarClusterSet;
typedef vector<FacCluster*> FacClusterSet; typedef vector<FacCluster*> FacClusterSet;
typedef unordered_map<Signature, VarNodes, SignatureHash> VarSignMap; typedef unordered_map<Signature, VarNodes, SignatureHash> VarSignMap;
typedef unordered_map<Signature, FactorNodes, SignatureHash> FacSignMap; typedef unordered_map<Signature, FacNodes, SignatureHash> FacSignMap;
@ -118,7 +118,7 @@ class VarCluster
class FacCluster class FacCluster
{ {
public: public:
FacCluster (const FactorNodes& groundFactors, const VarClusterSet& vcs) FacCluster (const FacNodes& groundFactors, const VarClusterSet& vcs)
{ {
groundFactors_ = groundFactors; groundFactors_ = groundFactors;
varClusters_ = vcs; varClusters_ = vcs;
@ -132,7 +132,7 @@ class FacCluster
return varClusters_; return varClusters_;
} }
bool containsGround (const FactorNode* fn) bool containsGround (const FacNode* fn)
{ {
for (unsigned i = 0; i < groundFactors_.size(); i++) { for (unsigned i = 0; i < groundFactors_.size(); i++) {
if (groundFactors_[i] == fn) { if (groundFactors_[i] == fn) {
@ -142,26 +142,26 @@ class FacCluster
return false; return false;
} }
FactorNode* getRepresentativeFactor (void) const FacNode* getRepresentativeFactor (void) const
{ {
return representFactor_; return representFactor_;
} }
void setRepresentativeFactor (FactorNode* fn) void setRepresentativeFactor (FacNode* fn)
{ {
representFactor_ = fn; representFactor_ = fn;
} }
const FactorNodes& getGroundFactors (void) const const FacNodes& getGroundFactors (void) const
{ {
return groundFactors_; return groundFactors_;
} }
private: private:
FactorNodes groundFactors_; FacNodes groundFactors_;
VarClusterSet varClusters_; VarClusterSet varClusters_;
FactorNode* representFactor_; FacNode* representFactor_;
}; };
@ -199,7 +199,7 @@ class CFactorGraph
{ {
return varColors_[vn->getIndex()]; return varColors_[vn->getIndex()];
} }
Color getColor (const FactorNode* fn) const { Color getColor (const FacNode* fn) const {
return factorColors_[fn->getIndex()]; return factorColors_[fn->getIndex()];
} }
@ -208,7 +208,7 @@ class CFactorGraph
varColors_[vn->getIndex()] = c; varColors_[vn->getIndex()] = c;
} }
void setColor (const FactorNode* fn, Color c) void setColor (const FacNode* fn, Color c)
{ {
factorColors_[fn->getIndex()] = c; factorColors_[fn->getIndex()] = c;
} }
@ -218,17 +218,17 @@ class CFactorGraph
return vid2VarCluster_.find (vid)->second; return vid2VarCluster_.find (vid)->second;
} }
void setInitialColors (void); void setInitialColors (void);
void createGroups (void); void createGroups (void);
void createClusters (const VarSignMap&, const FacSignMap&); void createClusters (const VarSignMap&, const FacSignMap&);
const Signature& getSignature (const VarNode*); const Signature& getSignature (const VarNode*);
const Signature& getSignature (const FactorNode*); const Signature& getSignature (const FacNode*);
void printGroups (const VarSignMap&, const FacSignMap&) const; void printGroups (const VarSignMap&, const FacSignMap&) const;
Color freeColor_; Color freeColor_;
vector<Color> varColors_; vector<Color> varColors_;

View File

@ -64,11 +64,11 @@ CbpSolver::initializeSolver (void)
unsigned nGroundVars, nGroundFacs, nWithoutNeighs; unsigned nGroundVars, nGroundFacs, nWithoutNeighs;
if (Constants::COLLECT_STATS) { if (Constants::COLLECT_STATS) {
nGroundVars = factorGraph_->varNodes().size(); nGroundVars = factorGraph_->varNodes().size();
nGroundFacs = factorGraph_->factorNodes().size(); nGroundFacs = factorGraph_->facNodes().size();
const VarNodes& vars = factorGraph_->varNodes(); const VarNodes& vars = factorGraph_->varNodes();
nWithoutNeighs = 0; nWithoutNeighs = 0;
for (unsigned i = 0; i < vars.size(); i++) { for (unsigned i = 0; i < vars.size(); i++) {
const FactorNodes& factors = vars[i]->neighbors(); const FacNodes& factors = vars[i]->neighbors();
if (factors.size() == 1 && factors[0]->neighbors().size() == 1) { if (factors.size() == 1 && factors[0]->neighbors().size() == 1) {
nWithoutNeighs ++; nWithoutNeighs ++;
} }
@ -84,7 +84,7 @@ CbpSolver::initializeSolver (void)
if (Constants::COLLECT_STATS) { if (Constants::COLLECT_STATS) {
unsigned nClusterVars = factorGraph_->varNodes().size(); unsigned nClusterVars = factorGraph_->varNodes().size();
unsigned nClusterFacs = factorGraph_->factorNodes().size(); unsigned nClusterFacs = factorGraph_->facNodes().size();
Statistics::updateCompressingStatistics (nGroundVars, nGroundFacs, Statistics::updateCompressingStatistics (nGroundVars, nGroundFacs,
nClusterVars, nClusterFacs, nClusterVars, nClusterFacs,
nWithoutNeighs); nWithoutNeighs);
@ -154,7 +154,7 @@ CbpSolver::maxResidualSchedule (void)
linkMap_.find (link)->second = sortedOrder_.insert (link); linkMap_.find (link)->second = sortedOrder_.insert (link);
// update the messages that depend on message source --> destin // update the messages that depend on message source --> destin
const FactorNodes& factorNeighbors = link->getVariable()->neighbors(); const FacNodes& factorNeighbors = link->getVariable()->neighbors();
for (unsigned i = 0; i < factorNeighbors.size(); i++) { for (unsigned i = 0; i < factorNeighbors.size(); i++) {
const SpLinkSet& links = ninf(factorNeighbors[i])->getLinks(); const SpLinkSet& links = ninf(factorNeighbors[i])->getLinks();
for (unsigned j = 0; j < links.size(); j++) { for (unsigned j = 0; j < links.size(); j++) {
@ -193,7 +193,7 @@ CbpSolver::getVar2FactorMsg (const SpLink* link) const
{ {
Params msg; Params msg;
const VarNode* src = link->getVariable(); const VarNode* src = link->getVariable();
const FactorNode* dst = link->getFactor(); const FacNode* dst = link->getFactor();
const CbpSolverLink* l = static_cast<const CbpSolverLink*> (link); const CbpSolverLink* l = static_cast<const CbpSolverLink*> (link);
if (src->hasEvidence()) { if (src->hasEvidence()) {
msg.resize (src->range(), LogAware::noEvidence()); msg.resize (src->range(), LogAware::noEvidence());

View File

@ -9,7 +9,7 @@ class Factor;
class CbpSolverLink : public SpLink class CbpSolverLink : public SpLink
{ {
public: public:
CbpSolverLink (FactorNode* fn, VarNode* vn, unsigned c) : SpLink (fn, vn) CbpSolverLink (FacNode* fn, VarNode* vn, unsigned c) : SpLink (fn, vn)
{ {
edgeCount_ = c; edgeCount_ = c;
poweredMsg_.resize (vn->range(), LogAware::one()); poweredMsg_.resize (vn->range(), LogAware::one());

View File

@ -22,13 +22,13 @@ FactorGraph::FactorGraph (const FactorGraph& fg)
for (unsigned i = 0; i < varNodes.size(); i++) { for (unsigned i = 0; i < varNodes.size(); i++) {
addVarNode (new VarNode (varNodes[i])); addVarNode (new VarNode (varNodes[i]));
} }
const FactorNodes& facNodes = fg.factorNodes(); const FacNodes& facNodes = fg.facNodes();
for (unsigned i = 0; i < facNodes.size(); i++) { for (unsigned i = 0; i < facNodes.size(); i++) {
FactorNode* facNode = new FactorNode (facNodes[i]->factor()); FacNode* facNode = new FacNode (facNodes[i]->factor());
addFactorNode (facNode); addFacNode (facNode);
const VarNodes& neighs = facNodes[i]->neighbors(); const VarNodes& neighs = facNodes[i]->neighbors();
for (unsigned j = 0; j < neighs.size(); j++) { for (unsigned j = 0; j < neighs.size(); j++) {
addEdge (facNode, varNodes_[neighs[j]->getIndex()]); addEdge (varNodes_[neighs[j]->getIndex()], facNode);
} }
} }
setIndexes(); setIndexes();
@ -39,9 +39,9 @@ FactorGraph::FactorGraph (const FactorGraph& fg)
void void
FactorGraph::readFromUaiFormat (const char* fileName) FactorGraph::readFromUaiFormat (const char* fileName)
{ {
ifstream is (fileName); std::ifstream is (fileName);
if (!is.is_open()) { if (!is.is_open()) {
cerr << "error: cannot read from file " + std::string (fileName) << endl; cerr << "error: cannot read from file " << fileName << endl;
abort(); abort();
} }
ignoreLines (is); ignoreLines (is);
@ -113,9 +113,9 @@ FactorGraph::readFromUaiFormat (const char* fileName)
void void
FactorGraph::readFromLibDaiFormat (const char* fileName) FactorGraph::readFromLibDaiFormat (const char* fileName)
{ {
ifstream is (fileName); std::ifstream is (fileName);
if (!is.is_open()) { if (!is.is_open()) {
cerr << "error: cannot read from file " + std::string (fileName) << endl; cerr << "error: cannot read from file " << fileName << endl;
abort(); abort();
} }
ignoreLines (is); ignoreLines (is);
@ -185,8 +185,8 @@ FactorGraph::~FactorGraph (void)
void void
FactorGraph::addFactor (const Factor& factor) FactorGraph::addFactor (const Factor& factor)
{ {
FactorNode* fn = new FactorNode (factor); FacNode* fn = new FacNode (factor);
addFactorNode (fn); addFacNode (fn);
const VarIds& vids = factor.arguments(); const VarIds& vids = factor.arguments();
for (unsigned i = 0; i < vids.size(); i++) { for (unsigned i = 0; i < vids.size(); i++) {
bool found = false; bool found = false;
@ -217,7 +217,7 @@ FactorGraph::addVarNode (VarNode* vn)
void void
FactorGraph::addFactorNode (FactorNode* fn) FactorGraph::addFacNode (FacNode* fn)
{ {
facNodes_.push_back (fn); facNodes_.push_back (fn);
fn->setIndex (facNodes_.size() - 1); fn->setIndex (facNodes_.size() - 1);
@ -226,7 +226,7 @@ FactorGraph::addFactorNode (FactorNode* fn)
void void
FactorGraph::addEdge (VarNode* vn, FactorNode* fn) FactorGraph::addEdge (VarNode* vn, FacNode* fn)
{ {
vn->addNeighbor (fn); vn->addNeighbor (fn);
fn->addNeighbor (vn); fn->addNeighbor (vn);
@ -234,15 +234,6 @@ FactorGraph::addEdge (VarNode* vn, FactorNode* fn)
void
FactorGraph::addEdge (FactorNode* fn, VarNode* vn)
{
fn->addNeighbor (vn);
vn->addNeighbor (fn);
}
bool bool
FactorGraph::isTree (void) const FactorGraph::isTree (void) const
{ {
@ -345,18 +336,15 @@ FactorGraph::exportToUaiFormat (const char* fileName) const
{ {
ofstream out (fileName); ofstream out (fileName);
if (!out.is_open()) { if (!out.is_open()) {
cerr << "error: cannot open file to write at " ; cerr << "error: cannot open file " << fileName << endl;
cerr << "FactorGraph::exportToUaiFormat()" << endl;
abort(); abort();
} }
out << "MARKOV" << endl; out << "MARKOV" << endl;
out << varNodes_.size() << endl; out << varNodes_.size() << endl;
for (unsigned i = 0; i < varNodes_.size(); i++) { for (unsigned i = 0; i < varNodes_.size(); i++) {
out << varNodes_[i]->range() << " " ; out << varNodes_[i]->range() << " " ;
} }
out << endl; out << endl;
out << facNodes_.size() << endl; out << facNodes_.size() << endl;
for (unsigned i = 0; i < facNodes_.size(); i++) { for (unsigned i = 0; i < facNodes_.size(); i++) {
const VarNodes& factorVars = facNodes_[i]->neighbors(); const VarNodes& factorVars = facNodes_[i]->neighbors();
@ -366,7 +354,6 @@ FactorGraph::exportToUaiFormat (const char* fileName) const
} }
out << endl; out << endl;
} }
for (unsigned i = 0; i < facNodes_.size(); i++) { for (unsigned i = 0; i < facNodes_.size(); i++) {
Params params = facNodes_[i]->params(); Params params = facNodes_[i]->params();
if (Globals::logDomain) { if (Globals::logDomain) {
@ -378,7 +365,6 @@ FactorGraph::exportToUaiFormat (const char* fileName) const
} }
out << endl; out << endl;
} }
out.close(); out.close();
} }
@ -389,8 +375,7 @@ FactorGraph::exportToLibDaiFormat (const char* fileName) const
{ {
ofstream out (fileName); ofstream out (fileName);
if (!out.is_open()) { if (!out.is_open()) {
cerr << "error: cannot open file to write at " ; cerr << "error: cannot open file " << fileName << endl;
cerr << "FactorGraph::exportToLibDaiFormat()" << endl;
abort(); abort();
} }
out << facNodes_.size() << endl << endl; out << facNodes_.size() << endl << endl;
@ -452,12 +437,12 @@ FactorGraph::containsCycle (void) const
bool bool
FactorGraph::containsCycle ( FactorGraph::containsCycle (
const VarNode* v, const VarNode* v,
const FactorNode* p, const FacNode* p,
vector<bool>& visitedVars, vector<bool>& visitedVars,
vector<bool>& visitedFactors) const vector<bool>& visitedFactors) const
{ {
visitedVars[v->getIndex()] = true; visitedVars[v->getIndex()] = true;
const FactorNodes& adjacencies = v->neighbors(); const FacNodes& adjacencies = v->neighbors();
for (unsigned i = 0; i < adjacencies.size(); i++) { for (unsigned i = 0; i < adjacencies.size(); i++) {
int w = adjacencies[i]->getIndex(); int w = adjacencies[i]->getIndex();
if (!visitedFactors[w]) { if (!visitedFactors[w]) {
@ -476,7 +461,7 @@ FactorGraph::containsCycle (
bool bool
FactorGraph::containsCycle ( FactorGraph::containsCycle (
const FactorNode* v, const FacNode* v,
const VarNode* p, const VarNode* p,
vector<bool>& visitedVars, vector<bool>& visitedVars,
vector<bool>& visitedFactors) const vector<bool>& visitedFactors) const

View File

@ -10,7 +10,7 @@
using namespace std; using namespace std;
class FactorNode; class FacNode;
class VarNode : public Var class VarNode : public Var
@ -20,21 +20,21 @@ class VarNode : public Var
VarNode (const Var* v) : Var (v) { } VarNode (const Var* v) : Var (v) { }
void addNeighbor (FactorNode* fn) { neighs_.push_back (fn); } void addNeighbor (FacNode* fn) { neighs_.push_back (fn); }
const FactorNodes& neighbors (void) const { return neighs_; } const FacNodes& neighbors (void) const { return neighs_; }
private: private:
DISALLOW_COPY_AND_ASSIGN (VarNode); DISALLOW_COPY_AND_ASSIGN (VarNode);
FactorNodes neighs_; FacNodes neighs_;
}; };
class FactorNode class FacNode
{ {
public: public:
FactorNode (const Factor& f) : factor_(f), index_(-1) { } FacNode (const Factor& f) : factor_(f), index_(-1) { }
const Factor& factor (void) const { return factor_; } const Factor& factor (void) const { return factor_; }
@ -53,7 +53,7 @@ class FactorNode
string getLabel (void) { return factor_.getLabel(); } string getLabel (void) { return factor_.getLabel(); }
private: private:
DISALLOW_COPY_AND_ASSIGN (FactorNode); DISALLOW_COPY_AND_ASSIGN (FacNode);
Factor factor_; Factor factor_;
VarNodes neighs_; VarNodes neighs_;
@ -81,7 +81,7 @@ class FactorGraph
const VarNodes& varNodes (void) const { return varNodes_; } const VarNodes& varNodes (void) const { return varNodes_; }
const FactorNodes& factorNodes (void) const { return facNodes_; } const FacNodes& facNodes (void) const { return facNodes_; }
void setFromBayesNetwork (void) { fromBayesNet_ = true; } void setFromBayesNetwork (void) { fromBayesNet_ = true; }
@ -101,11 +101,9 @@ class FactorGraph
void addVarNode (VarNode*); void addVarNode (VarNode*);
void addFactorNode (FactorNode*); void addFacNode (FacNode*);
void addEdge (VarNode*, FactorNode*); void addEdge (VarNode*, FacNode*);
void addEdge (FactorNode*, VarNode*);
bool isTree (void) const; bool isTree (void) const;
@ -130,14 +128,14 @@ class FactorGraph
bool containsCycle (void) const; bool containsCycle (void) const;
bool containsCycle (const VarNode*, const FactorNode*, bool containsCycle (const VarNode*, const FacNode*,
vector<bool>&, vector<bool>&) const; vector<bool>&, vector<bool>&) const;
bool containsCycle (const FactorNode*, const VarNode*, bool containsCycle (const FacNode*, const VarNode*,
vector<bool>&, vector<bool>&) const; vector<bool>&, vector<bool>&) const;
VarNodes varNodes_; VarNodes varNodes_;
FactorNodes facNodes_; FacNodes facNodes_;
bool fromBayesNet_; bool fromBayesNet_;
DAGraph structure_; DAGraph structure_;

View File

@ -14,14 +14,14 @@ using namespace std;
class Var; class Var;
class Factor; class Factor;
class VarNode; class VarNode;
class FactorNode; class FacNode;
typedef vector<double> Params; typedef vector<double> Params;
typedef unsigned VarId; typedef unsigned VarId;
typedef vector<VarId> VarIds; typedef vector<VarId> VarIds;
typedef vector<Var*> Vars; typedef vector<Var*> Vars;
typedef vector<VarNode*> VarNodes; typedef vector<VarNode*> VarNodes;
typedef vector<FactorNode*> FactorNodes; typedef vector<FacNode*> FacNodes;
typedef vector<Factor*> Factors; typedef vector<Factor*> Factors;
typedef vector<string> States; typedef vector<string> States;
typedef vector<unsigned> Ranges; typedef vector<unsigned> Ranges;

View File

@ -57,11 +57,11 @@ VarElimSolver::getJointDistributionOf (const VarIds& vids)
void void
VarElimSolver::createFactorList (void) VarElimSolver::createFactorList (void)
{ {
const FactorNodes& factorNodes = factorGraph_->factorNodes(); const FacNodes& facNodes = factorGraph_->facNodes();
factorList_.reserve (factorNodes.size() * 2); factorList_.reserve (facNodes.size() * 2);
for (unsigned i = 0; i < factorNodes.size(); i++) { for (unsigned i = 0; i < facNodes.size(); i++) {
factorList_.push_back (new Factor (factorNodes[i]->factor())); // FIXME factorList_.push_back (new Factor (facNodes[i]->factor())); // FIXME
const VarNodes& neighs = factorNodes[i]->neighbors(); const VarNodes& neighs = facNodes[i]->neighbors();
for (unsigned j = 0; j < neighs.size(); j++) { for (unsigned j = 0; j < neighs.size(); j++) {
unordered_map<VarId,vector<unsigned> >::iterator it unordered_map<VarId,vector<unsigned> >::iterator it
= varFactors_.find (neighs[j]->varId()); = varFactors_.find (neighs[j]->varId());