diff --git a/packages/CLPBN/horus/BeliefProp.cpp b/packages/CLPBN/horus/BeliefProp.cpp index d65c36a01..314f4a6c5 100644 --- a/packages/CLPBN/horus/BeliefProp.cpp +++ b/packages/CLPBN/horus/BeliefProp.cpp @@ -118,20 +118,19 @@ BeliefProp::getJointDistributionOf (const VarIds& jointVarIds) if (idx == facNodes.size()) { return getJointByConditioning (jointVarIds); } - return getFactorJoint (idx, jointVarIds); + return getFactorJoint (facNodes[idx], jointVarIds); } Params BeliefProp::getFactorJoint ( - size_t fnIdx, + FacNode* fn, const VarIds& jointVarIds) { if (runned_ == false) { runSolver(); } - FacNode* fn = fg.facNodes()[fnIdx]; Factor res (fn->factor()); const BpLinks& links = ninf(fn)->getLinks(); for (size_t i = 0; i < links.size(); i++) { @@ -350,20 +349,22 @@ BeliefProp::getVarToFactorMsg (const BpLink* link) const const BpLinks& links = ninf (src)->getLinks(); if (Globals::logDomain) { for (it = links.begin(); it != links.end(); ++it) { - msg += (*it)->message(); + if (*it != link) { + msg += (*it)->message(); + } if (Constants::SHOW_BP_CALCS) { cout << " x " << (*it)->message(); } } - msg -= link->message(); } else { for (it = links.begin(); it != links.end(); ++it) { - msg *= (*it)->message(); + if (*it != link) { + msg *= (*it)->message(); + } if (Constants::SHOW_BP_CALCS) { cout << " x " << (*it)->message(); } } - msg /= link->message(); } if (Constants::SHOW_BP_CALCS) { cout << " = " << msg; diff --git a/packages/CLPBN/horus/BeliefProp.h b/packages/CLPBN/horus/BeliefProp.h index 44c867dbc..1545abfc8 100644 --- a/packages/CLPBN/horus/BeliefProp.h +++ b/packages/CLPBN/horus/BeliefProp.h @@ -112,7 +112,7 @@ class BeliefProp : public Solver virtual Params getJointByConditioning (const VarIds&) const; public: - Params getFactorJoint (size_t fnIdx, const VarIds&); + Params getFactorJoint (FacNode* fn, const VarIds&); protected: SPNodeInfo* ninf (const VarNode* var) const diff --git a/packages/CLPBN/horus/CountingBp.cpp b/packages/CLPBN/horus/CountingBp.cpp index faf9f4d18..365ff7098 100644 --- a/packages/CLPBN/horus/CountingBp.cpp +++ b/packages/CLPBN/horus/CountingBp.cpp @@ -81,7 +81,9 @@ CountingBp::solveQuery (VarIds queryVids) for (size_t i = 0; i < queryVids.size(); i++) { reprArgs.push_back (getRepresentative (queryVids[i])); } - res = solver_->getFactorJoint (idx, reprArgs); + FacNode* reprFac = getRepresentative (facNodes[idx]); + assert (reprFac != 0); + res = solver_->getFactorJoint (reprFac, reprArgs); } } return res; diff --git a/packages/CLPBN/horus/LiftedBp.cpp b/packages/CLPBN/horus/LiftedBp.cpp index 8fc8573b6..05a5ea6af 100644 --- a/packages/CLPBN/horus/LiftedBp.cpp +++ b/packages/CLPBN/horus/LiftedBp.cpp @@ -8,7 +8,8 @@ LiftedBp::LiftedBp (const ParfactorList& pfList) : pfList_(pfList) { refineParfactors(); - solver_ = new WeightedBp (*getFactorGraph(), getWeights()); + createFactorGraph(); + solver_ = new WeightedBp (*fg_, getWeights()); } @@ -16,6 +17,7 @@ LiftedBp::LiftedBp (const ParfactorList& pfList) LiftedBp::~LiftedBp (void) { delete solver_; + delete fg_; } @@ -47,7 +49,7 @@ LiftedBp::solveQuery (const Grounds& query) for (unsigned i = 0; i < groups.size(); i++) { queryVids.push_back (groups[i]); } - res = solver_->getFactorJoint (idx, queryVids); + res = solver_->getFactorJoint (fg_->facNodes()[idx], queryVids); } } return res; @@ -131,10 +133,10 @@ LiftedBp::getQueryGroups (const Grounds& query) -FactorGraph* -LiftedBp::getFactorGraph (void) +void +LiftedBp::createFactorGraph (void) { - FactorGraph* fg = new FactorGraph(); + fg_ = new FactorGraph(); ParfactorList::const_iterator it = pfList_.begin(); for (; it != pfList_.end(); ++it) { vector groups = (*it)->getAllGroups(); @@ -142,9 +144,8 @@ LiftedBp::getFactorGraph (void) for (size_t i = 0; i < groups.size(); i++) { varIds.push_back (groups[i]); } - fg->addFactor (Factor (varIds, (*it)->ranges(), (*it)->params())); + fg_->addFactor (Factor (varIds, (*it)->ranges(), (*it)->params())); } - return fg; } diff --git a/packages/CLPBN/horus/LiftedBp.h b/packages/CLPBN/horus/LiftedBp.h index 01e18310e..c34956320 100644 --- a/packages/CLPBN/horus/LiftedBp.h +++ b/packages/CLPBN/horus/LiftedBp.h @@ -24,7 +24,7 @@ class LiftedBp vector getQueryGroups (const Grounds&); - FactorGraph* getFactorGraph (void); + void createFactorGraph (void); vector> getWeights (void) const; @@ -34,6 +34,7 @@ class LiftedBp ParfactorList pfList_; WeightedBp* solver_; + FactorGraph* fg_; }; diff --git a/packages/CLPBN/horus/WeightedBp.cpp b/packages/CLPBN/horus/WeightedBp.cpp index c4e308d4f..d8a32a246 100644 --- a/packages/CLPBN/horus/WeightedBp.cpp +++ b/packages/CLPBN/horus/WeightedBp.cpp @@ -275,7 +275,7 @@ void WeightedBp::printLinkInformation (void) const { for (size_t i = 0; i < links_.size(); i++) { - WeightedLink* l = static_cast (links_[i]); + WeightedLink* l = static_cast (links_[i]); cout << l->toString() << ":" << endl; cout << " curr msg = " << l->message() << endl; cout << " next msg = " << l->nextMessage() << endl;