some refactorings
This commit is contained in:
parent
444eaacc63
commit
acc5ab056a
@ -90,7 +90,7 @@ BpSolver::getPosterioriOf (VarId vid)
|
||||
probs += links[i]->getMessage();
|
||||
}
|
||||
LogAware::normalize (probs);
|
||||
Util::fromLog (probs);
|
||||
Util::exp (probs);
|
||||
} else {
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
probs *= links[i]->getMessage();
|
||||
@ -134,7 +134,7 @@ BpSolver::getJointDistributionOf (const VarIds& jointVarIds)
|
||||
res.normalize();
|
||||
Params jointDist = res.params();
|
||||
if (Globals::logDomain) {
|
||||
Util::fromLog (jointDist);
|
||||
Util::exp (jointDist);
|
||||
}
|
||||
return jointDist;
|
||||
}
|
||||
|
@ -20,8 +20,8 @@ class SpLink
|
||||
{
|
||||
fac_ = fn;
|
||||
var_ = vn;
|
||||
v1_.resize (vn->range(), LogAware::tl (1.0 / vn->range()));
|
||||
v2_.resize (vn->range(), LogAware::tl (1.0 / vn->range()));
|
||||
v1_.resize (vn->range(), LogAware::log (1.0 / vn->range()));
|
||||
v2_.resize (vn->range(), LogAware::log (1.0 / vn->range()));
|
||||
currMsg_ = &v1_;
|
||||
nextMsg_ = &v2_;
|
||||
msgSended_ = false;
|
||||
|
@ -85,7 +85,7 @@ CbpSolver::getPosterioriOf (VarId vid)
|
||||
probs += l->poweredMessage();
|
||||
}
|
||||
LogAware::normalize (probs);
|
||||
Util::fromLog (probs);
|
||||
Util::exp (probs);
|
||||
} else {
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
|
||||
|
@ -27,7 +27,7 @@ Factor::Factor (
|
||||
ranges_ = ranges;
|
||||
params_ = params;
|
||||
distId_ = distId;
|
||||
assert (params_.size() == Util::expectedSize (ranges_));
|
||||
assert (params_.size() == Util::sizeExpected (ranges_));
|
||||
}
|
||||
|
||||
|
||||
@ -43,7 +43,7 @@ Factor::Factor (
|
||||
}
|
||||
params_ = params;
|
||||
distId_ = distId;
|
||||
assert (params_.size() == Util::expectedSize (ranges_));
|
||||
assert (params_.size() == Util::sizeExpected (ranges_));
|
||||
}
|
||||
|
||||
|
||||
|
@ -38,7 +38,7 @@ class TFactor
|
||||
void setParams (const Params& newParams)
|
||||
{
|
||||
params_ = newParams;
|
||||
assert (params_.size() == Util::expectedSize (ranges_));
|
||||
assert (params_.size() == Util::sizeExpected (ranges_));
|
||||
}
|
||||
|
||||
int indexOf (const T& t) const
|
||||
|
@ -86,9 +86,9 @@ FactorGraph::readFromUaiFormat (const char* fileName)
|
||||
for (unsigned i = 0; i < nrFactors; i++) {
|
||||
ignoreLines (is);
|
||||
is >> nrParams;
|
||||
if (nrParams != Util::expectedSize (factorRanges[i])) {
|
||||
if (nrParams != Util::sizeExpected (factorRanges[i])) {
|
||||
cerr << "error: invalid number of parameters for factor nº " << i ;
|
||||
cerr << ", expected: " << Util::expectedSize (factorRanges[i]);
|
||||
cerr << ", expected: " << Util::sizeExpected (factorRanges[i]);
|
||||
cerr << ", given: " << nrParams << endl;
|
||||
abort();
|
||||
}
|
||||
@ -97,7 +97,7 @@ FactorGraph::readFromUaiFormat (const char* fileName)
|
||||
is >> params[j];
|
||||
}
|
||||
if (Globals::logDomain) {
|
||||
Util::toLog (params);
|
||||
Util::log (params);
|
||||
}
|
||||
addFactor (Factor (factorVarIds[i], factorRanges[i], params));
|
||||
}
|
||||
@ -144,7 +144,7 @@ FactorGraph::readFromLibDaiFormat (const char* fileName)
|
||||
ignoreLines (is);
|
||||
unsigned nNonzeros;
|
||||
is >> nNonzeros;
|
||||
Params params (Util::expectedSize (ranges), 0);
|
||||
Params params (Util::sizeExpected (ranges), 0);
|
||||
for (unsigned j = 0; j < nNonzeros; j++) {
|
||||
ignoreLines (is);
|
||||
unsigned index;
|
||||
@ -155,7 +155,7 @@ FactorGraph::readFromLibDaiFormat (const char* fileName)
|
||||
params[index] = val;
|
||||
}
|
||||
if (Globals::logDomain) {
|
||||
Util::toLog (params);
|
||||
Util::log (params);
|
||||
}
|
||||
reverse (vids.begin(), vids.end());
|
||||
Factor f (vids, ranges, params);
|
||||
@ -338,7 +338,7 @@ FactorGraph::exportToUaiFormat (const char* fileName) const
|
||||
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
||||
Params params = facNodes_[i]->factor().params();
|
||||
if (Globals::logDomain) {
|
||||
Util::fromLog (params);
|
||||
Util::exp (params);
|
||||
}
|
||||
out << endl << params.size() << endl << " " ;
|
||||
for (unsigned j = 0; j < params.size(); j++) {
|
||||
@ -373,7 +373,7 @@ FactorGraph::exportToLibDaiFormat (const char* fileName) const
|
||||
out << endl;
|
||||
Params params = facNodes_[i]->factor().params();
|
||||
if (Globals::logDomain) {
|
||||
Util::fromLog (params);
|
||||
Util::exp (params);
|
||||
}
|
||||
out << params.size() << endl;
|
||||
for (unsigned j = 0; j < params.size(); j++) {
|
||||
|
@ -644,7 +644,7 @@ FoveSolver::getJointDistributionOf (const Grounds& query)
|
||||
(*pfList_.begin())->normalize();
|
||||
Params params = (*pfList_.begin())->params();
|
||||
if (Globals::logDomain) {
|
||||
Util::fromLog (params);
|
||||
Util::exp (params);
|
||||
}
|
||||
return params;
|
||||
}
|
||||
|
@ -268,7 +268,7 @@ readParameters (YAP_Term paramL)
|
||||
paramL = YAP_TailOfTerm (paramL);
|
||||
}
|
||||
if (Globals::logDomain) {
|
||||
Util::toLog (params);
|
||||
Util::log (params);
|
||||
}
|
||||
return params;
|
||||
}
|
||||
|
@ -27,7 +27,7 @@ Parfactor::Parfactor (
|
||||
}
|
||||
}
|
||||
constr_ = new ConstraintTree (logVars, tuples);
|
||||
assert (params_.size() == Util::expectedSize (ranges_));
|
||||
assert (params_.size() == Util::sizeExpected (ranges_));
|
||||
}
|
||||
|
||||
|
||||
@ -39,7 +39,7 @@ Parfactor::Parfactor (const Parfactor* g, const Tuple& tuple)
|
||||
ranges_ = g->ranges();
|
||||
distId_ = g->distId();
|
||||
constr_ = new ConstraintTree (g->logVars(), {tuple});
|
||||
assert (params_.size() == Util::expectedSize (ranges_));
|
||||
assert (params_.size() == Util::sizeExpected (ranges_));
|
||||
}
|
||||
|
||||
|
||||
@ -51,7 +51,7 @@ Parfactor::Parfactor (const Parfactor* g, ConstraintTree* constr)
|
||||
ranges_ = g->ranges();
|
||||
distId_ = g->distId();
|
||||
constr_ = constr;
|
||||
assert (params_.size() == Util::expectedSize (ranges_));
|
||||
assert (params_.size() == Util::sizeExpected (ranges_));
|
||||
}
|
||||
|
||||
|
||||
@ -63,7 +63,7 @@ Parfactor::Parfactor (const Parfactor& g)
|
||||
ranges_ = g.ranges();
|
||||
distId_ = g.distId();
|
||||
constr_ = new ConstraintTree (*g.constr());
|
||||
assert (params_.size() == Util::expectedSize (ranges_));
|
||||
assert (params_.size() == Util::sizeExpected (ranges_));
|
||||
}
|
||||
|
||||
|
||||
|
@ -76,21 +76,6 @@ stringToDouble (string str)
|
||||
|
||||
|
||||
|
||||
void
|
||||
toLog (Params& v)
|
||||
{
|
||||
transform (v.begin(), v.end(), v.begin(), ::log);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
fromLog (Params& v)
|
||||
{
|
||||
transform (v.begin(), v.end(), v.begin(), ::exp);
|
||||
}
|
||||
|
||||
|
||||
|
||||
double
|
||||
factorial (unsigned num)
|
||||
@ -146,7 +131,7 @@ nrCombinations (unsigned n, unsigned k)
|
||||
|
||||
|
||||
unsigned
|
||||
expectedSize (const Ranges& ranges)
|
||||
sizeExpected (const Ranges& ranges)
|
||||
{
|
||||
return std::accumulate (
|
||||
ranges.begin(), ranges.end(), 1, multiplies<unsigned>());
|
||||
@ -155,7 +140,7 @@ expectedSize (const Ranges& ranges)
|
||||
|
||||
|
||||
unsigned
|
||||
getNumberOfDigits (int num)
|
||||
nrDigits (int num)
|
||||
{
|
||||
unsigned count = 1;
|
||||
while (num >= 10) {
|
||||
|
@ -19,6 +19,106 @@
|
||||
using namespace std;
|
||||
|
||||
|
||||
template <typename T>
|
||||
void operator+=(std::vector<T>& v, double val)
|
||||
{
|
||||
std::transform (v.begin(), v.end(), v.begin(),
|
||||
std::bind1st (plus<double>(), val));
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
void operator-=(std::vector<T>& v, double val)
|
||||
{
|
||||
std::transform (v.begin(), v.end(), v.begin(),
|
||||
std::bind1st (minus<double>(), val));
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
void operator*=(std::vector<T>& v, double val)
|
||||
{
|
||||
std::transform (v.begin(), v.end(), v.begin(),
|
||||
std::bind1st (multiplies<double>(), val));
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
void operator/=(std::vector<T>& v, double val)
|
||||
{
|
||||
std::transform (v.begin(), v.end(), v.begin(),
|
||||
std::bind1st (divides<double>(), val));
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
void operator+=(std::vector<T>& a, const std::vector<T>& b)
|
||||
{
|
||||
assert (a.size() == b.size());
|
||||
std::transform (a.begin(), a.end(), b.begin(), a.begin(),
|
||||
plus<double>());
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
void operator-=(std::vector<T>& a, const std::vector<T>& b)
|
||||
{
|
||||
assert (a.size() == b.size());
|
||||
std::transform (a.begin(), a.end(), b.begin(), a.begin(),
|
||||
minus<double>());
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
void operator*=(std::vector<T>& a, const std::vector<T>& b)
|
||||
{
|
||||
assert (a.size() == b.size());
|
||||
std::transform (a.begin(), a.end(), b.begin(), a.begin(),
|
||||
multiplies<double>());
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
void operator/=(std::vector<T>& a, const std::vector<T>& b)
|
||||
{
|
||||
assert (a.size() == b.size());
|
||||
std::transform (a.begin(), a.end(), b.begin(), a.begin(),
|
||||
divides<double>());
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
void operator^=(std::vector<T>& v, double exp)
|
||||
{
|
||||
std::transform (v.begin(), v.end(), v.begin(),
|
||||
std::bind2nd (ptr_fun<double, double, double> (std::pow), exp));
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
void operator^=(std::vector<T>& v, int iexp)
|
||||
{
|
||||
std::transform (v.begin(), v.end(), v.begin(),
|
||||
std::bind2nd (ptr_fun<double, int, double> (std::pow), iexp));
|
||||
}
|
||||
|
||||
|
||||
|
||||
namespace {
|
||||
const double NEG_INF = -numeric_limits<double>::infinity();
|
||||
};
|
||||
|
||||
|
||||
|
||||
namespace Util {
|
||||
|
||||
template <typename T> void addToVector (vector<T>&, const vector<T>&);
|
||||
@ -40,6 +140,10 @@ template <typename T> std::string toString (const T&);
|
||||
|
||||
template <> std::string toString (const bool&);
|
||||
|
||||
template <typename T> void log (vector<T>&);
|
||||
|
||||
template <typename T> void exp (vector<T>&);
|
||||
|
||||
double logSum (double, double);
|
||||
|
||||
void add (Params&, const Params&, unsigned);
|
||||
@ -52,19 +156,15 @@ unsigned stringToUnsigned (string);
|
||||
|
||||
double stringToDouble (string);
|
||||
|
||||
void toLog (Params&);
|
||||
|
||||
void fromLog (Params&);
|
||||
|
||||
double factorial (unsigned);
|
||||
|
||||
double logFactorial (unsigned);
|
||||
|
||||
unsigned nrCombinations (unsigned, unsigned);
|
||||
|
||||
unsigned expectedSize (const Ranges&);
|
||||
unsigned sizeExpected (const Ranges&);
|
||||
|
||||
unsigned getNumberOfDigits (int);
|
||||
unsigned nrDigits (int);
|
||||
|
||||
bool isInteger (const string&);
|
||||
|
||||
@ -168,9 +268,20 @@ std::ostream& operator << (std::ostream& os, const vector<T>& v)
|
||||
}
|
||||
|
||||
|
||||
namespace {
|
||||
const double NEG_INF = -numeric_limits<double>::infinity();
|
||||
};
|
||||
template <typename T> void
|
||||
Util::log (vector<T>& v)
|
||||
{
|
||||
transform (v.begin(), v.end(), v.begin(), ::log);
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T> void
|
||||
Util::exp (vector<T>& v)
|
||||
{
|
||||
transform (v.begin(), v.end(), v.begin(), ::exp);
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline double
|
||||
@ -244,154 +355,16 @@ Util::maxUnsigned (void)
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
void operator+=(std::vector<T>& v, double val)
|
||||
{
|
||||
std::transform (v.begin(), v.end(), v.begin(),
|
||||
std::bind1st (plus<double>(), val));
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
void operator-=(std::vector<T>& v, double val)
|
||||
{
|
||||
std::transform (v.begin(), v.end(), v.begin(),
|
||||
std::bind1st (minus<double>(), val));
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
void operator*=(std::vector<T>& v, double val)
|
||||
{
|
||||
std::transform (v.begin(), v.end(), v.begin(),
|
||||
std::bind1st (multiplies<double>(), val));
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
void operator/=(std::vector<T>& v, double val)
|
||||
{
|
||||
std::transform (v.begin(), v.end(), v.begin(),
|
||||
std::bind1st (divides<double>(), val));
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
void operator+=(std::vector<T>& a, const std::vector<T>& b)
|
||||
{
|
||||
assert (a.size() == b.size());
|
||||
std::transform (a.begin(), a.end(), b.begin(), a.begin(),
|
||||
plus<double>());
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
void operator-=(std::vector<T>& a, const std::vector<T>& b)
|
||||
{
|
||||
assert (a.size() == b.size());
|
||||
std::transform (a.begin(), a.end(), b.begin(), a.begin(),
|
||||
minus<double>());
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
void operator*=(std::vector<T>& a, const std::vector<T>& b)
|
||||
{
|
||||
assert (a.size() == b.size());
|
||||
std::transform (a.begin(), a.end(), b.begin(), a.begin(),
|
||||
multiplies<double>());
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
void operator/=(std::vector<T>& a, const std::vector<T>& b)
|
||||
{
|
||||
assert (a.size() == b.size());
|
||||
std::transform (a.begin(), a.end(), b.begin(), a.begin(),
|
||||
divides<double>());
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
void operator^=(std::vector<T>& v, double exp)
|
||||
{
|
||||
std::transform (v.begin(), v.end(), v.begin(),
|
||||
std::bind2nd (ptr_fun<double, double, double> (std::pow), exp));
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
void operator^=(std::vector<T>& v, int iexp)
|
||||
{
|
||||
std::transform (v.begin(), v.end(), v.begin(),
|
||||
std::bind2nd (ptr_fun<double, int, double> (std::pow), iexp));
|
||||
}
|
||||
|
||||
|
||||
|
||||
namespace LogAware {
|
||||
|
||||
inline double
|
||||
one()
|
||||
{
|
||||
return Globals::logDomain ? 0.0 : 1.0;
|
||||
}
|
||||
|
||||
|
||||
inline double
|
||||
zero() {
|
||||
return Globals::logDomain ? NEG_INF : 0.0 ;
|
||||
}
|
||||
|
||||
|
||||
inline double
|
||||
addIdenty()
|
||||
{
|
||||
return Globals::logDomain ? NEG_INF : 0.0;
|
||||
}
|
||||
|
||||
|
||||
inline double
|
||||
multIdenty()
|
||||
{
|
||||
return Globals::logDomain ? 0.0 : 1.0;
|
||||
}
|
||||
|
||||
|
||||
inline double
|
||||
withEvidence()
|
||||
{
|
||||
return Globals::logDomain ? 0.0 : 1.0;
|
||||
}
|
||||
|
||||
|
||||
inline double
|
||||
noEvidence() {
|
||||
return Globals::logDomain ? NEG_INF : 0.0;
|
||||
}
|
||||
|
||||
|
||||
inline double
|
||||
tl (double v)
|
||||
{
|
||||
return Globals::logDomain ? log (v) : v;
|
||||
}
|
||||
|
||||
|
||||
inline double
|
||||
fl (double v)
|
||||
{
|
||||
return Globals::logDomain ? exp (v) : v;
|
||||
}
|
||||
inline double one() { return Globals::logDomain ? 0.0 : 1.0; }
|
||||
inline double zero() { return Globals::logDomain ? NEG_INF : 0.0; }
|
||||
inline double addIdenty() { return Globals::logDomain ? NEG_INF : 0.0; }
|
||||
inline double multIdenty() { return Globals::logDomain ? 0.0 : 1.0; }
|
||||
inline double withEvidence() { return Globals::logDomain ? 0.0 : 1.0; }
|
||||
inline double noEvidence() { return Globals::logDomain ? NEG_INF : 0.0; }
|
||||
inline double log (double v) { return Globals::logDomain ? ::log (v) : v; }
|
||||
inline double exp (double v) { return Globals::logDomain ? ::exp (v) : v; }
|
||||
|
||||
|
||||
void normalize (Params&);
|
||||
|
@ -33,7 +33,7 @@ VarElimSolver::solveQuery (VarIds queryVids)
|
||||
processFactorList (queryVids);
|
||||
Params params = factorList_.back()->params();
|
||||
if (Globals::logDomain) {
|
||||
Util::fromLog (params);
|
||||
Util::exp (params);
|
||||
}
|
||||
return params;
|
||||
}
|
||||
|
Reference in New Issue
Block a user