2012-06-12 16:29:57 +01:00
|
|
|
#ifndef HORUS_BELIEFPROP_H
|
|
|
|
#define HORUS_BELIEFPROP_H
|
2012-05-23 14:56:01 +01:00
|
|
|
|
|
|
|
#include <set>
|
|
|
|
#include <vector>
|
2012-12-27 12:54:58 +00:00
|
|
|
|
2012-05-23 14:56:01 +01:00
|
|
|
#include <sstream>
|
|
|
|
|
2012-11-14 21:55:51 +00:00
|
|
|
#include "GroundSolver.h"
|
2012-05-23 14:56:01 +01:00
|
|
|
#include "FactorGraph.h"
|
2012-12-27 12:54:58 +00:00
|
|
|
|
2012-05-23 14:56:01 +01:00
|
|
|
|
|
|
|
using namespace std;
|
|
|
|
|
|
|
|
|
2012-12-27 15:00:30 +00:00
|
|
|
/**
 * Message-passing schedule used by BeliefProp (see
 * BeliefProp::setMsgSchedule). The enumerators are given explicit values
 * that match their implicit 0..3 numbering.
 *
 * NOTE(review): from the names, presumably SEQ_FIXED / SEQ_RANDOM are
 * sequential updates in a fixed / random order, PARALLEL updates all
 * messages at once, and MAX_RESIDUAL updates the largest-residual link
 * first; the actual semantics live in the .cpp implementation.
 */
