refactor ground solver interface
This commit is contained in:
parent
46e6a10625
commit
78e86a6330
@ -15,6 +15,7 @@
|
|||||||
BpSolver::BpSolver (const FactorGraph& fg) : Solver (fg)
|
BpSolver::BpSolver (const FactorGraph& fg) : Solver (fg)
|
||||||
{
|
{
|
||||||
factorGraph_ = &fg;
|
factorGraph_ = &fg;
|
||||||
|
runned_ = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -34,31 +35,14 @@ BpSolver::~BpSolver (void)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
Params
|
||||||
BpSolver::runSolver (void)
|
BpSolver::solveQuery (VarIds queryVids)
|
||||||
{
|
{
|
||||||
clock_t start;
|
assert (queryVids.empty() == false);
|
||||||
if (Constants::COLLECT_STATS) {
|
if (queryVids.size() == 1) {
|
||||||
start = clock();
|
return getPosterioriOf (queryVids[0]);
|
||||||
}
|
} else {
|
||||||
runLoopySolver();
|
return getJointDistributionOf (queryVids);
|
||||||
if (Constants::DEBUG >= 2) {
|
|
||||||
cout << endl;
|
|
||||||
if (nIters_ < BpOptions::maxIter) {
|
|
||||||
cout << "Sum-Product converged in " ;
|
|
||||||
cout << nIters_ << " iterations" << endl;
|
|
||||||
} else {
|
|
||||||
cout << "The maximum number of iterations was hit, terminating..." ;
|
|
||||||
cout << endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
unsigned size = factorGraph_->varNodes().size();
|
|
||||||
if (Constants::COLLECT_STATS) {
|
|
||||||
unsigned nIters = 0;
|
|
||||||
bool loopy = factorGraph_->isTree() == false;
|
|
||||||
if (loopy) nIters = nIters_;
|
|
||||||
double time = (double (clock() - start)) / CLOCKS_PER_SEC;
|
|
||||||
Statistics::updateStatistics (size, loopy, nIters, time);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -67,6 +51,9 @@ BpSolver::runSolver (void)
|
|||||||
Params
|
Params
|
||||||
BpSolver::getPosterioriOf (VarId vid)
|
BpSolver::getPosterioriOf (VarId vid)
|
||||||
{
|
{
|
||||||
|
if (runned_ == false) {
|
||||||
|
runSolver();
|
||||||
|
}
|
||||||
assert (factorGraph_->getVarNode (vid));
|
assert (factorGraph_->getVarNode (vid));
|
||||||
VarNode* var = factorGraph_->getVarNode (vid);
|
VarNode* var = factorGraph_->getVarNode (vid);
|
||||||
Params probs;
|
Params probs;
|
||||||
@ -97,6 +84,9 @@ BpSolver::getPosterioriOf (VarId vid)
|
|||||||
Params
|
Params
|
||||||
BpSolver::getJointDistributionOf (const VarIds& jointVarIds)
|
BpSolver::getJointDistributionOf (const VarIds& jointVarIds)
|
||||||
{
|
{
|
||||||
|
if (runned_ == false) {
|
||||||
|
runSolver();
|
||||||
|
}
|
||||||
int idx = -1;
|
int idx = -1;
|
||||||
VarNode* vn = factorGraph_->getVarNode (jointVarIds[0]);
|
VarNode* vn = factorGraph_->getVarNode (jointVarIds[0]);
|
||||||
const FacNodes& facNodes = vn->neighbors();
|
const FacNodes& facNodes = vn->neighbors();
|
||||||
@ -130,52 +120,6 @@ BpSolver::getJointDistributionOf (const VarIds& jointVarIds)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
BpSolver::runLoopySolver (void)
|
|
||||||
{
|
|
||||||
initializeSolver();
|
|
||||||
nIters_ = 0;
|
|
||||||
|
|
||||||
while (!converged() && nIters_ < BpOptions::maxIter) {
|
|
||||||
|
|
||||||
nIters_ ++;
|
|
||||||
if (Constants::DEBUG >= 2) {
|
|
||||||
Util::printHeader (" Iteration " + nIters_);
|
|
||||||
cout << endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
switch (BpOptions::schedule) {
|
|
||||||
case BpOptions::Schedule::SEQ_RANDOM:
|
|
||||||
random_shuffle (links_.begin(), links_.end());
|
|
||||||
// no break
|
|
||||||
|
|
||||||
case BpOptions::Schedule::SEQ_FIXED:
|
|
||||||
for (unsigned i = 0; i < links_.size(); i++) {
|
|
||||||
calculateAndUpdateMessage (links_[i]);
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
|
|
||||||
case BpOptions::Schedule::PARALLEL:
|
|
||||||
for (unsigned i = 0; i < links_.size(); i++) {
|
|
||||||
calculateMessage (links_[i]);
|
|
||||||
}
|
|
||||||
for (unsigned i = 0; i < links_.size(); i++) {
|
|
||||||
updateMessage(links_[i]);
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
|
|
||||||
case BpOptions::Schedule::MAX_RESIDUAL:
|
|
||||||
maxResidualSchedule();
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
if (Constants::DEBUG >= 2) {
|
|
||||||
cout << endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
BpSolver::initializeSolver (void)
|
BpSolver::initializeSolver (void)
|
||||||
{
|
{
|
||||||
@ -226,40 +170,6 @@ BpSolver::createLinks (void)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
bool
|
|
||||||
BpSolver::converged (void)
|
|
||||||
{
|
|
||||||
if (links_.size() == 0) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
if (nIters_ == 0 || nIters_ == 1) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
bool converged = true;
|
|
||||||
if (BpOptions::schedule == BpOptions::Schedule::MAX_RESIDUAL) {
|
|
||||||
double maxResidual = (*(sortedOrder_.begin()))->getResidual();
|
|
||||||
if (maxResidual > BpOptions::accuracy) {
|
|
||||||
converged = false;
|
|
||||||
} else {
|
|
||||||
converged = true;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (unsigned i = 0; i < links_.size(); i++) {
|
|
||||||
double residual = links_[i]->getResidual();
|
|
||||||
if (Constants::DEBUG >= 2) {
|
|
||||||
cout << links_[i]->toString() + " residual = " << residual << endl;
|
|
||||||
}
|
|
||||||
if (residual > BpOptions::accuracy) {
|
|
||||||
converged = false;
|
|
||||||
if (Constants::DEBUG == 0) break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return converged;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
BpSolver::maxResidualSchedule (void)
|
BpSolver::maxResidualSchedule (void)
|
||||||
{
|
{
|
||||||
@ -493,3 +403,114 @@ BpSolver::printLinkInformation (void) const
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
BpSolver::runSolver (void)
|
||||||
|
{
|
||||||
|
clock_t start;
|
||||||
|
if (Constants::COLLECT_STATS) {
|
||||||
|
start = clock();
|
||||||
|
}
|
||||||
|
runLoopySolver();
|
||||||
|
if (Constants::DEBUG >= 2) {
|
||||||
|
cout << endl;
|
||||||
|
if (nIters_ < BpOptions::maxIter) {
|
||||||
|
cout << "Sum-Product converged in " ;
|
||||||
|
cout << nIters_ << " iterations" << endl;
|
||||||
|
} else {
|
||||||
|
cout << "The maximum number of iterations was hit, terminating..." ;
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
unsigned size = factorGraph_->varNodes().size();
|
||||||
|
if (Constants::COLLECT_STATS) {
|
||||||
|
unsigned nIters = 0;
|
||||||
|
bool loopy = factorGraph_->isTree() == false;
|
||||||
|
if (loopy) nIters = nIters_;
|
||||||
|
double time = (double (clock() - start)) / CLOCKS_PER_SEC;
|
||||||
|
Statistics::updateStatistics (size, loopy, nIters, time);
|
||||||
|
}
|
||||||
|
runned_ = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
BpSolver::runLoopySolver (void)
|
||||||
|
{
|
||||||
|
initializeSolver();
|
||||||
|
nIters_ = 0;
|
||||||
|
|
||||||
|
while (!converged() && nIters_ < BpOptions::maxIter) {
|
||||||
|
|
||||||
|
nIters_ ++;
|
||||||
|
if (Constants::DEBUG >= 2) {
|
||||||
|
Util::printHeader (" Iteration " + nIters_);
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
switch (BpOptions::schedule) {
|
||||||
|
case BpOptions::Schedule::SEQ_RANDOM:
|
||||||
|
random_shuffle (links_.begin(), links_.end());
|
||||||
|
// no break
|
||||||
|
|
||||||
|
case BpOptions::Schedule::SEQ_FIXED:
|
||||||
|
for (unsigned i = 0; i < links_.size(); i++) {
|
||||||
|
calculateAndUpdateMessage (links_[i]);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
|
||||||
|
case BpOptions::Schedule::PARALLEL:
|
||||||
|
for (unsigned i = 0; i < links_.size(); i++) {
|
||||||
|
calculateMessage (links_[i]);
|
||||||
|
}
|
||||||
|
for (unsigned i = 0; i < links_.size(); i++) {
|
||||||
|
updateMessage(links_[i]);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
|
||||||
|
case BpOptions::Schedule::MAX_RESIDUAL:
|
||||||
|
maxResidualSchedule();
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if (Constants::DEBUG >= 2) {
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
BpSolver::converged (void)
|
||||||
|
{
|
||||||
|
if (links_.size() == 0) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (nIters_ == 0 || nIters_ == 1) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
bool converged = true;
|
||||||
|
if (BpOptions::schedule == BpOptions::Schedule::MAX_RESIDUAL) {
|
||||||
|
double maxResidual = (*(sortedOrder_.begin()))->getResidual();
|
||||||
|
if (maxResidual > BpOptions::accuracy) {
|
||||||
|
converged = false;
|
||||||
|
} else {
|
||||||
|
converged = true;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (unsigned i = 0; i < links_.size(); i++) {
|
||||||
|
double residual = links_[i]->getResidual();
|
||||||
|
if (Constants::DEBUG >= 2) {
|
||||||
|
cout << links_[i]->toString() + " residual = " << residual << endl;
|
||||||
|
}
|
||||||
|
if (residual > BpOptions::accuracy) {
|
||||||
|
converged = false;
|
||||||
|
if (Constants::DEBUG == 0) break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return converged;
|
||||||
|
}
|
||||||
|
|
||||||
|
@ -95,7 +95,7 @@ class BpSolver : public Solver
|
|||||||
|
|
||||||
virtual ~BpSolver (void);
|
virtual ~BpSolver (void);
|
||||||
|
|
||||||
void runSolver (void);
|
Params solveQuery (VarIds);
|
||||||
|
|
||||||
virtual Params getPosterioriOf (VarId);
|
virtual Params getPosterioriOf (VarId);
|
||||||
|
|
||||||
@ -169,6 +169,7 @@ class BpSolver : public Solver
|
|||||||
unsigned nIters_;
|
unsigned nIters_;
|
||||||
vector<SPNodeInfo*> varsI_;
|
vector<SPNodeInfo*> varsI_;
|
||||||
vector<SPNodeInfo*> facsI_;
|
vector<SPNodeInfo*> facsI_;
|
||||||
|
bool runned_;
|
||||||
const FactorGraph* factorGraph_;
|
const FactorGraph* factorGraph_;
|
||||||
|
|
||||||
typedef multiset<SpLink*, CompareResidual> SortedOrder;
|
typedef multiset<SpLink*, CompareResidual> SortedOrder;
|
||||||
@ -178,6 +179,7 @@ class BpSolver : public Solver
|
|||||||
SpLinkMap linkMap_;
|
SpLinkMap linkMap_;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
void runSolver (void);
|
||||||
void runLoopySolver (void);
|
void runLoopySolver (void);
|
||||||
bool converged (void);
|
bool converged (void);
|
||||||
};
|
};
|
||||||
|
@ -85,9 +85,8 @@ CbpSolver::initializeSolver (void)
|
|||||||
if (Constants::COLLECT_STATS) {
|
if (Constants::COLLECT_STATS) {
|
||||||
unsigned nClusterVars = factorGraph_->varNodes().size();
|
unsigned nClusterVars = factorGraph_->varNodes().size();
|
||||||
unsigned nClusterFacs = factorGraph_->facNodes().size();
|
unsigned nClusterFacs = factorGraph_->facNodes().size();
|
||||||
Statistics::updateCompressingStatistics (nGroundVars, nGroundFacs,
|
Statistics::updateCompressingStatistics (nGroundVars,
|
||||||
nClusterVars, nClusterFacs,
|
nGroundFacs, nClusterVars, nClusterFacs, nWithoutNeighs);
|
||||||
nWithoutNeighs);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// cout << "Compressed Factor Graph" << endl;
|
// cout << "Compressed Factor Graph" << endl;
|
||||||
|
@ -37,7 +37,7 @@ class CbpSolverLink : public SpLink
|
|||||||
class CbpSolver : public BpSolver
|
class CbpSolver : public BpSolver
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
CbpSolver (FactorGraph& fg) : BpSolver (fg) { }
|
CbpSolver (const FactorGraph& fg) : BpSolver (fg) { }
|
||||||
|
|
||||||
~CbpSolver (void);
|
~CbpSolver (void);
|
||||||
|
|
||||||
@ -47,13 +47,17 @@ class CbpSolver : public BpSolver
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
void initializeSolver (void);
|
void initializeSolver (void);
|
||||||
|
|
||||||
void createLinks (void);
|
void createLinks (void);
|
||||||
|
|
||||||
void maxResidualSchedule (void);
|
void maxResidualSchedule (void);
|
||||||
|
|
||||||
Params getVar2FactorMsg (const SpLink*) const;
|
Params getVar2FactorMsg (const SpLink*) const;
|
||||||
|
|
||||||
void printLinkInformation (void) const;
|
void printLinkInformation (void) const;
|
||||||
|
|
||||||
CFactorGraph* lfg_;
|
CFactorGraph* lfg_;
|
||||||
|
FactorGraph* factorGraph_;
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif // HORUS_CBP_H
|
#endif // HORUS_CBP_H
|
||||||
|
@ -34,7 +34,6 @@ ElimGraph::ElimGraph (const vector<Factor*>& factors)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
setIndexes();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -148,6 +147,7 @@ void
|
|||||||
ElimGraph::addNode (EgNode* n)
|
ElimGraph::addNode (EgNode* n)
|
||||||
{
|
{
|
||||||
nodes_.push_back (n);
|
nodes_.push_back (n);
|
||||||
|
n->setIndex (nodes_.size() - 1);
|
||||||
varMap_.insert (make_pair (n->varId(), n));
|
varMap_.insert (make_pair (n->varId(), n));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -301,13 +301,3 @@ ElimGraph::neighbors (const EgNode* n1, const EgNode* n2) const
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
ElimGraph::setIndexes (void)
|
|
||||||
{
|
|
||||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
|
||||||
nodes_[i]->setIndex (i);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
@ -78,8 +78,6 @@ class ElimGraph
|
|||||||
|
|
||||||
bool neighbors (const EgNode*, const EgNode*) const;
|
bool neighbors (const EgNode*, const EgNode*) const;
|
||||||
|
|
||||||
void setIndexes (void);
|
|
||||||
|
|
||||||
vector<EgNode*> nodes_;
|
vector<EgNode*> nodes_;
|
||||||
vector<bool> marked_;
|
vector<bool> marked_;
|
||||||
unordered_map<VarId, EgNode*> varMap_;
|
unordered_map<VarId, EgNode*> varMap_;
|
||||||
|
@ -241,7 +241,7 @@ Factor::print (void) const
|
|||||||
for (unsigned i = 0; i < args_.size(); i++) {
|
for (unsigned i = 0; i < args_.size(); i++) {
|
||||||
vars.push_back (new Var (args_[i], ranges_[i]));
|
vars.push_back (new Var (args_[i], ranges_[i]));
|
||||||
}
|
}
|
||||||
vector<string> jointStrings = Util::getJointStateStrings (vars);
|
vector<string> jointStrings = Util::getStateLines (vars);
|
||||||
for (unsigned i = 0; i < params_.size(); i++) {
|
for (unsigned i = 0; i < params_.size(); i++) {
|
||||||
cout << "f(" << jointStrings[i] << ")" ;
|
cout << "f(" << jointStrings[i] << ")" ;
|
||||||
cout << " = " << params_[i] << endl;
|
cout << " = " << params_[i] << endl;
|
||||||
|
@ -11,7 +11,7 @@
|
|||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
void processArguments (FactorGraph&, int, const char* []);
|
void processArguments (FactorGraph&, int, const char* []);
|
||||||
void runSolver (Solver*, const VarIds&);
|
void runSolver (const FactorGraph&, const VarIds&);
|
||||||
|
|
||||||
const string USAGE = "usage: \
|
const string USAGE = "usage: \
|
||||||
./hcli FILE [VARIABLE | OBSERVED_VARIABLE=EVIDENCE]..." ;
|
./hcli FILE [VARIABLE | OBSERVED_VARIABLE=EVIDENCE]..." ;
|
||||||
@ -25,8 +25,8 @@ main (int argc, const char* argv[])
|
|||||||
cerr << USAGE << endl;
|
cerr << USAGE << endl;
|
||||||
exit (0);
|
exit (0);
|
||||||
}
|
}
|
||||||
const string& fileName = argv[1];
|
string fileName = argv[1];
|
||||||
const string& extension = fileName.substr (
|
string extension = fileName.substr (
|
||||||
fileName.find_last_of ('.') + 1);
|
fileName.find_last_of ('.') + 1);
|
||||||
FactorGraph fg;
|
FactorGraph fg;
|
||||||
if (extension == "uai") {
|
if (extension == "uai") {
|
||||||
@ -38,8 +38,6 @@ main (int argc, const char* argv[])
|
|||||||
cerr << "in a UAI or libDAI file" << endl;
|
cerr << "in a UAI or libDAI file" << endl;
|
||||||
exit (0);
|
exit (0);
|
||||||
}
|
}
|
||||||
fg.print();
|
|
||||||
assert (false);
|
|
||||||
processArguments (fg, argc, argv);
|
processArguments (fg, argc, argv);
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
@ -123,6 +121,14 @@ processArguments (FactorGraph& fg, int argc, const char* argv[])
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
runSolver (fg, queryIds);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
runSolver (const FactorGraph& fg, const VarIds& queryIds)
|
||||||
|
{
|
||||||
Solver* solver = 0;
|
Solver* solver = 0;
|
||||||
switch (Globals::infAlgorithm) {
|
switch (Globals::infAlgorithm) {
|
||||||
case InfAlgorithms::VE:
|
case InfAlgorithms::VE:
|
||||||
@ -137,23 +143,10 @@ processArguments (FactorGraph& fg, int argc, const char* argv[])
|
|||||||
default:
|
default:
|
||||||
assert (false);
|
assert (false);
|
||||||
}
|
}
|
||||||
runSolver (solver, queryIds);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
runSolver (Solver* solver, const VarIds& queryIds)
|
|
||||||
{
|
|
||||||
if (queryIds.size() == 0) {
|
if (queryIds.size() == 0) {
|
||||||
solver->runSolver();
|
|
||||||
solver->printAllPosterioris();
|
solver->printAllPosterioris();
|
||||||
} else if (queryIds.size() == 1) {
|
|
||||||
solver->runSolver();
|
|
||||||
solver->printPosterioriOf (queryIds[0]);
|
|
||||||
} else {
|
} else {
|
||||||
solver->runSolver();
|
solver->printAnswer (queryIds);
|
||||||
solver->printJointDistributionOf (queryIds);
|
|
||||||
}
|
}
|
||||||
delete solver;
|
delete solver;
|
||||||
}
|
}
|
||||||
|
@ -379,11 +379,7 @@ void runVeSolver (
|
|||||||
mfg = BayesBall::getMinimalFactorGraph (*fg, tasks[i]);
|
mfg = BayesBall::getMinimalFactorGraph (*fg, tasks[i]);
|
||||||
}
|
}
|
||||||
VarElimSolver solver (*mfg);
|
VarElimSolver solver (*mfg);
|
||||||
if (tasks[i].size() == 1) {
|
results.push_back (solver.solveQuery (tasks[i]));
|
||||||
results.push_back (solver.getPosterioriOf (tasks[i][0]));
|
|
||||||
} else {
|
|
||||||
results.push_back (solver.getJointDistributionOf (tasks[i]));
|
|
||||||
}
|
|
||||||
if (fg->isFromBayesNetwork()) {
|
if (fg->isFromBayesNetwork()) {
|
||||||
delete mfg;
|
delete mfg;
|
||||||
}
|
}
|
||||||
@ -416,14 +412,9 @@ void runBpSolver (
|
|||||||
cerr << "error: unknow solver" << endl;
|
cerr << "error: unknow solver" << endl;
|
||||||
abort();
|
abort();
|
||||||
}
|
}
|
||||||
solver->runSolver();
|
|
||||||
results.reserve (tasks.size());
|
results.reserve (tasks.size());
|
||||||
for (unsigned i = 0; i < tasks.size(); i++) {
|
for (unsigned i = 0; i < tasks.size(); i++) {
|
||||||
if (tasks[i].size() == 1) {
|
results.push_back (solver->solveQuery (tasks[i]));
|
||||||
results.push_back (solver->getPosterioriOf (tasks[i][0]));
|
|
||||||
} else {
|
|
||||||
results.push_back (solver->getJointDistributionOf (tasks[i]));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if (fg->isFromBayesNetwork()) {
|
if (fg->isFromBayesNetwork()) {
|
||||||
delete mfg;
|
delete mfg;
|
||||||
|
@ -2,52 +2,36 @@
|
|||||||
#include "Util.h"
|
#include "Util.h"
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
Solver::printAnswer (const VarIds& vids)
|
||||||
|
{
|
||||||
|
Vars unobservedVars;
|
||||||
|
VarIds unobservedVids;
|
||||||
|
for (unsigned i = 0; i < vids.size(); i++) {
|
||||||
|
VarNode* vn = fg_.getVarNode (vids[i]);
|
||||||
|
if (vn->hasEvidence() == false) {
|
||||||
|
unobservedVars.push_back (vn);
|
||||||
|
unobservedVids.push_back (vids[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Params res = solveQuery (unobservedVids);
|
||||||
|
vector<string> stateLines = Util::getStateLines (unobservedVars);
|
||||||
|
for (unsigned i = 0; i < res.size(); i++) {
|
||||||
|
cout << "P(" << stateLines[i] << ") = " ;
|
||||||
|
cout << std::setprecision (Constants::PRECISION) << res[i];
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
Solver::printAllPosterioris (void)
|
Solver::printAllPosterioris (void)
|
||||||
{
|
{
|
||||||
const VarNodes& vars = fg_.varNodes();
|
const VarNodes& vars = fg_.varNodes();
|
||||||
for (unsigned i = 0; i < vars.size(); i++) {
|
for (unsigned i = 0; i < vars.size(); i++) {
|
||||||
printPosterioriOf (vars[i]->varId());
|
printAnswer ({vars[i]->varId()});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
Solver::printPosterioriOf (VarId vid)
|
|
||||||
{
|
|
||||||
VarNode* vn = fg_.getVarNode (vid);
|
|
||||||
const Params& posterioriDist = getPosterioriOf (vid);
|
|
||||||
const States& states = vn->states();
|
|
||||||
for (unsigned i = 0; i < states.size(); i++) {
|
|
||||||
cout << "P(" << vn->label() << "=" << states[i] << ") = " ;
|
|
||||||
cout << setprecision (Constants::PRECISION) << posterioriDist[i];
|
|
||||||
cout << endl;
|
|
||||||
}
|
|
||||||
cout << endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
Solver::printJointDistributionOf (const VarIds& vids)
|
|
||||||
{
|
|
||||||
Vars vars;
|
|
||||||
VarIds vidsWithoutEvidence;
|
|
||||||
for (unsigned i = 0; i < vids.size(); i++) {
|
|
||||||
VarNode* vn = fg_.getVarNode (vids[i]);
|
|
||||||
if (vn->hasEvidence() == false) {
|
|
||||||
vars.push_back (vn);
|
|
||||||
vidsWithoutEvidence.push_back (vids[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
const Params& jointDist = getJointDistributionOf (vidsWithoutEvidence);
|
|
||||||
vector<string> jointStrings = Util::getJointStateStrings (vars);
|
|
||||||
for (unsigned i = 0; i < jointDist.size(); i++) {
|
|
||||||
cout << "P(" << jointStrings[i] << ") = " ;
|
|
||||||
cout << setprecision (Constants::PRECISION) << jointDist[i];
|
|
||||||
cout << endl;
|
|
||||||
}
|
|
||||||
cout << endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
@ -16,19 +16,13 @@ class Solver
|
|||||||
|
|
||||||
virtual ~Solver() { } // ensure that subclass destructor is called
|
virtual ~Solver() { } // ensure that subclass destructor is called
|
||||||
|
|
||||||
virtual void runSolver (void) = 0;
|
virtual Params solveQuery (VarIds queryVids) = 0;
|
||||||
|
|
||||||
virtual Params getPosterioriOf (VarId) = 0;
|
void printAnswer (const VarIds& vids);
|
||||||
|
|
||||||
virtual Params getJointDistributionOf (const VarIds&) = 0;
|
|
||||||
|
|
||||||
void printAllPosterioris (void);
|
void printAllPosterioris (void);
|
||||||
|
|
||||||
void printPosterioriOf (VarId vid);
|
protected:
|
||||||
|
|
||||||
void printJointDistributionOf (const VarIds& vids);
|
|
||||||
|
|
||||||
private:
|
|
||||||
const FactorGraph& fg_;
|
const FactorGraph& fg_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -137,7 +137,7 @@ parametersToString (const Params& v, unsigned precision)
|
|||||||
|
|
||||||
|
|
||||||
vector<string>
|
vector<string>
|
||||||
getJointStateStrings (const Vars& vars)
|
getStateLines (const Vars& vars)
|
||||||
{
|
{
|
||||||
StatesIndexer idx (vars);
|
StatesIndexer idx (vars);
|
||||||
vector<string> jointStrings;
|
vector<string> jointStrings;
|
||||||
|
@ -62,7 +62,7 @@ bool isInteger (const string&);
|
|||||||
|
|
||||||
string parametersToString (const Params&, unsigned = Constants::PRECISION);
|
string parametersToString (const Params&, unsigned = Constants::PRECISION);
|
||||||
|
|
||||||
vector<string> getJointStateStrings (const Vars&);
|
vector<string> getStateLines (const Vars&);
|
||||||
|
|
||||||
void printHeader (string, std::ostream& os = std::cout);
|
void printHeader (string, std::ostream& os = std::cout);
|
||||||
|
|
||||||
|
@ -6,13 +6,6 @@
|
|||||||
#include "Util.h"
|
#include "Util.h"
|
||||||
|
|
||||||
|
|
||||||
VarElimSolver::VarElimSolver (const FactorGraph& fg) : Solver (fg)
|
|
||||||
{
|
|
||||||
factorGraph_ = &fg;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
VarElimSolver::~VarElimSolver (void)
|
VarElimSolver::~VarElimSolver (void)
|
||||||
{
|
{
|
||||||
delete factorList_.back();
|
delete factorList_.back();
|
||||||
@ -21,30 +14,15 @@ VarElimSolver::~VarElimSolver (void)
|
|||||||
|
|
||||||
|
|
||||||
Params
|
Params
|
||||||
VarElimSolver::getPosterioriOf (VarId vid)
|
VarElimSolver::solveQuery (VarIds queryVids)
|
||||||
{
|
|
||||||
assert (factorGraph_->getVarNode (vid));
|
|
||||||
VarNode* vn = factorGraph_->getVarNode (vid);
|
|
||||||
if (vn->hasEvidence()) {
|
|
||||||
Params params (vn->range(), 0.0);
|
|
||||||
params[vn->getEvidence()] = 1.0;
|
|
||||||
return params;
|
|
||||||
}
|
|
||||||
return getJointDistributionOf (VarIds() = {vid});
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Params
|
|
||||||
VarElimSolver::getJointDistributionOf (const VarIds& vids)
|
|
||||||
{
|
{
|
||||||
factorList_.clear();
|
factorList_.clear();
|
||||||
varFactors_.clear();
|
varFactors_.clear();
|
||||||
elimOrder_.clear();
|
elimOrder_.clear();
|
||||||
createFactorList();
|
createFactorList();
|
||||||
absorveEvidence();
|
absorveEvidence();
|
||||||
findEliminationOrder (vids);
|
findEliminationOrder (queryVids);
|
||||||
processFactorList (vids);
|
processFactorList (queryVids);
|
||||||
Params params = factorList_.back()->params();
|
Params params = factorList_.back()->params();
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
Util::fromLog (params);
|
Util::fromLog (params);
|
||||||
@ -57,10 +35,10 @@ VarElimSolver::getJointDistributionOf (const VarIds& vids)
|
|||||||
void
|
void
|
||||||
VarElimSolver::createFactorList (void)
|
VarElimSolver::createFactorList (void)
|
||||||
{
|
{
|
||||||
const FacNodes& facNodes = factorGraph_->facNodes();
|
const FacNodes& facNodes = fg_.facNodes();
|
||||||
factorList_.reserve (facNodes.size() * 2);
|
factorList_.reserve (facNodes.size() * 2);
|
||||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||||
factorList_.push_back (new Factor (facNodes[i]->factor())); // FIXME
|
factorList_.push_back (new Factor (facNodes[i]->factor()));
|
||||||
const VarNodes& neighs = facNodes[i]->neighbors();
|
const VarNodes& neighs = facNodes[i]->neighbors();
|
||||||
for (unsigned j = 0; j < neighs.size(); j++) {
|
for (unsigned j = 0; j < neighs.size(); j++) {
|
||||||
unordered_map<VarId,vector<unsigned> >::iterator it
|
unordered_map<VarId,vector<unsigned> >::iterator it
|
||||||
@ -79,7 +57,7 @@ VarElimSolver::createFactorList (void)
|
|||||||
void
|
void
|
||||||
VarElimSolver::absorveEvidence (void)
|
VarElimSolver::absorveEvidence (void)
|
||||||
{
|
{
|
||||||
const VarNodes& varNodes = factorGraph_->varNodes();
|
const VarNodes& varNodes = fg_.varNodes();
|
||||||
for (unsigned i = 0; i < varNodes.size(); i++) {
|
for (unsigned i = 0; i < varNodes.size(); i++) {
|
||||||
if (varNodes[i]->hasEvidence()) {
|
if (varNodes[i]->hasEvidence()) {
|
||||||
const vector<unsigned>& idxs =
|
const vector<unsigned>& idxs =
|
||||||
@ -125,7 +103,7 @@ VarElimSolver::processFactorList (const VarIds& vids)
|
|||||||
|
|
||||||
VarIds unobservedVids;
|
VarIds unobservedVids;
|
||||||
for (unsigned i = 0; i < vids.size(); i++) {
|
for (unsigned i = 0; i < vids.size(); i++) {
|
||||||
if (factorGraph_->getVarNode (vids[i])->hasEvidence() == false) {
|
if (fg_.getVarNode (vids[i])->hasEvidence() == false) {
|
||||||
unobservedVids.push_back (vids[i]);
|
unobservedVids.push_back (vids[i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -146,7 +124,7 @@ VarElimSolver::eliminate (VarId elimVar)
|
|||||||
unsigned idx = idxs[i];
|
unsigned idx = idxs[i];
|
||||||
if (factorList_[idx]) {
|
if (factorList_[idx]) {
|
||||||
if (result == 0) {
|
if (result == 0) {
|
||||||
result = new Factor(*factorList_[idx]);
|
result = new Factor (*factorList_[idx]);
|
||||||
} else {
|
} else {
|
||||||
result->multiply (*factorList_[idx]);
|
result->multiply (*factorList_[idx]);
|
||||||
}
|
}
|
||||||
|
@ -14,15 +14,11 @@ using namespace std;
|
|||||||
class VarElimSolver : public Solver
|
class VarElimSolver : public Solver
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
VarElimSolver (const FactorGraph&);
|
VarElimSolver (const FactorGraph& fg) : Solver (fg) { }
|
||||||
|
|
||||||
~VarElimSolver (void);
|
~VarElimSolver (void);
|
||||||
|
|
||||||
void runSolver (void) { }
|
Params solveQuery (VarIds);
|
||||||
|
|
||||||
Params getPosterioriOf (VarId);
|
|
||||||
|
|
||||||
Params getJointDistributionOf (const VarIds&);
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void createFactorList (void);
|
void createFactorList (void);
|
||||||
@ -37,7 +33,6 @@ class VarElimSolver : public Solver
|
|||||||
|
|
||||||
void printActiveFactors (void);
|
void printActiveFactors (void);
|
||||||
|
|
||||||
const FactorGraph* factorGraph_;
|
|
||||||
vector<Factor*> factorList_;
|
vector<Factor*> factorList_;
|
||||||
VarIds elimOrder_;
|
VarIds elimOrder_;
|
||||||
unordered_map<VarId, vector<unsigned>> varFactors_;
|
unordered_map<VarId, vector<unsigned>> varFactors_;
|
||||||
|
Reference in New Issue
Block a user