add missing files
This commit is contained in:
parent
d1b25f0864
commit
f1d52c0389
496
packages/CLPBN/clpbn/bp/BpSolver.cpp
Normal file
496
packages/CLPBN/clpbn/bp/BpSolver.cpp
Normal file
@ -0,0 +1,496 @@
|
||||
#include <cassert>
|
||||
#include <limits>
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "BpSolver.h"
|
||||
#include "FactorGraph.h"
|
||||
#include "Factor.h"
|
||||
#include "Indexer.h"
|
||||
#include "Horus.h"
|
||||
|
||||
|
||||
BpSolver::BpSolver (const FactorGraph& fg) : Solver (fg)
|
||||
{
|
||||
factorGraph_ = &fg;
|
||||
}
|
||||
|
||||
|
||||
|
||||
BpSolver::~BpSolver (void)
|
||||
{
|
||||
for (unsigned i = 0; i < varsI_.size(); i++) {
|
||||
delete varsI_[i];
|
||||
}
|
||||
for (unsigned i = 0; i < facsI_.size(); i++) {
|
||||
delete facsI_[i];
|
||||
}
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
delete links_[i];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BpSolver::runSolver (void)
|
||||
{
|
||||
clock_t start;
|
||||
if (Constants::COLLECT_STATS) {
|
||||
start = clock();
|
||||
}
|
||||
runLoopySolver();
|
||||
if (Constants::DEBUG >= 2) {
|
||||
cout << endl;
|
||||
if (nIters_ < BpOptions::maxIter) {
|
||||
cout << "Sum-Product converged in " ;
|
||||
cout << nIters_ << " iterations" << endl;
|
||||
} else {
|
||||
cout << "The maximum number of iterations was hit, terminating..." ;
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
unsigned size = factorGraph_->varNodes().size();
|
||||
if (Constants::COLLECT_STATS) {
|
||||
unsigned nIters = 0;
|
||||
bool loopy = factorGraph_->isTree() == false;
|
||||
if (loopy) nIters = nIters_;
|
||||
double time = (double (clock() - start)) / CLOCKS_PER_SEC;
|
||||
Statistics::updateStatistics (size, loopy, nIters, time);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
Params
|
||||
BpSolver::getPosterioriOf (VarId vid)
|
||||
{
|
||||
assert (factorGraph_->getVarNode (vid));
|
||||
VarNode* var = factorGraph_->getVarNode (vid);
|
||||
Params probs;
|
||||
if (var->hasEvidence()) {
|
||||
probs.resize (var->range(), LogAware::noEvidence());
|
||||
probs[var->getEvidence()] = LogAware::withEvidence();
|
||||
} else {
|
||||
probs.resize (var->range(), LogAware::multIdenty());
|
||||
const SpLinkSet& links = ninf(var)->getLinks();
|
||||
if (Globals::logDomain) {
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
Util::add (probs, links[i]->getMessage());
|
||||
}
|
||||
LogAware::normalize (probs);
|
||||
Util::fromLog (probs);
|
||||
} else {
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
Util::multiply (probs, links[i]->getMessage());
|
||||
}
|
||||
LogAware::normalize (probs);
|
||||
}
|
||||
}
|
||||
return probs;
|
||||
}
|
||||
|
||||
|
||||
|
||||
Params
|
||||
BpSolver::getJointDistributionOf (const VarIds& jointVarIds)
|
||||
{
|
||||
int idx = -1;
|
||||
VarNode* vn = factorGraph_->getVarNode (jointVarIds[0]);
|
||||
const FactorNodes& factorNodes = vn->neighbors();
|
||||
for (unsigned i = 0; i < factorNodes.size(); i++) {
|
||||
if (factorNodes[i]->factor()->contains (jointVarIds)) {
|
||||
idx = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (idx == -1) {
|
||||
return getJointByConditioning (jointVarIds);
|
||||
} else {
|
||||
Factor res (*factorNodes[idx]->factor());
|
||||
const SpLinkSet& links = ninf(factorNodes[idx])->getLinks();
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
Factor msg (links[i]->getVariable()->varId(),
|
||||
links[i]->getVariable()->range(),
|
||||
getVar2FactorMsg (links[i]));
|
||||
res.multiply (msg);
|
||||
}
|
||||
res.sumOutAllExcept (jointVarIds);
|
||||
res.reorderArguments (jointVarIds);
|
||||
res.normalize();
|
||||
Params jointDist = res.params();
|
||||
if (Globals::logDomain) {
|
||||
Util::fromLog (jointDist);
|
||||
}
|
||||
return jointDist;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BpSolver::runLoopySolver (void)
|
||||
{
|
||||
initializeSolver();
|
||||
nIters_ = 0;
|
||||
|
||||
while (!converged() && nIters_ < BpOptions::maxIter) {
|
||||
|
||||
nIters_ ++;
|
||||
if (Constants::DEBUG >= 2) {
|
||||
Util::printHeader (" Iteration " + nIters_);
|
||||
cout << endl;
|
||||
}
|
||||
|
||||
switch (BpOptions::schedule) {
|
||||
case BpOptions::Schedule::SEQ_RANDOM:
|
||||
random_shuffle (links_.begin(), links_.end());
|
||||
// no break
|
||||
|
||||
case BpOptions::Schedule::SEQ_FIXED:
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
calculateAndUpdateMessage (links_[i]);
|
||||
}
|
||||
break;
|
||||
|
||||
case BpOptions::Schedule::PARALLEL:
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
calculateMessage (links_[i]);
|
||||
}
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
updateMessage(links_[i]);
|
||||
}
|
||||
break;
|
||||
|
||||
case BpOptions::Schedule::MAX_RESIDUAL:
|
||||
maxResidualSchedule();
|
||||
break;
|
||||
}
|
||||
if (Constants::DEBUG >= 2) {
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BpSolver::initializeSolver (void)
|
||||
{
|
||||
const VarNodes& varNodes = factorGraph_->varNodes();
|
||||
for (unsigned i = 0; i < varsI_.size(); i++) {
|
||||
delete varsI_[i];
|
||||
}
|
||||
varsI_.reserve (varNodes.size());
|
||||
for (unsigned i = 0; i < varNodes.size(); i++) {
|
||||
varsI_.push_back (new SPNodeInfo());
|
||||
}
|
||||
|
||||
const FactorNodes& facNodes = factorGraph_->factorNodes();
|
||||
for (unsigned i = 0; i < facsI_.size(); i++) {
|
||||
delete facsI_[i];
|
||||
}
|
||||
facsI_.reserve (facNodes.size());
|
||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||
facsI_.push_back (new SPNodeInfo());
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
delete links_[i];
|
||||
}
|
||||
createLinks();
|
||||
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
FactorNode* src = links_[i]->getFactor();
|
||||
VarNode* dst = links_[i]->getVariable();
|
||||
ninf (dst)->addSpLink (links_[i]);
|
||||
ninf (src)->addSpLink (links_[i]);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BpSolver::createLinks (void)
|
||||
{
|
||||
const FactorNodes& facNodes = factorGraph_->factorNodes();
|
||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||
const VarNodes& neighbors = facNodes[i]->neighbors();
|
||||
for (unsigned j = 0; j < neighbors.size(); j++) {
|
||||
links_.push_back (new SpLink (facNodes[i], neighbors[j]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
BpSolver::converged (void)
|
||||
{
|
||||
if (links_.size() == 0) {
|
||||
return true;
|
||||
}
|
||||
if (nIters_ == 0 || nIters_ == 1) {
|
||||
return false;
|
||||
}
|
||||
bool converged = true;
|
||||
if (BpOptions::schedule == BpOptions::Schedule::MAX_RESIDUAL) {
|
||||
double maxResidual = (*(sortedOrder_.begin()))->getResidual();
|
||||
if (maxResidual > BpOptions::accuracy) {
|
||||
converged = false;
|
||||
} else {
|
||||
converged = true;
|
||||
}
|
||||
} else {
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
double residual = links_[i]->getResidual();
|
||||
if (Constants::DEBUG >= 2) {
|
||||
cout << links_[i]->toString() + " residual = " << residual << endl;
|
||||
}
|
||||
if (residual > BpOptions::accuracy) {
|
||||
converged = false;
|
||||
if (Constants::DEBUG == 0) break;
|
||||
}
|
||||
}
|
||||
}
|
||||
return converged;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BpSolver::maxResidualSchedule (void)
|
||||
{
|
||||
if (nIters_ == 1) {
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
calculateMessage (links_[i]);
|
||||
SortedOrder::iterator it = sortedOrder_.insert (links_[i]);
|
||||
linkMap_.insert (make_pair (links_[i], it));
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
for (unsigned c = 0; c < links_.size(); c++) {
|
||||
if (Constants::DEBUG >= 2) {
|
||||
cout << "current residuals:" << endl;
|
||||
for (SortedOrder::iterator it = sortedOrder_.begin();
|
||||
it != sortedOrder_.end(); it ++) {
|
||||
cout << " " << setw (30) << left << (*it)->toString();
|
||||
cout << "residual = " << (*it)->getResidual() << endl;
|
||||
}
|
||||
}
|
||||
|
||||
SortedOrder::iterator it = sortedOrder_.begin();
|
||||
SpLink* link = *it;
|
||||
if (link->getResidual() < BpOptions::accuracy) {
|
||||
return;
|
||||
}
|
||||
updateMessage (link);
|
||||
link->clearResidual();
|
||||
sortedOrder_.erase (it);
|
||||
linkMap_.find (link)->second = sortedOrder_.insert (link);
|
||||
|
||||
// update the messages that depend on message source --> destin
|
||||
const FactorNodes& factorNeighbors = link->getVariable()->neighbors();
|
||||
for (unsigned i = 0; i < factorNeighbors.size(); i++) {
|
||||
if (factorNeighbors[i] != link->getFactor()) {
|
||||
const SpLinkSet& links = ninf(factorNeighbors[i])->getLinks();
|
||||
for (unsigned j = 0; j < links.size(); j++) {
|
||||
if (links[j]->getVariable() != link->getVariable()) {
|
||||
calculateMessage (links[j]);
|
||||
SpLinkMap::iterator iter = linkMap_.find (links[j]);
|
||||
sortedOrder_.erase (iter->second);
|
||||
iter->second = sortedOrder_.insert (links[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (Constants::DEBUG >= 2) {
|
||||
Util::printDashedLine();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BpSolver::calculateFactor2VariableMsg (SpLink* link) const
|
||||
{
|
||||
const FactorNode* src = link->getFactor();
|
||||
const VarNode* dst = link->getVariable();
|
||||
const SpLinkSet& links = ninf(src)->getLinks();
|
||||
// calculate the product of messages that were sent
|
||||
// to factor `src', except from var `dst'
|
||||
unsigned msgSize = 1;
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
msgSize *= links[i]->getVariable()->range();
|
||||
}
|
||||
unsigned repetitions = 1;
|
||||
Params msgProduct (msgSize, LogAware::multIdenty());
|
||||
if (Globals::logDomain) {
|
||||
for (int i = links.size() - 1; i >= 0; i--) {
|
||||
if (links[i]->getVariable() != dst) {
|
||||
Util::add (msgProduct, getVar2FactorMsg (links[i]), repetitions);
|
||||
repetitions *= links[i]->getVariable()->range();
|
||||
} else {
|
||||
unsigned ds = links[i]->getVariable()->range();
|
||||
Util::add (msgProduct, Params (ds, 1.0), repetitions);
|
||||
repetitions *= ds;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int i = links.size() - 1; i >= 0; i--) {
|
||||
if (links[i]->getVariable() != dst) {
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << " message from " << links[i]->getVariable()->label();
|
||||
cout << ": " << endl;
|
||||
}
|
||||
Util::multiply (msgProduct, getVar2FactorMsg (links[i]), repetitions);
|
||||
repetitions *= links[i]->getVariable()->range();
|
||||
} else {
|
||||
unsigned ds = links[i]->getVariable()->range();
|
||||
Util::multiply (msgProduct, Params (ds, 1.0), repetitions);
|
||||
repetitions *= ds;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Factor result (src->factor()->arguments(),
|
||||
src->factor()->ranges(),
|
||||
msgProduct);
|
||||
result.multiply (*(src->factor()));
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << " message product: " << msgProduct << endl;
|
||||
cout << " original factor: " << src->params() << endl;
|
||||
cout << " factor product: " << result.params() << endl;
|
||||
}
|
||||
result.sumOutAllExcept (dst->varId());
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << " marginalized: " ;
|
||||
cout << result.params() << endl;
|
||||
}
|
||||
const Params& resultParams = result.params();
|
||||
Params& message = link->getNextMessage();
|
||||
for (unsigned i = 0; i < resultParams.size(); i++) {
|
||||
message[i] = resultParams[i];
|
||||
}
|
||||
LogAware::normalize (message);
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << " curr msg: " << link->getMessage() << endl;
|
||||
cout << " next msg: " << message << endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
Params
|
||||
BpSolver::getVar2FactorMsg (const SpLink* link) const
|
||||
{
|
||||
const VarNode* src = link->getVariable();
|
||||
const FactorNode* dst = link->getFactor();
|
||||
Params msg;
|
||||
if (src->hasEvidence()) {
|
||||
msg.resize (src->range(), LogAware::noEvidence());
|
||||
msg[src->getEvidence()] = LogAware::withEvidence();
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << msg;
|
||||
}
|
||||
} else {
|
||||
msg.resize (src->range(), LogAware::one());
|
||||
}
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << msg;
|
||||
}
|
||||
const SpLinkSet& links = ninf (src)->getLinks();
|
||||
if (Globals::logDomain) {
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
if (links[i]->getFactor() != dst) {
|
||||
Util::add (msg, links[i]->getMessage());
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
if (links[i]->getFactor() != dst) {
|
||||
Util::multiply (msg, links[i]->getMessage());
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << " x " << links[i]->getMessage();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << " = " << msg;
|
||||
}
|
||||
return msg;
|
||||
}
|
||||
|
||||
|
||||
|
||||
Params
|
||||
BpSolver::getJointByConditioning (const VarIds& jointVarIds) const
|
||||
{
|
||||
VarNodes jointVars;
|
||||
for (unsigned i = 0; i < jointVarIds.size(); i++) {
|
||||
assert (factorGraph_->getVarNode (jointVarIds[i]));
|
||||
jointVars.push_back (factorGraph_->getVarNode (jointVarIds[i]));
|
||||
}
|
||||
|
||||
FactorGraph* fg = new FactorGraph (*factorGraph_);
|
||||
BpSolver solver (*fg);
|
||||
solver.runSolver();
|
||||
Params prevBeliefs = solver.getPosterioriOf (jointVarIds[0]);
|
||||
|
||||
VarIds observedVids = {jointVars[0]->varId()};
|
||||
|
||||
for (unsigned i = 1; i < jointVarIds.size(); i++) {
|
||||
assert (jointVars[i]->hasEvidence() == false);
|
||||
Params newBeliefs;
|
||||
Vars observedVars;
|
||||
for (unsigned j = 0; j < observedVids.size(); j++) {
|
||||
observedVars.push_back (fg->getVarNode (observedVids[j]));
|
||||
}
|
||||
StatesIndexer idx (observedVars, false);
|
||||
while (idx.valid()) {
|
||||
for (unsigned j = 0; j < observedVars.size(); j++) {
|
||||
observedVars[j]->setEvidence (idx[j]);
|
||||
}
|
||||
++ idx;
|
||||
BpSolver solver (*fg);
|
||||
solver.runSolver();
|
||||
Params beliefs = solver.getPosterioriOf (jointVarIds[i]);
|
||||
for (unsigned k = 0; k < beliefs.size(); k++) {
|
||||
newBeliefs.push_back (beliefs[k]);
|
||||
}
|
||||
}
|
||||
|
||||
int count = -1;
|
||||
for (unsigned j = 0; j < newBeliefs.size(); j++) {
|
||||
if (j % jointVars[i]->range() == 0) {
|
||||
count ++;
|
||||
}
|
||||
newBeliefs[j] *= prevBeliefs[count];
|
||||
}
|
||||
prevBeliefs = newBeliefs;
|
||||
observedVids.push_back (jointVars[i]->varId());
|
||||
}
|
||||
return prevBeliefs;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BpSolver::printLinkInformation (void) const
|
||||
{
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
SpLink* l = links_[i];
|
||||
cout << l->toString() << ":" << endl;
|
||||
cout << " curr msg = " ;
|
||||
cout << l->getMessage() << endl;
|
||||
cout << " next msg = " ;
|
||||
cout << l->getNextMessage() << endl;
|
||||
cout << " residual = " << l->getResidual() << endl;
|
||||
}
|
||||
}
|
||||
|
186
packages/CLPBN/clpbn/bp/BpSolver.h
Normal file
186
packages/CLPBN/clpbn/bp/BpSolver.h
Normal file
@ -0,0 +1,186 @@
|
||||
#ifndef HORUS_BpSolver_H
|
||||
#define HORUS_BpSolver_H
|
||||
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include <sstream>
|
||||
|
||||
#include "Solver.h"
|
||||
#include "Factor.h"
|
||||
#include "FactorGraph.h"
|
||||
#include "Util.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
|
||||
class SpLink
|
||||
{
|
||||
public:
|
||||
SpLink (FactorNode* fn, VarNode* vn)
|
||||
{
|
||||
fac_ = fn;
|
||||
var_ = vn;
|
||||
v1_.resize (vn->range(), LogAware::tl (1.0 / vn->range()));
|
||||
v2_.resize (vn->range(), LogAware::tl (1.0 / vn->range()));
|
||||
currMsg_ = &v1_;
|
||||
nextMsg_ = &v2_;
|
||||
msgSended_ = false;
|
||||
residual_ = 0.0;
|
||||
}
|
||||
|
||||
virtual ~SpLink (void) { };
|
||||
|
||||
FactorNode* getFactor (void) const { return fac_; }
|
||||
|
||||
VarNode* getVariable (void) const { return var_; }
|
||||
|
||||
const Params& getMessage (void) const { return *currMsg_; }
|
||||
|
||||
Params& getNextMessage (void) { return *nextMsg_; }
|
||||
|
||||
bool messageWasSended (void) const { return msgSended_; }
|
||||
|
||||
double getResidual (void) const { return residual_; }
|
||||
|
||||
void clearResidual (void) { residual_ = 0.0; }
|
||||
|
||||
void updateResidual (void)
|
||||
{
|
||||
residual_ = LogAware::getMaxNorm (v1_,v2_);
|
||||
}
|
||||
|
||||
virtual void updateMessage (void)
|
||||
{
|
||||
swap (currMsg_, nextMsg_);
|
||||
msgSended_ = true;
|
||||
}
|
||||
|
||||
string toString (void) const
|
||||
{
|
||||
stringstream ss;
|
||||
ss << fac_->getLabel();
|
||||
ss << " -- " ;
|
||||
ss << var_->label();
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
protected:
|
||||
FactorNode* fac_;
|
||||
VarNode* var_;
|
||||
Params v1_;
|
||||
Params v2_;
|
||||
Params* currMsg_;
|
||||
Params* nextMsg_;
|
||||
bool msgSended_;
|
||||
double residual_;
|
||||
};
|
||||
|
||||
typedef vector<SpLink*> SpLinkSet;
|
||||
|
||||
|
||||
class SPNodeInfo
|
||||
{
|
||||
public:
|
||||
void addSpLink (SpLink* link) { links_.push_back (link); }
|
||||
const SpLinkSet& getLinks (void) { return links_; }
|
||||
private:
|
||||
SpLinkSet links_;
|
||||
};
|
||||
|
||||
|
||||
class BpSolver : public Solver
|
||||
{
|
||||
public:
|
||||
BpSolver (const FactorGraph&);
|
||||
|
||||
virtual ~BpSolver (void);
|
||||
|
||||
void runSolver (void);
|
||||
|
||||
virtual Params getPosterioriOf (VarId);
|
||||
|
||||
virtual Params getJointDistributionOf (const VarIds&);
|
||||
|
||||
protected:
|
||||
virtual void initializeSolver (void);
|
||||
|
||||
virtual void createLinks (void);
|
||||
|
||||
virtual void maxResidualSchedule (void);
|
||||
|
||||
virtual void calculateFactor2VariableMsg (SpLink*) const;
|
||||
|
||||
virtual Params getVar2FactorMsg (const SpLink*) const;
|
||||
|
||||
virtual Params getJointByConditioning (const VarIds&) const;
|
||||
|
||||
virtual void printLinkInformation (void) const;
|
||||
|
||||
SPNodeInfo* ninf (const VarNode* var) const
|
||||
{
|
||||
return varsI_[var->getIndex()];
|
||||
}
|
||||
|
||||
SPNodeInfo* ninf (const FactorNode* fac) const
|
||||
{
|
||||
return facsI_[fac->getIndex()];
|
||||
}
|
||||
|
||||
void calculateAndUpdateMessage (SpLink* link, bool calcResidual = true)
|
||||
{
|
||||
if (Constants::DEBUG >= 3) {
|
||||
cout << "calculating & updating " << link->toString() << endl;
|
||||
}
|
||||
calculateFactor2VariableMsg (link);
|
||||
if (calcResidual) {
|
||||
link->updateResidual();
|
||||
}
|
||||
link->updateMessage();
|
||||
}
|
||||
|
||||
void calculateMessage (SpLink* link, bool calcResidual = true)
|
||||
{
|
||||
if (Constants::DEBUG >= 3) {
|
||||
cout << "calculating " << link->toString() << endl;
|
||||
}
|
||||
calculateFactor2VariableMsg (link);
|
||||
if (calcResidual) {
|
||||
link->updateResidual();
|
||||
}
|
||||
}
|
||||
|
||||
void updateMessage (SpLink* link)
|
||||
{
|
||||
link->updateMessage();
|
||||
if (Constants::DEBUG >= 3) {
|
||||
cout << "updating " << link->toString() << endl;
|
||||
}
|
||||
}
|
||||
|
||||
struct CompareResidual
|
||||
{
|
||||
inline bool operator() (const SpLink* link1, const SpLink* link2)
|
||||
{
|
||||
return link1->getResidual() > link2->getResidual();
|
||||
}
|
||||
};
|
||||
|
||||
SpLinkSet links_;
|
||||
unsigned nIters_;
|
||||
vector<SPNodeInfo*> varsI_;
|
||||
vector<SPNodeInfo*> facsI_;
|
||||
const FactorGraph* factorGraph_;
|
||||
|
||||
typedef multiset<SpLink*, CompareResidual> SortedOrder;
|
||||
SortedOrder sortedOrder_;
|
||||
|
||||
typedef unordered_map<SpLink*, SortedOrder::iterator> SpLinkMap;
|
||||
SpLinkMap linkMap_;
|
||||
|
||||
private:
|
||||
void runLoopySolver (void);
|
||||
bool converged (void);
|
||||
};
|
||||
|
||||
#endif // HORUS_BpSolver_H
|
||||
|
102
packages/CLPBN/clpbn/bp/Var.cpp
Normal file
102
packages/CLPBN/clpbn/bp/Var.cpp
Normal file
@ -0,0 +1,102 @@
|
||||
#include <algorithm>
|
||||
#include <sstream>
|
||||
|
||||
#include "Var.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
|
||||
unordered_map<VarId, VarInfo> Var::varsInfo_;
|
||||
|
||||
|
||||
Var::Var (const Var* v)
|
||||
{
|
||||
varId_ = v->varId();
|
||||
range_ = v->range();
|
||||
evidence_ = v->getEvidence();
|
||||
index_ = std::numeric_limits<unsigned>::max();
|
||||
}
|
||||
|
||||
|
||||
|
||||
Var::Var (VarId varId, unsigned range, int evidence)
|
||||
{
|
||||
assert (range != 0);
|
||||
assert (evidence < (int) range);
|
||||
varId_ = varId;
|
||||
range_ = range;
|
||||
evidence_ = evidence;
|
||||
index_ = std::numeric_limits<unsigned>::max();
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
Var::isValidState (int stateIndex)
|
||||
{
|
||||
return stateIndex >= 0 && stateIndex < (int) range_;
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
Var::isValidState (const string& stateName)
|
||||
{
|
||||
States states = Var::getVarInformation (varId_).states;
|
||||
return Util::contains (states, stateName);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Var::setEvidence (int ev)
|
||||
{
|
||||
assert (ev < (int) range_);
|
||||
evidence_ = ev;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Var::setEvidence (const string& ev)
|
||||
{
|
||||
States states = Var::getVarInformation (varId_).states;
|
||||
for (unsigned i = 0; i < states.size(); i++) {
|
||||
if (states[i] == ev) {
|
||||
evidence_ = i;
|
||||
return;
|
||||
}
|
||||
}
|
||||
assert (false);
|
||||
}
|
||||
|
||||
|
||||
|
||||
string
|
||||
Var::label (void) const
|
||||
{
|
||||
if (Var::variablesHaveInformation()) {
|
||||
return Var::getVarInformation (varId_).label;
|
||||
}
|
||||
stringstream ss;
|
||||
ss << "x" << varId_;
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
|
||||
|
||||
States
|
||||
Var::states (void) const
|
||||
{
|
||||
if (Var::variablesHaveInformation()) {
|
||||
return Var::getVarInformation (varId_).states;
|
||||
}
|
||||
States states;
|
||||
for (unsigned i = 0; i < range_; i++) {
|
||||
stringstream ss;
|
||||
ss << i ;
|
||||
states.push_back (ss.str());
|
||||
}
|
||||
return states;
|
||||
}
|
||||
|
108
packages/CLPBN/clpbn/bp/Var.h
Normal file
108
packages/CLPBN/clpbn/bp/Var.h
Normal file
@ -0,0 +1,108 @@
|
||||
#ifndef HORUS_Var_H
|
||||
#define HORUS_Var_H
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "Util.h"
|
||||
#include "Horus.h"
|
||||
|
||||
|
||||
using namespace std;
|
||||
|
||||
|
||||
struct VarInfo
|
||||
{
|
||||
VarInfo (string l, const States& sts) : label(l), states(sts) { }
|
||||
string label;
|
||||
States states;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class Var
|
||||
{
|
||||
public:
|
||||
Var (const Var*);
|
||||
|
||||
Var (VarId, unsigned, int = Constants::NO_EVIDENCE);
|
||||
|
||||
virtual ~Var (void) { };
|
||||
|
||||
unsigned varId (void) const { return varId_; }
|
||||
|
||||
unsigned range (void) const { return range_; }
|
||||
|
||||
int getEvidence (void) const { return evidence_; }
|
||||
|
||||
unsigned getIndex (void) const { return index_; }
|
||||
|
||||
void setIndex (unsigned idx) { index_ = idx; }
|
||||
|
||||
operator unsigned () const { return index_; }
|
||||
|
||||
bool hasEvidence (void) const
|
||||
{
|
||||
return evidence_ != Constants::NO_EVIDENCE;
|
||||
}
|
||||
|
||||
bool operator== (const Var& var) const
|
||||
{
|
||||
assert (!(varId_ == var.varId() && range_ != var.range()));
|
||||
return varId_ == var.varId();
|
||||
}
|
||||
|
||||
bool operator!= (const Var& var) const
|
||||
{
|
||||
assert (!(varId_ == var.varId() && range_ != var.range()));
|
||||
return varId_ != var.varId();
|
||||
}
|
||||
|
||||
bool isValidState (int);
|
||||
|
||||
bool isValidState (const string&);
|
||||
|
||||
void setEvidence (int);
|
||||
|
||||
void setEvidence (const string&);
|
||||
|
||||
string label (void) const;
|
||||
|
||||
States states (void) const;
|
||||
|
||||
static void addVariableInformation (
|
||||
VarId vid, string label, const States& states)
|
||||
{
|
||||
assert (Util::contains (varsInfo_, vid) == false);
|
||||
varsInfo_.insert (make_pair (vid, VarInfo (label, states)));
|
||||
}
|
||||
|
||||
static VarInfo getVarInformation (VarId vid)
|
||||
{
|
||||
assert (Util::contains (varsInfo_, vid));
|
||||
return varsInfo_.find (vid)->second;
|
||||
}
|
||||
|
||||
static bool variablesHaveInformation (void)
|
||||
{
|
||||
return varsInfo_.size() != 0;
|
||||
}
|
||||
|
||||
static void clearVariablesInformation (void)
|
||||
{
|
||||
varsInfo_.clear();
|
||||
}
|
||||
|
||||
private:
|
||||
VarId varId_;
|
||||
unsigned range_;
|
||||
int evidence_;
|
||||
unsigned index_;
|
||||
|
||||
static unordered_map<VarId, VarInfo> varsInfo_;
|
||||
|
||||
};
|
||||
|
||||
#endif // BP_Var_H
|
||||
|
Reference in New Issue
Block a user