2011-05-17 12:00:33 +01:00
|
|
|
#include <cstdlib>
|
2011-07-22 21:33:30 +01:00
|
|
|
#include <limits>
|
2011-05-17 12:00:33 +01:00
|
|
|
#include <time.h>
|
2011-07-22 21:33:30 +01:00
|
|
|
|
2011-05-17 12:00:33 +01:00
|
|
|
#include <iostream>
|
|
|
|
#include <sstream>
|
2011-07-22 21:33:30 +01:00
|
|
|
#include <iomanip>
|
2011-05-17 12:00:33 +01:00
|
|
|
|
|
|
|
#include "BPSolver.h"
|
|
|
|
|
|
|
|
BPSolver::BPSolver (const BayesNet& bn) : Solver (&bn)
|
2011-07-22 21:33:30 +01:00
|
|
|
{
|
2011-05-17 12:00:33 +01:00
|
|
|
bn_ = &bn;
|
2011-07-22 21:33:30 +01:00
|
|
|
useAlwaysLoopySolver_ = false;
|
|
|
|
//jointCalcType_ = CHAIN_RULE;
|
|
|
|
jointCalcType_ = JUNCTION_NODE;
|
2011-05-17 12:00:33 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
BPSolver::~BPSolver (void)
|
|
|
|
{
|
2011-07-22 21:33:30 +01:00
|
|
|
for (unsigned i = 0; i < nodesI_.size(); i++) {
|
|
|
|
delete nodesI_[i];
|
|
|
|
}
|
|
|
|
for (unsigned i = 0; i < links_.size(); i++) {
|
|
|
|
delete links_[i];
|
2011-05-17 12:00:33 +01:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void
|
|
|
|
BPSolver::runSolver (void)
|
|
|
|
{
|
|
|
|
clock_t start_ = clock();
|
2011-07-22 21:33:30 +01:00
|
|
|
unsigned size = bn_->getNumberOfNodes();
|
|
|
|
unsigned nIters = 0;
|
|
|
|
initializeSolver();
|
|
|
|
if (bn_->isSingleConnected() && !useAlwaysLoopySolver_) {
|
2011-05-17 12:00:33 +01:00
|
|
|
runPolyTreeSolver();
|
|
|
|
Statistics::numSolvedPolyTrees ++;
|
|
|
|
} else {
|
2011-07-22 21:33:30 +01:00
|
|
|
runLoopySolver();
|
2011-05-17 12:00:33 +01:00
|
|
|
Statistics::numSolvedLoopyNets ++;
|
2011-07-22 21:33:30 +01:00
|
|
|
if (nIter_ >= SolverOptions::maxIter) {
|
2011-05-17 12:00:33 +01:00
|
|
|
Statistics::numUnconvergedRuns ++;
|
|
|
|
} else {
|
2011-07-22 21:33:30 +01:00
|
|
|
nIters = nIter_;
|
2011-05-17 12:00:33 +01:00
|
|
|
}
|
2011-07-22 21:33:30 +01:00
|
|
|
if (DL >= 2) {
|
2011-05-17 12:00:33 +01:00
|
|
|
cout << endl;
|
2011-07-22 21:33:30 +01:00
|
|
|
if (nIter_ < SolverOptions::maxIter) {
|
2011-05-17 12:00:33 +01:00
|
|
|
cout << "Belief propagation converged in " ;
|
|
|
|
cout << nIter_ << " iterations" << endl;
|
|
|
|
} else {
|
|
|
|
cout << "The maximum number of iterations was hit, terminating..." ;
|
|
|
|
cout << endl;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
double time = (double (clock() - start_)) / CLOCKS_PER_SEC;
|
2011-07-22 21:33:30 +01:00
|
|
|
Statistics::updateStats (size, nIters, time);
|
|
|
|
if (EXPORT_TO_DOT && size > EXPORT_MIN_SIZE) {
|
|
|
|
stringstream ss;
|
|
|
|
ss << size << "." ;
|
|
|
|
ss << Statistics::getCounting (size) << ".dot" ;
|
|
|
|
bn_->exportToDotFormat (ss.str().c_str());
|
|
|
|
}
|
2011-05-17 12:00:33 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ParamSet
|
2011-07-22 21:33:30 +01:00
|
|
|
BPSolver::getPosterioriOf (Vid vid) const
|
2011-05-17 12:00:33 +01:00
|
|
|
{
|
2011-07-22 21:33:30 +01:00
|
|
|
BayesNode* node = bn_->getBayesNode (vid);
|
|
|
|
assert (node);
|
|
|
|
return nodesI_[node->getIndex()]->getBeliefs();
|
2011-05-17 12:00:33 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ParamSet
|
2011-07-22 21:33:30 +01:00
|
|
|
BPSolver::getJointDistributionOf (const VidSet& jointVids)
|
2011-05-17 12:00:33 +01:00
|
|
|
{
|
2011-07-22 21:33:30 +01:00
|
|
|
if (DL >= 2) {
|
|
|
|
cout << "calculating joint distribution on: " ;
|
|
|
|
for (unsigned i = 0; i < jointVids.size(); i++) {
|
|
|
|
Variable* var = bn_->getBayesNode (jointVids[i]);
|
|
|
|
cout << var->getLabel() << " " ;
|
2011-05-17 12:00:33 +01:00
|
|
|
}
|
|
|
|
cout << endl;
|
|
|
|
}
|
|
|
|
|
2011-07-22 21:33:30 +01:00
|
|
|
if (jointCalcType_ == JUNCTION_NODE) {
|
|
|
|
return getJointByJunctionNode (jointVids);
|
|
|
|
} else {
|
|
|
|
return getJointByChainRule (jointVids);
|
2011-05-17 12:00:33 +01:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void
|
|
|
|
BPSolver::initializeSolver (void)
|
|
|
|
{
|
2011-07-22 21:33:30 +01:00
|
|
|
if (DL >= 2) {
|
|
|
|
if (!useAlwaysLoopySolver_) {
|
|
|
|
cout << "-> solver type = polytree solver" << endl;
|
|
|
|
cout << "-> schedule = n/a";
|
2011-05-17 12:00:33 +01:00
|
|
|
} else {
|
2011-07-22 21:33:30 +01:00
|
|
|
cout << "-> solver = loopy solver" << endl;
|
|
|
|
cout << "-> schedule = ";
|
|
|
|
switch (SolverOptions::schedule) {
|
|
|
|
case SolverOptions::S_SEQ_FIXED: cout << "sequential fixed" ; break;
|
|
|
|
case SolverOptions::S_SEQ_RANDOM: cout << "sequential random" ; break;
|
|
|
|
case SolverOptions::S_PARALLEL: cout << "parallel" ; break;
|
|
|
|
case SolverOptions::S_MAX_RESIDUAL: cout << "max residual" ; break;
|
|
|
|
}
|
2011-05-17 12:00:33 +01:00
|
|
|
}
|
|
|
|
cout << endl;
|
2011-07-22 21:33:30 +01:00
|
|
|
cout << "-> joint method = " ;
|
|
|
|
if (jointCalcType_ == JUNCTION_NODE) {
|
|
|
|
cout << "junction node" << endl;
|
|
|
|
} else {
|
|
|
|
cout << "chain rule " << endl;
|
|
|
|
}
|
|
|
|
cout << "-> max iters = " << SolverOptions::maxIter << endl;
|
|
|
|
cout << "-> accuracy = " << SolverOptions::accuracy << endl;
|
2011-05-17 12:00:33 +01:00
|
|
|
cout << endl;
|
|
|
|
}
|
|
|
|
|
2011-07-22 21:33:30 +01:00
|
|
|
CBnNodeSet nodes = bn_->getBayesNodes();
|
|
|
|
for (unsigned i = 0; i < nodesI_.size(); i++) {
|
|
|
|
delete nodesI_[i];
|
2011-05-17 12:00:33 +01:00
|
|
|
}
|
2011-07-22 21:33:30 +01:00
|
|
|
nodesI_.clear();
|
|
|
|
nodesI_.reserve (nodes.size());
|
|
|
|
links_.clear();
|
2011-05-17 12:00:33 +01:00
|
|
|
sortedOrder_.clear();
|
|
|
|
edgeMap_.clear();
|
|
|
|
|
|
|
|
for (unsigned i = 0; i < nodes.size(); i++) {
|
2011-07-22 21:33:30 +01:00
|
|
|
nodesI_.push_back (new BPNodeInfo (nodes[i]));
|
2011-05-17 12:00:33 +01:00
|
|
|
}
|
|
|
|
|
2011-07-22 21:33:30 +01:00
|
|
|
BnNodeSet roots = bn_->getRootNodes();
|
2011-05-17 12:00:33 +01:00
|
|
|
for (unsigned i = 0; i < roots.size(); i++) {
|
|
|
|
const ParamSet& params = roots[i]->getParameters();
|
|
|
|
ParamSet& piVals = M(roots[i])->getPiValues();
|
2011-07-22 21:33:30 +01:00
|
|
|
for (unsigned ri = 0; ri < roots[i]->getDomainSize(); ri++) {
|
2011-05-17 12:00:33 +01:00
|
|
|
piVals[ri] = params[ri];
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
for (unsigned i = 0; i < nodes.size(); i++) {
|
2011-07-22 21:33:30 +01:00
|
|
|
CBnNodeSet parents = nodes[i]->getParents();
|
|
|
|
for (unsigned j = 0; j < parents.size(); j++) {
|
|
|
|
Edge* newLink = new Edge (parents[j], nodes[i], PI_MSG);
|
|
|
|
links_.push_back (newLink);
|
|
|
|
M(nodes[i])->addIncomingParentLink (newLink);
|
|
|
|
M(parents[j])->addOutcomingChildLink (newLink);
|
2011-05-17 12:00:33 +01:00
|
|
|
}
|
2011-07-22 21:33:30 +01:00
|
|
|
CBnNodeSet childs = nodes[i]->getChilds();
|
2011-05-17 12:00:33 +01:00
|
|
|
for (unsigned j = 0; j < childs.size(); j++) {
|
2011-07-22 21:33:30 +01:00
|
|
|
Edge* newLink = new Edge (childs[j], nodes[i], LAMBDA_MSG);
|
|
|
|
links_.push_back (newLink);
|
|
|
|
M(nodes[i])->addIncomingChildLink (newLink);
|
|
|
|
M(childs[j])->addOutcomingParentLink (newLink);
|
2011-05-17 12:00:33 +01:00
|
|
|
}
|
|
|
|
}
|
2011-07-22 21:33:30 +01:00
|
|
|
|
2011-05-17 12:00:33 +01:00
|
|
|
for (unsigned i = 0; i < nodes.size(); i++) {
|
|
|
|
if (nodes[i]->hasEvidence()) {
|
2011-07-22 21:33:30 +01:00
|
|
|
ParamSet& piVals = M(nodes[i])->getPiValues();
|
|
|
|
ParamSet& ldVals = M(nodes[i])->getLambdaValues();
|
|
|
|
for (unsigned xi = 0; xi < nodes[i]->getDomainSize(); xi++) {
|
|
|
|
piVals[xi] = 0.0;
|
|
|
|
ldVals[xi] = 0.0;
|
2011-05-17 12:00:33 +01:00
|
|
|
}
|
2011-07-22 21:33:30 +01:00
|
|
|
piVals[nodes[i]->getEvidence()] = 1.0;
|
|
|
|
ldVals[nodes[i]->getEvidence()] = 1.0;
|
2011-05-17 12:00:33 +01:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void
|
2011-07-22 21:33:30 +01:00
|
|
|
BPSolver::runPolyTreeSolver (void)
|
2011-05-17 12:00:33 +01:00
|
|
|
{
|
2011-07-22 21:33:30 +01:00
|
|
|
CBnNodeSet nodes = bn_->getBayesNodes();
|
|
|
|
for (unsigned i = 0; i < nodes.size(); i++) {
|
|
|
|
if (nodes[i]->isRoot()) {
|
|
|
|
M(nodes[i])->markPiValuesAsCalculated();
|
2011-05-17 12:00:33 +01:00
|
|
|
}
|
2011-07-22 21:33:30 +01:00
|
|
|
if (nodes[i]->isLeaf()) {
|
|
|
|
M(nodes[i])->markLambdaValuesAsCalculated();
|
2011-05-17 12:00:33 +01:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2011-07-22 21:33:30 +01:00
|
|
|
bool finish = false;
|
|
|
|
while (!finish) {
|
|
|
|
finish = true;
|
|
|
|
for (unsigned i = 0; i < nodes.size(); i++) {
|
|
|
|
if (M(nodes[i])->arePiValuesCalculated() == false
|
|
|
|
&& M(nodes[i])->receivedAllPiMessages()) {
|
|
|
|
if (!nodes[i]->hasEvidence()) {
|
|
|
|
updatePiValues (nodes[i]);
|
|
|
|
}
|
|
|
|
M(nodes[i])->markPiValuesAsCalculated();
|
|
|
|
finish = false;
|
|
|
|
}
|
2011-05-17 12:00:33 +01:00
|
|
|
|
2011-07-22 21:33:30 +01:00
|
|
|
if (M(nodes[i])->areLambdaValuesCalculated() == false
|
|
|
|
&& M(nodes[i])->receivedAllLambdaMessages()) {
|
|
|
|
if (!nodes[i]->hasEvidence()) {
|
|
|
|
updateLambdaValues (nodes[i]);
|
|
|
|
}
|
|
|
|
M(nodes[i])->markLambdaValuesAsCalculated();
|
|
|
|
finish = false;
|
|
|
|
}
|
2011-05-17 12:00:33 +01:00
|
|
|
|
2011-07-22 21:33:30 +01:00
|
|
|
if (M(nodes[i])->arePiValuesCalculated()) {
|
|
|
|
CEdgeSet outChildLinks = M(nodes[i])->getOutcomingChildLinks();
|
|
|
|
for (unsigned j = 0; j < outChildLinks.size(); j++) {
|
|
|
|
BayesNode* child = outChildLinks[j]->getDestination();
|
|
|
|
if (!outChildLinks[j]->messageWasSended()) {
|
|
|
|
if (M(nodes[i])->readyToSendPiMsgTo (child)) {
|
|
|
|
outChildLinks[j]->setNextMessage (getMessage (outChildLinks[j]));
|
|
|
|
outChildLinks[j]->updateMessage();
|
|
|
|
M(child)->incNumPiMsgsRcv();
|
|
|
|
}
|
|
|
|
finish = false;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
2011-05-17 12:00:33 +01:00
|
|
|
|
2011-07-22 21:33:30 +01:00
|
|
|
if (M(nodes[i])->areLambdaValuesCalculated()) {
|
|
|
|
CEdgeSet outParentLinks = M(nodes[i])->getOutcomingParentLinks();
|
|
|
|
for (unsigned j = 0; j < outParentLinks.size(); j++) {
|
|
|
|
BayesNode* parent = outParentLinks[j]->getDestination();
|
|
|
|
if (!outParentLinks[j]->messageWasSended()) {
|
|
|
|
if (M(nodes[i])->readyToSendLambdaMsgTo (parent)) {
|
|
|
|
outParentLinks[j]->setNextMessage (getMessage (outParentLinks[j]));
|
|
|
|
outParentLinks[j]->updateMessage();
|
|
|
|
M(parent)->incNumLambdaMsgsRcv();
|
|
|
|
}
|
|
|
|
finish = false;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
2011-05-17 12:00:33 +01:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void
|
2011-07-22 21:33:30 +01:00
|
|
|
BPSolver::runLoopySolver()
|
2011-05-17 12:00:33 +01:00
|
|
|
{
|
|
|
|
nIter_ = 0;
|
2011-07-22 21:33:30 +01:00
|
|
|
while (!converged() && nIter_ < SolverOptions::maxIter) {
|
2011-05-17 12:00:33 +01:00
|
|
|
|
|
|
|
nIter_++;
|
2011-07-22 21:33:30 +01:00
|
|
|
if (DL >= 2) {
|
2011-05-17 12:00:33 +01:00
|
|
|
cout << endl;
|
|
|
|
cout << "****************************************" ;
|
|
|
|
cout << "****************************************" ;
|
|
|
|
cout << endl;
|
|
|
|
cout << " Iteration " << nIter_ << endl;
|
|
|
|
cout << "****************************************" ;
|
|
|
|
cout << "****************************************" ;
|
|
|
|
cout << endl;
|
|
|
|
}
|
|
|
|
|
2011-07-22 21:33:30 +01:00
|
|
|
switch (SolverOptions::schedule) {
|
2011-05-17 12:00:33 +01:00
|
|
|
|
2011-07-22 21:33:30 +01:00
|
|
|
case SolverOptions::S_SEQ_RANDOM:
|
|
|
|
random_shuffle (links_.begin(), links_.end());
|
2011-05-17 12:00:33 +01:00
|
|
|
// no break
|
|
|
|
|
2011-07-22 21:33:30 +01:00
|
|
|
case SolverOptions::S_SEQ_FIXED:
|
|
|
|
for (unsigned i = 0; i < links_.size(); i++) {
|
|
|
|
links_[i]->setNextMessage (getMessage (links_[i]));
|
|
|
|
links_[i]->updateMessage();
|
|
|
|
updateValues (links_[i]);
|
2011-05-17 12:00:33 +01:00
|
|
|
}
|
|
|
|
break;
|
|
|
|
|
2011-07-22 21:33:30 +01:00
|
|
|
case SolverOptions::S_PARALLEL:
|
|
|
|
for (unsigned i = 0; i < links_.size(); i++) {
|
|
|
|
//calculateNextMessage (links_[i]);
|
2011-05-17 12:00:33 +01:00
|
|
|
}
|
2011-07-22 21:33:30 +01:00
|
|
|
for (unsigned i = 0; i < links_.size(); i++) {
|
|
|
|
//updateMessage (links_[i]);
|
|
|
|
//updateValues (links_[i]);
|
2011-05-17 12:00:33 +01:00
|
|
|
}
|
|
|
|
break;
|
|
|
|
|
2011-07-22 21:33:30 +01:00
|
|
|
case SolverOptions::S_MAX_RESIDUAL:
|
2011-05-17 12:00:33 +01:00
|
|
|
maxResidualSchedule();
|
|
|
|
break;
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
2011-07-22 21:33:30 +01:00
|
|
|
bool
|
|
|
|
BPSolver::converged (void) const
|
|
|
|
{
|
|
|
|
// this can happen if the graph is fully disconnected
|
|
|
|
if (links_.size() == 0) {
|
|
|
|
return true;
|
|
|
|
}
|
|
|
|
if (nIter_ == 0 || nIter_ == 1) {
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
bool converged = true;
|
|
|
|
if (SolverOptions::schedule == SolverOptions::S_MAX_RESIDUAL) {
|
|
|
|
Param maxResidual = (*(sortedOrder_.begin()))->getResidual();
|
|
|
|
if (maxResidual < SolverOptions::accuracy) {
|
|
|
|
converged = true;
|
|
|
|
} else {
|
|
|
|
converged = false;
|
|
|
|
}
|
|
|
|
} else {
|
|
|
|
CBnNodeSet nodes = bn_->getBayesNodes();
|
|
|
|
for (unsigned i = 0; i < nodes.size(); i++) {
|
|
|
|
if (!nodes[i]->hasEvidence()) {
|
|
|
|
double change = M(nodes[i])->getBeliefChange();
|
|
|
|
if (DL >= 2) {
|
|
|
|
cout << nodes[i]->getLabel() + " belief change = " ;
|
|
|
|
cout << change << endl;
|
|
|
|
}
|
|
|
|
if (change > SolverOptions::accuracy) {
|
|
|
|
converged = false;
|
|
|
|
if (DL == 0) break;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return converged;
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
2011-05-17 12:00:33 +01:00
|
|
|
void
|
|
|
|
BPSolver::maxResidualSchedule (void)
|
|
|
|
{
|
|
|
|
if (nIter_ == 1) {
|
2011-07-22 21:33:30 +01:00
|
|
|
for (unsigned i = 0; i < links_.size(); i++) {
|
|
|
|
links_[i]->setNextMessage (getMessage (links_[i]));
|
|
|
|
links_[i]->updateResidual();
|
|
|
|
SortedOrder::iterator it = sortedOrder_.insert (links_[i]);
|
|
|
|
edgeMap_.insert (make_pair (links_[i], it));
|
|
|
|
if (DL >= 2) {
|
|
|
|
cout << "calculating " << links_[i]->toString() << endl;
|
|
|
|
}
|
2011-05-17 12:00:33 +01:00
|
|
|
}
|
|
|
|
return;
|
|
|
|
}
|
|
|
|
|
|
|
|
for (unsigned c = 0; c < sortedOrder_.size(); c++) {
|
2011-07-22 21:33:30 +01:00
|
|
|
if (DL >= 2) {
|
|
|
|
cout << endl << "current residuals:" << endl;
|
|
|
|
for (SortedOrder::iterator it = sortedOrder_.begin();
|
2011-05-17 12:00:33 +01:00
|
|
|
it != sortedOrder_.end(); it ++) {
|
2011-07-22 21:33:30 +01:00
|
|
|
cout << " " << setw (30) << left << (*it)->toString();
|
|
|
|
cout << "residual = " << (*it)->getResidual() << endl;
|
2011-05-17 12:00:33 +01:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2011-07-22 21:33:30 +01:00
|
|
|
SortedOrder::iterator it = sortedOrder_.begin();
|
|
|
|
Edge* edge = *it;
|
|
|
|
if (DL >= 2) {
|
|
|
|
cout << "updating " << edge->toString() << endl;
|
|
|
|
}
|
|
|
|
if (edge->getResidual() < SolverOptions::accuracy) {
|
2011-05-17 12:00:33 +01:00
|
|
|
return;
|
|
|
|
}
|
2011-07-22 21:33:30 +01:00
|
|
|
edge->updateMessage();
|
|
|
|
updateValues (edge);
|
|
|
|
edge->clearResidual();
|
2011-05-17 12:00:33 +01:00
|
|
|
sortedOrder_.erase (it);
|
2011-07-22 21:33:30 +01:00
|
|
|
edgeMap_.find (edge)->second = sortedOrder_.insert (edge);
|
|
|
|
|
|
|
|
// update the messages that depend on message source --> destin
|
|
|
|
CEdgeSet outChildLinks =
|
|
|
|
M(edge->getDestination())->getOutcomingChildLinks();
|
|
|
|
for (unsigned i = 0; i < outChildLinks.size(); i++) {
|
|
|
|
if (outChildLinks[i]->getDestination() != edge->getSource()) {
|
|
|
|
if (DL >= 2) {
|
|
|
|
cout << " calculating " << outChildLinks[i]->toString() << endl;
|
|
|
|
}
|
|
|
|
outChildLinks[i]->setNextMessage (getMessage (outChildLinks[i]));
|
|
|
|
outChildLinks[i]->updateResidual();
|
|
|
|
EdgeMap::iterator iter = edgeMap_.find (outChildLinks[i]);
|
2011-05-17 12:00:33 +01:00
|
|
|
sortedOrder_.erase (iter->second);
|
2011-07-22 21:33:30 +01:00
|
|
|
iter->second = sortedOrder_.insert (outChildLinks[i]);
|
2011-05-17 12:00:33 +01:00
|
|
|
}
|
|
|
|
}
|
2011-07-22 21:33:30 +01:00
|
|
|
CEdgeSet outParentLinks =
|
|
|
|
M(edge->getDestination())->getOutcomingParentLinks();
|
|
|
|
for (unsigned i = 0; i < outParentLinks.size(); i++) {
|
|
|
|
if (outParentLinks[i]->getDestination() != edge->getSource()) {
|
|
|
|
//&& !outParentLinks[i]->getDestination()->hasEvidence()) FIXME{
|
|
|
|
if (DL >= 2) {
|
|
|
|
cout << " calculating " << outParentLinks[i]->toString() << endl;
|
|
|
|
}
|
|
|
|
outParentLinks[i]->setNextMessage (getMessage (outParentLinks[i]));
|
|
|
|
outParentLinks[i]->updateResidual();
|
|
|
|
EdgeMap::iterator iter = edgeMap_.find (outParentLinks[i]);
|
2011-05-17 12:00:33 +01:00
|
|
|
sortedOrder_.erase (iter->second);
|
2011-07-22 21:33:30 +01:00
|
|
|
iter->second = sortedOrder_.insert (outParentLinks[i]);
|
2011-05-17 12:00:33 +01:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void
|
|
|
|
BPSolver::updatePiValues (BayesNode* x)
|
|
|
|
{
|
|
|
|
// π(Xi)
|
2011-07-22 21:33:30 +01:00
|
|
|
if (DL >= 3) {
|
|
|
|
cout << "updating " << PI << " values for " << x->getLabel() << endl;
|
|
|
|
}
|
|
|
|
CEdgeSet parentLinks = M(x)->getIncomingParentLinks();
|
|
|
|
assert (x->getParents() == parentLinks.size());
|
2011-05-17 12:00:33 +01:00
|
|
|
const vector<CptEntry>& entries = x->getCptEntries();
|
2011-07-22 21:33:30 +01:00
|
|
|
stringstream* calcs1 = 0;
|
|
|
|
stringstream* calcs2 = 0;
|
2011-05-17 12:00:33 +01:00
|
|
|
|
|
|
|
ParamSet messageProducts (entries.size());
|
|
|
|
for (unsigned k = 0; k < entries.size(); k++) {
|
|
|
|
if (DL >= 5) {
|
|
|
|
calcs1 = new stringstream;
|
|
|
|
calcs2 = new stringstream;
|
|
|
|
}
|
|
|
|
double messageProduct = 1.0;
|
2011-07-22 21:33:30 +01:00
|
|
|
const DConf& conf = entries[k].getDomainConfiguration();
|
|
|
|
for (unsigned i = 0; i < parentLinks.size(); i++) {
|
|
|
|
assert (parentLinks[i]->getSource() == parents[i]);
|
|
|
|
assert (parentLinks[i]->getDestination() == x);
|
|
|
|
messageProduct *= parentLinks[i]->getMessage()[conf[i]];
|
2011-05-17 12:00:33 +01:00
|
|
|
if (DL >= 5) {
|
|
|
|
if (i != 0) *calcs1 << "." ;
|
|
|
|
if (i != 0) *calcs2 << "*" ;
|
2011-07-22 21:33:30 +01:00
|
|
|
*calcs1 << PI << "(" << parentLinks[i]->getSource()->getLabel();
|
|
|
|
*calcs1 << " --> " << x->getLabel() << ")" ;
|
|
|
|
*calcs1 << "[" ;
|
|
|
|
*calcs1 << parentLinks[i]->getSource()->getDomain()[conf[i]];
|
|
|
|
*calcs1 << "]";
|
|
|
|
*calcs2 << parentLinks[i]->getMessage()[conf[i]];
|
2011-05-17 12:00:33 +01:00
|
|
|
}
|
|
|
|
}
|
|
|
|
messageProducts[k] = messageProduct;
|
|
|
|
if (DL >= 5) {
|
|
|
|
cout << " mp" << k;
|
|
|
|
cout << " = " << (*calcs1).str();
|
2011-07-22 21:33:30 +01:00
|
|
|
if (parentLinks.size() == 1) {
|
2011-05-17 12:00:33 +01:00
|
|
|
cout << " = " << messageProduct << endl;
|
|
|
|
} else {
|
|
|
|
cout << " = " << (*calcs2).str();
|
|
|
|
cout << " = " << messageProduct << endl;
|
|
|
|
}
|
|
|
|
delete calcs1;
|
|
|
|
delete calcs2;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2011-07-22 21:33:30 +01:00
|
|
|
for (unsigned xi = 0; xi < x->getDomainSize(); xi++) {
|
2011-05-17 12:00:33 +01:00
|
|
|
double sum = 0.0;
|
|
|
|
if (DL >= 5) {
|
|
|
|
calcs1 = new stringstream;
|
|
|
|
calcs2 = new stringstream;
|
|
|
|
}
|
|
|
|
for (unsigned k = 0; k < entries.size(); k++) {
|
|
|
|
sum += x->getProbability (xi, entries[k]) * messageProducts[k];
|
|
|
|
if (DL >= 5) {
|
|
|
|
if (k != 0) *calcs1 << " + " ;
|
|
|
|
if (k != 0) *calcs2 << " + " ;
|
|
|
|
*calcs1 << x->cptEntryToString (xi, entries[k]);
|
|
|
|
*calcs1 << ".mp" << k;
|
|
|
|
*calcs2 << x->getProbability (xi, entries[k]);
|
|
|
|
*calcs2 << "*" << messageProducts[k];
|
|
|
|
}
|
|
|
|
}
|
|
|
|
M(x)->setPiValue (xi, sum);
|
|
|
|
if (DL >= 5) {
|
|
|
|
cout << " " << PI << "(" << x->getLabel() << ")" ;
|
|
|
|
cout << "[" << x->getDomain()[xi] << "]" ;
|
|
|
|
cout << " = " << (*calcs1).str();
|
|
|
|
cout << " = " << (*calcs2).str();
|
|
|
|
cout << " = " << sum << endl;
|
|
|
|
delete calcs1;
|
|
|
|
delete calcs2;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void
|
|
|
|
BPSolver::updateLambdaValues (BayesNode* x)
|
|
|
|
{
|
|
|
|
// λ(Xi)
|
2011-07-22 21:33:30 +01:00
|
|
|
if (DL >= 3) {
|
|
|
|
cout << "updating " << LD << " values for " << x->getLabel() << endl;
|
|
|
|
}
|
|
|
|
CEdgeSet childLinks = M(x)->getIncomingChildLinks();
|
|
|
|
stringstream* calcs1 = 0;
|
|
|
|
stringstream* calcs2 = 0;
|
2011-05-17 12:00:33 +01:00
|
|
|
|
2011-07-22 21:33:30 +01:00
|
|
|
for (unsigned xi = 0; xi < x->getDomainSize(); xi++) {
|
2011-05-17 12:00:33 +01:00
|
|
|
double product = 1.0;
|
|
|
|
if (DL >= 5) {
|
|
|
|
calcs1 = new stringstream;
|
|
|
|
calcs2 = new stringstream;
|
|
|
|
}
|
2011-07-22 21:33:30 +01:00
|
|
|
for (unsigned i = 0; i < childLinks.size(); i++) {
|
|
|
|
assert (childLinks[i]->getDestination() == x);
|
|
|
|
product *= childLinks[i]->getMessage()[xi];
|
2011-05-17 12:00:33 +01:00
|
|
|
if (DL >= 5) {
|
|
|
|
if (i != 0) *calcs1 << "." ;
|
|
|
|
if (i != 0) *calcs2 << "*" ;
|
2011-07-22 21:33:30 +01:00
|
|
|
*calcs1 << LD << "(" << childLinks[i]->getSource()->getLabel();
|
2011-05-17 12:00:33 +01:00
|
|
|
*calcs1 << "-->" << x->getLabel() << ")" ;
|
|
|
|
*calcs1 << "[" << x->getDomain()[xi] << "]" ;
|
2011-07-22 21:33:30 +01:00
|
|
|
*calcs2 << childLinks[i]->getMessage()[xi];
|
2011-05-17 12:00:33 +01:00
|
|
|
}
|
|
|
|
}
|
|
|
|
M(x)->setLambdaValue (xi, product);
|
|
|
|
if (DL >= 5) {
|
|
|
|
cout << " " << LD << "(" << x->getLabel() << ")" ;
|
|
|
|
cout << "[" << x->getDomain()[xi] << "]" ;
|
|
|
|
cout << " = " << (*calcs1).str();
|
2011-07-22 21:33:30 +01:00
|
|
|
if (childLinks.size() == 1) {
|
2011-05-17 12:00:33 +01:00
|
|
|
cout << " = " << product << endl;
|
|
|
|
} else {
|
|
|
|
cout << " = " << (*calcs2).str();
|
|
|
|
cout << " = " << product << endl;
|
|
|
|
}
|
|
|
|
delete calcs1;
|
|
|
|
delete calcs2;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
2011-07-22 21:33:30 +01:00
|
|
|
ParamSet
|
|
|
|
BPSolver::calculateNextPiMessage (Edge* edge)
|
2011-05-17 12:00:33 +01:00
|
|
|
{
|
|
|
|
// πX(Zi)
|
2011-07-22 21:33:30 +01:00
|
|
|
BayesNode* z = edge->getSource();
|
|
|
|
BayesNode* x = edge->getDestination();
|
|
|
|
ParamSet zxPiNextMessage (z->getDomainSize());
|
|
|
|
CEdgeSet zChildLinks = M(z)->getIncomingChildLinks();
|
|
|
|
stringstream* calcs1 = 0;
|
|
|
|
stringstream* calcs2 = 0;
|
|
|
|
|
2011-05-17 12:00:33 +01:00
|
|
|
|
2011-07-22 21:33:30 +01:00
|
|
|
for (unsigned zi = 0; zi < z->getDomainSize(); zi++) {
|
2011-05-17 12:00:33 +01:00
|
|
|
double product = M(z)->getPiValue (zi);
|
|
|
|
if (DL >= 5) {
|
|
|
|
calcs1 = new stringstream;
|
|
|
|
calcs2 = new stringstream;
|
|
|
|
*calcs1 << PI << "(" << z->getLabel() << ")";
|
|
|
|
*calcs1 << "[" << z->getDomain()[zi] << "]" ;
|
|
|
|
*calcs2 << product;
|
|
|
|
}
|
2011-07-22 21:33:30 +01:00
|
|
|
for (unsigned i = 0; i < zChildLinks.size(); i++) {
|
|
|
|
assert (zChildLinks[i]->getDestination() == z);
|
|
|
|
if (zChildLinks[i]->getSource() != x) {
|
|
|
|
product *= zChildLinks[i]->getMessage()[zi];
|
2011-05-17 12:00:33 +01:00
|
|
|
if (DL >= 5) {
|
2011-07-22 21:33:30 +01:00
|
|
|
*calcs1 << "." << LD << "(" << zChildLinks[i]->getSource()->getLabel();
|
2011-05-17 12:00:33 +01:00
|
|
|
*calcs1 << "-->" << z->getLabel() << ")";
|
|
|
|
*calcs1 << "[" << z->getDomain()[zi] + "]" ;
|
2011-07-22 21:33:30 +01:00
|
|
|
*calcs2 << " * " << zChildLinks[i]->getMessage()[zi];
|
2011-05-17 12:00:33 +01:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
zxPiNextMessage[zi] = product;
|
|
|
|
if (DL >= 5) {
|
|
|
|
cout << " " << PI << "(" << z->getLabel();
|
|
|
|
cout << "-->" << x->getLabel() << ")" ;
|
|
|
|
cout << "[" << z->getDomain()[zi] << "]" ;
|
|
|
|
cout << " = " << (*calcs1).str();
|
2011-07-22 21:33:30 +01:00
|
|
|
if (zChildLinks.size() == 1) {
|
2011-05-17 12:00:33 +01:00
|
|
|
cout << " = " << product << endl;
|
|
|
|
} else {
|
|
|
|
cout << " = " << (*calcs2).str();
|
|
|
|
cout << " = " << product << endl;
|
|
|
|
}
|
|
|
|
delete calcs1;
|
|
|
|
delete calcs2;
|
|
|
|
}
|
|
|
|
}
|
2011-07-22 21:33:30 +01:00
|
|
|
return zxPiNextMessage;
|
2011-05-17 12:00:33 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
2011-07-22 21:33:30 +01:00
|
|
|
ParamSet
|
|
|
|
BPSolver::calculateNextLambdaMessage (Edge* edge)
|
2011-05-17 12:00:33 +01:00
|
|
|
{
|
|
|
|
// λY(Xi)
|
2011-07-22 21:33:30 +01:00
|
|
|
BayesNode* y = edge->getSource();
|
|
|
|
BayesNode* x = edge->getDestination();
|
|
|
|
if (!M(y)->receivedBottomInfluence()) {
|
|
|
|
//cout << "returning 1" << endl;
|
|
|
|
//return edge->getMessage();
|
|
|
|
}
|
|
|
|
if (x->hasEvidence()) {
|
|
|
|
//cout << "returning 2" << endl;
|
|
|
|
//return edge->getMessage();
|
|
|
|
}
|
|
|
|
ParamSet yxLambdaNextMessage (x->getDomainSize());
|
|
|
|
CEdgeSet yParentLinks = M(y)->getIncomingParentLinks();
|
2011-05-17 12:00:33 +01:00
|
|
|
const vector<CptEntry>& allEntries = y->getCptEntries();
|
|
|
|
int parentIndex = y->getIndexOfParent (x);
|
2011-07-22 21:33:30 +01:00
|
|
|
stringstream* calcs1 = 0;
|
|
|
|
stringstream* calcs2 = 0;
|
2011-05-17 12:00:33 +01:00
|
|
|
|
|
|
|
vector<CptEntry> entries;
|
2011-07-22 21:33:30 +01:00
|
|
|
DConstraint constr = make_pair (parentIndex, 0);
|
2011-05-17 12:00:33 +01:00
|
|
|
for (unsigned i = 0; i < allEntries.size(); i++) {
|
|
|
|
if (allEntries[i].matchConstraints(constr)) {
|
|
|
|
entries.push_back (allEntries[i]);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
ParamSet messageProducts (entries.size());
|
|
|
|
for (unsigned k = 0; k < entries.size(); k++) {
|
|
|
|
if (DL >= 5) {
|
|
|
|
calcs1 = new stringstream;
|
|
|
|
calcs2 = new stringstream;
|
|
|
|
}
|
|
|
|
double messageProduct = 1.0;
|
2011-07-22 21:33:30 +01:00
|
|
|
const DConf& conf = entries[k].getDomainConfiguration();
|
|
|
|
for (unsigned i = 0; i < yParentLinks.size(); i++) {
|
|
|
|
assert (yParentLinks[i]->getDestination() == y);
|
|
|
|
if (yParentLinks[i]->getSource() != x) {
|
2011-05-17 12:00:33 +01:00
|
|
|
if (DL >= 5) {
|
|
|
|
if (messageProduct != 1.0) *calcs1 << "*" ;
|
|
|
|
if (messageProduct != 1.0) *calcs2 << "*" ;
|
2011-07-22 21:33:30 +01:00
|
|
|
*calcs1 << PI << "(" << yParentLinks[i]->getSource()->getLabel();
|
2011-05-17 12:00:33 +01:00
|
|
|
*calcs1 << "-->" << y->getLabel() << ")" ;
|
2011-07-22 21:33:30 +01:00
|
|
|
*calcs1 << "[" ;
|
|
|
|
*calcs1 << yParentLinks[i]->getSource()->getDomain()[conf[i]];
|
|
|
|
*calcs1 << "]" ;
|
|
|
|
*calcs2 << yParentLinks[i]->getMessage()[conf[i]];
|
2011-05-17 12:00:33 +01:00
|
|
|
}
|
2011-07-22 21:33:30 +01:00
|
|
|
messageProduct *= yParentLinks[i]->getMessage()[conf[i]];
|
2011-05-17 12:00:33 +01:00
|
|
|
}
|
|
|
|
}
|
|
|
|
messageProducts[k] = messageProduct;
|
|
|
|
if (DL >= 5) {
|
|
|
|
cout << " mp" << k;
|
|
|
|
cout << " = " << (*calcs1).str();
|
2011-07-22 21:33:30 +01:00
|
|
|
if (yParentLinks.size() == 1) {
|
2011-05-17 12:00:33 +01:00
|
|
|
cout << 1 << endl;
|
2011-07-22 21:33:30 +01:00
|
|
|
} else if (yParentLinks.size() == 2) {
|
2011-05-17 12:00:33 +01:00
|
|
|
cout << " = " << messageProduct << endl;
|
|
|
|
} else {
|
|
|
|
cout << " = " << (*calcs2).str();
|
|
|
|
cout << " = " << messageProduct << endl;
|
|
|
|
}
|
|
|
|
delete calcs1;
|
|
|
|
delete calcs2;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2011-07-22 21:33:30 +01:00
|
|
|
for (unsigned xi = 0; xi < x->getDomainSize(); xi++) {
|
2011-05-17 12:00:33 +01:00
|
|
|
if (DL >= 5) {
|
|
|
|
calcs1 = new stringstream;
|
|
|
|
calcs2 = new stringstream;
|
|
|
|
}
|
|
|
|
vector<CptEntry> entries;
|
2011-07-22 21:33:30 +01:00
|
|
|
DConstraint constr = make_pair (parentIndex, xi);
|
2011-05-17 12:00:33 +01:00
|
|
|
for (unsigned i = 0; i < allEntries.size(); i++) {
|
|
|
|
if (allEntries[i].matchConstraints(constr)) {
|
|
|
|
entries.push_back (allEntries[i]);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
double outerSum = 0.0;
|
2011-07-22 21:33:30 +01:00
|
|
|
for (unsigned yi = 0; yi < y->getDomainSize(); yi++) {
|
2011-05-17 12:00:33 +01:00
|
|
|
if (DL >= 5) {
|
|
|
|
(yi != 0) ? *calcs1 << " + {" : *calcs1 << "{" ;
|
|
|
|
(yi != 0) ? *calcs2 << " + {" : *calcs2 << "{" ;
|
|
|
|
}
|
|
|
|
double innerSum = 0.0;
|
|
|
|
for (unsigned k = 0; k < entries.size(); k++) {
|
|
|
|
if (DL >= 5) {
|
|
|
|
if (k != 0) *calcs1 << " + " ;
|
|
|
|
if (k != 0) *calcs2 << " + " ;
|
|
|
|
*calcs1 << y->cptEntryToString (yi, entries[k]);
|
|
|
|
*calcs1 << ".mp" << k;
|
|
|
|
*calcs2 << y->getProbability (yi, entries[k]);
|
|
|
|
*calcs2 << "*" << messageProducts[k];
|
|
|
|
}
|
|
|
|
innerSum += y->getProbability (yi, entries[k]) * messageProducts[k];
|
|
|
|
}
|
|
|
|
outerSum += innerSum * M(y)->getLambdaValue (yi);
|
|
|
|
if (DL >= 5) {
|
|
|
|
*calcs1 << "}." << LD << "(" << y->getLabel() << ")" ;
|
|
|
|
*calcs1 << "[" << y->getDomain()[yi] << "]";
|
|
|
|
*calcs2 << "}*" << M(y)->getLambdaValue (yi);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
yxLambdaNextMessage[xi] = outerSum;
|
|
|
|
if (DL >= 5) {
|
|
|
|
cout << " " << LD << "(" << y->getLabel();
|
|
|
|
cout << "-->" << x->getLabel() << ")" ;
|
|
|
|
cout << "[" << x->getDomain()[xi] << "]" ;
|
|
|
|
cout << " = " << (*calcs1).str();
|
|
|
|
cout << " = " << (*calcs2).str();
|
|
|
|
cout << " = " << outerSum << endl;
|
|
|
|
delete calcs1;
|
|
|
|
delete calcs2;
|
|
|
|
}
|
|
|
|
}
|
2011-07-22 21:33:30 +01:00
|
|
|
return yxLambdaNextMessage;
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ParamSet
|
|
|
|
BPSolver::getJointByJunctionNode (const VidSet& jointVids) const
|
|
|
|
{
|
|
|
|
BnNodeSet jointVars;
|
|
|
|
for (unsigned i = 0; i < jointVids.size(); i++) {
|
|
|
|
jointVars.push_back (bn_->getBayesNode (jointVids[i]));
|
|
|
|
}
|
|
|
|
|
|
|
|
BayesNet* mrn = bn_->getMinimalRequesiteNetwork (jointVids);
|
|
|
|
|
|
|
|
BnNodeSet parents;
|
|
|
|
unsigned dsize = 1;
|
|
|
|
for (unsigned i = 0; i < jointVars.size(); i++) {
|
|
|
|
parents.push_back (mrn->getBayesNode (jointVids[i]));
|
|
|
|
dsize *= jointVars[i]->getDomainSize();
|
|
|
|
}
|
|
|
|
|
|
|
|
unsigned nParams = dsize * dsize;
|
|
|
|
ParamSet params (nParams);
|
|
|
|
|
|
|
|
for (unsigned i = 0; i < nParams; i++) {
|
|
|
|
unsigned row = i / dsize;
|
|
|
|
unsigned col = i % dsize;
|
|
|
|
if (row == col) {
|
|
|
|
params[i] = 1;
|
|
|
|
} else {
|
|
|
|
params[i] = 0;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
unsigned maxVid = std::numeric_limits<unsigned>::max();
|
|
|
|
Distribution* dist = new Distribution (params);
|
|
|
|
|
|
|
|
mrn->addNode (maxVid, dsize, NO_EVIDENCE, parents, dist);
|
|
|
|
mrn->setIndexes();
|
|
|
|
|
|
|
|
BPSolver solver (*mrn);
|
|
|
|
solver.runSolver();
|
|
|
|
|
|
|
|
const ParamSet& results = solver.getPosterioriOf (maxVid);
|
|
|
|
|
|
|
|
delete mrn;
|
|
|
|
delete dist;
|
|
|
|
|
|
|
|
return results;
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ParamSet
|
|
|
|
BPSolver::getJointByChainRule (const VidSet& jointVids) const
|
|
|
|
{
|
|
|
|
BnNodeSet jointVars;
|
|
|
|
for (unsigned i = 0; i < jointVids.size(); i++) {
|
|
|
|
jointVars.push_back (bn_->getBayesNode (jointVids[i]));
|
|
|
|
}
|
|
|
|
|
|
|
|
BayesNet* mrn = bn_->getMinimalRequesiteNetwork (jointVids[0]);
|
|
|
|
BPSolver solver (*mrn);
|
|
|
|
solver.runSolver();
|
|
|
|
ParamSet prevBeliefs = solver.getPosterioriOf (jointVids[0]);
|
|
|
|
delete mrn;
|
|
|
|
|
|
|
|
VarSet observedVars = {jointVars[0]};
|
|
|
|
|
|
|
|
for (unsigned i = 1; i < jointVids.size(); i++) {
|
|
|
|
mrn = bn_->getMinimalRequesiteNetwork (jointVids[i]);
|
|
|
|
ParamSet newBeliefs;
|
|
|
|
vector<DConf> confs =
|
|
|
|
Util::getDomainConfigurations (observedVars);
|
|
|
|
for (unsigned j = 0; j < confs.size(); j++) {
|
|
|
|
for (unsigned k = 0; k < observedVars.size(); k++) {
|
|
|
|
if (!observedVars[k]->hasEvidence()) {
|
|
|
|
BayesNode* node = mrn->getBayesNode (observedVars[k]->getVarId());
|
|
|
|
if (node) {
|
|
|
|
node->setEvidence (confs[j][k]);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
BPSolver solver (*mrn);
|
|
|
|
solver.runSolver();
|
|
|
|
ParamSet beliefs = solver.getPosterioriOf (jointVids[i]);
|
|
|
|
for (unsigned k = 0; k < beliefs.size(); k++) {
|
|
|
|
newBeliefs.push_back (beliefs[k]);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
int count = -1;
|
|
|
|
for (unsigned j = 0; j < newBeliefs.size(); j++) {
|
|
|
|
if (j % jointVars[i]->getDomainSize() == 0) {
|
|
|
|
count ++;
|
|
|
|
}
|
|
|
|
newBeliefs[j] *= prevBeliefs[count];
|
|
|
|
}
|
|
|
|
prevBeliefs = newBeliefs;
|
|
|
|
observedVars.push_back (jointVars[i]);
|
|
|
|
delete mrn;
|
|
|
|
}
|
|
|
|
return prevBeliefs;
|
2011-05-17 12:00:33 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void
|
|
|
|
BPSolver::printMessageStatusOf (const BayesNode* var) const
|
|
|
|
{
|
|
|
|
cout << left;
|
|
|
|
cout << setw (10) << "domain" ;
|
|
|
|
cout << setw (20) << PI << "(" + var->getLabel() + ")" ;
|
|
|
|
cout << setw (20) << LD << "(" + var->getLabel() + ")" ;
|
|
|
|
cout << setw (16) << "belief" ;
|
|
|
|
cout << endl;
|
|
|
|
cout << "--------------------------------" ;
|
|
|
|
cout << "--------------------------------" ;
|
|
|
|
cout << endl;
|
|
|
|
|
2011-07-22 21:33:30 +01:00
|
|
|
BPNodeInfo* x = M(var);
|
2011-05-17 12:00:33 +01:00
|
|
|
ParamSet& piVals = x->getPiValues();
|
|
|
|
ParamSet& ldVals = x->getLambdaValues();
|
|
|
|
ParamSet beliefs = x->getBeliefs();
|
|
|
|
const Domain& domain = var->getDomain();
|
2011-07-22 21:33:30 +01:00
|
|
|
CBnNodeSet& childs = var->getChilds();
|
2011-05-17 12:00:33 +01:00
|
|
|
|
2011-07-22 21:33:30 +01:00
|
|
|
for (unsigned xi = 0; xi < var->getDomainSize(); xi++) {
|
2011-05-17 12:00:33 +01:00
|
|
|
cout << setw (10) << domain[xi];
|
|
|
|
cout << setw (19) << piVals[xi];
|
|
|
|
cout << setw (19) << ldVals[xi];
|
|
|
|
cout.precision (PRECISION);
|
|
|
|
cout << setw (16) << beliefs[xi];
|
|
|
|
cout << endl;
|
|
|
|
}
|
|
|
|
cout << endl;
|
|
|
|
if (childs.size() > 0) {
|
|
|
|
string s = "(" + var->getLabel() + ")" ;
|
|
|
|
for (unsigned j = 0; j < childs.size(); j++) {
|
|
|
|
cout << setw (10) << "domain" ;
|
|
|
|
cout << setw (28) << PI + childs[j]->getLabel() + s;
|
|
|
|
cout << setw (28) << LD + childs[j]->getLabel() + s;
|
|
|
|
cout << endl;
|
|
|
|
cout << "--------------------------------" ;
|
|
|
|
cout << "--------------------------------" ;
|
|
|
|
cout << endl;
|
2011-07-22 21:33:30 +01:00
|
|
|
/* FIXME
|
2011-05-17 12:00:33 +01:00
|
|
|
const ParamSet& piMessage = x->getPiMessage (childs[j]);
|
|
|
|
const ParamSet& lambdaMessage = x->getLambdaMessage (childs[j]);
|
2011-07-22 21:33:30 +01:00
|
|
|
for (unsigned xi = 0; xi < var->getDomainSize(); xi++) {
|
2011-05-17 12:00:33 +01:00
|
|
|
cout << setw (10) << domain[xi];
|
|
|
|
cout.precision (PRECISION);
|
|
|
|
cout << setw (27) << piMessage[xi];
|
|
|
|
cout.precision (PRECISION);
|
|
|
|
cout << setw (27) << lambdaMessage[xi];
|
|
|
|
cout << endl;
|
|
|
|
}
|
|
|
|
cout << endl;
|
2011-07-22 21:33:30 +01:00
|
|
|
*/
|
2011-05-17 12:00:33 +01:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void
|
|
|
|
BPSolver::printAllMessageStatus (void) const
|
|
|
|
{
|
2011-07-22 21:33:30 +01:00
|
|
|
CBnNodeSet nodes = bn_->getBayesNodes();
|
2011-05-17 12:00:33 +01:00
|
|
|
for (unsigned i = 0; i < nodes.size(); i++) {
|
|
|
|
printMessageStatusOf (nodes[i]);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|