improve cbp by supporting factors connected to a single var two or more times

This commit is contained in:
Tiago Gomes 2012-04-26 00:54:06 +01:00
parent ad24a360ce
commit 689244a0d8
8 changed files with 252 additions and 226 deletions

View File

@ -325,7 +325,6 @@ BpSolver::calculateFactor2VariableMsg (SpLink* link)
} }
} }
} }
Factor result (src->factor().arguments(), Factor result (src->factor().arguments(),
src->factor().ranges(), msgProduct); src->factor().ranges(), msgProduct);
result.multiply (src->factor()); result.multiply (src->factor());
@ -336,18 +335,13 @@ BpSolver::calculateFactor2VariableMsg (SpLink* link)
} }
result.sumOutAllExcept (dst->varId()); result.sumOutAllExcept (dst->varId());
if (Constants::DEBUG >= 5) { if (Constants::DEBUG >= 5) {
cout << " marginalized: " ; cout << " marginalized: " << result.params() << endl;
cout << result.params() << endl;
} }
const Params& resultParams = result.params(); link->getNextMessage() = result.params();
Params& message = link->getNextMessage(); LogAware::normalize (link->getNextMessage());
for (unsigned i = 0; i < resultParams.size(); i++) {
message[i] = resultParams[i];
}
LogAware::normalize (message);
if (Constants::DEBUG >= 5) { if (Constants::DEBUG >= 5) {
cout << " curr msg: " << link->getMessage() << endl; cout << " curr msg: " << link->getMessage() << endl;
cout << " next msg: " << message << endl; cout << " next msg: " << link->getNextMessage() << endl;
} }
} }

View File

@ -183,7 +183,7 @@ class BpSolver : public Solver
bool converged (void); bool converged (void);
void printLinkInformation (void) const; virtual void printLinkInformation (void) const;
}; };
#endif // HORUS_BPSOLVER_H #endif // HORUS_BPSOLVER_H

View File

