This repository has been archived on 2023-08-20. You can view files and clone it, but cannot push or open issues or pull requests.
yap-6.3/packages/CLPBN/clpbn/bp/BnBpSolver.h
2011-12-12 15:29:51 +00:00

263 lines
7.9 KiB
C++

#ifndef HORUS_BNBPSOLVER_H
#define HORUS_BNBPSOLVER_H
#include <vector>
#include <set>
#include "Solver.h"
#include "BayesNet.h"
#include "Shared.h"
// File-scope using-directive: the declarations in this header spell
// standard-library names (string, vector, multiset, ...) unqualified.
using namespace std;

// Forward declaration; the class itself is defined further down.
class BpNodeInfo;

// Prefixes written by BpLink::toString() for DOWN ("pi") and
// UP ("ld") links respectively.
static const string PI_SYMBOL ("pi");
static const string LD_SYMBOL ("ld");

// Direction tag carried by every BpLink.
enum LinkOrientation
{
  UP,
  DOWN
};

// Tag stored in BnBpSolver::jointCalcType_.
// NOTE(review): presumably selects between getJointByChainRule() and
// getJointByJunctionNode() -- confirm against the implementation file.
enum JointCalcType
{
  CHAIN_RULE,
  JUNCTION_NODE
};
/**
 * A directed message channel between two nodes of the network.
 *
 * The link owns two equally sized buffers. One holds the message that is
 * currently visible to readers (getMessage), the other is the scratch
 * area where the next message is written (getNextMessage).
 * updateMessage() swaps the two roles.
 *
 * The object is non-copyable: currMsg_ and nextMsg_ point into this
 * object's own v1_/v2_, so a memberwise copy would leave the copy
 * pointing at the original's buffers.
 */
class BpLink
{
  public:
    /**
     * @param s  node the message comes from
     * @param d  node the message goes to
     * @param o  DOWN or UP; decides which endpoint's state count is used
     *           to size the buffers
     *
     * Both buffers start out uniform (every entry is
     * Util::tl (1.0 / nrStates)), the residual starts at 0 and the link
     * is flagged as "not yet sent".
     */
    BpLink (BayesNode* s, BayesNode* d, LinkOrientation o)
        : source_(s),
          destin_(d),
          orientation_(o),
          currMsg_(&v1_),
          nextMsg_(&v2_),
          msgSended_(false),
          residual_(0)
    {
      // DOWN messages are indexed by the states of the source node,
      // every other orientation by the states of the destination node
      // (see toString (unsigned) which labels entries the same way).
      BayesNode* sized = (orientation_ == LinkOrientation::DOWN) ? s : d;
      v1_.resize (sized->nrStates(), Util::tl (1.0 / sized->nrStates()));
      v2_.resize (sized->nrStates(), Util::tl (1.0 / sized->nrStates()));
    }

    // Copying would alias the internal buffer pointers (see class note).
    BpLink (const BpLink&) = delete;
    BpLink& operator= (const BpLink&) = delete;

    /** Publishes the next message by swapping the buffer roles and
     *  records that the link has sent at least one message. */
    void updateMessage (void)
    {
      swap (currMsg_, nextMsg_);
      msgSended_ = true;
    }

    /** Recomputes the residual as Util::getMaxNorm of the two buffers.
     *  The argument order is fixed (v1_, v2_) regardless of which buffer
     *  is currently the published one. */
    void updateResidual (void)
    {
      residual_ = Util::getMaxNorm (v1_, v2_);
    }

    /** @return "<symbol>(<source label> --> <destination label>)" where
     *  the symbol is PI_SYMBOL for DOWN links and LD_SYMBOL otherwise. */
    string toString (void) const
    {
      stringstream ss;
      if (orientation_ == LinkOrientation::DOWN) {
        ss << PI_SYMBOL;
      } else {
        ss << LD_SYMBOL;
      }
      ss << "(" << source_->label();
      ss << " --> " << destin_->label() << ")" ;
      return ss.str();
    }

    /** @return toString() followed by "[<state name>]", the state name
     *  being taken from the source node for DOWN links and from the
     *  destination node otherwise. stateIndex is not range-checked here. */
    string toString (unsigned stateIndex) const
    {
      stringstream ss;
      ss << toString() << "[" ;
      if (orientation_ == LinkOrientation::DOWN) {
        ss << source_->states()[stateIndex] << "]" ;
      } else {
        ss << destin_->states()[stateIndex] << "]" ;
      }
      return ss.str();
    }

