Don't use public members for solver flags

This commit is contained in:
Tiago Gomes 2012-12-27 15:44:40 +00:00
parent b996436b24
commit 7b7f663ac6
9 changed files with 64 additions and 45 deletions

View File

@ -9,9 +9,9 @@
#include "Horus.h" #include "Horus.h"
MsgSchedule BeliefProp::schedule = MsgSchedule::SEQ_FIXED; double BeliefProp::accuracy_ = 0.0001;
double BeliefProp::accuracy = 0.0001; unsigned BeliefProp::maxIter_ = 1000;
unsigned BeliefProp::maxIter = 1000; MsgSchedule BeliefProp::schedule_ = MsgSchedule::SEQ_FIXED;
BeliefProp::BeliefProp (const FactorGraph& fg) : GroundSolver (fg) BeliefProp::BeliefProp (const FactorGraph& fg) : GroundSolver (fg)
@ -53,14 +53,14 @@ BeliefProp::printSolverFlags (void) const
stringstream ss; stringstream ss;
ss << "belief propagation [" ; ss << "belief propagation [" ;
ss << "schedule=" ; ss << "schedule=" ;
switch (schedule) { switch (schedule_) {
case MsgSchedule::SEQ_FIXED: ss << "seq_fixed"; break; case MsgSchedule::SEQ_FIXED: ss << "seq_fixed"; break;
case MsgSchedule::SEQ_RANDOM: ss << "seq_random"; break; case MsgSchedule::SEQ_RANDOM: ss << "seq_random"; break;
case MsgSchedule::PARALLEL: ss << "parallel"; break; case MsgSchedule::PARALLEL: ss << "parallel"; break;
case MsgSchedule::MAX_RESIDUAL: ss << "max_residual"; break; case MsgSchedule::MAX_RESIDUAL: ss << "max_residual"; break;
} }
ss << ",max_iter=" << Util::toString (maxIter); ss << ",max_iter=" << Util::toString (maxIter_);
ss << ",accuracy=" << Util::toString (accuracy); ss << ",accuracy=" << Util::toString (accuracy_);
ss << ",log_domain=" << Util::toString (Globals::logDomain); ss << ",log_domain=" << Util::toString (Globals::logDomain);
ss << "]" ; ss << "]" ;
cout << ss.str() << endl; cout << ss.str() << endl;
@ -157,12 +157,12 @@ BeliefProp::runSolver (void)
{ {
initializeSolver(); initializeSolver();
nIters_ = 0; nIters_ = 0;
while (!converged() && nIters_ < maxIter) { while (!converged() && nIters_ < maxIter_) {
nIters_ ++; nIters_ ++;
if (Globals::verbosity > 1) { if (Globals::verbosity > 1) {
Util::printHeader (string ("Iteration ") + Util::toString (nIters_)); Util::printHeader (string ("Iteration ") + Util::toString (nIters_));
} }
switch (schedule) { switch (schedule_) {
case MsgSchedule::SEQ_RANDOM: case MsgSchedule::SEQ_RANDOM:
std::random_shuffle (links_.begin(), links_.end()); std::random_shuffle (links_.begin(), links_.end());
// no break // no break
@ -185,7 +185,7 @@ BeliefProp::runSolver (void)
} }
} }
if (Globals::verbosity > 0) { if (Globals::verbosity > 0) {
if (nIters_ < maxIter) { if (nIters_ < maxIter_) {
cout << "Belief propagation converged in " ; cout << "Belief propagation converged in " ;
cout << nIters_ << " iterations" << endl; cout << nIters_ << " iterations" << endl;
} else { } else {
@ -237,7 +237,7 @@ BeliefProp::maxResidualSchedule (void)
SortedOrder::iterator it = sortedOrder_.begin(); SortedOrder::iterator it = sortedOrder_.begin();
BpLink* link = *it; BpLink* link = *it;
if (link->residual() < accuracy) { if (link->residual() < accuracy_) {
return; return;
} }
updateMessage (link); updateMessage (link);
@ -427,9 +427,9 @@ BeliefProp::converged (void)
return false; return false;
} }
bool converged = true; bool converged = true;
if (schedule == MsgSchedule::MAX_RESIDUAL) { if (schedule_ == MsgSchedule::MAX_RESIDUAL) {
double maxResidual = (*(sortedOrder_.begin()))->residual(); double maxResidual = (*(sortedOrder_.begin()))->residual();
if (maxResidual > accuracy) { if (maxResidual > accuracy_) {
converged = false; converged = false;
} else { } else {
converged = true; converged = true;
@ -440,7 +440,7 @@ BeliefProp::converged (void)
if (Globals::verbosity > 1) { if (Globals::verbosity > 1) {
cout << links_[i]->toString() + " residual = " << residual << endl; cout << links_[i]->toString() + " residual = " << residual << endl;
} }
if (residual > accuracy) { if (residual > accuracy_) {
converged = false; converged = false;
if (Globals::verbosity < 2) { if (Globals::verbosity < 2) {
break; break;

View File

@ -108,9 +108,17 @@ class BeliefProp : public GroundSolver
Params getFactorJoint (FacNode* fn, const VarIds&); Params getFactorJoint (FacNode* fn, const VarIds&);
static MsgSchedule schedule; static double accuracy (void) { return accuracy_; }
static double accuracy;
static unsigned maxIter; static void setAccuracy (double acc) { accuracy_ = acc; }
static unsigned maxIterations (void) { return maxIter_; }
static void setMaxIterations (unsigned mi) { maxIter_ = mi; }
static MsgSchedule msgSchedule (void) { return schedule_; }
static void setMsgSchedule (MsgSchedule sch) { schedule_ = sch; }
protected: protected:
SPNodeInfo* ninf (const VarNode* var) const SPNodeInfo* ninf (const VarNode* var) const
@ -186,6 +194,10 @@ class BeliefProp : public GroundSolver
typedef unordered_map<BpLink*, SortedOrder::iterator> BpLinkMap; typedef unordered_map<BpLink*, SortedOrder::iterator> BpLinkMap;
BpLinkMap linkMap_; BpLinkMap linkMap_;
static double accuracy_;
static unsigned maxIter_;
static MsgSchedule schedule_;
private: private:
void initializeSolver (void); void initializeSolver (void);

View File

@ -37,14 +37,14 @@ CountingBp::printSolverFlags (void) const
stringstream ss; stringstream ss;
ss << "counting bp [" ; ss << "counting bp [" ;
ss << "schedule=" ; ss << "schedule=" ;
switch (WeightedBp::schedule) { switch (WeightedBp::msgSchedule()) {
case MsgSchedule::SEQ_FIXED: ss << "seq_fixed"; break; case MsgSchedule::SEQ_FIXED: ss << "seq_fixed"; break;
case MsgSchedule::SEQ_RANDOM: ss << "seq_random"; break; case MsgSchedule::SEQ_RANDOM: ss << "seq_random"; break;
case MsgSchedule::PARALLEL: ss << "parallel"; break; case MsgSchedule::PARALLEL: ss << "parallel"; break;
case MsgSchedule::MAX_RESIDUAL: ss << "max_residual"; break; case MsgSchedule::MAX_RESIDUAL: ss << "max_residual"; break;
} }
ss << ",max_iter=" << WeightedBp::maxIter; ss << ",max_iter=" << WeightedBp::maxIterations();
ss << ",accuracy=" << WeightedBp::accuracy; ss << ",accuracy=" << WeightedBp::accuracy();
ss << ",log_domain=" << Util::toString (Globals::logDomain); ss << ",log_domain=" << Util::toString (Globals::logDomain);
ss << ",chkif=" << Util::toString (CountingBp::checkForIdenticalFactors); ss << ",chkif=" << Util::toString (CountingBp::checkForIdenticalFactors);
ss << "]" ; ss << "]" ;

View File

@ -2,7 +2,7 @@
#include "ElimGraph.h" #include "ElimGraph.h"
ElimHeuristic ElimGraph::elimHeuristic = MIN_NEIGHBORS; ElimHeuristic ElimGraph::elimHeuristic_ = MIN_NEIGHBORS;
ElimGraph::ElimGraph (const vector<Factor*>& factors) ElimGraph::ElimGraph (const vector<Factor*>& factors)
@ -132,7 +132,7 @@ ElimGraph::getEliminationOrder (
const Factors& factors, const Factors& factors,
VarIds excludedVids) VarIds excludedVids)
{ {
if (elimHeuristic == ElimHeuristic::SEQUENTIAL) { if (elimHeuristic_ == ElimHeuristic::SEQUENTIAL) {
VarIds allVids; VarIds allVids;
Factors::const_iterator first = factors.begin(); Factors::const_iterator first = factors.begin();
Factors::const_iterator end = factors.end(); Factors::const_iterator end = factors.end();
@ -175,7 +175,7 @@ ElimGraph::getLowestCostNode (void) const
EgNode* bestNode = 0; EgNode* bestNode = 0;
unsigned minCost = Util::maxUnsigned(); unsigned minCost = Util::maxUnsigned();
EGNeighs::const_iterator it; EGNeighs::const_iterator it;
switch (elimHeuristic) { switch (elimHeuristic_) {
case MIN_NEIGHBORS: { case MIN_NEIGHBORS: {
for (it = unmarked_.begin(); it != unmarked_.end(); ++ it) { for (it = unmarked_.begin(); it != unmarked_.end(); ++ it) {
unsigned cost = getNeighborsCost (*it); unsigned cost = getNeighborsCost (*it);

View File

@ -58,7 +58,9 @@ class ElimGraph
static VarIds getEliminationOrder (const Factors&, VarIds); static VarIds getEliminationOrder (const Factors&, VarIds);
static ElimHeuristic elimHeuristic; static ElimHeuristic elimHeuristic (void) { return elimHeuristic_; }
static void setElimHeuristic (ElimHeuristic eh) { elimHeuristic_ = eh; }
private: private:
@ -132,6 +134,8 @@ class ElimGraph
vector<EgNode*> nodes_; vector<EgNode*> nodes_;
TinySet<EgNode*> unmarked_; TinySet<EgNode*> unmarked_;
unordered_map<VarId, EgNode*> varMap_; unordered_map<VarId, EgNode*> varMap_;
static ElimHeuristic elimHeuristic_;
}; };
#endif // HORUS_ELIMGRAPH_H #endif // HORUS_ELIMGRAPH_H

View File

@ -63,14 +63,14 @@ LiftedBp::printSolverFlags (void) const
stringstream ss; stringstream ss;
ss << "lifted bp [" ; ss << "lifted bp [" ;
ss << "schedule=" ; ss << "schedule=" ;
switch (WeightedBp::schedule) { switch (WeightedBp::msgSchedule()) {
case MsgSchedule::SEQ_FIXED: ss << "seq_fixed"; break; case MsgSchedule::SEQ_FIXED: ss << "seq_fixed"; break;
case MsgSchedule::SEQ_RANDOM: ss << "seq_random"; break; case MsgSchedule::SEQ_RANDOM: ss << "seq_random"; break;
case MsgSchedule::PARALLEL: ss << "parallel"; break; case MsgSchedule::PARALLEL: ss << "parallel"; break;
case MsgSchedule::MAX_RESIDUAL: ss << "max_residual"; break; case MsgSchedule::MAX_RESIDUAL: ss << "max_residual"; break;
} }
ss << ",max_iter=" << WeightedBp::maxIter; ss << ",max_iter=" << WeightedBp::maxIterations();
ss << ",accuracy=" << WeightedBp::accuracy; ss << ",accuracy=" << WeightedBp::accuracy();
ss << ",log_domain=" << Util::toString (Globals::logDomain); ss << ",log_domain=" << Util::toString (Globals::logDomain);
ss << "]" ; ss << "]" ;
cout << ss.str() << endl; cout << ss.str() << endl;

View File

@ -222,15 +222,15 @@ setHorusFlag (string key, string value)
} }
} else if (key == "elim_heuristic") { } else if (key == "elim_heuristic") {
if ( value == "sequential") { if ( value == "sequential") {
ElimGraph::elimHeuristic = ElimHeuristic::SEQUENTIAL; ElimGraph::setElimHeuristic (ElimHeuristic::SEQUENTIAL);
} else if (value == "min_neighbors") { } else if (value == "min_neighbors") {
ElimGraph::elimHeuristic = ElimHeuristic::MIN_NEIGHBORS; ElimGraph::setElimHeuristic (ElimHeuristic::MIN_NEIGHBORS);
} else if (value == "min_weight") { } else if (value == "min_weight") {
ElimGraph::elimHeuristic = ElimHeuristic::MIN_WEIGHT; ElimGraph::setElimHeuristic (ElimHeuristic::MIN_WEIGHT);
} else if (value == "min_fill") { } else if (value == "min_fill") {
ElimGraph::elimHeuristic = ElimHeuristic::MIN_FILL; ElimGraph::setElimHeuristic (ElimHeuristic::MIN_FILL);
} else if (value == "weighted_min_fill") { } else if (value == "weighted_min_fill") {
ElimGraph::elimHeuristic = ElimHeuristic::WEIGHTED_MIN_FILL; ElimGraph::setElimHeuristic (ElimHeuristic::WEIGHTED_MIN_FILL);
} else { } else {
cerr << "warning: invalid value `" << value << "' " ; cerr << "warning: invalid value `" << value << "' " ;
cerr << "for `" << key << "'" << endl; cerr << "for `" << key << "'" << endl;
@ -238,13 +238,13 @@ setHorusFlag (string key, string value)
} }
} else if (key == "schedule") { } else if (key == "schedule") {
if ( value == "seq_fixed") { if ( value == "seq_fixed") {
BeliefProp::schedule = MsgSchedule::SEQ_FIXED; BeliefProp::setMsgSchedule (MsgSchedule::SEQ_FIXED);
} else if (value == "seq_random") { } else if (value == "seq_random") {
BeliefProp::schedule = MsgSchedule::SEQ_RANDOM; BeliefProp::setMsgSchedule (MsgSchedule::SEQ_RANDOM);
} else if (value == "parallel") { } else if (value == "parallel") {
BeliefProp::schedule = MsgSchedule::PARALLEL; BeliefProp::setMsgSchedule (MsgSchedule::PARALLEL);
} else if (value == "max_residual") { } else if (value == "max_residual") {
BeliefProp::schedule = MsgSchedule::MAX_RESIDUAL; BeliefProp::setMsgSchedule (MsgSchedule::MAX_RESIDUAL);
} else { } else {
cerr << "warning: invalid value `" << value << "' " ; cerr << "warning: invalid value `" << value << "' " ;
cerr << "for `" << key << "'" << endl; cerr << "for `" << key << "'" << endl;
@ -252,12 +252,16 @@ setHorusFlag (string key, string value)
} }
} else if (key == "accuracy") { } else if (key == "accuracy") {
stringstream ss; stringstream ss;
double acc;
ss << value; ss << value;
ss >> BeliefProp::accuracy; ss >> acc;
BeliefProp::setAccuracy (acc);
} else if (key == "max_iter") { } else if (key == "max_iter") {
stringstream ss; stringstream ss;
unsigned mi;
ss << value; ss << value;
ss >> BeliefProp::maxIter; ss >> mi;
BeliefProp::setMaxIterations (mi);
} else if (key == "use_logarithms") { } else if (key == "use_logarithms") {
if ( value == "true") { if ( value == "true") {
Globals::logDomain = true; Globals::logDomain = true;

View File

@ -38,13 +38,12 @@ VarElim::printSolverFlags (void) const
stringstream ss; stringstream ss;
ss << "variable elimination [" ; ss << "variable elimination [" ;
ss << "elim_heuristic=" ; ss << "elim_heuristic=" ;
ElimHeuristic eh = ElimGraph::elimHeuristic; switch (ElimGraph::elimHeuristic()) {
switch (eh) { case ElimHeuristic::SEQUENTIAL: ss << "sequential"; break;
case SEQUENTIAL: ss << "sequential"; break; case ElimHeuristic::MIN_NEIGHBORS: ss << "min_neighbors"; break;
case MIN_NEIGHBORS: ss << "min_neighbors"; break; case ElimHeuristic::MIN_WEIGHT: ss << "min_weight"; break;
case MIN_WEIGHT: ss << "min_weight"; break; case ElimHeuristic::MIN_FILL: ss << "min_fill"; break;
case MIN_FILL: ss << "min_fill"; break; case ElimHeuristic::WEIGHTED_MIN_FILL: ss << "weighted_min_fill"; break;
case WEIGHTED_MIN_FILL: ss << "weighted_min_fill"; break;
} }
ss << ",log_domain=" << Util::toString (Globals::logDomain); ss << ",log_domain=" << Util::toString (Globals::logDomain);
ss << "]" ; ss << "]" ;

View File

@ -107,7 +107,7 @@ WeightedBp::maxResidualSchedule (void)
if (Globals::verbosity >= 1) { if (Globals::verbosity >= 1) {
cout << "updating " << (*sortedOrder_.begin())->toString() << endl; cout << "updating " << (*sortedOrder_.begin())->toString() << endl;
} }
if (link->residual() < accuracy) { if (link->residual() < accuracy_) {
return; return;
} }
link->updateMessage(); link->updateMessage();