refactor functions for summing out
This commit is contained in:
parent
df8a3c5fdc
commit
5ff161b10f
@ -51,16 +51,16 @@ Factor::Factor (
|
|||||||
void
|
void
|
||||||
Factor::sumOut (VarId vid)
|
Factor::sumOut (VarId vid)
|
||||||
{
|
{
|
||||||
if (vid == args_.back()) {
|
if (vid == args_.front() && ranges_.front() == 2) {
|
||||||
sumOutLastVariable(); // optimization
|
// optimization
|
||||||
return;
|
sumOutFirstVariable();
|
||||||
}
|
} else if (vid == args_.back() && ranges_.back() == 2) {
|
||||||
if (vid == args_.front()) {
|
// optimization
|
||||||
sumOutFirstVariable(); // optimization
|
sumOutLastVariable();
|
||||||
return;
|
} else {
|
||||||
}
|
assert (indexOf (vid) != args_.size());
|
||||||
assert (indexOf (vid) != args_.size());
|
sumOutIndex (indexOf (vid));
|
||||||
sumOutIndex (indexOf (vid));
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -69,12 +69,7 @@ void
|
|||||||
Factor::sumOutAllExcept (VarId vid)
|
Factor::sumOutAllExcept (VarId vid)
|
||||||
{
|
{
|
||||||
assert (indexOf (vid) != args_.size());
|
assert (indexOf (vid) != args_.size());
|
||||||
while (args_.back() != vid) {
|
sumOutAllExceptIndex (indexOf (vid));
|
||||||
sumOutLastVariable();
|
|
||||||
}
|
|
||||||
while (args_.front() != vid) {
|
|
||||||
sumOutFirstVariable();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -82,67 +77,12 @@ Factor::sumOutAllExcept (VarId vid)
|
|||||||
void
|
void
|
||||||
Factor::sumOutAllExcept (const VarIds& vids)
|
Factor::sumOutAllExcept (const VarIds& vids)
|
||||||
{
|
{
|
||||||
for (int i = 0; i < (int)args_.size(); i++) {
|
vector<bool> mask (args_.size(), false);
|
||||||
if (Util::contains (vids, args_[i]) == false) {
|
for (unsigned i = 0; i < vids.size(); i++) {
|
||||||
sumOut (args_[i]);
|
assert (indexOf (vids[i]) != args_.size());
|
||||||
i --;
|
mask[indexOf (vids[i])] = true;
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
sumOutArgs (mask);
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
Factor::sumOutIndex (size_t idx)
|
|
||||||
{
|
|
||||||
assert (idx < args_.size());
|
|
||||||
// number of parameters separating a different state of `var',
|
|
||||||
// with the states of the remaining variables fixed
|
|
||||||
unsigned varOffset = 1;
|
|
||||||
|
|
||||||
// number of parameters separating a different state of the variable
|
|
||||||
// on the left of `var', with the states of the remaining vars fixed
|
|
||||||
unsigned leftVarOffset = 1;
|
|
||||||
|
|
||||||
for (int i = args_.size() - 1; i > (int)idx; i--) {
|
|
||||||
varOffset *= ranges_[i];
|
|
||||||
leftVarOffset *= ranges_[i];
|
|
||||||
}
|
|
||||||
leftVarOffset *= ranges_[idx];
|
|
||||||
|
|
||||||
unsigned offset = 0;
|
|
||||||
unsigned count1 = 0;
|
|
||||||
unsigned count2 = 0;
|
|
||||||
unsigned newpsSize = params_.size() / ranges_[idx];
|
|
||||||
|
|
||||||
Params newps;
|
|
||||||
newps.reserve (newpsSize);
|
|
||||||
|
|
||||||
while (newps.size() < newpsSize) {
|
|
||||||
double sum = LogAware::addIdenty();
|
|
||||||
for (unsigned i = 0; i < ranges_[idx]; i++) {
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
sum = Util::logSum (sum, params_[offset]);
|
|
||||||
} else {
|
|
||||||
sum += params_[offset];
|
|
||||||
}
|
|
||||||
offset += varOffset;
|
|
||||||
}
|
|
||||||
newps.push_back (sum);
|
|
||||||
count1 ++;
|
|
||||||
if (idx == args_.size() - 1) {
|
|
||||||
offset = count1 * ranges_[idx];
|
|
||||||
} else {
|
|
||||||
if (((offset - varOffset + 1) % leftVarOffset) == 0) {
|
|
||||||
count1 = 0;
|
|
||||||
count2 ++;
|
|
||||||
}
|
|
||||||
offset = (leftVarOffset * count2) + count1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
args_.erase (args_.begin() + idx);
|
|
||||||
ranges_.erase (ranges_.begin() + idx);
|
|
||||||
params_ = newps;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -151,73 +91,12 @@ void
|
|||||||
Factor::sumOutAllExceptIndex (size_t idx)
|
Factor::sumOutAllExceptIndex (size_t idx)
|
||||||
{
|
{
|
||||||
assert (idx < args_.size());
|
assert (idx < args_.size());
|
||||||
while (args_.size() > idx + 1) {
|
vector<bool> mask (args_.size(), false);
|
||||||
sumOutLastVariable();
|
mask[idx] = true;
|
||||||
}
|
sumOutArgs (mask);
|
||||||
for (size_t i = 0; i < idx; i++) {
|
|
||||||
sumOutFirstVariable();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
Factor::sumOutFirstVariable (void)
|
|
||||||
{
|
|
||||||
assert (args_.size() > 1);
|
|
||||||
unsigned range = ranges_.front();
|
|
||||||
unsigned sep = params_.size() / range;
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
for (size_t i = sep; i < params_.size(); i++) {
|
|
||||||
params_[i % sep] = Util::logSum (params_[i % sep], params_[i]);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (size_t i = sep; i < params_.size(); i++) {
|
|
||||||
params_[i % sep] += params_[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
params_.resize (sep);
|
|
||||||
args_.erase (args_.begin());
|
|
||||||
ranges_.erase (ranges_.begin());
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
Factor::sumOutLastVariable (void)
|
|
||||||
{
|
|
||||||
assert (args_.size() > 1);
|
|
||||||
size_t idx1 = 0;
|
|
||||||
size_t idx2 = 0;
|
|
||||||
unsigned range = ranges_.back();
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
while (idx1 < params_.size()) {
|
|
||||||
params_[idx2] = params_[idx1];
|
|
||||||
idx1 ++;
|
|
||||||
for (unsigned j = 1; j < range; j++) {
|
|
||||||
params_[idx2] = Util::logSum (params_[idx2], params_[idx1]);
|
|
||||||
idx1 ++;
|
|
||||||
}
|
|
||||||
idx2 ++;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
while (idx1 < params_.size()) {
|
|
||||||
params_[idx2] = params_[idx1];
|
|
||||||
idx1 ++;
|
|
||||||
for (unsigned j = 1; j < range; j++) {
|
|
||||||
params_[idx2] += params_[idx1];
|
|
||||||
idx1 ++;
|
|
||||||
}
|
|
||||||
idx2 ++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
params_.resize (idx2);
|
|
||||||
args_.pop_back();
|
|
||||||
ranges_.pop_back();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
Factor::multiply (Factor& g)
|
Factor::multiply (Factor& g)
|
||||||
{
|
{
|
||||||
@ -276,6 +155,87 @@ Factor::print (void) const
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
Factor::sumOutFirstVariable (void)
|
||||||
|
{
|
||||||
|
size_t sep = params_.size() / 2;
|
||||||
|
if (Globals::logDomain) {
|
||||||
|
std::transform (
|
||||||
|
params_.begin(), params_.begin() + sep,
|
||||||
|
params_.begin() + sep, params_.begin(),
|
||||||
|
Util::logSum);
|
||||||
|
|
||||||
|
} else {
|
||||||
|
std::transform (
|
||||||
|
params_.begin(), params_.begin() + sep,
|
||||||
|
params_.begin() + sep, params_.begin(),
|
||||||
|
std::plus<double>());
|
||||||
|
}
|
||||||
|
params_.resize (sep);
|
||||||
|
args_.erase (args_.begin());
|
||||||
|
ranges_.erase (ranges_.begin());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
Factor::sumOutLastVariable (void)
|
||||||
|
{
|
||||||
|
Params::iterator first1 = params_.begin();
|
||||||
|
Params::iterator first2 = params_.begin();
|
||||||
|
Params::iterator last = params_.end();
|
||||||
|
if (Globals::logDomain) {
|
||||||
|
while (first2 != last) {
|
||||||
|
// the arguments can be swaped, but that is ok
|
||||||
|
*first1++ = Util::logSum (*first2++, *first2++);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
while (first2 != last) {
|
||||||
|
*first1++ = (*first2++) + (*first2++);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
params_.resize (params_.size() / 2);
|
||||||
|
args_.pop_back();
|
||||||
|
ranges_.pop_back();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
Factor::sumOutArgs (const vector<bool>& mask)
|
||||||
|
{
|
||||||
|
assert (mask.size() == args_.size());
|
||||||
|
size_t new_size = 1;
|
||||||
|
Ranges oldRanges = ranges_;
|
||||||
|
args_.clear();
|
||||||
|
ranges_.clear();
|
||||||
|
for (unsigned i = 0; i < mask.size(); i++) {
|
||||||
|
if (mask[i]) {
|
||||||
|
new_size *= ranges_[i];
|
||||||
|
args_.push_back (args_[i]);
|
||||||
|
ranges_.push_back (ranges_[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Params newps (new_size, LogAware::addIdenty());
|
||||||
|
Params::const_iterator first = params_.begin();
|
||||||
|
Params::const_iterator last = params_.end();
|
||||||
|
MappingIndexer indexer (oldRanges, mask);
|
||||||
|
if (Globals::logDomain) {
|
||||||
|
while (first != last) {
|
||||||
|
newps[indexer] = Util::logSum (newps[indexer], *first++);
|
||||||
|
++ indexer;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
while (first != last) {
|
||||||
|
newps[indexer] += *first++;
|
||||||
|
++ indexer;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
params_ = newps;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
Factor::copyFromFactor (const Factor& g)
|
Factor::copyFromFactor (const Factor& g)
|
||||||
{
|
{
|
||||||
|
@ -128,6 +128,31 @@ class TFactor
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void sumOutIndex (size_t idx)
|
||||||
|
{
|
||||||
|
assert (idx < args_.size());
|
||||||
|
assert (args_.size() > 1);
|
||||||
|
size_t new_size = params_.size() / ranges_[idx];
|
||||||
|
Params newps (new_size, LogAware::addIdenty());
|
||||||
|
Params::const_iterator first = params_.begin();
|
||||||
|
Params::const_iterator last = params_.end();
|
||||||
|
MappingIndexer indexer (ranges_, idx);
|
||||||
|
if (Globals::logDomain) {
|
||||||
|
while (first != last) {
|
||||||
|
newps[indexer] = Util::logSum (newps[indexer], *first++);
|
||||||
|
++ indexer;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
while (first != last) {
|
||||||
|
newps[indexer] += *first++;
|
||||||
|
++ indexer;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
params_ = newps;
|
||||||
|
args_.erase (args_.begin() + idx);
|
||||||
|
ranges_.erase (ranges_.begin() + idx);
|
||||||
|
}
|
||||||
|
|
||||||
void absorveEvidence (const T& arg, unsigned evidence)
|
void absorveEvidence (const T& arg, unsigned evidence)
|
||||||
{
|
{
|
||||||
size_t idx = indexOf (arg);
|
size_t idx = indexOf (arg);
|
||||||
@ -137,7 +162,7 @@ class TFactor
|
|||||||
params_.clear();
|
params_.clear();
|
||||||
params_.reserve (copy.size() / ranges_[idx]);
|
params_.reserve (copy.size() / ranges_[idx]);
|
||||||
StatesIndexer indexer (ranges_);
|
StatesIndexer indexer (ranges_);
|
||||||
for (size_t i = 0; i < evidence; i++) {
|
for (unsigned i = 0; i < evidence; i++) {
|
||||||
indexer.increment (idx);
|
indexer.increment (idx);
|
||||||
}
|
}
|
||||||
while (indexer.valid()) {
|
while (indexer.valid()) {
|
||||||
@ -246,7 +271,7 @@ class TFactor
|
|||||||
params_.push_back (copy[i]);
|
params_.push_back (copy[i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
@ -270,14 +295,8 @@ class Factor : public TFactor<VarId>
|
|||||||
|
|
||||||
void sumOutAllExcept (const VarIds&);
|
void sumOutAllExcept (const VarIds&);
|
||||||
|
|
||||||
void sumOutIndex (size_t idx);
|
|
||||||
|
|
||||||
void sumOutAllExceptIndex (size_t idx);
|
void sumOutAllExceptIndex (size_t idx);
|
||||||
|
|
||||||
void sumOutFirstVariable (void);
|
|
||||||
|
|
||||||
void sumOutLastVariable (void);
|
|
||||||
|
|
||||||
void multiply (Factor&);
|
void multiply (Factor&);
|
||||||
|
|
||||||
void reorderAccordingVarIds (void);
|
void reorderAccordingVarIds (void);
|
||||||
@ -287,6 +306,12 @@ class Factor : public TFactor<VarId>
|
|||||||
void print (void) const;
|
void print (void) const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
void sumOutFirstVariable (void);
|
||||||
|
|
||||||
|
void sumOutLastVariable (void);
|
||||||
|
|
||||||
|
void sumOutArgs (const vector<bool>& mask);
|
||||||
|
|
||||||
void copyFromFactor (const Factor& f);
|
void copyFromFactor (const Factor& f);
|
||||||
|
|
||||||
};
|
};
|
||||||
|
@ -219,12 +219,12 @@ SumOutOperator::apply (void)
|
|||||||
size_t fIdx = product->indexOfGroup (group_);
|
size_t fIdx = product->indexOfGroup (group_);
|
||||||
LogVarSet excl = product->exclusiveLogVars (fIdx);
|
LogVarSet excl = product->exclusiveLogVars (fIdx);
|
||||||
if (product->constr()->isCountNormalized (excl)) {
|
if (product->constr()->isCountNormalized (excl)) {
|
||||||
product->sumOut (fIdx);
|
product->sumOutIndex (fIdx);
|
||||||
pfList_.addShattered (product);
|
pfList_.addShattered (product);
|
||||||
} else {
|
} else {
|
||||||
Parfactors pfs = FoveSolver::countNormalize (product, excl);
|
Parfactors pfs = FoveSolver::countNormalize (product, excl);
|
||||||
for (size_t i = 0; i < pfs.size(); i++) {
|
for (size_t i = 0; i < pfs.size(); i++) {
|
||||||
pfs[i]->sumOut (fIdx);
|
pfs[i]->sumOutIndex (fIdx);
|
||||||
pfList_.add (pfs[i]);
|
pfList_.add (pfs[i]);
|
||||||
}
|
}
|
||||||
delete product;
|
delete product;
|
||||||
|
@ -146,75 +146,39 @@ class StatesIndexer
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
class MapIndexer
|
class MappingIndexer
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
MapIndexer (const Ranges& ranges, const vector<bool>& mapDims)
|
MappingIndexer (const Ranges& ranges, const vector<bool>& mask)
|
||||||
|
: index_(0), indices_(ranges.size(), 0), ranges_(ranges),
|
||||||
|
valid_(true)
|
||||||
{
|
{
|
||||||
assert (ranges.size() == mapDims.size());
|
|
||||||
size_t prod = 1;
|
size_t prod = 1;
|
||||||
offsets_.resize (ranges.size());
|
offsets_.resize (ranges.size());
|
||||||
for (size_t i = ranges.size(); i-- > 0; ) {
|
for (size_t i = ranges.size(); i-- > 0; ) {
|
||||||
if (mapDims[i]) {
|
if (mask[i]) {
|
||||||
offsets_[i] = prod;
|
offsets_[i] = prod;
|
||||||
prod *= ranges[i];
|
prod *= ranges[i];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
indices_.resize (ranges.size(), 0);
|
assert (ranges.size() == mask.size());
|
||||||
ranges_ = ranges;
|
|
||||||
index_ = 0;
|
|
||||||
valid_ = true;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
MapIndexer (const Ranges& ranges, size_t ignoreDim)
|
MappingIndexer (const Ranges& ranges, size_t dim)
|
||||||
|
: index_(0), indices_(ranges.size(), 0), ranges_(ranges),
|
||||||
|
valid_(true)
|
||||||
{
|
{
|
||||||
size_t prod = 1;
|
size_t prod = 1;
|
||||||
offsets_.resize (ranges.size());
|
offsets_.resize (ranges.size());
|
||||||
for (size_t i = ranges.size(); i-- > 0; ) {
|
for (size_t i = ranges.size(); i-- > 0; ) {
|
||||||
if (i != ignoreDim) {
|
if (i != dim) {
|
||||||
offsets_[i] = prod;
|
offsets_[i] = prod;
|
||||||
prod *= ranges[i];
|
prod *= ranges[i];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
indices_.resize (ranges.size(), 0);
|
|
||||||
ranges_ = ranges;
|
|
||||||
index_ = 0;
|
|
||||||
valid_ = true;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
MappingIndexer& operator++ (void)
|
||||||
MapIndexer (
|
|
||||||
const VarIds& loopVids,
|
|
||||||
const Ranges& loopRanges,
|
|
||||||
const VarIds& mapVids,
|
|
||||||
const Ranges& mapRanges)
|
|
||||||
{
|
|
||||||
unsigned prod = 1;
|
|
||||||
vector<unsigned> offsets (mapRanges.size());
|
|
||||||
for (size_t i = mapRanges.size(); i-- > 0; ) {
|
|
||||||
offsets[i] = prod;
|
|
||||||
prod *= mapRanges[i];
|
|
||||||
}
|
|
||||||
|
|
||||||
offsets_.reserve (loopVids.size());
|
|
||||||
for (size_t i = 0; i < loopVids.size(); i++) {
|
|
||||||
VarIds::const_iterator it =
|
|
||||||
std::find (mapVids.begin(), mapVids.end(), loopVids[i]);
|
|
||||||
if (it != mapVids.end()) {
|
|
||||||
offsets_.push_back (offsets[it - mapVids.begin()]);
|
|
||||||
} else {
|
|
||||||
offsets_.push_back (0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
indices_.resize (loopVids.size(), 0);
|
|
||||||
ranges_ = loopRanges;
|
|
||||||
index_ = 0;
|
|
||||||
size_ = prod;
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
|
|
||||||
MapIndexer& operator ++ (void)
|
|
||||||
{
|
{
|
||||||
assert (valid_);
|
assert (valid_);
|
||||||
for (size_t i = ranges_.size(); i-- > 0; ) {
|
for (size_t i = ranges_.size(); i-- > 0; ) {
|
||||||
@ -231,11 +195,6 @@ class MapIndexer
|
|||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t mappedIndex (void) const
|
|
||||||
{
|
|
||||||
return index_;
|
|
||||||
}
|
|
||||||
|
|
||||||
operator size_t (void) const
|
operator size_t (void) const
|
||||||
{
|
{
|
||||||
return index_;
|
return index_;
|
||||||
@ -259,21 +218,27 @@ class MapIndexer
|
|||||||
index_ = 0;
|
index_ = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
friend ostream& operator<< (ostream &os, const MapIndexer& idx)
|
friend std::ostream& operator<< (std::ostream&, const MappingIndexer&);
|
||||||
{
|
|
||||||
os << "(" << std::setw (2) << std::setfill('0') << idx.index_ << ") " ;
|
|
||||||
os << idx.indices_;
|
|
||||||
return os;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
size_t index_;
|
size_t index_;
|
||||||
bool valid_;
|
Ranges indices_;
|
||||||
vector<unsigned> ranges_;
|
const Ranges& ranges_;
|
||||||
vector<unsigned> indices_;
|
bool valid_;
|
||||||
vector<size_t> offsets_;
|
vector<size_t> offsets_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
inline std::ostream& operator<< (ostream &os, const MappingIndexer& mi)
|
||||||
|
{
|
||||||
|
os << "(" ;
|
||||||
|
os << std::setw (2) << std::setfill('0') << mi.index_;
|
||||||
|
os << ") " ;
|
||||||
|
os << mi.indices_;
|
||||||
|
return os;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
#endif // HORUS_STATESINDEXER_H
|
#endif // HORUS_STATESINDEXER_H
|
||||||
|
|
||||||
|
@ -124,7 +124,7 @@ Parfactor::exclusiveLogVars (size_t fIdx) const
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
Parfactor::sumOut (size_t fIdx)
|
Parfactor::sumOutIndex (size_t fIdx)
|
||||||
{
|
{
|
||||||
assert (fIdx < args_.size());
|
assert (fIdx < args_.size());
|
||||||
assert (args_[fIdx].contains (elimLogVars()));
|
assert (args_[fIdx].contains (elimLogVars()));
|
||||||
@ -134,46 +134,29 @@ Parfactor::sumOut (size_t fIdx)
|
|||||||
args_[fIdx].countedLogVar());
|
args_[fIdx].countedLogVar());
|
||||||
unsigned R = args_[fIdx].range();
|
unsigned R = args_[fIdx].range();
|
||||||
vector<double> numAssigns = HistogramSet::getNumAssigns (N, R);
|
vector<double> numAssigns = HistogramSet::getNumAssigns (N, R);
|
||||||
StatesIndexer sindexer (ranges_, fIdx);
|
StatesIndexer indexer (ranges_, fIdx);
|
||||||
while (sindexer.valid()) {
|
while (indexer.valid()) {
|
||||||
unsigned h = sindexer[fIdx];
|
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
params_[sindexer] += numAssigns[h];
|
params_[indexer] += numAssigns[ indexer[fIdx] ];
|
||||||
} else {
|
} else {
|
||||||
params_[sindexer] *= numAssigns[h];
|
params_[indexer] *= numAssigns[ indexer[fIdx] ];
|
||||||
}
|
}
|
||||||
++ sindexer;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Params copy = params_;
|
|
||||||
params_.clear();
|
|
||||||
params_.resize (copy.size() / ranges_[fIdx], LogAware::addIdenty());
|
|
||||||
MapIndexer indexer (ranges_, fIdx);
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
for (size_t i = 0; i < copy.size(); i++) {
|
|
||||||
params_[indexer] = Util::logSum (params_[indexer], copy[i]);
|
|
||||||
++ indexer;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (size_t i = 0; i < copy.size(); i++) {
|
|
||||||
params_[indexer] += copy[i];
|
|
||||||
++ indexer;
|
++ indexer;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
LogVarSet excl = exclusiveLogVars (fIdx);
|
LogVarSet excl = exclusiveLogVars (fIdx);
|
||||||
|
unsigned exp;
|
||||||
if (args_[fIdx].isCounting()) {
|
if (args_[fIdx].isCounting()) {
|
||||||
// counting log vars were already raised on counting conversion
|
// counting log vars were already raised on counting conversion
|
||||||
LogAware::pow (params_, constr_->getConditionalCount (
|
exp = constr_->getConditionalCount (excl - args_[fIdx].countedLogVar());
|
||||||
excl - args_[fIdx].countedLogVar()));
|
|
||||||
} else {
|
} else {
|
||||||
LogAware::pow (params_, constr_->getConditionalCount (excl));
|
exp = constr_->getConditionalCount (excl);
|
||||||
}
|
}
|
||||||
constr_->remove (excl);
|
constr_->remove (excl);
|
||||||
|
|
||||||
args_.erase (args_.begin() + fIdx);
|
TFactor<ProbFormula>::sumOutIndex (fIdx);
|
||||||
ranges_.erase (ranges_.begin() + fIdx);
|
LogAware::pow (params_, exp);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -245,10 +228,10 @@ Parfactor::countConvert (LogVar X)
|
|||||||
params_.reserve (sumout.size() * H);
|
params_.reserve (sumout.size() * H);
|
||||||
|
|
||||||
ranges_[fIdx] = H;
|
ranges_[fIdx] = H;
|
||||||
MapIndexer mapIndexer (ranges_, fIdx);
|
MappingIndexer mapIndexer (ranges_, fIdx);
|
||||||
while (mapIndexer.valid()) {
|
while (mapIndexer.valid()) {
|
||||||
double prod = LogAware::multIdenty();
|
double prod = LogAware::multIdenty();
|
||||||
size_t i = mapIndexer.mappedIndex();
|
size_t i = mapIndexer;
|
||||||
unsigned h = mapIndexer[fIdx];
|
unsigned h = mapIndexer[fIdx];
|
||||||
for (unsigned r = 0; r < R; r++) {
|
for (unsigned r = 0; r < R; r++) {
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
|
@ -44,7 +44,7 @@ class Parfactor : public TFactor<ProbFormula>
|
|||||||
|
|
||||||
LogVarSet exclusiveLogVars (size_t fIdx) const;
|
LogVarSet exclusiveLogVars (size_t fIdx) const;
|
||||||
|
|
||||||
void sumOut (size_t fIdx);
|
void sumOutIndex (size_t fIdx);
|
||||||
|
|
||||||
void multiply (Parfactor&);
|
void multiply (Parfactor&);
|
||||||
|
|
||||||
|
@ -1,10 +1,7 @@
|
|||||||
- Refactor sum out in factor
|
|
||||||
- Add a way to sum out several vars at the same time
|
|
||||||
- Receive ranges as a constant reference in Indexer
|
- Receive ranges as a constant reference in Indexer
|
||||||
- Check if evidence remains in the compressed factor graph
|
- Check if evidence remains in the compressed factor graph
|
||||||
- Consider using hashs instead of vectors of colors to calculate the groups in
|
- Consider using hashs instead of vectors of colors to calculate the groups in
|
||||||
counting bp
|
counting bp
|
||||||
- use more psize_t instead of unsigned for looping through params
|
|
||||||
- Find a way to decrease the time required to find an
|
- Find a way to decrease the time required to find an
|
||||||
elimination order for variable elimination
|
elimination order for variable elimination
|
||||||
- Add a sequential elimination heuristic
|
- Add a sequential elimination heuristic
|
||||||
|
Reference in New Issue
Block a user