yay, my first commit

This commit is contained in:
Tiago Gomes
2012-03-31 23:27:37 +01:00
parent 88411f4b40
commit 313f9a036a
68 changed files with 81842 additions and 2916 deletions

View File

@@ -1,53 +1,131 @@
#ifndef HORUS_UTIL_H
#define HORUS_UTIL_H
#include <cmath>
#include <cassert>
#include <limits>
#include <vector>
#include <set>
#include <queue>
#include <unordered_map>
#include <sstream>
#include <iostream>
#include "Horus.h"
using namespace std;
namespace Util {
void toLog (Params&);
void fromLog (Params&);
void normalize (Params&);
void logSum (double&, double);
void multiply (Params&, const Params&);
void multiply (Params&, const Params&, unsigned);
void add (Params&, const Params&);
void add (Params&, const Params&, unsigned);
void pow (Params&, double);
void pow (Params&, unsigned);
double pow (double, unsigned);
double factorial (double);
unsigned nrCombinations (unsigned, unsigned);
double getL1Distance (const Params&, const Params&);
double getMaxNorm (const Params&, const Params&);
unsigned getNumberOfDigits (int);
bool isInteger (const string&);
string parametersToString (const Params&, unsigned = PRECISION);
vector<string> getJointStateStrings (const VarNodes&);
double tl (double);
double fl (double);
double multIdenty();
double addIdenty();
double withEvidence();
double noEvidence();
double one();
double zero();
template <typename T> void addToVector (vector<T>&, const vector<T>&);
template <typename T> void addToQueue (queue<T>&, const vector<T>&);
template <typename T> bool contains (const vector<T>&, const T&);
template <typename T> bool contains (const set<T>&, const T&);
template <typename K, typename V> bool contains (
const unordered_map<K, V>&, const K&);
template <typename T> std::string toString (const T&);
void toLog (Params&);
void fromLog (Params&);
double logSum (double, double);
void multiply (Params&, const Params&);
void multiply (Params&, const Params&, unsigned);
void add (Params&, const Params&);
void add (Params&, const Params&, unsigned);
double factorial (double);
unsigned nrCombinations (unsigned, unsigned);
unsigned expectedSize (const Ranges&);
unsigned getNumberOfDigits (int);
bool isInteger (const string&);
string parametersToString (const Params&, unsigned = Constants::PRECISION);
vector<string> getJointStateStrings (const VarNodes&);
void printHeader (string, std::ostream& os = std::cout);
void printSubHeader (string, std::ostream& os = std::cout);
void printAsteriskLine (std::ostream& os = std::cout);
void printDashedLine (std::ostream& os = std::cout);
unsigned maxUnsigned (void);
};
template <class T>
std::string toString (const T& t)
template <typename T> void
Util::addToVector (vector<T>& v, const vector<T>& elements)
{
v.insert (v.end(), elements.begin(), elements.end());
}
template <typename T> void
Util::addToQueue (queue<T>& q, const vector<T>& elements)
{
for (unsigned i = 0; i < elements.size(); i++) {
q.push (elements[i]);
}
}
template <typename T> bool
Util::contains (const vector<T>& v, const T& e)
{
return std::find (v.begin(), v.end(), e) != v.end();
}
template <typename T> bool
Util::contains (const set<T>& s, const T& e)
{
return s.find (e) != s.end();
}
template <typename K, typename V> bool
Util::contains (
const unordered_map<K, V>& m, const K& k)
{
return m.find (k) != m.end();
}
template <typename T> std::string
Util::toString (const T& t)
{
std::stringstream ss;
ss << t;
return ss.str();
}
};
template <typename T>
@@ -62,28 +140,31 @@ std::ostream& operator << (std::ostream& os, const vector<T>& v)
}
namespace {
const double INF = -numeric_limits<double>::infinity();
};
inline void
Util::logSum (double& x, double y)
inline double
Util::logSum (double x, double y)
{
x = log (exp (x) + exp (y)); return;
return log (exp (x) + exp (y));
assert (isfinite (x) && isfinite (y));
// If one value is much smaller than the other, keep the larger value.
if (x < (y - log (1e200))) {
x = y;
return;
return y;
}
if (y < (x - log (1e200))) {
return;
return x;
}
double diff = x - y;
assert (isfinite (diff) && isfinite (x) && isfinite (y));
if (!isfinite (exp (diff))) { // difference is too large
x = x > y ? x : y;
} else { // otherwise return the sum.
x = y + log (static_cast<double>(1.0) + exp (diff));
if (!isfinite (exp (diff))) {
// difference is too large
return x > y ? x : y;
}
// otherwise return the sum.
return y + log (static_cast<double>(1.0) + exp (diff));
}
@@ -140,52 +221,87 @@ Util::add (Params& v1, const Params& v2, unsigned repetitions)
inline double
Util::tl (double v)
inline unsigned
Util::maxUnsigned (void)
{
return Globals::logDomain ? log(v) : v;
return numeric_limits<unsigned>::max();
}
inline double
Util::fl (double v)
{
return Globals::logDomain ? exp(v) : v;
}
namespace LogAware {
inline double
Util::multIdenty() {
return Globals::logDomain ? 0.0 : 1.0;
}
inline double
Util::addIdenty()
{
return Globals::logDomain ? INF : 0.0;
}
inline double
Util::withEvidence()
one()
{
return Globals::logDomain ? 0.0 : 1.0;
}
inline double
Util::noEvidence() {
return Globals::logDomain ? INF : 0.0;
}
inline double
Util::one()
{
return Globals::logDomain ? 0.0 : 1.0;
}
inline double
Util::zero() {
zero() {
return Globals::logDomain ? INF : 0.0 ;
}
inline double
addIdenty()
{
return Globals::logDomain ? INF : 0.0;
}
inline double
multIdenty()
{
return Globals::logDomain ? 0.0 : 1.0;
}
inline double
withEvidence()
{
return Globals::logDomain ? 0.0 : 1.0;
}
inline double
noEvidence() {
return Globals::logDomain ? INF : 0.0;
}
inline double
tl (double v)
{
return Globals::logDomain ? log (v) : v;
}
inline double
fl (double v)
{
return Globals::logDomain ? exp (v) : v;
}
void normalize (Params&);
double getL1Distance (const Params&, const Params&);
double getMaxNorm (const Params&, const Params&);
double pow (double, unsigned);
double pow (double, double);
void pow (Params&, unsigned);
void pow (Params&, double);
};
struct NetInfo
{
NetInfo (unsigned size, bool loopy, unsigned nIters, double time)
@@ -224,11 +340,17 @@ class Statistics
{
public:
static unsigned getSolvedNetworksCounting (void);
static void incrementPrimaryNetworksCounting (void);
static unsigned getPrimaryNetworksCounting (void);
static void updateStatistics (unsigned, bool, unsigned, double);
static void printStatistics (void);
static void writeStatisticsToFile (const char*);
static void updateCompressingStatistics (
unsigned, unsigned, unsigned, unsigned, unsigned);