replace Util::add and Util::multiply by Util::apply_n_times
This commit is contained in:
parent
f489a59194
commit
54ae29ae02
@ -275,7 +275,8 @@ BpSolver::calcFactorToVarMsg (SpLink* link)
|
|||||||
cout << " message from " << links[i]->varNode()->label();
|
cout << " message from " << links[i]->varNode()->label();
|
||||||
cout << ": " ;
|
cout << ": " ;
|
||||||
}
|
}
|
||||||
Util::add (msgProduct, getVarToFactorMsg (links[i]), reps);
|
Util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]),
|
||||||
|
reps, std::plus<double>());
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
cout << endl;
|
cout << endl;
|
||||||
}
|
}
|
||||||
@ -289,7 +290,8 @@ BpSolver::calcFactorToVarMsg (SpLink* link)
|
|||||||
cout << " message from " << links[i]->varNode()->label();
|
cout << " message from " << links[i]->varNode()->label();
|
||||||
cout << ": " ;
|
cout << ": " ;
|
||||||
}
|
}
|
||||||
Util::multiply (msgProduct, getVarToFactorMsg (links[i]), reps);
|
Util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]),
|
||||||
|
reps, std::multiplies<double>());
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
cout << endl;
|
cout << endl;
|
||||||
}
|
}
|
||||||
|
@ -219,7 +219,8 @@ CbpSolver::calcFactorToVarMsg (SpLink* _link)
|
|||||||
cout << " message from " << links[i]->varNode()->label();
|
cout << " message from " << links[i]->varNode()->label();
|
||||||
cout << ": " ;
|
cout << ": " ;
|
||||||
}
|
}
|
||||||
Util::add (msgProduct, getVarToFactorMsg (links[i]), reps);
|
Util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]),
|
||||||
|
reps, std::plus<double>());
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
cout << endl;
|
cout << endl;
|
||||||
}
|
}
|
||||||
@ -234,7 +235,8 @@ CbpSolver::calcFactorToVarMsg (SpLink* _link)
|
|||||||
cout << " message from " << links[i]->varNode()->label();
|
cout << " message from " << links[i]->varNode()->label();
|
||||||
cout << ": " ;
|
cout << ": " ;
|
||||||
}
|
}
|
||||||
Util::multiply (msgProduct, getVarToFactorMsg (links[i]), reps);
|
Util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]),
|
||||||
|
reps, std::multiplies<double>());
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
cout << endl;
|
cout << endl;
|
||||||
}
|
}
|
||||||
|
@ -41,6 +41,9 @@ template <typename K, typename V> bool contains (
|
|||||||
|
|
||||||
template <typename T> size_t indexOf (const vector<T>&, const T&);
|
template <typename T> size_t indexOf (const vector<T>&, const T&);
|
||||||
|
|
||||||
|
template <class Operation>
|
||||||
|
void apply_n_times (Params& v1, const Params& v2, unsigned repetitions, Operation);
|
||||||
|
|
||||||
template <typename T> void log (vector<T>&);
|
template <typename T> void log (vector<T>&);
|
||||||
|
|
||||||
template <typename T> void exp (vector<T>&);
|
template <typename T> void exp (vector<T>&);
|
||||||
@ -54,10 +57,6 @@ template <> std::string toString (const bool&);
|
|||||||
|
|
||||||
double logSum (double, double);
|
double logSum (double, double);
|
||||||
|
|
||||||
void add (Params&, const Params&, unsigned);
|
|
||||||
|
|
||||||
void multiply (Params&, const Params&, unsigned);
|
|
||||||
|
|
||||||
unsigned maxUnsigned (void);
|
unsigned maxUnsigned (void);
|
||||||
|
|
||||||
unsigned stringToUnsigned (string);
|
unsigned stringToUnsigned (string);
|
||||||
@ -153,10 +152,29 @@ Util::indexOf (const vector<T>& v, const T& e)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
template <class Operation> void
|
||||||
|
Util::apply_n_times (Params& v1, const Params& v2, unsigned repetitions,
|
||||||
|
Operation unary_op)
|
||||||
|
{
|
||||||
|
Params::iterator first = v1.begin();
|
||||||
|
Params::const_iterator last = v1.end();
|
||||||
|
Params::const_iterator first2 = v2.begin();
|
||||||
|
Params::const_iterator last2 = v2.end();
|
||||||
|
while (first != last) {
|
||||||
|
for (first2 = v2.begin(); first2 != last2; ++first2) {
|
||||||
|
std::transform (first, first + repetitions, first,
|
||||||
|
std::bind1st (unary_op, *first2));
|
||||||
|
first += repetitions;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
template <typename T> void
|
template <typename T> void
|
||||||
Util::log (vector<T>& v)
|
Util::log (vector<T>& v)
|
||||||
{
|
{
|
||||||
transform (v.begin(), v.end(), v.begin(), ::log);
|
std::transform (v.begin(), v.end(), v.begin(), ::log);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -164,7 +182,7 @@ Util::log (vector<T>& v)
|
|||||||
template <typename T> void
|
template <typename T> void
|
||||||
Util::exp (vector<T>& v)
|
Util::exp (vector<T>& v)
|
||||||
{
|
{
|
||||||
transform (v.begin(), v.end(), v.begin(), ::exp);
|
std::transform (v.begin(), v.end(), v.begin(), ::exp);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -224,36 +242,6 @@ Util::logSum (double x, double y)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
inline void
|
|
||||||
Util::add (Params& v1, const Params& v2, unsigned repetitions)
|
|
||||||
{
|
|
||||||
for (size_t count = 0; count < v1.size(); ) {
|
|
||||||
for (size_t i = 0; i < v2.size(); i++) {
|
|
||||||
for (unsigned r = 0; r < repetitions; r++) {
|
|
||||||
v1[count] += v2[i];
|
|
||||||
count ++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
inline void
|
|
||||||
Util::multiply (Params& v1, const Params& v2, unsigned repetitions)
|
|
||||||
{
|
|
||||||
for (size_t count = 0; count < v1.size(); ) {
|
|
||||||
for (size_t i = 0; i < v2.size(); i++) {
|
|
||||||
for (unsigned r = 0; r < repetitions; r++) {
|
|
||||||
v1[count] *= v2[i];
|
|
||||||
count ++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
inline unsigned
|
inline unsigned
|
||||||
Util::maxUnsigned (void)
|
Util::maxUnsigned (void)
|
||||||
{
|
{
|
||||||
@ -273,7 +261,6 @@ inline double noEvidence() { return Globals::logDomain ? NEG_INF : 0.0; }
|
|||||||
inline double log (double v) { return Globals::logDomain ? ::log (v) : v; }
|
inline double log (double v) { return Globals::logDomain ? ::log (v) : v; }
|
||||||
inline double exp (double v) { return Globals::logDomain ? ::exp (v) : v; }
|
inline double exp (double v) { return Globals::logDomain ? ::exp (v) : v; }
|
||||||
|
|
||||||
|
|
||||||
void normalize (Params&);
|
void normalize (Params&);
|
||||||
|
|
||||||
double getL1Distance (const Params&, const Params&);
|
double getL1Distance (const Params&, const Params&);
|
||||||
|
Reference in New Issue
Block a user