improve cbp by supporting factors connected to a single var two or more times

This commit is contained in:
Tiago Gomes 2012-04-26 00:54:06 +01:00
parent ad24a360ce
commit 689244a0d8
8 changed files with 252 additions and 226 deletions

View File

@ -325,7 +325,6 @@ BpSolver::calculateFactor2VariableMsg (SpLink* link)
}
}
}
Factor result (src->factor().arguments(),
src->factor().ranges(), msgProduct);
result.multiply (src->factor());
@ -336,18 +335,13 @@ BpSolver::calculateFactor2VariableMsg (SpLink* link)
}
result.sumOutAllExcept (dst->varId());
if (Constants::DEBUG >= 5) {
cout << " marginalized: " ;
cout << result.params() << endl;
cout << " marginalized: " << result.params() << endl;
}
const Params& resultParams = result.params();
Params& message = link->getNextMessage();
for (unsigned i = 0; i < resultParams.size(); i++) {
message[i] = resultParams[i];
}
LogAware::normalize (message);
link->getNextMessage() = result.params();
LogAware::normalize (link->getNextMessage());
if (Constants::DEBUG >= 5) {
cout << " curr msg: " << link->getMessage() << endl;
cout << " next msg: " << message << endl;
cout << " next msg: " << link->getNextMessage() << endl;
}
}

View File

@ -183,7 +183,7 @@ class BpSolver : public Solver
bool converged (void);
void printLinkInformation (void) const;
virtual void printLinkInformation (void) const;
};
#endif // HORUS_BPSOLVER_H

View File