enum MsgSchedule {
  SEQ_FIXED    = 0,
  SEQ_RANDOM   = 1,
  PARALLEL     = 2,
  MAX_RESIDUAL = 3
};
|
|
|
|
|
|
|
|
|
2012-05-31 21:12:46 +01:00
|
|
|
class BpLink
|
2012-05-23 14:56:01 +01:00
|
|
|
{
|
|
|
|
public:
|
2012-05-31 21:12:46 +01:00
|
|
|
BpLink (FacNode* fn, VarNode* vn)
|
2012-12-20 23:19:10 +00:00
|
|
|
{
|
2012-05-23 14:56:01 +01:00
|
|
|
fac_ = fn;
|
|
|
|
var_ = vn;
|
2012-05-24 16:14:13 +01:00
|
|
|
v1_.resize (vn->range(), LogAware::log (1.0 / vn->range()));
|
|
|
|
v2_.resize (vn->range(), LogAware::log (1.0 / vn->range()));
|
2012-05-23 14:56:01 +01:00
|
|
|
currMsg_ = &v1_;
|
|
|
|
nextMsg_ = &v2_;
|
|
|
|
residual_ = 0.0;
|
|
|
|
}
|
|
|
|
|
2012-05-31 21:12:46 +01:00
|
|
|
virtual ~BpLink (void) { };
|
2012-05-23 14:56:01 +01:00
|
|
|
|
2012-05-28 19:41:24 +01:00
|
|
|
FacNode* facNode (void) const { return fac_; }
|
2012-05-23 14:56:01 +01:00
|
|
|
|
2012-05-28 19:41:24 +01:00
|
|
|
VarNode* varNode (void) const { return var_; }
|
2012-05-23 14:56:01 +01:00
|
|
|
|
2012-05-28 19:41:24 +01:00
|
|
|
const Params& message (void) const { return *currMsg_; }
|
2012-05-23 14:56:01 +01:00
|
|
|
|
2012-05-28 19:41:24 +01:00
|
|
|
Params& nextMessage (void) { return *nextMsg_; }
|
2012-05-23 14:56:01 +01:00
|
|
|
|
2012-05-28 19:41:24 +01:00
|
|
|
double residual (void) const { return residual_; }
|
2012-05-23 14:56:01 +01:00
|
|
|
|
|
|
|
void clearResidual (void) { residual_ = 0.0; }
|
|
|
|
|
|
|
|
void updateResidual (void)
|
|
|
|
{
|
2013-01-05 12:04:43 +00:00
|
|
|
residual_ = LogAware::getMaxNorm (v1_, v2_);
|
2012-05-23 14:56:01 +01:00
|
|
|
}
|
|
|
|
|
2012-12-20 23:19:10 +00:00
|
|
|
virtual void updateMessage (void)
|
2012-05-23 14:56:01 +01:00
|
|
|
{
|
|
|
|
swap (currMsg_, nextMsg_);
|
|
|
|
}
|
|
|
|
|
|
|
|
string toString (void) const
|
|
|
|
{
|
|
|
|
stringstream ss;
|
|
|
|
ss << fac_->getLabel();
|
|
|
|
ss << " -- " ;
|
|
|
|
ss << var_->label();
|
|
|
|
return ss.str();
|
|
|
|
}
|
2012-12-17 18:39:42 +00:00
|
|
|
|
2012-05-23 14:56:01 +01:00
|
|
|
protected:
|
|
|
|
FacNode* fac_;
|
|
|
|
VarNode* var_;
|
|
|
|
Params v1_;
|
|
|
|
Params v2_;
|
|
|
|
Params* currMsg_;
|
|
|
|
Params* nextMsg_;
|
|
|
|
double residual_;
|
2012-12-27 22:25:45 +00:00
|
|
|
|
|
|
|
private:
|
|
|
|
DISALLOW_COPY_AND_ASSIGN (BpLink);
|
2012-05-23 14:56:01 +01:00
|
|
|
};
|
|
|
|
|
2012-05-31 21:12:46 +01:00
|
|
|
typedef vector<BpLink*> BpLinks;
|
2012-05-23 14:56:01 +01:00
|
|
|
|
|
|
|
|
|
|
|
class SPNodeInfo
|
|
|
|
{
|
|
|
|
public:
|
2012-12-27 22:25:45 +00:00
|
|
|
SPNodeInfo (void) { }
|
2012-05-31 21:12:46 +01:00
|
|
|
void addBpLink (BpLink* link) { links_.push_back (link); }
|
|
|
|
const BpLinks& getLinks (void) { return links_; }
|
2012-05-23 14:56:01 +01:00
|
|
|
private:
|
2012-05-31 21:12:46 +01:00
|
|
|
BpLinks links_;
|
2012-12-27 22:25:45 +00:00
|
|
|
DISALLOW_COPY_AND_ASSIGN (SPNodeInfo);
|
2012-05-23 14:56:01 +01:00
|
|
|
};
|
|
|
|
|
|
|
|
|
2012-11-14 21:55:51 +00:00
|
|
|
class BeliefProp : public GroundSolver
|
2012-05-23 14:56:01 +01:00
|
|
|
{
|
|
|
|
public:
|
2012-06-12 16:29:57 +01:00
|
|
|
BeliefProp (const FactorGraph&);
|
2012-05-23 14:56:01 +01:00
|
|
|
|
2012-06-12 16:29:57 +01:00
|
|
|
virtual ~BeliefProp (void);
|
2012-05-23 14:56:01 +01:00
|
|
|
|
|
|
|
Params solveQuery (VarIds);
|
|
|
|
|
|
|
|
virtual void printSolverFlags (void) const;
|
|
|
|
|
|
|
|
virtual Params getPosterioriOf (VarId);
|
|
|
|
|
|
|
|
virtual Params getJointDistributionOf (const VarIds&);
|
2012-12-17 18:39:42 +00:00
|
|
|
|
2012-12-27 15:05:40 +00:00
|
|
|
Params getFactorJoint (FacNode* fn, const VarIds&);
|
|
|
|
|
2012-12-27 15:44:40 +00:00
|
|
|
static double accuracy (void) { return accuracy_; }
|
|
|
|
|
|
|
|
static void setAccuracy (double acc) { accuracy_ = acc; }
|
|
|
|
|
|
|
|
static unsigned maxIterations (void) { return maxIter_; }
|
|
|
|
|
|
|
|
static void setMaxIterations (unsigned mi) { maxIter_ = mi; }
|
|
|
|
|
|
|
|
static MsgSchedule msgSchedule (void) { return schedule_; }
|
|
|
|
|
|
|
|
static void setMsgSchedule (MsgSchedule sch) { schedule_ = sch; }
|
2012-12-27 15:00:30 +00:00
|
|
|
|
2012-06-13 12:47:41 +01:00
|
|
|
protected:
|
2012-05-23 14:56:01 +01:00
|
|
|
SPNodeInfo* ninf (const VarNode* var) const
|
|
|
|
{
|
|
|
|
return varsI_[var->getIndex()];
|
|
|
|
}
|
|
|
|
|
|
|
|
SPNodeInfo* ninf (const FacNode* fac) const
|
|
|
|
{
|
|
|
|
return facsI_[fac->getIndex()];
|
|
|
|
}
|
|
|
|
|
2012-05-31 21:12:46 +01:00
|
|
|
void calculateAndUpdateMessage (BpLink* link, bool calcResidual = true)
|
2012-05-23 14:56:01 +01:00
|
|
|
{
|
|
|
|
if (Globals::verbosity > 2) {
|
|
|
|
cout << "calculating & updating " << link->toString() << endl;
|
|
|
|
}
|
2012-05-28 19:41:24 +01:00
|
|
|
calcFactorToVarMsg (link);
|
2012-05-23 14:56:01 +01:00
|
|
|
if (calcResidual) {
|
|
|
|
link->updateResidual();
|
|
|
|
}
|
|
|
|
link->updateMessage();
|
|
|
|
}
|
|
|
|
|
2012-05-31 21:12:46 +01:00
|
|
|
void calculateMessage (BpLink* link, bool calcResidual = true)
|
2012-05-23 14:56:01 +01:00
|
|
|
{
|
|
|
|
if (Globals::verbosity > 2) {
|
|
|
|
cout << "calculating " << link->toString() << endl;
|
|
|
|
}
|
2012-05-28 19:41:24 +01:00
|
|
|
calcFactorToVarMsg (link);
|
2012-05-23 14:56:01 +01:00
|
|
|
if (calcResidual) {
|
|
|
|
link->updateResidual();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2012-05-31 21:12:46 +01:00
|
|
|
void updateMessage (BpLink* link)
|
2012-05-23 14:56:01 +01:00
|
|
|
{
|
|
|
|
link->updateMessage();
|
|
|
|
if (Globals::verbosity > 2) {
|
|
|
|
cout << "updating " << link->toString() << endl;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
struct CompareResidual
|
|
|
|
{
|
2012-05-31 21:12:46 +01:00
|
|
|
inline bool operator() (const BpLink* link1, const BpLink* link2)
|
2012-05-23 14:56:01 +01:00
|
|
|
{
|
2012-05-28 19:41:24 +01:00
|
|
|
return link1->residual() > link2->residual();
|
2012-05-23 14:56:01 +01:00
|
|
|
}
|
|
|
|
};
|
|
|
|
|
2012-12-27 15:05:40 +00:00
|
|
|
void runSolver (void);
|
|
|
|
|
|
|
|
virtual void createLinks (void);
|
|
|
|
|
|
|
|
virtual void maxResidualSchedule (void);
|
|
|
|
|
|
|
|
virtual void calcFactorToVarMsg (BpLink*);
|
|
|
|
|
|
|
|
virtual Params getVarToFactorMsg (const BpLink*) const;
|
|
|
|
|
|
|
|
virtual Params getJointByConditioning (const VarIds&) const;
|
|
|
|
|
2012-05-31 21:24:40 +01:00
|
|
|
BpLinks links_;
|
2012-05-23 14:56:01 +01:00
|
|
|
unsigned nIters_;
|
|
|
|
vector<SPNodeInfo*> varsI_;
|
|
|
|
vector<SPNodeInfo*> facsI_;
|
|
|
|
bool runned_;
|
|
|
|
|
2012-05-31 21:12:46 +01:00
|
|
|
typedef multiset<BpLink*, CompareResidual> SortedOrder;
|
2012-05-23 14:56:01 +01:00
|
|
|
SortedOrder sortedOrder_;
|
|
|
|
|
2012-05-31 21:12:46 +01:00
|
|
|
typedef unordered_map<BpLink*, SortedOrder::iterator> BpLinkMap;
|
|
|
|
BpLinkMap linkMap_;
|
2012-05-23 14:56:01 +01:00
|
|
|
|
2012-12-27 15:44:40 +00:00
|
|
|
static double accuracy_;
|
|
|
|
static unsigned maxIter_;
|
|
|
|
static MsgSchedule schedule_;
|
|
|
|
|
2012-05-23 14:56:01 +01:00
|
|
|
private:
|
|
|
|
void initializeSolver (void);
|
|
|
|
|
|
|
|
bool converged (void);
|
|
|
|
|
|
|
|
virtual void printLinkInformation (void) const;
|
2012-12-27 22:25:45 +00:00
|
|
|
|
|
|
|
DISALLOW_COPY_AND_ASSIGN (BeliefProp);
|
2012-05-23 14:56:01 +01:00
|
|
|
};
|
|
|
|
|
2012-06-12 16:29:57 +01:00
|
|
|
#endif // HORUS_BELIEFPROP_H
|
2012-05-23 14:56:01 +01:00
|
|
|
|