yay, my first commit
This commit is contained in:
@@ -8,7 +8,7 @@
|
||||
|
||||
#include "Factor.h"
|
||||
#include "Indexer.h"
|
||||
#include "Util.h"
|
||||
|
||||
|
||||
|
||||
Factor::Factor (const Factor& g)
|
||||
@@ -18,206 +18,73 @@ Factor::Factor (const Factor& g)
|
||||
|
||||
|
||||
|
||||
Factor::Factor (VarId vid, unsigned nStates)
|
||||
Factor::Factor (VarId vid, unsigned nrStates)
|
||||
{
|
||||
varids_.push_back (vid);
|
||||
ranges_.push_back (nStates);
|
||||
dist_ = new Distribution (Params (nStates, 1.0));
|
||||
args_.push_back (vid);
|
||||
ranges_.push_back (nrStates);
|
||||
params_.resize (nrStates, 1.0);
|
||||
distId_ = Util::maxUnsigned();
|
||||
assert (params_.size() == Util::expectedSize (ranges_));
|
||||
}
|
||||
|
||||
|
||||
|
||||
Factor::Factor (const VarNodes& vars)
|
||||
{
|
||||
int nParams = 1;
|
||||
int nrParams = 1;
|
||||
for (unsigned i = 0; i < vars.size(); i++) {
|
||||
varids_.push_back (vars[i]->varId());
|
||||
args_.push_back (vars[i]->varId());
|
||||
ranges_.push_back (vars[i]->nrStates());
|
||||
nParams *= vars[i]->nrStates();
|
||||
nrParams *= vars[i]->nrStates();
|
||||
}
|
||||
// create a uniform distribution
|
||||
double val = 1.0 / nParams;
|
||||
dist_ = new Distribution (Params (nParams, val));
|
||||
double val = 1.0 / nrParams;
|
||||
params_.resize (nrParams, val);
|
||||
distId_ = Util::maxUnsigned();
|
||||
assert (params_.size() == Util::expectedSize (ranges_));
|
||||
}
|
||||
|
||||
|
||||
|
||||
Factor::Factor (VarId vid, unsigned nStates, const Params& params)
|
||||
Factor::Factor (
|
||||
VarId vid,
|
||||
unsigned nrStates,
|
||||
const Params& params)
|
||||
{
|
||||
varids_.push_back (vid);
|
||||
ranges_.push_back (nStates);
|
||||
dist_ = new Distribution (params);
|
||||
}
|
||||
|
||||
|
||||
|
||||
Factor::Factor (VarNodes& vars, Distribution* dist)
|
||||
{
|
||||
for (unsigned i = 0; i < vars.size(); i++) {
|
||||
varids_.push_back (vars[i]->varId());
|
||||
ranges_.push_back (vars[i]->nrStates());
|
||||
}
|
||||
dist_ = dist;
|
||||
}
|
||||
|
||||
|
||||
|
||||
Factor::Factor (const VarNodes& vars, const Params& params)
|
||||
{
|
||||
for (unsigned i = 0; i < vars.size(); i++) {
|
||||
varids_.push_back (vars[i]->varId());
|
||||
ranges_.push_back (vars[i]->nrStates());
|
||||
}
|
||||
dist_ = new Distribution (params);
|
||||
}
|
||||
|
||||
|
||||
|
||||
Factor::Factor (const VarIds& vids,
|
||||
const Ranges& ranges,
|
||||
const Params& params)
|
||||
{
|
||||
varids_ = vids;
|
||||
ranges_ = ranges;
|
||||
dist_ = new Distribution (params);
|
||||
}
|
||||
|
||||
|
||||
|
||||
Factor::~Factor (void)
|
||||
{
|
||||
if (dist_->shared() == false) {
|
||||
delete dist_;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Factor::setParameters (const Params& params)
|
||||
{
|
||||
assert (dist_->params.size() == params.size());
|
||||
dist_->params = params;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Factor::copyFromFactor (const Factor& g)
|
||||
{
|
||||
varids_ = g.getVarIds();
|
||||
ranges_ = g.getRanges();
|
||||
dist_ = new Distribution (g.getParameters());
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Factor::multiply (const Factor& g)
|
||||
{
|
||||
if (varids_.size() == 0) {
|
||||
copyFromFactor (g);
|
||||
return;
|
||||
}
|
||||
|
||||
const VarIds& g_varids = g.getVarIds();
|
||||
const Ranges& g_ranges = g.getRanges();
|
||||
const Params& g_params = g.getParameters();
|
||||
|
||||
if (varids_ == g_varids) {
|
||||
// optimization: if the factors contain the same set of variables,
|
||||
// we can do a 1 to 1 operation on the parameters
|
||||
if (Globals::logDomain) {
|
||||
Util::add (dist_->params, g_params);
|
||||
} else {
|
||||
Util::multiply (dist_->params, g_params);
|
||||
}
|
||||
} else {
|
||||
bool sharedVars = false;
|
||||
vector<unsigned> gvarpos;
|
||||
for (unsigned i = 0; i < g_varids.size(); i++) {
|
||||
int idx = indexOf (g_varids[i]);
|
||||
if (idx == -1) {
|
||||
insertVariable (g_varids[i], g_ranges[i]);
|
||||
gvarpos.push_back (varids_.size() - 1);
|
||||
} else {
|
||||
sharedVars = true;
|
||||
gvarpos.push_back (idx);
|
||||
}
|
||||
}
|
||||
if (sharedVars == false) {
|
||||
// optimization: if the original factors doesn't have common variables,
|
||||
// we don't need to marry the states of the common variables
|
||||
unsigned count = 0;
|
||||
for (unsigned i = 0; i < dist_->params.size(); i++) {
|
||||
if (Globals::logDomain) {
|
||||
dist_->params[i] += g_params[count];
|
||||
} else {
|
||||
dist_->params[i] *= g_params[count];
|
||||
}
|
||||
count ++;
|
||||
if (count >= g_params.size()) {
|
||||
count = 0;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
StatesIndexer indexer (ranges_, false);
|
||||
while (indexer.valid()) {
|
||||
unsigned g_li = 0;
|
||||
unsigned prod = 1;
|
||||
for (int j = gvarpos.size() - 1; j >= 0; j--) {
|
||||
g_li += indexer[gvarpos[j]] * prod;
|
||||
prod *= g_ranges[j];
|
||||
}
|
||||
if (Globals::logDomain) {
|
||||
dist_->params[indexer] += g_params[g_li];
|
||||
} else {
|
||||
dist_->params[indexer] *= g_params[g_li];
|
||||
}
|
||||
++ indexer;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Factor::insertVariable (VarId varId, unsigned nrStates)
|
||||
{
|
||||
assert (indexOf (varId) == -1);
|
||||
Params oldParams = dist_->params;
|
||||
dist_->params.clear();
|
||||
dist_->params.reserve (oldParams.size() * nrStates);
|
||||
for (unsigned i = 0; i < oldParams.size(); i++) {
|
||||
for (unsigned reps = 0; reps < nrStates; reps++) {
|
||||
dist_->params.push_back (oldParams[i]);
|
||||
}
|
||||
}
|
||||
varids_.push_back (varId);
|
||||
args_.push_back (vid);
|
||||
ranges_.push_back (nrStates);
|
||||
params_ = params;
|
||||
distId_ = Util::maxUnsigned();
|
||||
assert (params_.size() == Util::expectedSize (ranges_));
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Factor::insertVariables (const VarIds& varIds, const Ranges& ranges)
|
||||
Factor::Factor (
|
||||
const VarNodes& vars,
|
||||
const Params& params,
|
||||
unsigned distId)
|
||||
{
|
||||
Params oldParams = dist_->params;
|
||||
unsigned nrStates = 1;
|
||||
for (unsigned i = 0; i < varIds.size(); i++) {
|
||||
assert (indexOf (varIds[i]) == -1);
|
||||
varids_.push_back (varIds[i]);
|
||||
ranges_.push_back (ranges[i]);
|
||||
nrStates *= ranges[i];
|
||||
}
|
||||
dist_->params.clear();
|
||||
dist_->params.reserve (oldParams.size() * nrStates);
|
||||
for (unsigned i = 0; i < oldParams.size(); i++) {
|
||||
for (unsigned reps = 0; reps < nrStates; reps++) {
|
||||
dist_->params.push_back (oldParams[i]);
|
||||
}
|
||||
for (unsigned i = 0; i < vars.size(); i++) {
|
||||
args_.push_back (vars[i]->varId());
|
||||
ranges_.push_back (vars[i]->nrStates());
|
||||
}
|
||||
params_ = params;
|
||||
distId_ = distId;
|
||||
assert (params_.size() == Util::expectedSize (ranges_));
|
||||
}
|
||||
|
||||
|
||||
|
||||
Factor::Factor (
|
||||
const VarIds& vids,
|
||||
const Ranges& ranges,
|
||||
const Params& params)
|
||||
{
|
||||
args_ = vids;
|
||||
ranges_ = ranges;
|
||||
params_ = params;
|
||||
distId_ = Util::maxUnsigned();
|
||||
assert (params_.size() == Util::expectedSize (ranges_));
|
||||
}
|
||||
|
||||
|
||||
@@ -226,10 +93,10 @@ void
|
||||
Factor::sumOutAllExcept (VarId vid)
|
||||
{
|
||||
assert (indexOf (vid) != -1);
|
||||
while (varids_.back() != vid) {
|
||||
while (args_.back() != vid) {
|
||||
sumOutLastVariable();
|
||||
}
|
||||
while (varids_.front() != vid) {
|
||||
while (args_.front() != vid) {
|
||||
sumOutFirstVariable();
|
||||
}
|
||||
}
|
||||
@@ -239,9 +106,10 @@ Factor::sumOutAllExcept (VarId vid)
|
||||
void
|
||||
Factor::sumOutAllExcept (const VarIds& vids)
|
||||
{
|
||||
for (unsigned i = 0; i < varids_.size(); i++) {
|
||||
if (std::find (vids.begin(), vids.end(), varids_[i]) == vids.end()) {
|
||||
sumOut (varids_[i]);
|
||||
for (int i = 0; i < (int)args_.size(); i++) {
|
||||
if (Util::contains (vids, args_[i]) == false) {
|
||||
sumOut (args_[i]);
|
||||
i --;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -254,11 +122,11 @@ Factor::sumOut (VarId vid)
|
||||
int idx = indexOf (vid);
|
||||
assert (idx != -1);
|
||||
|
||||
if (vid == varids_.back()) {
|
||||
if (vid == args_.back()) {
|
||||
sumOutLastVariable(); // optimization
|
||||
return;
|
||||
}
|
||||
if (vid == varids_.front()) {
|
||||
if (vid == args_.front()) {
|
||||
sumOutFirstVariable(); // optimization
|
||||
return;
|
||||
}
|
||||
@@ -271,7 +139,7 @@ Factor::sumOut (VarId vid)
|
||||
// on the left of `var', with the states of the remaining vars fixed
|
||||
unsigned leftVarOffset = 1;
|
||||
|
||||
for (int i = varids_.size() - 1; i > idx; i--) {
|
||||
for (int i = args_.size() - 1; i > idx; i--) {
|
||||
varOffset *= ranges_[i];
|
||||
leftVarOffset *= ranges_[i];
|
||||
}
|
||||
@@ -280,25 +148,24 @@ Factor::sumOut (VarId vid)
|
||||
unsigned offset = 0;
|
||||
unsigned count1 = 0;
|
||||
unsigned count2 = 0;
|
||||
unsigned newpsSize = dist_->params.size() / ranges_[idx];
|
||||
unsigned newpsSize = params_.size() / ranges_[idx];
|
||||
|
||||
Params newps;
|
||||
newps.reserve (newpsSize);
|
||||
Params& params = dist_->params;
|
||||
|
||||
while (newps.size() < newpsSize) {
|
||||
double sum = Util::addIdenty();
|
||||
double sum = LogAware::addIdenty();
|
||||
for (unsigned i = 0; i < ranges_[idx]; i++) {
|
||||
if (Globals::logDomain) {
|
||||
Util::logSum (sum, params[offset]);
|
||||
sum = Util::logSum (sum, params_[offset]);
|
||||
} else {
|
||||
sum += params[offset];
|
||||
sum += params_[offset];
|
||||
}
|
||||
offset += varOffset;
|
||||
}
|
||||
newps.push_back (sum);
|
||||
count1 ++;
|
||||
if (idx == (int)varids_.size() - 1) {
|
||||
if (idx == (int)args_.size() - 1) {
|
||||
offset = count1 * ranges_[idx];
|
||||
} else {
|
||||
if (((offset - varOffset + 1) % leftVarOffset) == 0) {
|
||||
@@ -308,9 +175,9 @@ Factor::sumOut (VarId vid)
|
||||
offset = (leftVarOffset * count2) + count1;
|
||||
}
|
||||
}
|
||||
varids_.erase (varids_.begin() + idx);
|
||||
args_.erase (args_.begin() + idx);
|
||||
ranges_.erase (ranges_.begin() + idx);
|
||||
dist_->params = newps;
|
||||
params_ = newps;
|
||||
}
|
||||
|
||||
|
||||
@@ -318,20 +185,19 @@ Factor::sumOut (VarId vid)
|
||||
void
|
||||
Factor::sumOutFirstVariable (void)
|
||||
{
|
||||
Params& params = dist_->params;
|
||||
unsigned nStates = ranges_.front();
|
||||
unsigned sep = params.size() / nStates;
|
||||
unsigned sep = params_.size() / nStates;
|
||||
if (Globals::logDomain) {
|
||||
for (unsigned i = sep; i < params.size(); i++) {
|
||||
Util::logSum (params[i % sep], params[i]);
|
||||
for (unsigned i = sep; i < params_.size(); i++) {
|
||||
params_[i % sep] = Util::logSum (params_[i % sep], params_[i]);
|
||||
}
|
||||
} else {
|
||||
for (unsigned i = sep; i < params.size(); i++) {
|
||||
params[i % sep] += params[i];
|
||||
for (unsigned i = sep; i < params_.size(); i++) {
|
||||
params_[i % sep] += params_[i];
|
||||
}
|
||||
}
|
||||
params.resize (sep);
|
||||
varids_.erase (varids_.begin());
|
||||
params_.resize (sep);
|
||||
args_.erase (args_.begin());
|
||||
ranges_.erase (ranges_.begin());
|
||||
}
|
||||
|
||||
@@ -340,143 +206,56 @@ Factor::sumOutFirstVariable (void)
|
||||
void
|
||||
Factor::sumOutLastVariable (void)
|
||||
{
|
||||
Params& params = dist_->params;
|
||||
unsigned nStates = ranges_.back();
|
||||
unsigned idx1 = 0;
|
||||
unsigned idx2 = 0;
|
||||
if (Globals::logDomain) {
|
||||
while (idx1 < params.size()) {
|
||||
params[idx2] = params[idx1];
|
||||
while (idx1 < params_.size()) {
|
||||
params_[idx2] = params_[idx1];
|
||||
idx1 ++;
|
||||
for (unsigned j = 1; j < nStates; j++) {
|
||||
Util::logSum (params[idx2], params[idx1]);
|
||||
params_[idx2] = Util::logSum (params_[idx2], params_[idx1]);
|
||||
idx1 ++;
|
||||
}
|
||||
idx2 ++;
|
||||
}
|
||||
} else {
|
||||
while (idx1 < params.size()) {
|
||||
params[idx2] = params[idx1];
|
||||
while (idx1 < params_.size()) {
|
||||
params_[idx2] = params_[idx1];
|
||||
idx1 ++;
|
||||
for (unsigned j = 1; j < nStates; j++) {
|
||||
params[idx2] += params[idx1];
|
||||
params_[idx2] += params_[idx1];
|
||||
idx1 ++;
|
||||
}
|
||||
idx2 ++;
|
||||
}
|
||||
}
|
||||
params.resize (idx2);
|
||||
varids_.pop_back();
|
||||
params_.resize (idx2);
|
||||
args_.pop_back();
|
||||
ranges_.pop_back();
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Factor::orderVariables (void)
|
||||
Factor::multiply (Factor& g)
|
||||
{
|
||||
VarIds sortedVarIds = varids_;
|
||||
sort (sortedVarIds.begin(), sortedVarIds.end());
|
||||
reorderVariables (sortedVarIds);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Factor::reorderVariables (const VarIds& newVarIds)
|
||||
{
|
||||
assert (newVarIds.size() == varids_.size());
|
||||
if (newVarIds == varids_) {
|
||||
if (args_.size() == 0) {
|
||||
copyFromFactor (g);
|
||||
return;
|
||||
}
|
||||
|
||||
Ranges newRanges;
|
||||
vector<unsigned> positions;
|
||||
for (unsigned i = 0; i < newVarIds.size(); i++) {
|
||||
unsigned idx = indexOf (newVarIds[i]);
|
||||
newRanges.push_back (ranges_[idx]);
|
||||
positions.push_back (idx);
|
||||
}
|
||||
|
||||
unsigned N = ranges_.size();
|
||||
Params newParams (dist_->params.size());
|
||||
for (unsigned i = 0; i < dist_->params.size(); i++) {
|
||||
unsigned li = i;
|
||||
// calculate vector index corresponding to linear index
|
||||
vector<unsigned> vi (N);
|
||||
for (int k = N-1; k >= 0; k--) {
|
||||
vi[k] = li % ranges_[k];
|
||||
li /= ranges_[k];
|
||||
}
|
||||
// convert permuted vector index to corresponding linear index
|
||||
unsigned prod = 1;
|
||||
unsigned new_li = 0;
|
||||
for (int k = N-1; k >= 0; k--) {
|
||||
new_li += vi[positions[k]] * prod;
|
||||
prod *= ranges_[positions[k]];
|
||||
}
|
||||
newParams[new_li] = dist_->params[i];
|
||||
}
|
||||
varids_ = newVarIds;
|
||||
ranges_ = newRanges;
|
||||
dist_->params = newParams;
|
||||
TFactor<VarId>::multiply (g);
|
||||
cout << "Factor mult called" << endl;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Factor::absorveEvidence (VarId vid, unsigned evidence)
|
||||
Factor::reorderAccordingVarIds (void)
|
||||
{
|
||||
int idx = indexOf (vid);
|
||||
assert (idx != -1);
|
||||
|
||||
Params oldParams = dist_->params;
|
||||
dist_->params.clear();
|
||||
dist_->params.reserve (oldParams.size() / ranges_[idx]);
|
||||
StatesIndexer indexer (ranges_);
|
||||
for (unsigned i = 0; i < evidence; i++) {
|
||||
indexer.increment (idx);
|
||||
}
|
||||
while (indexer.valid()) {
|
||||
dist_->params.push_back (oldParams[indexer]);
|
||||
indexer.incrementExcluding (idx);
|
||||
}
|
||||
varids_.erase (varids_.begin() + idx);
|
||||
ranges_.erase (ranges_.begin() + idx);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Factor::normalize (void)
|
||||
{
|
||||
Util::normalize (dist_->params);
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
Factor::contains (const VarIds& vars) const
|
||||
{
|
||||
for (unsigned i = 0; i < vars.size(); i++) {
|
||||
if (indexOf (vars[i]) == -1) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
|
||||
int
|
||||
Factor::indexOf (VarId vid) const
|
||||
{
|
||||
for (unsigned i = 0; i < varids_.size(); i++) {
|
||||
if (varids_[i] == vid) {
|
||||
return i;
|
||||
}
|
||||
}
|
||||
return -1;
|
||||
VarIds sortedVarIds = args_;
|
||||
sort (sortedVarIds.begin(), sortedVarIds.end());
|
||||
reorderArguments (sortedVarIds);
|
||||
}
|
||||
|
||||
|
||||
@@ -486,9 +265,9 @@ Factor::getLabel (void) const
|
||||
{
|
||||
stringstream ss;
|
||||
ss << "f(" ;
|
||||
for (unsigned i = 0; i < varids_.size(); i++) {
|
||||
for (unsigned i = 0; i < args_.size(); i++) {
|
||||
if (i != 0) ss << "," ;
|
||||
ss << VarNode (varids_[i], ranges_[i]).label();
|
||||
ss << VarNode (args_[i], ranges_[i]).label();
|
||||
}
|
||||
ss << ")" ;
|
||||
return ss.str();
|
||||
@@ -500,13 +279,13 @@ void
|
||||
Factor::print (void) const
|
||||
{
|
||||
VarNodes vars;
|
||||
for (unsigned i = 0; i < varids_.size(); i++) {
|
||||
vars.push_back (new VarNode (varids_[i], ranges_[i]));
|
||||
for (unsigned i = 0; i < args_.size(); i++) {
|
||||
vars.push_back (new VarNode (args_[i], ranges_[i]));
|
||||
}
|
||||
vector<string> jointStrings = Util::getJointStateStrings (vars);
|
||||
for (unsigned i = 0; i < dist_->params.size(); i++) {
|
||||
for (unsigned i = 0; i < params_.size(); i++) {
|
||||
cout << "f(" << jointStrings[i] << ")" ;
|
||||
cout << " = " << dist_->params[i] << endl;
|
||||
cout << " = " << params_[i] << endl;
|
||||
}
|
||||
cout << endl;
|
||||
for (unsigned i = 0; i < vars.size(); i++) {
|
||||
@@ -515,3 +294,13 @@ Factor::print (void) const
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Factor::copyFromFactor (const Factor& g)
|
||||
{
|
||||
args_ = g.arguments();
|
||||
ranges_ = g.ranges();
|
||||
params_ = g.params();
|
||||
distId_ = g.distId();
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user