drop suport for Pearl belief propagation
This commit is contained in:
parent
0d23591058
commit
abb0410d07
@ -14,8 +14,8 @@
|
|||||||
void
|
void
|
||||||
DAGraph::addNode (DAGraphNode* n)
|
DAGraph::addNode (DAGraphNode* n)
|
||||||
{
|
{
|
||||||
nodes_.push_back (n);
|
|
||||||
assert (Util::contains (varMap_, n->varId()) == false);
|
assert (Util::contains (varMap_, n->varId()) == false);
|
||||||
|
nodes_.push_back (n);
|
||||||
varMap_[n->varId()] = n;
|
varMap_[n->varId()] = n;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,793 +0,0 @@
|
|||||||
#include <cstdlib>
|
|
||||||
#include <limits>
|
|
||||||
#include <time.h>
|
|
||||||
|
|
||||||
#include <algorithm>
|
|
||||||
|
|
||||||
#include <iostream>
|
|
||||||
#include <sstream>
|
|
||||||
#include <iomanip>
|
|
||||||
|
|
||||||
#include "BnBpSolver.h"
|
|
||||||
#include "Indexer.h"
|
|
||||||
|
|
||||||
BnBpSolver::BnBpSolver (const BayesNet& bn) : Solver (&bn)
|
|
||||||
{
|
|
||||||
bayesNet_ = &bn;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
BnBpSolver::~BnBpSolver (void)
|
|
||||||
{
|
|
||||||
for (unsigned i = 0; i < nodesI_.size(); i++) {
|
|
||||||
delete nodesI_[i];
|
|
||||||
}
|
|
||||||
for (unsigned i = 0; i < links_.size(); i++) {
|
|
||||||
delete links_[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
BnBpSolver::runSolver (void)
|
|
||||||
{
|
|
||||||
clock_t start;
|
|
||||||
if (Constants::COLLECT_STATS) {
|
|
||||||
start = clock();
|
|
||||||
}
|
|
||||||
initializeSolver();
|
|
||||||
runLoopySolver();
|
|
||||||
if (Constants::DEBUG >= 2) {
|
|
||||||
cout << endl;
|
|
||||||
if (nIters_ < BpOptions::maxIter) {
|
|
||||||
cout << "Belief propagation converged in " ;
|
|
||||||
cout << nIters_ << " iterations" << endl;
|
|
||||||
} else {
|
|
||||||
cout << "The maximum number of iterations was hit, terminating..." ;
|
|
||||||
cout << endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
unsigned size = bayesNet_->nrNodes();
|
|
||||||
if (Constants::COLLECT_STATS) {
|
|
||||||
unsigned nIters = 0;
|
|
||||||
bool loopy = bayesNet_->isPolyTree() == false;
|
|
||||||
if (loopy) nIters = nIters_;
|
|
||||||
double time = (double (clock() - start)) / CLOCKS_PER_SEC;
|
|
||||||
Statistics::updateStatistics (size, loopy, nIters, time);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Params
|
|
||||||
BnBpSolver::getPosterioriOf (VarId vid)
|
|
||||||
{
|
|
||||||
BayesNode* node = bayesNet_->getBayesNode (vid);
|
|
||||||
assert (node);
|
|
||||||
return nodesI_[node->getIndex()]->getBeliefs();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Params
|
|
||||||
BnBpSolver::getJointDistributionOf (const VarIds& jointVarIds)
|
|
||||||
{
|
|
||||||
if (Constants::DEBUG >= 2) {
|
|
||||||
cout << "calculating joint distribution on: " ;
|
|
||||||
for (unsigned i = 0; i < jointVarIds.size(); i++) {
|
|
||||||
VarNode* var = bayesNet_->getBayesNode (jointVarIds[i]);
|
|
||||||
cout << var->label() << " " ;
|
|
||||||
}
|
|
||||||
cout << endl;
|
|
||||||
}
|
|
||||||
return getJointByConditioning (jointVarIds);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
BnBpSolver::initializeSolver (void)
|
|
||||||
{
|
|
||||||
const BnNodeSet& nodes = bayesNet_->getBayesNodes();
|
|
||||||
for (unsigned i = 0; i < nodesI_.size(); i++) {
|
|
||||||
delete nodesI_[i];
|
|
||||||
}
|
|
||||||
nodesI_.clear();
|
|
||||||
nodesI_.reserve (nodes.size());
|
|
||||||
links_.clear();
|
|
||||||
sortedOrder_.clear();
|
|
||||||
linkMap_.clear();
|
|
||||||
|
|
||||||
for (unsigned i = 0; i < nodes.size(); i++) {
|
|
||||||
nodesI_.push_back (new BpNodeInfo (nodes[i]));
|
|
||||||
}
|
|
||||||
|
|
||||||
BnNodeSet roots = bayesNet_->getRootNodes();
|
|
||||||
for (unsigned i = 0; i < roots.size(); i++) {
|
|
||||||
const Params& params = roots[i]->params();
|
|
||||||
Params& piVals = ninf(roots[i])->getPiValues();
|
|
||||||
for (unsigned ri = 0; ri < roots[i]->range(); ri++) {
|
|
||||||
piVals[ri] = params[ri];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (unsigned i = 0; i < nodes.size(); i++) {
|
|
||||||
const BnNodeSet& parents = nodes[i]->getParents();
|
|
||||||
for (unsigned j = 0; j < parents.size(); j++) {
|
|
||||||
BpLink* newLink = new BpLink (
|
|
||||||
parents[j], nodes[i], LinkOrientation::DOWN);
|
|
||||||
links_.push_back (newLink);
|
|
||||||
ninf(nodes[i])->addIncomingParentLink (newLink);
|
|
||||||
ninf(parents[j])->addOutcomingChildLink (newLink);
|
|
||||||
}
|
|
||||||
const BnNodeSet& childs = nodes[i]->getChilds();
|
|
||||||
for (unsigned j = 0; j < childs.size(); j++) {
|
|
||||||
BpLink* newLink = new BpLink (
|
|
||||||
childs[j], nodes[i], LinkOrientation::UP);
|
|
||||||
links_.push_back (newLink);
|
|
||||||
ninf(nodes[i])->addIncomingChildLink (newLink);
|
|
||||||
ninf(childs[j])->addOutcomingParentLink (newLink);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (unsigned i = 0; i < nodes.size(); i++) {
|
|
||||||
if (nodes[i]->hasEvidence()) {
|
|
||||||
Params& piVals = ninf(nodes[i])->getPiValues();
|
|
||||||
Params& ldVals = ninf(nodes[i])->getLambdaValues();
|
|
||||||
for (unsigned xi = 0; xi < nodes[i]->range(); xi++) {
|
|
||||||
piVals[xi] = LogAware::noEvidence();
|
|
||||||
ldVals[xi] = LogAware::noEvidence();
|
|
||||||
}
|
|
||||||
piVals[nodes[i]->getEvidence()] = LogAware::withEvidence();
|
|
||||||
ldVals[nodes[i]->getEvidence()] = LogAware::withEvidence();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
BnBpSolver::runLoopySolver()
|
|
||||||
{
|
|
||||||
nIters_ = 0;
|
|
||||||
while (!converged() && nIters_ < BpOptions::maxIter) {
|
|
||||||
|
|
||||||
nIters_++;
|
|
||||||
if (Constants::DEBUG >= 2) {
|
|
||||||
Util::printHeader ("Iteration " + nIters_);
|
|
||||||
cout << endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
switch (BpOptions::schedule) {
|
|
||||||
|
|
||||||
case BpOptions::Schedule::SEQ_RANDOM:
|
|
||||||
random_shuffle (links_.begin(), links_.end());
|
|
||||||
// no break
|
|
||||||
|
|
||||||
case BpOptions::Schedule::SEQ_FIXED:
|
|
||||||
for (unsigned i = 0; i < links_.size(); i++) {
|
|
||||||
calculateAndUpdateMessage (links_[i]);
|
|
||||||
updateValues (links_[i]);
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
|
|
||||||
case BpOptions::Schedule::PARALLEL:
|
|
||||||
for (unsigned i = 0; i < links_.size(); i++) {
|
|
||||||
calculateMessage (links_[i]);
|
|
||||||
}
|
|
||||||
for (unsigned i = 0; i < links_.size(); i++) {
|
|
||||||
updateMessage (links_[i]);
|
|
||||||
updateValues (links_[i]);
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
|
|
||||||
case BpOptions::Schedule::MAX_RESIDUAL:
|
|
||||||
maxResidualSchedule();
|
|
||||||
break;
|
|
||||||
|
|
||||||
}
|
|
||||||
if (Constants::DEBUG >= 2) {
|
|
||||||
cout << endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
bool
|
|
||||||
BnBpSolver::converged (void) const
|
|
||||||
{
|
|
||||||
// this can happen if the graph is fully disconnected
|
|
||||||
if (links_.size() == 0) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
if (nIters_ == 0 || nIters_ == 1) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
bool converged = true;
|
|
||||||
if (BpOptions::schedule == BpOptions::Schedule::MAX_RESIDUAL) {
|
|
||||||
double maxResidual = (*(sortedOrder_.begin()))->getResidual();
|
|
||||||
if (maxResidual < BpOptions::accuracy) {
|
|
||||||
converged = true;
|
|
||||||
} else {
|
|
||||||
converged = false;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (unsigned i = 0; i < links_.size(); i++) {
|
|
||||||
double residual = links_[i]->getResidual();
|
|
||||||
if (Constants::DEBUG >= 2) {
|
|
||||||
cout << links_[i]->toString() + " residual change = " ;
|
|
||||||
cout << residual << endl;
|
|
||||||
}
|
|
||||||
if (residual > BpOptions::accuracy) {
|
|
||||||
converged = false;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return converged;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
BnBpSolver::maxResidualSchedule (void)
|
|
||||||
{
|
|
||||||
if (nIters_ == 1) {
|
|
||||||
for (unsigned i = 0; i < links_.size(); i++) {
|
|
||||||
calculateMessage (links_[i]);
|
|
||||||
SortedOrder::iterator it = sortedOrder_.insert (links_[i]);
|
|
||||||
linkMap_.insert (make_pair (links_[i], it));
|
|
||||||
}
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (unsigned c = 0; c < sortedOrder_.size(); c++) {
|
|
||||||
if (Constants::DEBUG >= 2) {
|
|
||||||
cout << "current residuals:" << endl;
|
|
||||||
for (SortedOrder::iterator it = sortedOrder_.begin();
|
|
||||||
it != sortedOrder_.end(); it ++) {
|
|
||||||
cout << " " << setw (30) << left << (*it)->toString();
|
|
||||||
cout << "residual = " << (*it)->getResidual() << endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
SortedOrder::iterator it = sortedOrder_.begin();
|
|
||||||
BpLink* link = *it;
|
|
||||||
if (link->getResidual() < BpOptions::accuracy) {
|
|
||||||
sortedOrder_.erase (it);
|
|
||||||
it = sortedOrder_.begin();
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
updateMessage (link);
|
|
||||||
updateValues (link);
|
|
||||||
link->clearResidual();
|
|
||||||
sortedOrder_.erase (it);
|
|
||||||
linkMap_.find (link)->second = sortedOrder_.insert (link);
|
|
||||||
|
|
||||||
const BpLinkSet& outParentLinks =
|
|
||||||
ninf(link->getDestination())->getOutcomingParentLinks();
|
|
||||||
for (unsigned i = 0; i < outParentLinks.size(); i++) {
|
|
||||||
if (outParentLinks[i]->getDestination() != link->getSource()
|
|
||||||
&& outParentLinks[i]->getDestination()->hasEvidence() == false) {
|
|
||||||
calculateMessage (outParentLinks[i]);
|
|
||||||
BpLinkMap::iterator iter = linkMap_.find (outParentLinks[i]);
|
|
||||||
sortedOrder_.erase (iter->second);
|
|
||||||
iter->second = sortedOrder_.insert (outParentLinks[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
const BpLinkSet& outChildLinks =
|
|
||||||
ninf(link->getDestination())->getOutcomingChildLinks();
|
|
||||||
for (unsigned i = 0; i < outChildLinks.size(); i++) {
|
|
||||||
if (outChildLinks[i]->getDestination() != link->getSource()) {
|
|
||||||
calculateMessage (outChildLinks[i]);
|
|
||||||
BpLinkMap::iterator iter = linkMap_.find (outChildLinks[i]);
|
|
||||||
sortedOrder_.erase (iter->second);
|
|
||||||
iter->second = sortedOrder_.insert (outChildLinks[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (Constants::DEBUG >= 2) {
|
|
||||||
Util::printDashedLine();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
BnBpSolver::updatePiValues (BayesNode* x)
|
|
||||||
{
|
|
||||||
// π(Xi)
|
|
||||||
if (Constants::DEBUG >= 3) {
|
|
||||||
cout << "updating " << PI_SYMBOL << " values for " << x->label() << endl;
|
|
||||||
}
|
|
||||||
Params& piValues = ninf(x)->getPiValues();
|
|
||||||
const BpLinkSet& parentLinks = ninf(x)->getIncomingParentLinks();
|
|
||||||
const BnNodeSet& ps = x->getParents();
|
|
||||||
Ranges ranges;
|
|
||||||
for (unsigned i = 0; i < ps.size(); i++) {
|
|
||||||
ranges.push_back (ps[i]->range());
|
|
||||||
}
|
|
||||||
StatesIndexer indexer (ranges, false);
|
|
||||||
stringstream* calcs1 = 0;
|
|
||||||
stringstream* calcs2 = 0;
|
|
||||||
|
|
||||||
Params messageProducts (indexer.size());
|
|
||||||
for (unsigned k = 0; k < indexer.size(); k++) {
|
|
||||||
if (Constants::DEBUG >= 5) {
|
|
||||||
calcs1 = new stringstream;
|
|
||||||
calcs2 = new stringstream;
|
|
||||||
}
|
|
||||||
double messageProduct = LogAware::multIdenty();
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
for (unsigned i = 0; i < parentLinks.size(); i++) {
|
|
||||||
messageProduct += parentLinks[i]->getMessage()[indexer[i]];
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (unsigned i = 0; i < parentLinks.size(); i++) {
|
|
||||||
messageProduct *= parentLinks[i]->getMessage()[indexer[i]];
|
|
||||||
if (Constants::DEBUG >= 5) {
|
|
||||||
if (i != 0) *calcs1 << " + " ;
|
|
||||||
if (i != 0) *calcs2 << " + " ;
|
|
||||||
*calcs1 << parentLinks[i]->toString (indexer[i]);
|
|
||||||
*calcs2 << parentLinks[i]->getMessage()[indexer[i]];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
messageProducts[k] = messageProduct;
|
|
||||||
if (Constants::DEBUG >= 5) {
|
|
||||||
cout << " mp" << k;
|
|
||||||
cout << " = " << (*calcs1).str();
|
|
||||||
if (parentLinks.size() == 1) {
|
|
||||||
cout << " = " << messageProduct << endl;
|
|
||||||
} else {
|
|
||||||
cout << " = " << (*calcs2).str();
|
|
||||||
cout << " = " << messageProduct << endl;
|
|
||||||
}
|
|
||||||
delete calcs1;
|
|
||||||
delete calcs2;
|
|
||||||
}
|
|
||||||
++ indexer;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (unsigned xi = 0; xi < x->range(); xi++) {
|
|
||||||
double sum = LogAware::addIdenty();
|
|
||||||
if (Constants::DEBUG >= 5) {
|
|
||||||
calcs1 = new stringstream;
|
|
||||||
calcs2 = new stringstream;
|
|
||||||
}
|
|
||||||
indexer.reset();
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
for (unsigned k = 0; k < indexer.size(); k++) {
|
|
||||||
sum = Util::logSum (sum,
|
|
||||||
x->getProbability(xi, indexer) + messageProducts[k]);
|
|
||||||
++ indexer;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (unsigned k = 0; k < indexer.size(); k++) {
|
|
||||||
sum += x->getProbability (xi, indexer) * messageProducts[k];
|
|
||||||
if (Constants::DEBUG >= 5) {
|
|
||||||
if (k != 0) *calcs1 << " + " ;
|
|
||||||
if (k != 0) *calcs2 << " + " ;
|
|
||||||
*calcs1 << x->cptEntryToString (xi, indexer.indices());
|
|
||||||
*calcs1 << ".mp" << k;
|
|
||||||
*calcs2 << LogAware::fl (x->getProbability (xi, indexer));
|
|
||||||
*calcs2 << "*" << messageProducts[k];
|
|
||||||
}
|
|
||||||
++ indexer;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
piValues[xi] = sum;
|
|
||||||
if (Constants::DEBUG >= 5) {
|
|
||||||
cout << " " << PI_SYMBOL << "(" << x->label() << ")" ;
|
|
||||||
cout << "[" << x->states()[xi] << "]" ;
|
|
||||||
cout << " = " << (*calcs1).str();
|
|
||||||
cout << " = " << (*calcs2).str();
|
|
||||||
cout << " = " << piValues[xi] << endl;
|
|
||||||
delete calcs1;
|
|
||||||
delete calcs2;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
BnBpSolver::updateLambdaValues (BayesNode* x)
|
|
||||||
{
|
|
||||||
// λ(Xi)
|
|
||||||
if (Constants::DEBUG >= 3) {
|
|
||||||
cout << "updating " << LD_SYMBOL << " values for " << x->label() << endl;
|
|
||||||
}
|
|
||||||
Params& lambdaValues = ninf(x)->getLambdaValues();
|
|
||||||
const BpLinkSet& childLinks = ninf(x)->getIncomingChildLinks();
|
|
||||||
stringstream* calcs1 = 0;
|
|
||||||
stringstream* calcs2 = 0;
|
|
||||||
|
|
||||||
for (unsigned xi = 0; xi < x->range(); xi++) {
|
|
||||||
if (Constants::DEBUG >= 5) {
|
|
||||||
calcs1 = new stringstream;
|
|
||||||
calcs2 = new stringstream;
|
|
||||||
}
|
|
||||||
double product = LogAware::multIdenty();
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
for (unsigned i = 0; i < childLinks.size(); i++) {
|
|
||||||
product += childLinks[i]->getMessage()[xi];
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (unsigned i = 0; i < childLinks.size(); i++) {
|
|
||||||
product *= childLinks[i]->getMessage()[xi];
|
|
||||||
if (Constants::DEBUG >= 5) {
|
|
||||||
if (i != 0) *calcs1 << "." ;
|
|
||||||
if (i != 0) *calcs2 << "*" ;
|
|
||||||
*calcs1 << childLinks[i]->toString (xi);
|
|
||||||
*calcs2 << childLinks[i]->getMessage()[xi];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
lambdaValues[xi] = product;
|
|
||||||
if (Constants::DEBUG >= 5) {
|
|
||||||
cout << " " << LD_SYMBOL << "(" << x->label() << ")" ;
|
|
||||||
cout << "[" << x->states()[xi] << "]" ;
|
|
||||||
cout << " = " << (*calcs1).str();
|
|
||||||
if (childLinks.size() == 1) {
|
|
||||||
cout << " = " << product << endl;
|
|
||||||
} else {
|
|
||||||
cout << " = " << (*calcs2).str();
|
|
||||||
cout << " = " << lambdaValues[xi] << endl;
|
|
||||||
}
|
|
||||||
delete calcs1;
|
|
||||||
delete calcs2;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
BnBpSolver::calculatePiMessage (BpLink* link)
|
|
||||||
{
|
|
||||||
// πX(Zi)
|
|
||||||
BayesNode* z = link->getSource();
|
|
||||||
BayesNode* x = link->getDestination();
|
|
||||||
Params& zxPiNextMessage = link->getNextMessage();
|
|
||||||
const BpLinkSet& zChildLinks = ninf(z)->getIncomingChildLinks();
|
|
||||||
stringstream* calcs1 = 0;
|
|
||||||
stringstream* calcs2 = 0;
|
|
||||||
|
|
||||||
const Params& zPiValues = ninf(z)->getPiValues();
|
|
||||||
for (unsigned zi = 0; zi < z->range(); zi++) {
|
|
||||||
double product = zPiValues[zi];
|
|
||||||
if (Constants::DEBUG >= 5) {
|
|
||||||
calcs1 = new stringstream;
|
|
||||||
calcs2 = new stringstream;
|
|
||||||
*calcs1 << PI_SYMBOL << "(" << z->label() << ")";
|
|
||||||
*calcs1 << "[" << z->states()[zi] << "]" ;
|
|
||||||
*calcs2 << product;
|
|
||||||
}
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
for (unsigned i = 0; i < zChildLinks.size(); i++) {
|
|
||||||
if (zChildLinks[i]->getSource() != x) {
|
|
||||||
product += zChildLinks[i]->getMessage()[zi];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (unsigned i = 0; i < zChildLinks.size(); i++) {
|
|
||||||
if (zChildLinks[i]->getSource() != x) {
|
|
||||||
product *= zChildLinks[i]->getMessage()[zi];
|
|
||||||
if (Constants::DEBUG >= 5) {
|
|
||||||
*calcs1 << "." << zChildLinks[i]->toString (zi);
|
|
||||||
*calcs2 << " * " << zChildLinks[i]->getMessage()[zi];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
zxPiNextMessage[zi] = product;
|
|
||||||
if (Constants::DEBUG >= 5) {
|
|
||||||
cout << " " << link->toString();
|
|
||||||
cout << "[" << z->states()[zi] << "]" ;
|
|
||||||
cout << " = " << (*calcs1).str();
|
|
||||||
if (zChildLinks.size() == 1) {
|
|
||||||
cout << " = " << product << endl;
|
|
||||||
} else {
|
|
||||||
cout << " = " << (*calcs2).str();
|
|
||||||
cout << " = " << product << endl;
|
|
||||||
}
|
|
||||||
delete calcs1;
|
|
||||||
delete calcs2;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
LogAware::normalize (zxPiNextMessage);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
BnBpSolver::calculateLambdaMessage (BpLink* link)
|
|
||||||
{
|
|
||||||
// λY(Xi)
|
|
||||||
BayesNode* y = link->getSource();
|
|
||||||
BayesNode* x = link->getDestination();
|
|
||||||
if (x->hasEvidence()) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
Params& yxLambdaNextMessage = link->getNextMessage();
|
|
||||||
const BpLinkSet& yParentLinks = ninf(y)->getIncomingParentLinks();
|
|
||||||
const Params& yLambdaValues = ninf(y)->getLambdaValues();
|
|
||||||
int parentIndex = y->indexOfParent (x);
|
|
||||||
stringstream* calcs1 = 0;
|
|
||||||
stringstream* calcs2 = 0;
|
|
||||||
|
|
||||||
const BnNodeSet& ps = y->getParents();
|
|
||||||
Ranges ranges;
|
|
||||||
for (unsigned i = 0; i < ps.size(); i++) {
|
|
||||||
ranges.push_back (ps[i]->range());
|
|
||||||
}
|
|
||||||
StatesIndexer indexer (ranges, false);
|
|
||||||
|
|
||||||
|
|
||||||
unsigned N = indexer.size() / x->range();
|
|
||||||
Params messageProducts (N);
|
|
||||||
for (unsigned k = 0; k < N; k++) {
|
|
||||||
while (indexer[parentIndex] != 0) {
|
|
||||||
++ indexer;
|
|
||||||
}
|
|
||||||
if (Constants::DEBUG >= 5) {
|
|
||||||
calcs1 = new stringstream;
|
|
||||||
calcs2 = new stringstream;
|
|
||||||
}
|
|
||||||
double messageProduct = LogAware::multIdenty();
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
for (unsigned i = 0; i < yParentLinks.size(); i++) {
|
|
||||||
if (yParentLinks[i]->getSource() != x) {
|
|
||||||
messageProduct += yParentLinks[i]->getMessage()[indexer[i]];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (unsigned i = 0; i < yParentLinks.size(); i++) {
|
|
||||||
if (yParentLinks[i]->getSource() != x) {
|
|
||||||
if (Constants::DEBUG >= 5) {
|
|
||||||
if (messageProduct != LogAware::multIdenty()) *calcs1 << "*" ;
|
|
||||||
if (messageProduct != LogAware::multIdenty()) *calcs2 << "*" ;
|
|
||||||
*calcs1 << yParentLinks[i]->toString (indexer[i]);
|
|
||||||
*calcs2 << yParentLinks[i]->getMessage()[indexer[i]];
|
|
||||||
}
|
|
||||||
messageProduct *= yParentLinks[i]->getMessage()[indexer[i]];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
messageProducts[k] = messageProduct;
|
|
||||||
++ indexer;
|
|
||||||
if (Constants::DEBUG >= 5) {
|
|
||||||
cout << " mp" << k;
|
|
||||||
cout << " = " << (*calcs1).str();
|
|
||||||
if (yParentLinks.size() == 1) {
|
|
||||||
cout << 1 << endl;
|
|
||||||
} else if (yParentLinks.size() == 2) {
|
|
||||||
cout << " = " << messageProduct << endl;
|
|
||||||
} else {
|
|
||||||
cout << " = " << (*calcs2).str();
|
|
||||||
cout << " = " << messageProduct << endl;
|
|
||||||
}
|
|
||||||
delete calcs1;
|
|
||||||
delete calcs2;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (unsigned xi = 0; xi < x->range(); xi++) {
|
|
||||||
if (Constants::DEBUG >= 5) {
|
|
||||||
calcs1 = new stringstream;
|
|
||||||
calcs2 = new stringstream;
|
|
||||||
}
|
|
||||||
double outerSum = LogAware::addIdenty();
|
|
||||||
for (unsigned yi = 0; yi < y->range(); yi++) {
|
|
||||||
if (Constants::DEBUG >= 5) {
|
|
||||||
(yi != 0) ? *calcs1 << " + {" : *calcs1 << "{" ;
|
|
||||||
(yi != 0) ? *calcs2 << " + {" : *calcs2 << "{" ;
|
|
||||||
}
|
|
||||||
double innerSum = LogAware::addIdenty();
|
|
||||||
indexer.reset();
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
for (unsigned k = 0; k < N; k++) {
|
|
||||||
while (indexer[parentIndex] != xi) {
|
|
||||||
++ indexer;
|
|
||||||
}
|
|
||||||
innerSum = Util::logSum (innerSum,
|
|
||||||
y->getProbability (yi, indexer) + messageProducts[k]);
|
|
||||||
++ indexer;
|
|
||||||
}
|
|
||||||
outerSum = Util::logSum (outerSum, innerSum + yLambdaValues[yi]);
|
|
||||||
} else {
|
|
||||||
for (unsigned k = 0; k < N; k++) {
|
|
||||||
while (indexer[parentIndex] != xi) {
|
|
||||||
++ indexer;
|
|
||||||
}
|
|
||||||
if (Constants::DEBUG >= 5) {
|
|
||||||
if (k != 0) *calcs1 << " + " ;
|
|
||||||
if (k != 0) *calcs2 << " + " ;
|
|
||||||
*calcs1 << y->cptEntryToString (yi, indexer.indices());
|
|
||||||
*calcs1 << ".mp" << k;
|
|
||||||
*calcs2 << y->getProbability (yi, indexer);
|
|
||||||
*calcs2 << "*" << messageProducts[k];
|
|
||||||
}
|
|
||||||
innerSum += y->getProbability (yi, indexer) * messageProducts[k];
|
|
||||||
++ indexer;
|
|
||||||
}
|
|
||||||
outerSum += innerSum * yLambdaValues[yi];
|
|
||||||
}
|
|
||||||
if (Constants::DEBUG >= 5) {
|
|
||||||
*calcs1 << "}." << LD_SYMBOL << "(" << y->label() << ")" ;
|
|
||||||
*calcs1 << "[" << y->states()[yi] << "]";
|
|
||||||
*calcs2 << "}*" << yLambdaValues[yi];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
yxLambdaNextMessage[xi] = outerSum;
|
|
||||||
if (Constants::DEBUG >= 5) {
|
|
||||||
cout << " " << link->toString();
|
|
||||||
cout << "[" << x->states()[xi] << "]" ;
|
|
||||||
cout << " = " << (*calcs1).str();
|
|
||||||
cout << " = " << (*calcs2).str();
|
|
||||||
cout << " = " << yxLambdaNextMessage[xi] << endl;
|
|
||||||
delete calcs1;
|
|
||||||
delete calcs2;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
LogAware::normalize (yxLambdaNextMessage);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Params
|
|
||||||
BnBpSolver::getJointByConditioning (const VarIds& jointVarIds) const
|
|
||||||
{
|
|
||||||
/*
|
|
||||||
BnNodeSet jointVars;
|
|
||||||
for (unsigned i = 0; i < jointVarIds.size(); i++) {
|
|
||||||
assert (bayesNet_->getBayesNode (jointVarIds[i]));
|
|
||||||
jointVars.push_back (bayesNet_->getBayesNode (jointVarIds[i]));
|
|
||||||
}
|
|
||||||
|
|
||||||
BayesNet* mrn = bayesNet_->getMinimalRequesiteNetwork (jointVarIds[0]);
|
|
||||||
BnBpSolver solver (*mrn);
|
|
||||||
solver.runSolver();
|
|
||||||
Params prevBeliefs = solver.getPosterioriOf (jointVarIds[0]);
|
|
||||||
delete mrn;
|
|
||||||
|
|
||||||
VarIds observedVids = {jointVars[0]->varId()};
|
|
||||||
|
|
||||||
for (unsigned i = 1; i < jointVarIds.size(); i++) {
|
|
||||||
assert (jointVars[i]->hasEvidence() == false);
|
|
||||||
VarIds reqVars = {jointVarIds[i]};
|
|
||||||
Util::addToVector (reqVars, observedVids);
|
|
||||||
mrn = bayesNet_->getMinimalRequesiteNetwork (reqVars);
|
|
||||||
Params newBeliefs;
|
|
||||||
VarNodes observedVars;
|
|
||||||
for (unsigned j = 0; j < observedVids.size(); j++) {
|
|
||||||
observedVars.push_back (mrn->getBayesNode (observedVids[j]));
|
|
||||||
}
|
|
||||||
StatesIndexer idx (observedVars, false);
|
|
||||||
while (idx.valid()) {
|
|
||||||
for (unsigned j = 0; j < observedVars.size(); j++) {
|
|
||||||
observedVars[j]->setEvidence (idx[j]);
|
|
||||||
}
|
|
||||||
BnBpSolver solver (*mrn);
|
|
||||||
solver.runSolver();
|
|
||||||
Params beliefs = solver.getPosterioriOf (jointVarIds[i]);
|
|
||||||
for (unsigned k = 0; k < beliefs.size(); k++) {
|
|
||||||
newBeliefs.push_back (beliefs[k]);
|
|
||||||
}
|
|
||||||
++ idx;
|
|
||||||
}
|
|
||||||
|
|
||||||
int count = -1;
|
|
||||||
for (unsigned j = 0; j < newBeliefs.size(); j++) {
|
|
||||||
if (j % jointVars[i]->range() == 0) {
|
|
||||||
count ++;
|
|
||||||
}
|
|
||||||
newBeliefs[j] *= prevBeliefs[count];
|
|
||||||
}
|
|
||||||
prevBeliefs = newBeliefs;
|
|
||||||
observedVids.push_back (jointVars[i]->varId());
|
|
||||||
delete mrn;
|
|
||||||
}
|
|
||||||
return prevBeliefs;
|
|
||||||
*/
|
|
||||||
return Params();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
BnBpSolver::printPiLambdaValues (const BayesNode* var) const
|
|
||||||
{
|
|
||||||
cout << left;
|
|
||||||
cout << setw (10) << "states" ;
|
|
||||||
cout << setw (20) << PI_SYMBOL << "(" + var->label() + ")" ;
|
|
||||||
cout << setw (20) << LD_SYMBOL << "(" + var->label() + ")" ;
|
|
||||||
cout << setw (16) << "belief" ;
|
|
||||||
cout << endl;
|
|
||||||
Util::printDashedLine();
|
|
||||||
cout << endl;
|
|
||||||
const States& states = var->states();
|
|
||||||
const Params& piVals = ninf(var)->getPiValues();
|
|
||||||
const Params& ldVals = ninf(var)->getLambdaValues();
|
|
||||||
const Params& beliefs = ninf(var)->getBeliefs();
|
|
||||||
for (unsigned xi = 0; xi < var->range(); xi++) {
|
|
||||||
cout << setw (10) << states[xi];
|
|
||||||
cout << setw (19) << piVals[xi];
|
|
||||||
cout << setw (19) << ldVals[xi];
|
|
||||||
cout.precision (Constants::PRECISION);
|
|
||||||
cout << setw (16) << beliefs[xi];
|
|
||||||
cout << endl;
|
|
||||||
}
|
|
||||||
cout << endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
BnBpSolver::printAllMessageStatus (void) const
|
|
||||||
{
|
|
||||||
const BnNodeSet& nodes = bayesNet_->getBayesNodes();
|
|
||||||
for (unsigned i = 0; i < nodes.size(); i++) {
|
|
||||||
printPiLambdaValues (nodes[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
BpNodeInfo::BpNodeInfo (BayesNode* node)
|
|
||||||
{
|
|
||||||
node_ = node;
|
|
||||||
piVals_.resize (node->range(), LogAware::one());
|
|
||||||
ldVals_.resize (node->range(), LogAware::one());
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Params
|
|
||||||
BpNodeInfo::getBeliefs (void) const
|
|
||||||
{
|
|
||||||
double sum = 0.0;
|
|
||||||
Params beliefs (node_->range());
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
for (unsigned xi = 0; xi < node_->range(); xi++) {
|
|
||||||
beliefs[xi] = exp (piVals_[xi] + ldVals_[xi]);
|
|
||||||
sum += beliefs[xi];
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (unsigned xi = 0; xi < node_->range(); xi++) {
|
|
||||||
beliefs[xi] = piVals_[xi] * ldVals_[xi];
|
|
||||||
sum += beliefs[xi];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
assert (sum);
|
|
||||||
for (unsigned xi = 0; xi < node_->range(); xi++) {
|
|
||||||
beliefs[xi] /= sum;
|
|
||||||
}
|
|
||||||
return beliefs;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
bool
|
|
||||||
BpNodeInfo::receivedBottomInfluence (void) const
|
|
||||||
{
|
|
||||||
// if all lambda values are equal, then neither
|
|
||||||
// this node neither its descendents have evidence,
|
|
||||||
// we can use this to don't send lambda messages his parents
|
|
||||||
bool childInfluenced = false;
|
|
||||||
for (unsigned xi = 1; xi < node_->range(); xi++) {
|
|
||||||
if (ldVals_[xi] != ldVals_[0]) {
|
|
||||||
childInfluenced = true;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return childInfluenced;
|
|
||||||
}
|
|
||||||
|
|
@ -1,271 +0,0 @@
|
|||||||
#ifndef HORUS_BNBPSOLVER_H
|
|
||||||
#define HORUS_BNBPSOLVER_H
|
|
||||||
|
|
||||||
#include <vector>
|
|
||||||
#include <set>
|
|
||||||
|
|
||||||
#include "Solver.h"
|
|
||||||
#include "BayesNet.h"
|
|
||||||
#include "Horus.h"
|
|
||||||
#include "Util.h"
|
|
||||||
|
|
||||||
using namespace std;
|
|
||||||
|
|
||||||
class BpNodeInfo;
|
|
||||||
|
|
||||||
static const string PI_SYMBOL = "pi" ;
|
|
||||||
static const string LD_SYMBOL = "ld" ;
|
|
||||||
|
|
||||||
enum LinkOrientation {UP, DOWN};
|
|
||||||
|
|
||||||
class BpLink
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
BpLink (BayesNode* s, BayesNode* d, LinkOrientation o)
|
|
||||||
{
|
|
||||||
source_ = s;
|
|
||||||
destin_ = d;
|
|
||||||
orientation_ = o;
|
|
||||||
if (orientation_ == LinkOrientation::DOWN) {
|
|
||||||
v1_.resize (s->range(), LogAware::tl (1.0 / s->range()));
|
|
||||||
v2_.resize (s->range(), LogAware::tl (1.0 / s->range()));
|
|
||||||
} else {
|
|
||||||
v1_.resize (d->range(), LogAware::tl (1.0 / d->range()));
|
|
||||||
v2_.resize (d->range(), LogAware::tl (1.0 / d->range()));
|
|
||||||
}
|
|
||||||
currMsg_ = &v1_;
|
|
||||||
nextMsg_ = &v2_;
|
|
||||||
residual_ = 0;
|
|
||||||
msgSended_ = false;
|
|
||||||
}
|
|
||||||
|
|
||||||
BayesNode* getSource (void) const { return source_; }
|
|
||||||
|
|
||||||
BayesNode* getDestination (void) const { return destin_; }
|
|
||||||
|
|
||||||
LinkOrientation getOrientation (void) const { return orientation_; }
|
|
||||||
|
|
||||||
const Params& getMessage (void) const { return *currMsg_; }
|
|
||||||
|
|
||||||
Params& getNextMessage (void) { return *nextMsg_;}
|
|
||||||
|
|
||||||
bool messageWasSended (void) const { return msgSended_; }
|
|
||||||
|
|
||||||
double getResidual (void) const { return residual_; }
|
|
||||||
|
|
||||||
void clearResidual (void) { residual_ = 0;}
|
|
||||||
|
|
||||||
void updateMessage (void)
|
|
||||||
{
|
|
||||||
swap (currMsg_, nextMsg_);
|
|
||||||
msgSended_ = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
void updateResidual (void)
|
|
||||||
{
|
|
||||||
residual_ = LogAware::getMaxNorm (v1_, v2_);
|
|
||||||
}
|
|
||||||
|
|
||||||
string toString (void) const
|
|
||||||
{
|
|
||||||
stringstream ss;
|
|
||||||
if (orientation_ == LinkOrientation::DOWN) {
|
|
||||||
ss << PI_SYMBOL;
|
|
||||||
} else {
|
|
||||||
ss << LD_SYMBOL;
|
|
||||||
}
|
|
||||||
ss << "(" << source_->label();
|
|
||||||
ss << " --> " << destin_->label() << ")" ;
|
|
||||||
return ss.str();
|
|
||||||
}
|
|
||||||
|
|
||||||
string toString (unsigned stateIndex) const
|
|
||||||
{
|
|
||||||
stringstream ss;
|
|
||||||
ss << toString() << "[" ;
|
|
||||||
if (orientation_ == LinkOrientation::DOWN) {
|
|
||||||
ss << source_->states()[stateIndex] << "]" ;
|
|
||||||
} else {
|
|
||||||
ss << destin_->states()[stateIndex] << "]" ;
|
|
||||||
}
|
|
||||||
return ss.str();
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
BayesNode* source_;
|
|
||||||
BayesNode* destin_;
|
|
||||||
LinkOrientation orientation_;
|
|
||||||
Params v1_;
|
|
||||||
Params v2_;
|
|
||||||
Params* currMsg_;
|
|
||||||
Params* nextMsg_;
|
|
||||||
bool msgSended_;
|
|
||||||
double residual_;
|
|
||||||
};
|
|
||||||
|
|
||||||
typedef vector<BpLink*> BpLinkSet;
|
|
||||||
|
|
||||||
|
|
||||||
class BpNodeInfo
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
BpNodeInfo (BayesNode*);
|
|
||||||
|
|
||||||
Params& getPiValues (void) { return piVals_; }
|
|
||||||
|
|
||||||
Params& getLambdaValues (void) { return ldVals_; }
|
|
||||||
|
|
||||||
const BpLinkSet& getIncomingParentLinks (void) { return inParentLinks_; }
|
|
||||||
|
|
||||||
const BpLinkSet& getIncomingChildLinks (void) { return inChildLinks_; }
|
|
||||||
|
|
||||||
const BpLinkSet& getOutcomingParentLinks (void) { return outParentLinks_; }
|
|
||||||
|
|
||||||
const BpLinkSet& getOutcomingChildLinks (void) { return outChildLinks_; }
|
|
||||||
|
|
||||||
void addIncomingParentLink (BpLink* l) { inParentLinks_.push_back (l); }
|
|
||||||
|
|
||||||
void addIncomingChildLink (BpLink* l) { inChildLinks_.push_back (l); }
|
|
||||||
|
|
||||||
void addOutcomingParentLink (BpLink* l) { outParentLinks_.push_back (l); }
|
|
||||||
|
|
||||||
void addOutcomingChildLink (BpLink* l) { outChildLinks_.push_back (l); }
|
|
||||||
|
|
||||||
Params getBeliefs (void) const;
|
|
||||||
|
|
||||||
bool receivedBottomInfluence (void) const;
|
|
||||||
|
|
||||||
|
|
||||||
private:
|
|
||||||
DISALLOW_COPY_AND_ASSIGN (BpNodeInfo);
|
|
||||||
|
|
||||||
const BayesNode* node_;
|
|
||||||
Params piVals_;
|
|
||||||
Params ldVals_;
|
|
||||||
BpLinkSet inParentLinks_;
|
|
||||||
BpLinkSet inChildLinks_;
|
|
||||||
BpLinkSet outParentLinks_;
|
|
||||||
BpLinkSet outChildLinks_;
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class BnBpSolver : public Solver
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
BnBpSolver (const BayesNet&);
|
|
||||||
|
|
||||||
~BnBpSolver (void);
|
|
||||||
|
|
||||||
void runSolver (void);
|
|
||||||
Params getPosterioriOf (VarId);
|
|
||||||
Params getJointDistributionOf (const VarIds&);
|
|
||||||
|
|
||||||
private:
|
|
||||||
DISALLOW_COPY_AND_ASSIGN (BnBpSolver);
|
|
||||||
|
|
||||||
void initializeSolver (void);
|
|
||||||
|
|
||||||
void runLoopySolver (void);
|
|
||||||
|
|
||||||
void maxResidualSchedule (void);
|
|
||||||
|
|
||||||
bool converged (void) const;
|
|
||||||
|
|
||||||
void updatePiValues (BayesNode*);
|
|
||||||
|
|
||||||
void updateLambdaValues (BayesNode*);
|
|
||||||
|
|
||||||
void calculateLambdaMessage (BpLink*);
|
|
||||||
|
|
||||||
void calculatePiMessage (BpLink*);
|
|
||||||
|
|
||||||
Params getJointByJunctionNode (const VarIds&);
|
|
||||||
|
|
||||||
Params getJointByConditioning (const VarIds&) const;
|
|
||||||
|
|
||||||
void printPiLambdaValues (const BayesNode*) const;
|
|
||||||
|
|
||||||
void printAllMessageStatus (void) const;
|
|
||||||
|
|
||||||
void calculateAndUpdateMessage (BpLink* link, bool calcResidual = true)
|
|
||||||
{
|
|
||||||
if (Constants::DEBUG >= 3) {
|
|
||||||
cout << "calculating & updating " << link->toString() << endl;
|
|
||||||
}
|
|
||||||
if (link->getOrientation() == LinkOrientation::DOWN) {
|
|
||||||
calculatePiMessage (link);
|
|
||||||
} else if (link->getOrientation() == LinkOrientation::UP) {
|
|
||||||
calculateLambdaMessage (link);
|
|
||||||
}
|
|
||||||
if (calcResidual) {
|
|
||||||
link->updateResidual();
|
|
||||||
}
|
|
||||||
link->updateMessage();
|
|
||||||
}
|
|
||||||
|
|
||||||
void calculateMessage (BpLink* link, bool calcResidual = true)
|
|
||||||
{
|
|
||||||
if (Constants::DEBUG >= 3) {
|
|
||||||
cout << "calculating " << link->toString() << endl;
|
|
||||||
}
|
|
||||||
if (link->getOrientation() == LinkOrientation::DOWN) {
|
|
||||||
calculatePiMessage (link);
|
|
||||||
} else if (link->getOrientation() == LinkOrientation::UP) {
|
|
||||||
calculateLambdaMessage (link);
|
|
||||||
}
|
|
||||||
if (calcResidual) {
|
|
||||||
link->updateResidual();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void updateMessage (BpLink* link)
|
|
||||||
{
|
|
||||||
if (Constants::DEBUG >= 3) {
|
|
||||||
cout << "updating " << link->toString() << endl;
|
|
||||||
}
|
|
||||||
link->updateMessage();
|
|
||||||
}
|
|
||||||
|
|
||||||
void updateValues (BpLink* link)
|
|
||||||
{
|
|
||||||
if (!link->getDestination()->hasEvidence()) {
|
|
||||||
if (link->getOrientation() == LinkOrientation::DOWN) {
|
|
||||||
updatePiValues (link->getDestination());
|
|
||||||
} else if (link->getOrientation() == LinkOrientation::UP) {
|
|
||||||
updateLambdaValues (link->getDestination());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
BpNodeInfo* ninf (const BayesNode* node) const
|
|
||||||
{
|
|
||||||
assert (node);
|
|
||||||
assert (node == bayesNet_->getBayesNode (node->varId()));
|
|
||||||
assert (node->getIndex() < nodesI_.size());
|
|
||||||
return nodesI_[node->getIndex()];
|
|
||||||
}
|
|
||||||
|
|
||||||
const BayesNet* bayesNet_;
|
|
||||||
vector<BpLink*> links_;
|
|
||||||
vector<BpNodeInfo*> nodesI_;
|
|
||||||
unsigned nIters_;
|
|
||||||
|
|
||||||
struct compare
|
|
||||||
{
|
|
||||||
inline bool operator() (const BpLink* e1, const BpLink* e2)
|
|
||||||
{
|
|
||||||
return e1->getResidual() > e2->getResidual();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
typedef multiset<BpLink*, compare> SortedOrder;
|
|
||||||
SortedOrder sortedOrder_;
|
|
||||||
|
|
||||||
typedef unordered_map<BpLink*, SortedOrder::iterator> BpLinkMap;
|
|
||||||
BpLinkMap linkMap_;
|
|
||||||
|
|
||||||
};
|
|
||||||
|
|
||||||
#endif // HORUS_BNBPSOLVER_H
|
|
||||||
|
|
@ -3,18 +3,13 @@
|
|||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
|
||||||
#include "BayesNet.h"
|
|
||||||
#include "FactorGraph.h"
|
#include "FactorGraph.h"
|
||||||
#include "VarElimSolver.h"
|
#include "VarElimSolver.h"
|
||||||
#include "BnBpSolver.h"
|
|
||||||
#include "FgBpSolver.h"
|
#include "FgBpSolver.h"
|
||||||
#include "CbpSolver.h"
|
#include "CbpSolver.h"
|
||||||
|
|
||||||
#include "ElimGraph.h"
|
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
void processArguments (BayesNet&, int, const char* []);
|
|
||||||
void processArguments (FactorGraph&, int, const char* []);
|
void processArguments (FactorGraph&, int, const char* []);
|
||||||
void runSolver (Solver*, const VarNodes&);
|
void runSolver (Solver*, const VarNodes&);
|
||||||
|
|
||||||
@ -25,53 +20,24 @@ const string USAGE = "usage: \
|
|||||||
int
|
int
|
||||||
main (int argc, const char* argv[])
|
main (int argc, const char* argv[])
|
||||||
{
|
{
|
||||||
VarIds vids1 = { 4, 1, 2, 3 } ;
|
|
||||||
VarIds vids2 = { 4, 5 } ;
|
|
||||||
VarIds vids3 = { 4, 6 } ;
|
|
||||||
VarIds vids4 = { 4, 7 } ;
|
|
||||||
// Factor f1 (vids1, {2,2,2,2},{0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,0.10,0.11,0.12,0.13,0.14,0.15,0.16});
|
|
||||||
// Factor f2 (vids2, {2,2},{0.1,0.2,0.3,0.4});
|
|
||||||
// Factor f3 (vids3, {2,2},{0.1,0.2,0.3,0.4});
|
|
||||||
// Factor f4 (vids4, {2,2},{0.1,0.2,0.3,0.4});
|
|
||||||
|
|
||||||
|
|
||||||
Factor* f1 = new Factor (vids1, {2,2,2,2},{0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,0.10,0.11,0.12,0.13,0.14,0.15,0.16});
|
|
||||||
Factor* f2 = new Factor (vids2, {2,2},{0.1,0.2,0.3,0.4});
|
|
||||||
Factor* f3 = new Factor (vids3, {2,2},{0.1,0.2,0.3,0.4});
|
|
||||||
Factor* f4 = new Factor (vids4, {2,2},{0.1,0.2,0.3,0.4});
|
|
||||||
Factor* f5 = new Factor (vids4, {2,2},{0.1,0.2,0.3,0.4});
|
|
||||||
|
|
||||||
vector<Factor*> fs = {f1,f2,f3,f4,f5};
|
|
||||||
//FactorGraph fg;
|
|
||||||
//fg.addFactor (f1);
|
|
||||||
//fg.addFactor (f2);
|
|
||||||
//fg.addFactor (f3);
|
|
||||||
//fg.addFactor (f4);
|
|
||||||
ElimGraph eg (fs);
|
|
||||||
eg.exportToGraphViz ("_eg.dot");
|
|
||||||
return 0;
|
|
||||||
if (!argv[1]) {
|
if (!argv[1]) {
|
||||||
cerr << "error: no graphical model specified" << endl;
|
cerr << "error: no graphical model specified" << endl;
|
||||||
cerr << USAGE << endl;
|
cerr << USAGE << endl;
|
||||||
exit (0);
|
exit (0);
|
||||||
}
|
}
|
||||||
const string& fileName = argv[1];
|
const string& fileName = argv[1];
|
||||||
const string& extension = fileName.substr (fileName.find_last_of ('.') + 1);
|
const string& extension = fileName.substr (
|
||||||
if (extension == "xml") {
|
fileName.find_last_of ('.') + 1);
|
||||||
BayesNet bn;
|
FactorGraph fg;
|
||||||
bn.readFromBifFormat (argv[1]);
|
if (extension == "uai") {
|
||||||
processArguments (bn, argc, argv);
|
|
||||||
} else if (extension == "uai") {
|
|
||||||
FactorGraph fg;
|
|
||||||
fg.readFromUaiFormat (argv[1]);
|
fg.readFromUaiFormat (argv[1]);
|
||||||
processArguments (fg, argc, argv);
|
processArguments (fg, argc, argv);
|
||||||
} else if (extension == "fg") {
|
} else if (extension == "fg") {
|
||||||
FactorGraph fg;
|
|
||||||
fg.readFromLibDaiFormat (argv[1]);
|
fg.readFromLibDaiFormat (argv[1]);
|
||||||
processArguments (fg, argc, argv);
|
processArguments (fg, argc, argv);
|
||||||
} else {
|
} else {
|
||||||
cerr << "error: the graphical model must be defined either " ;
|
cerr << "error: the graphical model must be defined either " ;
|
||||||
cerr << "in a xml, uai or libDAI file" << endl;
|
cerr << "in a UAI or libDAI file" << endl;
|
||||||
exit (0);
|
exit (0);
|
||||||
}
|
}
|
||||||
return 0;
|
return 0;
|
||||||
@ -79,83 +45,6 @@ main (int argc, const char* argv[])
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
processArguments (BayesNet& bn, int argc, const char* argv[])
|
|
||||||
{
|
|
||||||
VarNodes queryVars;
|
|
||||||
for (int i = 2; i < argc; i++) {
|
|
||||||
const string& arg = argv[i];
|
|
||||||
if (arg.find ('=') == std::string::npos) {
|
|
||||||
BayesNode* queryVar = bn.getBayesNode (arg);
|
|
||||||
if (queryVar) {
|
|
||||||
queryVars.push_back (queryVar);
|
|
||||||
} else {
|
|
||||||
cerr << "error: there isn't a variable labeled of " ;
|
|
||||||
cerr << "`" << arg << "'" ;
|
|
||||||
cerr << endl;
|
|
||||||
exit (0);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
size_t pos = arg.find ('=');
|
|
||||||
const string& label = arg.substr (0, pos);
|
|
||||||
const string& state = arg.substr (pos + 1);
|
|
||||||
if (label.empty()) {
|
|
||||||
cerr << "error: missing left argument" << endl;
|
|
||||||
cerr << USAGE << endl;
|
|
||||||
exit (0);
|
|
||||||
}
|
|
||||||
if (state.empty()) {
|
|
||||||
cerr << "error: missing right argument" << endl;
|
|
||||||
cerr << USAGE << endl;
|
|
||||||
exit (0);
|
|
||||||
}
|
|
||||||
BayesNode* node = bn.getBayesNode (label);
|
|
||||||
if (node) {
|
|
||||||
if (node->isValidState (state)) {
|
|
||||||
node->setEvidence (state);
|
|
||||||
} else {
|
|
||||||
cerr << "error: `" << state << "' " ;
|
|
||||||
cerr << "is not a valid state for " ;
|
|
||||||
cerr << "`" << node->label() << "'" ;
|
|
||||||
cerr << endl;
|
|
||||||
exit (0);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
cerr << "error: there isn't a variable labeled of " ;
|
|
||||||
cerr << "`" << label << "'" ;
|
|
||||||
cerr << endl;
|
|
||||||
exit (0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Solver* solver = 0;
|
|
||||||
FactorGraph* fg = 0;
|
|
||||||
switch (Globals::infAlgorithm) {
|
|
||||||
case InfAlgorithms::VE:
|
|
||||||
fg = new FactorGraph (bn);
|
|
||||||
solver = new VarElimSolver (*fg);
|
|
||||||
break;
|
|
||||||
case InfAlgorithms::BN_BP:
|
|
||||||
solver = new BnBpSolver (bn);
|
|
||||||
break;
|
|
||||||
case InfAlgorithms::FG_BP:
|
|
||||||
fg = new FactorGraph (bn);
|
|
||||||
solver = new FgBpSolver (*fg);
|
|
||||||
break;
|
|
||||||
case InfAlgorithms::CBP:
|
|
||||||
fg = new FactorGraph (bn);
|
|
||||||
solver = new CbpSolver (*fg);
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
assert (false);
|
|
||||||
}
|
|
||||||
runSolver (solver, queryVars);
|
|
||||||
delete fg;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
processArguments (FactorGraph& fg, int argc, const char* argv[])
|
processArguments (FactorGraph& fg, int argc, const char* argv[])
|
||||||
{
|
{
|
||||||
@ -238,7 +127,6 @@ processArguments (FactorGraph& fg, int argc, const char* argv[])
|
|||||||
case InfAlgorithms::VE:
|
case InfAlgorithms::VE:
|
||||||
solver = new VarElimSolver (fg);
|
solver = new VarElimSolver (fg);
|
||||||
break;
|
break;
|
||||||
case InfAlgorithms::BN_BP:
|
|
||||||
case InfAlgorithms::FG_BP:
|
case InfAlgorithms::FG_BP:
|
||||||
solver = new FgBpSolver (fg);
|
solver = new FgBpSolver (fg);
|
||||||
break;
|
break;
|
||||||
|
@ -12,7 +12,6 @@
|
|||||||
#include "FactorGraph.h"
|
#include "FactorGraph.h"
|
||||||
#include "FoveSolver.h"
|
#include "FoveSolver.h"
|
||||||
#include "VarElimSolver.h"
|
#include "VarElimSolver.h"
|
||||||
#include "BnBpSolver.h"
|
|
||||||
#include "FgBpSolver.h"
|
#include "FgBpSolver.h"
|
||||||
#include "CbpSolver.h"
|
#include "CbpSolver.h"
|
||||||
#include "ElimGraph.h"
|
#include "ElimGraph.h"
|
||||||
|
@ -56,7 +56,6 @@ HEADERS = \
|
|||||||
$(srcdir)/ConstraintTree.h \
|
$(srcdir)/ConstraintTree.h \
|
||||||
$(srcdir)/Solver.h \
|
$(srcdir)/Solver.h \
|
||||||
$(srcdir)/VarElimSolver.h \
|
$(srcdir)/VarElimSolver.h \
|
||||||
$(srcdir)/BnBpSolver.h \
|
|
||||||
$(srcdir)/FgBpSolver.h \
|
$(srcdir)/FgBpSolver.h \
|
||||||
$(srcdir)/CbpSolver.h \
|
$(srcdir)/CbpSolver.h \
|
||||||
$(srcdir)/FoveSolver.h \
|
$(srcdir)/FoveSolver.h \
|
||||||
@ -84,7 +83,6 @@ CPP_SOURCES = \
|
|||||||
$(srcdir)/VarNode.cpp \
|
$(srcdir)/VarNode.cpp \
|
||||||
$(srcdir)/Solver.cpp \
|
$(srcdir)/Solver.cpp \
|
||||||
$(srcdir)/VarElimSolver.cpp \
|
$(srcdir)/VarElimSolver.cpp \
|
||||||
$(srcdir)/BnBpSolver.cpp \
|
|
||||||
$(srcdir)/FgBpSolver.cpp \
|
$(srcdir)/FgBpSolver.cpp \
|
||||||
$(srcdir)/CbpSolver.cpp \
|
$(srcdir)/CbpSolver.cpp \
|
||||||
$(srcdir)/FoveSolver.cpp \
|
$(srcdir)/FoveSolver.cpp \
|
||||||
@ -110,7 +108,6 @@ OBJS = \
|
|||||||
VarNode.o \
|
VarNode.o \
|
||||||
Solver.o \
|
Solver.o \
|
||||||
VarElimSolver.o \
|
VarElimSolver.o \
|
||||||
BnBpSolver.o \
|
|
||||||
FgBpSolver.o \
|
FgBpSolver.o \
|
||||||
CbpSolver.o \
|
CbpSolver.o \
|
||||||
FoveSolver.o \
|
FoveSolver.o \
|
||||||
@ -134,7 +131,6 @@ HCLI_OBJS = \
|
|||||||
VarNode.o \
|
VarNode.o \
|
||||||
Solver.o \
|
Solver.o \
|
||||||
VarElimSolver.o \
|
VarElimSolver.o \
|
||||||
BnBpSolver.o \
|
|
||||||
FgBpSolver.o \
|
FgBpSolver.o \
|
||||||
CbpSolver.o \
|
CbpSolver.o \
|
||||||
FoveSolver.o \
|
FoveSolver.o \
|
||||||
|
Reference in New Issue
Block a user