This repository has been archived on 2023-08-20. You can view files and clone it, but cannot push or open issues or pull requests.
yap-6.3/packages/CLPBN/clpbn/bp/Util.h

397 lines
6.9 KiB
C
Raw Normal View History

2012-03-22 11:33:24 +00:00
#ifndef HORUS_UTIL_H
#define HORUS_UTIL_H
2012-03-31 23:27:37 +01:00
#include <cmath>
#include <cassert>
#include <limits>
2012-04-05 18:38:56 +01:00
#include <algorithm>
2012-03-22 11:33:24 +00:00
#include <vector>
2012-03-31 23:27:37 +01:00
#include <set>
#include <queue>
#include <unordered_map>
#include <sstream>
#include <iostream>
2012-03-22 11:33:24 +00:00
#include "Horus.h"
using namespace std;
2012-03-31 23:27:37 +01:00
2012-03-22 11:33:24 +00:00
namespace Util {
2012-03-31 23:27:37 +01:00
template <typename T> void addToVector (vector<T>&, const vector<T>&);
2012-04-05 18:38:56 +01:00
template <typename T> void addToSet (set<T>&, const vector<T>&);
template <typename T> void addToQueue (queue<T>&, const vector<T>&);
2012-03-31 23:27:37 +01:00
template <typename T> bool contains (const vector<T>&, const T&);
template <typename T> bool contains (const set<T>&, const T&);
template <typename K, typename V> bool contains (
const unordered_map<K, V>&, const K&);
template <typename T> std::string toString (const T&);
2012-04-13 15:21:11 +01:00
template <> std::string toString (const bool&);
unsigned stringToUnsigned (string);
double stringToDouble (string);
2012-03-31 23:27:37 +01:00
void toLog (Params&);
void fromLog (Params&);
double logSum (double, double);
void multiply (Params&, const Params&);
void multiply (Params&, const Params&, unsigned);
void add (Params&, const Params&);
void add (Params&, const Params&, unsigned);
unsigned maxUnsigned (void);
double factorial (unsigned);
2012-03-31 23:27:37 +01:00
double logFactorial (unsigned);
unsigned nrCombinations (unsigned, unsigned);
2012-03-31 23:27:37 +01:00
unsigned expectedSize (const Ranges&);
unsigned getNumberOfDigits (int);
bool isInteger (const string&);
string parametersToString (const Params&, unsigned = Constants::PRECISION);
2012-04-10 15:00:18 +01:00
vector<string> getStateLines (const Vars&);
2012-03-31 23:27:37 +01:00
bool setHorusFlag (string key, string value);
2012-03-31 23:27:37 +01:00
void printHeader (string, std::ostream& os = std::cout);
void printSubHeader (string, std::ostream& os = std::cout);
void printAsteriskLine (std::ostream& os = std::cout);
void printDashedLine (std::ostream& os = std::cout);
};
// Appends a copy of every item of `elements` to the back of `v`,
// keeping the order they have in `elements`.
template <typename T> void
Util::addToVector (std::vector<T>& v, const std::vector<T>& elements)
{
  if (elements.empty()) {
    return;  // nothing to append
  }
  const typename std::vector<T>::const_iterator first = elements.begin();
  const typename std::vector<T>::const_iterator last  = elements.end();
  v.insert (v.end(), first, last);
}
2012-04-05 18:38:56 +01:00
// Inserts every item of `elements` into the set `s`; items that are
// already present are left untouched by std::set::insert.
template <typename T> void
Util::addToSet (std::set<T>& s, const std::vector<T>& elements)
{
  for (const T& item : elements) {
    s.insert (item);
  }
}
2012-03-31 23:27:37 +01:00
// Pushes the items of `elements` onto the queue `q`, front to back.
template <typename T> void
Util::addToQueue (std::queue<T>& q, const std::vector<T>& elements)
{
  for (const T& item : elements) {
    q.push (item);
  }
}
// Linear search: true when some item of `v` compares equal to `e`.
template <typename T> bool
Util::contains (const std::vector<T>& v, const T& e)
{
  const typename std::vector<T>::const_iterator last = v.end();
  const typename std::vector<T>::const_iterator hit =
      std::find (v.begin(), last, e);
  return hit != last;
}
// True when `e` is a member of the set `s`.
template <typename T> bool
Util::contains (const std::set<T>& s, const T& e)
{
  return s.count (e) != 0;
}
template <typename K, typename V> bool
2012-04-05 18:38:56 +01:00
Util::contains (const unordered_map<K, V>& m, const K& k)
2012-03-31 23:27:37 +01:00
{
return m.find (k) != m.end();
}
// Renders `t` as text by streaming it through operator<<.
template <typename T> std::string
Util::toString (const T& t)
{
  std::ostringstream buffer;
  buffer << t;
  return buffer.str();
}
// Streams a vector as "[a, b, c]"; an empty vector is written as "[]".
template <typename T>
std::ostream& operator << (std::ostream& os, const std::vector<T>& v)
{
  os << "[" ;
  // empty before the first item, ", " before every later one
  const char* separator = "";
  for (const T& item : v) {
    os << separator << item;
    separator = ", ";
  }
  os << "]" ;
  return os;
}
2012-03-31 23:27:37 +01:00
namespace {

// Negative infinity for double.
const double NEG_INF = -std::numeric_limits<double>::infinity();

}
2012-03-22 11:33:24 +00:00
2012-03-31 23:27:37 +01:00
inline double
Util::logSum (double x, double y)
2012-03-22 11:33:24 +00:00
{
// std::log (std::exp (x) + std::exp (y)) can overflow!
assert (std::isnan (x) == false);
assert (std::isnan (y) == false);
if (x == NEG_INF) {
2012-03-31 23:27:37 +01:00
return y;
2012-03-22 11:33:24 +00:00
}
if (y == NEG_INF) {
2012-03-31 23:27:37 +01:00
return x;
2012-03-22 11:33:24 +00:00
}
// if one value is much smaller than the other,
// keep the larger value
const double tol = 460.517; // log (1e200)
if (x < y - tol) {
return y;
}
if (y < x - tol) {
return x;
}
assert (std::isnan (x - y) == false);
const double exp_diff = std::exp (x - y);
if (std::isfinite (exp_diff) == false) {
2012-03-31 23:27:37 +01:00
// difference is too large
return x > y ? x : y;
2012-03-22 11:33:24 +00:00
}
// otherwise return the sum
return y + std::log (static_cast<double>(1.0) + exp_diff);
2012-03-22 11:33:24 +00:00
}
// Element-wise product stored back into v1: v1[i] *= v2[i].
// Both arguments must have the same size (asserted).
inline void
Util::multiply (Params& v1, const Params& v2)
{
  assert (v1.size() == v2.size());
  unsigned idx = 0;
  while (idx < v1.size()) {
    v1[idx] *= v2[idx];
    ++ idx;
  }
}
// Multiplies v1 in place by v2. Each entry of v2 is applied to
// `repetitions` consecutive entries of v1, and the whole pattern starts
// over until the end of v1 is reached.
//   v1          - values updated in place
//   v2          - multipliers
//   repetitions - number of consecutive v1 entries sharing one v2 entry
// Nothing is written past the end of v1, even when its size is not a
// multiple of v2.size() * repetitions.
inline void
Util::multiply (Params& v1, const Params& v2, unsigned repetitions)
{
  if (v1.size() == 0) {
    return;
  }
  // with no multipliers or zero repetitions the cursor below would
  // never advance and the outer loop would spin forever
  assert (v2.size() > 0 && repetitions > 0);
  if (v2.size() == 0 || repetitions == 0) {
    return;
  }
  for (std::size_t count = 0; count < v1.size(); ) {
    for (std::size_t i = 0; i < v2.size() && count < v1.size(); i++) {
      for (unsigned r = 0; r < repetitions && count < v1.size(); r++) {
        v1[count] *= v2[i];
        count ++;
      }
    }
  }
}
// Element-wise sum stored back into v1: v1[i] += v2[i].
// Both arguments must have the same size (asserted).
inline void
Util::add (Params& v1, const Params& v2)
{
  assert (v1.size() == v2.size());
  unsigned idx = 0;
  while (idx < v1.size()) {
    v1[idx] += v2[idx];
    ++ idx;
  }
}
// Adds v2 to v1 in place. Each entry of v2 is applied to `repetitions`
// consecutive entries of v1, and the whole pattern starts over until
// the end of v1 is reached.
//   v1          - values updated in place
//   v2          - addends
//   repetitions - number of consecutive v1 entries sharing one v2 entry
// Nothing is written past the end of v1, even when its size is not a
// multiple of v2.size() * repetitions.
inline void
Util::add (Params& v1, const Params& v2, unsigned repetitions)
{
  if (v1.size() == 0) {
    return;
  }
  // with no addends or zero repetitions the cursor below would
  // never advance and the outer loop would spin forever
  assert (v2.size() > 0 && repetitions > 0);
  if (v2.size() == 0 || repetitions == 0) {
    return;
  }
  for (std::size_t count = 0; count < v1.size(); ) {
    for (std::size_t i = 0; i < v2.size() && count < v1.size(); i++) {
      for (unsigned r = 0; r < repetitions && count < v1.size(); r++) {
        v1[count] += v2[i];
        count ++;
      }
    }
  }
}
2012-03-31 23:27:37 +01:00
// Largest value representable by `unsigned`.
inline unsigned
Util::maxUnsigned ()
{
  const unsigned largest = std::numeric_limits<unsigned>::max();
  return largest;
}
2012-03-31 23:27:37 +01:00
namespace LogAware {
2012-03-22 11:33:24 +00:00
inline double
2012-03-31 23:27:37 +01:00
one()
2012-03-22 11:33:24 +00:00
{
2012-03-31 23:27:37 +01:00
return Globals::logDomain ? 0.0 : 1.0;
2012-03-22 11:33:24 +00:00
}
2012-03-31 23:27:37 +01:00
2012-03-22 11:33:24 +00:00
inline double
2012-03-31 23:27:37 +01:00
zero() {
return Globals::logDomain ? NEG_INF : 0.0 ;
2012-03-22 11:33:24 +00:00
}
2012-03-31 23:27:37 +01:00
2012-03-22 11:33:24 +00:00
inline double
2012-03-31 23:27:37 +01:00
addIdenty()
2012-03-22 11:33:24 +00:00
{
return Globals::logDomain ? NEG_INF : 0.0;
2012-03-22 11:33:24 +00:00
}
2012-03-31 23:27:37 +01:00
inline double
multIdenty()
{
return Globals::logDomain ? 0.0 : 1.0;
}
2012-03-22 11:33:24 +00:00
inline double
2012-03-31 23:27:37 +01:00
withEvidence()
2012-03-22 11:33:24 +00:00
{
return Globals::logDomain ? 0.0 : 1.0;
}
2012-03-31 23:27:37 +01:00
2012-03-22 11:33:24 +00:00
inline double
2012-03-31 23:27:37 +01:00
noEvidence() {
return Globals::logDomain ? NEG_INF : 0.0;
2012-03-22 11:33:24 +00:00
}
2012-03-31 23:27:37 +01:00
2012-03-22 11:33:24 +00:00
inline double
2012-03-31 23:27:37 +01:00
tl (double v)
2012-03-22 11:33:24 +00:00
{
2012-03-31 23:27:37 +01:00
return Globals::logDomain ? log (v) : v;
2012-03-22 11:33:24 +00:00
}
2012-03-31 23:27:37 +01:00
2012-03-22 11:33:24 +00:00
inline double
2012-03-31 23:27:37 +01:00
fl (double v)
{
return Globals::logDomain ? exp (v) : v;
2012-03-22 11:33:24 +00:00
}
2012-03-31 23:27:37 +01:00
void normalize (Params&);
double getL1Distance (const Params&, const Params&);
double getMaxNorm (const Params&, const Params&);
double pow (double, unsigned);
double pow (double, double);
void pow (Params&, unsigned);
void pow (Params&, double);
};
2012-04-13 15:21:11 +01:00
2012-03-22 11:33:24 +00:00
// Plain record holding four values describing one network; the
// constructor copies its arguments into the fields of the same name.
struct NetInfo
{
  NetInfo (unsigned size, bool loopy, unsigned nIters, double time)
      : size (size), loopy (loopy), nIters (nIters), time (time)
  {
  }

