510 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			510 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
#include <cstdlib>
 | 
						|
#include <cassert>
 | 
						|
 | 
						|
#include <algorithm>
 | 
						|
 | 
						|
#include <iostream>
 | 
						|
#include <sstream>
 | 
						|
 | 
						|
#include "Factor.h"
 | 
						|
#include "StatesIndexer.h"
 | 
						|
 | 
						|
 | 
						|
// Copy constructor: takes copies of g's variable ids and ranges and builds a
// brand-new Distribution from g's parameter vector, so the two factors do
// not share parameter storage.
Factor::Factor (const Factor& g)
{
  varids_ = g.getVarIds();
  ranges_ = g.getRanges();
  dist_   = new Distribution (g.getDistribution()->params);
}
 | 
						|
 | 
						|
 | 
						|
 | 
						|
// Builds a factor over the single variable `vid', which has `nStates'
// states. Every parameter starts at 1.0.
// NOTE(review): the value 1.0 is used regardless of NSPACE and is not
// normalised; confirm this is what callers expect in LOGARITHM space.
Factor::Factor (VarId vid, unsigned nStates)
{
  ranges_.push_back (nStates);
  varids_.push_back (vid);
  const ParamSet ones (nStates, 1.0);
  dist_ = new Distribution (ones);
}
 | 
						|
 | 
						|
 | 
						|
 | 
						|
// Builds a factor over `vars' (kept in the given order) holding a uniform
// distribution. With P the product of the variables' state counts, each
// of the P parameters equals 1.0 / P.
// NOTE(review): the value is computed in linear space regardless of NSPACE.
Factor::Factor (const VarNodes& vars)
{
  int nParams = 1;
  for (VarNodes::const_iterator it = vars.begin(); it != vars.end(); ++it) {
    const unsigned nStates = (*it)->nrStates();
    varids_.push_back ((*it)->varId());
    ranges_.push_back (nStates);
    nParams *= nStates;
  }
  const double uniform = 1.0 / nParams;
  dist_ = new Distribution (ParamSet (nParams, uniform));
}
 | 
						|
 | 
						|
 | 
						|
 | 
						|
// Builds a factor over the single variable `vid' (`nStates' states) with
// the given parameter vector.
// NOTE(review): params.size() is presumably expected to equal nStates; this
// constructor does not check it.
Factor::Factor (VarId vid, unsigned nStates, const ParamSet& params)
{
  ranges_.push_back (nStates);
  varids_.push_back (vid);
  // allocate last, so nothing is left dangling if a push_back throws
  dist_ = new Distribution (params);
}
 | 
						|
 | 
						|
 | 
						|
 | 
						|
// Builds a factor over `vars' that uses the given Distribution object
// directly: the pointer is stored, not copied.
// NOTE(review): ownership of `dist' (whether this factor frees it, or
// whether it may be shared with other factors) is not visible here.
Factor::Factor (VarNodes& vars, Distribution* dist)
{
  dist_ = dist;
  for (VarNodes::iterator it = vars.begin(); it != vars.end(); ++it) {
    varids_.push_back ((*it)->varId());
    ranges_.push_back ((*it)->nrStates());
  }
}
 | 
						|
 | 
						|
 | 
						|
 | 
						|
// Builds a factor over `vars' (in the given order) with a new Distribution
// holding a copy of `params'.
Factor::Factor (const VarNodes& vars, const ParamSet& params)
{
  for (VarNodes::const_iterator it = vars.begin(); it != vars.end(); ++it) {
    const VarNode* var = *it;
    varids_.push_back (var->varId());
    ranges_.push_back (var->nrStates());
  }
  dist_ = new Distribution (params);
}
 | 
						|
 | 
						|
 | 
						|
 | 
						|
// Builds a factor directly from its parts:
//   vids   - variable ids, in factor order
//   ranges - number of states of each variable (parallel to vids)
//   params - parameter vector, copied into a new Distribution
Factor::Factor (const VarIdSet& vids,
                const Ranges& ranges,
                const ParamSet& params)
{
  ranges_ = ranges;
  varids_ = vids;
  dist_   = new Distribution (params);
}
 | 
						|
 | 
						|
 | 
						|
 | 
						|
void
 | 
						|
Factor::setParameters (const ParamSet& params)
 | 
						|
{
 | 
						|
  assert (dist_->params.size() == params.size());
 | 
						|
  dist_->updateParameters (params);
 | 
						|
}
 | 
						|
 | 
						|
 | 
						|
 | 
						|
void
 | 
						|
