refactorings
This commit is contained in:
parent
78e86a6330
commit
8697fcd2b4
@ -14,7 +14,7 @@
|
|||||||
|
|
||||||
BpSolver::BpSolver (const FactorGraph& fg) : Solver (fg)
|
BpSolver::BpSolver (const FactorGraph& fg) : Solver (fg)
|
||||||
{
|
{
|
||||||
factorGraph_ = &fg;
|
fg_ = &fg;
|
||||||
runned_ = false;
|
runned_ = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -54,8 +54,8 @@ BpSolver::getPosterioriOf (VarId vid)
|
|||||||
if (runned_ == false) {
|
if (runned_ == false) {
|
||||||
runSolver();
|
runSolver();
|
||||||
}
|
}
|
||||||
assert (factorGraph_->getVarNode (vid));
|
assert (fg_->getVarNode (vid));
|
||||||
VarNode* var = factorGraph_->getVarNode (vid);
|
VarNode* var = fg_->getVarNode (vid);
|
||||||
Params probs;
|
Params probs;
|
||||||
if (var->hasEvidence()) {
|
if (var->hasEvidence()) {
|
||||||
probs.resize (var->range(), LogAware::noEvidence());
|
probs.resize (var->range(), LogAware::noEvidence());
|
||||||
@ -88,7 +88,7 @@ BpSolver::getJointDistributionOf (const VarIds& jointVarIds)
|
|||||||
runSolver();
|
runSolver();
|
||||||
}
|
}
|
||||||
int idx = -1;
|
int idx = -1;
|
||||||
VarNode* vn = factorGraph_->getVarNode (jointVarIds[0]);
|
VarNode* vn = fg_->getVarNode (jointVarIds[0]);
|
||||||
const FacNodes& facNodes = vn->neighbors();
|
const FacNodes& facNodes = vn->neighbors();
|
||||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||||
if (facNodes[i]->factor().contains (jointVarIds)) {
|
if (facNodes[i]->factor().contains (jointVarIds)) {
|
||||||
@ -121,37 +121,64 @@ BpSolver::getJointDistributionOf (const VarIds& jointVarIds)
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
BpSolver::initializeSolver (void)
|
BpSolver::runSolver (void)
|
||||||
{
|
{
|
||||||
const VarNodes& varNodes = factorGraph_->varNodes();
|
clock_t start;
|
||||||
for (unsigned i = 0; i < varsI_.size(); i++) {
|
if (Constants::COLLECT_STATS) {
|
||||||
delete varsI_[i];
|
start = clock();
|
||||||
}
|
}
|
||||||
varsI_.reserve (varNodes.size());
|
initializeSolver();
|
||||||
for (unsigned i = 0; i < varNodes.size(); i++) {
|
nIters_ = 0;
|
||||||
varsI_.push_back (new SPNodeInfo());
|
while (!converged() && nIters_ < BpOptions::maxIter) {
|
||||||
|
nIters_ ++;
|
||||||
|
if (Constants::DEBUG >= 2) {
|
||||||
|
Util::printHeader (" Iteration " + nIters_);
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
switch (BpOptions::schedule) {
|
||||||
|
case BpOptions::Schedule::SEQ_RANDOM:
|
||||||
|
random_shuffle (links_.begin(), links_.end());
|
||||||
|
// no break
|
||||||
|
case BpOptions::Schedule::SEQ_FIXED:
|
||||||
|
for (unsigned i = 0; i < links_.size(); i++) {
|
||||||
|
calculateAndUpdateMessage (links_[i]);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case BpOptions::Schedule::PARALLEL:
|
||||||
|
for (unsigned i = 0; i < links_.size(); i++) {
|
||||||
|
calculateMessage (links_[i]);
|
||||||
|
}
|
||||||
|
for (unsigned i = 0; i < links_.size(); i++) {
|
||||||
|
updateMessage(links_[i]);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case BpOptions::Schedule::MAX_RESIDUAL:
|
||||||
|
maxResidualSchedule();
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if (Constants::DEBUG >= 2) {
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
if (Constants::DEBUG >= 2) {
|
||||||
const FacNodes& facNodes = factorGraph_->facNodes();
|
cout << endl;
|
||||||
for (unsigned i = 0; i < facsI_.size(); i++) {
|
if (nIters_ < BpOptions::maxIter) {
|
||||||
delete facsI_[i];
|
cout << "Sum-Product converged in " ;
|
||||||
|
cout << nIters_ << " iterations" << endl;
|
||||||
|
} else {
|
||||||
|
cout << "The maximum number of iterations was hit, terminating..." ;
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
facsI_.reserve (facNodes.size());
|
unsigned size = fg_->varNodes().size();
|
||||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
if (Constants::COLLECT_STATS) {
|
||||||
facsI_.push_back (new SPNodeInfo());
|
unsigned nIters = 0;
|
||||||
}
|
bool loopy = fg_->isTree() == false;
|
||||||
|
if (loopy) nIters = nIters_;
|
||||||
for (unsigned i = 0; i < links_.size(); i++) {
|
double time = (double (clock() - start)) / CLOCKS_PER_SEC;
|
||||||
delete links_[i];
|
Statistics::updateStatistics (size, loopy, nIters, time);
|
||||||
}
|
|
||||||
createLinks();
|
|
||||||
|
|
||||||
for (unsigned i = 0; i < links_.size(); i++) {
|
|
||||||
FacNode* src = links_[i]->getFactor();
|
|
||||||
VarNode* dst = links_[i]->getVariable();
|
|
||||||
ninf (dst)->addSpLink (links_[i]);
|
|
||||||
ninf (src)->addSpLink (links_[i]);
|
|
||||||
}
|
}
|
||||||
|
runned_ = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -159,7 +186,7 @@ BpSolver::initializeSolver (void)
|
|||||||
void
|
void
|
||||||
BpSolver::createLinks (void)
|
BpSolver::createLinks (void)
|
||||||
{
|
{
|
||||||
const FacNodes& facNodes = factorGraph_->facNodes();
|
const FacNodes& facNodes = fg_->facNodes();
|
||||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||||
const VarNodes& neighbors = facNodes[i]->neighbors();
|
const VarNodes& neighbors = facNodes[i]->neighbors();
|
||||||
for (unsigned j = 0; j < neighbors.size(); j++) {
|
for (unsigned j = 0; j < neighbors.size(); j++) {
|
||||||
@ -342,11 +369,11 @@ BpSolver::getJointByConditioning (const VarIds& jointVarIds) const
|
|||||||
{
|
{
|
||||||
VarNodes jointVars;
|
VarNodes jointVars;
|
||||||
for (unsigned i = 0; i < jointVarIds.size(); i++) {
|
for (unsigned i = 0; i < jointVarIds.size(); i++) {
|
||||||
assert (factorGraph_->getVarNode (jointVarIds[i]));
|
assert (fg_->getVarNode (jointVarIds[i]));
|
||||||
jointVars.push_back (factorGraph_->getVarNode (jointVarIds[i]));
|
jointVars.push_back (fg_->getVarNode (jointVarIds[i]));
|
||||||
}
|
}
|
||||||
|
|
||||||
FactorGraph* fg = new FactorGraph (*factorGraph_);
|
FactorGraph* fg = new FactorGraph (*fg_);
|
||||||
BpSolver solver (*fg);
|
BpSolver solver (*fg);
|
||||||
solver.runSolver();
|
solver.runSolver();
|
||||||
Params prevBeliefs = solver.getPosterioriOf (jointVarIds[0]);
|
Params prevBeliefs = solver.getPosterioriOf (jointVarIds[0]);
|
||||||
@ -390,93 +417,24 @@ BpSolver::getJointByConditioning (const VarIds& jointVarIds) const
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
BpSolver::printLinkInformation (void) const
|
BpSolver::initializeSolver (void)
|
||||||
{
|
{
|
||||||
|
const VarNodes& varNodes = fg_->varNodes();
|
||||||
|
varsI_.reserve (varNodes.size());
|
||||||
|
for (unsigned i = 0; i < varNodes.size(); i++) {
|
||||||
|
varsI_.push_back (new SPNodeInfo());
|
||||||
|
}
|
||||||
|
const FacNodes& facNodes = fg_->facNodes();
|
||||||
|
facsI_.reserve (facNodes.size());
|
||||||
|
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||||
|
facsI_.push_back (new SPNodeInfo());
|
||||||
|
}
|
||||||
|
createLinks();
|
||||||
for (unsigned i = 0; i < links_.size(); i++) {
|
for (unsigned i = 0; i < links_.size(); i++) {
|
||||||
SpLink* l = links_[i];
|
FacNode* src = links_[i]->getFactor();
|
||||||
cout << l->toString() << ":" << endl;
|
VarNode* dst = links_[i]->getVariable();
|
||||||
cout << " curr msg = " ;
|
ninf (dst)->addSpLink (links_[i]);
|
||||||
cout << l->getMessage() << endl;
|
ninf (src)->addSpLink (links_[i]);
|
||||||
cout << " next msg = " ;
|
|
||||||
cout << l->getNextMessage() << endl;
|
|
||||||
cout << " residual = " << l->getResidual() << endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
BpSolver::runSolver (void)
|
|
||||||
{
|
|
||||||
clock_t start;
|
|
||||||
if (Constants::COLLECT_STATS) {
|
|
||||||
start = clock();
|
|
||||||
}
|
|
||||||
runLoopySolver();
|
|
||||||
if (Constants::DEBUG >= 2) {
|
|
||||||
cout << endl;
|
|
||||||
if (nIters_ < BpOptions::maxIter) {
|
|
||||||
cout << "Sum-Product converged in " ;
|
|
||||||
cout << nIters_ << " iterations" << endl;
|
|
||||||
} else {
|
|
||||||
cout << "The maximum number of iterations was hit, terminating..." ;
|
|
||||||
cout << endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
unsigned size = factorGraph_->varNodes().size();
|
|
||||||
if (Constants::COLLECT_STATS) {
|
|
||||||
unsigned nIters = 0;
|
|
||||||
bool loopy = factorGraph_->isTree() == false;
|
|
||||||
if (loopy) nIters = nIters_;
|
|
||||||
double time = (double (clock() - start)) / CLOCKS_PER_SEC;
|
|
||||||
Statistics::updateStatistics (size, loopy, nIters, time);
|
|
||||||
}
|
|
||||||
runned_ = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
BpSolver::runLoopySolver (void)
|
|
||||||
{
|
|
||||||
initializeSolver();
|
|
||||||
nIters_ = 0;
|
|
||||||
|
|
||||||
while (!converged() && nIters_ < BpOptions::maxIter) {
|
|
||||||
|
|
||||||
nIters_ ++;
|
|
||||||
if (Constants::DEBUG >= 2) {
|
|
||||||
Util::printHeader (" Iteration " + nIters_);
|
|
||||||
cout << endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
switch (BpOptions::schedule) {
|
|
||||||
case BpOptions::Schedule::SEQ_RANDOM:
|
|
||||||
random_shuffle (links_.begin(), links_.end());
|
|
||||||
// no break
|
|
||||||
|
|
||||||
case BpOptions::Schedule::SEQ_FIXED:
|
|
||||||
for (unsigned i = 0; i < links_.size(); i++) {
|
|
||||||
calculateAndUpdateMessage (links_[i]);
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
|
|
||||||
case BpOptions::Schedule::PARALLEL:
|
|
||||||
for (unsigned i = 0; i < links_.size(); i++) {
|
|
||||||
calculateMessage (links_[i]);
|
|
||||||
}
|
|
||||||
for (unsigned i = 0; i < links_.size(); i++) {
|
|
||||||
updateMessage(links_[i]);
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
|
|
||||||
case BpOptions::Schedule::MAX_RESIDUAL:
|
|
||||||
maxResidualSchedule();
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
if (Constants::DEBUG >= 2) {
|
|
||||||
cout << endl;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -488,7 +446,7 @@ BpSolver::converged (void)
|
|||||||
if (links_.size() == 0) {
|
if (links_.size() == 0) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
if (nIters_ == 0 || nIters_ == 1) {
|
if (nIters_ <= 1) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
bool converged = true;
|
bool converged = true;
|
||||||
@ -514,3 +472,19 @@ BpSolver::converged (void)
|
|||||||
return converged;
|
return converged;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
BpSolver::printLinkInformation (void) const
|
||||||
|
{
|
||||||
|
for (unsigned i = 0; i < links_.size(); i++) {
|
||||||
|
SpLink* l = links_[i];
|
||||||
|
cout << l->toString() << ":" << endl;
|
||||||
|
cout << " curr msg = " ;
|
||||||
|
cout << l->getMessage() << endl;
|
||||||
|
cout << " next msg = " ;
|
||||||
|
cout << l->getNextMessage() << endl;
|
||||||
|
cout << " residual = " << l->getResidual() << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
#ifndef HORUS_BpSolver_H
|
#ifndef HORUS_BPSOLVER_H
|
||||||
#define HORUS_BpSolver_H
|
#define HORUS_BPSOLVER_H
|
||||||
|
|
||||||
#include <set>
|
#include <set>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
@ -102,7 +102,7 @@ class BpSolver : public Solver
|
|||||||
virtual Params getJointDistributionOf (const VarIds&);
|
virtual Params getJointDistributionOf (const VarIds&);
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
virtual void initializeSolver (void);
|
void runSolver (void);
|
||||||
|
|
||||||
virtual void createLinks (void);
|
virtual void createLinks (void);
|
||||||
|
|
||||||
@ -114,8 +114,6 @@ class BpSolver : public Solver
|
|||||||
|
|
||||||
virtual Params getJointByConditioning (const VarIds&) const;
|
virtual Params getJointByConditioning (const VarIds&) const;
|
||||||
|
|
||||||
virtual void printLinkInformation (void) const;
|
|
||||||
|
|
||||||
SPNodeInfo* ninf (const VarNode* var) const
|
SPNodeInfo* ninf (const VarNode* var) const
|
||||||
{
|
{
|
||||||
return varsI_[var->getIndex()];
|
return varsI_[var->getIndex()];
|
||||||
@ -170,7 +168,7 @@ class BpSolver : public Solver
|
|||||||
vector<SPNodeInfo*> varsI_;
|
vector<SPNodeInfo*> varsI_;
|
||||||
vector<SPNodeInfo*> facsI_;
|
vector<SPNodeInfo*> facsI_;
|
||||||
bool runned_;
|
bool runned_;
|
||||||
const FactorGraph* factorGraph_;
|
const FactorGraph* fg_;
|
||||||
|
|
||||||
typedef multiset<SpLink*, CompareResidual> SortedOrder;
|
typedef multiset<SpLink*, CompareResidual> SortedOrder;
|
||||||
SortedOrder sortedOrder_;
|
SortedOrder sortedOrder_;
|
||||||
@ -179,10 +177,12 @@ class BpSolver : public Solver
|
|||||||
SpLinkMap linkMap_;
|
SpLinkMap linkMap_;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void runSolver (void);
|
void initializeSolver (void);
|
||||||
void runLoopySolver (void);
|
|
||||||
bool converged (void);
|
bool converged (void);
|
||||||
|
|
||||||
|
void printLinkInformation (void) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif // HORUS_BpSolver_H
|
#endif // HORUS_BPSOLVER_H
|
||||||
|
|
||||||
|
@ -18,14 +18,14 @@ CFactorGraph::CFactorGraph (const FactorGraph& fg)
|
|||||||
}
|
}
|
||||||
|
|
||||||
const FacNodes& facNodes = fg.facNodes();
|
const FacNodes& facNodes = fg.facNodes();
|
||||||
factorSignatures_.reserve (facNodes.size());
|
facSignatures_.reserve (facNodes.size());
|
||||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||||
unsigned c = facNodes[i]->neighbors().size() + 1;
|
unsigned c = facNodes[i]->neighbors().size() + 1;
|
||||||
factorSignatures_.push_back (Signature (c));
|
facSignatures_.push_back (Signature (c));
|
||||||
}
|
}
|
||||||
|
|
||||||
varColors_.resize (varNodes.size());
|
varColors_.resize (varNodes.size());
|
||||||
factorColors_.resize (facNodes.size());
|
facColors_.resize (facNodes.size());
|
||||||
setInitialColors();
|
setInitialColors();
|
||||||
createGroups();
|
createGroups();
|
||||||
}
|
}
|
||||||
@ -111,7 +111,7 @@ void
|
|||||||
CFactorGraph::createGroups (void)
|
CFactorGraph::createGroups (void)
|
||||||
{
|
{
|
||||||
VarSignMap varGroups;
|
VarSignMap varGroups;
|
||||||
FacSignMap factorGroups;
|
FacSignMap facGroups;
|
||||||
unsigned nIters = 0;
|
unsigned nIters = 0;
|
||||||
bool groupsHaveChanged = true;
|
bool groupsHaveChanged = true;
|
||||||
const VarNodes& varNodes = groundFg_->varNodes();
|
const VarNodes& varNodes = groundFg_->varNodes();
|
||||||
@ -120,19 +120,19 @@ CFactorGraph::createGroups (void)
|
|||||||
while (groupsHaveChanged || nIters == 1) {
|
while (groupsHaveChanged || nIters == 1) {
|
||||||
nIters ++;
|
nIters ++;
|
||||||
|
|
||||||
unsigned prevFactorGroupsSize = factorGroups.size();
|
unsigned prevFactorGroupsSize = facGroups.size();
|
||||||
factorGroups.clear();
|
facGroups.clear();
|
||||||
// set a new color to the factors with the same signature
|
// set a new color to the factors with the same signature
|
||||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||||
const Signature& signature = getSignature (facNodes[i]);
|
const Signature& signature = getSignature (facNodes[i]);
|
||||||
FacSignMap::iterator it = factorGroups.find (signature);
|
FacSignMap::iterator it = facGroups.find (signature);
|
||||||
if (it == factorGroups.end()) {
|
if (it == facGroups.end()) {
|
||||||
it = factorGroups.insert (make_pair (signature, FacNodes())).first;
|
it = facGroups.insert (make_pair (signature, FacNodes())).first;
|
||||||
}
|
}
|
||||||
it->second.push_back (facNodes[i]);
|
it->second.push_back (facNodes[i]);
|
||||||
}
|
}
|
||||||
for (FacSignMap::iterator it = factorGroups.begin();
|
for (FacSignMap::iterator it = facGroups.begin();
|
||||||
it != factorGroups.end(); it++) {
|
it != facGroups.end(); it++) {
|
||||||
Color newColor = getFreeColor();
|
Color newColor = getFreeColor();
|
||||||
FacNodes& groupMembers = it->second;
|
FacNodes& groupMembers = it->second;
|
||||||
for (unsigned i = 0; i < groupMembers.size(); i++) {
|
for (unsigned i = 0; i < groupMembers.size(); i++) {
|
||||||
@ -161,10 +161,10 @@ CFactorGraph::createGroups (void)
|
|||||||
}
|
}
|
||||||
|
|
||||||
groupsHaveChanged = prevVarGroupsSize != varGroups.size()
|
groupsHaveChanged = prevVarGroupsSize != varGroups.size()
|
||||||
|| prevFactorGroupsSize != factorGroups.size();
|
|| prevFactorGroupsSize != facGroups.size();
|
||||||
}
|
}
|
||||||
//printGroups (varGroups, factorGroups);
|
//printGroups (varGroups, facGroups);
|
||||||
createClusters (varGroups, factorGroups);
|
createClusters (varGroups, facGroups);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -172,7 +172,7 @@ CFactorGraph::createGroups (void)
|
|||||||
void
|
void
|
||||||
CFactorGraph::createClusters (
|
CFactorGraph::createClusters (
|
||||||
const VarSignMap& varGroups,
|
const VarSignMap& varGroups,
|
||||||
const FacSignMap& factorGroups)
|
const FacSignMap& facGroups)
|
||||||
{
|
{
|
||||||
varClusters_.reserve (varGroups.size());
|
varClusters_.reserve (varGroups.size());
|
||||||
for (VarSignMap::const_iterator it = varGroups.begin();
|
for (VarSignMap::const_iterator it = varGroups.begin();
|
||||||
@ -185,12 +185,12 @@ CFactorGraph::createClusters (
|
|||||||
varClusters_.push_back (vc);
|
varClusters_.push_back (vc);
|
||||||
}
|
}
|
||||||
|
|
||||||
facClusters_.reserve (factorGroups.size());
|
facClusters_.reserve (facGroups.size());
|
||||||
for (FacSignMap::const_iterator it = factorGroups.begin();
|
for (FacSignMap::const_iterator it = facGroups.begin();
|
||||||
it != factorGroups.end(); it++) {
|
it != facGroups.end(); it++) {
|
||||||
FacNode* groupFactor = it->second[0];
|
FacNode* groupFactor = it->second[0];
|
||||||
const VarNodes& neighs = groupFactor->neighbors();
|
const VarNodes& neighs = groupFactor->neighbors();
|
||||||
VarClusterSet varClusters;
|
VarClusters varClusters;
|
||||||
varClusters.reserve (neighs.size());
|
varClusters.reserve (neighs.size());
|
||||||
for (unsigned i = 0; i < neighs.size(); i++) {
|
for (unsigned i = 0; i < neighs.size(); i++) {
|
||||||
VarId vid = neighs[i]->varId();
|
VarId vid = neighs[i]->varId();
|
||||||
@ -223,7 +223,7 @@ CFactorGraph::getSignature (const VarNode* varNode)
|
|||||||
const Signature&
|
const Signature&
|
||||||
CFactorGraph::getSignature (const FacNode* facNode)
|
CFactorGraph::getSignature (const FacNode* facNode)
|
||||||
{
|
{
|
||||||
Signature& sign = factorSignatures_[facNode->getIndex()];
|
Signature& sign = facSignatures_[facNode->getIndex()];
|
||||||
vector<Color>::iterator it = sign.colors.begin();
|
vector<Color>::iterator it = sign.colors.begin();
|
||||||
const VarNodes& neighs = facNode->neighbors();
|
const VarNodes& neighs = facNode->neighbors();
|
||||||
for (unsigned i = 0; i < neighs.size(); i++) {
|
for (unsigned i = 0; i < neighs.size(); i++) {
|
||||||
@ -237,7 +237,7 @@ CFactorGraph::getSignature (const FacNode* facNode)
|
|||||||
|
|
||||||
|
|
||||||
FactorGraph*
|
FactorGraph*
|
||||||
CFactorGraph::getCompressedFactorGraph (void)
|
CFactorGraph::getGroundFactorGraph (void) const
|
||||||
{
|
{
|
||||||
FactorGraph* fg = new FactorGraph();
|
FactorGraph* fg = new FactorGraph();
|
||||||
for (unsigned i = 0; i < varClusters_.size(); i++) {
|
for (unsigned i = 0; i < varClusters_.size(); i++) {
|
||||||
@ -248,7 +248,7 @@ CFactorGraph::getCompressedFactorGraph (void)
|
|||||||
}
|
}
|
||||||
|
|
||||||
for (unsigned i = 0; i < facClusters_.size(); i++) {
|
for (unsigned i = 0; i < facClusters_.size(); i++) {
|
||||||
const VarClusterSet& myVarClusters = facClusters_[i]->getVarClusters();
|
const VarClusters& myVarClusters = facClusters_[i]->getVarClusters();
|
||||||
Vars myGroundVars;
|
Vars myGroundVars;
|
||||||
myGroundVars.reserve (myVarClusters.size());
|
myGroundVars.reserve (myVarClusters.size());
|
||||||
for (unsigned j = 0; j < myVarClusters.size(); j++) {
|
for (unsigned j = 0; j < myVarClusters.size(); j++) {
|
||||||
@ -300,7 +300,7 @@ CFactorGraph::getGroundEdgeCount (
|
|||||||
void
|
void
|
||||||
CFactorGraph::printGroups (
|
CFactorGraph::printGroups (
|
||||||
const VarSignMap& varGroups,
|
const VarSignMap& varGroups,
|
||||||
const FacSignMap& factorGroups) const
|
const FacSignMap& facGroups) const
|
||||||
{
|
{
|
||||||
unsigned count = 1;
|
unsigned count = 1;
|
||||||
cout << "variable groups:" << endl;
|
cout << "variable groups:" << endl;
|
||||||
@ -319,8 +319,8 @@ CFactorGraph::printGroups (
|
|||||||
|
|
||||||
count = 1;
|
count = 1;
|
||||||
cout << endl << "factor groups:" << endl;
|
cout << endl << "factor groups:" << endl;
|
||||||
for (FacSignMap::const_iterator it = factorGroups.begin();
|
for (FacSignMap::const_iterator it = facGroups.begin();
|
||||||
it != factorGroups.end(); it++) {
|
it != facGroups.end(); it++) {
|
||||||
const FacNodes& groupMembers = it->second;
|
const FacNodes& groupMembers = it->second;
|
||||||
if (groupMembers.size() > 0) {
|
if (groupMembers.size() > 0) {
|
||||||
cout << ++count << ": " ;
|
cout << ++count << ": " ;
|
||||||
|
@ -22,8 +22,8 @@ typedef unordered_map<unsigned, vector<Color>> VarColorMap;
|
|||||||
typedef unordered_map<unsigned, Color> DistColorMap;
|
typedef unordered_map<unsigned, Color> DistColorMap;
|
||||||
typedef unordered_map<VarId, VarCluster*> VarId2VarCluster;
|
typedef unordered_map<VarId, VarCluster*> VarId2VarCluster;
|
||||||
|
|
||||||
typedef vector<VarCluster*> VarClusterSet;
|
typedef vector<VarCluster*> VarClusters;
|
||||||
typedef vector<FacCluster*> FacClusterSet;
|
typedef vector<FacCluster*> FacClusters;
|
||||||
|
|
||||||
typedef unordered_map<Signature, VarNodes, SignatureHash> VarSignMap;
|
typedef unordered_map<Signature, VarNodes, SignatureHash> VarSignMap;
|
||||||
typedef unordered_map<Signature, FacNodes, SignatureHash> FacSignMap;
|
typedef unordered_map<Signature, FacNodes, SignatureHash> FacSignMap;
|
||||||
@ -99,18 +99,20 @@ class VarCluster
|
|||||||
facClusters_.push_back (fc);
|
facClusters_.push_back (fc);
|
||||||
}
|
}
|
||||||
|
|
||||||
const FacClusterSet& getFacClusters (void) const
|
const FacClusters& getFacClusters (void) const
|
||||||
{
|
{
|
||||||
return facClusters_;
|
return facClusters_;
|
||||||
}
|
}
|
||||||
|
|
||||||
VarNode* getRepresentativeVariable (void) const { return representVar_; }
|
VarNode* getRepresentativeVariable (void) const { return representVar_; }
|
||||||
void setRepresentativeVariable (VarNode* v) { representVar_ = v; }
|
|
||||||
const VarNodes& getGroundVarNodes (void) const { return groundVars_; }
|
void setRepresentativeVariable (VarNode* v) { representVar_ = v; }
|
||||||
|
|
||||||
|
const VarNodes& getGroundVarNodes (void) const { return groundVars_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
VarNodes groundVars_;
|
VarNodes groundVars_;
|
||||||
FacClusterSet facClusters_;
|
FacClusters facClusters_;
|
||||||
VarNode* representVar_;
|
VarNode* representVar_;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -118,7 +120,7 @@ class VarCluster
|
|||||||
class FacCluster
|
class FacCluster
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
FacCluster (const FacNodes& groundFactors, const VarClusterSet& vcs)
|
FacCluster (const FacNodes& groundFactors, const VarClusters& vcs)
|
||||||
{
|
{
|
||||||
groundFactors_ = groundFactors;
|
groundFactors_ = groundFactors;
|
||||||
varClusters_ = vcs;
|
varClusters_ = vcs;
|
||||||
@ -127,7 +129,7 @@ class FacCluster
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const VarClusterSet& getVarClusters (void) const
|
const VarClusters& getVarClusters (void) const
|
||||||
{
|
{
|
||||||
return varClusters_;
|
return varClusters_;
|
||||||
}
|
}
|
||||||
@ -160,7 +162,7 @@ class FacCluster
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
FacNodes groundFactors_;
|
FacNodes groundFactors_;
|
||||||
VarClusterSet varClusters_;
|
VarClusters varClusters_;
|
||||||
FacNode* representFactor_;
|
FacNode* representFactor_;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -172,9 +174,9 @@ class CFactorGraph
|
|||||||
|
|
||||||
~CFactorGraph (void);
|
~CFactorGraph (void);
|
||||||
|
|
||||||
const VarClusterSet& getVarClusters (void) { return varClusters_; }
|
const VarClusters& getVarClusters (void) { return varClusters_; }
|
||||||
|
|
||||||
const FacClusterSet& getFacClusters (void) { return facClusters_; }
|
const FacClusters& getFacClusters (void) { return facClusters_; }
|
||||||
|
|
||||||
VarNode* getEquivalentVariable (VarId vid)
|
VarNode* getEquivalentVariable (VarId vid)
|
||||||
{
|
{
|
||||||
@ -182,7 +184,7 @@ class CFactorGraph
|
|||||||
return vc->getRepresentativeVariable();
|
return vc->getRepresentativeVariable();
|
||||||
}
|
}
|
||||||
|
|
||||||
FactorGraph* getCompressedFactorGraph (void);
|
FactorGraph* getGroundFactorGraph (void) const;
|
||||||
|
|
||||||
unsigned getGroundEdgeCount (const FacCluster*, const VarCluster*) const;
|
unsigned getGroundEdgeCount (const FacCluster*, const VarCluster*) const;
|
||||||
|
|
||||||
@ -200,7 +202,7 @@ class CFactorGraph
|
|||||||
return varColors_[vn->getIndex()];
|
return varColors_[vn->getIndex()];
|
||||||
}
|
}
|
||||||
Color getColor (const FacNode* fn) const {
|
Color getColor (const FacNode* fn) const {
|
||||||
return factorColors_[fn->getIndex()];
|
return facColors_[fn->getIndex()];
|
||||||
}
|
}
|
||||||
|
|
||||||
void setColor (const VarNode* vn, Color c)
|
void setColor (const VarNode* vn, Color c)
|
||||||
@ -210,7 +212,7 @@ class CFactorGraph
|
|||||||
|
|
||||||
void setColor (const FacNode* fn, Color c)
|
void setColor (const FacNode* fn, Color c)
|
||||||
{
|
{
|
||||||
factorColors_[fn->getIndex()] = c;
|
facColors_[fn->getIndex()] = c;
|
||||||
}
|
}
|
||||||
|
|
||||||
VarCluster* getVariableCluster (VarId vid) const
|
VarCluster* getVariableCluster (VarId vid) const
|
||||||
@ -232,11 +234,11 @@ class CFactorGraph
|
|||||||
|
|
||||||
Color freeColor_;
|
Color freeColor_;
|
||||||
vector<Color> varColors_;
|
vector<Color> varColors_;
|
||||||
vector<Color> factorColors_;
|
vector<Color> facColors_;
|
||||||
vector<Signature> varSignatures_;
|
vector<Signature> varSignatures_;
|
||||||
vector<Signature> factorSignatures_;
|
vector<Signature> facSignatures_;
|
||||||
VarClusterSet varClusters_;
|
VarClusters varClusters_;
|
||||||
FacClusterSet facClusters_;
|
FacClusters facClusters_;
|
||||||
VarId2VarCluster vid2VarCluster_;
|
VarId2VarCluster vid2VarCluster_;
|
||||||
const FactorGraph* groundFg_;
|
const FactorGraph* groundFg_;
|
||||||
};
|
};
|
||||||
|
@ -1,10 +1,41 @@
|
|||||||
#include "CbpSolver.h"
|
#include "CbpSolver.h"
|
||||||
|
|
||||||
|
|
||||||
|
CbpSolver::CbpSolver (const FactorGraph& fg) : BpSolver (fg)
|
||||||
|
{
|
||||||
|
unsigned nGroundVars, nGroundFacs, nWithoutNeighs;
|
||||||
|
if (Constants::COLLECT_STATS) {
|
||||||
|
nGroundVars = fg_->varNodes().size();
|
||||||
|
nGroundFacs = fg_->facNodes().size();
|
||||||
|
const VarNodes& vars = fg_->varNodes();
|
||||||
|
nWithoutNeighs = 0;
|
||||||
|
for (unsigned i = 0; i < vars.size(); i++) {
|
||||||
|
const FacNodes& factors = vars[i]->neighbors();
|
||||||
|
if (factors.size() == 1 && factors[0]->neighbors().size() == 1) {
|
||||||
|
nWithoutNeighs ++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cfg_ = new CFactorGraph (fg);
|
||||||
|
fg_ = cfg_->getGroundFactorGraph();
|
||||||
|
if (Constants::COLLECT_STATS) {
|
||||||
|
unsigned nClusterVars = fg_->varNodes().size();
|
||||||
|
unsigned nClusterFacs = fg_->facNodes().size();
|
||||||
|
Statistics::updateCompressingStatistics (nGroundVars,
|
||||||
|
nGroundFacs, nClusterVars, nClusterFacs, nWithoutNeighs);
|
||||||
|
}
|
||||||
|
// Util::printHeader ("Uncompressed Factor Graph");
|
||||||
|
// fg->print();
|
||||||
|
// Util::printHeader ("Compressed Factor Graph");
|
||||||
|
// fg_->print();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
CbpSolver::~CbpSolver (void)
|
CbpSolver::~CbpSolver (void)
|
||||||
{
|
{
|
||||||
delete lfg_;
|
delete cfg_;
|
||||||
delete factorGraph_;
|
delete fg_;
|
||||||
for (unsigned i = 0; i < links_.size(); i++) {
|
for (unsigned i = 0; i < links_.size(); i++) {
|
||||||
delete links_[i];
|
delete links_[i];
|
||||||
}
|
}
|
||||||
@ -16,8 +47,11 @@ CbpSolver::~CbpSolver (void)
|
|||||||
Params
|
Params
|
||||||
CbpSolver::getPosterioriOf (VarId vid)
|
CbpSolver::getPosterioriOf (VarId vid)
|
||||||
{
|
{
|
||||||
assert (lfg_->getEquivalentVariable (vid));
|
if (runned_ == false) {
|
||||||
VarNode* var = lfg_->getEquivalentVariable (vid);
|
runSolver();
|
||||||
|
}
|
||||||
|
assert (cfg_->getEquivalentVariable (vid));
|
||||||
|
VarNode* var = cfg_->getEquivalentVariable (vid);
|
||||||
Params probs;
|
Params probs;
|
||||||
if (var->hasEvidence()) {
|
if (var->hasEvidence()) {
|
||||||
probs.resize (var->range(), LogAware::noEvidence());
|
probs.resize (var->range(), LogAware::noEvidence());
|
||||||
@ -26,16 +60,16 @@ CbpSolver::getPosterioriOf (VarId vid)
|
|||||||
probs.resize (var->range(), LogAware::multIdenty());
|
probs.resize (var->range(), LogAware::multIdenty());
|
||||||
const SpLinkSet& links = ninf(var)->getLinks();
|
const SpLinkSet& links = ninf(var)->getLinks();
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
for (unsigned i = 0; i < links.size(); i++) {
|
for (unsigned i = 0; i < links.size(); i++) {
|
||||||
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
|
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
|
||||||
Util::add (probs, l->getPoweredMessage());
|
Util::add (probs, l->poweredMessage());
|
||||||
}
|
}
|
||||||
LogAware::normalize (probs);
|
LogAware::normalize (probs);
|
||||||
Util::fromLog (probs);
|
Util::fromLog (probs);
|
||||||
} else {
|
} else {
|
||||||
for (unsigned i = 0; i < links.size(); i++) {
|
for (unsigned i = 0; i < links.size(); i++) {
|
||||||
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
|
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
|
||||||
Util::multiply (probs, l->getPoweredMessage());
|
Util::multiply (probs, l->poweredMessage());
|
||||||
}
|
}
|
||||||
LogAware::normalize (probs);
|
LogAware::normalize (probs);
|
||||||
}
|
}
|
||||||
@ -46,67 +80,28 @@ CbpSolver::getPosterioriOf (VarId vid)
|
|||||||
|
|
||||||
|
|
||||||
Params
|
Params
|
||||||
CbpSolver::getJointDistributionOf (const VarIds& jointVarIds)
|
CbpSolver::getJointDistributionOf (const VarIds& jointVids)
|
||||||
{
|
{
|
||||||
VarIds eqVarIds;
|
VarIds eqVarIds;
|
||||||
for (unsigned i = 0; i < jointVarIds.size(); i++) {
|
for (unsigned i = 0; i < jointVids.size(); i++) {
|
||||||
eqVarIds.push_back (lfg_->getEquivalentVariable (jointVarIds[i])->varId());
|
VarNode* vn = cfg_->getEquivalentVariable (jointVids[i]);
|
||||||
|
eqVarIds.push_back (vn->varId());
|
||||||
}
|
}
|
||||||
return BpSolver::getJointDistributionOf (eqVarIds);
|
return BpSolver::getJointDistributionOf (eqVarIds);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
CbpSolver::initializeSolver (void)
|
|
||||||
{
|
|
||||||
unsigned nGroundVars, nGroundFacs, nWithoutNeighs;
|
|
||||||
if (Constants::COLLECT_STATS) {
|
|
||||||
nGroundVars = factorGraph_->varNodes().size();
|
|
||||||
nGroundFacs = factorGraph_->facNodes().size();
|
|
||||||
const VarNodes& vars = factorGraph_->varNodes();
|
|
||||||
nWithoutNeighs = 0;
|
|
||||||
for (unsigned i = 0; i < vars.size(); i++) {
|
|
||||||
const FacNodes& factors = vars[i]->neighbors();
|
|
||||||
if (factors.size() == 1 && factors[0]->neighbors().size() == 1) {
|
|
||||||
nWithoutNeighs ++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
lfg_ = new CFactorGraph (*factorGraph_);
|
|
||||||
|
|
||||||
// cout << "Uncompressed Factor Graph" << endl;
|
|
||||||
// factorGraph_->print();
|
|
||||||
// factorGraph_->exportToGraphViz ("uncompressed_fg.dot");
|
|
||||||
factorGraph_ = lfg_->getCompressedFactorGraph();
|
|
||||||
|
|
||||||
if (Constants::COLLECT_STATS) {
|
|
||||||
unsigned nClusterVars = factorGraph_->varNodes().size();
|
|
||||||
unsigned nClusterFacs = factorGraph_->facNodes().size();
|
|
||||||
Statistics::updateCompressingStatistics (nGroundVars,
|
|
||||||
nGroundFacs, nClusterVars, nClusterFacs, nWithoutNeighs);
|
|
||||||
}
|
|
||||||
|
|
||||||
// cout << "Compressed Factor Graph" << endl;
|
|
||||||
// factorGraph_->print();
|
|
||||||
// factorGraph_->exportToGraphViz ("compressed_fg.dot");
|
|
||||||
// abort();
|
|
||||||
BpSolver::initializeSolver();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
CbpSolver::createLinks (void)
|
CbpSolver::createLinks (void)
|
||||||
{
|
{
|
||||||
const FacClusterSet fcs = lfg_->getFacClusters();
|
const FacClusters& fcs = cfg_->getFacClusters();
|
||||||
for (unsigned i = 0; i < fcs.size(); i++) {
|
for (unsigned i = 0; i < fcs.size(); i++) {
|
||||||
const VarClusterSet vcs = fcs[i]->getVarClusters();
|
const VarClusters& vcs = fcs[i]->getVarClusters();
|
||||||
for (unsigned j = 0; j < vcs.size(); j++) {
|
for (unsigned j = 0; j < vcs.size(); j++) {
|
||||||
unsigned c = lfg_->getGroundEdgeCount (fcs[i], vcs[j]);
|
unsigned c = cfg_->getGroundEdgeCount (fcs[i], vcs[j]);
|
||||||
links_.push_back (new CbpSolverLink (fcs[i]->getRepresentativeFactor(),
|
links_.push_back (new CbpSolverLink (
|
||||||
|
fcs[i]->getRepresentativeFactor(),
|
||||||
vcs[j]->getRepresentativeVariable(), c));
|
vcs[j]->getRepresentativeVariable(), c));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -197,10 +192,10 @@ CbpSolver::getVar2FactorMsg (const SpLink* link) const
|
|||||||
if (src->hasEvidence()) {
|
if (src->hasEvidence()) {
|
||||||
msg.resize (src->range(), LogAware::noEvidence());
|
msg.resize (src->range(), LogAware::noEvidence());
|
||||||
double value = link->getMessage()[src->getEvidence()];
|
double value = link->getMessage()[src->getEvidence()];
|
||||||
msg[src->getEvidence()] = LogAware::pow (value, l->getNumberOfEdges() - 1);
|
msg[src->getEvidence()] = LogAware::pow (value, l->nrEdges() - 1);
|
||||||
} else {
|
} else {
|
||||||
msg = link->getMessage();
|
msg = link->getMessage();
|
||||||
LogAware::pow (msg, l->getNumberOfEdges() - 1);
|
LogAware::pow (msg, l->nrEdges() - 1);
|
||||||
}
|
}
|
||||||
if (Constants::DEBUG >= 5) {
|
if (Constants::DEBUG >= 5) {
|
||||||
cout << " " << "init: " << msg << endl;
|
cout << " " << "init: " << msg << endl;
|
||||||
@ -210,17 +205,17 @@ CbpSolver::getVar2FactorMsg (const SpLink* link) const
|
|||||||
for (unsigned i = 0; i < links.size(); i++) {
|
for (unsigned i = 0; i < links.size(); i++) {
|
||||||
if (links[i]->getFactor() != dst) {
|
if (links[i]->getFactor() != dst) {
|
||||||
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
|
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
|
||||||
Util::add (msg, l->getPoweredMessage());
|
Util::add (msg, l->poweredMessage());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for (unsigned i = 0; i < links.size(); i++) {
|
for (unsigned i = 0; i < links.size(); i++) {
|
||||||
if (links[i]->getFactor() != dst) {
|
if (links[i]->getFactor() != dst) {
|
||||||
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
|
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
|
||||||
Util::multiply (msg, l->getPoweredMessage());
|
Util::multiply (msg, l->poweredMessage());
|
||||||
if (Constants::DEBUG >= 5) {
|
if (Constants::DEBUG >= 5) {
|
||||||
cout << " msg from " << l->getFactor()->getLabel() << ": " ;
|
cout << " msg from " << l->getFactor()->getLabel() << ": " ;
|
||||||
cout << l->getPoweredMessage() << endl;
|
cout << l->poweredMessage() << endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -242,7 +237,7 @@ CbpSolver::printLinkInformation (void) const
|
|||||||
cout << l->toString() << ":" << endl;
|
cout << l->toString() << ":" << endl;
|
||||||
cout << " curr msg = " << l->getMessage() << endl;
|
cout << " curr msg = " << l->getMessage() << endl;
|
||||||
cout << " next msg = " << l->getNextMessage() << endl;
|
cout << " next msg = " << l->getNextMessage() << endl;
|
||||||
cout << " powered = " << l->getPoweredMessage() << endl;
|
cout << " powered = " << l->poweredMessage() << endl;
|
||||||
cout << " residual = " << l->getResidual() << endl;
|
cout << " residual = " << l->getResidual() << endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -9,27 +9,25 @@ class Factor;
|
|||||||
class CbpSolverLink : public SpLink
|
class CbpSolverLink : public SpLink
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
CbpSolverLink (FacNode* fn, VarNode* vn, unsigned c) : SpLink (fn, vn)
|
CbpSolverLink (FacNode* fn, VarNode* vn, unsigned c)
|
||||||
{
|
: SpLink (fn, vn), nrEdges_(c),
|
||||||
edgeCount_ = c;
|
pwdMsg_(vn->range(), LogAware::one()) { }
|
||||||
poweredMsg_.resize (vn->range(), LogAware::one());
|
|
||||||
}
|
|
||||||
|
|
||||||
unsigned getNumberOfEdges (void) const { return edgeCount_; }
|
unsigned nrEdges (void) const { return nrEdges_; }
|
||||||
|
|
||||||
const Params& getPoweredMessage (void) const { return poweredMsg_; }
|
const Params& poweredMessage (void) const { return pwdMsg_; }
|
||||||
|
|
||||||
void updateMessage (void)
|
void updateMessage (void)
|
||||||
{
|
{
|
||||||
poweredMsg_ = *nextMsg_;
|
pwdMsg_ = *nextMsg_;
|
||||||
swap (currMsg_, nextMsg_);
|
swap (currMsg_, nextMsg_);
|
||||||
msgSended_ = true;
|
msgSended_ = true;
|
||||||
LogAware::pow (poweredMsg_, edgeCount_);
|
LogAware::pow (pwdMsg_, nrEdges_);
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Params poweredMsg_;
|
unsigned nrEdges_;
|
||||||
unsigned edgeCount_;
|
Params pwdMsg_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
@ -37,16 +35,15 @@ class CbpSolverLink : public SpLink
|
|||||||
class CbpSolver : public BpSolver
|
class CbpSolver : public BpSolver
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
CbpSolver (const FactorGraph& fg) : BpSolver (fg) { }
|
CbpSolver (const FactorGraph& fg);
|
||||||
|
|
||||||
~CbpSolver (void);
|
~CbpSolver (void);
|
||||||
|
|
||||||
Params getPosterioriOf (VarId);
|
Params getPosterioriOf (VarId);
|
||||||
|
|
||||||
Params getJointDistributionOf (const VarIds&);
|
Params getJointDistributionOf (const VarIds&);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void initializeSolver (void);
|
|
||||||
|
|
||||||
void createLinks (void);
|
void createLinks (void);
|
||||||
|
|
||||||
@ -56,8 +53,7 @@ class CbpSolver : public BpSolver
|
|||||||
|
|
||||||
void printLinkInformation (void) const;
|
void printLinkInformation (void) const;
|
||||||
|
|
||||||
CFactorGraph* lfg_;
|
CFactorGraph* cfg_;
|
||||||
FactorGraph* factorGraph_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif // HORUS_CBP_H
|
#endif // HORUS_CBP_H
|
||||||
|
@ -14,25 +14,43 @@ void processArguments (FactorGraph&, int, const char* []);
|
|||||||
void runSolver (const FactorGraph&, const VarIds&);
|
void runSolver (const FactorGraph&, const VarIds&);
|
||||||
|
|
||||||
const string USAGE = "usage: \
|
const string USAGE = "usage: \
|
||||||
./hcli FILE [VARIABLE | OBSERVED_VARIABLE=EVIDENCE]..." ;
|
./hcli ve|bp|cbp NETWORK_FILE [VARIABLE | OBSERVED_VARIABLE=EVIDENCE]..." ;
|
||||||
|
|
||||||
|
|
||||||
int
|
int
|
||||||
main (int argc, const char* argv[])
|
main (int argc, const char* argv[])
|
||||||
{
|
{
|
||||||
if (!argv[1]) {
|
if (argc <= 1) {
|
||||||
|
cerr << "error: no solver specified" << endl;
|
||||||
cerr << "error: no graphical model specified" << endl;
|
cerr << "error: no graphical model specified" << endl;
|
||||||
cerr << USAGE << endl;
|
cerr << USAGE << endl;
|
||||||
exit (0);
|
exit (0);
|
||||||
}
|
}
|
||||||
string fileName = argv[1];
|
if (argc <= 2) {
|
||||||
|
cerr << "error: no graphical model specified" << endl;
|
||||||
|
cerr << USAGE << endl;
|
||||||
|
exit (0);
|
||||||
|
}
|
||||||
|
string solver (argv[1]);
|
||||||
|
if (solver == "ve") {
|
||||||
|
Globals::infAlgorithm = InfAlgorithms::VE;
|
||||||
|
} else if (solver == "bp") {
|
||||||
|
Globals::infAlgorithm = InfAlgorithms::BP;
|
||||||
|
} else if (solver == "cbp") {
|
||||||
|
Globals::infAlgorithm = InfAlgorithms::CBP;
|
||||||
|
} else {
|
||||||
|
cerr << "error: unknow solver `" << solver << "'" << endl ;
|
||||||
|
cerr << USAGE << endl;
|
||||||
|
exit(0);
|
||||||
|
}
|
||||||
|
string fileName (argv[2]);
|
||||||
string extension = fileName.substr (
|
string extension = fileName.substr (
|
||||||
fileName.find_last_of ('.') + 1);
|
fileName.find_last_of ('.') + 1);
|
||||||
FactorGraph fg;
|
FactorGraph fg;
|
||||||
if (extension == "uai") {
|
if (extension == "uai") {
|
||||||
fg.readFromUaiFormat (argv[1]);
|
fg.readFromUaiFormat (fileName.c_str());
|
||||||
} else if (extension == "fg") {
|
} else if (extension == "fg") {
|
||||||
fg.readFromLibDaiFormat (argv[1]);
|
fg.readFromLibDaiFormat (fileName.c_str());
|
||||||
} else {
|
} else {
|
||||||
cerr << "error: the graphical model must be defined either " ;
|
cerr << "error: the graphical model must be defined either " ;
|
||||||
cerr << "in a UAI or libDAI file" << endl;
|
cerr << "in a UAI or libDAI file" << endl;
|
||||||
@ -48,7 +66,7 @@ void
|
|||||||
processArguments (FactorGraph& fg, int argc, const char* argv[])
|
processArguments (FactorGraph& fg, int argc, const char* argv[])
|
||||||
{
|
{
|
||||||
VarIds queryIds;
|
VarIds queryIds;
|
||||||
for (int i = 2; i < argc; i++) {
|
for (int i = 3; i < argc; i++) {
|
||||||
const string& arg = argv[i];
|
const string& arg = argv[i];
|
||||||
if (arg.find ('=') == std::string::npos) {
|
if (arg.find ('=') == std::string::npos) {
|
||||||
if (!Util::isInteger (arg)) {
|
if (!Util::isInteger (arg)) {
|
||||||
|
@ -8,7 +8,7 @@ Solver::printAnswer (const VarIds& vids)
|
|||||||
Vars unobservedVars;
|
Vars unobservedVars;
|
||||||
VarIds unobservedVids;
|
VarIds unobservedVids;
|
||||||
for (unsigned i = 0; i < vids.size(); i++) {
|
for (unsigned i = 0; i < vids.size(); i++) {
|
||||||
VarNode* vn = fg_.getVarNode (vids[i]);
|
VarNode* vn = fg.getVarNode (vids[i]);
|
||||||
if (vn->hasEvidence() == false) {
|
if (vn->hasEvidence() == false) {
|
||||||
unobservedVars.push_back (vn);
|
unobservedVars.push_back (vn);
|
||||||
unobservedVids.push_back (vids[i]);
|
unobservedVids.push_back (vids[i]);
|
||||||
@ -29,7 +29,7 @@ Solver::printAnswer (const VarIds& vids)
|
|||||||
void
|
void
|
||||||
Solver::printAllPosterioris (void)
|
Solver::printAllPosterioris (void)
|
||||||
{
|
{
|
||||||
const VarNodes& vars = fg_.varNodes();
|
const VarNodes& vars = fg.varNodes();
|
||||||
for (unsigned i = 0; i < vars.size(); i++) {
|
for (unsigned i = 0; i < vars.size(); i++) {
|
||||||
printAnswer ({vars[i]->varId()});
|
printAnswer ({vars[i]->varId()});
|
||||||
}
|
}
|
||||||
|
@ -12,7 +12,7 @@ using namespace std;
|
|||||||
class Solver
|
class Solver
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
Solver (const FactorGraph& fg) : fg_(fg) { }
|
Solver (const FactorGraph& factorGraph) : fg(factorGraph) { }
|
||||||
|
|
||||||
virtual ~Solver() { } // ensure that subclass destructor is called
|
virtual ~Solver() { } // ensure that subclass destructor is called
|
||||||
|
|
||||||
@ -23,7 +23,7 @@ class Solver
|
|||||||
void printAllPosterioris (void);
|
void printAllPosterioris (void);
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
const FactorGraph& fg_;
|
const FactorGraph& fg;
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif // HORUS_SOLVER_H
|
#endif // HORUS_SOLVER_H
|
||||||
|
@ -35,7 +35,7 @@ VarElimSolver::solveQuery (VarIds queryVids)
|
|||||||
void
|
void
|
||||||
VarElimSolver::createFactorList (void)
|
VarElimSolver::createFactorList (void)
|
||||||
{
|
{
|
||||||
const FacNodes& facNodes = fg_.facNodes();
|
const FacNodes& facNodes = fg.facNodes();
|
||||||
factorList_.reserve (facNodes.size() * 2);
|
factorList_.reserve (facNodes.size() * 2);
|
||||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||||
factorList_.push_back (new Factor (facNodes[i]->factor()));
|
factorList_.push_back (new Factor (facNodes[i]->factor()));
|
||||||
@ -57,7 +57,7 @@ VarElimSolver::createFactorList (void)
|
|||||||
void
|
void
|
||||||
VarElimSolver::absorveEvidence (void)
|
VarElimSolver::absorveEvidence (void)
|
||||||
{
|
{
|
||||||
const VarNodes& varNodes = fg_.varNodes();
|
const VarNodes& varNodes = fg.varNodes();
|
||||||
for (unsigned i = 0; i < varNodes.size(); i++) {
|
for (unsigned i = 0; i < varNodes.size(); i++) {
|
||||||
if (varNodes[i]->hasEvidence()) {
|
if (varNodes[i]->hasEvidence()) {
|
||||||
const vector<unsigned>& idxs =
|
const vector<unsigned>& idxs =
|
||||||
@ -103,7 +103,7 @@ VarElimSolver::processFactorList (const VarIds& vids)
|
|||||||
|
|
||||||
VarIds unobservedVids;
|
VarIds unobservedVids;
|
||||||
for (unsigned i = 0; i < vids.size(); i++) {
|
for (unsigned i = 0; i < vids.size(); i++) {
|
||||||
if (fg_.getVarNode (vids[i])->hasEvidence() == false) {
|
if (fg.getVarNode (vids[i])->hasEvidence() == false) {
|
||||||
unobservedVids.push_back (vids[i]);
|
unobservedVids.push_back (vids[i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user