  unsigned  size;
  bool      loopy;
  unsigned  nIters;
  double    time;
};
// Plain record of five counters. The constructor arguments map, in
// order, to nrGroundVars, nrGroundFactors, nrClusterVars,
// nrClusterFactors and nrNeighborless.
struct CompressInfo
{
  CompressInfo (unsigned a, unsigned b, unsigned c, unsigned d, unsigned e)
      : nrGroundVars (a),
        nrGroundFactors (b),
        nrClusterVars (c),
        nrClusterFactors (d),
        nrNeighborless (e)
  {
  }

  unsigned  nrGroundVars;
  unsigned  nrGroundFactors;
  unsigned  nrClusterVars;
  unsigned  nrClusterFactors;
  unsigned  nrNeighborless;
};
// Static-only interface around three class-wide stores: a list of
// NetInfo records, a list of CompressInfo records and a counter.
// Every member is only declared here; the definitions are not part of
// this header.
class Statistics
{
  public:
    static unsigned getSolvedNetworksCounting ();

    static void incrementPrimaryNetworksCounting ();

    static unsigned getPrimaryNetworksCounting ();

    // argument types match the NetInfo constructor
    static void updateStatistics (unsigned, bool, unsigned, double);

    static void printStatistics ();

    static void writeStatistics (const char*);

    // argument types match the CompressInfo constructor
    static void updateCompressingStatistics (
        unsigned, unsigned, unsigned, unsigned, unsigned);

  private:
    static std::string getStatisticString ();

    static std::vector<NetInfo>       netInfo_;
    static std::vector<CompressInfo>  compressInfo_;
    static unsigned                   primaryNetCount_;
};
#endif // HORUS_UTIL_H