Factor::copyFromFactor (const Factor& g)
 | 
						|
{
 | 
						|
  varids_ = g.getVarIds();
 | 
						|
  ranges_ = g.getRanges();
 | 
						|
  dist_ = new Distribution (g.getDistribution()->params);
 | 
						|
}
 | 
						|
 | 
						|
 | 
						|
 | 
						|
void
 | 
						|
Factor::multiplyByFactor (const Factor& g, const vector<CptEntry>* entries)
 | 
						|
{
 | 
						|
  if (varids_.size() == 0) {
 | 
						|
    copyFromFactor (g);
 | 
						|
    return;
 | 
						|
  }
 | 
						|
 | 
						|
  const VarIdSet&  gvarids = g.getVarIds();
 | 
						|
  const Ranges&    granges = g.getRanges();
 | 
						|
  const ParamSet&  gparams = g.getParameters();
 | 
						|
 | 
						|
  if (varids_ == gvarids) {
 | 
						|
    // optimization: if the factors contain the same set of variables,
 | 
						|
    // we can do a 1 to 1 operation on the parameters
 | 
						|
    switch (NSPACE) {
 | 
						|
      case NumberSpace::NORMAL:
 | 
						|
        Util::multiply (dist_->params, gparams);
 | 
						|
        break;
 | 
						|
      case NumberSpace::LOGARITHM:
 | 
						|
        Util::add (dist_->params, gparams);
 | 
						|
    }
 | 
						|
  } else {
 | 
						|
    bool hasCommonVars = false;
 | 
						|
    vector<unsigned> gvarpos;
 | 
						|
    for (unsigned i = 0; i < gvarids.size(); i++) {
 | 
						|
      int pos = getPositionOf (gvarids[i]);
 | 
						|
      if (pos == -1) {
 | 
						|
        insertVariable (gvarids[i], granges[i]);
 | 
						|
        gvarpos.push_back (varids_.size() - 1);
 | 
						|
      } else {
 | 
						|
        hasCommonVars = true;
 | 
						|
        gvarpos.push_back (pos);
 | 
						|
      }
 | 
						|
    }
 | 
						|
    if (hasCommonVars) {
 | 
						|
      vector<unsigned> gvaroffsets (gvarids.size());
 | 
						|
      gvaroffsets[gvarids.size() - 1] = 1;
 | 
						|
      for (int i = gvarids.size() - 2; i >= 0; i--) {
 | 
						|
        gvaroffsets[i] = gvaroffsets[i + 1] * granges[i + 1];
 | 
						|
      }
 | 
						|
 | 
						|
      if (entries == 0) {
 | 
						|
        entries = &getCptEntries();
 | 
						|
      }
 | 
						|
 | 
						|
      for (unsigned i = 0; i < entries->size(); i++) {
 | 
						|
        unsigned idx = 0;
 | 
						|
        const DConf& conf = (*entries)[i].getDomainConfiguration();
 | 
						|
        for (unsigned j = 0; j < gvarpos.size(); j++) {
 | 
						|
          idx += gvaroffsets[j] * conf[ gvarpos[j] ];
 | 
						|
        }
 | 
						|
        switch (NSPACE) {
 | 
						|
          case NumberSpace::NORMAL:
 | 
						|
            dist_->params[i] *= gparams[idx];
 | 
						|
            break;
 | 
						|
          case NumberSpace::LOGARITHM:
 | 
						|
            dist_->params[i] += gparams[idx];
 | 
						|
        }
 | 
						|
      }
 | 
						|
    } else {
 | 
						|
      // optimization: if the original factors doesn't have common variables,
 | 
						|
      // we don't need to marry the states of the common variables
 | 
						|
      unsigned count = 0;
 | 
						|
      for (unsigned i = 0; i < dist_->params.size(); i++) {
 | 
						|
        switch (NSPACE) {
 | 
						|
          case NumberSpace::NORMAL:
 | 
						|
            dist_->params[i] *= gparams[count];
 | 
						|
            break;
 | 
						|
          case NumberSpace::LOGARITHM:
 | 
						|
            dist_->params[i] += gparams[count];
 | 
						|
        }
 | 
						|
        count ++;
 | 
						|
        if (count >= gparams.size()) {
 | 
						|
          count = 0;
 | 
						|
        }
 | 
						|
      }
 | 
						|
    }
 | 
						|
  }
 | 
						|
  dist_->entries.clear();
 | 
						|
}
 | 
						|
 | 
						|
 | 
						|
 | 
						|
void
 | 
						|
