add operators to manipulate vectors of parameters

This commit is contained in:
Tiago Gomes 2012-05-24 14:55:30 +01:00
parent 6cb718942a
commit 444eaacc63
6 changed files with 138 additions and 109 deletions

View File

@ -87,13 +87,13 @@ BpSolver::getPosterioriOf (VarId vid)
const SpLinkSet& links = ninf(var)->getLinks(); const SpLinkSet& links = ninf(var)->getLinks();
if (Globals::logDomain) { if (Globals::logDomain) {
for (unsigned i = 0; i < links.size(); i++) { for (unsigned i = 0; i < links.size(); i++) {
Util::add (probs, links[i]->getMessage()); probs += links[i]->getMessage();
} }
LogAware::normalize (probs); LogAware::normalize (probs);
Util::fromLog (probs); Util::fromLog (probs);
} else { } else {
for (unsigned i = 0; i < links.size(); i++) { for (unsigned i = 0; i < links.size(); i++) {
Util::multiply (probs, links[i]->getMessage()); probs *= links[i]->getMessage();
} }
LogAware::normalize (probs); LogAware::normalize (probs);
} }
@ -362,16 +362,16 @@ BpSolver::getVar2FactorMsg (const SpLink* link) const
if (Globals::logDomain) { if (Globals::logDomain) {
SpLinkSet::const_iterator it; SpLinkSet::const_iterator it;
for (it = links.begin(); it != links.end(); ++ it) { for (it = links.begin(); it != links.end(); ++ it) {
Util::add (msg, (*it)->getMessage()); msg += (*it)->getMessage();
if (Constants::SHOW_BP_CALCS) { if (Constants::SHOW_BP_CALCS) {
cout << " x " << (*it)->getMessage(); cout << " x " << (*it)->getMessage();
} }
} }
Util::subtract (msg, link->getMessage()); msg -= link->getMessage();
} else { } else {
for (unsigned i = 0; i < links.size(); i++) { for (unsigned i = 0; i < links.size(); i++) {
if (links[i]->getFactor() != dst) { if (links[i]->getFactor() != dst) {
Util::multiply (msg, links[i]->getMessage()); msg *= links[i]->getMessage();
if (Constants::SHOW_BP_CALCS) { if (Constants::SHOW_BP_CALCS) {
cout << " x " << links[i]->getMessage(); cout << " x " << links[i]->getMessage();
} }

View File

@ -82,14 +82,14 @@ CbpSolver::getPosterioriOf (VarId vid)
if (Globals::logDomain) { if (Globals::logDomain) {
for (unsigned i = 0; i < links.size(); i++) { for (unsigned i = 0; i < links.size(); i++) {
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]); CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
Util::add (probs, l->poweredMessage()); probs += l->poweredMessage();
} }
LogAware::normalize (probs); LogAware::normalize (probs);
Util::fromLog (probs); Util::fromLog (probs);
} else { } else {
for (unsigned i = 0; i < links.size(); i++) { for (unsigned i = 0; i < links.size(); i++) {
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]); CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
Util::multiply (probs, l->poweredMessage()); probs *= l->poweredMessage();
} }
LogAware::normalize (probs); LogAware::normalize (probs);
} }
@ -330,14 +330,14 @@ CbpSolver::getVar2FactorMsg (const SpLink* _link) const
CbpSolverLink* cl = static_cast<CbpSolverLink*> (links[i]); CbpSolverLink* cl = static_cast<CbpSolverLink*> (links[i]);
if ( ! (cl->getFactor() == dst && cl->index() == link->index())) { if ( ! (cl->getFactor() == dst && cl->index() == link->index())) {
CbpSolverLink* cl = static_cast<CbpSolverLink*> (links[i]); CbpSolverLink* cl = static_cast<CbpSolverLink*> (links[i]);
Util::add (msg, cl->poweredMessage()); msg += cl->poweredMessage();
} }
} }
} else { } else {
for (unsigned i = 0; i < links.size(); i++) { for (unsigned i = 0; i < links.size(); i++) {
CbpSolverLink* cl = static_cast<CbpSolverLink*> (links[i]); CbpSolverLink* cl = static_cast<CbpSolverLink*> (links[i]);
if ( ! (cl->getFactor() == dst && cl->index() == link->index())) { if ( ! (cl->getFactor() == dst && cl->index() == link->index())) {
Util::multiply (msg, cl->poweredMessage()); msg *= cl->poweredMessage();
if (Constants::SHOW_BP_CALCS) { if (Constants::SHOW_BP_CALCS) {
cout << " x " << cl->getNextMessage() << "^" << link->nrEdges(); cout << " x " << cl->getNextMessage() << "^" << link->nrEdges();
} }

View File

@ -79,11 +79,8 @@ class TFactor
if (args_ == g_args) { if (args_ == g_args) {
// optimization: if the factors contain the same set of args, // optimization: if the factors contain the same set of args,
// we can do a 1 to 1 operation on the parameters // we can do a 1 to 1 operation on the parameters
if (Globals::logDomain) { Globals::logDomain ? params_ += g_params
Util::add (params_, g_params); : params_ *= g_params;
} else {
Util::multiply (params_, g_params);
}
} else { } else {
bool sharedArgs = false; bool sharedArgs = false;
vector<unsigned> gvarpos; vector<unsigned> gvarpos;

View File

@ -79,9 +79,7 @@ stringToDouble (string str)
void void
toLog (Params& v) toLog (Params& v)
{ {
for (unsigned i = 0; i < v.size(); i++) { transform (v.begin(), v.end(), v.begin(), ::log);
v[i] = log (v[i]);
}
} }
@ -89,9 +87,7 @@ toLog (Params& v)
void void
fromLog (Params& v) fromLog (Params& v)
{ {
for (unsigned i = 0; i < v.size(); i++) { transform (v.begin(), v.end(), v.begin(), ::exp);
v[i] = exp (v[i]);
}
} }
@ -152,11 +148,8 @@ nrCombinations (unsigned n, unsigned k)
unsigned unsigned
expectedSize (const Ranges& ranges) expectedSize (const Ranges& ranges)
{ {
unsigned prod = 1; return std::accumulate (
for (unsigned i = 0; i < ranges.size(); i++) { ranges.begin(), ranges.end(), 1, multiplies<unsigned>());
prod *= ranges[i];
}
return prod;
} }
@ -410,55 +403,40 @@ getMaxNorm (const Params& v1, const Params& v2)
} }
double double
pow (double p, unsigned expoent) pow (double base, unsigned iexp)
{ {
return Globals::logDomain ? p * expoent : std::pow (p, expoent); return Globals::logDomain ? base * iexp : std::pow (base, iexp);
} }
double double
pow (double p, double expoent) pow (double base, double exp)
{ {
// assumes that `expoent' is never in log domain // assumes that `expoent' is never in log domain
return Globals::logDomain ? p * expoent : std::pow (p, expoent); return Globals::logDomain ? base * exp : std::pow (base, exp);
} }
void void
pow (Params& v, unsigned expoent) pow (Params& v, unsigned iexp)
{ {
if (expoent == 1) { if (iexp == 1) {
return; return;
} }
if (Globals::logDomain) { Globals::logDomain ? v *= iexp : v ^= (int)iexp;
for (unsigned i = 0; i < v.size(); i++) {
v[i] *= expoent;
}
} else {
for (unsigned i = 0; i < v.size(); i++) {
v[i] = std::pow (v[i], expoent);
}
}
} }
void void
pow (Params& v, double expoent) pow (Params& v, double exp)
{ {
// assumes that `expoent' is never in log domain // `expoent' should not be in log domain
if (Globals::logDomain) { Globals::logDomain ? v *= exp : v ^= exp;
for (unsigned i = 0; i < v.size(); i++) {
v[i] *= expoent;
}
} else {
for (unsigned i = 0; i < v.size(); i++) {
v[i] = std::pow (v[i], expoent);
}
}
} }
} }

View File

@ -40,6 +40,14 @@ template <typename T> std::string toString (const T&);
template <> std::string toString (const bool&); template <> std::string toString (const bool&);
double logSum (double, double);
void add (Params&, const Params&, unsigned);
void multiply (Params&, const Params&, unsigned);
unsigned maxUnsigned (void);
unsigned stringToUnsigned (string); unsigned stringToUnsigned (string);
double stringToDouble (string); double stringToDouble (string);
@ -48,20 +56,6 @@ void toLog (Params&);
void fromLog (Params&); void fromLog (Params&);
double logSum (double, double);
void multiply (Params&, const Params&);
void multiply (Params&, const Params&, unsigned);
void add (Params&, const Params&);
void subtract (Params&, const Params&);
void add (Params&, const Params&, unsigned);
unsigned maxUnsigned (void);
double factorial (unsigned); double factorial (unsigned);
double logFactorial (unsigned); double logFactorial (unsigned);
@ -147,10 +141,7 @@ Util::indexOf (const vector<T>& v, const T& e)
{ {
int pos = std::distance (v.begin(), int pos = std::distance (v.begin(),
std::find (v.begin(), v.end(), e)); std::find (v.begin(), v.end(), e));
if (pos == (int)v.size()) { return pos != (int)v.size() ? pos : -1;
pos = -1;
}
return pos;
} }
@ -216,11 +207,15 @@ Util::logSum (double x, double y)
inline void inline void
Util::multiply (Params& v1, const Params& v2) Util::add (Params& v1, const Params& v2, unsigned repetitions)
{ {
assert (v1.size() == v2.size()); for (unsigned count = 0; count < v1.size(); ) {
for (unsigned i = 0; i < v1.size(); i++) { for (unsigned i = 0; i < v2.size(); i++) {
v1[i] *= v2[i]; for (unsigned r = 0; r < repetitions; r++) {
v1[count] += v2[i];
count ++;
}
}
} }
} }
@ -241,41 +236,6 @@ Util::multiply (Params& v1, const Params& v2, unsigned repetitions)
inline void
Util::add (Params& v1, const Params& v2)
{
assert (v1.size() == v2.size());
std::transform (v1.begin(), v1.end(), v2.begin(),
v1.begin(), plus<double>());
}
inline void
Util::subtract (Params& v1, const Params& v2)
{
assert (v1.size() == v2.size());
std::transform (v1.begin(), v1.end(), v2.begin(),
v1.begin(), minus<double>());
}
inline void
Util::add (Params& v1, const Params& v2, unsigned repetitions)
{
for (unsigned count = 0; count < v1.size(); ) {
for (unsigned i = 0; i < v2.size(); i++) {
for (unsigned r = 0; r < repetitions; r++) {
v1[count] += v2[i];
count ++;
}
}
}
}
inline unsigned inline unsigned
Util::maxUnsigned (void) Util::maxUnsigned (void)
{ {
@ -284,6 +244,100 @@ Util::maxUnsigned (void)
template <typename T>
void operator+=(std::vector<T>& v, double val)
{
std::transform (v.begin(), v.end(), v.begin(),
std::bind1st (plus<double>(), val));
}
template <typename T>
void operator-=(std::vector<T>& v, double val)
{
std::transform (v.begin(), v.end(), v.begin(),
std::bind1st (minus<double>(), val));
}
template <typename T>
void operator*=(std::vector<T>& v, double val)
{
std::transform (v.begin(), v.end(), v.begin(),
std::bind1st (multiplies<double>(), val));
}
template <typename T>
void operator/=(std::vector<T>& v, double val)
{
std::transform (v.begin(), v.end(), v.begin(),
std::bind1st (divides<double>(), val));
}
template <typename T>
void operator+=(std::vector<T>& a, const std::vector<T>& b)
{
assert (a.size() == b.size());
std::transform (a.begin(), a.end(), b.begin(), a.begin(),
plus<double>());
}
template <typename T>
void operator-=(std::vector<T>& a, const std::vector<T>& b)
{
assert (a.size() == b.size());
std::transform (a.begin(), a.end(), b.begin(), a.begin(),
minus<double>());
}
template <typename T>
void operator*=(std::vector<T>& a, const std::vector<T>& b)
{
assert (a.size() == b.size());
std::transform (a.begin(), a.end(), b.begin(), a.begin(),
multiplies<double>());
}
template <typename T>
void operator/=(std::vector<T>& a, const std::vector<T>& b)
{
assert (a.size() == b.size());
std::transform (a.begin(), a.end(), b.begin(), a.begin(),
divides<double>());
}
template <typename T>
void operator^=(std::vector<T>& v, double exp)
{
std::transform (v.begin(), v.end(), v.begin(),
std::bind2nd (ptr_fun<double, double, double> (std::pow), exp));
}
template <typename T>
void operator^=(std::vector<T>& v, int iexp)
{
std::transform (v.begin(), v.end(), v.begin(),
std::bind2nd (ptr_fun<double, int, double> (std::pow), iexp));
}
namespace LogAware { namespace LogAware {
inline double inline double