improve cbp by supporting factors connected to a single var two or more times
This commit is contained in:
parent
ad24a360ce
commit
689244a0d8
@ -325,7 +325,6 @@ BpSolver::calculateFactor2VariableMsg (SpLink* link)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Factor result (src->factor().arguments(),
|
Factor result (src->factor().arguments(),
|
||||||
src->factor().ranges(), msgProduct);
|
src->factor().ranges(), msgProduct);
|
||||||
result.multiply (src->factor());
|
result.multiply (src->factor());
|
||||||
@ -336,18 +335,13 @@ BpSolver::calculateFactor2VariableMsg (SpLink* link)
|
|||||||
}
|
}
|
||||||
result.sumOutAllExcept (dst->varId());
|
result.sumOutAllExcept (dst->varId());
|
||||||
if (Constants::DEBUG >= 5) {
|
if (Constants::DEBUG >= 5) {
|
||||||
cout << " marginalized: " ;
|
cout << " marginalized: " << result.params() << endl;
|
||||||
cout << result.params() << endl;
|
|
||||||
}
|
}
|
||||||
const Params& resultParams = result.params();
|
link->getNextMessage() = result.params();
|
||||||
Params& message = link->getNextMessage();
|
LogAware::normalize (link->getNextMessage());
|
||||||
for (unsigned i = 0; i < resultParams.size(); i++) {
|
|
||||||
message[i] = resultParams[i];
|
|
||||||
}
|
|
||||||
LogAware::normalize (message);
|
|
||||||
if (Constants::DEBUG >= 5) {
|
if (Constants::DEBUG >= 5) {
|
||||||
cout << " curr msg: " << link->getMessage() << endl;
|
cout << " curr msg: " << link->getMessage() << endl;
|
||||||
cout << " next msg: " << message << endl;
|
cout << " next msg: " << link->getNextMessage() << endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -183,7 +183,7 @@ class BpSolver : public Solver
|
|||||||
|
|
||||||
bool converged (void);
|
bool converged (void);
|
||||||
|
|
||||||
void printLinkInformation (void) const;
|
virtual void printLinkInformation (void) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif // HORUS_BPSOLVER_H
|
#endif // HORUS_BPSOLVER_H
|
||||||
|
@ -7,20 +7,6 @@ bool CFactorGraph::checkForIdenticalFactors = true;
|
|||||||
CFactorGraph::CFactorGraph (const FactorGraph& fg)
|
CFactorGraph::CFactorGraph (const FactorGraph& fg)
|
||||||
: freeColor_(0), groundFg_(&fg)
|
: freeColor_(0), groundFg_(&fg)
|
||||||
{
|
{
|
||||||
const VarNodes& varNodes = fg.varNodes();
|
|
||||||
varSignatures_.reserve (varNodes.size());
|
|
||||||
for (unsigned i = 0; i < varNodes.size(); i++) {
|
|
||||||
unsigned c = (varNodes[i]->neighbors().size() * 2) + 1;
|
|
||||||
varSignatures_.push_back (Signature (c));
|
|
||||||
}
|
|
||||||
const FacNodes& facNodes = fg.facNodes();
|
|
||||||
facSignatures_.reserve (facNodes.size());
|
|
||||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
|
||||||
unsigned c = facNodes[i]->neighbors().size() + 1;
|
|
||||||
facSignatures_.push_back (Signature (c));
|
|
||||||
}
|
|
||||||
varColors_.resize (varNodes.size());
|
|
||||||
facColors_.resize (facNodes.size());
|
|
||||||
findIdenticalFactors();
|
findIdenticalFactors();
|
||||||
setInitialColors();
|
setInitialColors();
|
||||||
createGroups();
|
createGroups();
|
||||||
@ -77,6 +63,9 @@ CFactorGraph::findIdenticalFactors()
|
|||||||
void
|
void
|
||||||
CFactorGraph::setInitialColors (void)
|
CFactorGraph::setInitialColors (void)
|
||||||
{
|
{
|
||||||
|
varColors_.resize (groundFg_->nrVarNodes());
|
||||||
|
facColors_.resize (groundFg_->nrFacNodes());
|
||||||
|
|
||||||
// create the initial variable colors
|
// create the initial variable colors
|
||||||
VarColorMap colorMap;
|
VarColorMap colorMap;
|
||||||
const VarNodes& varNodes = groundFg_->varNodes();
|
const VarNodes& varNodes = groundFg_->varNodes();
|
||||||
@ -127,31 +116,11 @@ CFactorGraph::createGroups (void)
|
|||||||
while (groupsHaveChanged || nIters == 1) {
|
while (groupsHaveChanged || nIters == 1) {
|
||||||
nIters ++;
|
nIters ++;
|
||||||
|
|
||||||
unsigned prevFactorGroupsSize = facGroups.size();
|
|
||||||
facGroups.clear();
|
|
||||||
// set a new color to the factors with the same signature
|
|
||||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
|
||||||
const Signature& signature = getSignature (facNodes[i]);
|
|
||||||
FacSignMap::iterator it = facGroups.find (signature);
|
|
||||||
if (it == facGroups.end()) {
|
|
||||||
it = facGroups.insert (make_pair (signature, FacNodes())).first;
|
|
||||||
}
|
|
||||||
it->second.push_back (facNodes[i]);
|
|
||||||
}
|
|
||||||
for (FacSignMap::iterator it = facGroups.begin();
|
|
||||||
it != facGroups.end(); it++) {
|
|
||||||
Color newColor = getFreeColor();
|
|
||||||
FacNodes& groupMembers = it->second;
|
|
||||||
for (unsigned i = 0; i < groupMembers.size(); i++) {
|
|
||||||
setColor (groupMembers[i], newColor);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// set a new color to the variables with the same signature
|
// set a new color to the variables with the same signature
|
||||||
unsigned prevVarGroupsSize = varGroups.size();
|
unsigned prevVarGroupsSize = varGroups.size();
|
||||||
varGroups.clear();
|
varGroups.clear();
|
||||||
for (unsigned i = 0; i < varNodes.size(); i++) {
|
for (unsigned i = 0; i < varNodes.size(); i++) {
|
||||||
const Signature& signature = getSignature (varNodes[i]);
|
const VarSignature& signature = getSignature (varNodes[i]);
|
||||||
VarSignMap::iterator it = varGroups.find (signature);
|
VarSignMap::iterator it = varGroups.find (signature);
|
||||||
if (it == varGroups.end()) {
|
if (it == varGroups.end()) {
|
||||||
it = varGroups.insert (make_pair (signature, VarNodes())).first;
|
it = varGroups.insert (make_pair (signature, VarNodes())).first;
|
||||||
@ -167,6 +136,26 @@ CFactorGraph::createGroups (void)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
unsigned prevFactorGroupsSize = facGroups.size();
|
||||||
|
facGroups.clear();
|
||||||
|
// set a new color to the factors with the same signature
|
||||||
|
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||||
|
const FacSignature& signature = getSignature (facNodes[i]);
|
||||||
|
FacSignMap::iterator it = facGroups.find (signature);
|
||||||
|
if (it == facGroups.end()) {
|
||||||
|
it = facGroups.insert (make_pair (signature, FacNodes())).first;
|
||||||
|
}
|
||||||
|
it->second.push_back (facNodes[i]);
|
||||||
|
}
|
||||||
|
for (FacSignMap::iterator it = facGroups.begin();
|
||||||
|
it != facGroups.end(); it++) {
|
||||||
|
Color newColor = getFreeColor();
|
||||||
|
FacNodes& groupMembers = it->second;
|
||||||
|
for (unsigned i = 0; i < groupMembers.size(); i++) {
|
||||||
|
setColor (groupMembers[i], newColor);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
groupsHaveChanged = prevVarGroupsSize != varGroups.size()
|
groupsHaveChanged = prevVarGroupsSize != varGroups.size()
|
||||||
|| prevFactorGroupsSize != facGroups.size();
|
|| prevFactorGroupsSize != facGroups.size();
|
||||||
}
|
}
|
||||||
@ -209,67 +198,61 @@ CFactorGraph::createClusters (
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
const Signature&
|
VarSignature
|
||||||
CFactorGraph::getSignature (const VarNode* varNode)
|
CFactorGraph::getSignature (const VarNode* varNode)
|
||||||
{
|
{
|
||||||
Signature& sign = varSignatures_[varNode->getIndex()];
|
|
||||||
Colors::iterator it = sign.colors.begin();
|
|
||||||
const FacNodes& neighs = varNode->neighbors();
|
const FacNodes& neighs = varNode->neighbors();
|
||||||
|
VarSignature sign;
|
||||||
|
sign.reserve (neighs.size() + 1);
|
||||||
for (unsigned i = 0; i < neighs.size(); i++) {
|
for (unsigned i = 0; i < neighs.size(); i++) {
|
||||||
*it = getColor (neighs[i]);
|
sign.push_back (make_pair (
|
||||||
it ++;
|
getColor (neighs[i]),
|
||||||
*it = neighs[i]->factor().indexOf (varNode->varId());
|
neighs[i]->factor().indexOf (varNode->varId())));
|
||||||
it ++;
|
|
||||||
}
|
}
|
||||||
*it = getColor (varNode);
|
std::sort (sign.begin(), sign.end());
|
||||||
|
sign.push_back (make_pair (getColor (varNode), 0));
|
||||||
return sign;
|
return sign;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
const Signature&
|
FacSignature
|
||||||
CFactorGraph::getSignature (const FacNode* facNode)
|
CFactorGraph::getSignature (const FacNode* facNode)
|
||||||
{
|
{
|
||||||
Signature& sign = facSignatures_[facNode->getIndex()];
|
|
||||||
Colors::iterator it = sign.colors.begin();
|
|
||||||
const VarNodes& neighs = facNode->neighbors();
|
const VarNodes& neighs = facNode->neighbors();
|
||||||
|
FacSignature sign;
|
||||||
|
sign.reserve (neighs.size() + 1);
|
||||||
for (unsigned i = 0; i < neighs.size(); i++) {
|
for (unsigned i = 0; i < neighs.size(); i++) {
|
||||||
*it = getColor (neighs[i]);
|
sign.push_back (getColor (neighs[i]));
|
||||||
it ++;
|
|
||||||
}
|
}
|
||||||
std::sort (sign.colors.begin(), -- sign.colors.end());
|
sign.push_back (getColor (facNode));
|
||||||
*it = getColor (facNode);
|
|
||||||
return sign;
|
return sign;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
FactorGraph*
|
FactorGraph*
|
||||||
CFactorGraph::getGroundFactorGraph (void) const
|
CFactorGraph::getGroundFactorGraph (void)
|
||||||
{
|
{
|
||||||
FactorGraph* fg = new FactorGraph();
|
FactorGraph* fg = new FactorGraph();
|
||||||
for (unsigned i = 0; i < varClusters_.size(); i++) {
|
for (unsigned i = 0; i < varClusters_.size(); i++) {
|
||||||
VarNode* newVar = new VarNode (varClusters_[i]->members()[0]);
|
VarNode* newVar = new VarNode (varClusters_[i]->first());
|
||||||
varClusters_[i]->setRepresentative (newVar);
|
varClusters_[i]->setRepresentative (newVar);
|
||||||
fg->addVarNode (newVar);
|
fg->addVarNode (newVar);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (unsigned i = 0; i < facClusters_.size(); i++) {
|
for (unsigned i = 0; i < facClusters_.size(); i++) {
|
||||||
const VarClusters& myVarClusters = facClusters_[i]->varClusters();
|
Vars vars;
|
||||||
Vars myGroundVars;
|
const VarClusters& clusters = facClusters_[i]->varClusters();
|
||||||
myGroundVars.reserve (myVarClusters.size());
|
for (unsigned j = 0; j < clusters.size(); j++) {
|
||||||
for (unsigned j = 0; j < myVarClusters.size(); j++) {
|
vars.push_back (clusters[j]->representative());
|
||||||
VarNode* v = myVarClusters[j]->getRepresentative();
|
|
||||||
myGroundVars.push_back (v);
|
|
||||||
}
|
}
|
||||||
|
const Factor& groundFac = facClusters_[i]->first()->factor();
|
||||||
FacNode* fn = new FacNode (Factor (
|
FacNode* fn = new FacNode (Factor (
|
||||||
myGroundVars,
|
vars, groundFac.params(), groundFac.distId()));
|
||||||
facClusters_[i]->members()[0]->factor().params(),
|
|
||||||
facClusters_[i]->members()[0]->factor().distId()));
|
|
||||||
facClusters_[i]->setRepresentative (fn);
|
facClusters_[i]->setRepresentative (fn);
|
||||||
fg->addFacNode (fn);
|
fg->addFacNode (fn);
|
||||||
for (unsigned j = 0; j < myGroundVars.size(); j++) {
|
for (unsigned j = 0; j < vars.size(); j++) {
|
||||||
fg->addEdge (static_cast<VarNode*> (myGroundVars[j]), fn);
|
fg->addEdge (static_cast<VarNode*> (vars[j]), fn);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return fg;
|
return fg;
|
||||||
@ -280,29 +263,21 @@ CFactorGraph::getGroundFactorGraph (void) const
|
|||||||
unsigned
|
unsigned
|
||||||
CFactorGraph::getEdgeCount (
|
CFactorGraph::getEdgeCount (
|
||||||
const FacCluster* fc,
|
const FacCluster* fc,
|
||||||
const VarCluster* vc) const
|
const VarCluster* vc,
|
||||||
|
unsigned index) const
|
||||||
{
|
{
|
||||||
unsigned count = 0;
|
unsigned count = 0;
|
||||||
VarId vid = vc->members().front()->varId();
|
VarId reprVid = vc->representative()->varId();
|
||||||
const FacNodes& members = fc->members();
|
VarNode* groundVar = groundFg_->getVarNode (reprVid);
|
||||||
for (unsigned i = 0; i < members.size(); i++) {
|
const FacNodes& neighs = groundVar->neighbors();
|
||||||
if (members[i]->factor().contains (vid)) {
|
for (unsigned i = 0; i < neighs.size(); i++) {
|
||||||
|
FacNodes::const_iterator it;
|
||||||
|
it = std::find (fc->members().begin(), fc->members().end(), neighs[i]);
|
||||||
|
if (it != fc->members().end() &&
|
||||||
|
(*it)->factor().indexOf (reprVid) == (int)index) {
|
||||||
count ++;
|
count ++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (Constants::DEBUG > 0) {
|
|
||||||
const VarNodes& vars = vc->members();
|
|
||||||
for (unsigned i = 1; i < vars.size(); i++) {
|
|
||||||
VarId vid = vars[i]->varId();
|
|
||||||
unsigned count2 = 0;
|
|
||||||
for (unsigned i = 0; i < members.size(); i++) {
|
|
||||||
if (members[i]->factor().contains (vid)) {
|
|
||||||
count2 ++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
assert (count == count2);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return count;
|
return count;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -327,7 +302,6 @@ CFactorGraph::printGroups (
|
|||||||
cout << endl;
|
cout << endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
count = 1;
|
count = 1;
|
||||||
cout << endl << "factor groups:" << endl;
|
cout << endl << "factor groups:" << endl;
|
||||||
for (FacSignMap::const_iterator it = facGroups.begin();
|
for (FacSignMap::const_iterator it = facGroups.begin();
|
||||||
|
@ -10,17 +10,19 @@
|
|||||||
|
|
||||||
class VarCluster;
|
class VarCluster;
|
||||||
class FacCluster;
|
class FacCluster;
|
||||||
class Signature;
|
class VarSignatureHash;
|
||||||
class SignatureHash;
|
class FacSignatureHash;
|
||||||
|
|
||||||
typedef long Color;
|
typedef long Color;
|
||||||
typedef vector<Color> Colors;
|
typedef vector<Color> Colors;
|
||||||
|
typedef vector<std::pair<Color,unsigned>> VarSignature;
|
||||||
|
typedef vector<Color> FacSignature;
|
||||||
|
|
||||||
typedef unordered_map<unsigned, Color> DistColorMap;
|
typedef unordered_map<unsigned, Color> DistColorMap;
|
||||||
typedef unordered_map<unsigned, Colors> VarColorMap;
|
typedef unordered_map<unsigned, Colors> VarColorMap;
|
||||||
|
|
||||||
typedef unordered_map<Signature, VarNodes, SignatureHash> VarSignMap;
|
typedef unordered_map<VarSignature, VarNodes, VarSignatureHash> VarSignMap;
|
||||||
typedef unordered_map<Signature, FacNodes, SignatureHash> FacSignMap;
|
typedef unordered_map<FacSignature, FacNodes, FacSignatureHash> FacSignMap;
|
||||||
|
|
||||||
typedef vector<VarCluster*> VarClusters;
|
typedef vector<VarCluster*> VarClusters;
|
||||||
typedef vector<FacCluster*> FacClusters;
|
typedef vector<FacCluster*> FacClusters;
|
||||||
@ -28,53 +30,27 @@ typedef vector<FacCluster*> FacClusters;
|
|||||||
typedef unordered_map<VarId, VarCluster*> VarId2VarCluster;
|
typedef unordered_map<VarId, VarCluster*> VarId2VarCluster;
|
||||||
|
|
||||||
|
|
||||||
struct Signature
|
struct VarSignatureHash
|
||||||
{
|
{
|
||||||
Signature (unsigned size) : colors(size) { }
|
size_t operator() (const VarSignature &sig) const
|
||||||
|
|
||||||
bool operator< (const Signature& sig) const
|
|
||||||
{
|
{
|
||||||
if (colors.size() < sig.colors.size()) {
|
size_t val = hash<size_t>()(sig.size());
|
||||||
return true;
|
for (unsigned i = 0; i < sig.size(); i++) {
|
||||||
} else if (colors.size() > sig.colors.size()) {
|
val ^= hash<size_t>()(sig[i].first);
|
||||||
return false;
|
val ^= hash<size_t>()(sig[i].second);
|
||||||
} else {
|
|
||||||
for (unsigned i = 0; i < colors.size(); i++) {
|
|
||||||
if (colors[i] < sig.colors[i]) {
|
|
||||||
return true;
|
|
||||||
} else if (colors[i] > sig.colors[i]) {
|
|
||||||
return false;
|
|
||||||
}
|
}
|
||||||
|
return val;
|
||||||
}
|
}
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool operator== (const Signature& sig) const
|
|
||||||
{
|
|
||||||
if (colors.size() != sig.colors.size()) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
for (unsigned i = 0; i < colors.size(); i++) {
|
|
||||||
if (colors[i] != sig.colors[i]) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
Colors colors;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
struct FacSignatureHash
|
||||||
struct SignatureHash
|
|
||||||
{
|
{
|
||||||
size_t operator() (const Signature &sig) const
|
size_t operator() (const FacSignature &sig) const
|
||||||
{
|
{
|
||||||
size_t val = hash<size_t>()(sig.colors.size());
|
size_t val = hash<size_t>()(sig.size());
|
||||||
for (unsigned i = 0; i < sig.colors.size(); i++) {
|
for (unsigned i = 0; i < sig.size(); i++) {
|
||||||
val ^= hash<size_t>()(sig.colors[i]);
|
val ^= hash<size_t>()(sig[i]);
|
||||||
}
|
}
|
||||||
return val;
|
return val;
|
||||||
}
|
}
|
||||||
@ -87,19 +63,16 @@ class VarCluster
|
|||||||
public:
|
public:
|
||||||
VarCluster (const VarNodes& vs) : members_(vs) { }
|
VarCluster (const VarNodes& vs) : members_(vs) { }
|
||||||
|
|
||||||
|
const VarNode* first (void) const { return members_.front(); }
|
||||||
|
|
||||||
const VarNodes& members (void) const { return members_; }
|
const VarNodes& members (void) const { return members_; }
|
||||||
|
|
||||||
const FacClusters& facClusters (void) const { return facClusters_; }
|
VarNode* representative (void) const { return repr_; }
|
||||||
|
|
||||||
void addFacCluster (FacCluster* fc) { facClusters_.push_back (fc); }
|
|
||||||
|
|
||||||
VarNode* getRepresentative (void) const { return repr_; }
|
|
||||||
|
|
||||||
void setRepresentative (VarNode* vn) { repr_ = vn; }
|
void setRepresentative (VarNode* vn) { repr_ = vn; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
VarNodes members_;
|
VarNodes members_;
|
||||||
FacClusters facClusters_;
|
|
||||||
VarNode* repr_;
|
VarNode* repr_;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -108,27 +81,18 @@ class FacCluster
|
|||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
FacCluster (const FacNodes& fcs, const VarClusters& vcs)
|
FacCluster (const FacNodes& fcs, const VarClusters& vcs)
|
||||||
: members_(fcs), varClusters_(vcs)
|
: members_(fcs), varClusters_(vcs) { }
|
||||||
{
|
|
||||||
for (unsigned i = 0; i < varClusters_.size(); i++) {
|
const FacNode* first (void) const { return members_.front(); }
|
||||||
varClusters_[i]->addFacCluster (this);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const FacNodes& members (void) const { return members_; }
|
const FacNodes& members (void) const { return members_; }
|
||||||
|
|
||||||
const VarClusters& varClusters (void) const { return varClusters_; }
|
VarClusters& varClusters (void) { return varClusters_; }
|
||||||
|
|
||||||
FacNode* getRepresentative (void) const { return repr_; }
|
FacNode* representative (void) const { return repr_; }
|
||||||
|
|
||||||
void setRepresentative (FacNode* fn) { repr_ = fn; }
|
void setRepresentative (FacNode* fn) { repr_ = fn; }
|
||||||
|
|
||||||
bool containsGround (const FacNode* fn) const
|
|
||||||
{
|
|
||||||
return std::find (members_.begin(), members_.end(), fn)
|
|
||||||
!= members_.end();
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
FacNodes members_;
|
FacNodes members_;
|
||||||
VarClusters varClusters_;
|
VarClusters varClusters_;
|
||||||
@ -147,15 +111,16 @@ class CFactorGraph
|
|||||||
|
|
||||||
const FacClusters& facClusters (void) { return facClusters_; }
|
const FacClusters& facClusters (void) { return facClusters_; }
|
||||||
|
|
||||||
VarNode* getEquivalentVariable (VarId vid)
|
VarNode* getEquivalent (VarId vid)
|
||||||
{
|
{
|
||||||
VarCluster* vc = vid2VarCluster_.find (vid)->second;
|
VarCluster* vc = vid2VarCluster_.find (vid)->second;
|
||||||
return vc->getRepresentative();
|
return vc->representative();
|
||||||
}
|
}
|
||||||
|
|
||||||
FactorGraph* getGroundFactorGraph (void) const;
|
FactorGraph* getGroundFactorGraph (void);
|
||||||
|
|
||||||
unsigned getEdgeCount (const FacCluster*, const VarCluster*) const;
|
unsigned getEdgeCount (const FacCluster*,
|
||||||
|
const VarCluster*, unsigned index) const;
|
||||||
|
|
||||||
static bool checkForIdenticalFactors;
|
static bool checkForIdenticalFactors;
|
||||||
|
|
||||||
@ -184,11 +149,6 @@ class CFactorGraph
|
|||||||
facColors_[fn->getIndex()] = c;
|
facColors_[fn->getIndex()] = c;
|
||||||
}
|
}
|
||||||
|
|
||||||
VarCluster* getVariableCluster (VarId vid) const
|
|
||||||
{
|
|
||||||
return vid2VarCluster_.find (vid)->second;
|
|
||||||
}
|
|
||||||
|
|
||||||
void findIdenticalFactors (void);
|
void findIdenticalFactors (void);
|
||||||
|
|
||||||
void setInitialColors (void);
|
void setInitialColors (void);
|
||||||
@ -197,17 +157,15 @@ class CFactorGraph
|
|||||||
|
|
||||||
void createClusters (const VarSignMap&, const FacSignMap&);
|
void createClusters (const VarSignMap&, const FacSignMap&);
|
||||||
|
|
||||||
const Signature& getSignature (const VarNode*);
|
VarSignature getSignature (const VarNode*);
|
||||||
|
|
||||||
const Signature& getSignature (const FacNode*);
|
FacSignature getSignature (const FacNode*);
|
||||||
|
|
||||||
void printGroups (const VarSignMap&, const FacSignMap&) const;
|
void printGroups (const VarSignMap&, const FacSignMap&) const;
|
||||||
|
|
||||||
Color freeColor_;
|
Color freeColor_;
|
||||||
Colors varColors_;
|
Colors varColors_;
|
||||||
Colors facColors_;
|
Colors facColors_;
|
||||||
vector<Signature> varSignatures_;
|
|
||||||
vector<Signature> facSignatures_;
|
|
||||||
VarClusters varClusters_;
|
VarClusters varClusters_;
|
||||||
FacClusters facClusters_;
|
FacClusters facClusters_;
|
||||||
VarId2VarCluster vid2VarCluster_;
|
VarId2VarCluster vid2VarCluster_;
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
#include "CbpSolver.h"
|
#include "CbpSolver.h"
|
||||||
|
|
||||||
|
vector<int> CbpSolver::counts;
|
||||||
|
|
||||||
CbpSolver::CbpSolver (const FactorGraph& fg) : BpSolver (fg)
|
CbpSolver::CbpSolver (const FactorGraph& fg) : BpSolver (fg)
|
||||||
{
|
{
|
||||||
@ -24,16 +25,14 @@ CbpSolver::CbpSolver (const FactorGraph& fg) : BpSolver (fg)
|
|||||||
Statistics::updateCompressingStatistics (nrGroundVars,
|
Statistics::updateCompressingStatistics (nrGroundVars,
|
||||||
nrGroundFacs, nrClusterVars, nrClusterFacs, nrNeighborless);
|
nrGroundFacs, nrClusterVars, nrClusterFacs, nrNeighborless);
|
||||||
}
|
}
|
||||||
// cout << "uncompressed factor graph:" << endl;
|
if (Constants::DEBUG >= 5) {
|
||||||
// cout << " " << fg.nrVarNodes() << " variables " << endl;
|
cout << "uncompressed factor graph:" << endl;
|
||||||
// cout << " " << fg.nrFacNodes() << " factors " << endl;
|
cout << " " << fg.nrVarNodes() << " variables " << endl;
|
||||||
// cout << "compressed factor graph:" << endl;
|
cout << " " << fg.nrFacNodes() << " factors " << endl;
|
||||||
// cout << " " << fg_->nrVarNodes() << " variables " << endl;
|
cout << "compressed factor graph:" << endl;
|
||||||
// cout << " " << fg_->nrFacNodes() << " factors " << endl;
|
cout << " " << fg_->nrVarNodes() << " variables " << endl;
|
||||||
// Util::printHeader ("Compressed Factor Graph");
|
cout << " " << fg_->nrFacNodes() << " factors " << endl;
|
||||||
// fg_->print();
|
}
|
||||||
// Util::printHeader ("Uncompressed Factor Graph");
|
|
||||||
// fg.print();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -81,8 +80,8 @@ CbpSolver::getPosterioriOf (VarId vid)
|
|||||||
if (runned_ == false) {
|
if (runned_ == false) {
|
||||||
runSolver();
|
runSolver();
|
||||||
}
|
}
|
||||||
assert (cfg_->getEquivalentVariable (vid));
|
assert (cfg_->getEquivalent (vid));
|
||||||
VarNode* var = cfg_->getEquivalentVariable (vid);
|
VarNode* var = cfg_->getEquivalent (vid);
|
||||||
Params probs;
|
Params probs;
|
||||||
if (var->hasEvidence()) {
|
if (var->hasEvidence()) {
|
||||||
probs.resize (var->range(), LogAware::noEvidence());
|
probs.resize (var->range(), LogAware::noEvidence());
|
||||||
@ -115,7 +114,7 @@ CbpSolver::getJointDistributionOf (const VarIds& jointVids)
|
|||||||
{
|
{
|
||||||
VarIds eqVarIds;
|
VarIds eqVarIds;
|
||||||
for (unsigned i = 0; i < jointVids.size(); i++) {
|
for (unsigned i = 0; i < jointVids.size(); i++) {
|
||||||
VarNode* vn = cfg_->getEquivalentVariable (jointVids[i]);
|
VarNode* vn = cfg_->getEquivalent (jointVids[i]);
|
||||||
eqVarIds.push_back (vn->varId());
|
eqVarIds.push_back (vn->varId());
|
||||||
}
|
}
|
||||||
return BpSolver::getJointDistributionOf (eqVarIds);
|
return BpSolver::getJointDistributionOf (eqVarIds);
|
||||||
@ -130,10 +129,15 @@ CbpSolver::createLinks (void)
|
|||||||
for (unsigned i = 0; i < fcs.size(); i++) {
|
for (unsigned i = 0; i < fcs.size(); i++) {
|
||||||
const VarClusters& vcs = fcs[i]->varClusters();
|
const VarClusters& vcs = fcs[i]->varClusters();
|
||||||
for (unsigned j = 0; j < vcs.size(); j++) {
|
for (unsigned j = 0; j < vcs.size(); j++) {
|
||||||
unsigned c = cfg_->getEdgeCount (fcs[i], vcs[j]);
|
unsigned count = cfg_->getEdgeCount (fcs[i], vcs[j], j);
|
||||||
|
if (Constants::DEBUG >= 5) {
|
||||||
|
cout << "creating edge " ;
|
||||||
|
cout << fcs[i]->representative()->getLabel() << " -> " ;
|
||||||
|
cout << vcs[j]->representative()->label();
|
||||||
|
cout << " idx=" << j << ", count=" << count << endl;
|
||||||
|
}
|
||||||
links_.push_back (new CbpSolverLink (
|
links_.push_back (new CbpSolverLink (
|
||||||
fcs[i]->getRepresentative(),
|
fcs[i]->representative(), vcs[j]->representative(), j, count));
|
||||||
vcs[j]->getRepresentative(), c));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -213,47 +217,135 @@ CbpSolver::maxResidualSchedule (void)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
Params
|
void
|
||||||
CbpSolver::getVar2FactorMsg (const SpLink* link) const
|
CbpSolver::calculateFactor2VariableMsg (SpLink* _link)
|
||||||
{
|
{
|
||||||
Params msg;
|
CbpSolverLink* link = static_cast<CbpSolverLink*> (_link);
|
||||||
|
FacNode* src = link->getFactor();
|
||||||
|
const VarNode* dst = link->getVariable();
|
||||||
|
const SpLinkSet& links = ninf(src)->getLinks();
|
||||||
|
// calculate the product of messages that were sent
|
||||||
|
// to factor `src', except from var `dst'
|
||||||
|
unsigned msgSize = 1;
|
||||||
|
for (unsigned i = 0; i < links.size(); i++) {
|
||||||
|
msgSize *= links[i]->getVariable()->range();
|
||||||
|
}
|
||||||
|
unsigned repetitions = 1;
|
||||||
|
Params msgProduct (msgSize, LogAware::multIdenty());
|
||||||
|
if (Globals::logDomain) {
|
||||||
|
for (int i = links.size() - 1; i >= 0; i--) {
|
||||||
|
const CbpSolverLink* cl = static_cast<const CbpSolverLink*> (links[i]);
|
||||||
|
if ( ! (cl->getVariable() == dst && cl->index() == link->index())) {
|
||||||
|
if (Constants::DEBUG >= 5) {
|
||||||
|
cout << " message from " << links[i]->getVariable()->label();
|
||||||
|
cout << ": " ;
|
||||||
|
}
|
||||||
|
Util::add (msgProduct, getVar2FactorMsg (links[i]), repetitions);
|
||||||
|
repetitions *= links[i]->getVariable()->range();
|
||||||
|
if (Constants::DEBUG >= 5) {
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
unsigned range = links[i]->getVariable()->range();
|
||||||
|
Util::add (msgProduct, Params (range, 0.0), repetitions);
|
||||||
|
repetitions *= range;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (int i = links.size() - 1; i >= 0; i--) {
|
||||||
|
const CbpSolverLink* cl = static_cast<const CbpSolverLink*> (links[i]);
|
||||||
|
if ( ! (cl->getVariable() == dst && cl->index() == link->index())) {
|
||||||
|
if (Constants::DEBUG >= 5) {
|
||||||
|
cout << " message from " << links[i]->getVariable()->label();
|
||||||
|
cout << ": " ;
|
||||||
|
}
|
||||||
|
Util::multiply (msgProduct, getVar2FactorMsg (links[i]), repetitions);
|
||||||
|
repetitions *= links[i]->getVariable()->range();
|
||||||
|
if (Constants::DEBUG >= 5) {
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
unsigned range = links[i]->getVariable()->range();
|
||||||
|
Util::multiply (msgProduct, Params (range, 1.0), repetitions);
|
||||||
|
repetitions *= range;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Factor result (src->factor().arguments(),
|
||||||
|
src->factor().ranges(), msgProduct);
|
||||||
|
assert (msgProduct.size() == src->factor().size());
|
||||||
|
if (Globals::logDomain) {
|
||||||
|
for (unsigned i = 0; i < result.size(); i++) {
|
||||||
|
result[i] += src->factor()[i];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (unsigned i = 0; i < result.size(); i++) {
|
||||||
|
result[i] *= src->factor()[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (Constants::DEBUG >= 5) {
|
||||||
|
cout << " message product: " << msgProduct << endl;
|
||||||
|
cout << " original factor: " << src->factor().params() << endl;
|
||||||
|
cout << " factor product: " << result.params() << endl;
|
||||||
|
}
|
||||||
|
result.sumOutAllExceptIndex (link->index());
|
||||||
|
if (Constants::DEBUG >= 5) {
|
||||||
|
cout << " marginalized: " << result.params() << endl;
|
||||||
|
}
|
||||||
|
link->getNextMessage() = result.params();
|
||||||
|
LogAware::normalize (link->getNextMessage());
|
||||||
|
if (Constants::DEBUG >= 5) {
|
||||||
|
cout << " curr msg: " << link->getMessage() << endl;
|
||||||
|
cout << " next msg: " << link->getNextMessage() << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Params
|
||||||
|
CbpSolver::getVar2FactorMsg (const SpLink* _link) const
|
||||||
|
{
|
||||||
|
const CbpSolverLink* link = static_cast<const CbpSolverLink*> (_link);
|
||||||
const VarNode* src = link->getVariable();
|
const VarNode* src = link->getVariable();
|
||||||
const FacNode* dst = link->getFactor();
|
const FacNode* dst = link->getFactor();
|
||||||
const CbpSolverLink* l = static_cast<const CbpSolverLink*> (link);
|
Params msg;
|
||||||
if (src->hasEvidence()) {
|
if (src->hasEvidence()) {
|
||||||
msg.resize (src->range(), LogAware::noEvidence());
|
msg.resize (src->range(), LogAware::noEvidence());
|
||||||
double value = link->getMessage()[src->getEvidence()];
|
double value = link->getMessage()[src->getEvidence()];
|
||||||
msg[src->getEvidence()] = LogAware::pow (value, l->nrEdges() - 1);
|
if (Constants::DEBUG >= 5) {
|
||||||
|
msg[src->getEvidence()] = value;
|
||||||
|
cout << msg << "^" << link->nrEdges() << "-1" ;
|
||||||
|
}
|
||||||
|
msg[src->getEvidence()] = LogAware::pow (value, link->nrEdges() - 1);
|
||||||
} else {
|
} else {
|
||||||
msg = link->getMessage();
|
msg = link->getMessage();
|
||||||
LogAware::pow (msg, l->nrEdges() - 1);
|
|
||||||
}
|
|
||||||
if (Constants::DEBUG >= 5) {
|
if (Constants::DEBUG >= 5) {
|
||||||
cout << " " << "init: " << msg << " " << src->hasEvidence() << endl;
|
cout << msg << "^" << link->nrEdges() << "-1" ;
|
||||||
|
}
|
||||||
|
LogAware::pow (msg, link->nrEdges() - 1);
|
||||||
}
|
}
|
||||||
const SpLinkSet& links = ninf(src)->getLinks();
|
const SpLinkSet& links = ninf(src)->getLinks();
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
for (unsigned i = 0; i < links.size(); i++) {
|
for (unsigned i = 0; i < links.size(); i++) {
|
||||||
if (links[i]->getFactor() != dst) {
|
CbpSolverLink* cl = static_cast<CbpSolverLink*> (links[i]);
|
||||||
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
|
if ( ! (cl->getFactor() == dst && cl->index() == link->index())) {
|
||||||
Util::add (msg, l->poweredMessage());
|
CbpSolverLink* cl = static_cast<CbpSolverLink*> (links[i]);
|
||||||
|
Util::add (msg, cl->poweredMessage());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for (unsigned i = 0; i < links.size(); i++) {
|
for (unsigned i = 0; i < links.size(); i++) {
|
||||||
if (links[i]->getFactor() != dst) {
|
CbpSolverLink* cl = static_cast<CbpSolverLink*> (links[i]);
|
||||||
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
|
if ( ! (cl->getFactor() == dst && cl->index() == link->index())) {
|
||||||
Util::multiply (msg, l->poweredMessage());
|
Util::multiply (msg, cl->poweredMessage());
|
||||||
if (Constants::DEBUG >= 5) {
|
if (Constants::DEBUG >= 5) {
|
||||||
cout << " msg from " << l->getFactor()->getLabel() << ": " ;
|
cout << " x " << cl->getNextMessage() << "^" << link->nrEdges();
|
||||||
cout << l->poweredMessage() << endl;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (Constants::DEBUG >= 5) {
|
if (Constants::DEBUG >= 5) {
|
||||||
cout << " result = " << msg << endl;
|
cout << " = " << msg;
|
||||||
}
|
}
|
||||||
return msg;
|
return msg;
|
||||||
}
|
}
|
||||||
@ -264,13 +356,14 @@ void
|
|||||||
CbpSolver::printLinkInformation (void) const
|
CbpSolver::printLinkInformation (void) const
|
||||||
{
|
{
|
||||||
for (unsigned i = 0; i < links_.size(); i++) {
|
for (unsigned i = 0; i < links_.size(); i++) {
|
||||||
CbpSolverLink* l = static_cast<CbpSolverLink*> (links_[i]);
|
CbpSolverLink* cl = static_cast<CbpSolverLink*> (links_[i]);
|
||||||
cout << l->toString() << ":" << endl;
|
cout << cl->toString() << ":" << endl;
|
||||||
cout << " curr msg = " << l->getMessage() << endl;
|
cout << " curr msg = " << cl->getMessage() << endl;
|
||||||
cout << " next msg = " << l->getNextMessage() << endl;
|
cout << " next msg = " << cl->getNextMessage() << endl;
|
||||||
cout << " nr edges = " << l->nrEdges() << endl;
|
cout << " index = " << cl->index() << endl;
|
||||||
cout << " powered = " << l->poweredMessage() << endl;
|
cout << " nr edges = " << cl->nrEdges() << endl;
|
||||||
cout << " residual = " << l->getResidual() << endl;
|
cout << " powered = " << cl->poweredMessage() << endl;
|
||||||
|
cout << " residual = " << cl->getResidual() << endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -9,10 +9,12 @@ class Factor;
|
|||||||
class CbpSolverLink : public SpLink
|
class CbpSolverLink : public SpLink
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
CbpSolverLink (FacNode* fn, VarNode* vn, unsigned c)
|
CbpSolverLink (FacNode* fn, VarNode* vn, unsigned idx, unsigned count)
|
||||||
: SpLink (fn, vn), nrEdges_(c),
|
: SpLink (fn, vn), index_(idx), nrEdges_(count),
|
||||||
pwdMsg_(vn->range(), LogAware::one()) { }
|
pwdMsg_(vn->range(), LogAware::one()) { }
|
||||||
|
|
||||||
|
unsigned index (void) const { return index_; }
|
||||||
|
|
||||||
unsigned nrEdges (void) const { return nrEdges_; }
|
unsigned nrEdges (void) const { return nrEdges_; }
|
||||||
|
|
||||||
const Params& poweredMessage (void) const { return pwdMsg_; }
|
const Params& poweredMessage (void) const { return pwdMsg_; }
|
||||||
@ -26,6 +28,7 @@ class CbpSolverLink : public SpLink
|
|||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
unsigned index_;
|
||||||
unsigned nrEdges_;
|
unsigned nrEdges_;
|
||||||
Params pwdMsg_;
|
Params pwdMsg_;
|
||||||
};
|
};
|
||||||
@ -35,6 +38,8 @@ class CbpSolverLink : public SpLink
|
|||||||
class CbpSolver : public BpSolver
|
class CbpSolver : public BpSolver
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
|
static vector<int> counts;
|
||||||
|
|
||||||
CbpSolver (const FactorGraph& fg);
|
CbpSolver (const FactorGraph& fg);
|
||||||
|
|
||||||
~CbpSolver (void);
|
~CbpSolver (void);
|
||||||
@ -51,6 +56,8 @@ class CbpSolver : public BpSolver
|
|||||||
|
|
||||||
void maxResidualSchedule (void);
|
void maxResidualSchedule (void);
|
||||||
|
|
||||||
|
void calculateFactor2VariableMsg (SpLink*);
|
||||||
|
|
||||||
Params getVar2FactorMsg (const SpLink*) const;
|
Params getVar2FactorMsg (const SpLink*) const;
|
||||||
|
|
||||||
void printLinkInformation (void) const;
|
void printLinkInformation (void) const;
|
||||||
|
@ -105,7 +105,7 @@ Factor::sumOutIndex (unsigned idx)
|
|||||||
// on the left of `var', with the states of the remaining vars fixed
|
// on the left of `var', with the states of the remaining vars fixed
|
||||||
unsigned leftVarOffset = 1;
|
unsigned leftVarOffset = 1;
|
||||||
|
|
||||||
for (int i = args_.size() - 1; i > idx; i--) {
|
for (int i = args_.size() - 1; i > (int)idx; i--) {
|
||||||
varOffset *= ranges_[i];
|
varOffset *= ranges_[i];
|
||||||
leftVarOffset *= ranges_[i];
|
leftVarOffset *= ranges_[i];
|
||||||
}
|
}
|
||||||
@ -151,7 +151,7 @@ Factor::sumOutIndex (unsigned idx)
|
|||||||
void
|
void
|
||||||
Factor::sumOutAllExceptIndex (unsigned idx)
|
Factor::sumOutAllExceptIndex (unsigned idx)
|
||||||
{
|
{
|
||||||
int i = idx;
|
int i = (int)idx;
|
||||||
while (args_.size() > i + 1) {
|
while (args_.size() > i + 1) {
|
||||||
sumOutLastVariable();
|
sumOutLastVariable();
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user