refactor ground solver interface

This commit is contained in:
Tiago Gomes 2012-04-10 15:00:18 +01:00
parent 46e6a10625
commit 78e86a6330
15 changed files with 191 additions and 242 deletions

View File

@ -15,6 +15,7 @@
BpSolver::BpSolver (const FactorGraph& fg) : Solver (fg)
{
factorGraph_ = &fg;
runned_ = false;
}
@ -34,31 +35,14 @@ BpSolver::~BpSolver (void)
void
BpSolver::runSolver (void)
Params
BpSolver::solveQuery (VarIds queryVids)
{
clock_t start;
if (Constants::COLLECT_STATS) {
start = clock();
}
runLoopySolver();
if (Constants::DEBUG >= 2) {
cout << endl;
if (nIters_ < BpOptions::maxIter) {
cout << "Sum-Product converged in " ;
cout << nIters_ << " iterations" << endl;
} else {
cout << "The maximum number of iterations was hit, terminating..." ;
cout << endl;
}
}
unsigned size = factorGraph_->varNodes().size();
if (Constants::COLLECT_STATS) {
unsigned nIters = 0;
bool loopy = factorGraph_->isTree() == false;
if (loopy) nIters = nIters_;
double time = (double (clock() - start)) / CLOCKS_PER_SEC;
Statistics::updateStatistics (size, loopy, nIters, time);
assert (queryVids.empty() == false);
if (queryVids.size() == 1) {
return getPosterioriOf (queryVids[0]);
} else {
return getJointDistributionOf (queryVids);
}
}
@ -67,6 +51,9 @@ BpSolver::runSolver (void)
Params
BpSolver::getPosterioriOf (VarId vid)
{
if (runned_ == false) {
runSolver();
}
assert (factorGraph_->getVarNode (vid));
VarNode* var = factorGraph_->getVarNode (vid);
Params probs;
@ -97,6 +84,9 @@ BpSolver::getPosterioriOf (VarId vid)
Params
BpSolver::getJointDistributionOf (const VarIds& jointVarIds)
{
if (runned_ == false) {
runSolver();
}
int idx = -1;
VarNode* vn = factorGraph_->getVarNode (jointVarIds[0]);
const FacNodes& facNodes = vn->neighbors();
@ -130,52 +120,6 @@ BpSolver::getJointDistributionOf (const VarIds& jointVarIds)
void
BpSolver::runLoopySolver (void)
{
initializeSolver();
nIters_ = 0;
while (!converged() && nIters_ < BpOptions::maxIter) {
nIters_ ++;
if (Constants::DEBUG >= 2) {
Util::printHeader (" Iteration " + nIters_);
cout << endl;
}
switch (BpOptions::schedule) {
case BpOptions::Schedule::SEQ_RANDOM:
random_shuffle (links_.begin(), links_.end());
// no break
case BpOptions::Schedule::SEQ_FIXED:
for (unsigned i = 0; i < links_.size(); i++) {
calculateAndUpdateMessage (links_[i]);
}
break;
case BpOptions::Schedule::PARALLEL:
for (unsigned i = 0; i < links_.size(); i++) {
calculateMessage (links_[i]);
}
for (unsigned i = 0; i < links_.size(); i++) {
updateMessage(links_[i]);
}
break;
case BpOptions::Schedule::MAX_RESIDUAL:
maxResidualSchedule();
break;
}
if (Constants::DEBUG >= 2) {
cout << endl;
}
}
}
void
BpSolver::initializeSolver (void)
{
@ -226,40 +170,6 @@ BpSolver::createLinks (void)
bool
BpSolver::converged (void)
{
if (links_.size() == 0) {
return true;
}
if (nIters_ == 0 || nIters_ == 1) {
return false;
}
bool converged = true;
if (BpOptions::schedule == BpOptions::Schedule::MAX_RESIDUAL) {
double maxResidual = (*(sortedOrder_.begin()))->getResidual();
if (maxResidual > BpOptions::accuracy) {
converged = false;
} else {
converged = true;
}
} else {
for (unsigned i = 0; i < links_.size(); i++) {
double residual = links_[i]->getResidual();
if (Constants::DEBUG >= 2) {
cout << links_[i]->toString() + " residual = " << residual << endl;
}
if (residual > BpOptions::accuracy) {
converged = false;
if (Constants::DEBUG == 0) break;
}
}
}
return converged;
}
void
BpSolver::maxResidualSchedule (void)
{
@ -493,3 +403,114 @@ BpSolver::printLinkInformation (void) const
}
}
void
BpSolver::runSolver (void)
{
clock_t start;
if (Constants::COLLECT_STATS) {
start = clock();
}
runLoopySolver();
if (Constants::DEBUG >= 2) {
cout << endl;
if (nIters_ < BpOptions::maxIter) {
cout << "Sum-Product converged in " ;
cout << nIters_ << " iterations" << endl;
} else {
cout << "The maximum number of iterations was hit, terminating..." ;
cout << endl;
}
}
unsigned size = factorGraph_->varNodes().size();
if (Constants::COLLECT_STATS) {
unsigned nIters = 0;
bool loopy = factorGraph_->isTree() == false;
if (loopy) nIters = nIters_;
double time = (double (clock() - start)) / CLOCKS_PER_SEC;
Statistics::updateStatistics (size, loopy, nIters, time);
}
runned_ = true;
}
void
BpSolver::runLoopySolver (void)
{
initializeSolver();
nIters_ = 0;
while (!converged() && nIters_ < BpOptions::maxIter) {
nIters_ ++;
if (Constants::DEBUG >= 2) {
Util::printHeader (" Iteration " + nIters_);
cout << endl;
}
switch (BpOptions::schedule) {
case BpOptions::Schedule::SEQ_RANDOM:
random_shuffle (links_.begin(), links_.end());
// no break
case BpOptions::Schedule::SEQ_FIXED:
for (unsigned i = 0; i < links_.size(); i++) {
calculateAndUpdateMessage (links_[i]);
}
break;
case BpOptions::Schedule::PARALLEL:
for (unsigned i = 0; i < links_.size(); i++) {
calculateMessage (links_[i]);
}
for (unsigned i = 0; i < links_.size(); i++) {
updateMessage(links_[i]);
}
break;
case BpOptions::Schedule::MAX_RESIDUAL:
maxResidualSchedule();
break;
}
if (Constants::DEBUG >= 2) {
cout << endl;
}
}
}
bool
BpSolver::converged (void)
{
if (links_.size() == 0) {
return true;
}
if (nIters_ == 0 || nIters_ == 1) {
return false;
}
bool converged = true;
if (BpOptions::schedule == BpOptions::Schedule::MAX_RESIDUAL) {
double maxResidual = (*(sortedOrder_.begin()))->getResidual();
if (maxResidual > BpOptions::accuracy) {
converged = false;
} else {
converged = true;
}
} else {
for (unsigned i = 0; i < links_.size(); i++) {
double residual = links_[i]->getResidual();
if (Constants::DEBUG >= 2) {
cout << links_[i]->toString() + " residual = " << residual << endl;
}
if (residual > BpOptions::accuracy) {
converged = false;
if (Constants::DEBUG == 0) break;
}
}
}
return converged;
}

View File

@ -95,7 +95,7 @@ class BpSolver : public Solver
virtual ~BpSolver (void);
void runSolver (void);
Params solveQuery (VarIds);
virtual Params getPosterioriOf (VarId);
@ -169,6 +169,7 @@ class BpSolver : public Solver
unsigned nIters_;
vector<SPNodeInfo*> varsI_;
vector<SPNodeInfo*> facsI_;
bool runned_;
const FactorGraph* factorGraph_;
typedef multiset<SpLink*, CompareResidual> SortedOrder;
@ -178,6 +179,7 @@ class BpSolver : public Solver
SpLinkMap linkMap_;
private:
void runSolver (void);
void runLoopySolver (void);
bool converged (void);
};

View File

@ -85,9 +85,8 @@ CbpSolver::initializeSolver (void)
if (Constants::COLLECT_STATS) {
unsigned nClusterVars = factorGraph_->varNodes().size();
unsigned nClusterFacs = factorGraph_->facNodes().size();
Statistics::updateCompressingStatistics (nGroundVars, nGroundFacs,
nClusterVars, nClusterFacs,
nWithoutNeighs);
Statistics::updateCompressingStatistics (nGroundVars,
nGroundFacs, nClusterVars, nClusterFacs, nWithoutNeighs);
}
// cout << "Compressed Factor Graph" << endl;

View File

@ -37,7 +37,7 @@ class CbpSolverLink : public SpLink
class CbpSolver : public BpSolver
{
public:
CbpSolver (FactorGraph& fg) : BpSolver (fg) { }
CbpSolver (const FactorGraph& fg) : BpSolver (fg) { }
~CbpSolver (void);
@ -47,13 +47,17 @@ class CbpSolver : public BpSolver
private:
void initializeSolver (void);
void createLinks (void);
void maxResidualSchedule (void);
Params getVar2FactorMsg (const SpLink*) const;
void printLinkInformation (void) const;
CFactorGraph* lfg_;
FactorGraph* factorGraph_;
};
#endif // HORUS_CBP_H

View File

@ -34,7 +34,6 @@ ElimGraph::ElimGraph (const vector<Factor*>& factors)
}
}
}
setIndexes();
}
@ -148,6 +147,7 @@ void
ElimGraph::addNode (EgNode* n)
{
nodes_.push_back (n);
n->setIndex (nodes_.size() - 1);
varMap_.insert (make_pair (n->varId(), n));
}
@ -301,13 +301,3 @@ ElimGraph::neighbors (const EgNode* n1, const EgNode* n2) const
return false;
}
void
ElimGraph::setIndexes (void)
{
for (unsigned i = 0; i < nodes_.size(); i++) {
nodes_[i]->setIndex (i);
}
}

