some refactorings

This commit is contained in:
Tiago Gomes 2012-05-24 16:14:13 +01:00
parent 444eaacc63
commit acc5ab056a
12 changed files with 152 additions and 194 deletions

View File

@ -90,7 +90,7 @@ BpSolver::getPosterioriOf (VarId vid)
probs += links[i]->getMessage();
}
LogAware::normalize (probs);
Util::fromLog (probs);
Util::exp (probs);
} else {
for (unsigned i = 0; i < links.size(); i++) {
probs *= links[i]->getMessage();
@ -134,7 +134,7 @@ BpSolver::getJointDistributionOf (const VarIds& jointVarIds)
res.normalize();
Params jointDist = res.params();
if (Globals::logDomain) {
Util::fromLog (jointDist);
Util::exp (jointDist);
}
return jointDist;
}

View File

@ -20,8 +20,8 @@ class SpLink
{
fac_ = fn;
var_ = vn;
v1_.resize (vn->range(), LogAware::tl (1.0 / vn->range()));
v2_.resize (vn->range(), LogAware::tl (1.0 / vn->range()));
v1_.resize (vn->range(), LogAware::log (1.0 / vn->range()));
v2_.resize (vn->range(), LogAware::log (1.0 / vn->range()));
currMsg_ = &v1_;
nextMsg_ = &v2_;
msgSended_ = false;

View File

@ -85,7 +85,7 @@ CbpSolver::getPosterioriOf (VarId vid)
probs += l->poweredMessage();
}
LogAware::normalize (probs);
Util::fromLog (probs);
Util::exp (probs);
} else {
for (unsigned i = 0; i < links.size(); i++) {
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);

View File

@ -27,7 +27,7 @@ Factor::Factor (
ranges_ = ranges;
params_ = params;
distId_ = distId;
assert (params_.size() == Util::expectedSize (ranges_));
assert (params_.size() == Util::sizeExpected (ranges_));
}
@ -43,7 +43,7 @@ Factor::Factor (
}
params_ = params;
distId_ = distId;
assert (params_.size() == Util::expectedSize (ranges_));
assert (params_.size() == Util::sizeExpected (ranges_));
}

View File

@ -38,7 +38,7 @@ class TFactor
void setParams (const Params& newParams)
{
params_ = newParams;
assert (params_.size() == Util::expectedSize (ranges_));
assert (params_.size() == Util::sizeExpected (ranges_));
}
int indexOf (const T& t) const

View File

@ -86,9 +86,9 @@ FactorGraph::readFromUaiFormat (const char* fileName)
for (unsigned i = 0; i < nrFactors; i++) {
ignoreLines (is);
is >> nrParams;
if (nrParams != Util::expectedSize (factorRanges[i])) {
if (nrParams != Util::sizeExpected (factorRanges[i])) {
cerr << "error: invalid number of parameters for factor nº " << i ;
cerr << ", expected: " << Util::expectedSize (factorRanges[i]);
cerr << ", expected: " << Util::sizeExpected (factorRanges[i]);
cerr << ", given: " << nrParams << endl;
abort();
}
@ -97,7 +97,7 @@ FactorGraph::readFromUaiFormat (const char* fileName)
is >> params[j];
}
if (Globals::logDomain) {
Util::toLog (params);
Util::log (params);
}
addFactor (Factor (factorVarIds[i], factorRanges[i], params));
}
@ -144,7 +144,7 @@ FactorGraph::readFromLibDaiFormat (const char* fileName)
ignoreLines (is);
unsigned nNonzeros;
is >> nNonzeros;
Params params (Util::expectedSize (ranges), 0);
Params params (Util::sizeExpected (ranges), 0);
for (unsigned j = 0; j < nNonzeros; j++) {
ignoreLines (is);
unsigned index;
@ -155,7 +155,7 @@ FactorGraph::readFromLibDaiFormat (const char* fileName)
params[index] = val;
}
if (Globals::logDomain) {
Util::toLog (params);
Util::log (params);
}
reverse (vids.begin(), vids.end());
Factor f (vids, ranges, params);
@ -338,7 +338,7 @@ FactorGraph::exportToUaiFormat (const char* fileName) const
for (unsigned i = 0; i < facNodes_.size(); i++) {
Params params = facNodes_[i]->factor().params();
if (Globals::logDomain) {
Util::fromLog (params);
Util::exp (params);
}
out << endl << params.size() << endl << " " ;
for (unsigned j = 0; j < params.size(); j++) {
@ -373,7 +373,7 @@ FactorGraph::exportToLibDaiFormat (const char* fileName) const
out << endl;
Params params = facNodes_[i]->factor().params();
if (Globals::logDomain) {
Util::fromLog (params);
Util::exp (params);
}
out << params.size() << endl;
for (unsigned j = 0; j < params.size(); j++) {

View File

@ -644,7 +644,7 @@ FoveSolver::getJointDistributionOf (const Grounds& query)
(*pfList_.begin())->normalize();
Params params = (*pfList_.begin())->params();
if (Globals::logDomain) {
Util::fromLog (params);
Util::exp (params);
}
return params;
}

View File

@ -268,7 +268,7 @@ readParameters (YAP_Term paramL)
paramL = YAP_TailOfTerm (paramL);
}
if (Globals::logDomain) {
Util::toLog (params);
Util::log (params);
}
return params;
}

