This repository has been archived on 2023-08-20. You can view files and clone it, but cannot push or open issues or pull requests.
yap-6.3/packages/CLPBN/horus/Factor.cpp
Tiago Gomes a2e54a235a Factor: improve factor multiplication
Pass the argument as reference-to-const and also allow chaining of multiplications.
2013-03-21 21:49:12 +00:00

228 lines
4.4 KiB
C++

#include <cassert>
#include <algorithm>
#include <iostream>
#include <sstream>
#include "Factor.h"
#include "Indexer.h"
#include "Var.h"
namespace Horus {
Factor::Factor (
const VarIds& vids,
const Ranges& ranges,
const Params& params,
unsigned distId)
{
args_ = vids;
ranges_ = ranges;
params_ = params;
distId_ = distId;
assert (params_.size() == Util::sizeExpected (ranges_));
}
Factor::Factor (
const Vars& vars,
const Params& params,
unsigned distId)
{
for (size_t i = 0; i < vars.size(); i++) {
args_.push_back (vars[i]->varId());
ranges_.push_back (vars[i]->range());
}
params_ = params;
distId_ = distId;
assert (params_.size() == Util::sizeExpected (ranges_));
}
void
Factor::sumOut (VarId vid)
{
if (vid == args_.front() && ranges_.front() == 2) {
// optimization
sumOutFirstVariable();
} else if (vid == args_.back() && ranges_.back() == 2) {
// optimization
sumOutLastVariable();
} else {
assert (indexOf (vid) != args_.size());
sumOutIndex (indexOf (vid));
}
}
void
Factor::sumOutAllExcept (VarId vid)
{
assert (indexOf (vid) != args_.size());
sumOutAllExceptIndex (indexOf (vid));
}
void
Factor::sumOutAllExcept (const VarIds& vids)
{
std::vector<bool> mask (args_.size(), false);
for (unsigned i = 0; i < vids.size(); i++) {
assert (indexOf (vids[i]) != args_.size());
mask[indexOf (vids[i])] = true;
}
sumOutArgs (mask);
}
void
Factor::sumOutAllExceptIndex (size_t idx)
{
assert (idx < args_.size());
std::vector<bool> mask (args_.size(), false);
mask[idx] = true;
sumOutArgs (mask);
}
Factor&
Factor::multiply (const Factor& g)
{
if (args_.empty()) {
operator= (g);
} else {
GenericFactor<VarId>::multiply (g);
}
return *this;
}
std::string
Factor::getLabel() const
{
std::stringstream ss;
ss << "f(" ;
for (size_t i = 0; i < args_.size(); i++) {
if (i != 0) ss << "," ;
ss << Var (args_[i], ranges_[i]).label();
}
ss << ")" ;
return ss.str();
}
void
Factor::print() const
{
Vars vars;
for (size_t i = 0; i < args_.size(); i++) {
vars.push_back (new Var (args_[i], ranges_[i]));
}
std::vector<std::string> jointStrings = Util::getStateLines (vars);
for (size_t i = 0; i < params_.size(); i++) {
// cout << "[" << distId_ << "] " ;
std::cout << "f(" << jointStrings[i] << ")" ;
std::cout << " = " << params_[i] << std::endl;
}
std::cout << std::endl;
for (size_t i = 0; i < vars.size(); i++) {
delete vars[i];
}
}
void
Factor::sumOutFirstVariable()
{
assert (ranges_.front() == 2);
size_t sep = params_.size() / 2;
if (Globals::logDomain) {
std::transform (
params_.begin(), params_.begin() + sep,
params_.begin() + sep, params_.begin(),
Util::logSum);
} else {
std::transform (
params_.begin(), params_.begin() + sep,
params_.begin() + sep, params_.begin(),
std::plus<double>());
}
params_.resize (sep);
args_.erase (args_.begin());
ranges_.erase (ranges_.begin());
}
void
Factor::sumOutLastVariable()
{
assert (ranges_.back() == 2);
Params::iterator first1 = params_.begin();
Params::iterator first2 = params_.begin();
Params::iterator last = params_.end();
if (Globals::logDomain) {
while (first2 != last) {
double tmp = *first2++;
*first1++ = Util::logSum (tmp, *first2++);
}
} else {
while (first2 != last) {
*first1 = *first2++;
*first1++ += *first2++;
}
}
params_.resize (params_.size() / 2);
args_.pop_back();
ranges_.pop_back();
}
void
Factor::sumOutArgs (const std::vector<bool>& mask)
{
assert (mask.size() == args_.size());
size_t new_size = 1;
Ranges oldRanges = ranges_;
args_.clear();
ranges_.clear();
for (unsigned i = 0; i < mask.size(); i++) {
if (mask[i]) {
new_size *= ranges_[i];
args_.push_back (args_[i]);
ranges_.push_back (ranges_[i]);
}
}
Params newps (new_size, LogAware::addIdenty());
Params::const_iterator first = params_.begin();
Params::const_iterator last = params_.end();
MapIndexer indexer (oldRanges, mask);
if (Globals::logDomain) {
while (first != last) {
newps[indexer] = Util::logSum (newps[indexer], *first++);
++ indexer;
}
} else {
while (first != last) {
newps[indexer] += *first++;
++ indexer;
}
}
params_ = newps;
}
} // namespace Horus