add operators to manipulate vectors of parameters
This commit is contained in:
parent
6cb718942a
commit
444eaacc63
@ -87,13 +87,13 @@ BpSolver::getPosterioriOf (VarId vid)
|
|||||||
const SpLinkSet& links = ninf(var)->getLinks();
|
const SpLinkSet& links = ninf(var)->getLinks();
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
for (unsigned i = 0; i < links.size(); i++) {
|
for (unsigned i = 0; i < links.size(); i++) {
|
||||||
Util::add (probs, links[i]->getMessage());
|
probs += links[i]->getMessage();
|
||||||
}
|
}
|
||||||
LogAware::normalize (probs);
|
LogAware::normalize (probs);
|
||||||
Util::fromLog (probs);
|
Util::fromLog (probs);
|
||||||
} else {
|
} else {
|
||||||
for (unsigned i = 0; i < links.size(); i++) {
|
for (unsigned i = 0; i < links.size(); i++) {
|
||||||
Util::multiply (probs, links[i]->getMessage());
|
probs *= links[i]->getMessage();
|
||||||
}
|
}
|
||||||
LogAware::normalize (probs);
|
LogAware::normalize (probs);
|
||||||
}
|
}
|
||||||
@ -362,16 +362,16 @@ BpSolver::getVar2FactorMsg (const SpLink* link) const
|
|||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
SpLinkSet::const_iterator it;
|
SpLinkSet::const_iterator it;
|
||||||
for (it = links.begin(); it != links.end(); ++ it) {
|
for (it = links.begin(); it != links.end(); ++ it) {
|
||||||
Util::add (msg, (*it)->getMessage());
|
msg += (*it)->getMessage();
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
cout << " x " << (*it)->getMessage();
|
cout << " x " << (*it)->getMessage();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Util::subtract (msg, link->getMessage());
|
msg -= link->getMessage();
|
||||||
} else {
|
} else {
|
||||||
for (unsigned i = 0; i < links.size(); i++) {
|
for (unsigned i = 0; i < links.size(); i++) {
|
||||||
if (links[i]->getFactor() != dst) {
|
if (links[i]->getFactor() != dst) {
|
||||||
Util::multiply (msg, links[i]->getMessage());
|
msg *= links[i]->getMessage();
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
cout << " x " << links[i]->getMessage();
|
cout << " x " << links[i]->getMessage();
|
||||||
}
|
}
|
||||||
|
@ -82,14 +82,14 @@ CbpSolver::getPosterioriOf (VarId vid)
|
|||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
for (unsigned i = 0; i < links.size(); i++) {
|
for (unsigned i = 0; i < links.size(); i++) {
|
||||||
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
|
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
|
||||||
Util::add (probs, l->poweredMessage());
|
probs += l->poweredMessage();
|
||||||
}
|
}
|
||||||
LogAware::normalize (probs);
|
LogAware::normalize (probs);
|
||||||
Util::fromLog (probs);
|
Util::fromLog (probs);
|
||||||
} else {
|
} else {
|
||||||
for (unsigned i = 0; i < links.size(); i++) {
|
for (unsigned i = 0; i < links.size(); i++) {
|
||||||
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
|
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
|
||||||
Util::multiply (probs, l->poweredMessage());
|
probs *= l->poweredMessage();
|
||||||
}
|
}
|
||||||
LogAware::normalize (probs);
|
LogAware::normalize (probs);
|
||||||
}
|
}
|
||||||
@ -330,14 +330,14 @@ CbpSolver::getVar2FactorMsg (const SpLink* _link) const
|
|||||||
CbpSolverLink* cl = static_cast<CbpSolverLink*> (links[i]);
|
CbpSolverLink* cl = static_cast<CbpSolverLink*> (links[i]);
|
||||||
if ( ! (cl->getFactor() == dst && cl->index() == link->index())) {
|
if ( ! (cl->getFactor() == dst && cl->index() == link->index())) {
|
||||||
CbpSolverLink* cl = static_cast<CbpSolverLink*> (links[i]);
|
CbpSolverLink* cl = static_cast<CbpSolverLink*> (links[i]);
|
||||||
Util::add (msg, cl->poweredMessage());
|
msg += cl->poweredMessage();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for (unsigned i = 0; i < links.size(); i++) {
|
for (unsigned i = 0; i < links.size(); i++) {
|
||||||
CbpSolverLink* cl = static_cast<CbpSolverLink*> (links[i]);
|
CbpSolverLink* cl = static_cast<CbpSolverLink*> (links[i]);
|
||||||
if ( ! (cl->getFactor() == dst && cl->index() == link->index())) {
|
if ( ! (cl->getFactor() == dst && cl->index() == link->index())) {
|
||||||
Util::multiply (msg, cl->poweredMessage());
|
msg *= cl->poweredMessage();
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
cout << " x " << cl->getNextMessage() << "^" << link->nrEdges();
|
cout << " x " << cl->getNextMessage() << "^" << link->nrEdges();
|
||||||
}
|
}
|
||||||
|
@ -79,11 +79,8 @@ class TFactor
|
|||||||
if (args_ == g_args) {
|
if (args_ == g_args) {
|
||||||
// optimization: if the factors contain the same set of args,
|
// optimization: if the factors contain the same set of args,
|
||||||
// we can do a 1 to 1 operation on the parameters
|
// we can do a 1 to 1 operation on the parameters
|
||||||
if (Globals::logDomain) {
|
Globals::logDomain ? params_ += g_params
|
||||||
Util::add (params_, g_params);
|
: params_ *= g_params;
|
||||||
} else {
|
|
||||||
Util::multiply (params_, g_params);
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
bool sharedArgs = false;
|
bool sharedArgs = false;
|
||||||
vector<unsigned> gvarpos;
|
vector<unsigned> gvarpos;
|
||||||
|
@ -79,9 +79,7 @@ stringToDouble (string str)
|
|||||||
void
|
void
|
||||||
toLog (Params& v)
|
toLog (Params& v)
|
||||||
{
|
{
|
||||||
for (unsigned i = 0; i < v.size(); i++) {
|
transform (v.begin(), v.end(), v.begin(), ::log);
|
||||||
v[i] = log (v[i]);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -89,9 +87,7 @@ toLog (Params& v)
|
|||||||
void
|
void
|
||||||
fromLog (Params& v)
|
fromLog (Params& v)
|
||||||
{
|
{
|
||||||
for (unsigned i = 0; i < v.size(); i++) {
|
transform (v.begin(), v.end(), v.begin(), ::exp);
|
||||||
v[i] = exp (v[i]);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -152,11 +148,8 @@ nrCombinations (unsigned n, unsigned k)
|
|||||||
unsigned
|
unsigned
|
||||||
expectedSize (const Ranges& ranges)
|
expectedSize (const Ranges& ranges)
|
||||||
{
|
{
|
||||||
unsigned prod = 1;
|
return std::accumulate (
|
||||||
for (unsigned i = 0; i < ranges.size(); i++) {
|
ranges.begin(), ranges.end(), 1, multiplies<unsigned>());
|
||||||
prod *= ranges[i];
|
|
||||||
}
|
|
||||||
return prod;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -410,55 +403,40 @@ getMaxNorm (const Params& v1, const Params& v2)
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
double
|
double
|
||||||
pow (double p, unsigned expoent)
|
pow (double base, unsigned iexp)
|
||||||
{
|
{
|
||||||
return Globals::logDomain ? p * expoent : std::pow (p, expoent);
|
return Globals::logDomain ? base * iexp : std::pow (base, iexp);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
double
|
double
|
||||||
pow (double p, double expoent)
|
pow (double base, double exp)
|
||||||
{
|
{
|
||||||
// assumes that `expoent' is never in log domain
|
// assumes that `expoent' is never in log domain
|
||||||
return Globals::logDomain ? p * expoent : std::pow (p, expoent);
|
return Globals::logDomain ? base * exp : std::pow (base, exp);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
pow (Params& v, unsigned expoent)
|
pow (Params& v, unsigned iexp)
|
||||||
{
|
{
|
||||||
if (expoent == 1) {
|
if (iexp == 1) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (Globals::logDomain) {
|
Globals::logDomain ? v *= iexp : v ^= (int)iexp;
|
||||||
for (unsigned i = 0; i < v.size(); i++) {
|
|
||||||
v[i] *= expoent;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (unsigned i = 0; i < v.size(); i++) {
|
|
||||||
v[i] = std::pow (v[i], expoent);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
pow (Params& v, double expoent)
|
pow (Params& v, double exp)
|
||||||
{
|
{
|
||||||
// assumes that `expoent' is never in log domain
|
// `expoent' should not be in log domain
|
||||||
if (Globals::logDomain) {
|
Globals::logDomain ? v *= exp : v ^= exp;
|
||||||
for (unsigned i = 0; i < v.size(); i++) {
|
|
||||||
v[i] *= expoent;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (unsigned i = 0; i < v.size(); i++) {
|
|
||||||
v[i] = std::pow (v[i], expoent);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -40,6 +40,14 @@ template <typename T> std::string toString (const T&);
|
|||||||
|
|
||||||
template <> std::string toString (const bool&);
|
template <> std::string toString (const bool&);
|
||||||
|
|
||||||
|
double logSum (double, double);
|
||||||
|
|
||||||
|
void add (Params&, const Params&, unsigned);
|
||||||
|
|
||||||
|
void multiply (Params&, const Params&, unsigned);
|
||||||
|
|
||||||
|
unsigned maxUnsigned (void);
|
||||||
|
|
||||||
unsigned stringToUnsigned (string);
|
unsigned stringToUnsigned (string);
|
||||||
|
|
||||||
double stringToDouble (string);
|
double stringToDouble (string);
|
||||||
@ -48,20 +56,6 @@ void toLog (Params&);
|
|||||||
|
|
||||||
void fromLog (Params&);
|
void fromLog (Params&);
|
||||||
|
|
||||||
double logSum (double, double);
|
|
||||||
|
|
||||||
void multiply (Params&, const Params&);
|
|
||||||
|
|
||||||
void multiply (Params&, const Params&, unsigned);
|
|
||||||
|
|
||||||
void add (Params&, const Params&);
|
|
||||||
|
|
||||||
void subtract (Params&, const Params&);
|
|
||||||
|
|
||||||
void add (Params&, const Params&, unsigned);
|
|
||||||
|
|
||||||
unsigned maxUnsigned (void);
|
|
||||||
|
|
||||||
double factorial (unsigned);
|
double factorial (unsigned);
|
||||||
|
|
||||||
double logFactorial (unsigned);
|
double logFactorial (unsigned);
|
||||||
@ -147,10 +141,7 @@ Util::indexOf (const vector<T>& v, const T& e)
|
|||||||
{
|
{
|
||||||
int pos = std::distance (v.begin(),
|
int pos = std::distance (v.begin(),
|
||||||
std::find (v.begin(), v.end(), e));
|
std::find (v.begin(), v.end(), e));
|
||||||
if (pos == (int)v.size()) {
|
return pos != (int)v.size() ? pos : -1;
|
||||||
pos = -1;
|
|
||||||
}
|
|
||||||
return pos;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -216,11 +207,15 @@ Util::logSum (double x, double y)
|
|||||||
|
|
||||||
|
|
||||||
inline void
|
inline void
|
||||||
Util::multiply (Params& v1, const Params& v2)
|
Util::add (Params& v1, const Params& v2, unsigned repetitions)
|
||||||
{
|
{
|
||||||
assert (v1.size() == v2.size());
|
for (unsigned count = 0; count < v1.size(); ) {
|
||||||
for (unsigned i = 0; i < v1.size(); i++) {
|
for (unsigned i = 0; i < v2.size(); i++) {
|
||||||
v1[i] *= v2[i];
|
for (unsigned r = 0; r < repetitions; r++) {
|
||||||
|
v1[count] += v2[i];
|
||||||
|
count ++;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -241,41 +236,6 @@ Util::multiply (Params& v1, const Params& v2, unsigned repetitions)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
inline void
|
|
||||||
Util::add (Params& v1, const Params& v2)
|
|
||||||
{
|
|
||||||
assert (v1.size() == v2.size());
|
|
||||||
std::transform (v1.begin(), v1.end(), v2.begin(),
|
|
||||||
v1.begin(), plus<double>());
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
inline void
|
|
||||||
Util::subtract (Params& v1, const Params& v2)
|
|
||||||
{
|
|
||||||
assert (v1.size() == v2.size());
|
|
||||||
std::transform (v1.begin(), v1.end(), v2.begin(),
|
|
||||||
v1.begin(), minus<double>());
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
inline void
|
|
||||||
Util::add (Params& v1, const Params& v2, unsigned repetitions)
|
|
||||||
{
|
|
||||||
for (unsigned count = 0; count < v1.size(); ) {
|
|
||||||
for (unsigned i = 0; i < v2.size(); i++) {
|
|
||||||
for (unsigned r = 0; r < repetitions; r++) {
|
|
||||||
v1[count] += v2[i];
|
|
||||||
count ++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
inline unsigned
|
inline unsigned
|
||||||
Util::maxUnsigned (void)
|
Util::maxUnsigned (void)
|
||||||
{
|
{
|
||||||
@ -284,6 +244,100 @@ Util::maxUnsigned (void)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void operator+=(std::vector<T>& v, double val)
|
||||||
|
{
|
||||||
|
std::transform (v.begin(), v.end(), v.begin(),
|
||||||
|
std::bind1st (plus<double>(), val));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void operator-=(std::vector<T>& v, double val)
|
||||||
|
{
|
||||||
|
std::transform (v.begin(), v.end(), v.begin(),
|
||||||
|
std::bind1st (minus<double>(), val));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void operator*=(std::vector<T>& v, double val)
|
||||||
|
{
|
||||||
|
std::transform (v.begin(), v.end(), v.begin(),
|
||||||
|
std::bind1st (multiplies<double>(), val));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void operator/=(std::vector<T>& v, double val)
|
||||||
|
{
|
||||||
|
std::transform (v.begin(), v.end(), v.begin(),
|
||||||
|
std::bind1st (divides<double>(), val));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void operator+=(std::vector<T>& a, const std::vector<T>& b)
|
||||||
|
{
|
||||||
|
assert (a.size() == b.size());
|
||||||
|
std::transform (a.begin(), a.end(), b.begin(), a.begin(),
|
||||||
|
plus<double>());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void operator-=(std::vector<T>& a, const std::vector<T>& b)
|
||||||
|
{
|
||||||
|
assert (a.size() == b.size());
|
||||||
|
std::transform (a.begin(), a.end(), b.begin(), a.begin(),
|
||||||
|
minus<double>());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void operator*=(std::vector<T>& a, const std::vector<T>& b)
|
||||||
|
{
|
||||||
|
assert (a.size() == b.size());
|
||||||
|
std::transform (a.begin(), a.end(), b.begin(), a.begin(),
|
||||||
|
multiplies<double>());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void operator/=(std::vector<T>& a, const std::vector<T>& b)
|
||||||
|
{
|
||||||
|
assert (a.size() == b.size());
|
||||||
|
std::transform (a.begin(), a.end(), b.begin(), a.begin(),
|
||||||
|
divides<double>());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void operator^=(std::vector<T>& v, double exp)
|
||||||
|
{
|
||||||
|
std::transform (v.begin(), v.end(), v.begin(),
|
||||||
|
std::bind2nd (ptr_fun<double, double, double> (std::pow), exp));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void operator^=(std::vector<T>& v, int iexp)
|
||||||
|
{
|
||||||
|
std::transform (v.begin(), v.end(), v.begin(),
|
||||||
|
std::bind2nd (ptr_fun<double, int, double> (std::pow), iexp));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
namespace LogAware {
|
namespace LogAware {
|
||||||
|
|
||||||
inline double
|
inline double
|
||||||
|
Reference in New Issue
Block a user