From 004e6c0c5fc5119fa03063f43abc41377bd88d7a Mon Sep 17 00:00:00 2001 From: Tiago Gomes Date: Mon, 28 May 2012 18:26:15 +0100 Subject: [PATCH] minor improvements --- packages/CLPBN/horus/BpSolver.cpp | 12 +++++------- packages/CLPBN/horus/HorusCli.cpp | 2 +- packages/CLPBN/horus/Solver.cpp | 24 +++++++++++------------- 3 files changed, 17 insertions(+), 21 deletions(-) diff --git a/packages/CLPBN/horus/BpSolver.cpp b/packages/CLPBN/horus/BpSolver.cpp index 4bba67823..944fa16c0 100644 --- a/packages/CLPBN/horus/BpSolver.cpp +++ b/packages/CLPBN/horus/BpSolver.cpp @@ -39,11 +39,9 @@ Params BpSolver::solveQuery (VarIds queryVids) { assert (queryVids.empty() == false); - if (queryVids.size() == 1) { - return getPosterioriOf (queryVids[0]); - } else { - return getJointDistributionOf (queryVids); - } + return queryVids.size() == 1 + ? getPosterioriOf (queryVids[0]) + : getJointDistributionOf (queryVids); } @@ -61,8 +59,8 @@ BpSolver::printSolverFlags (void) const case Sch::PARALLEL: ss << "parallel"; break; case Sch::MAX_RESIDUAL: ss << "max_residual"; break; } - ss << ",max_iter=" << Util::toString (BpOptions::maxIter); - ss << ",accuracy=" << Util::toString (BpOptions::accuracy); + ss << ",max_iter=" << Util::toString (BpOptions::maxIter); + ss << ",accuracy=" << Util::toString (BpOptions::accuracy); ss << ",log_domain=" << Util::toString (Globals::logDomain); ss << "]" ; cout << ss.str() << endl; diff --git a/packages/CLPBN/horus/HorusCli.cpp b/packages/CLPBN/horus/HorusCli.cpp index 00bbeefbd..7766d65f0 100644 --- a/packages/CLPBN/horus/HorusCli.cpp +++ b/packages/CLPBN/horus/HorusCli.cpp @@ -178,7 +178,7 @@ runSolver (const FactorGraph& fg, const VarIds& queryIds) solver->printSolverFlags(); cout << endl; } - if (queryIds.size() == 0) { + if (queryIds.empty()) { solver->printAllPosterioris(); } else { solver->printAnswer (queryIds); diff --git a/packages/CLPBN/horus/Solver.cpp b/packages/CLPBN/horus/Solver.cpp index 20d503a02..4f1b52d5b 100644 --- a/packages/CLPBN/horus/Solver.cpp +++ b/packages/CLPBN/horus/Solver.cpp @@ -14,14 +14,16 @@ Solver::printAnswer (const VarIds& vids) unobservedVids.push_back (vids[i]); } } - Params res = solveQuery (unobservedVids); - vector stateLines = Util::getStateLines (unobservedVars); - for (size_t i = 0; i < res.size(); i++) { - cout << "P(" << stateLines[i] << ") = " ; - cout << std::setprecision (Constants::PRECISION) << res[i]; + if (unobservedVids.empty() == false) { + Params res = solveQuery (unobservedVids); + vector stateLines = Util::getStateLines (unobservedVars); + for (size_t i = 0; i < res.size(); i++) { + cout << "P(" << stateLines[i] << ") = " ; + cout << std::setprecision (Constants::PRECISION) << res[i]; + cout << endl; + } cout << endl; } - cout << endl; } @@ -29,14 +31,10 @@ Solver::printAnswer (const VarIds& vids) void Solver::printAllPosterioris (void) { - VarIds vids; - const VarNodes& vars = fg.varNodes(); + VarNodes vars = fg.varNodes(); + std::sort (vars.begin(), vars.end(), sortByVarId()); for (size_t i = 0; i < vars.size(); i++) { - vids.push_back (vars[i]->varId()); - } - std::sort (vids.begin(), vids.end()); - for (size_t i = 0; i < vids.size(); i++) { - printAnswer ({vids[i]}); + printAnswer ({vars[i]->varId()}); } }