some refactorings

This commit is contained in:
Tiago Gomes 2012-05-24 16:14:13 +01:00
parent 444eaacc63
commit acc5ab056a
12 changed files with 152 additions and 194 deletions

View File

@ -90,7 +90,7 @@ BpSolver::getPosterioriOf (VarId vid)
probs += links[i]->getMessage(); probs += links[i]->getMessage();
} }
LogAware::normalize (probs); LogAware::normalize (probs);
Util::fromLog (probs); Util::exp (probs);
} else { } else {
for (unsigned i = 0; i < links.size(); i++) { for (unsigned i = 0; i < links.size(); i++) {
probs *= links[i]->getMessage(); probs *= links[i]->getMessage();
@ -134,7 +134,7 @@ BpSolver::getJointDistributionOf (const VarIds& jointVarIds)
res.normalize(); res.normalize();
Params jointDist = res.params(); Params jointDist = res.params();
if (Globals::logDomain) { if (Globals::logDomain) {
Util::fromLog (jointDist); Util::exp (jointDist);
} }
return jointDist; return jointDist;
} }

View File

@ -20,8 +20,8 @@ class SpLink
{ {
fac_ = fn; fac_ = fn;
var_ = vn; var_ = vn;
v1_.resize (vn->range(), LogAware::tl (1.0 / vn->range())); v1_.resize (vn->range(), LogAware::log (1.0 / vn->range()));
v2_.resize (vn->range(), LogAware::tl (1.0 / vn->range())); v2_.resize (vn->range(), LogAware::log (1.0 / vn->range()));
currMsg_ = &v1_; currMsg_ = &v1_;
nextMsg_ = &v2_; nextMsg_ = &v2_;
msgSended_ = false; msgSended_ = false;

View File

@ -85,7 +85,7 @@ CbpSolver::getPosterioriOf (VarId vid)
probs += l->poweredMessage(); probs += l->poweredMessage();
} }
LogAware::normalize (probs); LogAware::normalize (probs);
Util::fromLog (probs); Util::exp (probs);
} else { } else {
for (unsigned i = 0; i < links.size(); i++) { for (unsigned i = 0; i < links.size(); i++) {
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]); CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);

View File

@ -27,7 +27,7 @@ Factor::Factor (
ranges_ = ranges; ranges_ = ranges;
params_ = params; params_ = params;
distId_ = distId; distId_ = distId;
assert (params_.size() == Util::expectedSize (ranges_)); assert (params_.size() == Util::sizeExpected (ranges_));
} }
@ -43,7 +43,7 @@ Factor::Factor (
} }
params_ = params; params_ = params;
distId_ = distId; distId_ = distId;
assert (params_.size() == Util::expectedSize (ranges_)); assert (params_.size() == Util::sizeExpected (ranges_));
} }

View File

@ -38,7 +38,7 @@ class TFactor
void setParams (const Params& newParams) void setParams (const Params& newParams)
{ {
params_ = newParams; params_ = newParams;
assert (params_.size() == Util::expectedSize (ranges_)); assert (params_.size() == Util::sizeExpected (ranges_));
} }
int indexOf (const T& t) const int indexOf (const T& t) const

View File

