refactor indexer classes and receive te ranges as a constant reference

This commit is contained in:
Tiago Gomes 2012-05-25 21:16:08 +01:00
parent 5ff161b10f
commit 2efca0c85a
7 changed files with 71 additions and 80 deletions

View File

@ -399,18 +399,18 @@ BpSolver::getJointByConditioning (const VarIds& jointVarIds) const
observedVars.push_back (fg->getVarNode (observedVids[j])); observedVars.push_back (fg->getVarNode (observedVids[j]));
observedRanges.push_back (observedVars.back()->range()); observedRanges.push_back (observedVars.back()->range());
} }
StatesIndexer idx (observedRanges, false); Indexer indexer (observedRanges, false);
while (idx.valid()) { while (indexer.valid()) {
for (size_t j = 0; j < observedVars.size(); j++) { for (size_t j = 0; j < observedVars.size(); j++) {
observedVars[j]->setEvidence (idx[j]); observedVars[j]->setEvidence (indexer[j]);
} }
++ idx;
BpSolver solver (*fg); BpSolver solver (*fg);
solver.runSolver(); solver.runSolver();
Params beliefs = solver.getPosterioriOf (jointVarIds[i]); Params beliefs = solver.getPosterioriOf (jointVarIds[i]);
for (size_t k = 0; k < beliefs.size(); k++) { for (size_t k = 0; k < beliefs.size(); k++) {
newBeliefs.push_back (beliefs[k]); newBeliefs.push_back (beliefs[k]);
} }
++ indexer;
} }
int count = -1; int count = -1;

View File