View File

@ -78,8 +78,6 @@ class ElimGraph
bool neighbors (const EgNode*, const EgNode*) const;
void setIndexes (void);
vector<EgNode*> nodes_;
vector<bool> marked_;
unordered_map<VarId, EgNode*> varMap_;

View File

@ -241,7 +241,7 @@ Factor::print (void) const
for (unsigned i = 0; i < args_.size(); i++) {
vars.push_back (new Var (args_[i], ranges_[i]));
}
vector<string> jointStrings = Util::getJointStateStrings (vars);
vector<string> jointStrings = Util::getStateLines (vars);
for (unsigned i = 0; i < params_.size(); i++) {
cout << "f(" << jointStrings[i] << ")" ;
cout << " = " << params_[i] << endl;

View File

@ -11,7 +11,7 @@
using namespace std;
void processArguments (FactorGraph&, int, const char* []);
void runSolver (Solver*, const VarIds&);
void runSolver (const FactorGraph&, const VarIds&);
const string USAGE = "usage: \
./hcli FILE [VARIABLE | OBSERVED_VARIABLE=EVIDENCE]..." ;
@ -25,8 +25,8 @@ main (int argc, const char* argv[])
cerr << USAGE << endl;
exit (0);
}
const string& fileName = argv[1];
const string& extension = fileName.substr (
string fileName = argv[1];
string extension = fileName.substr (
fileName.find_last_of ('.') + 1);
FactorGraph fg;
if (extension == "uai") {
@ -38,8 +38,6 @@ main (int argc, const char* argv[])
cerr << "in a UAI or libDAI file" << endl;
exit (0);
}
fg.print();
assert (false);
processArguments (fg, argc, argv);
return 0;
}
@ -123,6 +121,14 @@ processArguments (FactorGraph& fg, int argc, const char* argv[])
}
}
}
runSolver (fg, queryIds);
}
void
runSolver (const FactorGraph& fg, const VarIds& queryIds)
{
Solver* solver = 0;
switch (Globals::infAlgorithm) {
case InfAlgorithms::VE:
@ -137,23 +143,10 @@ processArguments (FactorGraph& fg, int argc, const char* argv[])
default:
assert (false);
}
runSolver (solver, queryIds);
}
void
runSolver (Solver* solver, const VarIds& queryIds)
{
if (queryIds.size() == 0) {
solver->runSolver();
solver->printAllPosterioris();
} else if (queryIds.size() == 1) {
solver->runSolver();
solver->printPosterioriOf (queryIds[0]);
} else {
solver->runSolver();
solver->printJointDistributionOf (queryIds);
solver->printAnswer (queryIds);
}
delete solver;
}

