Move belief propagation stuff out from Horus.h
This commit is contained in:
parent
cbea630fbf
commit
de0a118ae5
@ -9,6 +9,11 @@
|
||||
#include "Horus.h"
|
||||
|
||||
|
||||
MsgSchedule BeliefProp::schedule = MsgSchedule::SEQ_FIXED;
|
||||
double BeliefProp::accuracy = 0.0001;
|
||||
unsigned BeliefProp::maxIter = 1000;
|
||||
|
||||
|
||||
BeliefProp::BeliefProp (const FactorGraph& fg) : GroundSolver (fg)
|
||||
{
|
||||
runned_ = false;
|
||||
@ -48,15 +53,14 @@ BeliefProp::printSolverFlags (void) const
|
||||
stringstream ss;
|
||||
ss << "belief propagation [" ;
|
||||
ss << "schedule=" ;
|
||||
typedef BpOptions::Schedule Sch;
|
||||
switch (BpOptions::schedule) {
|
||||
case Sch::SEQ_FIXED: ss << "seq_fixed"; break;
|
||||
case Sch::SEQ_RANDOM: ss << "seq_random"; break;
|
||||
case Sch::PARALLEL: ss << "parallel"; break;
|
||||
case Sch::MAX_RESIDUAL: ss << "max_residual"; break;
|
||||
switch (schedule) {
|
||||
case MsgSchedule::SEQ_FIXED: ss << "seq_fixed"; break;
|
||||
case MsgSchedule::SEQ_RANDOM: ss << "seq_random"; break;
|
||||
case MsgSchedule::PARALLEL: ss << "parallel"; break;
|
||||
case MsgSchedule::MAX_RESIDUAL: ss << "max_residual"; break;
|
||||
}
|
||||
ss << ",max_iter=" << Util::toString (BpOptions::maxIter);
|
||||
ss << ",accuracy=" << Util::toString (BpOptions::accuracy);
|
||||
ss << ",max_iter=" << Util::toString (maxIter);
|
||||
ss << ",accuracy=" << Util::toString (accuracy);
|
||||
ss << ",log_domain=" << Util::toString (Globals::logDomain);
|
||||
ss << "]" ;
|
||||
cout << ss.str() << endl;
|
||||
@ -153,21 +157,21 @@ BeliefProp::runSolver (void)
|
||||
{
|
||||
initializeSolver();
|
||||
nIters_ = 0;
|
||||
while (!converged() && nIters_ < BpOptions::maxIter) {
|
||||
while (!converged() && nIters_ < maxIter) {
|
||||
nIters_ ++;
|
||||
if (Globals::verbosity > 1) {
|
||||
Util::printHeader (string ("Iteration ") + Util::toString (nIters_));
|
||||
}
|
||||
switch (BpOptions::schedule) {
|
||||
case BpOptions::Schedule::SEQ_RANDOM:
|
||||
switch (schedule) {
|
||||
case MsgSchedule::SEQ_RANDOM:
|
||||
std::random_shuffle (links_.begin(), links_.end());
|
||||
// no break
|
||||
case BpOptions::Schedule::SEQ_FIXED:
|
||||
case MsgSchedule::SEQ_FIXED:
|
||||
for (size_t i = 0; i < links_.size(); i++) {
|
||||
calculateAndUpdateMessage (links_[i]);
|
||||
}
|
||||
break;
|
||||
case BpOptions::Schedule::PARALLEL:
|
||||
case MsgSchedule::PARALLEL:
|
||||
for (size_t i = 0; i < links_.size(); i++) {
|
||||
calculateMessage (links_[i]);
|
||||
}
|
||||
@ -175,13 +179,13 @@ BeliefProp::runSolver (void)
|
||||
updateMessage(links_[i]);
|
||||
}
|
||||
break;
|
||||
case BpOptions::Schedule::MAX_RESIDUAL:
|
||||
case MsgSchedule::MAX_RESIDUAL:
|
||||
maxResidualSchedule();
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (Globals::verbosity > 0) {
|
||||
if (nIters_ < BpOptions::maxIter) {
|
||||
if (nIters_ < maxIter) {
|
||||
cout << "Belief propagation converged in " ;
|
||||
cout << nIters_ << " iterations" << endl;
|
||||
} else {
|
||||
@ -233,7 +237,7 @@ BeliefProp::maxResidualSchedule (void)
|
||||
|
||||
SortedOrder::iterator it = sortedOrder_.begin();
|
||||
BpLink* link = *it;
|
||||
if (link->residual() < BpOptions::accuracy) {
|
||||
if (link->residual() < accuracy) {
|
||||
return;
|
||||
}
|
||||
updateMessage (link);
|
||||
@ -423,9 +427,9 @@ BeliefProp::converged (void)
|
||||
return false;
|
||||
}
|
||||
bool converged = true;
|
||||
if (BpOptions::schedule == BpOptions::Schedule::MAX_RESIDUAL) {
|
||||
if (schedule == MsgSchedule::MAX_RESIDUAL) {
|
||||
double maxResidual = (*(sortedOrder_.begin()))->residual();
|
||||
if (maxResidual > BpOptions::accuracy) {
|
||||
if (maxResidual > accuracy) {
|
||||
converged = false;
|
||||
} else {
|
||||
converged = true;
|
||||
@ -436,7 +440,7 @@ BeliefProp::converged (void)
|
||||
if (Globals::verbosity > 1) {
|
||||
cout << links_[i]->toString() + " residual = " << residual << endl;
|
||||
}
|
||||
if (residual > BpOptions::accuracy) {
|
||||
if (residual > accuracy) {
|
||||
converged = false;
|
||||
if (Globals::verbosity < 2) {
|
||||
break;
|
||||
|
@ -13,6 +13,14 @@
|
||||
using namespace std;
|
||||
|
||||
|
||||
enum MsgSchedule {
|
||||
SEQ_FIXED,
|
||||
SEQ_RANDOM,
|
||||
PARALLEL,
|
||||
MAX_RESIDUAL
|
||||
};
|
||||
|
||||
|
||||
class BpLink
|
||||
{
|
||||
public:
|
||||
@ -98,6 +106,10 @@ class BeliefProp : public GroundSolver
|
||||
|
||||
virtual Params getJointDistributionOf (const VarIds&);
|
||||
|
||||
static MsgSchedule schedule;
|
||||
static double accuracy;
|
||||
static unsigned maxIter;
|
||||
|
||||
protected:
|
||||
void runSolver (void);
|
||||
|
||||
|
@ -37,15 +37,14 @@ CountingBp::printSolverFlags (void) const
|
||||
stringstream ss;
|
||||
ss << "counting bp [" ;
|
||||
ss << "schedule=" ;
|
||||
typedef BpOptions::Schedule Sch;
|
||||
switch (BpOptions::schedule) {
|
||||
case Sch::SEQ_FIXED: ss << "seq_fixed"; break;
|
||||
case Sch::SEQ_RANDOM: ss << "seq_random"; break;
|
||||
case Sch::PARALLEL: ss << "parallel"; break;
|
||||
case Sch::MAX_RESIDUAL: ss << "max_residual"; break;
|
||||
switch (WeightedBp::schedule) {
|
||||
case MsgSchedule::SEQ_FIXED: ss << "seq_fixed"; break;
|
||||
case MsgSchedule::SEQ_RANDOM: ss << "seq_random"; break;
|
||||
case MsgSchedule::PARALLEL: ss << "parallel"; break;
|
||||
case MsgSchedule::MAX_RESIDUAL: ss << "max_residual"; break;
|
||||
}
|
||||
ss << ",max_iter=" << BpOptions::maxIter;
|
||||
ss << ",accuracy=" << BpOptions::accuracy;
|
||||
ss << ",max_iter=" << WeightedBp::maxIter;
|
||||
ss << ",accuracy=" << WeightedBp::accuracy;
|
||||
ss << ",log_domain=" << Util::toString (Globals::logDomain);
|
||||
ss << ",chkif=" << Util::toString (CountingBp::checkForIdenticalFactors);
|
||||
ss << "]" ;
|
||||
|
@ -67,19 +67,5 @@ const unsigned PRECISION = 6;
|
||||
|
||||
};
|
||||
|
||||
|
||||
namespace BpOptions
|
||||
{
|
||||
enum Schedule {
|
||||
SEQ_FIXED,
|
||||
SEQ_RANDOM,
|
||||
PARALLEL,
|
||||
MAX_RESIDUAL
|
||||
};
|
||||
extern Schedule schedule;
|
||||
extern double accuracy;
|
||||
extern unsigned maxIter;
|
||||
}
|
||||
|
||||
#endif // HORUS_HORUS_H
|
||||
|
||||
|
@ -63,15 +63,14 @@ LiftedBp::printSolverFlags (void) const
|
||||
stringstream ss;
|
||||
ss << "lifted bp [" ;
|
||||
ss << "schedule=" ;
|
||||
typedef BpOptions::Schedule Sch;
|
||||
switch (BpOptions::schedule) {
|
||||
case Sch::SEQ_FIXED: ss << "seq_fixed"; break;
|
||||
case Sch::SEQ_RANDOM: ss << "seq_random"; break;
|
||||
case Sch::PARALLEL: ss << "parallel"; break;
|
||||
case Sch::MAX_RESIDUAL: ss << "max_residual"; break;
|
||||
switch (WeightedBp::schedule) {
|
||||
case MsgSchedule::SEQ_FIXED: ss << "seq_fixed"; break;
|
||||
case MsgSchedule::SEQ_RANDOM: ss << "seq_random"; break;
|
||||
case MsgSchedule::PARALLEL: ss << "parallel"; break;
|
||||
case MsgSchedule::MAX_RESIDUAL: ss << "max_residual"; break;
|
||||
}
|
||||
ss << ",max_iter=" << BpOptions::maxIter;
|
||||
ss << ",accuracy=" << BpOptions::accuracy;
|
||||
ss << ",max_iter=" << WeightedBp::maxIter;
|
||||
ss << ",accuracy=" << WeightedBp::accuracy;
|
||||
ss << ",log_domain=" << Util::toString (Globals::logDomain);
|
||||
ss << "]" ;
|
||||
cout << ss.str() << endl;
|
||||
|
@ -3,6 +3,7 @@
|
||||
#include "Util.h"
|
||||
#include "Indexer.h"
|
||||
#include "ElimGraph.h"
|
||||
#include "BeliefProp.h"
|
||||
|
||||
|
||||
namespace Globals {
|
||||
@ -18,17 +19,6 @@ GroundSolverType groundSolver = GroundSolverType::VE;
|
||||
|
||||
|
||||
|
||||
namespace BpOptions {
|
||||
Schedule schedule = BpOptions::Schedule::SEQ_FIXED;
|
||||
//Schedule schedule = BpOptions::Schedule::SEQ_RANDOM;
|
||||
//Schedule schedule = BpOptions::Schedule::PARALLEL;
|
||||
//Schedule schedule = BpOptions::Schedule::MAX_RESIDUAL;
|
||||
double accuracy = 0.0001;
|
||||
unsigned maxIter = 1000;
|
||||
}
|
||||
|
||||
|
||||
|
||||
namespace Util {
|
||||
|
||||
|
||||
@ -248,13 +238,13 @@ setHorusFlag (string key, string value)
|
||||
}
|
||||
} else if (key == "schedule") {
|
||||
if ( value == "seq_fixed") {
|
||||
BpOptions::schedule = BpOptions::Schedule::SEQ_FIXED;
|
||||
BeliefProp::schedule = MsgSchedule::SEQ_FIXED;
|
||||
} else if (value == "seq_random") {
|
||||
BpOptions::schedule = BpOptions::Schedule::SEQ_RANDOM;
|
||||
BeliefProp::schedule = MsgSchedule::SEQ_RANDOM;
|
||||
} else if (value == "parallel") {
|
||||
BpOptions::schedule = BpOptions::Schedule::PARALLEL;
|
||||
BeliefProp::schedule = MsgSchedule::PARALLEL;
|
||||
} else if (value == "max_residual") {
|
||||
BpOptions::schedule = BpOptions::Schedule::MAX_RESIDUAL;
|
||||
BeliefProp::schedule = MsgSchedule::MAX_RESIDUAL;
|
||||
} else {
|
||||
cerr << "warning: invalid value `" << value << "' " ;
|
||||
cerr << "for `" << key << "'" << endl;
|
||||
@ -263,11 +253,11 @@ setHorusFlag (string key, string value)
|
||||
} else if (key == "accuracy") {
|
||||
stringstream ss;
|
||||
ss << value;
|
||||
ss >> BpOptions::accuracy;
|
||||
ss >> BeliefProp::accuracy;
|
||||
} else if (key == "max_iter") {
|
||||
stringstream ss;
|
||||
ss << value;
|
||||
ss >> BpOptions::maxIter;
|
||||
ss >> BeliefProp::maxIter;
|
||||
} else if (key == "use_logarithms") {
|
||||
if ( value == "true") {
|
||||
Globals::logDomain = true;
|
||||
|
@ -107,7 +107,7 @@ WeightedBp::maxResidualSchedule (void)
|
||||
if (Globals::verbosity >= 1) {
|
||||
cout << "updating " << (*sortedOrder_.begin())->toString() << endl;
|
||||
}
|
||||
if (link->residual() < BpOptions::accuracy) {
|
||||
if (link->residual() < accuracy) {
|
||||
return;
|
||||
}
|
||||
link->updateMessage();
|
||||
|
Reference in New Issue
Block a user