445 lines
9.2 KiB
C++
445 lines
9.2 KiB
C++
#include "Util.h"
|
|
#include "Indexer.h"
|
|
#include "ElimGraph.h"
|
|
#include "BeliefProp.h"
|
|
|
|
|
|
namespace Horus {
|
|
|
|
namespace Globals {
|
|
|
|
bool logDomain = false;
|
|
|
|
unsigned verbosity = 0;
|
|
|
|
LiftedSolverType liftedSolver = LiftedSolverType::lveSolver;
|
|
|
|
GroundSolverType groundSolver = GroundSolverType::veSolver;
|
|
|
|
}
|
|
|
|
|
|
|
|
namespace Util {
|
|
|
|
template <> std::string
|
|
toString (const bool& b)
|
|
{
|
|
std::stringstream ss;
|
|
ss << std::boolalpha << b;
|
|
return ss.str();
|
|
}
|
|
|
|
|
|
|
|
// Parses `str` as a non-negative decimal integer.
//
// Terminates the program with EXIT_FAILURE when the value read is negative
// or does not fit in an `unsigned`. If no number can be read from the start
// of the string, 0 is returned.
unsigned
stringToUnsigned (std::string str)
{
  // Read into a wide signed type. Reading into `int` clamped every value in
  // (INT_MAX, UINT_MAX] to INT_MAX. The explicit initialization matters for
  // empty or whitespace-only input: there the extraction fails before any
  // parsing and leaves the target untouched.
  long long val = 0;
  std::istringstream ss (str);
  ss >> val;
  if (val < 0) {
    std::cerr << "Error: the number readed is negative." << std::endl;
    exit (EXIT_FAILURE);
  }
  if (static_cast<unsigned long long> (val)
      > std::numeric_limits<unsigned>::max()) {
    std::cerr << "Error: the number read is too large." << std::endl;
    exit (EXIT_FAILURE);
  }
  return static_cast<unsigned> (val);
}
|
|
|
|
|
|
|
|
// Parses `str` as a floating-point number (leading whitespace is skipped).
// Returns 0.0 when no number can be read from the start of the string.
double
stringToDouble (std::string str)
{
  // Initialized because, for an empty or whitespace-only string, the
  // extraction fails before parsing and does not write to `val`. Without
  // this, an indeterminate value would be returned.
  double val = 0.0;
  std::istringstream ss (str);
  ss >> val;
  return val;
}
|
|
|
|
|
|
|
|
// Returns num! computed in double precision (0! == 1).
// The product is built in ascending order of factors. For large `num` it
// overflows to infinity (beyond 170! for IEEE doubles).
double
factorial (unsigned num)
{
  double product = 1.0;
  // Multiplying by 1 is a no-op, so start at the first non-trivial factor.
  for (unsigned factor = 2; factor <= num; ++factor) {
    product *= factor;
  }
  return product;
}
|
|
|
|
|
|
|
|
double
|
|
logFactorial (unsigned num)
|
|
{
|
|
double result = 0.0;
|
|
if (num < 150) {
|
|
result = std::log (factorial (num));
|
|
} else {
|
|
for (unsigned i = 1; i <= num; i++) {
|
|
result += std::log (i);
|
|
}
|
|
}
|
|
return result;
|
|
}
|
|
|
|
|
|
|
|
// Returns the binomial coefficient C(n, k): the number of ways to choose
// k elements out of n. Requires n >= k.
//
// The value is built exactly with the recurrence
//   C(m + i, i) = C(m + i - 1, i - 1) * (m + i) / i,   with m = n - k,
// where every division is exact. A running product n * (n-1) * ... in an
// `unsigned` (the previous approach) overflows already for C(20, 10).
// The log/exp route truncated its result.
//
// If the true coefficient does not fit in an `unsigned`, the largest
// `unsigned` value is returned (saturation) rather than a wrapped or
// undefined result.
unsigned
nrCombinations (unsigned n, unsigned k)
{
  assert (n >= k);
  const unsigned long long limit = std::numeric_limits<unsigned>::max();
  // C(n, k) == C(n, n - k): iterate over the smaller of the two.
  if (k > n - k) {
    k = n - k;
  }
  const unsigned long long m = n - k;
  unsigned long long result = 1;
  for (unsigned long long i = 1; i <= k; i++) {
    // Here result == C(m + i - 1, i - 1) <= limit and m + i <= n.
    // With a 32-bit unsigned, the product therefore fits in 64 bits.
    result = result * (m + i) / i;
    if (result > limit) {
      // The partial coefficients never decrease ((m + i) / i >= 1),
      // so the final C(n, k) cannot fit either.
      return static_cast<unsigned> (limit);
    }
  }
  return static_cast<unsigned> (result);
}
|
|
|
|
|
|
|
|
size_t
|
|
sizeExpected (const Ranges& ranges)
|
|
{
|
|
return std::accumulate (ranges.begin(),
|
|
ranges.end(), 1, std::multiplies<unsigned>());
|
|
}
|
|
|
|
|
|
|
|
// Returns the number of decimal digits of `num`.
// Negative values get no special treatment: every num < 10, including all
// negative numbers, yields 1.
unsigned
nrDigits (int num)
{
  unsigned digits = 1;
  for (; num >= 10; num /= 10) {
    ++digits;
  }
  return digits;
}
|
|
|
|
|
|
|
|
// Returns true when `s` is exactly the canonical decimal spelling of an
// `int` (e.g. "42" or "-7"). Strings that only start with an int, or use a
// non-canonical form ("+5", "007", " 5", "4.2"), or lie outside the int
// range are rejected.
bool
isInteger (const std::string& s)
{
  std::istringstream in (s);
  // Must be initialized: for an empty or whitespace-only string the
  // extraction fails before parsing and leaves `integer` untouched. Printing
  // it would otherwise read an indeterminate value.
  int integer = 0;
  in >> integer;
  // Round-trip the parsed value and compare against the input.
  std::ostringstream canonical;
  canonical << integer;
  return s == canonical.str();
}
|
|
|
|
|
|
|
|
std::string
|
|
parametersToString (const Params& v, unsigned precision)
|
|
{
|
|
std::stringstream ss;
|
|
ss.precision (precision);
|
|
ss << "[" ;
|
|
for (size_t i = 0; i < v.size(); i++) {
|
|
if (i != 0) ss << ", " ;
|
|
ss << v[i];
|
|
}
|
|
ss << "]" ;
|
|
return ss.str();
|
|
}
|
|
|
|
|
|
|
|
std::vector<std::string>
|
|
getStateLines (const Vars& vars)
|
|
{
|
|
Ranges ranges;
|
|
for (size_t i = 0; i < vars.size(); i++) {
|
|
ranges.push_back (vars[i]->range());
|
|
}
|
|
Indexer indexer (ranges);
|
|
std::vector<std::string> jointStrings;
|
|
while (indexer.valid()) {
|
|
std::stringstream ss;
|
|
for (size_t i = 0; i < vars.size(); i++) {
|
|
if (i != 0) ss << ", " ;
|
|
ss << vars[i]->label() << "=" ;
|
|
ss << vars[i]->states()[(indexer[i])];
|
|
}
|
|
jointStrings.push_back (ss.str());
|
|
++ indexer;
|
|
}
|
|
return jointStrings;
|
|
}
|
|
|
|
|
|
|
|
// Prints a warning to stderr saying that `value` is not accepted for
// `option`. Always returns false, so callers can store the result as their
// "option accepted" status.
bool invalidValue (std::string option, std::string value)
{
  std::cerr << "Warning: invalid value `" << value << "' "
            << "for `" << option << "'."
            << std::endl;
  return false;
}
|
|
|
|
|
|
|
|
bool
|
|
setHorusFlag (std::string option, std::string value)
|
|
{
|
|
bool returnVal = true;
|
|
if (option == "lifted_solver") {
|
|
if (value == "lve")
|
|
Globals::liftedSolver = LiftedSolverType::lveSolver;
|
|
else if (value == "lbp")
|
|
Globals::liftedSolver = LiftedSolverType::lbpSolver;
|
|
else if (value == "lkc")
|
|
Globals::liftedSolver = LiftedSolverType::lkcSolver;
|
|
else
|
|
returnVal = invalidValue (option, value);
|
|
|
|
} else if (option == "ground_solver" || option == "solver") {
|
|
if (value == "hve")
|
|
Globals::groundSolver = GroundSolverType::veSolver;
|
|
else if (value == "bp")
|
|
Globals::groundSolver = GroundSolverType::bpSolver;
|
|
else if (value == "cbp")
|
|
Globals::groundSolver = GroundSolverType::CbpSolver;
|
|
else
|
|
returnVal = invalidValue (option, value);
|
|
|
|
} else if (option == "verbosity") {
|
|
std::stringstream ss;
|
|
ss << value;
|
|
ss >> Globals::verbosity;
|
|
|
|
} else if (option == "use_logarithms") {
|
|
if (value == "true") Globals::logDomain = true;
|
|
else if (value == "false") Globals::logDomain = false;
|
|
else returnVal = invalidValue (option, value);
|
|
|
|
} else if (option == "hve_elim_heuristic") {
|
|
typedef ElimGraph::ElimHeuristic ElimHeuristic;
|
|
if (value == "sequential")
|
|
ElimGraph::setElimHeuristic (ElimHeuristic::sequentialEh);
|
|
else if (value == "min_neighbors")
|
|
ElimGraph::setElimHeuristic (ElimHeuristic::minNeighborsEh);
|
|
else if (value == "min_weight")
|
|
ElimGraph::setElimHeuristic (ElimHeuristic::minWeightEh);
|
|
else if (value == "min_fill")
|
|
ElimGraph::setElimHeuristic (ElimHeuristic::minFillEh);
|
|
else if (value == "weighted_min_fill")
|
|
ElimGraph::setElimHeuristic (ElimHeuristic::weightedMinFillEh);
|
|
else
|
|
returnVal = invalidValue (option, value);
|
|
|
|
} else if (option == "bp_msg_schedule") {
|
|
typedef BeliefProp::MsgSchedule MsgSchedule;
|
|
if (value == "seq_fixed")
|
|
BeliefProp::setMsgSchedule (MsgSchedule::seqFixedSch);
|
|
else if (value == "seq_random")
|
|
BeliefProp::setMsgSchedule (MsgSchedule::seqRandomSch);
|
|
else if (value == "parallel")
|
|
BeliefProp::setMsgSchedule (MsgSchedule::parallelSch);
|
|
else if (value == "max_residual")
|
|
BeliefProp::setMsgSchedule (MsgSchedule::maxResidualSch);
|
|
else
|
|
returnVal = invalidValue (option, value);
|
|
|
|
} else if (option == "bp_accuracy") {
|
|
std::stringstream ss;
|
|
double acc;
|
|
ss << value;
|
|
ss >> acc;
|
|
BeliefProp::setAccuracy (acc);
|
|
|
|
} else if (option == "bp_max_iter") {
|
|
std::stringstream ss;
|
|
unsigned mi;
|
|
ss << value;
|
|
ss >> mi;
|
|
BeliefProp::setMaxIterations (mi);
|
|
|
|
} else if (option == "export_libdai") {
|
|
if (value == "true") FactorGraph::enableExportToLibDai();
|
|
else if (value == "false") FactorGraph::disableExportToLibDai();
|
|
else returnVal = invalidValue (option, value);
|
|
|
|
} else if (option == "export_uai") {
|
|
if (value == "true") FactorGraph::enableExportToUai();
|
|
else if (value == "false") FactorGraph::disableExportToUai();
|
|
else returnVal = invalidValue (option, value);
|
|
|
|
} else if (option == "export_graphviz") {
|
|
if (value == "true") FactorGraph::enableExportToGraphViz();
|
|
else if (value == "false") FactorGraph::disableExportToGraphViz();
|
|
else returnVal = invalidValue (option, value);
|
|
|
|
} else if (option == "print_fg") {
|
|
if (value == "true") FactorGraph::enablePrintFactorGraph();
|
|
else if (value == "false") FactorGraph::disablePrintFactorGraph();
|
|
else returnVal = invalidValue (option, value);
|
|
|
|
} else {
|
|
std::cerr << "Warning: invalid option `" << option << "'" << std::endl;
|
|
returnVal = false;
|
|
}
|
|
return returnVal;
|
|
}
|
|
|
|
|
|
|
|
void
|
|
printHeader (std::string header, std::ostream& os)
|
|
{
|
|
printAsteriskLine (os);
|
|
os << header << std::endl;
|
|
printAsteriskLine (os);
|
|
}
|
|
|
|
|
|
|
|
void
|
|
printSubHeader (std::string header, std::ostream& os)
|
|
{
|
|
printDashedLine (os);
|
|
os << header << std::endl;
|
|
printDashedLine (os);
|
|
}
|
|
|
|
|
|
|
|
// Writes a rule of 64 '*' characters to `os`, followed by std::endl.
void
printAsteriskLine (std::ostream& os)
{
  // Emitted as two 32-character halves, exactly as two separate insertions.
  const std::string half (32, '*');
  os << half;
  os << half;
  os << std::endl;
}
|
|
|
|
|
|
|
|
// Writes a rule of 64 '-' characters to `os`, followed by std::endl.
void
printDashedLine (std::ostream& os)
{
  // Emitted as two 32-character halves, exactly as two separate insertions.
  const std::string half (32, '-');
  os << half;
  os << half;
  os << std::endl;
}
|
|
|
|
} // namespace Util
|
|
|
|
|
|
|
|
namespace LogAware {
|
|
|
|
void
|
|
normalize (Params& v)
|
|
{
|
|
if (Globals::logDomain) {
|
|
double sum = std::accumulate (v.begin(), v.end(),
|
|
LogAware::addIdenty(), Util::logSum);
|
|
assert (sum != -std::numeric_limits<double>::infinity());
|
|
v -= sum;
|
|
} else {
|
|
double sum = std::accumulate (v.begin(), v.end(), 0.0);
|
|
assert (sum != 0.0);
|
|
v /= sum;
|
|
}
|
|
}
|
|
|
|
|
|
|
|
double
|
|
getL1Distance (const Params& v1, const Params& v2)
|
|
{
|
|
assert (v1.size() == v2.size());
|
|
double dist = 0.0;
|
|
if (Globals::logDomain) {
|
|
dist = std::inner_product (v1.begin(), v1.end(), v2.begin(), 0.0,
|
|
std::plus<double>(), FuncObj::abs_diff_exp<double>());
|
|
} else {
|
|
dist = std::inner_product (v1.begin(), v1.end(), v2.begin(), 0.0,
|
|
std::plus<double>(), FuncObj::abs_diff<double>());
|
|
}
|
|
return dist;
|
|
}
|
|
|
|
|
|
|
|
double
|
|
getMaxNorm (const Params& v1, const Params& v2)
|
|
{
|
|
assert (v1.size() == v2.size());
|
|
double max = 0.0;
|
|
if (Globals::logDomain) {
|
|
max = std::inner_product (v1.begin(), v1.end(), v2.begin(), 0.0,
|
|
FuncObj::max<double>(), FuncObj::abs_diff_exp<double>());
|
|
} else {
|
|
max = std::inner_product (v1.begin(), v1.end(), v2.begin(), 0.0,
|
|
FuncObj::max<double>(), FuncObj::abs_diff<double>());
|
|
}
|
|
return max;
|
|
}
|
|
|
|
|
|
|
|
double
|
|
pow (double base, unsigned iexp)
|
|
{
|
|
return Globals::logDomain
|
|
? base * iexp
|
|
: std::pow (base, iexp);
|
|
}
|
|
|
|
|
|
|
|
double
|
|
pow (double base, double exp)
|
|
{
|
|
// `expoent' should not be in log domain
|
|
return Globals::logDomain
|
|
? base * exp
|
|
: std::pow (base, exp);
|
|
}
|
|
|
|
|
|
|
|
void
|
|
pow (Params& v, unsigned iexp)
|
|
{
|
|
if (iexp == 1) {
|
|
return;
|
|
}
|
|
Globals::logDomain ? v *= iexp : v ^= (int)iexp;
|
|
}
|
|
|
|
|
|
|
|
void
|
|
pow (Params& v, double exp)
|
|
{
|
|
// `expoent' should not be in log domain
|
|
Globals::logDomain ? v *= exp : v ^= exp;
|
|
}
|
|
|
|
} // namespace LogAware
|
|
|
|
} // namespace Horus
|
|
|
|
|