View File

@ -27,7 +27,7 @@ Parfactor::Parfactor (
}
}
constr_ = new ConstraintTree (logVars, tuples);
assert (params_.size() == Util::expectedSize (ranges_));
assert (params_.size() == Util::sizeExpected (ranges_));
}
@ -39,7 +39,7 @@ Parfactor::Parfactor (const Parfactor* g, const Tuple& tuple)
ranges_ = g->ranges();
distId_ = g->distId();
constr_ = new ConstraintTree (g->logVars(), {tuple});
assert (params_.size() == Util::expectedSize (ranges_));
assert (params_.size() == Util::sizeExpected (ranges_));
}
@ -51,7 +51,7 @@ Parfactor::Parfactor (const Parfactor* g, ConstraintTree* constr)
ranges_ = g->ranges();
distId_ = g->distId();
constr_ = constr;
assert (params_.size() == Util::expectedSize (ranges_));
assert (params_.size() == Util::sizeExpected (ranges_));
}
@ -63,7 +63,7 @@ Parfactor::Parfactor (const Parfactor& g)
ranges_ = g.ranges();
distId_ = g.distId();
constr_ = new ConstraintTree (*g.constr());
assert (params_.size() == Util::expectedSize (ranges_));
assert (params_.size() == Util::sizeExpected (ranges_));
}

View File