Factor::insertVariable (VarId vid, unsigned nStates)
 | 
						|
{
 | 
						|
  assert (getPositionOf (vid) == -1);
 | 
						|
  ParamSet newPs;
 | 
						|
  newPs.reserve (dist_->params.size() * nStates);
 | 
						|
  for (unsigned i = 0; i < dist_->params.size(); i++) {
 | 
						|
    for (unsigned j = 0; j < nStates; j++) {
 | 
						|
      newPs.push_back (dist_->params[i]);
 | 
						|
    }
 | 
						|
  }
 | 
						|
  varids_.push_back (vid);
 | 
						|
  ranges_.push_back (nStates);
 | 
						|
  dist_->updateParameters (newPs);
 | 
						|
  dist_->entries.clear();
 | 
						|
}
 | 
						|
 | 
						|
 | 
						|
 | 
						|
void
 | 
						|
Factor::removeAllVariablesExcept (VarId vid)
 | 
						|
{
 | 
						|
  assert (getPositionOf (vid) != -1);
 | 
						|
  while (varids_.back() != vid) {
 | 
						|
    removeLastVariable();
 | 
						|
  }
 | 
						|
  while (varids_.front() != vid) {
 | 
						|
    removeFirstVariable();
 | 
						|
  } 
 | 
						|
}
 | 
						|
 | 
						|
 | 
						|
 | 
						|
void
 | 
						|
Factor::removeVariable (VarId vid)
 | 
						|
{
 | 
						|
  int pos = getPositionOf (vid);
 | 
						|
  assert (pos != -1);
 | 
						|
 | 
						|
  if (vid == varids_.back()) {
 | 
						|
    removeLastVariable();  // optimization
 | 
						|
    return;
 | 
						|
  } 
 | 
						|
  if (vid == varids_.front()) {
 | 
						|
    removeFirstVariable(); // optimization
 | 
						|
    return;
 | 
						|
  } 
 | 
						|
 | 
						|
  // number of parameters separating a different state of `var', 
 | 
						|
  // with the states of the remaining variables fixed
 | 
						|
  unsigned varOffset = 1;
 | 
						|
 | 
						|
  // number of parameters separating a different state of the variable
 | 
						|
  // on the left of `var', with the states of the remaining vars fixed
 | 
						|
  unsigned leftVarOffset = 1;
 | 
						|
 | 
						|
  for (int i = varids_.size() - 1; i > pos; i--) {
 | 
						|
    varOffset     *= ranges_[i];
 | 
						|
    leftVarOffset *= ranges_[i];
 | 
						|
  }
 | 
						|
  leftVarOffset *= ranges_[pos];
 | 
						|
 | 
						|
  unsigned offset    = 0;
 | 
						|
  unsigned count1    = 0;
 | 
						|
  unsigned count2    = 0;
 | 
						|
  unsigned newPsSize = dist_->params.size() / ranges_[pos];
 | 
						|
 | 
						|
  ParamSet newPs;
 | 
						|
  newPs.reserve (newPsSize);
 | 
						|
 | 
						|
  while (newPs.size() < newPsSize) {
 | 
						|
    double sum = Util::addIdenty();
 | 
						|
    for (unsigned i = 0; i < ranges_[pos]; i++) {
 | 
						|
      switch (NSPACE) {
 | 
						|
        case NumberSpace::NORMAL:
 | 
						|
          sum += dist_->params[offset];
 | 
						|
          break;
 | 
						|
        case NumberSpace::LOGARITHM:
 | 
						|
          Util::logSum (sum, dist_->params[offset]);
 | 
						|
      }
 | 
						|
      offset += varOffset;
 | 
						|
    }
 | 
						|
    newPs.push_back (sum);
 | 
						|
    count1 ++;
 | 
						|
    if (pos == (int)varids_.size() - 1) {
 | 
						|
      offset = count1 * ranges_[pos];
 | 
						|
    } else {
 | 
						|
      if (((offset - varOffset + 1) % leftVarOffset) == 0) {
 | 
						|
        count1 = 0;
 | 
						|
        count2 ++;
 | 
						|
      }
 | 
						|
      offset = (leftVarOffset * count2) + count1;
 | 
						|
    }
 | 
						|
  }
 | 
						|
  varids_.erase (varids_.begin() + pos);
 | 
						|
  ranges_.erase (ranges_.begin() + pos);
 | 
						|
  dist_->updateParameters (newPs);
 | 
						|
  dist_->entries.clear();
 | 
						|
}
 | 
						|
 | 
						|
 | 
						|
 | 
						|
void
 | 
						|
