improve calculation of joint dist in counting bp
This commit is contained in:
parent
b43e3316b3
commit
0e83a75b60
@ -117,24 +117,36 @@ BeliefProp::getJointDistributionOf (const VarIds& jointVarIds)
|
||||
}
|
||||
if (idx == facNodes.size()) {
|
||||
return getJointByConditioning (jointVarIds);
|
||||
} else {
|
||||
Factor res (facNodes[idx]->factor());
|
||||
const BpLinks& links = ninf(facNodes[idx])->getLinks();
|
||||
for (size_t i = 0; i < links.size(); i++) {
|
||||
Factor msg ({links[i]->varNode()->varId()},
|
||||
{links[i]->varNode()->range()},
|
||||
getVarToFactorMsg (links[i]));
|
||||
res.multiply (msg);
|
||||
}
|
||||
res.sumOutAllExcept (jointVarIds);
|
||||
res.reorderArguments (jointVarIds);
|
||||
res.normalize();
|
||||
Params jointDist = res.params();
|
||||
if (Globals::logDomain) {
|
||||
Util::exp (jointDist);
|
||||
}
|
||||
return jointDist;
|
||||
}
|
||||
return getFactorJoint (facNodes[idx], jointVarIds);
|
||||
}
|
||||
|
||||
|
||||
|
||||
Params
|
||||
BeliefProp::getFactorJoint (
|
||||
FacNode* fn,
|
||||
const VarIds& jointVarIds)
|
||||
{
|
||||
if (runned_ == false) {
|
||||
runSolver();
|
||||
}
|
||||
Factor res (fn->factor());
|
||||
const BpLinks& links = ninf(fn)->getLinks();
|
||||
for (size_t i = 0; i < links.size(); i++) {
|
||||
Factor msg ({links[i]->varNode()->varId()},
|
||||
{links[i]->varNode()->range()},
|
||||
getVarToFactorMsg (links[i]));
|
||||
res.multiply (msg);
|
||||
}
|
||||
res.sumOutAllExcept (jointVarIds);
|
||||
res.reorderArguments (jointVarIds);
|
||||
res.normalize();
|
||||
Params jointDist = res.params();
|
||||
if (Globals::logDomain) {
|
||||
Util::exp (jointDist);
|
||||
}
|
||||
return jointDist;
|
||||
}
|
||||
|
||||
|
||||
@ -363,53 +375,7 @@ BeliefProp::getVarToFactorMsg (const BpLink* link) const
|
||||
Params
|
||||
BeliefProp::getJointByConditioning (const VarIds& jointVarIds) const
|
||||
{
|
||||
VarNodes jointVars;
|
||||
for (size_t i = 0; i < jointVarIds.size(); i++) {
|
||||
assert (fg.getVarNode (jointVarIds[i]));
|
||||
jointVars.push_back (fg.getVarNode (jointVarIds[i]));
|
||||
}
|
||||
|
||||
FactorGraph* tempFg = new FactorGraph (fg);
|
||||
BeliefProp solver (*tempFg);
|
||||
solver.runSolver();
|
||||
Params prevBeliefs = solver.getPosterioriOf (jointVarIds[0]);
|
||||
|
||||
VarIds observedVids = {jointVars[0]->varId()};
|
||||
|
||||
for (size_t i = 1; i < jointVarIds.size(); i++) {
|
||||
assert (jointVars[i]->hasEvidence() == false);
|
||||
Params newBeliefs;
|
||||
Vars observedVars;
|
||||
Ranges observedRanges;
|
||||
for (size_t j = 0; j < observedVids.size(); j++) {
|
||||
observedVars.push_back (tempFg->getVarNode (observedVids[j]));
|
||||
observedRanges.push_back (observedVars.back()->range());
|
||||
}
|
||||
Indexer indexer (observedRanges, false);
|
||||
while (indexer.valid()) {
|
||||
for (size_t j = 0; j < observedVars.size(); j++) {
|
||||
observedVars[j]->setEvidence (indexer[j]);
|
||||
}
|
||||
BeliefProp solver (*tempFg);
|
||||
solver.runSolver();
|
||||
Params beliefs = solver.getPosterioriOf (jointVarIds[i]);
|
||||
for (size_t k = 0; k < beliefs.size(); k++) {
|
||||
newBeliefs.push_back (beliefs[k]);
|
||||
}
|
||||
++ indexer;
|
||||
}
|
||||
|
||||
int count = -1;
|
||||
for (size_t j = 0; j < newBeliefs.size(); j++) {
|
||||
if (j % jointVars[i]->range() == 0) {
|
||||
count ++;
|
||||
}
|
||||
newBeliefs[j] *= prevBeliefs[count];
|
||||
}
|
||||
prevBeliefs = newBeliefs;
|
||||
observedVids.push_back (jointVars[i]->varId());
|
||||
}
|
||||
return prevBeliefs;
|
||||
return Solver::getJointByConditioning (GroundSolver::BP, fg, jointVarIds);
|
||||
}
|
||||
|
||||
|
||||
|
@ -111,6 +111,10 @@ class BeliefProp : public Solver
|
||||
|
||||
virtual Params getJointByConditioning (const VarIds&) const;
|
||||
|
||||
public:
|
||||
Params getFactorJoint (FacNode*, const VarIds&);
|
||||
|
||||
protected:
|
||||
SPNodeInfo* ninf (const VarNode* var) const
|
||||
{
|
||||
return varsI_[var->getIndex()];
|
||||
|
@ -74,16 +74,17 @@ CountingBp::solveQuery (VarIds queryVids)
|
||||
cout << endl;
|
||||
}
|
||||
if (idx == facNodes.size()) {
|
||||
cerr << "error: only joint distributions on variables of some " ;
|
||||
cerr << "clique are supported with the current solver" ;
|
||||
cerr << endl;
|
||||
exit (1);
|
||||
res = Solver::getJointByConditioning (
|
||||
GroundSolver::CBP, fg, queryVids);
|
||||
} else {
|
||||
FacNode* reprFn = getRepresentative (facNodes[idx]);
|
||||
assert (reprFn != 0);
|
||||
VarIds reprArgs;
|
||||
for (size_t i = 0; i < queryVids.size(); i++) {
|
||||
reprArgs.push_back (getRepresentative (queryVids[i]));
|
||||
}
|
||||
res = solver_->getFactorJoint (reprFn, reprArgs);
|
||||
}
|
||||
VarIds representatives;
|
||||
for (size_t i = 0; i < queryVids.size(); i++) {
|
||||
representatives.push_back (getRepresentative (queryVids[i]));
|
||||
}
|
||||
res = solver_->getJointDistributionOf (representatives);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
@ -292,6 +293,29 @@ CountingBp::getSignature (const FacNode* facNode)
|
||||
|
||||
|
||||
|
||||
VarId
|
||||
CountingBp::getRepresentative (VarId vid)
|
||||
{
|
||||
assert (Util::contains (vid2VarCluster_, vid));
|
||||
VarCluster* vc = vid2VarCluster_.find (vid)->second;
|
||||
return vc->representative()->varId();
|
||||
}
|
||||
|
||||
|
||||
|
||||
FacNode*
|
||||
CountingBp::getRepresentative (FacNode* fn)
|
||||
{
|
||||
for (size_t i = 0; i < facClusters_.size(); i++) {
|
||||
if (Util::contains (facClusters_[i]->members(), fn)) {
|
||||
return facClusters_[i]->representative();
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
|
||||
FactorGraph*
|
||||
CountingBp::getCompressedFactorGraph (void)
|
||||
{
|
||||
|
@ -154,12 +154,9 @@ class CountingBp : public Solver
|
||||
|
||||
void printGroups (const VarSignMap&, const FacSignMap&) const;
|
||||
|
||||
VarId getRepresentative (VarId vid)
|
||||
{
|
||||
assert (Util::contains (vid2VarCluster_, vid));
|
||||
VarCluster* vc = vid2VarCluster_.find (vid)->second;
|
||||
return vc->representative()->varId();
|
||||
}
|
||||
VarId getRepresentative (VarId vid);
|
||||
|
||||
FacNode* getRepresentative (FacNode*);
|
||||
|
||||
FactorGraph* getCompressedFactorGraph (void);
|
||||
|
||||
|
@ -1,5 +1,8 @@
|
||||
#include "Solver.h"
|
||||
#include "Util.h"
|
||||
#include "BeliefProp.h"
|
||||
#include "CountingBp.h"
|
||||
#include "VarElim.h"
|
||||
|
||||
|
||||
void
|
||||
@ -38,3 +41,67 @@ Solver::printAllPosterioris (void)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
Params
|
||||
Solver::getJointByConditioning (
|
||||
GroundSolver solverType,
|
||||
FactorGraph fg,
|
||||
const VarIds& jointVarIds) const
|
||||
{
|
||||
VarNodes jointVars;
|
||||
for (size_t i = 0; i < jointVarIds.size(); i++) {
|
||||
assert (fg.getVarNode (jointVarIds[i]));
|
||||
jointVars.push_back (fg.getVarNode (jointVarIds[i]));
|
||||
}
|
||||
|
||||
Solver* solver = 0;
|
||||
switch (solverType) {
|
||||
case GroundSolver::BP: solver = new BeliefProp (fg); break;
|
||||
case GroundSolver::CBP: solver = new CountingBp (fg); break;
|
||||
case GroundSolver::VE: solver = new VarElim (fg); break;
|
||||
}
|
||||
Params prevBeliefs = solver->solveQuery ({jointVarIds[0]});
|
||||
VarIds observedVids = {jointVars[0]->varId()};
|
||||
|
||||
for (size_t i = 1; i < jointVarIds.size(); i++) {
|
||||
assert (jointVars[i]->hasEvidence() == false);
|
||||
Params newBeliefs;
|
||||
Vars observedVars;
|
||||
Ranges observedRanges;
|
||||
for (size_t j = 0; j < observedVids.size(); j++) {
|
||||
observedVars.push_back (fg.getVarNode (observedVids[j]));
|
||||
observedRanges.push_back (observedVars.back()->range());
|
||||
}
|
||||
Indexer indexer (observedRanges, false);
|
||||
while (indexer.valid()) {
|
||||
for (size_t j = 0; j < observedVars.size(); j++) {
|
||||
observedVars[j]->setEvidence (indexer[j]);
|
||||
}
|
||||
delete solver;
|
||||
switch (solverType) {
|
||||
case GroundSolver::BP: solver = new BeliefProp (fg); break;
|
||||
case GroundSolver::CBP: solver = new CountingBp (fg); break;
|
||||
case GroundSolver::VE: solver = new VarElim (fg); break;
|
||||
}
|
||||
Params beliefs = solver->solveQuery ({jointVarIds[i]});
|
||||
for (size_t k = 0; k < beliefs.size(); k++) {
|
||||
newBeliefs.push_back (beliefs[k]);
|
||||
}
|
||||
++ indexer;
|
||||
}
|
||||
|
||||
int count = -1;
|
||||
for (size_t j = 0; j < newBeliefs.size(); j++) {
|
||||
if (j % jointVars[i]->range() == 0) {
|
||||
count ++;
|
||||
}
|
||||
newBeliefs[j] *= prevBeliefs[count];
|
||||
}
|
||||
prevBeliefs = newBeliefs;
|
||||
observedVids.push_back (jointVars[i]->varId());
|
||||
}
|
||||
delete solver;
|
||||
return prevBeliefs;
|
||||
}
|
||||
|
||||
|
@ -3,8 +3,9 @@
|
||||
|
||||
#include <iomanip>
|
||||
|
||||
#include "Var.h"
|
||||
#include "FactorGraph.h"
|
||||
#include "Var.h"
|
||||
#include "Horus.h"
|
||||
|
||||
|
||||
using namespace std;
|
||||
@ -23,6 +24,9 @@ class Solver
|
||||
void printAnswer (const VarIds& vids);
|
||||
|
||||
void printAllPosterioris (void);
|
||||
|
||||
Params getJointByConditioning (GroundSolver,
|
||||
FactorGraph, const VarIds& jointVarIds) const;
|
||||
|
||||
protected:
|
||||
const FactorGraph& fg;
|
||||
|
Reference in New Issue
Block a user