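// Capture notes (from an archived-repository web history/blame view):
//   This repository was archived on 2023-08-20. Its files can be viewed and
//   cloned, but pushes, issues and pull requests are disabled.
//   Path: yap-6.3/packages/CLPBN/horus/BeliefProp.h
//   Size: 189 lines, 3.9 KiB (the viewer labelled the language as "C").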

#ifndef HORUS_BELIEFPROP_H
#define HORUS_BELIEFPROP_H
2012-05-23 14:56:01 +01:00
#include <set>
#include <vector>
#include <sstream>
2012-11-14 21:55:51 +00:00
#include "GroundSolver.h"
2012-05-23 14:56:01 +01:00
#include "Factor.h"
#include "FactorGraph.h"
#include "Util.h"
using namespace std;
class BpLink
2012-05-23 14:56:01 +01:00
{
public:
BpLink (FacNode* fn, VarNode* vn)
2012-05-23 14:56:01 +01:00
{
fac_ = fn;
var_ = vn;
2012-05-24 16:14:13 +01:00
v1_.resize (vn->range(), LogAware::log (1.0 / vn->range()));
v2_.resize (vn->range(), LogAware::log (1.0 / vn->range()));
2012-05-23 14:56:01 +01:00
currMsg_ = &v1_;
nextMsg_ = &v2_;
residual_ = 0.0;
}
virtual ~BpLink (void) { };
2012-05-23 14:56:01 +01:00
2012-05-28 19:41:24 +01:00
FacNode* facNode (void) const { return fac_; }
2012-05-23 14:56:01 +01:00
2012-05-28 19:41:24 +01:00
VarNode* varNode (void) const { return var_; }
2012-05-23 14:56:01 +01:00
2012-05-28 19:41:24 +01:00
const Params& message (void) const { return *currMsg_; }
2012-05-23 14:56:01 +01:00
2012-05-28 19:41:24 +01:00
Params& nextMessage (void) { return *nextMsg_; }
2012-05-23 14:56:01 +01:00
2012-05-28 19:41:24 +01:00
double residual (void) const { return residual_; }
2012-05-23 14:56:01 +01:00
void clearResidual (void) { residual_ = 0.0; }
void updateResidual (void)
{
residual_ = LogAware::getMaxNorm (v1_,v2_);
}
virtual void updateMessage (void)
{
swap (currMsg_, nextMsg_);
}
string toString (void) const
{
stringstream ss;
ss << fac_->getLabel();
ss << " -- " ;
ss << var_->label();
return ss.str();
}
2012-12-17 18:39:42 +00:00
2012-05-23 14:56:01 +01:00
protected:
FacNode* fac_;
VarNode* var_;
Params v1_;
Params v2_;
Params* currMsg_;
Params* nextMsg_;
double residual_;
};
typedef vector<BpLink*> BpLinks;
2012-05-23 14:56:01 +01:00
class SPNodeInfo
{
public:
void addBpLink (BpLink* link) { links_.push_back (link); }
const BpLinks& getLinks (void) { return links_; }
2012-05-23 14:56:01 +01:00
private:
BpLinks links_;
2012-05-23 14:56:01 +01:00
};
2012-11-14 21:55:51 +00:00
class BeliefProp : public GroundSolver
2012-05-23 14:56:01 +01:00
{
public:
BeliefProp (const FactorGraph&);
2012-05-23 14:56:01 +01:00
virtual ~BeliefProp (void);
2012-05-23 14:56:01 +01:00
Params solveQuery (VarIds);
virtual void printSolverFlags (void) const;
virtual Params getPosterioriOf (VarId);
virtual Params getJointDistributionOf (const VarIds&);
2012-12-17 18:39:42 +00:00
2012-05-23 14:56:01 +01:00
protected:
void runSolver (void);
virtual void createLinks (void);
virtual void maxResidualSchedule (void);
virtual void calcFactorToVarMsg (BpLink*);
2012-05-23 14:56:01 +01:00
virtual Params getVarToFactorMsg (const BpLink*) const;
2012-05-23 14:56:01 +01:00
virtual Params getJointByConditioning (const VarIds&) const;
public:
2012-09-11 18:48:16 +01:00
Params getFactorJoint (FacNode* fn, const VarIds&);
protected:
2012-05-23 14:56:01 +01:00
SPNodeInfo* ninf (const VarNode* var) const
{
return varsI_[var->getIndex()];
}
SPNodeInfo* ninf (const FacNode* fac) const
{
return facsI_[fac->getIndex()];
}
void calculateAndUpdateMessage (BpLink* link, bool calcResidual = true)
2012-05-23 14:56:01 +01:00
{
if (Globals::verbosity > 2) {
cout << "calculating & updating " << link->toString() << endl;
}
2012-05-28 19:41:24 +01:00
calcFactorToVarMsg (link);
2012-05-23 14:56:01 +01:00
if (calcResidual) {
link->updateResidual();
}
link->updateMessage();
}
void calculateMessage (BpLink* link, bool calcResidual = true)
2012-05-23 14:56:01 +01:00
{
if (Globals::verbosity > 2) {
cout << "calculating " << link->toString() << endl;
}
2012-05-28 19:41:24 +01:00
calcFactorToVarMsg (link);
2012-05-23 14:56:01 +01:00
if (calcResidual) {
link->updateResidual();
}
}
void updateMessage (BpLink* link)
2012-05-23 14:56:01 +01:00
{
link->updateMessage();
if (Globals::verbosity > 2) {
cout << "updating " << link->toString() << endl;
}
}
struct CompareResidual
{
inline bool operator() (const BpLink* link1, const BpLink* link2)
2012-05-23 14:56:01 +01:00
{
2012-05-28 19:41:24 +01:00
return link1->residual() > link2->residual();
2012-05-23 14:56:01 +01:00
}
};
2012-05-31 21:24:40 +01:00
BpLinks links_;
2012-05-23 14:56:01 +01:00
unsigned nIters_;
vector<SPNodeInfo*> varsI_;
vector<SPNodeInfo*> facsI_;
bool runned_;
typedef multiset<BpLink*, CompareResidual> SortedOrder;
2012-05-23 14:56:01 +01:00
SortedOrder sortedOrder_;
typedef unordered_map<BpLink*, SortedOrder::iterator> BpLinkMap;
BpLinkMap linkMap_;
2012-05-23 14:56:01 +01:00
private:
void initializeSolver (void);
bool converged (void);
virtual void printLinkInformation (void) const;
};
#endif // HORUS_BELIEFPROP_H
2012-05-23 14:56:01 +01:00