refactor the way we calculate the product of two factors
This commit is contained in:
parent
bc2da47804
commit
1239832c21
@ -73,63 +73,42 @@ class TFactor
|
|||||||
|
|
||||||
void multiply (TFactor<T>& g)
|
void multiply (TFactor<T>& g)
|
||||||
{
|
{
|
||||||
|
if (args_ == g.arguments()) {
|
||||||
|
// optimization
|
||||||
|
Globals::logDomain
|
||||||
|
? params_ += g.params()
|
||||||
|
: params_ *= g.params();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
unsigned range_prod = 1;
|
||||||
|
bool share_arguments = false;
|
||||||
const vector<T>& g_args = g.arguments();
|
const vector<T>& g_args = g.arguments();
|
||||||
const Ranges& g_ranges = g.ranges();
|
const Ranges& g_ranges = g.ranges();
|
||||||
const Params& g_params = g.params();
|
const Params& g_params = g.params();
|
||||||
if (args_ == g_args) {
|
|
||||||
// optimization: if the factors contain the same set of args,
|
|
||||||
// we can do a 1 to 1 operation on the parameters
|
|
||||||
Globals::logDomain ? params_ += g_params
|
|
||||||
: params_ *= g_params;
|
|
||||||
} else {
|
|
||||||
bool sharedArgs = false;
|
|
||||||
vector<size_t> gvarpos;
|
|
||||||
for (size_t i = 0; i < g_args.size(); i++) {
|
for (size_t i = 0; i < g_args.size(); i++) {
|
||||||
size_t idx = indexOf (g_args[i]);
|
size_t idx = indexOf (g_args[i]);
|
||||||
if (idx == g_args.size()) {
|
if (idx == args_.size()) {
|
||||||
ullong newSize = params_.size() * g_ranges[i];
|
range_prod *= g_ranges[i];
|
||||||
if (newSize > params_.max_size()) {
|
args_.push_back (g_args[i]);
|
||||||
cerr << "error: an overflow occurred on factor multiplication" ;
|
ranges_.push_back (g_ranges[i]);
|
||||||
cerr << endl;
|
|
||||||
abort();
|
|
||||||
}
|
|
||||||
insertArgument (g_args[i], g_ranges[i]);
|
|
||||||
gvarpos.push_back (args_.size() - 1);
|
|
||||||
} else {
|
} else {
|
||||||
sharedArgs = true;
|
share_arguments = true;
|
||||||
gvarpos.push_back (idx);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (sharedArgs == false) {
|
if (share_arguments == false) {
|
||||||
// optimization: if the original factors doesn't have common args,
|
// optimization
|
||||||
// we don't need to marry the states of the common args
|
cartesianProduct (g_params.begin(), g_params.end());
|
||||||
size_t count = 0;
|
} else {
|
||||||
for (size_t i = 0; i < params_.size(); i++) {
|
extend (range_prod);
|
||||||
|
Params::iterator it = params_.begin();
|
||||||
|
CutIndexer indexer (args_, ranges_, g_args, g_ranges);
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
params_[i] += g_params[count];
|
for (; indexer.valid(); ++indexer) {
|
||||||
} else {
|
*it++ += g_params[indexer];
|
||||||
params_[i] *= g_params[count];
|
|
||||||
}
|
|
||||||
count ++;
|
|
||||||
if (count >= g_params.size()) {
|
|
||||||
count = 0;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
Indexer indexer (ranges_, false);
|
for (; indexer.valid(); ++indexer) {
|
||||||
while (indexer.valid()) {
|
*it++ *= g_params[indexer];
|
||||||
size_t g_li = 0;
|
|
||||||
size_t prod = 1;
|
|
||||||
for (size_t j = gvarpos.size(); j-- > 0; ) {
|
|
||||||
g_li += indexer[gvarpos[j]] * prod;
|
|
||||||
prod *= g_ranges[j];
|
|
||||||
}
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
params_[indexer] += g_params[g_li];
|
|
||||||
} else {
|
|
||||||
params_[indexer] *= g_params[g_li];
|
|
||||||
}
|
|
||||||
++ indexer;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -145,14 +124,12 @@ class TFactor
|
|||||||
Params::const_iterator last = params_.end();
|
Params::const_iterator last = params_.end();
|
||||||
CutIndexer indexer (ranges_, idx);
|
CutIndexer indexer (ranges_, idx);
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
while (first != last) {
|
for (; first != last; ++indexer) {
|
||||||
newps[indexer] = Util::logSum (newps[indexer], *first++);
|
newps[indexer] = Util::logSum (newps[indexer], *first++);
|
||||||
++ indexer;
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
while (first != last) {
|
for (; first != last; ++indexer) {
|
||||||
newps[indexer] += *first++;
|
newps[indexer] += *first++;
|
||||||
++ indexer;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
params_ = newps;
|
params_ = newps;
|
||||||
@ -160,15 +137,15 @@ class TFactor
|
|||||||
ranges_.erase (ranges_.begin() + idx);
|
ranges_.erase (ranges_.begin() + idx);
|
||||||
}
|
}
|
||||||
|
|
||||||
void absorveEvidence (const T& arg, unsigned evidence)
|
void absorveEvidence (const T& arg, unsigned obsIdx)
|
||||||
{
|
{
|
||||||
size_t idx = indexOf (arg);
|
size_t idx = indexOf (arg);
|
||||||
assert (idx != args_.size());
|
assert (idx != args_.size());
|
||||||
assert (evidence < ranges_[idx]);
|
assert (obsIdx < ranges_[idx]);
|
||||||
Params newps;
|
Params newps;
|
||||||
newps.reserve (params_.size() / ranges_[idx]);
|
newps.reserve (params_.size() / ranges_[idx]);
|
||||||
Indexer indexer (ranges_);
|
Indexer indexer (ranges_);
|
||||||
for (unsigned i = 0; i < evidence; i++) {
|
for (unsigned i = 0; i < obsIdx; ++i) {
|
||||||
indexer.incrementDimension (idx);
|
indexer.incrementDimension (idx);
|
||||||
}
|
}
|
||||||
while (indexer.valid()) {
|
while (indexer.valid()) {
|
||||||
@ -199,7 +176,7 @@ class TFactor
|
|||||||
size_t li = i;
|
size_t li = i;
|
||||||
// calculate vector index corresponding to linear index
|
// calculate vector index corresponding to linear index
|
||||||
vector<unsigned> vi (N);
|
vector<unsigned> vi (N);
|
||||||
for (int k = N-1; k >= 0; k--) {
|
for (unsigned k = N; k-- > 0; ) {
|
||||||
vi[k] = li % ranges_[k];
|
vi[k] = li % ranges_[k];
|
||||||
li /= ranges_[k];
|
li /= ranges_[k];
|
||||||
}
|
}
|
||||||
@ -246,39 +223,45 @@ class TFactor
|
|||||||
unsigned distId_;
|
unsigned distId_;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void insertArgument (const T& arg, unsigned range)
|
void extend (unsigned range_prod)
|
||||||
{
|
{
|
||||||
assert (indexOf (arg) == args_.size());
|
Params backup = params_;
|
||||||
Params copy = params_;
|
|
||||||
params_.clear();
|
params_.clear();
|
||||||
params_.reserve (copy.size() * range);
|
params_.reserve (backup.size() * range_prod);
|
||||||
for (size_t i = 0; i < copy.size(); i++) {
|
Params::const_iterator first = backup.begin();
|
||||||
for (unsigned reps = 0; reps < range; reps++) {
|
Params::const_iterator last = backup.end();
|
||||||
params_.push_back (copy[i]);
|
for (; first != last; ++first) {
|
||||||
|
for (unsigned reps = 0; reps < range_prod; ++reps) {
|
||||||
|
params_.push_back (*first);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
args_.push_back (arg);
|
|
||||||
ranges_.push_back (range);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void insertArguments (const vector<T>& args, const Ranges& ranges)
|
void cartesianProduct (
|
||||||
|
Params::const_iterator first2,
|
||||||
|
Params::const_iterator last2)
|
||||||
{
|
{
|
||||||
Params copy = params_;
|
Params backup = params_;
|
||||||
unsigned nrStates = 1;
|
|
||||||
for (size_t i = 0; i < args.size(); i++) {
|
|
||||||
assert (indexOf (args[i]) == args_.size());
|
|
||||||
args_.push_back (args[i]);
|
|
||||||
ranges_.push_back (ranges[i]);
|
|
||||||
nrStates *= ranges[i];
|
|
||||||
}
|
|
||||||
params_.clear();
|
params_.clear();
|
||||||
params_.reserve (copy.size() * nrStates);
|
params_.reserve (params_.size() * (last2 - first2));
|
||||||
for (size_t i = 0; i < copy.size(); i++) {
|
Params::const_iterator first1 = backup.begin();
|
||||||
for (unsigned reps = 0; reps < nrStates; reps++) {
|
Params::const_iterator last1 = backup.end();
|
||||||
params_.push_back (copy[i]);
|
Params::const_iterator tmp;
|
||||||
|
if (Globals::logDomain) {
|
||||||
|
for (; first1 != last1; ++first1) {
|
||||||
|
for (tmp = first2; tmp != last2; ++tmp) {
|
||||||
|
params_.push_back ((*first1) + (*tmp));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (; first1 != last1; ++first1) {
|
||||||
|
for (tmp = first2; tmp != last2; ++tmp) {
|
||||||
|
params_.push_back ((*first1) * (*tmp));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user