use STL to calculate L1 distance and max norm

This commit is contained in:
Tiago Gomes 2012-05-29 13:48:08 +01:00
parent 3ac854b2ff
commit 6feb746412
2 changed files with 49 additions and 31 deletions

View File

@ -321,23 +321,15 @@ namespace LogAware {
void void
normalize (Params& v) normalize (Params& v)
{ {
double sum = LogAware::addIdenty();
if (Globals::logDomain) { if (Globals::logDomain) {
for (size_t i = 0; i < v.size(); i++) { double sum = std::accumulate (v.begin(), v.end(),
sum = Util::logSum (sum, v[i]); LogAware::addIdenty(), Util::logSum);
}
assert (sum != -numeric_limits<double>::infinity()); assert (sum != -numeric_limits<double>::infinity());
for (size_t i = 0; i < v.size(); i++) { v -= sum;
v[i] -= sum;
}
} else { } else {
for (size_t i = 0; i < v.size(); i++) { double sum = std::accumulate (v.begin(), v.end(), 0);
sum += v[i];
}
assert (sum != 0.0); assert (sum != 0.0);
for (size_t i = 0; i < v.size(); i++) { v /= sum;
v[i] /= sum;
}
} }
} }
@ -349,13 +341,11 @@ getL1Distance (const Params& v1, const Params& v2)
assert (v1.size() == v2.size()); assert (v1.size() == v2.size());
double dist = 0.0; double dist = 0.0;
if (Globals::logDomain) { if (Globals::logDomain) {
for (size_t i = 0; i < v1.size(); i++) { dist = std::inner_product (v1.begin(), v1.end(), v2.begin(), 0.0,
dist += abs (exp(v1[i]) - exp(v2[i])); std::plus<double>(), FuncObject::abs_diff_exp<double>());
}
} else { } else {
for (size_t i = 0; i < v1.size(); i++) { dist = std::inner_product (v1.begin(), v1.end(), v2.begin(), 0.0,
dist += abs (v1[i] - v2[i]); std::plus<double>(), FuncObject::abs_diff<double>());
}
} }
return dist; return dist;
} }
@ -368,19 +358,11 @@ getMaxNorm (const Params& v1, const Params& v2)
assert (v1.size() == v2.size()); assert (v1.size() == v2.size());
double max = 0.0; double max = 0.0;
if (Globals::logDomain) { if (Globals::logDomain) {
for (size_t i = 0; i < v1.size(); i++) { max = std::inner_product (v1.begin(), v1.end(), v2.begin(), 0.0,
double diff = abs (exp(v1[i]) - exp(v2[i])); FuncObject::max<double>(), FuncObject::abs_diff_exp<double>());
if (diff > max) {
max = diff;
}
}
} else { } else {
for (size_t i = 0; i < v1.size(); i++) { max = std::inner_product (v1.begin(), v1.end(), v2.begin(), 0.0,
double diff = abs (v1[i] - v2[i]); FuncObject::max<double>(), FuncObject::abs_diff<double>());
if (diff > max) {
max = diff;
}
}
} }
return max; return max;
} }

View File

@ -382,5 +382,41 @@ std::ostream& operator << (std::ostream& os, const vector<T>& v)
return os; return os;
} }
namespace FuncObject {
template<typename T>
struct max : public std::binary_function<T, T, T>
{
T operator() (const T& x, const T& y) const
{
return x < y ? y : x;
}
};
template <typename T>
struct abs_diff : public std::binary_function<T, T, T>
{
T operator() (const T& x, const T& y) const
{
return std::abs (x - y);
}
};
template <typename T>
struct abs_diff_exp : public std::binary_function<T, T, T>
{
T operator() (const T& x, const T& y) const
{
return std::abs (std::exp (x) - std::exp (y));
}
};
}
#endif // HORUS_UTIL_H #endif // HORUS_UTIL_H