refactor functions for summing out

This commit is contained in:
Tiago Gomes
2012-05-25 20:15:05 +01:00
parent df8a3c5fdc
commit 5ff161b10f
7 changed files with 176 additions and 246 deletions

View File

@@ -124,7 +124,7 @@ Parfactor::exclusiveLogVars (size_t fIdx) const
void
Parfactor::sumOut (size_t fIdx)
Parfactor::sumOutIndex (size_t fIdx)
{
assert (fIdx < args_.size());
assert (args_[fIdx].contains (elimLogVars()));
@@ -134,46 +134,29 @@ Parfactor::sumOut (size_t fIdx)
args_[fIdx].countedLogVar());
unsigned R = args_[fIdx].range();
vector<double> numAssigns = HistogramSet::getNumAssigns (N, R);
StatesIndexer sindexer (ranges_, fIdx);
while (sindexer.valid()) {
unsigned h = sindexer[fIdx];
StatesIndexer indexer (ranges_, fIdx);
while (indexer.valid()) {
if (Globals::logDomain) {
params_[sindexer] += numAssigns[h];
params_[indexer] += numAssigns[ indexer[fIdx] ];
} else {
params_[sindexer] *= numAssigns[h];
params_[indexer] *= numAssigns[ indexer[fIdx] ];
}
++ sindexer;
}
}
Params copy = params_;
params_.clear();
params_.resize (copy.size() / ranges_[fIdx], LogAware::addIdenty());
MapIndexer indexer (ranges_, fIdx);
if (Globals::logDomain) {
for (size_t i = 0; i < copy.size(); i++) {
params_[indexer] = Util::logSum (params_[indexer], copy[i]);
++ indexer;
}
} else {
for (size_t i = 0; i < copy.size(); i++) {
params_[indexer] += copy[i];
++ indexer;
}
}
LogVarSet excl = exclusiveLogVars (fIdx);
unsigned exp;
if (args_[fIdx].isCounting()) {
// counting log vars were already raised on counting conversion
LogAware::pow (params_, constr_->getConditionalCount (
excl - args_[fIdx].countedLogVar()));
exp = constr_->getConditionalCount (excl - args_[fIdx].countedLogVar());
} else {
LogAware::pow (params_, constr_->getConditionalCount (excl));
exp = constr_->getConditionalCount (excl);
}
constr_->remove (excl);
args_.erase (args_.begin() + fIdx);
ranges_.erase (ranges_.begin() + fIdx);
TFactor<ProbFormula>::sumOutIndex (fIdx);
LogAware::pow (params_, exp);
}
@@ -245,10 +228,10 @@ Parfactor::countConvert (LogVar X)
params_.reserve (sumout.size() * H);
ranges_[fIdx] = H;
MapIndexer mapIndexer (ranges_, fIdx);
MappingIndexer mapIndexer (ranges_, fIdx);
while (mapIndexer.valid()) {
double prod = LogAware::multIdenty();
size_t i = mapIndexer.mappedIndex();
size_t i = mapIndexer;
unsigned h = mapIndexer[fIdx];
for (unsigned r = 0; r < R; r++) {
if (Globals::logDomain) {