Add a way to summout a dimension given a index instead of a variable id.
This is required for counting belief propagation.
This commit is contained in:
parent
fc362fe123
commit
ad24a360ce
@ -49,6 +49,24 @@ Factor::Factor (
|
||||
|
||||
|
||||
|
||||
void
|
||||
Factor::sumOut (VarId vid)
|
||||
{
|
||||
int idx = indexOf (vid);
|
||||
assert (idx != -1);
|
||||
if (vid == args_.back()) {
|
||||
sumOutLastVariable(); // optimization
|
||||
return;
|
||||
}
|
||||
if (vid == args_.front()) {
|
||||
sumOutFirstVariable(); // optimization
|
||||
return;
|
||||
}
|
||||
sumOutIndex (idx);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Factor::sumOutAllExcept (VarId vid)
|
||||
{
|
||||
@ -77,20 +95,8 @@ Factor::sumOutAllExcept (const VarIds& vids)
|
||||
|
||||
|
||||
void
|
||||
Factor::sumOut (VarId vid)
|
||||
Factor::sumOutIndex (unsigned idx)
|
||||
{
|
||||
int idx = indexOf (vid);
|
||||
assert (idx != -1);
|
||||
|
||||
if (vid == args_.back()) {
|
||||
sumOutLastVariable(); // optimization
|
||||
return;
|
||||
}
|
||||
if (vid == args_.front()) {
|
||||
sumOutFirstVariable(); // optimization
|
||||
return;
|
||||
}
|
||||
|
||||
// number of parameters separating a different state of `var',
|
||||
// with the states of the remaining variables fixed
|
||||
unsigned varOffset = 1;
|
||||
@ -125,7 +131,7 @@ Factor::sumOut (VarId vid)
|
||||
}
|
||||
newps.push_back (sum);
|
||||
count1 ++;
|
||||
if (idx == (int)args_.size() - 1) {
|
||||
if (idx == args_.size() - 1) {
|
||||
offset = count1 * ranges_[idx];
|
||||
} else {
|
||||
if (((offset - varOffset + 1) % leftVarOffset) == 0) {
|
||||
@ -142,6 +148,21 @@ Factor::sumOut (VarId vid)
|
||||
|
||||
|
||||
|
||||
void
|
||||
Factor::sumOutAllExceptIndex (unsigned idx)
|
||||
{
|
||||
int i = idx;
|
||||
while (args_.size() > i + 1) {
|
||||
sumOutLastVariable();
|
||||
}
|
||||
while (i > 0) {
|
||||
sumOutFirstVariable();
|
||||
i -- ;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Factor::sumOutFirstVariable (void)
|
||||
{
|
||||
@ -243,7 +264,8 @@ Factor::print (void) const
|
||||
}
|
||||
vector<string> jointStrings = Util::getStateLines (vars);
|
||||
for (unsigned i = 0; i < params_.size(); i++) {
|
||||
cout << "[" << distId_ << "] f(" << jointStrings[i] << ")" ;
|
||||
// cout << "[" << distId_ << "] " ;
|
||||
cout << "f(" << jointStrings[i] << ")" ;
|
||||
cout << " = " << params_[i] << endl;
|
||||
}
|
||||
cout << endl;
|
||||
|
@ -204,6 +204,13 @@ class TFactor
|
||||
return true;
|
||||
}
|
||||
|
||||
double& operator[] (unsigned idx)
|
||||
{
|
||||
assert (idx < params_.size());
|
||||
return params_[idx];
|
||||
}
|
||||
|
||||
|
||||
protected:
|
||||
vector<T> args_;
|
||||
Ranges ranges_;
|
||||
@ -261,11 +268,15 @@ class Factor : public TFactor<VarId>
|
||||
Factor (const Vars&, const Params&,
|
||||
unsigned = Util::maxUnsigned());
|
||||
|
||||
void sumOut (VarId);
|
||||
|
||||
void sumOutAllExcept (VarId);
|
||||
|
||||
void sumOutAllExcept (const VarIds&);
|
||||
|
||||
void sumOut (VarId);
|
||||
void sumOutIndex (unsigned idx);
|
||||
|
||||
void sumOutAllExceptIndex (unsigned idx);
|
||||
|
||||
void sumOutFirstVariable (void);
|
||||
|
||||
|
Reference in New Issue
Block a user