Move belief propagation stuff out from Horus.h

This commit is contained in:
Tiago Gomes 2012-12-27 15:00:30 +00:00
parent cbea630fbf
commit de0a118ae5
7 changed files with 57 additions and 67 deletions

View File

@ -9,6 +9,11 @@
#include "Horus.h"
MsgSchedule BeliefProp::schedule = MsgSchedule::SEQ_FIXED;
double BeliefProp::accuracy = 0.0001;
unsigned BeliefProp::maxIter = 1000;
BeliefProp::BeliefProp (const FactorGraph& fg) : GroundSolver (fg)
{
runned_ = false;
@ -48,15 +53,14 @@ BeliefProp::printSolverFlags (void) const
stringstream ss;
ss << "belief propagation [" ;
ss << "schedule=" ;
typedef BpOptions::Schedule Sch;
switch (BpOptions::schedule) {
case Sch::SEQ_FIXED: ss << "seq_fixed"; break;
case Sch::SEQ_RANDOM: ss << "seq_random"; break;
case Sch::PARALLEL: ss << "parallel"; break;
case Sch::MAX_RESIDUAL: ss << "max_residual"; break;
switch (schedule) {
case MsgSchedule::SEQ_FIXED: ss << "seq_fixed"; break;
case MsgSchedule::SEQ_RANDOM: ss << "seq_random"; break;
case MsgSchedule::PARALLEL: ss << "parallel"; break;
case MsgSchedule::MAX_RESIDUAL: ss << "max_residual"; break;
}
ss << ",max_iter=" << Util::toString (BpOptions::maxIter);
ss << ",accuracy=" << Util::toString (BpOptions::accuracy);
ss << ",max_iter=" << Util::toString (maxIter);
ss << ",accuracy=" << Util::toString (accuracy);
ss << ",log_domain=" << Util::toString (Globals::logDomain);
ss << "]" ;
cout << ss.str() << endl;
@ -153,21 +157,21 @@ BeliefProp::runSolver (void)
{
initializeSolver();
nIters_ = 0;
while (!converged() && nIters_ < BpOptions::maxIter) {
while (!converged() && nIters_ < maxIter) {
nIters_ ++;
if (Globals::verbosity > 1) {
Util::printHeader (string ("Iteration ") + Util::toString (nIters_));
}
switch (BpOptions::schedule) {
case BpOptions::Schedule::SEQ_RANDOM:
switch (schedule) {
case MsgSchedule::SEQ_RANDOM:
std::random_shuffle (links_.begin(), links_.end());
// no break
case BpOptions::Schedule::SEQ_FIXED:
case MsgSchedule::SEQ_FIXED:
for (size_t i = 0; i < links_.size(); i++) {
calculateAndUpdateMessage (links_[i]);
}
break;
case BpOptions::Schedule::PARALLEL:
case MsgSchedule::PARALLEL:
for (size_t i = 0; i < links_.size(); i++) {
calculateMessage (links_[i]);
}
@ -175,13 +179,13 @@ BeliefProp::runSolver (void)
updateMessage(links_[i]);
}
break;
case BpOptions::Schedule::MAX_RESIDUAL:
case MsgSchedule::MAX_RESIDUAL:
maxResidualSchedule();
break;
}
}
if (Globals::verbosity > 0) {
if (nIters_ < BpOptions::maxIter) {
if (nIters_ < maxIter) {
cout << "Belief propagation converged in " ;
cout << nIters_ << " iterations" << endl;
} else {
@ -233,7 +237,7 @@ BeliefProp::maxResidualSchedule (void)
SortedOrder::iterator it = sortedOrder_.begin();
BpLink* link = *it;
if (link->residual() < BpOptions::accuracy) {
if (link->residual() < accuracy) {
return;
}
updateMessage (link);
@ -423,9 +427,9 @@ BeliefProp::converged (void)
return false;
}
bool converged = true;
if (BpOptions::schedule == BpOptions::Schedule::MAX_RESIDUAL) {
if (schedule == MsgSchedule::MAX_RESIDUAL) {
double maxResidual = (*(sortedOrder_.begin()))->residual();
if (maxResidual > BpOptions::accuracy) {
if (maxResidual > accuracy) {
converged = false;
} else {
converged = true;
@ -436,7 +440,7 @@ BeliefProp::converged (void)
if (Globals::verbosity > 1) {
cout << links_[i]->toString() + " residual = " << residual << endl;
}
if (residual > BpOptions::accuracy) {
if (residual > accuracy) {
converged = false;
if (Globals::verbosity < 2) {
break;

View File

@ -13,6 +13,14 @@
using namespace std;
enum MsgSchedule {
SEQ_FIXED,
SEQ_RANDOM,
PARALLEL,
MAX_RESIDUAL
};
class BpLink
{
public:
@ -98,6 +106,10 @@ class BeliefProp : public GroundSolver
virtual Params getJointDistributionOf (const VarIds&);
static MsgSchedule schedule;
static double accuracy;
static unsigned maxIter;
protected:
void runSolver (void);

View File

@ -37,15 +37,14 @@ CountingBp::printSolverFlags (void) const
stringstream ss;
ss << "counting bp [" ;
ss << "schedule=" ;
typedef BpOptions::Schedule Sch;
switch (BpOptions::schedule) {
case Sch::SEQ_FIXED: ss << "seq_fixed"; break;
case Sch::SEQ_RANDOM: ss << "seq_random"; break;
case Sch::PARALLEL: ss << "parallel"; break;
case Sch::MAX_RESIDUAL: ss << "max_residual"; break;
switch (WeightedBp::schedule) {
case MsgSchedule::SEQ_FIXED: ss << "seq_fixed"; break;
case MsgSchedule::SEQ_RANDOM: ss << "seq_random"; break;
case MsgSchedule::PARALLEL: ss << "parallel"; break;
case MsgSchedule::MAX_RESIDUAL: ss << "max_residual"; break;
}
ss << ",max_iter=" << BpOptions::maxIter;
ss << ",accuracy=" << BpOptions::accuracy;
ss << ",max_iter=" << WeightedBp::maxIter;
ss << ",accuracy=" << WeightedBp::accuracy;
ss << ",log_domain=" << Util::toString (Globals::logDomain);
ss << ",chkif=" << Util::toString (CountingBp::checkForIdenticalFactors);
ss << "]" ;

View File

@ -67,19 +67,5 @@ const unsigned PRECISION = 6;
};
namespace BpOptions
{
enum Schedule {
SEQ_FIXED,
SEQ_RANDOM,
PARALLEL,
MAX_RESIDUAL
};
extern Schedule schedule;
extern double accuracy;
extern unsigned maxIter;
}
#endif // HORUS_HORUS_H

View File

@ -63,15 +63,14 @@ LiftedBp::printSolverFlags (void) const
stringstream ss;
ss << "lifted bp [" ;
ss << "schedule=" ;
typedef BpOptions::Schedule Sch;
switch (BpOptions::schedule) {
case Sch::SEQ_FIXED: ss << "seq_fixed"; break;
case Sch::SEQ_RANDOM: ss << "seq_random"; break;
case Sch::PARALLEL: ss << "parallel"; break;
case Sch::MAX_RESIDUAL: ss << "max_residual"; break;
switch (WeightedBp::schedule) {
case MsgSchedule::SEQ_FIXED: ss << "seq_fixed"; break;
case MsgSchedule::SEQ_RANDOM: ss << "seq_random"; break;
case MsgSchedule::PARALLEL: ss << "parallel"; break;
case MsgSchedule::MAX_RESIDUAL: ss << "max_residual"; break;
}
ss << ",max_iter=" << BpOptions::maxIter;
ss << ",accuracy=" << BpOptions::accuracy;
ss << ",max_iter=" << WeightedBp::maxIter;
ss << ",accuracy=" << WeightedBp::accuracy;
ss << ",log_domain=" << Util::toString (Globals::logDomain);
ss << "]" ;
cout << ss.str() << endl;

View File

@ -3,6 +3,7 @@
#include "Util.h"
#include "Indexer.h"
#include "ElimGraph.h"
#include "BeliefProp.h"
namespace Globals {
@ -18,17 +19,6 @@ GroundSolverType groundSolver = GroundSolverType::VE;
namespace BpOptions {
Schedule schedule = BpOptions::Schedule::SEQ_FIXED;
//Schedule schedule = BpOptions::Schedule::SEQ_RANDOM;
//Schedule schedule = BpOptions::Schedule::PARALLEL;
//Schedule schedule = BpOptions::Schedule::MAX_RESIDUAL;
double accuracy = 0.0001;
unsigned maxIter = 1000;
}
namespace Util {
@ -248,13 +238,13 @@ setHorusFlag (string key, string value)
}
} else if (key == "schedule") {
if ( value == "seq_fixed") {
BpOptions::schedule = BpOptions::Schedule::SEQ_FIXED;
BeliefProp::schedule = MsgSchedule::SEQ_FIXED;
} else if (value == "seq_random") {
BpOptions::schedule = BpOptions::Schedule::SEQ_RANDOM;
BeliefProp::schedule = MsgSchedule::SEQ_RANDOM;
} else if (value == "parallel") {
BpOptions::schedule = BpOptions::Schedule::PARALLEL;
BeliefProp::schedule = MsgSchedule::PARALLEL;
} else if (value == "max_residual") {
BpOptions::schedule = BpOptions::Schedule::MAX_RESIDUAL;
BeliefProp::schedule = MsgSchedule::MAX_RESIDUAL;
} else {
cerr << "warning: invalid value `" << value << "' " ;
cerr << "for `" << key << "'" << endl;
@ -263,11 +253,11 @@ setHorusFlag (string key, string value)
} else if (key == "accuracy") {
stringstream ss;
ss << value;
ss >> BpOptions::accuracy;
ss >> BeliefProp::accuracy;
} else if (key == "max_iter") {
stringstream ss;
ss << value;
ss >> BpOptions::maxIter;
ss >> BeliefProp::maxIter;
} else if (key == "use_logarithms") {
if ( value == "true") {
Globals::logDomain = true;

View File

@ -107,7 +107,7 @@ WeightedBp::maxResidualSchedule (void)
if (Globals::verbosity >= 1) {
cout << "updating " << (*sortedOrder_.begin())->toString() << endl;
}
if (link->residual() < BpOptions::accuracy) {
if (link->residual() < accuracy) {
return;
}
link->updateMessage();