Move BpLink to inside of BeliefProp
This commit is contained in:
parent
6b0e125e3b
commit
f0572e3cfb
@ -12,55 +12,6 @@
|
|||||||
|
|
||||||
namespace Horus {
|
namespace Horus {
|
||||||
|
|
||||||
BpLink::BpLink (FacNode* fn, VarNode* vn)
|
|
||||||
{
|
|
||||||
fac_ = fn;
|
|
||||||
var_ = vn;
|
|
||||||
v1_.resize (vn->range(), LogAware::log (1.0 / vn->range()));
|
|
||||||
v2_.resize (vn->range(), LogAware::log (1.0 / vn->range()));
|
|
||||||
currMsg_ = &v1_;
|
|
||||||
nextMsg_ = &v2_;
|
|
||||||
residual_ = 0.0;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
BpLink::clearResidual (void)
|
|
||||||
{
|
|
||||||
residual_ = 0.0;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
BpLink::updateResidual (void)
|
|
||||||
{
|
|
||||||
residual_ = LogAware::getMaxNorm (v1_, v2_);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
BpLink::updateMessage (void)
|
|
||||||
{
|
|
||||||
swap (currMsg_, nextMsg_);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
std::string
|
|
||||||
BpLink::toString (void) const
|
|
||||||
{
|
|
||||||
std::stringstream ss;
|
|
||||||
ss << fac_->getLabel();
|
|
||||||
ss << " -- " ;
|
|
||||||
ss << var_->label();
|
|
||||||
return ss.str();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
double BeliefProp::accuracy_ = 0.0001;
|
double BeliefProp::accuracy_ = 0.0001;
|
||||||
unsigned BeliefProp::maxIter_ = 1000;
|
unsigned BeliefProp::maxIter_ = 1000;
|
||||||
|
|
||||||
@ -207,6 +158,55 @@ BeliefProp::getFactorJoint (
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
BeliefProp::BpLink::BpLink (FacNode* fn, VarNode* vn)
|
||||||
|
{
|
||||||
|
fac_ = fn;
|
||||||
|
var_ = vn;
|
||||||
|
v1_.resize (vn->range(), LogAware::log (1.0 / vn->range()));
|
||||||
|
v2_.resize (vn->range(), LogAware::log (1.0 / vn->range()));
|
||||||
|
currMsg_ = &v1_;
|
||||||
|
nextMsg_ = &v2_;
|
||||||
|
residual_ = 0.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
BeliefProp::BpLink::clearResidual (void)
|
||||||
|
{
|
||||||
|
residual_ = 0.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
BeliefProp::BpLink::updateResidual (void)
|
||||||
|
{
|
||||||
|
residual_ = LogAware::getMaxNorm (v1_, v2_);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
BeliefProp::BpLink::updateMessage (void)
|
||||||
|
{
|
||||||
|
swap (currMsg_, nextMsg_);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
std::string
|
||||||
|
BeliefProp::BpLink::toString (void) const
|
||||||
|
{
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << fac_->getLabel();
|
||||||
|
ss << " -- " ;
|
||||||
|
ss << var_->label();
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
BeliefProp::calculateAndUpdateMessage (BpLink* link, bool calcResidual)
|
BeliefProp::calculateAndUpdateMessage (BpLink* link, bool calcResidual)
|
||||||
{
|
{
|
||||||
|
@ -11,47 +11,6 @@
|
|||||||
|
|
||||||
namespace Horus {
|
namespace Horus {
|
||||||
|
|
||||||
|
|
||||||
class BpLink {
|
|
||||||
public:
|
|
||||||
BpLink (FacNode* fn, VarNode* vn);
|
|
||||||
|
|
||||||
virtual ~BpLink (void) { };
|
|
||||||
|
|
||||||
FacNode* facNode (void) const { return fac_; }
|
|
||||||
|
|
||||||
VarNode* varNode (void) const { return var_; }
|
|
||||||
|
|
||||||
const Params& message (void) const { return *currMsg_; }
|
|
||||||
|
|
||||||
Params& nextMessage (void) { return *nextMsg_; }
|
|
||||||
|
|
||||||
double residual (void) const { return residual_; }
|
|
||||||
|
|
||||||
void clearResidual (void);
|
|
||||||
|
|
||||||
void updateResidual (void);
|
|
||||||
|
|
||||||
virtual void updateMessage (void);
|
|
||||||
|
|
||||||
std::string toString (void) const;
|
|
||||||
|
|
||||||
protected:
|
|
||||||
FacNode* fac_;
|
|
||||||
VarNode* var_;
|
|
||||||
Params v1_;
|
|
||||||
Params v2_;
|
|
||||||
Params* currMsg_;
|
|
||||||
Params* nextMsg_;
|
|
||||||
double residual_;
|
|
||||||
|
|
||||||
private:
|
|
||||||
DISALLOW_COPY_AND_ASSIGN (BpLink);
|
|
||||||
};
|
|
||||||
|
|
||||||
typedef std::vector<BpLink*> BpLinks;
|
|
||||||
|
|
||||||
|
|
||||||
class BeliefProp : public GroundSolver {
|
class BeliefProp : public GroundSolver {
|
||||||
private:
|
private:
|
||||||
class SPNodeInfo;
|
class SPNodeInfo;
|
||||||
@ -91,13 +50,51 @@ class BeliefProp : public GroundSolver {
|
|||||||
static void setMsgSchedule (MsgSchedule sch) { schedule_ = sch; }
|
static void setMsgSchedule (MsgSchedule sch) { schedule_ = sch; }
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
class BpLink {
|
||||||
|
public:
|
||||||
|
BpLink (FacNode* fn, VarNode* vn);
|
||||||
|
|
||||||
|
virtual ~BpLink (void) { };
|
||||||
|
|
||||||
|
FacNode* facNode (void) const { return fac_; }
|
||||||
|
|
||||||
|
VarNode* varNode (void) const { return var_; }
|
||||||
|
|
||||||
|
const Params& message (void) const { return *currMsg_; }
|
||||||
|
|
||||||
|
Params& nextMessage (void) { return *nextMsg_; }
|
||||||
|
|
||||||
|
double residual (void) const { return residual_; }
|
||||||
|
|
||||||
|
void clearResidual (void);
|
||||||
|
|
||||||
|
void updateResidual (void);
|
||||||
|
|
||||||
|
virtual void updateMessage (void);
|
||||||
|
|
||||||
|
std::string toString (void) const;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
FacNode* fac_;
|
||||||
|
VarNode* var_;
|
||||||
|
Params v1_;
|
||||||
|
Params v2_;
|
||||||
|
Params* currMsg_;
|
||||||
|
Params* nextMsg_;
|
||||||
|
double residual_;
|
||||||
|
|
||||||
|
private:
|
||||||
|
DISALLOW_COPY_AND_ASSIGN (BpLink);
|
||||||
|
};
|
||||||
|
|
||||||
struct CmpResidual {
|
struct CmpResidual {
|
||||||
bool operator() (const BpLink* l1, const BpLink* l2) {
|
bool operator() (const BpLink* l1, const BpLink* l2) {
|
||||||
return l1->residual() > l2->residual();
|
return l1->residual() > l2->residual();
|
||||||
}};
|
}};
|
||||||
|
|
||||||
typedef std::multiset<BpLink*, CmpResidual> SortedOrder;
|
typedef std::vector<BeliefProp::BpLink*> BpLinks;
|
||||||
typedef std::unordered_map<BpLink*, SortedOrder::iterator> BpLinkMap;
|
typedef std::multiset<BpLink*, CmpResidual> SortedOrder;
|
||||||
|
typedef std::unordered_map<BpLink*, SortedOrder::iterator> BpLinkMap;
|
||||||
|
|
||||||
SPNodeInfo* ninf (const VarNode* var) const;
|
SPNodeInfo* ninf (const VarNode* var) const;
|
||||||
|
|
||||||
@ -138,7 +135,7 @@ class BeliefProp : public GroundSolver {
|
|||||||
public:
|
public:
|
||||||
SPNodeInfo (void) { }
|
SPNodeInfo (void) { }
|
||||||
|
|
||||||
void addBpLink (BpLink* link) { links_.push_back (link); }
|
void addBpLink (BeliefProp::BpLink* link) { links_.push_back (link); }
|
||||||
|
|
||||||
const BpLinks& getLinks (void) { return links_; }
|
const BpLinks& getLinks (void) { return links_; }
|
||||||
|
|
||||||
|
@ -8,6 +8,16 @@
|
|||||||
|
|
||||||
namespace Horus {
|
namespace Horus {
|
||||||
|
|
||||||
|
WeightedBp::WeightedBp (
|
||||||
|
const FactorGraph& fg,
|
||||||
|
const std::vector<std::vector<unsigned>>& weights)
|
||||||
|
: BeliefProp (fg), weights_(weights)
|
||||||
|
{
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
WeightedBp::~WeightedBp (void)
|
WeightedBp::~WeightedBp (void)
|
||||||
{
|
{
|
||||||
for (size_t i = 0; i < links_.size(); i++) {
|
for (size_t i = 0; i < links_.size(); i++) {
|
||||||
@ -53,6 +63,19 @@ WeightedBp::getPosterioriOf (VarId vid)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
WeightedBp::WeightedLink::WeightedLink (
|
||||||
|
FacNode* fn,
|
||||||
|
VarNode* vn,
|
||||||
|
size_t idx,
|
||||||
|
unsigned weight)
|
||||||
|
: BpLink (fn, vn), index_(idx), weight_(weight),
|
||||||
|
pwdMsg_(vn->range(), LogAware::one())
|
||||||
|
{
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
WeightedBp::createLinks (void)
|
WeightedBp::createLinks (void)
|
||||||
{
|
{
|
||||||
|
@ -6,51 +6,37 @@
|
|||||||
|
|
||||||
namespace Horus {
|
namespace Horus {
|
||||||
|
|
||||||
class WeightedLink : public BpLink {
|
|
||||||
public:
|
|
||||||
WeightedLink (FacNode* fn, VarNode* vn, size_t idx, unsigned weight)
|
|
||||||
: BpLink (fn, vn), index_(idx), weight_(weight),
|
|
||||||
pwdMsg_(vn->range(), LogAware::one()) { }
|
|
||||||
|
|
||||||
size_t index (void) const { return index_; }
|
|
||||||
|
|
||||||
unsigned weight (void) const { return weight_; }
|
|
||||||
|
|
||||||
const Params& powMessage (void) const { return pwdMsg_; }
|
|
||||||
|
|
||||||
void updateMessage (void);
|
|
||||||
|
|
||||||
private:
|
|
||||||
DISALLOW_COPY_AND_ASSIGN (WeightedLink);
|
|
||||||
|
|
||||||
size_t index_;
|
|
||||||
unsigned weight_;
|
|
||||||
Params pwdMsg_;
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
inline void
|
|
||||||
WeightedLink::updateMessage (void)
|
|
||||||
{
|
|
||||||
pwdMsg_ = *nextMsg_;
|
|
||||||
swap (currMsg_, nextMsg_);
|
|
||||||
LogAware::pow (pwdMsg_, weight_);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class WeightedBp : public BeliefProp {
|
class WeightedBp : public BeliefProp {
|
||||||
public:
|
public:
|
||||||
WeightedBp (const FactorGraph& fg,
|
WeightedBp (const FactorGraph& fg,
|
||||||
const std::vector<std::vector<unsigned>>& weights)
|
const std::vector<std::vector<unsigned>>& weights);
|
||||||
: BeliefProp (fg), weights_(weights) { }
|
|
||||||
|
|
||||||
~WeightedBp (void);
|
~WeightedBp (void);
|
||||||
|
|
||||||
Params getPosterioriOf (VarId);
|
Params getPosterioriOf (VarId);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
class WeightedLink : public BeliefProp::BpLink {
|
||||||
|
public:
|
||||||
|
WeightedLink (FacNode* fn, VarNode* vn, size_t idx,
|
||||||
|
unsigned weight);
|
||||||
|
|
||||||
|
size_t index (void) const { return index_; }
|
||||||
|
|
||||||
|
unsigned weight (void) const { return weight_; }
|
||||||
|
|
||||||
|
const Params& powMessage (void) const { return pwdMsg_; }
|
||||||
|
|
||||||
|
void updateMessage (void);
|
||||||
|
|
||||||
|
private:
|
||||||
|
DISALLOW_COPY_AND_ASSIGN (WeightedLink);
|
||||||
|
|
||||||
|
size_t index_;
|
||||||
|
unsigned weight_;
|
||||||
|
Params pwdMsg_;
|
||||||
|
};
|
||||||
|
|
||||||
void createLinks (void);
|
void createLinks (void);
|
||||||
|
|
||||||
void maxResidualSchedule (void);
|
void maxResidualSchedule (void);
|
||||||
@ -66,6 +52,17 @@ class WeightedBp : public BeliefProp {
|
|||||||
DISALLOW_COPY_AND_ASSIGN (WeightedBp);
|
DISALLOW_COPY_AND_ASSIGN (WeightedBp);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
inline void
|
||||||
|
WeightedBp::WeightedLink::updateMessage (void)
|
||||||
|
{
|
||||||
|
pwdMsg_ = *nextMsg_;
|
||||||
|
swap (currMsg_, nextMsg_);
|
||||||
|
LogAware::pow (pwdMsg_, weight_);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace Horus
|
} // namespace Horus
|
||||||
|
|
||||||
#endif // YAP_PACKAGES_CLPBN_HORUS_WEIGHTEDBP_H_
|
#endif // YAP_PACKAGES_CLPBN_HORUS_WEIGHTEDBP_H_
|
||||||
|
Reference in New Issue
Block a user