improve cbp by supporting factors connected to a single var two or more times
This commit is contained in:
parent
ad24a360ce
commit
689244a0d8
@ -325,7 +325,6 @@ BpSolver::calculateFactor2VariableMsg (SpLink* link)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Factor result (src->factor().arguments(),
|
||||
src->factor().ranges(), msgProduct);
|
||||
result.multiply (src->factor());
|
||||
@ -336,18 +335,13 @@ BpSolver::calculateFactor2VariableMsg (SpLink* link)
|
||||
}
|
||||
result.sumOutAllExcept (dst->varId());
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << " marginalized: " ;
|
||||
cout << result.params() << endl;
|
||||
cout << " marginalized: " << result.params() << endl;
|
||||
}
|
||||
const Params& resultParams = result.params();
|
||||
Params& message = link->getNextMessage();
|
||||
for (unsigned i = 0; i < resultParams.size(); i++) {
|
||||
message[i] = resultParams[i];
|
||||
}
|
||||
LogAware::normalize (message);
|
||||
link->getNextMessage() = result.params();
|
||||
LogAware::normalize (link->getNextMessage());
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << " curr msg: " << link->getMessage() << endl;
|
||||
cout << " next msg: " << message << endl;
|
||||
cout << " next msg: " << link->getNextMessage() << endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -183,7 +183,7 @@ class BpSolver : public Solver
|
||||
|
||||
bool converged (void);
|
||||
|
||||
void printLinkInformation (void) const;
|
||||
virtual void printLinkInformation (void) const;
|
||||
};
|
||||
|
||||
#endif // HORUS_BPSOLVER_H
|
||||
|
@ -7,20 +7,6 @@ bool CFactorGraph::checkForIdenticalFactors = true;
|
||||
CFactorGraph::CFactorGraph (const FactorGraph& fg)
|
||||
: freeColor_(0), groundFg_(&fg)
|
||||
{
|
||||
const VarNodes& varNodes = fg.varNodes();
|
||||
varSignatures_.reserve (varNodes.size());
|
||||
for (unsigned i = 0; i < varNodes.size(); i++) {
|
||||
unsigned c = (varNodes[i]->neighbors().size() * 2) + 1;
|
||||
varSignatures_.push_back (Signature (c));
|
||||
}
|
||||
const FacNodes& facNodes = fg.facNodes();
|
||||
facSignatures_.reserve (facNodes.size());
|
||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||
unsigned c = facNodes[i]->neighbors().size() + 1;
|
||||
facSignatures_.push_back (Signature (c));
|
||||
}
|
||||
varColors_.resize (varNodes.size());
|
||||
facColors_.resize (facNodes.size());
|
||||
findIdenticalFactors();
|
||||
setInitialColors();
|
||||
createGroups();
|
||||
@ -77,6 +63,9 @@ CFactorGraph::findIdenticalFactors()
|
||||
void
|
||||
CFactorGraph::setInitialColors (void)
|
||||
{
|
||||
varColors_.resize (groundFg_->nrVarNodes());
|
||||
facColors_.resize (groundFg_->nrFacNodes());
|
||||
|
||||
// create the initial variable colors
|
||||
VarColorMap colorMap;
|
||||
const VarNodes& varNodes = groundFg_->varNodes();
|
||||
@ -127,31 +116,11 @@ CFactorGraph::createGroups (void)
|
||||
while (groupsHaveChanged || nIters == 1) {
|
||||
nIters ++;
|
||||
|
||||
unsigned prevFactorGroupsSize = facGroups.size();
|
||||
facGroups.clear();
|
||||
// set a new color to the factors with the same signature
|
||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||
const Signature& signature = getSignature (facNodes[i]);
|
||||
FacSignMap::iterator it = facGroups.find (signature);
|
||||
if (it == facGroups.end()) {
|
||||
it = facGroups.insert (make_pair (signature, FacNodes())).first;
|
||||
}
|
||||
it->second.push_back (facNodes[i]);
|
||||
}
|
||||
for (FacSignMap::iterator it = facGroups.begin();
|
||||
it != facGroups.end(); it++) {
|
||||
Color newColor = getFreeColor();
|
||||
FacNodes& groupMembers = it->second;
|
||||
for (unsigned i = 0; i < groupMembers.size(); i++) {
|
||||
setColor (groupMembers[i], newColor);
|
||||
}
|
||||
}
|
||||
|
||||
// set a new color to the variables with the same signature
|
||||
unsigned prevVarGroupsSize = varGroups.size();
|
||||
varGroups.clear();
|
||||
for (unsigned i = 0; i < varNodes.size(); i++) {
|
||||
const Signature& signature = getSignature (varNodes[i]);
|
||||
const VarSignature& signature = getSignature (varNodes[i]);
|
||||
VarSignMap::iterator it = varGroups.find (signature);
|
||||
if (it == varGroups.end()) {
|
||||
it = varGroups.insert (make_pair (signature, VarNodes())).first;
|
||||
@ -167,6 +136,26 @@ CFactorGraph::createGroups (void)
|
||||
}
|
||||
}
|
||||
|
||||
unsigned prevFactorGroupsSize = facGroups.size();
|
||||
facGroups.clear();
|
||||
// set a new color to the factors with the same signature
|
||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||
const FacSignature& signature = getSignature (facNodes[i]);
|
||||
FacSignMap::iterator it = facGroups.find (signature);
|
||||
if (it == facGroups.end()) {
|
||||
it = facGroups.insert (make_pair (signature, FacNodes())).first;
|
||||
}
|
||||
it->second.push_back (facNodes[i]);
|
||||
}
|
||||
for (FacSignMap::iterator it = facGroups.begin();
|
||||
it != facGroups.end(); it++) {
|
||||
Color newColor = getFreeColor();
|
||||
FacNodes& groupMembers = it->second;
|
||||
for (unsigned i = 0; i < groupMembers.size(); i++) {
|
||||
setColor (groupMembers[i], newColor);
|
||||
}
|
||||
}
|
||||
|
||||
groupsHaveChanged = prevVarGroupsSize != varGroups.size()
|
||||
|| prevFactorGroupsSize != facGroups.size();
|
||||
}
|
||||
@ -183,7 +172,7 @@ CFactorGraph::createClusters (
|
||||
{
|
||||
varClusters_.reserve (varGroups.size());
|
||||
for (VarSignMap::const_iterator it = varGroups.begin();
|
||||
it != varGroups.end(); it++) {
|
||||
it != varGroups.end(); it++) {
|
||||
const VarNodes& groupVars = it->second;
|
||||
VarCluster* vc = new VarCluster (groupVars);
|
||||
for (unsigned i = 0; i < groupVars.size(); i++) {
|
||||
@ -194,7 +183,7 @@ CFactorGraph::createClusters (
|
||||
|
||||
facClusters_.reserve (facGroups.size());
|
||||
for (FacSignMap::const_iterator it = facGroups.begin();
|
||||
it != facGroups.end(); it++) {
|
||||
it != facGroups.end(); it++) {
|
||||
FacNode* groupFactor = it->second[0];
|
||||
const VarNodes& neighs = groupFactor->neighbors();
|
||||
VarClusters varClusters;
|
||||
@ -209,67 +198,61 @@ CFactorGraph::createClusters (
|
||||
|
||||
|
||||
|
||||
const Signature&
|
||||
VarSignature
|
||||
CFactorGraph::getSignature (const VarNode* varNode)
|
||||
{
|
||||
Signature& sign = varSignatures_[varNode->getIndex()];
|
||||
Colors::iterator it = sign.colors.begin();
|
||||
const FacNodes& neighs = varNode->neighbors();
|
||||
VarSignature sign;
|
||||
sign.reserve (neighs.size() + 1);
|
||||
for (unsigned i = 0; i < neighs.size(); i++) {
|
||||
*it = getColor (neighs[i]);
|
||||
it ++;
|
||||
*it = neighs[i]->factor().indexOf (varNode->varId());
|
||||
it ++;
|
||||
sign.push_back (make_pair (
|
||||
getColor (neighs[i]),
|
||||
neighs[i]->factor().indexOf (varNode->varId())));
|
||||
}
|
||||
*it = getColor (varNode);
|
||||
std::sort (sign.begin(), sign.end());
|
||||
sign.push_back (make_pair (getColor (varNode), 0));
|
||||
return sign;
|
||||
}
|
||||
|
||||
|
||||
|
||||
const Signature&
|
||||
FacSignature
|
||||
CFactorGraph::getSignature (const FacNode* facNode)
|
||||
{
|
||||
Signature& sign = facSignatures_[facNode->getIndex()];
|
||||
Colors::iterator it = sign.colors.begin();
|
||||
const VarNodes& neighs = facNode->neighbors();
|
||||
FacSignature sign;
|
||||
sign.reserve (neighs.size() + 1);
|
||||
for (unsigned i = 0; i < neighs.size(); i++) {
|
||||
*it = getColor (neighs[i]);
|
||||
it ++;
|
||||
sign.push_back (getColor (neighs[i]));
|
||||
}
|
||||
std::sort (sign.colors.begin(), -- sign.colors.end());
|
||||
*it = getColor (facNode);
|
||||
sign.push_back (getColor (facNode));
|
||||
return sign;
|
||||
}
|
||||
|
||||
|
||||
|
||||
FactorGraph*
|
||||
CFactorGraph::getGroundFactorGraph (void) const
|
||||
CFactorGraph::getGroundFactorGraph (void)
|
||||
{
|
||||
FactorGraph* fg = new FactorGraph();
|
||||
for (unsigned i = 0; i < varClusters_.size(); i++) {
|
||||
VarNode* newVar = new VarNode (varClusters_[i]->members()[0]);
|
||||
VarNode* newVar = new VarNode (varClusters_[i]->first());
|
||||
varClusters_[i]->setRepresentative (newVar);
|
||||
fg->addVarNode (newVar);
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < facClusters_.size(); i++) {
|
||||
const VarClusters& myVarClusters = facClusters_[i]->varClusters();
|
||||
Vars myGroundVars;
|
||||
myGroundVars.reserve (myVarClusters.size());
|
||||
for (unsigned j = 0; j < myVarClusters.size(); j++) {
|
||||
VarNode* v = myVarClusters[j]->getRepresentative();
|
||||
myGroundVars.push_back (v);
|
||||
Vars vars;
|
||||
const VarClusters& clusters = facClusters_[i]->varClusters();
|
||||
for (unsigned j = 0; j < clusters.size(); j++) {
|
||||
vars.push_back (clusters[j]->representative());
|
||||
}
|
||||
const Factor& groundFac = facClusters_[i]->first()->factor();
|
||||
FacNode* fn = new FacNode (Factor (
|
||||
myGroundVars,
|
||||
facClusters_[i]->members()[0]->factor().params(),
|
||||
facClusters_[i]->members()[0]->factor().distId()));
|
||||
vars, groundFac.params(), groundFac.distId()));
|
||||
facClusters_[i]->setRepresentative (fn);
|
||||
fg->addFacNode (fn);
|
||||
for (unsigned j = 0; j < myGroundVars.size(); j++) {
|
||||
fg->addEdge (static_cast<VarNode*> (myGroundVars[j]), fn);
|
||||
for (unsigned j = 0; j < vars.size(); j++) {
|
||||
fg->addEdge (static_cast<VarNode*> (vars[j]), fn);
|
||||
}
|
||||
}
|
||||
return fg;
|
||||
@ -280,29 +263,21 @@ CFactorGraph::getGroundFactorGraph (void) const
|
||||
unsigned
|
||||
CFactorGraph::getEdgeCount (
|
||||
const FacCluster* fc,
|
||||
const VarCluster* vc) const
|
||||
const VarCluster* vc,
|
||||
unsigned index) const
|
||||
{
|
||||
unsigned count = 0;
|
||||
VarId vid = vc->members().front()->varId();
|
||||
const FacNodes& members = fc->members();
|
||||
for (unsigned i = 0; i < members.size(); i++) {
|
||||
if (members[i]->factor().contains (vid)) {
|
||||
VarId reprVid = vc->representative()->varId();
|
||||
VarNode* groundVar = groundFg_->getVarNode (reprVid);
|
||||
const FacNodes& neighs = groundVar->neighbors();
|
||||
for (unsigned i = 0; i < neighs.size(); i++) {
|
||||
FacNodes::const_iterator it;
|
||||
it = std::find (fc->members().begin(), fc->members().end(), neighs[i]);
|
||||
if (it != fc->members().end() &&
|
||||
(*it)->factor().indexOf (reprVid) == (int)index) {
|
||||
count ++;
|
||||
}
|
||||
}
|
||||
if (Constants::DEBUG > 0) {
|
||||
const VarNodes& vars = vc->members();
|
||||
for (unsigned i = 1; i < vars.size(); i++) {
|
||||
VarId vid = vars[i]->varId();
|
||||
unsigned count2 = 0;
|
||||
for (unsigned i = 0; i < members.size(); i++) {
|
||||
if (members[i]->factor().contains (vid)) {
|
||||
count2 ++;
|
||||
}
|
||||
}
|
||||
assert (count == count2);
|
||||
}
|
||||
}
|
||||
return count;
|
||||
}
|
||||
|
||||
@ -327,7 +302,6 @@ CFactorGraph::printGroups (
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
|
||||
count = 1;
|
||||
cout << endl << "factor groups:" << endl;
|
||||
for (FacSignMap::const_iterator it = facGroups.begin();
|
||||
|
@ -10,17 +10,19 @@
|
||||
|
||||
class VarCluster;
|
||||
class FacCluster;
|
||||
class Signature;
|
||||
class SignatureHash;
|
||||
class VarSignatureHash;
|
||||
class FacSignatureHash;
|
||||
|
||||
typedef long Color;
|
||||
typedef vector<Color> Colors;
|
||||
typedef vector<std::pair<Color,unsigned>> VarSignature;
|
||||
typedef vector<Color> FacSignature;
|
||||
|
||||
typedef unordered_map<unsigned, Color> DistColorMap;
|
||||
typedef unordered_map<unsigned, Colors> VarColorMap;
|
||||
|
||||
typedef unordered_map<Signature, VarNodes, SignatureHash> VarSignMap;
|
||||
typedef unordered_map<Signature, FacNodes, SignatureHash> FacSignMap;
|
||||
typedef unordered_map<VarSignature, VarNodes, VarSignatureHash> VarSignMap;
|
||||
typedef unordered_map<FacSignature, FacNodes, FacSignatureHash> FacSignMap;
|
||||
|
||||
typedef vector<VarCluster*> VarClusters;
|
||||
typedef vector<FacCluster*> FacClusters;
|
||||
@ -28,53 +30,27 @@ typedef vector<FacCluster*> FacClusters;
|
||||
typedef unordered_map<VarId, VarCluster*> VarId2VarCluster;
|
||||
|
||||
|
||||
struct Signature
|
||||
struct VarSignatureHash
|
||||
{
|
||||
Signature (unsigned size) : colors(size) { }
|
||||
|
||||
bool operator< (const Signature& sig) const
|
||||
size_t operator() (const VarSignature &sig) const
|
||||
{
|
||||
if (colors.size() < sig.colors.size()) {
|
||||
return true;
|
||||
} else if (colors.size() > sig.colors.size()) {
|
||||
return false;
|
||||
} else {
|
||||
for (unsigned i = 0; i < colors.size(); i++) {
|
||||
if (colors[i] < sig.colors[i]) {
|
||||
return true;
|
||||
} else if (colors[i] > sig.colors[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
size_t val = hash<size_t>()(sig.size());
|
||||
for (unsigned i = 0; i < sig.size(); i++) {
|
||||
val ^= hash<size_t>()(sig[i].first);
|
||||
val ^= hash<size_t>()(sig[i].second);
|
||||
}
|
||||
return false;
|
||||
return val;
|
||||
}
|
||||
|
||||
bool operator== (const Signature& sig) const
|
||||
{
|
||||
if (colors.size() != sig.colors.size()) {
|
||||
return false;
|
||||
}
|
||||
for (unsigned i = 0; i < colors.size(); i++) {
|
||||
if (colors[i] != sig.colors[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
Colors colors;
|
||||
};
|
||||
|
||||
|
||||
|
||||
struct SignatureHash
|
||||
struct FacSignatureHash
|
||||
{
|
||||
size_t operator() (const Signature &sig) const
|
||||
size_t operator() (const FacSignature &sig) const
|
||||
{
|
||||
size_t val = hash<size_t>()(sig.colors.size());
|
||||
for (unsigned i = 0; i < sig.colors.size(); i++) {
|
||||
val ^= hash<size_t>()(sig.colors[i]);
|
||||
size_t val = hash<size_t>()(sig.size());
|
||||
for (unsigned i = 0; i < sig.size(); i++) {
|
||||
val ^= hash<size_t>()(sig[i]);
|
||||
}
|
||||
return val;
|
||||
}
|
||||
@ -87,19 +63,16 @@ class VarCluster
|
||||
public:
|
||||
VarCluster (const VarNodes& vs) : members_(vs) { }
|
||||
|
||||
const VarNode* first (void) const { return members_.front(); }
|
||||
|
||||
const VarNodes& members (void) const { return members_; }
|
||||
|
||||
const FacClusters& facClusters (void) const { return facClusters_; }
|
||||
|
||||
void addFacCluster (FacCluster* fc) { facClusters_.push_back (fc); }
|
||||
|
||||
VarNode* getRepresentative (void) const { return repr_; }
|
||||
VarNode* representative (void) const { return repr_; }
|
||||
|
||||
void setRepresentative (VarNode* vn) { repr_ = vn; }
|
||||
|
||||
private:
|
||||
VarNodes members_;
|
||||
FacClusters facClusters_;
|
||||
VarNode* repr_;
|
||||
};
|
||||
|
||||
@ -108,26 +81,17 @@ class FacCluster
|
||||
{
|
||||
public:
|
||||
FacCluster (const FacNodes& fcs, const VarClusters& vcs)
|
||||
: members_(fcs), varClusters_(vcs)
|
||||
{
|
||||
for (unsigned i = 0; i < varClusters_.size(); i++) {
|
||||
varClusters_[i]->addFacCluster (this);
|
||||
}
|
||||
}
|
||||
: members_(fcs), varClusters_(vcs) { }
|
||||
|
||||
const FacNode* first (void) const { return members_.front(); }
|
||||
|
||||
const FacNodes& members (void) const { return members_; }
|
||||
|
||||
const VarClusters& varClusters (void) const { return varClusters_; }
|
||||
VarClusters& varClusters (void) { return varClusters_; }
|
||||
|
||||
FacNode* getRepresentative (void) const { return repr_; }
|
||||
FacNode* representative (void) const { return repr_; }
|
||||
|
||||
void setRepresentative (FacNode* fn) { repr_ = fn; }
|
||||
|
||||
bool containsGround (const FacNode* fn) const
|
||||
{
|
||||
return std::find (members_.begin(), members_.end(), fn)
|
||||
!= members_.end();
|
||||
}
|
||||
|
||||
private:
|
||||
FacNodes members_;
|
||||
@ -147,15 +111,16 @@ class CFactorGraph
|
||||
|
||||
const FacClusters& facClusters (void) { return facClusters_; }
|
||||
|
||||
VarNode* getEquivalentVariable (VarId vid)
|
||||
VarNode* getEquivalent (VarId vid)
|
||||
{
|
||||
VarCluster* vc = vid2VarCluster_.find (vid)->second;
|
||||
return vc->getRepresentative();
|
||||
return vc->representative();
|
||||
}
|
||||
|
||||
FactorGraph* getGroundFactorGraph (void) const;
|
||||
FactorGraph* getGroundFactorGraph (void);
|
||||
|
||||
unsigned getEdgeCount (const FacCluster*, const VarCluster*) const;
|
||||
unsigned getEdgeCount (const FacCluster*,
|
||||
const VarCluster*, unsigned index) const;
|
||||
|
||||
static bool checkForIdenticalFactors;
|
||||
|
||||
@ -184,11 +149,6 @@ class CFactorGraph
|
||||
facColors_[fn->getIndex()] = c;
|
||||
}
|
||||
|
||||
VarCluster* getVariableCluster (VarId vid) const
|
||||
{
|
||||
return vid2VarCluster_.find (vid)->second;
|
||||
}
|
||||
|
||||
void findIdenticalFactors (void);
|
||||
|
||||
void setInitialColors (void);
|
||||
@ -197,21 +157,19 @@ class CFactorGraph
|
||||
|
||||
void createClusters (const VarSignMap&, const FacSignMap&);
|
||||
|
||||
const Signature& getSignature (const VarNode*);
|
||||
VarSignature getSignature (const VarNode*);
|
||||
|
||||
const Signature& getSignature (const FacNode*);
|
||||
FacSignature getSignature (const FacNode*);
|
||||
|
||||
void printGroups (const VarSignMap&, const FacSignMap&) const;
|
||||
|
||||
Color freeColor_;
|
||||
Colors varColors_;
|
||||
Colors facColors_;
|
||||
vector<Signature> varSignatures_;
|
||||
vector<Signature> facSignatures_;
|
||||
VarClusters varClusters_;
|
||||
FacClusters facClusters_;
|
||||
VarId2VarCluster vid2VarCluster_;
|
||||
const FactorGraph* groundFg_;
|
||||
Color freeColor_;
|
||||
Colors varColors_;
|
||||
Colors facColors_;
|
||||
VarClusters varClusters_;
|
||||
FacClusters facClusters_;
|
||||
VarId2VarCluster vid2VarCluster_;
|
||||
const FactorGraph* groundFg_;
|
||||
};
|
||||
|
||||
#endif // HORUS_CFACTORGRAPH_H
|
||||
|
@ -1,5 +1,6 @@
|
||||
#include "CbpSolver.h"
|
||||
|
||||
vector<int> CbpSolver::counts;
|
||||
|
||||
CbpSolver::CbpSolver (const FactorGraph& fg) : BpSolver (fg)
|
||||
{
|
||||
@ -24,16 +25,14 @@ CbpSolver::CbpSolver (const FactorGraph& fg) : BpSolver (fg)
|
||||
Statistics::updateCompressingStatistics (nrGroundVars,
|
||||
nrGroundFacs, nrClusterVars, nrClusterFacs, nrNeighborless);
|
||||
}
|
||||
// cout << "uncompressed factor graph:" << endl;
|
||||
// cout << " " << fg.nrVarNodes() << " variables " << endl;
|
||||
// cout << " " << fg.nrFacNodes() << " factors " << endl;
|
||||
// cout << "compressed factor graph:" << endl;
|
||||
// cout << " " << fg_->nrVarNodes() << " variables " << endl;
|
||||
// cout << " " << fg_->nrFacNodes() << " factors " << endl;
|
||||
// Util::printHeader ("Compressed Factor Graph");
|
||||
// fg_->print();
|
||||
// Util::printHeader ("Uncompressed Factor Graph");
|
||||
// fg.print();
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << "uncompressed factor graph:" << endl;
|
||||
cout << " " << fg.nrVarNodes() << " variables " << endl;
|
||||
cout << " " << fg.nrFacNodes() << " factors " << endl;
|
||||
cout << "compressed factor graph:" << endl;
|
||||
cout << " " << fg_->nrVarNodes() << " variables " << endl;
|
||||
cout << " " << fg_->nrFacNodes() << " factors " << endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -81,8 +80,8 @@ CbpSolver::getPosterioriOf (VarId vid)
|
||||
if (runned_ == false) {
|
||||
runSolver();
|
||||
}
|
||||
assert (cfg_->getEquivalentVariable (vid));
|
||||
VarNode* var = cfg_->getEquivalentVariable (vid);
|
||||
assert (cfg_->getEquivalent (vid));
|
||||
VarNode* var = cfg_->getEquivalent (vid);
|
||||
Params probs;
|
||||
if (var->hasEvidence()) {
|
||||
probs.resize (var->range(), LogAware::noEvidence());
|
||||
@ -115,7 +114,7 @@ CbpSolver::getJointDistributionOf (const VarIds& jointVids)
|
||||
{
|
||||
VarIds eqVarIds;
|
||||
for (unsigned i = 0; i < jointVids.size(); i++) {
|
||||
VarNode* vn = cfg_->getEquivalentVariable (jointVids[i]);
|
||||
VarNode* vn = cfg_->getEquivalent (jointVids[i]);
|
||||
eqVarIds.push_back (vn->varId());
|
||||
}
|
||||
return BpSolver::getJointDistributionOf (eqVarIds);
|
||||
@ -125,15 +124,20 @@ CbpSolver::getJointDistributionOf (const VarIds& jointVids)
|
||||
|
||||
void
|
||||
CbpSolver::createLinks (void)
|
||||
{
|
||||
{
|
||||
const FacClusters& fcs = cfg_->facClusters();
|
||||
for (unsigned i = 0; i < fcs.size(); i++) {
|
||||
const VarClusters& vcs = fcs[i]->varClusters();
|
||||
for (unsigned j = 0; j < vcs.size(); j++) {
|
||||
unsigned c = cfg_->getEdgeCount (fcs[i], vcs[j]);
|
||||
unsigned count = cfg_->getEdgeCount (fcs[i], vcs[j], j);
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << "creating edge " ;
|
||||
cout << fcs[i]->representative()->getLabel() << " -> " ;
|
||||
cout << vcs[j]->representative()->label();
|
||||
cout << " idx=" << j << ", count=" << count << endl;
|
||||
}
|
||||
links_.push_back (new CbpSolverLink (
|
||||
fcs[i]->getRepresentative(),
|
||||
vcs[j]->getRepresentative(), c));
|
||||
fcs[i]->representative(), vcs[j]->representative(), j, count));
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -213,47 +217,135 @@ CbpSolver::maxResidualSchedule (void)
|
||||
|
||||
|
||||
|
||||
Params
|
||||
CbpSolver::getVar2FactorMsg (const SpLink* link) const
|
||||
void
|
||||
CbpSolver::calculateFactor2VariableMsg (SpLink* _link)
|
||||
{
|
||||
Params msg;
|
||||
CbpSolverLink* link = static_cast<CbpSolverLink*> (_link);
|
||||
FacNode* src = link->getFactor();
|
||||
const VarNode* dst = link->getVariable();
|
||||
const SpLinkSet& links = ninf(src)->getLinks();
|
||||
// calculate the product of messages that were sent
|
||||
// to factor `src', except from var `dst'
|
||||
unsigned msgSize = 1;
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
msgSize *= links[i]->getVariable()->range();
|
||||
}
|
||||
unsigned repetitions = 1;
|
||||
Params msgProduct (msgSize, LogAware::multIdenty());
|
||||
if (Globals::logDomain) {
|
||||
for (int i = links.size() - 1; i >= 0; i--) {
|
||||
const CbpSolverLink* cl = static_cast<const CbpSolverLink*> (links[i]);
|
||||
if ( ! (cl->getVariable() == dst && cl->index() == link->index())) {
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << " message from " << links[i]->getVariable()->label();
|
||||
cout << ": " ;
|
||||
}
|
||||
Util::add (msgProduct, getVar2FactorMsg (links[i]), repetitions);
|
||||
repetitions *= links[i]->getVariable()->range();
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << endl;
|
||||
}
|
||||
} else {
|
||||
unsigned range = links[i]->getVariable()->range();
|
||||
Util::add (msgProduct, Params (range, 0.0), repetitions);
|
||||
repetitions *= range;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int i = links.size() - 1; i >= 0; i--) {
|
||||
const CbpSolverLink* cl = static_cast<const CbpSolverLink*> (links[i]);
|
||||
if ( ! (cl->getVariable() == dst && cl->index() == link->index())) {
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << " message from " << links[i]->getVariable()->label();
|
||||
cout << ": " ;
|
||||
}
|
||||
Util::multiply (msgProduct, getVar2FactorMsg (links[i]), repetitions);
|
||||
repetitions *= links[i]->getVariable()->range();
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << endl;
|
||||
}
|
||||
} else {
|
||||
unsigned range = links[i]->getVariable()->range();
|
||||
Util::multiply (msgProduct, Params (range, 1.0), repetitions);
|
||||
repetitions *= range;
|
||||
}
|
||||
}
|
||||
}
|
||||
Factor result (src->factor().arguments(),
|
||||
src->factor().ranges(), msgProduct);
|
||||
assert (msgProduct.size() == src->factor().size());
|
||||
if (Globals::logDomain) {
|
||||
for (unsigned i = 0; i < result.size(); i++) {
|
||||
result[i] += src->factor()[i];
|
||||
}
|
||||
} else {
|
||||
for (unsigned i = 0; i < result.size(); i++) {
|
||||
result[i] *= src->factor()[i];
|
||||
}
|
||||
}
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << " message product: " << msgProduct << endl;
|
||||
cout << " original factor: " << src->factor().params() << endl;
|
||||
cout << " factor product: " << result.params() << endl;
|
||||
}
|
||||
result.sumOutAllExceptIndex (link->index());
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << " marginalized: " << result.params() << endl;
|
||||
}
|
||||
link->getNextMessage() = result.params();
|
||||
LogAware::normalize (link->getNextMessage());
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << " curr msg: " << link->getMessage() << endl;
|
||||
cout << " next msg: " << link->getNextMessage() << endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
Params
|
||||
CbpSolver::getVar2FactorMsg (const SpLink* _link) const
|
||||
{
|
||||
const CbpSolverLink* link = static_cast<const CbpSolverLink*> (_link);
|
||||
const VarNode* src = link->getVariable();
|
||||
const FacNode* dst = link->getFactor();
|
||||
const CbpSolverLink* l = static_cast<const CbpSolverLink*> (link);
|
||||
Params msg;
|
||||
if (src->hasEvidence()) {
|
||||
msg.resize (src->range(), LogAware::noEvidence());
|
||||
double value = link->getMessage()[src->getEvidence()];
|
||||
msg[src->getEvidence()] = LogAware::pow (value, l->nrEdges() - 1);
|
||||
if (Constants::DEBUG >= 5) {
|
||||
msg[src->getEvidence()] = value;
|
||||
cout << msg << "^" << link->nrEdges() << "-1" ;
|
||||
}
|
||||
msg[src->getEvidence()] = LogAware::pow (value, link->nrEdges() - 1);
|
||||
} else {
|
||||
msg = link->getMessage();
|
||||
LogAware::pow (msg, l->nrEdges() - 1);
|
||||
}
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << " " << "init: " << msg << " " << src->hasEvidence() << endl;
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << msg << "^" << link->nrEdges() << "-1" ;
|
||||
}
|
||||
LogAware::pow (msg, link->nrEdges() - 1);
|
||||
}
|
||||
const SpLinkSet& links = ninf(src)->getLinks();
|
||||
if (Globals::logDomain) {
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
if (links[i]->getFactor() != dst) {
|
||||
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
|
||||
Util::add (msg, l->poweredMessage());
|
||||
CbpSolverLink* cl = static_cast<CbpSolverLink*> (links[i]);
|
||||
if ( ! (cl->getFactor() == dst && cl->index() == link->index())) {
|
||||
CbpSolverLink* cl = static_cast<CbpSolverLink*> (links[i]);
|
||||
Util::add (msg, cl->poweredMessage());
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
if (links[i]->getFactor() != dst) {
|
||||
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
|
||||
Util::multiply (msg, l->poweredMessage());
|
||||
CbpSolverLink* cl = static_cast<CbpSolverLink*> (links[i]);
|
||||
if ( ! (cl->getFactor() == dst && cl->index() == link->index())) {
|
||||
Util::multiply (msg, cl->poweredMessage());
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << " msg from " << l->getFactor()->getLabel() << ": " ;
|
||||
cout << l->poweredMessage() << endl;
|
||||
cout << " x " << cl->getNextMessage() << "^" << link->nrEdges();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << " result = " << msg << endl;
|
||||
cout << " = " << msg;
|
||||
}
|
||||
return msg;
|
||||
}
|
||||
@ -264,13 +356,14 @@ void
|
||||
CbpSolver::printLinkInformation (void) const
|
||||
{
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
CbpSolverLink* l = static_cast<CbpSolverLink*> (links_[i]);
|
||||
cout << l->toString() << ":" << endl;
|
||||
cout << " curr msg = " << l->getMessage() << endl;
|
||||
cout << " next msg = " << l->getNextMessage() << endl;
|
||||
cout << " nr edges = " << l->nrEdges() << endl;
|
||||
cout << " powered = " << l->poweredMessage() << endl;
|
||||
cout << " residual = " << l->getResidual() << endl;
|
||||
CbpSolverLink* cl = static_cast<CbpSolverLink*> (links_[i]);
|
||||
cout << cl->toString() << ":" << endl;
|
||||
cout << " curr msg = " << cl->getMessage() << endl;
|
||||
cout << " next msg = " << cl->getNextMessage() << endl;
|
||||
cout << " index = " << cl->index() << endl;
|
||||
cout << " nr edges = " << cl->nrEdges() << endl;
|
||||
cout << " powered = " << cl->poweredMessage() << endl;
|
||||
cout << " residual = " << cl->getResidual() << endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -9,10 +9,12 @@ class Factor;
|
||||
class CbpSolverLink : public SpLink
|
||||
{
|
||||
public:
|
||||
CbpSolverLink (FacNode* fn, VarNode* vn, unsigned c)
|
||||
: SpLink (fn, vn), nrEdges_(c),
|
||||
CbpSolverLink (FacNode* fn, VarNode* vn, unsigned idx, unsigned count)
|
||||
: SpLink (fn, vn), index_(idx), nrEdges_(count),
|
||||
pwdMsg_(vn->range(), LogAware::one()) { }
|
||||
|
||||
unsigned index (void) const { return index_; }
|
||||
|
||||
unsigned nrEdges (void) const { return nrEdges_; }
|
||||
|
||||
const Params& poweredMessage (void) const { return pwdMsg_; }
|
||||
@ -26,6 +28,7 @@ class CbpSolverLink : public SpLink
|
||||
}
|
||||
|
||||
private:
|
||||
unsigned index_;
|
||||
unsigned nrEdges_;
|
||||
Params pwdMsg_;
|
||||
};
|
||||
@ -35,6 +38,8 @@ class CbpSolverLink : public SpLink
|
||||
class CbpSolver : public BpSolver
|
||||
{
|
||||
public:
|
||||
static vector<int> counts;
|
||||
|
||||
CbpSolver (const FactorGraph& fg);
|
||||
|
||||
~CbpSolver (void);
|
||||
@ -51,6 +56,8 @@ class CbpSolver : public BpSolver
|
||||
|
||||
void maxResidualSchedule (void);
|
||||
|
||||
void calculateFactor2VariableMsg (SpLink*);
|
||||
|
||||
Params getVar2FactorMsg (const SpLink*) const;
|
||||
|
||||
void printLinkInformation (void) const;
|
||||
|
@ -105,7 +105,7 @@ Factor::sumOutIndex (unsigned idx)
|
||||
// on the left of `var', with the states of the remaining vars fixed
|
||||
unsigned leftVarOffset = 1;
|
||||
|
||||
for (int i = args_.size() - 1; i > idx; i--) {
|
||||
for (int i = args_.size() - 1; i > (int)idx; i--) {
|
||||
varOffset *= ranges_[i];
|
||||
leftVarOffset *= ranges_[i];
|
||||
}
|
||||
@ -151,7 +151,7 @@ Factor::sumOutIndex (unsigned idx)
|
||||
void
|
||||
Factor::sumOutAllExceptIndex (unsigned idx)
|
||||
{
|
||||
int i = idx;
|
||||
int i = (int)idx;
|
||||
while (args_.size() > i + 1) {
|
||||
sumOutLastVariable();
|
||||
}
|
||||
|
@ -163,7 +163,7 @@ FactorGraph::readFromLibDaiFormat (const char* fileName)
|
||||
reverse (vids.begin(), vids.end());
|
||||
Factor f (vids, ranges, params);
|
||||
reverse (vids.begin(), vids.end());
|
||||
f.reorderArguments (vids);
|
||||
f.reorderArguments (vids);
|
||||
addFactor (f);
|
||||
}
|
||||
is.close();
|
||||
|
Reference in New Issue
Block a user