@ -76,21 +76,6 @@ stringToDouble (string str)
void
toLog (Params& v)
{
transform (v.begin(), v.end(), v.begin(), ::log);
}
void
fromLog (Params& v)
{
transform (v.begin(), v.end(), v.begin(), ::exp);
}
double
factorial (unsigned num)
@ -146,7 +131,7 @@ nrCombinations (unsigned n, unsigned k)
unsigned
expectedSize (const Ranges& ranges)
sizeExpected (const Ranges& ranges)
{
return std::accumulate (
ranges.begin(), ranges.end(), 1, multiplies<unsigned>());
@ -155,7 +140,7 @@ expectedSize (const Ranges& ranges)
unsigned
getNumberOfDigits (int num)
nrDigits (int num)
{
unsigned count = 1;
while (num >= 10) {

View File

@ -19,6 +19,106 @@
using namespace std;
template <typename T>
void operator+=(std::vector<T>& v, double val)
{
std::transform (v.begin(), v.end(), v.begin(),
std::bind1st (plus<double>(), val));
}
template <typename T>
void operator-=(std::vector<T>& v, double val)
{
std::transform (v.begin(), v.end(), v.begin(),
std::bind1st (minus<double>(), val));
}
template <typename T>
void operator*=(std::vector<T>& v, double val)
{
std::transform (v.begin(), v.end(), v.begin(),
std::bind1st (multiplies<double>(), val));
}
template <typename T>
void operator/=(std::vector<T>& v, double val)
{
std::transform (v.begin(), v.end(), v.begin(),
std::bind1st (divides<double>(), val));
}
template <typename T>
void operator+=(std::vector<T>& a, const std::vector<T>& b)
{
assert (a.size() == b.size());
std::transform (a.begin(), a.end(), b.begin(), a.begin(),
plus<double>());
}
template <typename T>
void operator-=(std::vector<T>& a, const std::vector<T>& b)
{
assert (a.size() == b.size());
std::transform (a.begin(), a.end(), b.begin(), a.begin(),
minus<double>());
}
template <typename T>
void operator*=(std::vector<T>& a, const std::vector<T>& b)
{
assert (a.size() == b.size());
std::transform (a.begin(), a.end(), b.begin(), a.begin(),
multiplies<double>());
}
template <typename T>
void operator/=(std::vector<T>& a, const std::vector<T>& b)
{
assert (a.size() == b.size());
std::transform (a.begin(), a.end(), b.begin(), a.begin(),
divides<double>());
}
template <typename T>
void operator^=(std::vector<T>& v, double exp)
{
std::transform (v.begin(), v.end(), v.begin(),
std::bind2nd (ptr_fun<double, double, double> (std::pow), exp));
}
template <typename T>
void operator^=(std::vector<T>& v, int iexp)
{
std::transform (v.begin(), v.end(), v.begin(),
std::bind2nd (ptr_fun<double, int, double> (std::pow), iexp));
}
namespace {
const double NEG_INF = -numeric_limits<double>::infinity();
};
namespace Util {
template <typename T> void addToVector (vector<T>&, const vector<T>&);
@ -40,6 +140,10 @@ template <typename T> std::string toString (const T&);
template <> std::string toString (const bool&);
template <typename T> void log (vector<T>&);
template <typename T> void exp (vector<T>&);
double logSum (double, double);
void add (Params&, const Params&, unsigned);
@ -52,19 +156,15 @@ unsigned stringToUnsigned (string);
double stringToDouble (string);
void toLog (Params&);
void fromLog (Params&);
double factorial (unsigned);
double logFactorial (unsigned);
unsigned nrCombinations (unsigned, unsigned);
unsigned expectedSize (const Ranges&);
unsigned sizeExpected (const Ranges&);
unsigned getNumberOfDigits (int);
unsigned nrDigits (int);
bool isInteger (const string&);
@ -168,9 +268,20 @@ std::ostream& operator << (std::ostream& os, const vector<T>& v)
}
namespace {
const double NEG_INF = -numeric_limits<double>::infinity();
};
template <typename T> void
Util::log (vector<T>& v)
{
transform (v.begin(), v.end(), v.begin(), ::log);
}
template <typename T> void
Util::exp (vector<T>& v)
{
transform (v.begin(), v.end(), v.begin(), ::exp);
}
inline double
@ -244,154 +355,16 @@ Util::maxUnsigned (void)
template <typename T>
void operator+=(std::vector<T>& v, double val)
{
std::transform (v.begin(), v.end(), v.begin(),
std::bind1st (plus<double>(), val));
}
template <typename T>
void operator-=(std::vector<T>& v, double val)
{
std::transform (v.begin(), v.end(), v.begin(),
std::bind1st (minus<double>(), val));
}
template <typename T>
void operator*=(std::vector<T>& v, double val)
{
std::transform (v.begin(), v.end(), v.begin(),
std::bind1st (multiplies<double>(), val));
}
template <typename T>
void operator/=(std::vector<T>& v, double val)
{
std::transform (v.begin(), v.end(), v.begin(),
std::bind1st (divides<double>(), val));
}
template <typename T>
void operator+=(std::vector<T>& a, const std::vector<T>& b)
{
assert (a.size() == b.size());
std::transform (a.begin(), a.end(), b.begin(), a.begin(),
plus<double>());
}
template <typename T>
void operator-=(std::vector<T>& a, const std::vector<T>& b)
{
assert (a.size() == b.size());
std::transform (a.begin(), a.end(), b.begin(), a.begin(),
minus<double>());
}
template <typename T>
void operator*=(std::vector<T>& a, const std::vector<T>& b)
{
assert (a.size() == b.size());
std::transform (a.begin(), a.end(), b.begin(), a.begin(),
multiplies<double>());
}
template <typename T>
void operator/=(std::vector<T>& a, const std::vector<T>& b)
{
assert (a.size() == b.size());
std::transform (a.begin(), a.end(), b.begin(), a.begin(),
divides<double>());
}
template <typename T>
void operator^=(std::vector<T>& v, double exp)
{
std::transform (v.begin(), v.end(), v.begin(),
std::bind2nd (ptr_fun<double, double, double> (std::pow), exp));
}
template <typename T>
void operator^=(std::vector<T>& v, int iexp)
{
std::transform (v.begin(), v.end(), v.begin(),
std::bind2nd (ptr_fun<double, int, double> (std::pow), iexp));
}
namespace LogAware {
inline double
one()
{
return Globals::logDomain ? 0.0 : 1.0;
}
inline double
zero() {
return Globals::logDomain ? NEG_INF : 0.0 ;
}
inline double
addIdenty()
{
return Globals::logDomain ? NEG_INF : 0.0;
}
inline double
multIdenty()
{
return Globals::logDomain ? 0.0 : 1.0;
}
inline double
withEvidence()
{
return Globals::logDomain ? 0.0 : 1.0;
}
inline double
noEvidence() {
return Globals::logDomain ? NEG_INF : 0.0;
}
inline double
tl (double v)
{
return Globals::logDomain ? log (v) : v;
}
inline double
fl (double v)
{
return Globals::logDomain ? exp (v) : v;
}
inline double one() { return Globals::logDomain ? 0.0 : 1.0; }
inline double zero() { return Globals::logDomain ? NEG_INF : 0.0; }
inline double addIdenty() { return Globals::logDomain ? NEG_INF : 0.0; }
inline double multIdenty() { return Globals::logDomain ? 0.0 : 1.0; }
inline double withEvidence() { return Globals::logDomain ? 0.0 : 1.0; }
inline double noEvidence() { return Globals::logDomain ? NEG_INF : 0.0; }
inline double log (double v) { return Globals::logDomain ? ::log (v) : v; }
inline double exp (double v) { return Globals::logDomain ? ::exp (v) : v; }
void normalize (Params&);

View File

@ -33,7 +33,7 @@ VarElimSolver::solveQuery (VarIds queryVids)
processFactorList (queryVids);
Params params = factorList_.back()->params();
if (Globals::logDomain) {
Util::fromLog (params);
Util::exp (params);
}
return params;
}