/*
 * This repository has been archived on 2023-08-20. You can view files and
 * clone it, but cannot push or open issues or pull requests.
 *
 * yap-6.3/packages/CLPBN/clpbn/bp/Factor.cpp
 * 510 lines, 12 KiB, C++ (viewer options: Raw | Normal View | History)
 */

#include <cstdlib>
#include <cassert>
// 2011-12-12 15:29:51 +00:00
#include <algorithm>
#include <iostream>
#include <sstream>
#include "Factor.h"
// 2011-12-12 15:29:51 +00:00
#include "StatesIndexer.h"
// Copy constructor: takes g's variable ids and ranges, and gives this factor
// its own newly allocated Distribution holding a copy of g's parameters
// (see copyFromFactor).
Factor::Factor (const Factor& g)
{
  copyFromFactor (g);
}
// Builds a single-variable factor over `vid' with `nStates' states, every
// parameter initialised to 1.0.
Factor::Factor (VarId vid, unsigned nStates)
{
  varids_.push_back (vid);
  ranges_.push_back (nStates);
  const ParamSet ones (nStates, 1.0);
  dist_ = new Distribution (ones);
}
// Builds a factor over `vars' (kept in the given order) whose parameters form
// a uniform distribution: the table has one entry per joint configuration,
// i.e. the product of the variables' state counts, each equal to 1 / that
// product.
Factor::Factor (const VarNodes& vars)
{
  // Accumulate in unsigned, like the other state counts in this file: a
  // signed int accumulator mixes signedness with the state counts and
  // overflowing it is undefined behaviour.
  unsigned nParams = 1;
  for (unsigned i = 0; i < vars.size(); i++) {
    const unsigned nStates = vars[i]->nrStates();
    varids_.push_back (vars[i]->varId());
    ranges_.push_back (nStates);
    nParams *= nStates;
  }
  // create a uniform distribution
  const double val = 1.0 / nParams;
  dist_ = new Distribution (ParamSet (nParams, val));
}
// Builds a single-variable factor over `vid' (with `nStates' states) whose
// Distribution is a newly allocated one built from `params'.
Factor::Factor (VarId vid, unsigned nStates, const ParamSet& params)
{
  dist_ = new Distribution (params);
  varids_.push_back (vid);
  ranges_.push_back (nStates);
}
// Builds a factor over `vars' (kept in the given order) that uses `dist' as
// its distribution. The pointer is stored as-is -- no copy is made -- so the
// Distribution may be shared with whoever else holds it.
// NOTE(review): ownership of `dist' is not visible in this file; confirm who
// is responsible for freeing it.
Factor::Factor (VarNodes& vars, Distribution* dist)
{
  for (unsigned k = 0; k < vars.size(); k++) {
    varids_.push_back (vars[k]->varId());
    ranges_.push_back (vars[k]->nrStates());
  }
  dist_ = dist;
}
// Builds a factor over `vars' (kept in the given order) whose Distribution is
// a newly allocated one built from `params'.
Factor::Factor (const VarNodes& vars, const ParamSet& params)
{
  for (unsigned k = 0; k < vars.size(); k++) {
    varids_.push_back (vars[k]->varId());
    ranges_.push_back (vars[k]->nrStates());
  }
  dist_ = new Distribution (params);
}
// Builds a factor directly from a list of variable ids, the matching list of
// state counts, and the parameters (copied into a new Distribution).
Factor::Factor (const VarIdSet& vids,
                const Ranges& ranges,
                const ParamSet& params)
{
  dist_   = new Distribution (params);
  ranges_ = ranges;
  varids_ = vids;
}
// Replaces the factor's parameters with `params'. The new set must have as
// many entries as the current one; this is only checked by assert.
void
Factor::setParameters (const ParamSet& params)
{
  assert (params.size() == dist_->params.size());
  dist_->updateParameters (params);
}
void
2011-12-12 15:29:51 +00:00
Factor::copyFromFactor (const Factor& g)
{
2011-12-12 15:29:51 +00:00
varids_ = g.getVarIds();
ranges_ = g.getRanges();
dist_ = new Distribution (g.getDistribution()->params);
}
void
Factor::multiplyByFactor (const Factor& g, const vector<CptEntry>* entries)
{
2011-12-12 15:29:51 +00:00
if (varids_.size() == 0) {
copyFromFactor (g);
return;
}
2011-12-12 15:29:51 +00:00
const VarIdSet& gvarids = g.getVarIds();
const Ranges& granges = g.getRanges();
const ParamSet& gparams = g.getParameters();
2011-12-12 15:29:51 +00:00
if (varids_ == gvarids) {
// optimization: if the factors contain the same set of variables,
2011-12-12 15:29:51 +00:00
// we can do a 1 to 1 operation on the parameters
switch (NSPACE) {
case NumberSpace::NORMAL:
Util::multiply (dist_->params, gparams);
break;
case NumberSpace::LOGARITHM:
Util::add (dist_->params, gparams);
}
} else {
bool hasCommonVars = false;
2011-12-12 15:29:51 +00:00
vector<unsigned> gvarpos;
for (unsigned i = 0; i < gvarids.size(); i++) {
int pos = getPositionOf (gvarids[i]);
if (pos == -1) {
insertVariable (gvarids[i], granges[i]);
gvarpos.push_back (varids_.size() - 1);
} else {
hasCommonVars = true;
2011-12-12 15:29:51 +00:00
gvarpos.push_back (pos);
}
}
if (hasCommonVars) {
2011-12-12 15:29:51 +00:00
vector<unsigned> gvaroffsets (gvarids.size());
gvaroffsets[gvarids.size() - 1] = 1;
for (int i = gvarids.size() - 2; i >= 0; i--) {
gvaroffsets[i] = gvaroffsets[i + 1] * granges[i + 1];
}
if (entries == 0) {
entries = &getCptEntries();
}
for (unsigned i = 0; i < entries->size(); i++) {
unsigned idx = 0;
const DConf& conf = (*entries)[i].getDomainConfiguration();
2011-12-12 15:29:51 +00:00
for (unsigned j = 0; j < gvarpos.size(); j++) {
idx += gvaroffsets[j] * conf[ gvarpos[j] ];
}
switch (NSPACE) {
case NumberSpace::NORMAL:
dist_->params[i] *= gparams[idx];
break;
case NumberSpace::LOGARITHM:
dist_->params[i] += gparams[idx];
}
}
} else {
// optimization: if the original factors doesn't have common variables,
// we don't need to marry the states of the common variables
unsigned count = 0;
for (unsigned i = 0; i < dist_->params.size(); i++) {
2011-12-12 15:29:51 +00:00
switch (NSPACE) {
case NumberSpace::NORMAL:
dist_->params[i] *= gparams[count];
break;
case NumberSpace::LOGARITHM:
dist_->params[i] += gparams[count];
}
count ++;
2011-12-12 15:29:51 +00:00
if (count >= gparams.size()) {
count = 0;
}
}
}
}
2011-12-12 15:29:51 +00:00
dist_->entries.clear();
}
void
2011-12-12 15:29:51 +00:00
Factor::insertVariable (VarId vid, unsigned nStates)
{
2011-12-12 15:29:51 +00:00
assert (getPositionOf (vid) == -1);
ParamSet newPs;
2011-12-12 15:29:51 +00:00
newPs.reserve (dist_->params.size() * nStates);
for (unsigned i = 0; i < dist_->params.size(); i++) {
2011-12-12 15:29:51 +00:00
for (unsigned j = 0; j < nStates; j++) {
newPs.push_back (dist_->params[i]);
}
}
2011-12-12 15:29:51 +00:00
varids_.push_back (vid);
ranges_.push_back (nStates);
dist_->updateParameters (newPs);
2011-12-12 15:29:51 +00:00
dist_->entries.clear();
}
void
2011-12-12 15:29:51 +00:00
Factor::removeAllVariablesExcept (VarId vid)
{
2011-12-12 15:29:51 +00:00
assert (getPositionOf (vid) != -1);
while (varids_.back() != vid) {
removeLastVariable();
}
while (varids_.front() != vid) {
removeFirstVariable();
}
}
// Sums `vid' out of the factor (log-sum in LOGARITHM space) and removes it
// from the scope. If `vid' is the last or the first variable, the cheaper
// removeLastVariable / removeFirstVariable is used instead.
void
Factor::removeVariable (VarId vid)
{
  int pos = getPositionOf (vid);
  assert (pos != -1);
  if (vid == varids_.back()) {
    removeLastVariable(); // optimization
    return;
  }
  if (vid == varids_.front()) {
    removeFirstVariable(); // optimization
    return;
  }
  // Parameters are stored with the last variable varying fastest, so a
  // linear index decomposes as
  //     left * block + s * inner + right
  // where s is the state of `vid', inner is the product of the ranges to
  // the right of pos, block = ranges_[pos] * inner and right < inner.
  unsigned inner = 1;
  for (unsigned i = pos + 1; i < varids_.size(); i++) {
    inner *= ranges_[i];
  }
  const unsigned nStates = ranges_[pos];
  const unsigned block   = inner * nStates;
  const ParamSet& ps     = dist_->params;
  const unsigned newSize = ps.size() / nStates;

  ParamSet newPs;
  newPs.reserve (newSize);
  while (newPs.size() < newSize) {
    // output entry k corresponds to (left, right) = (k / inner, k % inner)
    const unsigned k = newPs.size();
    unsigned idx = (k / inner) * block + (k % inner);
    double sum = Util::addIdenty();
    for (unsigned s = 0; s < nStates; s++, idx += inner) {
      switch (NSPACE) {
        case NumberSpace::NORMAL:
          sum += ps[idx];
          break;
        case NumberSpace::LOGARITHM:
          Util::logSum (sum, ps[idx]);
      }
    }
    newPs.push_back (sum);
  }
  varids_.erase (varids_.begin() + pos);
  ranges_.erase (ranges_.begin() + pos);
  dist_->updateParameters (newPs);
  dist_->entries.clear();
}
void
Factor::removeFirstVariable (void)
{
2011-12-12 15:29:51 +00:00
ParamSet& params = dist_->params;
unsigned nStates = ranges_.front();
unsigned sep = params.size() / nStates;
switch (NSPACE) {
case NumberSpace::NORMAL:
for (unsigned i = sep; i < params.size(); i++) {
params[i % sep] += params[i];
}
break;
case NumberSpace::LOGARITHM:
for (unsigned i = sep; i < params.size(); i++) {
Util::logSum (params[i % sep], params[i]);
}
}
params.resize (sep);
varids_.erase (varids_.begin());
ranges_.erase (ranges_.begin());
dist_->entries.clear();
}
2011-12-12 15:29:51 +00:00
void
Factor::removeLastVariable (void)
{
ParamSet& params = dist_->params;
unsigned nStates = ranges_.back();
unsigned idx1 = 0;
unsigned idx2 = 0;
switch (NSPACE) {
case NumberSpace::NORMAL:
while (idx1 < params.size()) {
params[idx2] = params[idx1];
idx1 ++;
for (unsigned j = 1; j < nStates; j++) {
params[idx2] += params[idx1];
idx1 ++;
}
2011-12-12 15:29:51 +00:00
idx2 ++;
}
2011-12-12 15:29:51 +00:00
break;
case NumberSpace::LOGARITHM:
while (idx1 < params.size()) {
params[idx2] = params[idx1];
idx1 ++;
for (unsigned j = 1; j < nStates; j++) {
Util::logSum (params[idx2], params[idx1]);
idx1 ++;
}
idx2 ++;
}
}
params.resize (idx2);
varids_.pop_back();
ranges_.pop_back();
dist_->entries.clear();
}
void
Factor::orderVariables (void)
{
VarIdSet sortedVarIds = varids_;
sort (sortedVarIds.begin(), sortedVarIds.end());
orderVariables (sortedVarIds);
}
// Permutes the factor so its variables appear in the order `newVarIdOrder',
// which must be a permutation of the current variable ids. Parameters are
// moved so every joint configuration keeps its value. Clears the cached
// CptEntry table.
void
Factor::orderVariables (const VarIdSet& newVarIdOrder)
{
  assert (newVarIdOrder.size() == varids_.size());
  if (newVarIdOrder == varids_) {
    return;
  }
  const unsigned N = ranges_.size();

  // positions[k] = current position of the variable that ends up at k.
  // One lookup per variable (previously two), and a missing id is caught
  // instead of silently becoming a huge unsigned index.
  vector<unsigned> positions (N);
  Ranges newRangeOrder;
  for (unsigned k = 0; k < N; k++) {
    const int pos = getPositionOf (newVarIdOrder[k]);
    assert (pos != -1);
    positions[k] = pos;
    newRangeOrder.push_back (ranges_[pos]);
  }

  ParamSet newPs (dist_->params.size());
  // Scratch multi-index, hoisted out of the loop so it is not reallocated
  // for every parameter.
  vector<unsigned> vi (N);
  for (unsigned i = 0; i < dist_->params.size(); i++) {
    // linear index (old order) -> per-variable states, last varies fastest
    unsigned li = i;
    for (int k = N - 1; k >= 0; k--) {
      vi[k] = li % ranges_[k];
      li /= ranges_[k];
    }
    // per-variable states -> linear index in the new order
    unsigned prod = 1;
    unsigned newLi = 0;
    for (int k = N - 1; k >= 0; k--) {
      newLi += vi[positions[k]] * prod;
      prod *= newRangeOrder[k];
    }
    newPs[newLi] = dist_->params[i];
  }

  varids_ = newVarIdOrder;
  ranges_ = newRangeOrder;
  dist_->params = newPs;
  dist_->entries.clear();
}
// Applies evidence `vid = evidence'. A StatesIndexer over the current ranges
// is first advanced `evidence' times in position pos, then stepped with
// nextSameState(pos). The parameter at each visited linear index is kept --
// presumably exactly the configurations where `vid' is in state `evidence'
// (TODO confirm against StatesIndexer). `vid' is then dropped from the scope.
void
Factor::removeInconsistentEntries (VarId vid, unsigned evidence)
{
  int pos = getPositionOf (vid);
  assert (pos != -1);
  ParamSet kept;
  kept.reserve (dist_->params.size() / ranges_[pos]);
  StatesIndexer indexer (ranges_);
  for (unsigned s = 0; s < evidence; s++) {
    indexer.incrementState (pos);
  }
  for (; indexer.valid(); indexer.nextSameState (pos)) {
    kept.push_back (dist_->params[indexer.getLinearIndex()]);
  }
  varids_.erase (varids_.begin() + pos);
  ranges_.erase (ranges_.begin() + pos);
  dist_->updateParameters (kept);
  dist_->entries.clear();
}
// Returns a printable label "f(l1,l2,...)". Each li is the label() of a
// VarNode built from the factor's i-th variable id and range.
string
Factor::getLabel (void) const
{
  stringstream ss;
  ss << "f(" ;
  const char* sep = "";
  for (unsigned i = 0; i < varids_.size(); i++) {
    ss << sep << VarNode (varids_[i], ranges_[i]).label();
    sep = ",";
  }
  ss << ")" ;
  return ss.str();
}
void
2011-12-12 15:29:51 +00:00
Factor::printFactor (void) const
{
2011-12-12 15:29:51 +00:00
VarNodes vars;
for (unsigned i = 0; i < varids_.size(); i++) {
vars.push_back (new VarNode (varids_[i], ranges_[i]));
}
vector<string> jointStrings = Util::getJointStateStrings (vars);
for (unsigned i = 0; i < dist_->params.size(); i++) {
cout << "f(" << jointStrings[i] << ")" ;
cout << " = " << dist_->params[i] << endl;
}
2011-12-12 15:29:51 +00:00
for (unsigned i = 0; i < vars.size(); i++) {
delete vars[i];
}
}
int
2011-12-12 15:29:51 +00:00
Factor::getPositionOf (VarId vid) const
{
2011-12-12 15:29:51 +00:00
for (unsigned i = 0; i < varids_.size(); i++) {
if (varids_[i] == vid) {
return i;
}
}
return -1;
}
2011-12-12 15:29:51 +00:00
// Returns the factor's CptEntry table. Entry i pairs parameter index i with
// its joint configuration: one state per variable, with the last variable
// varying fastest. The table is cached in dist_->entries and rebuilt only
// when that cache is empty; the mutating members in this file clear it.
const vector<CptEntry>&
Factor::getCptEntries (void) const
{
  if (dist_->entries.size() != 0) {
    return dist_->entries;
  }
  const unsigned nParams = dist_->params.size();
  const int nVars = varids_.size();
  DConf conf;
  conf.resize (nVars);
  dist_->entries.clear();
  dist_->entries.reserve (nParams);
  for (unsigned i = 0; i < nParams; i++) {
    // decode i as a mixed-radix number whose least significant digit is the
    // state of the last variable
    unsigned rest = i;
    for (int v = nVars - 1; v >= 0; v--) {
      conf[v] = rest % ranges_[v];
      rest /= ranges_[v];
    }
    dist_->entries.push_back (CptEntry (i, conf));
  }
  return dist_->entries;
}