improve calculation of joint dist in counting bp

This commit is contained in:
Tiago Gomes
2012-06-13 12:47:41 +01:00
parent b43e3316b3
commit 0e83a75b60
6 changed files with 142 additions and 80 deletions

View File

@@ -117,24 +117,36 @@ BeliefProp::getJointDistributionOf (const VarIds& jointVarIds)
}
if (idx == facNodes.size()) {
return getJointByConditioning (jointVarIds);
} else {
Factor res (facNodes[idx]->factor());
const BpLinks& links = ninf(facNodes[idx])->getLinks();
for (size_t i = 0; i < links.size(); i++) {
Factor msg ({links[i]->varNode()->varId()},
{links[i]->varNode()->range()},
getVarToFactorMsg (links[i]));
res.multiply (msg);
}
res.sumOutAllExcept (jointVarIds);
res.reorderArguments (jointVarIds);
res.normalize();
Params jointDist = res.params();
if (Globals::logDomain) {
Util::exp (jointDist);
}
return jointDist;
}
return getFactorJoint (facNodes[idx], jointVarIds);
}
Params
BeliefProp::getFactorJoint (
FacNode* fn,
const VarIds& jointVarIds)
{
if (runned_ == false) {
runSolver();
}
Factor res (fn->factor());
const BpLinks& links = ninf(fn)->getLinks();
for (size_t i = 0; i < links.size(); i++) {
Factor msg ({links[i]->varNode()->varId()},
{links[i]->varNode()->range()},
getVarToFactorMsg (links[i]));
res.multiply (msg);
}
res.sumOutAllExcept (jointVarIds);
res.reorderArguments (jointVarIds);
res.normalize();
Params jointDist = res.params();
if (Globals::logDomain) {
Util::exp (jointDist);
}
return jointDist;
}
@@ -363,53 +375,7 @@ BeliefProp::getVarToFactorMsg (const BpLink* link) const
Params
BeliefProp::getJointByConditioning (const VarIds& jointVarIds) const
{
VarNodes jointVars;
for (size_t i = 0; i < jointVarIds.size(); i++) {
assert (fg.getVarNode (jointVarIds[i]));
jointVars.push_back (fg.getVarNode (jointVarIds[i]));
}
FactorGraph* tempFg = new FactorGraph (fg);
BeliefProp solver (*tempFg);
solver.runSolver();
Params prevBeliefs = solver.getPosterioriOf (jointVarIds[0]);
VarIds observedVids = {jointVars[0]->varId()};
for (size_t i = 1; i < jointVarIds.size(); i++) {
assert (jointVars[i]->hasEvidence() == false);
Params newBeliefs;
Vars observedVars;
Ranges observedRanges;
for (size_t j = 0; j < observedVids.size(); j++) {
observedVars.push_back (tempFg->getVarNode (observedVids[j]));
observedRanges.push_back (observedVars.back()->range());
}
Indexer indexer (observedRanges, false);
while (indexer.valid()) {
for (size_t j = 0; j < observedVars.size(); j++) {
observedVars[j]->setEvidence (indexer[j]);
}
BeliefProp solver (*tempFg);
solver.runSolver();
Params beliefs = solver.getPosterioriOf (jointVarIds[i]);
for (size_t k = 0; k < beliefs.size(); k++) {
newBeliefs.push_back (beliefs[k]);
}
++ indexer;
}
int count = -1;
for (size_t j = 0; j < newBeliefs.size(); j++) {
if (j % jointVars[i]->range() == 0) {
count ++;
}
newBeliefs[j] *= prevBeliefs[count];
}
prevBeliefs = newBeliefs;
observedVids.push_back (jointVars[i]->varId());
}
return prevBeliefs;
return Solver::getJointByConditioning (GroundSolver::BP, fg, jointVarIds);
}