// Archived listing of yap-6.3/packages/CLPBN/clpbn/bp/FgBpSolver.h
// (187 lines, 4.0 KiB). The source repository was archived on 2023-08-20:
// it can still be viewed and cloned, but no longer accepts pushes, issues,
// or pull requests.
#ifndef HORUS_FGBPSOLVER_H
#define HORUS_FGBPSOLVER_H
#include <iostream>
#include <set>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "Solver.h"
#include "Factor.h"
#include "FactorGraph.h"
#include "Util.h"

using namespace std;
class SpLink
{
public:
SpLink (FgFacNode* fn, FgVarNode* vn)
{
fac_ = fn;
var_ = vn;
2012-04-05 18:38:56 +01:00
v1_.resize (vn->range(), LogAware::tl (1.0 / vn->range()));
v2_.resize (vn->range(), LogAware::tl (1.0 / vn->range()));
2011-12-12 15:29:51 +00:00
currMsg_ = &v1_;
nextMsg_ = &v2_;
msgSended_ = false;
residual_ = 0.0;
}
2012-03-31 23:27:37 +01:00
virtual ~SpLink (void) { };
2011-12-12 15:29:51 +00:00
2012-03-31 23:27:37 +01:00
FgFacNode* getFactor (void) const { return fac_; }
FgVarNode* getVariable (void) const { return var_; }
const Params& getMessage (void) const { return *currMsg_; }
Params& getNextMessage (void) { return *nextMsg_; }
bool messageWasSended (void) const { return msgSended_; }
double getResidual (void) const { return residual_; }
void clearResidual (void) { residual_ = 0.0; }
void updateResidual (void)
2011-12-12 15:29:51 +00:00
{
2012-03-31 23:27:37 +01:00
residual_ = LogAware::getMaxNorm (v1_,v2_);
2011-12-12 15:29:51 +00:00
}
2012-03-31 23:27:37 +01:00
virtual void updateMessage (void)
2011-12-12 15:29:51 +00:00
{
2012-03-31 23:27:37 +01:00
swap (currMsg_, nextMsg_);
msgSended_ = true;
2011-12-12 15:29:51 +00:00
}
string toString (void) const
{
stringstream ss;
ss << fac_->getLabel();
ss << " -- " ;
ss << var_->label();
return ss.str();
}
protected:
2012-03-31 23:27:37 +01:00
FgFacNode* fac_;
FgVarNode* var_;
Params v1_;
Params v2_;
Params* currMsg_;
Params* nextMsg_;
bool msgSended_;
double residual_;
2011-12-12 15:29:51 +00:00
};
typedef vector<SpLink*> SpLinkSet;
class SPNodeInfo
{
public:
2012-03-31 23:27:37 +01:00
void addSpLink (SpLink* link) { links_.push_back (link); }
const SpLinkSet& getLinks (void) { return links_; }
2011-12-12 15:29:51 +00:00
private:
2012-03-31 23:27:37 +01:00
SpLinkSet links_;
2011-12-12 15:29:51 +00:00
};
class FgBpSolver : public Solver
{
public:
FgBpSolver (const FactorGraph&);
2012-03-31 23:27:37 +01:00
2011-12-12 15:29:51 +00:00
virtual ~FgBpSolver (void);
2012-03-31 23:27:37 +01:00
void runSolver (void);
virtual Params getPosterioriOf (VarId);
virtual Params getJointDistributionOf (const VarIds&);
2011-12-12 15:29:51 +00:00
protected:
2012-03-31 23:27:37 +01:00
virtual void initializeSolver (void);
virtual void createLinks (void);
virtual void maxResidualSchedule (void);
virtual void calculateFactor2VariableMsg (SpLink*) const;
virtual Params getVar2FactorMsg (const SpLink*) const;
virtual Params getJointByConditioning (const VarIds&) const;
virtual void printLinkInformation (void) const;
SPNodeInfo* ninf (const FgVarNode* var) const
{
return varsI_[var->getIndex()];
}
SPNodeInfo* ninf (const FgFacNode* fac) const
{
return facsI_[fac->getIndex()];
}
2011-12-12 15:29:51 +00:00
void calculateAndUpdateMessage (SpLink* link, bool calcResidual = true)
{
2012-03-31 23:27:37 +01:00
if (Constants::DEBUG >= 3) {
2011-12-12 15:29:51 +00:00
cout << "calculating & updating " << link->toString() << endl;
}
calculateFactor2VariableMsg (link);
if (calcResidual) {
link->updateResidual();
}
link->updateMessage();
}
void calculateMessage (SpLink* link, bool calcResidual = true)
{
2012-03-31 23:27:37 +01:00
if (Constants::DEBUG >= 3) {
2011-12-12 15:29:51 +00:00
cout << "calculating " << link->toString() << endl;
}
calculateFactor2VariableMsg (link);
if (calcResidual) {
link->updateResidual();
}
}
void updateMessage (SpLink* link)
{
link->updateMessage();
2012-03-31 23:27:37 +01:00
if (Constants::DEBUG >= 3) {
2011-12-12 15:29:51 +00:00
cout << "updating " << link->toString() << endl;
}
}
2012-03-31 23:27:37 +01:00
struct CompareResidual
2011-12-12 15:29:51 +00:00
{
inline bool operator() (const SpLink* link1, const SpLink* link2)
{
return link1->getResidual() > link2->getResidual();
}
};
SpLinkSet links_;
unsigned nIters_;
vector<SPNodeInfo*> varsI_;
vector<SPNodeInfo*> facsI_;
const FactorGraph* factorGraph_;
typedef multiset<SpLink*, CompareResidual> SortedOrder;
SortedOrder sortedOrder_;
typedef unordered_map<SpLink*, SortedOrder::iterator> SpLinkMap;
SpLinkMap linkMap_;
private:
2012-03-31 23:27:37 +01:00
void runLoopySolver (void);
bool converged (void);
2011-12-12 15:29:51 +00:00
};
#endif // HORUS_FGBPSOLVER_H