add support to (real) lifted belief propagation
This commit is contained in:
@@ -1,22 +1,32 @@
|
||||
#include "CbpSolver.h"
|
||||
#include "WeightedBpSolver.h"
|
||||
|
||||
|
||||
CbpSolver::CbpSolver (const FactorGraph& fg) : BpSolver (fg)
|
||||
bool CbpSolver::checkForIdenticalFactors = true;
|
||||
|
||||
|
||||
CbpSolver::CbpSolver (const FactorGraph& fg)
|
||||
: Solver (fg), freeColor_(0)
|
||||
{
|
||||
cfg_ = new CFactorGraph (fg);
|
||||
fg_ = cfg_->getGroundFactorGraph();
|
||||
findIdenticalFactors();
|
||||
setInitialColors();
|
||||
createGroups();
|
||||
compressedFg_ = getCompressedFactorGraph();
|
||||
solver_ = new WeightedBpSolver (*compressedFg_, getWeights());
|
||||
}
|
||||
|
||||
|
||||
|
||||
CbpSolver::~CbpSolver (void)
|
||||
{
|
||||
delete cfg_;
|
||||
delete fg_;
|
||||
for (size_t i = 0; i < links_.size(); i++) {
|
||||
delete links_[i];
|
||||
delete solver_;
|
||||
delete compressedFg_;
|
||||
for (size_t i = 0; i < varClusters_.size(); i++) {
|
||||
delete varClusters_[i];
|
||||
}
|
||||
for (size_t i = 0; i < facClusters_.size(); i++) {
|
||||
delete facClusters_[i];
|
||||
}
|
||||
links_.clear();
|
||||
}
|
||||
|
||||
|
||||
@@ -38,44 +48,28 @@ CbpSolver::printSolverFlags (void) const
|
||||
ss << ",accuracy=" << BpOptions::accuracy;
|
||||
ss << ",log_domain=" << Util::toString (Globals::logDomain);
|
||||
ss << ",chkif=" <<
|
||||
Util::toString (CFactorGraph::checkForIdenticalFactors);
|
||||
Util::toString (CbpSolver::checkForIdenticalFactors);
|
||||
ss << "]" ;
|
||||
cout << ss.str() << endl;
|
||||
}
|
||||
|
||||
|
||||
|
||||
Params
|
||||
CbpSolver::solveQuery (VarIds queryVids)
|
||||
{
|
||||
assert (queryVids.empty() == false);
|
||||
return queryVids.size() == 1
|
||||
? getPosterioriOf (queryVids[0])
|
||||
: getJointDistributionOf (queryVids);
|
||||
}
|
||||
|
||||
|
||||
|
||||
Params
|
||||
CbpSolver::getPosterioriOf (VarId vid)
|
||||
{
|
||||
if (runned_ == false) {
|
||||
runSolver();
|
||||
}
|
||||
assert (cfg_->getEquivalent (vid));
|
||||
VarNode* var = cfg_->getEquivalent (vid);
|
||||
Params probs;
|
||||
if (var->hasEvidence()) {
|
||||
probs.resize (var->range(), LogAware::noEvidence());
|
||||
probs[var->getEvidence()] = LogAware::withEvidence();
|
||||
} else {
|
||||
probs.resize (var->range(), LogAware::multIdenty());
|
||||
const SpLinkSet& links = ninf(var)->getLinks();
|
||||
if (Globals::logDomain) {
|
||||
for (size_t i = 0; i < links.size(); i++) {
|
||||
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
|
||||
probs += l->powMessage();
|
||||
}
|
||||
LogAware::normalize (probs);
|
||||
Util::exp (probs);
|
||||
} else {
|
||||
for (size_t i = 0; i < links.size(); i++) {
|
||||
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
|
||||
probs *= l->powMessage();
|
||||
}
|
||||
LogAware::normalize (probs);
|
||||
}
|
||||
}
|
||||
return probs;
|
||||
return solver_->getPosterioriOf (getRepresentative (vid));
|
||||
}
|
||||
|
||||
|
||||
@@ -83,255 +77,320 @@ CbpSolver::getPosterioriOf (VarId vid)
|
||||
Params
|
||||
CbpSolver::getJointDistributionOf (const VarIds& jointVids)
|
||||
{
|
||||
VarIds eqVarIds;
|
||||
VarIds representatives;
|
||||
for (size_t i = 0; i < jointVids.size(); i++) {
|
||||
VarNode* vn = cfg_->getEquivalent (jointVids[i]);
|
||||
eqVarIds.push_back (vn->varId());
|
||||
representatives.push_back (getRepresentative (jointVids[i]));
|
||||
}
|
||||
return BpSolver::getJointDistributionOf (eqVarIds);
|
||||
return solver_->getJointDistributionOf (representatives);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
CbpSolver::createLinks (void)
|
||||
CbpSolver::findIdenticalFactors()
|
||||
{
|
||||
if (Globals::verbosity > 0) {
|
||||
cout << "compressed factor graph contains " ;
|
||||
cout << fg_->nrVarNodes() << " variables and " ;
|
||||
cout << fg_->nrFacNodes() << " factors " << endl;
|
||||
cout << endl;
|
||||
}
|
||||
const FacClusters& fcs = cfg_->facClusters();
|
||||
for (size_t i = 0; i < fcs.size(); i++) {
|
||||
const VarClusters& vcs = fcs[i]->varClusters();
|
||||
for (size_t j = 0; j < vcs.size(); j++) {
|
||||
unsigned count = cfg_->getEdgeCount (fcs[i], vcs[j], j);
|
||||
if (Globals::verbosity > 1) {
|
||||
cout << "creating link " ;
|
||||
cout << fcs[i]->representative()->getLabel();
|
||||
cout << " -- " ;
|
||||
cout << vcs[j]->representative()->label();
|
||||
cout << " idx=" << j << ", count=" << count << endl;
|
||||
}
|
||||
links_.push_back (new CbpSolverLink (
|
||||
fcs[i]->representative(), vcs[j]->representative(), j, count));
|
||||
}
|
||||
}
|
||||
if (Globals::verbosity > 1) {
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
CbpSolver::maxResidualSchedule (void)
|
||||
{
|
||||
if (nIters_ == 1) {
|
||||
for (size_t i = 0; i < links_.size(); i++) {
|
||||
calculateMessage (links_[i]);
|
||||
SortedOrder::iterator it = sortedOrder_.insert (links_[i]);
|
||||
linkMap_.insert (make_pair (links_[i], it));
|
||||
if (Globals::verbosity >= 1) {
|
||||
cout << "calculating " << links_[i]->toString() << endl;
|
||||
}
|
||||
}
|
||||
const FacNodes& facNodes = fg.facNodes();
|
||||
if (checkForIdenticalFactors == false ||
|
||||
facNodes.size() == 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
for (size_t c = 0; c < links_.size(); c++) {
|
||||
if (Globals::verbosity > 1) {
|
||||
cout << endl << "current residuals:" << endl;
|
||||
for (SortedOrder::iterator it = sortedOrder_.begin();
|
||||
it != sortedOrder_.end(); ++it) {
|
||||
cout << " " << setw (30) << left << (*it)->toString();
|
||||
cout << "residual = " << (*it)->residual() << endl;
|
||||
}
|
||||
}
|
||||
|
||||
SortedOrder::iterator it = sortedOrder_.begin();
|
||||
SpLink* link = *it;
|
||||
if (Globals::verbosity >= 1) {
|
||||
cout << "updating " << (*sortedOrder_.begin())->toString() << endl;
|
||||
}
|
||||
if (link->residual() < BpOptions::accuracy) {
|
||||
return;
|
||||
}
|
||||
link->updateMessage();
|
||||
link->clearResidual();
|
||||
sortedOrder_.erase (it);
|
||||
linkMap_.find (link)->second = sortedOrder_.insert (link);
|
||||
|
||||
// update the messages that depend on message source --> destin
|
||||
const FacNodes& factorNeighbors = link->varNode()->neighbors();
|
||||
for (size_t i = 0; i < factorNeighbors.size(); i++) {
|
||||
const SpLinkSet& links = ninf(factorNeighbors[i])->getLinks();
|
||||
for (size_t j = 0; j < links.size(); j++) {
|
||||
if (links[j]->varNode() != link->varNode()) {
|
||||
if (Globals::verbosity > 1) {
|
||||
cout << " calculating " << links[j]->toString() << endl;
|
||||
}
|
||||
calculateMessage (links[j]);
|
||||
SpLinkMap::iterator iter = linkMap_.find (links[j]);
|
||||
sortedOrder_.erase (iter->second);
|
||||
iter->second = sortedOrder_.insert (links[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
// in counting bp, the message that a variable X sends to
|
||||
// to a factor F depends on the message that F sent to the X
|
||||
const SpLinkSet& links = ninf(link->facNode())->getLinks();
|
||||
for (size_t i = 0; i < links.size(); i++) {
|
||||
if (links[i]->varNode() != link->varNode()) {
|
||||
if (Globals::verbosity > 1) {
|
||||
cout << " calculating " << links[i]->toString() << endl;
|
||||
}
|
||||
calculateMessage (links[i]);
|
||||
SpLinkMap::iterator iter = linkMap_.find (links[i]);
|
||||
sortedOrder_.erase (iter->second);
|
||||
iter->second = sortedOrder_.insert (links[i]);
|
||||
for (size_t i = 0; i < facNodes.size(); i++) {
|
||||
facNodes[i]->factor().setDistId (Util::maxUnsigned());
|
||||
}
|
||||
unsigned groupCount = 1;
|
||||
for (size_t i = 0; i < facNodes.size() - 1; i++) {
|
||||
Factor& f1 = facNodes[i]->factor();
|
||||
if (f1.distId() != Util::maxUnsigned()) {
|
||||
continue;
|
||||
}
|
||||
f1.setDistId (groupCount);
|
||||
for (size_t j = i + 1; j < facNodes.size(); j++) {
|
||||
Factor& f2 = facNodes[j]->factor();
|
||||
if (f2.distId() != Util::maxUnsigned()) {
|
||||
continue;
|
||||
}
|
||||
if (f1.size() == f2.size() &&
|
||||
f1.ranges() == f2.ranges() &&
|
||||
f1.params() == f2.params()) {
|
||||
f2.setDistId (groupCount);
|
||||
}
|
||||
}
|
||||
groupCount ++;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
CbpSolver::calcFactorToVarMsg (SpLink* _link)
|
||||
CbpSolver::setInitialColors (void)
|
||||
{
|
||||
CbpSolverLink* link = static_cast<CbpSolverLink*> (_link);
|
||||
FacNode* src = link->facNode();
|
||||
const VarNode* dst = link->varNode();
|
||||
const SpLinkSet& links = ninf(src)->getLinks();
|
||||
// calculate the product of messages that were sent
|
||||
// to factor `src', except from var `dst'
|
||||
unsigned reps = 1;
|
||||
unsigned msgSize = Util::sizeExpected (src->factor().ranges());
|
||||
Params msgProduct (msgSize, LogAware::multIdenty());
|
||||
if (Globals::logDomain) {
|
||||
for (size_t i = links.size(); i-- > 0; ) {
|
||||
const CbpSolverLink* cl = static_cast<const CbpSolverLink*> (links[i]);
|
||||
if ( ! (cl->varNode() == dst && cl->index() == link->index())) {
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
cout << " message from " << links[i]->varNode()->label();
|
||||
cout << ": " ;
|
||||
}
|
||||
Util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]),
|
||||
reps, std::plus<double>());
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
reps *= links[i]->varNode()->range();
|
||||
varColors_.resize (fg.nrVarNodes());
|
||||
facColors_.resize (fg.nrFacNodes());
|
||||
// create the initial variable colors
|
||||
VarColorMap colorMap;
|
||||
const VarNodes& varNodes = fg.varNodes();
|
||||
for (size_t i = 0; i < varNodes.size(); i++) {
|
||||
unsigned range = varNodes[i]->range();
|
||||
VarColorMap::iterator it = colorMap.find (range);
|
||||
if (it == colorMap.end()) {
|
||||
it = colorMap.insert (make_pair (
|
||||
range, Colors (range + 1, -1))).first;
|
||||
}
|
||||
} else {
|
||||
for (size_t i = links.size(); i-- > 0; ) {
|
||||
const CbpSolverLink* cl = static_cast<const CbpSolverLink*> (links[i]);
|
||||
if ( ! (cl->varNode() == dst && cl->index() == link->index())) {
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
cout << " message from " << links[i]->varNode()->label();
|
||||
cout << ": " ;
|
||||
}
|
||||
Util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]),
|
||||
reps, std::multiplies<double>());
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
reps *= links[i]->varNode()->range();
|
||||
unsigned idx = varNodes[i]->hasEvidence()
|
||||
? varNodes[i]->getEvidence()
|
||||
: range;
|
||||
Colors& stateColors = it->second;
|
||||
if (stateColors[idx] == -1) {
|
||||
stateColors[idx] = getNewColor();
|
||||
}
|
||||
setColor (varNodes[i], stateColors[idx]);
|
||||
}
|
||||
Factor result (src->factor().arguments(),
|
||||
src->factor().ranges(), msgProduct);
|
||||
assert (msgProduct.size() == src->factor().size());
|
||||
if (Globals::logDomain) {
|
||||
result.params() += src->factor().params();
|
||||
} else {
|
||||
result.params() *= src->factor().params();
|
||||
}
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
cout << " message product: " << msgProduct << endl;
|
||||
cout << " original factor: " << src->factor().params() << endl;
|
||||
cout << " factor product: " << result.params() << endl;
|
||||
}
|
||||
result.sumOutAllExceptIndex (link->index());
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
cout << " marginalized: " << result.params() << endl;
|
||||
}
|
||||
link->nextMessage() = result.params();
|
||||
LogAware::normalize (link->nextMessage());
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
cout << " curr msg: " << link->message() << endl;
|
||||
cout << " next msg: " << link->nextMessage() << endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
Params
|
||||
CbpSolver::getVarToFactorMsg (const SpLink* _link) const
|
||||
{
|
||||
const CbpSolverLink* link = static_cast<const CbpSolverLink*> (_link);
|
||||
const VarNode* src = link->varNode();
|
||||
const FacNode* dst = link->facNode();
|
||||
Params msg;
|
||||
if (src->hasEvidence()) {
|
||||
msg.resize (src->range(), LogAware::noEvidence());
|
||||
double value = link->message()[src->getEvidence()];
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
msg[src->getEvidence()] = value;
|
||||
cout << msg << "^" << link->nrEdges() << "-1" ;
|
||||
const FacNodes& facNodes = fg.facNodes();
|
||||
// create the initial factor colors
|
||||
DistColorMap distColors;
|
||||
for (size_t i = 0; i < facNodes.size(); i++) {
|
||||
unsigned distId = facNodes[i]->factor().distId();
|
||||
DistColorMap::iterator it = distColors.find (distId);
|
||||
if (it == distColors.end()) {
|
||||
it = distColors.insert (make_pair (distId, getNewColor())).first;
|
||||
}
|
||||
msg[src->getEvidence()] = LogAware::pow (value, link->nrEdges() - 1);
|
||||
} else {
|
||||
msg = link->message();
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
cout << msg << "^" << link->nrEdges() << "-1" ;
|
||||
}
|
||||
LogAware::pow (msg, link->nrEdges() - 1);
|
||||
setColor (facNodes[i], it->second);
|
||||
}
|
||||
const SpLinkSet& links = ninf(src)->getLinks();
|
||||
if (Globals::logDomain) {
|
||||
for (size_t i = 0; i < links.size(); i++) {
|
||||
CbpSolverLink* cl = static_cast<CbpSolverLink*> (links[i]);
|
||||
if ( ! (cl->facNode() == dst && cl->index() == link->index())) {
|
||||
CbpSolverLink* cl = static_cast<CbpSolverLink*> (links[i]);
|
||||
msg += cl->powMessage();
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (size_t i = 0; i < links.size(); i++) {
|
||||
CbpSolverLink* cl = static_cast<CbpSolverLink*> (links[i]);
|
||||
if ( ! (cl->facNode() == dst && cl->index() == link->index())) {
|
||||
msg *= cl->powMessage();
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
cout << " x " << cl->nextMessage() << "^" << link->nrEdges();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
cout << " = " << msg;
|
||||
}
|
||||
return msg;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
CbpSolver::printLinkInformation (void) const
|
||||
CbpSolver::createGroups (void)
|
||||
{
|
||||
for (size_t i = 0; i < links_.size(); i++) {
|
||||
CbpSolverLink* cl = static_cast<CbpSolverLink*> (links_[i]);
|
||||
cout << cl->toString() << ":" << endl;
|
||||
cout << " curr msg = " << cl->message() << endl;
|
||||
cout << " next msg = " << cl->nextMessage() << endl;
|
||||
cout << " index = " << cl->index() << endl;
|
||||
cout << " nr edges = " << cl->nrEdges() << endl;
|
||||
cout << " powered = " << cl->powMessage() << endl;
|
||||
cout << " residual = " << cl->residual() << endl;
|
||||
VarSignMap varGroups;
|
||||
FacSignMap facGroups;
|
||||
unsigned nIters = 0;
|
||||
bool groupsHaveChanged = true;
|
||||
const VarNodes& varNodes = fg.varNodes();
|
||||
const FacNodes& facNodes = fg.facNodes();
|
||||
|
||||
while (groupsHaveChanged || nIters == 1) {
|
||||
nIters ++;
|
||||
|
||||
// set a new color to the variables with the same signature
|
||||
size_t prevVarGroupsSize = varGroups.size();
|
||||
varGroups.clear();
|
||||
for (size_t i = 0; i < varNodes.size(); i++) {
|
||||
const VarSignature& signature = getSignature (varNodes[i]);
|
||||
VarSignMap::iterator it = varGroups.find (signature);
|
||||
if (it == varGroups.end()) {
|
||||
it = varGroups.insert (make_pair (signature, VarNodes())).first;
|
||||
}
|
||||
it->second.push_back (varNodes[i]);
|
||||
}
|
||||
for (VarSignMap::iterator it = varGroups.begin();
|
||||
it != varGroups.end(); ++it) {
|
||||
Color newColor = getNewColor();
|
||||
VarNodes& groupMembers = it->second;
|
||||
for (size_t i = 0; i < groupMembers.size(); i++) {
|
||||
setColor (groupMembers[i], newColor);
|
||||
}
|
||||
}
|
||||
|
||||
size_t prevFactorGroupsSize = facGroups.size();
|
||||
facGroups.clear();
|
||||
// set a new color to the factors with the same signature
|
||||
for (size_t i = 0; i < facNodes.size(); i++) {
|
||||
const FacSignature& signature = getSignature (facNodes[i]);
|
||||
FacSignMap::iterator it = facGroups.find (signature);
|
||||
if (it == facGroups.end()) {
|
||||
it = facGroups.insert (make_pair (signature, FacNodes())).first;
|
||||
}
|
||||
it->second.push_back (facNodes[i]);
|
||||
}
|
||||
for (FacSignMap::iterator it = facGroups.begin();
|
||||
it != facGroups.end(); ++it) {
|
||||
Color newColor = getNewColor();
|
||||
FacNodes& groupMembers = it->second;
|
||||
for (size_t i = 0; i < groupMembers.size(); i++) {
|
||||
setColor (groupMembers[i], newColor);
|
||||
}
|
||||
}
|
||||
|
||||
groupsHaveChanged = prevVarGroupsSize != varGroups.size()
|
||||
|| prevFactorGroupsSize != facGroups.size();
|
||||
}
|
||||
// printGroups (varGroups, facGroups);
|
||||
createClusters (varGroups, facGroups);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
CbpSolver::createClusters (
|
||||
const VarSignMap& varGroups,
|
||||
const FacSignMap& facGroups)
|
||||
{
|
||||
varClusters_.reserve (varGroups.size());
|
||||
for (VarSignMap::const_iterator it = varGroups.begin();
|
||||
it != varGroups.end(); ++it) {
|
||||
const VarNodes& groupVars = it->second;
|
||||
VarCluster* vc = new VarCluster (groupVars);
|
||||
for (size_t i = 0; i < groupVars.size(); i++) {
|
||||
vid2VarCluster_.insert (make_pair (groupVars[i]->varId(), vc));
|
||||
}
|
||||
varClusters_.push_back (vc);
|
||||
}
|
||||
|
||||
facClusters_.reserve (facGroups.size());
|
||||
for (FacSignMap::const_iterator it = facGroups.begin();
|
||||
it != facGroups.end(); ++it) {
|
||||
FacNode* groupFactor = it->second[0];
|
||||
const VarNodes& neighs = groupFactor->neighbors();
|
||||
VarClusters varClusters;
|
||||
varClusters.reserve (neighs.size());
|
||||
for (size_t i = 0; i < neighs.size(); i++) {
|
||||
VarId vid = neighs[i]->varId();
|
||||
varClusters.push_back (vid2VarCluster_.find (vid)->second);
|
||||
}
|
||||
facClusters_.push_back (new FacCluster (it->second, varClusters));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
VarSignature
|
||||
CbpSolver::getSignature (const VarNode* varNode)
|
||||
{
|
||||
const FacNodes& neighs = varNode->neighbors();
|
||||
VarSignature sign;
|
||||
sign.reserve (neighs.size() + 1);
|
||||
for (size_t i = 0; i < neighs.size(); i++) {
|
||||
sign.push_back (make_pair (
|
||||
getColor (neighs[i]),
|
||||
neighs[i]->factor().indexOf (varNode->varId())));
|
||||
}
|
||||
std::sort (sign.begin(), sign.end());
|
||||
sign.push_back (make_pair (getColor (varNode), 0));
|
||||
return sign;
|
||||
}
|
||||
|
||||
|
||||
|
||||
FacSignature
|
||||
CbpSolver::getSignature (const FacNode* facNode)
|
||||
{
|
||||
const VarNodes& neighs = facNode->neighbors();
|
||||
FacSignature sign;
|
||||
sign.reserve (neighs.size() + 1);
|
||||
for (size_t i = 0; i < neighs.size(); i++) {
|
||||
sign.push_back (getColor (neighs[i]));
|
||||
}
|
||||
sign.push_back (getColor (facNode));
|
||||
return sign;
|
||||
}
|
||||
|
||||
|
||||
|
||||
FactorGraph*
|
||||
CbpSolver::getCompressedFactorGraph (void)
|
||||
{
|
||||
FactorGraph* fg = new FactorGraph();
|
||||
for (size_t i = 0; i < varClusters_.size(); i++) {
|
||||
VarNode* newVar = new VarNode (varClusters_[i]->first());
|
||||
varClusters_[i]->setRepresentative (newVar);
|
||||
fg->addVarNode (newVar);
|
||||
}
|
||||
for (size_t i = 0; i < facClusters_.size(); i++) {
|
||||
Vars vars;
|
||||
const VarClusters& clusters = facClusters_[i]->varClusters();
|
||||
for (size_t j = 0; j < clusters.size(); j++) {
|
||||
vars.push_back (clusters[j]->representative());
|
||||
}
|
||||
const Factor& groundFac = facClusters_[i]->first()->factor();
|
||||
FacNode* fn = new FacNode (Factor (
|
||||
vars, groundFac.params(), groundFac.distId()));
|
||||
facClusters_[i]->setRepresentative (fn);
|
||||
fg->addFacNode (fn);
|
||||
for (size_t j = 0; j < vars.size(); j++) {
|
||||
fg->addEdge (static_cast<VarNode*> (vars[j]), fn);
|
||||
}
|
||||
}
|
||||
return fg;
|
||||
}
|
||||
|
||||
|
||||
|
||||
vector<vector<unsigned>>
|
||||
CbpSolver::getWeights (void) const
|
||||
{
|
||||
vector<vector<unsigned>> weights;
|
||||
weights.reserve (facClusters_.size());
|
||||
for (size_t i = 0; i < facClusters_.size(); i++) {
|
||||
const VarClusters& neighs = facClusters_[i]->varClusters();
|
||||
weights.push_back ({ });
|
||||
weights.back().reserve (neighs.size());
|
||||
for (size_t j = 0; j < neighs.size(); j++) {
|
||||
weights.back().push_back (getWeight (
|
||||
facClusters_[i], neighs[j], j));
|
||||
}
|
||||
}
|
||||
return weights;
|
||||
}
|
||||
|
||||
|
||||
|
||||
unsigned
|
||||
CbpSolver::getWeight (
|
||||
const FacCluster* fc,
|
||||
const VarCluster* vc,
|
||||
size_t index) const
|
||||
{
|
||||
unsigned weight = 0;
|
||||
VarId reprVid = vc->representative()->varId();
|
||||
VarNode* groundVar = fg.getVarNode (reprVid);
|
||||
const FacNodes& neighs = groundVar->neighbors();
|
||||
for (size_t i = 0; i < neighs.size(); i++) {
|
||||
FacNodes::const_iterator it;
|
||||
it = std::find (fc->members().begin(), fc->members().end(), neighs[i]);
|
||||
if (it != fc->members().end() &&
|
||||
(*it)->factor().indexOf (reprVid) == index) {
|
||||
weight ++;
|
||||
}
|
||||
}
|
||||
return weight;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
CbpSolver::printGroups (
|
||||
const VarSignMap& varGroups,
|
||||
const FacSignMap& facGroups) const
|
||||
{
|
||||
unsigned count = 1;
|
||||
cout << "variable groups:" << endl;
|
||||
for (VarSignMap::const_iterator it = varGroups.begin();
|
||||
it != varGroups.end(); ++it) {
|
||||
const VarNodes& groupMembers = it->second;
|
||||
if (groupMembers.size() > 0) {
|
||||
cout << count << ": " ;
|
||||
for (size_t i = 0; i < groupMembers.size(); i++) {
|
||||
cout << groupMembers[i]->label() << " " ;
|
||||
}
|
||||
count ++;
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
count = 1;
|
||||
cout << endl << "factor groups:" << endl;
|
||||
for (FacSignMap::const_iterator it = facGroups.begin();
|
||||
it != facGroups.end(); ++it) {
|
||||
const FacNodes& groupMembers = it->second;
|
||||
if (groupMembers.size() > 0) {
|
||||
cout << ++count << ": " ;
|
||||
for (size_t i = 0; i < groupMembers.size(); i++) {
|
||||
cout << groupMembers[i]->getLabel() << " " ;
|
||||
}
|
||||
count ++;
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user