new version of belief propagation solver.
This commit is contained in:
@@ -1,38 +1,77 @@
|
||||
#include <cassert>
|
||||
#include <algorithm>
|
||||
#include <limits>
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "SPSolver.h"
|
||||
#include "FactorGraph.h"
|
||||
#include "FgVarNode.h"
|
||||
#include "Factor.h"
|
||||
|
||||
SPSolver* Link::klass = 0;
|
||||
#include "Shared.h"
|
||||
|
||||
|
||||
SPSolver::SPSolver (const FactorGraph& fg) : Solver (&fg)
|
||||
SPSolver::SPSolver (FactorGraph& fg) : Solver (&fg)
|
||||
{
|
||||
fg_ = &fg;
|
||||
accuracy_ = 0.0001;
|
||||
maxIter_ = 10000;
|
||||
//schedule_ = S_SEQ_FIXED;
|
||||
//schedule_ = S_SEQ_RANDOM;
|
||||
//schedule_ = S_SEQ_PARALLEL;
|
||||
schedule_ = S_MAX_RESIDUAL;
|
||||
Link::klass = this;
|
||||
FgVarSet vars = fg_->getFgVarNodes();
|
||||
for (unsigned i = 0; i < vars.size(); i++) {
|
||||
msgs_.push_back (new MessageBanket (vars[i]));
|
||||
}
|
||||
fg_ = &fg;
|
||||
}
|
||||
|
||||
|
||||
|
||||
SPSolver::~SPSolver (void)
|
||||
{
|
||||
for (unsigned i = 0; i < msgs_.size(); i++) {
|
||||
delete msgs_[i];
|
||||
for (unsigned i = 0; i < varsI_.size(); i++) {
|
||||
delete varsI_[i];
|
||||
}
|
||||
for (unsigned i = 0; i < factorsI_.size(); i++) {
|
||||
delete factorsI_[i];
|
||||
}
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
delete links_[i];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
SPSolver::runTreeSolver (void)
|
||||
{
|
||||
CFactorSet factors = fg_->getFactors();
|
||||
bool finish = false;
|
||||
while (!finish) {
|
||||
finish = true;
|
||||
for (unsigned i = 0; i < factors.size(); i++) {
|
||||
CLinkSet links = factorsI_[factors[i]->getIndex()]->getLinks();
|
||||
for (unsigned j = 0; j < links.size(); j++) {
|
||||
if (!links[j]->messageWasSended()) {
|
||||
if (readyToSendMessage(links[j])) {
|
||||
links[j]->setNextMessage (getFactor2VarMsg (links[j]));
|
||||
links[j]->updateMessage();
|
||||
}
|
||||
finish = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
SPSolver::readyToSendMessage (const Link* link) const
|
||||
{
|
||||
CFgVarSet factorVars = link->getFactor()->getFgVarNodes();
|
||||
for (unsigned i = 0; i < factorVars.size(); i++) {
|
||||
if (factorVars[i] != link->getVariable()) {
|
||||
CLinkSet links = varsI_[factorVars[i]->getIndex()]->getLinks();
|
||||
for (unsigned j = 0; j < links.size(); j++) {
|
||||
if (links[j]->getFactor() != link->getFactor() &&
|
||||
!links[j]->messageWasSended()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
@@ -40,62 +79,54 @@ SPSolver::~SPSolver (void)
|
||||
void
|
||||
SPSolver::runSolver (void)
|
||||
{
|
||||
initializeSolver();
|
||||
runTreeSolver();
|
||||
return;
|
||||
nIter_ = 0;
|
||||
vector<Factor*> factors = fg_->getFactors();
|
||||
for (unsigned i = 0; i < factors.size(); i++) {
|
||||
FgVarSet neighbors = factors[i]->getFgVarNodes();
|
||||
for (unsigned j = 0; j < neighbors.size(); j++) {
|
||||
updateOrder_.push_back (Link (factors[i], neighbors[j]));
|
||||
}
|
||||
}
|
||||
|
||||
while (!converged() && nIter_ < maxIter_) {
|
||||
if (DL >= 1) {
|
||||
while (!converged() && nIter_ < SolverOptions::maxIter) {
|
||||
|
||||
nIter_ ++;
|
||||
if (DL >= 2) {
|
||||
cout << endl;
|
||||
cout << "****************************************" ;
|
||||
cout << "****************************************" ;
|
||||
cout << endl;
|
||||
cout << " Iteration " << nIter_ + 1 << endl;
|
||||
cout << " Iteration " << nIter_ << endl;
|
||||
cout << "****************************************" ;
|
||||
cout << "****************************************" ;
|
||||
cout << endl;
|
||||
}
|
||||
|
||||
switch (schedule_) {
|
||||
|
||||
case S_SEQ_RANDOM:
|
||||
random_shuffle (updateOrder_.begin(), updateOrder_.end());
|
||||
switch (SolverOptions::schedule) {
|
||||
case SolverOptions::S_SEQ_RANDOM:
|
||||
random_shuffle (links_.begin(), links_.end());
|
||||
// no break
|
||||
|
||||
case S_SEQ_FIXED:
|
||||
for (unsigned c = 0; c < updateOrder_.size(); c++) {
|
||||
Link& link = updateOrder_[c];
|
||||
calculateNextMessage (link.source, link.destination);
|
||||
updateMessage (updateOrder_[c]);
|
||||
case SolverOptions::S_SEQ_FIXED:
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
links_[i]->setNextMessage (getFactor2VarMsg (links_[i]));
|
||||
links_[i]->updateMessage();
|
||||
}
|
||||
break;
|
||||
|
||||
case S_PARALLEL:
|
||||
for (unsigned c = 0; c < updateOrder_.size(); c++) {
|
||||
Link link = updateOrder_[c];
|
||||
calculateNextMessage (link.source, link.destination);
|
||||
case SolverOptions::S_PARALLEL:
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
links_[i]->setNextMessage (getFactor2VarMsg (links_[i]));
|
||||
}
|
||||
for (unsigned c = 0; c < updateOrder_.size(); c++) {
|
||||
Link link = updateOrder_[c];
|
||||
updateMessage (updateOrder_[c]);
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
links_[i]->updateMessage();
|
||||
}
|
||||
break;
|
||||
|
||||
case S_MAX_RESIDUAL:
|
||||
case SolverOptions::S_MAX_RESIDUAL:
|
||||
maxResidualSchedule();
|
||||
break;
|
||||
}
|
||||
|
||||
nIter_++;
|
||||
}
|
||||
cout << endl;
|
||||
if (DL >= 1) {
|
||||
if (nIter_ < maxIter_) {
|
||||
|
||||
if (DL >= 2) {
|
||||
cout << endl;
|
||||
if (nIter_ < SolverOptions::maxIter) {
|
||||
cout << "Loopy Sum-Product converged in " ;
|
||||
cout << nIter_ << " iterations" << endl;
|
||||
} else {
|
||||
@@ -108,58 +139,168 @@ SPSolver::runSolver (void)
|
||||
|
||||
|
||||
ParamSet
|
||||
SPSolver::getPosterioriOf (const Variable* var) const
|
||||
SPSolver::getPosterioriOf (Vid vid) const
|
||||
{
|
||||
assert (var);
|
||||
assert (var == fg_->getVariableById (var->getVarId()));
|
||||
assert (var->getIndex() < msgs_.size());
|
||||
assert (fg_->getFgVarNode (vid));
|
||||
FgVarNode* var = fg_->getFgVarNode (vid);
|
||||
ParamSet probs;
|
||||
|
||||
ParamSet probs (var->getDomainSize(), 1);
|
||||
if (var->hasEvidence()) {
|
||||
for (unsigned i = 0; i < probs.size(); i++) {
|
||||
if ((int)i != var->getEvidence()) {
|
||||
probs[i] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
probs.resize (var->getDomainSize(), 0.0);
|
||||
probs[var->getEvidence()] = 1.0;
|
||||
} else {
|
||||
|
||||
MessageBanket* mb = msgs_[var->getIndex()];
|
||||
const FgVarNode* varNode = fg_->getFgVarNodes()[var->getIndex()];
|
||||
vector<Factor*> neighbors = varNode->getFactors();
|
||||
for (unsigned i = 0; i < neighbors.size(); i++) {
|
||||
const Message& msg = mb->getMessage (neighbors[i]);
|
||||
probs.resize (var->getDomainSize(), 1.0);
|
||||
CLinkSet links = varsI_[var->getIndex()]->getLinks();
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
CParamSet msg = links[i]->getMessage();
|
||||
for (unsigned j = 0; j < msg.size(); j++) {
|
||||
probs[j] *= msg[j];
|
||||
}
|
||||
}
|
||||
Util::normalize (probs);
|
||||
}
|
||||
|
||||
return probs;
|
||||
}
|
||||
|
||||
|
||||
|
||||
ParamSet
|
||||
SPSolver::getJointDistributionOf (const VidSet& jointVids)
|
||||
{
|
||||
FgVarSet jointVars;
|
||||
unsigned dsize = 1;
|
||||
for (unsigned i = 0; i < jointVids.size(); i++) {
|
||||
FgVarNode* varNode = fg_->getFgVarNode (jointVids[i]);
|
||||
dsize *= varNode->getDomainSize();
|
||||
jointVars.push_back (varNode);
|
||||
}
|
||||
|
||||
unsigned maxVid = std::numeric_limits<unsigned>::max();
|
||||
FgVarNode* junctionVar = new FgVarNode (maxVid, dsize);
|
||||
FgVarSet factorVars = { junctionVar };
|
||||
for (unsigned i = 0; i < jointVars.size(); i++) {
|
||||
factorVars.push_back (jointVars[i]);
|
||||
}
|
||||
|
||||
unsigned nParams = dsize * dsize;
|
||||
ParamSet params (nParams);
|
||||
for (unsigned i = 0; i < nParams; i++) {
|
||||
unsigned row = i / dsize;
|
||||
unsigned col = i % dsize;
|
||||
if (row == col) {
|
||||
params[i] = 1;
|
||||
} else {
|
||||
params[i] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
Distribution* dist = new Distribution (params, maxVid);
|
||||
Factor* newFactor = new Factor (factorVars, dist);
|
||||
fg_->addVariable (junctionVar);
|
||||
fg_->addFactor (newFactor);
|
||||
|
||||
runSolver();
|
||||
ParamSet results = getPosterioriOf (maxVid);
|
||||
deleteJunction (newFactor, junctionVar);
|
||||
|
||||
return results;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
SPSolver::initializeSolver (void)
|
||||
{
|
||||
fg_->setIndexes();
|
||||
|
||||
CFgVarSet vars = fg_->getFgVarNodes();
|
||||
for (unsigned i = 0; i < varsI_.size(); i++) {
|
||||
delete varsI_[i];
|
||||
}
|
||||
varsI_.reserve (vars.size());
|
||||
for (unsigned i = 0; i < vars.size(); i++) {
|
||||
varsI_.push_back (new SPNodeInfo());
|
||||
}
|
||||
|
||||
CFactorSet factors = fg_->getFactors();
|
||||
for (unsigned i = 0; i < factorsI_.size(); i++) {
|
||||
delete factorsI_[i];
|
||||
}
|
||||
factorsI_.reserve (factors.size());
|
||||
for (unsigned i = 0; i < factors.size(); i++) {
|
||||
factorsI_.push_back (new SPNodeInfo());
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
delete links_[i];
|
||||
}
|
||||
createLinks();
|
||||
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
Factor* source = links_[i]->getFactor();
|
||||
FgVarNode* dest = links_[i]->getVariable();
|
||||
varsI_[dest->getIndex()]->addLink (links_[i]);
|
||||
factorsI_[source->getIndex()]->addLink (links_[i]);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
SPSolver::createLinks (void)
|
||||
{
|
||||
CFactorSet factors = fg_->getFactors();
|
||||
for (unsigned i = 0; i < factors.size(); i++) {
|
||||
CFgVarSet neighbors = factors[i]->getFgVarNodes();
|
||||
for (unsigned j = 0; j < neighbors.size(); j++) {
|
||||
links_.push_back (new Link (factors[i], neighbors[j]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
SPSolver::deleteJunction (Factor* f, FgVarNode* v)
|
||||
{
|
||||
fg_->removeFactor (f);
|
||||
f->freeDistribution();
|
||||
delete f;
|
||||
fg_->removeVariable (v);
|
||||
delete v;
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
SPSolver::converged (void)
|
||||
{
|
||||
// this can happen if the graph is fully disconnected
|
||||
if (links_.size() == 0) {
|
||||
return true;
|
||||
}
|
||||
if (nIter_ == 0 || nIter_ == 1) {
|
||||
return false;
|
||||
}
|
||||
bool converged = true;
|
||||
for (unsigned i = 0; i < updateOrder_.size(); i++) {
|
||||
double residual = getResidual (updateOrder_[i]);
|
||||
if (DL >= 1) {
|
||||
cout << updateOrder_[i].toString();
|
||||
cout << " residual = " << residual << endl;
|
||||
}
|
||||
if (residual > accuracy_) {
|
||||
if (SolverOptions::schedule == SolverOptions::S_MAX_RESIDUAL) {
|
||||
Param maxResidual = (*(sortedOrder_.begin()))->getResidual();
|
||||
if (maxResidual < SolverOptions::accuracy) {
|
||||
converged = true;
|
||||
} else {
|
||||
converged = false;
|
||||
if (DL == 0) {
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
double residual = links_[i]->getResidual();
|
||||
if (DL >= 2) {
|
||||
cout << links_[i]->toString() + " residual = " << residual << endl;
|
||||
}
|
||||
}
|
||||
if (residual > SolverOptions::accuracy) {
|
||||
converged = false;
|
||||
if (DL == 0) break;
|
||||
}
|
||||
}
|
||||
}
|
||||
return converged;
|
||||
}
|
||||
@@ -169,127 +310,161 @@ SPSolver::converged (void)
|
||||
void
|
||||
SPSolver::maxResidualSchedule (void)
|
||||
{
|
||||
if (nIter_ == 0) {
|
||||
for (unsigned c = 0; c < updateOrder_.size(); c++) {
|
||||
Link& l = updateOrder_[c];
|
||||
calculateNextMessage (l.source, l.destination);
|
||||
if (DL >= 1) {
|
||||
cout << updateOrder_[c].toString() << " residual = " ;
|
||||
cout << getResidual (updateOrder_[c]) << endl;
|
||||
if (nIter_ == 1) {
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
links_[i]->setNextMessage (getFactor2VarMsg (links_[i]));
|
||||
SortedOrder::iterator it = sortedOrder_.insert (links_[i]);
|
||||
linkMap_.insert (make_pair (links_[i], it));
|
||||
if (DL >= 2 && DL < 5) {
|
||||
cout << "calculating " << links_[i]->toString() << endl;
|
||||
}
|
||||
}
|
||||
sort (updateOrder_.begin(), updateOrder_.end(), compareResidual);
|
||||
} else {
|
||||
return;
|
||||
}
|
||||
|
||||
for (unsigned c = 0; c < updateOrder_.size(); c++) {
|
||||
Link& link = updateOrder_.front();
|
||||
updateMessage (link);
|
||||
resetResidual (link);
|
||||
for (unsigned c = 0; c < links_.size(); c++) {
|
||||
if (DL >= 2) {
|
||||
cout << endl << "current residuals:" << endl;
|
||||
for (SortedOrder::iterator it = sortedOrder_.begin();
|
||||
it != sortedOrder_.end(); it ++) {
|
||||
cout << " " << setw (30) << left << (*it)->toString();
|
||||
cout << "residual = " << (*it)->getResidual() << endl;
|
||||
}
|
||||
}
|
||||
|
||||
// update the messages that depend on message source --> destination
|
||||
vector<Factor*> fstLevelNeighbors = link.destination->getFactors();
|
||||
for (unsigned i = 0; i < fstLevelNeighbors.size(); i++) {
|
||||
if (fstLevelNeighbors[i] != link.source) {
|
||||
FgVarSet sndLevelNeighbors;
|
||||
sndLevelNeighbors = fstLevelNeighbors[i]->getFgVarNodes();
|
||||
for (unsigned j = 0; j < sndLevelNeighbors.size(); j++) {
|
||||
if (sndLevelNeighbors[j] != link.destination) {
|
||||
calculateNextMessage (fstLevelNeighbors[i], sndLevelNeighbors[j]);
|
||||
SortedOrder::iterator it = sortedOrder_.begin();
|
||||
Link* link = *it;
|
||||
if (DL >= 2) {
|
||||
cout << "updating " << (*sortedOrder_.begin())->toString() << endl;
|
||||
}
|
||||
if (link->getResidual() < SolverOptions::accuracy) {
|
||||
return;
|
||||
}
|
||||
link->updateMessage();
|
||||
link->clearResidual();
|
||||
sortedOrder_.erase (it);
|
||||
linkMap_.find (link)->second = sortedOrder_.insert (link);
|
||||
|
||||
// update the messages that depend on message source --> destin
|
||||
CFactorSet factorNeighbors = link->getVariable()->getFactors();
|
||||
for (unsigned i = 0; i < factorNeighbors.size(); i++) {
|
||||
if (factorNeighbors[i] != link->getFactor()) {
|
||||
CLinkSet links = factorsI_[factorNeighbors[i]->getIndex()]->getLinks();
|
||||
for (unsigned j = 0; j < links.size(); j++) {
|
||||
if (links[j]->getVariable() != link->getVariable()) {
|
||||
if (DL >= 2 && DL < 5) {
|
||||
cout << " calculating " << links[j]->toString() << endl;
|
||||
}
|
||||
links[j]->setNextMessage (getFactor2VarMsg (links[j]));
|
||||
LinkMap::iterator iter = linkMap_.find (links[j]);
|
||||
sortedOrder_.erase (iter->second);
|
||||
iter->second = sortedOrder_.insert (links[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
sort (updateOrder_.begin(), updateOrder_.end(), compareResidual);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
SPSolver::updateMessage (const Link& link)
|
||||
ParamSet
|
||||
SPSolver::getFactor2VarMsg (const Link* link) const
|
||||
{
|
||||
updateMessage (link.source, link.destination);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
SPSolver::updateMessage (const Factor* src, const FgVarNode* dest)
|
||||
{
|
||||
msgs_[dest->getIndex()]->updateMessage (src);
|
||||
/* cout << src->getLabel() << " --> " << dest->getLabel() << endl;
|
||||
cout << " m: " ;
|
||||
Message msg = msgs_[dest->getIndex()]->getMessage (src);
|
||||
for (unsigned i = 0; i < msg.size(); i++) {
|
||||
if (i != 0) cout << ", " ;
|
||||
cout << msg[i];
|
||||
}
|
||||
cout << endl;
|
||||
*/
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
SPSolver::calculateNextMessage (const Link& link)
|
||||
{
|
||||
calculateNextMessage (link.source, link.destination);
|
||||
}
|
||||
|
||||
|
||||
void
|
||||
SPSolver::calculateNextMessage (const Factor* src, const FgVarNode* dest)
|
||||
{
|
||||
FgVarSet neighbors = src->getFgVarNodes();
|
||||
// calculate the product of MessageBankets sended
|
||||
const Factor* src = link->getFactor();
|
||||
const FgVarNode* dest = link->getVariable();
|
||||
CFgVarSet neighbors = src->getFgVarNodes();
|
||||
CLinkSet links = factorsI_[src->getIndex()]->getLinks();
|
||||
// calculate the product of messages that were sent
|
||||
// to factor `src', except from var `dest'
|
||||
Factor result = *src;
|
||||
for (unsigned i = 0; i < neighbors.size(); i++) {
|
||||
if (neighbors[i] != dest) {
|
||||
Message msg (neighbors[i]->getDomainSize(), 1);
|
||||
calculateVarFactorMessage (neighbors[i], src, msg);
|
||||
result *= Factor (neighbors[i], msg);
|
||||
}
|
||||
Factor result (*src);
|
||||
Factor temp;
|
||||
if (DL >= 5) {
|
||||
cout << "calculating " ;
|
||||
cout << src->getLabel() << " --> " << dest->getLabel();
|
||||
cout << endl;
|
||||
}
|
||||
// marginalize all vars except `dest'
|
||||
for (unsigned i = 0; i < neighbors.size(); i++) {
|
||||
if (neighbors[i] != dest) {
|
||||
result.marginalizeVariable (neighbors[i]);
|
||||
}
|
||||
}
|
||||
msgs_[dest->getIndex()]->setNextMessage (src, result.getParameters());
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
SPSolver::calculateVarFactorMessage (const FgVarNode* src,
|
||||
const Factor* dest,
|
||||
Message& placeholder) const
|
||||
{
|
||||
assert (src->getDomainSize() == (int)placeholder.size());
|
||||
if (src->hasEvidence()) {
|
||||
for (unsigned i = 0; i < placeholder.size(); i++) {
|
||||
if ((int)i != src->getEvidence()) {
|
||||
placeholder[i] = 0.0;
|
||||
if (links[i]->getVariable() != dest) {
|
||||
if (DL >= 5) {
|
||||
cout << " message from " << links[i]->getVariable()->getLabel();
|
||||
cout << ": " ;
|
||||
ParamSet p = getVar2FactorMsg (links[i]);
|
||||
cout << endl;
|
||||
Factor temp2 (links[i]->getVariable(), p);
|
||||
temp.multiplyByFactor (temp2);
|
||||
temp2.freeDistribution();
|
||||
} else {
|
||||
placeholder[i] = 1.0;
|
||||
}
|
||||
}
|
||||
|
||||
} else {
|
||||
|
||||
MessageBanket* mb = msgs_[src->getIndex()];
|
||||
vector<Factor*> neighbors = src->getFactors();
|
||||
for (unsigned i = 0; i < neighbors.size(); i++) {
|
||||
if (neighbors[i] != dest) {
|
||||
const Message& fromFactor = mb->getMessage (neighbors[i]);
|
||||
for (unsigned j = 0; j < fromFactor.size(); j++) {
|
||||
placeholder[j] *= fromFactor[j];
|
||||
}
|
||||
Factor temp2 (links[i]->getVariable(), getVar2FactorMsg (links[i]));
|
||||
temp.multiplyByFactor (temp2);
|
||||
temp2.freeDistribution();
|
||||
}
|
||||
}
|
||||
}
|
||||
if (links.size() >= 2) {
|
||||
result.multiplyByFactor (temp, &(src->getCptEntries()));
|
||||
if (DL >= 5) {
|
||||
cout << " message product: " ;
|
||||
cout << Util::parametersToString (temp.getParameters()) << endl;
|
||||
cout << " factor product: " ;
|
||||
cout << Util::parametersToString (src->getParameters());
|
||||
cout << " x " ;
|
||||
cout << Util::parametersToString (temp.getParameters());
|
||||
cout << " = " ;
|
||||
cout << Util::parametersToString (result.getParameters()) << endl;
|
||||
}
|
||||
temp.freeDistribution();
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
if (links[i]->getVariable() != dest) {
|
||||
result.removeVariable (links[i]->getVariable());
|
||||
}
|
||||
}
|
||||
if (DL >= 5) {
|
||||
cout << " final message: " ;
|
||||
cout << Util::parametersToString (result.getParameters()) << endl << endl;
|
||||
}
|
||||
ParamSet msg = result.getParameters();
|
||||
result.freeDistribution();
|
||||
return msg;
|
||||
}
|
||||
|
||||
|
||||
|
||||
ParamSet
|
||||
SPSolver::getVar2FactorMsg (const Link* link) const
|
||||
{
|
||||
const FgVarNode* src = link->getVariable();
|
||||
const Factor* dest = link->getFactor();
|
||||
ParamSet msg;
|
||||
if (src->hasEvidence()) {
|
||||
msg.resize (src->getDomainSize(), 0.0);
|
||||
msg[src->getEvidence()] = 1.0;
|
||||
if (DL >= 5) {
|
||||
cout << Util::parametersToString (msg);
|
||||
}
|
||||
} else {
|
||||
msg.resize (src->getDomainSize(), 1.0);
|
||||
}
|
||||
if (DL >= 5) {
|
||||
cout << Util::parametersToString (msg);
|
||||
}
|
||||
CLinkSet links = varsI_[src->getIndex()]->getLinks();
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
if (links[i]->getFactor() != dest) {
|
||||
CParamSet msgFromFactor = links[i]->getMessage();
|
||||
for (unsigned j = 0; j < msgFromFactor.size(); j++) {
|
||||
msg[j] *= msgFromFactor[j];
|
||||
}
|
||||
if (DL >= 5) {
|
||||
cout << " x " << Util::parametersToString (msgFromFactor);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (DL >= 5) {
|
||||
cout << " = " << Util::parametersToString (msg);
|
||||
}
|
||||
return msg;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user