@ -86,9 +86,9 @@ FactorGraph::readFromUaiFormat (const char* fileName)
for (unsigned i = 0; i < nrFactors; i++) { for (unsigned i = 0; i < nrFactors; i++) {
ignoreLines (is); ignoreLines (is);
is >> nrParams; is >> nrParams;
if (nrParams != Util::expectedSize (factorRanges[i])) { if (nrParams != Util::sizeExpected (factorRanges[i])) {
cerr << "error: invalid number of parameters for factor nº " << i ; cerr << "error: invalid number of parameters for factor nº " << i ;
cerr << ", expected: " << Util::expectedSize (factorRanges[i]); cerr << ", expected: " << Util::sizeExpected (factorRanges[i]);
cerr << ", given: " << nrParams << endl; cerr << ", given: " << nrParams << endl;
abort(); abort();
} }
@ -97,7 +97,7 @@ FactorGraph::readFromUaiFormat (const char* fileName)
is >> params[j]; is >> params[j];
} }
if (Globals::logDomain) { if (Globals::logDomain) {
Util::toLog (params); Util::log (params);
} }
addFactor (Factor (factorVarIds[i], factorRanges[i], params)); addFactor (Factor (factorVarIds[i], factorRanges[i], params));
} }
@ -144,7 +144,7 @@ FactorGraph::readFromLibDaiFormat (const char* fileName)
ignoreLines (is); ignoreLines (is);
unsigned nNonzeros; unsigned nNonzeros;
is >> nNonzeros; is >> nNonzeros;
Params params (Util::expectedSize (ranges), 0); Params params (Util::sizeExpected (ranges), 0);
for (unsigned j = 0; j < nNonzeros; j++) { for (unsigned j = 0; j < nNonzeros; j++) {
ignoreLines (is); ignoreLines (is);
unsigned index; unsigned index;
@ -155,7 +155,7 @@ FactorGraph::readFromLibDaiFormat (const char* fileName)
params[index] = val; params[index] = val;
} }
if (Globals::logDomain) { if (Globals::logDomain) {
Util::toLog (params); Util::log (params);
} }
reverse (vids.begin(), vids.end()); reverse (vids.begin(), vids.end());
Factor f (vids, ranges, params); Factor f (vids, ranges, params);
@ -338,7 +338,7 @@ FactorGraph::exportToUaiFormat (const char* fileName) const
for (unsigned i = 0; i < facNodes_.size(); i++) { for (unsigned i = 0; i < facNodes_.size(); i++) {
Params params = facNodes_[i]->factor().params(); Params params = facNodes_[i]->factor().params();
if (Globals::logDomain) { if (Globals::logDomain) {
Util::fromLog (params); Util::exp (params);
} }
out << endl << params.size() << endl << " " ; out << endl << params.size() << endl << " " ;
for (unsigned j = 0; j < params.size(); j++) { for (unsigned j = 0; j < params.size(); j++) {
@ -373,7 +373,7 @@ FactorGraph::exportToLibDaiFormat (const char* fileName) const
out << endl; out << endl;
Params params = facNodes_[i]->factor().params(); Params params = facNodes_[i]->factor().params();
if (Globals::logDomain) { if (Globals::logDomain) {
Util::fromLog (params); Util::exp (params);
} }
out << params.size() << endl; out << params.size() << endl;
for (unsigned j = 0; j < params.size(); j++) { for (unsigned j = 0; j < params.size(); j++) {

View File

@ -644,7 +644,7 @@ FoveSolver::getJointDistributionOf (const Grounds& query)
(*pfList_.begin())->normalize(); (*pfList_.begin())->normalize();
Params params = (*pfList_.begin())->params(); Params params = (*pfList_.begin())->params();
if (Globals::logDomain) { if (Globals::logDomain) {
Util::fromLog (params); Util::exp (params);
} }
return params; return params;
} }

View File

@ -268,7 +268,7 @@ readParameters (YAP_Term paramL)
paramL = YAP_TailOfTerm (paramL); paramL = YAP_TailOfTerm (paramL);
} }
if (Globals::logDomain) { if (Globals::logDomain) {
Util::toLog (params); Util::log (params);
} }
return params; return params;
} }

View File

@ -27,7 +27,7 @@ Parfactor::Parfactor (
} }
} }
constr_ = new ConstraintTree (logVars, tuples); constr_ = new ConstraintTree (logVars, tuples);
assert (params_.size() == Util::expectedSize (ranges_)); assert (params_.size() == Util::sizeExpected (ranges_));
} }
@ -39,7 +39,7 @@ Parfactor::Parfactor (const Parfactor* g, const Tuple& tuple)
ranges_ = g->ranges(); ranges_ = g->ranges();
distId_ = g->distId(); distId_ = g->distId();
constr_ = new ConstraintTree (g->logVars(), {tuple}); constr_ = new ConstraintTree (g->logVars(), {tuple});
assert (params_.size() == Util::expectedSize (ranges_)); assert (params_.size() == Util::sizeExpected (ranges_));
} }
@ -51,7 +51,7 @@ Parfactor::Parfactor (const Parfactor* g, ConstraintTree* constr)
ranges_ = g->ranges(); ranges_ = g->ranges();
distId_ = g->distId(); distId_ = g->distId();
constr_ = constr; constr_ = constr;
assert (params_.size() == Util::expectedSize (ranges_)); assert (params_.size() == Util::sizeExpected (ranges_));
} }
@ -63,7 +63,7 @@ Parfactor::Parfactor (const Parfactor& g)
ranges_ = g.ranges(); ranges_ = g.ranges();
distId_ = g.distId(); distId_ = g.distId();
constr_ = new ConstraintTree (*g.constr()); constr_ = new ConstraintTree (*g.constr());
assert (params_.size() == Util::expectedSize (ranges_)); assert (params_.size() == Util::sizeExpected (ranges_));
} }

