This repository has been archived on 2023-08-20. You can view files and clone it, but cannot push or open issues or pull requests.
yap-6.3/packages/CLPBN/clpbn/bp/CFactorGraph.cpp

346 lines
9.5 KiB
C++
Raw Normal View History

2011-12-12 15:29:51 +00:00
#include "CFactorGraph.h"
#include "Factor.h"
2012-03-22 11:33:24 +00:00
// Class-wide switch, enabled by default. When it is false,
// findIdenticalFactors() returns immediately without touching any factor.
bool CFactorGraph::checkForIdenticalFactors = true;
2011-12-12 15:29:51 +00:00
2012-04-16 21:42:14 +01:00
// Prepares the per-node signature and color storage for the ground factor
// graph `fg`, then runs findIdenticalFactors(), setInitialColors() and
// createGroups() in that order. Only the address of `fg` is kept (in
// groundFg_); no copy of the graph is made here.
CFactorGraph::CFactorGraph (const FactorGraph& fg)
    : freeColor_(0), groundFg_(&fg)
{
  // Variable signatures: two slots per neighboring factor plus one slot
  // for the variable itself (see getSignature (const VarNode*)).
  const VarNodes& vars = fg.varNodes();
  varSignatures_.reserve (vars.size());
  for (unsigned v = 0; v < vars.size(); v++) {
    const unsigned nSlots = 1 + (2 * vars[v]->neighbors().size());
    varSignatures_.push_back (Signature (nSlots));
  }

  // Factor signatures: one slot per neighboring variable plus one slot
  // for the factor itself (see getSignature (const FacNode*)).
  const FacNodes& facs = fg.facNodes();
  facSignatures_.reserve (facs.size());
  for (unsigned f = 0; f < facs.size(); f++) {
    const unsigned nSlots = 1 + facs[f]->neighbors().size();
    facSignatures_.push_back (Signature (nSlots));
  }

  // One color entry per ground node.
  varColors_.resize (vars.size());
  facColors_.resize (facs.size());

  findIdenticalFactors();
  setInitialColors();
  createGroups();
}
// Releases every cluster object stored in varClusters_ and facClusters_
// (they are allocated with `new` in createClusters()).
CFactorGraph::~CFactorGraph (void)
{
  unsigned pos = 0;
  while (pos < varClusters_.size()) {
    delete varClusters_[pos];
    pos ++;
  }
  pos = 0;
  while (pos < facClusters_.size()) {
    delete facClusters_[pos];
    pos ++;
  }
}
2012-04-16 21:42:14 +01:00
void
CFactorGraph::findIdenticalFactors()
{
if (checkForIdenticalFactors == false) {
return;
}
const FacNodes& facNodes = groundFg_->facNodes();
for (unsigned i = 0; i < facNodes.size(); i++) {
facNodes[i]->factor().setDistId (Util::maxUnsigned());
}
unsigned groupCount = 1;
for (unsigned i = 0; i < facNodes.size(); i++) {
Factor& f1 = facNodes[i]->factor();
if (f1.distId() != Util::maxUnsigned()) {
continue;
}
f1.setDistId (groupCount);
for (unsigned j = i + 1; j < facNodes.size(); j++) {
Factor& f2 = facNodes[j]->factor();
if (f2.distId() != Util::maxUnsigned()) {
continue;
}
if (f1.size() == f2.size() &&
f1.ranges() == f2.ranges() &&
f1.params() == f2.params()) {
f2.setDistId (groupCount);
}
}
groupCount ++;
}
}
2011-12-12 15:29:51 +00:00
void
CFactorGraph::setInitialColors (void)
{
// create the initial variable colors
VarColorMap colorMap;
2012-04-05 23:00:48 +01:00
const VarNodes& varNodes = groundFg_->varNodes();
2011-12-12 15:29:51 +00:00
for (unsigned i = 0; i < varNodes.size(); i++) {
2012-04-05 18:38:56 +01:00
unsigned dsize = varNodes[i]->range();
2011-12-12 15:29:51 +00:00
VarColorMap::iterator it = colorMap.find (dsize);
if (it == colorMap.end()) {
it = colorMap.insert (make_pair (
2012-04-13 15:56:37 +01:00
dsize, Colors (dsize+1,-1))).first;
2011-12-12 15:29:51 +00:00
}
unsigned idx;
if (varNodes[i]->hasEvidence()) {
idx = varNodes[i]->getEvidence();
} else {
idx = dsize;
}
2012-04-13 15:56:37 +01:00
Colors& stateColors = it->second;
2011-12-12 15:29:51 +00:00
if (stateColors[idx] == -1) {
stateColors[idx] = getFreeColor();
}
setColor (varNodes[i], stateColors[idx]);
}
2012-04-10 11:51:56 +01:00
const FacNodes& facNodes = groundFg_->facNodes();
2011-12-12 15:29:51 +00:00
// create the initial factor colors
DistColorMap distColors;
for (unsigned i = 0; i < facNodes.size(); i++) {
unsigned distId = facNodes[i]->factor().distId();
2012-03-31 23:27:37 +01:00
DistColorMap::iterator it = distColors.find (distId);
2011-12-12 15:29:51 +00:00
if (it == distColors.end()) {
2012-03-31 23:27:37 +01:00
it = distColors.insert (make_pair (distId, getFreeColor())).first;
2011-12-12 15:29:51 +00:00
}
setColor (facNodes[i], it->second);
}
}
void
CFactorGraph::createGroups (void)
{
2012-03-31 23:27:37 +01:00
VarSignMap varGroups;
2012-04-10 20:43:08 +01:00
FacSignMap facGroups;
2011-12-12 15:29:51 +00:00
unsigned nIters = 0;
bool groupsHaveChanged = true;
2012-04-05 23:00:48 +01:00
const VarNodes& varNodes = groundFg_->varNodes();
2012-04-10 11:51:56 +01:00
const FacNodes& facNodes = groundFg_->facNodes();
2011-12-12 15:29:51 +00:00
while (groupsHaveChanged || nIters == 1) {
nIters ++;
2012-04-10 20:43:08 +01:00
unsigned prevFactorGroupsSize = facGroups.size();
facGroups.clear();
2011-12-12 15:29:51 +00:00
// set a new color to the factors with the same signature
for (unsigned i = 0; i < facNodes.size(); i++) {
const Signature& signature = getSignature (facNodes[i]);
2012-04-10 20:43:08 +01:00
FacSignMap::iterator it = facGroups.find (signature);
if (it == facGroups.end()) {
it = facGroups.insert (make_pair (signature, FacNodes())).first;
2011-12-12 15:29:51 +00:00
}
it->second.push_back (facNodes[i]);
}
2012-04-10 20:43:08 +01:00
for (FacSignMap::iterator it = facGroups.begin();
it != facGroups.end(); it++) {
2011-12-12 15:29:51 +00:00
Color newColor = getFreeColor();
2012-04-10 11:51:56 +01:00
FacNodes& groupMembers = it->second;
2011-12-12 15:29:51 +00:00
for (unsigned i = 0; i < groupMembers.size(); i++) {
setColor (groupMembers[i], newColor);
}
}
// set a new color to the variables with the same signature
unsigned prevVarGroupsSize = varGroups.size();
varGroups.clear();
for (unsigned i = 0; i < varNodes.size(); i++) {
const Signature& signature = getSignature (varNodes[i]);
VarSignMap::iterator it = varGroups.find (signature);
if (it == varGroups.end()) {
2012-04-05 23:00:48 +01:00
it = varGroups.insert (make_pair (signature, VarNodes())).first;
2011-12-12 15:29:51 +00:00
}
it->second.push_back (varNodes[i]);
}
for (VarSignMap::iterator it = varGroups.begin();
it != varGroups.end(); it++) {
Color newColor = getFreeColor();
2012-04-05 23:00:48 +01:00
VarNodes& groupMembers = it->second;
2011-12-12 15:29:51 +00:00
for (unsigned i = 0; i < groupMembers.size(); i++) {
setColor (groupMembers[i], newColor);
}
}
groupsHaveChanged = prevVarGroupsSize != varGroups.size()
2012-04-10 20:43:08 +01:00
|| prevFactorGroupsSize != facGroups.size();
2011-12-12 15:29:51 +00:00
}
// printGroups (varGroups, facGroups);
2012-04-10 20:43:08 +01:00
createClusters (varGroups, facGroups);
2011-12-12 15:29:51 +00:00
}
void
2012-03-31 23:27:37 +01:00
CFactorGraph::createClusters (
const VarSignMap& varGroups,
2012-04-10 20:43:08 +01:00
const FacSignMap& facGroups)
2011-12-12 15:29:51 +00:00
{
varClusters_.reserve (varGroups.size());
for (VarSignMap::const_iterator it = varGroups.begin();
it != varGroups.end(); it++) {
2012-04-05 23:00:48 +01:00
const VarNodes& groupVars = it->second;
2011-12-12 15:29:51 +00:00
VarCluster* vc = new VarCluster (groupVars);
for (unsigned i = 0; i < groupVars.size(); i++) {
vid2VarCluster_.insert (make_pair (groupVars[i]->varId(), vc));
}
varClusters_.push_back (vc);
}
2012-04-10 20:43:08 +01:00
facClusters_.reserve (facGroups.size());
for (FacSignMap::const_iterator it = facGroups.begin();
it != facGroups.end(); it++) {
2012-04-10 11:51:56 +01:00
FacNode* groupFactor = it->second[0];
2012-04-05 23:00:48 +01:00
const VarNodes& neighs = groupFactor->neighbors();
2012-04-10 20:43:08 +01:00
VarClusters varClusters;
2011-12-12 15:29:51 +00:00
varClusters.reserve (neighs.size());
for (unsigned i = 0; i < neighs.size(); i++) {
VarId vid = neighs[i]->varId();
varClusters.push_back (vid2VarCluster_.find (vid)->second);
}
2012-03-22 11:33:24 +00:00
facClusters_.push_back (new FacCluster (it->second, varClusters));
2011-12-12 15:29:51 +00:00
}
}
// Refreshes and returns the stored signature of `varNode`.
// Layout of sign.colors, for a variable with k neighboring factors:
//   [2*i]   color of the i-th neighboring factor
//   [2*i+1] position of this variable inside that factor (indexOf)
//   [2*k]   current color of the variable itself
// The returned reference points into varSignatures_ and is overwritten by
// the next call for the same node.
const Signature&
CFactorGraph::getSignature (const VarNode* varNode)
{
  Signature& sign = varSignatures_[varNode->getIndex()];
  const FacNodes& facs = varNode->neighbors();
  for (unsigned f = 0; f < facs.size(); f++) {
    sign.colors[2 * f]     = getColor (facs[f]);
    sign.colors[2 * f + 1] = facs[f]->factor().indexOf (varNode->varId());
  }
  sign.colors[2 * facs.size()] = getColor (varNode);
  return sign;
}
// Refreshes and returns the stored signature of `facNode`.
// Layout of sign.colors, for a factor with k neighboring variables:
//   [i] color of the i-th neighboring variable
//   [k] current color of the factor itself
// The returned reference points into facSignatures_ and is overwritten by
// the next call for the same node.
const Signature&
CFactorGraph::getSignature (const FacNode* facNode)
{
  Signature& sign = facSignatures_[facNode->getIndex()];
  const VarNodes& vars = facNode->neighbors();
  for (unsigned v = 0; v < vars.size(); v++) {
    sign.colors[v] = getColor (vars[v]);
  }
  sign.colors[vars.size()] = getColor (facNode);
  return sign;
}
// Builds a new FactorGraph with one variable node per variable cluster and
// one factor node per factor cluster.
//  - Each new variable node is constructed from the first member of its
//    cluster and stored as that cluster's representative.
//  - Each new factor takes params() and distId() from the first member of
//    its cluster, ranges over the representatives of the linked variable
//    clusters, and is stored as the factor cluster's representative.
// Returns a graph allocated with `new`; nothing in this function releases
// it. Although the method is const, it updates the representatives held by
// the cluster objects.
FactorGraph*
CFactorGraph::getGroundFactorGraph (void) const
{
  FactorGraph* result = new FactorGraph();

  for (unsigned c = 0; c < varClusters_.size(); c++) {
    VarNode* rep = new VarNode (varClusters_[c]->members()[0]);
    varClusters_[c]->setRepresentative (rep);
    result->addVarNode (rep);
  }

  for (unsigned c = 0; c < facClusters_.size(); c++) {
    // Collect the representative of every linked variable cluster.
    const VarClusters& linked = facClusters_[c]->varClusters();
    Vars args;
    args.reserve (linked.size());
    for (unsigned k = 0; k < linked.size(); k++) {
      VarNode* rep = linked[k]->getRepresentative();
      args.push_back (rep);
    }

    FacNode* facRep = new FacNode (Factor (
        args,
        facClusters_[c]->members()[0]->factor().params(),
        facClusters_[c]->members()[0]->factor().distId()));
    facClusters_[c]->setRepresentative (facRep);
    result->addFacNode (facRep);

    // Connect the new factor to each of its argument variables.
    for (unsigned k = 0; k < args.size(); k++) {
      result->addEdge (static_cast<VarNode*> (args[k]), facRep);
    }
  }
  return result;
}
unsigned
CFactorGraph::getEdgeCount (
2012-03-22 11:33:24 +00:00
const FacCluster* fc,
const VarCluster* vc) const
2011-12-12 15:29:51 +00:00
{
unsigned count = 0;
2012-04-16 21:42:14 +01:00
VarId vid = vc->members().front()->varId();
const FacNodes& members = fc->members();
for (unsigned i = 0; i < members.size(); i++) {
if (members[i]->factor().contains (vid)) {
2011-12-12 15:29:51 +00:00
count ++;
}
}
2012-04-16 21:42:14 +01:00
if (Constants::DEBUG > 0) {
const VarNodes& vars = vc->members();
for (unsigned i = 1; i < vars.size(); i++) {
VarId vid = vars[i]->varId();
unsigned count2 = 0;
for (unsigned i = 0; i < members.size(); i++) {
if (members[i]->factor().contains (vid)) {
count2 ++;
}
}
assert (count == count2);
}
}
2011-12-12 15:29:51 +00:00
return count;
}
void
2012-03-31 23:27:37 +01:00
CFactorGraph::printGroups (
const VarSignMap& varGroups,
2012-04-10 20:43:08 +01:00
const FacSignMap& facGroups) const
2011-12-12 15:29:51 +00:00
{
unsigned count = 1;
cout << "variable groups:" << endl;
for (VarSignMap::const_iterator it = varGroups.begin();
it != varGroups.end(); it++) {
2012-04-05 23:00:48 +01:00
const VarNodes& groupMembers = it->second;
2011-12-12 15:29:51 +00:00
if (groupMembers.size() > 0) {
cout << count << ": " ;
for (unsigned i = 0; i < groupMembers.size(); i++) {
cout << groupMembers[i]->label() << " " ;
}
count ++;
cout << endl;
}
}
count = 1;
cout << endl << "factor groups:" << endl;
2012-04-10 20:43:08 +01:00
for (FacSignMap::const_iterator it = facGroups.begin();
it != facGroups.end(); it++) {
2012-04-10 11:51:56 +01:00
const FacNodes& groupMembers = it->second;
2011-12-12 15:29:51 +00:00
if (groupMembers.size() > 0) {
cout << ++count << ": " ;
for (unsigned i = 0; i < groupMembers.size(); i++) {
cout << groupMembers[i]->getLabel() << " " ;
}
count ++;
cout << endl;
}
}
}