Don't use public members for solver flags
This commit is contained in:
parent
b996436b24
commit
7b7f663ac6
@ -9,9 +9,9 @@
|
||||
#include "Horus.h"
|
||||
|
||||
|
||||
MsgSchedule BeliefProp::schedule = MsgSchedule::SEQ_FIXED;
|
||||
double BeliefProp::accuracy = 0.0001;
|
||||
unsigned BeliefProp::maxIter = 1000;
|
||||
double BeliefProp::accuracy_ = 0.0001;
|
||||
unsigned BeliefProp::maxIter_ = 1000;
|
||||
MsgSchedule BeliefProp::schedule_ = MsgSchedule::SEQ_FIXED;
|
||||
|
||||
|
||||
BeliefProp::BeliefProp (const FactorGraph& fg) : GroundSolver (fg)
|
||||
@ -53,14 +53,14 @@ BeliefProp::printSolverFlags (void) const
|
||||
stringstream ss;
|
||||
ss << "belief propagation [" ;
|
||||
ss << "schedule=" ;
|
||||
switch (schedule) {
|
||||
switch (schedule_) {
|
||||
case MsgSchedule::SEQ_FIXED: ss << "seq_fixed"; break;
|
||||
case MsgSchedule::SEQ_RANDOM: ss << "seq_random"; break;
|
||||
case MsgSchedule::PARALLEL: ss << "parallel"; break;
|
||||
case MsgSchedule::MAX_RESIDUAL: ss << "max_residual"; break;
|
||||
}
|
||||
ss << ",max_iter=" << Util::toString (maxIter);
|
||||
ss << ",accuracy=" << Util::toString (accuracy);
|
||||
ss << ",max_iter=" << Util::toString (maxIter_);
|
||||
ss << ",accuracy=" << Util::toString (accuracy_);
|
||||
ss << ",log_domain=" << Util::toString (Globals::logDomain);
|
||||
ss << "]" ;
|
||||
cout << ss.str() << endl;
|
||||
@ -157,12 +157,12 @@ BeliefProp::runSolver (void)
|
||||
{
|
||||
initializeSolver();
|
||||
nIters_ = 0;
|
||||
while (!converged() && nIters_ < maxIter) {
|
||||
while (!converged() && nIters_ < maxIter_) {
|
||||
nIters_ ++;
|
||||
if (Globals::verbosity > 1) {
|
||||
Util::printHeader (string ("Iteration ") + Util::toString (nIters_));
|
||||
}
|
||||
switch (schedule) {
|
||||
switch (schedule_) {
|
||||
case MsgSchedule::SEQ_RANDOM:
|
||||
std::random_shuffle (links_.begin(), links_.end());
|
||||
// no break
|
||||
@ -185,7 +185,7 @@ BeliefProp::runSolver (void)
|
||||
}
|
||||
}
|
||||
if (Globals::verbosity > 0) {
|
||||
if (nIters_ < maxIter) {
|
||||
if (nIters_ < maxIter_) {
|
||||
cout << "Belief propagation converged in " ;
|
||||
cout << nIters_ << " iterations" << endl;
|
||||
} else {
|
||||
@ -237,7 +237,7 @@ BeliefProp::maxResidualSchedule (void)
|
||||
|
||||
SortedOrder::iterator it = sortedOrder_.begin();
|
||||
BpLink* link = *it;
|
||||
if (link->residual() < accuracy) {
|
||||
if (link->residual() < accuracy_) {
|
||||
return;
|
||||
}
|
||||
updateMessage (link);
|
||||
@ -427,9 +427,9 @@ BeliefProp::converged (void)
|
||||
return false;
|
||||
}
|
||||
bool converged = true;
|
||||
if (schedule == MsgSchedule::MAX_RESIDUAL) {
|
||||
if (schedule_ == MsgSchedule::MAX_RESIDUAL) {
|
||||
double maxResidual = (*(sortedOrder_.begin()))->residual();
|
||||
if (maxResidual > accuracy) {
|
||||
if (maxResidual > accuracy_) {
|
||||
converged = false;
|
||||
} else {
|
||||
converged = true;
|
||||
@ -440,7 +440,7 @@ BeliefProp::converged (void)
|
||||
if (Globals::verbosity > 1) {
|
||||
cout << links_[i]->toString() + " residual = " << residual << endl;
|
||||
}
|
||||
if (residual > accuracy) {
|
||||
if (residual > accuracy_) {
|
||||
converged = false;
|
||||
if (Globals::verbosity < 2) {
|
||||
break;
|
||||
|
@ -108,9 +108,17 @@ class BeliefProp : public GroundSolver
|
||||
|
||||
Params getFactorJoint (FacNode* fn, const VarIds&);
|
||||
|
||||
static MsgSchedule schedule;
|
||||
static double accuracy;
|
||||
static unsigned maxIter;
|
||||
static double accuracy (void) { return accuracy_; }
|
||||
|
||||
static void setAccuracy (double acc) { accuracy_ = acc; }
|
||||
|
||||
static unsigned maxIterations (void) { return maxIter_; }
|
||||
|
||||
static void setMaxIterations (unsigned mi) { maxIter_ = mi; }
|
||||
|
||||
static MsgSchedule msgSchedule (void) { return schedule_; }
|
||||
|
||||
static void setMsgSchedule (MsgSchedule sch) { schedule_ = sch; }
|
||||
|
||||
protected:
|
||||
SPNodeInfo* ninf (const VarNode* var) const
|
||||
@ -186,6 +194,10 @@ class BeliefProp : public GroundSolver
|
||||
typedef unordered_map<BpLink*, SortedOrder::iterator> BpLinkMap;
|
||||
BpLinkMap linkMap_;
|
||||
|
||||
static double accuracy_;
|
||||
static unsigned maxIter_;
|
||||
static MsgSchedule schedule_;
|
||||
|
||||
private:
|
||||
void initializeSolver (void);
|
||||
|
||||
|
@ -37,14 +37,14 @@ CountingBp::printSolverFlags (void) const
|
||||
stringstream ss;
|
||||
ss << "counting bp [" ;
|
||||
ss << "schedule=" ;
|
||||
switch (WeightedBp::schedule) {
|
||||
switch (WeightedBp::msgSchedule()) {
|
||||
case MsgSchedule::SEQ_FIXED: ss << "seq_fixed"; break;
|
||||
case MsgSchedule::SEQ_RANDOM: ss << "seq_random"; break;
|
||||
case MsgSchedule::PARALLEL: ss << "parallel"; break;
|
||||
case MsgSchedule::MAX_RESIDUAL: ss << "max_residual"; break;
|
||||
}
|
||||
ss << ",max_iter=" << WeightedBp::maxIter;
|
||||
ss << ",accuracy=" << WeightedBp::accuracy;
|
||||
ss << ",max_iter=" << WeightedBp::maxIterations();
|
||||
ss << ",accuracy=" << WeightedBp::accuracy();
|
||||
ss << ",log_domain=" << Util::toString (Globals::logDomain);
|
||||
ss << ",chkif=" << Util::toString (CountingBp::checkForIdenticalFactors);
|
||||
ss << "]" ;
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
#include "ElimGraph.h"
|
||||
|
||||
ElimHeuristic ElimGraph::elimHeuristic = MIN_NEIGHBORS;
|
||||
ElimHeuristic ElimGraph::elimHeuristic_ = MIN_NEIGHBORS;
|
||||
|
||||
|
||||
ElimGraph::ElimGraph (const vector<Factor*>& factors)
|
||||
@ -132,7 +132,7 @@ ElimGraph::getEliminationOrder (
|
||||
const Factors& factors,
|
||||
VarIds excludedVids)
|
||||
{
|
||||
if (elimHeuristic == ElimHeuristic::SEQUENTIAL) {
|
||||
if (elimHeuristic_ == ElimHeuristic::SEQUENTIAL) {
|
||||
VarIds allVids;
|
||||
Factors::const_iterator first = factors.begin();
|
||||
Factors::const_iterator end = factors.end();
|
||||
@ -175,7 +175,7 @@ ElimGraph::getLowestCostNode (void) const
|
||||
EgNode* bestNode = 0;
|
||||
unsigned minCost = Util::maxUnsigned();
|
||||
EGNeighs::const_iterator it;
|
||||
switch (elimHeuristic) {
|
||||
switch (elimHeuristic_) {
|
||||
case MIN_NEIGHBORS: {
|
||||
for (it = unmarked_.begin(); it != unmarked_.end(); ++ it) {
|
||||
unsigned cost = getNeighborsCost (*it);
|
||||
|
@ -58,7 +58,9 @@ class ElimGraph
|
||||
|
||||
static VarIds getEliminationOrder (const Factors&, VarIds);
|
||||
|
||||
static ElimHeuristic elimHeuristic;
|
||||
static ElimHeuristic elimHeuristic (void) { return elimHeuristic_; }
|
||||
|
||||
static void setElimHeuristic (ElimHeuristic eh) { elimHeuristic_ = eh; }
|
||||
|
||||
private:
|
||||
|
||||
@ -132,6 +134,8 @@ class ElimGraph
|
||||
vector<EgNode*> nodes_;
|
||||
TinySet<EgNode*> unmarked_;
|
||||
unordered_map<VarId, EgNode*> varMap_;
|
||||
|
||||
static ElimHeuristic elimHeuristic_;
|
||||
};
|
||||
|
||||
#endif // HORUS_ELIMGRAPH_H
|
||||
|
@ -63,14 +63,14 @@ LiftedBp::printSolverFlags (void) const
|
||||
stringstream ss;
|
||||
ss << "lifted bp [" ;
|
||||
ss << "schedule=" ;
|
||||
switch (WeightedBp::schedule) {
|
||||
switch (WeightedBp::msgSchedule()) {
|
||||
case MsgSchedule::SEQ_FIXED: ss << "seq_fixed"; break;
|
||||
case MsgSchedule::SEQ_RANDOM: ss << "seq_random"; break;
|
||||
case MsgSchedule::PARALLEL: ss << "parallel"; break;
|
||||
case MsgSchedule::MAX_RESIDUAL: ss << "max_residual"; break;
|
||||
}
|
||||
ss << ",max_iter=" << WeightedBp::maxIter;
|
||||
ss << ",accuracy=" << WeightedBp::accuracy;
|
||||
ss << ",max_iter=" << WeightedBp::maxIterations();
|
||||
ss << ",accuracy=" << WeightedBp::accuracy();
|
||||
ss << ",log_domain=" << Util::toString (Globals::logDomain);
|
||||
ss << "]" ;
|
||||
cout << ss.str() << endl;
|
||||
|
@ -222,15 +222,15 @@ setHorusFlag (string key, string value)
|
||||
}
|
||||
} else if (key == "elim_heuristic") {
|
||||
if ( value == "sequential") {
|
||||
ElimGraph::elimHeuristic = ElimHeuristic::SEQUENTIAL;
|
||||
ElimGraph::setElimHeuristic (ElimHeuristic::SEQUENTIAL);
|
||||
} else if (value == "min_neighbors") {
|
||||
ElimGraph::elimHeuristic = ElimHeuristic::MIN_NEIGHBORS;
|
||||
ElimGraph::setElimHeuristic (ElimHeuristic::MIN_NEIGHBORS);
|
||||
} else if (value == "min_weight") {
|
||||
ElimGraph::elimHeuristic = ElimHeuristic::MIN_WEIGHT;
|
||||
ElimGraph::setElimHeuristic (ElimHeuristic::MIN_WEIGHT);
|
||||
} else if (value == "min_fill") {
|
||||
ElimGraph::elimHeuristic = ElimHeuristic::MIN_FILL;
|
||||
ElimGraph::setElimHeuristic (ElimHeuristic::MIN_FILL);
|
||||
} else if (value == "weighted_min_fill") {
|
||||
ElimGraph::elimHeuristic = ElimHeuristic::WEIGHTED_MIN_FILL;
|
||||
ElimGraph::setElimHeuristic (ElimHeuristic::WEIGHTED_MIN_FILL);
|
||||
} else {
|
||||
cerr << "warning: invalid value `" << value << "' " ;
|
||||
cerr << "for `" << key << "'" << endl;
|
||||
@ -238,13 +238,13 @@ setHorusFlag (string key, string value)
|
||||
}
|
||||
} else if (key == "schedule") {
|
||||
if ( value == "seq_fixed") {
|
||||
BeliefProp::schedule = MsgSchedule::SEQ_FIXED;
|
||||
BeliefProp::setMsgSchedule (MsgSchedule::SEQ_FIXED);
|
||||
} else if (value == "seq_random") {
|
||||
BeliefProp::schedule = MsgSchedule::SEQ_RANDOM;
|
||||
BeliefProp::setMsgSchedule (MsgSchedule::SEQ_RANDOM);
|
||||
} else if (value == "parallel") {
|
||||
BeliefProp::schedule = MsgSchedule::PARALLEL;
|
||||
BeliefProp::setMsgSchedule (MsgSchedule::PARALLEL);
|
||||
} else if (value == "max_residual") {
|
||||
BeliefProp::schedule = MsgSchedule::MAX_RESIDUAL;
|
||||
BeliefProp::setMsgSchedule (MsgSchedule::MAX_RESIDUAL);
|
||||
} else {
|
||||
cerr << "warning: invalid value `" << value << "' " ;
|
||||
cerr << "for `" << key << "'" << endl;
|
||||
@ -252,12 +252,16 @@ setHorusFlag (string key, string value)
|
||||
}
|
||||
} else if (key == "accuracy") {
|
||||
stringstream ss;
|
||||
double acc;
|
||||
ss << value;
|
||||
ss >> BeliefProp::accuracy;
|
||||
ss >> acc;
|
||||
BeliefProp::setAccuracy (acc);
|
||||
} else if (key == "max_iter") {
|
||||
stringstream ss;
|
||||
unsigned mi;
|
||||
ss << value;
|
||||
ss >> BeliefProp::maxIter;
|
||||
ss >> mi;
|
||||
BeliefProp::setMaxIterations (mi);
|
||||
} else if (key == "use_logarithms") {
|
||||
if ( value == "true") {
|
||||
Globals::logDomain = true;
|
||||
|
@ -38,13 +38,12 @@ VarElim::printSolverFlags (void) const
|
||||
stringstream ss;
|
||||
ss << "variable elimination [" ;
|
||||
ss << "elim_heuristic=" ;
|
||||
ElimHeuristic eh = ElimGraph::elimHeuristic;
|
||||
switch (eh) {
|
||||
case SEQUENTIAL: ss << "sequential"; break;
|
||||
case MIN_NEIGHBORS: ss << "min_neighbors"; break;
|
||||
case MIN_WEIGHT: ss << "min_weight"; break;
|
||||
case MIN_FILL: ss << "min_fill"; break;
|
||||
case WEIGHTED_MIN_FILL: ss << "weighted_min_fill"; break;
|
||||
switch (ElimGraph::elimHeuristic()) {
|
||||
case ElimHeuristic::SEQUENTIAL: ss << "sequential"; break;
|
||||
case ElimHeuristic::MIN_NEIGHBORS: ss << "min_neighbors"; break;
|
||||
case ElimHeuristic::MIN_WEIGHT: ss << "min_weight"; break;
|
||||
case ElimHeuristic::MIN_FILL: ss << "min_fill"; break;
|
||||
case ElimHeuristic::WEIGHTED_MIN_FILL: ss << "weighted_min_fill"; break;
|
||||
}
|
||||
ss << ",log_domain=" << Util::toString (Globals::logDomain);
|
||||
ss << "]" ;
|
||||
|
@ -107,7 +107,7 @@ WeightedBp::maxResidualSchedule (void)
|
||||
if (Globals::verbosity >= 1) {
|
||||
cout << "updating " << (*sortedOrder_.begin())->toString() << endl;
|
||||
}
|
||||
if (link->residual() < accuracy) {
|
||||
if (link->residual() < accuracy_) {
|
||||
return;
|
||||
}
|
||||
link->updateMessage();
|
||||
|
Reference in New Issue
Block a user