fix align of logical variables
This commit is contained in:
		@@ -347,6 +347,10 @@ ConstraintTree::rename (LogVar X_old, LogVar X_new)
 | 
			
		||||
void
 | 
			
		||||
ConstraintTree::applySubstitution (const Substitution& theta)
 | 
			
		||||
{
 | 
			
		||||
  LogVars discardedLvs = theta.getDiscardedLogVars();
 | 
			
		||||
  for (unsigned i = 0; i < discardedLvs.size(); i++) {
 | 
			
		||||
    remove(discardedLvs[i]);
 | 
			
		||||
  }
 | 
			
		||||
  for (unsigned i = 0; i < logVars_.size(); i++) {
 | 
			
		||||
    logVars_[i] = theta.newNameFor (logVars_[i]);
 | 
			
		||||
  }
 | 
			
		||||
 
 | 
			
		||||
@@ -245,7 +245,6 @@ Factor::multiply (Factor& g)
 | 
			
		||||
    return;
 | 
			
		||||
  }
 | 
			
		||||
  TFactor<VarId>::multiply (g);
 | 
			
		||||
  cout << "Factor mult called" << endl;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -62,6 +62,12 @@ class TFactor
 | 
			
		||||
      return args_[idx];
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    T& argument (unsigned idx)
 | 
			
		||||
    {
 | 
			
		||||
      assert (idx < args_.size());
 | 
			
		||||
      return args_[idx];
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    unsigned range (unsigned idx) const
 | 
			
		||||
    {
 | 
			
		||||
      assert (idx < ranges_.size());
 | 
			
		||||
 
 | 
			
		||||
@@ -554,7 +554,7 @@ FoveSolver::runWeakBayesBall (const Grounds& query)
 | 
			
		||||
  for (unsigned i = 0; i < query.size(); i++) {
 | 
			
		||||
    ParfactorList::iterator it = pfList_.begin();
 | 
			
		||||
    while (it != pfList_.end()) {
 | 
			
		||||
      int group = (*it)->groupWithGround (query[i]);
 | 
			
		||||
      int group = (*it)->findGroup (query[i]);
 | 
			
		||||
      if (group != -1) {
 | 
			
		||||
        todo.push (group);
 | 
			
		||||
        done.insert (group);
 | 
			
		||||
 
 | 
			
		||||
@@ -452,7 +452,6 @@ setBayesNetParams (void)
 | 
			
		||||
int
 | 
			
		||||
setExtraVarsInfo (void)
 | 
			
		||||
{
 | 
			
		||||
  // BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
 | 
			
		||||
  GraphicalModel::clearVariablesInformation();
 | 
			
		||||
  YAP_Term varsInfoL =  YAP_ARG2;
 | 
			
		||||
  while (varsInfoL != YAP_TermNil()) {
 | 
			
		||||
 
 | 
			
		||||
@@ -95,6 +95,25 @@ ostream& operator<< (ostream &os, const Ground& gr)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
LogVars
 | 
			
		||||
Substitution::getDiscardedLogVars (void) const
 | 
			
		||||
{
 | 
			
		||||
  LogVars discardedLvs;
 | 
			
		||||
  set<LogVar> doneLvs;
 | 
			
		||||
  unordered_map<LogVar, LogVar>::const_iterator it;
 | 
			
		||||
  it = subs_.begin();
 | 
			
		||||
  while (it != subs_.end()) {
 | 
			
		||||
    if (Util::contains (doneLvs, it->second)) {
 | 
			
		||||
      discardedLvs.push_back (it->first);
 | 
			
		||||
    } else {
 | 
			
		||||
      doneLvs.insert (it->second);
 | 
			
		||||
    }
 | 
			
		||||
    it ++;
 | 
			
		||||
  }
 | 
			
		||||
  return discardedLvs;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
ostream& operator<< (ostream &os, const Substitution& theta)
 | 
			
		||||
{
 | 
			
		||||
 
 | 
			
		||||
@@ -141,10 +141,13 @@ class Substitution
 | 
			
		||||
      return subs_.find (X)->second;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    LogVars getDiscardedLogVars (void) const;
 | 
			
		||||
 | 
			
		||||
    friend ostream& operator<< (ostream &os, const Substitution& theta);
 | 
			
		||||
 | 
			
		||||
  private:
 | 
			
		||||
    unordered_map<LogVar, LogVar> subs_;
 | 
			
		||||
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -191,7 +191,6 @@ Parfactor::multiply (Parfactor& g)
 | 
			
		||||
{
 | 
			
		||||
  alignAndExponentiate (this, &g);
 | 
			
		||||
  TFactor<ProbFormula>::multiply (g);
 | 
			
		||||
  cout << "calling lifted mult" << endl;
 | 
			
		||||
  constr_->join (g.constr(), true);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -377,15 +376,6 @@ Parfactor::absorveEvidence (const ProbFormula& formula, unsigned evidence)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
void
 | 
			
		||||
Parfactor::setFormulaGroup (const ProbFormula& f, int group)
 | 
			
		||||
{
 | 
			
		||||
  assert (indexOf (f) != -1);
 | 
			
		||||
  args_[indexOf (f)].setGroup (group);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
void
 | 
			
		||||
Parfactor::setNewGroups (void)
 | 
			
		||||
{
 | 
			
		||||
@@ -415,7 +405,7 @@ Parfactor::applySubstitution (const Substitution& theta)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
int
 | 
			
		||||
Parfactor::groupWithGround (const Ground& ground) const
 | 
			
		||||
Parfactor::findGroup (const Ground& ground) const
 | 
			
		||||
{
 | 
			
		||||
  int group = -1;
 | 
			
		||||
  for (unsigned i = 0; i < args_.size(); i++) {
 | 
			
		||||
@@ -436,7 +426,7 @@ Parfactor::groupWithGround (const Ground& ground) const
 | 
			
		||||
bool
 | 
			
		||||
Parfactor::containsGround (const Ground& ground) const
 | 
			
		||||
{
 | 
			
		||||
  return groupWithGround (ground) != -1;
 | 
			
		||||
  return findGroup (ground) != -1;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@@ -670,7 +660,6 @@ Parfactor::align (
 | 
			
		||||
  LogVar freeLogVar = 0;
 | 
			
		||||
  Substitution theta1;
 | 
			
		||||
  Substitution theta2;
 | 
			
		||||
 | 
			
		||||
  const LogVarSet& allLvs1 = g1->logVarSet();
 | 
			
		||||
  for (unsigned i = 0; i < allLvs1.size(); i++) {
 | 
			
		||||
    theta1.add (allLvs1[i], freeLogVar);
 | 
			
		||||
 
 | 
			
		||||
@@ -60,13 +60,11 @@ class Parfactor : public TFactor<ProbFormula>
 | 
			
		||||
 | 
			
		||||
    void absorveEvidence (const ProbFormula&, unsigned);
 | 
			
		||||
 | 
			
		||||
    void setFormulaGroup (const ProbFormula&, int);
 | 
			
		||||
 | 
			
		||||
    void setNewGroups (void);
 | 
			
		||||
 | 
			
		||||
    void applySubstitution (const Substitution&);
 | 
			
		||||
 | 
			
		||||
    int groupWithGround (const Ground&) const;
 | 
			
		||||
    int findGroup (const Ground&) const;
 | 
			
		||||
 | 
			
		||||
    bool containsGround (const Ground&) const;
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -197,8 +197,8 @@ ParfactorList::shatter (Parfactor* g1, Parfactor* g2)
 | 
			
		||||
  for (unsigned i = 0; i < formulas1.size(); i++) {
 | 
			
		||||
    for (unsigned j = 0; j < formulas2.size(); j++) {
 | 
			
		||||
      if (formulas1[i].sameSkeletonAs (formulas2[j])) {
 | 
			
		||||
        std::pair<Parfactors, Parfactors> res 
 | 
			
		||||
            = shatter (formulas1[i], g1, formulas2[j], g2);
 | 
			
		||||
        std::pair<Parfactors, Parfactors> res;
 | 
			
		||||
        res = shatter (i, g1, j, g2);
 | 
			
		||||
        if (res.first.empty()  == false || 
 | 
			
		||||
            res.second.empty() == false) {
 | 
			
		||||
          return res;
 | 
			
		||||
@@ -213,9 +213,11 @@ ParfactorList::shatter (Parfactor* g1, Parfactor* g2)
 | 
			
		||||
 | 
			
		||||
std::pair<Parfactors, Parfactors>
 | 
			
		||||
ParfactorList::shatter (
 | 
			
		||||
    ProbFormula& f1, Parfactor* g1,
 | 
			
		||||
    ProbFormula& f2, Parfactor* g2)
 | 
			
		||||
    unsigned fIdx1, Parfactor* g1,
 | 
			
		||||
    unsigned fIdx2, Parfactor* g2)
 | 
			
		||||
{
 | 
			
		||||
  ProbFormula& f1 = g1->argument (fIdx1);
 | 
			
		||||
  ProbFormula& f2 = g2->argument (fIdx2);
 | 
			
		||||
  // cout << endl;
 | 
			
		||||
  // Util::printDashLine();
 | 
			
		||||
  // cout << "-> SHATTERING (#" << g1 << ", #" << g2 << ")" << endl;
 | 
			
		||||
@@ -299,8 +301,8 @@ ParfactorList::shatter (
 | 
			
		||||
  } else {
 | 
			
		||||
    group = ProbFormula::getNewGroup();
 | 
			
		||||
  }
 | 
			
		||||
  Parfactors res1 = shatter (g1, f1, commCt1, exclCt1, group);
 | 
			
		||||
  Parfactors res2 = shatter (g2, f2, commCt2, exclCt2, group);
 | 
			
		||||
  Parfactors res1 = shatter (g1, fIdx1, commCt1, exclCt1, group);
 | 
			
		||||
  Parfactors res2 = shatter (g2, fIdx2, commCt2, exclCt2, group);
 | 
			
		||||
  return make_pair (res1, res2);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -309,15 +311,16 @@ ParfactorList::shatter (
 | 
			
		||||
Parfactors
 | 
			
		||||
ParfactorList::shatter (
 | 
			
		||||
    Parfactor* g,
 | 
			
		||||
    const ProbFormula& f,
 | 
			
		||||
    unsigned fIdx,
 | 
			
		||||
    ConstraintTree* commCt,
 | 
			
		||||
    ConstraintTree* exclCt,
 | 
			
		||||
    unsigned commGroup)
 | 
			
		||||
{
 | 
			
		||||
  ProbFormula& f = g->argument (fIdx);
 | 
			
		||||
  if (exclCt->empty()) {
 | 
			
		||||
    delete commCt;
 | 
			
		||||
    delete exclCt;
 | 
			
		||||
    g->setFormulaGroup (f, commGroup);
 | 
			
		||||
    f.setGroup (commGroup);
 | 
			
		||||
    return { };
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
@@ -346,7 +349,7 @@ ParfactorList::shatter (
 | 
			
		||||
  } else {
 | 
			
		||||
    Parfactor* newPf = new Parfactor (g, commCt);
 | 
			
		||||
    newPf->setNewGroups();
 | 
			
		||||
    newPf->setFormulaGroup (f, commGroup);
 | 
			
		||||
    newPf->argument (fIdx).setGroup (commGroup);
 | 
			
		||||
    result.push_back (newPf);
 | 
			
		||||
    newPf = new Parfactor (g, exclCt);
 | 
			
		||||
    newPf->setNewGroups();
 | 
			
		||||
 
 | 
			
		||||
@@ -67,11 +67,11 @@ class ParfactorList
 | 
			
		||||
        Parfactor*, Parfactor*);
 | 
			
		||||
 | 
			
		||||
    std::pair<Parfactors, Parfactors> shatter (
 | 
			
		||||
        ProbFormula&, Parfactor*, ProbFormula&, Parfactor*);
 | 
			
		||||
        unsigned, Parfactor*, unsigned, Parfactor*);
 | 
			
		||||
 | 
			
		||||
    Parfactors shatter (
 | 
			
		||||
        Parfactor*,
 | 
			
		||||
        const ProbFormula&,
 | 
			
		||||
        unsigned,
 | 
			
		||||
        ConstraintTree*,
 | 
			
		||||
        ConstraintTree*,
 | 
			
		||||
        unsigned);
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user