2011-05-17 12:00:33 +01:00
|
|
|
#include <cstdlib>
|
|
|
|
#include <cassert>
|
|
|
|
|
2011-12-12 15:29:51 +00:00
|
|
|
#include <algorithm>
|
|
|
|
|
2011-07-22 21:33:30 +01:00
|
|
|
#include <iostream>
|
|
|
|
#include <sstream>
|
|
|
|
|
2011-05-17 12:00:33 +01:00
|
|
|
#include "Factor.h"
|
2012-03-22 11:33:24 +00:00
|
|
|
#include "Indexer.h"
|
2012-03-31 23:27:37 +01:00
|
|
|
|
2011-05-17 12:00:33 +01:00
|
|
|
|
|
|
|
|
2011-07-22 21:33:30 +01:00
|
|
|
/*
 * Copy constructor.
 * Replicates the argument ids, their ranges, the parameter table and the
 * distribution id of `g' into the newly constructed factor.
 */
Factor::Factor (const Factor& g)
{
  args_   = g.arguments();
  ranges_ = g.ranges();
  params_ = g.params();
  distId_ = g.distId();
}
|
|
|
|
|
|
|
|
|
|
|
|
|
2012-03-31 23:27:37 +01:00
|
|
|
/*
 * Builds a factor over the single variable `vid', which has `nrStates'
 * states. Every table entry is initialised to 1.0.
 * distId_ is set to the sentinel Util::maxUnsigned() (presumably meaning
 * "no shared distribution" -- confirm against Factor.h).
 */
Factor::Factor (VarId vid, unsigned nrStates)
{
  distId_ = Util::maxUnsigned();
  args_.push_back (vid);
  ranges_.push_back (nrStates);
  params_.resize (nrStates, 1.0);
  // The table must hold exactly one entry per state.
  assert (params_.size() == Util::expectedSize (ranges_));
}
|
|
|
|
|
|
|
|
|
|
|
|
|
2011-12-12 15:29:51 +00:00
|
|
|
/*
 * Builds a factor over the variables in `vars', keeping their order.
 * Each variable contributes its id to args_ and its state count to ranges_.
 * The parameter table gets one entry per joint state, every entry equal to
 * 1 / (number of joint states), so the entries sum to 1 in the linear domain.
 * NOTE(review): the value is not converted even when Globals::logDomain is
 * set -- confirm that is intended.
 * distId_ is set to the sentinel Util::maxUnsigned().
 */
Factor::Factor (const VarNodes& vars)
{
  // The number of joint states is a product of (unsigned) state counts.
  // Accumulate it unsigned: a signed product would be undefined behaviour
  // on overflow, and the value feeds an unsigned size anyway.
  unsigned nrParams = 1;
  for (unsigned i = 0; i < vars.size(); i++) {
    args_.push_back (vars[i]->varId());
    ranges_.push_back (vars[i]->nrStates());
    nrParams *= vars[i]->nrStates();
  }
  params_.resize (nrParams, 1.0 / nrParams);
  distId_ = Util::maxUnsigned();
  assert (params_.size() == Util::expectedSize (ranges_));
}
|
|
|
|
|
|
|
|
|
|
|
|
|
2012-03-31 23:27:37 +01:00
|
|
|
/*
 * Builds a factor over the single variable `vid' (with `nrStates' states)
 * using the caller-supplied parameter table `params'.
 * The table size must equal the number of states; this is asserted.
 * distId_ is set to the sentinel Util::maxUnsigned().
 */
Factor::Factor (
    VarId vid,
    unsigned nrStates,
    const Params& params)
{
  params_ = params;
  ranges_.push_back (nrStates);
  args_.push_back (vid);
  distId_ = Util::maxUnsigned();
  assert (params_.size() == Util::expectedSize (ranges_));
}
|
|
|
|
|
|
|
|
|
|
|
|
|
2012-03-31 23:27:37 +01:00
|
|
|
/*
 * Builds a factor over the variables in `vars' (kept in order) with the
 * given parameter table and distribution id.
 * The table size must equal the product of the variables' state counts;
 * this is asserted.
 */
Factor::Factor (
    const VarNodes& vars,
    const Params& params,
    unsigned distId)
{
  params_ = params;
  distId_ = distId;
  const unsigned nrVars = vars.size();
  for (unsigned k = 0; k != nrVars; ++k) {
    args_.push_back (vars[k]->varId());
    ranges_.push_back (vars[k]->nrStates());
  }
  assert (params_.size() == Util::expectedSize (ranges_));
}
|
|
|
|
|
|
|
|
|
|
|
|
|
2012-03-31 23:27:37 +01:00
|
|
|
/*
 * Builds a factor directly from its raw components: argument ids `vids',
 * their state counts `ranges', the parameter table `params' and the
 * distribution id `distId'.
 * The table size must match `ranges'; this is asserted.
 */