View File

@ -76,21 +76,6 @@ stringToDouble (string str)
void
toLog (Params& v)
{
transform (v.begin(), v.end(), v.begin(), ::log);
}
void
fromLog (Params& v)
{
transform (v.begin(), v.end(), v.begin(), ::exp);
}
double double
factorial (unsigned num) factorial (unsigned num)
@ -146,7 +131,7 @@ nrCombinations (unsigned n, unsigned k)
unsigned unsigned
expectedSize (const Ranges& ranges) sizeExpected (const Ranges& ranges)
{ {
return std::accumulate ( return std::accumulate (
ranges.begin(), ranges.end(), 1, multiplies<unsigned>()); ranges.begin(), ranges.end(), 1, multiplies<unsigned>());
@ -155,7 +140,7 @@ expectedSize (const Ranges& ranges)
unsigned unsigned
getNumberOfDigits (int num) nrDigits (int num)
{ {
unsigned count = 1; unsigned count = 1;
while (num >= 10) { while (num >= 10) {

View File

@ -19,6 +19,106 @@
using namespace std; using namespace std;
template <typename T>
void operator+=(std::vector<T>& v, double val)
{
std::transform (v.begin(), v.end(), v.begin(),
std::bind1st (plus<double>(), val));
}
template <typename T>
void operator-=(std::vector<T>& v, double val)
{
std::transform (v.begin(), v.end(), v.begin(),
std::bind1st (minus<double>(), val));
}
template <typename T>
void operator*=(std::vector<T>& v, double val)
{
std::transform (v.begin(), v.end(), v.begin(),
std::bind1st (multiplies<double>(), val));
}
template <typename T>
void operator/=(std::vector<T>& v, double val)
{
std::transform (v.begin(), v.end(), v.begin(),
std::bind1st (divides<double>(), val));
}
template <typename T>
void operator+=(std::vector<T>& a, const std::vector<T>& b)
{
assert (a.size() == b.size());
std::transform (a.begin(), a.end(), b.begin(), a.begin(),
plus<double>());
}
template <typename T>
void operator-=(std::vector<T>& a, const std::vector<T>& b)
{
assert (a.size() == b.size());
std::transform (a.begin(), a.end(), b.begin(), a.begin(),
minus<double>());
}
template <typename T>
void operator*=(std::vector<T>& a, const std::vector<T>& b)
{
assert (a.size() == b.size());
std::transform (a.begin(), a.end(), b.begin(), a.begin(),
multiplies<double>());
}
template <typename T>
void operator/=(std::vector<T>& a, const std::vector<T>& b)
{
assert (a.size() == b.size());
std::transform (a.begin(), a.end(), b.begin(), a.begin(),
divides<double>());
}
template <typename T>
void operator^=(std::vector<T>& v, double exp)
{
std::transform (v.begin(), v.end(), v.begin(),
std::bind2nd (ptr_fun<double, double, double> (std::pow), exp));
}
template <typename T>
void operator^=(std::vector<T>& v, int iexp)
{
std::transform (v.begin(), v.end(), v.begin(),
std::bind2nd (ptr_fun<double, int, double> (std::pow), iexp));
}
namespace {
const double NEG_INF = -numeric_limits<double>::infinity();
};
namespace Util { namespace Util {
template <typename T> void addToVector (vector<T>&, const vector<T>&); template <typename T> void addToVector (vector<T>&, const vector<T>&);
@ -40,6 +140,10 @@ template <typename T> std::string toString (const T&);
template <> std::string toString (const bool&); template <> std::string toString (const bool&);
template <typename T> void log (vector<T>&);
template <typename T> void exp (vector<T>&);
double logSum (double, double); double logSum (double, double);
void add (Params&, const Params&, unsigned); void add (Params&, const Params&, unsigned);
@ -52,19 +156,15 @@ unsigned stringToUnsigned (string);
double stringToDouble (string); double stringToDouble (string);
void toLog (Params&);
void fromLog (Params&);
double factorial (unsigned); double factorial (unsigned);
double logFactorial (unsigned); double logFactorial (unsigned);
unsigned nrCombinations (unsigned, unsigned); unsigned nrCombinations (unsigned, unsigned);
unsigned expectedSize (const Ranges&); unsigned sizeExpected (const Ranges&);
unsigned getNumberOfDigits (int); unsigned nrDigits (int);
bool isInteger (const string&); bool isInteger (const string&);
@ -168,9 +268,20 @@ std::ostream& operator << (std::ostream& os, const vector<T>& v)
} }
namespace { template <typename T> void
const double NEG_INF = -numeric_limits<double>::infinity(); Util::log (vector<T>& v)
}; {
transform (v.begin(), v.end(), v.begin(), ::log);
}
template <typename T> void
Util::exp (vector<T>& v)
{
transform (v.begin(), v.end(), v.begin(), ::exp);
}
inline double inline double
@ -244,154 +355,16 @@ Util::maxUnsigned (void)
template <typename T>
void operator+=(std::vector<T>& v, double val)
{
std::transform (v.begin(), v.end(), v.begin(),
std::bind1st (plus<double>(), val));
}
template <typename T>
void operator-=(std::vector<T>& v, double val)
{
std::transform (v.begin(), v.end(), v.begin(),
std::bind1st (minus<double>(), val));
}
template <typename T>
void operator*=(std::vector<T>& v, double val)
{
std::transform (v.begin(), v.end(), v.begin(),
std::bind1st (multiplies<double>(), val));
}
template <typename T>
void operator/=(std::vector<T>& v, double val)
{
std::transform (v.begin(), v.end(), v.begin(),
std::bind1st (divides<double>(), val));
}
template <typename T>
void operator+=(std::vector<T>& a, const std::vector<T>& b)
{
assert (a.size() == b.size());
std::transform (a.begin(), a.end(), b.begin(), a.begin(),
plus<double>());
}
template <typename T>
void operator-=(std::vector<T>& a, const std::vector<T>& b)
{
assert (a.size() == b.size());
std::transform (a.begin(), a.end(), b.begin(), a.begin(),
minus<double>());
}
template <typename T>
void operator*=(std::vector<T>& a, const std::vector<T>& b)
{
assert (a.size() == b.size());
std::transform (a.begin(), a.end(), b.begin(), a.begin(),
multiplies<double>());
}
template <typename T>
void operator/=(std::vector<T>& a, const std::vector<T>& b)
{
assert (a.size() == b.size());
std::transform (a.begin(), a.end(), b.begin(), a.begin(),
divides<double>());
}
template <typename T>
void operator^=(std::vector<T>& v, double exp)
{
std::transform (v.begin(), v.end(), v.begin(),
std::bind2nd (ptr_fun<double, double, double> (std::pow), exp));
}
template <typename T>
void operator^=(std::vector<T>& v, int iexp)
{
std::transform (v.begin(), v.end(), v.begin(),
std::bind2nd (ptr_fun<double, int, double> (std::pow), iexp));
}
namespace LogAware { namespace LogAware {
inline double inline double one() { return Globals::logDomain ? 0.0 : 1.0; }
one() inline double zero() { return Globals::logDomain ? NEG_INF : 0.0; }
{ inline double addIdenty() { return Globals::logDomain ? NEG_INF : 0.0; }
return Globals::logDomain ? 0.0 : 1.0; inline double multIdenty() { return Globals::logDomain ? 0.0 : 1.0; }
} inline double withEvidence() { return Globals::logDomain ? 0.0 : 1.0; }
inline double noEvidence() { return Globals::logDomain ? NEG_INF : 0.0; }
inline double log (double v) { return Globals::logDomain ? ::log (v) : v; }
inline double inline double exp (double v) { return Globals::logDomain ? ::exp (v) : v; }
zero() {
return Globals::logDomain ? NEG_INF : 0.0 ;
}
inline double
addIdenty()
{
return Globals::logDomain ? NEG_INF : 0.0;
}
inline double
multIdenty()
{
return Globals::logDomain ? 0.0 : 1.0;
}
inline double
withEvidence()
{
return Globals::logDomain ? 0.0 : 1.0;
}
inline double
noEvidence() {
return Globals::logDomain ? NEG_INF : 0.0;
}
inline double
tl (double v)
{
return Globals::logDomain ? log (v) : v;
}
inline double
fl (double v)
{
return Globals::logDomain ? exp (v) : v;
}
void normalize (Params&); void normalize (Params&);

View File

@ -33,7 +33,7 @@ VarElimSolver::solveQuery (VarIds queryVids)
processFactorList (queryVids); processFactorList (queryVids);
Params params = factorList_.back()->params(); Params params = factorList_.back()->params();
if (Globals::logDomain) { if (Globals::logDomain) {
Util::fromLog (params); Util::exp (params);
} }
return params; return params;
} }