some refactorings

This commit is contained in:
Tiago Gomes 2012-04-16 21:42:14 +01:00
parent 0d5888e238
commit 66353e67ec
2 changed files with 93 additions and 113 deletions

View File

@ -1,4 +1,3 @@
#include "CFactorGraph.h" #include "CFactorGraph.h"
#include "Factor.h" #include "Factor.h"
@ -6,26 +5,23 @@
bool CFactorGraph::checkForIdenticalFactors = true; bool CFactorGraph::checkForIdenticalFactors = true;
CFactorGraph::CFactorGraph (const FactorGraph& fg) CFactorGraph::CFactorGraph (const FactorGraph& fg)
: freeColor_(0), groundFg_(&fg)
{ {
groundFg_ = &fg;
freeColor_ = 0;
const VarNodes& varNodes = fg.varNodes(); const VarNodes& varNodes = fg.varNodes();
varSignatures_.reserve (varNodes.size()); varSignatures_.reserve (varNodes.size());
for (unsigned i = 0; i < varNodes.size(); i++) { for (unsigned i = 0; i < varNodes.size(); i++) {
unsigned c = (varNodes[i]->neighbors().size() * 2) + 1; unsigned c = (varNodes[i]->neighbors().size() * 2) + 1;
varSignatures_.push_back (Signature (c)); varSignatures_.push_back (Signature (c));
} }
const FacNodes& facNodes = fg.facNodes(); const FacNodes& facNodes = fg.facNodes();
facSignatures_.reserve (facNodes.size()); facSignatures_.reserve (facNodes.size());
for (unsigned i = 0; i < facNodes.size(); i++) { for (unsigned i = 0; i < facNodes.size(); i++) {
unsigned c = facNodes[i]->neighbors().size() + 1; unsigned c = facNodes[i]->neighbors().size() + 1;
facSignatures_.push_back (Signature (c)); facSignatures_.push_back (Signature (c));
} }
varColors_.resize (varNodes.size()); varColors_.resize (varNodes.size());
facColors_.resize (facNodes.size()); facColors_.resize (facNodes.size());
findIdenticalFactors();
setInitialColors(); setInitialColors();
createGroups(); createGroups();
} }
@ -44,6 +40,40 @@ CFactorGraph::~CFactorGraph (void)
void
CFactorGraph::findIdenticalFactors()
{
if (checkForIdenticalFactors == false) {
return;
}
const FacNodes& facNodes = groundFg_->facNodes();
for (unsigned i = 0; i < facNodes.size(); i++) {
facNodes[i]->factor().setDistId (Util::maxUnsigned());
}
unsigned groupCount = 1;
for (unsigned i = 0; i < facNodes.size(); i++) {
Factor& f1 = facNodes[i]->factor();
if (f1.distId() != Util::maxUnsigned()) {
continue;
}
f1.setDistId (groupCount);
for (unsigned j = i + 1; j < facNodes.size(); j++) {
Factor& f2 = facNodes[j]->factor();
if (f2.distId() != Util::maxUnsigned()) {
continue;
}
if (f1.size() == f2.size() &&
f1.ranges() == f2.ranges() &&
f1.params() == f2.params()) {
f2.setDistId (groupCount);
}
}
groupCount ++;
}
}
void void
CFactorGraph::setInitialColors (void) CFactorGraph::setInitialColors (void)
{ {
@ -69,34 +99,7 @@ CFactorGraph::setInitialColors (void)
} }
setColor (varNodes[i], stateColors[idx]); setColor (varNodes[i], stateColors[idx]);
} }
const FacNodes& facNodes = groundFg_->facNodes(); const FacNodes& facNodes = groundFg_->facNodes();
for (unsigned i = 0; i < facNodes.size(); i++) {
facNodes[i]->factor().setDistId (Util::maxUnsigned());
}
// FIXME FIXME FIXME : pfl should give correct dist ids.
if (checkForIdenticalFactors || true) {
unsigned groupCount = 1;
for (unsigned i = 0; i < facNodes.size(); i++) {
Factor& f1 = facNodes[i]->factor();
if (f1.distId() != Util::maxUnsigned()) {
continue;
}
f1.setDistId (groupCount);
for (unsigned j = i + 1; j < facNodes.size(); j++) {
Factor& f2 = facNodes[j]->factor();
if (f2.distId() != Util::maxUnsigned()) {
continue;
}
if (f1.size() == f2.size() &&
f1.ranges() == f2.ranges() &&
f1.params() == f2.params()) {
f2.setDistId (groupCount);
}
}
groupCount ++;
}
}
// create the initial factor colors // create the initial factor colors
DistColorMap distColors; DistColorMap distColors;
for (unsigned i = 0; i < facNodes.size(); i++) { for (unsigned i = 0; i < facNodes.size(); i++) {
@ -245,23 +248,24 @@ CFactorGraph::getGroundFactorGraph (void) const
{ {
FactorGraph* fg = new FactorGraph(); FactorGraph* fg = new FactorGraph();
for (unsigned i = 0; i < varClusters_.size(); i++) { for (unsigned i = 0; i < varClusters_.size(); i++) {
VarNode* var = varClusters_[i]->getGroundVarNodes()[0]; VarNode* newVar = new VarNode (varClusters_[i]->members()[0]);
VarNode* newVar = new VarNode (var); varClusters_[i]->setRepresentative (newVar);
varClusters_[i]->setRepresentativeVariable (newVar);
fg->addVarNode (newVar); fg->addVarNode (newVar);
} }
for (unsigned i = 0; i < facClusters_.size(); i++) { for (unsigned i = 0; i < facClusters_.size(); i++) {
const VarClusters& myVarClusters = facClusters_[i]->getVarClusters(); const VarClusters& myVarClusters = facClusters_[i]->varClusters();
Vars myGroundVars; Vars myGroundVars;
myGroundVars.reserve (myVarClusters.size()); myGroundVars.reserve (myVarClusters.size());
for (unsigned j = 0; j < myVarClusters.size(); j++) { for (unsigned j = 0; j < myVarClusters.size(); j++) {
VarNode* v = myVarClusters[j]->getRepresentativeVariable(); VarNode* v = myVarClusters[j]->getRepresentative();
myGroundVars.push_back (v); myGroundVars.push_back (v);
} }
FacNode* fn = new FacNode (Factor (myGroundVars, FacNode* fn = new FacNode (Factor (
facClusters_[i]->getGroundFactors()[0]->factor().params())); myGroundVars,
facClusters_[i]->setRepresentativeFactor (fn); facClusters_[i]->members()[0]->factor().params(),
facClusters_[i]->members()[0]->factor().distId()));
facClusters_[i]->setRepresentative (fn);
fg->addFacNode (fn); fg->addFacNode (fn);
for (unsigned j = 0; j < myGroundVars.size(); j++) { for (unsigned j = 0; j < myGroundVars.size(); j++) {
fg->addEdge (static_cast<VarNode*> (myGroundVars[j]), fn); fg->addEdge (static_cast<VarNode*> (myGroundVars[j]), fn);
@ -278,24 +282,26 @@ CFactorGraph::getEdgeCount (
const VarCluster* vc) const const VarCluster* vc) const
{ {
unsigned count = 0; unsigned count = 0;
VarId vid = vc->getGroundVarNodes().front()->varId(); VarId vid = vc->members().front()->varId();
const FacNodes& clusterGroundFactors = fc->getGroundFactors(); const FacNodes& members = fc->members();
for (unsigned i = 0; i < clusterGroundFactors.size(); i++) { for (unsigned i = 0; i < members.size(); i++) {
if (clusterGroundFactors[i]->factor().contains (vid)) { if (members[i]->factor().contains (vid)) {
count ++; count ++;
} }
} }
// CVarNodes vars = vc->getGroundVarNodes(); if (Constants::DEBUG > 0) {
// for (unsigned i = 1; i < vars.size(); i++) { const VarNodes& vars = vc->members();
// VarNode* var = vc->getGroundVarNodes()[i]; for (unsigned i = 1; i < vars.size(); i++) {
// unsigned count2 = 0; VarId vid = vars[i]->varId();
// for (unsigned i = 0; i < clusterGroundFactors.size(); i++) { unsigned count2 = 0;
// if (clusterGroundFactors[i]->getPosition (var) != -1) { for (unsigned i = 0; i < members.size(); i++) {
// count2 ++; if (members[i]->factor().contains (vid)) {
// } count2 ++;
// } }
// if (count != count2) { cout << "oops!" << endl; abort(); } }
// } assert (count == count2);
}
}
return count; return count;
} }

View File

@ -5,6 +5,7 @@
#include "FactorGraph.h" #include "FactorGraph.h"
#include "Factor.h" #include "Factor.h"
#include "Util.h"
#include "Horus.h" #include "Horus.h"
class VarCluster; class VarCluster;
@ -84,83 +85,54 @@ struct SignatureHash
class VarCluster class VarCluster
{ {
public: public:
VarCluster (const VarNodes& vs) VarCluster (const VarNodes& vs) : members_(vs) { }
{
for (unsigned i = 0; i < vs.size(); i++) {
groundVars_.push_back (vs[i]);
}
}
void addFacCluster (FacCluster* fc) const VarNodes& members (void) const { return members_; }
{
facClusters_.push_back (fc);
}
const FacClusters& getFacClusters (void) const const FacClusters& facClusters (void) const { return facClusters_; }
{
return facClusters_;
}
VarNode* getRepresentativeVariable (void) const { return representVar_; } void addFacCluster (FacCluster* fc) { facClusters_.push_back (fc); }
void setRepresentativeVariable (VarNode* v) { representVar_ = v; } VarNode* getRepresentative (void) const { return repr_; }
const VarNodes& getGroundVarNodes (void) const { return groundVars_; } void setRepresentative (VarNode* vn) { repr_ = vn; }
private: private:
VarNodes groundVars_; VarNodes members_;
FacClusters facClusters_; FacClusters facClusters_;
VarNode* representVar_; VarNode* repr_;
}; };
class FacCluster class FacCluster
{ {
public: public:
FacCluster (const FacNodes& groundFactors, const VarClusters& vcs) FacCluster (const FacNodes& fcs, const VarClusters& vcs)
: members_(fcs), varClusters_(vcs)
{ {
groundFactors_ = groundFactors;
varClusters_ = vcs;
for (unsigned i = 0; i < varClusters_.size(); i++) { for (unsigned i = 0; i < varClusters_.size(); i++) {
varClusters_[i]->addFacCluster (this); varClusters_[i]->addFacCluster (this);
} }
} }
const VarClusters& getVarClusters (void) const const FacNodes& members (void) const { return members_; }
{
return varClusters_;
}
bool containsGround (const FacNode* fn) const VarClusters& varClusters (void) const { return varClusters_; }
{
for (unsigned i = 0; i < groundFactors_.size(); i++) {
if (groundFactors_[i] == fn) {
return true;
}
}
return false;
}
FacNode* getRepresentativeFactor (void) const FacNode* getRepresentative (void) const { return repr_; }
{
return representFactor_;
}
void setRepresentativeFactor (FacNode* fn) void setRepresentative (FacNode* fn) { repr_ = fn; }
{
representFactor_ = fn;
}
const FacNodes& getGroundFactors (void) const bool containsGround (const FacNode* fn) const
{ {
return groundFactors_; return std::find (members_.begin(), members_.end(), fn)
!= members_.end();
} }
private: private:
FacNodes groundFactors_; FacNodes members_;
VarClusters varClusters_; VarClusters varClusters_;
FacNode* representFactor_; FacNode* repr_;
}; };
@ -171,14 +143,14 @@ class CFactorGraph
~CFactorGraph (void); ~CFactorGraph (void);
const VarClusters& getVarClusters (void) { return varClusters_; } const VarClusters& varClusters (void) { return varClusters_; }
const FacClusters& getFacClusters (void) { return facClusters_; } const FacClusters& facClusters (void) { return facClusters_; }
VarNode* getEquivalentVariable (VarId vid) VarNode* getEquivalentVariable (VarId vid)
{ {
VarCluster* vc = vid2VarCluster_.find (vid)->second; VarCluster* vc = vid2VarCluster_.find (vid)->second;
return vc->getRepresentativeVariable(); return vc->getRepresentative();
} }
FactorGraph* getGroundFactorGraph (void) const; FactorGraph* getGroundFactorGraph (void) const;
@ -217,6 +189,8 @@ class CFactorGraph
return vid2VarCluster_.find (vid)->second; return vid2VarCluster_.find (vid)->second;
} }
void findIdenticalFactors (void);
void setInitialColors (void); void setInitialColors (void);
void createGroups (void); void createGroups (void);