refactor functions for summing out

This commit is contained in:
Tiago Gomes
2012-05-25 20:15:05 +01:00
parent df8a3c5fdc
commit 5ff161b10f
7 changed files with 176 additions and 246 deletions

View File

@@ -128,6 +128,31 @@ class TFactor
}
}
void sumOutIndex (size_t idx)
{
assert (idx < args_.size());
assert (args_.size() > 1);
size_t new_size = params_.size() / ranges_[idx];
Params newps (new_size, LogAware::addIdenty());
Params::const_iterator first = params_.begin();
Params::const_iterator last = params_.end();
MappingIndexer indexer (ranges_, idx);
if (Globals::logDomain) {
while (first != last) {
newps[indexer] = Util::logSum (newps[indexer], *first++);
++ indexer;
}
} else {
while (first != last) {
newps[indexer] += *first++;
++ indexer;
}
}
params_ = newps;
args_.erase (args_.begin() + idx);
ranges_.erase (ranges_.begin() + idx);
}
void absorveEvidence (const T& arg, unsigned evidence)
{
size_t idx = indexOf (arg);
@@ -137,7 +162,7 @@ class TFactor
params_.clear();
params_.reserve (copy.size() / ranges_[idx]);
StatesIndexer indexer (ranges_);
for (size_t i = 0; i < evidence; i++) {
for (unsigned i = 0; i < evidence; i++) {
indexer.increment (idx);
}
while (indexer.valid()) {
@@ -246,7 +271,7 @@ class TFactor
params_.push_back (copy[i]);
}
}
}
}
};
@@ -270,14 +295,8 @@ class Factor : public TFactor<VarId>
void sumOutAllExcept (const VarIds&);
void sumOutIndex (size_t idx);
void sumOutAllExceptIndex (size_t idx);
void sumOutFirstVariable (void);
void sumOutLastVariable (void);
void multiply (Factor&);
void reorderAccordingVarIds (void);
@@ -287,6 +306,12 @@ class Factor : public TFactor<VarId>
void print (void) const;
private:
void sumOutFirstVariable (void);
void sumOutLastVariable (void);
void sumOutArgs (const vector<bool>& mask);
void copyFromFactor (const Factor& f);
};