#ifndef HORUS_FGBPSOLVER_H #define HORUS_FGBPSOLVER_H #include #include #include #include "Solver.h" #include "Factor.h" #include "FactorGraph.h" using namespace std; class SpLink { public: SpLink (FgFacNode* fn, FgVarNode* vn) { fac_ = fn; var_ = vn; v1_.resize (vn->nrStates(), Util::tl (1.0 / vn->nrStates())); v2_.resize (vn->nrStates(), Util::tl (1.0 / vn->nrStates())); currMsg_ = &v1_; nextMsg_ = &v2_; msgSended_ = false; residual_ = 0.0; } virtual ~SpLink (void) {}; virtual void updateMessage (void) { swap (currMsg_, nextMsg_); msgSended_ = true; } void updateResidual (void) { residual_ = Util::getMaxNorm (v1_, v2_); } string toString (void) const { stringstream ss; ss << fac_->getLabel(); ss << " -- " ; ss << var_->label(); return ss.str(); } FgFacNode* getFactor (void) const { return fac_; } FgVarNode* getVariable (void) const { return var_; } const ParamSet& getMessage (void) const { return *currMsg_; } ParamSet& getNextMessage (void) { return *nextMsg_; } bool messageWasSended (void) const { return msgSended_; } double getResidual (void) const { return residual_; } void clearResidual (void) { residual_ = 0.0; } protected: FgFacNode* fac_; FgVarNode* var_; ParamSet v1_; ParamSet v2_; ParamSet* currMsg_; ParamSet* nextMsg_; bool msgSended_; double residual_; }; typedef vector SpLinkSet; class SPNodeInfo { public: void addSpLink (SpLink* link) { links_.push_back (link); } const SpLinkSet& getLinks (void) { return links_; } private: SpLinkSet links_; }; class FgBpSolver : public Solver { public: FgBpSolver (const FactorGraph&); virtual ~FgBpSolver (void); void runSolver (void); virtual ParamSet getPosterioriOf (VarId); virtual ParamSet getJointDistributionOf (const VarIdSet&); protected: virtual void initializeSolver (void); virtual void createLinks (void); virtual void maxResidualSchedule (void); virtual void calculateFactor2VariableMsg (SpLink*) const; virtual ParamSet getVar2FactorMsg (const SpLink*) const; virtual void printLinkInformation (void) const; void calculateAndUpdateMessage (SpLink* link, bool calcResidual = true) { if (DL >= 3) { cout << "calculating & updating " << link->toString() << endl; } calculateFactor2VariableMsg (link); if (calcResidual) { link->updateResidual(); } link->updateMessage(); } void calculateMessage (SpLink* link, bool calcResidual = true) { if (DL >= 3) { cout << "calculating " << link->toString() << endl; } calculateFactor2VariableMsg (link); if (calcResidual) { link->updateResidual(); } } void updateMessage (SpLink* link) { link->updateMessage(); if (DL >= 3) { cout << "updating " << link->toString() << endl; } } SPNodeInfo* ninf (const FgVarNode* var) const { return varsI_[var->getIndex()]; } SPNodeInfo* ninf (const FgFacNode* fac) const { return facsI_[fac->getIndex()]; } struct CompareResidual { inline bool operator() (const SpLink* link1, const SpLink* link2) { return link1->getResidual() > link2->getResidual(); } }; SpLinkSet links_; unsigned nIters_; vector varsI_; vector facsI_; const FactorGraph* factorGraph_; typedef multiset SortedOrder; SortedOrder sortedOrder_; typedef unordered_map SpLinkMap; SpLinkMap linkMap_; private: void runTreeSolver (void); bool readyToSendMessage (const SpLink*) const; void runLoopySolver (void); bool converged (void); }; #endif // HORUS_FGBPSOLVER_H