add operators to manipulate vectors of parameters

This commit is contained in:
Tiago Gomes 2012-05-24 14:55:30 +01:00
parent 6cb718942a
commit 444eaacc63
6 changed files with 138 additions and 109 deletions

View File

@ -87,13 +87,13 @@ BpSolver::getPosterioriOf (VarId vid)
const SpLinkSet& links = ninf(var)->getLinks();
if (Globals::logDomain) {
for (unsigned i = 0; i < links.size(); i++) {
Util::add (probs, links[i]->getMessage());
probs += links[i]->getMessage();
}
LogAware::normalize (probs);
Util::fromLog (probs);
} else {
for (unsigned i = 0; i < links.size(); i++) {
Util::multiply (probs, links[i]->getMessage());
probs *= links[i]->getMessage();
}
LogAware::normalize (probs);
}
@ -362,16 +362,16 @@ BpSolver::getVar2FactorMsg (const SpLink* link) const
if (Globals::logDomain) {
SpLinkSet::const_iterator it;
for (it = links.begin(); it != links.end(); ++ it) {
Util::add (msg, (*it)->getMessage());
msg += (*it)->getMessage();
if (Constants::SHOW_BP_CALCS) {
cout << " x " << (*it)->getMessage();
}
}
Util::subtract (msg, link->getMessage());
msg -= link->getMessage();
} else {
for (unsigned i = 0; i < links.size(); i++) {
if (links[i]->getFactor() != dst) {
Util::multiply (msg, links[i]->getMessage());
msg *= links[i]->getMessage();
if (Constants::SHOW_BP_CALCS) {
cout << " x " << links[i]->getMessage();
}

View File

@ -82,14 +82,14 @@ CbpSolver::getPosterioriOf (VarId vid)
if (Globals::logDomain) {
for (unsigned i = 0; i < links.size(); i++) {
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
Util::add (probs, l->poweredMessage());
probs += l->poweredMessage();
}
LogAware::normalize (probs);
Util::fromLog (probs);
} else {
for (unsigned i = 0; i < links.size(); i++) {
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
Util::multiply (probs, l->poweredMessage());
probs *= l->poweredMessage();
}
LogAware::normalize (probs);
}
@ -330,14 +330,14 @@ CbpSolver::getVar2FactorMsg (const SpLink* _link) const
CbpSolverLink* cl = static_cast<CbpSolverLink*> (links[i]);
if ( ! (cl->getFactor() == dst && cl->index() == link->index())) {
CbpSolverLink* cl = static_cast<CbpSolverLink*> (links[i]);
Util::add (msg, cl->poweredMessage());
msg += cl->poweredMessage();
}
}
} else {
for (unsigned i = 0; i < links.size(); i++) {
CbpSolverLink* cl = static_cast<CbpSolverLink*> (links[i]);
if ( ! (cl->getFactor() == dst && cl->index() == link->index())) {
Util::multiply (msg, cl->poweredMessage());
msg *= cl->poweredMessage();
if (Constants::SHOW_BP_CALCS) {
cout << " x " << cl->getNextMessage() << "^" << link->nrEdges();
}

View File

@ -79,11 +79,8 @@ class TFactor
if (args_ == g_args) {
// optimization: if the factors contain the same set of args,
// we can do a 1 to 1 operation on the parameters
if (Globals::logDomain) {
Util::add (params_, g_params);
} else {
Util::multiply (params_, g_params);
}
Globals::logDomain ? params_ += g_params
: params_ *= g_params;
} else {
bool sharedArgs = false;
vector<unsigned> gvarpos;

View File

@ -572,7 +572,7 @@ extern "C" void
init_predicates (void)
{
YAP_UserCPredicate ("cpp_create_lifted_network", createLiftedNetwork, 3);
YAP_UserCPredicate ("cpp_create_ground_network", createGroundNetwork, 4);
YAP_UserCPredicate ("cpp_create_ground_network", createGroundNetwork, 4);
YAP_UserCPredicate ("cpp_run_lifted_solver", runLiftedSolver, 3);
YAP_UserCPredicate ("cpp_run_ground_solver", runGroundSolver, 3);
YAP_UserCPredicate ("cpp_set_parfactors_params", setParfactorsParams, 2);

View File

@ -79,9 +79,7 @@ stringToDouble (string str)
void
toLog (Params& v)
{
for (unsigned i = 0; i < v.size(); i++) {
v[i] = log (v[i]);
}
transform (v.begin(), v.end(), v.begin(), ::log);
}
@ -89,9 +87,7 @@ toLog (Params& v)
void
fromLog (Params& v)
{
for (unsigned i = 0; i < v.size(); i++) {
v[i] = exp (v[i]);
}
transform (v.begin(), v.end(), v.begin(), ::exp);
}
@ -152,11 +148,8 @@ nrCombinations (unsigned n, unsigned k)
unsigned
expectedSize (const Ranges& ranges)
{
unsigned prod = 1;
for (unsigned i = 0; i < ranges.size(); i++) {
prod *= ranges[i];
}
return prod;
return std::accumulate (
ranges.begin(), ranges.end(), 1, multiplies<unsigned>());
}
@ -410,55 +403,40 @@ getMaxNorm (const Params& v1, const Params& v2)
}
double
pow (double p, unsigned expoent)
pow (double base, unsigned iexp)
{
return Globals::logDomain ? p * expoent : std::pow (p, expoent);
return Globals::logDomain ? base * iexp : std::pow (base, iexp);
}
double
pow (double p, double expoent)
pow (double base, double exp)
{
// assumes that `expoent' is never in log domain
return Globals::logDomain ? p * expoent : std::pow (p, expoent);
return Globals::logDomain ? base * exp : std::pow (base, exp);
}
void
pow (Params& v, unsigned expoent)
pow (Params& v, unsigned iexp)
{
if (expoent == 1) {
if (iexp == 1) {
return;
}
if (Globals::logDomain) {
for (unsigned i = 0; i < v.size(); i++) {
v[i] *= expoent;
}
} else {
for (unsigned i = 0; i < v.size(); i++) {
v[i] = std::pow (v[i], expoent);
}
}
Globals::logDomain ? v *= iexp : v ^= (int)iexp;
}
void
pow (Params& v, double expoent)
pow (Params& v, double exp)
{
// assumes that `expoent' is never in log domain
if (Globals::logDomain) {
for (unsigned i = 0; i < v.size(); i++) {
v[i] *= expoent;
}
} else {
for (unsigned i = 0; i < v.size(); i++) {
v[i] = std::pow (v[i], expoent);
}
}
// `expoent' should not be in log domain
Globals::logDomain ? v *= exp : v ^= exp;
}
}

View File

@ -40,6 +40,14 @@ template <typename T> std::string toString (const T&);
template <> std::string toString (const bool&);
double logSum (double, double);
void add (Params&, const Params&, unsigned);
void multiply (Params&, const Params&, unsigned);
unsigned maxUnsigned (void);
unsigned stringToUnsigned (string);
double stringToDouble (string);
@ -48,20 +56,6 @@ void toLog (Params&);
void fromLog (Params&);
double logSum (double, double);
void multiply (Params&, const Params&);
void multiply (Params&, const Params&, unsigned);
void add (Params&, const Params&);
void subtract (Params&, const Params&);
void add (Params&, const Params&, unsigned);
unsigned maxUnsigned (void);
double factorial (unsigned);
double logFactorial (unsigned);
@ -147,10 +141,7 @@ Util::indexOf (const vector<T>& v, const T& e)
{
int pos = std::distance (v.begin(),
std::find (v.begin(), v.end(), e));
if (pos == (int)v.size()) {
pos = -1;
}
return pos;
return pos != (int)v.size() ? pos : -1;
}
@ -216,11 +207,15 @@ Util::logSum (double x, double y)
inline void
Util::multiply (Params& v1, const Params& v2)
Util::add (Params& v1, const Params& v2, unsigned repetitions)
{
assert (v1.size() == v2.size());
for (unsigned i = 0; i < v1.size(); i++) {
v1[i] *= v2[i];
for (unsigned count = 0; count < v1.size(); ) {
for (unsigned i = 0; i < v2.size(); i++) {
for (unsigned r = 0; r < repetitions; r++) {
v1[count] += v2[i];
count ++;
}
}
}
}
@ -241,41 +236,6 @@ Util::multiply (Params& v1, const Params& v2, unsigned repetitions)
inline void
Util::add (Params& v1, const Params& v2)
{
assert (v1.size() == v2.size());
std::transform (v1.begin(), v1.end(), v2.begin(),
v1.begin(), plus<double>());
}
inline void
Util::subtract (Params& v1, const Params& v2)
{
assert (v1.size() == v2.size());
std::transform (v1.begin(), v1.end(), v2.begin(),
v1.begin(), minus<double>());
}
inline void
Util::add (Params& v1, const Params& v2, unsigned repetitions)
{
for (unsigned count = 0; count < v1.size(); ) {
for (unsigned i = 0; i < v2.size(); i++) {
for (unsigned r = 0; r < repetitions; r++) {
v1[count] += v2[i];
count ++;
}
}
}
}
inline unsigned
Util::maxUnsigned (void)
{
@ -284,6 +244,100 @@ Util::maxUnsigned (void)
template <typename T>
void operator+=(std::vector<T>& v, double val)
{
std::transform (v.begin(), v.end(), v.begin(),
std::bind1st (plus<double>(), val));
}
template <typename T>
void operator-=(std::vector<T>& v, double val)
{
std::transform (v.begin(), v.end(), v.begin(),
std::bind1st (minus<double>(), val));
}
template <typename T>
void operator*=(std::vector<T>& v, double val)
{
std::transform (v.begin(), v.end(), v.begin(),
std::bind1st (multiplies<double>(), val));
}
template <typename T>
void operator/=(std::vector<T>& v, double val)
{
std::transform (v.begin(), v.end(), v.begin(),
std::bind1st (divides<double>(), val));
}
template <typename T>
void operator+=(std::vector<T>& a, const std::vector<T>& b)
{
assert (a.size() == b.size());
std::transform (a.begin(), a.end(), b.begin(), a.begin(),
plus<double>());
}
template <typename T>
void operator-=(std::vector<T>& a, const std::vector<T>& b)
{
assert (a.size() == b.size());
std::transform (a.begin(), a.end(), b.begin(), a.begin(),
minus<double>());
}
template <typename T>
void operator*=(std::vector<T>& a, const std::vector<T>& b)
{
assert (a.size() == b.size());
std::transform (a.begin(), a.end(), b.begin(), a.begin(),
multiplies<double>());
}
template <typename T>
void operator/=(std::vector<T>& a, const std::vector<T>& b)
{
assert (a.size() == b.size());
std::transform (a.begin(), a.end(), b.begin(), a.begin(),
divides<double>());
}
template <typename T>
void operator^=(std::vector<T>& v, double exp)
{
std::transform (v.begin(), v.end(), v.begin(),
std::bind2nd (ptr_fun<double, double, double> (std::pow), exp));
}
template <typename T>
void operator^=(std::vector<T>& v, int iexp)
{
std::transform (v.begin(), v.end(), v.begin(),
std::bind2nd (ptr_fun<double, int, double> (std::pow), iexp));
}
namespace LogAware {
inline double