View File

@ -379,11 +379,7 @@ void runVeSolver (
mfg = BayesBall::getMinimalFactorGraph (*fg, tasks[i]);
}
VarElimSolver solver (*mfg);
if (tasks[i].size() == 1) {
results.push_back (solver.getPosterioriOf (tasks[i][0]));
} else {
results.push_back (solver.getJointDistributionOf (tasks[i]));
}
results.push_back (solver.solveQuery (tasks[i]));
if (fg->isFromBayesNetwork()) {
delete mfg;
}
@ -416,14 +412,9 @@ void runBpSolver (
cerr << "error: unknow solver" << endl;
abort();
}
solver->runSolver();
results.reserve (tasks.size());
for (unsigned i = 0; i < tasks.size(); i++) {
if (tasks[i].size() == 1) {
results.push_back (solver->getPosterioriOf (tasks[i][0]));
} else {
results.push_back (solver->getJointDistributionOf (tasks[i]));
}
results.push_back (solver->solveQuery (tasks[i]));
}
if (fg->isFromBayesNetwork()) {
delete mfg;

View File

@ -2,52 +2,36 @@
#include "Util.h"
void
Solver::printAnswer (const VarIds& vids)
{
Vars unobservedVars;
VarIds unobservedVids;
for (unsigned i = 0; i < vids.size(); i++) {
VarNode* vn = fg_.getVarNode (vids[i]);
if (vn->hasEvidence() == false) {
unobservedVars.push_back (vn);
unobservedVids.push_back (vids[i]);
}
}
Params res = solveQuery (unobservedVids);
vector<string> stateLines = Util::getStateLines (unobservedVars);
for (unsigned i = 0; i < res.size(); i++) {
cout << "P(" << stateLines[i] << ") = " ;
cout << std::setprecision (Constants::PRECISION) << res[i];
cout << endl;
}
cout << endl;
}
void
Solver::printAllPosterioris (void)
{
const VarNodes& vars = fg_.varNodes();
for (unsigned i = 0; i < vars.size(); i++) {
printPosterioriOf (vars[i]->varId());
printAnswer ({vars[i]->varId()});
}
}
void
Solver::printPosterioriOf (VarId vid)
{
VarNode* vn = fg_.getVarNode (vid);
const Params& posterioriDist = getPosterioriOf (vid);
const States& states = vn->states();
for (unsigned i = 0; i < states.size(); i++) {
cout << "P(" << vn->label() << "=" << states[i] << ") = " ;
cout << setprecision (Constants::PRECISION) << posterioriDist[i];
cout << endl;
}
cout << endl;
}
void
Solver::printJointDistributionOf (const VarIds& vids)
{
Vars vars;
VarIds vidsWithoutEvidence;
for (unsigned i = 0; i < vids.size(); i++) {
VarNode* vn = fg_.getVarNode (vids[i]);
if (vn->hasEvidence() == false) {
vars.push_back (vn);
vidsWithoutEvidence.push_back (vids[i]);
}
}
const Params& jointDist = getJointDistributionOf (vidsWithoutEvidence);
vector<string> jointStrings = Util::getJointStateStrings (vars);
for (unsigned i = 0; i < jointDist.size(); i++) {
cout << "P(" << jointStrings[i] << ") = " ;
cout << setprecision (Constants::PRECISION) << jointDist[i];
cout << endl;
}
cout << endl;
}

View File

@ -16,19 +16,13 @@ class Solver
virtual ~Solver() { } // ensure that subclass destructor is called
virtual void runSolver (void) = 0;
virtual Params solveQuery (VarIds queryVids) = 0;
virtual Params getPosterioriOf (VarId) = 0;
virtual Params getJointDistributionOf (const VarIds&) = 0;
void printAnswer (const VarIds& vids);
void printAllPosterioris (void);
void printPosterioriOf (VarId vid);
void printJointDistributionOf (const VarIds& vids);
private:
protected:
const FactorGraph& fg_;
};

