diff --git a/packages/CLPBN/horus/BeliefProp.cpp b/packages/CLPBN/horus/BeliefProp.cpp index e8d5244ad..bf8f30a79 100644 --- a/packages/CLPBN/horus/BeliefProp.cpp +++ b/packages/CLPBN/horus/BeliefProp.cpp @@ -9,9 +9,9 @@ #include "Horus.h" -MsgSchedule BeliefProp::schedule = MsgSchedule::SEQ_FIXED; -double BeliefProp::accuracy = 0.0001; -unsigned BeliefProp::maxIter = 1000; +double BeliefProp::accuracy_ = 0.0001; +unsigned BeliefProp::maxIter_ = 1000; +MsgSchedule BeliefProp::schedule_ = MsgSchedule::SEQ_FIXED; BeliefProp::BeliefProp (const FactorGraph& fg) : GroundSolver (fg) @@ -53,14 +53,14 @@ BeliefProp::printSolverFlags (void) const stringstream ss; ss << "belief propagation [" ; ss << "schedule=" ; - switch (schedule) { + switch (schedule_) { case MsgSchedule::SEQ_FIXED: ss << "seq_fixed"; break; case MsgSchedule::SEQ_RANDOM: ss << "seq_random"; break; case MsgSchedule::PARALLEL: ss << "parallel"; break; case MsgSchedule::MAX_RESIDUAL: ss << "max_residual"; break; } - ss << ",max_iter=" << Util::toString (maxIter); - ss << ",accuracy=" << Util::toString (accuracy); + ss << ",max_iter=" << Util::toString (maxIter_); + ss << ",accuracy=" << Util::toString (accuracy_); ss << ",log_domain=" << Util::toString (Globals::logDomain); ss << "]" ; cout << ss.str() << endl; @@ -157,12 +157,12 @@ BeliefProp::runSolver (void) { initializeSolver(); nIters_ = 0; - while (!converged() && nIters_ < maxIter) { + while (!converged() && nIters_ < maxIter_) { nIters_ ++; if (Globals::verbosity > 1) { Util::printHeader (string ("Iteration ") + Util::toString (nIters_)); } - switch (schedule) { + switch (schedule_) { case MsgSchedule::SEQ_RANDOM: std::random_shuffle (links_.begin(), links_.end()); // no break @@ -185,7 +185,7 @@ BeliefProp::runSolver (void) } } if (Globals::verbosity > 0) { - if (nIters_ < maxIter) { + if (nIters_ < maxIter_) { cout << "Belief propagation converged in " ; cout << nIters_ << " iterations" << endl; } else { @@ -237,7 +237,7 @@ BeliefProp::maxResidualSchedule (void) SortedOrder::iterator it = sortedOrder_.begin(); BpLink* link = *it; - if (link->residual() < accuracy) { + if (link->residual() < accuracy_) { return; } updateMessage (link); @@ -427,9 +427,9 @@ BeliefProp::converged (void) return false; } bool converged = true; - if (schedule == MsgSchedule::MAX_RESIDUAL) { + if (schedule_ == MsgSchedule::MAX_RESIDUAL) { double maxResidual = (*(sortedOrder_.begin()))->residual(); - if (maxResidual > accuracy) { + if (maxResidual > accuracy_) { converged = false; } else { converged = true; @@ -440,7 +440,7 @@ BeliefProp::converged (void) if (Globals::verbosity > 1) { cout << links_[i]->toString() + " residual = " << residual << endl; } - if (residual > accuracy) { + if (residual > accuracy_) { converged = false; if (Globals::verbosity < 2) { break; diff --git a/packages/CLPBN/horus/BeliefProp.h b/packages/CLPBN/horus/BeliefProp.h index beaf73b1e..6399c65e7 100644 --- a/packages/CLPBN/horus/BeliefProp.h +++ b/packages/CLPBN/horus/BeliefProp.h @@ -108,9 +108,17 @@ class BeliefProp : public GroundSolver Params getFactorJoint (FacNode* fn, const VarIds&); - static MsgSchedule schedule; - static double accuracy; - static unsigned maxIter; + static double accuracy (void) { return accuracy_; } + + static void setAccuracy (double acc) { accuracy_ = acc; } + + static unsigned maxIterations (void) { return maxIter_; } + + static void setMaxIterations (unsigned mi) { maxIter_ = mi; } + + static MsgSchedule msgSchedule (void) { return schedule_; } + + static void setMsgSchedule (MsgSchedule sch) { schedule_ = sch; } protected: SPNodeInfo* ninf (const VarNode* var) const @@ -186,6 +194,10 @@ class BeliefProp : public GroundSolver typedef unordered_map BpLinkMap; BpLinkMap linkMap_; + static double accuracy_; + static unsigned maxIter_; + static MsgSchedule schedule_; + private: void initializeSolver (void); diff --git a/packages/CLPBN/horus/CountingBp.cpp b/packages/CLPBN/horus/CountingBp.cpp index 876104f2a..006bf99fd 100644 --- a/packages/CLPBN/horus/CountingBp.cpp +++ b/packages/CLPBN/horus/CountingBp.cpp @@ -37,14 +37,14 @@ CountingBp::printSolverFlags (void) const stringstream ss; ss << "counting bp [" ; ss << "schedule=" ; - switch (WeightedBp::schedule) { + switch (WeightedBp::msgSchedule()) { case MsgSchedule::SEQ_FIXED: ss << "seq_fixed"; break; case MsgSchedule::SEQ_RANDOM: ss << "seq_random"; break; case MsgSchedule::PARALLEL: ss << "parallel"; break; case MsgSchedule::MAX_RESIDUAL: ss << "max_residual"; break; } - ss << ",max_iter=" << WeightedBp::maxIter; - ss << ",accuracy=" << WeightedBp::accuracy; + ss << ",max_iter=" << WeightedBp::maxIterations(); + ss << ",accuracy=" << WeightedBp::accuracy(); ss << ",log_domain=" << Util::toString (Globals::logDomain); ss << ",chkif=" << Util::toString (CountingBp::checkForIdenticalFactors); ss << "]" ; diff --git a/packages/CLPBN/horus/ElimGraph.cpp b/packages/CLPBN/horus/ElimGraph.cpp index 0292c775f..3a808a8c2 100644 --- a/packages/CLPBN/horus/ElimGraph.cpp +++ b/packages/CLPBN/horus/ElimGraph.cpp @@ -2,7 +2,7 @@ #include "ElimGraph.h" -ElimHeuristic ElimGraph::elimHeuristic = MIN_NEIGHBORS; +ElimHeuristic ElimGraph::elimHeuristic_ = MIN_NEIGHBORS; ElimGraph::ElimGraph (const vector& factors) @@ -132,7 +132,7 @@ ElimGraph::getEliminationOrder ( const Factors& factors, VarIds excludedVids) { - if (elimHeuristic == ElimHeuristic::SEQUENTIAL) { + if (elimHeuristic_ == ElimHeuristic::SEQUENTIAL) { VarIds allVids; Factors::const_iterator first = factors.begin(); Factors::const_iterator end = factors.end(); @@ -175,7 +175,7 @@ ElimGraph::getLowestCostNode (void) const EgNode* bestNode = 0; unsigned minCost = Util::maxUnsigned(); EGNeighs::const_iterator it; - switch (elimHeuristic) { + switch (elimHeuristic_) { case MIN_NEIGHBORS: { for (it = unmarked_.begin(); it != unmarked_.end(); ++ it) { unsigned cost = getNeighborsCost (*it); diff --git a/packages/CLPBN/horus/ElimGraph.h b/packages/CLPBN/horus/ElimGraph.h index 575258829..881f59759 100644 --- a/packages/CLPBN/horus/ElimGraph.h +++ b/packages/CLPBN/horus/ElimGraph.h @@ -58,7 +58,9 @@ class ElimGraph static VarIds getEliminationOrder (const Factors&, VarIds); - static ElimHeuristic elimHeuristic; + static ElimHeuristic elimHeuristic (void) { return elimHeuristic_; } + + static void setElimHeuristic (ElimHeuristic eh) { elimHeuristic_ = eh; } private: @@ -132,6 +134,8 @@ class ElimGraph vector nodes_; TinySet unmarked_; unordered_map varMap_; + + static ElimHeuristic elimHeuristic_; }; #endif // HORUS_ELIMGRAPH_H diff --git a/packages/CLPBN/horus/LiftedBp.cpp b/packages/CLPBN/horus/LiftedBp.cpp index 7cfb49c23..18f056f8a 100644 --- a/packages/CLPBN/horus/LiftedBp.cpp +++ b/packages/CLPBN/horus/LiftedBp.cpp @@ -63,14 +63,14 @@ LiftedBp::printSolverFlags (void) const stringstream ss; ss << "lifted bp [" ; ss << "schedule=" ; - switch (WeightedBp::schedule) { + switch (WeightedBp::msgSchedule()) { case MsgSchedule::SEQ_FIXED: ss << "seq_fixed"; break; case MsgSchedule::SEQ_RANDOM: ss << "seq_random"; break; case MsgSchedule::PARALLEL: ss << "parallel"; break; case MsgSchedule::MAX_RESIDUAL: ss << "max_residual"; break; } - ss << ",max_iter=" << WeightedBp::maxIter; - ss << ",accuracy=" << WeightedBp::accuracy; + ss << ",max_iter=" << WeightedBp::maxIterations(); + ss << ",accuracy=" << WeightedBp::accuracy(); ss << ",log_domain=" << Util::toString (Globals::logDomain); ss << "]" ; cout << ss.str() << endl; diff --git a/packages/CLPBN/horus/Util.cpp b/packages/CLPBN/horus/Util.cpp index 810be63c8..9fad10705 100644 --- a/packages/CLPBN/horus/Util.cpp +++ b/packages/CLPBN/horus/Util.cpp @@ -222,15 +222,15 @@ setHorusFlag (string key, string value) } } else if (key == "elim_heuristic") { if ( value == "sequential") { - ElimGraph::elimHeuristic = ElimHeuristic::SEQUENTIAL; + ElimGraph::setElimHeuristic (ElimHeuristic::SEQUENTIAL); } else if (value == "min_neighbors") { - ElimGraph::elimHeuristic = ElimHeuristic::MIN_NEIGHBORS; + ElimGraph::setElimHeuristic (ElimHeuristic::MIN_NEIGHBORS); } else if (value == "min_weight") { - ElimGraph::elimHeuristic = ElimHeuristic::MIN_WEIGHT; + ElimGraph::setElimHeuristic (ElimHeuristic::MIN_WEIGHT); } else if (value == "min_fill") { - ElimGraph::elimHeuristic = ElimHeuristic::MIN_FILL; + ElimGraph::setElimHeuristic (ElimHeuristic::MIN_FILL); } else if (value == "weighted_min_fill") { - ElimGraph::elimHeuristic = ElimHeuristic::WEIGHTED_MIN_FILL; + ElimGraph::setElimHeuristic (ElimHeuristic::WEIGHTED_MIN_FILL); } else { cerr << "warning: invalid value `" << value << "' " ; cerr << "for `" << key << "'" << endl; @@ -238,13 +238,13 @@ setHorusFlag (string key, string value) } } else if (key == "schedule") { if ( value == "seq_fixed") { - BeliefProp::schedule = MsgSchedule::SEQ_FIXED; + BeliefProp::setMsgSchedule (MsgSchedule::SEQ_FIXED); } else if (value == "seq_random") { - BeliefProp::schedule = MsgSchedule::SEQ_RANDOM; + BeliefProp::setMsgSchedule (MsgSchedule::SEQ_RANDOM); } else if (value == "parallel") { - BeliefProp::schedule = MsgSchedule::PARALLEL; + BeliefProp::setMsgSchedule (MsgSchedule::PARALLEL); } else if (value == "max_residual") { - BeliefProp::schedule = MsgSchedule::MAX_RESIDUAL; + BeliefProp::setMsgSchedule (MsgSchedule::MAX_RESIDUAL); } else { cerr << "warning: invalid value `" << value << "' " ; cerr << "for `" << key << "'" << endl; @@ -252,12 +252,16 @@ setHorusFlag (string key, string value) } } else if (key == "accuracy") { stringstream ss; + double acc; ss << value; - ss >> BeliefProp::accuracy; + ss >> acc; + BeliefProp::setAccuracy (acc); } else if (key == "max_iter") { stringstream ss; + unsigned mi; ss << value; - ss >> BeliefProp::maxIter; + ss >> mi; + BeliefProp::setMaxIterations (mi); } else if (key == "use_logarithms") { if ( value == "true") { Globals::logDomain = true; diff --git a/packages/CLPBN/horus/VarElim.cpp b/packages/CLPBN/horus/VarElim.cpp index d31f6ce51..e1b11edf8 100644 --- a/packages/CLPBN/horus/VarElim.cpp +++ b/packages/CLPBN/horus/VarElim.cpp @@ -38,13 +38,12 @@ VarElim::printSolverFlags (void) const stringstream ss; ss << "variable elimination [" ; ss << "elim_heuristic=" ; - ElimHeuristic eh = ElimGraph::elimHeuristic; - switch (eh) { - case SEQUENTIAL: ss << "sequential"; break; - case MIN_NEIGHBORS: ss << "min_neighbors"; break; - case MIN_WEIGHT: ss << "min_weight"; break; - case MIN_FILL: ss << "min_fill"; break; - case WEIGHTED_MIN_FILL: ss << "weighted_min_fill"; break; + switch (ElimGraph::elimHeuristic()) { + case ElimHeuristic::SEQUENTIAL: ss << "sequential"; break; + case ElimHeuristic::MIN_NEIGHBORS: ss << "min_neighbors"; break; + case ElimHeuristic::MIN_WEIGHT: ss << "min_weight"; break; + case ElimHeuristic::MIN_FILL: ss << "min_fill"; break; + case ElimHeuristic::WEIGHTED_MIN_FILL: ss << "weighted_min_fill"; break; } ss << ",log_domain=" << Util::toString (Globals::logDomain); ss << "]" ; diff --git a/packages/CLPBN/horus/WeightedBp.cpp b/packages/CLPBN/horus/WeightedBp.cpp index 28a31bb60..269891f78 100644 --- a/packages/CLPBN/horus/WeightedBp.cpp +++ b/packages/CLPBN/horus/WeightedBp.cpp @@ -107,7 +107,7 @@ WeightedBp::maxResidualSchedule (void) if (Globals::verbosity >= 1) { cout << "updating " << (*sortedOrder_.begin())->toString() << endl; } - if (link->residual() < accuracy) { + if (link->residual() < accuracy_) { return; } link->updateMessage();