refactor ground solver interface

This commit is contained in:
Tiago Gomes
2012-04-10 15:00:18 +01:00
parent 46e6a10625
commit 78e86a6330
15 changed files with 191 additions and 242 deletions

View File

@@ -15,6 +15,7 @@
BpSolver::BpSolver (const FactorGraph& fg) : Solver (fg)
{
factorGraph_ = &fg;
runned_ = false;
}
@@ -34,31 +35,14 @@ BpSolver::~BpSolver (void)
void
BpSolver::runSolver (void)
Params
BpSolver::solveQuery (VarIds queryVids)
{
clock_t start;
if (Constants::COLLECT_STATS) {
start = clock();
}
runLoopySolver();
if (Constants::DEBUG >= 2) {
cout << endl;
if (nIters_ < BpOptions::maxIter) {
cout << "Sum-Product converged in " ;
cout << nIters_ << " iterations" << endl;
} else {
cout << "The maximum number of iterations was hit, terminating..." ;
cout << endl;
}
}
unsigned size = factorGraph_->varNodes().size();
if (Constants::COLLECT_STATS) {
unsigned nIters = 0;
bool loopy = factorGraph_->isTree() == false;
if (loopy) nIters = nIters_;
double time = (double (clock() - start)) / CLOCKS_PER_SEC;
Statistics::updateStatistics (size, loopy, nIters, time);
assert (queryVids.empty() == false);
if (queryVids.size() == 1) {
return getPosterioriOf (queryVids[0]);
} else {
return getJointDistributionOf (queryVids);
}
}
@@ -67,6 +51,9 @@ BpSolver::runSolver (void)
Params
BpSolver::getPosterioriOf (VarId vid)
{
if (runned_ == false) {
runSolver();
}
assert (factorGraph_->getVarNode (vid));
VarNode* var = factorGraph_->getVarNode (vid);
Params probs;
@@ -97,6 +84,9 @@ BpSolver::getPosterioriOf (VarId vid)
Params
BpSolver::getJointDistributionOf (const VarIds& jointVarIds)
{
if (runned_ == false) {
runSolver();
}
int idx = -1;
VarNode* vn = factorGraph_->getVarNode (jointVarIds[0]);
const FacNodes& facNodes = vn->neighbors();
@@ -130,52 +120,6 @@ BpSolver::getJointDistributionOf (const VarIds& jointVarIds)
void
BpSolver::runLoopySolver (void)
{
initializeSolver();
nIters_ = 0;
while (!converged() && nIters_ < BpOptions::maxIter) {
nIters_ ++;
if (Constants::DEBUG >= 2) {
Util::printHeader (" Iteration " + nIters_);
cout << endl;
}
switch (BpOptions::schedule) {
case BpOptions::Schedule::SEQ_RANDOM:
random_shuffle (links_.begin(), links_.end());
// no break
case BpOptions::Schedule::SEQ_FIXED:
for (unsigned i = 0; i < links_.size(); i++) {
calculateAndUpdateMessage (links_[i]);
}
break;
case BpOptions::Schedule::PARALLEL:
for (unsigned i = 0; i < links_.size(); i++) {
calculateMessage (links_[i]);
}
for (unsigned i = 0; i < links_.size(); i++) {
updateMessage(links_[i]);
}
break;
case BpOptions::Schedule::MAX_RESIDUAL:
maxResidualSchedule();
break;
}
if (Constants::DEBUG >= 2) {
cout << endl;
}
}
}
void
BpSolver::initializeSolver (void)
{
@@ -226,40 +170,6 @@ BpSolver::createLinks (void)
bool
BpSolver::converged (void)
{
if (links_.size() == 0) {
return true;
}
if (nIters_ == 0 || nIters_ == 1) {
return false;
}
bool converged = true;
if (BpOptions::schedule == BpOptions::Schedule::MAX_RESIDUAL) {
double maxResidual = (*(sortedOrder_.begin()))->getResidual();
if (maxResidual > BpOptions::accuracy) {
converged = false;
} else {
converged = true;
}
} else {
for (unsigned i = 0; i < links_.size(); i++) {
double residual = links_[i]->getResidual();
if (Constants::DEBUG >= 2) {
cout << links_[i]->toString() + " residual = " << residual << endl;
}
if (residual > BpOptions::accuracy) {
converged = false;
if (Constants::DEBUG == 0) break;
}
}
}
return converged;
}
void
BpSolver::maxResidualSchedule (void)
{
@@ -493,3 +403,114 @@ BpSolver::printLinkInformation (void) const
}
}
void
BpSolver::runSolver (void)
{
clock_t start;
if (Constants::COLLECT_STATS) {
start = clock();
}
runLoopySolver();
if (Constants::DEBUG >= 2) {
cout << endl;
if (nIters_ < BpOptions::maxIter) {
cout << "Sum-Product converged in " ;
cout << nIters_ << " iterations" << endl;
} else {
cout << "The maximum number of iterations was hit, terminating..." ;
cout << endl;
}
}
unsigned size = factorGraph_->varNodes().size();
if (Constants::COLLECT_STATS) {
unsigned nIters = 0;
bool loopy = factorGraph_->isTree() == false;
if (loopy) nIters = nIters_;
double time = (double (clock() - start)) / CLOCKS_PER_SEC;
Statistics::updateStatistics (size, loopy, nIters, time);
}
runned_ = true;
}
void
BpSolver::runLoopySolver (void)
{
initializeSolver();
nIters_ = 0;
while (!converged() && nIters_ < BpOptions::maxIter) {
nIters_ ++;
if (Constants::DEBUG >= 2) {
Util::printHeader (" Iteration " + nIters_);
cout << endl;
}
switch (BpOptions::schedule) {
case BpOptions::Schedule::SEQ_RANDOM:
random_shuffle (links_.begin(), links_.end());
// no break
case BpOptions::Schedule::SEQ_FIXED:
for (unsigned i = 0; i < links_.size(); i++) {
calculateAndUpdateMessage (links_[i]);
}
break;
case BpOptions::Schedule::PARALLEL:
for (unsigned i = 0; i < links_.size(); i++) {
calculateMessage (links_[i]);
}
for (unsigned i = 0; i < links_.size(); i++) {
updateMessage(links_[i]);
}
break;
case BpOptions::Schedule::MAX_RESIDUAL:
maxResidualSchedule();
break;
}
if (Constants::DEBUG >= 2) {
cout << endl;
}
}
}
bool
BpSolver::converged (void)
{
if (links_.size() == 0) {
return true;
}
if (nIters_ == 0 || nIters_ == 1) {
return false;
}
bool converged = true;
if (BpOptions::schedule == BpOptions::Schedule::MAX_RESIDUAL) {
double maxResidual = (*(sortedOrder_.begin()))->getResidual();
if (maxResidual > BpOptions::accuracy) {
converged = false;
} else {
converged = true;
}
} else {
for (unsigned i = 0; i < links_.size(); i++) {
double residual = links_[i]->getResidual();
if (Constants::DEBUG >= 2) {
cout << links_[i]->toString() + " residual = " << residual << endl;
}
if (residual > BpOptions::accuracy) {
converged = false;
if (Constants::DEBUG == 0) break;
}
}
}
return converged;
}