This repository has been archived on 2023-08-20. You can view files and clone it, but cannot push or open issues or pull requests.
yap-6.3/packages/CLPBN/horus/Factor.cpp

238 lines
4.4 KiB
C++
Raw Normal View History

2012-05-23 14:56:01 +01:00
#include <cstdlib>
#include <cassert>
#include <algorithm>
#include <iostream>
#include <sstream>
#include "Factor.h"
2012-12-27 12:54:58 +00:00
#include "Var.h"
2012-05-23 14:56:01 +01:00
// Copy constructor. All field copying (arguments, ranges, parameter table,
// distribution id) is centralised in clone(), so this just forwards to it.
Factor::Factor (const Factor& g)
{
  this->clone (g);
}
// Builds a factor from explicit variable ids, their ranges and a flattened
// parameter table. distId is stored unchanged in distId_.
// The parameter table size must match Util::sizeExpected(ranges)
// (presumably the product of the ranges -- checked by the assert below).
Factor::Factor (
    const VarIds& vids,
    const Ranges& ranges,
    const Params& params,
    unsigned distId)
{
  distId_ = distId;
  params_ = params;
  ranges_ = ranges;
  args_   = vids;
  assert (params_.size() == Util::sizeExpected (ranges_));
}
// Builds a factor whose arguments are the given variables (in order):
// each variable contributes its id to args_ and its range to ranges_.
// params is the flattened table and must match Util::sizeExpected(ranges_).
Factor::Factor (
    const Vars& vars,
    const Params& params,
    unsigned distId)
{
  const size_t count = vars.size();
  for (size_t k = 0; k < count; ++k) {
    Var* var = vars[k];
    args_.push_back (var->varId());
    ranges_.push_back (var->range());
  }
  distId_ = distId;
  params_ = params;
  assert (params_.size() == Util::sizeExpected (ranges_));
}
// Sums variable `vid` out of this factor and removes it from the argument
// list. A binary (range 2) variable in the first or last position goes
// through a specialised routine; every other case uses sumOutIndex().
// Precondition: vid is one of this factor's arguments.
void
Factor::sumOut (VarId vid)
{
  // Locate the variable once and validate the position before touching
  // args_/ranges_, so an argument-less factor is never dereferenced.
  const size_t idx = indexOf (vid);
  assert (idx < args_.size());
  if (idx >= args_.size()) {
    // Precondition violated with asserts disabled: there is nothing valid
    // to sum out, so leave the factor unchanged instead of indexing past
    // the end of args_/ranges_.
    return;
  }
  if (idx == 0 && ranges_[idx] == 2) {
    // optimization: binary variable in the first position
    sumOutFirstVariable();
  } else if (idx + 1 == args_.size() && ranges_[idx] == 2) {
    // optimization: binary variable in the last position
    sumOutLastVariable();
  } else {
    sumOutIndex (idx);
  }
}
// Keeps only variable `vid`: every other argument is summed out.
// Precondition: vid is one of this factor's arguments.
void
Factor::sumOutAllExcept (VarId vid)
{
  const size_t keep = indexOf (vid);
  assert (keep != args_.size());
  sumOutAllExceptIndex (keep);
}
void
Factor::sumOutAllExcept (const VarIds& vids)
{
2012-05-25 20:15:05 +01:00
vector<bool> mask (args_.size(), false);
for (unsigned i = 0; i < vids.size(); i++) {
assert (indexOf (vids[i]) != args_.size());
mask[indexOf (vids[i])] = true;
2012-05-23 14:56:01 +01:00
}
2012-05-25 20:15:05 +01:00
sumOutArgs (mask);
2012-05-23 14:56:01 +01:00
}
void
2012-05-24 22:55:20 +01:00
Factor::sumOutAllExceptIndex (size_t idx)
2012-05-23 14:56:01 +01:00
{
assert (idx < args_.size());
2012-05-25 20:15:05 +01:00
vector<bool> mask (args_.size(), false);
mask[idx] = true;
sumOutArgs (mask);
2012-05-23 14:56:01 +01:00
}
// Multiplies g into this factor. When this factor has no arguments it is
// simply replaced by a copy of g (its own parameters are discarded);
// otherwise the generic TFactor<VarId>::multiply does the work.
void
Factor::multiply (Factor& g)
{
  if (!args_.empty()) {
    TFactor<VarId>::multiply (g);
    return;
  }
  clone (g);
}
string
Factor::getLabel (void) const
{
stringstream ss;
ss << "f(" ;
2012-05-24 22:55:20 +01:00
for (size_t i = 0; i < args_.size(); i++) {
2012-05-23 14:56:01 +01:00
if (i != 0) ss << "," ;
ss << Var (args_[i], ranges_[i]).label();
}
ss << ")" ;
return ss.str();
}
void
Factor::print (void) const
{
Vars vars;
2012-05-24 22:55:20 +01:00
for (size_t i = 0; i < args_.size(); i++) {
2012-05-23 14:56:01 +01:00
vars.push_back (new Var (args_[i], ranges_[i]));
}
vector<string> jointStrings = Util::getStateLines (vars);
2012-05-24 22:55:20 +01:00
for (size_t i = 0; i < params_.size(); i++) {
2012-05-23 14:56:01 +01:00
// cout << "[" << distId_ << "] " ;
cout << "f(" << jointStrings[i] << ")" ;
cout << " = " << params_[i] << endl;
}
cout << endl;
2012-05-24 22:55:20 +01:00
for (size_t i = 0; i < vars.size(); i++) {
2012-05-23 14:56:01 +01:00
delete vars[i];
}
}
2012-05-25 20:15:05 +01:00
void
Factor::sumOutFirstVariable (void)
{
size_t sep = params_.size() / 2;
if (Globals::logDomain) {
std::transform (
params_.begin(), params_.begin() + sep,
params_.begin() + sep, params_.begin(),
Util::logSum);
} else {
std::transform (
params_.begin(), params_.begin() + sep,
params_.begin() + sep, params_.begin(),
std::plus<double>());
}
params_.resize (sep);
args_.erase (args_.begin());
ranges_.erase (ranges_.begin());
}
// Sums out the last argument. Only valid when that argument has range 2:
// its two states occupy adjacent table slots, so each pair
// (params_[2k], params_[2k+1]) is combined into params_[k] -- via
// Util::logSum in log domain, plain addition otherwise -- and the table is
// shrunk to half its size.
void
Factor::sumOutLastVariable (void)
{
  const size_t newSize = params_.size() / 2;
  // Each operand is read through its own index, so no expression modifies
  // an iterator more than once and evaluation order cannot matter.
  // Writing params_[k] in place is safe: later iterations only read slots
  // 2k' and 2k'+1 with k' > k, which have not been overwritten yet.
  if (Globals::logDomain) {
    for (size_t k = 0; k < newSize; k++) {
      const double a = params_[2 * k];
      const double b = params_[2 * k + 1];
      params_[k] = Util::logSum (a, b);
    }
  } else {
    for (size_t k = 0; k < newSize; k++) {
      params_[k] = params_[2 * k] + params_[2 * k + 1];
    }
  }
  params_.resize (newSize);
  args_.pop_back();
  ranges_.pop_back();
}
void
Factor::sumOutArgs (const vector<bool>& mask)
{
assert (mask.size() == args_.size());
size_t new_size = 1;
Ranges oldRanges = ranges_;
args_.clear();
ranges_.clear();
for (unsigned i = 0; i < mask.size(); i++) {
if (mask[i]) {
new_size *= ranges_[i];
args_.push_back (args_[i]);
ranges_.push_back (ranges_[i]);
}
}
Params newps (new_size, LogAware::addIdenty());
Params::const_iterator first = params_.begin();
Params::const_iterator last = params_.end();
2012-05-28 12:32:15 +01:00
MapIndexer indexer (oldRanges, mask);
2012-05-25 20:15:05 +01:00
if (Globals::logDomain) {
while (first != last) {
newps[indexer] = Util::logSum (newps[indexer], *first++);
++ indexer;
}
} else {
while (first != last) {
newps[indexer] += *first++;
++ indexer;
}
}
params_ = newps;
}
2012-05-23 14:56:01 +01:00
void
2012-05-25 21:22:48 +01:00
Factor::clone (const Factor& g)
2012-05-23 14:56:01 +01:00
{
args_ = g.arguments();
ranges_ = g.ranges();
params_ = g.params();
distId_ = g.distId();
}