Factor::Factor (
    const VarIds& vids,
    const Ranges& ranges,
    const Params& params,
    unsigned distId)
{
  distId_ = distId;
  params_ = params;
  ranges_ = ranges;
  args_   = vids;
  assert (params_.size() == Util::expectedSize (ranges_));
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void
|
2012-03-22 11:33:24 +00:00
|
|
|
Factor::sumOutAllExcept (VarId vid)
|
2011-07-22 21:33:30 +01:00
|
|
|
{
|
2012-03-22 11:33:24 +00:00
|
|
|
assert (indexOf (vid) != -1);
|
2012-03-31 23:27:37 +01:00
|
|
|
while (args_.back() != vid) {
|
2012-03-22 11:33:24 +00:00
|
|
|
sumOutLastVariable();
|
2011-12-12 15:29:51 +00:00
|
|
|
}
|
2012-03-31 23:27:37 +01:00
|
|
|
while (args_.front() != vid) {
|
2012-03-22 11:33:24 +00:00
|
|
|
sumOutFirstVariable();
|
2011-12-12 15:29:51 +00:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void
|
2012-03-22 11:33:24 +00:00
|
|
|
Factor::sumOutAllExcept (const VarIds& vids)
|
|
|
|
{
|
2012-03-31 23:27:37 +01:00
|
|
|
for (int i = 0; i < (int)args_.size(); i++) {
|
|
|
|
if (Util::contains (vids, args_[i]) == false) {
|
|
|
|
sumOut (args_[i]);
|
|
|
|
i --;
|
2012-03-22 11:33:24 +00:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void
|
|
|
|
Factor::sumOut (VarId vid)
|
2011-12-12 15:29:51 +00:00
|
|
|
{
|
2012-03-22 11:33:24 +00:00
|
|
|
int idx = indexOf (vid);
|
|
|
|
assert (idx != -1);
|
2011-12-12 15:29:51 +00:00
|
|
|
|
2012-03-31 23:27:37 +01:00
|
|
|
if (vid == args_.back()) {
|
2012-03-22 11:33:24 +00:00
|
|
|
sumOutLastVariable(); // optimization
|
2011-12-12 15:29:51 +00:00
|
|
|
return;
|
|
|
|
}
|
2012-03-31 23:27:37 +01:00
|
|
|
if (vid == args_.front()) {
|
2012-03-22 11:33:24 +00:00
|
|
|
sumOutFirstVariable(); // optimization
|
2011-12-12 15:29:51 +00:00
|
|
|
return;
|
|
|
|
}
|
2011-05-17 12:00:33 +01:00
|
|
|
|
2011-07-22 21:33:30 +01:00
|
|
|
// number of parameters separating a different state of `var',
|
|
|
|
// with the states of the remaining variables fixed
|
|
|
|
unsigned varOffset = 1;
|
2011-05-17 12:00:33 +01:00
|
|
|
|
2011-07-22 21:33:30 +01:00
|
|
|
// number of parameters separating a different state of the variable
|
|
|
|
// on the left of `var', with the states of the remaining vars fixed
|
|
|
|
unsigned leftVarOffset = 1;
|
2011-05-17 12:00:33 +01:00
|
|
|
|
2012-03-31 23:27:37 +01:00
|
|
|
for (int i = args_.size() - 1; i > idx; i--) {
|
2011-12-12 15:29:51 +00:00
|
|
|
varOffset *= ranges_[i];
|
|
|
|
leftVarOffset *= ranges_[i];
|
2011-05-17 12:00:33 +01:00
|
|
|
}
|
2012-03-22 11:33:24 +00:00
|
|
|
leftVarOffset *= ranges_[idx];
|
2011-07-22 21:33:30 +01:00
|
|
|
|
|
|
|
unsigned offset = 0;
|
|
|
|
unsigned count1 = 0;
|
|
|
|
unsigned count2 = 0;
|
2012-03-31 23:27:37 +01:00
|
|
|
unsigned newpsSize = params_.size() / ranges_[idx];
|
2011-07-22 21:33:30 +01:00
|
|
|
|
2012-03-22 11:33:24 +00:00
|
|
|
Params newps;
|
|
|
|
newps.reserve (newpsSize);
|
2011-05-17 12:00:33 +01:00
|
|
|
|
2012-03-22 11:33:24 +00:00
|
|
|
while (newps.size() < newpsSize) {
|
2012-03-31 23:27:37 +01:00
|
|
|
double sum = LogAware::addIdenty();
|
2012-03-22 11:33:24 +00:00
|
|
|
for (unsigned i = 0; i < ranges_[idx]; i++) {
|
|
|
|
if (Globals::logDomain) {
|
2012-03-31 23:27:37 +01:00
|
|
|
sum = Util::logSum (sum, params_[offset]);
|
2012-03-22 11:33:24 +00:00
|
|
|
} else {
|
2012-03-31 23:27:37 +01:00
|
|
|
sum += params_[offset];
|
2011-12-12 15:29:51 +00:00
|
|
|
}
|
2011-07-22 21:33:30 +01:00
|
|
|
offset += varOffset;
|
2011-05-17 12:00:33 +01:00
|
|
|
}
|
2012-03-22 11:33:24 +00:00
|
|
|
newps.push_back (sum);
|
2011-07-22 21:33:30 +01:00
|
|
|
count1 ++;
|
2012-03-31 23:27:37 +01:00
|
|
|
if (idx == (int)args_.size() - 1) {
|
2012-03-22 11:33:24 +00:00
|
|
|
offset = count1 * ranges_[idx];
|
2011-05-17 12:00:33 +01:00
|
|
|
} else {
|
2011-07-22 21:33:30 +01:00
|
|
|
if (((offset - varOffset + 1) % leftVarOffset) == 0) {
|
|
|
|
count1 = 0;
|
|
|
|
count2 ++;
|
2011-05-17 12:00:33 +01:00
|
|
|
}
|
2011-07-22 21:33:30 +01:00
|
|
|
offset = (leftVarOffset * count2) + count1;
|
2011-05-17 12:00:33 +01:00
|
|
|
}
|
|
|
|
}
|
2012-03-31 23:27:37 +01:00
|
|
|
args_.erase (args_.begin() + idx);
|
2012-03-22 11:33:24 +00:00
|
|
|
ranges_.erase (ranges_.begin() + idx);
|
2012-03-31 23:27:37 +01:00
|
|
|
params_ = newps;
|
2011-05-17 12:00:33 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
2011-12-12 15:29:51 +00:00
|
|
|
void
|
2012-03-22 11:33:24 +00:00
|
|
|
Factor::sumOutFirstVariable (void)
|
2011-05-17 12:00:33 +01:00
|
|
|
{
|
2011-12-12 15:29:51 +00:00
|
|
|
unsigned nStates = ranges_.front();
|
2012-03-31 23:27:37 +01:00
|
|
|
unsigned sep = params_.size() / nStates;
|
2012-03-22 11:33:24 +00:00
|
|
|
if (Globals::logDomain) {
|
2012-03-31 23:27:37 +01:00
|
|
|
for (unsigned i = sep; i < params_.size(); i++) {
|
|
|
|
params_[i % sep] = Util::logSum (params_[i % sep], params_[i]);
|
2012-03-22 11:33:24 +00:00
|
|
|
}
|
|
|
|
} else {
|
2012-03-31 23:27:37 +01:00
|
|
|
for (unsigned i = sep; i < params_.size(); i++) {
|
|
|
|
params_[i % sep] += params_[i];
|
2012-03-22 11:33:24 +00:00
|
|
|
}
|
2011-12-12 15:29:51 +00:00
|
|
|
}
|
2012-03-31 23:27:37 +01:00
|
|
|
params_.resize (sep);
|
|
|
|
args_.erase (args_.begin());
|
2011-12-12 15:29:51 +00:00
|
|
|
ranges_.erase (ranges_.begin());
|
|
|
|
}
|
2011-07-22 21:33:30 +01:00
|
|
|
|
2011-12-12 15:29:51 +00:00
|
|
|
|
|
|
|
|
|
|
|
void
|
2012-03-22 11:33:24 +00:00
|
|
|
Factor::sumOutLastVariable (void)
|
2011-12-12 15:29:51 +00:00
|
|
|
{
|
|
|
|
unsigned nStates = ranges_.back();
|
|
|
|
unsigned idx1 = 0;
|
|
|
|
unsigned idx2 = 0;
|
2012-03-22 11:33:24 +00:00
|
|
|
if (Globals::logDomain) {
|
2012-03-31 23:27:37 +01:00
|
|
|
while (idx1 < params_.size()) {
|
|
|
|
params_[idx2] = params_[idx1];
|
2012-03-22 11:33:24 +00:00
|
|
|
idx1 ++;
|
|
|
|
for (unsigned j = 1; j < nStates; j++) {
|
2012-03-31 23:27:37 +01:00
|
|
|
params_[idx2] = Util::logSum (params_[idx2], params_[idx1]);
|
2011-12-12 15:29:51 +00:00
|
|
|
idx1 ++;
|
2011-07-22 21:33:30 +01:00
|
|
|
}
|
2012-03-22 11:33:24 +00:00
|
|
|
idx2 ++;
|
|
|
|
}
|
|
|
|
} else {
|
2012-03-31 23:27:37 +01:00
|
|
|
while (idx1 < params_.size()) {
|
|
|
|
params_[idx2] = params_[idx1];
|
2012-03-22 11:33:24 +00:00
|
|
|
idx1 ++;
|
|
|
|
for (unsigned j = 1; j < nStates; j++) {
|
2012-03-31 23:27:37 +01:00
|
|
|
params_[idx2] += params_[idx1];
|
2011-12-12 15:29:51 +00:00
|
|
|
idx1 ++;
|
|
|
|
}
|
2012-03-22 11:33:24 +00:00
|
|
|
idx2 ++;
|
|
|
|
}
|
2011-12-12 15:29:51 +00:00
|
|
|
}
|
2012-03-31 23:27:37 +01:00
|
|
|
params_.resize (idx2);
|
|
|
|
args_.pop_back();
|
2011-12-12 15:29:51 +00:00
|
|
|
ranges_.pop_back();
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void
|
2012-03-31 23:27:37 +01:00
|
|
|
Factor::multiply (Factor& g)
|
2011-12-12 15:29:51 +00:00
|
|
|
{
|
2012-03-31 23:27:37 +01:00
|
|
|
if (args_.size() == 0) {
|
|
|
|
copyFromFactor (g);
|
2011-12-12 15:29:51 +00:00
|
|
|
return;
|
|
|
|
}
|
2012-03-31 23:27:37 +01:00
|
|
|
TFactor<VarId>::multiply (g);
|
2011-12-12 15:29:51 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void
|
2012-03-31 23:27:37 +01:00
|
|
|
Factor::reorderAccordingVarIds (void)
|
2011-12-12 15:29:51 +00:00
|
|
|
{
|
2012-03-31 23:27:37 +01:00
|
|
|
VarIds sortedVarIds = args_;
|
|
|
|
sort (sortedVarIds.begin(), sortedVarIds.end());
|
|
|
|
reorderArguments (sortedVarIds);
|
2011-05-17 12:00:33 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/*
 * Returns a textual label "f(l1,l2,...)". Each li is the label of a
 * temporary VarNode built from the i-th argument id and its range.
 */
string
Factor::getLabel (void) const
{
  stringstream label;
  label << "f(" ;
  const char* separator = "";
  for (unsigned i = 0; i < args_.size(); i++) {
    label << separator << VarNode (args_[i], ranges_[i]).label();
    separator = ",";
  }
  label << ")" ;
  return label.str();
}
|
|
|
|
|
|
|
|
|
|
|
|
|
2011-07-22 21:33:30 +01:00
|
|
|
void
|
2012-03-22 11:33:24 +00:00
|
|
|
Factor::print (void) const
|
2011-05-17 12:00:33 +01:00
|
|
|
{
|
2011-12-12 15:29:51 +00:00
|
|
|
VarNodes vars;
|
2012-03-31 23:27:37 +01:00
|
|
|
for (unsigned i = 0; i < args_.size(); i++) {
|
|
|
|
vars.push_back (new VarNode (args_[i], ranges_[i]));
|
2011-12-12 15:29:51 +00:00
|
|
|
}
|
|
|
|
vector<string> jointStrings = Util::getJointStateStrings (vars);
|
2012-03-31 23:27:37 +01:00
|
|
|
for (unsigned i = 0; i < params_.size(); i++) {
|
2011-12-12 15:29:51 +00:00
|
|
|
cout << "f(" << jointStrings[i] << ")" ;
|
2012-03-31 23:27:37 +01:00
|
|
|
cout << " = " << params_[i] << endl;
|
2011-05-17 12:00:33 +01:00
|
|
|
}
|
2012-03-22 11:33:24 +00:00
|
|
|
cout << endl;
|
2011-12-12 15:29:51 +00:00
|
|
|
for (unsigned i = 0; i < vars.size(); i++) {
|
|
|
|
delete vars[i];
|
2011-05-17 12:00:33 +01:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
2012-03-31 23:27:37 +01:00
|
|
|
|
|
|
|
void
|
|
|
|
Factor::copyFromFactor (const Factor& g)
|
|
|
|
{
|
|
|
|
args_ = g.arguments();
|
|
|
|
ranges_ = g.ranges();
|
|
|
|
params_ = g.params();
|
|
|
|
distId_ = g.distId();
|
|
|
|
}
|
|
|
|
|