This repository has been archived on 2023-08-20. You can view files and clone it, but cannot push or open issues or pull requests.
yap-6.3/packages/CLPBN/horus/BeliefProp.h

178 lines
3.7 KiB
C
Raw Normal View History

2013-02-07 17:50:02 +00:00
#ifndef YAP_PACKAGES_CLPBN_HORUS_BELIEFPROP_H_
#define YAP_PACKAGES_CLPBN_HORUS_BELIEFPROP_H_
2012-05-23 14:56:01 +01:00
#include <set>
#include <string>
#include <unordered_map>
#include <vector>
2012-05-23 14:56:01 +01:00
2012-11-14 21:55:51 +00:00
#include "GroundSolver.h"
2012-05-23 14:56:01 +01:00
#include "FactorGraph.h"
2012-12-27 12:54:58 +00:00
2012-05-23 14:56:01 +01:00
namespace Horus {
2013-02-07 23:53:13 +00:00
// Message-passing schedules selectable through
// BeliefProp::setMsgSchedule(). Kept as an unscoped enum so existing code
// can keep using the bare enumerator names.
enum MsgSchedule {
  seqFixedSch    = 0,  // presumably: sequential updates in a fixed order
  seqRandomSch   = 1,  // presumably: sequential updates in a random order
  parallelSch    = 2,  // presumably: all messages updated in parallel
  maxResidualSch = 3   // residual-driven schedule (see maxResidualSchedule)
};
class BpLink {
2012-05-23 14:56:01 +01:00
public:
BpLink (FacNode* fn, VarNode* vn);
2012-05-23 14:56:01 +01:00
virtual ~BpLink (void) { };
2012-05-23 14:56:01 +01:00
2012-05-28 19:41:24 +01:00
FacNode* facNode (void) const { return fac_; }
2012-05-23 14:56:01 +01:00
2012-05-28 19:41:24 +01:00
VarNode* varNode (void) const { return var_; }
2012-05-23 14:56:01 +01:00
2012-05-28 19:41:24 +01:00
const Params& message (void) const { return *currMsg_; }
2012-05-23 14:56:01 +01:00
2012-05-28 19:41:24 +01:00
Params& nextMessage (void) { return *nextMsg_; }
2012-05-23 14:56:01 +01:00
2012-05-28 19:41:24 +01:00
double residual (void) const { return residual_; }
2012-05-23 14:56:01 +01:00
void clearResidual (void);
2012-05-23 14:56:01 +01:00
void updateResidual (void);
2012-05-23 14:56:01 +01:00
virtual void updateMessage (void);
2012-05-23 14:56:01 +01:00
2013-02-07 13:37:15 +00:00
std::string toString (void) const;
2012-12-17 18:39:42 +00:00
2012-05-23 14:56:01 +01:00
protected:
FacNode* fac_;
VarNode* var_;
Params v1_;
Params v2_;
Params* currMsg_;
Params* nextMsg_;
double residual_;
2012-12-27 22:25:45 +00:00
private:
DISALLOW_COPY_AND_ASSIGN (BpLink);
2012-05-23 14:56:01 +01:00
};
2013-02-07 13:37:15 +00:00
// Sequence of non-owning pointers to belief-propagation links.
using BpLinks = std::vector<BpLink*>;
2012-05-23 14:56:01 +01:00
class SPNodeInfo {
2012-05-23 14:56:01 +01:00
public:
2012-12-27 22:25:45 +00:00
SPNodeInfo (void) { }
void addBpLink (BpLink* link) { links_.push_back (link); }
const BpLinks& getLinks (void) { return links_; }
2012-05-23 14:56:01 +01:00
private:
BpLinks links_;
2012-12-27 22:25:45 +00:00
DISALLOW_COPY_AND_ASSIGN (SPNodeInfo);
2012-05-23 14:56:01 +01:00
};
class BeliefProp : public GroundSolver {
2012-05-23 14:56:01 +01:00
public:
BeliefProp (const FactorGraph&);
2012-05-23 14:56:01 +01:00
virtual ~BeliefProp (void);
2012-05-23 14:56:01 +01:00
Params solveQuery (VarIds);
virtual void printSolverFlags (void) const;
virtual Params getPosterioriOf (VarId);
virtual Params getJointDistributionOf (const VarIds&);
2012-12-17 18:39:42 +00:00
2012-12-27 15:05:40 +00:00
Params getFactorJoint (FacNode* fn, const VarIds&);
static double accuracy (void) { return accuracy_; }
static void setAccuracy (double acc) { accuracy_ = acc; }
static unsigned maxIterations (void) { return maxIter_; }
static void setMaxIterations (unsigned mi) { maxIter_ = mi; }
static MsgSchedule msgSchedule (void) { return schedule_; }
static void setMsgSchedule (MsgSchedule sch) { schedule_ = sch; }
protected:
struct CmpResidual {
bool operator() (const BpLink* l1, const BpLink* l2) {
return l1->residual() > l2->residual();
}};
typedef std::multiset<BpLink*, CmpResidual> SortedOrder;
typedef std::unordered_map<BpLink*, SortedOrder::iterator> BpLinkMap;
SPNodeInfo* ninf (const VarNode* var) const;
2012-05-23 14:56:01 +01:00
SPNodeInfo* ninf (const FacNode* fac) const;
2012-05-23 14:56:01 +01:00
void calculateAndUpdateMessage (BpLink* link, bool calcResidual = true);
2012-05-23 14:56:01 +01:00
void calculateMessage (BpLink* link, bool calcResidual = true);
2012-05-23 14:56:01 +01:00
void updateMessage (BpLink* link);
2012-05-23 14:56:01 +01:00
2012-12-27 15:05:40 +00:00
void runSolver (void);
virtual void createLinks (void);
virtual void maxResidualSchedule (void);
virtual void calcFactorToVarMsg (BpLink*);
virtual Params getVarToFactorMsg (const BpLink*) const;
virtual Params getJointByConditioning (const VarIds&) const;
2013-02-07 13:37:15 +00:00
BpLinks links_;
unsigned nIters_;
std::vector<SPNodeInfo*> varsI_;
std::vector<SPNodeInfo*> facsI_;
bool runned_;
SortedOrder sortedOrder_;
BpLinkMap linkMap_;
2012-05-23 14:56:01 +01:00
static double accuracy_;
static unsigned maxIter_;
static MsgSchedule schedule_;
2012-05-23 14:56:01 +01:00
private:
void initializeSolver (void);
bool converged (void);
virtual void printLinkInformation (void) const;
2012-12-27 22:25:45 +00:00
DISALLOW_COPY_AND_ASSIGN (BeliefProp);
2012-05-23 14:56:01 +01:00
};
// Returns the bookkeeping entry for a variable node. varsI_ is indexed by
// the node's getIndex(); no bounds check is performed.
inline SPNodeInfo*
BeliefProp::ninf (const VarNode* var) const
{
  const auto slot = var->getIndex();
  return varsI_[slot];
}
// Returns the bookkeeping entry for a factor node. facsI_ is indexed by
// the node's getIndex(); no bounds check is performed.
inline SPNodeInfo*
BeliefProp::ninf (const FacNode* fac) const
{
  const auto slot = fac->getIndex();
  return facsI_[slot];
}
} // namespace Horus
2013-02-07 23:53:13 +00:00
2013-02-08 00:20:01 +00:00
#endif // YAP_PACKAGES_CLPBN_HORUS_BELIEFPROP_H_
2012-05-23 14:56:01 +01:00