diff --git a/packages/CLPBN/horus/CbpSolver.cpp b/packages/CLPBN/horus/CbpSolver.cpp index 619c73263..e8fe349f0 100644 --- a/packages/CLPBN/horus/CbpSolver.cpp +++ b/packages/CLPBN/horus/CbpSolver.cpp @@ -59,29 +59,17 @@ Params CbpSolver::solveQuery (VarIds queryVids) { assert (queryVids.empty() == false); - return queryVids.size() == 1 - ? getPosterioriOf (queryVids[0]) - : getJointDistributionOf (queryVids); -} - - - -Params -CbpSolver::getPosterioriOf (VarId vid) -{ - return solver_->getPosterioriOf (getRepresentative (vid)); -} - - - -Params -CbpSolver::getJointDistributionOf (const VarIds& jointVids) -{ - VarIds representatives; - for (size_t i = 0; i < jointVids.size(); i++) { - representatives.push_back (getRepresentative (jointVids[i])); + Params res; + if (queryVids.size() == 1) { + res = solver_->getPosterioriOf (getRepresentative (queryVids[0])); + } else { + VarIds representatives; + for (size_t i = 0; i < queryVids.size(); i++) { + representatives.push_back (getRepresentative (queryVids[i])); + } + res = solver_->getJointDistributionOf (representatives); } - return solver_->getJointDistributionOf (representatives); + return res; } diff --git a/packages/CLPBN/horus/CbpSolver.h b/packages/CLPBN/horus/CbpSolver.h index 38078712f..b2d789a58 100644 --- a/packages/CLPBN/horus/CbpSolver.h +++ b/packages/CLPBN/horus/CbpSolver.h @@ -58,7 +58,6 @@ struct FacSignHash }; - class VarCluster { public: @@ -78,7 +77,6 @@ class VarCluster }; - class FacCluster { public: @@ -102,7 +100,6 @@ class FacCluster }; - class CbpSolver : public Solver { public: @@ -113,15 +110,10 @@ class CbpSolver : public Solver void printSolverFlags (void) const; Params solveQuery (VarIds); - - Params getPosterioriOf (VarId); - - Params getJointDistributionOf (const VarIds&); - + static bool checkForIdenticalFactors; private: - Color getNewColor (void) { ++ freeColor_; diff --git a/packages/CLPBN/horus/FoveSolver.cpp b/packages/CLPBN/horus/FoveSolver.cpp index ddc42b4b0..65f348e7d 100644 --- a/packages/CLPBN/horus/FoveSolver.cpp +++ b/packages/CLPBN/horus/FoveSolver.cpp @@ -630,16 +630,9 @@ GroundOperator::getAffectedFormulas (void) Params -FoveSolver::getPosterioriOf (const Ground& query) -{ - return getJointDistributionOf ({query}); -} - - - -Params -FoveSolver::getJointDistributionOf (const Grounds& query) +FoveSolver::solveQuery (const Grounds& query) { + assert (query.empty() == false); runSolver (query); (*pfList_.begin())->normalize(); Params params = (*pfList_.begin())->params(); diff --git a/packages/CLPBN/horus/FoveSolver.h b/packages/CLPBN/horus/FoveSolver.h index 4fab77d82..ed5ff4cf0 100644 --- a/packages/CLPBN/horus/FoveSolver.h +++ b/packages/CLPBN/horus/FoveSolver.h @@ -135,9 +135,7 @@ class FoveSolver public: FoveSolver (const ParfactorList& pfList) : pfList_(pfList) { } - Params getPosterioriOf (const Ground&); - - Params getJointDistributionOf (const Grounds&); + Params solveQuery (const Grounds&); void printSolverFlags (void) const; diff --git a/packages/CLPBN/horus/HorusYap.cpp b/packages/CLPBN/horus/HorusYap.cpp index 082d2c884..d7ba5a803 100644 --- a/packages/CLPBN/horus/HorusYap.cpp +++ b/packages/CLPBN/horus/HorusYap.cpp @@ -317,22 +317,14 @@ runLiftedSolver (void) solver.printSolverFlags(); cout << endl; } - if (queryVars.size() == 1) { - results.push_back (solver.getPosterioriOf (queryVars[0])); - } else { - results.push_back (solver.getJointDistributionOf (queryVars)); - } + results.push_back (solver.solveQuery (queryVars)); } else if (Globals::liftedSolver == LiftedSolvers::LBP) { LiftedBpSolver solver (pfListCopy); if (Globals::verbosity > 0 && taskList == YAP_ARG2) { solver.printSolverFlags(); cout << endl; } - if (queryVars.size() == 1) { - results.push_back (solver.getPosterioriOf (queryVars[0])); - } else { - results.push_back (solver.getJointDistributionOf (queryVars)); - } + results.push_back (solver.solveQuery (queryVars)); } else { assert (false); } diff --git a/packages/CLPBN/horus/LiftedBpSolver.cpp b/packages/CLPBN/horus/LiftedBpSolver.cpp index a53205e1f..6132f6d64 100644 --- a/packages/CLPBN/horus/LiftedBpSolver.cpp +++ b/packages/CLPBN/horus/LiftedBpSolver.cpp @@ -15,23 +15,21 @@ LiftedBpSolver::LiftedBpSolver (const ParfactorList& pfList) Params -LiftedBpSolver::getPosterioriOf (const Ground& query) -{ - vector groups = getQueryGroups ({query}); - return solver_->getPosterioriOf (groups[0]); -} - - - -Params -LiftedBpSolver::getJointDistributionOf (const Grounds& query) +LiftedBpSolver::solveQuery (const Grounds& query) { + assert (query.empty() == false); + Params res; vector groups = getQueryGroups (query); - VarIds queryVids; - for (unsigned i = 0; i < groups.size(); i++) { - queryVids.push_back (groups[i]); + if (query.size() == 1) { + res = solver_->getPosterioriOf (groups[0]); + } else { + VarIds queryVids; + for (unsigned i = 0; i < groups.size(); i++) { + queryVids.push_back (groups[i]); + } + res = solver_->getJointDistributionOf (queryVids); } - return solver_->getJointDistributionOf (queryVids); + return res; } diff --git a/packages/CLPBN/horus/LiftedBpSolver.h b/packages/CLPBN/horus/LiftedBpSolver.h index 869db05b3..17695157b 100644 --- a/packages/CLPBN/horus/LiftedBpSolver.h +++ b/packages/CLPBN/horus/LiftedBpSolver.h @@ -12,9 +12,7 @@ class LiftedBpSolver public: LiftedBpSolver (const ParfactorList& pfList); - Params getPosterioriOf (const Ground&); - - Params getJointDistributionOf (const Grounds&); + Params solveQuery (const Grounds&); void printSolverFlags (void) const; diff --git a/packages/CLPBN/horus/WeightedBpSolver.cpp b/packages/CLPBN/horus/WeightedBpSolver.cpp index 0cc07f98f..a56d92373 100644 --- a/packages/CLPBN/horus/WeightedBpSolver.cpp +++ b/packages/CLPBN/horus/WeightedBpSolver.cpp @@ -1,14 +1,6 @@ #include "WeightedBpSolver.h" -WeightedBpSolver::WeightedBpSolver ( - const FactorGraph& fg, const vector>& weights) - : BpSolver (fg), weights_(weights) -{ -} - - - WeightedBpSolver::~WeightedBpSolver (void) { for (size_t i = 0; i < links_.size(); i++) { @@ -172,8 +164,8 @@ WeightedBpSolver::calcFactorToVarMsg (BpLink* _link) Params msgProduct (msgSize, LogAware::multIdenty()); if (Globals::logDomain) { for (size_t i = links.size(); i-- > 0; ) { - const WeightedLink* cl = static_cast (links[i]); - if ( ! (cl->varNode() == dst && cl->index() == link->index())) { + const WeightedLink* l = static_cast (links[i]); + if ( ! (l->varNode() == dst && l->index() == link->index())) { if (Constants::SHOW_BP_CALCS) { cout << " message from " << links[i]->varNode()->label(); cout << ": " ; @@ -188,8 +180,8 @@ WeightedBpSolver::calcFactorToVarMsg (BpLink* _link) } } else { for (size_t i = links.size(); i-- > 0; ) { - const WeightedLink* cl = static_cast (links[i]); - if ( ! (cl->varNode() == dst && cl->index() == link->index())) { + const WeightedLink* l = static_cast (links[i]); + if ( ! (l->varNode() == dst && l->index() == link->index())) { if (Constants::SHOW_BP_CALCS) { cout << " message from " << links[i]->varNode()->label(); cout << ": " ; @@ -255,19 +247,18 @@ WeightedBpSolver::getVarToFactorMsg (const BpLink* _link) const const BpLinks& links = ninf(src)->getLinks(); if (Globals::logDomain) { for (size_t i = 0; i < links.size(); i++) { - WeightedLink* cl = static_cast (links[i]); - if ( ! (cl->facNode() == dst && cl->index() == link->index())) { - WeightedLink* cl = static_cast (links[i]); - msg += cl->powMessage(); + WeightedLink* l = static_cast (links[i]); + if ( ! (l->facNode() == dst && l->index() == link->index())) { + msg += l->powMessage(); } } } else { for (size_t i = 0; i < links.size(); i++) { - WeightedLink* cl = static_cast (links[i]); - if ( ! (cl->facNode() == dst && cl->index() == link->index())) { - msg *= cl->powMessage(); + WeightedLink* l = static_cast (links[i]); + if ( ! (l->facNode() == dst && l->index() == link->index())) { + msg *= l->powMessage(); if (Constants::SHOW_BP_CALCS) { - cout << " x " << cl->nextMessage() << "^" << link->weight(); + cout << " x " << l->nextMessage() << "^" << link->weight(); } } } @@ -284,14 +275,14 @@ void WeightedBpSolver::printLinkInformation (void) const { for (size_t i = 0; i < links_.size(); i++) { - WeightedLink* cl = static_cast (links_[i]); - cout << cl->toString() << ":" << endl; - cout << " curr msg = " << cl->message() << endl; - cout << " next msg = " << cl->nextMessage() << endl; - cout << " index = " << cl->index() << endl; - cout << " weight = " << cl->weight() << endl; - cout << " powered = " << cl->powMessage() << endl; - cout << " residual = " << cl->residual() << endl; + WeightedLink* l = static_cast (links_[i]); + cout << l->toString() << ":" << endl; + cout << " curr msg = " << l->message() << endl; + cout << " next msg = " << l->nextMessage() << endl; + cout << " pow msg = " << l->powMessage() << endl; + cout << " index = " << l->index() << endl; + cout << " weight = " << l->weight() << endl; + cout << " residual = " << l->residual() << endl; } } diff --git a/packages/CLPBN/horus/WeightedBpSolver.h b/packages/CLPBN/horus/WeightedBpSolver.h index 6416eb48b..53bce1bc7 100644 --- a/packages/CLPBN/horus/WeightedBpSolver.h +++ b/packages/CLPBN/horus/WeightedBpSolver.h @@ -35,7 +35,8 @@ class WeightedBpSolver : public BpSolver { public: WeightedBpSolver (const FactorGraph& fg, - const vector>&); + const vector>& weights) + : BpSolver (fg), weights_(weights) { } ~WeightedBpSolver (void);