replace Util::add and Util::multiply by Util::apply_n_times

This commit is contained in:
Tiago Gomes 2012-05-28 21:09:56 +01:00
parent f489a59194
commit 54ae29ae02
3 changed files with 32 additions and 41 deletions

View File

@ -275,7 +275,8 @@ BpSolver::calcFactorToVarMsg (SpLink* link)
cout << " message from " << links[i]->varNode()->label();
cout << ": " ;
}
Util::add (msgProduct, getVarToFactorMsg (links[i]), reps);
Util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]),
reps, std::plus<double>());
if (Constants::SHOW_BP_CALCS) {
cout << endl;
}
@ -289,7 +290,8 @@ BpSolver::calcFactorToVarMsg (SpLink* link)
cout << " message from " << links[i]->varNode()->label();
cout << ": " ;
}
Util::multiply (msgProduct, getVarToFactorMsg (links[i]), reps);
Util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]),
reps, std::multiplies<double>());
if (Constants::SHOW_BP_CALCS) {
cout << endl;
}

View File

@ -219,7 +219,8 @@ CbpSolver::calcFactorToVarMsg (SpLink* _link)
cout << " message from " << links[i]->varNode()->label();
cout << ": " ;
}
Util::add (msgProduct, getVarToFactorMsg (links[i]), reps);
Util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]),
reps, std::plus<double>());
if (Constants::SHOW_BP_CALCS) {
cout << endl;
}
@ -234,7 +235,8 @@ CbpSolver::calcFactorToVarMsg (SpLink* _link)
cout << " message from " << links[i]->varNode()->label();
cout << ": " ;
}
Util::multiply (msgProduct, getVarToFactorMsg (links[i]), reps);
Util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]),
reps, std::multiplies<double>());
if (Constants::SHOW_BP_CALCS) {
cout << endl;
}

View File

@ -41,6 +41,9 @@ template <typename K, typename V> bool contains (
template <typename T> size_t indexOf (const vector<T>&, const T&);
template <class Operation>
void apply_n_times (Params& v1, const Params& v2, unsigned repetitions, Operation);
template <typename T> void log (vector<T>&);
template <typename T> void exp (vector<T>&);
@ -54,10 +57,6 @@ template <> std::string toString (const bool&);
double logSum (double, double);
void add (Params&, const Params&, unsigned);
void multiply (Params&, const Params&, unsigned);
unsigned maxUnsigned (void);
unsigned stringToUnsigned (string);
@ -153,10 +152,29 @@ Util::indexOf (const vector<T>& v, const T& e)
template <class Operation> void
Util::apply_n_times (Params& v1, const Params& v2, unsigned repetitions,
Operation unary_op)
{
Params::iterator first = v1.begin();
Params::const_iterator last = v1.end();
Params::const_iterator first2 = v2.begin();
Params::const_iterator last2 = v2.end();
while (first != last) {
for (first2 = v2.begin(); first2 != last2; ++first2) {
std::transform (first, first + repetitions, first,
std::bind1st (unary_op, *first2));
first += repetitions;
}
}
}
template <typename T> void
Util::log (vector<T>& v)
{
transform (v.begin(), v.end(), v.begin(), ::log);
std::transform (v.begin(), v.end(), v.begin(), ::log);
}
@ -164,7 +182,7 @@ Util::log (vector<T>& v)
template <typename T> void
Util::exp (vector<T>& v)
{
transform (v.begin(), v.end(), v.begin(), ::exp);
std::transform (v.begin(), v.end(), v.begin(), ::exp);
}
@ -224,36 +242,6 @@ Util::logSum (double x, double y)
inline void
Util::add (Params& v1, const Params& v2, unsigned repetitions)
{
for (size_t count = 0; count < v1.size(); ) {
for (size_t i = 0; i < v2.size(); i++) {
for (unsigned r = 0; r < repetitions; r++) {
v1[count] += v2[i];
count ++;
}
}
}
}
inline void
Util::multiply (Params& v1, const Params& v2, unsigned repetitions)
{
for (size_t count = 0; count < v1.size(); ) {
for (size_t i = 0; i < v2.size(); i++) {
for (unsigned r = 0; r < repetitions; r++) {
v1[count] *= v2[i];
count ++;
}
}
}
}
inline unsigned
Util::maxUnsigned (void)
{
@ -273,7 +261,6 @@ inline double noEvidence() { return Globals::logDomain ? NEG_INF : 0.0; }
inline double log (double v) { return Globals::logDomain ? ::log (v) : v; }
inline double exp (double v) { return Globals::logDomain ? ::exp (v) : v; }
void normalize (Params&);
double getL1Distance (const Params&, const Params&);