improve calculation of joint dist in counting bp

This commit is contained in:
Tiago Gomes
2012-06-13 12:47:41 +01:00
parent b43e3316b3
commit 0e83a75b60
6 changed files with 142 additions and 80 deletions

View File

@@ -74,16 +74,17 @@ CountingBp::solveQuery (VarIds queryVids)
cout << endl;
}
if (idx == facNodes.size()) {
cerr << "error: only joint distributions on variables of some " ;
cerr << "clique are supported with the current solver" ;
cerr << endl;
exit (1);
res = Solver::getJointByConditioning (
GroundSolver::CBP, fg, queryVids);
} else {
FacNode* reprFn = getRepresentative (facNodes[idx]);
assert (reprFn != 0);
VarIds reprArgs;
for (size_t i = 0; i < queryVids.size(); i++) {
reprArgs.push_back (getRepresentative (queryVids[i]));
}
res = solver_->getFactorJoint (reprFn, reprArgs);
}
VarIds representatives;
for (size_t i = 0; i < queryVids.size(); i++) {
representatives.push_back (getRepresentative (queryVids[i]));
}
res = solver_->getJointDistributionOf (representatives);
}
return res;
}
@@ -292,6 +293,29 @@ CountingBp::getSignature (const FacNode* facNode)
VarId
CountingBp::getRepresentative (VarId vid)
{
assert (Util::contains (vid2VarCluster_, vid));
VarCluster* vc = vid2VarCluster_.find (vid)->second;
return vc->representative()->varId();
}
FacNode*
CountingBp::getRepresentative (FacNode* fn)
{
for (size_t i = 0; i < facClusters_.size(); i++) {
if (Util::contains (facClusters_[i]->members(), fn)) {
return facClusters_[i]->representative();
}
}
return 0;
}
FactorGraph*
CountingBp::getCompressedFactorGraph (void)
{