use STL to calculate L1 distance and max norm

This commit is contained in:
Tiago Gomes 2012-05-29 13:48:08 +01:00
parent 3ac854b2ff
commit 6feb746412
2 changed files with 49 additions and 31 deletions

View File

@ -321,23 +321,15 @@ namespace LogAware {
void
normalize (Params& v)
{
double sum = LogAware::addIdenty();
if (Globals::logDomain) {
for (size_t i = 0; i < v.size(); i++) {
sum = Util::logSum (sum, v[i]);
}
double sum = std::accumulate (v.begin(), v.end(),
LogAware::addIdenty(), Util::logSum);
assert (sum != -numeric_limits<double>::infinity());
for (size_t i = 0; i < v.size(); i++) {
v[i] -= sum;
}
v -= sum;
} else {
for (size_t i = 0; i < v.size(); i++) {
sum += v[i];
}
double sum = std::accumulate (v.begin(), v.end(), 0);
assert (sum != 0.0);
for (size_t i = 0; i < v.size(); i++) {
v[i] /= sum;
}
v /= sum;
}
}
@ -349,13 +341,11 @@ getL1Distance (const Params& v1, const Params& v2)
assert (v1.size() == v2.size());
double dist = 0.0;
if (Globals::logDomain) {
for (size_t i = 0; i < v1.size(); i++) {
dist += abs (exp(v1[i]) - exp(v2[i]));
}
dist = std::inner_product (v1.begin(), v1.end(), v2.begin(), 0.0,
std::plus<double>(), FuncObject::abs_diff_exp<double>());
} else {
for (size_t i = 0; i < v1.size(); i++) {
dist += abs (v1[i] - v2[i]);
}
dist = std::inner_product (v1.begin(), v1.end(), v2.begin(), 0.0,
std::plus<double>(), FuncObject::abs_diff<double>());
}
return dist;
}
@ -368,19 +358,11 @@ getMaxNorm (const Params& v1, const Params& v2)
assert (v1.size() == v2.size());
double max = 0.0;
if (Globals::logDomain) {
for (size_t i = 0; i < v1.size(); i++) {
double diff = abs (exp(v1[i]) - exp(v2[i]));
if (diff > max) {
max = diff;
}
}
max = std::inner_product (v1.begin(), v1.end(), v2.begin(), 0.0,
FuncObject::max<double>(), FuncObject::abs_diff_exp<double>());
} else {
for (size_t i = 0; i < v1.size(); i++) {
double diff = abs (v1[i] - v2[i]);
if (diff > max) {
max = diff;
}
}
max = std::inner_product (v1.begin(), v1.end(), v2.begin(), 0.0,
FuncObject::max<double>(), FuncObject::abs_diff<double>());
}
return max;
}

View File

@ -382,5 +382,41 @@ std::ostream& operator << (std::ostream& os, const vector<T>& v)
return os;
}
namespace FuncObject {
template<typename T>
struct max : public std::binary_function<T, T, T>
{
T operator() (const T& x, const T& y) const
{
return x < y ? y : x;
}
};
template <typename T>
struct abs_diff : public std::binary_function<T, T, T>
{
T operator() (const T& x, const T& y) const
{
return std::abs (x - y);
}
};
template <typename T>
struct abs_diff_exp : public std::binary_function<T, T, T>
{
T operator() (const T& x, const T& y) const
{
return std::abs (std::exp (x) - std::exp (y));
}
};
}
#endif // HORUS_UTIL_H