Don't use public members for solver flags

This commit is contained in:
Tiago Gomes 2012-12-27 15:44:40 +00:00
parent b996436b24
commit 7b7f663ac6
9 changed files with 64 additions and 45 deletions

View File

@ -9,9 +9,9 @@
#include "Horus.h"
MsgSchedule BeliefProp::schedule = MsgSchedule::SEQ_FIXED;
double BeliefProp::accuracy = 0.0001;
unsigned BeliefProp::maxIter = 1000;
double BeliefProp::accuracy_ = 0.0001;
unsigned BeliefProp::maxIter_ = 1000;
MsgSchedule BeliefProp::schedule_ = MsgSchedule::SEQ_FIXED;
BeliefProp::BeliefProp (const FactorGraph& fg) : GroundSolver (fg)
@ -53,14 +53,14 @@ BeliefProp::printSolverFlags (void) const
stringstream ss;
ss << "belief propagation [" ;
ss << "schedule=" ;
switch (schedule) {
switch (schedule_) {
case MsgSchedule::SEQ_FIXED: ss << "seq_fixed"; break;
case MsgSchedule::SEQ_RANDOM: ss << "seq_random"; break;
case MsgSchedule::PARALLEL: ss << "parallel"; break;
case MsgSchedule::MAX_RESIDUAL: ss << "max_residual"; break;
}
ss << ",max_iter=" << Util::toString (maxIter);
ss << ",accuracy=" << Util::toString (accuracy);
ss << ",max_iter=" << Util::toString (maxIter_);
ss << ",accuracy=" << Util::toString (accuracy_);
ss << ",log_domain=" << Util::toString (Globals::logDomain);
ss << "]" ;
cout << ss.str() << endl;
@ -157,12 +157,12 @@ BeliefProp::runSolver (void)
{
initializeSolver();
nIters_ = 0;
while (!converged() && nIters_ < maxIter) {
while (!converged() && nIters_ < maxIter_) {
nIters_ ++;
if (Globals::verbosity > 1) {
Util::printHeader (string ("Iteration ") + Util::toString (nIters_));
}
switch (schedule) {
switch (schedule_) {
case MsgSchedule::SEQ_RANDOM:
std::random_shuffle (links_.begin(), links_.end());
// no break
@ -185,7 +185,7 @@ BeliefProp::runSolver (void)
}
}
if (Globals::verbosity > 0) {
if (nIters_ < maxIter) {
if (nIters_ < maxIter_) {
cout << "Belief propagation converged in " ;
cout << nIters_ << " iterations" << endl;
} else {
@ -237,7 +237,7 @@ BeliefProp::maxResidualSchedule (void)
SortedOrder::iterator it = sortedOrder_.begin();
BpLink* link = *it;
if (link->residual() < accuracy) {
if (link->residual() < accuracy_) {
return;
}
updateMessage (link);
@ -427,9 +427,9 @@ BeliefProp::converged (void)
return false;
}
bool converged = true;
if (schedule == MsgSchedule::MAX_RESIDUAL) {
if (schedule_ == MsgSchedule::MAX_RESIDUAL) {
double maxResidual = (*(sortedOrder_.begin()))->residual();
if (maxResidual > accuracy) {
if (maxResidual > accuracy_) {
converged = false;
} else {
converged = true;
@ -440,7 +440,7 @@ BeliefProp::converged (void)
if (Globals::verbosity > 1) {
cout << links_[i]->toString() + " residual = " << residual << endl;
}
if (residual > accuracy) {
if (residual > accuracy_) {
converged = false;
if (Globals::verbosity < 2) {
break;

View File

@ -108,9 +108,17 @@ class BeliefProp : public GroundSolver
Params getFactorJoint (FacNode* fn, const VarIds&);
static MsgSchedule schedule;
static double accuracy;
static unsigned maxIter;
static double accuracy (void) { return accuracy_; }
static void setAccuracy (double acc) { accuracy_ = acc; }
static unsigned maxIterations (void) { return maxIter_; }
static void setMaxIterations (unsigned mi) { maxIter_ = mi; }
static MsgSchedule msgSchedule (void) { return schedule_; }
static void setMsgSchedule (MsgSchedule sch) { schedule_ = sch; }
protected:
SPNodeInfo* ninf (const VarNode* var) const
@ -186,6 +194,10 @@ class BeliefProp : public GroundSolver
typedef unordered_map<BpLink*, SortedOrder::iterator> BpLinkMap;
BpLinkMap linkMap_;
static double accuracy_;
static unsigned maxIter_;
static MsgSchedule schedule_;
private:
void initializeSolver (void);

View File

@ -37,14 +37,14 @@ CountingBp::printSolverFlags (void) const
stringstream ss;
ss << "counting bp [" ;
ss << "schedule=" ;
switch (WeightedBp::schedule) {
switch (WeightedBp::msgSchedule()) {
case MsgSchedule::SEQ_FIXED: ss << "seq_fixed"; break;
case MsgSchedule::SEQ_RANDOM: ss << "seq_random"; break;
case MsgSchedule::PARALLEL: ss << "parallel"; break;
case MsgSchedule::MAX_RESIDUAL: ss << "max_residual"; break;
}
ss << ",max_iter=" << WeightedBp::maxIter;
ss << ",accuracy=" << WeightedBp::accuracy;
ss << ",max_iter=" << WeightedBp::maxIterations();
ss << ",accuracy=" << WeightedBp::accuracy();
ss << ",log_domain=" << Util::toString (Globals::logDomain);
ss << ",chkif=" << Util::toString (CountingBp::checkForIdenticalFactors);
ss << "]" ;

View File

@ -2,7 +2,7 @@
#include "ElimGraph.h"
ElimHeuristic ElimGraph::elimHeuristic = MIN_NEIGHBORS;
ElimHeuristic ElimGraph::elimHeuristic_ = MIN_NEIGHBORS;
ElimGraph::ElimGraph (const vector<Factor*>& factors)
@ -132,7 +132,7 @@ ElimGraph::getEliminationOrder (
const Factors& factors,
VarIds excludedVids)
{
if (elimHeuristic == ElimHeuristic::SEQUENTIAL) {
if (elimHeuristic_ == ElimHeuristic::SEQUENTIAL) {
VarIds allVids;
Factors::const_iterator first = factors.begin();
Factors::const_iterator end = factors.end();
@ -175,7 +175,7 @@ ElimGraph::getLowestCostNode (void) const
EgNode* bestNode = 0;
unsigned minCost = Util::maxUnsigned();
EGNeighs::const_iterator it;
switch (elimHeuristic) {
switch (elimHeuristic_) {
case MIN_NEIGHBORS: {
for (it = unmarked_.begin(); it != unmarked_.end(); ++ it) {
unsigned cost = getNeighborsCost (*it);

View File

@ -58,7 +58,9 @@ class ElimGraph
static VarIds getEliminationOrder (const Factors&, VarIds);
static ElimHeuristic elimHeuristic;
static ElimHeuristic elimHeuristic (void) { return elimHeuristic_; }
static void setElimHeuristic (ElimHeuristic eh) { elimHeuristic_ = eh; }
private:
@ -132,6 +134,8 @@ class ElimGraph
vector<EgNode*> nodes_;
TinySet<EgNode*> unmarked_;
unordered_map<VarId, EgNode*> varMap_;
static ElimHeuristic elimHeuristic_;
};
#endif // HORUS_ELIMGRAPH_H

View File

@ -63,14 +63,14 @@ LiftedBp::printSolverFlags (void) const
stringstream ss;
ss << "lifted bp [" ;
ss << "schedule=" ;
switch (WeightedBp::schedule) {
switch (WeightedBp::msgSchedule()) {
case MsgSchedule::SEQ_FIXED: ss << "seq_fixed"; break;
case MsgSchedule::SEQ_RANDOM: ss << "seq_random"; break;
case MsgSchedule::PARALLEL: ss << "parallel"; break;
case MsgSchedule::MAX_RESIDUAL: ss << "max_residual"; break;
}
ss << ",max_iter=" << WeightedBp::maxIter;
ss << ",accuracy=" << WeightedBp::accuracy;
ss << ",max_iter=" << WeightedBp::maxIterations();
ss << ",accuracy=" << WeightedBp::accuracy();
ss << ",log_domain=" << Util::toString (Globals::logDomain);
ss << "]" ;
cout << ss.str() << endl;

View File

@ -222,15 +222,15 @@ setHorusFlag (string key, string value)
}
} else if (key == "elim_heuristic") {
if ( value == "sequential") {
ElimGraph::elimHeuristic = ElimHeuristic::SEQUENTIAL;
ElimGraph::setElimHeuristic (ElimHeuristic::SEQUENTIAL);
} else if (value == "min_neighbors") {
ElimGraph::elimHeuristic = ElimHeuristic::MIN_NEIGHBORS;
ElimGraph::setElimHeuristic (ElimHeuristic::MIN_NEIGHBORS);
} else if (value == "min_weight") {
ElimGraph::elimHeuristic = ElimHeuristic::MIN_WEIGHT;
ElimGraph::setElimHeuristic (ElimHeuristic::MIN_WEIGHT);
} else if (value == "min_fill") {
ElimGraph::elimHeuristic = ElimHeuristic::MIN_FILL;
ElimGraph::setElimHeuristic (ElimHeuristic::MIN_FILL);
} else if (value == "weighted_min_fill") {
ElimGraph::elimHeuristic = ElimHeuristic::WEIGHTED_MIN_FILL;
ElimGraph::setElimHeuristic (ElimHeuristic::WEIGHTED_MIN_FILL);
} else {
cerr << "warning: invalid value `" << value << "' " ;
cerr << "for `" << key << "'" << endl;
@ -238,13 +238,13 @@ setHorusFlag (string key, string value)
}
} else if (key == "schedule") {
if ( value == "seq_fixed") {
BeliefProp::schedule = MsgSchedule::SEQ_FIXED;
BeliefProp::setMsgSchedule (MsgSchedule::SEQ_FIXED);
} else if (value == "seq_random") {
BeliefProp::schedule = MsgSchedule::SEQ_RANDOM;
BeliefProp::setMsgSchedule (MsgSchedule::SEQ_RANDOM);
} else if (value == "parallel") {
BeliefProp::schedule = MsgSchedule::PARALLEL;
BeliefProp::setMsgSchedule (MsgSchedule::PARALLEL);
} else if (value == "max_residual") {
BeliefProp::schedule = MsgSchedule::MAX_RESIDUAL;
BeliefProp::setMsgSchedule (MsgSchedule::MAX_RESIDUAL);
} else {
cerr << "warning: invalid value `" << value << "' " ;
cerr << "for `" << key << "'" << endl;
@ -252,12 +252,16 @@ setHorusFlag (string key, string value)
}
} else if (key == "accuracy") {
stringstream ss;
double acc;
ss << value;
ss >> BeliefProp::accuracy;
ss >> acc;
BeliefProp::setAccuracy (acc);
} else if (key == "max_iter") {
stringstream ss;
unsigned mi;
ss << value;
ss >> BeliefProp::maxIter;
ss >> mi;
BeliefProp::setMaxIterations (mi);
} else if (key == "use_logarithms") {
if ( value == "true") {
Globals::logDomain = true;

View File

@ -38,13 +38,12 @@ VarElim::printSolverFlags (void) const
stringstream ss;
ss << "variable elimination [" ;
ss << "elim_heuristic=" ;
ElimHeuristic eh = ElimGraph::elimHeuristic;
switch (eh) {
case SEQUENTIAL: ss << "sequential"; break;
case MIN_NEIGHBORS: ss << "min_neighbors"; break;
case MIN_WEIGHT: ss << "min_weight"; break;
case MIN_FILL: ss << "min_fill"; break;
case WEIGHTED_MIN_FILL: ss << "weighted_min_fill"; break;
switch (ElimGraph::elimHeuristic()) {
case ElimHeuristic::SEQUENTIAL: ss << "sequential"; break;
case ElimHeuristic::MIN_NEIGHBORS: ss << "min_neighbors"; break;
case ElimHeuristic::MIN_WEIGHT: ss << "min_weight"; break;
case ElimHeuristic::MIN_FILL: ss << "min_fill"; break;
case ElimHeuristic::WEIGHTED_MIN_FILL: ss << "weighted_min_fill"; break;
}
ss << ",log_domain=" << Util::toString (Globals::logDomain);
ss << "]" ;

View File

@ -107,7 +107,7 @@ WeightedBp::maxResidualSchedule (void)
if (Globals::verbosity >= 1) {
cout << "updating " << (*sortedOrder_.begin())->toString() << endl;
}
if (link->residual() < accuracy) {
if (link->residual() < accuracy_) {
return;
}
link->updateMessage();