This repository has been archived on 2023-08-20. You can view files and clone it, but cannot push or open issues or pull requests.
yap-6.3/packages/CLPBN/horus/Factor.cpp

240 lines
4.5 KiB
C++
Raw Normal View History

2012-05-23 14:56:01 +01:00
#include <cassert>
#include <algorithm>
#include <iostream>
#include <sstream>
#include "Factor.h"
2012-12-27 12:54:58 +00:00
#include "Var.h"
2012-05-23 14:56:01 +01:00
2013-02-07 23:53:13 +00:00
namespace horus {
2012-05-23 14:56:01 +01:00
// Copy constructor: fills this factor with the contents of g by
// delegating to clone().
Factor::Factor (const Factor& g)
{
  this->clone (g);
}
// Builds a factor from explicit argument ids, their ranges, a parameter
// table and a distribution id.
//
// In builds with assertions enabled, the number of parameters is checked
// against util::sizeExpected() of the ranges.
Factor::Factor (
    const VarIds& vids,
    const Ranges& ranges,
    const Params& params,
    unsigned distId)
{
  distId_ = distId;
  params_ = params;
  ranges_ = ranges;
  args_   = vids;
  assert (util::sizeExpected (ranges_) == params_.size());
}
// Builds a factor over the given variables: each variable contributes its
// varId() as an argument and its range() as the matching range, in order.
//
// In builds with assertions enabled, the number of parameters is checked
// against util::sizeExpected() of the collected ranges.
Factor::Factor (
    const Vars& vars,
    const Params& params,
    unsigned distId)
{
  const size_t nrVars = vars.size();
  for (size_t idx = 0; idx < nrVars; ++idx) {
    args_.push_back (vars[idx]->varId());
    ranges_.push_back (vars[idx]->range());
  }
  distId_ = distId;
  params_ = params;
  assert (util::sizeExpected (ranges_) == params_.size());
}
// Sums out the variable vid from this factor.
//
// vid must be one of this factor's arguments (asserted). Two fast paths
// exist for variables with range 2: one when vid is the first argument
// and one when it is the last argument. Every other case goes through the
// generic sumOutIndex().
void
Factor::sumOut (VarId vid)
{
  // Resolve and validate the position first, so that front()/back() are
  // never evaluated on an empty argument list before the assertion fires.
  const size_t idx = indexOf (vid);
  assert (idx < args_.size());
  if (idx == 0 && ranges_.front() == 2) {
    // optimization
    sumOutFirstVariable();
  } else if (vid == args_.back() && ranges_.back() == 2) {
    // optimization
    sumOutLastVariable();
  } else {
    sumOutIndex (idx);
  }
}
// Sums out every argument except vid.
// vid must be an argument of this factor (asserted).
void
Factor::sumOutAllExcept (VarId vid)
{
  const size_t keepIdx = indexOf (vid);
  assert (keepIdx != args_.size());
  sumOutAllExceptIndex (keepIdx);
}
void
Factor::sumOutAllExcept (const VarIds& vids)
{
2013-02-07 13:37:15 +00:00
std::vector<bool> mask (args_.size(), false);
2012-05-25 20:15:05 +01:00
for (unsigned i = 0; i < vids.size(); i++) {
assert (indexOf (vids[i]) != args_.size());
mask[indexOf (vids[i])] = true;
2012-05-23 14:56:01 +01:00
}
2012-05-25 20:15:05 +01:00
sumOutArgs (mask);
2012-05-23 14:56:01 +01:00
}
void
2012-05-24 22:55:20 +01:00
Factor::sumOutAllExceptIndex (size_t idx)
2012-05-23 14:56:01 +01:00
{
assert (idx < args_.size());
2013-02-07 13:37:15 +00:00
std::vector<bool> mask (args_.size(), false);
2012-05-25 20:15:05 +01:00
mask[idx] = true;
sumOutArgs (mask);
2012-05-23 14:56:01 +01:00
}
// Multiplies this factor by g.
// A factor without arguments simply becomes a copy of g; otherwise the
// work is delegated to TFactor<VarId>::multiply().
void
Factor::multiply (Factor& g)
{
  if (!args_.empty()) {
    TFactor<VarId>::multiply (g);
    return;
  }
  clone (g);
}
2013-02-07 13:37:15 +00:00
std::string
2012-05-23 14:56:01 +01:00
Factor::getLabel (void) const
{
2013-02-07 13:37:15 +00:00
std::stringstream ss;
2012-05-23 14:56:01 +01:00
ss << "f(" ;
2012-05-24 22:55:20 +01:00
for (size_t i = 0; i < args_.size(); i++) {
2012-05-23 14:56:01 +01:00
if (i != 0) ss << "," ;
ss << Var (args_[i], ranges_[i]).label();
}
ss << ")" ;
return ss.str();
}
void
Factor::print (void) const
{
Vars vars;
2012-05-24 22:55:20 +01:00
for (size_t i = 0; i < args_.size(); i++) {
2012-05-23 14:56:01 +01:00
vars.push_back (new Var (args_[i], ranges_[i]));
}
2013-02-08 00:15:41 +00:00
std::vector<std::string> jointStrings = util::getStateLines (vars);
2012-05-24 22:55:20 +01:00
for (size_t i = 0; i < params_.size(); i++) {
2012-05-23 14:56:01 +01:00
// cout << "[" << distId_ << "] " ;
2013-02-07 13:37:15 +00:00
std::cout << "f(" << jointStrings[i] << ")" ;
std::cout << " = " << params_[i] << std::endl;
2012-05-23 14:56:01 +01:00
}
2013-02-07 13:37:15 +00:00
std::cout << std::endl;
2012-05-24 22:55:20 +01:00
for (size_t i = 0; i < vars.size(); i++) {
2012-05-23 14:56:01 +01:00
delete vars[i];
}
}
2012-05-25 20:15:05 +01:00
void
Factor::sumOutFirstVariable (void)
{
size_t sep = params_.size() / 2;
2013-02-08 00:15:41 +00:00
if (globals::logDomain) {
2012-05-25 20:15:05 +01:00
std::transform (
params_.begin(), params_.begin() + sep,
params_.begin() + sep, params_.begin(),
2013-02-08 00:15:41 +00:00
util::logSum);
2012-05-25 20:15:05 +01:00
} else {
std::transform (
params_.begin(), params_.begin() + sep,
params_.begin() + sep, params_.begin(),
std::plus<double>());
}
params_.resize (sep);
args_.erase (args_.begin());
ranges_.erase (ranges_.begin());
}
void
Factor::sumOutLastVariable (void)
{
Params::iterator first1 = params_.begin();
Params::iterator first2 = params_.begin();
Params::iterator last = params_.end();
2013-02-08 00:15:41 +00:00
if (globals::logDomain) {
2012-05-25 20:15:05 +01:00
while (first2 != last) {
// the arguments can be swaped, but that is ok
2013-02-08 00:15:41 +00:00
*first1++ = util::logSum (*first2++, *first2++);
2012-05-25 20:15:05 +01:00
}
} else {
while (first2 != last) {
*first1++ = (*first2++) + (*first2++);
}
}
params_.resize (params_.size() / 2);
args_.pop_back();
ranges_.pop_back();
}
void
2013-02-07 13:37:15 +00:00
Factor::sumOutArgs (const std::vector<bool>& mask)
2012-05-25 20:15:05 +01:00
{
assert (mask.size() == args_.size());
size_t new_size = 1;
Ranges oldRanges = ranges_;
args_.clear();
ranges_.clear();
for (unsigned i = 0; i < mask.size(); i++) {
if (mask[i]) {
new_size *= ranges_[i];
args_.push_back (args_[i]);
ranges_.push_back (ranges_[i]);
}
}
2013-02-08 00:15:41 +00:00
Params newps (new_size, log_aware::addIdenty());
2012-05-25 20:15:05 +01:00
Params::const_iterator first = params_.begin();
Params::const_iterator last = params_.end();
2012-05-28 12:32:15 +01:00
MapIndexer indexer (oldRanges, mask);
2013-02-08 00:15:41 +00:00
if (globals::logDomain) {
2012-05-25 20:15:05 +01:00
while (first != last) {
2013-02-08 00:15:41 +00:00
newps[indexer] = util::logSum (newps[indexer], *first++);
2012-05-25 20:15:05 +01:00
++ indexer;
}
} else {
while (first != last) {
newps[indexer] += *first++;
++ indexer;
}
}
params_ = newps;
}
2012-05-23 14:56:01 +01:00
void
2012-05-25 21:22:48 +01:00
Factor::clone (const Factor& g)
2012-05-23 14:56:01 +01:00
{
args_ = g.arguments();
ranges_ = g.ranges();
params_ = g.params();
distId_ = g.distId();
}
2013-02-07 23:53:13 +00:00
} // namespace horus