Merge branch 'master' of ssh://yap.git.sourceforge.net/gitroot/yap/yap-6.3

This commit is contained in:
Vítor Santos Costa 2012-10-03 21:54:13 +01:00
commit 2d373a28a6
6 changed files with 23 additions and 18 deletions

View File

@ -118,20 +118,19 @@ BeliefProp::getJointDistributionOf (const VarIds& jointVarIds)
if (idx == facNodes.size()) { if (idx == facNodes.size()) {
return getJointByConditioning (jointVarIds); return getJointByConditioning (jointVarIds);
} }
return getFactorJoint (idx, jointVarIds); return getFactorJoint (facNodes[idx], jointVarIds);
} }
Params Params
BeliefProp::getFactorJoint ( BeliefProp::getFactorJoint (
size_t fnIdx, FacNode* fn,
const VarIds& jointVarIds) const VarIds& jointVarIds)
{ {
if (runned_ == false) { if (runned_ == false) {
runSolver(); runSolver();
} }
FacNode* fn = fg.facNodes()[fnIdx];
Factor res (fn->factor()); Factor res (fn->factor());
const BpLinks& links = ninf(fn)->getLinks(); const BpLinks& links = ninf(fn)->getLinks();
for (size_t i = 0; i < links.size(); i++) { for (size_t i = 0; i < links.size(); i++) {
@ -350,20 +349,22 @@ BeliefProp::getVarToFactorMsg (const BpLink* link) const
const BpLinks& links = ninf (src)->getLinks(); const BpLinks& links = ninf (src)->getLinks();
if (Globals::logDomain) { if (Globals::logDomain) {
for (it = links.begin(); it != links.end(); ++it) { for (it = links.begin(); it != links.end(); ++it) {
msg += (*it)->message(); if (*it != link) {
msg += (*it)->message();
}
if (Constants::SHOW_BP_CALCS) { if (Constants::SHOW_BP_CALCS) {
cout << " x " << (*it)->message(); cout << " x " << (*it)->message();
} }
} }
msg -= link->message();
} else { } else {
for (it = links.begin(); it != links.end(); ++it) { for (it = links.begin(); it != links.end(); ++it) {
msg *= (*it)->message(); if (*it != link) {
msg *= (*it)->message();
}
if (Constants::SHOW_BP_CALCS) { if (Constants::SHOW_BP_CALCS) {
cout << " x " << (*it)->message(); cout << " x " << (*it)->message();
} }
} }
msg /= link->message();
} }
if (Constants::SHOW_BP_CALCS) { if (Constants::SHOW_BP_CALCS) {
cout << " = " << msg; cout << " = " << msg;

View File

@ -112,7 +112,7 @@ class BeliefProp : public Solver
virtual Params getJointByConditioning (const VarIds&) const; virtual Params getJointByConditioning (const VarIds&) const;
public: public:
Params getFactorJoint (size_t fnIdx, const VarIds&); Params getFactorJoint (FacNode* fn, const VarIds&);
protected: protected:
SPNodeInfo* ninf (const VarNode* var) const SPNodeInfo* ninf (const VarNode* var) const

View File

@ -81,7 +81,9 @@ CountingBp::solveQuery (VarIds queryVids)
for (size_t i = 0; i < queryVids.size(); i++) { for (size_t i = 0; i < queryVids.size(); i++) {
reprArgs.push_back (getRepresentative (queryVids[i])); reprArgs.push_back (getRepresentative (queryVids[i]));
} }
res = solver_->getFactorJoint (idx, reprArgs); FacNode* reprFac = getRepresentative (facNodes[idx]);
assert (reprFac != 0);
res = solver_->getFactorJoint (reprFac, reprArgs);
} }
} }
return res; return res;

View File

@ -8,7 +8,8 @@ LiftedBp::LiftedBp (const ParfactorList& pfList)
: pfList_(pfList) : pfList_(pfList)
{ {
refineParfactors(); refineParfactors();
solver_ = new WeightedBp (*getFactorGraph(), getWeights()); createFactorGraph();
solver_ = new WeightedBp (*fg_, getWeights());
} }
@ -16,6 +17,7 @@ LiftedBp::LiftedBp (const ParfactorList& pfList)
LiftedBp::~LiftedBp (void) LiftedBp::~LiftedBp (void)
{ {
delete solver_; delete solver_;
delete fg_;
} }
@ -47,7 +49,7 @@ LiftedBp::solveQuery (const Grounds& query)
for (unsigned i = 0; i < groups.size(); i++) { for (unsigned i = 0; i < groups.size(); i++) {
queryVids.push_back (groups[i]); queryVids.push_back (groups[i]);
} }
res = solver_->getFactorJoint (idx, queryVids); res = solver_->getFactorJoint (fg_->facNodes()[idx], queryVids);
} }
} }
return res; return res;
@ -131,10 +133,10 @@ LiftedBp::getQueryGroups (const Grounds& query)
FactorGraph* void
LiftedBp::getFactorGraph (void) LiftedBp::createFactorGraph (void)
{ {
FactorGraph* fg = new FactorGraph(); fg_ = new FactorGraph();
ParfactorList::const_iterator it = pfList_.begin(); ParfactorList::const_iterator it = pfList_.begin();
for (; it != pfList_.end(); ++it) { for (; it != pfList_.end(); ++it) {
vector<PrvGroup> groups = (*it)->getAllGroups(); vector<PrvGroup> groups = (*it)->getAllGroups();
@ -142,9 +144,8 @@ LiftedBp::getFactorGraph (void)
for (size_t i = 0; i < groups.size(); i++) { for (size_t i = 0; i < groups.size(); i++) {
varIds.push_back (groups[i]); varIds.push_back (groups[i]);
} }
fg->addFactor (Factor (varIds, (*it)->ranges(), (*it)->params())); fg_->addFactor (Factor (varIds, (*it)->ranges(), (*it)->params()));
} }
return fg;
} }

View File

@ -24,7 +24,7 @@ class LiftedBp
vector<PrvGroup> getQueryGroups (const Grounds&); vector<PrvGroup> getQueryGroups (const Grounds&);
FactorGraph* getFactorGraph (void); void createFactorGraph (void);
vector<vector<unsigned>> getWeights (void) const; vector<vector<unsigned>> getWeights (void) const;
@ -34,6 +34,7 @@ class LiftedBp
ParfactorList pfList_; ParfactorList pfList_;
WeightedBp* solver_; WeightedBp* solver_;
FactorGraph* fg_;
}; };

View File

@ -275,7 +275,7 @@ void
WeightedBp::printLinkInformation (void) const WeightedBp::printLinkInformation (void) const
{ {
for (size_t i = 0; i < links_.size(); i++) { for (size_t i = 0; i < links_.size(); i++) {
WeightedLink* l = static_cast<WeightedLink*> (links_[i]); WeightedLink* l = static_cast<WeightedLink*> (links_[i]);
cout << l->toString() << ":" << endl; cout << l->toString() << ":" << endl;
cout << " curr msg = " << l->message() << endl; cout << " curr msg = " << l->message() << endl;
cout << " next msg = " << l->nextMessage() << endl; cout << " next msg = " << l->nextMessage() << endl;