diff --git a/packages/CLPBN/horus/BpSolver.cpp b/packages/CLPBN/horus/BpSolver.cpp index 5296e50e3..93d6d6e50 100644 --- a/packages/CLPBN/horus/BpSolver.cpp +++ b/packages/CLPBN/horus/BpSolver.cpp @@ -87,13 +87,13 @@ BpSolver::getPosterioriOf (VarId vid) const SpLinkSet& links = ninf(var)->getLinks(); if (Globals::logDomain) { for (unsigned i = 0; i < links.size(); i++) { - Util::add (probs, links[i]->getMessage()); + probs += links[i]->getMessage(); } LogAware::normalize (probs); Util::fromLog (probs); } else { for (unsigned i = 0; i < links.size(); i++) { - Util::multiply (probs, links[i]->getMessage()); + probs *= links[i]->getMessage(); } LogAware::normalize (probs); } @@ -362,16 +362,16 @@ BpSolver::getVar2FactorMsg (const SpLink* link) const if (Globals::logDomain) { SpLinkSet::const_iterator it; for (it = links.begin(); it != links.end(); ++ it) { - Util::add (msg, (*it)->getMessage()); + msg += (*it)->getMessage(); if (Constants::SHOW_BP_CALCS) { cout << " x " << (*it)->getMessage(); } } - Util::subtract (msg, link->getMessage()); + msg -= link->getMessage(); } else { for (unsigned i = 0; i < links.size(); i++) { if (links[i]->getFactor() != dst) { - Util::multiply (msg, links[i]->getMessage()); + msg *= links[i]->getMessage(); if (Constants::SHOW_BP_CALCS) { cout << " x " << links[i]->getMessage(); } diff --git a/packages/CLPBN/horus/CbpSolver.cpp b/packages/CLPBN/horus/CbpSolver.cpp index 6b38021ec..a098af51b 100644 --- a/packages/CLPBN/horus/CbpSolver.cpp +++ b/packages/CLPBN/horus/CbpSolver.cpp @@ -82,14 +82,14 @@ CbpSolver::getPosterioriOf (VarId vid) if (Globals::logDomain) { for (unsigned i = 0; i < links.size(); i++) { CbpSolverLink* l = static_cast (links[i]); - Util::add (probs, l->poweredMessage()); + probs += l->poweredMessage(); } LogAware::normalize (probs); Util::fromLog (probs); } else { for (unsigned i = 0; i < links.size(); i++) { CbpSolverLink* l = static_cast (links[i]); - Util::multiply (probs, l->poweredMessage()); + probs *= l->poweredMessage(); } LogAware::normalize (probs); } @@ -330,14 +330,14 @@ CbpSolver::getVar2FactorMsg (const SpLink* _link) const CbpSolverLink* cl = static_cast (links[i]); if ( ! (cl->getFactor() == dst && cl->index() == link->index())) { CbpSolverLink* cl = static_cast (links[i]); - Util::add (msg, cl->poweredMessage()); + msg += cl->poweredMessage(); } } } else { for (unsigned i = 0; i < links.size(); i++) { CbpSolverLink* cl = static_cast (links[i]); if ( ! (cl->getFactor() == dst && cl->index() == link->index())) { - Util::multiply (msg, cl->poweredMessage()); + msg *= cl->poweredMessage(); if (Constants::SHOW_BP_CALCS) { cout << " x " << cl->getNextMessage() << "^" << link->nrEdges(); } diff --git a/packages/CLPBN/horus/Factor.h b/packages/CLPBN/horus/Factor.h index aa04fe32f..df45810e2 100644 --- a/packages/CLPBN/horus/Factor.h +++ b/packages/CLPBN/horus/Factor.h @@ -79,11 +79,8 @@ class TFactor if (args_ == g_args) { // optimization: if the factors contain the same set of args, // we can do a 1 to 1 operation on the parameters - if (Globals::logDomain) { - Util::add (params_, g_params); - } else { - Util::multiply (params_, g_params); - } + Globals::logDomain ? params_ += g_params + : params_ *= g_params; } else { bool sharedArgs = false; vector gvarpos; diff --git a/packages/CLPBN/horus/HorusYap.cpp b/packages/CLPBN/horus/HorusYap.cpp index e373f0445..58a820ba9 100644 --- a/packages/CLPBN/horus/HorusYap.cpp +++ b/packages/CLPBN/horus/HorusYap.cpp @@ -572,7 +572,7 @@ extern "C" void init_predicates (void) { YAP_UserCPredicate ("cpp_create_lifted_network", createLiftedNetwork, 3); - YAP_UserCPredicate ("cpp_create_ground_network", createGroundNetwork, 4); + YAP_UserCPredicate ("cpp_create_ground_network", createGroundNetwork, 4); YAP_UserCPredicate ("cpp_run_lifted_solver", runLiftedSolver, 3); YAP_UserCPredicate ("cpp_run_ground_solver", runGroundSolver, 3); YAP_UserCPredicate ("cpp_set_parfactors_params", setParfactorsParams, 2); diff --git a/packages/CLPBN/horus/Util.cpp b/packages/CLPBN/horus/Util.cpp index 083103ff8..446cc75ca 100644 --- a/packages/CLPBN/horus/Util.cpp +++ b/packages/CLPBN/horus/Util.cpp @@ -79,9 +79,7 @@ stringToDouble (string str) void toLog (Params& v) { - for (unsigned i = 0; i < v.size(); i++) { - v[i] = log (v[i]); - } + transform (v.begin(), v.end(), v.begin(), ::log); } @@ -89,9 +87,7 @@ toLog (Params& v) void fromLog (Params& v) { - for (unsigned i = 0; i < v.size(); i++) { - v[i] = exp (v[i]); - } + transform (v.begin(), v.end(), v.begin(), ::exp); } @@ -152,11 +148,8 @@ nrCombinations (unsigned n, unsigned k) unsigned expectedSize (const Ranges& ranges) { - unsigned prod = 1; - for (unsigned i = 0; i < ranges.size(); i++) { - prod *= ranges[i]; - } - return prod; + return std::accumulate ( + ranges.begin(), ranges.end(), 1, multiplies()); } @@ -410,55 +403,40 @@ getMaxNorm (const Params& v1, const Params& v2) } + double -pow (double p, unsigned expoent) +pow (double base, unsigned iexp) { - return Globals::logDomain ? p * expoent : std::pow (p, expoent); + return Globals::logDomain ? base * iexp : std::pow (base, iexp); } double -pow (double p, double expoent) +pow (double base, double exp) { // assumes that `expoent' is never in log domain - return Globals::logDomain ? p * expoent : std::pow (p, expoent); + return Globals::logDomain ? base * exp : std::pow (base, exp); } void -pow (Params& v, unsigned expoent) +pow (Params& v, unsigned iexp) { - if (expoent == 1) { + if (iexp == 1) { return; } - if (Globals::logDomain) { - for (unsigned i = 0; i < v.size(); i++) { - v[i] *= expoent; - } - } else { - for (unsigned i = 0; i < v.size(); i++) { - v[i] = std::pow (v[i], expoent); - } - } + Globals::logDomain ? v *= iexp : v ^= (int)iexp; } void -pow (Params& v, double expoent) +pow (Params& v, double exp) { - // assumes that `expoent' is never in log domain - if (Globals::logDomain) { - for (unsigned i = 0; i < v.size(); i++) { - v[i] *= expoent; - } - } else { - for (unsigned i = 0; i < v.size(); i++) { - v[i] = std::pow (v[i], expoent); - } - } + // `expoent' should not be in log domain + Globals::logDomain ? v *= exp : v ^= exp; } } diff --git a/packages/CLPBN/horus/Util.h b/packages/CLPBN/horus/Util.h index 5b9304d2a..4f748d3e3 100644 --- a/packages/CLPBN/horus/Util.h +++ b/packages/CLPBN/horus/Util.h @@ -40,6 +40,14 @@ template std::string toString (const T&); template <> std::string toString (const bool&); +double logSum (double, double); + +void add (Params&, const Params&, unsigned); + +void multiply (Params&, const Params&, unsigned); + +unsigned maxUnsigned (void); + unsigned stringToUnsigned (string); double stringToDouble (string); @@ -48,20 +56,6 @@ void toLog (Params&); void fromLog (Params&); -double logSum (double, double); - -void multiply (Params&, const Params&); - -void multiply (Params&, const Params&, unsigned); - -void add (Params&, const Params&); - -void subtract (Params&, const Params&); - -void add (Params&, const Params&, unsigned); - -unsigned maxUnsigned (void); - double factorial (unsigned); double logFactorial (unsigned); @@ -147,10 +141,7 @@ Util::indexOf (const vector& v, const T& e) { int pos = std::distance (v.begin(), std::find (v.begin(), v.end(), e)); - if (pos == (int)v.size()) { - pos = -1; - } - return pos; + return pos != (int)v.size() ? pos : -1; } @@ -216,11 +207,15 @@ Util::logSum (double x, double y) inline void -Util::multiply (Params& v1, const Params& v2) +Util::add (Params& v1, const Params& v2, unsigned repetitions) { - assert (v1.size() == v2.size()); - for (unsigned i = 0; i < v1.size(); i++) { - v1[i] *= v2[i]; + for (unsigned count = 0; count < v1.size(); ) { + for (unsigned i = 0; i < v2.size(); i++) { + for (unsigned r = 0; r < repetitions; r++) { + v1[count] += v2[i]; + count ++; + } + } } } @@ -241,41 +236,6 @@ Util::multiply (Params& v1, const Params& v2, unsigned repetitions) -inline void -Util::add (Params& v1, const Params& v2) -{ - assert (v1.size() == v2.size()); - std::transform (v1.begin(), v1.end(), v2.begin(), - v1.begin(), plus()); -} - - - -inline void -Util::subtract (Params& v1, const Params& v2) -{ - assert (v1.size() == v2.size()); - std::transform (v1.begin(), v1.end(), v2.begin(), - v1.begin(), minus()); -} - - - -inline void -Util::add (Params& v1, const Params& v2, unsigned repetitions) -{ - for (unsigned count = 0; count < v1.size(); ) { - for (unsigned i = 0; i < v2.size(); i++) { - for (unsigned r = 0; r < repetitions; r++) { - v1[count] += v2[i]; - count ++; - } - } - } -} - - - inline unsigned Util::maxUnsigned (void) { @@ -284,6 +244,100 @@ Util::maxUnsigned (void) +template +void operator+=(std::vector& v, double val) +{ + std::transform (v.begin(), v.end(), v.begin(), + std::bind1st (plus(), val)); +} + + + +template +void operator-=(std::vector& v, double val) +{ + std::transform (v.begin(), v.end(), v.begin(), + std::bind1st (minus(), val)); +} + + + +template +void operator*=(std::vector& v, double val) +{ + std::transform (v.begin(), v.end(), v.begin(), + std::bind1st (multiplies(), val)); +} + + + +template +void operator/=(std::vector& v, double val) +{ + std::transform (v.begin(), v.end(), v.begin(), + std::bind1st (divides(), val)); +} + + + +template +void operator+=(std::vector& a, const std::vector& b) +{ + assert (a.size() == b.size()); + std::transform (a.begin(), a.end(), b.begin(), a.begin(), + plus()); +} + + + +template +void operator-=(std::vector& a, const std::vector& b) +{ + assert (a.size() == b.size()); + std::transform (a.begin(), a.end(), b.begin(), a.begin(), + minus()); +} + + + +template +void operator*=(std::vector& a, const std::vector& b) +{ + assert (a.size() == b.size()); + std::transform (a.begin(), a.end(), b.begin(), a.begin(), + multiplies()); +} + + + +template +void operator/=(std::vector& a, const std::vector& b) +{ + assert (a.size() == b.size()); + std::transform (a.begin(), a.end(), b.begin(), a.begin(), + divides()); +} + + + +template +void operator^=(std::vector& v, double exp) +{ + std::transform (v.begin(), v.end(), v.begin(), + std::bind2nd (ptr_fun (std::pow), exp)); +} + + + +template +void operator^=(std::vector& v, int iexp) +{ + std::transform (v.begin(), v.end(), v.begin(), + std::bind2nd (ptr_fun (std::pow), iexp)); +} + + + namespace LogAware { inline double