This repository has been archived on 2023-08-20. You can view files and clone it, but cannot push or open issues or pull requests.
yap-6.3/packages/CLPBN/clpbn/bp/Factor.cpp

266 lines
5.0 KiB
C++
Raw Normal View History

#include <cstdlib>
#include <cassert>
2011-12-12 15:29:51 +00:00
#include <algorithm>
#include <iostream>
#include <sstream>
#include "Factor.h"
2012-03-22 11:33:24 +00:00
#include "Indexer.h"
2012-03-31 23:27:37 +01:00
Factor::Factor (const Factor& g)
{
  // Copy construction is implemented by the same routine used for
  // assignment-like copies: arguments, ranges, parameters and the
  // distribution id are all taken from g.
  this->copyFromFactor (g);
}
2012-03-31 23:27:37 +01:00
Factor::Factor (
    const VarIds& vids,
    const Ranges& ranges,
    const Params& params,
    unsigned distId)
{
  // Store the factor description exactly as given by the caller.
  distId_ = distId;
  params_ = params;
  ranges_ = ranges;
  args_   = vids;
  // The parameter table size must agree with the ranges
  // (Util::expectedSize presumably returns the product of the ranges).
  assert (params_.size() == Util::expectedSize (ranges_));
}
2012-03-31 23:27:37 +01:00
Factor::Factor (
    const Vars& vars,
    const Params& params,
    unsigned distId)
{
  // Record each variable's id and range, preserving the caller's order.
  for (unsigned k = 0; k < vars.size(); ++k) {
    args_.push_back (vars[k]->varId());
    ranges_.push_back (vars[k]->range());
  }
  distId_ = distId;
  params_ = params;
  // The parameter table size must agree with the collected ranges.
  assert (params_.size() == Util::expectedSize (ranges_));
}
void
2012-03-22 11:33:24 +00:00
Factor::sumOutAllExcept (VarId vid)
{
2012-03-22 11:33:24 +00:00
assert (indexOf (vid) != -1);
2012-03-31 23:27:37 +01:00
while (args_.back() != vid) {
2012-03-22 11:33:24 +00:00
sumOutLastVariable();
2011-12-12 15:29:51 +00:00
}
2012-03-31 23:27:37 +01:00
while (args_.front() != vid) {
2012-03-22 11:33:24 +00:00
sumOutFirstVariable();
2011-12-12 15:29:51 +00:00
}
}
void
2012-03-22 11:33:24 +00:00
Factor::sumOutAllExcept (const VarIds& vids)
{
2012-03-31 23:27:37 +01:00
for (int i = 0; i < (int)args_.size(); i++) {
if (Util::contains (vids, args_[i]) == false) {
sumOut (args_[i]);
i --;
2012-03-22 11:33:24 +00:00
}
}
}
void
Factor::sumOut (VarId vid)
{
  // Removes `vid' from this factor. Every group of parameters that differ
  // only in the state of `vid' is replaced by its sum (a log-sum via
  // Util::logSum when Globals::logDomain is set).
  int idx = indexOf (vid);
  assert (idx != -1);

  if (vid == args_.back()) {
    sumOutLastVariable(); // optimization
    return;
  }
  if (vid == args_.front()) {
    sumOutFirstVariable(); // optimization
    return;
  }

  // From here on `vid' is an interior argument. The table is indexed with
  // the last argument varying fastest, so it splits into nBlocks blocks of
  // blockSize entries. Inside one block, consecutive states of `vid' are
  // varOffset entries apart.
  const unsigned pos   = idx;
  const unsigned range = ranges_[pos];
  // number of parameters separating two consecutive states of `vid',
  // with the states of the remaining variables fixed
  unsigned varOffset = 1;
  for (unsigned i = pos + 1; i < args_.size(); i++) {
    varOffset *= ranges_[i];
  }
  // number of parameters covering every state of `vid' and of all the
  // arguments to its right
  const unsigned blockSize = varOffset * range;
  const unsigned nBlocks   = params_.size() / blockSize;

  Params newps;
  newps.reserve (nBlocks * varOffset);
  for (unsigned b = 0; b < nBlocks; b++) {
    const unsigned blockStart = b * blockSize;
    for (unsigned r = 0; r < varOffset; r++) {
      // Accumulate over the states of `vid' in ascending order, starting
      // from the additive identity of the current domain.
      double sum = LogAware::addIdenty();
      for (unsigned s = 0; s < range; s++) {
        const double p = params_[blockStart + s * varOffset + r];
        if (Globals::logDomain) {
          sum = Util::logSum (sum, p);
        } else {
          sum += p;
        }
      }
      newps.push_back (sum);
    }
  }
  args_.erase (args_.begin() + pos);
  ranges_.erase (ranges_.begin() + pos);
  params_ = newps;
}
2011-12-12 15:29:51 +00:00
void
2012-03-22 11:33:24 +00:00
Factor::sumOutFirstVariable (void)
{
2012-04-05 18:38:56 +01:00
unsigned range = ranges_.front();
unsigned sep = params_.size() / range;
2012-03-22 11:33:24 +00:00
if (Globals::logDomain) {
2012-03-31 23:27:37 +01:00
for (unsigned i = sep; i < params_.size(); i++) {
params_[i % sep] = Util::logSum (params_[i % sep], params_[i]);
2012-03-22 11:33:24 +00:00
}
} else {
2012-03-31 23:27:37 +01:00
for (unsigned i = sep; i < params_.size(); i++) {
params_[i % sep] += params_[i];
2012-03-22 11:33:24 +00:00
}
2011-12-12 15:29:51 +00:00
}
2012-03-31 23:27:37 +01:00
params_.resize (sep);
args_.erase (args_.begin());
2011-12-12 15:29:51 +00:00
ranges_.erase (ranges_.begin());
}
2011-12-12 15:29:51 +00:00
void
2012-03-22 11:33:24 +00:00
Factor::sumOutLastVariable (void)
2011-12-12 15:29:51 +00:00
{
2012-04-05 18:38:56 +01:00
unsigned range = ranges_.back();
2011-12-12 15:29:51 +00:00
unsigned idx1 = 0;
unsigned idx2 = 0;
2012-03-22 11:33:24 +00:00
if (Globals::logDomain) {
2012-03-31 23:27:37 +01:00
while (idx1 < params_.size()) {
params_[idx2] = params_[idx1];
2012-03-22 11:33:24 +00:00
idx1 ++;
2012-04-05 18:38:56 +01:00
for (unsigned j = 1; j < range; j++) {
2012-03-31 23:27:37 +01:00
params_[idx2] = Util::logSum (params_[idx2], params_[idx1]);
2011-12-12 15:29:51 +00:00
idx1 ++;
}
2012-03-22 11:33:24 +00:00
idx2 ++;
}
} else {
2012-03-31 23:27:37 +01:00
while (idx1 < params_.size()) {
params_[idx2] = params_[idx1];
2012-03-22 11:33:24 +00:00
idx1 ++;
2012-04-05 18:38:56 +01:00
for (unsigned j = 1; j < range; j++) {
2012-03-31 23:27:37 +01:00
params_[idx2] += params_[idx1];
2011-12-12 15:29:51 +00:00
idx1 ++;
}
2012-03-22 11:33:24 +00:00
idx2 ++;
}
2011-12-12 15:29:51 +00:00
}
2012-03-31 23:27:37 +01:00
params_.resize (idx2);
args_.pop_back();
2011-12-12 15:29:51 +00:00
ranges_.pop_back();
}
void
2012-03-31 23:27:37 +01:00
Factor::multiply (Factor& g)
2011-12-12 15:29:51 +00:00
{
2012-03-31 23:27:37 +01:00
if (args_.size() == 0) {
copyFromFactor (g);
2011-12-12 15:29:51 +00:00
return;
}
2012-03-31 23:27:37 +01:00
TFactor<VarId>::multiply (g);
2011-12-12 15:29:51 +00:00
}
void
2012-03-31 23:27:37 +01:00
Factor::reorderAccordingVarIds (void)
2011-12-12 15:29:51 +00:00
{
2012-03-31 23:27:37 +01:00
VarIds sortedVarIds = args_;
sort (sortedVarIds.begin(), sortedVarIds.end());
reorderArguments (sortedVarIds);
}
string
Factor::getLabel (void) const
{
  // Builds "f(<label>,<label>,...)" from the label of each argument.
  stringstream ss;
  ss << "f(" ;
  const char* delimiter = "";
  for (unsigned i = 0; i < args_.size(); i++) {
    ss << delimiter;
    ss << Var (args_[i], ranges_[i]).label();
    delimiter = ",";
  }
  ss << ")" ;
  return ss.str();
}
void
2012-03-22 11:33:24 +00:00
Factor::print (void) const
{
2012-04-05 23:00:48 +01:00
Vars vars;
2012-03-31 23:27:37 +01:00
for (unsigned i = 0; i < args_.size(); i++) {
2012-04-05 23:00:48 +01:00
vars.push_back (new Var (args_[i], ranges_[i]));
2011-12-12 15:29:51 +00:00
}
vector<string> jointStrings = Util::getJointStateStrings (vars);
2012-03-31 23:27:37 +01:00
for (unsigned i = 0; i < params_.size(); i++) {
2011-12-12 15:29:51 +00:00
cout << "f(" << jointStrings[i] << ")" ;
2012-03-31 23:27:37 +01:00
cout << " = " << params_[i] << endl;
}
2012-03-22 11:33:24 +00:00
cout << endl;
2011-12-12 15:29:51 +00:00
for (unsigned i = 0; i < vars.size(); i++) {
delete vars[i];
}
}
2012-03-31 23:27:37 +01:00
void
Factor::copyFromFactor (const Factor& g)
{
  // Replace this factor's whole state with g's; the fields are independent,
  // so the order of the copies does not matter.
  distId_ = g.distId();
  params_ = g.params();
  ranges_ = g.ranges();
  args_   = g.arguments();
}