131 lines
3.0 KiB
C++
131 lines
3.0 KiB
C++
#ifndef BP_SP_SOLVER_H
|
|
#define BP_SP_SOLVER_H
|
|
|
|
#include <vector>
|
|
#include <set>
|
|
|
|
#include "Solver.h"
|
|
#include "FgVarNode.h"
|
|
#include "Factor.h"
|
|
|
|
using namespace std;
|
|
|
|
class FactorGraph;
|
|
class SPSolver;
|
|
|
|
|
|
class Link
|
|
{
|
|
public:
|
|
Link (Factor* f, FgVarNode* v)
|
|
{
|
|
factor_ = f;
|
|
var_ = v;
|
|
currMsg_.resize (v->getDomainSize(), 1);
|
|
nextMsg_.resize (v->getDomainSize(), 1);
|
|
msgSended_ = false;
|
|
residual_ = 0.0;
|
|
}
|
|
|
|
void setMessage (ParamSet msg)
|
|
{
|
|
Util::normalize (msg);
|
|
residual_ = Util::getMaxNorm (currMsg_, msg);
|
|
currMsg_ = msg;
|
|
}
|
|
|
|
void setNextMessage (CParamSet msg)
|
|
{
|
|
nextMsg_ = msg;
|
|
Util::normalize (nextMsg_);
|
|
residual_ = Util::getMaxNorm (currMsg_, nextMsg_);
|
|
}
|
|
|
|
void updateMessage (void)
|
|
{
|
|
currMsg_ = nextMsg_;
|
|
msgSended_ = true;
|
|
}
|
|
|
|
string toString (void) const
|
|
{
|
|
stringstream ss;
|
|
ss << factor_->getLabel();
|
|
ss << " -- " ;
|
|
ss << var_->getLabel();
|
|
return ss.str();
|
|
}
|
|
|
|
Factor* getFactor (void) const { return factor_; }
|
|
FgVarNode* getVariable (void) const { return var_; }
|
|
CParamSet getMessage (void) const { return currMsg_; }
|
|
bool messageWasSended (void) const { return msgSended_; }
|
|
double getResidual (void) const { return residual_; }
|
|
void clearResidual (void) { residual_ = 0.0; }
|
|
|
|
private:
|
|
Factor* factor_;
|
|
FgVarNode* var_;
|
|
ParamSet currMsg_;
|
|
ParamSet nextMsg_;
|
|
bool msgSended_;
|
|
double residual_;
|
|
};
|
|
|
|
|
|
class SPNodeInfo
|
|
{
|
|
public:
|
|
void addLink (Link* link) { links_.push_back (link); }
|
|
CLinkSet getLinks (void) { return links_; }
|
|
|
|
private:
|
|
LinkSet links_;
|
|
};
|
|
|
|
|
|
class SPSolver : public Solver
|
|
{
|
|
public:
|
|
SPSolver (FactorGraph&);
|
|
virtual ~SPSolver (void);
|
|
|
|
void runSolver (void);
|
|
virtual ParamSet getPosterioriOf (Vid) const;
|
|
ParamSet getJointDistributionOf (CVidSet);
|
|
|
|
protected:
|
|
virtual void initializeSolver (void);
|
|
void runTreeSolver (void);
|
|
bool readyToSendMessage (const Link*) const;
|
|
virtual void createLinks (void);
|
|
virtual void deleteJunction (Factor*, FgVarNode*);
|
|
bool converged (void);
|
|
virtual void maxResidualSchedule (void);
|
|
virtual ParamSet getFactor2VarMsg (const Link*) const;
|
|
virtual ParamSet getVar2FactorMsg (const Link*) const;
|
|
|
|
struct CompareResidual {
|
|
inline bool operator() (const Link* link1, const Link* link2)
|
|
{
|
|
return link1->getResidual() > link2->getResidual();
|
|
}
|
|
};
|
|
|
|
FactorGraph* fg_;
|
|
LinkSet links_;
|
|
vector<SPNodeInfo*> varsI_;
|
|
vector<SPNodeInfo*> factorsI_;
|
|
unsigned nIter_;
|
|
|
|
typedef multiset<Link*, CompareResidual> SortedOrder;
|
|
SortedOrder sortedOrder_;
|
|
|
|
typedef map<Link*, SortedOrder::iterator> LinkMap;
|
|
LinkMap linkMap_;
|
|
|
|
};
|
|
|
|
#endif // BP_SP_SOLVER_H
|
|
|