445 lines
		
	
	
		
			9.2 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			445 lines
		
	
	
		
			9.2 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
#include "Util.h"
 | 
						|
#include "Indexer.h"
 | 
						|
#include "ElimGraph.h"
 | 
						|
#include "BeliefProp.h"
 | 
						|
 | 
						|
 | 
						|
namespace Horus {
 | 
						|
 | 
						|
namespace Globals {
 | 
						|
 | 
						|
bool logDomain = false;
 | 
						|
 | 
						|
unsigned verbosity = 0;
 | 
						|
 | 
						|
LiftedSolverType liftedSolver = LiftedSolverType::lveSolver;
 | 
						|
 | 
						|
GroundSolverType groundSolver = GroundSolverType::veSolver;
 | 
						|
 | 
						|
}
 | 
						|
 | 
						|
 | 
						|
 | 
						|
namespace Util {
 | 
						|
 | 
						|
template <> std::string
 | 
						|
toString (const bool& b)
 | 
						|
{
 | 
						|
  std::stringstream ss;
 | 
						|
  ss << std::boolalpha << b;
 | 
						|
  return ss.str();
 | 
						|
}
 | 
						|
 | 
						|
 | 
						|
 | 
						|
// Parses a non-negative integer from the start of `str`.
//
// Leading whitespace is skipped and parsing stops at the first character
// that cannot belong to the number; a string with no leading number yields 0.
// A negative value, or one that does not fit in an unsigned, prints an error
// to std::cerr and terminates the program.
unsigned
stringToUnsigned (std::string str)
{
  // Read into a wide signed type: values above INT_MAX but within the
  // unsigned range are accepted, and negative input can still be detected.
  // The variable must be initialised because a failed extraction on an
  // empty stream leaves it untouched.
  long long val = 0;
  std::stringstream ss (str);
  ss >> val;
  if (val < 0) {
    std::cerr << "Error: the number read is negative." << std::endl;
    exit (EXIT_FAILURE);
  }
  if (val > static_cast<long long> (std::numeric_limits<unsigned>::max())) {
    std::cerr << "Error: the number read is too large." << std::endl;
    exit (EXIT_FAILURE);
  }
  return static_cast<unsigned> (val);
}
 | 
						|
 | 
						|
 | 
						|
 | 
						|
// Parses a floating-point number from the start of `str`.
//
// Leading whitespace is skipped and trailing characters are ignored. If no
// number can be read (including an empty string) the result is 0.0.
double
stringToDouble (std::string str)
{
  // Must be initialised: on an empty stream operator>> fails before touching
  // its target, which would otherwise return an indeterminate value.
  double val = 0.0;
  std::stringstream ss (str);
  ss >> val;
  return val;
}
 | 
						|
 | 
						|
 | 
						|
 | 
						|
// Returns num! as a double (1.0 for num == 0 and num == 1).
double
factorial (unsigned num)
{
  double product = 1.0;
  // Multiplying by 1 is a no-op, so the running product starts at 2.
  for (unsigned k = 2; k <= num; ++k) {
    product *= k;
  }
  return product;
}
 | 
						|
 | 
						|
 | 
						|
 | 
						|
double
 | 
						|
logFactorial (unsigned num)
 | 
						|
{
 | 
						|
  double result = 0.0;
 | 
						|
  if (num < 150) {
 | 
						|
    result = std::log (factorial (num));
 | 
						|
  } else {
 | 
						|
    for (unsigned i = 1; i <= num; i++) {
 | 
						|
      result += std::log (i);
 | 
						|
    }
 | 
						|
  }
 | 
						|
  return result;
 | 
						|
}
 | 
						|
 | 
						|
 | 
						|
 | 
						|
// Returns the binomial coefficient C(n, k) = n! / (k! (n - k)!). Requires n >= k.
//
// The value is computed exactly with integer arithmetic through the
// recurrence C(m, i) = C(m, i - 1) * m / i, which divides exactly at every
// step; this avoids both overflowing a running product of n*(n-1)*... and
// the rounding error of an exp/log approximation. If C(n, k) does not fit
// in an unsigned, the largest unsigned value is returned.
unsigned
nrCombinations (unsigned n, unsigned k)
{
  assert (n >= k);
  // C(n, k) == C(n, n - k): iterate over the smaller of the two.
  const unsigned kk = (k > n - k) ? n - k : k;
  const unsigned long long limit = std::numeric_limits<unsigned>::max();
  unsigned long long result = 1;
  for (unsigned i = 1; i <= kk; i++) {
    // Before this step result == C(n - kk + i - 1, i - 1); multiplying by
    // (n - kk + i) yields i * C(n - kk + i, i), so the division is exact.
    // result <= limit here, so the product fits in 64 bits.
    result = result * (n - kk + i) / i;
    if (result > limit) {
      // Intermediate values never decrease, so the final one is too large.
      return static_cast<unsigned> (limit);
    }
  }
  return static_cast<unsigned> (result);
}
 | 
						|
 | 
						|
 | 
						|
 | 
						|
size_t
 | 
						|
sizeExpected (const Ranges& ranges)
 | 
						|
{
 | 
						|
  return std::accumulate (ranges.begin(),
 | 
						|
      ranges.end(), 1, std::multiplies<unsigned>());
 | 
						|
}
 | 
						|
 | 
						|
 | 
						|
 | 
						|
// Returns the number of decimal digits in a non-negative `num` (0 has one
// digit). NOTE(review): every negative value yields 1.
unsigned
nrDigits (int num)
{
  unsigned digits = 1;
  for (int rest = num; rest >= 10; rest /= 10) {
    ++digits;
  }
  return digits;
}
 | 
						|
 | 
						|
 | 
						|
 | 
						|
// Returns true iff `s` is the canonical decimal text of an int. The parsed
// value must print back to exactly `s`, so surrounding whitespace, a '+'
// sign, leading zeros, fractional parts and out-of-range values are
// all rejected.
bool
isInteger (const std::string& s)
{
  std::stringstream in (s);
  int integer = 0;
  // A failed extraction (e.g. on empty input) may leave `integer`
  // untouched, so bail out before using it.
  if (!(in >> integer)) {
    return false;
  }
  std::stringstream out;
  out << integer;
  return in.str() == out.str();
}
 | 
						|
 | 
						|
 | 
						|
 | 
						|
std::string
 | 
						|
parametersToString (const Params& v, unsigned precision)
 | 
						|
{
 | 
						|
  std::stringstream ss;
 | 
						|
  ss.precision (precision);
 | 
						|
  ss << "[" ;
 | 
						|
  for (size_t i = 0; i < v.size(); i++) {
 | 
						|
    if (i != 0) ss << ", " ;
 | 
						|
    ss << v[i];
 | 
						|
  }
 | 
						|
  ss << "]" ;
 | 
						|
  return ss.str();
 | 
						|
}
 | 
						|
 | 
						|
 | 
						|
 | 
						|
std::vector<std::string>
 | 
						|
getStateLines (const Vars& vars)
 | 
						|
{
 | 
						|
  Ranges ranges;
 | 
						|
  for (size_t i = 0; i < vars.size(); i++) {
 | 
						|
    ranges.push_back (vars[i]->range());
 | 
						|
  }
 | 
						|
  Indexer indexer (ranges);
 | 
						|
  std::vector<std::string> jointStrings;
 | 
						|
  while (indexer.valid()) {
 | 
						|
    std::stringstream ss;
 | 
						|
    for (size_t i = 0; i < vars.size(); i++) {
 | 
						|
      if (i != 0) ss << ", " ;
 | 
						|
      ss << vars[i]->label() << "=" ;
 | 
						|
      ss << vars[i]->states()[(indexer[i])];
 | 
						|
    }
 | 
						|
    jointStrings.push_back (ss.str());
 | 
						|
    ++ indexer;
 | 
						|
  }
 | 
						|
  return jointStrings;
 | 
						|
}
 | 
						|
 | 
						|
 | 
						|
 | 
						|
bool invalidValue (std::string option, std::string value)
{
  // Warn about a rejected option value; always returns false so callers
  // can use the result directly as their own failure status.
  std::cerr << "Warning: invalid value `" << value << "' "
            << "for `" << option << "'."
            << std::endl;
  return false;
}
 | 
						|
 | 
						|
 | 
						|
 | 
						|
bool
 | 
						|
setHorusFlag (std::string option, std::string value)
 | 
						|
{
 | 
						|
  bool returnVal = true;
 | 
						|
  if (option == "lifted_solver") {
 | 
						|
    if      (value == "lve")
 | 
						|
      Globals::liftedSolver = LiftedSolverType::lveSolver;
 | 
						|
    else if (value == "lbp")
 | 
						|
      Globals::liftedSolver = LiftedSolverType::lbpSolver;
 | 
						|
    else if (value == "lkc")
 | 
						|
      Globals::liftedSolver = LiftedSolverType::lkcSolver;
 | 
						|
    else
 | 
						|
      returnVal = invalidValue (option, value);
 | 
						|
 | 
						|
  } else if (option == "ground_solver" || option == "solver") {
 | 
						|
    if      (value == "hve")
 | 
						|
      Globals::groundSolver = GroundSolverType::veSolver;
 | 
						|
    else if (value == "bp")
 | 
						|
      Globals::groundSolver = GroundSolverType::bpSolver;
 | 
						|
    else if (value == "cbp")
 | 
						|
      Globals::groundSolver = GroundSolverType::CbpSolver;
 | 
						|
    else
 | 
						|
      returnVal = invalidValue (option, value);
 | 
						|
 | 
						|
  } else if (option == "verbosity") {
 | 
						|
    std::stringstream ss;
 | 
						|
    ss << value;
 | 
						|
    ss >> Globals::verbosity;
 | 
						|
 | 
						|
  } else if (option == "use_logarithms") {
 | 
						|
    if      (value == "true")  Globals::logDomain = true;
 | 
						|
    else if (value == "false") Globals::logDomain = false;
 | 
						|
    else                       returnVal = invalidValue (option, value);
 | 
						|
 | 
						|
  } else if (option == "hve_elim_heuristic") {
 | 
						|
    typedef ElimGraph::ElimHeuristic ElimHeuristic;
 | 
						|
    if      (value == "sequential")
 | 
						|
      ElimGraph::setElimHeuristic (ElimHeuristic::sequentialEh);
 | 
						|
    else if (value == "min_neighbors")
 | 
						|
      ElimGraph::setElimHeuristic (ElimHeuristic::minNeighborsEh);
 | 
						|
    else if (value == "min_weight")
 | 
						|
      ElimGraph::setElimHeuristic (ElimHeuristic::minWeightEh);
 | 
						|
    else if (value == "min_fill")
 | 
						|
      ElimGraph::setElimHeuristic (ElimHeuristic::minFillEh);
 | 
						|
    else if (value == "weighted_min_fill")
 | 
						|
      ElimGraph::setElimHeuristic (ElimHeuristic::weightedMinFillEh);
 | 
						|
    else
 | 
						|
      returnVal = invalidValue (option, value);
 | 
						|
 | 
						|
  } else if (option == "bp_msg_schedule") {
 | 
						|
    typedef BeliefProp::MsgSchedule MsgSchedule;
 | 
						|
    if      (value == "seq_fixed")
 | 
						|
      BeliefProp::setMsgSchedule (MsgSchedule::seqFixedSch);
 | 
						|
    else if (value == "seq_random")
 | 
						|
      BeliefProp::setMsgSchedule (MsgSchedule::seqRandomSch);
 | 
						|
    else if (value == "parallel")
 | 
						|
      BeliefProp::setMsgSchedule (MsgSchedule::parallelSch);
 | 
						|
    else if (value == "max_residual")
 | 
						|
      BeliefProp::setMsgSchedule (MsgSchedule::maxResidualSch);
 | 
						|
    else
 | 
						|
      returnVal = invalidValue (option, value);
 | 
						|
 | 
						|
  } else if (option == "bp_accuracy") {
 | 
						|
    std::stringstream ss;
 | 
						|
    double acc;
 | 
						|
    ss << value;
 | 
						|
    ss >> acc;
 | 
						|
    BeliefProp::setAccuracy (acc);
 | 
						|
 | 
						|
  } else if (option == "bp_max_iter") {
 | 
						|
    std::stringstream ss;
 | 
						|
    unsigned mi;
 | 
						|
    ss << value;
 | 
						|
    ss >> mi;
 | 
						|
    BeliefProp::setMaxIterations (mi);
 | 
						|
 | 
						|
  } else if (option == "export_libdai") {
 | 
						|
    if      (value == "true")  FactorGraph::enableExportToLibDai();
 | 
						|
    else if (value == "false") FactorGraph::disableExportToLibDai();
 | 
						|
    else                       returnVal = invalidValue (option, value);
 | 
						|
 | 
						|
  } else if (option == "export_uai") {
 | 
						|
    if      (value == "true")  FactorGraph::enableExportToUai();
 | 
						|
    else if (value == "false") FactorGraph::disableExportToUai();
 | 
						|
    else                       returnVal = invalidValue (option, value);
 | 
						|
 | 
						|
  } else if (option == "export_graphviz") {
 | 
						|
    if      (value == "true")  FactorGraph::enableExportToGraphViz();
 | 
						|
    else if (value == "false") FactorGraph::disableExportToGraphViz();
 | 
						|
    else                       returnVal = invalidValue (option, value);
 | 
						|
 | 
						|
  } else if (option == "print_fg") {
 | 
						|
    if      (value == "true")  FactorGraph::enablePrintFactorGraph();
 | 
						|
    else if (value == "false") FactorGraph::disablePrintFactorGraph();
 | 
						|
    else                       returnVal = invalidValue (option, value);
 | 
						|
 | 
						|
  } else {
 | 
						|
    std::cerr << "Warning: invalid option `" << option << "'" << std::endl;
 | 
						|
    returnVal = false;
 | 
						|
  }
 | 
						|
  return returnVal;
 | 
						|
}
 | 
						|
 | 
						|
 | 
						|
 | 
						|
void
 | 
						|
printHeader (std::string header, std::ostream& os)
 | 
						|
{
 | 
						|
  printAsteriskLine (os);
 | 
						|
  os << header << std::endl;
 | 
						|
  printAsteriskLine (os);
 | 
						|
}
 | 
						|
 | 
						|
 | 
						|
 | 
						|
void
 | 
						|
printSubHeader (std::string header, std::ostream& os)
 | 
						|
{
 | 
						|
  printDashedLine (os);
 | 
						|
  os << header << std::endl;
 | 
						|
  printDashedLine (os);
 | 
						|
}
 | 
						|
 | 
						|
 | 
						|
 | 
						|
void
printAsteriskLine (std::ostream& os)
{
  // A 64-column rule of '*', written as two 32-character halves, then endl.
  const std::string half (32, '*');
  os << half << half << std::endl;
}
 | 
						|
 | 
						|
 | 
						|
 | 
						|
void
printDashedLine (std::ostream& os)
{
  // A 64-column rule of '-', written as two 32-character halves, then endl.
  const std::string half (32, '-');
  os << half << half << std::endl;
}
 | 
						|
 | 
						|
}  // namespace Util
 | 
						|
 | 
						|
 | 
						|
 | 
						|
namespace LogAware {
 | 
						|
 | 
						|
void
 | 
						|
normalize (Params& v)
 | 
						|
{
 | 
						|
  if (Globals::logDomain) {
 | 
						|
    double sum = std::accumulate (v.begin(), v.end(),
 | 
						|
        LogAware::addIdenty(), Util::logSum);
 | 
						|
    assert (sum != -std::numeric_limits<double>::infinity());
 | 
						|
    v -= sum;
 | 
						|
  } else {
 | 
						|
    double sum = std::accumulate (v.begin(), v.end(), 0.0);
 | 
						|
    assert (sum != 0.0);
 | 
						|
    v /= sum;
 | 
						|
  }
 | 
						|
}
 | 
						|
 | 
						|
 | 
						|
 | 
						|
double
 | 
						|
getL1Distance (const Params& v1, const Params& v2)
 | 
						|
{
 | 
						|
  assert (v1.size() == v2.size());
 | 
						|
  double dist = 0.0;
 | 
						|
  if (Globals::logDomain) {
 | 
						|
    dist = std::inner_product (v1.begin(), v1.end(), v2.begin(), 0.0,
 | 
						|
        std::plus<double>(), FuncObj::abs_diff_exp<double>());
 | 
						|
  } else {
 | 
						|
    dist = std::inner_product (v1.begin(), v1.end(), v2.begin(), 0.0,
 | 
						|
        std::plus<double>(), FuncObj::abs_diff<double>());
 | 
						|
  }
 | 
						|
  return dist;
 | 
						|
}
 | 
						|
 | 
						|
 | 
						|
 | 
						|
double
 | 
						|
getMaxNorm (const Params& v1, const Params& v2)
 | 
						|
{
 | 
						|
  assert (v1.size() == v2.size());
 | 
						|
  double max = 0.0;
 | 
						|
  if (Globals::logDomain) {
 | 
						|
    max = std::inner_product (v1.begin(), v1.end(), v2.begin(), 0.0,
 | 
						|
        FuncObj::max<double>(), FuncObj::abs_diff_exp<double>());
 | 
						|
  } else {
 | 
						|
    max = std::inner_product (v1.begin(), v1.end(), v2.begin(), 0.0,
 | 
						|
        FuncObj::max<double>(), FuncObj::abs_diff<double>());
 | 
						|
  }
 | 
						|
  return max;
 | 
						|
}
 | 
						|
 | 
						|
 | 
						|
 | 
						|
double
 | 
						|
pow (double base, unsigned iexp)
 | 
						|
{
 | 
						|
  return Globals::logDomain
 | 
						|
      ? base * iexp
 | 
						|
      : std::pow (base, iexp);
 | 
						|
}
 | 
						|
 | 
						|
 | 
						|
 | 
						|
double
 | 
						|
pow (double base, double exp)
 | 
						|
{
 | 
						|
  // `expoent' should not be in log domain
 | 
						|
  return Globals::logDomain
 | 
						|
      ? base * exp
 | 
						|
      : std::pow (base, exp);
 | 
						|
}
 | 
						|
 | 
						|
 | 
						|
 | 
						|
void
 | 
						|
pow (Params& v, unsigned iexp)
 | 
						|
{
 | 
						|
  if (iexp == 1) {
 | 
						|
    return;
 | 
						|
  }
 | 
						|
  Globals::logDomain ? v *= iexp : v ^= (int)iexp;
 | 
						|
}
 | 
						|
 | 
						|
 | 
						|
 | 
						|
void
 | 
						|
pow (Params& v, double exp)
 | 
						|
{
 | 
						|
  // `expoent' should not be in log domain
 | 
						|
  Globals::logDomain ? v *= exp : v ^= exp;
 | 
						|
}
 | 
						|
 | 
						|
}  // namespace LogAware
 | 
						|
 | 
						|
}  // namespace Horus
 | 
						|
 | 
						|
 |