This repository has been archived on 2023-08-20. You can view files and clone it, but cannot push or open issues or pull requests.
yap-6.3/packages/CLPBN/clpbn/bp/SPSolver.h
2011-07-22 21:33:30 +01:00

131 lines
3.0 KiB
C++
Executable File

#ifndef BP_SP_SOLVER_H
#define BP_SP_SOLVER_H
#include <vector>
#include <set>
#include "Solver.h"
#include "FgVarNode.h"
#include "Factor.h"
using namespace std;
class FactorGraph;
class SPSolver;
// One edge of the factor graph: it pairs a Factor with an FgVarNode and
// carries the message state for that pair (current message, pending message,
// residual, and a flag telling whether updateMessage() has run).
// The two pointers are stored as given; nothing in this class deletes them.
class Link
{
  public:
    // Builds the link between factor `f` and variable `v`.
    // Both message containers are sized to v->getDomainSize() entries, each
    // set to 1. `v` is dereferenced here, so it must not be null.
    Link (Factor* f, FgVarNode* v)
        : factor_(f),
          var_(v),
          msgSended_(false),
          residual_(0.0)
    {
      currMsg_.resize (v->getDomainSize(), 1);
      nextMsg_.resize (v->getDomainSize(), 1);
    }

    // Normalizes `msg` (a private copy, taken by value) and installs it as
    // the current message. The residual is computed against the previous
    // current message before that one is overwritten.
    void setMessage (ParamSet msg)
    {
      Util::normalize (msg);
      residual_ = Util::getMaxNorm (currMsg_, msg);
      currMsg_  = msg;
    }

    // Stores a normalized copy of `msg` as the pending message and sets the
    // residual to Util::getMaxNorm of the current and pending messages.
    // The current message is left untouched until updateMessage() is called.
    void setNextMessage (CParamSet msg)
    {
      nextMsg_ = msg;
      Util::normalize (nextMsg_);
      residual_ = Util::getMaxNorm (currMsg_, nextMsg_);
    }

    // Promotes the pending message to current and records that a message
    // has been sent. The residual is not modified here.
    void updateMessage (void)
    {
      currMsg_   = nextMsg_;
      msgSended_ = true;
    }

    // Returns "<factor label> -- <variable label>".
    string toString (void) const
    {
      stringstream label;
      label << factor_->getLabel() << " -- " << var_->getLabel();
      return label.str();
    }

    // Factor endpoint of this link.
    Factor* getFactor (void) const
    {
      return factor_;
    }

    // Variable endpoint of this link.
    FgVarNode* getVariable (void) const
    {
      return var_;
    }

    // Current message (not the pending one).
    CParamSet getMessage (void) const
    {
      return currMsg_;
    }

    // True once updateMessage() has been called at least once.
    bool messageWasSended (void) const
    {
      return msgSended_;
    }

    // Residual stored by the last setMessage()/setNextMessage() call,
    // or 0.0 if neither ran since construction or clearResidual().
    double getResidual (void) const
    {
      return residual_;
    }

    // Resets the residual to 0.0.
    void clearResidual (void)
    {
      residual_ = 0.0;
    }

  private:
    Factor*     factor_;     // factor endpoint (not deleted by this class)
    FgVarNode*  var_;        // variable endpoint (not deleted by this class)
    ParamSet    currMsg_;    // message currently in effect
    ParamSet    nextMsg_;    // message staged by setNextMessage()
    bool        msgSended_;  // set by updateMessage()
    double      residual_;   // max-norm distance computed by the setters
};
// Per-node bookkeeping: the collection of links attached to one node.
// Only the Link pointers are stored; nothing in this class deletes them.
class SPNodeInfo
{
  public:
    // Appends `link` to this node's collection of links.
    void addLink (Link* link) { links_.push_back (link); }

    // Returns the links added through addLink().
    // Declared const because it only reads state, so it can be called
    // through a const SPNodeInfo (matching the const getters of Link).
    // NOTE(review): assumes CLinkSet is a const reference or value type,
    // as its name suggests - confirm against its typedef.
    CLinkSet getLinks (void) const { return links_; }

  private:
    LinkSet links_;
};
// Solver that operates on a FactorGraph by exchanging messages over Link
// objects. Only the declarations live here; see the implementation file for
// the behaviour of each member function.
class SPSolver : public Solver
{
  public:
    SPSolver (FactorGraph&);
    virtual ~SPSolver (void);

    // Entry point that runs the solver.
    void runSolver (void);
    // Returns a ParamSet for the variable identified by the given Vid.
    virtual ParamSet getPosterioriOf (Vid) const;
    // Returns a ParamSet for the given set of variable ids.
    ParamSet getJointDistributionOf (CVidSet);

  protected:
    virtual void initializeSolver (void);
    void runTreeSolver (void);
    bool readyToSendMessage (const Link*) const;
    virtual void createLinks (void);
    virtual void deleteJunction (Factor*, FgVarNode*);
    bool converged (void);
    virtual void maxResidualSchedule (void);
    virtual ParamSet getFactor2VarMsg (const Link*) const;
    virtual ParamSet getVar2FactorMsg (const Link*) const;

    // Strict ordering of links by descending residual, so that begin() of a
    // container sorted with it refers to a link with the largest residual.
    struct CompareResidual {
      // Must be const-qualified: associative containers require their
      // comparison object to be callable as const (LWG 2542), and C++17
      // standard libraries reject a non-const call operator at compile time.
      bool operator() (const Link* link1, const Link* link2) const
      {
        return link1->getResidual() > link2->getResidual();
      }
    };

    FactorGraph*         fg_;
    LinkSet              links_;
    vector<SPNodeInfo*>  varsI_;
    vector<SPNodeInfo*>  factorsI_;
    unsigned             nIter_;

    // Links kept sorted by CompareResidual (largest residual first).
    // NOTE(review): the sort key is the link's mutable residual, so a link
    // presumably has to be erased and re-inserted whenever its residual
    // changes - confirm in maxResidualSchedule().
    typedef multiset<Link*, CompareResidual> SortedOrder;
    SortedOrder sortedOrder_;

    // Maps a link to an iterator into a SortedOrder container
    // (presumably its own position in sortedOrder_ - confirm in the
    // implementation).
    typedef map<Link*, SortedOrder::iterator> LinkMap;
    LinkMap linkMap_;
};
#endif // BP_SP_SOLVER_H