Factor::removeFirstVariable (void)
 | 
						|
{
 | 
						|
  ParamSet& params = dist_->params;
 | 
						|
  unsigned nStates = ranges_.front();
 | 
						|
  unsigned sep = params.size() / nStates;
 | 
						|
  switch (NSPACE) {
 | 
						|
    case NumberSpace::NORMAL:
 | 
						|
      for (unsigned i = sep; i < params.size(); i++) {
 | 
						|
        params[i % sep] += params[i];
 | 
						|
      }
 | 
						|
      break;
 | 
						|
    case NumberSpace::LOGARITHM:
 | 
						|
      for (unsigned i = sep; i < params.size(); i++) {
 | 
						|
        Util::logSum (params[i % sep], params[i]);
 | 
						|
      }
 | 
						|
  }
 | 
						|
  params.resize (sep);
 | 
						|
  varids_.erase (varids_.begin());
 | 
						|
  ranges_.erase (ranges_.begin());
 | 
						|
  dist_->entries.clear();
 | 
						|
}
 | 
						|
 | 
						|
 | 
						|
 | 
						|
void
 | 
						|
Factor::removeLastVariable (void)
 | 
						|
{
 | 
						|
  ParamSet& params = dist_->params;
 | 
						|
  unsigned nStates = ranges_.back();
 | 
						|
  unsigned idx1 = 0;
 | 
						|
  unsigned idx2 = 0;
 | 
						|
  switch (NSPACE) {
 | 
						|
    case NumberSpace::NORMAL:
 | 
						|
      while (idx1 < params.size()) {
 | 
						|
        params[idx2] = params[idx1];
 | 
						|
        idx1 ++;
 | 
						|
        for (unsigned j = 1; j < nStates; j++) {
 | 
						|
          params[idx2] += params[idx1];
 | 
						|
          idx1 ++;
 | 
						|
        }
 | 
						|
        idx2 ++;
 | 
						|
      }
 | 
						|
      break;
 | 
						|
    case NumberSpace::LOGARITHM:
 | 
						|
      while (idx1 < params.size()) {
 | 
						|
        params[idx2] = params[idx1];
 | 
						|
        idx1 ++;
 | 
						|
        for (unsigned j = 1; j < nStates; j++) {
 | 
						|
          Util::logSum (params[idx2], params[idx1]);
 | 
						|
          idx1 ++;
 | 
						|
        }
 | 
						|
        idx2 ++;
 | 
						|
      }
 | 
						|
  }
 | 
						|
  params.resize (idx2);
 | 
						|
  varids_.pop_back();
 | 
						|
  ranges_.pop_back();
 | 
						|
  dist_->entries.clear();
 | 
						|
}
 | 
						|
 | 
						|
 | 
						|
 | 
						|
void
 | 
						|
Factor::orderVariables (void)
 | 
						|
{
 | 
						|
  VarIdSet sortedVarIds = varids_;
 | 
						|
  sort (sortedVarIds.begin(), sortedVarIds.end());
 | 
						|
  orderVariables (sortedVarIds);
 | 
						|
}
 | 
						|
 | 
						|
 | 
						|
 | 
						|
void
 | 
						|
Factor::orderVariables (const VarIdSet& newVarIdOrder)
 | 
						|
{
 | 
						|
  assert (newVarIdOrder.size() == varids_.size());
 | 
						|
  if (newVarIdOrder == varids_) {
 | 
						|
    return;
 | 
						|
  }
 | 
						|
 | 
						|
  Ranges newRangeOrder;
 | 
						|
  for (unsigned i = 0; i < newVarIdOrder.size(); i++) {
 | 
						|
    unsigned pos = getPositionOf (newVarIdOrder[i]);
 | 
						|
    newRangeOrder.push_back (ranges_[pos]);
 | 
						|
  }
 | 
						|
 | 
						|
  vector<unsigned> positions;
 | 
						|
  for (unsigned i = 0; i < newVarIdOrder.size(); i++) {
 | 
						|
    positions.push_back (getPositionOf (newVarIdOrder[i]));
 | 
						|
  }
 | 
						|
      
 | 
						|
  unsigned N = ranges_.size();
 | 
						|
  ParamSet newPs (dist_->params.size());
 | 
						|
  for (unsigned i = 0; i < dist_->params.size(); i++) {
 | 
						|
    unsigned li = i;
 | 
						|
    // calculate vector index corresponding to linear index
 | 
						|
    vector<unsigned> vi (N);
 | 
						|
    for (int k = N-1; k >= 0; k--) {
 | 
						|
      vi[k] = li % ranges_[k];
 | 
						|
      li /= ranges_[k];
 | 
						|
    }
 | 
						|
    // convert permuted vector index to corresponding linear index
 | 
						|
    unsigned prod = 1;
 | 
						|
    unsigned new_li = 0;
 | 
						|
    for (int k = N-1; k >= 0; k--) {
 | 
						|
      new_li += vi[positions[k]] * prod;
 | 
						|
      prod *= ranges_[positions[k]];
 | 
						|
    }
 | 
						|
    newPs[new_li] = dist_->params[i];
 | 
						|
  }
 | 
						|
  varids_ = newVarIdOrder;
 | 
						|
  ranges_ = newRangeOrder;
 | 
						|
  dist_->params = newPs;
 | 
						|
  dist_->entries.clear(); 
 | 
						|
}
 | 
						|
 | 
						|
 | 
						|
 | 
						|
