refactor functions for summing out
This commit is contained in:
@@ -51,16 +51,16 @@ Factor::Factor (
|
||||
void
|
||||
Factor::sumOut (VarId vid)
|
||||
{
|
||||
if (vid == args_.back()) {
|
||||
sumOutLastVariable(); // optimization
|
||||
return;
|
||||
}
|
||||
if (vid == args_.front()) {
|
||||
sumOutFirstVariable(); // optimization
|
||||
return;
|
||||
}
|
||||
assert (indexOf (vid) != args_.size());
|
||||
sumOutIndex (indexOf (vid));
|
||||
if (vid == args_.front() && ranges_.front() == 2) {
|
||||
// optimization
|
||||
sumOutFirstVariable();
|
||||
} else if (vid == args_.back() && ranges_.back() == 2) {
|
||||
// optimization
|
||||
sumOutLastVariable();
|
||||
} else {
|
||||
assert (indexOf (vid) != args_.size());
|
||||
sumOutIndex (indexOf (vid));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -69,12 +69,7 @@ void
|
||||
Factor::sumOutAllExcept (VarId vid)
|
||||
{
|
||||
assert (indexOf (vid) != args_.size());
|
||||
while (args_.back() != vid) {
|
||||
sumOutLastVariable();
|
||||
}
|
||||
while (args_.front() != vid) {
|
||||
sumOutFirstVariable();
|
||||
}
|
||||
sumOutAllExceptIndex (indexOf (vid));
|
||||
}
|
||||
|
||||
|
||||
@@ -82,67 +77,12 @@ Factor::sumOutAllExcept (VarId vid)
|
||||
void
|
||||
Factor::sumOutAllExcept (const VarIds& vids)
|
||||
{
|
||||
for (int i = 0; i < (int)args_.size(); i++) {
|
||||
if (Util::contains (vids, args_[i]) == false) {
|
||||
sumOut (args_[i]);
|
||||
i --;
|
||||
}
|
||||
vector<bool> mask (args_.size(), false);
|
||||
for (unsigned i = 0; i < vids.size(); i++) {
|
||||
assert (indexOf (vids[i]) != args_.size());
|
||||
mask[indexOf (vids[i])] = true;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Factor::sumOutIndex (size_t idx)
|
||||
{
|
||||
assert (idx < args_.size());
|
||||
// number of parameters separating a different state of `var',
|
||||
// with the states of the remaining variables fixed
|
||||
unsigned varOffset = 1;
|
||||
|
||||
// number of parameters separating a different state of the variable
|
||||
// on the left of `var', with the states of the remaining vars fixed
|
||||
unsigned leftVarOffset = 1;
|
||||
|
||||
for (int i = args_.size() - 1; i > (int)idx; i--) {
|
||||
varOffset *= ranges_[i];
|
||||
leftVarOffset *= ranges_[i];
|
||||
}
|
||||
leftVarOffset *= ranges_[idx];
|
||||
|
||||
unsigned offset = 0;
|
||||
unsigned count1 = 0;
|
||||
unsigned count2 = 0;
|
||||
unsigned newpsSize = params_.size() / ranges_[idx];
|
||||
|
||||
Params newps;
|
||||
newps.reserve (newpsSize);
|
||||
|
||||
while (newps.size() < newpsSize) {
|
||||
double sum = LogAware::addIdenty();
|
||||
for (unsigned i = 0; i < ranges_[idx]; i++) {
|
||||
if (Globals::logDomain) {
|
||||
sum = Util::logSum (sum, params_[offset]);
|
||||
} else {
|
||||
sum += params_[offset];
|
||||
}
|
||||
offset += varOffset;
|
||||
}
|
||||
newps.push_back (sum);
|
||||
count1 ++;
|
||||
if (idx == args_.size() - 1) {
|
||||
offset = count1 * ranges_[idx];
|
||||
} else {
|
||||
if (((offset - varOffset + 1) % leftVarOffset) == 0) {
|
||||
count1 = 0;
|
||||
count2 ++;
|
||||
}
|
||||
offset = (leftVarOffset * count2) + count1;
|
||||
}
|
||||
}
|
||||
args_.erase (args_.begin() + idx);
|
||||
ranges_.erase (ranges_.begin() + idx);
|
||||
params_ = newps;
|
||||
sumOutArgs (mask);
|
||||
}
|
||||
|
||||
|
||||
@@ -151,73 +91,12 @@ void
|
||||
Factor::sumOutAllExceptIndex (size_t idx)
|
||||
{
|
||||
assert (idx < args_.size());
|
||||
while (args_.size() > idx + 1) {
|
||||
sumOutLastVariable();
|
||||
}
|
||||
for (size_t i = 0; i < idx; i++) {
|
||||
sumOutFirstVariable();
|
||||
}
|
||||
vector<bool> mask (args_.size(), false);
|
||||
mask[idx] = true;
|
||||
sumOutArgs (mask);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Factor::sumOutFirstVariable (void)
|
||||
{
|
||||
assert (args_.size() > 1);
|
||||
unsigned range = ranges_.front();
|
||||
unsigned sep = params_.size() / range;
|
||||
if (Globals::logDomain) {
|
||||
for (size_t i = sep; i < params_.size(); i++) {
|
||||
params_[i % sep] = Util::logSum (params_[i % sep], params_[i]);
|
||||
}
|
||||
} else {
|
||||
for (size_t i = sep; i < params_.size(); i++) {
|
||||
params_[i % sep] += params_[i];
|
||||
}
|
||||
}
|
||||
params_.resize (sep);
|
||||
args_.erase (args_.begin());
|
||||
ranges_.erase (ranges_.begin());
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Factor::sumOutLastVariable (void)
|
||||
{
|
||||
assert (args_.size() > 1);
|
||||
size_t idx1 = 0;
|
||||
size_t idx2 = 0;
|
||||
unsigned range = ranges_.back();
|
||||
if (Globals::logDomain) {
|
||||
while (idx1 < params_.size()) {
|
||||
params_[idx2] = params_[idx1];
|
||||
idx1 ++;
|
||||
for (unsigned j = 1; j < range; j++) {
|
||||
params_[idx2] = Util::logSum (params_[idx2], params_[idx1]);
|
||||
idx1 ++;
|
||||
}
|
||||
idx2 ++;
|
||||
}
|
||||
} else {
|
||||
while (idx1 < params_.size()) {
|
||||
params_[idx2] = params_[idx1];
|
||||
idx1 ++;
|
||||
for (unsigned j = 1; j < range; j++) {
|
||||
params_[idx2] += params_[idx1];
|
||||
idx1 ++;
|
||||
}
|
||||
idx2 ++;
|
||||
}
|
||||
}
|
||||
params_.resize (idx2);
|
||||
args_.pop_back();
|
||||
ranges_.pop_back();
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Factor::multiply (Factor& g)
|
||||
{
|
||||
@@ -276,6 +155,87 @@ Factor::print (void) const
|
||||
|
||||
|
||||
|
||||
void
|
||||
Factor::sumOutFirstVariable (void)
|
||||
{
|
||||
size_t sep = params_.size() / 2;
|
||||
if (Globals::logDomain) {
|
||||
std::transform (
|
||||
params_.begin(), params_.begin() + sep,
|
||||
params_.begin() + sep, params_.begin(),
|
||||
Util::logSum);
|
||||
|
||||
} else {
|
||||
std::transform (
|
||||
params_.begin(), params_.begin() + sep,
|
||||
params_.begin() + sep, params_.begin(),
|
||||
std::plus<double>());
|
||||
}
|
||||
params_.resize (sep);
|
||||
args_.erase (args_.begin());
|
||||
ranges_.erase (ranges_.begin());
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Factor::sumOutLastVariable (void)
|
||||
{
|
||||
Params::iterator first1 = params_.begin();
|
||||
Params::iterator first2 = params_.begin();
|
||||
Params::iterator last = params_.end();
|
||||
if (Globals::logDomain) {
|
||||
while (first2 != last) {
|
||||
// the arguments can be swaped, but that is ok
|
||||
*first1++ = Util::logSum (*first2++, *first2++);
|
||||
}
|
||||
} else {
|
||||
while (first2 != last) {
|
||||
*first1++ = (*first2++) + (*first2++);
|
||||
}
|
||||
}
|
||||
params_.resize (params_.size() / 2);
|
||||
args_.pop_back();
|
||||
ranges_.pop_back();
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Factor::sumOutArgs (const vector<bool>& mask)
|
||||
{
|
||||
assert (mask.size() == args_.size());
|
||||
size_t new_size = 1;
|
||||
Ranges oldRanges = ranges_;
|
||||
args_.clear();
|
||||
ranges_.clear();
|
||||
for (unsigned i = 0; i < mask.size(); i++) {
|
||||
if (mask[i]) {
|
||||
new_size *= ranges_[i];
|
||||
args_.push_back (args_[i]);
|
||||
ranges_.push_back (ranges_[i]);
|
||||
}
|
||||
}
|
||||
Params newps (new_size, LogAware::addIdenty());
|
||||
Params::const_iterator first = params_.begin();
|
||||
Params::const_iterator last = params_.end();
|
||||
MappingIndexer indexer (oldRanges, mask);
|
||||
if (Globals::logDomain) {
|
||||
while (first != last) {
|
||||
newps[indexer] = Util::logSum (newps[indexer], *first++);
|
||||
++ indexer;
|
||||
}
|
||||
} else {
|
||||
while (first != last) {
|
||||
newps[indexer] += *first++;
|
||||
++ indexer;
|
||||
}
|
||||
}
|
||||
params_ = newps;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Factor::copyFromFactor (const Factor& g)
|
||||
{
|
||||
|
Reference in New Issue
Block a user