From 47768176031d5e1d301b0211f31f3257ddad3d89 Mon Sep 17 00:00:00 2001 From: Tiago Gomes Date: Sat, 10 Nov 2012 00:18:20 +0000 Subject: [PATCH] move more code around --- packages/CLPBN/horus/LiftedBp.cpp | 17 +++---- packages/CLPBN/horus/LiftedOperations.cpp | 61 +++++++++++++++++++++++ packages/CLPBN/horus/LiftedOperations.h | 4 +- packages/CLPBN/horus/LiftedVe.cpp | 61 +---------------------- packages/CLPBN/horus/LiftedVe.h | 26 +++------- 5 files changed, 81 insertions(+), 88 deletions(-) diff --git a/packages/CLPBN/horus/LiftedBp.cpp b/packages/CLPBN/horus/LiftedBp.cpp index 1c5a93023..468e79d72 100644 --- a/packages/CLPBN/horus/LiftedBp.cpp +++ b/packages/CLPBN/horus/LiftedBp.cpp @@ -2,7 +2,6 @@ #include "WeightedBp.h" #include "FactorGraph.h" #include "LiftedOperations.h" -#include "LiftedVe.h" LiftedBp::LiftedBp (const ParfactorList& pfList) @@ -190,12 +189,12 @@ LiftedBp::rangeOfGround (const Ground& gr) Params LiftedBp::getJointByConditioning ( const ParfactorList& pfList, - const Grounds& grounds) + const Grounds& query) { LiftedBp solver (pfList); - Params prevBeliefs = solver.solveQuery ({grounds[0]}); - Grounds obsGrounds = {grounds[0]}; - for (size_t i = 1; i < grounds.size(); i++) { + Params prevBeliefs = solver.solveQuery ({query[0]}); + Grounds obsGrounds = {query[0]}; + for (size_t i = 1; i < query.size(); i++) { Params newBeliefs; vector obsFs; Ranges obsRanges; @@ -210,16 +209,16 @@ LiftedBp::getJointByConditioning ( obsFs[j].setEvidence (indexer[j]); } ParfactorList tempPfList (pfList); - LiftedVe::absorveEvidence (tempPfList, obsFs); + LiftedOperations::absorveEvidence (tempPfList, obsFs); LiftedBp solver (tempPfList); - Params beliefs = solver.solveQuery ({grounds[i]}); + Params beliefs = solver.solveQuery ({query[i]}); for (size_t k = 0; k < beliefs.size(); k++) { newBeliefs.push_back (beliefs[k]); } ++ indexer; } int count = -1; - unsigned range = rangeOfGround (grounds[i]); + unsigned range = rangeOfGround (query[i]); for (size_t j = 0; j < newBeliefs.size(); j++) { if (j % range == 0) { count ++; @@ -227,7 +226,7 @@ LiftedBp::getJointByConditioning ( newBeliefs[j] *= prevBeliefs[count]; } prevBeliefs = newBeliefs; - obsGrounds.push_back (grounds[i]); + obsGrounds.push_back (query[i]); } return prevBeliefs; } diff --git a/packages/CLPBN/horus/LiftedOperations.cpp b/packages/CLPBN/horus/LiftedOperations.cpp index 1c9c8b413..ec40695a7 100644 --- a/packages/CLPBN/horus/LiftedOperations.cpp +++ b/packages/CLPBN/horus/LiftedOperations.cpp @@ -54,6 +54,67 @@ LiftedOperations::shatterAgainstQuery ( +void +LiftedOperations::runWeakBayesBall ( + ParfactorList& pfList, + const Grounds& query) +{ + queue todo; // groups to process + set done; // processed or in queue + for (size_t i = 0; i < query.size(); i++) { + ParfactorList::iterator it = pfList.begin(); + while (it != pfList.end()) { + PrvGroup group = (*it)->findGroup (query[i]); + if (group != numeric_limits::max()) { + todo.push (group); + done.insert (group); + break; + } + ++ it; + } + } + + set requiredPfs; + while (todo.empty() == false) { + PrvGroup group = todo.front(); + ParfactorList::iterator it = pfList.begin(); + while (it != pfList.end()) { + if (Util::contains (requiredPfs, *it) == false && + (*it)->containsGroup (group)) { + vector groups = (*it)->getAllGroups(); + for (size_t i = 0; i < groups.size(); i++) { + if (Util::contains (done, groups[i]) == false) { + todo.push (groups[i]); + done.insert (groups[i]); + } + } + requiredPfs.insert (*it); + } + ++ it; + } + todo.pop(); + } + + ParfactorList::iterator it = pfList.begin(); + bool foundNotRequired = false; + while (it != pfList.end()) { + if (Util::contains (requiredPfs, *it) == false) { + if (Globals::verbosity > 2) { + if (foundNotRequired == false) { + Util::printHeader ("PARFACTORS TO DISCARD"); + foundNotRequired = true; + } + (*it)->print(); + } + it = pfList.removeAndDelete (it); + } else { + ++ it; + } + } +} + + + void LiftedOperations::absorveEvidence ( ParfactorList& pfList, diff --git a/packages/CLPBN/horus/LiftedOperations.h b/packages/CLPBN/horus/LiftedOperations.h index ae6b6f0a9..1e21f317c 100644 --- a/packages/CLPBN/horus/LiftedOperations.h +++ b/packages/CLPBN/horus/LiftedOperations.h @@ -9,6 +9,9 @@ class LiftedOperations static void shatterAgainstQuery ( ParfactorList& pfList, const Grounds& query); + static void runWeakBayesBall ( + ParfactorList& pfList, const Grounds&); + static void absorveEvidence ( ParfactorList& pfList, ObservedFormulas& obsFormulas); @@ -17,7 +20,6 @@ class LiftedOperations static Parfactor calcGroundMultiplication (Parfactor pf); private: - static Parfactors absorve (ObservedFormula&, Parfactor*); }; diff --git a/packages/CLPBN/horus/LiftedVe.cpp b/packages/CLPBN/horus/LiftedVe.cpp index 0b5435295..add7a36a6 100644 --- a/packages/CLPBN/horus/LiftedVe.cpp +++ b/packages/CLPBN/horus/LiftedVe.cpp @@ -661,7 +661,7 @@ LiftedVe::runSolver (const Grounds& query) { largestCost_ = std::log (0); LiftedOperations::shatterAgainstQuery (pfList_, query); - runWeakBayesBall (query); + LiftedOperations::runWeakBayesBall (pfList_, query); while (true) { if (Globals::verbosity > 2) { Util::printDashedLine(); @@ -727,62 +727,3 @@ LiftedVe::getBestOperation (const Grounds& query) return bestOp; } - - -void -LiftedVe::runWeakBayesBall (const Grounds& query) -{ - queue todo; // groups to process - set done; // processed or in queue - for (size_t i = 0; i < query.size(); i++) { - ParfactorList::iterator it = pfList_.begin(); - while (it != pfList_.end()) { - PrvGroup group = (*it)->findGroup (query[i]); - if (group != numeric_limits::max()) { - todo.push (group); - done.insert (group); - break; - } - ++ it; - } - } - - set requiredPfs; - while (todo.empty() == false) { - PrvGroup group = todo.front(); - ParfactorList::iterator it = pfList_.begin(); - while (it != pfList_.end()) { - if (Util::contains (requiredPfs, *it) == false && - (*it)->containsGroup (group)) { - vector groups = (*it)->getAllGroups(); - for (size_t i = 0; i < groups.size(); i++) { - if (Util::contains (done, groups[i]) == false) { - todo.push (groups[i]); - done.insert (groups[i]); - } - } - requiredPfs.insert (*it); - } - ++ it; - } - todo.pop(); - } - - ParfactorList::iterator it = pfList_.begin(); - bool foundNotRequired = false; - while (it != pfList_.end()) { - if (Util::contains (requiredPfs, *it) == false) { - if (Globals::verbosity > 2) { - if (foundNotRequired == false) { - Util::printHeader ("PARFACTORS TO DISCARD"); - foundNotRequired = true; - } - (*it)->print(); - } - it = pfList_.removeAndDelete (it); - } else { - ++ it; - } - } -} - diff --git a/packages/CLPBN/horus/LiftedVe.h b/packages/CLPBN/horus/LiftedVe.h index 85adacc18..9a464a348 100644 --- a/packages/CLPBN/horus/LiftedVe.h +++ b/packages/CLPBN/horus/LiftedVe.h @@ -45,9 +45,9 @@ class ProductOperator : public LiftedOperator private: static bool validOp (Parfactor*, Parfactor*); - ParfactorList::iterator g1_; - ParfactorList::iterator g2_; - ParfactorList& pfList_; + ParfactorList::iterator g1_; + ParfactorList::iterator g2_; + ParfactorList& pfList_; }; @@ -125,9 +125,9 @@ class GroundOperator : public LiftedOperator private: vector> getAffectedFormulas (void); - PrvGroup group_; - unsigned lvIndex_; - ParfactorList& pfList_; + PrvGroup group_; + unsigned lvIndex_; + ParfactorList& pfList_; }; @@ -141,23 +141,13 @@ class LiftedVe void printSolverFlags (void) const; - static void absorveEvidence ( - ParfactorList& pfList, ObservedFormulas& obsFormulas); - private: void runSolver (const Grounds&); LiftedOperator* getBestOperation (const Grounds&); - void runWeakBayesBall (const Grounds&); - - void shatterAgainstQuery (const Grounds&); - - static Parfactors absorve (ObservedFormula&, Parfactor*); - - ParfactorList pfList_; - - double largestCost_; + ParfactorList pfList_; + double largestCost_; }; #endif // HORUS_LIFTEDVE_H