refactor ground solver interface
This commit is contained in:
parent
46e6a10625
commit
78e86a6330
@ -15,6 +15,7 @@
|
||||
BpSolver::BpSolver (const FactorGraph& fg) : Solver (fg)
|
||||
{
|
||||
factorGraph_ = &fg;
|
||||
runned_ = false;
|
||||
}
|
||||
|
||||
|
||||
@ -34,31 +35,14 @@ BpSolver::~BpSolver (void)
|
||||
|
||||
|
||||
|
||||
void
|
||||
BpSolver::runSolver (void)
|
||||
Params
|
||||
BpSolver::solveQuery (VarIds queryVids)
|
||||
{
|
||||
clock_t start;
|
||||
if (Constants::COLLECT_STATS) {
|
||||
start = clock();
|
||||
}
|
||||
runLoopySolver();
|
||||
if (Constants::DEBUG >= 2) {
|
||||
cout << endl;
|
||||
if (nIters_ < BpOptions::maxIter) {
|
||||
cout << "Sum-Product converged in " ;
|
||||
cout << nIters_ << " iterations" << endl;
|
||||
} else {
|
||||
cout << "The maximum number of iterations was hit, terminating..." ;
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
unsigned size = factorGraph_->varNodes().size();
|
||||
if (Constants::COLLECT_STATS) {
|
||||
unsigned nIters = 0;
|
||||
bool loopy = factorGraph_->isTree() == false;
|
||||
if (loopy) nIters = nIters_;
|
||||
double time = (double (clock() - start)) / CLOCKS_PER_SEC;
|
||||
Statistics::updateStatistics (size, loopy, nIters, time);
|
||||
assert (queryVids.empty() == false);
|
||||
if (queryVids.size() == 1) {
|
||||
return getPosterioriOf (queryVids[0]);
|
||||
} else {
|
||||
return getJointDistributionOf (queryVids);
|
||||
}
|
||||
}
|
||||
|
||||
@ -67,6 +51,9 @@ BpSolver::runSolver (void)
|
||||
Params
|
||||
BpSolver::getPosterioriOf (VarId vid)
|
||||
{
|
||||
if (runned_ == false) {
|
||||
runSolver();
|
||||
}
|
||||
assert (factorGraph_->getVarNode (vid));
|
||||
VarNode* var = factorGraph_->getVarNode (vid);
|
||||
Params probs;
|
||||
@ -97,6 +84,9 @@ BpSolver::getPosterioriOf (VarId vid)
|
||||
Params
|
||||
BpSolver::getJointDistributionOf (const VarIds& jointVarIds)
|
||||
{
|
||||
if (runned_ == false) {
|
||||
runSolver();
|
||||
}
|
||||
int idx = -1;
|
||||
VarNode* vn = factorGraph_->getVarNode (jointVarIds[0]);
|
||||
const FacNodes& facNodes = vn->neighbors();
|
||||
@ -130,52 +120,6 @@ BpSolver::getJointDistributionOf (const VarIds& jointVarIds)
|
||||
|
||||
|
||||
|
||||
void
|
||||
BpSolver::runLoopySolver (void)
|
||||
{
|
||||
initializeSolver();
|
||||
nIters_ = 0;
|
||||
|
||||
while (!converged() && nIters_ < BpOptions::maxIter) {
|
||||
|
||||
nIters_ ++;
|
||||
if (Constants::DEBUG >= 2) {
|
||||
Util::printHeader (" Iteration " + nIters_);
|
||||
cout << endl;
|
||||
}
|
||||
|
||||
switch (BpOptions::schedule) {
|
||||
case BpOptions::Schedule::SEQ_RANDOM:
|
||||
random_shuffle (links_.begin(), links_.end());
|
||||
// no break
|
||||
|
||||
case BpOptions::Schedule::SEQ_FIXED:
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
calculateAndUpdateMessage (links_[i]);
|
||||
}
|
||||
break;
|
||||
|
||||
case BpOptions::Schedule::PARALLEL:
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
calculateMessage (links_[i]);
|
||||
}
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
updateMessage(links_[i]);
|
||||
}
|
||||
break;
|
||||
|
||||
case BpOptions::Schedule::MAX_RESIDUAL:
|
||||
maxResidualSchedule();
|
||||
break;
|
||||
}
|
||||
if (Constants::DEBUG >= 2) {
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BpSolver::initializeSolver (void)
|
||||
{
|
||||
@ -226,40 +170,6 @@ BpSolver::createLinks (void)
|
||||
|
||||
|
||||
|
||||
bool
|
||||
BpSolver::converged (void)
|
||||
{
|
||||
if (links_.size() == 0) {
|
||||
return true;
|
||||
}
|
||||
if (nIters_ == 0 || nIters_ == 1) {
|
||||
return false;
|
||||
}
|
||||
bool converged = true;
|
||||
if (BpOptions::schedule == BpOptions::Schedule::MAX_RESIDUAL) {
|
||||
double maxResidual = (*(sortedOrder_.begin()))->getResidual();
|
||||
if (maxResidual > BpOptions::accuracy) {
|
||||
converged = false;
|
||||
} else {
|
||||
converged = true;
|
||||
}
|
||||
} else {
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
double residual = links_[i]->getResidual();
|
||||
if (Constants::DEBUG >= 2) {
|
||||
cout << links_[i]->toString() + " residual = " << residual << endl;
|
||||
}
|
||||
if (residual > BpOptions::accuracy) {
|
||||
converged = false;
|
||||
if (Constants::DEBUG == 0) break;
|
||||
}
|
||||
}
|
||||
}
|
||||
return converged;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BpSolver::maxResidualSchedule (void)
|
||||
{
|
||||
@ -493,3 +403,114 @@ BpSolver::printLinkInformation (void) const
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BpSolver::runSolver (void)
|
||||
{
|
||||
clock_t start;
|
||||
if (Constants::COLLECT_STATS) {
|
||||
start = clock();
|
||||
}
|
||||
runLoopySolver();
|
||||
if (Constants::DEBUG >= 2) {
|
||||
cout << endl;
|
||||
if (nIters_ < BpOptions::maxIter) {
|
||||
cout << "Sum-Product converged in " ;
|
||||
cout << nIters_ << " iterations" << endl;
|
||||
} else {
|
||||
cout << "The maximum number of iterations was hit, terminating..." ;
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
unsigned size = factorGraph_->varNodes().size();
|
||||
if (Constants::COLLECT_STATS) {
|
||||
unsigned nIters = 0;
|
||||
bool loopy = factorGraph_->isTree() == false;
|
||||
if (loopy) nIters = nIters_;
|
||||
double time = (double (clock() - start)) / CLOCKS_PER_SEC;
|
||||
Statistics::updateStatistics (size, loopy, nIters, time);
|
||||
}
|
||||
runned_ = true;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BpSolver::runLoopySolver (void)
|
||||
{
|
||||
initializeSolver();
|
||||
nIters_ = 0;
|
||||
|
||||
while (!converged() && nIters_ < BpOptions::maxIter) {
|
||||
|
||||
nIters_ ++;
|
||||
if (Constants::DEBUG >= 2) {
|
||||
Util::printHeader (" Iteration " + nIters_);
|
||||
cout << endl;
|
||||
}
|
||||
|
||||
switch (BpOptions::schedule) {
|
||||
case BpOptions::Schedule::SEQ_RANDOM:
|
||||
random_shuffle (links_.begin(), links_.end());
|
||||
// no break
|
||||
|
||||
case BpOptions::Schedule::SEQ_FIXED:
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
calculateAndUpdateMessage (links_[i]);
|
||||
}
|
||||
break;
|
||||
|
||||
case BpOptions::Schedule::PARALLEL:
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
calculateMessage (links_[i]);
|
||||
}
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
updateMessage(links_[i]);
|
||||
}
|
||||
break;
|
||||
|
||||
case BpOptions::Schedule::MAX_RESIDUAL:
|
||||
maxResidualSchedule();
|
||||
break;
|
||||
}
|
||||
if (Constants::DEBUG >= 2) {
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
BpSolver::converged (void)
|
||||
{
|
||||
if (links_.size() == 0) {
|
||||
return true;
|
||||
}
|
||||
if (nIters_ == 0 || nIters_ == 1) {
|
||||
return false;
|
||||
}
|
||||
bool converged = true;
|
||||
if (BpOptions::schedule == BpOptions::Schedule::MAX_RESIDUAL) {
|
||||
double maxResidual = (*(sortedOrder_.begin()))->getResidual();
|
||||
if (maxResidual > BpOptions::accuracy) {
|
||||
converged = false;
|
||||
} else {
|
||||
converged = true;
|
||||
}
|
||||
} else {
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
double residual = links_[i]->getResidual();
|
||||
if (Constants::DEBUG >= 2) {
|
||||
cout << links_[i]->toString() + " residual = " << residual << endl;
|
||||
}
|
||||
if (residual > BpOptions::accuracy) {
|
||||
converged = false;
|
||||
if (Constants::DEBUG == 0) break;
|
||||
}
|
||||
}
|
||||
}
|
||||
return converged;
|
||||
}
|
||||
|
||||
|
@ -95,7 +95,7 @@ class BpSolver : public Solver
|
||||
|
||||
virtual ~BpSolver (void);
|
||||
|
||||
void runSolver (void);
|
||||
Params solveQuery (VarIds);
|
||||
|
||||
virtual Params getPosterioriOf (VarId);
|
||||
|
||||
@ -169,6 +169,7 @@ class BpSolver : public Solver
|
||||
unsigned nIters_;
|
||||
vector<SPNodeInfo*> varsI_;
|
||||
vector<SPNodeInfo*> facsI_;
|
||||
bool runned_;
|
||||
const FactorGraph* factorGraph_;
|
||||
|
||||
typedef multiset<SpLink*, CompareResidual> SortedOrder;
|
||||
@ -178,6 +179,7 @@ class BpSolver : public Solver
|
||||
SpLinkMap linkMap_;
|
||||
|
||||
private:
|
||||
void runSolver (void);
|
||||
void runLoopySolver (void);
|
||||
bool converged (void);
|
||||
};
|
||||
|
@ -85,9 +85,8 @@ CbpSolver::initializeSolver (void)
|
||||
if (Constants::COLLECT_STATS) {
|
||||
unsigned nClusterVars = factorGraph_->varNodes().size();
|
||||
unsigned nClusterFacs = factorGraph_->facNodes().size();
|
||||
Statistics::updateCompressingStatistics (nGroundVars, nGroundFacs,
|
||||
nClusterVars, nClusterFacs,
|
||||
nWithoutNeighs);
|
||||
Statistics::updateCompressingStatistics (nGroundVars,
|
||||
nGroundFacs, nClusterVars, nClusterFacs, nWithoutNeighs);
|
||||
}
|
||||
|
||||
// cout << "Compressed Factor Graph" << endl;
|
||||
|
@ -37,7 +37,7 @@ class CbpSolverLink : public SpLink
|
||||
class CbpSolver : public BpSolver
|
||||
{
|
||||
public:
|
||||
CbpSolver (FactorGraph& fg) : BpSolver (fg) { }
|
||||
CbpSolver (const FactorGraph& fg) : BpSolver (fg) { }
|
||||
|
||||
~CbpSolver (void);
|
||||
|
||||
@ -47,13 +47,17 @@ class CbpSolver : public BpSolver
|
||||
|
||||
private:
|
||||
void initializeSolver (void);
|
||||
|
||||
void createLinks (void);
|
||||
|
||||
void maxResidualSchedule (void);
|
||||
|
||||
Params getVar2FactorMsg (const SpLink*) const;
|
||||
|
||||
void printLinkInformation (void) const;
|
||||
|
||||
CFactorGraph* lfg_;
|
||||
FactorGraph* factorGraph_;
|
||||
};
|
||||
|
||||
#endif // HORUS_CBP_H
|
||||
|
@ -34,7 +34,6 @@ ElimGraph::ElimGraph (const vector<Factor*>& factors)
|
||||
}
|
||||
}
|
||||
}
|
||||
setIndexes();
|
||||
}
|
||||
|
||||
|
||||
@ -148,6 +147,7 @@ void
|
||||
ElimGraph::addNode (EgNode* n)
|
||||
{
|
||||
nodes_.push_back (n);
|
||||
n->setIndex (nodes_.size() - 1);
|
||||
varMap_.insert (make_pair (n->varId(), n));
|
||||
}
|
||||
|
||||
@ -301,13 +301,3 @@ ElimGraph::neighbors (const EgNode* n1, const EgNode* n2) const
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
ElimGraph::setIndexes (void)
|
||||
{
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
nodes_[i]->setIndex (i);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -78,8 +78,6 @@ class ElimGraph
|
||||
|
||||
bool neighbors (const EgNode*, const EgNode*) const;
|
||||
|
||||
void setIndexes (void);
|
||||
|
||||
vector<EgNode*> nodes_;
|
||||
vector<bool> marked_;
|
||||
unordered_map<VarId, EgNode*> varMap_;
|
||||
|
@ -241,7 +241,7 @@ Factor::print (void) const
|
||||
for (unsigned i = 0; i < args_.size(); i++) {
|
||||
vars.push_back (new Var (args_[i], ranges_[i]));
|
||||
}
|
||||
vector<string> jointStrings = Util::getJointStateStrings (vars);
|
||||
vector<string> jointStrings = Util::getStateLines (vars);
|
||||
for (unsigned i = 0; i < params_.size(); i++) {
|
||||
cout << "f(" << jointStrings[i] << ")" ;
|
||||
cout << " = " << params_[i] << endl;
|
||||
|
@ -11,7 +11,7 @@
|
||||
using namespace std;
|
||||
|
||||
void processArguments (FactorGraph&, int, const char* []);
|
||||
void runSolver (Solver*, const VarIds&);
|
||||
void runSolver (const FactorGraph&, const VarIds&);
|
||||
|
||||
const string USAGE = "usage: \
|
||||
./hcli FILE [VARIABLE | OBSERVED_VARIABLE=EVIDENCE]..." ;
|
||||
@ -25,8 +25,8 @@ main (int argc, const char* argv[])
|
||||
cerr << USAGE << endl;
|
||||
exit (0);
|
||||
}
|
||||
const string& fileName = argv[1];
|
||||
const string& extension = fileName.substr (
|
||||
string fileName = argv[1];
|
||||
string extension = fileName.substr (
|
||||
fileName.find_last_of ('.') + 1);
|
||||
FactorGraph fg;
|
||||
if (extension == "uai") {
|
||||
@ -38,8 +38,6 @@ main (int argc, const char* argv[])
|
||||
cerr << "in a UAI or libDAI file" << endl;
|
||||
exit (0);
|
||||
}
|
||||
fg.print();
|
||||
assert (false);
|
||||
processArguments (fg, argc, argv);
|
||||
return 0;
|
||||
}
|
||||
@ -123,6 +121,14 @@ processArguments (FactorGraph& fg, int argc, const char* argv[])
|
||||
}
|
||||
}
|
||||
}
|
||||
runSolver (fg, queryIds);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
runSolver (const FactorGraph& fg, const VarIds& queryIds)
|
||||
{
|
||||
Solver* solver = 0;
|
||||
switch (Globals::infAlgorithm) {
|
||||
case InfAlgorithms::VE:
|
||||
@ -137,23 +143,10 @@ processArguments (FactorGraph& fg, int argc, const char* argv[])
|
||||
default:
|
||||
assert (false);
|
||||
}
|
||||
runSolver (solver, queryIds);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
runSolver (Solver* solver, const VarIds& queryIds)
|
||||
{
|
||||
if (queryIds.size() == 0) {
|
||||
solver->runSolver();
|
||||
solver->printAllPosterioris();
|
||||
} else if (queryIds.size() == 1) {
|
||||
solver->runSolver();
|
||||
solver->printPosterioriOf (queryIds[0]);
|
||||
} else {
|
||||
solver->runSolver();
|
||||
solver->printJointDistributionOf (queryIds);
|
||||
solver->printAnswer (queryIds);
|
||||
}
|
||||
delete solver;
|
||||
}
|
||||
|
@ -379,11 +379,7 @@ void runVeSolver (
|
||||
mfg = BayesBall::getMinimalFactorGraph (*fg, tasks[i]);
|
||||
}
|
||||
VarElimSolver solver (*mfg);
|
||||
if (tasks[i].size() == 1) {
|
||||
results.push_back (solver.getPosterioriOf (tasks[i][0]));
|
||||
} else {
|
||||
results.push_back (solver.getJointDistributionOf (tasks[i]));
|
||||
}
|
||||
results.push_back (solver.solveQuery (tasks[i]));
|
||||
if (fg->isFromBayesNetwork()) {
|
||||
delete mfg;
|
||||
}
|
||||
@ -416,14 +412,9 @@ void runBpSolver (
|
||||
cerr << "error: unknow solver" << endl;
|
||||
abort();
|
||||
}
|
||||
solver->runSolver();
|
||||
results.reserve (tasks.size());
|
||||
for (unsigned i = 0; i < tasks.size(); i++) {
|
||||
if (tasks[i].size() == 1) {
|
||||
results.push_back (solver->getPosterioriOf (tasks[i][0]));
|
||||
} else {
|
||||
results.push_back (solver->getJointDistributionOf (tasks[i]));
|
||||
}
|
||||
results.push_back (solver->solveQuery (tasks[i]));
|
||||
}
|
||||
if (fg->isFromBayesNetwork()) {
|
||||
delete mfg;
|
||||
|
@ -2,52 +2,36 @@
|
||||
#include "Util.h"
|
||||
|
||||
|
||||
void
|
||||
Solver::printAnswer (const VarIds& vids)
|
||||
{
|
||||
Vars unobservedVars;
|
||||
VarIds unobservedVids;
|
||||
for (unsigned i = 0; i < vids.size(); i++) {
|
||||
VarNode* vn = fg_.getVarNode (vids[i]);
|
||||
if (vn->hasEvidence() == false) {
|
||||
unobservedVars.push_back (vn);
|
||||
unobservedVids.push_back (vids[i]);
|
||||
}
|
||||
}
|
||||
Params res = solveQuery (unobservedVids);
|
||||
vector<string> stateLines = Util::getStateLines (unobservedVars);
|
||||
for (unsigned i = 0; i < res.size(); i++) {
|
||||
cout << "P(" << stateLines[i] << ") = " ;
|
||||
cout << std::setprecision (Constants::PRECISION) << res[i];
|
||||
cout << endl;
|
||||
}
|
||||
cout << endl;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Solver::printAllPosterioris (void)
|
||||
{
|
||||
const VarNodes& vars = fg_.varNodes();
|
||||
for (unsigned i = 0; i < vars.size(); i++) {
|
||||
printPosterioriOf (vars[i]->varId());
|
||||
printAnswer ({vars[i]->varId()});
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Solver::printPosterioriOf (VarId vid)
|
||||
{
|
||||
VarNode* vn = fg_.getVarNode (vid);
|
||||
const Params& posterioriDist = getPosterioriOf (vid);
|
||||
const States& states = vn->states();
|
||||
for (unsigned i = 0; i < states.size(); i++) {
|
||||
cout << "P(" << vn->label() << "=" << states[i] << ") = " ;
|
||||
cout << setprecision (Constants::PRECISION) << posterioriDist[i];
|
||||
cout << endl;
|
||||
}
|
||||
cout << endl;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Solver::printJointDistributionOf (const VarIds& vids)
|
||||
{
|
||||
Vars vars;
|
||||
VarIds vidsWithoutEvidence;
|
||||
for (unsigned i = 0; i < vids.size(); i++) {
|
||||
VarNode* vn = fg_.getVarNode (vids[i]);
|
||||
if (vn->hasEvidence() == false) {
|
||||
vars.push_back (vn);
|
||||
vidsWithoutEvidence.push_back (vids[i]);
|
||||
}
|
||||
}
|
||||
const Params& jointDist = getJointDistributionOf (vidsWithoutEvidence);
|
||||
vector<string> jointStrings = Util::getJointStateStrings (vars);
|
||||
for (unsigned i = 0; i < jointDist.size(); i++) {
|
||||
cout << "P(" << jointStrings[i] << ") = " ;
|
||||
cout << setprecision (Constants::PRECISION) << jointDist[i];
|
||||
cout << endl;
|
||||
}
|
||||
cout << endl;
|
||||
}
|
||||
|
||||
|
@ -16,19 +16,13 @@ class Solver
|
||||
|
||||
virtual ~Solver() { } // ensure that subclass destructor is called
|
||||
|
||||
virtual void runSolver (void) = 0;
|
||||
virtual Params solveQuery (VarIds queryVids) = 0;
|
||||
|
||||
virtual Params getPosterioriOf (VarId) = 0;
|
||||
|
||||
virtual Params getJointDistributionOf (const VarIds&) = 0;
|
||||
void printAnswer (const VarIds& vids);
|
||||
|
||||
void printAllPosterioris (void);
|
||||
|
||||
void printPosterioriOf (VarId vid);
|
||||
|
||||
void printJointDistributionOf (const VarIds& vids);
|
||||
|
||||
private:
|
||||
protected:
|
||||
const FactorGraph& fg_;
|
||||
};
|
||||
|
||||
|
@ -137,7 +137,7 @@ parametersToString (const Params& v, unsigned precision)
|
||||
|
||||
|
||||
vector<string>
|
||||
getJointStateStrings (const Vars& vars)
|
||||
getStateLines (const Vars& vars)
|
||||
{
|
||||
StatesIndexer idx (vars);
|
||||
vector<string> jointStrings;
|
||||
|
@ -62,7 +62,7 @@ bool isInteger (const string&);
|
||||
|
||||
string parametersToString (const Params&, unsigned = Constants::PRECISION);
|
||||
|
||||
vector<string> getJointStateStrings (const Vars&);
|
||||
vector<string> getStateLines (const Vars&);
|
||||
|
||||
void printHeader (string, std::ostream& os = std::cout);
|
||||
|
||||
|
@ -6,13 +6,6 @@
|
||||
#include "Util.h"
|
||||
|
||||
|
||||
VarElimSolver::VarElimSolver (const FactorGraph& fg) : Solver (fg)
|
||||
{
|
||||
factorGraph_ = &fg;
|
||||
}
|
||||
|
||||
|
||||
|
||||
VarElimSolver::~VarElimSolver (void)
|
||||
{
|
||||
delete factorList_.back();
|
||||
@ -21,30 +14,15 @@ VarElimSolver::~VarElimSolver (void)
|
||||
|
||||
|
||||
Params
|
||||
VarElimSolver::getPosterioriOf (VarId vid)
|
||||
{
|
||||
assert (factorGraph_->getVarNode (vid));
|
||||
VarNode* vn = factorGraph_->getVarNode (vid);
|
||||
if (vn->hasEvidence()) {
|
||||
Params params (vn->range(), 0.0);
|
||||
params[vn->getEvidence()] = 1.0;
|
||||
return params;
|
||||
}
|
||||
return getJointDistributionOf (VarIds() = {vid});
|
||||
}
|
||||
|
||||
|
||||
|
||||
Params
|
||||
VarElimSolver::getJointDistributionOf (const VarIds& vids)
|
||||
VarElimSolver::solveQuery (VarIds queryVids)
|
||||
{
|
||||
factorList_.clear();
|
||||
varFactors_.clear();
|
||||
elimOrder_.clear();
|
||||
createFactorList();
|
||||
absorveEvidence();
|
||||
findEliminationOrder (vids);
|
||||
processFactorList (vids);
|
||||
findEliminationOrder (queryVids);
|
||||
processFactorList (queryVids);
|
||||
Params params = factorList_.back()->params();
|
||||
if (Globals::logDomain) {
|
||||
Util::fromLog (params);
|
||||
@ -57,10 +35,10 @@ VarElimSolver::getJointDistributionOf (const VarIds& vids)
|
||||
void
|
||||
VarElimSolver::createFactorList (void)
|
||||
{
|
||||
const FacNodes& facNodes = factorGraph_->facNodes();
|
||||
const FacNodes& facNodes = fg_.facNodes();
|
||||
factorList_.reserve (facNodes.size() * 2);
|
||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||
factorList_.push_back (new Factor (facNodes[i]->factor())); // FIXME
|
||||
factorList_.push_back (new Factor (facNodes[i]->factor()));
|
||||
const VarNodes& neighs = facNodes[i]->neighbors();
|
||||
for (unsigned j = 0; j < neighs.size(); j++) {
|
||||
unordered_map<VarId,vector<unsigned> >::iterator it
|
||||
@ -79,7 +57,7 @@ VarElimSolver::createFactorList (void)
|
||||
void
|
||||
VarElimSolver::absorveEvidence (void)
|
||||
{
|
||||
const VarNodes& varNodes = factorGraph_->varNodes();
|
||||
const VarNodes& varNodes = fg_.varNodes();
|
||||
for (unsigned i = 0; i < varNodes.size(); i++) {
|
||||
if (varNodes[i]->hasEvidence()) {
|
||||
const vector<unsigned>& idxs =
|
||||
@ -125,7 +103,7 @@ VarElimSolver::processFactorList (const VarIds& vids)
|
||||
|
||||
VarIds unobservedVids;
|
||||
for (unsigned i = 0; i < vids.size(); i++) {
|
||||
if (factorGraph_->getVarNode (vids[i])->hasEvidence() == false) {
|
||||
if (fg_.getVarNode (vids[i])->hasEvidence() == false) {
|
||||
unobservedVids.push_back (vids[i]);
|
||||
}
|
||||
}
|
||||
@ -146,7 +124,7 @@ VarElimSolver::eliminate (VarId elimVar)
|
||||
unsigned idx = idxs[i];
|
||||
if (factorList_[idx]) {
|
||||
if (result == 0) {
|
||||
result = new Factor(*factorList_[idx]);
|
||||
result = new Factor (*factorList_[idx]);
|
||||
} else {
|
||||
result->multiply (*factorList_[idx]);
|
||||
}
|
||||
|
@ -14,15 +14,11 @@ using namespace std;
|
||||
class VarElimSolver : public Solver
|
||||
{
|
||||
public:
|
||||
VarElimSolver (const FactorGraph&);
|
||||
VarElimSolver (const FactorGraph& fg) : Solver (fg) { }
|
||||
|
||||
~VarElimSolver (void);
|
||||
|
||||
void runSolver (void) { }
|
||||
|
||||
Params getPosterioriOf (VarId);
|
||||
|
||||
Params getJointDistributionOf (const VarIds&);
|
||||
Params solveQuery (VarIds);
|
||||
|
||||
private:
|
||||
void createFactorList (void);
|
||||
@ -37,7 +33,6 @@ class VarElimSolver : public Solver
|
||||
|
||||
void printActiveFactors (void);
|
||||
|
||||
const FactorGraph* factorGraph_;
|
||||
vector<Factor*> factorList_;
|
||||
VarIds elimOrder_;
|
||||
unordered_map<VarId, vector<unsigned>> varFactors_;
|
||||
|
Reference in New Issue
Block a user