add support to (real) lifted belief propagation

This commit is contained in:
Tiago Gomes 2012-05-31 21:12:46 +01:00
parent 22780c4559
commit 22d8876d33
10 changed files with 646 additions and 830 deletions

View File

@ -82,7 +82,7 @@ BpSolver::getPosterioriOf (VarId vid)
probs[var->getEvidence()] = LogAware::withEvidence();
} else {
probs.resize (var->range(), LogAware::multIdenty());
const SpLinkSet& links = ninf(var)->getLinks();
const BpLinks& links = ninf(var)->getLinks();
if (Globals::logDomain) {
for (size_t i = 0; i < links.size(); i++) {
probs += links[i]->message();
@ -120,7 +120,7 @@ BpSolver::getJointDistributionOf (const VarIds& jointVarIds)
return getJointByConditioning (jointVarIds);
} else {
Factor res (facNodes[idx]->factor());
const SpLinkSet& links = ninf(facNodes[idx])->getLinks();
const BpLinks& links = ninf(facNodes[idx])->getLinks();
for (size_t i = 0; i < links.size(); i++) {
Factor msg ({links[i]->varNode()->varId()},
{links[i]->varNode()->range()},
@ -194,7 +194,7 @@ BpSolver::createLinks (void)
for (size_t i = 0; i < facNodes.size(); i++) {
const VarNodes& neighbors = facNodes[i]->neighbors();
for (size_t j = 0; j < neighbors.size(); j++) {
links_.push_back (new SpLink (facNodes[i], neighbors[j]));
links_.push_back (new BpLink (facNodes[i], neighbors[j]));
}
}
}
@ -224,7 +224,7 @@ BpSolver::maxResidualSchedule (void)
}
SortedOrder::iterator it = sortedOrder_.begin();
SpLink* link = *it;
BpLink* link = *it;
if (link->residual() < BpOptions::accuracy) {
return;
}
@ -237,11 +237,11 @@ BpSolver::maxResidualSchedule (void)
const FacNodes& factorNeighbors = link->varNode()->neighbors();
for (size_t i = 0; i < factorNeighbors.size(); i++) {
if (factorNeighbors[i] != link->facNode()) {
const SpLinkSet& links = ninf(factorNeighbors[i])->getLinks();
const BpLinks& links = ninf(factorNeighbors[i])->getLinks();
for (size_t j = 0; j < links.size(); j++) {
if (links[j]->varNode() != link->varNode()) {
calculateMessage (links[j]);
SpLinkMap::iterator iter = linkMap_.find (links[j]);
BpLinkMap::iterator iter = linkMap_.find (links[j]);
sortedOrder_.erase (iter->second);
iter->second = sortedOrder_.insert (links[j]);
}
@ -257,11 +257,11 @@ BpSolver::maxResidualSchedule (void)
void
BpSolver::calcFactorToVarMsg (SpLink* link)
BpSolver::calcFactorToVarMsg (BpLink* link)
{
FacNode* src = link->facNode();
const VarNode* dst = link->varNode();
const SpLinkSet& links = ninf(src)->getLinks();
const BpLinks& links = ninf(src)->getLinks();
// calculate the product of messages that were sent
// to factor `src', except from var `dst'
unsigned reps = 1;
@ -321,7 +321,7 @@ BpSolver::calcFactorToVarMsg (SpLink* link)
Params
BpSolver::getVarToFactorMsg (const SpLink* link) const
BpSolver::getVarToFactorMsg (const BpLink* link) const
{
const VarNode* src = link->varNode();
Params msg;
@ -334,8 +334,8 @@ BpSolver::getVarToFactorMsg (const SpLink* link) const
if (Constants::SHOW_BP_CALCS) {
cout << msg;
}
SpLinkSet::const_iterator it;
const SpLinkSet& links = ninf (src)->getLinks();
BpLinks::const_iterator it;
const BpLinks& links = ninf (src)->getLinks();
if (Globals::logDomain) {
for (it = links.begin(); it != links.end(); ++it) {
msg += (*it)->message();
@ -432,8 +432,8 @@ BpSolver::initializeSolver (void)
for (size_t i = 0; i < links_.size(); i++) {
FacNode* src = links_[i]->facNode();
VarNode* dst = links_[i]->varNode();
ninf (dst)->addSpLink (links_[i]);
ninf (src)->addSpLink (links_[i]);
ninf (dst)->addBpLink (links_[i]);
ninf (src)->addBpLink (links_[i]);
}
}
@ -491,7 +491,7 @@ void
BpSolver::printLinkInformation (void) const
{
for (size_t i = 0; i < links_.size(); i++) {
SpLink* l = links_[i];
BpLink* l = links_[i];
cout << l->toString() << ":" << endl;
cout << " curr msg = " ;
cout << l->message() << endl;

View File

@ -13,10 +13,10 @@
using namespace std;
class SpLink
class BpLink
{
public:
SpLink (FacNode* fn, VarNode* vn)
BpLink (FacNode* fn, VarNode* vn)
{
fac_ = fn;
var_ = vn;
@ -27,7 +27,7 @@ class SpLink
residual_ = 0.0;
}
virtual ~SpLink (void) { };
virtual ~BpLink (void) { };
FacNode* facNode (void) const { return fac_; }
@ -70,16 +70,16 @@ class SpLink
double residual_;
};
typedef vector<SpLink*> SpLinkSet;
typedef vector<BpLink*> BpLinks;
class SPNodeInfo
{
public:
void addSpLink (SpLink* link) { links_.push_back (link); }
const SpLinkSet& getLinks (void) { return links_; }
void addBpLink (BpLink* link) { links_.push_back (link); }
const BpLinks& getLinks (void) { return links_; }
private:
SpLinkSet links_;
BpLinks links_;
};
@ -105,9 +105,9 @@ class BpSolver : public Solver
virtual void maxResidualSchedule (void);
virtual void calcFactorToVarMsg (SpLink*);
virtual void calcFactorToVarMsg (BpLink*);
virtual Params getVarToFactorMsg (const SpLink*) const;
virtual Params getVarToFactorMsg (const BpLink*) const;
virtual Params getJointByConditioning (const VarIds&) const;
@ -121,7 +121,7 @@ class BpSolver : public Solver
return facsI_[fac->getIndex()];
}
void calculateAndUpdateMessage (SpLink* link, bool calcResidual = true)
void calculateAndUpdateMessage (BpLink* link, bool calcResidual = true)
{
if (Globals::verbosity > 2) {
cout << "calculating & updating " << link->toString() << endl;
@ -133,7 +133,7 @@ class BpSolver : public Solver
link->updateMessage();
}
void calculateMessage (SpLink* link, bool calcResidual = true)
void calculateMessage (BpLink* link, bool calcResidual = true)
{
if (Globals::verbosity > 2) {
cout << "calculating " << link->toString() << endl;
@ -144,7 +144,7 @@ class BpSolver : public Solver
}
}
void updateMessage (SpLink* link)
void updateMessage (BpLink* link)
{
link->updateMessage();
if (Globals::verbosity > 2) {
@ -154,24 +154,24 @@ class BpSolver : public Solver
struct CompareResidual
{
inline bool operator() (const SpLink* link1, const SpLink* link2)
inline bool operator() (const BpLink* link1, const BpLink* link2)
{
return link1->residual() > link2->residual();
}
};
SpLinkSet links_;
BpLinks links_;
unsigned nIters_;
vector<SPNodeInfo*> varsI_;
vector<SPNodeInfo*> facsI_;
bool runned_;
const FactorGraph* fg_;
typedef multiset<SpLink*, CompareResidual> SortedOrder;
typedef multiset<BpLink*, CompareResidual> SortedOrder;
SortedOrder sortedOrder_;
typedef unordered_map<SpLink*, SortedOrder::iterator> SpLinkMap;
SpLinkMap linkMap_;
typedef unordered_map<BpLink*, SortedOrder::iterator> BpLinkMap;
BpLinkMap linkMap_;
private:
void initializeSolver (void);

View File

@ -1,318 +0,0 @@
#include "CFactorGraph.h"
#include "Factor.h"
bool CFactorGraph::checkForIdenticalFactors = true;
CFactorGraph::CFactorGraph (const FactorGraph& fg)
: freeColor_(0), groundFg_(&fg)
{
findIdenticalFactors();
setInitialColors();
createGroups();
}
CFactorGraph::~CFactorGraph (void)
{
for (size_t i = 0; i < varClusters_.size(); i++) {
delete varClusters_[i];
}
for (size_t i = 0; i < facClusters_.size(); i++) {
delete facClusters_[i];
}
}
void
CFactorGraph::findIdenticalFactors()
{
const FacNodes& facNodes = groundFg_->facNodes();
if (checkForIdenticalFactors == false ||
facNodes.size() == 1) {
return;
}
for (size_t i = 0; i < facNodes.size(); i++) {
facNodes[i]->factor().setDistId (Util::maxUnsigned());
}
unsigned groupCount = 1;
for (size_t i = 0; i < facNodes.size() - 1; i++) {
Factor& f1 = facNodes[i]->factor();
if (f1.distId() != Util::maxUnsigned()) {
continue;
}
f1.setDistId (groupCount);
for (size_t j = i + 1; j < facNodes.size(); j++) {
Factor& f2 = facNodes[j]->factor();
if (f2.distId() != Util::maxUnsigned()) {
continue;
}
if (f1.size() == f2.size() &&
f1.ranges() == f2.ranges() &&
f1.params() == f2.params()) {
f2.setDistId (groupCount);
}
}
groupCount ++;
}
}
void
CFactorGraph::setInitialColors (void)
{
varColors_.resize (groundFg_->nrVarNodes());
facColors_.resize (groundFg_->nrFacNodes());
// create the initial variable colors
VarColorMap colorMap;
const VarNodes& varNodes = groundFg_->varNodes();
for (size_t i = 0; i < varNodes.size(); i++) {
unsigned range = varNodes[i]->range();
VarColorMap::iterator it = colorMap.find (range);
if (it == colorMap.end()) {
it = colorMap.insert (make_pair (
range, Colors (range + 1, -1))).first;
}
unsigned idx = varNodes[i]->hasEvidence()
? varNodes[i]->getEvidence()
: range;
Colors& stateColors = it->second;
if (stateColors[idx] == -1) {
stateColors[idx] = getFreeColor();
}
setColor (varNodes[i], stateColors[idx]);
}
const FacNodes& facNodes = groundFg_->facNodes();
// create the initial factor colors
DistColorMap distColors;
for (size_t i = 0; i < facNodes.size(); i++) {
unsigned distId = facNodes[i]->factor().distId();
DistColorMap::iterator it = distColors.find (distId);
if (it == distColors.end()) {
it = distColors.insert (make_pair (distId, getFreeColor())).first;
}
setColor (facNodes[i], it->second);
}
}
void
CFactorGraph::createGroups (void)
{
VarSignMap varGroups;
FacSignMap facGroups;
unsigned nIters = 0;
bool groupsHaveChanged = true;
const VarNodes& varNodes = groundFg_->varNodes();
const FacNodes& facNodes = groundFg_->facNodes();
while (groupsHaveChanged || nIters == 1) {
nIters ++;
// set a new color to the variables with the same signature
size_t prevVarGroupsSize = varGroups.size();
varGroups.clear();
for (size_t i = 0; i < varNodes.size(); i++) {
const VarSignature& signature = getSignature (varNodes[i]);
VarSignMap::iterator it = varGroups.find (signature);
if (it == varGroups.end()) {
it = varGroups.insert (make_pair (signature, VarNodes())).first;
}
it->second.push_back (varNodes[i]);
}
for (VarSignMap::iterator it = varGroups.begin();
it != varGroups.end(); ++it) {
Color newColor = getFreeColor();
VarNodes& groupMembers = it->second;
for (size_t i = 0; i < groupMembers.size(); i++) {
setColor (groupMembers[i], newColor);
}
}
size_t prevFactorGroupsSize = facGroups.size();
facGroups.clear();
// set a new color to the factors with the same signature
for (size_t i = 0; i < facNodes.size(); i++) {
const FacSignature& signature = getSignature (facNodes[i]);
FacSignMap::iterator it = facGroups.find (signature);
if (it == facGroups.end()) {
it = facGroups.insert (make_pair (signature, FacNodes())).first;
}
it->second.push_back (facNodes[i]);
}
for (FacSignMap::iterator it = facGroups.begin();
it != facGroups.end(); ++it) {
Color newColor = getFreeColor();
FacNodes& groupMembers = it->second;
for (size_t i = 0; i < groupMembers.size(); i++) {
setColor (groupMembers[i], newColor);
}
}
groupsHaveChanged = prevVarGroupsSize != varGroups.size()
|| prevFactorGroupsSize != facGroups.size();
}
// printGroups (varGroups, facGroups);
createClusters (varGroups, facGroups);
}
void
CFactorGraph::createClusters (
const VarSignMap& varGroups,
const FacSignMap& facGroups)
{
varClusters_.reserve (varGroups.size());
for (VarSignMap::const_iterator it = varGroups.begin();
it != varGroups.end(); ++it) {
const VarNodes& groupVars = it->second;
VarCluster* vc = new VarCluster (groupVars);
for (size_t i = 0; i < groupVars.size(); i++) {
vid2VarCluster_.insert (make_pair (groupVars[i]->varId(), vc));
}
varClusters_.push_back (vc);
}
facClusters_.reserve (facGroups.size());
for (FacSignMap::const_iterator it = facGroups.begin();
it != facGroups.end(); ++it) {
FacNode* groupFactor = it->second[0];
const VarNodes& neighs = groupFactor->neighbors();
VarClusters varClusters;
varClusters.reserve (neighs.size());
for (size_t i = 0; i < neighs.size(); i++) {
VarId vid = neighs[i]->varId();
varClusters.push_back (vid2VarCluster_.find (vid)->second);
}
facClusters_.push_back (new FacCluster (it->second, varClusters));
}
}
VarSignature
CFactorGraph::getSignature (const VarNode* varNode)
{
const FacNodes& neighs = varNode->neighbors();
VarSignature sign;
sign.reserve (neighs.size() + 1);
for (size_t i = 0; i < neighs.size(); i++) {
sign.push_back (make_pair (
getColor (neighs[i]),
neighs[i]->factor().indexOf (varNode->varId())));
}
std::sort (sign.begin(), sign.end());
sign.push_back (make_pair (getColor (varNode), 0));
return sign;
}
FacSignature
CFactorGraph::getSignature (const FacNode* facNode)
{
const VarNodes& neighs = facNode->neighbors();
FacSignature sign;
sign.reserve (neighs.size() + 1);
for (size_t i = 0; i < neighs.size(); i++) {
sign.push_back (getColor (neighs[i]));
}
sign.push_back (getColor (facNode));
return sign;
}
FactorGraph*
CFactorGraph::getGroundFactorGraph (void)
{
FactorGraph* fg = new FactorGraph();
for (size_t i = 0; i < varClusters_.size(); i++) {
VarNode* newVar = new VarNode (varClusters_[i]->first());
varClusters_[i]->setRepresentative (newVar);
fg->addVarNode (newVar);
}
for (size_t i = 0; i < facClusters_.size(); i++) {
Vars vars;
const VarClusters& clusters = facClusters_[i]->varClusters();
for (size_t j = 0; j < clusters.size(); j++) {
vars.push_back (clusters[j]->representative());
}
const Factor& groundFac = facClusters_[i]->first()->factor();
FacNode* fn = new FacNode (Factor (
vars, groundFac.params(), groundFac.distId()));
facClusters_[i]->setRepresentative (fn);
fg->addFacNode (fn);
for (size_t j = 0; j < vars.size(); j++) {
fg->addEdge (static_cast<VarNode*> (vars[j]), fn);
}
}
return fg;
}
unsigned
CFactorGraph::getEdgeCount (
const FacCluster* fc,
const VarCluster* vc,
size_t index) const
{
unsigned count = 0;
VarId reprVid = vc->representative()->varId();
VarNode* groundVar = groundFg_->getVarNode (reprVid);
const FacNodes& neighs = groundVar->neighbors();
for (size_t i = 0; i < neighs.size(); i++) {
FacNodes::const_iterator it;
it = std::find (fc->members().begin(), fc->members().end(), neighs[i]);
if (it != fc->members().end() &&
(*it)->factor().indexOf (reprVid) == index) {
count ++;
}
}
return count;
}
void
CFactorGraph::printGroups (
const VarSignMap& varGroups,
const FacSignMap& facGroups) const
{
unsigned count = 1;
cout << "variable groups:" << endl;
for (VarSignMap::const_iterator it = varGroups.begin();
it != varGroups.end(); ++it) {
const VarNodes& groupMembers = it->second;
if (groupMembers.size() > 0) {
cout << count << ": " ;
for (size_t i = 0; i < groupMembers.size(); i++) {
cout << groupMembers[i]->label() << " " ;
}
count ++;
cout << endl;
}
}
count = 1;
cout << endl << "factor groups:" << endl;
for (FacSignMap::const_iterator it = facGroups.begin();
it != facGroups.end(); ++it) {
const FacNodes& groupMembers = it->second;
if (groupMembers.size() > 0) {
cout << ++count << ": " ;
for (size_t i = 0; i < groupMembers.size(); i++) {
cout << groupMembers[i]->getLabel() << " " ;
}
count ++;
cout << endl;
}
}
}

View File

@ -1,176 +0,0 @@
#ifndef HORUS_CFACTORGRAPH_H
#define HORUS_CFACTORGRAPH_H
#include <unordered_map>
#include "FactorGraph.h"
#include "Factor.h"
#include "Util.h"
#include "Horus.h"
class VarCluster;
class FacCluster;
class VarSignatureHash;
class FacSignatureHash;
typedef long Color;
typedef vector<Color> Colors;
typedef vector<std::pair<Color,unsigned>> VarSignature;
typedef vector<Color> FacSignature;
typedef unordered_map<unsigned, Color> DistColorMap;
typedef unordered_map<unsigned, Colors> VarColorMap;
typedef unordered_map<VarSignature, VarNodes, VarSignatureHash> VarSignMap;
typedef unordered_map<FacSignature, FacNodes, FacSignatureHash> FacSignMap;
typedef vector<VarCluster*> VarClusters;
typedef vector<FacCluster*> FacClusters;
typedef unordered_map<VarId, VarCluster*> VarId2VarCluster;
struct VarSignatureHash
{
size_t operator() (const VarSignature &sig) const
{
size_t val = hash<size_t>()(sig.size());
for (size_t i = 0; i < sig.size(); i++) {
val ^= hash<size_t>()(sig[i].first);
val ^= hash<size_t>()(sig[i].second);
}
return val;
}
};
struct FacSignatureHash
{
size_t operator() (const FacSignature &sig) const
{
size_t val = hash<size_t>()(sig.size());
for (size_t i = 0; i < sig.size(); i++) {
val ^= hash<size_t>()(sig[i]);
}
return val;
}
};
class VarCluster
{
public:
VarCluster (const VarNodes& vs) : members_(vs) { }
const VarNode* first (void) const { return members_.front(); }
const VarNodes& members (void) const { return members_; }
VarNode* representative (void) const { return repr_; }
void setRepresentative (VarNode* vn) { repr_ = vn; }
private:
VarNodes members_;
VarNode* repr_;
};
class FacCluster
{
public:
FacCluster (const FacNodes& fcs, const VarClusters& vcs)
: members_(fcs), varClusters_(vcs) { }
const FacNode* first (void) const { return members_.front(); }
const FacNodes& members (void) const { return members_; }
VarClusters& varClusters (void) { return varClusters_; }
FacNode* representative (void) const { return repr_; }
void setRepresentative (FacNode* fn) { repr_ = fn; }
private:
FacNodes members_;
VarClusters varClusters_;
FacNode* repr_;
};
class CFactorGraph
{
public:
CFactorGraph (const FactorGraph&);
~CFactorGraph (void);
const VarClusters& varClusters (void) { return varClusters_; }
const FacClusters& facClusters (void) { return facClusters_; }
VarNode* getEquivalent (VarId vid)
{
VarCluster* vc = vid2VarCluster_.find (vid)->second;
return vc->representative();
}
FactorGraph* getGroundFactorGraph (void);
unsigned getEdgeCount (const FacCluster*,
const VarCluster*, size_t index) const;
static bool checkForIdenticalFactors;
private:
Color getFreeColor (void)
{
++ freeColor_;
return freeColor_ - 1;
}
Color getColor (const VarNode* vn) const
{
return varColors_[vn->getIndex()];
}
Color getColor (const FacNode* fn) const {
return facColors_[fn->getIndex()];
}
void setColor (const VarNode* vn, Color c)
{
varColors_[vn->getIndex()] = c;
}
void setColor (const FacNode* fn, Color c)
{
facColors_[fn->getIndex()] = c;
}
void findIdenticalFactors (void);
void setInitialColors (void);
void createGroups (void);
void createClusters (const VarSignMap&, const FacSignMap&);
VarSignature getSignature (const VarNode*);
FacSignature getSignature (const FacNode*);
void printGroups (const VarSignMap&, const FacSignMap&) const;
Color freeColor_;
Colors varColors_;
Colors facColors_;
VarClusters varClusters_;
FacClusters facClusters_;
VarId2VarCluster vid2VarCluster_;
const FactorGraph* groundFg_;
};
#endif // HORUS_CFACTORGRAPH_H

View File

@ -1,22 +1,32 @@
#include "CbpSolver.h"
#include "WeightedBpSolver.h"
CbpSolver::CbpSolver (const FactorGraph& fg) : BpSolver (fg)
bool CbpSolver::checkForIdenticalFactors = true;
CbpSolver::CbpSolver (const FactorGraph& fg)
: Solver (fg), freeColor_(0)
{
cfg_ = new CFactorGraph (fg);
fg_ = cfg_->getGroundFactorGraph();
findIdenticalFactors();
setInitialColors();
createGroups();
compressedFg_ = getCompressedFactorGraph();
solver_ = new WeightedBpSolver (*compressedFg_, getWeights());
}
CbpSolver::~CbpSolver (void)
{
delete cfg_;
delete fg_;
for (size_t i = 0; i < links_.size(); i++) {
delete links_[i];
delete solver_;
delete compressedFg_;
for (size_t i = 0; i < varClusters_.size(); i++) {
delete varClusters_[i];
}
for (size_t i = 0; i < facClusters_.size(); i++) {
delete facClusters_[i];
}
links_.clear();
}
@ -38,44 +48,28 @@ CbpSolver::printSolverFlags (void) const
ss << ",accuracy=" << BpOptions::accuracy;
ss << ",log_domain=" << Util::toString (Globals::logDomain);
ss << ",chkif=" <<
Util::toString (CFactorGraph::checkForIdenticalFactors);
Util::toString (CbpSolver::checkForIdenticalFactors);
ss << "]" ;
cout << ss.str() << endl;
}
Params
CbpSolver::solveQuery (VarIds queryVids)
{
assert (queryVids.empty() == false);
return queryVids.size() == 1
? getPosterioriOf (queryVids[0])
: getJointDistributionOf (queryVids);
}
Params
CbpSolver::getPosterioriOf (VarId vid)
{
if (runned_ == false) {
runSolver();
}
assert (cfg_->getEquivalent (vid));
VarNode* var = cfg_->getEquivalent (vid);
Params probs;
if (var->hasEvidence()) {
probs.resize (var->range(), LogAware::noEvidence());
probs[var->getEvidence()] = LogAware::withEvidence();
} else {
probs.resize (var->range(), LogAware::multIdenty());
const SpLinkSet& links = ninf(var)->getLinks();
if (Globals::logDomain) {
for (size_t i = 0; i < links.size(); i++) {
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
probs += l->powMessage();
}
LogAware::normalize (probs);
Util::exp (probs);
} else {
for (size_t i = 0; i < links.size(); i++) {
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
probs *= l->powMessage();
}
LogAware::normalize (probs);
}
}
return probs;
return solver_->getPosterioriOf (getRepresentative (vid));
}
@ -83,255 +77,320 @@ CbpSolver::getPosterioriOf (VarId vid)
Params
CbpSolver::getJointDistributionOf (const VarIds& jointVids)
{
VarIds eqVarIds;
VarIds representatives;
for (size_t i = 0; i < jointVids.size(); i++) {
VarNode* vn = cfg_->getEquivalent (jointVids[i]);
eqVarIds.push_back (vn->varId());
representatives.push_back (getRepresentative (jointVids[i]));
}
return BpSolver::getJointDistributionOf (eqVarIds);
return solver_->getJointDistributionOf (representatives);
}
void
CbpSolver::createLinks (void)
CbpSolver::findIdenticalFactors()
{
if (Globals::verbosity > 0) {
cout << "compressed factor graph contains " ;
cout << fg_->nrVarNodes() << " variables and " ;
cout << fg_->nrFacNodes() << " factors " << endl;
cout << endl;
}
const FacClusters& fcs = cfg_->facClusters();
for (size_t i = 0; i < fcs.size(); i++) {
const VarClusters& vcs = fcs[i]->varClusters();
for (size_t j = 0; j < vcs.size(); j++) {
unsigned count = cfg_->getEdgeCount (fcs[i], vcs[j], j);
if (Globals::verbosity > 1) {
cout << "creating link " ;
cout << fcs[i]->representative()->getLabel();
cout << " -- " ;
cout << vcs[j]->representative()->label();
cout << " idx=" << j << ", count=" << count << endl;
}
links_.push_back (new CbpSolverLink (
fcs[i]->representative(), vcs[j]->representative(), j, count));
}
}
if (Globals::verbosity > 1) {
cout << endl;
}
}
void
CbpSolver::maxResidualSchedule (void)
{
if (nIters_ == 1) {
for (size_t i = 0; i < links_.size(); i++) {
calculateMessage (links_[i]);
SortedOrder::iterator it = sortedOrder_.insert (links_[i]);
linkMap_.insert (make_pair (links_[i], it));
if (Globals::verbosity >= 1) {
cout << "calculating " << links_[i]->toString() << endl;
}
}
const FacNodes& facNodes = fg.facNodes();
if (checkForIdenticalFactors == false ||
facNodes.size() == 1) {
return;
}
for (size_t c = 0; c < links_.size(); c++) {
if (Globals::verbosity > 1) {
cout << endl << "current residuals:" << endl;
for (SortedOrder::iterator it = sortedOrder_.begin();
it != sortedOrder_.end(); ++it) {
cout << " " << setw (30) << left << (*it)->toString();
cout << "residual = " << (*it)->residual() << endl;
}
}
SortedOrder::iterator it = sortedOrder_.begin();
SpLink* link = *it;
if (Globals::verbosity >= 1) {
cout << "updating " << (*sortedOrder_.begin())->toString() << endl;
}
if (link->residual() < BpOptions::accuracy) {
return;
}
link->updateMessage();
link->clearResidual();
sortedOrder_.erase (it);
linkMap_.find (link)->second = sortedOrder_.insert (link);
// update the messages that depend on message source --> destin
const FacNodes& factorNeighbors = link->varNode()->neighbors();
for (size_t i = 0; i < factorNeighbors.size(); i++) {
const SpLinkSet& links = ninf(factorNeighbors[i])->getLinks();
for (size_t j = 0; j < links.size(); j++) {
if (links[j]->varNode() != link->varNode()) {
if (Globals::verbosity > 1) {
cout << " calculating " << links[j]->toString() << endl;
}
calculateMessage (links[j]);
SpLinkMap::iterator iter = linkMap_.find (links[j]);
sortedOrder_.erase (iter->second);
iter->second = sortedOrder_.insert (links[j]);
}
}
}
// in counting bp, the message that a variable X sends to
// to a factor F depends on the message that F sent to the X
const SpLinkSet& links = ninf(link->facNode())->getLinks();
for (size_t i = 0; i < links.size(); i++) {
if (links[i]->varNode() != link->varNode()) {
if (Globals::verbosity > 1) {
cout << " calculating " << links[i]->toString() << endl;
}
calculateMessage (links[i]);
SpLinkMap::iterator iter = linkMap_.find (links[i]);
sortedOrder_.erase (iter->second);
iter->second = sortedOrder_.insert (links[i]);
for (size_t i = 0; i < facNodes.size(); i++) {
facNodes[i]->factor().setDistId (Util::maxUnsigned());
}
unsigned groupCount = 1;
for (size_t i = 0; i < facNodes.size() - 1; i++) {
Factor& f1 = facNodes[i]->factor();
if (f1.distId() != Util::maxUnsigned()) {
continue;
}
f1.setDistId (groupCount);
for (size_t j = i + 1; j < facNodes.size(); j++) {
Factor& f2 = facNodes[j]->factor();
if (f2.distId() != Util::maxUnsigned()) {
continue;
}
if (f1.size() == f2.size() &&
f1.ranges() == f2.ranges() &&
f1.params() == f2.params()) {
f2.setDistId (groupCount);
}
}
groupCount ++;
}
}
void
CbpSolver::calcFactorToVarMsg (SpLink* _link)
CbpSolver::setInitialColors (void)
{
CbpSolverLink* link = static_cast<CbpSolverLink*> (_link);
FacNode* src = link->facNode();
const VarNode* dst = link->varNode();
const SpLinkSet& links = ninf(src)->getLinks();
// calculate the product of messages that were sent
// to factor `src', except from var `dst'
unsigned reps = 1;
unsigned msgSize = Util::sizeExpected (src->factor().ranges());
Params msgProduct (msgSize, LogAware::multIdenty());
if (Globals::logDomain) {
for (size_t i = links.size(); i-- > 0; ) {
const CbpSolverLink* cl = static_cast<const CbpSolverLink*> (links[i]);
if ( ! (cl->varNode() == dst && cl->index() == link->index())) {
if (Constants::SHOW_BP_CALCS) {
cout << " message from " << links[i]->varNode()->label();
cout << ": " ;
varColors_.resize (fg.nrVarNodes());
facColors_.resize (fg.nrFacNodes());
// create the initial variable colors
VarColorMap colorMap;
const VarNodes& varNodes = fg.varNodes();
for (size_t i = 0; i < varNodes.size(); i++) {
unsigned range = varNodes[i]->range();
VarColorMap::iterator it = colorMap.find (range);
if (it == colorMap.end()) {
it = colorMap.insert (make_pair (
range, Colors (range + 1, -1))).first;
}
Util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]),
reps, std::plus<double>());
if (Constants::SHOW_BP_CALCS) {
cout << endl;
unsigned idx = varNodes[i]->hasEvidence()
? varNodes[i]->getEvidence()
: range;
Colors& stateColors = it->second;
if (stateColors[idx] == -1) {
stateColors[idx] = getNewColor();
}
setColor (varNodes[i], stateColors[idx]);
}
reps *= links[i]->varNode()->range();
const FacNodes& facNodes = fg.facNodes();
// create the initial factor colors
DistColorMap distColors;
for (size_t i = 0; i < facNodes.size(); i++) {
unsigned distId = facNodes[i]->factor().distId();
DistColorMap::iterator it = distColors.find (distId);
if (it == distColors.end()) {
it = distColors.insert (make_pair (distId, getNewColor())).first;
}
} else {
for (size_t i = links.size(); i-- > 0; ) {
const CbpSolverLink* cl = static_cast<const CbpSolverLink*> (links[i]);
if ( ! (cl->varNode() == dst && cl->index() == link->index())) {
if (Constants::SHOW_BP_CALCS) {
cout << " message from " << links[i]->varNode()->label();
cout << ": " ;
setColor (facNodes[i], it->second);
}
Util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]),
reps, std::multiplies<double>());
if (Constants::SHOW_BP_CALCS) {
cout << endl;
}
}
reps *= links[i]->varNode()->range();
}
}
Factor result (src->factor().arguments(),
src->factor().ranges(), msgProduct);
assert (msgProduct.size() == src->factor().size());
if (Globals::logDomain) {
result.params() += src->factor().params();
} else {
result.params() *= src->factor().params();
}
if (Constants::SHOW_BP_CALCS) {
cout << " message product: " << msgProduct << endl;
cout << " original factor: " << src->factor().params() << endl;
cout << " factor product: " << result.params() << endl;
}
result.sumOutAllExceptIndex (link->index());
if (Constants::SHOW_BP_CALCS) {
cout << " marginalized: " << result.params() << endl;
}
link->nextMessage() = result.params();
LogAware::normalize (link->nextMessage());
if (Constants::SHOW_BP_CALCS) {
cout << " curr msg: " << link->message() << endl;
cout << " next msg: " << link->nextMessage() << endl;
}
}
Params
CbpSolver::getVarToFactorMsg (const SpLink* _link) const
{
const CbpSolverLink* link = static_cast<const CbpSolverLink*> (_link);
const VarNode* src = link->varNode();
const FacNode* dst = link->facNode();
Params msg;
if (src->hasEvidence()) {
msg.resize (src->range(), LogAware::noEvidence());
double value = link->message()[src->getEvidence()];
if (Constants::SHOW_BP_CALCS) {
msg[src->getEvidence()] = value;
cout << msg << "^" << link->nrEdges() << "-1" ;
}
msg[src->getEvidence()] = LogAware::pow (value, link->nrEdges() - 1);
} else {
msg = link->message();
if (Constants::SHOW_BP_CALCS) {
cout << msg << "^" << link->nrEdges() << "-1" ;
}
LogAware::pow (msg, link->nrEdges() - 1);
}
const SpLinkSet& links = ninf(src)->getLinks();
if (Globals::logDomain) {
for (size_t i = 0; i < links.size(); i++) {
CbpSolverLink* cl = static_cast<CbpSolverLink*> (links[i]);
if ( ! (cl->facNode() == dst && cl->index() == link->index())) {
CbpSolverLink* cl = static_cast<CbpSolverLink*> (links[i]);
msg += cl->powMessage();
}
}
} else {
for (size_t i = 0; i < links.size(); i++) {
CbpSolverLink* cl = static_cast<CbpSolverLink*> (links[i]);
if ( ! (cl->facNode() == dst && cl->index() == link->index())) {
msg *= cl->powMessage();
if (Constants::SHOW_BP_CALCS) {
cout << " x " << cl->nextMessage() << "^" << link->nrEdges();
}
}
}
}
if (Constants::SHOW_BP_CALCS) {
cout << " = " << msg;
}
return msg;
}
void
CbpSolver::printLinkInformation (void) const
CbpSolver::createGroups (void)
{
for (size_t i = 0; i < links_.size(); i++) {
CbpSolverLink* cl = static_cast<CbpSolverLink*> (links_[i]);
cout << cl->toString() << ":" << endl;
cout << " curr msg = " << cl->message() << endl;
cout << " next msg = " << cl->nextMessage() << endl;
cout << " index = " << cl->index() << endl;
cout << " nr edges = " << cl->nrEdges() << endl;
cout << " powered = " << cl->powMessage() << endl;
cout << " residual = " << cl->residual() << endl;
VarSignMap varGroups;
FacSignMap facGroups;
unsigned nIters = 0;
bool groupsHaveChanged = true;
const VarNodes& varNodes = fg.varNodes();
const FacNodes& facNodes = fg.facNodes();
while (groupsHaveChanged || nIters == 1) {
nIters ++;
// set a new color to the variables with the same signature
size_t prevVarGroupsSize = varGroups.size();
varGroups.clear();
for (size_t i = 0; i < varNodes.size(); i++) {
const VarSignature& signature = getSignature (varNodes[i]);
VarSignMap::iterator it = varGroups.find (signature);
if (it == varGroups.end()) {
it = varGroups.insert (make_pair (signature, VarNodes())).first;
}
it->second.push_back (varNodes[i]);
}
for (VarSignMap::iterator it = varGroups.begin();
it != varGroups.end(); ++it) {
Color newColor = getNewColor();
VarNodes& groupMembers = it->second;
for (size_t i = 0; i < groupMembers.size(); i++) {
setColor (groupMembers[i], newColor);
}
}
size_t prevFactorGroupsSize = facGroups.size();
facGroups.clear();
// set a new color to the factors with the same signature
for (size_t i = 0; i < facNodes.size(); i++) {
const FacSignature& signature = getSignature (facNodes[i]);
FacSignMap::iterator it = facGroups.find (signature);
if (it == facGroups.end()) {
it = facGroups.insert (make_pair (signature, FacNodes())).first;
}
it->second.push_back (facNodes[i]);
}
for (FacSignMap::iterator it = facGroups.begin();
it != facGroups.end(); ++it) {
Color newColor = getNewColor();
FacNodes& groupMembers = it->second;
for (size_t i = 0; i < groupMembers.size(); i++) {
setColor (groupMembers[i], newColor);
}
}
groupsHaveChanged = prevVarGroupsSize != varGroups.size()
|| prevFactorGroupsSize != facGroups.size();
}
// printGroups (varGroups, facGroups);
createClusters (varGroups, facGroups);
}
void
CbpSolver::createClusters (
const VarSignMap& varGroups,
const FacSignMap& facGroups)
{
varClusters_.reserve (varGroups.size());
for (VarSignMap::const_iterator it = varGroups.begin();
it != varGroups.end(); ++it) {
const VarNodes& groupVars = it->second;
VarCluster* vc = new VarCluster (groupVars);
for (size_t i = 0; i < groupVars.size(); i++) {
vid2VarCluster_.insert (make_pair (groupVars[i]->varId(), vc));
}
varClusters_.push_back (vc);
}
facClusters_.reserve (facGroups.size());
for (FacSignMap::const_iterator it = facGroups.begin();
it != facGroups.end(); ++it) {
FacNode* groupFactor = it->second[0];
const VarNodes& neighs = groupFactor->neighbors();
VarClusters varClusters;
varClusters.reserve (neighs.size());
for (size_t i = 0; i < neighs.size(); i++) {
VarId vid = neighs[i]->varId();
varClusters.push_back (vid2VarCluster_.find (vid)->second);
}
facClusters_.push_back (new FacCluster (it->second, varClusters));
}
}
VarSignature
CbpSolver::getSignature (const VarNode* varNode)
{
const FacNodes& neighs = varNode->neighbors();
VarSignature sign;
sign.reserve (neighs.size() + 1);
for (size_t i = 0; i < neighs.size(); i++) {
sign.push_back (make_pair (
getColor (neighs[i]),
neighs[i]->factor().indexOf (varNode->varId())));
}
std::sort (sign.begin(), sign.end());
sign.push_back (make_pair (getColor (varNode), 0));
return sign;
}
FacSignature
CbpSolver::getSignature (const FacNode* facNode)
{
const VarNodes& neighs = facNode->neighbors();
FacSignature sign;
sign.reserve (neighs.size() + 1);
for (size_t i = 0; i < neighs.size(); i++) {
sign.push_back (getColor (neighs[i]));
}
sign.push_back (getColor (facNode));
return sign;
}
FactorGraph*
CbpSolver::getCompressedFactorGraph (void)
{
FactorGraph* fg = new FactorGraph();
for (size_t i = 0; i < varClusters_.size(); i++) {
VarNode* newVar = new VarNode (varClusters_[i]->first());
varClusters_[i]->setRepresentative (newVar);
fg->addVarNode (newVar);
}
for (size_t i = 0; i < facClusters_.size(); i++) {
Vars vars;
const VarClusters& clusters = facClusters_[i]->varClusters();
for (size_t j = 0; j < clusters.size(); j++) {
vars.push_back (clusters[j]->representative());
}
const Factor& groundFac = facClusters_[i]->first()->factor();
FacNode* fn = new FacNode (Factor (
vars, groundFac.params(), groundFac.distId()));
facClusters_[i]->setRepresentative (fn);
fg->addFacNode (fn);
for (size_t j = 0; j < vars.size(); j++) {
fg->addEdge (static_cast<VarNode*> (vars[j]), fn);
}
}
return fg;
}
vector<vector<unsigned>>
CbpSolver::getWeights (void) const
{
vector<vector<unsigned>> weights;
weights.reserve (facClusters_.size());
for (size_t i = 0; i < facClusters_.size(); i++) {
const VarClusters& neighs = facClusters_[i]->varClusters();
weights.push_back ({ });
weights.back().reserve (neighs.size());
for (size_t j = 0; j < neighs.size(); j++) {
weights.back().push_back (getWeight (
facClusters_[i], neighs[j], j));
}
}
return weights;
}
unsigned
CbpSolver::getWeight (
const FacCluster* fc,
const VarCluster* vc,
size_t index) const
{
unsigned weight = 0;
VarId reprVid = vc->representative()->varId();
VarNode* groundVar = fg.getVarNode (reprVid);
const FacNodes& neighs = groundVar->neighbors();
for (size_t i = 0; i < neighs.size(); i++) {
FacNodes::const_iterator it;
it = std::find (fc->members().begin(), fc->members().end(), neighs[i]);
if (it != fc->members().end() &&
(*it)->factor().indexOf (reprVid) == index) {
weight ++;
}
}
return weight;
}
void
CbpSolver::printGroups (
const VarSignMap& varGroups,
const FacSignMap& facGroups) const
{
unsigned count = 1;
cout << "variable groups:" << endl;
for (VarSignMap::const_iterator it = varGroups.begin();
it != varGroups.end(); ++it) {
const VarNodes& groupMembers = it->second;
if (groupMembers.size() > 0) {
cout << count << ": " ;
for (size_t i = 0; i < groupMembers.size(); i++) {
cout << groupMembers[i]->label() << " " ;
}
count ++;
cout << endl;
}
}
count = 1;
cout << endl << "factor groups:" << endl;
for (FacSignMap::const_iterator it = facGroups.begin();
it != facGroups.end(); ++it) {
const FacNodes& groupMembers = it->second;
if (groupMembers.size() > 0) {
cout << ++count << ": " ;
for (size_t i = 0; i < groupMembers.size(); i++) {
cout << groupMembers[i]->getLabel() << " " ;
}
count ++;
cout << endl;
}
}
}

View File

@ -1,40 +1,109 @@
#ifndef HORUS_CBP_H
#define HORUS_CBP_H
#ifndef HORUS_CBPSOLVER_H
#define HORUS_CBPSOLVER_H
#include "BpSolver.h"
#include "CFactorGraph.h"
#include <unordered_map>
class Factor;
#include "Solver.h"
#include "FactorGraph.h"
#include "Util.h"
#include "Horus.h"
class CbpSolverLink : public SpLink
class VarCluster;
class FacCluster;
class VarSignHash;
class FacSignHash;
class WeightedBpSolver;
typedef long Color;
typedef vector<Color> Colors;
typedef vector<std::pair<Color,unsigned>> VarSignature;
typedef vector<Color> FacSignature;
typedef unordered_map<unsigned, Color> DistColorMap;
typedef unordered_map<unsigned, Colors> VarColorMap;
typedef unordered_map<VarSignature, VarNodes, VarSignHash> VarSignMap;
typedef unordered_map<FacSignature, FacNodes, FacSignHash> FacSignMap;
typedef vector<VarCluster*> VarClusters;
typedef vector<FacCluster*> FacClusters;
typedef unordered_map<VarId, VarCluster*> VarId2VarCluster;
struct VarSignHash
{
public:
CbpSolverLink (FacNode* fn, VarNode* vn, size_t idx, unsigned count)
: SpLink (fn, vn), index_(idx), nrEdges_(count),
pwdMsg_(vn->range(), LogAware::one()) { }
size_t index (void) const { return index_; }
unsigned nrEdges (void) const { return nrEdges_; }
const Params& powMessage (void) const { return pwdMsg_; }
void updateMessage (void)
size_t operator() (const VarSignature &sig) const
{
pwdMsg_ = *nextMsg_;
swap (currMsg_, nextMsg_);
LogAware::pow (pwdMsg_, nrEdges_);
size_t val = hash<size_t>()(sig.size());
for (size_t i = 0; i < sig.size(); i++) {
val ^= hash<size_t>()(sig[i].first);
val ^= hash<size_t>()(sig[i].second);
}
return val;
}
};
private:
size_t index_;
unsigned nrEdges_;
Params pwdMsg_;
struct FacSignHash
{
size_t operator() (const FacSignature &sig) const
{
size_t val = hash<size_t>()(sig.size());
for (size_t i = 0; i < sig.size(); i++) {
val ^= hash<size_t>()(sig[i]);
}
return val;
}
};
class CbpSolver : public BpSolver
class VarCluster
{
public:
VarCluster (const VarNodes& vs) : members_(vs) { }
const VarNode* first (void) const { return members_.front(); }
const VarNodes& members (void) const { return members_; }
VarNode* representative (void) const { return repr_; }
void setRepresentative (VarNode* vn) { repr_ = vn; }
private:
VarNodes members_;
VarNode* repr_;
};
class FacCluster
{
public:
FacCluster (const FacNodes& fcs, const VarClusters& vcs)
: members_(fcs), varClusters_(vcs) { }
const FacNode* first (void) const { return members_.front(); }
const FacNodes& members (void) const { return members_; }
VarClusters& varClusters (void) { return varClusters_; }
FacNode* representative (void) const { return repr_; }
void setRepresentative (FacNode* fn) { repr_ = fn; }
private:
FacNodes members_;
VarClusters varClusters_;
FacNode* repr_;
};
class CbpSolver : public Solver
{
public:
CbpSolver (const FactorGraph& fg);
@ -43,24 +112,80 @@ class CbpSolver : public BpSolver
void printSolverFlags (void) const;
Params solveQuery (VarIds);
Params getPosterioriOf (VarId);
Params getJointDistributionOf (const VarIds&);
static bool checkForIdenticalFactors;
private:
void createLinks (void);
Color getNewColor (void)
{
++ freeColor_;
return freeColor_ - 1;
}
void maxResidualSchedule (void);
Color getColor (const VarNode* vn) const
{
return varColors_[vn->getIndex()];
}
void calcFactorToVarMsg (SpLink*);
Color getColor (const FacNode* fn) const
{
return facColors_[fn->getIndex()];
}
Params getVarToFactorMsg (const SpLink*) const;
void setColor (const VarNode* vn, Color c)
{
varColors_[vn->getIndex()] = c;
}
void printLinkInformation (void) const;
void setColor (const FacNode* fn, Color c)
{
facColors_[fn->getIndex()] = c;
}
CFactorGraph* cfg_;
void findIdenticalFactors (void);
void setInitialColors (void);
void createGroups (void);
void createClusters (const VarSignMap&, const FacSignMap&);
VarSignature getSignature (const VarNode*);
FacSignature getSignature (const FacNode*);
void printGroups (const VarSignMap&, const FacSignMap&) const;
VarId getRepresentative (VarId vid)
{
assert (Util::contains (vid2VarCluster_, vid));
VarCluster* vc = vid2VarCluster_.find (vid)->second;
return vc->representative()->varId();
}
FactorGraph* getCompressedFactorGraph (void);
vector<vector<unsigned>> getWeights (void) const;
unsigned getWeight (const FacCluster*,
const VarCluster*, size_t index) const;
Color freeColor_;
Colors varColors_;
Colors facColors_;
VarClusters varClusters_;
FacClusters facClusters_;
VarId2VarCluster vid2VarCluster_;
const FactorGraph* compressedFg_;
WeightedBpSolver* solver_;
};
#endif // HORUS_CBP_H
#endif // HORUS_CBPSOLVER_H

View File

@ -436,7 +436,7 @@ void runBpSolver (
if (Globals::groundSolver == GroundSolvers::BP) {
solver = new BpSolver (*fg); // FIXME
} else if (Globals::groundSolver == GroundSolvers::CBP) {
CFactorGraph::checkForIdenticalFactors = false;
CbpSolver::checkForIdenticalFactors = false;
solver = new CbpSolver (*fg); // FIXME
} else {
cerr << "error: unknow solver" << endl;

View File

@ -1,19 +1,37 @@
#include "LiftedBpSolver.h"
#include "WeightedBpSolver.h"
#include "FactorGraph.h"
#include "FoveSolver.h"
Params
LiftedBpSolver::getPosterioriOf (const Ground&)
LiftedBpSolver::LiftedBpSolver (const ParfactorList& pfList)
: pfList_(pfList)
{
return Params();
refineParfactors();
solver_ = new WeightedBpSolver (*getFactorGraph(), getWeights());
}
Params
LiftedBpSolver::getJointDistributionOf (const Grounds&)
LiftedBpSolver::getPosterioriOf (const Ground& query)
{
return Params();
vector<PrvGroup> groups = getQueryGroups ({query});
return solver_->getPosterioriOf (groups[0]);
}
Params
LiftedBpSolver::getJointDistributionOf (const Grounds& query)
{
vector<PrvGroup> groups = getQueryGroups (query);
VarIds queryVids;
for (unsigned i = 0; i < groups.size(); i++) {
queryVids.push_back (groups[i]);
}
return solver_->getJointDistributionOf (queryVids);
}
@ -28,3 +46,96 @@ LiftedBpSolver::printSolverFlags (void) const
cout << ss.str() << endl;
}
void
LiftedBpSolver::refineParfactors (void)
{
while (iterate() == false);
if (Globals::verbosity > 2) {
Util::printHeader ("AFTER REFINEMENT");
pfList_.print();
}
}
bool
LiftedBpSolver::iterate (void)
{
ParfactorList::iterator it = pfList_.begin();
while (it != pfList_.end()) {
const ProbFormulas& args = (*it)->arguments();
for (size_t i = 0; i < args.size(); i++) {
LogVarSet lvs = (*it)->logVarSet() - args[i].logVars();
if ((*it)->constr()->isCountNormalized (lvs) == false) {
Parfactors pfs = FoveSolver::countNormalize (*it, lvs);
it = pfList_.removeAndDelete (it);
pfList_.add (pfs);
return false;
}
}
++ it;
}
return true;
}
vector<PrvGroup>
LiftedBpSolver::getQueryGroups (const Grounds& query)
{
vector<PrvGroup> queryGroups;
for (unsigned i = 0; i < query.size(); i++) {
ParfactorList::const_iterator it = pfList_.begin();
for (; it != pfList_.end(); ++it) {
if ((*it)->containsGround (query[i])) {
queryGroups.push_back ((*it)->findGroup (query[i]));
break;
}
}
}
assert (queryGroups.size() == query.size());
return queryGroups;
}
FactorGraph*
LiftedBpSolver::getFactorGraph (void)
{
FactorGraph* fg = new FactorGraph();
ParfactorList::const_iterator it = pfList_.begin();
for (; it != pfList_.end(); ++it) {
vector<PrvGroup> groups = (*it)->getAllGroups();
VarIds varIds;
for (size_t i = 0; i < groups.size(); i++) {
varIds.push_back (groups[i]);
}
fg->addFactor (Factor (varIds, (*it)->ranges(), (*it)->params()));
}
return fg;
}
vector<vector<unsigned>>
LiftedBpSolver::getWeights (void) const
{
vector<vector<unsigned>> weights;
weights.reserve (pfList_.size());
ParfactorList::const_iterator it = pfList_.begin();
for (; it != pfList_.end(); ++it) {
const ProbFormulas& args = (*it)->arguments();
weights.push_back ({ });
weights.back().reserve (args.size());
for (size_t i = 0; i < args.size(); i++) {
LogVarSet lvs = (*it)->logVarSet() - args[i].logVars();
weights.back().push_back ((*it)->constr()->getConditionalCount (lvs));
}
}
return weights;
}

View File

@ -3,10 +3,14 @@
#include "ParfactorList.h"
class SpLink;
class FactorGraph;
class WeightedBpSolver;
class LiftedBpSolver
{
public:
LiftedBpSolver (const ParfactorList& pfList) : pfList_(pfList) { }
LiftedBpSolver (const ParfactorList& pfList);
Params getPosterioriOf (const Ground&);
@ -15,7 +19,18 @@ class LiftedBpSolver
void printSolverFlags (void) const;
private:
void refineParfactors (void);
bool iterate (void);
vector<PrvGroup> getQueryGroups (const Grounds&);
FactorGraph* getFactorGraph (void);
vector<vector<unsigned>> getWeights (void) const;
ParfactorList pfList_;
WeightedBpSolver* solver_;
};

View File

@ -50,7 +50,6 @@ HEADERS = \
$(srcdir)/ElimGraph.h \
$(srcdir)/FactorGraph.h \
$(srcdir)/Factor.h \
$(srcdir)/CFactorGraph.h \
$(srcdir)/ConstraintTree.h \
$(srcdir)/Solver.h \
$(srcdir)/VarElimSolver.h \
@ -66,6 +65,7 @@ HEADERS = \
$(srcdir)/LiftedUtils.h \
$(srcdir)/TinySet.h \
$(srcdir)/LiftedBpSolver.h \
$(srcdir)/WeightedBpSolver.h \
$(srcdir)/Util.h \
$(srcdir)/Horus.h
@ -75,7 +75,6 @@ CPP_SOURCES = \
$(srcdir)/ElimGraph.cpp \
$(srcdir)/FactorGraph.cpp \
$(srcdir)/Factor.cpp \
$(srcdir)/CFactorGraph.cpp \
$(srcdir)/ConstraintTree.cpp \
$(srcdir)/Var.cpp \
$(srcdir)/Solver.cpp \
@ -90,6 +89,7 @@ CPP_SOURCES = \
$(srcdir)/LiftedUtils.cpp \
$(srcdir)/Util.cpp \
$(srcdir)/LiftedBpSolver.cpp \
$(srcdir)/WeightedBpSolver.cpp \
$(srcdir)/HorusYap.cpp \
$(srcdir)/HorusCli.cpp
@ -99,7 +99,6 @@ OBJS = \
ElimGraph.o \
FactorGraph.o \
Factor.o \
CFactorGraph.o \
ConstraintTree.o \
Var.o \
Solver.o \
@ -114,6 +113,7 @@ OBJS = \
LiftedUtils.o \
Util.o \
LiftedBpSolver.o \
WeightedBpSolver.o \
HorusYap.o
HCLI_OBJS = \
@ -122,7 +122,6 @@ HCLI_OBJS = \
ElimGraph.o \
FactorGraph.o \
Factor.o \
CFactorGraph.o \
ConstraintTree.o \
Var.o \
Solver.o \
@ -134,6 +133,7 @@ HCLI_OBJS = \
ProbFormula.o \
Histogram.o \
ParfactorList.o \
WeightedBpSolver.o \
LiftedUtils.o \
Util.o \
HorusCli.o