@ -7,20 +7,6 @@ bool CFactorGraph::checkForIdenticalFactors = true;
CFactorGraph::CFactorGraph (const FactorGraph& fg)
: freeColor_(0), groundFg_(&fg)
{
const VarNodes& varNodes = fg.varNodes();
varSignatures_.reserve (varNodes.size());
for (unsigned i = 0; i < varNodes.size(); i++) {
unsigned c = (varNodes[i]->neighbors().size() * 2) + 1;
varSignatures_.push_back (Signature (c));
}
const FacNodes& facNodes = fg.facNodes();
facSignatures_.reserve (facNodes.size());
for (unsigned i = 0; i < facNodes.size(); i++) {
unsigned c = facNodes[i]->neighbors().size() + 1;
facSignatures_.push_back (Signature (c));
}
varColors_.resize (varNodes.size());
facColors_.resize (facNodes.size());
findIdenticalFactors();
setInitialColors();
createGroups();
@ -77,6 +63,9 @@ CFactorGraph::findIdenticalFactors()
void
CFactorGraph::setInitialColors (void)
{
varColors_.resize (groundFg_->nrVarNodes());
facColors_.resize (groundFg_->nrFacNodes());
// create the initial variable colors
VarColorMap colorMap;
const VarNodes& varNodes = groundFg_->varNodes();
@ -127,31 +116,11 @@ CFactorGraph::createGroups (void)
while (groupsHaveChanged || nIters == 1) {
nIters ++;
unsigned prevFactorGroupsSize = facGroups.size();
facGroups.clear();
// set a new color to the factors with the same signature
for (unsigned i = 0; i < facNodes.size(); i++) {
const Signature& signature = getSignature (facNodes[i]);
FacSignMap::iterator it = facGroups.find (signature);
if (it == facGroups.end()) {
it = facGroups.insert (make_pair (signature, FacNodes())).first;
}
it->second.push_back (facNodes[i]);
}
for (FacSignMap::iterator it = facGroups.begin();
it != facGroups.end(); it++) {
Color newColor = getFreeColor();
FacNodes& groupMembers = it->second;
for (unsigned i = 0; i < groupMembers.size(); i++) {
setColor (groupMembers[i], newColor);
}
}
// set a new color to the variables with the same signature
unsigned prevVarGroupsSize = varGroups.size();
varGroups.clear();
for (unsigned i = 0; i < varNodes.size(); i++) {
const Signature& signature = getSignature (varNodes[i]);
const VarSignature& signature = getSignature (varNodes[i]);
VarSignMap::iterator it = varGroups.find (signature);
if (it == varGroups.end()) {
it = varGroups.insert (make_pair (signature, VarNodes())).first;
@ -167,6 +136,26 @@ CFactorGraph::createGroups (void)
}
}
unsigned prevFactorGroupsSize = facGroups.size();
facGroups.clear();
// set a new color to the factors with the same signature
for (unsigned i = 0; i < facNodes.size(); i++) {
const FacSignature& signature = getSignature (facNodes[i]);
FacSignMap::iterator it = facGroups.find (signature);
if (it == facGroups.end()) {
it = facGroups.insert (make_pair (signature, FacNodes())).first;
}
it->second.push_back (facNodes[i]);
}
for (FacSignMap::iterator it = facGroups.begin();
it != facGroups.end(); it++) {
Color newColor = getFreeColor();
FacNodes& groupMembers = it->second;
for (unsigned i = 0; i < groupMembers.size(); i++) {
setColor (groupMembers[i], newColor);
}
}
groupsHaveChanged = prevVarGroupsSize != varGroups.size()
|| prevFactorGroupsSize != facGroups.size();
}
@ -183,7 +172,7 @@ CFactorGraph::createClusters (
{
varClusters_.reserve (varGroups.size());
for (VarSignMap::const_iterator it = varGroups.begin();
it != varGroups.end(); it++) {
it != varGroups.end(); it++) {
const VarNodes& groupVars = it->second;
VarCluster* vc = new VarCluster (groupVars);
for (unsigned i = 0; i < groupVars.size(); i++) {
@ -194,7 +183,7 @@ CFactorGraph::createClusters (
facClusters_.reserve (facGroups.size());
for (FacSignMap::const_iterator it = facGroups.begin();
it != facGroups.end(); it++) {
it != facGroups.end(); it++) {
FacNode* groupFactor = it->second[0];
const VarNodes& neighs = groupFactor->neighbors();
VarClusters varClusters;
@ -209,67 +198,61 @@ CFactorGraph::createClusters (
const Signature&
VarSignature
CFactorGraph::getSignature (const VarNode* varNode)
{
Signature& sign = varSignatures_[varNode->getIndex()];
Colors::iterator it = sign.colors.begin();
const FacNodes& neighs = varNode->neighbors();
VarSignature sign;
sign.reserve (neighs.size() + 1);
for (unsigned i = 0; i < neighs.size(); i++) {
*it = getColor (neighs[i]);
it ++;
*it = neighs[i]->factor().indexOf (varNode->varId());
it ++;
sign.push_back (make_pair (
getColor (neighs[i]),
neighs[i]->factor().indexOf (varNode->varId())));
}
*it = getColor (varNode);
std::sort (sign.begin(), sign.end());
sign.push_back (make_pair (getColor (varNode), 0));
return sign;
}
const Signature&
FacSignature
CFactorGraph::getSignature (const FacNode* facNode)
{
Signature& sign = facSignatures_[facNode->getIndex()];
Colors::iterator it = sign.colors.begin();
const VarNodes& neighs = facNode->neighbors();
FacSignature sign;
sign.reserve (neighs.size() + 1);
for (unsigned i = 0; i < neighs.size(); i++) {
*it = getColor (neighs[i]);
it ++;
sign.push_back (getColor (neighs[i]));
}
std::sort (sign.colors.begin(), -- sign.colors.end());
*it = getColor (facNode);
sign.push_back (getColor (facNode));
return sign;
}
FactorGraph*
CFactorGraph::getGroundFactorGraph (void) const
CFactorGraph::getGroundFactorGraph (void)
{
FactorGraph* fg = new FactorGraph();
for (unsigned i = 0; i < varClusters_.size(); i++) {
VarNode* newVar = new VarNode (varClusters_[i]->members()[0]);
VarNode* newVar = new VarNode (varClusters_[i]->first());
varClusters_[i]->setRepresentative (newVar);
fg->addVarNode (newVar);
}
for (unsigned i = 0; i < facClusters_.size(); i++) {
const VarClusters& myVarClusters = facClusters_[i]->varClusters();
Vars myGroundVars;
myGroundVars.reserve (myVarClusters.size());
for (unsigned j = 0; j < myVarClusters.size(); j++) {
VarNode* v = myVarClusters[j]->getRepresentative();
myGroundVars.push_back (v);
Vars vars;
const VarClusters& clusters = facClusters_[i]->varClusters();
for (unsigned j = 0; j < clusters.size(); j++) {
vars.push_back (clusters[j]->representative());
}
const Factor& groundFac = facClusters_[i]->first()->factor();
FacNode* fn = new FacNode (Factor (
myGroundVars,
facClusters_[i]->members()[0]->factor().params(),
facClusters_[i]->members()[0]->factor().distId()));
vars, groundFac.params(), groundFac.distId()));
facClusters_[i]->setRepresentative (fn);
fg->addFacNode (fn);
for (unsigned j = 0; j < myGroundVars.size(); j++) {
fg->addEdge (static_cast<VarNode*> (myGroundVars[j]), fn);
for (unsigned j = 0; j < vars.size(); j++) {
fg->addEdge (static_cast<VarNode*> (vars[j]), fn);
}
}
return fg;
@ -280,29 +263,21 @@ CFactorGraph::getGroundFactorGraph (void) const
unsigned
CFactorGraph::getEdgeCount (
const FacCluster* fc,
const VarCluster* vc) const
const VarCluster* vc,
unsigned index) const
{
unsigned count = 0;
VarId vid = vc->members().front()->varId();
const FacNodes& members = fc->members();
for (unsigned i = 0; i < members.size(); i++) {
if (members[i]->factor().contains (vid)) {
VarId reprVid = vc->representative()->varId();
VarNode* groundVar = groundFg_->getVarNode (reprVid);
const FacNodes& neighs = groundVar->neighbors();
for (unsigned i = 0; i < neighs.size(); i++) {
FacNodes::const_iterator it;
it = std::find (fc->members().begin(), fc->members().end(), neighs[i]);
if (it != fc->members().end() &&
(*it)->factor().indexOf (reprVid) == (int)index) {
count ++;
}
}
if (Constants::DEBUG > 0) {
const VarNodes& vars = vc->members();
for (unsigned i = 1; i < vars.size(); i++) {
VarId vid = vars[i]->varId();
unsigned count2 = 0;
for (unsigned i = 0; i < members.size(); i++) {
if (members[i]->factor().contains (vid)) {
count2 ++;
}
}
assert (count == count2);
}
}
return count;
}
@ -327,7 +302,6 @@ CFactorGraph::printGroups (
cout << endl;
}
}
count = 1;
cout << endl << "factor groups:" << endl;
for (FacSignMap::const_iterator it = facGroups.begin();

View File

@ -10,17 +10,19 @@
class VarCluster;
class FacCluster;
class Signature;
class SignatureHash;
class VarSignatureHash;
class FacSignatureHash;
typedef long Color;
typedef vector<Color> Colors;
typedef vector<std::pair<Color,unsigned>> VarSignature;
typedef vector<Color> FacSignature;
typedef unordered_map<unsigned, Color> DistColorMap;
typedef unordered_map<unsigned, Colors> VarColorMap;
typedef unordered_map<Signature, VarNodes, SignatureHash> VarSignMap;
typedef unordered_map<Signature, FacNodes, SignatureHash> FacSignMap;
typedef unordered_map<VarSignature, VarNodes, VarSignatureHash> VarSignMap;
typedef unordered_map<FacSignature, FacNodes, FacSignatureHash> FacSignMap;
typedef vector<VarCluster*> VarClusters;
typedef vector<FacCluster*> FacClusters;
@ -28,53 +30,27 @@ typedef vector<FacCluster*> FacClusters;
typedef unordered_map<VarId, VarCluster*> VarId2VarCluster;
struct Signature
struct VarSignatureHash
{
Signature (unsigned size) : colors(size) { }
bool operator< (const Signature& sig) const
size_t operator() (const VarSignature &sig) const
{
if (colors.size() < sig.colors.size()) {
return true;
} else if (colors.size() > sig.colors.size()) {
return false;
} else {
for (unsigned i = 0; i < colors.size(); i++) {
if (colors[i] < sig.colors[i]) {
return true;
} else if (colors[i] > sig.colors[i]) {
return false;
}
}
size_t val = hash<size_t>()(sig.size());
for (unsigned i = 0; i < sig.size(); i++) {
val ^= hash<size_t>()(sig[i].first);
val ^= hash<size_t>()(sig[i].second);
}
return false;
return val;
}
bool operator== (const Signature& sig) const
{
if (colors.size() != sig.colors.size()) {
return false;
}
for (unsigned i = 0; i < colors.size(); i++) {
if (colors[i] != sig.colors[i]) {
return false;
}
}
return true;
}
Colors colors;
};
struct SignatureHash
struct FacSignatureHash
{
size_t operator() (const Signature &sig) const
size_t operator() (const FacSignature &sig) const
{
size_t val = hash<size_t>()(sig.colors.size());
for (unsigned i = 0; i < sig.colors.size(); i++) {
val ^= hash<size_t>()(sig.colors[i]);
size_t val = hash<size_t>()(sig.size());
for (unsigned i = 0; i < sig.size(); i++) {
val ^= hash<size_t>()(sig[i]);
}
return val;
}
@ -87,19 +63,16 @@ class VarCluster
public:
VarCluster (const VarNodes& vs) : members_(vs) { }
const VarNode* first (void) const { return members_.front(); }
const VarNodes& members (void) const { return members_; }
const FacClusters& facClusters (void) const { return facClusters_; }
void addFacCluster (FacCluster* fc) { facClusters_.push_back (fc); }
VarNode* getRepresentative (void) const { return repr_; }
VarNode* representative (void) const { return repr_; }
void setRepresentative (VarNode* vn) { repr_ = vn; }
private:
VarNodes members_;
FacClusters facClusters_;
VarNode* repr_;
};
@ -108,26 +81,17 @@ class FacCluster
{
public:
FacCluster (const FacNodes& fcs, const VarClusters& vcs)
: members_(fcs), varClusters_(vcs)
{
for (unsigned i = 0; i < varClusters_.size(); i++) {
varClusters_[i]->addFacCluster (this);
}
}
: members_(fcs), varClusters_(vcs) { }
const FacNode* first (void) const { return members_.front(); }
const FacNodes& members (void) const { return members_; }
const VarClusters& varClusters (void) const { return varClusters_; }
VarClusters& varClusters (void) { return varClusters_; }
FacNode* getRepresentative (void) const { return repr_; }
FacNode* representative (void) const { return repr_; }
void setRepresentative (FacNode* fn) { repr_ = fn; }
bool containsGround (const FacNode* fn) const
{
return std::find (members_.begin(), members_.end(), fn)
!= members_.end();
}
private:
FacNodes members_;
@ -147,15 +111,16 @@ class CFactorGraph
const FacClusters& facClusters (void) { return facClusters_; }
VarNode* getEquivalentVariable (VarId vid)
VarNode* getEquivalent (VarId vid)
{
VarCluster* vc = vid2VarCluster_.find (vid)->second;
return vc->getRepresentative();
return vc->representative();
}
FactorGraph* getGroundFactorGraph (void) const;
FactorGraph* getGroundFactorGraph (void);
unsigned getEdgeCount (const FacCluster*, const VarCluster*) const;
unsigned getEdgeCount (const FacCluster*,
const VarCluster*, unsigned index) const;
static bool checkForIdenticalFactors;
@ -184,11 +149,6 @@ class CFactorGraph
facColors_[fn->getIndex()] = c;
}
VarCluster* getVariableCluster (VarId vid) const
{
return vid2VarCluster_.find (vid)->second;
}
void findIdenticalFactors (void);
void setInitialColors (void);
@ -197,21 +157,19 @@ class CFactorGraph
void createClusters (const VarSignMap&, const FacSignMap&);
const Signature& getSignature (const VarNode*);
VarSignature getSignature (const VarNode*);
const Signature& getSignature (const FacNode*);
FacSignature getSignature (const FacNode*);
void printGroups (const VarSignMap&, const FacSignMap&) const;
Color freeColor_;
Colors varColors_;
Colors facColors_;
vector<Signature> varSignatures_;
vector<Signature> facSignatures_;
VarClusters varClusters_;
FacClusters facClusters_;
VarId2VarCluster vid2VarCluster_;
const FactorGraph* groundFg_;
Color freeColor_;
Colors varColors_;
Colors facColors_;
VarClusters varClusters_;
FacClusters facClusters_;
VarId2VarCluster vid2VarCluster_;
const FactorGraph* groundFg_;
};
#endif // HORUS_CFACTORGRAPH_H

View File

@ -1,5 +1,6 @@
#include "CbpSolver.h"
vector<int> CbpSolver::counts;
CbpSolver::CbpSolver (const FactorGraph& fg) : BpSolver (fg)
{
@ -24,16 +25,14 @@ CbpSolver::CbpSolver (const FactorGraph& fg) : BpSolver (fg)
Statistics::updateCompressingStatistics (nrGroundVars,
nrGroundFacs, nrClusterVars, nrClusterFacs, nrNeighborless);
}
// cout << "uncompressed factor graph:" << endl;
// cout << " " << fg.nrVarNodes() << " variables " << endl;
// cout << " " << fg.nrFacNodes() << " factors " << endl;
// cout << "compressed factor graph:" << endl;
// cout << " " << fg_->nrVarNodes() << " variables " << endl;
// cout << " " << fg_->nrFacNodes() << " factors " << endl;
// Util::printHeader ("Compressed Factor Graph");
// fg_->print();
// Util::printHeader ("Uncompressed Factor Graph");
// fg.print();
if (Constants::DEBUG >= 5) {
cout << "uncompressed factor graph:" << endl;
cout << " " << fg.nrVarNodes() << " variables " << endl;
cout << " " << fg.nrFacNodes() << " factors " << endl;
cout << "compressed factor graph:" << endl;
cout << " " << fg_->nrVarNodes() << " variables " << endl;
cout << " " << fg_->nrFacNodes() << " factors " << endl;
}
}
@ -81,8 +80,8 @@ CbpSolver::getPosterioriOf (VarId vid)
if (runned_ == false) {
runSolver();
}
assert (cfg_->getEquivalentVariable (vid));
VarNode* var = cfg_->getEquivalentVariable (vid);
assert (cfg_->getEquivalent (vid));
VarNode* var = cfg_->getEquivalent (vid);
Params probs;
if (var->hasEvidence()) {
probs.resize (var->range(), LogAware::noEvidence());
@ -115,7 +114,7 @@ CbpSolver::getJointDistributionOf (const VarIds& jointVids)
{
VarIds eqVarIds;
for (unsigned i = 0; i < jointVids.size(); i++) {
VarNode* vn = cfg_->getEquivalentVariable (jointVids[i]);
VarNode* vn = cfg_->getEquivalent (jointVids[i]);
eqVarIds.push_back (vn->varId());
}
return BpSolver::getJointDistributionOf (eqVarIds);
@ -125,15 +124,20 @@ CbpSolver::getJointDistributionOf (const VarIds& jointVids)
void
CbpSolver::createLinks (void)
{
{
const FacClusters& fcs = cfg_->facClusters();
for (unsigned i = 0; i < fcs.size(); i++) {
const VarClusters& vcs = fcs[i]->varClusters();
for (unsigned j = 0; j < vcs.size(); j++) {
unsigned c = cfg_->getEdgeCount (fcs[i], vcs[j]);
unsigned count = cfg_->getEdgeCount (fcs[i], vcs[j], j);
if (Constants::DEBUG >= 5) {
cout << "creating edge " ;
cout << fcs[i]->representative()->getLabel() << " -> " ;
cout << vcs[j]->representative()->label();
cout << " idx=" << j << ", count=" << count << endl;
}
links_.push_back (new CbpSolverLink (
fcs[i]->getRepresentative(),
vcs[j]->getRepresentative(), c));
fcs[i]->representative(), vcs[j]->representative(), j, count));
}
}
}
@ -213,47 +217,135 @@ CbpSolver::maxResidualSchedule (void)
Params
CbpSolver::getVar2FactorMsg (const SpLink* link) const
void
CbpSolver::calculateFactor2VariableMsg (SpLink* _link)
{
Params msg;
CbpSolverLink* link = static_cast<CbpSolverLink*> (_link);
FacNode* src = link->getFactor();
const VarNode* dst = link->getVariable();
const SpLinkSet& links = ninf(src)->getLinks();
// calculate the product of messages that were sent
// to factor `src', except from var `dst'
unsigned msgSize = 1;
for (unsigned i = 0; i < links.size(); i++) {
msgSize *= links[i]->getVariable()->range();
}
unsigned repetitions = 1;
Params msgProduct (msgSize, LogAware::multIdenty());
if (Globals::logDomain) {
for (int i = links.size() - 1; i >= 0; i--) {
const CbpSolverLink* cl = static_cast<const CbpSolverLink*> (links[i]);
if ( ! (cl->getVariable() == dst && cl->index() == link->index())) {
if (Constants::DEBUG >= 5) {
cout << " message from " << links[i]->getVariable()->label();
cout << ": " ;
}
Util::add (msgProduct, getVar2FactorMsg (links[i]), repetitions);
repetitions *= links[i]->getVariable()->range();
if (Constants::DEBUG >= 5) {
cout << endl;
}
} else {
unsigned range = links[i]->getVariable()->range();
Util::add (msgProduct, Params (range, 0.0), repetitions);
repetitions *= range;
}
}
} else {
for (int i = links.size() - 1; i >= 0; i--) {
const CbpSolverLink* cl = static_cast<const CbpSolverLink*> (links[i]);
if ( ! (cl->getVariable() == dst && cl->index() == link->index())) {
if (Constants::DEBUG >= 5) {
cout << " message from " << links[i]->getVariable()->label();
cout << ": " ;
}
Util::multiply (msgProduct, getVar2FactorMsg (links[i]), repetitions);
repetitions *= links[i]->getVariable()->range();
if (Constants::DEBUG >= 5) {
cout << endl;
}
} else {
unsigned range = links[i]->getVariable()->range();
Util::multiply (msgProduct, Params (range, 1.0), repetitions);
repetitions *= range;
}
}
}
Factor result (src->factor().arguments(),
src->factor().ranges(), msgProduct);
assert (msgProduct.size() == src->factor().size());
if (Globals::logDomain) {
for (unsigned i = 0; i < result.size(); i++) {
result[i] += src->factor()[i];
}
} else {
for (unsigned i = 0; i < result.size(); i++) {
result[i] *= src->factor()[i];
}
}
if (Constants::DEBUG >= 5) {
cout << " message product: " << msgProduct << endl;
cout << " original factor: " << src->factor().params() << endl;
cout << " factor product: " << result.params() << endl;
}
result.sumOutAllExceptIndex (link->index());
if (Constants::DEBUG >= 5) {
cout << " marginalized: " << result.params() << endl;
}
link->getNextMessage() = result.params();
LogAware::normalize (link->getNextMessage());
if (Constants::DEBUG >= 5) {
cout << " curr msg: " << link->getMessage() << endl;
cout << " next msg: " << link->getNextMessage() << endl;
}
}
Params
CbpSolver::getVar2FactorMsg (const SpLink* _link) const
{
const CbpSolverLink* link = static_cast<const CbpSolverLink*> (_link);
const VarNode* src = link->getVariable();
const FacNode* dst = link->getFactor();
const CbpSolverLink* l = static_cast<const CbpSolverLink*> (link);
Params msg;
if (src->hasEvidence()) {
msg.resize (src->range(), LogAware::noEvidence());
double value = link->getMessage()[src->getEvidence()];
msg[src->getEvidence()] = LogAware::pow (value, l->nrEdges() - 1);
if (Constants::DEBUG >= 5) {
msg[src->getEvidence()] = value;
cout << msg << "^" << link->nrEdges() << "-1" ;
}
msg[src->getEvidence()] = LogAware::pow (value, link->nrEdges() - 1);
} else {
msg = link->getMessage();
LogAware::pow (msg, l->nrEdges() - 1);
}
if (Constants::DEBUG >= 5) {
cout << " " << "init: " << msg << " " << src->hasEvidence() << endl;
if (Constants::DEBUG >= 5) {
cout << msg << "^" << link->nrEdges() << "-1" ;
}
LogAware::pow (msg, link->nrEdges() - 1);
}
const SpLinkSet& links = ninf(src)->getLinks();
if (Globals::logDomain) {
for (unsigned i = 0; i < links.size(); i++) {
if (links[i]->getFactor() != dst) {
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
Util::add (msg, l->poweredMessage());
CbpSolverLink* cl = static_cast<CbpSolverLink*> (links[i]);
if ( ! (cl->getFactor() == dst && cl->index() == link->index())) {
CbpSolverLink* cl = static_cast<CbpSolverLink*> (links[i]);
Util::add (msg, cl->poweredMessage());
}
}
} else {
for (unsigned i = 0; i < links.size(); i++) {
if (links[i]->getFactor() != dst) {
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
Util::multiply (msg, l->poweredMessage());
CbpSolverLink* cl = static_cast<CbpSolverLink*> (links[i]);
if ( ! (cl->getFactor() == dst && cl->index() == link->index())) {
Util::multiply (msg, cl->poweredMessage());
if (Constants::DEBUG >= 5) {
cout << " msg from " << l->getFactor()->getLabel() << ": " ;
cout << l->poweredMessage() << endl;
cout << " x " << cl->getNextMessage() << "^" << link->nrEdges();
}
}
}
}
if (Constants::DEBUG >= 5) {
cout << " result = " << msg << endl;
cout << " = " << msg;
}
return msg;
}
@ -264,13 +356,14 @@ void
CbpSolver::printLinkInformation (void) const
{
for (unsigned i = 0; i < links_.size(); i++) {
CbpSolverLink* l = static_cast<CbpSolverLink*> (links_[i]);
cout << l->toString() << ":" << endl;
cout << " curr msg = " << l->getMessage() << endl;
cout << " next msg = " << l->getNextMessage() << endl;
cout << " nr edges = " << l->nrEdges() << endl;
cout << " powered = " << l->poweredMessage() << endl;
cout << " residual = " << l->getResidual() << endl;
CbpSolverLink* cl = static_cast<CbpSolverLink*> (links_[i]);
cout << cl->toString() << ":" << endl;
cout << " curr msg = " << cl->getMessage() << endl;
cout << " next msg = " << cl->getNextMessage() << endl;
cout << " index = " << cl->index() << endl;
cout << " nr edges = " << cl->nrEdges() << endl;
cout << " powered = " << cl->poweredMessage() << endl;
cout << " residual = " << cl->getResidual() << endl;
}
}

View File

@ -9,10 +9,12 @@ class Factor;
class CbpSolverLink : public SpLink
{
public:
CbpSolverLink (FacNode* fn, VarNode* vn, unsigned c)
: SpLink (fn, vn), nrEdges_(c),
CbpSolverLink (FacNode* fn, VarNode* vn, unsigned idx, unsigned count)
: SpLink (fn, vn), index_(idx), nrEdges_(count),
pwdMsg_(vn->range(), LogAware::one()) { }
unsigned index (void) const { return index_; }
unsigned nrEdges (void) const { return nrEdges_; }
const Params& poweredMessage (void) const { return pwdMsg_; }
@ -26,6 +28,7 @@ class CbpSolverLink : public SpLink
}
private:
unsigned index_;
unsigned nrEdges_;
Params pwdMsg_;
};
@ -35,6 +38,8 @@ class CbpSolverLink : public SpLink
class CbpSolver : public BpSolver
{
public:
static vector<int> counts;
CbpSolver (const FactorGraph& fg);
~CbpSolver (void);
@ -51,6 +56,8 @@ class CbpSolver : public BpSolver
void maxResidualSchedule (void);
void calculateFactor2VariableMsg (SpLink*);
Params getVar2FactorMsg (const SpLink*) const;
void printLinkInformation (void) const;

View File

@ -105,7 +105,7 @@ Factor::sumOutIndex (unsigned idx)
// on the left of `var', with the states of the remaining vars fixed
unsigned leftVarOffset = 1;
for (int i = args_.size() - 1; i > idx; i--) {
for (int i = args_.size() - 1; i > (int)idx; i--) {
varOffset *= ranges_[i];
leftVarOffset *= ranges_[i];
}
@ -151,7 +151,7 @@ Factor::sumOutIndex (unsigned idx)
void
Factor::sumOutAllExceptIndex (unsigned idx)
{
int i = idx;
int i = (int)idx;
while (args_.size() > i + 1) {
sumOutLastVariable();
}

View File

@ -163,7 +163,7 @@ FactorGraph::readFromLibDaiFormat (const char* fileName)
reverse (vids.begin(), vids.end());
Factor f (vids, ranges, params);
reverse (vids.begin(), vids.end());
f.reorderArguments (vids);
f.reorderArguments (vids);
addFactor (f);
}
is.close();