@ -219,7 +219,7 @@ Factor::sumOutArgs (const vector<bool>& mask)
Params newps (new_size, LogAware::addIdenty()); Params newps (new_size, LogAware::addIdenty());
Params::const_iterator first = params_.begin(); Params::const_iterator first = params_.begin();
Params::const_iterator last = params_.end(); Params::const_iterator last = params_.end();
MappingIndexer indexer (oldRanges, mask); CutIndexer indexer (oldRanges, mask);
if (Globals::logDomain) { if (Globals::logDomain) {
while (first != last) { while (first != last) {
newps[indexer] = Util::logSum (newps[indexer], *first++); newps[indexer] = Util::logSum (newps[indexer], *first++);

View File

@ -109,7 +109,7 @@ class TFactor
} }
} }
} else { } else {
StatesIndexer indexer (ranges_, false); Indexer indexer (ranges_, false);
while (indexer.valid()) { while (indexer.valid()) {
size_t g_li = 0; size_t g_li = 0;
size_t prod = 1; size_t prod = 1;
@ -136,7 +136,7 @@ class TFactor
Params newps (new_size, LogAware::addIdenty()); Params newps (new_size, LogAware::addIdenty());
Params::const_iterator first = params_.begin(); Params::const_iterator first = params_.begin();
Params::const_iterator last = params_.end(); Params::const_iterator last = params_.end();
MappingIndexer indexer (ranges_, idx); CutIndexer indexer (ranges_, idx);
if (Globals::logDomain) { if (Globals::logDomain) {
while (first != last) { while (first != last) {
newps[indexer] = Util::logSum (newps[indexer], *first++); newps[indexer] = Util::logSum (newps[indexer], *first++);
@ -161,13 +161,13 @@ class TFactor
Params copy = params_; Params copy = params_;
params_.clear(); params_.clear();
params_.reserve (copy.size() / ranges_[idx]); params_.reserve (copy.size() / ranges_[idx]);
StatesIndexer indexer (ranges_); Indexer indexer (ranges_);
for (unsigned i = 0; i < evidence; i++) { for (unsigned i = 0; i < evidence; i++) {
indexer.increment (idx); indexer.incrementDimension (idx);
} }
while (indexer.valid()) { while (indexer.valid()) {
params_.push_back (copy[indexer]); params_.push_back (copy[indexer]);
indexer.incrementExcluding (idx); indexer.incrementExceptDimension (idx);
} }
args_.erase (args_.begin() + idx); args_.erase (args_.begin() + idx);
ranges_.erase (ranges_.begin() + idx); ranges_.erase (ranges_.begin() + idx);

View File

@ -1,29 +1,23 @@
#ifndef HORUS_STATESINDEXER_H #ifndef HORUS_INDEXER_H
#define HORUS_STATESINDEXER_H #define HORUS_INDEXER_H
#include <algorithm> #include <algorithm>
#include <numeric>
#include <functional> #include <functional>
#include <sstream> #include <sstream>
#include <iomanip> #include <iomanip>
#include "Var.h"
#include "Util.h" #include "Util.h"
class Indexer
class StatesIndexer
{ {
public: public:
Indexer (const Ranges& ranges, bool calcOffsets = true)
StatesIndexer (const Ranges& ranges, bool calcOffsets = true) : index_(0), indices_(ranges.size(), 0), ranges_(ranges)
{ {
li_ = 0;
size_ = std::accumulate (ranges.begin(), ranges.end(), 1, size_ = std::accumulate (ranges.begin(), ranges.end(), 1,
std::multiplies<unsigned>()); std::multiplies<unsigned>());
indices_.resize (ranges.size(), 0);
ranges_ = ranges;
if (calcOffsets) { if (calcOffsets) {
calculateOffsets(); calculateOffsets();
} }
@ -39,47 +33,37 @@ class StatesIndexer
indices_[i] = 0; indices_[i] = 0;
} }
} }
li_ ++; index_ ++;
} }
void increment (size_t dim) void incrementDimension (size_t dim)
{ {
assert (dim < ranges_.size()); assert (dim < ranges_.size());
assert (ranges_.size() == offsets_.size()); assert (ranges_.size() == offsets_.size());
assert (indices_[dim] < ranges_[dim]); assert (indices_[dim] < ranges_[dim]);
indices_[dim] ++; indices_[dim] ++;
li_ += offsets_[dim]; index_ += offsets_[dim];
} }
void incrementExcluding (size_t skipDim) void incrementExceptDimension (size_t dim)
{ {
assert (ranges_.size() == offsets_.size()); assert (ranges_.size() == offsets_.size());
for (size_t i = ranges_.size(); i-- > 0; ) { for (size_t i = ranges_.size(); i-- > 0; ) {
if (i != (int)skipDim) { if (i != dim) {
indices_[i] ++; indices_[i] ++;
li_ += offsets_[i]; index_ += offsets_[i];
if (indices_[i] != ranges_[i]) { if (indices_[i] != ranges_[i]) {
return; return;
} else { } else {
indices_[i] = 0; indices_[i] = 0;
li_ -= offsets_[i] * ranges_[i]; index_ -= offsets_[i] * ranges_[i];
} }
} }
} }
li_ = size_; index_ = size_;
} }
size_t linearIndex (void) const Indexer& operator++ (void)
{
return li_;
}
const vector<unsigned>& indices (void) const
{
return indices_;
}
StatesIndexer& operator ++ (void)
{ {
increment(); increment();
return *this; return *this;
@ -87,7 +71,7 @@ class StatesIndexer
operator size_t (void) const operator size_t (void) const
{ {
return li_; return index_;
} }
unsigned operator[] (size_t dim) const unsigned operator[] (size_t dim) const
@ -99,19 +83,19 @@ class StatesIndexer
bool valid (void) const bool valid (void) const
{ {
return li_ < size_; return index_ < size_;
} }
void reset (void) void reset (void)
{ {
std::fill (indices_.begin(), indices_.end(), 0); std::fill (indices_.begin(), indices_.end(), 0);
li_ = 0; index_ = 0;
} }
void reset (size_t dim) void resetDimension (size_t dim)
{ {
indices_[dim] = 0; indices_[dim] = 0;
li_ -= offsets_[dim] * ranges_[dim]; index_ -= offsets_[dim] * ranges_[dim];
} }
size_t size (void) const size_t size (void) const
@ -119,12 +103,7 @@ class StatesIndexer
return size_ ; return size_ ;
} }
friend ostream& operator<< (ostream &os, const StatesIndexer& idx) friend std::ostream& operator<< (std::ostream&, const Indexer&);
{
os << "(" << std::setw (2) << std::setfill('0') << idx.li_ << ") " ;
os << idx.indices_;
return os;
}
private: private:
void calculateOffsets (void) void calculateOffsets (void)
@ -137,19 +116,31 @@ class StatesIndexer
} }
} }
size_t li_; size_t index_;
size_t size_; Ranges indices_;
vector<unsigned> indices_; const Ranges& ranges_;
vector<unsigned> ranges_; size_t size_;
vector<size_t> offsets_; vector<size_t> offsets_;
}; };
class MappingIndexer inline std::ostream&
operator<< (std::ostream& os, const Indexer& indexer)
{
os << "(" ;
os << std::setw (2) << std::setfill('0') << indexer.index_;
os << ") " ;
os << indexer.indices_;
return os;
}
class CutIndexer
{ {
public: public:
MappingIndexer (const Ranges& ranges, const vector<bool>& mask) CutIndexer (const Ranges& ranges, const vector<bool>& mask)
: index_(0), indices_(ranges.size(), 0), ranges_(ranges), : index_(0), indices_(ranges.size(), 0), ranges_(ranges),
valid_(true) valid_(true)
{ {
@ -164,7 +155,7 @@ class MappingIndexer
assert (ranges.size() == mask.size()); assert (ranges.size() == mask.size());
} }
MappingIndexer (const Ranges& ranges, size_t dim) CutIndexer (const Ranges& ranges, size_t dim)
: index_(0), indices_(ranges.size(), 0), ranges_(ranges), : index_(0), indices_(ranges.size(), 0), ranges_(ranges),
valid_(true) valid_(true)
{ {
@ -178,7 +169,7 @@ class MappingIndexer
} }
} }
MappingIndexer& operator++ (void) CutIndexer& operator++ (void)
{ {
assert (valid_); assert (valid_);
for (size_t i = ranges_.size(); i-- > 0; ) { for (size_t i = ranges_.size(); i-- > 0; ) {
@ -218,7 +209,7 @@ class MappingIndexer
index_ = 0; index_ = 0;
} }
friend std::ostream& operator<< (std::ostream&, const MappingIndexer&); friend std::ostream& operator<< (std::ostream&, const CutIndexer&);
private: private:
size_t index_; size_t index_;
@ -230,15 +221,16 @@ class MappingIndexer
inline std::ostream& operator<< (ostream &os, const MappingIndexer& mi) inline std::ostream&
operator<< (std::ostream &os, const CutIndexer& indexer)
{ {
os << "(" ; os << "(" ;
os << std::setw (2) << std::setfill('0') << mi.index_; os << std::setw (2) << std::setfill('0') << indexer.index_;
os << ") " ; os << ") " ;
os << mi.indices_; os << indexer.indices_;
return os; return os;
} }
#endif // HORUS_STATESINDEXER_H #endif // HORUS_INDEXER_H

View File

@ -134,7 +134,7 @@ Parfactor::sumOutIndex (size_t fIdx)
args_[fIdx].countedLogVar()); args_[fIdx].countedLogVar());
unsigned R = args_[fIdx].range(); unsigned R = args_[fIdx].range();
vector<double> numAssigns = HistogramSet::getNumAssigns (N, R); vector<double> numAssigns = HistogramSet::getNumAssigns (N, R);
StatesIndexer indexer (ranges_, fIdx); Indexer indexer (ranges_, fIdx);
while (indexer.valid()) { while (indexer.valid()) {
if (Globals::logDomain) { if (Globals::logDomain) {
params_[indexer] += numAssigns[ indexer[fIdx] ]; params_[indexer] += numAssigns[ indexer[fIdx] ];
@ -210,29 +210,29 @@ Parfactor::countConvert (LogVar X)
unsigned H = HistogramSet::nrHistograms (N, R); unsigned H = HistogramSet::nrHistograms (N, R);
vector<Histogram> histograms = HistogramSet::getHistograms (N, R); vector<Histogram> histograms = HistogramSet::getHistograms (N, R);
StatesIndexer indexer (ranges_); Indexer indexer (ranges_);
vector<Params> sumout (params_.size() / R); vector<Params> sumout (params_.size() / R);
unsigned count = 0; unsigned count = 0;
while (indexer.valid()) { while (indexer.valid()) {
sumout[count].reserve (R); sumout[count].reserve (R);
for (unsigned r = 0; r < R; r++) { for (unsigned r = 0; r < R; r++) {
sumout[count].push_back (params_[indexer]); sumout[count].push_back (params_[indexer]);
indexer.increment (fIdx); indexer.incrementDimension (fIdx);
} }
count ++; count ++;
indexer.reset (fIdx); indexer.resetDimension (fIdx);
indexer.incrementExcluding (fIdx); indexer.incrementExceptDimension (fIdx);
} }
params_.clear(); params_.clear();
params_.reserve (sumout.size() * H); params_.reserve (sumout.size() * H);
ranges_[fIdx] = H; ranges_[fIdx] = H;
MappingIndexer mapIndexer (ranges_, fIdx); CutIndexer cutIndexer (ranges_, fIdx);
while (mapIndexer.valid()) { while (cutIndexer.valid()) {
double prod = LogAware::multIdenty(); double prod = LogAware::multIdenty();
size_t i = mapIndexer; size_t i = cutIndexer;
unsigned h = mapIndexer[fIdx]; unsigned h = cutIndexer[fIdx];
for (unsigned r = 0; r < R; r++) { for (unsigned r = 0; r < R; r++) {
if (Globals::logDomain) { if (Globals::logDomain) {
prod += LogAware::pow (sumout[i][r], histograms[h][r]); prod += LogAware::pow (sumout[i][r], histograms[h][r]);
@ -241,7 +241,7 @@ Parfactor::countConvert (LogVar X)
} }
} }
params_.push_back (prod); params_.push_back (prod);
++ mapIndexer; ++ cutIndexer;
} }
args_[fIdx].setCountedLogVar (X); args_[fIdx].setCountedLogVar (X);
simplifyCountingFormulas (fIdx); simplifyCountingFormulas (fIdx);
@ -310,7 +310,7 @@ Parfactor::fullExpand (LogVar X)
sumIndexes.reserve (N * R); sumIndexes.reserve (N * R);
Ranges expandRanges (N, R); Ranges expandRanges (N, R);
StatesIndexer indexer (expandRanges); Indexer indexer (expandRanges);
while (indexer.valid()) { while (indexer.valid()) {
vector<unsigned> hist (R, 0); vector<unsigned> hist (R, 0);
for (unsigned n = 0; n < N; n++) { for (unsigned n = 0; n < N; n++) {
@ -572,7 +572,7 @@ void
Parfactor::printParameters (void) const Parfactor::printParameters (void) const
{ {
vector<string> jointStrings; vector<string> jointStrings;
StatesIndexer indexer (ranges_); Indexer indexer (ranges_);
while (indexer.valid()) { while (indexer.valid()) {
stringstream ss; stringstream ss;
for (size_t i = 0; i < args_.size(); i++) { for (size_t i = 0; i < args_.size(); i++) {
@ -740,7 +740,7 @@ Parfactor::simplifyParfactor (size_t fIdx1, size_t fIdx2)
{ {
Params copy = params_; Params copy = params_;
params_.clear(); params_.clear();
StatesIndexer indexer (ranges_); Indexer indexer (ranges_);
while (indexer.valid()) { while (indexer.valid()) {
if (indexer[fIdx1] == indexer[fIdx2]) { if (indexer[fIdx1] == indexer[fIdx2]) {
params_.push_back (copy[indexer]); params_.push_back (copy[indexer]);

View File

@ -1,4 +1,3 @@
- Receive ranges as a constant reference in Indexer
- Check if evidence remains in the compressed factor graph - Check if evidence remains in the compressed factor graph
- Consider using hashs instead of vectors of colors to calculate the groups in - Consider using hashs instead of vectors of colors to calculate the groups in
counting bp counting bp

View File

@ -183,7 +183,7 @@ getStateLines (const Vars& vars)
for (size_t i = 0; i < vars.size(); i++) { for (size_t i = 0; i < vars.size(); i++) {
ranges.push_back (vars[i]->range()); ranges.push_back (vars[i]->range());
} }
StatesIndexer indexer (ranges); Indexer indexer (ranges);
vector<string> jointStrings; vector<string> jointStrings;
while (indexer.valid()) { while (indexer.valid()) {
stringstream ss; stringstream ss;