This repository has been archived on 2023-08-20. You can view files and clone it, but cannot push or open issues or pull requests.
yap-6.3/packages/CLPBN/horus/CountingBp.cpp

434 lines
12 KiB
C++
Raw Normal View History

2013-02-07 20:09:10 +00:00
#include <algorithm>
#include <cassert>
#include <iostream>
#include <sstream>
#include <utility>
#include <vector>
#include "CountingBp.h"
#include "WeightedBp.h"
2012-05-23 14:56:01 +01:00
namespace Horus {
2013-02-07 23:53:13 +00:00
2012-12-27 23:21:32 +00:00
// When true, findIdenticalFactors() (run from the constructor) gives
// factors with equal ranges and parameters a shared distribution id.
bool CountingBp::fif_ (true);

// Builds the compressed ("counted") factor graph for `fg` and a
// WeightedBp solver that runs on it. Each stage below consumes state
// produced by the previous one, so the call order matters.
CountingBp::CountingBp (const FactorGraph& fg)
    : GroundSolver (fg),
      freeColor_ (0)
{
  findIdenticalFactors ();   // shared distIds for identical factors
  setInitialColors ();       // seed colors from ranges, evidence, distIds
  createGroups ();           // refine colors into variable/factor clusters
  compressedFg_ = getCompressedFactorGraph ();
  solver_ = new WeightedBp (*compressedFg_, getWeights ());
}
// Releases everything the constructor allocated. The solver is deleted
// before the compressed graph it was built from.
CountingBp::~CountingBp (void)
{
  delete solver_;
  delete compressedFg_;
  for (auto cluster : varClusters_) {
    delete cluster;
  }
  for (auto cluster : facClusters_) {
    delete cluster;
  }
}
void
CountingBp::printSolverFlags (void) const
2012-05-23 14:56:01 +01:00
{
2013-02-07 13:37:15 +00:00
std::stringstream ss;
2012-05-23 14:56:01 +01:00
ss << "counting bp [" ;
2013-01-08 17:06:40 +00:00
ss << "bp_msg_schedule=" ;
switch (WeightedBp::msgSchedule()) {
case MsgSchedule::SEQ_FIXED: ss << "seq_fixed"; break;
case MsgSchedule::SEQ_RANDOM: ss << "seq_random"; break;
case MsgSchedule::PARALLEL: ss << "parallel"; break;
case MsgSchedule::MAX_RESIDUAL: ss << "max_residual"; break;
2012-05-23 14:56:01 +01:00
}
2013-01-08 17:06:40 +00:00
ss << ",bp_max_iter=" << WeightedBp::maxIterations();
ss << ",bp_accuracy=" << WeightedBp::accuracy();
ss << ",log_domain=" << Util::toString (Globals::logDomain);
ss << ",fif=" << Util::toString (CountingBp::fif_);
2012-05-23 14:56:01 +01:00
ss << "]" ;
2013-02-07 13:37:15 +00:00
std::cout << ss.str() << std::endl;
2012-05-23 14:56:01 +01:00
}
// Answers a marginal (one variable) or joint (several variables) query.
// A single variable is answered through its cluster representative in
// the compressed graph. A joint query is answered from the compressed
// factor that represents a ground factor whose scope holds every query
// variable. When no such factor exists, it falls back to
// GroundSolver::getJointByConditioning.
Params
CountingBp::solveQuery (VarIds queryVids)
{
  assert (queryVids.empty() == false);
  if (queryVids.size() == 1) {
    return solver_->getPosterioriOf (getRepresentative (queryVids[0]));
  }
  // Only factors adjacent to the first query variable can contain it.
  const FacNodes& candidates = fg.getVarNode (queryVids[0])->neighbors();
  FacNode* host = nullptr;
  for (size_t k = 0; k < candidates.size() && host == nullptr; k++) {
    if (candidates[k]->factor().contains (queryVids)) {
      host = candidates[k];
    }
  }
  if (host == nullptr) {
    return GroundSolver::getJointByConditioning (
        GroundSolverType::CBP, fg, queryVids);
  }
  VarIds reprArgs;
  for (size_t k = 0; k < queryVids.size(); k++) {
    reprArgs.push_back (getRepresentative (queryVids[k]));
  }
  FacNode* reprFac = getRepresentative (host);
  assert (reprFac);
  return solver_->getFactorJoint (reprFac, reprArgs);
}
void
CountingBp::findIdenticalFactors()
2012-05-23 14:56:01 +01:00
{
const FacNodes& facNodes = fg.facNodes();
2012-12-27 23:21:32 +00:00
if (fif_ == false || facNodes.size() == 1) {
return;
}
for (size_t i = 0; i < facNodes.size(); i++) {
facNodes[i]->factor().setDistId (Util::maxUnsigned());
2012-05-23 14:56:01 +01:00
}
unsigned groupCount = 1;
for (size_t i = 0; i < facNodes.size() - 1; i++) {
Factor& f1 = facNodes[i]->factor();
if (f1.distId() != Util::maxUnsigned()) {
continue;
}
f1.setDistId (groupCount);
for (size_t j = i + 1; j < facNodes.size(); j++) {
Factor& f2 = facNodes[j]->factor();
if (f2.distId() != Util::maxUnsigned()) {
continue;
}
if (f1.size() == f2.size() &&
f1.ranges() == f2.ranges() &&
f1.params() == f2.params()) {
f2.setDistId (groupCount);
2012-05-23 14:56:01 +01:00
}
}
groupCount ++;
2012-05-23 14:56:01 +01:00
}
}
void
CountingBp::setInitialColors (void)
2012-05-23 14:56:01 +01:00
{
varColors_.resize (fg.nrVarNodes());
facColors_.resize (fg.nrFacNodes());
// create the initial variable colors
VarColorMap colorMap;
const VarNodes& varNodes = fg.varNodes();
for (size_t i = 0; i < varNodes.size(); i++) {
unsigned range = varNodes[i]->range();
VarColorMap::iterator it = colorMap.find (range);
if (it == colorMap.end()) {
2013-02-07 13:37:15 +00:00
it = colorMap.insert (std::make_pair (
2012-12-20 23:19:10 +00:00
range, Colors (range + 1, -1))).first;
2012-05-23 14:56:01 +01:00
}
unsigned idx = varNodes[i]->hasEvidence()
? varNodes[i]->getEvidence()
: range;
Colors& stateColors = it->second;
if (stateColors[idx] == -1) {
stateColors[idx] = getNewColor();
}
setColor (varNodes[i], stateColors[idx]);
2012-05-23 14:56:01 +01:00
}
const FacNodes& facNodes = fg.facNodes();
// create the initial factor colors
DistColorMap distColors;
for (size_t i = 0; i < facNodes.size(); i++) {
unsigned distId = facNodes[i]->factor().distId();
DistColorMap::iterator it = distColors.find (distId);
if (it == distColors.end()) {
2013-02-07 13:37:15 +00:00
it = distColors.insert (std::make_pair (
distId, getNewColor())).first;
2012-05-23 14:56:01 +01:00
}
setColor (facNodes[i], it->second);
}
}
2012-05-23 14:56:01 +01:00
void
CountingBp::createGroups (void)
{
VarSignMap varGroups;
FacSignMap facGroups;
unsigned nIters = 0;
bool groupsHaveChanged = true;
const VarNodes& varNodes = fg.varNodes();
const FacNodes& facNodes = fg.facNodes();
while (groupsHaveChanged || nIters == 1) {
nIters ++;
// set a new color to the variables with the same signature
size_t prevVarGroupsSize = varGroups.size();
varGroups.clear();
for (size_t i = 0; i < varNodes.size(); i++) {
const VarSignature& signature = getSignature (varNodes[i]);
VarSignMap::iterator it = varGroups.find (signature);
if (it == varGroups.end()) {
2013-02-07 13:37:15 +00:00
it = varGroups.insert (std::make_pair (
signature, VarNodes())).first;
}
it->second.push_back (varNodes[i]);
2012-05-23 14:56:01 +01:00
}
for (VarSignMap::iterator it = varGroups.begin();
it != varGroups.end(); ++it) {
Color newColor = getNewColor();
VarNodes& groupMembers = it->second;
for (size_t i = 0; i < groupMembers.size(); i++) {
setColor (groupMembers[i], newColor);
}
2012-05-23 14:56:01 +01:00
}
size_t prevFactorGroupsSize = facGroups.size();
facGroups.clear();
// set a new color to the factors with the same signature
for (size_t i = 0; i < facNodes.size(); i++) {
const FacSignature& signature = getSignature (facNodes[i]);
FacSignMap::iterator it = facGroups.find (signature);
if (it == facGroups.end()) {
2013-02-07 13:37:15 +00:00
it = facGroups.insert (std::make_pair (
signature, FacNodes())).first;
2012-05-23 14:56:01 +01:00
}
it->second.push_back (facNodes[i]);
2012-05-23 14:56:01 +01:00
}
for (FacSignMap::iterator it = facGroups.begin();
it != facGroups.end(); ++it) {
Color newColor = getNewColor();
FacNodes& groupMembers = it->second;
for (size_t i = 0; i < groupMembers.size(); i++) {
setColor (groupMembers[i], newColor);
2012-05-23 14:56:01 +01:00
}
}
groupsHaveChanged = prevVarGroupsSize != varGroups.size()
|| prevFactorGroupsSize != facGroups.size();
2012-05-23 14:56:01 +01:00
}
// printGroups (varGroups, facGroups);
createClusters (varGroups, facGroups);
2012-05-23 14:56:01 +01:00
}
void
CountingBp::createClusters (
const VarSignMap& varGroups,
const FacSignMap& facGroups)
2012-05-23 14:56:01 +01:00
{
varClusters_.reserve (varGroups.size());
for (VarSignMap::const_iterator it = varGroups.begin();
it != varGroups.end(); ++it) {
const VarNodes& groupVars = it->second;
VarCluster* vc = new VarCluster (groupVars);
for (size_t i = 0; i < groupVars.size(); i++) {
2013-02-07 13:37:15 +00:00
varClusterMap_.insert (std::make_pair (
groupVars[i]->varId(), vc));
2012-05-23 14:56:01 +01:00
}
varClusters_.push_back (vc);
2012-05-23 14:56:01 +01:00
}
facClusters_.reserve (facGroups.size());
for (FacSignMap::const_iterator it = facGroups.begin();
it != facGroups.end(); ++it) {
FacNode* groupFactor = it->second[0];
const VarNodes& neighs = groupFactor->neighbors();
VarClusters varClusters;
varClusters.reserve (neighs.size());
for (size_t i = 0; i < neighs.size(); i++) {
VarId vid = neighs[i]->varId();
varClusters.push_back (varClusterMap_.find (vid)->second);
}
facClusters_.push_back (new FacCluster (it->second, varClusters));
2012-05-23 14:56:01 +01:00
}
}
// Signature of a variable node. It holds one (factor color, position of
// this variable in that factor) pair per neighboring factor, sorted so
// that neighbor order does not matter. The variable's own color comes
// last, paired with 0.
VarSignature
CountingBp::getSignature (const VarNode* varNode)
{
  const FacNodes& facs = varNode->neighbors();
  const VarId vid = varNode->varId();
  VarSignature signature;
  signature.reserve (facs.size() + 1);
  for (size_t k = 0; k < facs.size(); k++) {
    signature.push_back (std::make_pair (
        getColor (facs[k]), facs[k]->factor().indexOf (vid)));
  }
  std::sort (signature.begin(), signature.end());
  signature.push_back (std::make_pair (getColor (varNode), 0));
  return signature;
}
// Signature of a factor node: its neighbors' colors in argument order
// (not sorted, because position matters), followed by its own color.
FacSignature
CountingBp::getSignature (const FacNode* facNode)
{
  const VarNodes& args = facNode->neighbors();
  FacSignature signature;
  signature.reserve (args.size() + 1);
  for (size_t k = 0; k < args.size(); k++) {
    signature.push_back (getColor (args[k]));
  }
  signature.push_back (getColor (facNode));
  return signature;
}
// Maps a ground VarId to the VarId of its cluster's representative.
// `vid` must be a known ground variable; this is only checked by assert.
VarId
CountingBp::getRepresentative (VarId vid)
{
  assert (Util::contains (varClusterMap_, vid));
  return varClusterMap_.find (vid)->second->representative()->varId();
}
// Returns the representative of the factor cluster that contains the
// ground factor `fn`, found by a linear scan over all clusters.
// Returns a null pointer when no cluster contains it.
FacNode*
CountingBp::getRepresentative (FacNode* fn)
{
  for (auto cluster : facClusters_) {
    if (Util::contains (cluster->members(), fn)) {
      return cluster->representative();
    }
  }
  return nullptr;
}
// Builds a new factor graph with one node per cluster. Each cluster's
// new node is also stored as that cluster's representative.
// - Variable nodes are created from the cluster's first member.
// - Factor nodes copy params and distId from the cluster's first ground
//   factor, with their scope remapped to the representative variables.
// The returned graph is heap-allocated; the destructor deletes it.
FactorGraph*
CountingBp::getCompressedFactorGraph (void)
{
  // Named `compressed` so it does not shadow the ground-graph member `fg`.
  FactorGraph* compressed = new FactorGraph();

  for (auto cluster : varClusters_) {
    VarNode* repr = new VarNode (cluster->first());
    cluster->setRepresentative (repr);
    compressed->addVarNode (repr);
  }

  for (auto cluster : facClusters_) {
    Vars scope;
    const VarClusters& scopeClusters = cluster->varClusters();
    for (size_t k = 0; k < scopeClusters.size(); k++) {
      scope.push_back (scopeClusters[k]->representative());
    }
    const Factor& prototype = cluster->first()->factor();
    FacNode* repr = new FacNode (Factor (
        scope, prototype.params(), prototype.distId()));
    cluster->setRepresentative (repr);
    compressed->addFacNode (repr);
    for (size_t k = 0; k < scope.size(); k++) {
      compressed->addEdge (static_cast<VarNode*> (scope[k]), repr);
    }
  }
  return compressed;
}
2013-02-07 13:37:15 +00:00
std::vector<std::vector<unsigned>>
CountingBp::getWeights (void) const
{
2013-02-07 13:37:15 +00:00
std::vector<std::vector<unsigned>> weights;
weights.reserve (facClusters_.size());
for (size_t i = 0; i < facClusters_.size(); i++) {
const VarClusters& neighs = facClusters_[i]->varClusters();
weights.push_back ({ });
weights.back().reserve (neighs.size());
for (size_t j = 0; j < neighs.size(); j++) {
weights.back().push_back (getWeight (
facClusters_[i], neighs[j], j));
2012-05-23 14:56:01 +01:00
}
}
return weights;
}
unsigned
CountingBp::getWeight (
const FacCluster* fc,
const VarCluster* vc,
size_t index) const
{
unsigned weight = 0;
VarId reprVid = vc->representative()->varId();
VarNode* groundVar = fg.getVarNode (reprVid);
const FacNodes& neighs = groundVar->neighbors();
for (size_t i = 0; i < neighs.size(); i++) {
FacNodes::const_iterator it;
it = std::find (fc->members().begin(), fc->members().end(), neighs[i]);
if (it != fc->members().end() &&
(*it)->factor().indexOf (reprVid) == index) {
weight ++;
}
2012-05-23 14:56:01 +01:00
}
return weight;
2012-05-23 14:56:01 +01:00
}
void
CountingBp::printGroups (
const VarSignMap& varGroups,
const FacSignMap& facGroups) const
2012-05-23 14:56:01 +01:00
{
unsigned count = 1;
2013-02-07 13:37:15 +00:00
std::cout << "variable groups:" << std::endl;
for (VarSignMap::const_iterator it = varGroups.begin();
it != varGroups.end(); ++it) {
const VarNodes& groupMembers = it->second;
if (groupMembers.size() > 0) {
2013-02-07 13:37:15 +00:00
std::cout << count << ": " ;
for (size_t i = 0; i < groupMembers.size(); i++) {
2013-02-07 13:37:15 +00:00
std::cout << groupMembers[i]->label() << " " ;
}
count ++;
2013-02-07 13:37:15 +00:00
std::cout << std::endl;
}
}
count = 1;
2013-02-07 13:37:15 +00:00
std::cout << std::endl << "factor groups:" << std::endl;
for (FacSignMap::const_iterator it = facGroups.begin();
it != facGroups.end(); ++it) {
const FacNodes& groupMembers = it->second;
if (groupMembers.size() > 0) {
2013-02-07 13:37:15 +00:00
std::cout << ++count << ": " ;
for (size_t i = 0; i < groupMembers.size(); i++) {
2013-02-07 13:37:15 +00:00
std::cout << groupMembers[i]->getLabel() << " " ;
}
count ++;
2013-02-07 13:37:15 +00:00
std::cout << std::endl;
}
2012-05-23 14:56:01 +01:00
}
}
} // namespace Horus
2013-02-07 23:53:13 +00:00