diff --git a/packages/CLPBN/horus/BeliefProp.cpp b/packages/CLPBN/horus/BeliefProp.cpp index 4ce70fef8..3d2237c49 100644 --- a/packages/CLPBN/horus/BeliefProp.cpp +++ b/packages/CLPBN/horus/BeliefProp.cpp @@ -117,24 +117,36 @@ BeliefProp::getJointDistributionOf (const VarIds& jointVarIds) } if (idx == facNodes.size()) { return getJointByConditioning (jointVarIds); - } else { - Factor res (facNodes[idx]->factor()); - const BpLinks& links = ninf(facNodes[idx])->getLinks(); - for (size_t i = 0; i < links.size(); i++) { - Factor msg ({links[i]->varNode()->varId()}, - {links[i]->varNode()->range()}, - getVarToFactorMsg (links[i])); - res.multiply (msg); - } - res.sumOutAllExcept (jointVarIds); - res.reorderArguments (jointVarIds); - res.normalize(); - Params jointDist = res.params(); - if (Globals::logDomain) { - Util::exp (jointDist); - } - return jointDist; } + return getFactorJoint (facNodes[idx], jointVarIds); +} + + + +Params +BeliefProp::getFactorJoint ( + FacNode* fn, + const VarIds& jointVarIds) +{ + if (runned_ == false) { + runSolver(); + } + Factor res (fn->factor()); + const BpLinks& links = ninf(fn)->getLinks(); + for (size_t i = 0; i < links.size(); i++) { + Factor msg ({links[i]->varNode()->varId()}, + {links[i]->varNode()->range()}, + getVarToFactorMsg (links[i])); + res.multiply (msg); + } + res.sumOutAllExcept (jointVarIds); + res.reorderArguments (jointVarIds); + res.normalize(); + Params jointDist = res.params(); + if (Globals::logDomain) { + Util::exp (jointDist); + } + return jointDist; } @@ -363,53 +375,7 @@ BeliefProp::getVarToFactorMsg (const BpLink* link) const Params BeliefProp::getJointByConditioning (const VarIds& jointVarIds) const { - VarNodes jointVars; - for (size_t i = 0; i < jointVarIds.size(); i++) { - assert (fg.getVarNode (jointVarIds[i])); - jointVars.push_back (fg.getVarNode (jointVarIds[i])); - } - - FactorGraph* tempFg = new FactorGraph (fg); - BeliefProp solver (*tempFg); - solver.runSolver(); - Params prevBeliefs = solver.getPosterioriOf (jointVarIds[0]); - - VarIds observedVids = {jointVars[0]->varId()}; - - for (size_t i = 1; i < jointVarIds.size(); i++) { - assert (jointVars[i]->hasEvidence() == false); - Params newBeliefs; - Vars observedVars; - Ranges observedRanges; - for (size_t j = 0; j < observedVids.size(); j++) { - observedVars.push_back (tempFg->getVarNode (observedVids[j])); - observedRanges.push_back (observedVars.back()->range()); - } - Indexer indexer (observedRanges, false); - while (indexer.valid()) { - for (size_t j = 0; j < observedVars.size(); j++) { - observedVars[j]->setEvidence (indexer[j]); - } - BeliefProp solver (*tempFg); - solver.runSolver(); - Params beliefs = solver.getPosterioriOf (jointVarIds[i]); - for (size_t k = 0; k < beliefs.size(); k++) { - newBeliefs.push_back (beliefs[k]); - } - ++ indexer; - } - - int count = -1; - for (size_t j = 0; j < newBeliefs.size(); j++) { - if (j % jointVars[i]->range() == 0) { - count ++; - } - newBeliefs[j] *= prevBeliefs[count]; - } - prevBeliefs = newBeliefs; - observedVids.push_back (jointVars[i]->varId()); - } - return prevBeliefs; + return Solver::getJointByConditioning (GroundSolver::BP, fg, jointVarIds); } diff --git a/packages/CLPBN/horus/BeliefProp.h b/packages/CLPBN/horus/BeliefProp.h index 767767de9..af8da9a23 100644 --- a/packages/CLPBN/horus/BeliefProp.h +++ b/packages/CLPBN/horus/BeliefProp.h @@ -111,6 +111,10 @@ class BeliefProp : public Solver virtual Params getJointByConditioning (const VarIds&) const; + public: + Params getFactorJoint (FacNode*, const VarIds&); + + protected: SPNodeInfo* ninf (const VarNode* var) const { return varsI_[var->getIndex()]; diff --git a/packages/CLPBN/horus/CountingBp.cpp b/packages/CLPBN/horus/CountingBp.cpp index f09faef2c..ffd1abc64 100644 --- a/packages/CLPBN/horus/CountingBp.cpp +++ b/packages/CLPBN/horus/CountingBp.cpp @@ -74,16 +74,17 @@ CountingBp::solveQuery (VarIds queryVids) cout << endl; } if (idx == facNodes.size()) { - cerr << "error: only joint distributions on variables of some " ; - cerr << "clique are supported with the current solver" ; - cerr << endl; - exit (1); + res = Solver::getJointByConditioning ( + GroundSolver::CBP, fg, queryVids); + } else { + FacNode* reprFn = getRepresentative (facNodes[idx]); + assert (reprFn != 0); + VarIds reprArgs; + for (size_t i = 0; i < queryVids.size(); i++) { + reprArgs.push_back (getRepresentative (queryVids[i])); + } + res = solver_->getFactorJoint (reprFn, reprArgs); } - VarIds representatives; - for (size_t i = 0; i < queryVids.size(); i++) { - representatives.push_back (getRepresentative (queryVids[i])); - } - res = solver_->getJointDistributionOf (representatives); } return res; } @@ -292,6 +293,29 @@ CountingBp::getSignature (const FacNode* facNode) +VarId +CountingBp::getRepresentative (VarId vid) +{ + assert (Util::contains (vid2VarCluster_, vid)); + VarCluster* vc = vid2VarCluster_.find (vid)->second; + return vc->representative()->varId(); +} + + + +FacNode* +CountingBp::getRepresentative (FacNode* fn) +{ + for (size_t i = 0; i < facClusters_.size(); i++) { + if (Util::contains (facClusters_[i]->members(), fn)) { + return facClusters_[i]->representative(); + } + } + return 0; +} + + + FactorGraph* CountingBp::getCompressedFactorGraph (void) { diff --git a/packages/CLPBN/horus/CountingBp.h b/packages/CLPBN/horus/CountingBp.h index 894f4ceeb..d54a47fee 100644 --- a/packages/CLPBN/horus/CountingBp.h +++ b/packages/CLPBN/horus/CountingBp.h @@ -154,12 +154,9 @@ class CountingBp : public Solver void printGroups (const VarSignMap&, const FacSignMap&) const; - VarId getRepresentative (VarId vid) - { - assert (Util::contains (vid2VarCluster_, vid)); - VarCluster* vc = vid2VarCluster_.find (vid)->second; - return vc->representative()->varId(); - } + VarId getRepresentative (VarId vid); + + FacNode* getRepresentative (FacNode*); FactorGraph* getCompressedFactorGraph (void); diff --git a/packages/CLPBN/horus/Solver.cpp b/packages/CLPBN/horus/Solver.cpp index 4f1b52d5b..4cb3b6768 100644 --- a/packages/CLPBN/horus/Solver.cpp +++ b/packages/CLPBN/horus/Solver.cpp @@ -1,5 +1,8 @@ #include "Solver.h" #include "Util.h" +#include "BeliefProp.h" +#include "CountingBp.h" +#include "VarElim.h" void @@ -38,3 +41,67 @@ Solver::printAllPosterioris (void) } } + + +Params +Solver::getJointByConditioning ( + GroundSolver solverType, + FactorGraph fg, + const VarIds& jointVarIds) const +{ + VarNodes jointVars; + for (size_t i = 0; i < jointVarIds.size(); i++) { + assert (fg.getVarNode (jointVarIds[i])); + jointVars.push_back (fg.getVarNode (jointVarIds[i])); + } + + Solver* solver = 0; + switch (solverType) { + case GroundSolver::BP: solver = new BeliefProp (fg); break; + case GroundSolver::CBP: solver = new CountingBp (fg); break; + case GroundSolver::VE: solver = new VarElim (fg); break; + } + Params prevBeliefs = solver->solveQuery ({jointVarIds[0]}); + VarIds observedVids = {jointVars[0]->varId()}; + + for (size_t i = 1; i < jointVarIds.size(); i++) { + assert (jointVars[i]->hasEvidence() == false); + Params newBeliefs; + Vars observedVars; + Ranges observedRanges; + for (size_t j = 0; j < observedVids.size(); j++) { + observedVars.push_back (fg.getVarNode (observedVids[j])); + observedRanges.push_back (observedVars.back()->range()); + } + Indexer indexer (observedRanges, false); + while (indexer.valid()) { + for (size_t j = 0; j < observedVars.size(); j++) { + observedVars[j]->setEvidence (indexer[j]); + } + delete solver; + switch (solverType) { + case GroundSolver::BP: solver = new BeliefProp (fg); break; + case GroundSolver::CBP: solver = new CountingBp (fg); break; + case GroundSolver::VE: solver = new VarElim (fg); break; + } + Params beliefs = solver->solveQuery ({jointVarIds[i]}); + for (size_t k = 0; k < beliefs.size(); k++) { + newBeliefs.push_back (beliefs[k]); + } + ++ indexer; + } + + int count = -1; + for (size_t j = 0; j < newBeliefs.size(); j++) { + if (j % jointVars[i]->range() == 0) { + count ++; + } + newBeliefs[j] *= prevBeliefs[count]; + } + prevBeliefs = newBeliefs; + observedVids.push_back (jointVars[i]->varId()); + } + delete solver; + return prevBeliefs; +} + diff --git a/packages/CLPBN/horus/Solver.h b/packages/CLPBN/horus/Solver.h index cc22795d1..a378b2419 100644 --- a/packages/CLPBN/horus/Solver.h +++ b/packages/CLPBN/horus/Solver.h @@ -3,8 +3,9 @@ #include -#include "Var.h" #include "FactorGraph.h" +#include "Var.h" +#include "Horus.h" using namespace std; @@ -23,6 +24,9 @@ class Solver void printAnswer (const VarIds& vids); void printAllPosterioris (void); + + Params getJointByConditioning (GroundSolver, + FactorGraph, const VarIds& jointVarIds) const; protected: const FactorGraph& fg;