Add a way to summout a dimension given a index instead of a variable id.

This is required for counting belief propagation.
This commit is contained in:
Tiago Gomes 2012-04-25 22:59:01 +01:00
parent fc362fe123
commit ad24a360ce
2 changed files with 49 additions and 16 deletions

View File

@ -49,6 +49,24 @@ Factor::Factor (
void
Factor::sumOut (VarId vid)
{
int idx = indexOf (vid);
assert (idx != -1);
if (vid == args_.back()) {
sumOutLastVariable(); // optimization
return;
}
if (vid == args_.front()) {
sumOutFirstVariable(); // optimization
return;
}
sumOutIndex (idx);
}
void void
Factor::sumOutAllExcept (VarId vid) Factor::sumOutAllExcept (VarId vid)
{ {
@ -77,20 +95,8 @@ Factor::sumOutAllExcept (const VarIds& vids)
void void
Factor::sumOut (VarId vid) Factor::sumOutIndex (unsigned idx)
{ {
int idx = indexOf (vid);
assert (idx != -1);
if (vid == args_.back()) {
sumOutLastVariable(); // optimization
return;
}
if (vid == args_.front()) {
sumOutFirstVariable(); // optimization
return;
}
// number of parameters separating a different state of `var', // number of parameters separating a different state of `var',
// with the states of the remaining variables fixed // with the states of the remaining variables fixed
unsigned varOffset = 1; unsigned varOffset = 1;
@ -125,7 +131,7 @@ Factor::sumOut (VarId vid)
} }
newps.push_back (sum); newps.push_back (sum);
count1 ++; count1 ++;
if (idx == (int)args_.size() - 1) { if (idx == args_.size() - 1) {
offset = count1 * ranges_[idx]; offset = count1 * ranges_[idx];
} else { } else {
if (((offset - varOffset + 1) % leftVarOffset) == 0) { if (((offset - varOffset + 1) % leftVarOffset) == 0) {
@ -142,6 +148,21 @@ Factor::sumOut (VarId vid)
void
Factor::sumOutAllExceptIndex (unsigned idx)
{
int i = idx;
while (args_.size() > i + 1) {
sumOutLastVariable();
}
while (i > 0) {
sumOutFirstVariable();
i -- ;
}
}
void void
Factor::sumOutFirstVariable (void) Factor::sumOutFirstVariable (void)
{ {
@ -243,7 +264,8 @@ Factor::print (void) const
} }
vector<string> jointStrings = Util::getStateLines (vars); vector<string> jointStrings = Util::getStateLines (vars);
for (unsigned i = 0; i < params_.size(); i++) { for (unsigned i = 0; i < params_.size(); i++) {
cout << "[" << distId_ << "] f(" << jointStrings[i] << ")" ; // cout << "[" << distId_ << "] " ;
cout << "f(" << jointStrings[i] << ")" ;
cout << " = " << params_[i] << endl; cout << " = " << params_[i] << endl;
} }
cout << endl; cout << endl;

View File

@ -204,6 +204,13 @@ class TFactor
return true; return true;
} }
double& operator[] (unsigned idx)
{
assert (idx < params_.size());
return params_[idx];
}
protected: protected:
vector<T> args_; vector<T> args_;
Ranges ranges_; Ranges ranges_;
@ -261,11 +268,15 @@ class Factor : public TFactor<VarId>
Factor (const Vars&, const Params&, Factor (const Vars&, const Params&,
unsigned = Util::maxUnsigned()); unsigned = Util::maxUnsigned());
void sumOut (VarId);
void sumOutAllExcept (VarId); void sumOutAllExcept (VarId);
void sumOutAllExcept (const VarIds&); void sumOutAllExcept (const VarIds&);
void sumOut (VarId); void sumOutIndex (unsigned idx);
void sumOutAllExceptIndex (unsigned idx);
void sumOutFirstVariable (void); void sumOutFirstVariable (void);