2013-02-07 20:09:10 +00:00
|
|
|
#include <algorithm>
#include <cassert>
#include <iomanip>
#include <iostream>
#include <memory>
#include <string>
#include <vector>
|
|
|
|
|
2012-11-14 23:17:39 +00:00
|
|
|
#include "GroundSolver.h"
|
2012-12-27 22:28:19 +00:00
|
|
|
#include "VarElim.h"
|
2012-11-14 23:17:39 +00:00
|
|
|
#include "BeliefProp.h"
|
|
|
|
#include "CountingBp.h"
|
2012-12-27 12:54:58 +00:00
|
|
|
#include "Util.h"
|
2012-11-14 23:17:39 +00:00
|
|
|
|
|
|
|
|
2013-02-07 23:53:13 +00:00
|
|
|
namespace horus {
|
|
|
|
|
2012-11-14 23:17:39 +00:00
|
|
|
void
|
|
|
|
GroundSolver::printAnswer (const VarIds& vids)
|
|
|
|
{
|
|
|
|
Vars unobservedVars;
|
|
|
|
VarIds unobservedVids;
|
|
|
|
for (size_t i = 0; i < vids.size(); i++) {
|
|
|
|
VarNode* vn = fg.getVarNode (vids[i]);
|
|
|
|
if (vn->hasEvidence() == false) {
|
|
|
|
unobservedVars.push_back (vn);
|
|
|
|
unobservedVids.push_back (vids[i]);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if (unobservedVids.empty() == false) {
|
|
|
|
Params res = solveQuery (unobservedVids);
|
2013-02-07 13:37:15 +00:00
|
|
|
std::vector<std::string> stateLines =
|
2013-02-08 00:15:41 +00:00
|
|
|
util::getStateLines (unobservedVars);
|
2012-11-14 23:17:39 +00:00
|
|
|
for (size_t i = 0; i < res.size(); i++) {
|
2013-02-07 13:37:15 +00:00
|
|
|
std::cout << "P(" << stateLines[i] << ") = " ;
|
2013-02-08 00:15:41 +00:00
|
|
|
std::cout << std::setprecision (constants::PRECISION) << res[i];
|
2013-02-07 13:37:15 +00:00
|
|
|
std::cout << std::endl;
|
2012-11-14 23:17:39 +00:00
|
|
|
}
|
2013-02-07 13:37:15 +00:00
|
|
|
std::cout << std::endl;
|
2012-11-14 23:17:39 +00:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void
|
|
|
|
GroundSolver::printAllPosterioris (void)
|
|
|
|
{
|
|
|
|
VarNodes vars = fg.varNodes();
|
|
|
|
std::sort (vars.begin(), vars.end(), sortByVarId());
|
|
|
|
for (size_t i = 0; i < vars.size(); i++) {
|
|
|
|
printAnswer ({vars[i]->varId()});
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Computes the joint distribution of `jointVarIds` by the chain rule:
 *
 *   P(x1, ..., xn) = P(x1) * P(x2 | x1) * ... * P(xn | x1, ..., xn-1)
 *
 * P(x1) comes from one solver run. Each later conditional is obtained by
 * setting evidence on the already-processed variables, for every joint
 * state of them enumerated by an Indexer, and running a fresh solver of
 * type `solverType` on the conditioned graph.
 *
 * @param solverType  which ground solver to instantiate (BP, CBP or VE).
 * @param fg          factor graph, taken by value. Evidence is set on the
 *                    nodes of this copy, presumably so it does not leak
 *                    into the caller's graph (assumes FactorGraph's copy
 *                    constructor deep-copies its nodes — TODO confirm).
 * @param jointVarIds ids of the query variables. Every id must exist in
 *                    `fg`; all but the first must be unobserved (asserted).
 * @return one probability per joint state. Within each group, the state
 *         of the last variable varies fastest; the order of the groups
 *         follows the Indexer enumeration over the preceding variables.
 *         The result is empty if `jointVarIds` is empty or if
 *         `solverType` matches no known solver.
 */
Params
GroundSolver::getJointByConditioning (
    GroundSolverType solverType,
    FactorGraph fg,
    const VarIds& jointVarIds)
{
  if (jointVarIds.empty()) {
    // Nothing to compute; also keeps jointVarIds[0] below in bounds.
    return Params();
  }

  VarNodes jointVars;
  for (size_t i = 0; i < jointVarIds.size(); i++) {
    VarNode* vn = fg.getVarNode (jointVarIds[i]);
    assert (vn);
    jointVars.push_back (vn);
  }

  // Builds a fresh solver over the current state of `fg`, including any
  // evidence set so far. Returns null for an unknown solver type.
  auto makeSolver = [&] () -> std::unique_ptr<GroundSolver> {
    std::unique_ptr<GroundSolver> s;
    switch (solverType) {
      case GroundSolverType::BP:  s.reset (new BeliefProp (fg)); break;
      case GroundSolverType::CBP: s.reset (new CountingBp (fg)); break;
      case GroundSolverType::VE:  s.reset (new VarElim (fg));    break;
    }
    assert (s && "unknown GroundSolverType");
    return s;
  };

  Params prevBeliefs;
  {
    std::unique_ptr<GroundSolver> solver = makeSolver();
    if (!solver) {
      return Params();
    }
    prevBeliefs = solver->solveQuery ({jointVarIds[0]});
  }

  // Variables already folded into prevBeliefs (jointVars[0 .. i-1] at
  // iteration i), together with their ranges for the Indexer.
  Vars observedVars;
  Ranges observedRanges;
  for (size_t i = 1; i < jointVars.size(); i++) {
    observedVars.push_back (jointVars[i - 1]);
    observedRanges.push_back (jointVars[i - 1]->range());

    assert (jointVars[i]->hasEvidence() == false);

    // P(x_i | observed state) for every joint state of the observed
    // variables, appended one block of range(x_i) entries per state.
    Params newBeliefs;
    Indexer indexer (observedRanges, false);
    while (indexer.valid()) {
      for (size_t j = 0; j < observedVars.size(); j++) {
        observedVars[j]->setEvidence (indexer[j]);
      }
      // Same solverType as the first (non-null) solver, so this is
      // non-null too. A new solver is needed because the evidence changed.
      std::unique_ptr<GroundSolver> solver = makeSolver();
      Params beliefs = solver->solveQuery ({jointVarIds[i]});
      newBeliefs.insert (newBeliefs.end(), beliefs.begin(), beliefs.end());
      ++ indexer;
    }

    // Chain rule: scale each block by the probability of the observed
    // state that produced it (block index == j / nStates).
    const size_t nStates = jointVars[i]->range();
    for (size_t j = 0; j < newBeliefs.size(); j++) {
      newBeliefs[j] *= prevBeliefs[j / nStates];
    }
    prevBeliefs.swap (newBeliefs);
  }
  return prevBeliefs;
}
|
|
|
|
|
2013-02-07 23:53:13 +00:00
|
|
|
} // namespace horus
|
|
|
|
|