improve calculation of joint dist in counting bp
This commit is contained in:
@@ -74,16 +74,17 @@ CountingBp::solveQuery (VarIds queryVids)
|
||||
cout << endl;
|
||||
}
|
||||
if (idx == facNodes.size()) {
|
||||
cerr << "error: only joint distributions on variables of some " ;
|
||||
cerr << "clique are supported with the current solver" ;
|
||||
cerr << endl;
|
||||
exit (1);
|
||||
res = Solver::getJointByConditioning (
|
||||
GroundSolver::CBP, fg, queryVids);
|
||||
} else {
|
||||
FacNode* reprFn = getRepresentative (facNodes[idx]);
|
||||
assert (reprFn != 0);
|
||||
VarIds reprArgs;
|
||||
for (size_t i = 0; i < queryVids.size(); i++) {
|
||||
reprArgs.push_back (getRepresentative (queryVids[i]));
|
||||
}
|
||||
res = solver_->getFactorJoint (reprFn, reprArgs);
|
||||
}
|
||||
VarIds representatives;
|
||||
for (size_t i = 0; i < queryVids.size(); i++) {
|
||||
representatives.push_back (getRepresentative (queryVids[i]));
|
||||
}
|
||||
res = solver_->getJointDistributionOf (representatives);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
@@ -292,6 +293,29 @@ CountingBp::getSignature (const FacNode* facNode)
|
||||
|
||||
|
||||
|
||||
VarId
|
||||
CountingBp::getRepresentative (VarId vid)
|
||||
{
|
||||
assert (Util::contains (vid2VarCluster_, vid));
|
||||
VarCluster* vc = vid2VarCluster_.find (vid)->second;
|
||||
return vc->representative()->varId();
|
||||
}
|
||||
|
||||
|
||||
|
||||
FacNode*
|
||||
CountingBp::getRepresentative (FacNode* fn)
|
||||
{
|
||||
for (size_t i = 0; i < facClusters_.size(); i++) {
|
||||
if (Util::contains (facClusters_[i]->members(), fn)) {
|
||||
return facClusters_[i]->representative();
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
|
||||
FactorGraph*
|
||||
CountingBp::getCompressedFactorGraph (void)
|
||||
{
|
||||
|
Reference in New Issue
Block a user