improve calculation of joint dist in counting bp
This commit is contained in:
parent
b43e3316b3
commit
0e83a75b60
@ -117,24 +117,36 @@ BeliefProp::getJointDistributionOf (const VarIds& jointVarIds)
|
|||||||
}
|
}
|
||||||
if (idx == facNodes.size()) {
|
if (idx == facNodes.size()) {
|
||||||
return getJointByConditioning (jointVarIds);
|
return getJointByConditioning (jointVarIds);
|
||||||
} else {
|
|
||||||
Factor res (facNodes[idx]->factor());
|
|
||||||
const BpLinks& links = ninf(facNodes[idx])->getLinks();
|
|
||||||
for (size_t i = 0; i < links.size(); i++) {
|
|
||||||
Factor msg ({links[i]->varNode()->varId()},
|
|
||||||
{links[i]->varNode()->range()},
|
|
||||||
getVarToFactorMsg (links[i]));
|
|
||||||
res.multiply (msg);
|
|
||||||
}
|
|
||||||
res.sumOutAllExcept (jointVarIds);
|
|
||||||
res.reorderArguments (jointVarIds);
|
|
||||||
res.normalize();
|
|
||||||
Params jointDist = res.params();
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
Util::exp (jointDist);
|
|
||||||
}
|
|
||||||
return jointDist;
|
|
||||||
}
|
}
|
||||||
|
return getFactorJoint (facNodes[idx], jointVarIds);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Params
|
||||||
|
BeliefProp::getFactorJoint (
|
||||||
|
FacNode* fn,
|
||||||
|
const VarIds& jointVarIds)
|
||||||
|
{
|
||||||
|
if (runned_ == false) {
|
||||||
|
runSolver();
|
||||||
|
}
|
||||||
|
Factor res (fn->factor());
|
||||||
|
const BpLinks& links = ninf(fn)->getLinks();
|
||||||
|
for (size_t i = 0; i < links.size(); i++) {
|
||||||
|
Factor msg ({links[i]->varNode()->varId()},
|
||||||
|
{links[i]->varNode()->range()},
|
||||||
|
getVarToFactorMsg (links[i]));
|
||||||
|
res.multiply (msg);
|
||||||
|
}
|
||||||
|
res.sumOutAllExcept (jointVarIds);
|
||||||
|
res.reorderArguments (jointVarIds);
|
||||||
|
res.normalize();
|
||||||
|
Params jointDist = res.params();
|
||||||
|
if (Globals::logDomain) {
|
||||||
|
Util::exp (jointDist);
|
||||||
|
}
|
||||||
|
return jointDist;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -363,53 +375,7 @@ BeliefProp::getVarToFactorMsg (const BpLink* link) const
|
|||||||
Params
|
Params
|
||||||
BeliefProp::getJointByConditioning (const VarIds& jointVarIds) const
|
BeliefProp::getJointByConditioning (const VarIds& jointVarIds) const
|
||||||
{
|
{
|
||||||
VarNodes jointVars;
|
return Solver::getJointByConditioning (GroundSolver::BP, fg, jointVarIds);
|
||||||
for (size_t i = 0; i < jointVarIds.size(); i++) {
|
|
||||||
assert (fg.getVarNode (jointVarIds[i]));
|
|
||||||
jointVars.push_back (fg.getVarNode (jointVarIds[i]));
|
|
||||||
}
|
|
||||||
|
|
||||||
FactorGraph* tempFg = new FactorGraph (fg);
|
|
||||||
BeliefProp solver (*tempFg);
|
|
||||||
solver.runSolver();
|
|
||||||
Params prevBeliefs = solver.getPosterioriOf (jointVarIds[0]);
|
|
||||||
|
|
||||||
VarIds observedVids = {jointVars[0]->varId()};
|
|
||||||
|
|
||||||
for (size_t i = 1; i < jointVarIds.size(); i++) {
|
|
||||||
assert (jointVars[i]->hasEvidence() == false);
|
|
||||||
Params newBeliefs;
|
|
||||||
Vars observedVars;
|
|
||||||
Ranges observedRanges;
|
|
||||||
for (size_t j = 0; j < observedVids.size(); j++) {
|
|
||||||
observedVars.push_back (tempFg->getVarNode (observedVids[j]));
|
|
||||||
observedRanges.push_back (observedVars.back()->range());
|
|
||||||
}
|
|
||||||
Indexer indexer (observedRanges, false);
|
|
||||||
while (indexer.valid()) {
|
|
||||||
for (size_t j = 0; j < observedVars.size(); j++) {
|
|
||||||
observedVars[j]->setEvidence (indexer[j]);
|
|
||||||
}
|
|
||||||
BeliefProp solver (*tempFg);
|
|
||||||
solver.runSolver();
|
|
||||||
Params beliefs = solver.getPosterioriOf (jointVarIds[i]);
|
|
||||||
for (size_t k = 0; k < beliefs.size(); k++) {
|
|
||||||
newBeliefs.push_back (beliefs[k]);
|
|
||||||
}
|
|
||||||
++ indexer;
|
|
||||||
}
|
|
||||||
|
|
||||||
int count = -1;
|
|
||||||
for (size_t j = 0; j < newBeliefs.size(); j++) {
|
|
||||||
if (j % jointVars[i]->range() == 0) {
|
|
||||||
count ++;
|
|
||||||
}
|
|
||||||
newBeliefs[j] *= prevBeliefs[count];
|
|
||||||
}
|
|
||||||
prevBeliefs = newBeliefs;
|
|
||||||
observedVids.push_back (jointVars[i]->varId());
|
|
||||||
}
|
|
||||||
return prevBeliefs;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -111,6 +111,10 @@ class BeliefProp : public Solver
|
|||||||
|
|
||||||
virtual Params getJointByConditioning (const VarIds&) const;
|
virtual Params getJointByConditioning (const VarIds&) const;
|
||||||
|
|
||||||
|
public:
|
||||||
|
Params getFactorJoint (FacNode*, const VarIds&);
|
||||||
|
|
||||||
|
protected:
|
||||||
SPNodeInfo* ninf (const VarNode* var) const
|
SPNodeInfo* ninf (const VarNode* var) const
|
||||||
{
|
{
|
||||||
return varsI_[var->getIndex()];
|
return varsI_[var->getIndex()];
|
||||||
|
@ -74,16 +74,17 @@ CountingBp::solveQuery (VarIds queryVids)
|
|||||||
cout << endl;
|
cout << endl;
|
||||||
}
|
}
|
||||||
if (idx == facNodes.size()) {
|
if (idx == facNodes.size()) {
|
||||||
cerr << "error: only joint distributions on variables of some " ;
|
res = Solver::getJointByConditioning (
|
||||||
cerr << "clique are supported with the current solver" ;
|
GroundSolver::CBP, fg, queryVids);
|
||||||
cerr << endl;
|
} else {
|
||||||
exit (1);
|
FacNode* reprFn = getRepresentative (facNodes[idx]);
|
||||||
|
assert (reprFn != 0);
|
||||||
|
VarIds reprArgs;
|
||||||
|
for (size_t i = 0; i < queryVids.size(); i++) {
|
||||||
|
reprArgs.push_back (getRepresentative (queryVids[i]));
|
||||||
|
}
|
||||||
|
res = solver_->getFactorJoint (reprFn, reprArgs);
|
||||||
}
|
}
|
||||||
VarIds representatives;
|
|
||||||
for (size_t i = 0; i < queryVids.size(); i++) {
|
|
||||||
representatives.push_back (getRepresentative (queryVids[i]));
|
|
||||||
}
|
|
||||||
res = solver_->getJointDistributionOf (representatives);
|
|
||||||
}
|
}
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
@ -292,6 +293,29 @@ CountingBp::getSignature (const FacNode* facNode)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
VarId
|
||||||
|
CountingBp::getRepresentative (VarId vid)
|
||||||
|
{
|
||||||
|
assert (Util::contains (vid2VarCluster_, vid));
|
||||||
|
VarCluster* vc = vid2VarCluster_.find (vid)->second;
|
||||||
|
return vc->representative()->varId();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
FacNode*
|
||||||
|
CountingBp::getRepresentative (FacNode* fn)
|
||||||
|
{
|
||||||
|
for (size_t i = 0; i < facClusters_.size(); i++) {
|
||||||
|
if (Util::contains (facClusters_[i]->members(), fn)) {
|
||||||
|
return facClusters_[i]->representative();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
FactorGraph*
|
FactorGraph*
|
||||||
CountingBp::getCompressedFactorGraph (void)
|
CountingBp::getCompressedFactorGraph (void)
|
||||||
{
|
{
|
||||||
|
@ -154,12 +154,9 @@ class CountingBp : public Solver
|
|||||||
|
|
||||||
void printGroups (const VarSignMap&, const FacSignMap&) const;
|
void printGroups (const VarSignMap&, const FacSignMap&) const;
|
||||||
|
|
||||||
VarId getRepresentative (VarId vid)
|
VarId getRepresentative (VarId vid);
|
||||||
{
|
|
||||||
assert (Util::contains (vid2VarCluster_, vid));
|
FacNode* getRepresentative (FacNode*);
|
||||||
VarCluster* vc = vid2VarCluster_.find (vid)->second;
|
|
||||||
return vc->representative()->varId();
|
|
||||||
}
|
|
||||||
|
|
||||||
FactorGraph* getCompressedFactorGraph (void);
|
FactorGraph* getCompressedFactorGraph (void);
|
||||||
|
|
||||||
|
@ -1,5 +1,8 @@
|
|||||||
#include "Solver.h"
|
#include "Solver.h"
|
||||||
#include "Util.h"
|
#include "Util.h"
|
||||||
|
#include "BeliefProp.h"
|
||||||
|
#include "CountingBp.h"
|
||||||
|
#include "VarElim.h"
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
@ -38,3 +41,67 @@ Solver::printAllPosterioris (void)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Params
|
||||||
|
Solver::getJointByConditioning (
|
||||||
|
GroundSolver solverType,
|
||||||
|
FactorGraph fg,
|
||||||
|
const VarIds& jointVarIds) const
|
||||||
|
{
|
||||||
|
VarNodes jointVars;
|
||||||
|
for (size_t i = 0; i < jointVarIds.size(); i++) {
|
||||||
|
assert (fg.getVarNode (jointVarIds[i]));
|
||||||
|
jointVars.push_back (fg.getVarNode (jointVarIds[i]));
|
||||||
|
}
|
||||||
|
|
||||||
|
Solver* solver = 0;
|
||||||
|
switch (solverType) {
|
||||||
|
case GroundSolver::BP: solver = new BeliefProp (fg); break;
|
||||||
|
case GroundSolver::CBP: solver = new CountingBp (fg); break;
|
||||||
|
case GroundSolver::VE: solver = new VarElim (fg); break;
|
||||||
|
}
|
||||||
|
Params prevBeliefs = solver->solveQuery ({jointVarIds[0]});
|
||||||
|
VarIds observedVids = {jointVars[0]->varId()};
|
||||||
|
|
||||||
|
for (size_t i = 1; i < jointVarIds.size(); i++) {
|
||||||
|
assert (jointVars[i]->hasEvidence() == false);
|
||||||
|
Params newBeliefs;
|
||||||
|
Vars observedVars;
|
||||||
|
Ranges observedRanges;
|
||||||
|
for (size_t j = 0; j < observedVids.size(); j++) {
|
||||||
|
observedVars.push_back (fg.getVarNode (observedVids[j]));
|
||||||
|
observedRanges.push_back (observedVars.back()->range());
|
||||||
|
}
|
||||||
|
Indexer indexer (observedRanges, false);
|
||||||
|
while (indexer.valid()) {
|
||||||
|
for (size_t j = 0; j < observedVars.size(); j++) {
|
||||||
|
observedVars[j]->setEvidence (indexer[j]);
|
||||||
|
}
|
||||||
|
delete solver;
|
||||||
|
switch (solverType) {
|
||||||
|
case GroundSolver::BP: solver = new BeliefProp (fg); break;
|
||||||
|
case GroundSolver::CBP: solver = new CountingBp (fg); break;
|
||||||
|
case GroundSolver::VE: solver = new VarElim (fg); break;
|
||||||
|
}
|
||||||
|
Params beliefs = solver->solveQuery ({jointVarIds[i]});
|
||||||
|
for (size_t k = 0; k < beliefs.size(); k++) {
|
||||||
|
newBeliefs.push_back (beliefs[k]);
|
||||||
|
}
|
||||||
|
++ indexer;
|
||||||
|
}
|
||||||
|
|
||||||
|
int count = -1;
|
||||||
|
for (size_t j = 0; j < newBeliefs.size(); j++) {
|
||||||
|
if (j % jointVars[i]->range() == 0) {
|
||||||
|
count ++;
|
||||||
|
}
|
||||||
|
newBeliefs[j] *= prevBeliefs[count];
|
||||||
|
}
|
||||||
|
prevBeliefs = newBeliefs;
|
||||||
|
observedVids.push_back (jointVars[i]->varId());
|
||||||
|
}
|
||||||
|
delete solver;
|
||||||
|
return prevBeliefs;
|
||||||
|
}
|
||||||
|
|
||||||
|
@ -3,8 +3,9 @@
|
|||||||
|
|
||||||
#include <iomanip>
|
#include <iomanip>
|
||||||
|
|
||||||
#include "Var.h"
|
|
||||||
#include "FactorGraph.h"
|
#include "FactorGraph.h"
|
||||||
|
#include "Var.h"
|
||||||
|
#include "Horus.h"
|
||||||
|
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
@ -23,6 +24,9 @@ class Solver
|
|||||||
void printAnswer (const VarIds& vids);
|
void printAnswer (const VarIds& vids);
|
||||||
|
|
||||||
void printAllPosterioris (void);
|
void printAllPosterioris (void);
|
||||||
|
|
||||||
|
Params getJointByConditioning (GroundSolver,
|
||||||
|
FactorGraph, const VarIds& jointVarIds) const;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
const FactorGraph& fg;
|
const FactorGraph& fg;
|
||||||
|
Reference in New Issue
Block a user