diff --git a/packages/CLPBN/clpbn/bp/Parfactor.cpp b/packages/CLPBN/clpbn/bp/Parfactor.cpp index 60e7c1648..8d49136b7 100644 --- a/packages/CLPBN/clpbn/bp/Parfactor.cpp +++ b/packages/CLPBN/clpbn/bp/Parfactor.cpp @@ -660,8 +660,8 @@ Parfactor::alignAndExponentiate (Parfactor* g1, Parfactor* g2) unsigned condCount2 = g2->constr()->getConditionalCount (Y_2); LogAware::pow (g1->params(), 1.0 / condCount2); LogAware::pow (g2->params(), 1.0 / condCount1); - // this must be done in the end or else X_1 and X_2 - // will refer the old log var names in the code above + // the alignment should be done in the end or else X_1 and X_2 + // will refer to the old log var names on the code above align (g1, X_1, g2, X_2); } @@ -673,24 +673,39 @@ Parfactor::align ( Parfactor* g2, const LogVars& alignLvs2) { LogVar freeLogVar = 0; - Substitution theta1; - Substitution theta2; + Substitution theta1, theta2; + for (unsigned i = 0; i < alignLvs1.size(); i++) { + bool b1 = theta1.containsReplacementFor (alignLvs1[i]); + bool b2 = theta2.containsReplacementFor (alignLvs2[i]); + // handle this type of situation: + // g1 = p(X), q(X) ; X in {x1} + // g2 = p(X), q(Y) ; X in {x1}, Y in {x1} + if (b1 == false && b2 == false) { + theta1.add (alignLvs1[i], freeLogVar); + theta2.add (alignLvs2[i], freeLogVar); + ++ freeLogVar; + } else if (b1 == false && b2) { + theta1.add (alignLvs1[i], theta2.newNameFor (alignLvs2[i])); + } else if (b1 && b2 == false) { + theta2.add (alignLvs2[i], theta1.newNameFor (alignLvs1[i])); + } + } const LogVarSet& allLvs1 = g1->logVarSet(); for (unsigned i = 0; i < allLvs1.size(); i++) { - theta1.add (allLvs1[i], freeLogVar); - ++ freeLogVar; + if (theta1.containsReplacementFor (allLvs1[i]) == false) { + theta1.add (allLvs1[i], freeLogVar); + ++ freeLogVar; + } } - const LogVarSet& allLvs2 = g2->logVarSet(); for (unsigned i = 0; i < allLvs2.size(); i++) { - theta2.add (allLvs2[i], freeLogVar); - ++ freeLogVar; - } - - assert (alignLvs1.size() == alignLvs2.size()); - for (unsigned i = 0; i < alignLvs1.size(); i++) { - theta1.rename (alignLvs1[i], theta2.newNameFor (alignLvs2[i])); + if (theta2.containsReplacementFor (allLvs2[i]) == false) { + theta2.add (allLvs2[i], freeLogVar); + ++ freeLogVar; + } } + // cout << "theta1: " << theta1 << endl; + // cout << "theta2: " << theta2 << endl; g1->applySubstitution (theta1); g2->applySubstitution (theta2); }