Move belief propagation stuff out from Horus.h

This commit is contained in:
Tiago Gomes 2012-12-27 15:00:30 +00:00
parent cbea630fbf
commit de0a118ae5
7 changed files with 57 additions and 67 deletions

View File

@ -9,6 +9,11 @@
#include "Horus.h" #include "Horus.h"
MsgSchedule BeliefProp::schedule = MsgSchedule::SEQ_FIXED;
double BeliefProp::accuracy = 0.0001;
unsigned BeliefProp::maxIter = 1000;
BeliefProp::BeliefProp (const FactorGraph& fg) : GroundSolver (fg) BeliefProp::BeliefProp (const FactorGraph& fg) : GroundSolver (fg)
{ {
runned_ = false; runned_ = false;
@ -48,15 +53,14 @@ BeliefProp::printSolverFlags (void) const
stringstream ss; stringstream ss;
ss << "belief propagation [" ; ss << "belief propagation [" ;
ss << "schedule=" ; ss << "schedule=" ;
typedef BpOptions::Schedule Sch; switch (schedule) {
switch (BpOptions::schedule) { case MsgSchedule::SEQ_FIXED: ss << "seq_fixed"; break;
case Sch::SEQ_FIXED: ss << "seq_fixed"; break; case MsgSchedule::SEQ_RANDOM: ss << "seq_random"; break;
case Sch::SEQ_RANDOM: ss << "seq_random"; break; case MsgSchedule::PARALLEL: ss << "parallel"; break;
case Sch::PARALLEL: ss << "parallel"; break; case MsgSchedule::MAX_RESIDUAL: ss << "max_residual"; break;
case Sch::MAX_RESIDUAL: ss << "max_residual"; break;
} }
ss << ",max_iter=" << Util::toString (BpOptions::maxIter); ss << ",max_iter=" << Util::toString (maxIter);
ss << ",accuracy=" << Util::toString (BpOptions::accuracy); ss << ",accuracy=" << Util::toString (accuracy);
ss << ",log_domain=" << Util::toString (Globals::logDomain); ss << ",log_domain=" << Util::toString (Globals::logDomain);
ss << "]" ; ss << "]" ;
cout << ss.str() << endl; cout << ss.str() << endl;
@ -153,21 +157,21 @@ BeliefProp::runSolver (void)
{ {
initializeSolver(); initializeSolver();
nIters_ = 0; nIters_ = 0;
while (!converged() && nIters_ < BpOptions::maxIter) { while (!converged() && nIters_ < maxIter) {
nIters_ ++; nIters_ ++;
if (Globals::verbosity > 1) { if (Globals::verbosity > 1) {
Util::printHeader (string ("Iteration ") + Util::toString (nIters_)); Util::printHeader (string ("Iteration ") + Util::toString (nIters_));
} }
switch (BpOptions::schedule) { switch (schedule) {
case BpOptions::Schedule::SEQ_RANDOM: case MsgSchedule::SEQ_RANDOM:
std::random_shuffle (links_.begin(), links_.end()); std::random_shuffle (links_.begin(), links_.end());
// no break // no break
case BpOptions::Schedule::SEQ_FIXED: case MsgSchedule::SEQ_FIXED:
for (size_t i = 0; i < links_.size(); i++) { for (size_t i = 0; i < links_.size(); i++) {
calculateAndUpdateMessage (links_[i]); calculateAndUpdateMessage (links_[i]);
} }
break; break;
case BpOptions::Schedule::PARALLEL: case MsgSchedule::PARALLEL:
for (size_t i = 0; i < links_.size(); i++) { for (size_t i = 0; i < links_.size(); i++) {
calculateMessage (links_[i]); calculateMessage (links_[i]);
} }
@ -175,13 +179,13 @@ BeliefProp::runSolver (void)
updateMessage(links_[i]); updateMessage(links_[i]);
} }
break; break;
case BpOptions::Schedule::MAX_RESIDUAL: case MsgSchedule::MAX_RESIDUAL:
maxResidualSchedule(); maxResidualSchedule();
break; break;
} }
} }
if (Globals::verbosity > 0) { if (Globals::verbosity > 0) {
if (nIters_ < BpOptions::maxIter) { if (nIters_ < maxIter) {
cout << "Belief propagation converged in " ; cout << "Belief propagation converged in " ;
cout << nIters_ << " iterations" << endl; cout << nIters_ << " iterations" << endl;
} else { } else {
@ -233,7 +237,7 @@ BeliefProp::maxResidualSchedule (void)
SortedOrder::iterator it = sortedOrder_.begin(); SortedOrder::iterator it = sortedOrder_.begin();
BpLink* link = *it; BpLink* link = *it;
if (link->residual() < BpOptions::accuracy) { if (link->residual() < accuracy) {
return; return;
} }
updateMessage (link); updateMessage (link);
@ -423,9 +427,9 @@ BeliefProp::converged (void)
return false; return false;
} }
bool converged = true; bool converged = true;
if (BpOptions::schedule == BpOptions::Schedule::MAX_RESIDUAL) { if (schedule == MsgSchedule::MAX_RESIDUAL) {
double maxResidual = (*(sortedOrder_.begin()))->residual(); double maxResidual = (*(sortedOrder_.begin()))->residual();
if (maxResidual > BpOptions::accuracy) { if (maxResidual > accuracy) {
converged = false; converged = false;
} else { } else {
converged = true; converged = true;
@ -436,7 +440,7 @@ BeliefProp::converged (void)
if (Globals::verbosity > 1) { if (Globals::verbosity > 1) {
cout << links_[i]->toString() + " residual = " << residual << endl; cout << links_[i]->toString() + " residual = " << residual << endl;
} }
if (residual > BpOptions::accuracy) { if (residual > accuracy) {
converged = false; converged = false;
if (Globals::verbosity < 2) { if (Globals::verbosity < 2) {
break; break;

View File

@ -13,6 +13,14 @@
using namespace std; using namespace std;
enum MsgSchedule {
SEQ_FIXED,
SEQ_RANDOM,
PARALLEL,
MAX_RESIDUAL
};
class BpLink class BpLink
{ {
public: public:
@ -98,6 +106,10 @@ class BeliefProp : public GroundSolver
virtual Params getJointDistributionOf (const VarIds&); virtual Params getJointDistributionOf (const VarIds&);
static MsgSchedule schedule;
static double accuracy;
static unsigned maxIter;
protected: protected:
void runSolver (void); void runSolver (void);

View File

@ -37,15 +37,14 @@ CountingBp::printSolverFlags (void) const
stringstream ss; stringstream ss;
ss << "counting bp [" ; ss << "counting bp [" ;
ss << "schedule=" ; ss << "schedule=" ;
typedef BpOptions::Schedule Sch; switch (WeightedBp::schedule) {
switch (BpOptions::schedule) { case MsgSchedule::SEQ_FIXED: ss << "seq_fixed"; break;
case Sch::SEQ_FIXED: ss << "seq_fixed"; break; case MsgSchedule::SEQ_RANDOM: ss << "seq_random"; break;
case Sch::SEQ_RANDOM: ss << "seq_random"; break; case MsgSchedule::PARALLEL: ss << "parallel"; break;
case Sch::PARALLEL: ss << "parallel"; break; case MsgSchedule::MAX_RESIDUAL: ss << "max_residual"; break;
case Sch::MAX_RESIDUAL: ss << "max_residual"; break;
} }
ss << ",max_iter=" << BpOptions::maxIter; ss << ",max_iter=" << WeightedBp::maxIter;
ss << ",accuracy=" << BpOptions::accuracy; ss << ",accuracy=" << WeightedBp::accuracy;
ss << ",log_domain=" << Util::toString (Globals::logDomain); ss << ",log_domain=" << Util::toString (Globals::logDomain);
ss << ",chkif=" << Util::toString (CountingBp::checkForIdenticalFactors); ss << ",chkif=" << Util::toString (CountingBp::checkForIdenticalFactors);
ss << "]" ; ss << "]" ;

View File

@ -67,19 +67,5 @@ const unsigned PRECISION = 6;
}; };
namespace BpOptions
{
enum Schedule {
SEQ_FIXED,
SEQ_RANDOM,
PARALLEL,
MAX_RESIDUAL
};
extern Schedule schedule;
extern double accuracy;
extern unsigned maxIter;
}
#endif // HORUS_HORUS_H #endif // HORUS_HORUS_H

View File

@ -63,15 +63,14 @@ LiftedBp::printSolverFlags (void) const
stringstream ss; stringstream ss;
ss << "lifted bp [" ; ss << "lifted bp [" ;
ss << "schedule=" ; ss << "schedule=" ;
typedef BpOptions::Schedule Sch; switch (WeightedBp::schedule) {
switch (BpOptions::schedule) { case MsgSchedule::SEQ_FIXED: ss << "seq_fixed"; break;
case Sch::SEQ_FIXED: ss << "seq_fixed"; break; case MsgSchedule::SEQ_RANDOM: ss << "seq_random"; break;
case Sch::SEQ_RANDOM: ss << "seq_random"; break; case MsgSchedule::PARALLEL: ss << "parallel"; break;
case Sch::PARALLEL: ss << "parallel"; break; case MsgSchedule::MAX_RESIDUAL: ss << "max_residual"; break;
case Sch::MAX_RESIDUAL: ss << "max_residual"; break;
} }
ss << ",max_iter=" << BpOptions::maxIter; ss << ",max_iter=" << WeightedBp::maxIter;
ss << ",accuracy=" << BpOptions::accuracy; ss << ",accuracy=" << WeightedBp::accuracy;
ss << ",log_domain=" << Util::toString (Globals::logDomain); ss << ",log_domain=" << Util::toString (Globals::logDomain);
ss << "]" ; ss << "]" ;
cout << ss.str() << endl; cout << ss.str() << endl;

View File

@ -3,6 +3,7 @@
#include "Util.h" #include "Util.h"
#include "Indexer.h" #include "Indexer.h"
#include "ElimGraph.h" #include "ElimGraph.h"
#include "BeliefProp.h"
namespace Globals { namespace Globals {
@ -18,17 +19,6 @@ GroundSolverType groundSolver = GroundSolverType::VE;
namespace BpOptions {
Schedule schedule = BpOptions::Schedule::SEQ_FIXED;
//Schedule schedule = BpOptions::Schedule::SEQ_RANDOM;
//Schedule schedule = BpOptions::Schedule::PARALLEL;
//Schedule schedule = BpOptions::Schedule::MAX_RESIDUAL;
double accuracy = 0.0001;
unsigned maxIter = 1000;
}
namespace Util { namespace Util {
@ -248,13 +238,13 @@ setHorusFlag (string key, string value)
} }
} else if (key == "schedule") { } else if (key == "schedule") {
if ( value == "seq_fixed") { if ( value == "seq_fixed") {
BpOptions::schedule = BpOptions::Schedule::SEQ_FIXED; BeliefProp::schedule = MsgSchedule::SEQ_FIXED;
} else if (value == "seq_random") { } else if (value == "seq_random") {
BpOptions::schedule = BpOptions::Schedule::SEQ_RANDOM; BeliefProp::schedule = MsgSchedule::SEQ_RANDOM;
} else if (value == "parallel") { } else if (value == "parallel") {
BpOptions::schedule = BpOptions::Schedule::PARALLEL; BeliefProp::schedule = MsgSchedule::PARALLEL;
} else if (value == "max_residual") { } else if (value == "max_residual") {
BpOptions::schedule = BpOptions::Schedule::MAX_RESIDUAL; BeliefProp::schedule = MsgSchedule::MAX_RESIDUAL;
} else { } else {
cerr << "warning: invalid value `" << value << "' " ; cerr << "warning: invalid value `" << value << "' " ;
cerr << "for `" << key << "'" << endl; cerr << "for `" << key << "'" << endl;
@ -263,11 +253,11 @@ setHorusFlag (string key, string value)
} else if (key == "accuracy") { } else if (key == "accuracy") {
stringstream ss; stringstream ss;
ss << value; ss << value;
ss >> BpOptions::accuracy; ss >> BeliefProp::accuracy;
} else if (key == "max_iter") { } else if (key == "max_iter") {
stringstream ss; stringstream ss;
ss << value; ss << value;
ss >> BpOptions::maxIter; ss >> BeliefProp::maxIter;
} else if (key == "use_logarithms") { } else if (key == "use_logarithms") {
if ( value == "true") { if ( value == "true") {
Globals::logDomain = true; Globals::logDomain = true;

View File

@ -107,7 +107,7 @@ WeightedBp::maxResidualSchedule (void)
if (Globals::verbosity >= 1) { if (Globals::verbosity >= 1) {
cout << "updating " << (*sortedOrder_.begin())->toString() << endl; cout << "updating " << (*sortedOrder_.begin())->toString() << endl;
} }
if (link->residual() < BpOptions::accuracy) { if (link->residual() < accuracy) {
return; return;
} }
link->updateMessage(); link->updateMessage();