add operators to manipulate vectors of parameters
This commit is contained in:
parent
6cb718942a
commit
444eaacc63
@ -87,13 +87,13 @@ BpSolver::getPosterioriOf (VarId vid)
|
||||
const SpLinkSet& links = ninf(var)->getLinks();
|
||||
if (Globals::logDomain) {
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
Util::add (probs, links[i]->getMessage());
|
||||
probs += links[i]->getMessage();
|
||||
}
|
||||
LogAware::normalize (probs);
|
||||
Util::fromLog (probs);
|
||||
} else {
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
Util::multiply (probs, links[i]->getMessage());
|
||||
probs *= links[i]->getMessage();
|
||||
}
|
||||
LogAware::normalize (probs);
|
||||
}
|
||||
@ -362,16 +362,16 @@ BpSolver::getVar2FactorMsg (const SpLink* link) const
|
||||
if (Globals::logDomain) {
|
||||
SpLinkSet::const_iterator it;
|
||||
for (it = links.begin(); it != links.end(); ++ it) {
|
||||
Util::add (msg, (*it)->getMessage());
|
||||
msg += (*it)->getMessage();
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
cout << " x " << (*it)->getMessage();
|
||||
}
|
||||
}
|
||||
Util::subtract (msg, link->getMessage());
|
||||
msg -= link->getMessage();
|
||||
} else {
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
if (links[i]->getFactor() != dst) {
|
||||
Util::multiply (msg, links[i]->getMessage());
|
||||
msg *= links[i]->getMessage();
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
cout << " x " << links[i]->getMessage();
|
||||
}
|
||||
|
@ -82,14 +82,14 @@ CbpSolver::getPosterioriOf (VarId vid)
|
||||
if (Globals::logDomain) {
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
|
||||
Util::add (probs, l->poweredMessage());
|
||||
probs += l->poweredMessage();
|
||||
}
|
||||
LogAware::normalize (probs);
|
||||
Util::fromLog (probs);
|
||||
} else {
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
|
||||
Util::multiply (probs, l->poweredMessage());
|
||||
probs *= l->poweredMessage();
|
||||
}
|
||||
LogAware::normalize (probs);
|
||||
}
|
||||
@ -330,14 +330,14 @@ CbpSolver::getVar2FactorMsg (const SpLink* _link) const
|
||||
CbpSolverLink* cl = static_cast<CbpSolverLink*> (links[i]);
|
||||
if ( ! (cl->getFactor() == dst && cl->index() == link->index())) {
|
||||
CbpSolverLink* cl = static_cast<CbpSolverLink*> (links[i]);
|
||||
Util::add (msg, cl->poweredMessage());
|
||||
msg += cl->poweredMessage();
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
CbpSolverLink* cl = static_cast<CbpSolverLink*> (links[i]);
|
||||
if ( ! (cl->getFactor() == dst && cl->index() == link->index())) {
|
||||
Util::multiply (msg, cl->poweredMessage());
|
||||
msg *= cl->poweredMessage();
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
cout << " x " << cl->getNextMessage() << "^" << link->nrEdges();
|
||||
}
|
||||
|
@ -79,11 +79,8 @@ class TFactor
|
||||
if (args_ == g_args) {
|
||||
// optimization: if the factors contain the same set of args,
|
||||
// we can do a 1 to 1 operation on the parameters
|
||||
if (Globals::logDomain) {
|
||||
Util::add (params_, g_params);
|
||||
} else {
|
||||
Util::multiply (params_, g_params);
|
||||
}
|
||||
Globals::logDomain ? params_ += g_params
|
||||
: params_ *= g_params;
|
||||
} else {
|
||||
bool sharedArgs = false;
|
||||
vector<unsigned> gvarpos;
|
||||
|
@ -572,7 +572,7 @@ extern "C" void
|
||||
init_predicates (void)
|
||||
{
|
||||
YAP_UserCPredicate ("cpp_create_lifted_network", createLiftedNetwork, 3);
|
||||
YAP_UserCPredicate ("cpp_create_ground_network", createGroundNetwork, 4);
|
||||
YAP_UserCPredicate ("cpp_create_ground_network", createGroundNetwork, 4);
|
||||
YAP_UserCPredicate ("cpp_run_lifted_solver", runLiftedSolver, 3);
|
||||
YAP_UserCPredicate ("cpp_run_ground_solver", runGroundSolver, 3);
|
||||
YAP_UserCPredicate ("cpp_set_parfactors_params", setParfactorsParams, 2);
|
||||
|
@ -79,9 +79,7 @@ stringToDouble (string str)
|
||||
void
|
||||
toLog (Params& v)
|
||||
{
|
||||
for (unsigned i = 0; i < v.size(); i++) {
|
||||
v[i] = log (v[i]);
|
||||
}
|
||||
transform (v.begin(), v.end(), v.begin(), ::log);
|
||||
}
|
||||
|
||||
|
||||
@ -89,9 +87,7 @@ toLog (Params& v)
|
||||
void
|
||||
fromLog (Params& v)
|
||||
{
|
||||
for (unsigned i = 0; i < v.size(); i++) {
|
||||
v[i] = exp (v[i]);
|
||||
}
|
||||
transform (v.begin(), v.end(), v.begin(), ::exp);
|
||||
}
|
||||
|
||||
|
||||
@ -152,11 +148,8 @@ nrCombinations (unsigned n, unsigned k)
|
||||
unsigned
|
||||
expectedSize (const Ranges& ranges)
|
||||
{
|
||||
unsigned prod = 1;
|
||||
for (unsigned i = 0; i < ranges.size(); i++) {
|
||||
prod *= ranges[i];
|
||||
}
|
||||
return prod;
|
||||
return std::accumulate (
|
||||
ranges.begin(), ranges.end(), 1, multiplies<unsigned>());
|
||||
}
|
||||
|
||||
|
||||
@ -410,55 +403,40 @@ getMaxNorm (const Params& v1, const Params& v2)
|
||||
}
|
||||
|
||||
|
||||
|
||||
double
|
||||
pow (double p, unsigned expoent)
|
||||
pow (double base, unsigned iexp)
|
||||
{
|
||||
return Globals::logDomain ? p * expoent : std::pow (p, expoent);
|
||||
return Globals::logDomain ? base * iexp : std::pow (base, iexp);
|
||||
}
|
||||
|
||||
|
||||
|
||||
double
|
||||
pow (double p, double expoent)
|
||||
pow (double base, double exp)
|
||||
{
|
||||
// assumes that `expoent' is never in log domain
|
||||
return Globals::logDomain ? p * expoent : std::pow (p, expoent);
|
||||
return Globals::logDomain ? base * exp : std::pow (base, exp);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
pow (Params& v, unsigned expoent)
|
||||
pow (Params& v, unsigned iexp)
|
||||
{
|
||||
if (expoent == 1) {
|
||||
if (iexp == 1) {
|
||||
return;
|
||||
}
|
||||
if (Globals::logDomain) {
|
||||
for (unsigned i = 0; i < v.size(); i++) {
|
||||
v[i] *= expoent;
|
||||
}
|
||||
} else {
|
||||
for (unsigned i = 0; i < v.size(); i++) {
|
||||
v[i] = std::pow (v[i], expoent);
|
||||
}
|
||||
}
|
||||
Globals::logDomain ? v *= iexp : v ^= (int)iexp;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
pow (Params& v, double expoent)
|
||||
pow (Params& v, double exp)
|
||||
{
|
||||
// assumes that `expoent' is never in log domain
|
||||
if (Globals::logDomain) {
|
||||
for (unsigned i = 0; i < v.size(); i++) {
|
||||
v[i] *= expoent;
|
||||
}
|
||||
} else {
|
||||
for (unsigned i = 0; i < v.size(); i++) {
|
||||
v[i] = std::pow (v[i], expoent);
|
||||
}
|
||||
}
|
||||
// `expoent' should not be in log domain
|
||||
Globals::logDomain ? v *= exp : v ^= exp;
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -40,6 +40,14 @@ template <typename T> std::string toString (const T&);
|
||||
|
||||
template <> std::string toString (const bool&);
|
||||
|
||||
double logSum (double, double);
|
||||
|
||||
void add (Params&, const Params&, unsigned);
|
||||
|
||||
void multiply (Params&, const Params&, unsigned);
|
||||
|
||||
unsigned maxUnsigned (void);
|
||||
|
||||
unsigned stringToUnsigned (string);
|
||||
|
||||
double stringToDouble (string);
|
||||
@ -48,20 +56,6 @@ void toLog (Params&);
|
||||
|
||||
void fromLog (Params&);
|
||||
|
||||
double logSum (double, double);
|
||||
|
||||
void multiply (Params&, const Params&);
|
||||
|
||||
void multiply (Params&, const Params&, unsigned);
|
||||
|
||||
void add (Params&, const Params&);
|
||||
|
||||
void subtract (Params&, const Params&);
|
||||
|
||||
void add (Params&, const Params&, unsigned);
|
||||
|
||||
unsigned maxUnsigned (void);
|
||||
|
||||
double factorial (unsigned);
|
||||
|
||||
double logFactorial (unsigned);
|
||||
@ -147,10 +141,7 @@ Util::indexOf (const vector<T>& v, const T& e)
|
||||
{
|
||||
int pos = std::distance (v.begin(),
|
||||
std::find (v.begin(), v.end(), e));
|
||||
if (pos == (int)v.size()) {
|
||||
pos = -1;
|
||||
}
|
||||
return pos;
|
||||
return pos != (int)v.size() ? pos : -1;
|
||||
}
|
||||
|
||||
|
||||
@ -216,11 +207,15 @@ Util::logSum (double x, double y)
|
||||
|
||||
|
||||
inline void
|
||||
Util::multiply (Params& v1, const Params& v2)
|
||||
Util::add (Params& v1, const Params& v2, unsigned repetitions)
|
||||
{
|
||||
assert (v1.size() == v2.size());
|
||||
for (unsigned i = 0; i < v1.size(); i++) {
|
||||
v1[i] *= v2[i];
|
||||
for (unsigned count = 0; count < v1.size(); ) {
|
||||
for (unsigned i = 0; i < v2.size(); i++) {
|
||||
for (unsigned r = 0; r < repetitions; r++) {
|
||||
v1[count] += v2[i];
|
||||
count ++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -241,41 +236,6 @@ Util::multiply (Params& v1, const Params& v2, unsigned repetitions)
|
||||
|
||||
|
||||
|
||||
inline void
|
||||
Util::add (Params& v1, const Params& v2)
|
||||
{
|
||||
assert (v1.size() == v2.size());
|
||||
std::transform (v1.begin(), v1.end(), v2.begin(),
|
||||
v1.begin(), plus<double>());
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline void
|
||||
Util::subtract (Params& v1, const Params& v2)
|
||||
{
|
||||
assert (v1.size() == v2.size());
|
||||
std::transform (v1.begin(), v1.end(), v2.begin(),
|
||||
v1.begin(), minus<double>());
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline void
|
||||
Util::add (Params& v1, const Params& v2, unsigned repetitions)
|
||||
{
|
||||
for (unsigned count = 0; count < v1.size(); ) {
|
||||
for (unsigned i = 0; i < v2.size(); i++) {
|
||||
for (unsigned r = 0; r < repetitions; r++) {
|
||||
v1[count] += v2[i];
|
||||
count ++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline unsigned
|
||||
Util::maxUnsigned (void)
|
||||
{
|
||||
@ -284,6 +244,100 @@ Util::maxUnsigned (void)
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
void operator+=(std::vector<T>& v, double val)
|
||||
{
|
||||
std::transform (v.begin(), v.end(), v.begin(),
|
||||
std::bind1st (plus<double>(), val));
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
void operator-=(std::vector<T>& v, double val)
|
||||
{
|
||||
std::transform (v.begin(), v.end(), v.begin(),
|
||||
std::bind1st (minus<double>(), val));
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
void operator*=(std::vector<T>& v, double val)
|
||||
{
|
||||
std::transform (v.begin(), v.end(), v.begin(),
|
||||
std::bind1st (multiplies<double>(), val));
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
void operator/=(std::vector<T>& v, double val)
|
||||
{
|
||||
std::transform (v.begin(), v.end(), v.begin(),
|
||||
std::bind1st (divides<double>(), val));
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
void operator+=(std::vector<T>& a, const std::vector<T>& b)
|
||||
{
|
||||
assert (a.size() == b.size());
|
||||
std::transform (a.begin(), a.end(), b.begin(), a.begin(),
|
||||
plus<double>());
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
void operator-=(std::vector<T>& a, const std::vector<T>& b)
|
||||
{
|
||||
assert (a.size() == b.size());
|
||||
std::transform (a.begin(), a.end(), b.begin(), a.begin(),
|
||||
minus<double>());
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
void operator*=(std::vector<T>& a, const std::vector<T>& b)
|
||||
{
|
||||
assert (a.size() == b.size());
|
||||
std::transform (a.begin(), a.end(), b.begin(), a.begin(),
|
||||
multiplies<double>());
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
void operator/=(std::vector<T>& a, const std::vector<T>& b)
|
||||
{
|
||||
assert (a.size() == b.size());
|
||||
std::transform (a.begin(), a.end(), b.begin(), a.begin(),
|
||||
divides<double>());
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
void operator^=(std::vector<T>& v, double exp)
|
||||
{
|
||||
std::transform (v.begin(), v.end(), v.begin(),
|
||||
std::bind2nd (ptr_fun<double, double, double> (std::pow), exp));
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
void operator^=(std::vector<T>& v, int iexp)
|
||||
{
|
||||
std::transform (v.begin(), v.end(), v.begin(),
|
||||
std::bind2nd (ptr_fun<double, int, double> (std::pow), iexp));
|
||||
}
|
||||
|
||||
|
||||
|
||||
namespace LogAware {
|
||||
|
||||
inline double
|
||||
|
Reference in New Issue
Block a user