refactor indexer classes and receive te ranges as a constant reference
This commit is contained in:
parent
5ff161b10f
commit
2efca0c85a
@ -399,18 +399,18 @@ BpSolver::getJointByConditioning (const VarIds& jointVarIds) const
|
|||||||
observedVars.push_back (fg->getVarNode (observedVids[j]));
|
observedVars.push_back (fg->getVarNode (observedVids[j]));
|
||||||
observedRanges.push_back (observedVars.back()->range());
|
observedRanges.push_back (observedVars.back()->range());
|
||||||
}
|
}
|
||||||
StatesIndexer idx (observedRanges, false);
|
Indexer indexer (observedRanges, false);
|
||||||
while (idx.valid()) {
|
while (indexer.valid()) {
|
||||||
for (size_t j = 0; j < observedVars.size(); j++) {
|
for (size_t j = 0; j < observedVars.size(); j++) {
|
||||||
observedVars[j]->setEvidence (idx[j]);
|
observedVars[j]->setEvidence (indexer[j]);
|
||||||
}
|
}
|
||||||
++ idx;
|
|
||||||
BpSolver solver (*fg);
|
BpSolver solver (*fg);
|
||||||
solver.runSolver();
|
solver.runSolver();
|
||||||
Params beliefs = solver.getPosterioriOf (jointVarIds[i]);
|
Params beliefs = solver.getPosterioriOf (jointVarIds[i]);
|
||||||
for (size_t k = 0; k < beliefs.size(); k++) {
|
for (size_t k = 0; k < beliefs.size(); k++) {
|
||||||
newBeliefs.push_back (beliefs[k]);
|
newBeliefs.push_back (beliefs[k]);
|
||||||
}
|
}
|
||||||
|
++ indexer;
|
||||||
}
|
}
|
||||||
|
|
||||||
int count = -1;
|
int count = -1;
|
||||||
|
@ -219,7 +219,7 @@ Factor::sumOutArgs (const vector<bool>& mask)
|
|||||||
Params newps (new_size, LogAware::addIdenty());
|
Params newps (new_size, LogAware::addIdenty());
|
||||||
Params::const_iterator first = params_.begin();
|
Params::const_iterator first = params_.begin();
|
||||||
Params::const_iterator last = params_.end();
|
Params::const_iterator last = params_.end();
|
||||||
MappingIndexer indexer (oldRanges, mask);
|
CutIndexer indexer (oldRanges, mask);
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
while (first != last) {
|
while (first != last) {
|
||||||
newps[indexer] = Util::logSum (newps[indexer], *first++);
|
newps[indexer] = Util::logSum (newps[indexer], *first++);
|
||||||
|
@ -109,7 +109,7 @@ class TFactor
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
StatesIndexer indexer (ranges_, false);
|
Indexer indexer (ranges_, false);
|
||||||
while (indexer.valid()) {
|
while (indexer.valid()) {
|
||||||
size_t g_li = 0;
|
size_t g_li = 0;
|
||||||
size_t prod = 1;
|
size_t prod = 1;
|
||||||
@ -136,7 +136,7 @@ class TFactor
|
|||||||
Params newps (new_size, LogAware::addIdenty());
|
Params newps (new_size, LogAware::addIdenty());
|
||||||
Params::const_iterator first = params_.begin();
|
Params::const_iterator first = params_.begin();
|
||||||
Params::const_iterator last = params_.end();
|
Params::const_iterator last = params_.end();
|
||||||
MappingIndexer indexer (ranges_, idx);
|
CutIndexer indexer (ranges_, idx);
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
while (first != last) {
|
while (first != last) {
|
||||||
newps[indexer] = Util::logSum (newps[indexer], *first++);
|
newps[indexer] = Util::logSum (newps[indexer], *first++);
|
||||||
@ -161,13 +161,13 @@ class TFactor
|
|||||||
Params copy = params_;
|
Params copy = params_;
|
||||||
params_.clear();
|
params_.clear();
|
||||||
params_.reserve (copy.size() / ranges_[idx]);
|
params_.reserve (copy.size() / ranges_[idx]);
|
||||||
StatesIndexer indexer (ranges_);
|
Indexer indexer (ranges_);
|
||||||
for (unsigned i = 0; i < evidence; i++) {
|
for (unsigned i = 0; i < evidence; i++) {
|
||||||
indexer.increment (idx);
|
indexer.incrementDimension (idx);
|
||||||
}
|
}
|
||||||
while (indexer.valid()) {
|
while (indexer.valid()) {
|
||||||
params_.push_back (copy[indexer]);
|
params_.push_back (copy[indexer]);
|
||||||
indexer.incrementExcluding (idx);
|
indexer.incrementExceptDimension (idx);
|
||||||
}
|
}
|
||||||
args_.erase (args_.begin() + idx);
|
args_.erase (args_.begin() + idx);
|
||||||
ranges_.erase (ranges_.begin() + idx);
|
ranges_.erase (ranges_.begin() + idx);
|
||||||
|
@ -1,29 +1,23 @@
|
|||||||
#ifndef HORUS_STATESINDEXER_H
|
#ifndef HORUS_INDEXER_H
|
||||||
#define HORUS_STATESINDEXER_H
|
#define HORUS_INDEXER_H
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <numeric>
|
|
||||||
#include <functional>
|
#include <functional>
|
||||||
|
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <iomanip>
|
#include <iomanip>
|
||||||
|
|
||||||
#include "Var.h"
|
|
||||||
#include "Util.h"
|
#include "Util.h"
|
||||||
|
|
||||||
|
|
||||||
|
class Indexer
|
||||||
class StatesIndexer
|
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
|
Indexer (const Ranges& ranges, bool calcOffsets = true)
|
||||||
StatesIndexer (const Ranges& ranges, bool calcOffsets = true)
|
: index_(0), indices_(ranges.size(), 0), ranges_(ranges)
|
||||||
{
|
{
|
||||||
li_ = 0;
|
|
||||||
size_ = std::accumulate (ranges.begin(), ranges.end(), 1,
|
size_ = std::accumulate (ranges.begin(), ranges.end(), 1,
|
||||||
std::multiplies<unsigned>());
|
std::multiplies<unsigned>());
|
||||||
indices_.resize (ranges.size(), 0);
|
|
||||||
ranges_ = ranges;
|
|
||||||
if (calcOffsets) {
|
if (calcOffsets) {
|
||||||
calculateOffsets();
|
calculateOffsets();
|
||||||
}
|
}
|
||||||
@ -39,47 +33,37 @@ class StatesIndexer
|
|||||||
indices_[i] = 0;
|
indices_[i] = 0;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
li_ ++;
|
index_ ++;
|
||||||
}
|
}
|
||||||
|
|
||||||
void increment (size_t dim)
|
void incrementDimension (size_t dim)
|
||||||
{
|
{
|
||||||
assert (dim < ranges_.size());
|
assert (dim < ranges_.size());
|
||||||
assert (ranges_.size() == offsets_.size());
|
assert (ranges_.size() == offsets_.size());
|
||||||
assert (indices_[dim] < ranges_[dim]);
|
assert (indices_[dim] < ranges_[dim]);
|
||||||
indices_[dim] ++;
|
indices_[dim] ++;
|
||||||
li_ += offsets_[dim];
|
index_ += offsets_[dim];
|
||||||
}
|
}
|
||||||
|
|
||||||
void incrementExcluding (size_t skipDim)
|
void incrementExceptDimension (size_t dim)
|
||||||
{
|
{
|
||||||
assert (ranges_.size() == offsets_.size());
|
assert (ranges_.size() == offsets_.size());
|
||||||
for (size_t i = ranges_.size(); i-- > 0; ) {
|
for (size_t i = ranges_.size(); i-- > 0; ) {
|
||||||
if (i != (int)skipDim) {
|
if (i != dim) {
|
||||||
indices_[i] ++;
|
indices_[i] ++;
|
||||||
li_ += offsets_[i];
|
index_ += offsets_[i];
|
||||||
if (indices_[i] != ranges_[i]) {
|
if (indices_[i] != ranges_[i]) {
|
||||||
return;
|
return;
|
||||||
} else {
|
} else {
|
||||||
indices_[i] = 0;
|
indices_[i] = 0;
|
||||||
li_ -= offsets_[i] * ranges_[i];
|
index_ -= offsets_[i] * ranges_[i];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
li_ = size_;
|
index_ = size_;
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t linearIndex (void) const
|
Indexer& operator++ (void)
|
||||||
{
|
|
||||||
return li_;
|
|
||||||
}
|
|
||||||
|
|
||||||
const vector<unsigned>& indices (void) const
|
|
||||||
{
|
|
||||||
return indices_;
|
|
||||||
}
|
|
||||||
|
|
||||||
StatesIndexer& operator ++ (void)
|
|
||||||
{
|
{
|
||||||
increment();
|
increment();
|
||||||
return *this;
|
return *this;
|
||||||
@ -87,7 +71,7 @@ class StatesIndexer
|
|||||||
|
|
||||||
operator size_t (void) const
|
operator size_t (void) const
|
||||||
{
|
{
|
||||||
return li_;
|
return index_;
|
||||||
}
|
}
|
||||||
|
|
||||||
unsigned operator[] (size_t dim) const
|
unsigned operator[] (size_t dim) const
|
||||||
@ -99,19 +83,19 @@ class StatesIndexer
|
|||||||
|
|
||||||
bool valid (void) const
|
bool valid (void) const
|
||||||
{
|
{
|
||||||
return li_ < size_;
|
return index_ < size_;
|
||||||
}
|
}
|
||||||
|
|
||||||
void reset (void)
|
void reset (void)
|
||||||
{
|
{
|
||||||
std::fill (indices_.begin(), indices_.end(), 0);
|
std::fill (indices_.begin(), indices_.end(), 0);
|
||||||
li_ = 0;
|
index_ = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
void reset (size_t dim)
|
void resetDimension (size_t dim)
|
||||||
{
|
{
|
||||||
indices_[dim] = 0;
|
indices_[dim] = 0;
|
||||||
li_ -= offsets_[dim] * ranges_[dim];
|
index_ -= offsets_[dim] * ranges_[dim];
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t size (void) const
|
size_t size (void) const
|
||||||
@ -119,12 +103,7 @@ class StatesIndexer
|
|||||||
return size_ ;
|
return size_ ;
|
||||||
}
|
}
|
||||||
|
|
||||||
friend ostream& operator<< (ostream &os, const StatesIndexer& idx)
|
friend std::ostream& operator<< (std::ostream&, const Indexer&);
|
||||||
{
|
|
||||||
os << "(" << std::setw (2) << std::setfill('0') << idx.li_ << ") " ;
|
|
||||||
os << idx.indices_;
|
|
||||||
return os;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void calculateOffsets (void)
|
void calculateOffsets (void)
|
||||||
@ -137,19 +116,31 @@ class StatesIndexer
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t li_;
|
size_t index_;
|
||||||
|
Ranges indices_;
|
||||||
|
const Ranges& ranges_;
|
||||||
size_t size_;
|
size_t size_;
|
||||||
vector<unsigned> indices_;
|
|
||||||
vector<unsigned> ranges_;
|
|
||||||
vector<size_t> offsets_;
|
vector<size_t> offsets_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class MappingIndexer
|
inline std::ostream&
|
||||||
|
operator<< (std::ostream& os, const Indexer& indexer)
|
||||||
|
{
|
||||||
|
os << "(" ;
|
||||||
|
os << std::setw (2) << std::setfill('0') << indexer.index_;
|
||||||
|
os << ") " ;
|
||||||
|
os << indexer.indices_;
|
||||||
|
return os;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class CutIndexer
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
MappingIndexer (const Ranges& ranges, const vector<bool>& mask)
|
CutIndexer (const Ranges& ranges, const vector<bool>& mask)
|
||||||
: index_(0), indices_(ranges.size(), 0), ranges_(ranges),
|
: index_(0), indices_(ranges.size(), 0), ranges_(ranges),
|
||||||
valid_(true)
|
valid_(true)
|
||||||
{
|
{
|
||||||
@ -164,7 +155,7 @@ class MappingIndexer
|
|||||||
assert (ranges.size() == mask.size());
|
assert (ranges.size() == mask.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
MappingIndexer (const Ranges& ranges, size_t dim)
|
CutIndexer (const Ranges& ranges, size_t dim)
|
||||||
: index_(0), indices_(ranges.size(), 0), ranges_(ranges),
|
: index_(0), indices_(ranges.size(), 0), ranges_(ranges),
|
||||||
valid_(true)
|
valid_(true)
|
||||||
{
|
{
|
||||||
@ -178,7 +169,7 @@ class MappingIndexer
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
MappingIndexer& operator++ (void)
|
CutIndexer& operator++ (void)
|
||||||
{
|
{
|
||||||
assert (valid_);
|
assert (valid_);
|
||||||
for (size_t i = ranges_.size(); i-- > 0; ) {
|
for (size_t i = ranges_.size(); i-- > 0; ) {
|
||||||
@ -218,7 +209,7 @@ class MappingIndexer
|
|||||||
index_ = 0;
|
index_ = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
friend std::ostream& operator<< (std::ostream&, const MappingIndexer&);
|
friend std::ostream& operator<< (std::ostream&, const CutIndexer&);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
size_t index_;
|
size_t index_;
|
||||||
@ -230,15 +221,16 @@ class MappingIndexer
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
inline std::ostream& operator<< (ostream &os, const MappingIndexer& mi)
|
inline std::ostream&
|
||||||
|
operator<< (std::ostream &os, const CutIndexer& indexer)
|
||||||
{
|
{
|
||||||
os << "(" ;
|
os << "(" ;
|
||||||
os << std::setw (2) << std::setfill('0') << mi.index_;
|
os << std::setw (2) << std::setfill('0') << indexer.index_;
|
||||||
os << ") " ;
|
os << ") " ;
|
||||||
os << mi.indices_;
|
os << indexer.indices_;
|
||||||
return os;
|
return os;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
#endif // HORUS_STATESINDEXER_H
|
#endif // HORUS_INDEXER_H
|
||||||
|
|
||||||
|
@ -134,7 +134,7 @@ Parfactor::sumOutIndex (size_t fIdx)
|
|||||||
args_[fIdx].countedLogVar());
|
args_[fIdx].countedLogVar());
|
||||||
unsigned R = args_[fIdx].range();
|
unsigned R = args_[fIdx].range();
|
||||||
vector<double> numAssigns = HistogramSet::getNumAssigns (N, R);
|
vector<double> numAssigns = HistogramSet::getNumAssigns (N, R);
|
||||||
StatesIndexer indexer (ranges_, fIdx);
|
Indexer indexer (ranges_, fIdx);
|
||||||
while (indexer.valid()) {
|
while (indexer.valid()) {
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
params_[indexer] += numAssigns[ indexer[fIdx] ];
|
params_[indexer] += numAssigns[ indexer[fIdx] ];
|
||||||
@ -210,29 +210,29 @@ Parfactor::countConvert (LogVar X)
|
|||||||
unsigned H = HistogramSet::nrHistograms (N, R);
|
unsigned H = HistogramSet::nrHistograms (N, R);
|
||||||
vector<Histogram> histograms = HistogramSet::getHistograms (N, R);
|
vector<Histogram> histograms = HistogramSet::getHistograms (N, R);
|
||||||
|
|
||||||
StatesIndexer indexer (ranges_);
|
Indexer indexer (ranges_);
|
||||||
vector<Params> sumout (params_.size() / R);
|
vector<Params> sumout (params_.size() / R);
|
||||||
unsigned count = 0;
|
unsigned count = 0;
|
||||||
while (indexer.valid()) {
|
while (indexer.valid()) {
|
||||||
sumout[count].reserve (R);
|
sumout[count].reserve (R);
|
||||||
for (unsigned r = 0; r < R; r++) {
|
for (unsigned r = 0; r < R; r++) {
|
||||||
sumout[count].push_back (params_[indexer]);
|
sumout[count].push_back (params_[indexer]);
|
||||||
indexer.increment (fIdx);
|
indexer.incrementDimension (fIdx);
|
||||||
}
|
}
|
||||||
count ++;
|
count ++;
|
||||||
indexer.reset (fIdx);
|
indexer.resetDimension (fIdx);
|
||||||
indexer.incrementExcluding (fIdx);
|
indexer.incrementExceptDimension (fIdx);
|
||||||
}
|
}
|
||||||
|
|
||||||
params_.clear();
|
params_.clear();
|
||||||
params_.reserve (sumout.size() * H);
|
params_.reserve (sumout.size() * H);
|
||||||
|
|
||||||
ranges_[fIdx] = H;
|
ranges_[fIdx] = H;
|
||||||
MappingIndexer mapIndexer (ranges_, fIdx);
|
CutIndexer cutIndexer (ranges_, fIdx);
|
||||||
while (mapIndexer.valid()) {
|
while (cutIndexer.valid()) {
|
||||||
double prod = LogAware::multIdenty();
|
double prod = LogAware::multIdenty();
|
||||||
size_t i = mapIndexer;
|
size_t i = cutIndexer;
|
||||||
unsigned h = mapIndexer[fIdx];
|
unsigned h = cutIndexer[fIdx];
|
||||||
for (unsigned r = 0; r < R; r++) {
|
for (unsigned r = 0; r < R; r++) {
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
prod += LogAware::pow (sumout[i][r], histograms[h][r]);
|
prod += LogAware::pow (sumout[i][r], histograms[h][r]);
|
||||||
@ -241,7 +241,7 @@ Parfactor::countConvert (LogVar X)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
params_.push_back (prod);
|
params_.push_back (prod);
|
||||||
++ mapIndexer;
|
++ cutIndexer;
|
||||||
}
|
}
|
||||||
args_[fIdx].setCountedLogVar (X);
|
args_[fIdx].setCountedLogVar (X);
|
||||||
simplifyCountingFormulas (fIdx);
|
simplifyCountingFormulas (fIdx);
|
||||||
@ -310,7 +310,7 @@ Parfactor::fullExpand (LogVar X)
|
|||||||
sumIndexes.reserve (N * R);
|
sumIndexes.reserve (N * R);
|
||||||
|
|
||||||
Ranges expandRanges (N, R);
|
Ranges expandRanges (N, R);
|
||||||
StatesIndexer indexer (expandRanges);
|
Indexer indexer (expandRanges);
|
||||||
while (indexer.valid()) {
|
while (indexer.valid()) {
|
||||||
vector<unsigned> hist (R, 0);
|
vector<unsigned> hist (R, 0);
|
||||||
for (unsigned n = 0; n < N; n++) {
|
for (unsigned n = 0; n < N; n++) {
|
||||||
@ -572,7 +572,7 @@ void
|
|||||||
Parfactor::printParameters (void) const
|
Parfactor::printParameters (void) const
|
||||||
{
|
{
|
||||||
vector<string> jointStrings;
|
vector<string> jointStrings;
|
||||||
StatesIndexer indexer (ranges_);
|
Indexer indexer (ranges_);
|
||||||
while (indexer.valid()) {
|
while (indexer.valid()) {
|
||||||
stringstream ss;
|
stringstream ss;
|
||||||
for (size_t i = 0; i < args_.size(); i++) {
|
for (size_t i = 0; i < args_.size(); i++) {
|
||||||
@ -740,7 +740,7 @@ Parfactor::simplifyParfactor (size_t fIdx1, size_t fIdx2)
|
|||||||
{
|
{
|
||||||
Params copy = params_;
|
Params copy = params_;
|
||||||
params_.clear();
|
params_.clear();
|
||||||
StatesIndexer indexer (ranges_);
|
Indexer indexer (ranges_);
|
||||||
while (indexer.valid()) {
|
while (indexer.valid()) {
|
||||||
if (indexer[fIdx1] == indexer[fIdx2]) {
|
if (indexer[fIdx1] == indexer[fIdx2]) {
|
||||||
params_.push_back (copy[indexer]);
|
params_.push_back (copy[indexer]);
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
- Receive ranges as a constant reference in Indexer
|
|
||||||
- Check if evidence remains in the compressed factor graph
|
- Check if evidence remains in the compressed factor graph
|
||||||
- Consider using hashs instead of vectors of colors to calculate the groups in
|
- Consider using hashs instead of vectors of colors to calculate the groups in
|
||||||
counting bp
|
counting bp
|
||||||
|
@ -183,7 +183,7 @@ getStateLines (const Vars& vars)
|
|||||||
for (size_t i = 0; i < vars.size(); i++) {
|
for (size_t i = 0; i < vars.size(); i++) {
|
||||||
ranges.push_back (vars[i]->range());
|
ranges.push_back (vars[i]->range());
|
||||||
}
|
}
|
||||||
StatesIndexer indexer (ranges);
|
Indexer indexer (ranges);
|
||||||
vector<string> jointStrings;
|
vector<string> jointStrings;
|
||||||
while (indexer.valid()) {
|
while (indexer.valid()) {
|
||||||
stringstream ss;
|
stringstream ss;
|
||||||
|
Reference in New Issue
Block a user