minor improvements
This commit is contained in:
parent
0cee50496e
commit
004e6c0c5f
@ -39,11 +39,9 @@ Params
|
|||||||
BpSolver::solveQuery (VarIds queryVids)
|
BpSolver::solveQuery (VarIds queryVids)
|
||||||
{
|
{
|
||||||
assert (queryVids.empty() == false);
|
assert (queryVids.empty() == false);
|
||||||
if (queryVids.size() == 1) {
|
return queryVids.size() == 1
|
||||||
return getPosterioriOf (queryVids[0]);
|
? getPosterioriOf (queryVids[0])
|
||||||
} else {
|
: getJointDistributionOf (queryVids);
|
||||||
return getJointDistributionOf (queryVids);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -61,8 +59,8 @@ BpSolver::printSolverFlags (void) const
|
|||||||
case Sch::PARALLEL: ss << "parallel"; break;
|
case Sch::PARALLEL: ss << "parallel"; break;
|
||||||
case Sch::MAX_RESIDUAL: ss << "max_residual"; break;
|
case Sch::MAX_RESIDUAL: ss << "max_residual"; break;
|
||||||
}
|
}
|
||||||
ss << ",max_iter=" << Util::toString (BpOptions::maxIter);
|
ss << ",max_iter=" << Util::toString (BpOptions::maxIter);
|
||||||
ss << ",accuracy=" << Util::toString (BpOptions::accuracy);
|
ss << ",accuracy=" << Util::toString (BpOptions::accuracy);
|
||||||
ss << ",log_domain=" << Util::toString (Globals::logDomain);
|
ss << ",log_domain=" << Util::toString (Globals::logDomain);
|
||||||
ss << "]" ;
|
ss << "]" ;
|
||||||
cout << ss.str() << endl;
|
cout << ss.str() << endl;
|
||||||
|
@ -178,7 +178,7 @@ runSolver (const FactorGraph& fg, const VarIds& queryIds)
|
|||||||
solver->printSolverFlags();
|
solver->printSolverFlags();
|
||||||
cout << endl;
|
cout << endl;
|
||||||
}
|
}
|
||||||
if (queryIds.size() == 0) {
|
if (queryIds.empty()) {
|
||||||
solver->printAllPosterioris();
|
solver->printAllPosterioris();
|
||||||
} else {
|
} else {
|
||||||
solver->printAnswer (queryIds);
|
solver->printAnswer (queryIds);
|
||||||
|
@ -14,14 +14,16 @@ Solver::printAnswer (const VarIds& vids)
|
|||||||
unobservedVids.push_back (vids[i]);
|
unobservedVids.push_back (vids[i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Params res = solveQuery (unobservedVids);
|
if (unobservedVids.empty() == false) {
|
||||||
vector<string> stateLines = Util::getStateLines (unobservedVars);
|
Params res = solveQuery (unobservedVids);
|
||||||
for (size_t i = 0; i < res.size(); i++) {
|
vector<string> stateLines = Util::getStateLines (unobservedVars);
|
||||||
cout << "P(" << stateLines[i] << ") = " ;
|
for (size_t i = 0; i < res.size(); i++) {
|
||||||
cout << std::setprecision (Constants::PRECISION) << res[i];
|
cout << "P(" << stateLines[i] << ") = " ;
|
||||||
|
cout << std::setprecision (Constants::PRECISION) << res[i];
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
cout << endl;
|
cout << endl;
|
||||||
}
|
}
|
||||||
cout << endl;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -29,14 +31,10 @@ Solver::printAnswer (const VarIds& vids)
|
|||||||
void
|
void
|
||||||
Solver::printAllPosterioris (void)
|
Solver::printAllPosterioris (void)
|
||||||
{
|
{
|
||||||
VarIds vids;
|
VarNodes vars = fg.varNodes();
|
||||||
const VarNodes& vars = fg.varNodes();
|
std::sort (vars.begin(), vars.end(), sortByVarId());
|
||||||
for (size_t i = 0; i < vars.size(); i++) {
|
for (size_t i = 0; i < vars.size(); i++) {
|
||||||
vids.push_back (vars[i]->varId());
|
printAnswer ({vars[i]->varId()});
|
||||||
}
|
|
||||||
std::sort (vids.begin(), vids.end());
|
|
||||||
for (size_t i = 0; i < vids.size(); i++) {
|
|
||||||
printAnswer ({vids[i]});
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user