2011-05-17 12:00:33 +01:00
|
|
|
#include <cstdlib>
|
|
|
|
#include <cassert>
|
|
|
|
|
2011-07-22 21:33:30 +01:00
|
|
|
#include <iostream>
|
|
|
|
#include <sstream>
|
|
|
|
|
2011-05-17 12:00:33 +01:00
|
|
|
#include "Factor.h"
|
|
|
|
#include "FgVarNode.h"
|
|
|
|
|
|
|
|
|
2011-07-22 21:33:30 +01:00
|
|
|
// Copy constructor: this factor receives g's variables and its own, newly
// allocated Distribution built from g's parameters (see copyFactor), so the
// two factors do not share a Distribution object.
Factor::Factor (const Factor& g)
{
  // all copying logic lives in copyFactor so that multiplyByFactor can
  // reuse it
  copyFactor (g);
}
|
|
|
|
|
|
|
|
|
|
|
|
|
2011-07-22 21:33:30 +01:00
|
|
|
// Builds a factor over the single variable `var' whose table holds one
// parameter per state of `var', all set to the same value (a uniform
// distribution), i.e. the same result as Factor(FgVarSet) with {var}.
//
// The members are initialised directly here. Writing `Factor (...)' inside
// a constructor body does not delegate: it only creates and destroys a
// temporary, which would leave this object's vars_ empty and dist_
// uninitialised.
Factor::Factor (FgVarNode* var)
{
  vars_.push_back (var);
  unsigned nParams = var->getDomainSize();
  dist_ = new Distribution (ParamSet (nParams, 1.0 / nParams));
}
|
|
|
|
|
|
|
|
|
|
|
|
|
2011-07-22 21:33:30 +01:00
|
|
|
// Builds a factor over `vars' whose table has one parameter per joint
// configuration of the variables (the product of their domain sizes),
// every parameter being 1 / (number of parameters), i.e. a uniform
// distribution. With no variables the table holds a single 1.0.
Factor::Factor (const FgVarSet& vars)
{
  vars_ = vars;
  int nParams = 1;
  for (FgVarSet::const_iterator it = vars_.begin(); it != vars_.end(); ++it) {
    nParams *= (*it)->getDomainSize();
  }
  const double uniform = 1.0 / nParams;
  dist_ = new Distribution (ParamSet (nParams, uniform));
}
|
|
|
|
|
|
|
|
|
|
|
|
|
2011-07-22 21:33:30 +01:00
|
|
|
// Builds a factor over the single variable `var' whose table is a new
// Distribution constructed from `params'.
// NOTE(review): params.size() is not checked against var's domain size
// here — callers presumably guarantee they match.
Factor::Factor (FgVarNode* var, const ParamSet& params)
{
  dist_ = new Distribution (params);
  vars_.push_back (var);
}
|
|
|
|
|
|
|
|
|
|
|
|
|
2011-07-22 21:33:30 +01:00
|
|
|
// Builds a factor over `vars' that uses the given Distribution object
// directly (the pointer is stored, not copied).
// NOTE(review): whether this factor then owns `dist' or shares it with
// other factors is not visible in this file — confirm before adding any
// cleanup.
Factor::Factor (FgVarSet& vars, Distribution* dist)
{
  dist_ = dist;
  vars_ = vars;
}
|
|
|
|
|
|
|
|
|
|
|
|
|
2011-07-22 21:33:30 +01:00
|
|
|
// Builds a factor over `vars' whose table is a new Distribution constructed
// from `params'.
// NOTE(review): params.size() is not checked against the product of the
// variables' domain sizes here — callers presumably guarantee they match.
Factor::Factor (const FgVarSet& vars, const ParamSet& params)
{
  dist_ = new Distribution (params);
  vars_ = vars;
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void
|
|
|
|
Factor::setParameters (const ParamSet& params)
|
|
|
|
{
|
2011-07-22 21:33:30 +01:00
|
|
|
assert (dist_->params.size() == params.size());
|
|
|
|
dist_->updateParameters (params);
|
2011-05-17 12:00:33 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
2011-07-22 21:33:30 +01:00
|
|
|
void
|
|
|
|
Factor::copyFactor (const Factor& g)
|
2011-05-17 12:00:33 +01:00
|
|
|
{
|
2011-07-22 21:33:30 +01:00
|
|
|
vars_ = g.getFgVarNodes();
|
|
|
|
dist_ = new Distribution (g.getDistribution()->params);
|
2011-05-17 12:00:33 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
2011-07-22 21:33:30 +01:00
|
|
|
void
|
|
|
|
Factor::multiplyByFactor (const Factor& g, const vector<CptEntry>* entries)
|
2011-05-17 12:00:33 +01:00
|
|
|
{
|
2011-07-22 21:33:30 +01:00
|
|
|
if (vars_.size() == 0) {
|
|
|
|
copyFactor (g);
|
|
|
|
return;
|
|
|
|
}
|
|
|
|
|
|
|
|
const FgVarSet& gVs = g.getFgVarNodes();
|
2011-05-17 12:00:33 +01:00
|
|
|
const ParamSet& gPs = g.getParameters();
|
|
|
|
|
2011-07-22 21:33:30 +01:00
|
|
|
bool factorsAreEqual = true;
|
|
|
|
if (gVs.size() == vars_.size()) {
|
|
|
|
for (unsigned i = 0; i < vars_.size(); i++) {
|
|
|
|
if (gVs[i] != vars_[i]) {
|
|
|
|
factorsAreEqual = false;
|
|
|
|
break;
|
|
|
|
}
|
2011-05-17 12:00:33 +01:00
|
|
|
}
|
2011-07-22 21:33:30 +01:00
|
|
|
} else {
|
|
|
|
factorsAreEqual = false;
|
2011-05-17 12:00:33 +01:00
|
|
|
}
|
2011-07-22 21:33:30 +01:00
|
|
|
|
|
|
|
if (factorsAreEqual) {
|
|
|
|
// optimization: if the factors contain the same set of variables,
|
|
|
|
// we can do 1 to 1 operations on the parameteres
|
|
|
|
for (unsigned i = 0; i < dist_->params.size(); i++) {
|
|
|
|
dist_->params[i] *= gPs[i];
|
2011-05-17 12:00:33 +01:00
|
|
|
}
|
2011-07-22 21:33:30 +01:00
|
|
|
} else {
|
|
|
|
bool hasCommonVars = false;
|
|
|
|
vector<unsigned> gVsIndexes;
|
|
|
|
for (unsigned i = 0; i < gVs.size(); i++) {
|
|
|
|
int idx = getIndexOf (gVs[i]);
|
|
|
|
if (idx == -1) {
|
|
|
|
insertVariable (gVs[i]);
|
|
|
|
gVsIndexes.push_back (vars_.size() - 1);
|
|
|
|
} else {
|
|
|
|
hasCommonVars = true;
|
|
|
|
gVsIndexes.push_back (idx);
|
2011-05-17 12:00:33 +01:00
|
|
|
}
|
|
|
|
}
|
2011-07-22 21:33:30 +01:00
|
|
|
if (hasCommonVars) {
|
|
|
|
vector<unsigned> gVsOffsets (gVs.size());
|
|
|
|
gVsOffsets[gVs.size() - 1] = 1;
|
|
|
|
for (int i = gVs.size() - 2; i >= 0; i--) {
|
|
|
|
gVsOffsets[i] = gVsOffsets[i + 1] * gVs[i + 1]->getDomainSize();
|
|
|
|
}
|
|
|
|
|
|
|
|
if (entries == 0) {
|
|
|
|
entries = &getCptEntries();
|
|
|
|
}
|
|
|
|
|
|
|
|
for (unsigned i = 0; i < entries->size(); i++) {
|
|
|
|
unsigned idx = 0;
|
|
|
|
const DConf& conf = (*entries)[i].getDomainConfiguration();
|
|
|
|
for (unsigned j = 0; j < gVsIndexes.size(); j++) {
|
|
|
|
idx += gVsOffsets[j] * conf[ gVsIndexes[j] ];
|
|
|
|
}
|
|
|
|
dist_->params[i] = dist_->params[i] * gPs[idx];
|
|
|
|
}
|
|
|
|
} else {
|
|
|
|
// optimization: if the original factors doesn't have common variables,
|
|
|
|
// we don't need to marry the states of the common variables
|
|
|
|
unsigned count = 0;
|
|
|
|
for (unsigned i = 0; i < dist_->params.size(); i++) {
|
|
|
|
dist_->params[i] *= gPs[count];
|
|
|
|
count ++;
|
|
|
|
if (count >= gPs.size()) {
|
|
|
|
count = 0;
|
|
|
|
}
|
2011-05-17 12:00:33 +01:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void
|
|
|
|
Factor::insertVariable (FgVarNode* var)
|
|
|
|
{
|
2011-07-22 21:33:30 +01:00
|
|
|
assert (getIndexOf (var) == -1);
|
|
|
|
ParamSet newPs;
|
|
|
|
newPs.reserve (dist_->params.size() * var->getDomainSize());
|
|
|
|
for (unsigned i = 0; i < dist_->params.size(); i++) {
|
|
|
|
for (unsigned j = 0; j < var->getDomainSize(); j++) {
|
|
|
|
newPs.push_back (dist_->params[i]);
|
2011-05-17 12:00:33 +01:00
|
|
|
}
|
|
|
|
}
|
2011-07-22 21:33:30 +01:00
|
|
|
vars_.push_back (var);
|
|
|
|
dist_->updateParameters (newPs);
|
2011-05-17 12:00:33 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void
|
2011-07-22 21:33:30 +01:00
|
|
|
Factor::removeVariable (const FgVarNode* var)
|
|
|
|
{
|
2011-05-17 12:00:33 +01:00
|
|
|
int varIndex = getIndexOf (var);
|
2011-07-22 21:33:30 +01:00
|
|
|
assert (varIndex >= 0 && varIndex < (int)vars_.size());
|
2011-05-17 12:00:33 +01:00
|
|
|
|
2011-07-22 21:33:30 +01:00
|
|
|
// number of parameters separating a different state of `var',
|
|
|
|
// with the states of the remaining variables fixed
|
|
|
|
unsigned varOffset = 1;
|
2011-05-17 12:00:33 +01:00
|
|
|
|
2011-07-22 21:33:30 +01:00
|
|
|
// number of parameters separating a different state of the variable
|
|
|
|
// on the left of `var', with the states of the remaining vars fixed
|
|
|
|
unsigned leftVarOffset = 1;
|
2011-05-17 12:00:33 +01:00
|
|
|
|
2011-07-22 21:33:30 +01:00
|
|
|
for (int i = vars_.size() - 1; i > varIndex; i--) {
|
|
|
|
varOffset *= vars_[i]->getDomainSize();
|
|
|
|
leftVarOffset *= vars_[i]->getDomainSize();
|
2011-05-17 12:00:33 +01:00
|
|
|
}
|
2011-07-22 21:33:30 +01:00
|
|
|
leftVarOffset *= vars_[varIndex]->getDomainSize();
|
|
|
|
|
|
|
|
unsigned offset = 0;
|
|
|
|
unsigned count1 = 0;
|
|
|
|
unsigned count2 = 0;
|
|
|
|
unsigned newPsSize = dist_->params.size() / vars_[varIndex]->getDomainSize();
|
|
|
|
|
2011-05-17 12:00:33 +01:00
|
|
|
ParamSet newPs;
|
|
|
|
newPs.reserve (newPsSize);
|
|
|
|
|
2011-07-22 21:33:30 +01:00
|
|
|
// stringstream ss;
|
|
|
|
// ss << "marginalizing " << vars_[varIndex]->getLabel();
|
|
|
|
// ss << " from factor " << getLabel() << endl;
|
2011-05-17 12:00:33 +01:00
|
|
|
while (newPs.size() < newPsSize) {
|
2011-07-22 21:33:30 +01:00
|
|
|
// ss << " sum = ";
|
2011-05-17 12:00:33 +01:00
|
|
|
double sum = 0.0;
|
2011-07-22 21:33:30 +01:00
|
|
|
for (unsigned i = 0; i < vars_[varIndex]->getDomainSize(); i++) {
|
|
|
|
// if (i != 0) ss << " + ";
|
|
|
|
// ss << dist_->params[offset];
|
|
|
|
sum += dist_->params[offset];
|
|
|
|
offset += varOffset;
|
2011-05-17 12:00:33 +01:00
|
|
|
}
|
|
|
|
newPs.push_back (sum);
|
2011-07-22 21:33:30 +01:00
|
|
|
count1 ++;
|
|
|
|
if (varIndex == (int)vars_.size() - 1) {
|
|
|
|
offset = count1 * vars_[varIndex]->getDomainSize();
|
2011-05-17 12:00:33 +01:00
|
|
|
} else {
|
2011-07-22 21:33:30 +01:00
|
|
|
if (((offset - varOffset + 1) % leftVarOffset) == 0) {
|
|
|
|
count1 = 0;
|
|
|
|
count2 ++;
|
2011-05-17 12:00:33 +01:00
|
|
|
}
|
2011-07-22 21:33:30 +01:00
|
|
|
offset = (leftVarOffset * count2) + count1;
|
2011-05-17 12:00:33 +01:00
|
|
|
}
|
2011-07-22 21:33:30 +01:00
|
|
|
// ss << " = " << sum << endl;
|
2011-05-17 12:00:33 +01:00
|
|
|
}
|
2011-07-22 21:33:30 +01:00
|
|
|
// cout << ss.str() << endl;
|
|
|
|
vars_.erase (vars_.begin() + varIndex);
|
|
|
|
dist_->updateParameters (newPs);
|
2011-05-17 12:00:33 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
2011-07-22 21:33:30 +01:00
|
|
|
// Returns one CptEntry per parameter, pairing parameter index p with the
// configuration (one state per variable, in scope order) it corresponds to.
// The last variable varies fastest. The list is built lazily the first
// time and cached in dist_->entries; later calls return the cached list
// unchanged while it is non-empty.
const vector<CptEntry>&
Factor::getCptEntries (void) const
{
  if (dist_->entries.empty()) {
    const unsigned nParams = dist_->params.size();
    const unsigned nVars = vars_.size();
    dist_->entries.reserve (nParams);
    for (unsigned p = 0; p < nParams; p++) {
      // decode p as a mixed-radix number, rightmost variable least
      // significant
      DConf conf (nVars);
      unsigned rest = p;
      for (int v = (int)nVars - 1; v >= 0; v--) {
        const unsigned nStates = vars_[v]->getDomainSize();
        conf[v] = rest % nStates;
        rest /= nStates;
      }
      dist_->entries.push_back (CptEntry (p, conf));
    }
  }
  return dist_->entries;
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Returns a printable name for the factor: "Φ(" followed by the labels of
// its variables in scope order, separated by commas, and then ")".
string
Factor::getLabel (void) const
{
  stringstream ss;
  ss << "Φ(" ;
  for (FgVarSet::const_iterator it = vars_.begin(); it != vars_.end(); ++it) {
    if (it != vars_.begin()) {
      ss << "," ;
    }
    ss << (*it)->getLabel();
  }
  ss << ")" ;
  return ss.str();
}
|
|
|
|
|
|
|
|
|
|
|
|
|
2011-07-22 21:33:30 +01:00
|
|
|
void
|
|
|
|
Factor::printFactor (void)
|
2011-05-17 12:00:33 +01:00
|
|
|
{
|
2011-07-22 21:33:30 +01:00
|
|
|
stringstream ss;
|
|
|
|
ss << getLabel() << endl;
|
|
|
|
ss << "--------------------" << endl;
|
|
|
|
VarSet vs;
|
|
|
|
for (unsigned i = 0; i < vars_.size(); i++) {
|
|
|
|
vs.push_back (vars_[i]);
|
2011-05-17 12:00:33 +01:00
|
|
|
}
|
2011-07-22 21:33:30 +01:00
|
|
|
vector<string> domainConfs = Util::getInstantiations (vs);
|
|
|
|
const vector<CptEntry>& entries = getCptEntries();
|
|
|
|
for (unsigned i = 0; i < entries.size(); i++) {
|
|
|
|
ss << "Φ(" << domainConfs[i] << ")" ;
|
|
|
|
unsigned idx = entries[i].getParameterIndex();
|
|
|
|
ss << " = " << dist_->params[idx] << endl;
|
2011-05-17 12:00:33 +01:00
|
|
|
}
|
2011-07-22 21:33:30 +01:00
|
|
|
cout << ss.str();
|
2011-05-17 12:00:33 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
int
|
|
|
|
Factor::getIndexOf (const FgVarNode* var) const
|
|
|
|
{
|
2011-07-22 21:33:30 +01:00
|
|
|
for (unsigned i = 0; i < vars_.size(); i++) {
|
|
|
|
if (vars_[i] == var) {
|
2011-05-17 12:00:33 +01:00
|
|
|
return i;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return -1;
|
|
|
|
}
|
|
|
|
|