Add a way to summout a dimension given a index instead of a variable id.
This is required for counting belief propagation.
This commit is contained in:
parent
fc362fe123
commit
ad24a360ce
@ -49,6 +49,24 @@ Factor::Factor (
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
Factor::sumOut (VarId vid)
|
||||||
|
{
|
||||||
|
int idx = indexOf (vid);
|
||||||
|
assert (idx != -1);
|
||||||
|
if (vid == args_.back()) {
|
||||||
|
sumOutLastVariable(); // optimization
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (vid == args_.front()) {
|
||||||
|
sumOutFirstVariable(); // optimization
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
sumOutIndex (idx);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
Factor::sumOutAllExcept (VarId vid)
|
Factor::sumOutAllExcept (VarId vid)
|
||||||
{
|
{
|
||||||
@ -77,20 +95,8 @@ Factor::sumOutAllExcept (const VarIds& vids)
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
Factor::sumOut (VarId vid)
|
Factor::sumOutIndex (unsigned idx)
|
||||||
{
|
{
|
||||||
int idx = indexOf (vid);
|
|
||||||
assert (idx != -1);
|
|
||||||
|
|
||||||
if (vid == args_.back()) {
|
|
||||||
sumOutLastVariable(); // optimization
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
if (vid == args_.front()) {
|
|
||||||
sumOutFirstVariable(); // optimization
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// number of parameters separating a different state of `var',
|
// number of parameters separating a different state of `var',
|
||||||
// with the states of the remaining variables fixed
|
// with the states of the remaining variables fixed
|
||||||
unsigned varOffset = 1;
|
unsigned varOffset = 1;
|
||||||
@ -125,7 +131,7 @@ Factor::sumOut (VarId vid)
|
|||||||
}
|
}
|
||||||
newps.push_back (sum);
|
newps.push_back (sum);
|
||||||
count1 ++;
|
count1 ++;
|
||||||
if (idx == (int)args_.size() - 1) {
|
if (idx == args_.size() - 1) {
|
||||||
offset = count1 * ranges_[idx];
|
offset = count1 * ranges_[idx];
|
||||||
} else {
|
} else {
|
||||||
if (((offset - varOffset + 1) % leftVarOffset) == 0) {
|
if (((offset - varOffset + 1) % leftVarOffset) == 0) {
|
||||||
@ -142,6 +148,21 @@ Factor::sumOut (VarId vid)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
Factor::sumOutAllExceptIndex (unsigned idx)
|
||||||
|
{
|
||||||
|
int i = idx;
|
||||||
|
while (args_.size() > i + 1) {
|
||||||
|
sumOutLastVariable();
|
||||||
|
}
|
||||||
|
while (i > 0) {
|
||||||
|
sumOutFirstVariable();
|
||||||
|
i -- ;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
Factor::sumOutFirstVariable (void)
|
Factor::sumOutFirstVariable (void)
|
||||||
{
|
{
|
||||||
@ -243,7 +264,8 @@ Factor::print (void) const
|
|||||||
}
|
}
|
||||||
vector<string> jointStrings = Util::getStateLines (vars);
|
vector<string> jointStrings = Util::getStateLines (vars);
|
||||||
for (unsigned i = 0; i < params_.size(); i++) {
|
for (unsigned i = 0; i < params_.size(); i++) {
|
||||||
cout << "[" << distId_ << "] f(" << jointStrings[i] << ")" ;
|
// cout << "[" << distId_ << "] " ;
|
||||||
|
cout << "f(" << jointStrings[i] << ")" ;
|
||||||
cout << " = " << params_[i] << endl;
|
cout << " = " << params_[i] << endl;
|
||||||
}
|
}
|
||||||
cout << endl;
|
cout << endl;
|
||||||
|
@ -204,6 +204,13 @@ class TFactor
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
double& operator[] (unsigned idx)
|
||||||
|
{
|
||||||
|
assert (idx < params_.size());
|
||||||
|
return params_[idx];
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
vector<T> args_;
|
vector<T> args_;
|
||||||
Ranges ranges_;
|
Ranges ranges_;
|
||||||
@ -261,11 +268,15 @@ class Factor : public TFactor<VarId>
|
|||||||
Factor (const Vars&, const Params&,
|
Factor (const Vars&, const Params&,
|
||||||
unsigned = Util::maxUnsigned());
|
unsigned = Util::maxUnsigned());
|
||||||
|
|
||||||
|
void sumOut (VarId);
|
||||||
|
|
||||||
void sumOutAllExcept (VarId);
|
void sumOutAllExcept (VarId);
|
||||||
|
|
||||||
void sumOutAllExcept (const VarIds&);
|
void sumOutAllExcept (const VarIds&);
|
||||||
|
|
||||||
void sumOut (VarId);
|
void sumOutIndex (unsigned idx);
|
||||||
|
|
||||||
|
void sumOutAllExceptIndex (unsigned idx);
|
||||||
|
|
||||||
void sumOutFirstVariable (void);
|
void sumOutFirstVariable (void);
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user