Kill SPNodeInfo class

This commit is contained in:
Tiago Gomes 2013-03-09 16:41:53 +00:00
parent d563fce952
commit 95539226ea
4 changed files with 40 additions and 48 deletions

View File

@ -29,11 +29,17 @@ BeliefProp::BeliefProp (const FactorGraph& fg)
BeliefProp::~BeliefProp() BeliefProp::~BeliefProp()
{ {
for (size_t i = 0; i < varsI_.size(); i++) { for (size_t i = 0; i < varsLinks_.size(); i++) {
delete varsI_[i]; BpLinks& links = varsLinks_[i];
for (unsigned j = 0; j < links.size(); j++) {
delete links[j];
}
} }
for (size_t i = 0; i < facsI_.size(); i++) { for (size_t i = 0; i < facsLinks_.size(); i++) {
delete facsI_[i]; BpLinks& links = facsLinks_[i];
for (unsigned j = 0; j < links.size(); j++) {
delete links[j];
}
} }
for (size_t i = 0; i < links_.size(); i++) { for (size_t i = 0; i < links_.size(); i++) {
delete links_[i]; delete links_[i];
@ -88,7 +94,7 @@ BeliefProp::getPosterioriOf (VarId vid)
probs[var->getEvidence()] = LogAware::withEvidence(); probs[var->getEvidence()] = LogAware::withEvidence();
} else { } else {
probs.resize (var->range(), LogAware::multIdenty()); probs.resize (var->range(), LogAware::multIdenty());
const BpLinks& links = ninf(var)->getLinks(); const BpLinks& links = getLinks (var);
if (Globals::logDomain) { if (Globals::logDomain) {
for (size_t i = 0; i < links.size(); i++) { for (size_t i = 0; i < links.size(); i++) {
probs += links[i]->message(); probs += links[i]->message();
@ -139,7 +145,7 @@ BeliefProp::getFactorJoint (
runSolver(); runSolver();
} }
Factor res (fn->factor()); Factor res (fn->factor());
const BpLinks& links = ninf(fn)->getLinks(); const BpLinks& links = getLinks( fn);
for (size_t i = 0; i < links.size(); i++) { for (size_t i = 0; i < links.size(); i++) {
Factor msg ({links[i]->varNode()->varId()}, Factor msg ({links[i]->varNode()->varId()},
{links[i]->varNode()->range()}, {links[i]->varNode()->range()},
@ -350,7 +356,7 @@ BeliefProp::maxResidualSchedule()
const FacNodes& factorNeighbors = link->varNode()->neighbors(); const FacNodes& factorNeighbors = link->varNode()->neighbors();
for (size_t i = 0; i < factorNeighbors.size(); i++) { for (size_t i = 0; i < factorNeighbors.size(); i++) {
if (factorNeighbors[i] != link->facNode()) { if (factorNeighbors[i] != link->facNode()) {
const BpLinks& links = ninf(factorNeighbors[i])->getLinks(); const BpLinks& links = getLinks (factorNeighbors[i]);
for (size_t j = 0; j < links.size(); j++) { for (size_t j = 0; j < links.size(); j++) {
if (links[j]->varNode() != link->varNode()) { if (links[j]->varNode() != link->varNode()) {
calculateMessage (links[j]); calculateMessage (links[j]);
@ -374,7 +380,7 @@ BeliefProp::calcFactorToVarMsg (BpLink* link)
{ {
FacNode* src = link->facNode(); FacNode* src = link->facNode();
const VarNode* dst = link->varNode(); const VarNode* dst = link->varNode();
const BpLinks& links = ninf(src)->getLinks(); const BpLinks& links = getLinks (src);
// calculate the product of messages that were sent // calculate the product of messages that were sent
// to factor `src', except from var `dst' // to factor `src', except from var `dst'
unsigned reps = 1; unsigned reps = 1;
@ -435,7 +441,7 @@ BeliefProp::calcFactorToVarMsg (BpLink* link)
Params Params
BeliefProp::getVarToFactorMsg (const BpLink* link) const BeliefProp::getVarToFactorMsg (const BpLink* link)
{ {
const VarNode* src = link->varNode(); const VarNode* src = link->varNode();
Params msg; Params msg;
@ -449,7 +455,7 @@ BeliefProp::getVarToFactorMsg (const BpLink* link) const
std::cout << msg; std::cout << msg;
} }
BpLinks::const_iterator it; BpLinks::const_iterator it;
const BpLinks& links = ninf (src)->getLinks(); const BpLinks& links = getLinks (src);
if (Globals::logDomain) { if (Globals::logDomain) {
for (it = links.begin(); it != links.end(); ++it) { for (it = links.begin(); it != links.end(); ++it) {
if (*it != link) { if (*it != link) {
@ -490,21 +496,21 @@ void
BeliefProp::initializeSolver() BeliefProp::initializeSolver()
{ {
const VarNodes& varNodes = fg.varNodes(); const VarNodes& varNodes = fg.varNodes();
varsI_.reserve (varNodes.size()); varsLinks_.reserve (varNodes.size());
for (size_t i = 0; i < varNodes.size(); i++) { for (size_t i = 0; i < varNodes.size(); i++) {
varsI_.push_back (new SPNodeInfo()); varsLinks_.push_back (BpLinks());
} }
const FacNodes& facNodes = fg.facNodes(); const FacNodes& facNodes = fg.facNodes();
facsI_.reserve (facNodes.size()); facsLinks_.reserve (facNodes.size());
for (size_t i = 0; i < facNodes.size(); i++) { for (size_t i = 0; i < facNodes.size(); i++) {
facsI_.push_back (new SPNodeInfo()); facsLinks_.push_back (BpLinks());
} }
createLinks(); createLinks();
for (size_t i = 0; i < links_.size(); i++) { for (size_t i = 0; i < links_.size(); i++) {
FacNode* src = links_[i]->facNode(); FacNode* src = links_[i]->facNode();
VarNode* dst = links_[i]->varNode(); VarNode* dst = links_[i]->varNode();
ninf (dst)->addBpLink (links_[i]); getLinks (dst).push_back (links_[i]);
ninf (src)->addBpLink (links_[i]); getLinks (src).push_back (links_[i]);
} }
} }

View File

@ -96,9 +96,9 @@ class BeliefProp : public GroundSolver {
typedef std::multiset<BpLink*, CmpResidual> SortedOrder; typedef std::multiset<BpLink*, CmpResidual> SortedOrder;
typedef std::unordered_map<BpLink*, SortedOrder::iterator> BpLinkMap; typedef std::unordered_map<BpLink*, SortedOrder::iterator> BpLinkMap;
SPNodeInfo* ninf (const VarNode* var) const; BpLinks& getLinks (const VarNode* var);
SPNodeInfo* ninf (const FacNode* fac) const; BpLinks& getLinks (const FacNode* fac);
void calculateAndUpdateMessage (BpLink* link, bool calcResidual = true); void calculateAndUpdateMessage (BpLink* link, bool calcResidual = true);
@ -114,14 +114,14 @@ class BeliefProp : public GroundSolver {
virtual void calcFactorToVarMsg (BpLink*); virtual void calcFactorToVarMsg (BpLink*);
virtual Params getVarToFactorMsg (const BpLink*) const; virtual Params getVarToFactorMsg (const BpLink*);
virtual Params getJointByConditioning (const VarIds&) const; virtual Params getJointByConditioning (const VarIds&) const;
BpLinks links_; BpLinks links_;
unsigned nIters_; unsigned nIters_;
std::vector<SPNodeInfo*> varsI_; std::vector<BpLinks> varsLinks_;
std::vector<SPNodeInfo*> facsI_; std::vector<BpLinks> facsLinks_;
bool runned_; bool runned_;
SortedOrder sortedOrder_; SortedOrder sortedOrder_;
BpLinkMap linkMap_; BpLinkMap linkMap_;
@ -131,20 +131,6 @@ class BeliefProp : public GroundSolver {
static MsgSchedule schedule_; static MsgSchedule schedule_;
private: private:
class SPNodeInfo {
public:
SPNodeInfo() { }
void addBpLink (BeliefProp::BpLink* link) { links_.push_back (link); }
const BpLinks& getLinks() { return links_; }
private:
BpLinks links_;
DISALLOW_COPY_AND_ASSIGN (SPNodeInfo);
};
void initializeSolver(); void initializeSolver();
bool converged(); bool converged();
@ -156,18 +142,18 @@ class BeliefProp : public GroundSolver {
inline BeliefProp::SPNodeInfo* inline BeliefProp::BpLinks&
BeliefProp::ninf (const VarNode* var) const BeliefProp::getLinks (const VarNode* var)
{ {
return varsI_[var->getIndex()]; return varsLinks_[var->getIndex()];
} }
inline BeliefProp::SPNodeInfo* inline BeliefProp::BpLinks&
BeliefProp::ninf (const FacNode* fac) const BeliefProp::getLinks (const FacNode* fac)
{ {
return facsI_[fac->getIndex()]; return facsLinks_[fac->getIndex()];
} }
} // namespace Horus } // namespace Horus

View File

@ -42,7 +42,7 @@ WeightedBp::getPosterioriOf (VarId vid)
probs[var->getEvidence()] = LogAware::withEvidence(); probs[var->getEvidence()] = LogAware::withEvidence();
} else { } else {
probs.resize (var->range(), LogAware::multIdenty()); probs.resize (var->range(), LogAware::multIdenty());
const BpLinks& links = ninf(var)->getLinks(); const BpLinks& links = getLinks (var);
if (Globals::logDomain) { if (Globals::logDomain) {
for (size_t i = 0; i < links.size(); i++) { for (size_t i = 0; i < links.size(); i++) {
WeightedLink* l = static_cast<WeightedLink*> (links[i]); WeightedLink* l = static_cast<WeightedLink*> (links[i]);
@ -152,7 +152,7 @@ WeightedBp::maxResidualSchedule()
// update the messages that depend on message source --> destin // update the messages that depend on message source --> destin
const FacNodes& factorNeighbors = link->varNode()->neighbors(); const FacNodes& factorNeighbors = link->varNode()->neighbors();
for (size_t i = 0; i < factorNeighbors.size(); i++) { for (size_t i = 0; i < factorNeighbors.size(); i++) {
const BpLinks& links = ninf(factorNeighbors[i])->getLinks(); const BpLinks& links = getLinks (factorNeighbors[i]);
for (size_t j = 0; j < links.size(); j++) { for (size_t j = 0; j < links.size(); j++) {
if (links[j]->varNode() != link->varNode()) { if (links[j]->varNode() != link->varNode()) {
if (Globals::verbosity > 1) { if (Globals::verbosity > 1) {
@ -168,7 +168,7 @@ WeightedBp::maxResidualSchedule()
} }
// in counting bp, the message that a variable X sends to // in counting bp, the message that a variable X sends to
// to a factor F depends on the message that F sent to the X // to a factor F depends on the message that F sent to the X
const BpLinks& links = ninf(link->facNode())->getLinks(); const BpLinks& links = getLinks (link->facNode());
for (size_t i = 0; i < links.size(); i++) { for (size_t i = 0; i < links.size(); i++) {
if (links[i]->varNode() != link->varNode()) { if (links[i]->varNode() != link->varNode()) {
if (Globals::verbosity > 1) { if (Globals::verbosity > 1) {
@ -192,7 +192,7 @@ WeightedBp::calcFactorToVarMsg (BpLink* _link)
WeightedLink* link = static_cast<WeightedLink*> (_link); WeightedLink* link = static_cast<WeightedLink*> (_link);
FacNode* src = link->facNode(); FacNode* src = link->facNode();
const VarNode* dst = link->varNode(); const VarNode* dst = link->varNode();
const BpLinks& links = ninf(src)->getLinks(); const BpLinks& links = getLinks (src);
// calculate the product of messages that were sent // calculate the product of messages that were sent
// to factor `src', except from var `dst' // to factor `src', except from var `dst'
unsigned reps = 1; unsigned reps = 1;
@ -265,7 +265,7 @@ WeightedBp::calcFactorToVarMsg (BpLink* _link)
Params Params
WeightedBp::getVarToFactorMsg (const BpLink* _link) const WeightedBp::getVarToFactorMsg (const BpLink* _link)
{ {
const WeightedLink* link = static_cast<const WeightedLink*> (_link); const WeightedLink* link = static_cast<const WeightedLink*> (_link);
const VarNode* src = link->varNode(); const VarNode* src = link->varNode();
@ -286,7 +286,7 @@ WeightedBp::getVarToFactorMsg (const BpLink* _link) const
} }
LogAware::pow (msg, link->weight() - 1); LogAware::pow (msg, link->weight() - 1);
} }
const BpLinks& links = ninf(src)->getLinks(); const BpLinks& links = getLinks (src);
if (Globals::logDomain) { if (Globals::logDomain) {
for (size_t i = 0; i < links.size(); i++) { for (size_t i = 0; i < links.size(); i++) {
WeightedLink* l = static_cast<WeightedLink*> (links[i]); WeightedLink* l = static_cast<WeightedLink*> (links[i]);

View File

@ -43,7 +43,7 @@ class WeightedBp : public BeliefProp {
void calcFactorToVarMsg (BpLink*); void calcFactorToVarMsg (BpLink*);
Params getVarToFactorMsg (const BpLink*) const; Params getVarToFactorMsg (const BpLink*);
void printLinkInformation() const; void printLinkInformation() const;