Move BpLink to inside of BeliefProp

This commit is contained in:
Tiago Gomes 2013-02-20 23:59:03 +00:00
parent 6b0e125e3b
commit f0572e3cfb
4 changed files with 146 additions and 129 deletions

View File

@ -12,55 +12,6 @@
namespace Horus {
BpLink::BpLink (FacNode* fn, VarNode* vn)
{
fac_ = fn;
var_ = vn;
v1_.resize (vn->range(), LogAware::log (1.0 / vn->range()));
v2_.resize (vn->range(), LogAware::log (1.0 / vn->range()));
currMsg_ = &v1_;
nextMsg_ = &v2_;
residual_ = 0.0;
}
void
BpLink::clearResidual (void)
{
residual_ = 0.0;
}
void
BpLink::updateResidual (void)
{
residual_ = LogAware::getMaxNorm (v1_, v2_);
}
void
BpLink::updateMessage (void)
{
swap (currMsg_, nextMsg_);
}
std::string
BpLink::toString (void) const
{
std::stringstream ss;
ss << fac_->getLabel();
ss << " -- " ;
ss << var_->label();
return ss.str();
}
double BeliefProp::accuracy_ = 0.0001;
unsigned BeliefProp::maxIter_ = 1000;
@ -207,6 +158,55 @@ BeliefProp::getFactorJoint (
BeliefProp::BpLink::BpLink (FacNode* fn, VarNode* vn)
{
fac_ = fn;
var_ = vn;
v1_.resize (vn->range(), LogAware::log (1.0 / vn->range()));
v2_.resize (vn->range(), LogAware::log (1.0 / vn->range()));
currMsg_ = &v1_;
nextMsg_ = &v2_;
residual_ = 0.0;
}
void
BeliefProp::BpLink::clearResidual (void)
{
residual_ = 0.0;
}
void
BeliefProp::BpLink::updateResidual (void)
{
residual_ = LogAware::getMaxNorm (v1_, v2_);
}
void
BeliefProp::BpLink::updateMessage (void)
{
swap (currMsg_, nextMsg_);
}
std::string
BeliefProp::BpLink::toString (void) const
{
std::stringstream ss;
ss << fac_->getLabel();
ss << " -- " ;
ss << var_->label();
return ss.str();
}
void
BeliefProp::calculateAndUpdateMessage (BpLink* link, bool calcResidual)
{

View File

@ -11,47 +11,6 @@
namespace Horus {
class BpLink {
public:
BpLink (FacNode* fn, VarNode* vn);
virtual ~BpLink (void) { };
FacNode* facNode (void) const { return fac_; }
VarNode* varNode (void) const { return var_; }
const Params& message (void) const { return *currMsg_; }
Params& nextMessage (void) { return *nextMsg_; }
double residual (void) const { return residual_; }
void clearResidual (void);
void updateResidual (void);
virtual void updateMessage (void);
std::string toString (void) const;
protected:
FacNode* fac_;
VarNode* var_;
Params v1_;
Params v2_;
Params* currMsg_;
Params* nextMsg_;
double residual_;
private:
DISALLOW_COPY_AND_ASSIGN (BpLink);
};
typedef std::vector<BpLink*> BpLinks;
class BeliefProp : public GroundSolver {
private:
class SPNodeInfo;
@ -91,13 +50,51 @@ class BeliefProp : public GroundSolver {
static void setMsgSchedule (MsgSchedule sch) { schedule_ = sch; }
protected:
class BpLink {
public:
BpLink (FacNode* fn, VarNode* vn);
virtual ~BpLink (void) { };
FacNode* facNode (void) const { return fac_; }
VarNode* varNode (void) const { return var_; }
const Params& message (void) const { return *currMsg_; }
Params& nextMessage (void) { return *nextMsg_; }
double residual (void) const { return residual_; }
void clearResidual (void);
void updateResidual (void);
virtual void updateMessage (void);
std::string toString (void) const;
protected:
FacNode* fac_;
VarNode* var_;
Params v1_;
Params v2_;
Params* currMsg_;
Params* nextMsg_;
double residual_;
private:
DISALLOW_COPY_AND_ASSIGN (BpLink);
};
struct CmpResidual {
bool operator() (const BpLink* l1, const BpLink* l2) {
return l1->residual() > l2->residual();
}};
typedef std::multiset<BpLink*, CmpResidual> SortedOrder;
typedef std::unordered_map<BpLink*, SortedOrder::iterator> BpLinkMap;
typedef std::vector<BeliefProp::BpLink*> BpLinks;
typedef std::multiset<BpLink*, CmpResidual> SortedOrder;
typedef std::unordered_map<BpLink*, SortedOrder::iterator> BpLinkMap;
SPNodeInfo* ninf (const VarNode* var) const;
@ -138,7 +135,7 @@ class BeliefProp : public GroundSolver {
public:
SPNodeInfo (void) { }
void addBpLink (BpLink* link) { links_.push_back (link); }
void addBpLink (BeliefProp::BpLink* link) { links_.push_back (link); }
const BpLinks& getLinks (void) { return links_; }

View File

@ -8,6 +8,16 @@
namespace Horus {
WeightedBp::WeightedBp (
const FactorGraph& fg,
const std::vector<std::vector<unsigned>>& weights)
: BeliefProp (fg), weights_(weights)
{
}
WeightedBp::~WeightedBp (void)
{
for (size_t i = 0; i < links_.size(); i++) {
@ -53,6 +63,19 @@ WeightedBp::getPosterioriOf (VarId vid)
WeightedBp::WeightedLink::WeightedLink (
FacNode* fn,
VarNode* vn,
size_t idx,
unsigned weight)
: BpLink (fn, vn), index_(idx), weight_(weight),
pwdMsg_(vn->range(), LogAware::one())
{
}
void
WeightedBp::createLinks (void)
{

View File

@ -6,51 +6,37 @@
namespace Horus {
class WeightedLink : public BpLink {
public:
WeightedLink (FacNode* fn, VarNode* vn, size_t idx, unsigned weight)
: BpLink (fn, vn), index_(idx), weight_(weight),
pwdMsg_(vn->range(), LogAware::one()) { }
size_t index (void) const { return index_; }
unsigned weight (void) const { return weight_; }
const Params& powMessage (void) const { return pwdMsg_; }
void updateMessage (void);
private:
DISALLOW_COPY_AND_ASSIGN (WeightedLink);
size_t index_;
unsigned weight_;
Params pwdMsg_;
};
inline void
WeightedLink::updateMessage (void)
{
pwdMsg_ = *nextMsg_;
swap (currMsg_, nextMsg_);
LogAware::pow (pwdMsg_, weight_);
}
class WeightedBp : public BeliefProp {
public:
WeightedBp (const FactorGraph& fg,
const std::vector<std::vector<unsigned>>& weights)
: BeliefProp (fg), weights_(weights) { }
const std::vector<std::vector<unsigned>>& weights);
~WeightedBp (void);
Params getPosterioriOf (VarId);
private:
class WeightedLink : public BeliefProp::BpLink {
public:
WeightedLink (FacNode* fn, VarNode* vn, size_t idx,
unsigned weight);
size_t index (void) const { return index_; }
unsigned weight (void) const { return weight_; }
const Params& powMessage (void) const { return pwdMsg_; }
void updateMessage (void);
private:
DISALLOW_COPY_AND_ASSIGN (WeightedLink);
size_t index_;
unsigned weight_;
Params pwdMsg_;
};
void createLinks (void);
void maxResidualSchedule (void);
@ -66,6 +52,17 @@ class WeightedBp : public BeliefProp {
DISALLOW_COPY_AND_ASSIGN (WeightedBp);
};
inline void
WeightedBp::WeightedLink::updateMessage (void)
{
pwdMsg_ = *nextMsg_;
swap (currMsg_, nextMsg_);
LogAware::pow (pwdMsg_, weight_);
}
} // namespace Horus
#endif // YAP_PACKAGES_CLPBN_HORUS_WEIGHTEDBP_H_