#ifndef BP_SP_SOLVER_H #define BP_SP_SOLVER_H #include #include #include "Solver.h" #include "FgVarNode.h" #include "Factor.h" using namespace std; class FactorGraph; class SPSolver; class Link { public: Link (Factor* f, FgVarNode* v) { factor_ = f; var_ = v; currMsg_.resize (v->getDomainSize(), 1); nextMsg_.resize (v->getDomainSize(), 1); msgSended_ = false; residual_ = 0.0; } void setMessage (ParamSet msg) { Util::normalize (msg); residual_ = Util::getMaxNorm (currMsg_, msg); currMsg_ = msg; } void setNextMessage (CParamSet msg) { nextMsg_ = msg; Util::normalize (nextMsg_); residual_ = Util::getMaxNorm (currMsg_, nextMsg_); } void updateMessage (void) { currMsg_ = nextMsg_; msgSended_ = true; } string toString (void) const { stringstream ss; ss << factor_->getLabel(); ss << " -- " ; ss << var_->getLabel(); return ss.str(); } Factor* getFactor (void) const { return factor_; } FgVarNode* getVariable (void) const { return var_; } CParamSet getMessage (void) const { return currMsg_; } bool messageWasSended (void) const { return msgSended_; } double getResidual (void) const { return residual_; } void clearResidual (void) { residual_ = 0.0; } private: Factor* factor_; FgVarNode* var_; ParamSet currMsg_; ParamSet nextMsg_; bool msgSended_; double residual_; }; class SPNodeInfo { public: void addLink (Link* link) { links_.push_back (link); } CLinkSet getLinks (void) { return links_; } private: LinkSet links_; }; class SPSolver : public Solver { public: SPSolver (FactorGraph&); virtual ~SPSolver (void); void runSolver (void); virtual ParamSet getPosterioriOf (Vid) const; ParamSet getJointDistributionOf (CVidSet); protected: virtual void initializeSolver (void); void runTreeSolver (void); bool readyToSendMessage (const Link*) const; virtual void createLinks (void); virtual void deleteJunction (Factor*, FgVarNode*); bool converged (void); virtual void maxResidualSchedule (void); virtual ParamSet getFactor2VarMsg (const Link*) const; virtual ParamSet getVar2FactorMsg (const Link*) const; struct CompareResidual { inline bool operator() (const Link* link1, const Link* link2) { return link1->getResidual() > link2->getResidual(); } }; FactorGraph* fg_; LinkSet links_; vector varsI_; vector factorsI_; unsigned nIter_; typedef multiset SortedOrder; SortedOrder sortedOrder_; typedef map LinkMap; LinkMap linkMap_; }; #endif // BP_SP_SOLVER_H