@ -7,20 +7,6 @@ bool CFactorGraph::checkForIdenticalFactors = true;
CFactorGraph::CFactorGraph (const FactorGraph& fg) CFactorGraph::CFactorGraph (const FactorGraph& fg)
: freeColor_(0), groundFg_(&fg) : freeColor_(0), groundFg_(&fg)
{ {
const VarNodes& varNodes = fg.varNodes();
varSignatures_.reserve (varNodes.size());
for (unsigned i = 0; i < varNodes.size(); i++) {
unsigned c = (varNodes[i]->neighbors().size() * 2) + 1;
varSignatures_.push_back (Signature (c));
}
const FacNodes& facNodes = fg.facNodes();
facSignatures_.reserve (facNodes.size());
for (unsigned i = 0; i < facNodes.size(); i++) {
unsigned c = facNodes[i]->neighbors().size() + 1;
facSignatures_.push_back (Signature (c));
}
varColors_.resize (varNodes.size());
facColors_.resize (facNodes.size());
findIdenticalFactors(); findIdenticalFactors();
setInitialColors(); setInitialColors();
createGroups(); createGroups();
@ -77,6 +63,9 @@ CFactorGraph::findIdenticalFactors()
void void
CFactorGraph::setInitialColors (void) CFactorGraph::setInitialColors (void)
{ {
varColors_.resize (groundFg_->nrVarNodes());
facColors_.resize (groundFg_->nrFacNodes());
// create the initial variable colors // create the initial variable colors
VarColorMap colorMap; VarColorMap colorMap;
const VarNodes& varNodes = groundFg_->varNodes(); const VarNodes& varNodes = groundFg_->varNodes();
@ -127,31 +116,11 @@ CFactorGraph::createGroups (void)
while (groupsHaveChanged || nIters == 1) { while (groupsHaveChanged || nIters == 1) {
nIters ++; nIters ++;
unsigned prevFactorGroupsSize = facGroups.size();
facGroups.clear();
// set a new color to the factors with the same signature
for (unsigned i = 0; i < facNodes.size(); i++) {
const Signature& signature = getSignature (facNodes[i]);
FacSignMap::iterator it = facGroups.find (signature);
if (it == facGroups.end()) {
it = facGroups.insert (make_pair (signature, FacNodes())).first;
}
it->second.push_back (facNodes[i]);
}
for (FacSignMap::iterator it = facGroups.begin();
it != facGroups.end(); it++) {
Color newColor = getFreeColor();
FacNodes& groupMembers = it->second;
for (unsigned i = 0; i < groupMembers.size(); i++) {
setColor (groupMembers[i], newColor);
}
}
// set a new color to the variables with the same signature // set a new color to the variables with the same signature
unsigned prevVarGroupsSize = varGroups.size(); unsigned prevVarGroupsSize = varGroups.size();
varGroups.clear(); varGroups.clear();
for (unsigned i = 0; i < varNodes.size(); i++) { for (unsigned i = 0; i < varNodes.size(); i++) {
const Signature& signature = getSignature (varNodes[i]); const VarSignature& signature = getSignature (varNodes[i]);
VarSignMap::iterator it = varGroups.find (signature); VarSignMap::iterator it = varGroups.find (signature);
if (it == varGroups.end()) { if (it == varGroups.end()) {
it = varGroups.insert (make_pair (signature, VarNodes())).first; it = varGroups.insert (make_pair (signature, VarNodes())).first;
@ -167,6 +136,26 @@ CFactorGraph::createGroups (void)
} }
} }
unsigned prevFactorGroupsSize = facGroups.size();
facGroups.clear();
// set a new color to the factors with the same signature
for (unsigned i = 0; i < facNodes.size(); i++) {
const FacSignature& signature = getSignature (facNodes[i]);
FacSignMap::iterator it = facGroups.find (signature);
if (it == facGroups.end()) {
it = facGroups.insert (make_pair (signature, FacNodes())).first;
}
it->second.push_back (facNodes[i]);
}
for (FacSignMap::iterator it = facGroups.begin();
it != facGroups.end(); it++) {
Color newColor = getFreeColor();
FacNodes& groupMembers = it->second;
for (unsigned i = 0; i < groupMembers.size(); i++) {
setColor (groupMembers[i], newColor);
}
}
groupsHaveChanged = prevVarGroupsSize != varGroups.size() groupsHaveChanged = prevVarGroupsSize != varGroups.size()
|| prevFactorGroupsSize != facGroups.size(); || prevFactorGroupsSize != facGroups.size();
} }
@ -209,67 +198,61 @@ CFactorGraph::createClusters (
const Signature& VarSignature
CFactorGraph::getSignature (const VarNode* varNode) CFactorGraph::getSignature (const VarNode* varNode)
{ {
Signature& sign = varSignatures_[varNode->getIndex()];
Colors::iterator it = sign.colors.begin();
const FacNodes& neighs = varNode->neighbors(); const FacNodes& neighs = varNode->neighbors();
VarSignature sign;
sign.reserve (neighs.size() + 1);
for (unsigned i = 0; i < neighs.size(); i++) { for (unsigned i = 0; i < neighs.size(); i++) {
*it = getColor (neighs[i]); sign.push_back (make_pair (
it ++; getColor (neighs[i]),
*it = neighs[i]->factor().indexOf (varNode->varId()); neighs[i]->factor().indexOf (varNode->varId())));
it ++;
} }
*it = getColor (varNode); std::sort (sign.begin(), sign.end());
sign.push_back (make_pair (getColor (varNode), 0));
return sign; return sign;
} }
const Signature& FacSignature
CFactorGraph::getSignature (const FacNode* facNode) CFactorGraph::getSignature (const FacNode* facNode)
{ {
Signature& sign = facSignatures_[facNode->getIndex()];
Colors::iterator it = sign.colors.begin();
const VarNodes& neighs = facNode->neighbors(); const VarNodes& neighs = facNode->neighbors();
FacSignature sign;
sign.reserve (neighs.size() + 1);
for (unsigned i = 0; i < neighs.size(); i++) { for (unsigned i = 0; i < neighs.size(); i++) {
*it = getColor (neighs[i]); sign.push_back (getColor (neighs[i]));
it ++;
} }
std::sort (sign.colors.begin(), -- sign.colors.end()); sign.push_back (getColor (facNode));
*it = getColor (facNode);
return sign; return sign;
} }
FactorGraph* FactorGraph*
CFactorGraph::getGroundFactorGraph (void) const CFactorGraph::getGroundFactorGraph (void)
{ {
FactorGraph* fg = new FactorGraph(); FactorGraph* fg = new FactorGraph();
for (unsigned i = 0; i < varClusters_.size(); i++) { for (unsigned i = 0; i < varClusters_.size(); i++) {
VarNode* newVar = new VarNode (varClusters_[i]->members()[0]); VarNode* newVar = new VarNode (varClusters_[i]->first());
varClusters_[i]->setRepresentative (newVar); varClusters_[i]->setRepresentative (newVar);
fg->addVarNode (newVar); fg->addVarNode (newVar);
} }
for (unsigned i = 0; i < facClusters_.size(); i++) { for (unsigned i = 0; i < facClusters_.size(); i++) {
const VarClusters& myVarClusters = facClusters_[i]->varClusters(); Vars vars;
Vars myGroundVars; const VarClusters& clusters = facClusters_[i]->varClusters();
myGroundVars.reserve (myVarClusters.size()); for (unsigned j = 0; j < clusters.size(); j++) {
for (unsigned j = 0; j < myVarClusters.size(); j++) { vars.push_back (clusters[j]->representative());
VarNode* v = myVarClusters[j]->getRepresentative();
myGroundVars.push_back (v);
} }
const Factor& groundFac = facClusters_[i]->first()->factor();
FacNode* fn = new FacNode (Factor ( FacNode* fn = new FacNode (Factor (
myGroundVars, vars, groundFac.params(), groundFac.distId()));
facClusters_[i]->members()[0]->factor().params(),
facClusters_[i]->members()[0]->factor().distId()));
facClusters_[i]->setRepresentative (fn); facClusters_[i]->setRepresentative (fn);
fg->addFacNode (fn); fg->addFacNode (fn);
for (unsigned j = 0; j < myGroundVars.size(); j++) { for (unsigned j = 0; j < vars.size(); j++) {
fg->addEdge (static_cast<VarNode*> (myGroundVars[j]), fn); fg->addEdge (static_cast<VarNode*> (vars[j]), fn);
} }
} }
return fg; return fg;
@ -280,29 +263,21 @@ CFactorGraph::getGroundFactorGraph (void) const
unsigned unsigned
CFactorGraph::getEdgeCount ( CFactorGraph::getEdgeCount (
const FacCluster* fc, const FacCluster* fc,
const VarCluster* vc) const const VarCluster* vc,
unsigned index) const
{ {
unsigned count = 0; unsigned count = 0;
VarId vid = vc->members().front()->varId(); VarId reprVid = vc->representative()->varId();
const FacNodes& members = fc->members(); VarNode* groundVar = groundFg_->getVarNode (reprVid);
for (unsigned i = 0; i < members.size(); i++) { const FacNodes& neighs = groundVar->neighbors();
if (members[i]->factor().contains (vid)) { for (unsigned i = 0; i < neighs.size(); i++) {
FacNodes::const_iterator it;
it = std::find (fc->members().begin(), fc->members().end(), neighs[i]);
if (it != fc->members().end() &&
(*it)->factor().indexOf (reprVid) == (int)index) {
count ++; count ++;
} }
} }
if (Constants::DEBUG > 0) {
const VarNodes& vars = vc->members();
for (unsigned i = 1; i < vars.size(); i++) {
VarId vid = vars[i]->varId();
unsigned count2 = 0;
for (unsigned i = 0; i < members.size(); i++) {
if (members[i]->factor().contains (vid)) {
count2 ++;
}
}
assert (count == count2);
}
}
return count; return count;
} }
@ -327,7 +302,6 @@ CFactorGraph::printGroups (
cout << endl; cout << endl;
} }
} }
count = 1; count = 1;
cout << endl << "factor groups:" << endl; cout << endl << "factor groups:" << endl;
for (FacSignMap::const_iterator it = facGroups.begin(); for (FacSignMap::const_iterator it = facGroups.begin();

View File

@ -10,17 +10,19 @@
class VarCluster; class VarCluster;
class FacCluster; class FacCluster;
class Signature; class VarSignatureHash;
class SignatureHash; class FacSignatureHash;
typedef long Color; typedef long Color;
typedef vector<Color> Colors; typedef vector<Color> Colors;
typedef vector<std::pair<Color,unsigned>> VarSignature;
typedef vector<Color> FacSignature;
typedef unordered_map<unsigned, Color> DistColorMap; typedef unordered_map<unsigned, Color> DistColorMap;
typedef unordered_map<unsigned, Colors> VarColorMap; typedef unordered_map<unsigned, Colors> VarColorMap;
typedef unordered_map<Signature, VarNodes, SignatureHash> VarSignMap; typedef unordered_map<VarSignature, VarNodes, VarSignatureHash> VarSignMap;
typedef unordered_map<Signature, FacNodes, SignatureHash> FacSignMap; typedef unordered_map<FacSignature, FacNodes, FacSignatureHash> FacSignMap;
typedef vector<VarCluster*> VarClusters; typedef vector<VarCluster*> VarClusters;
typedef vector<FacCluster*> FacClusters; typedef vector<FacCluster*> FacClusters;
@ -28,53 +30,27 @@ typedef vector<FacCluster*> FacClusters;
typedef unordered_map<VarId, VarCluster*> VarId2VarCluster; typedef unordered_map<VarId, VarCluster*> VarId2VarCluster;
struct Signature struct VarSignatureHash
{ {
Signature (unsigned size) : colors(size) { } size_t operator() (const VarSignature &sig) const
bool operator< (const Signature& sig) const
{ {
if (colors.size() < sig.colors.size()) { size_t val = hash<size_t>()(sig.size());
return true; for (unsigned i = 0; i < sig.size(); i++) {
} else if (colors.size() > sig.colors.size()) { val ^= hash<size_t>()(sig[i].first);
return false; val ^= hash<size_t>()(sig[i].second);
} else {
for (unsigned i = 0; i < colors.size(); i++) {
if (colors[i] < sig.colors[i]) {
return true;
} else if (colors[i] > sig.colors[i]) {
return false;
} }
return val;
} }
}
return false;
}
bool operator== (const Signature& sig) const
{
if (colors.size() != sig.colors.size()) {
return false;
}
for (unsigned i = 0; i < colors.size(); i++) {
if (colors[i] != sig.colors[i]) {
return false;
}
}
return true;
}
Colors colors;
}; };
struct FacSignatureHash
struct SignatureHash
{ {
size_t operator() (const Signature &sig) const size_t operator() (const FacSignature &sig) const
{ {
size_t val = hash<size_t>()(sig.colors.size()); size_t val = hash<size_t>()(sig.size());
for (unsigned i = 0; i < sig.colors.size(); i++) { for (unsigned i = 0; i < sig.size(); i++) {
val ^= hash<size_t>()(sig.colors[i]); val ^= hash<size_t>()(sig[i]);
} }
return val; return val;
} }
@ -87,19 +63,16 @@ class VarCluster
public: public:
VarCluster (const VarNodes& vs) : members_(vs) { } VarCluster (const VarNodes& vs) : members_(vs) { }
const VarNode* first (void) const { return members_.front(); }
const VarNodes& members (void) const { return members_; } const VarNodes& members (void) const { return members_; }
const FacClusters& facClusters (void) const { return facClusters_; } VarNode* representative (void) const { return repr_; }
void addFacCluster (FacCluster* fc) { facClusters_.push_back (fc); }
VarNode* getRepresentative (void) const { return repr_; }
void setRepresentative (VarNode* vn) { repr_ = vn; } void setRepresentative (VarNode* vn) { repr_ = vn; }
private: private:
VarNodes members_; VarNodes members_;
FacClusters facClusters_;
VarNode* repr_; VarNode* repr_;
}; };
@ -108,27 +81,18 @@ class FacCluster
{ {
public: public:
FacCluster (const FacNodes& fcs, const VarClusters& vcs) FacCluster (const FacNodes& fcs, const VarClusters& vcs)
: members_(fcs), varClusters_(vcs) : members_(fcs), varClusters_(vcs) { }
{
for (unsigned i = 0; i < varClusters_.size(); i++) { const FacNode* first (void) const { return members_.front(); }
varClusters_[i]->addFacCluster (this);
}
}
const FacNodes& members (void) const { return members_; } const FacNodes& members (void) const { return members_; }
const VarClusters& varClusters (void) const { return varClusters_; } VarClusters& varClusters (void) { return varClusters_; }
FacNode* getRepresentative (void) const { return repr_; } FacNode* representative (void) const { return repr_; }
void setRepresentative (FacNode* fn) { repr_ = fn; } void setRepresentative (FacNode* fn) { repr_ = fn; }
bool containsGround (const FacNode* fn) const
{
return std::find (members_.begin(), members_.end(), fn)
!= members_.end();
}
private: private:
FacNodes members_; FacNodes members_;
VarClusters varClusters_; VarClusters varClusters_;
@ -147,15 +111,16 @@ class CFactorGraph
const FacClusters& facClusters (void) { return facClusters_; } const FacClusters& facClusters (void) { return facClusters_; }
VarNode* getEquivalentVariable (VarId vid) VarNode* getEquivalent (VarId vid)
{ {
VarCluster* vc = vid2VarCluster_.find (vid)->second; VarCluster* vc = vid2VarCluster_.find (vid)->second;
return vc->getRepresentative(); return vc->representative();
} }
FactorGraph* getGroundFactorGraph (void) const; FactorGraph* getGroundFactorGraph (void);
unsigned getEdgeCount (const FacCluster*, const VarCluster*) const; unsigned getEdgeCount (const FacCluster*,
const VarCluster*, unsigned index) const;
static bool checkForIdenticalFactors; static bool checkForIdenticalFactors;
@ -184,11 +149,6 @@ class CFactorGraph
facColors_[fn->getIndex()] = c; facColors_[fn->getIndex()] = c;
} }
VarCluster* getVariableCluster (VarId vid) const
{
return vid2VarCluster_.find (vid)->second;
}
void findIdenticalFactors (void); void findIdenticalFactors (void);
void setInitialColors (void); void setInitialColors (void);
@ -197,17 +157,15 @@ class CFactorGraph
void createClusters (const VarSignMap&, const FacSignMap&); void createClusters (const VarSignMap&, const FacSignMap&);
const Signature& getSignature (const VarNode*); VarSignature getSignature (const VarNode*);
const Signature& getSignature (const FacNode*); FacSignature getSignature (const FacNode*);
void printGroups (const VarSignMap&, const FacSignMap&) const; void printGroups (const VarSignMap&, const FacSignMap&) const;
Color freeColor_; Color freeColor_;
Colors varColors_; Colors varColors_;
Colors facColors_; Colors facColors_;
vector<Signature> varSignatures_;
vector<Signature> facSignatures_;
VarClusters varClusters_; VarClusters varClusters_;
FacClusters facClusters_; FacClusters facClusters_;
VarId2VarCluster vid2VarCluster_; VarId2VarCluster vid2VarCluster_;

View File

@ -1,5 +1,6 @@
#include "CbpSolver.h" #include "CbpSolver.h"
vector<int> CbpSolver::counts;
CbpSolver::CbpSolver (const FactorGraph& fg) : BpSolver (fg) CbpSolver::CbpSolver (const FactorGraph& fg) : BpSolver (fg)
{ {
@ -24,16 +25,14 @@ CbpSolver::CbpSolver (const FactorGraph& fg) : BpSolver (fg)
Statistics::updateCompressingStatistics (nrGroundVars, Statistics::updateCompressingStatistics (nrGroundVars,
nrGroundFacs, nrClusterVars, nrClusterFacs, nrNeighborless); nrGroundFacs, nrClusterVars, nrClusterFacs, nrNeighborless);
} }
// cout << "uncompressed factor graph:" << endl; if (Constants::DEBUG >= 5) {
// cout << " " << fg.nrVarNodes() << " variables " << endl; cout << "uncompressed factor graph:" << endl;
// cout << " " << fg.nrFacNodes() << " factors " << endl; cout << " " << fg.nrVarNodes() << " variables " << endl;
// cout << "compressed factor graph:" << endl; cout << " " << fg.nrFacNodes() << " factors " << endl;
// cout << " " << fg_->nrVarNodes() << " variables " << endl; cout << "compressed factor graph:" << endl;
// cout << " " << fg_->nrFacNodes() << " factors " << endl; cout << " " << fg_->nrVarNodes() << " variables " << endl;
// Util::printHeader ("Compressed Factor Graph"); cout << " " << fg_->nrFacNodes() << " factors " << endl;
// fg_->print(); }
// Util::printHeader ("Uncompressed Factor Graph");
// fg.print();
} }
@ -81,8 +80,8 @@ CbpSolver::getPosterioriOf (VarId vid)
if (runned_ == false) { if (runned_ == false) {
runSolver(); runSolver();
} }
assert (cfg_->getEquivalentVariable (vid)); assert (cfg_->getEquivalent (vid));
VarNode* var = cfg_->getEquivalentVariable (vid); VarNode* var = cfg_->getEquivalent (vid);
Params probs; Params probs;
if (var->hasEvidence()) { if (var->hasEvidence()) {
probs.resize (var->range(), LogAware::noEvidence()); probs.resize (var->range(), LogAware::noEvidence());
@ -115,7 +114,7 @@ CbpSolver::getJointDistributionOf (const VarIds& jointVids)
{ {
VarIds eqVarIds; VarIds eqVarIds;
for (unsigned i = 0; i < jointVids.size(); i++) { for (unsigned i = 0; i < jointVids.size(); i++) {
VarNode* vn = cfg_->getEquivalentVariable (jointVids[i]); VarNode* vn = cfg_->getEquivalent (jointVids[i]);
eqVarIds.push_back (vn->varId()); eqVarIds.push_back (vn->varId());
} }
return BpSolver::getJointDistributionOf (eqVarIds); return BpSolver::getJointDistributionOf (eqVarIds);
@ -130,10 +129,15 @@ CbpSolver::createLinks (void)
for (unsigned i = 0; i < fcs.size(); i++) { for (unsigned i = 0; i < fcs.size(); i++) {
const VarClusters& vcs = fcs[i]->varClusters(); const VarClusters& vcs = fcs[i]->varClusters();
for (unsigned j = 0; j < vcs.size(); j++) { for (unsigned j = 0; j < vcs.size(); j++) {
unsigned c = cfg_->getEdgeCount (fcs[i], vcs[j]); unsigned count = cfg_->getEdgeCount (fcs[i], vcs[j], j);
if (Constants::DEBUG >= 5) {
cout << "creating edge " ;
cout << fcs[i]->representative()->getLabel() << " -> " ;
cout << vcs[j]->representative()->label();
cout << " idx=" << j << ", count=" << count << endl;
}
links_.push_back (new CbpSolverLink ( links_.push_back (new CbpSolverLink (
fcs[i]->getRepresentative(), fcs[i]->representative(), vcs[j]->representative(), j, count));
vcs[j]->getRepresentative(), c));
} }
} }
} }
@ -213,47 +217,135 @@ CbpSolver::maxResidualSchedule (void)
Params void
CbpSolver::getVar2FactorMsg (const SpLink* link) const CbpSolver::calculateFactor2VariableMsg (SpLink* _link)
{ {
Params msg; CbpSolverLink* link = static_cast<CbpSolverLink*> (_link);
FacNode* src = link->getFactor();
const VarNode* dst = link->getVariable();
const SpLinkSet& links = ninf(src)->getLinks();
// calculate the product of messages that were sent
// to factor `src', except from var `dst'
unsigned msgSize = 1;
for (unsigned i = 0; i < links.size(); i++) {
msgSize *= links[i]->getVariable()->range();
}
unsigned repetitions = 1;
Params msgProduct (msgSize, LogAware::multIdenty());
if (Globals::logDomain) {
for (int i = links.size() - 1; i >= 0; i--) {
const CbpSolverLink* cl = static_cast<const CbpSolverLink*> (links[i]);
if ( ! (cl->getVariable() == dst && cl->index() == link->index())) {
if (Constants::DEBUG >= 5) {
cout << " message from " << links[i]->getVariable()->label();
cout << ": " ;
}
Util::add (msgProduct, getVar2FactorMsg (links[i]), repetitions);
repetitions *= links[i]->getVariable()->range();
if (Constants::DEBUG >= 5) {
cout << endl;
}
} else {
unsigned range = links[i]->getVariable()->range();
Util::add (msgProduct, Params (range, 0.0), repetitions);
repetitions *= range;
}
}
} else {
for (int i = links.size() - 1; i >= 0; i--) {
const CbpSolverLink* cl = static_cast<const CbpSolverLink*> (links[i]);
if ( ! (cl->getVariable() == dst && cl->index() == link->index())) {
if (Constants::DEBUG >= 5) {
cout << " message from " << links[i]->getVariable()->label();
cout << ": " ;
}
Util::multiply (msgProduct, getVar2FactorMsg (links[i]), repetitions);
repetitions *= links[i]->getVariable()->range();
if (Constants::DEBUG >= 5) {
cout << endl;
}
} else {
unsigned range = links[i]->getVariable()->range();
Util::multiply (msgProduct, Params (range, 1.0), repetitions);
repetitions *= range;
}
}
}
Factor result (src->factor().arguments(),
src->factor().ranges(), msgProduct);
assert (msgProduct.size() == src->factor().size());
if (Globals::logDomain) {
for (unsigned i = 0; i < result.size(); i++) {
result[i] += src->factor()[i];
}
} else {
for (unsigned i = 0; i < result.size(); i++) {
result[i] *= src->factor()[i];
}
}
if (Constants::DEBUG >= 5) {
cout << " message product: " << msgProduct << endl;
cout << " original factor: " << src->factor().params() << endl;
cout << " factor product: " << result.params() << endl;
}
result.sumOutAllExceptIndex (link->index());
if (Constants::DEBUG >= 5) {
cout << " marginalized: " << result.params() << endl;
}
link->getNextMessage() = result.params();
LogAware::normalize (link->getNextMessage());
if (Constants::DEBUG >= 5) {
cout << " curr msg: " << link->getMessage() << endl;
cout << " next msg: " << link->getNextMessage() << endl;
}
}
Params
CbpSolver::getVar2FactorMsg (const SpLink* _link) const
{
const CbpSolverLink* link = static_cast<const CbpSolverLink*> (_link);
const VarNode* src = link->getVariable(); const VarNode* src = link->getVariable();
const FacNode* dst = link->getFactor(); const FacNode* dst = link->getFactor();
const CbpSolverLink* l = static_cast<const CbpSolverLink*> (link); Params msg;
if (src->hasEvidence()) { if (src->hasEvidence()) {
msg.resize (src->range(), LogAware::noEvidence()); msg.resize (src->range(), LogAware::noEvidence());
double value = link->getMessage()[src->getEvidence()]; double value = link->getMessage()[src->getEvidence()];
msg[src->getEvidence()] = LogAware::pow (value, l->nrEdges() - 1); if (Constants::DEBUG >= 5) {
msg[src->getEvidence()] = value;
cout << msg << "^" << link->nrEdges() << "-1" ;
}
msg[src->getEvidence()] = LogAware::pow (value, link->nrEdges() - 1);
} else { } else {
msg = link->getMessage(); msg = link->getMessage();
LogAware::pow (msg, l->nrEdges() - 1);
}
if (Constants::DEBUG >= 5) { if (Constants::DEBUG >= 5) {
cout << " " << "init: " << msg << " " << src->hasEvidence() << endl; cout << msg << "^" << link->nrEdges() << "-1" ;
}
LogAware::pow (msg, link->nrEdges() - 1);
} }
const SpLinkSet& links = ninf(src)->getLinks(); const SpLinkSet& links = ninf(src)->getLinks();
if (Globals::logDomain) { if (Globals::logDomain) {
for (unsigned i = 0; i < links.size(); i++) { for (unsigned i = 0; i < links.size(); i++) {
if (links[i]->getFactor() != dst) { CbpSolverLink* cl = static_cast<CbpSolverLink*> (links[i]);
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]); if ( ! (cl->getFactor() == dst && cl->index() == link->index())) {
Util::add (msg, l->poweredMessage()); CbpSolverLink* cl = static_cast<CbpSolverLink*> (links[i]);
Util::add (msg, cl->poweredMessage());
} }
} }
} else { } else {
for (unsigned i = 0; i < links.size(); i++) { for (unsigned i = 0; i < links.size(); i++) {
if (links[i]->getFactor() != dst) { CbpSolverLink* cl = static_cast<CbpSolverLink*> (links[i]);
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]); if ( ! (cl->getFactor() == dst && cl->index() == link->index())) {
Util::multiply (msg, l->poweredMessage()); Util::multiply (msg, cl->poweredMessage());
if (Constants::DEBUG >= 5) { if (Constants::DEBUG >= 5) {
cout << " msg from " << l->getFactor()->getLabel() << ": " ; cout << " x " << cl->getNextMessage() << "^" << link->nrEdges();
cout << l->poweredMessage() << endl;
} }
} }
} }
} }
if (Constants::DEBUG >= 5) { if (Constants::DEBUG >= 5) {
cout << " result = " << msg << endl; cout << " = " << msg;
} }
return msg; return msg;
} }
@ -264,13 +356,14 @@ void
CbpSolver::printLinkInformation (void) const CbpSolver::printLinkInformation (void) const
{ {
for (unsigned i = 0; i < links_.size(); i++) { for (unsigned i = 0; i < links_.size(); i++) {
CbpSolverLink* l = static_cast<CbpSolverLink*> (links_[i]); CbpSolverLink* cl = static_cast<CbpSolverLink*> (links_[i]);
cout << l->toString() << ":" << endl; cout << cl->toString() << ":" << endl;
cout << " curr msg = " << l->getMessage() << endl; cout << " curr msg = " << cl->getMessage() << endl;
cout << " next msg = " << l->getNextMessage() << endl; cout << " next msg = " << cl->getNextMessage() << endl;
cout << " nr edges = " << l->nrEdges() << endl; cout << " index = " << cl->index() << endl;
cout << " powered = " << l->poweredMessage() << endl; cout << " nr edges = " << cl->nrEdges() << endl;
cout << " residual = " << l->getResidual() << endl; cout << " powered = " << cl->poweredMessage() << endl;
cout << " residual = " << cl->getResidual() << endl;
} }
} }

View File

@ -9,10 +9,12 @@ class Factor;
class CbpSolverLink : public SpLink class CbpSolverLink : public SpLink
{ {
public: public:
CbpSolverLink (FacNode* fn, VarNode* vn, unsigned c) CbpSolverLink (FacNode* fn, VarNode* vn, unsigned idx, unsigned count)
: SpLink (fn, vn), nrEdges_(c), : SpLink (fn, vn), index_(idx), nrEdges_(count),
pwdMsg_(vn->range(), LogAware::one()) { } pwdMsg_(vn->range(), LogAware::one()) { }
unsigned index (void) const { return index_; }
unsigned nrEdges (void) const { return nrEdges_; } unsigned nrEdges (void) const { return nrEdges_; }
const Params& poweredMessage (void) const { return pwdMsg_; } const Params& poweredMessage (void) const { return pwdMsg_; }
@ -26,6 +28,7 @@ class CbpSolverLink : public SpLink
} }
private: private:
unsigned index_;
unsigned nrEdges_; unsigned nrEdges_;
Params pwdMsg_; Params pwdMsg_;
}; };
@ -35,6 +38,8 @@ class CbpSolverLink : public SpLink
class CbpSolver : public BpSolver class CbpSolver : public BpSolver
{ {
public: public:
static vector<int> counts;
CbpSolver (const FactorGraph& fg); CbpSolver (const FactorGraph& fg);
~CbpSolver (void); ~CbpSolver (void);
@ -51,6 +56,8 @@ class CbpSolver : public BpSolver
void maxResidualSchedule (void); void maxResidualSchedule (void);
void calculateFactor2VariableMsg (SpLink*);
Params getVar2FactorMsg (const SpLink*) const; Params getVar2FactorMsg (const SpLink*) const;
void printLinkInformation (void) const; void printLinkInformation (void) const;

View File

@ -105,7 +105,7 @@ Factor::sumOutIndex (unsigned idx)
// on the left of `var', with the states of the remaining vars fixed // on the left of `var', with the states of the remaining vars fixed
unsigned leftVarOffset = 1; unsigned leftVarOffset = 1;
for (int i = args_.size() - 1; i > idx; i--) { for (int i = args_.size() - 1; i > (int)idx; i--) {
varOffset *= ranges_[i]; varOffset *= ranges_[i];
leftVarOffset *= ranges_[i]; leftVarOffset *= ranges_[i];
} }
@ -151,7 +151,7 @@ Factor::sumOutIndex (unsigned idx)
void void
Factor::sumOutAllExceptIndex (unsigned idx) Factor::sumOutAllExceptIndex (unsigned idx)
{ {
int i = idx; int i = (int)idx;
while (args_.size() > i + 1) { while (args_.size() > i + 1) {
sumOutLastVariable(); sumOutLastVariable();
} }