improvements in the way we align logical variables

This commit is contained in:
Tiago Gomes 2012-04-26 18:00:06 +01:00
parent cc09e77707
commit 995a11be83

View File

@ -660,8 +660,8 @@ Parfactor::alignAndExponentiate (Parfactor* g1, Parfactor* g2)
unsigned condCount2 = g2->constr()->getConditionalCount (Y_2); unsigned condCount2 = g2->constr()->getConditionalCount (Y_2);
LogAware::pow (g1->params(), 1.0 / condCount2); LogAware::pow (g1->params(), 1.0 / condCount2);
LogAware::pow (g2->params(), 1.0 / condCount1); LogAware::pow (g2->params(), 1.0 / condCount1);
// this must be done in the end or else X_1 and X_2 // the alignment should be done in the end or else X_1 and X_2
// will refer the old log var names in the code above // will refer to the old log var names on the code above
align (g1, X_1, g2, X_2); align (g1, X_1, g2, X_2);
} }
@ -673,24 +673,39 @@ Parfactor::align (
Parfactor* g2, const LogVars& alignLvs2) Parfactor* g2, const LogVars& alignLvs2)
{ {
LogVar freeLogVar = 0; LogVar freeLogVar = 0;
Substitution theta1; Substitution theta1, theta2;
Substitution theta2; for (unsigned i = 0; i < alignLvs1.size(); i++) {
bool b1 = theta1.containsReplacementFor (alignLvs1[i]);
bool b2 = theta2.containsReplacementFor (alignLvs2[i]);
// handle this type of situation:
// g1 = p(X), q(X) ; X in {x1}
// g2 = p(X), q(Y) ; X in {x1}, Y in {x1}
if (b1 == false && b2 == false) {
theta1.add (alignLvs1[i], freeLogVar);
theta2.add (alignLvs2[i], freeLogVar);
++ freeLogVar;
} else if (b1 == false && b2) {
theta1.add (alignLvs1[i], theta2.newNameFor (alignLvs2[i]));
} else if (b1 && b2 == false) {
theta2.add (alignLvs2[i], theta1.newNameFor (alignLvs1[i]));
}
}
const LogVarSet& allLvs1 = g1->logVarSet(); const LogVarSet& allLvs1 = g1->logVarSet();
for (unsigned i = 0; i < allLvs1.size(); i++) { for (unsigned i = 0; i < allLvs1.size(); i++) {
theta1.add (allLvs1[i], freeLogVar); if (theta1.containsReplacementFor (allLvs1[i]) == false) {
++ freeLogVar; theta1.add (allLvs1[i], freeLogVar);
++ freeLogVar;
}
} }
const LogVarSet& allLvs2 = g2->logVarSet(); const LogVarSet& allLvs2 = g2->logVarSet();
for (unsigned i = 0; i < allLvs2.size(); i++) { for (unsigned i = 0; i < allLvs2.size(); i++) {
theta2.add (allLvs2[i], freeLogVar); if (theta2.containsReplacementFor (allLvs2[i]) == false) {
++ freeLogVar; theta2.add (allLvs2[i], freeLogVar);
} ++ freeLogVar;
}
assert (alignLvs1.size() == alignLvs2.size());
for (unsigned i = 0; i < alignLvs1.size(); i++) {
theta1.rename (alignLvs1[i], theta2.newNameFor (alignLvs2[i]));
} }
// cout << "theta1: " << theta1 << endl;
// cout << "theta2: " << theta2 << endl;
g1->applySubstitution (theta1); g1->applySubstitution (theta1);
g2->applySubstitution (theta2); g2->applySubstitution (theta2);
} }