void
 | 
						|
Factor::removeInconsistentEntries (VarId vid, unsigned evidence)
 | 
						|
{
 | 
						|
  int pos = getPositionOf (vid);
 | 
						|
  assert (pos != -1);
 | 
						|
  ParamSet newPs;
 | 
						|
  newPs.reserve (dist_->params.size() / ranges_[pos]);
 | 
						|
  StatesIndexer idx (ranges_);
 | 
						|
  for (unsigned i = 0; i < evidence; i++) {
 | 
						|
    idx.incrementState (pos);
 | 
						|
  }
 | 
						|
  while (idx.valid()) {
 | 
						|
    newPs.push_back (dist_->params[idx.getLinearIndex()]);
 | 
						|
    idx.nextSameState (pos);
 | 
						|
  }
 | 
						|
  varids_.erase (varids_.begin() + pos);
 | 
						|
  ranges_.erase (ranges_.begin() + pos);
 | 
						|
  dist_->updateParameters (newPs);
 | 
						|
  dist_->entries.clear();
 | 
						|
}
 | 
						|
 | 
						|
 | 
						|
 | 
						|
// Returns a label of the form "f(l1,l2,...)", where each li is the label of
// a temporary VarNode built from the i-th variable id and its range.
string
Factor::getLabel (void) const
{
  stringstream out;
  out << "f(" ;
  for (unsigned i = 0; i < varids_.size(); i++) {
    if (i > 0) {
      out << "," ;
    }
    VarNode var (varids_[i], ranges_[i]);
    out << var.label();
  }
  out << ")" ;
  return out.str();
}
 | 
						|
 | 
						|
 | 
						|
 | 
						|
void
 | 
						|
Factor::printFactor (void) const
 | 
						|
{
 | 
						|
  VarNodes vars;
 | 
						|
  for (unsigned i = 0; i < varids_.size(); i++) {
 | 
						|
    vars.push_back (new VarNode (varids_[i], ranges_[i]));
 | 
						|
  }
 | 
						|
  vector<string> jointStrings = Util::getJointStateStrings (vars);
 | 
						|
  for (unsigned i = 0; i < dist_->params.size(); i++) {
 | 
						|
    cout << "f(" << jointStrings[i] << ")" ;
 | 
						|
    cout << " = " << dist_->params[i] << endl;
 | 
						|
  }
 | 
						|
  for (unsigned i = 0; i < vars.size(); i++) {
 | 
						|
    delete vars[i];
 | 
						|
  }
 | 
						|
}
 | 
						|
 | 
						|
 | 
						|
 | 
						|
int
 | 
						|
Factor::getPositionOf (VarId vid) const
 | 
						|
{
 | 
						|
  for (unsigned i = 0; i < varids_.size(); i++) {
 | 
						|
    if (varids_[i] == vid) {
 | 
						|
      return i;
 | 
						|
    }
 | 
						|
  }
 | 
						|
  return -1;
 | 
						|
}
 | 
						|
 | 
						|
 | 
						|
 | 
						|
// Returns one CptEntry per parameter. Entry i carries the index i and the
// per-variable states (DConf) of that parameter, with the last variable
// varying fastest. The table is cached in dist_->entries and rebuilt
// only while that cache is empty; the mutating methods of this class clear
// it.
const vector<CptEntry>&
Factor::getCptEntries (void) const
{
  vector<CptEntry>& cache = dist_->entries;
  if (cache.empty()) {
    const unsigned nVars   = varids_.size();
    const unsigned nParams = dist_->params.size();
    cache.reserve (nParams);
    for (unsigned li = 0; li < nParams; li++) {
      // decode linear index li as a mixed-radix number over ranges_
      DConf conf;
      conf.resize (nVars);
      unsigned rest = li;
      for (int k = nVars - 1; k >= 0; k--) {
        conf[k] = rest % ranges_[k];
        rest /= ranges_[k];
      }
      cache.push_back (CptEntry (li, conf));
    }
  }
  return cache;
}
 | 
						|
 |