improve calculation of joint dist in counting bp

This commit is contained in:
Tiago Gomes 2012-06-13 12:47:41 +01:00
parent b43e3316b3
commit 0e83a75b60
6 changed files with 142 additions and 80 deletions

View File

@ -117,24 +117,36 @@ BeliefProp::getJointDistributionOf (const VarIds& jointVarIds)
}
if (idx == facNodes.size()) {
return getJointByConditioning (jointVarIds);
} else {
Factor res (facNodes[idx]->factor());
const BpLinks& links = ninf(facNodes[idx])->getLinks();
for (size_t i = 0; i < links.size(); i++) {
Factor msg ({links[i]->varNode()->varId()},
{links[i]->varNode()->range()},
getVarToFactorMsg (links[i]));
res.multiply (msg);
}
res.sumOutAllExcept (jointVarIds);
res.reorderArguments (jointVarIds);
res.normalize();
Params jointDist = res.params();
if (Globals::logDomain) {
Util::exp (jointDist);
}
return jointDist;
}
return getFactorJoint (facNodes[idx], jointVarIds);
}
Params
BeliefProp::getFactorJoint (
FacNode* fn,
const VarIds& jointVarIds)
{
if (runned_ == false) {
runSolver();
}
Factor res (fn->factor());
const BpLinks& links = ninf(fn)->getLinks();
for (size_t i = 0; i < links.size(); i++) {
Factor msg ({links[i]->varNode()->varId()},
{links[i]->varNode()->range()},
getVarToFactorMsg (links[i]));
res.multiply (msg);
}
res.sumOutAllExcept (jointVarIds);
res.reorderArguments (jointVarIds);
res.normalize();
Params jointDist = res.params();
if (Globals::logDomain) {
Util::exp (jointDist);
}
return jointDist;
}
@ -363,53 +375,7 @@ BeliefProp::getVarToFactorMsg (const BpLink* link) const
Params
BeliefProp::getJointByConditioning (const VarIds& jointVarIds) const
{
VarNodes jointVars;
for (size_t i = 0; i < jointVarIds.size(); i++) {
assert (fg.getVarNode (jointVarIds[i]));
jointVars.push_back (fg.getVarNode (jointVarIds[i]));
}
FactorGraph* tempFg = new FactorGraph (fg);
BeliefProp solver (*tempFg);
solver.runSolver();
Params prevBeliefs = solver.getPosterioriOf (jointVarIds[0]);
VarIds observedVids = {jointVars[0]->varId()};
for (size_t i = 1; i < jointVarIds.size(); i++) {
assert (jointVars[i]->hasEvidence() == false);
Params newBeliefs;
Vars observedVars;
Ranges observedRanges;
for (size_t j = 0; j < observedVids.size(); j++) {
observedVars.push_back (tempFg->getVarNode (observedVids[j]));
observedRanges.push_back (observedVars.back()->range());
}
Indexer indexer (observedRanges, false);
while (indexer.valid()) {
for (size_t j = 0; j < observedVars.size(); j++) {
observedVars[j]->setEvidence (indexer[j]);
}
BeliefProp solver (*tempFg);
solver.runSolver();
Params beliefs = solver.getPosterioriOf (jointVarIds[i]);
for (size_t k = 0; k < beliefs.size(); k++) {
newBeliefs.push_back (beliefs[k]);
}
++ indexer;
}
int count = -1;
for (size_t j = 0; j < newBeliefs.size(); j++) {
if (j % jointVars[i]->range() == 0) {
count ++;
}
newBeliefs[j] *= prevBeliefs[count];
}
prevBeliefs = newBeliefs;
observedVids.push_back (jointVars[i]->varId());
}
return prevBeliefs;
return Solver::getJointByConditioning (GroundSolver::BP, fg, jointVarIds);
}

View File