View File

@ -137,7 +137,7 @@ parametersToString (const Params& v, unsigned precision)
vector<string>
getJointStateStrings (const Vars& vars)
getStateLines (const Vars& vars)
{
StatesIndexer idx (vars);
vector<string> jointStrings;

View File

@ -62,7 +62,7 @@ bool isInteger (const string&);
string parametersToString (const Params&, unsigned = Constants::PRECISION);
vector<string> getJointStateStrings (const Vars&);
vector<string> getStateLines (const Vars&);
void printHeader (string, std::ostream& os = std::cout);

View File

@ -6,13 +6,6 @@
#include "Util.h"
VarElimSolver::VarElimSolver (const FactorGraph& fg) : Solver (fg)
{
factorGraph_ = &fg;
}
VarElimSolver::~VarElimSolver (void)
{
delete factorList_.back();
@ -21,30 +14,15 @@ VarElimSolver::~VarElimSolver (void)
Params
VarElimSolver::getPosterioriOf (VarId vid)
{
assert (factorGraph_->getVarNode (vid));
VarNode* vn = factorGraph_->getVarNode (vid);
if (vn->hasEvidence()) {
Params params (vn->range(), 0.0);
params[vn->getEvidence()] = 1.0;
return params;
}
return getJointDistributionOf (VarIds() = {vid});
}
Params
VarElimSolver::getJointDistributionOf (const VarIds& vids)
VarElimSolver::solveQuery (VarIds queryVids)
{
factorList_.clear();
varFactors_.clear();
elimOrder_.clear();
createFactorList();
absorveEvidence();
findEliminationOrder (vids);
processFactorList (vids);
findEliminationOrder (queryVids);
processFactorList (queryVids);
Params params = factorList_.back()->params();
if (Globals::logDomain) {
Util::fromLog (params);
@ -57,10 +35,10 @@ VarElimSolver::getJointDistributionOf (const VarIds& vids)
void
VarElimSolver::createFactorList (void)
{
const FacNodes& facNodes = factorGraph_->facNodes();
const FacNodes& facNodes = fg_.facNodes();
factorList_.reserve (facNodes.size() * 2);
for (unsigned i = 0; i < facNodes.size(); i++) {
factorList_.push_back (new Factor (facNodes[i]->factor())); // FIXME
factorList_.push_back (new Factor (facNodes[i]->factor()));
const VarNodes& neighs = facNodes[i]->neighbors();
for (unsigned j = 0; j < neighs.size(); j++) {
unordered_map<VarId,vector<unsigned> >::iterator it
@ -79,7 +57,7 @@ VarElimSolver::createFactorList (void)
void
VarElimSolver::absorveEvidence (void)
{
const VarNodes& varNodes = factorGraph_->varNodes();
const VarNodes& varNodes = fg_.varNodes();
for (unsigned i = 0; i < varNodes.size(); i++) {
if (varNodes[i]->hasEvidence()) {
const vector<unsigned>& idxs =
@ -125,7 +103,7 @@ VarElimSolver::processFactorList (const VarIds& vids)
VarIds unobservedVids;
for (unsigned i = 0; i < vids.size(); i++) {
if (factorGraph_->getVarNode (vids[i])->hasEvidence() == false) {
if (fg_.getVarNode (vids[i])->hasEvidence() == false) {
unobservedVids.push_back (vids[i]);
}
}
@ -146,7 +124,7 @@ VarElimSolver::eliminate (VarId elimVar)
unsigned idx = idxs[i];
if (factorList_[idx]) {
if (result == 0) {
result = new Factor(*factorList_[idx]);
result = new Factor (*factorList_[idx]);
} else {
result->multiply (*factorList_[idx]);
}

View File

@ -14,15 +14,11 @@ using namespace std;
class VarElimSolver : public Solver
{
public:
VarElimSolver (const FactorGraph&);
VarElimSolver (const FactorGraph& fg) : Solver (fg) { }
~VarElimSolver (void);
void runSolver (void) { }
Params getPosterioriOf (VarId);
Params getJointDistributionOf (const VarIds&);
Params solveQuery (VarIds);
private:
void createFactorList (void);
@ -37,7 +33,6 @@ class VarElimSolver : public Solver
void printActiveFactors (void);
const FactorGraph* factorGraph_;
vector<Factor*> factorList_;
VarIds elimOrder_;
unordered_map<VarId, vector<unsigned>> varFactors_;