minor improvements

This commit is contained in:
Tiago Gomes 2012-05-28 18:26:15 +01:00
parent 0cee50496e
commit 004e6c0c5f
3 changed files with 17 additions and 21 deletions

View File

@ -39,11 +39,9 @@ Params
BpSolver::solveQuery (VarIds queryVids) BpSolver::solveQuery (VarIds queryVids)
{ {
assert (queryVids.empty() == false); assert (queryVids.empty() == false);
if (queryVids.size() == 1) { return queryVids.size() == 1
return getPosterioriOf (queryVids[0]); ? getPosterioriOf (queryVids[0])
} else { : getJointDistributionOf (queryVids);
return getJointDistributionOf (queryVids);
}
} }
@ -61,8 +59,8 @@ BpSolver::printSolverFlags (void) const
case Sch::PARALLEL: ss << "parallel"; break; case Sch::PARALLEL: ss << "parallel"; break;
case Sch::MAX_RESIDUAL: ss << "max_residual"; break; case Sch::MAX_RESIDUAL: ss << "max_residual"; break;
} }
ss << ",max_iter=" << Util::toString (BpOptions::maxIter); ss << ",max_iter=" << Util::toString (BpOptions::maxIter);
ss << ",accuracy=" << Util::toString (BpOptions::accuracy); ss << ",accuracy=" << Util::toString (BpOptions::accuracy);
ss << ",log_domain=" << Util::toString (Globals::logDomain); ss << ",log_domain=" << Util::toString (Globals::logDomain);
ss << "]" ; ss << "]" ;
cout << ss.str() << endl; cout << ss.str() << endl;

View File

@ -178,7 +178,7 @@ runSolver (const FactorGraph& fg, const VarIds& queryIds)
solver->printSolverFlags(); solver->printSolverFlags();
cout << endl; cout << endl;
} }
if (queryIds.size() == 0) { if (queryIds.empty()) {
solver->printAllPosterioris(); solver->printAllPosterioris();
} else { } else {
solver->printAnswer (queryIds); solver->printAnswer (queryIds);

View File

@ -14,14 +14,16 @@ Solver::printAnswer (const VarIds& vids)
unobservedVids.push_back (vids[i]); unobservedVids.push_back (vids[i]);
} }
} }
Params res = solveQuery (unobservedVids); if (unobservedVids.empty() == false) {
vector<string> stateLines = Util::getStateLines (unobservedVars); Params res = solveQuery (unobservedVids);
for (size_t i = 0; i < res.size(); i++) { vector<string> stateLines = Util::getStateLines (unobservedVars);
cout << "P(" << stateLines[i] << ") = " ; for (size_t i = 0; i < res.size(); i++) {
cout << std::setprecision (Constants::PRECISION) << res[i]; cout << "P(" << stateLines[i] << ") = " ;
cout << std::setprecision (Constants::PRECISION) << res[i];
cout << endl;
}
cout << endl; cout << endl;
} }
cout << endl;
} }
@ -29,14 +31,10 @@ Solver::printAnswer (const VarIds& vids)
void void
Solver::printAllPosterioris (void) Solver::printAllPosterioris (void)
{ {
VarIds vids; VarNodes vars = fg.varNodes();
const VarNodes& vars = fg.varNodes(); std::sort (vars.begin(), vars.end(), sortByVarId());
for (size_t i = 0; i < vars.size(); i++) { for (size_t i = 0; i < vars.size(); i++) {
vids.push_back (vars[i]->varId()); printAnswer ({vars[i]->varId()});
}
std::sort (vids.begin(), vids.end());
for (size_t i = 0; i < vids.size(); i++) {
printAnswer ({vids[i]});
} }
} }