    BayesNode*       getSource (void) const       { return source_; }
    BayesNode*       getDestination (void) const  { return destin_; }
    LinkOrientation  getOrientation (void) const  { return orientation_; }

    /** Currently published message (read-only). */
    const ParamSet&  getMessage (void) const      { return *currMsg_; }
    /** Scratch buffer the next message is written into. */
    ParamSet&        getNextMessage (void)        { return *nextMsg_; }

    /** True once updateMessage() has been called at least once. */
    bool             messageWasSended (void) const { return msgSended_; }
    double           getResidual (void) const     { return residual_; }
    void             clearResidual (void)         { residual_ = 0; }

  private:
    BayesNode*       source_;
    BayesNode*       destin_;
    LinkOrientation  orientation_;
    ParamSet         v1_;          // first message buffer
    ParamSet         v2_;          // second message buffer
    ParamSet*        currMsg_;     // points at v1_ or v2_
    ParamSet*        nextMsg_;     // points at the other buffer
    bool             msgSended_;
    double           residual_;
};
// Sequence of non-owning link pointers.
typedef vector<BpLink*> BpLinkSet;

/**
 * Belief-propagation bookkeeping attached to one BayesNode: its pi and
 * lambda value vectors, counters of received pi/lambda messages, flags
 * telling whether the values were already calculated, and the four sets
 * of links (incoming/outgoing, parent side/child side) touching the node.
 *
 * Member functions that are only declared here are defined in the
 * implementation file; their signatures are kept exactly as they were so
 * those definitions keep matching. The inline read-only accessors are
 * const-qualified so they can also be used through a const reference.
 */
class BpNodeInfo
{
  public:
    BpNodeInfo (BayesNode*);

    ParamSet getBeliefs (void) const;
    bool     receivedBottomInfluence (void) const;

    // Mutable access to the stored pi / lambda value vectors.
    ParamSet& getPiValues (void)     { return piVals_; }
    ParamSet& getLambdaValues (void) { return ldVals_; }

    // Message counters; each call adds one to the respective counter.
    void incNumPiMsgsReceived (void)     { nPiMsgsRcv_ ++; }
    void incNumLambdaMsgsReceived (void) { nLdMsgsRcv_ ++; }

    // Whether the pi / lambda values have been flagged as calculated.
    bool piValuesCalculated (void) const     { return piValsCalc_; }
    bool lambdaValuesCalculated (void) const { return ldValsCalc_; }

    void markPiValuesAsCalculated (void);
    void markLambdaValuesAsCalculated (void);
    bool receivedAllPiMessages (void);
    bool receivedAllLambdaMessages (void);
    bool readyToSendPiMsgTo (const BayesNode*) const ;
    bool readyToSendLambdaMsgTo (const BayesNode*) const;

    // Read-only views of the link sets.
    const BpLinkSet& getIncomingParentLinks (void) const
    {
      return inParentLinks_;
    }
    const BpLinkSet& getIncomingChildLinks (void) const
    {
      return inChildLinks_;
    }
    const BpLinkSet& getOutcomingParentLinks (void) const
    {
      return outParentLinks_;
    }
    const BpLinkSet& getOutcomingChildLinks (void) const
    {
      return outChildLinks_;
    }

    // Registration of links; the pointer is appended as-is (no ownership
    // is taken in this header, no duplicate check is made).
    void addIncomingParentLink (BpLink* l)  { inParentLinks_.push_back (l); }
    void addIncomingChildLink (BpLink* l)   { inChildLinks_.push_back (l); }
    void addOutcomingParentLink (BpLink* l) { outParentLinks_.push_back (l); }
    void addOutcomingChildLink (BpLink* l)  { outChildLinks_.push_back (l); }

  private:
    DISALLOW_COPY_AND_ASSIGN (BpNodeInfo);

    ParamSet          piVals_;      // pi values
    ParamSet          ldVals_;      // lambda values
    unsigned          nPiMsgsRcv_;  // number of pi messages received
    unsigned          nLdMsgsRcv_;  // number of lambda messages received
    bool              piValsCalc_;
    bool              ldValsCalc_;
    BpLinkSet         inParentLinks_;
    BpLinkSet         inChildLinks_;
    BpLinkSet         outParentLinks_;
    BpLinkSet         outChildLinks_;
    const BayesNode*  node_;
};
/**
 * Belief-propagation solver over a BayesNet.
 *
 * The public interface is the one inherited from Solver (runSolver and
 * the two query functions). Everything else is private machinery: the
 * list of links, the per-node BpNodeInfo table and, for residual-based
 * scheduling, a multiset of links ordered by descending residual plus a
 * map from each link to its position in that multiset.
 */
