Move BpLink to inside of BeliefProp
This commit is contained in:
parent
6b0e125e3b
commit
f0572e3cfb
@ -12,55 +12,6 @@
|
||||
|
||||
namespace Horus {
|
||||
|
||||
BpLink::BpLink (FacNode* fn, VarNode* vn)
|
||||
{
|
||||
fac_ = fn;
|
||||
var_ = vn;
|
||||
v1_.resize (vn->range(), LogAware::log (1.0 / vn->range()));
|
||||
v2_.resize (vn->range(), LogAware::log (1.0 / vn->range()));
|
||||
currMsg_ = &v1_;
|
||||
nextMsg_ = &v2_;
|
||||
residual_ = 0.0;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BpLink::clearResidual (void)
|
||||
{
|
||||
residual_ = 0.0;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BpLink::updateResidual (void)
|
||||
{
|
||||
residual_ = LogAware::getMaxNorm (v1_, v2_);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BpLink::updateMessage (void)
|
||||
{
|
||||
swap (currMsg_, nextMsg_);
|
||||
}
|
||||
|
||||
|
||||
|
||||
std::string
|
||||
BpLink::toString (void) const
|
||||
{
|
||||
std::stringstream ss;
|
||||
ss << fac_->getLabel();
|
||||
ss << " -- " ;
|
||||
ss << var_->label();
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
|
||||
|
||||
double BeliefProp::accuracy_ = 0.0001;
|
||||
unsigned BeliefProp::maxIter_ = 1000;
|
||||
|
||||
@ -207,6 +158,55 @@ BeliefProp::getFactorJoint (
|
||||
|
||||
|
||||
|
||||
BeliefProp::BpLink::BpLink (FacNode* fn, VarNode* vn)
|
||||
{
|
||||
fac_ = fn;
|
||||
var_ = vn;
|
||||
v1_.resize (vn->range(), LogAware::log (1.0 / vn->range()));
|
||||
v2_.resize (vn->range(), LogAware::log (1.0 / vn->range()));
|
||||
currMsg_ = &v1_;
|
||||
nextMsg_ = &v2_;
|
||||
residual_ = 0.0;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BeliefProp::BpLink::clearResidual (void)
|
||||
{
|
||||
residual_ = 0.0;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BeliefProp::BpLink::updateResidual (void)
|
||||
{
|
||||
residual_ = LogAware::getMaxNorm (v1_, v2_);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BeliefProp::BpLink::updateMessage (void)
|
||||
{
|
||||
swap (currMsg_, nextMsg_);
|
||||
}
|
||||
|
||||
|
||||
|
||||
std::string
|
||||
BeliefProp::BpLink::toString (void) const
|
||||
{
|
||||
std::stringstream ss;
|
||||
ss << fac_->getLabel();
|
||||
ss << " -- " ;
|
||||
ss << var_->label();
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BeliefProp::calculateAndUpdateMessage (BpLink* link, bool calcResidual)
|
||||
{
|
||||
|
@ -11,47 +11,6 @@
|
||||
|
||||
namespace Horus {
|
||||
|
||||
|
||||
class BpLink {
|
||||
public:
|
||||
BpLink (FacNode* fn, VarNode* vn);
|
||||
|
||||
virtual ~BpLink (void) { };
|
||||
|
||||
FacNode* facNode (void) const { return fac_; }
|
||||
|
||||
VarNode* varNode (void) const { return var_; }
|
||||
|
||||
const Params& message (void) const { return *currMsg_; }
|
||||
|
||||
Params& nextMessage (void) { return *nextMsg_; }
|
||||
|
||||
double residual (void) const { return residual_; }
|
||||
|
||||
void clearResidual (void);
|
||||
|
||||
void updateResidual (void);
|
||||
|
||||
virtual void updateMessage (void);
|
||||
|
||||
std::string toString (void) const;
|
||||
|
||||
protected:
|
||||
FacNode* fac_;
|
||||
VarNode* var_;
|
||||
Params v1_;
|
||||
Params v2_;
|
||||
Params* currMsg_;
|
||||
Params* nextMsg_;
|
||||
double residual_;
|
||||
|
||||
private:
|
||||
DISALLOW_COPY_AND_ASSIGN (BpLink);
|
||||
};
|
||||
|
||||
typedef std::vector<BpLink*> BpLinks;
|
||||
|
||||
|
||||
class BeliefProp : public GroundSolver {
|
||||
private:
|
||||
class SPNodeInfo;
|
||||
@ -91,13 +50,51 @@ class BeliefProp : public GroundSolver {
|
||||
static void setMsgSchedule (MsgSchedule sch) { schedule_ = sch; }
|
||||
|
||||
protected:
|
||||
class BpLink {
|
||||
public:
|
||||
BpLink (FacNode* fn, VarNode* vn);
|
||||
|
||||
virtual ~BpLink (void) { };
|
||||
|
||||
FacNode* facNode (void) const { return fac_; }
|
||||
|
||||
VarNode* varNode (void) const { return var_; }
|
||||
|
||||
const Params& message (void) const { return *currMsg_; }
|
||||
|
||||
Params& nextMessage (void) { return *nextMsg_; }
|
||||
|
||||
double residual (void) const { return residual_; }
|
||||
|
||||
void clearResidual (void);
|
||||
|
||||
void updateResidual (void);
|
||||
|
||||
virtual void updateMessage (void);
|
||||
|
||||
std::string toString (void) const;
|
||||
|
||||
protected:
|
||||
FacNode* fac_;
|
||||
VarNode* var_;
|
||||
Params v1_;
|
||||
Params v2_;
|
||||
Params* currMsg_;
|
||||
Params* nextMsg_;
|
||||
double residual_;
|
||||
|
||||
private:
|
||||
DISALLOW_COPY_AND_ASSIGN (BpLink);
|
||||
};
|
||||
|
||||
struct CmpResidual {
|
||||
bool operator() (const BpLink* l1, const BpLink* l2) {
|
||||
return l1->residual() > l2->residual();
|
||||
}};
|
||||
|
||||
typedef std::multiset<BpLink*, CmpResidual> SortedOrder;
|
||||
typedef std::unordered_map<BpLink*, SortedOrder::iterator> BpLinkMap;
|
||||
typedef std::vector<BeliefProp::BpLink*> BpLinks;
|
||||
typedef std::multiset<BpLink*, CmpResidual> SortedOrder;
|
||||
typedef std::unordered_map<BpLink*, SortedOrder::iterator> BpLinkMap;
|
||||
|
||||
SPNodeInfo* ninf (const VarNode* var) const;
|
||||
|
||||
@ -138,7 +135,7 @@ class BeliefProp : public GroundSolver {
|
||||
public:
|
||||
SPNodeInfo (void) { }
|
||||
|
||||
void addBpLink (BpLink* link) { links_.push_back (link); }
|
||||
void addBpLink (BeliefProp::BpLink* link) { links_.push_back (link); }
|
||||
|
||||
const BpLinks& getLinks (void) { return links_; }
|
||||
|
||||
|
@ -8,6 +8,16 @@
|
||||
|
||||
namespace Horus {
|
||||
|
||||
WeightedBp::WeightedBp (
|
||||
const FactorGraph& fg,
|
||||
const std::vector<std::vector<unsigned>>& weights)
|
||||
: BeliefProp (fg), weights_(weights)
|
||||
{
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
WeightedBp::~WeightedBp (void)
|
||||
{
|
||||
for (size_t i = 0; i < links_.size(); i++) {
|
||||
@ -53,6 +63,19 @@ WeightedBp::getPosterioriOf (VarId vid)
|
||||
|
||||
|
||||
|
||||
WeightedBp::WeightedLink::WeightedLink (
|
||||
FacNode* fn,
|
||||
VarNode* vn,
|
||||
size_t idx,
|
||||
unsigned weight)
|
||||
: BpLink (fn, vn), index_(idx), weight_(weight),
|
||||
pwdMsg_(vn->range(), LogAware::one())
|
||||
{
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
WeightedBp::createLinks (void)
|
||||
{
|
||||
|
@ -6,51 +6,37 @@
|
||||
|
||||
namespace Horus {
|
||||
|
||||
class WeightedLink : public BpLink {
|
||||
public:
|
||||
WeightedLink (FacNode* fn, VarNode* vn, size_t idx, unsigned weight)
|
||||
: BpLink (fn, vn), index_(idx), weight_(weight),
|
||||
pwdMsg_(vn->range(), LogAware::one()) { }
|
||||
|
||||
size_t index (void) const { return index_; }
|
||||
|
||||
unsigned weight (void) const { return weight_; }
|
||||
|
||||
const Params& powMessage (void) const { return pwdMsg_; }
|
||||
|
||||
void updateMessage (void);
|
||||
|
||||
private:
|
||||
DISALLOW_COPY_AND_ASSIGN (WeightedLink);
|
||||
|
||||
size_t index_;
|
||||
unsigned weight_;
|
||||
Params pwdMsg_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
inline void
|
||||
WeightedLink::updateMessage (void)
|
||||
{
|
||||
pwdMsg_ = *nextMsg_;
|
||||
swap (currMsg_, nextMsg_);
|
||||
LogAware::pow (pwdMsg_, weight_);
|
||||
}
|
||||
|
||||
|
||||
|
||||
class WeightedBp : public BeliefProp {
|
||||
public:
|
||||
WeightedBp (const FactorGraph& fg,
|
||||
const std::vector<std::vector<unsigned>>& weights)
|
||||
: BeliefProp (fg), weights_(weights) { }
|
||||
const std::vector<std::vector<unsigned>>& weights);
|
||||
|
||||
~WeightedBp (void);
|
||||
|
||||
Params getPosterioriOf (VarId);
|
||||
|
||||
private:
|
||||
class WeightedLink : public BeliefProp::BpLink {
|
||||
public:
|
||||
WeightedLink (FacNode* fn, VarNode* vn, size_t idx,
|
||||
unsigned weight);
|
||||
|
||||
size_t index (void) const { return index_; }
|
||||
|
||||
unsigned weight (void) const { return weight_; }
|
||||
|
||||
const Params& powMessage (void) const { return pwdMsg_; }
|
||||
|
||||
void updateMessage (void);
|
||||
|
||||
private:
|
||||
DISALLOW_COPY_AND_ASSIGN (WeightedLink);
|
||||
|
||||
size_t index_;
|
||||
unsigned weight_;
|
||||
Params pwdMsg_;
|
||||
};
|
||||
|
||||
void createLinks (void);
|
||||
|
||||
void maxResidualSchedule (void);
|
||||
@ -66,6 +52,17 @@ class WeightedBp : public BeliefProp {
|
||||
DISALLOW_COPY_AND_ASSIGN (WeightedBp);
|
||||
};
|
||||
|
||||
|
||||
|
||||
|
||||
inline void
|
||||
WeightedBp::WeightedLink::updateMessage (void)
|
||||
{
|
||||
pwdMsg_ = *nextMsg_;
|
||||
swap (currMsg_, nextMsg_);
|
||||
LogAware::pow (pwdMsg_, weight_);
|
||||
}
|
||||
|
||||
} // namespace Horus
|
||||
|
||||
#endif // YAP_PACKAGES_CLPBN_HORUS_WEIGHTEDBP_H_
|
||||
|
Reference in New Issue
Block a user