Merge branch 'master' of ssh://yap.git.sourceforge.net/gitroot/yap/yap-6.3

This commit is contained in:
Vítor Santos Costa 2012-10-03 21:54:13 +01:00
commit 2d373a28a6
6 changed files with 23 additions and 18 deletions

View File

@ -118,20 +118,19 @@ BeliefProp::getJointDistributionOf (const VarIds& jointVarIds)
if (idx == facNodes.size()) {
return getJointByConditioning (jointVarIds);
}
return getFactorJoint (idx, jointVarIds);
return getFactorJoint (facNodes[idx], jointVarIds);
}
Params
BeliefProp::getFactorJoint (
size_t fnIdx,
FacNode* fn,
const VarIds& jointVarIds)
{
if (runned_ == false) {
runSolver();
}
FacNode* fn = fg.facNodes()[fnIdx];
Factor res (fn->factor());
const BpLinks& links = ninf(fn)->getLinks();
for (size_t i = 0; i < links.size(); i++) {
@ -350,20 +349,22 @@ BeliefProp::getVarToFactorMsg (const BpLink* link) const
const BpLinks& links = ninf (src)->getLinks();
if (Globals::logDomain) {
for (it = links.begin(); it != links.end(); ++it) {
msg += (*it)->message();
if (*it != link) {
msg += (*it)->message();
}
if (Constants::SHOW_BP_CALCS) {
cout << " x " << (*it)->message();
}
}
msg -= link->message();
} else {
for (it = links.begin(); it != links.end(); ++it) {
msg *= (*it)->message();
if (*it != link) {
msg *= (*it)->message();
}
if (Constants::SHOW_BP_CALCS) {
cout << " x " << (*it)->message();
}
}
msg /= link->message();
}
if (Constants::SHOW_BP_CALCS) {
cout << " = " << msg;

View File

@ -112,7 +112,7 @@ class BeliefProp : public Solver
virtual Params getJointByConditioning (const VarIds&) const;
public:
Params getFactorJoint (size_t fnIdx, const VarIds&);
Params getFactorJoint (FacNode* fn, const VarIds&);
protected:
SPNodeInfo* ninf (const VarNode* var) const

View File

@ -81,7 +81,9 @@ CountingBp::solveQuery (VarIds queryVids)
for (size_t i = 0; i < queryVids.size(); i++) {
reprArgs.push_back (getRepresentative (queryVids[i]));
}
res = solver_->getFactorJoint (idx, reprArgs);
FacNode* reprFac = getRepresentative (facNodes[idx]);
assert (reprFac != 0);
res = solver_->getFactorJoint (reprFac, reprArgs);
}
}
return res;

View File

@ -8,7 +8,8 @@ LiftedBp::LiftedBp (const ParfactorList& pfList)
: pfList_(pfList)
{
refineParfactors();
solver_ = new WeightedBp (*getFactorGraph(), getWeights());
createFactorGraph();
solver_ = new WeightedBp (*fg_, getWeights());
}
@ -16,6 +17,7 @@ LiftedBp::LiftedBp (const ParfactorList& pfList)
LiftedBp::~LiftedBp (void)
{
delete solver_;
delete fg_;
}
@ -47,7 +49,7 @@ LiftedBp::solveQuery (const Grounds& query)
for (unsigned i = 0; i < groups.size(); i++) {
queryVids.push_back (groups[i]);
}
res = solver_->getFactorJoint (idx, queryVids);
res = solver_->getFactorJoint (fg_->facNodes()[idx], queryVids);
}
}
return res;
@ -131,10 +133,10 @@ LiftedBp::getQueryGroups (const Grounds& query)
FactorGraph*
LiftedBp::getFactorGraph (void)
void
LiftedBp::createFactorGraph (void)
{
FactorGraph* fg = new FactorGraph();
fg_ = new FactorGraph();
ParfactorList::const_iterator it = pfList_.begin();
for (; it != pfList_.end(); ++it) {
vector<PrvGroup> groups = (*it)->getAllGroups();
@ -142,9 +144,8 @@ LiftedBp::getFactorGraph (void)
for (size_t i = 0; i < groups.size(); i++) {
varIds.push_back (groups[i]);
}
fg->addFactor (Factor (varIds, (*it)->ranges(), (*it)->params()));
fg_->addFactor (Factor (varIds, (*it)->ranges(), (*it)->params()));
}
return fg;
}

View File

@ -24,7 +24,7 @@ class LiftedBp
vector<PrvGroup> getQueryGroups (const Grounds&);
FactorGraph* getFactorGraph (void);
void createFactorGraph (void);
vector<vector<unsigned>> getWeights (void) const;
@ -34,6 +34,7 @@ class LiftedBp
ParfactorList pfList_;
WeightedBp* solver_;
FactorGraph* fg_;
};