Don't use public members for solver flags
This commit is contained in:
parent
b996436b24
commit
7b7f663ac6
@ -9,9 +9,9 @@
|
|||||||
#include "Horus.h"
|
#include "Horus.h"
|
||||||
|
|
||||||
|
|
||||||
MsgSchedule BeliefProp::schedule = MsgSchedule::SEQ_FIXED;
|
double BeliefProp::accuracy_ = 0.0001;
|
||||||
double BeliefProp::accuracy = 0.0001;
|
unsigned BeliefProp::maxIter_ = 1000;
|
||||||
unsigned BeliefProp::maxIter = 1000;
|
MsgSchedule BeliefProp::schedule_ = MsgSchedule::SEQ_FIXED;
|
||||||
|
|
||||||
|
|
||||||
BeliefProp::BeliefProp (const FactorGraph& fg) : GroundSolver (fg)
|
BeliefProp::BeliefProp (const FactorGraph& fg) : GroundSolver (fg)
|
||||||
@ -53,14 +53,14 @@ BeliefProp::printSolverFlags (void) const
|
|||||||
stringstream ss;
|
stringstream ss;
|
||||||
ss << "belief propagation [" ;
|
ss << "belief propagation [" ;
|
||||||
ss << "schedule=" ;
|
ss << "schedule=" ;
|
||||||
switch (schedule) {
|
switch (schedule_) {
|
||||||
case MsgSchedule::SEQ_FIXED: ss << "seq_fixed"; break;
|
case MsgSchedule::SEQ_FIXED: ss << "seq_fixed"; break;
|
||||||
case MsgSchedule::SEQ_RANDOM: ss << "seq_random"; break;
|
case MsgSchedule::SEQ_RANDOM: ss << "seq_random"; break;
|
||||||
case MsgSchedule::PARALLEL: ss << "parallel"; break;
|
case MsgSchedule::PARALLEL: ss << "parallel"; break;
|
||||||
case MsgSchedule::MAX_RESIDUAL: ss << "max_residual"; break;
|
case MsgSchedule::MAX_RESIDUAL: ss << "max_residual"; break;
|
||||||
}
|
}
|
||||||
ss << ",max_iter=" << Util::toString (maxIter);
|
ss << ",max_iter=" << Util::toString (maxIter_);
|
||||||
ss << ",accuracy=" << Util::toString (accuracy);
|
ss << ",accuracy=" << Util::toString (accuracy_);
|
||||||
ss << ",log_domain=" << Util::toString (Globals::logDomain);
|
ss << ",log_domain=" << Util::toString (Globals::logDomain);
|
||||||
ss << "]" ;
|
ss << "]" ;
|
||||||
cout << ss.str() << endl;
|
cout << ss.str() << endl;
|
||||||
@ -157,12 +157,12 @@ BeliefProp::runSolver (void)
|
|||||||
{
|
{
|
||||||
initializeSolver();
|
initializeSolver();
|
||||||
nIters_ = 0;
|
nIters_ = 0;
|
||||||
while (!converged() && nIters_ < maxIter) {
|
while (!converged() && nIters_ < maxIter_) {
|
||||||
nIters_ ++;
|
nIters_ ++;
|
||||||
if (Globals::verbosity > 1) {
|
if (Globals::verbosity > 1) {
|
||||||
Util::printHeader (string ("Iteration ") + Util::toString (nIters_));
|
Util::printHeader (string ("Iteration ") + Util::toString (nIters_));
|
||||||
}
|
}
|
||||||
switch (schedule) {
|
switch (schedule_) {
|
||||||
case MsgSchedule::SEQ_RANDOM:
|
case MsgSchedule::SEQ_RANDOM:
|
||||||
std::random_shuffle (links_.begin(), links_.end());
|
std::random_shuffle (links_.begin(), links_.end());
|
||||||
// no break
|
// no break
|
||||||
@ -185,7 +185,7 @@ BeliefProp::runSolver (void)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (Globals::verbosity > 0) {
|
if (Globals::verbosity > 0) {
|
||||||
if (nIters_ < maxIter) {
|
if (nIters_ < maxIter_) {
|
||||||
cout << "Belief propagation converged in " ;
|
cout << "Belief propagation converged in " ;
|
||||||
cout << nIters_ << " iterations" << endl;
|
cout << nIters_ << " iterations" << endl;
|
||||||
} else {
|
} else {
|
||||||
@ -237,7 +237,7 @@ BeliefProp::maxResidualSchedule (void)
|
|||||||
|
|
||||||
SortedOrder::iterator it = sortedOrder_.begin();
|
SortedOrder::iterator it = sortedOrder_.begin();
|
||||||
BpLink* link = *it;
|
BpLink* link = *it;
|
||||||
if (link->residual() < accuracy) {
|
if (link->residual() < accuracy_) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
updateMessage (link);
|
updateMessage (link);
|
||||||
@ -427,9 +427,9 @@ BeliefProp::converged (void)
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
bool converged = true;
|
bool converged = true;
|
||||||
if (schedule == MsgSchedule::MAX_RESIDUAL) {
|
if (schedule_ == MsgSchedule::MAX_RESIDUAL) {
|
||||||
double maxResidual = (*(sortedOrder_.begin()))->residual();
|
double maxResidual = (*(sortedOrder_.begin()))->residual();
|
||||||
if (maxResidual > accuracy) {
|
if (maxResidual > accuracy_) {
|
||||||
converged = false;
|
converged = false;
|
||||||
} else {
|
} else {
|
||||||
converged = true;
|
converged = true;
|
||||||
@ -440,7 +440,7 @@ BeliefProp::converged (void)
|
|||||||
if (Globals::verbosity > 1) {
|
if (Globals::verbosity > 1) {
|
||||||
cout << links_[i]->toString() + " residual = " << residual << endl;
|
cout << links_[i]->toString() + " residual = " << residual << endl;
|
||||||
}
|
}
|
||||||
if (residual > accuracy) {
|
if (residual > accuracy_) {
|
||||||
converged = false;
|
converged = false;
|
||||||
if (Globals::verbosity < 2) {
|
if (Globals::verbosity < 2) {
|
||||||
break;
|
break;
|
||||||
|
@ -108,9 +108,17 @@ class BeliefProp : public GroundSolver
|
|||||||
|
|
||||||
Params getFactorJoint (FacNode* fn, const VarIds&);
|
Params getFactorJoint (FacNode* fn, const VarIds&);
|
||||||
|
|
||||||
static MsgSchedule schedule;
|
static double accuracy (void) { return accuracy_; }
|
||||||
static double accuracy;
|
|
||||||
static unsigned maxIter;
|
static void setAccuracy (double acc) { accuracy_ = acc; }
|
||||||
|
|
||||||
|
static unsigned maxIterations (void) { return maxIter_; }
|
||||||
|
|
||||||
|
static void setMaxIterations (unsigned mi) { maxIter_ = mi; }
|
||||||
|
|
||||||
|
static MsgSchedule msgSchedule (void) { return schedule_; }
|
||||||
|
|
||||||
|
static void setMsgSchedule (MsgSchedule sch) { schedule_ = sch; }
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
SPNodeInfo* ninf (const VarNode* var) const
|
SPNodeInfo* ninf (const VarNode* var) const
|
||||||
@ -186,6 +194,10 @@ class BeliefProp : public GroundSolver
|
|||||||
typedef unordered_map<BpLink*, SortedOrder::iterator> BpLinkMap;
|
typedef unordered_map<BpLink*, SortedOrder::iterator> BpLinkMap;
|
||||||
BpLinkMap linkMap_;
|
BpLinkMap linkMap_;
|
||||||
|
|
||||||
|
static double accuracy_;
|
||||||
|
static unsigned maxIter_;
|
||||||
|
static MsgSchedule schedule_;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void initializeSolver (void);
|
void initializeSolver (void);
|
||||||
|
|
||||||
|
@ -37,14 +37,14 @@ CountingBp::printSolverFlags (void) const
|
|||||||
stringstream ss;
|
stringstream ss;
|
||||||
ss << "counting bp [" ;
|
ss << "counting bp [" ;
|
||||||
ss << "schedule=" ;
|
ss << "schedule=" ;
|
||||||
switch (WeightedBp::schedule) {
|
switch (WeightedBp::msgSchedule()) {
|
||||||
case MsgSchedule::SEQ_FIXED: ss << "seq_fixed"; break;
|
case MsgSchedule::SEQ_FIXED: ss << "seq_fixed"; break;
|
||||||
case MsgSchedule::SEQ_RANDOM: ss << "seq_random"; break;
|
case MsgSchedule::SEQ_RANDOM: ss << "seq_random"; break;
|
||||||
case MsgSchedule::PARALLEL: ss << "parallel"; break;
|
case MsgSchedule::PARALLEL: ss << "parallel"; break;
|
||||||
case MsgSchedule::MAX_RESIDUAL: ss << "max_residual"; break;
|
case MsgSchedule::MAX_RESIDUAL: ss << "max_residual"; break;
|
||||||
}
|
}
|
||||||
ss << ",max_iter=" << WeightedBp::maxIter;
|
ss << ",max_iter=" << WeightedBp::maxIterations();
|
||||||
ss << ",accuracy=" << WeightedBp::accuracy;
|
ss << ",accuracy=" << WeightedBp::accuracy();
|
||||||
ss << ",log_domain=" << Util::toString (Globals::logDomain);
|
ss << ",log_domain=" << Util::toString (Globals::logDomain);
|
||||||
ss << ",chkif=" << Util::toString (CountingBp::checkForIdenticalFactors);
|
ss << ",chkif=" << Util::toString (CountingBp::checkForIdenticalFactors);
|
||||||
ss << "]" ;
|
ss << "]" ;
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
#include "ElimGraph.h"
|
#include "ElimGraph.h"
|
||||||
|
|
||||||
ElimHeuristic ElimGraph::elimHeuristic = MIN_NEIGHBORS;
|
ElimHeuristic ElimGraph::elimHeuristic_ = MIN_NEIGHBORS;
|
||||||
|
|
||||||
|
|
||||||
ElimGraph::ElimGraph (const vector<Factor*>& factors)
|
ElimGraph::ElimGraph (const vector<Factor*>& factors)
|
||||||
@ -132,7 +132,7 @@ ElimGraph::getEliminationOrder (
|
|||||||
const Factors& factors,
|
const Factors& factors,
|
||||||
VarIds excludedVids)
|
VarIds excludedVids)
|
||||||
{
|
{
|
||||||
if (elimHeuristic == ElimHeuristic::SEQUENTIAL) {
|
if (elimHeuristic_ == ElimHeuristic::SEQUENTIAL) {
|
||||||
VarIds allVids;
|
VarIds allVids;
|
||||||
Factors::const_iterator first = factors.begin();
|
Factors::const_iterator first = factors.begin();
|
||||||
Factors::const_iterator end = factors.end();
|
Factors::const_iterator end = factors.end();
|
||||||
@ -175,7 +175,7 @@ ElimGraph::getLowestCostNode (void) const
|
|||||||
EgNode* bestNode = 0;
|
EgNode* bestNode = 0;
|
||||||
unsigned minCost = Util::maxUnsigned();
|
unsigned minCost = Util::maxUnsigned();
|
||||||
EGNeighs::const_iterator it;
|
EGNeighs::const_iterator it;
|
||||||
switch (elimHeuristic) {
|
switch (elimHeuristic_) {
|
||||||
case MIN_NEIGHBORS: {
|
case MIN_NEIGHBORS: {
|
||||||
for (it = unmarked_.begin(); it != unmarked_.end(); ++ it) {
|
for (it = unmarked_.begin(); it != unmarked_.end(); ++ it) {
|
||||||
unsigned cost = getNeighborsCost (*it);
|
unsigned cost = getNeighborsCost (*it);
|
||||||
|
@ -58,7 +58,9 @@ class ElimGraph
|
|||||||
|
|
||||||
static VarIds getEliminationOrder (const Factors&, VarIds);
|
static VarIds getEliminationOrder (const Factors&, VarIds);
|
||||||
|
|
||||||
static ElimHeuristic elimHeuristic;
|
static ElimHeuristic elimHeuristic (void) { return elimHeuristic_; }
|
||||||
|
|
||||||
|
static void setElimHeuristic (ElimHeuristic eh) { elimHeuristic_ = eh; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
||||||
@ -132,6 +134,8 @@ class ElimGraph
|
|||||||
vector<EgNode*> nodes_;
|
vector<EgNode*> nodes_;
|
||||||
TinySet<EgNode*> unmarked_;
|
TinySet<EgNode*> unmarked_;
|
||||||
unordered_map<VarId, EgNode*> varMap_;
|
unordered_map<VarId, EgNode*> varMap_;
|
||||||
|
|
||||||
|
static ElimHeuristic elimHeuristic_;
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif // HORUS_ELIMGRAPH_H
|
#endif // HORUS_ELIMGRAPH_H
|
||||||
|
@ -63,14 +63,14 @@ LiftedBp::printSolverFlags (void) const
|
|||||||
stringstream ss;
|
stringstream ss;
|
||||||
ss << "lifted bp [" ;
|
ss << "lifted bp [" ;
|
||||||
ss << "schedule=" ;
|
ss << "schedule=" ;
|
||||||
switch (WeightedBp::schedule) {
|
switch (WeightedBp::msgSchedule()) {
|
||||||
case MsgSchedule::SEQ_FIXED: ss << "seq_fixed"; break;
|
case MsgSchedule::SEQ_FIXED: ss << "seq_fixed"; break;
|
||||||
case MsgSchedule::SEQ_RANDOM: ss << "seq_random"; break;
|
case MsgSchedule::SEQ_RANDOM: ss << "seq_random"; break;
|
||||||
case MsgSchedule::PARALLEL: ss << "parallel"; break;
|
case MsgSchedule::PARALLEL: ss << "parallel"; break;
|
||||||
case MsgSchedule::MAX_RESIDUAL: ss << "max_residual"; break;
|
case MsgSchedule::MAX_RESIDUAL: ss << "max_residual"; break;
|
||||||
}
|
}
|
||||||
ss << ",max_iter=" << WeightedBp::maxIter;
|
ss << ",max_iter=" << WeightedBp::maxIterations();
|
||||||
ss << ",accuracy=" << WeightedBp::accuracy;
|
ss << ",accuracy=" << WeightedBp::accuracy();
|
||||||
ss << ",log_domain=" << Util::toString (Globals::logDomain);
|
ss << ",log_domain=" << Util::toString (Globals::logDomain);
|
||||||
ss << "]" ;
|
ss << "]" ;
|
||||||
cout << ss.str() << endl;
|
cout << ss.str() << endl;
|
||||||
|
@ -222,15 +222,15 @@ setHorusFlag (string key, string value)
|
|||||||
}
|
}
|
||||||
} else if (key == "elim_heuristic") {
|
} else if (key == "elim_heuristic") {
|
||||||
if ( value == "sequential") {
|
if ( value == "sequential") {
|
||||||
ElimGraph::elimHeuristic = ElimHeuristic::SEQUENTIAL;
|
ElimGraph::setElimHeuristic (ElimHeuristic::SEQUENTIAL);
|
||||||
} else if (value == "min_neighbors") {
|
} else if (value == "min_neighbors") {
|
||||||
ElimGraph::elimHeuristic = ElimHeuristic::MIN_NEIGHBORS;
|
ElimGraph::setElimHeuristic (ElimHeuristic::MIN_NEIGHBORS);
|
||||||
} else if (value == "min_weight") {
|
} else if (value == "min_weight") {
|
||||||
ElimGraph::elimHeuristic = ElimHeuristic::MIN_WEIGHT;
|
ElimGraph::setElimHeuristic (ElimHeuristic::MIN_WEIGHT);
|
||||||
} else if (value == "min_fill") {
|
} else if (value == "min_fill") {
|
||||||
ElimGraph::elimHeuristic = ElimHeuristic::MIN_FILL;
|
ElimGraph::setElimHeuristic (ElimHeuristic::MIN_FILL);
|
||||||
} else if (value == "weighted_min_fill") {
|
} else if (value == "weighted_min_fill") {
|
||||||
ElimGraph::elimHeuristic = ElimHeuristic::WEIGHTED_MIN_FILL;
|
ElimGraph::setElimHeuristic (ElimHeuristic::WEIGHTED_MIN_FILL);
|
||||||
} else {
|
} else {
|
||||||
cerr << "warning: invalid value `" << value << "' " ;
|
cerr << "warning: invalid value `" << value << "' " ;
|
||||||
cerr << "for `" << key << "'" << endl;
|
cerr << "for `" << key << "'" << endl;
|
||||||
@ -238,13 +238,13 @@ setHorusFlag (string key, string value)
|
|||||||
}
|
}
|
||||||
} else if (key == "schedule") {
|
} else if (key == "schedule") {
|
||||||
if ( value == "seq_fixed") {
|
if ( value == "seq_fixed") {
|
||||||
BeliefProp::schedule = MsgSchedule::SEQ_FIXED;
|
BeliefProp::setMsgSchedule (MsgSchedule::SEQ_FIXED);
|
||||||
} else if (value == "seq_random") {
|
} else if (value == "seq_random") {
|
||||||
BeliefProp::schedule = MsgSchedule::SEQ_RANDOM;
|
BeliefProp::setMsgSchedule (MsgSchedule::SEQ_RANDOM);
|
||||||
} else if (value == "parallel") {
|
} else if (value == "parallel") {
|
||||||
BeliefProp::schedule = MsgSchedule::PARALLEL;
|
BeliefProp::setMsgSchedule (MsgSchedule::PARALLEL);
|
||||||
} else if (value == "max_residual") {
|
} else if (value == "max_residual") {
|
||||||
BeliefProp::schedule = MsgSchedule::MAX_RESIDUAL;
|
BeliefProp::setMsgSchedule (MsgSchedule::MAX_RESIDUAL);
|
||||||
} else {
|
} else {
|
||||||
cerr << "warning: invalid value `" << value << "' " ;
|
cerr << "warning: invalid value `" << value << "' " ;
|
||||||
cerr << "for `" << key << "'" << endl;
|
cerr << "for `" << key << "'" << endl;
|
||||||
@ -252,12 +252,16 @@ setHorusFlag (string key, string value)
|
|||||||
}
|
}
|
||||||
} else if (key == "accuracy") {
|
} else if (key == "accuracy") {
|
||||||
stringstream ss;
|
stringstream ss;
|
||||||
|
double acc;
|
||||||
ss << value;
|
ss << value;
|
||||||
ss >> BeliefProp::accuracy;
|
ss >> acc;
|
||||||
|
BeliefProp::setAccuracy (acc);
|
||||||
} else if (key == "max_iter") {
|
} else if (key == "max_iter") {
|
||||||
stringstream ss;
|
stringstream ss;
|
||||||
|
unsigned mi;
|
||||||
ss << value;
|
ss << value;
|
||||||
ss >> BeliefProp::maxIter;
|
ss >> mi;
|
||||||
|
BeliefProp::setMaxIterations (mi);
|
||||||
} else if (key == "use_logarithms") {
|
} else if (key == "use_logarithms") {
|
||||||
if ( value == "true") {
|
if ( value == "true") {
|
||||||
Globals::logDomain = true;
|
Globals::logDomain = true;
|
||||||
|
@ -38,13 +38,12 @@ VarElim::printSolverFlags (void) const
|
|||||||
stringstream ss;
|
stringstream ss;
|
||||||
ss << "variable elimination [" ;
|
ss << "variable elimination [" ;
|
||||||
ss << "elim_heuristic=" ;
|
ss << "elim_heuristic=" ;
|
||||||
ElimHeuristic eh = ElimGraph::elimHeuristic;
|
switch (ElimGraph::elimHeuristic()) {
|
||||||
switch (eh) {
|
case ElimHeuristic::SEQUENTIAL: ss << "sequential"; break;
|
||||||
case SEQUENTIAL: ss << "sequential"; break;
|
case ElimHeuristic::MIN_NEIGHBORS: ss << "min_neighbors"; break;
|
||||||
case MIN_NEIGHBORS: ss << "min_neighbors"; break;
|
case ElimHeuristic::MIN_WEIGHT: ss << "min_weight"; break;
|
||||||
case MIN_WEIGHT: ss << "min_weight"; break;
|
case ElimHeuristic::MIN_FILL: ss << "min_fill"; break;
|
||||||
case MIN_FILL: ss << "min_fill"; break;
|
case ElimHeuristic::WEIGHTED_MIN_FILL: ss << "weighted_min_fill"; break;
|
||||||
case WEIGHTED_MIN_FILL: ss << "weighted_min_fill"; break;
|
|
||||||
}
|
}
|
||||||
ss << ",log_domain=" << Util::toString (Globals::logDomain);
|
ss << ",log_domain=" << Util::toString (Globals::logDomain);
|
||||||
ss << "]" ;
|
ss << "]" ;
|
||||||
|
@ -107,7 +107,7 @@ WeightedBp::maxResidualSchedule (void)
|
|||||||
if (Globals::verbosity >= 1) {
|
if (Globals::verbosity >= 1) {
|
||||||
cout << "updating " << (*sortedOrder_.begin())->toString() << endl;
|
cout << "updating " << (*sortedOrder_.begin())->toString() << endl;
|
||||||
}
|
}
|
||||||
if (link->residual() < accuracy) {
|
if (link->residual() < accuracy_) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
link->updateMessage();
|
link->updateMessage();
|
||||||
|
Reference in New Issue
Block a user