@ -111,6 +111,10 @@ class BeliefProp : public Solver
virtual Params getJointByConditioning (const VarIds&) const;
public:
Params getFactorJoint (FacNode*, const VarIds&);
protected:
SPNodeInfo* ninf (const VarNode* var) const
{
return varsI_[var->getIndex()];

View File

@ -74,16 +74,17 @@ CountingBp::solveQuery (VarIds queryVids)
cout << endl;
}
if (idx == facNodes.size()) {
cerr << "error: only joint distributions on variables of some " ;
cerr << "clique are supported with the current solver" ;
cerr << endl;
exit (1);
res = Solver::getJointByConditioning (
GroundSolver::CBP, fg, queryVids);
} else {
FacNode* reprFn = getRepresentative (facNodes[idx]);
assert (reprFn != 0);
VarIds reprArgs;
for (size_t i = 0; i < queryVids.size(); i++) {
reprArgs.push_back (getRepresentative (queryVids[i]));
}
res = solver_->getFactorJoint (reprFn, reprArgs);
}
VarIds representatives;
for (size_t i = 0; i < queryVids.size(); i++) {
representatives.push_back (getRepresentative (queryVids[i]));
}
res = solver_->getJointDistributionOf (representatives);
}
return res;
}
@ -292,6 +293,29 @@ CountingBp::getSignature (const FacNode* facNode)
VarId
CountingBp::getRepresentative (VarId vid)
{
assert (Util::contains (vid2VarCluster_, vid));
VarCluster* vc = vid2VarCluster_.find (vid)->second;
return vc->representative()->varId();
}
FacNode*
CountingBp::getRepresentative (FacNode* fn)
{
for (size_t i = 0; i < facClusters_.size(); i++) {
if (Util::contains (facClusters_[i]->members(), fn)) {
return facClusters_[i]->representative();
}
}
return 0;
}
FactorGraph*
CountingBp::getCompressedFactorGraph (void)
{

View File

@ -154,12 +154,9 @@ class CountingBp : public Solver
void printGroups (const VarSignMap&, const FacSignMap&) const;
VarId getRepresentative (VarId vid)
{
assert (Util::contains (vid2VarCluster_, vid));
VarCluster* vc = vid2VarCluster_.find (vid)->second;
return vc->representative()->varId();
}
VarId getRepresentative (VarId vid);
FacNode* getRepresentative (FacNode*);
FactorGraph* getCompressedFactorGraph (void);

View File

@ -1,5 +1,8 @@
#include "Solver.h"
#include "Util.h"
#include "BeliefProp.h"
#include "CountingBp.h"
#include "VarElim.h"
void
@ -38,3 +41,67 @@ Solver::printAllPosterioris (void)
}
}
Params
Solver::getJointByConditioning (
GroundSolver solverType,
FactorGraph fg,
const VarIds& jointVarIds) const
{
VarNodes jointVars;
for (size_t i = 0; i < jointVarIds.size(); i++) {
assert (fg.getVarNode (jointVarIds[i]));
jointVars.push_back (fg.getVarNode (jointVarIds[i]));
}
Solver* solver = 0;
switch (solverType) {
case GroundSolver::BP: solver = new BeliefProp (fg); break;
case GroundSolver::CBP: solver = new CountingBp (fg); break;
case GroundSolver::VE: solver = new VarElim (fg); break;
}
Params prevBeliefs = solver->solveQuery ({jointVarIds[0]});
VarIds observedVids = {jointVars[0]->varId()};
for (size_t i = 1; i < jointVarIds.size(); i++) {
assert (jointVars[i]->hasEvidence() == false);
Params newBeliefs;
Vars observedVars;
Ranges observedRanges;
for (size_t j = 0; j < observedVids.size(); j++) {
observedVars.push_back (fg.getVarNode (observedVids[j]));
observedRanges.push_back (observedVars.back()->range());
}
Indexer indexer (observedRanges, false);
while (indexer.valid()) {
for (size_t j = 0; j < observedVars.size(); j++) {
observedVars[j]->setEvidence (indexer[j]);
}
delete solver;
switch (solverType) {
case GroundSolver::BP: solver = new BeliefProp (fg); break;
case GroundSolver::CBP: solver = new CountingBp (fg); break;
case GroundSolver::VE: solver = new VarElim (fg); break;
}
Params beliefs = solver->solveQuery ({jointVarIds[i]});
for (size_t k = 0; k < beliefs.size(); k++) {
newBeliefs.push_back (beliefs[k]);
}
++ indexer;
}
int count = -1;
for (size_t j = 0; j < newBeliefs.size(); j++) {
if (j % jointVars[i]->range() == 0) {
count ++;
}
newBeliefs[j] *= prevBeliefs[count];
}
prevBeliefs = newBeliefs;
observedVids.push_back (jointVars[i]->varId());
}
delete solver;
return prevBeliefs;
}

View File

@ -3,8 +3,9 @@
#include <iomanip>
#include "Var.h"
#include "FactorGraph.h"
#include "Var.h"
#include "Horus.h"
using namespace std;
@ -23,6 +24,9 @@ class Solver
void printAnswer (const VarIds& vids);
void printAllPosterioris (void);
Params getJointByConditioning (GroundSolver,
FactorGraph, const VarIds& jointVarIds) const;
protected:
const FactorGraph& fg;