Move belief propagation stuff out from Horus.h
This commit is contained in:
parent
cbea630fbf
commit
de0a118ae5
@ -9,6 +9,11 @@
|
|||||||
#include "Horus.h"
|
#include "Horus.h"
|
||||||
|
|
||||||
|
|
||||||
|
MsgSchedule BeliefProp::schedule = MsgSchedule::SEQ_FIXED;
|
||||||
|
double BeliefProp::accuracy = 0.0001;
|
||||||
|
unsigned BeliefProp::maxIter = 1000;
|
||||||
|
|
||||||
|
|
||||||
BeliefProp::BeliefProp (const FactorGraph& fg) : GroundSolver (fg)
|
BeliefProp::BeliefProp (const FactorGraph& fg) : GroundSolver (fg)
|
||||||
{
|
{
|
||||||
runned_ = false;
|
runned_ = false;
|
||||||
@ -48,15 +53,14 @@ BeliefProp::printSolverFlags (void) const
|
|||||||
stringstream ss;
|
stringstream ss;
|
||||||
ss << "belief propagation [" ;
|
ss << "belief propagation [" ;
|
||||||
ss << "schedule=" ;
|
ss << "schedule=" ;
|
||||||
typedef BpOptions::Schedule Sch;
|
switch (schedule) {
|
||||||
switch (BpOptions::schedule) {
|
case MsgSchedule::SEQ_FIXED: ss << "seq_fixed"; break;
|
||||||
case Sch::SEQ_FIXED: ss << "seq_fixed"; break;
|
case MsgSchedule::SEQ_RANDOM: ss << "seq_random"; break;
|
||||||
case Sch::SEQ_RANDOM: ss << "seq_random"; break;
|
case MsgSchedule::PARALLEL: ss << "parallel"; break;
|
||||||
case Sch::PARALLEL: ss << "parallel"; break;
|
case MsgSchedule::MAX_RESIDUAL: ss << "max_residual"; break;
|
||||||
case Sch::MAX_RESIDUAL: ss << "max_residual"; break;
|
|
||||||
}
|
}
|
||||||
ss << ",max_iter=" << Util::toString (BpOptions::maxIter);
|
ss << ",max_iter=" << Util::toString (maxIter);
|
||||||
ss << ",accuracy=" << Util::toString (BpOptions::accuracy);
|
ss << ",accuracy=" << Util::toString (accuracy);
|
||||||
ss << ",log_domain=" << Util::toString (Globals::logDomain);
|
ss << ",log_domain=" << Util::toString (Globals::logDomain);
|
||||||
ss << "]" ;
|
ss << "]" ;
|
||||||
cout << ss.str() << endl;
|
cout << ss.str() << endl;
|
||||||
@ -153,21 +157,21 @@ BeliefProp::runSolver (void)
|
|||||||
{
|
{
|
||||||
initializeSolver();
|
initializeSolver();
|
||||||
nIters_ = 0;
|
nIters_ = 0;
|
||||||
while (!converged() && nIters_ < BpOptions::maxIter) {
|
while (!converged() && nIters_ < maxIter) {
|
||||||
nIters_ ++;
|
nIters_ ++;
|
||||||
if (Globals::verbosity > 1) {
|
if (Globals::verbosity > 1) {
|
||||||
Util::printHeader (string ("Iteration ") + Util::toString (nIters_));
|
Util::printHeader (string ("Iteration ") + Util::toString (nIters_));
|
||||||
}
|
}
|
||||||
switch (BpOptions::schedule) {
|
switch (schedule) {
|
||||||
case BpOptions::Schedule::SEQ_RANDOM:
|
case MsgSchedule::SEQ_RANDOM:
|
||||||
std::random_shuffle (links_.begin(), links_.end());
|
std::random_shuffle (links_.begin(), links_.end());
|
||||||
// no break
|
// no break
|
||||||
case BpOptions::Schedule::SEQ_FIXED:
|
case MsgSchedule::SEQ_FIXED:
|
||||||
for (size_t i = 0; i < links_.size(); i++) {
|
for (size_t i = 0; i < links_.size(); i++) {
|
||||||
calculateAndUpdateMessage (links_[i]);
|
calculateAndUpdateMessage (links_[i]);
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
case BpOptions::Schedule::PARALLEL:
|
case MsgSchedule::PARALLEL:
|
||||||
for (size_t i = 0; i < links_.size(); i++) {
|
for (size_t i = 0; i < links_.size(); i++) {
|
||||||
calculateMessage (links_[i]);
|
calculateMessage (links_[i]);
|
||||||
}
|
}
|
||||||
@ -175,13 +179,13 @@ BeliefProp::runSolver (void)
|
|||||||
updateMessage(links_[i]);
|
updateMessage(links_[i]);
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
case BpOptions::Schedule::MAX_RESIDUAL:
|
case MsgSchedule::MAX_RESIDUAL:
|
||||||
maxResidualSchedule();
|
maxResidualSchedule();
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (Globals::verbosity > 0) {
|
if (Globals::verbosity > 0) {
|
||||||
if (nIters_ < BpOptions::maxIter) {
|
if (nIters_ < maxIter) {
|
||||||
cout << "Belief propagation converged in " ;
|
cout << "Belief propagation converged in " ;
|
||||||
cout << nIters_ << " iterations" << endl;
|
cout << nIters_ << " iterations" << endl;
|
||||||
} else {
|
} else {
|
||||||
@ -233,7 +237,7 @@ BeliefProp::maxResidualSchedule (void)
|
|||||||
|
|
||||||
SortedOrder::iterator it = sortedOrder_.begin();
|
SortedOrder::iterator it = sortedOrder_.begin();
|
||||||
BpLink* link = *it;
|
BpLink* link = *it;
|
||||||
if (link->residual() < BpOptions::accuracy) {
|
if (link->residual() < accuracy) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
updateMessage (link);
|
updateMessage (link);
|
||||||
@ -423,9 +427,9 @@ BeliefProp::converged (void)
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
bool converged = true;
|
bool converged = true;
|
||||||
if (BpOptions::schedule == BpOptions::Schedule::MAX_RESIDUAL) {
|
if (schedule == MsgSchedule::MAX_RESIDUAL) {
|
||||||
double maxResidual = (*(sortedOrder_.begin()))->residual();
|
double maxResidual = (*(sortedOrder_.begin()))->residual();
|
||||||
if (maxResidual > BpOptions::accuracy) {
|
if (maxResidual > accuracy) {
|
||||||
converged = false;
|
converged = false;
|
||||||
} else {
|
} else {
|
||||||
converged = true;
|
converged = true;
|
||||||
@ -436,7 +440,7 @@ BeliefProp::converged (void)
|
|||||||
if (Globals::verbosity > 1) {
|
if (Globals::verbosity > 1) {
|
||||||
cout << links_[i]->toString() + " residual = " << residual << endl;
|
cout << links_[i]->toString() + " residual = " << residual << endl;
|
||||||
}
|
}
|
||||||
if (residual > BpOptions::accuracy) {
|
if (residual > accuracy) {
|
||||||
converged = false;
|
converged = false;
|
||||||
if (Globals::verbosity < 2) {
|
if (Globals::verbosity < 2) {
|
||||||
break;
|
break;
|
||||||
|
@ -13,6 +13,14 @@
|
|||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
|
|
||||||
|
enum MsgSchedule {
|
||||||
|
SEQ_FIXED,
|
||||||
|
SEQ_RANDOM,
|
||||||
|
PARALLEL,
|
||||||
|
MAX_RESIDUAL
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
class BpLink
|
class BpLink
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
@ -98,6 +106,10 @@ class BeliefProp : public GroundSolver
|
|||||||
|
|
||||||
virtual Params getJointDistributionOf (const VarIds&);
|
virtual Params getJointDistributionOf (const VarIds&);
|
||||||
|
|
||||||
|
static MsgSchedule schedule;
|
||||||
|
static double accuracy;
|
||||||
|
static unsigned maxIter;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
void runSolver (void);
|
void runSolver (void);
|
||||||
|
|
||||||
|
@ -37,15 +37,14 @@ CountingBp::printSolverFlags (void) const
|
|||||||
stringstream ss;
|
stringstream ss;
|
||||||
ss << "counting bp [" ;
|
ss << "counting bp [" ;
|
||||||
ss << "schedule=" ;
|
ss << "schedule=" ;
|
||||||
typedef BpOptions::Schedule Sch;
|
switch (WeightedBp::schedule) {
|
||||||
switch (BpOptions::schedule) {
|
case MsgSchedule::SEQ_FIXED: ss << "seq_fixed"; break;
|
||||||
case Sch::SEQ_FIXED: ss << "seq_fixed"; break;
|
case MsgSchedule::SEQ_RANDOM: ss << "seq_random"; break;
|
||||||
case Sch::SEQ_RANDOM: ss << "seq_random"; break;
|
case MsgSchedule::PARALLEL: ss << "parallel"; break;
|
||||||
case Sch::PARALLEL: ss << "parallel"; break;
|
case MsgSchedule::MAX_RESIDUAL: ss << "max_residual"; break;
|
||||||
case Sch::MAX_RESIDUAL: ss << "max_residual"; break;
|
|
||||||
}
|
}
|
||||||
ss << ",max_iter=" << BpOptions::maxIter;
|
ss << ",max_iter=" << WeightedBp::maxIter;
|
||||||
ss << ",accuracy=" << BpOptions::accuracy;
|
ss << ",accuracy=" << WeightedBp::accuracy;
|
||||||
ss << ",log_domain=" << Util::toString (Globals::logDomain);
|
ss << ",log_domain=" << Util::toString (Globals::logDomain);
|
||||||
ss << ",chkif=" << Util::toString (CountingBp::checkForIdenticalFactors);
|
ss << ",chkif=" << Util::toString (CountingBp::checkForIdenticalFactors);
|
||||||
ss << "]" ;
|
ss << "]" ;
|
||||||
|
@ -67,19 +67,5 @@ const unsigned PRECISION = 6;
|
|||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
namespace BpOptions
|
|
||||||
{
|
|
||||||
enum Schedule {
|
|
||||||
SEQ_FIXED,
|
|
||||||
SEQ_RANDOM,
|
|
||||||
PARALLEL,
|
|
||||||
MAX_RESIDUAL
|
|
||||||
};
|
|
||||||
extern Schedule schedule;
|
|
||||||
extern double accuracy;
|
|
||||||
extern unsigned maxIter;
|
|
||||||
}
|
|
||||||
|
|
||||||
#endif // HORUS_HORUS_H
|
#endif // HORUS_HORUS_H
|
||||||
|
|
||||||
|
@ -63,15 +63,14 @@ LiftedBp::printSolverFlags (void) const
|
|||||||
stringstream ss;
|
stringstream ss;
|
||||||
ss << "lifted bp [" ;
|
ss << "lifted bp [" ;
|
||||||
ss << "schedule=" ;
|
ss << "schedule=" ;
|
||||||
typedef BpOptions::Schedule Sch;
|
switch (WeightedBp::schedule) {
|
||||||
switch (BpOptions::schedule) {
|
case MsgSchedule::SEQ_FIXED: ss << "seq_fixed"; break;
|
||||||
case Sch::SEQ_FIXED: ss << "seq_fixed"; break;
|
case MsgSchedule::SEQ_RANDOM: ss << "seq_random"; break;
|
||||||
case Sch::SEQ_RANDOM: ss << "seq_random"; break;
|
case MsgSchedule::PARALLEL: ss << "parallel"; break;
|
||||||
case Sch::PARALLEL: ss << "parallel"; break;
|
case MsgSchedule::MAX_RESIDUAL: ss << "max_residual"; break;
|
||||||
case Sch::MAX_RESIDUAL: ss << "max_residual"; break;
|
|
||||||
}
|
}
|
||||||
ss << ",max_iter=" << BpOptions::maxIter;
|
ss << ",max_iter=" << WeightedBp::maxIter;
|
||||||
ss << ",accuracy=" << BpOptions::accuracy;
|
ss << ",accuracy=" << WeightedBp::accuracy;
|
||||||
ss << ",log_domain=" << Util::toString (Globals::logDomain);
|
ss << ",log_domain=" << Util::toString (Globals::logDomain);
|
||||||
ss << "]" ;
|
ss << "]" ;
|
||||||
cout << ss.str() << endl;
|
cout << ss.str() << endl;
|
||||||
|
@ -3,6 +3,7 @@
|
|||||||
#include "Util.h"
|
#include "Util.h"
|
||||||
#include "Indexer.h"
|
#include "Indexer.h"
|
||||||
#include "ElimGraph.h"
|
#include "ElimGraph.h"
|
||||||
|
#include "BeliefProp.h"
|
||||||
|
|
||||||
|
|
||||||
namespace Globals {
|
namespace Globals {
|
||||||
@ -18,17 +19,6 @@ GroundSolverType groundSolver = GroundSolverType::VE;
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
namespace BpOptions {
|
|
||||||
Schedule schedule = BpOptions::Schedule::SEQ_FIXED;
|
|
||||||
//Schedule schedule = BpOptions::Schedule::SEQ_RANDOM;
|
|
||||||
//Schedule schedule = BpOptions::Schedule::PARALLEL;
|
|
||||||
//Schedule schedule = BpOptions::Schedule::MAX_RESIDUAL;
|
|
||||||
double accuracy = 0.0001;
|
|
||||||
unsigned maxIter = 1000;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
namespace Util {
|
namespace Util {
|
||||||
|
|
||||||
|
|
||||||
@ -248,13 +238,13 @@ setHorusFlag (string key, string value)
|
|||||||
}
|
}
|
||||||
} else if (key == "schedule") {
|
} else if (key == "schedule") {
|
||||||
if ( value == "seq_fixed") {
|
if ( value == "seq_fixed") {
|
||||||
BpOptions::schedule = BpOptions::Schedule::SEQ_FIXED;
|
BeliefProp::schedule = MsgSchedule::SEQ_FIXED;
|
||||||
} else if (value == "seq_random") {
|
} else if (value == "seq_random") {
|
||||||
BpOptions::schedule = BpOptions::Schedule::SEQ_RANDOM;
|
BeliefProp::schedule = MsgSchedule::SEQ_RANDOM;
|
||||||
} else if (value == "parallel") {
|
} else if (value == "parallel") {
|
||||||
BpOptions::schedule = BpOptions::Schedule::PARALLEL;
|
BeliefProp::schedule = MsgSchedule::PARALLEL;
|
||||||
} else if (value == "max_residual") {
|
} else if (value == "max_residual") {
|
||||||
BpOptions::schedule = BpOptions::Schedule::MAX_RESIDUAL;
|
BeliefProp::schedule = MsgSchedule::MAX_RESIDUAL;
|
||||||
} else {
|
} else {
|
||||||
cerr << "warning: invalid value `" << value << "' " ;
|
cerr << "warning: invalid value `" << value << "' " ;
|
||||||
cerr << "for `" << key << "'" << endl;
|
cerr << "for `" << key << "'" << endl;
|
||||||
@ -263,11 +253,11 @@ setHorusFlag (string key, string value)
|
|||||||
} else if (key == "accuracy") {
|
} else if (key == "accuracy") {
|
||||||
stringstream ss;
|
stringstream ss;
|
||||||
ss << value;
|
ss << value;
|
||||||
ss >> BpOptions::accuracy;
|
ss >> BeliefProp::accuracy;
|
||||||
} else if (key == "max_iter") {
|
} else if (key == "max_iter") {
|
||||||
stringstream ss;
|
stringstream ss;
|
||||||
ss << value;
|
ss << value;
|
||||||
ss >> BpOptions::maxIter;
|
ss >> BeliefProp::maxIter;
|
||||||
} else if (key == "use_logarithms") {
|
} else if (key == "use_logarithms") {
|
||||||
if ( value == "true") {
|
if ( value == "true") {
|
||||||
Globals::logDomain = true;
|
Globals::logDomain = true;
|
||||||
|
@ -107,7 +107,7 @@ WeightedBp::maxResidualSchedule (void)
|
|||||||
if (Globals::verbosity >= 1) {
|
if (Globals::verbosity >= 1) {
|
||||||
cout << "updating " << (*sortedOrder_.begin())->toString() << endl;
|
cout << "updating " << (*sortedOrder_.begin())->toString() << endl;
|
||||||
}
|
}
|
||||||
if (link->residual() < BpOptions::accuracy) {
|
if (link->residual() < accuracy) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
link->updateMessage();
|
link->updateMessage();
|
||||||
|
Reference in New Issue
Block a user