diff --git a/packages/CLPBN/horus/BeliefProp.cpp b/packages/CLPBN/horus/BeliefProp.cpp index 64195c91b..e8d5244ad 100644 --- a/packages/CLPBN/horus/BeliefProp.cpp +++ b/packages/CLPBN/horus/BeliefProp.cpp @@ -9,6 +9,11 @@ #include "Horus.h" +MsgSchedule BeliefProp::schedule = MsgSchedule::SEQ_FIXED; +double BeliefProp::accuracy = 0.0001; +unsigned BeliefProp::maxIter = 1000; + + BeliefProp::BeliefProp (const FactorGraph& fg) : GroundSolver (fg) { runned_ = false; @@ -48,15 +53,14 @@ BeliefProp::printSolverFlags (void) const stringstream ss; ss << "belief propagation [" ; ss << "schedule=" ; - typedef BpOptions::Schedule Sch; - switch (BpOptions::schedule) { - case Sch::SEQ_FIXED: ss << "seq_fixed"; break; - case Sch::SEQ_RANDOM: ss << "seq_random"; break; - case Sch::PARALLEL: ss << "parallel"; break; - case Sch::MAX_RESIDUAL: ss << "max_residual"; break; + switch (schedule) { + case MsgSchedule::SEQ_FIXED: ss << "seq_fixed"; break; + case MsgSchedule::SEQ_RANDOM: ss << "seq_random"; break; + case MsgSchedule::PARALLEL: ss << "parallel"; break; + case MsgSchedule::MAX_RESIDUAL: ss << "max_residual"; break; } - ss << ",max_iter=" << Util::toString (BpOptions::maxIter); - ss << ",accuracy=" << Util::toString (BpOptions::accuracy); + ss << ",max_iter=" << Util::toString (maxIter); + ss << ",accuracy=" << Util::toString (accuracy); ss << ",log_domain=" << Util::toString (Globals::logDomain); ss << "]" ; cout << ss.str() << endl; @@ -153,21 +157,21 @@ BeliefProp::runSolver (void) { initializeSolver(); nIters_ = 0; - while (!converged() && nIters_ < BpOptions::maxIter) { + while (!converged() && nIters_ < maxIter) { nIters_ ++; if (Globals::verbosity > 1) { Util::printHeader (string ("Iteration ") + Util::toString (nIters_)); } - switch (BpOptions::schedule) { - case BpOptions::Schedule::SEQ_RANDOM: + switch (schedule) { + case MsgSchedule::SEQ_RANDOM: std::random_shuffle (links_.begin(), links_.end()); // no break - case BpOptions::Schedule::SEQ_FIXED: + case MsgSchedule::SEQ_FIXED: for (size_t i = 0; i < links_.size(); i++) { calculateAndUpdateMessage (links_[i]); } break; - case BpOptions::Schedule::PARALLEL: + case MsgSchedule::PARALLEL: for (size_t i = 0; i < links_.size(); i++) { calculateMessage (links_[i]); } @@ -175,13 +179,13 @@ BeliefProp::runSolver (void) updateMessage(links_[i]); } break; - case BpOptions::Schedule::MAX_RESIDUAL: + case MsgSchedule::MAX_RESIDUAL: maxResidualSchedule(); break; } } if (Globals::verbosity > 0) { - if (nIters_ < BpOptions::maxIter) { + if (nIters_ < maxIter) { cout << "Belief propagation converged in " ; cout << nIters_ << " iterations" << endl; } else { @@ -233,7 +237,7 @@ BeliefProp::maxResidualSchedule (void) SortedOrder::iterator it = sortedOrder_.begin(); BpLink* link = *it; - if (link->residual() < BpOptions::accuracy) { + if (link->residual() < accuracy) { return; } updateMessage (link); @@ -423,9 +427,9 @@ BeliefProp::converged (void) return false; } bool converged = true; - if (BpOptions::schedule == BpOptions::Schedule::MAX_RESIDUAL) { + if (schedule == MsgSchedule::MAX_RESIDUAL) { double maxResidual = (*(sortedOrder_.begin()))->residual(); - if (maxResidual > BpOptions::accuracy) { + if (maxResidual > accuracy) { converged = false; } else { converged = true; @@ -436,7 +440,7 @@ BeliefProp::converged (void) if (Globals::verbosity > 1) { cout << links_[i]->toString() + " residual = " << residual << endl; } - if (residual > BpOptions::accuracy) { + if (residual > accuracy) { converged = false; if (Globals::verbosity < 2) { break; diff --git a/packages/CLPBN/horus/BeliefProp.h b/packages/CLPBN/horus/BeliefProp.h index cfdf98cbb..a7f9a3961 100644 --- a/packages/CLPBN/horus/BeliefProp.h +++ b/packages/CLPBN/horus/BeliefProp.h @@ -13,6 +13,14 @@ using namespace std; +enum MsgSchedule { + SEQ_FIXED, + SEQ_RANDOM, + PARALLEL, + MAX_RESIDUAL +}; + + class BpLink { public: @@ -98,6 +106,10 @@ class BeliefProp : public GroundSolver virtual Params getJointDistributionOf (const VarIds&); + static MsgSchedule schedule; + static double accuracy; + static unsigned maxIter; + protected: void runSolver (void); diff --git a/packages/CLPBN/horus/CountingBp.cpp b/packages/CLPBN/horus/CountingBp.cpp index a0836332f..876104f2a 100644 --- a/packages/CLPBN/horus/CountingBp.cpp +++ b/packages/CLPBN/horus/CountingBp.cpp @@ -37,15 +37,14 @@ CountingBp::printSolverFlags (void) const stringstream ss; ss << "counting bp [" ; ss << "schedule=" ; - typedef BpOptions::Schedule Sch; - switch (BpOptions::schedule) { - case Sch::SEQ_FIXED: ss << "seq_fixed"; break; - case Sch::SEQ_RANDOM: ss << "seq_random"; break; - case Sch::PARALLEL: ss << "parallel"; break; - case Sch::MAX_RESIDUAL: ss << "max_residual"; break; + switch (WeightedBp::schedule) { + case MsgSchedule::SEQ_FIXED: ss << "seq_fixed"; break; + case MsgSchedule::SEQ_RANDOM: ss << "seq_random"; break; + case MsgSchedule::PARALLEL: ss << "parallel"; break; + case MsgSchedule::MAX_RESIDUAL: ss << "max_residual"; break; } - ss << ",max_iter=" << BpOptions::maxIter; - ss << ",accuracy=" << BpOptions::accuracy; + ss << ",max_iter=" << WeightedBp::maxIter; + ss << ",accuracy=" << WeightedBp::accuracy; ss << ",log_domain=" << Util::toString (Globals::logDomain); ss << ",chkif=" << Util::toString (CountingBp::checkForIdenticalFactors); ss << "]" ; diff --git a/packages/CLPBN/horus/Horus.h b/packages/CLPBN/horus/Horus.h index 960e7bb6a..17141d63e 100644 --- a/packages/CLPBN/horus/Horus.h +++ b/packages/CLPBN/horus/Horus.h @@ -67,19 +67,5 @@ const unsigned PRECISION = 6; }; - -namespace BpOptions -{ - enum Schedule { - SEQ_FIXED, - SEQ_RANDOM, - PARALLEL, - MAX_RESIDUAL - }; - extern Schedule schedule; - extern double accuracy; - extern unsigned maxIter; -} - #endif // HORUS_HORUS_H diff --git a/packages/CLPBN/horus/LiftedBp.cpp b/packages/CLPBN/horus/LiftedBp.cpp index bdf761e4f..7cfb49c23 100644 --- a/packages/CLPBN/horus/LiftedBp.cpp +++ b/packages/CLPBN/horus/LiftedBp.cpp @@ -63,15 +63,14 @@ LiftedBp::printSolverFlags (void) const stringstream ss; ss << "lifted bp [" ; ss << "schedule=" ; - typedef BpOptions::Schedule Sch; - switch (BpOptions::schedule) { - case Sch::SEQ_FIXED: ss << "seq_fixed"; break; - case Sch::SEQ_RANDOM: ss << "seq_random"; break; - case Sch::PARALLEL: ss << "parallel"; break; - case Sch::MAX_RESIDUAL: ss << "max_residual"; break; + switch (WeightedBp::schedule) { + case MsgSchedule::SEQ_FIXED: ss << "seq_fixed"; break; + case MsgSchedule::SEQ_RANDOM: ss << "seq_random"; break; + case MsgSchedule::PARALLEL: ss << "parallel"; break; + case MsgSchedule::MAX_RESIDUAL: ss << "max_residual"; break; } - ss << ",max_iter=" << BpOptions::maxIter; - ss << ",accuracy=" << BpOptions::accuracy; + ss << ",max_iter=" << WeightedBp::maxIter; + ss << ",accuracy=" << WeightedBp::accuracy; ss << ",log_domain=" << Util::toString (Globals::logDomain); ss << "]" ; cout << ss.str() << endl; diff --git a/packages/CLPBN/horus/Util.cpp b/packages/CLPBN/horus/Util.cpp index 4258908d0..810be63c8 100644 --- a/packages/CLPBN/horus/Util.cpp +++ b/packages/CLPBN/horus/Util.cpp @@ -3,6 +3,7 @@ #include "Util.h" #include "Indexer.h" #include "ElimGraph.h" +#include "BeliefProp.h" namespace Globals { @@ -18,17 +19,6 @@ GroundSolverType groundSolver = GroundSolverType::VE; -namespace BpOptions { -Schedule schedule = BpOptions::Schedule::SEQ_FIXED; -//Schedule schedule = BpOptions::Schedule::SEQ_RANDOM; -//Schedule schedule = BpOptions::Schedule::PARALLEL; -//Schedule schedule = BpOptions::Schedule::MAX_RESIDUAL; -double accuracy = 0.0001; -unsigned maxIter = 1000; -} - - - namespace Util { @@ -248,13 +238,13 @@ setHorusFlag (string key, string value) } } else if (key == "schedule") { if ( value == "seq_fixed") { - BpOptions::schedule = BpOptions::Schedule::SEQ_FIXED; + BeliefProp::schedule = MsgSchedule::SEQ_FIXED; } else if (value == "seq_random") { - BpOptions::schedule = BpOptions::Schedule::SEQ_RANDOM; + BeliefProp::schedule = MsgSchedule::SEQ_RANDOM; } else if (value == "parallel") { - BpOptions::schedule = BpOptions::Schedule::PARALLEL; + BeliefProp::schedule = MsgSchedule::PARALLEL; } else if (value == "max_residual") { - BpOptions::schedule = BpOptions::Schedule::MAX_RESIDUAL; + BeliefProp::schedule = MsgSchedule::MAX_RESIDUAL; } else { cerr << "warning: invalid value `" << value << "' " ; cerr << "for `" << key << "'" << endl; @@ -263,11 +253,11 @@ setHorusFlag (string key, string value) } else if (key == "accuracy") { stringstream ss; ss << value; - ss >> BpOptions::accuracy; + ss >> BeliefProp::accuracy; } else if (key == "max_iter") { stringstream ss; ss << value; - ss >> BpOptions::maxIter; + ss >> BeliefProp::maxIter; } else if (key == "use_logarithms") { if ( value == "true") { Globals::logDomain = true; diff --git a/packages/CLPBN/horus/WeightedBp.cpp b/packages/CLPBN/horus/WeightedBp.cpp index 9f6fca8df..28a31bb60 100644 --- a/packages/CLPBN/horus/WeightedBp.cpp +++ b/packages/CLPBN/horus/WeightedBp.cpp @@ -107,7 +107,7 @@ WeightedBp::maxResidualSchedule (void) if (Globals::verbosity >= 1) { cout << "updating " << (*sortedOrder_.begin())->toString() << endl; } - if (link->residual() < BpOptions::accuracy) { + if (link->residual() < accuracy) { return; } link->updateMessage();