This repository has been archived on 2023-08-20. You can view files and clone it, but cannot push or open issues or pull requests.
yap-6.3/packages/CLPBN/horus/Util.cpp

445 lines
9.2 KiB
C++
Raw Permalink Normal View History

2012-05-23 14:56:01 +01:00
#include "Util.h"
#include "Indexer.h"
#include "ElimGraph.h"
#include "BeliefProp.h"
2012-05-23 14:56:01 +01:00
namespace Horus {
2013-02-07 23:53:13 +00:00
namespace Globals {

// When true, the LogAware helpers treat every probability as a logarithm
// (see LogAware::normalize / LogAware::pow); defaults to the linear domain.
bool logDomain{false};

// Verbosity level, settable through setHorusFlag ("verbosity"); starts at 0.
unsigned verbosity{0};

// Solver used for lifted inference (lve unless overridden by a flag).
LiftedSolverType liftedSolver{LiftedSolverType::lveSolver};

// Solver used for ground inference (ve unless overridden by a flag).
GroundSolverType groundSolver{GroundSolverType::veSolver};

}
2012-05-23 14:56:01 +01:00
namespace Util {
2012-05-23 14:56:01 +01:00
template <> std::string
toString (const bool& b)
{
  // Booleans are rendered as words ("true"/"false") rather than 1/0,
  // using the stream's own boolalpha spelling.
  std::ostringstream out;
  out.setf (std::ios_base::boolalpha);
  out << b;
  return out.str();
}
// Parses `str' as a non-negative integer.
// Aborts the program (EXIT_FAILURE) with a message on stderr when the number
// read is negative or does not fit in an `unsigned'.
// Text that holds no number at all yields 0.
unsigned
stringToUnsigned (std::string str)
{
  // A wider signed type lets us detect both negative input and values above
  // INT_MAX that still fit in `unsigned' (parsing into `int' clamped those).
  // Zero-initialised: on empty/blank input the extraction fails before any
  // value is stored, so an uninitialised variable would be read afterwards.
  long long val = 0;
  std::stringstream ss;
  ss << str;
  ss >> val;
  if (val < 0) {
    std::cerr << "Error: the number read is negative." << std::endl;
    exit (EXIT_FAILURE);
  }
  if (val > static_cast<long long> (std::numeric_limits<unsigned>::max())) {
    std::cerr << "Error: the number read is too large." << std::endl;
    exit (EXIT_FAILURE);
  }
  return static_cast<unsigned> (val);
}
// Parses `str' as a double.
// Text that holds no number at all (including an empty string) yields 0.0.
double
stringToDouble (std::string str)
{
  // Zero-initialised: on empty/blank input the extraction fails before any
  // value is stored, so an uninitialised variable would otherwise be returned.
  double val = 0.0;
  std::stringstream ss;
  ss << str;
  ss >> val;
  return val;
}
// Returns num! computed in floating point (0! == 1! == 1.0).
double
factorial (unsigned num)
{
  // Multiplying by 1 is a no-op, so the product can start at 2.
  double product = 1.0;
  unsigned k = 2;
  while (k <= num) {
    product *= k;
    ++k;
  }
  return product;
}
// Returns log(num!).
double
logFactorial (unsigned num)
{
  if (num >= 150) {
    // Large arguments: sum log(k) term by term instead of forming num!.
    double logSum = 0.0;
    for (unsigned k = 1; k <= num; ++k) {
      logSum += std::log (static_cast<double> (k));
    }
    return logSum;
  }
  // Small arguments: num! is still finite as a double (149! < 1e260),
  // so take the logarithm of the product directly.
  double product = 1.0;
  for (unsigned k = 2; k <= num; ++k) {
    product *= k;
  }
  return std::log (product);
}
// Returns the binomial coefficient C(n, k), i.e. the number of ways to pick
// k items out of n.
// The value is exact whenever it fits in `unsigned'.
// Larger values saturate at std::numeric_limits<unsigned>::max().
// Requires n >= k.
unsigned
nrCombinations (unsigned n, unsigned k)
{
  assert (n >= k);
  // C(n, k) == C(n, n - k); the smaller k needs fewer iterations and keeps
  // every intermediate value no larger than the final one.
  if (k > n - k) {
    k = n - k;
  }
  const unsigned long long cap = std::numeric_limits<unsigned>::max();
  unsigned long long result = 1;
  for (unsigned i = 1; i <= k; i++) {
    // Before this step result == C(n - k + i - 1, i - 1); afterwards it is
    // C(n - k + i, i).
    // The division is exact because that value is an integer.
    // result <= cap and (n - k + i) <= UINT_MAX, so the product fits in 64 bits.
    result = result * (n - k + i) / i;
    if (result > cap) {
      // Intermediate values never decrease, so the final value cannot fit.
      return std::numeric_limits<unsigned>::max();
    }
  }
  return static_cast<unsigned> (result);
}
2012-05-24 22:55:20 +01:00
size_t
2012-05-24 16:14:13 +01:00
sizeExpected (const Ranges& ranges)
2012-05-23 14:56:01 +01:00
{
2012-05-28 21:27:52 +01:00
return std::accumulate (ranges.begin(),
2013-02-07 13:37:15 +00:00
ranges.end(), 1, std::multiplies<unsigned>());
2012-05-23 14:56:01 +01:00
}
// Returns how many decimal digits `num' has (at least 1).
// Negative numbers never enter the loop and so always report 1.
unsigned
nrDigits (int num)
{
  unsigned digits = 1;
  for (; num >= 10; num /= 10) {
    ++digits;
  }
  return digits;
}
// Returns true when `s' is exactly the canonical decimal text of an int.
// It is a round-trip check: parse `s', print the value back, and compare.
// Therefore "007", "+5", " 5", "4.2" and values out of int range are rejected.
bool
isInteger (const std::string& s)
{
  std::stringstream ss1 (s);
  std::stringstream ss2;
  // Zero-initialised: on empty/blank input the extraction fails before any
  // value is stored, and an uninitialised int would then be printed below.
  int integer = 0;
  ss1 >> integer;
  ss2 << integer;
  return (ss1.str() == ss2.str());
}
2013-02-07 13:37:15 +00:00
std::string
2012-05-23 14:56:01 +01:00
parametersToString (const Params& v, unsigned precision)
{
2013-02-07 13:37:15 +00:00
std::stringstream ss;
2012-05-23 14:56:01 +01:00
ss.precision (precision);
2012-12-20 23:19:10 +00:00
ss << "[" ;
2012-05-24 22:55:20 +01:00
for (size_t i = 0; i < v.size(); i++) {
2012-05-23 14:56:01 +01:00
if (i != 0) ss << ", " ;
ss << v[i];
}
ss << "]" ;
return ss.str();
}
2013-02-07 13:37:15 +00:00
std::vector<std::string>
2012-05-23 14:56:01 +01:00
getStateLines (const Vars& vars)
{
2012-05-24 22:55:20 +01:00
Ranges ranges;
for (size_t i = 0; i < vars.size(); i++) {
ranges.push_back (vars[i]->range());
}
Indexer indexer (ranges);
2013-02-07 13:37:15 +00:00
std::vector<std::string> jointStrings;
2012-05-24 22:55:20 +01:00
while (indexer.valid()) {
2013-02-07 13:37:15 +00:00
std::stringstream ss;
2012-05-24 22:55:20 +01:00
for (size_t i = 0; i < vars.size(); i++) {
2012-05-23 14:56:01 +01:00
if (i != 0) ss << ", " ;
2012-05-24 22:55:20 +01:00
ss << vars[i]->label() << "=" ;
ss << vars[i]->states()[(indexer[i])];
2012-05-23 14:56:01 +01:00
}
jointStrings.push_back (ss.str());
2012-05-24 22:55:20 +01:00
++ indexer;
2012-05-23 14:56:01 +01:00
}
return jointStrings;
}
2013-02-07 13:37:15 +00:00
// Prints a warning about `value' being unacceptable for `option' on stderr.
// Always returns false, so callers can assign the result as their own
// return value in one step.
bool invalidValue (std::string option, std::string value)
{
  std::cerr << "Warning: invalid value `" << value << "' "
            << "for `" << option << "'."
            << std::endl;
  return false;
}
2012-05-23 14:56:01 +01:00
bool
2013-02-07 13:37:15 +00:00
setHorusFlag (std::string option, std::string value)
2012-05-23 14:56:01 +01:00
{
bool returnVal = true;
if (option == "lifted_solver") {
2013-03-09 17:14:00 +00:00
if (value == "lve")
Globals::liftedSolver = LiftedSolverType::lveSolver;
else if (value == "lbp")
Globals::liftedSolver = LiftedSolverType::lbpSolver;
else if (value == "lkc")
Globals::liftedSolver = LiftedSolverType::lkcSolver;
else
returnVal = invalidValue (option, value);
2013-01-08 21:13:58 +00:00
} else if (option == "ground_solver" || option == "solver") {
2013-03-09 17:14:00 +00:00
if (value == "hve")
Globals::groundSolver = GroundSolverType::veSolver;
else if (value == "bp")
Globals::groundSolver = GroundSolverType::bpSolver;
else if (value == "cbp")
Globals::groundSolver = GroundSolverType::CbpSolver;
else
returnVal = invalidValue (option, value);
2013-01-08 21:13:58 +00:00
} else if (option == "verbosity") {
2013-02-07 13:37:15 +00:00
std::stringstream ss;
2012-12-27 23:21:32 +00:00
ss << value;
ss >> Globals::verbosity;
2013-01-08 21:13:58 +00:00
} else if (option == "use_logarithms") {
if (value == "true") Globals::logDomain = true;
else if (value == "false") Globals::logDomain = false;
else returnVal = invalidValue (option, value);
2013-01-08 21:13:58 +00:00
} else if (option == "hve_elim_heuristic") {
typedef ElimGraph::ElimHeuristic ElimHeuristic;
2013-01-08 21:13:58 +00:00
if (value == "sequential")
ElimGraph::setElimHeuristic (ElimHeuristic::sequentialEh);
2013-01-08 21:13:58 +00:00
else if (value == "min_neighbors")
ElimGraph::setElimHeuristic (ElimHeuristic::minNeighborsEh);
2013-01-08 21:13:58 +00:00
else if (value == "min_weight")
ElimGraph::setElimHeuristic (ElimHeuristic::minWeightEh);
2013-01-08 21:13:58 +00:00
else if (value == "min_fill")
ElimGraph::setElimHeuristic (ElimHeuristic::minFillEh);
2013-01-08 21:13:58 +00:00
else if (value == "weighted_min_fill")
ElimGraph::setElimHeuristic (ElimHeuristic::weightedMinFillEh);
2013-01-08 21:13:58 +00:00
else
returnVal = invalidValue (option, value);
2013-01-08 21:13:58 +00:00
} else if (option == "bp_msg_schedule") {
typedef BeliefProp::MsgSchedule MsgSchedule;
2013-01-08 21:13:58 +00:00
if (value == "seq_fixed")
BeliefProp::setMsgSchedule (MsgSchedule::seqFixedSch);
2013-01-08 21:13:58 +00:00
else if (value == "seq_random")
BeliefProp::setMsgSchedule (MsgSchedule::seqRandomSch);
2013-01-08 21:13:58 +00:00
else if (value == "parallel")
BeliefProp::setMsgSchedule (MsgSchedule::parallelSch);
2013-01-08 21:13:58 +00:00
else if (value == "max_residual")
BeliefProp::setMsgSchedule (MsgSchedule::maxResidualSch);
2013-01-08 21:13:58 +00:00
else
returnVal = invalidValue (option, value);
2013-01-08 21:13:58 +00:00
} else if (option == "bp_accuracy") {
2013-02-07 13:37:15 +00:00
std::stringstream ss;
double acc;
2012-05-23 14:56:01 +01:00
ss << value;
ss >> acc;
BeliefProp::setAccuracy (acc);
2013-01-08 21:13:58 +00:00
} else if (option == "bp_max_iter") {
2013-02-07 13:37:15 +00:00
std::stringstream ss;
unsigned mi;
2012-05-23 14:56:01 +01:00
ss << value;
ss >> mi;
BeliefProp::setMaxIterations (mi);
2013-01-08 21:13:58 +00:00
} else if (option == "export_libdai") {
2013-01-08 21:13:58 +00:00
if (value == "true") FactorGraph::enableExportToLibDai();
else if (value == "false") FactorGraph::disableExportToLibDai();
else returnVal = invalidValue (option, value);
2013-01-08 21:13:58 +00:00
} else if (option == "export_uai") {
2013-01-08 21:13:58 +00:00
if (value == "true") FactorGraph::enableExportToUai();
else if (value == "false") FactorGraph::disableExportToUai();
else returnVal = invalidValue (option, value);
2013-01-08 21:13:58 +00:00
} else if (option == "export_graphviz") {
2013-01-08 21:13:58 +00:00
if (value == "true") FactorGraph::enableExportToGraphViz();
else if (value == "false") FactorGraph::disableExportToGraphViz();
else returnVal = invalidValue (option, value);
2013-01-08 21:13:58 +00:00
} else if (option == "print_fg") {
2013-01-08 21:13:58 +00:00
if (value == "true") FactorGraph::enablePrintFactorGraph();
else if (value == "false") FactorGraph::disablePrintFactorGraph();
else returnVal = invalidValue (option, value);
2013-01-08 21:13:58 +00:00
2012-05-23 14:56:01 +01:00
} else {
2013-02-07 13:37:15 +00:00
std::cerr << "Warning: invalid option `" << option << "'" << std::endl;
2012-05-23 14:56:01 +01:00
returnVal = false;
}
return returnVal;
}
void
2013-02-07 13:37:15 +00:00
printHeader (std::string header, std::ostream& os)
2012-05-23 14:56:01 +01:00
{
printAsteriskLine (os);
2013-02-07 13:37:15 +00:00
os << header << std::endl;
2012-05-23 14:56:01 +01:00
printAsteriskLine (os);
}
void
2013-02-07 13:37:15 +00:00
printSubHeader (std::string header, std::ostream& os)
2012-05-23 14:56:01 +01:00
{
printDashedLine (os);
2013-02-07 13:37:15 +00:00
os << header << std::endl;
2012-05-23 14:56:01 +01:00
printDashedLine (os);
}
// Writes a full-width row of asterisks followed by std::endl to `os'.
void
printAsteriskLine (std::ostream& os)
{
  os << "********************************"
        "********************************"
     << std::endl;
}
// Writes a full-width row of dashes followed by std::endl to `os'.
void
printDashedLine (std::ostream& os)
{
  os << "--------------------------------"
        "--------------------------------"
     << std::endl;
}
2013-02-07 23:53:13 +00:00
} // namespace Util
2012-05-23 14:56:01 +01:00
namespace LogAware {
2012-05-23 14:56:01 +01:00
// Normalises `v' in place, honouring Globals::logDomain.
// - Linear domain: every entry is divided by the sum of all entries.
// - Log domain: the entries are combined with Util::logSum, starting from
//   LogAware::addIdenty(), and that total is subtracted from every entry.
// Asserts that the total is not the domain's "zero" (0.0 or -infinity).
void
normalize (Params& v)
{
  if (!Globals::logDomain) {
    const double total = std::accumulate (v.begin(), v.end(), 0.0);
    assert (total != 0.0);
    v /= total;
    return;
  }
  const double logTotal = std::accumulate (v.begin(), v.end(),
      LogAware::addIdenty(), Util::logSum);
  assert (logTotal != -std::numeric_limits<double>::infinity());
  v -= logTotal;
}
// Sums the element-wise differences between `v1' and `v2' (equal sizes
// required).
// In the log domain the entries are compared via FuncObj::abs_diff_exp,
// otherwise via FuncObj::abs_diff.
double
getL1Distance (const Params& v1, const Params& v2)
{
  assert (v1.size() == v2.size());
  if (Globals::logDomain) {
    return std::inner_product (v1.begin(), v1.end(), v2.begin(), 0.0,
        std::plus<double>(), FuncObj::abs_diff_exp<double>());
  }
  return std::inner_product (v1.begin(), v1.end(), v2.begin(), 0.0,
      std::plus<double>(), FuncObj::abs_diff<double>());
}
// Returns the largest element-wise difference between `v1' and `v2'
// (equal sizes required), starting from 0.0.
// In the log domain the entries are compared via FuncObj::abs_diff_exp,
// otherwise via FuncObj::abs_diff.
double
getMaxNorm (const Params& v1, const Params& v2)
{
  assert (v1.size() == v2.size());
  if (Globals::logDomain) {
    return std::inner_product (v1.begin(), v1.end(), v2.begin(), 0.0,
        FuncObj::max<double>(), FuncObj::abs_diff_exp<double>());
  }
  return std::inner_product (v1.begin(), v1.end(), v2.begin(), 0.0,
      FuncObj::max<double>(), FuncObj::abs_diff<double>());
}
2012-05-23 14:56:01 +01:00
double
pow (double base, unsigned iexp)
2012-05-23 14:56:01 +01:00
{
return Globals::logDomain
2012-05-28 21:27:52 +01:00
? base * iexp
: std::pow (base, iexp);
2012-05-23 14:56:01 +01:00
}
double
pow (double base, double exp)
2012-05-23 14:56:01 +01:00
{
2012-05-28 21:27:52 +01:00
// `expoent' should not be in log domain
return Globals::logDomain
2012-05-28 21:27:52 +01:00
? base * exp
: std::pow (base, exp);
2012-05-23 14:56:01 +01:00
}
// Raises every entry of `v' to the integer power `iexp', in place.
// Log domain: the entries are scaled by `iexp'.
// Linear domain: Params' operator^= is applied.
void
pow (Params& v, unsigned iexp)
{
  // The first power is the identity; skip the work entirely.
  if (iexp == 1) {
    return;
  }
  if (Globals::logDomain) {
    v *= iexp;
  } else {
    v ^= static_cast<int> (iexp);
  }
}
// Raises every entry of `v' to the real power `exp', in place.
// Only the entries follow Globals::logDomain; the exponent is always a plain
// (non-log) number.
void
pow (Params& v, double exp)
{
  if (Globals::logDomain) {
    v *= exp;
  } else {
    v ^= exp;
  }
}
} // namespace LogAware
2013-02-07 23:53:13 +00:00
} // namespace Horus
2013-02-07 23:53:13 +00:00
2012-05-23 14:56:01 +01:00