class BnBpSolver : public Solver
{
  public:
    BnBpSolver (const BayesNet&);
    ~BnBpSolver (void);

    void     runSolver (void);
    ParamSet getPosterioriOf (VarId);
    ParamSet getJointDistributionOf (const VarIdSet&);

  private:
    DISALLOW_COPY_AND_ASSIGN (BnBpSolver);

    void     initializeSolver (void);
    void     runPolyTreeSolver (void);
    void     runLoopySolver (void);
    void     maxResidualSchedule (void);
    bool     converged (void) const;
    void     updatePiValues (BayesNode*);
    void     updateLambdaValues (BayesNode*);
    void     calculateLambdaMessage (BpLink*);
    void     calculatePiMessage (BpLink*);
    ParamSet getJointByJunctionNode (const VarIdSet&);
    ParamSet getJointByChainRule (const VarIdSet&) const;
    void     printPiLambdaValues (const BayesNode*) const;
    void     printAllMessageStatus (void) const;

    // Shared core of calculateMessage / calculateAndUpdateMessage:
    // computes the next message of the link according to its orientation
    // and, when requested, refreshes the link's residual.
    void computeNextMessage (BpLink* link, bool calcResidual)
    {
      if (link->getOrientation() == LinkOrientation::DOWN) {
        calculatePiMessage (link);
      } else if (link->getOrientation() == LinkOrientation::UP) {
        calculateLambdaMessage (link);
      }
      if (calcResidual) {
        link->updateResidual();
      }
    }

    /**
     * Computes the next message of the link and publishes it right away.
     * @param link          link to process
     * @param calcResidual  when true the residual is refreshed before the
     *                      message buffers are swapped
     * Prints a trace line when the debug level DL is 3 or higher.
     */
    void calculateAndUpdateMessage (BpLink* link, bool calcResidual = true)
    {
      if (DL >= 3) {
        cout << "calculating & updating " << link->toString() << endl;
      }
      computeNextMessage (link, calcResidual);
      link->updateMessage();
    }

    /**
     * Computes the next message of the link without publishing it.
     * @param link          link to process
     * @param calcResidual  when true the link's residual is refreshed
     * Prints a trace line when the debug level DL is 3 or higher.
     */
    void calculateMessage (BpLink* link, bool calcResidual = true)
    {
      if (DL >= 3) {
        cout << "calculating " << link->toString() << endl;
      }
      computeNextMessage (link, calcResidual);
    }

    /** Publishes the previously computed next message of the link.
     *  Prints a trace line when the debug level DL is 3 or higher. */
    void updateMessage (BpLink* link)
    {
      if (DL >= 3) {
        cout << "updating " << link->toString() << endl;
      }
      link->updateMessage();
    }

    /** Refreshes the pi values (DOWN link) or lambda values (UP link) of
     *  the link's destination node. Destinations that carry evidence are
     *  left untouched. */
    void updateValues (BpLink* link)
    {
      if (link->getDestination()->hasEvidence()) {
        return;
      }
      if (link->getOrientation() == LinkOrientation::DOWN) {
        updatePiValues (link->getDestination());
      } else if (link->getOrientation() == LinkOrientation::UP) {
        updateLambdaValues (link->getDestination());
      }
    }

    /** Looks up the bookkeeping record of a node by the node's index.
     *  The assertions check that the node is non-null, is the very node
     *  registered in bayesNet_ under its variable id, and that its index
     *  lies inside nodesI_. */
    BpNodeInfo* ninf (const BayesNode* node) const
    {
      assert (node);
      assert (node == bayesNet_->getBayesNode (node->varId()));
      assert (node->getIndex() < nodesI_.size());
      return nodesI_[node->getIndex()];
    }

    const BayesNet*      bayesNet_;
    vector<BpLink*>      links_;
    vector<BpNodeInfo*>  nodesI_;
    unsigned             nIters_;
    JointCalcType        jointCalcType_;

    // Orders links by descending residual. The call operator is const
    // because associative containers invoke their comparator through a
    // const object.
    struct compare
    {
      inline bool operator() (const BpLink* e1, const BpLink* e2) const
      {
        return e1->getResidual() > e2->getResidual();
      }
    };

    typedef multiset<BpLink*, compare> SortedOrder;
    SortedOrder sortedOrder_;

    // Position of every link inside sortedOrder_.
    typedef unordered_map<BpLink*, SortedOrder::iterator> BpLinkMap;
    BpLinkMap linkMap_;
};
#endif // HORUS_BNBPSOLVER_H