Kill SPNodeInfo class
This commit is contained in:
parent
d563fce952
commit
95539226ea
@ -29,11 +29,17 @@ BeliefProp::BeliefProp (const FactorGraph& fg)
|
|||||||
|
|
||||||
BeliefProp::~BeliefProp()
|
BeliefProp::~BeliefProp()
|
||||||
{
|
{
|
||||||
for (size_t i = 0; i < varsI_.size(); i++) {
|
for (size_t i = 0; i < varsLinks_.size(); i++) {
|
||||||
delete varsI_[i];
|
BpLinks& links = varsLinks_[i];
|
||||||
|
for (unsigned j = 0; j < links.size(); j++) {
|
||||||
|
delete links[j];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
for (size_t i = 0; i < facsI_.size(); i++) {
|
for (size_t i = 0; i < facsLinks_.size(); i++) {
|
||||||
delete facsI_[i];
|
BpLinks& links = facsLinks_[i];
|
||||||
|
for (unsigned j = 0; j < links.size(); j++) {
|
||||||
|
delete links[j];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
for (size_t i = 0; i < links_.size(); i++) {
|
for (size_t i = 0; i < links_.size(); i++) {
|
||||||
delete links_[i];
|
delete links_[i];
|
||||||
@ -88,7 +94,7 @@ BeliefProp::getPosterioriOf (VarId vid)
|
|||||||
probs[var->getEvidence()] = LogAware::withEvidence();
|
probs[var->getEvidence()] = LogAware::withEvidence();
|
||||||
} else {
|
} else {
|
||||||
probs.resize (var->range(), LogAware::multIdenty());
|
probs.resize (var->range(), LogAware::multIdenty());
|
||||||
const BpLinks& links = ninf(var)->getLinks();
|
const BpLinks& links = getLinks (var);
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
for (size_t i = 0; i < links.size(); i++) {
|
for (size_t i = 0; i < links.size(); i++) {
|
||||||
probs += links[i]->message();
|
probs += links[i]->message();
|
||||||
@ -139,7 +145,7 @@ BeliefProp::getFactorJoint (
|
|||||||
runSolver();
|
runSolver();
|
||||||
}
|
}
|
||||||
Factor res (fn->factor());
|
Factor res (fn->factor());
|
||||||
const BpLinks& links = ninf(fn)->getLinks();
|
const BpLinks& links = getLinks( fn);
|
||||||
for (size_t i = 0; i < links.size(); i++) {
|
for (size_t i = 0; i < links.size(); i++) {
|
||||||
Factor msg ({links[i]->varNode()->varId()},
|
Factor msg ({links[i]->varNode()->varId()},
|
||||||
{links[i]->varNode()->range()},
|
{links[i]->varNode()->range()},
|
||||||
@ -350,7 +356,7 @@ BeliefProp::maxResidualSchedule()
|
|||||||
const FacNodes& factorNeighbors = link->varNode()->neighbors();
|
const FacNodes& factorNeighbors = link->varNode()->neighbors();
|
||||||
for (size_t i = 0; i < factorNeighbors.size(); i++) {
|
for (size_t i = 0; i < factorNeighbors.size(); i++) {
|
||||||
if (factorNeighbors[i] != link->facNode()) {
|
if (factorNeighbors[i] != link->facNode()) {
|
||||||
const BpLinks& links = ninf(factorNeighbors[i])->getLinks();
|
const BpLinks& links = getLinks (factorNeighbors[i]);
|
||||||
for (size_t j = 0; j < links.size(); j++) {
|
for (size_t j = 0; j < links.size(); j++) {
|
||||||
if (links[j]->varNode() != link->varNode()) {
|
if (links[j]->varNode() != link->varNode()) {
|
||||||
calculateMessage (links[j]);
|
calculateMessage (links[j]);
|
||||||
@ -374,7 +380,7 @@ BeliefProp::calcFactorToVarMsg (BpLink* link)
|
|||||||
{
|
{
|
||||||
FacNode* src = link->facNode();
|
FacNode* src = link->facNode();
|
||||||
const VarNode* dst = link->varNode();
|
const VarNode* dst = link->varNode();
|
||||||
const BpLinks& links = ninf(src)->getLinks();
|
const BpLinks& links = getLinks (src);
|
||||||
// calculate the product of messages that were sent
|
// calculate the product of messages that were sent
|
||||||
// to factor `src', except from var `dst'
|
// to factor `src', except from var `dst'
|
||||||
unsigned reps = 1;
|
unsigned reps = 1;
|
||||||
@ -435,7 +441,7 @@ BeliefProp::calcFactorToVarMsg (BpLink* link)
|
|||||||
|
|
||||||
|
|
||||||
Params
|
Params
|
||||||
BeliefProp::getVarToFactorMsg (const BpLink* link) const
|
BeliefProp::getVarToFactorMsg (const BpLink* link)
|
||||||
{
|
{
|
||||||
const VarNode* src = link->varNode();
|
const VarNode* src = link->varNode();
|
||||||
Params msg;
|
Params msg;
|
||||||
@ -449,7 +455,7 @@ BeliefProp::getVarToFactorMsg (const BpLink* link) const
|
|||||||
std::cout << msg;
|
std::cout << msg;
|
||||||
}
|
}
|
||||||
BpLinks::const_iterator it;
|
BpLinks::const_iterator it;
|
||||||
const BpLinks& links = ninf (src)->getLinks();
|
const BpLinks& links = getLinks (src);
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
for (it = links.begin(); it != links.end(); ++it) {
|
for (it = links.begin(); it != links.end(); ++it) {
|
||||||
if (*it != link) {
|
if (*it != link) {
|
||||||
@ -490,21 +496,21 @@ void
|
|||||||
BeliefProp::initializeSolver()
|
BeliefProp::initializeSolver()
|
||||||
{
|
{
|
||||||
const VarNodes& varNodes = fg.varNodes();
|
const VarNodes& varNodes = fg.varNodes();
|
||||||
varsI_.reserve (varNodes.size());
|
varsLinks_.reserve (varNodes.size());
|
||||||
for (size_t i = 0; i < varNodes.size(); i++) {
|
for (size_t i = 0; i < varNodes.size(); i++) {
|
||||||
varsI_.push_back (new SPNodeInfo());
|
varsLinks_.push_back (BpLinks());
|
||||||
}
|
}
|
||||||
const FacNodes& facNodes = fg.facNodes();
|
const FacNodes& facNodes = fg.facNodes();
|
||||||
facsI_.reserve (facNodes.size());
|
facsLinks_.reserve (facNodes.size());
|
||||||
for (size_t i = 0; i < facNodes.size(); i++) {
|
for (size_t i = 0; i < facNodes.size(); i++) {
|
||||||
facsI_.push_back (new SPNodeInfo());
|
facsLinks_.push_back (BpLinks());
|
||||||
}
|
}
|
||||||
createLinks();
|
createLinks();
|
||||||
for (size_t i = 0; i < links_.size(); i++) {
|
for (size_t i = 0; i < links_.size(); i++) {
|
||||||
FacNode* src = links_[i]->facNode();
|
FacNode* src = links_[i]->facNode();
|
||||||
VarNode* dst = links_[i]->varNode();
|
VarNode* dst = links_[i]->varNode();
|
||||||
ninf (dst)->addBpLink (links_[i]);
|
getLinks (dst).push_back (links_[i]);
|
||||||
ninf (src)->addBpLink (links_[i]);
|
getLinks (src).push_back (links_[i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -96,9 +96,9 @@ class BeliefProp : public GroundSolver {
|
|||||||
typedef std::multiset<BpLink*, CmpResidual> SortedOrder;
|
typedef std::multiset<BpLink*, CmpResidual> SortedOrder;
|
||||||
typedef std::unordered_map<BpLink*, SortedOrder::iterator> BpLinkMap;
|
typedef std::unordered_map<BpLink*, SortedOrder::iterator> BpLinkMap;
|
||||||
|
|
||||||
SPNodeInfo* ninf (const VarNode* var) const;
|
BpLinks& getLinks (const VarNode* var);
|
||||||
|
|
||||||
SPNodeInfo* ninf (const FacNode* fac) const;
|
BpLinks& getLinks (const FacNode* fac);
|
||||||
|
|
||||||
void calculateAndUpdateMessage (BpLink* link, bool calcResidual = true);
|
void calculateAndUpdateMessage (BpLink* link, bool calcResidual = true);
|
||||||
|
|
||||||
@ -114,14 +114,14 @@ class BeliefProp : public GroundSolver {
|
|||||||
|
|
||||||
virtual void calcFactorToVarMsg (BpLink*);
|
virtual void calcFactorToVarMsg (BpLink*);
|
||||||
|
|
||||||
virtual Params getVarToFactorMsg (const BpLink*) const;
|
virtual Params getVarToFactorMsg (const BpLink*);
|
||||||
|
|
||||||
virtual Params getJointByConditioning (const VarIds&) const;
|
virtual Params getJointByConditioning (const VarIds&) const;
|
||||||
|
|
||||||
BpLinks links_;
|
BpLinks links_;
|
||||||
unsigned nIters_;
|
unsigned nIters_;
|
||||||
std::vector<SPNodeInfo*> varsI_;
|
std::vector<BpLinks> varsLinks_;
|
||||||
std::vector<SPNodeInfo*> facsI_;
|
std::vector<BpLinks> facsLinks_;
|
||||||
bool runned_;
|
bool runned_;
|
||||||
SortedOrder sortedOrder_;
|
SortedOrder sortedOrder_;
|
||||||
BpLinkMap linkMap_;
|
BpLinkMap linkMap_;
|
||||||
@ -131,20 +131,6 @@ class BeliefProp : public GroundSolver {
|
|||||||
static MsgSchedule schedule_;
|
static MsgSchedule schedule_;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
class SPNodeInfo {
|
|
||||||
public:
|
|
||||||
SPNodeInfo() { }
|
|
||||||
|
|
||||||
void addBpLink (BeliefProp::BpLink* link) { links_.push_back (link); }
|
|
||||||
|
|
||||||
const BpLinks& getLinks() { return links_; }
|
|
||||||
|
|
||||||
private:
|
|
||||||
BpLinks links_;
|
|
||||||
|
|
||||||
DISALLOW_COPY_AND_ASSIGN (SPNodeInfo);
|
|
||||||
};
|
|
||||||
|
|
||||||
void initializeSolver();
|
void initializeSolver();
|
||||||
|
|
||||||
bool converged();
|
bool converged();
|
||||||
@ -156,18 +142,18 @@ class BeliefProp : public GroundSolver {
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
inline BeliefProp::SPNodeInfo*
|
inline BeliefProp::BpLinks&
|
||||||
BeliefProp::ninf (const VarNode* var) const
|
BeliefProp::getLinks (const VarNode* var)
|
||||||
{
|
{
|
||||||
return varsI_[var->getIndex()];
|
return varsLinks_[var->getIndex()];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
inline BeliefProp::SPNodeInfo*
|
inline BeliefProp::BpLinks&
|
||||||
BeliefProp::ninf (const FacNode* fac) const
|
BeliefProp::getLinks (const FacNode* fac)
|
||||||
{
|
{
|
||||||
return facsI_[fac->getIndex()];
|
return facsLinks_[fac->getIndex()];
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace Horus
|
} // namespace Horus
|
||||||
|
@ -42,7 +42,7 @@ WeightedBp::getPosterioriOf (VarId vid)
|
|||||||
probs[var->getEvidence()] = LogAware::withEvidence();
|
probs[var->getEvidence()] = LogAware::withEvidence();
|
||||||
} else {
|
} else {
|
||||||
probs.resize (var->range(), LogAware::multIdenty());
|
probs.resize (var->range(), LogAware::multIdenty());
|
||||||
const BpLinks& links = ninf(var)->getLinks();
|
const BpLinks& links = getLinks (var);
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
for (size_t i = 0; i < links.size(); i++) {
|
for (size_t i = 0; i < links.size(); i++) {
|
||||||
WeightedLink* l = static_cast<WeightedLink*> (links[i]);
|
WeightedLink* l = static_cast<WeightedLink*> (links[i]);
|
||||||
@ -152,7 +152,7 @@ WeightedBp::maxResidualSchedule()
|
|||||||
// update the messages that depend on message source --> destin
|
// update the messages that depend on message source --> destin
|
||||||
const FacNodes& factorNeighbors = link->varNode()->neighbors();
|
const FacNodes& factorNeighbors = link->varNode()->neighbors();
|
||||||
for (size_t i = 0; i < factorNeighbors.size(); i++) {
|
for (size_t i = 0; i < factorNeighbors.size(); i++) {
|
||||||
const BpLinks& links = ninf(factorNeighbors[i])->getLinks();
|
const BpLinks& links = getLinks (factorNeighbors[i]);
|
||||||
for (size_t j = 0; j < links.size(); j++) {
|
for (size_t j = 0; j < links.size(); j++) {
|
||||||
if (links[j]->varNode() != link->varNode()) {
|
if (links[j]->varNode() != link->varNode()) {
|
||||||
if (Globals::verbosity > 1) {
|
if (Globals::verbosity > 1) {
|
||||||
@ -168,7 +168,7 @@ WeightedBp::maxResidualSchedule()
|
|||||||
}
|
}
|
||||||
// in counting bp, the message that a variable X sends to
|
// in counting bp, the message that a variable X sends to
|
||||||
// to a factor F depends on the message that F sent to the X
|
// to a factor F depends on the message that F sent to the X
|
||||||
const BpLinks& links = ninf(link->facNode())->getLinks();
|
const BpLinks& links = getLinks (link->facNode());
|
||||||
for (size_t i = 0; i < links.size(); i++) {
|
for (size_t i = 0; i < links.size(); i++) {
|
||||||
if (links[i]->varNode() != link->varNode()) {
|
if (links[i]->varNode() != link->varNode()) {
|
||||||
if (Globals::verbosity > 1) {
|
if (Globals::verbosity > 1) {
|
||||||
@ -192,7 +192,7 @@ WeightedBp::calcFactorToVarMsg (BpLink* _link)
|
|||||||
WeightedLink* link = static_cast<WeightedLink*> (_link);
|
WeightedLink* link = static_cast<WeightedLink*> (_link);
|
||||||
FacNode* src = link->facNode();
|
FacNode* src = link->facNode();
|
||||||
const VarNode* dst = link->varNode();
|
const VarNode* dst = link->varNode();
|
||||||
const BpLinks& links = ninf(src)->getLinks();
|
const BpLinks& links = getLinks (src);
|
||||||
// calculate the product of messages that were sent
|
// calculate the product of messages that were sent
|
||||||
// to factor `src', except from var `dst'
|
// to factor `src', except from var `dst'
|
||||||
unsigned reps = 1;
|
unsigned reps = 1;
|
||||||
@ -265,7 +265,7 @@ WeightedBp::calcFactorToVarMsg (BpLink* _link)
|
|||||||
|
|
||||||
|
|
||||||
Params
|
Params
|
||||||
WeightedBp::getVarToFactorMsg (const BpLink* _link) const
|
WeightedBp::getVarToFactorMsg (const BpLink* _link)
|
||||||
{
|
{
|
||||||
const WeightedLink* link = static_cast<const WeightedLink*> (_link);
|
const WeightedLink* link = static_cast<const WeightedLink*> (_link);
|
||||||
const VarNode* src = link->varNode();
|
const VarNode* src = link->varNode();
|
||||||
@ -286,7 +286,7 @@ WeightedBp::getVarToFactorMsg (const BpLink* _link) const
|
|||||||
}
|
}
|
||||||
LogAware::pow (msg, link->weight() - 1);
|
LogAware::pow (msg, link->weight() - 1);
|
||||||
}
|
}
|
||||||
const BpLinks& links = ninf(src)->getLinks();
|
const BpLinks& links = getLinks (src);
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
for (size_t i = 0; i < links.size(); i++) {
|
for (size_t i = 0; i < links.size(); i++) {
|
||||||
WeightedLink* l = static_cast<WeightedLink*> (links[i]);
|
WeightedLink* l = static_cast<WeightedLink*> (links[i]);
|
||||||
|
@ -43,7 +43,7 @@ class WeightedBp : public BeliefProp {
|
|||||||
|
|
||||||
void calcFactorToVarMsg (BpLink*);
|
void calcFactorToVarMsg (BpLink*);
|
||||||
|
|
||||||
Params getVarToFactorMsg (const BpLink*) const;
|
Params getVarToFactorMsg (const BpLink*);
|
||||||
|
|
||||||
void printLinkInformation() const;
|
void printLinkInformation() const;
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user