This repository has been archived on 2023-08-20. You can view files and clone it, but cannot push or open issues or pull requests.
yap-6.3/packages/CLPBN/clpbn/bp/Factor.h

306 lines
7.7 KiB
C
Raw Normal View History

2011-12-12 15:29:51 +00:00
#ifndef HORUS_FACTOR_H
#define HORUS_FACTOR_H
#include <vector>
2012-04-05 23:00:48 +01:00
#include "Var.h"
2012-03-31 23:27:37 +01:00
#include "Indexer.h"
#include "Util.h"
2011-12-12 15:29:51 +00:00
using namespace std;
2012-03-31 23:27:37 +01:00
template <typename T>
class TFactor
{
public:
const vector<T>& arguments (void) const { return args_; }
vector<T>& arguments (void) { return args_; }
const Ranges& ranges (void) const { return ranges_; }
const Params& params (void) const { return params_; }
Params& params (void) { return params_; }
unsigned nrArguments (void) const { return args_.size(); }
unsigned size (void) const { return params_.size(); }
unsigned distId (void) const { return distId_; }
void setDistId (unsigned id) { distId_ = id; }
2012-04-10 12:53:52 +01:00
void normalize (void) { LogAware::normalize (params_); }
2012-03-31 23:27:37 +01:00
void setParams (const Params& newParams)
{
params_ = newParams;
assert (params_.size() == Util::expectedSize (ranges_));
}
int indexOf (const T& t) const
{
int idx = -1;
for (unsigned i = 0; i < args_.size(); i++) {
if (args_[i] == t) {
idx = i;
break;
}
}
return idx;
}
const T& argument (unsigned idx) const
{
assert (idx < args_.size());
return args_[idx];
2012-04-03 11:58:21 +01:00
}
T& argument (unsigned idx)
{
assert (idx < args_.size());
return args_[idx];
2012-03-31 23:27:37 +01:00
}
unsigned range (unsigned idx) const
{
assert (idx < ranges_.size());
return ranges_[idx];
}
void multiply (TFactor<T>& g)
{
const vector<T>& g_args = g.arguments();
const Ranges& g_ranges = g.ranges();
const Params& g_params = g.params();
if (args_ == g_args) {
// optimization: if the factors contain the same set of args,
// we can do a 1 to 1 operation on the parameters
if (Globals::logDomain) {
Util::add (params_, g_params);
} else {
Util::multiply (params_, g_params);
}
} else {
bool sharedArgs = false;
vector<unsigned> gvarpos;
for (unsigned i = 0; i < g_args.size(); i++) {
int idx = indexOf (g_args[i]);
if (idx == -1) {
ullong newSize = params_.size() * g_ranges[i];
if (newSize > params_.max_size()) {
2012-04-30 14:12:00 +01:00
cerr << "error: an overflow occurred on factor multiplication" ;
cerr << endl;
abort();
}
2012-03-31 23:27:37 +01:00
insertArgument (g_args[i], g_ranges[i]);
gvarpos.push_back (args_.size() - 1);
} else {
sharedArgs = true;
gvarpos.push_back (idx);
}
}
if (sharedArgs == false) {
// optimization: if the original factors doesn't have common args,
// we don't need to marry the states of the common args
unsigned count = 0;
for (unsigned i = 0; i < params_.size(); i++) {
if (Globals::logDomain) {
params_[i] += g_params[count];
} else {
params_[i] *= g_params[count];
}
count ++;
if (count >= g_params.size()) {
count = 0;
}
}
} else {
StatesIndexer indexer (ranges_, false);
while (indexer.valid()) {
unsigned g_li = 0;
unsigned prod = 1;
for (int j = gvarpos.size() - 1; j >= 0; j--) {
g_li += indexer[gvarpos[j]] * prod;
prod *= g_ranges[j];
}
if (Globals::logDomain) {
params_[indexer] += g_params[g_li];
} else {
params_[indexer] *= g_params[g_li];
}
++ indexer;
}
}
}
}
void absorveEvidence (const T& arg, unsigned evidence)
{
int idx = indexOf (arg);
assert (idx != -1);
assert (evidence < ranges_[idx]);
Params copy = params_;
params_.clear();
params_.reserve (copy.size() / ranges_[idx]);
StatesIndexer indexer (ranges_);
for (unsigned i = 0; i < evidence; i++) {
indexer.increment (idx);
}
while (indexer.valid()) {
params_.push_back (copy[indexer]);
indexer.incrementExcluding (idx);
}
args_.erase (args_.begin() + idx);
ranges_.erase (ranges_.begin() + idx);
}
void reorderArguments (const vector<T> newArgs)
{
assert (newArgs.size() == args_.size());
if (newArgs == args_) {
return; // already in the wanted order
}
Ranges newRanges;
vector<unsigned> positions;
for (unsigned i = 0; i < newArgs.size(); i++) {
unsigned idx = indexOf (newArgs[i]);
newRanges.push_back (ranges_[idx]);
positions.push_back (idx);
}
unsigned N = ranges_.size();
Params newParams (params_.size());
for (unsigned i = 0; i < params_.size(); i++) {
unsigned li = i;
// calculate vector index corresponding to linear index
vector<unsigned> vi (N);
for (int k = N-1; k >= 0; k--) {
vi[k] = li % ranges_[k];
li /= ranges_[k];
}
// convert permuted vector index to corresponding linear index
unsigned prod = 1;
unsigned new_li = 0;
for (int k = N - 1; k >= 0; k--) {
new_li += vi[positions[k]] * prod;
prod *= ranges_[positions[k]];
}
newParams[new_li] = params_[i];
}
args_ = newArgs;
ranges_ = newRanges;
params_ = newParams;
}
bool contains (const T& arg) const
{
return Util::contains (args_, arg);
}
bool contains (const vector<T>& args) const
{
for (unsigned i = 0; i < args_.size(); i++) {
if (contains (args[i]) == false) {
return false;
}
}
return true;
}
double& operator[] (psize_t idx)
{
assert (idx < params_.size());
return params_[idx];
}
2012-03-31 23:27:37 +01:00
protected:
vector<T> args_;
Ranges ranges_;
Params params_;
unsigned distId_;
private:
void insertArgument (const T& arg, unsigned range)
{
assert (indexOf (arg) == -1);
Params copy = params_;
params_.clear();
params_.reserve (copy.size() * range);
for (unsigned i = 0; i < copy.size(); i++) {
for (unsigned reps = 0; reps < range; reps++) {
params_.push_back (copy[i]);
}
}
args_.push_back (arg);
ranges_.push_back (range);
}
void insertArguments (const vector<T>& args, const Ranges& ranges)
{
Params copy = params_;
unsigned nrStates = 1;
for (unsigned i = 0; i < args.size(); i++) {
assert (indexOf (args[i]) == -1);
args_.push_back (args[i]);
ranges_.push_back (ranges[i]);
nrStates *= ranges[i];
}
params_.clear();
params_.reserve (copy.size() * nrStates);
for (unsigned i = 0; i < copy.size(); i++) {
for (unsigned reps = 0; reps < nrStates; reps++) {
params_.push_back (copy[i]);
}
}
}
};
class Factor : public TFactor<VarId>
{
public:
Factor (void) { }
2012-03-31 23:27:37 +01:00
Factor (const Factor&);
2012-03-31 23:27:37 +01:00
2012-04-10 12:53:52 +01:00
Factor (const VarIds&, const Ranges&, const Params&,
2012-03-31 23:27:37 +01:00
unsigned = Util::maxUnsigned());
2012-04-10 12:53:52 +01:00
Factor (const Vars&, const Params&,
unsigned = Util::maxUnsigned());
2012-03-31 23:27:37 +01:00
void sumOut (VarId);
2012-03-31 23:27:37 +01:00
void sumOutAllExcept (VarId);
void sumOutAllExcept (const VarIds&);
void sumOutIndex (unsigned idx);
void sumOutAllExceptIndex (unsigned idx);
2012-03-31 23:27:37 +01:00
void sumOutFirstVariable (void);
void sumOutLastVariable (void);
void multiply (Factor&);
void reorderAccordingVarIds (void);
string getLabel (void) const;
void print (void) const;
private:
2012-03-31 23:27:37 +01:00
void copyFromFactor (const Factor& f);
2011-12-12 15:29:51 +00:00
};
2011-12-12 15:29:51 +00:00
#endif // HORUS_FACTOR_H