refactor functions for summing out
This commit is contained in:
@@ -124,7 +124,7 @@ Parfactor::exclusiveLogVars (size_t fIdx) const
|
||||
|
||||
|
||||
void
|
||||
Parfactor::sumOut (size_t fIdx)
|
||||
Parfactor::sumOutIndex (size_t fIdx)
|
||||
{
|
||||
assert (fIdx < args_.size());
|
||||
assert (args_[fIdx].contains (elimLogVars()));
|
||||
@@ -134,46 +134,29 @@ Parfactor::sumOut (size_t fIdx)
|
||||
args_[fIdx].countedLogVar());
|
||||
unsigned R = args_[fIdx].range();
|
||||
vector<double> numAssigns = HistogramSet::getNumAssigns (N, R);
|
||||
StatesIndexer sindexer (ranges_, fIdx);
|
||||
while (sindexer.valid()) {
|
||||
unsigned h = sindexer[fIdx];
|
||||
StatesIndexer indexer (ranges_, fIdx);
|
||||
while (indexer.valid()) {
|
||||
if (Globals::logDomain) {
|
||||
params_[sindexer] += numAssigns[h];
|
||||
params_[indexer] += numAssigns[ indexer[fIdx] ];
|
||||
} else {
|
||||
params_[sindexer] *= numAssigns[h];
|
||||
params_[indexer] *= numAssigns[ indexer[fIdx] ];
|
||||
}
|
||||
++ sindexer;
|
||||
}
|
||||
}
|
||||
|
||||
Params copy = params_;
|
||||
params_.clear();
|
||||
params_.resize (copy.size() / ranges_[fIdx], LogAware::addIdenty());
|
||||
MapIndexer indexer (ranges_, fIdx);
|
||||
if (Globals::logDomain) {
|
||||
for (size_t i = 0; i < copy.size(); i++) {
|
||||
params_[indexer] = Util::logSum (params_[indexer], copy[i]);
|
||||
++ indexer;
|
||||
}
|
||||
} else {
|
||||
for (size_t i = 0; i < copy.size(); i++) {
|
||||
params_[indexer] += copy[i];
|
||||
++ indexer;
|
||||
}
|
||||
}
|
||||
|
||||
LogVarSet excl = exclusiveLogVars (fIdx);
|
||||
unsigned exp;
|
||||
if (args_[fIdx].isCounting()) {
|
||||
// counting log vars were already raised on counting conversion
|
||||
LogAware::pow (params_, constr_->getConditionalCount (
|
||||
excl - args_[fIdx].countedLogVar()));
|
||||
exp = constr_->getConditionalCount (excl - args_[fIdx].countedLogVar());
|
||||
} else {
|
||||
LogAware::pow (params_, constr_->getConditionalCount (excl));
|
||||
exp = constr_->getConditionalCount (excl);
|
||||
}
|
||||
constr_->remove (excl);
|
||||
|
||||
args_.erase (args_.begin() + fIdx);
|
||||
ranges_.erase (ranges_.begin() + fIdx);
|
||||
TFactor<ProbFormula>::sumOutIndex (fIdx);
|
||||
LogAware::pow (params_, exp);
|
||||
}
|
||||
|
||||
|
||||
@@ -245,10 +228,10 @@ Parfactor::countConvert (LogVar X)
|
||||
params_.reserve (sumout.size() * H);
|
||||
|
||||
ranges_[fIdx] = H;
|
||||
MapIndexer mapIndexer (ranges_, fIdx);
|
||||
MappingIndexer mapIndexer (ranges_, fIdx);
|
||||
while (mapIndexer.valid()) {
|
||||
double prod = LogAware::multIdenty();
|
||||
size_t i = mapIndexer.mappedIndex();
|
||||
size_t i = mapIndexer;
|
||||
unsigned h = mapIndexer[fIdx];
|
||||
for (unsigned r = 0; r < R; r++) {
|
||||
if (Globals::logDomain) {
|
||||
|
Reference in New Issue
Block a user