284 lines
		
	
	
		
			7.5 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			284 lines
		
	
	
		
			7.5 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
#include <vector>
 | 
						|
#include <queue>
 | 
						|
#include <iostream>
 | 
						|
 | 
						|
#include "LiftedOperations.h"
 | 
						|
 | 
						|
 | 
						|
namespace Horus {
 | 
						|
 | 
						|
namespace LiftedOperations {
 | 
						|
 | 
						|
namespace {
 | 
						|
 | 
						|
Parfactors absorve (ObservedFormula& obsFormula, Parfactor* g);
 | 
						|
 | 
						|
}
 | 
						|
 | 
						|
void
 | 
						|
shatterAgainstQuery (ParfactorList& pfList, const Grounds& query)
 | 
						|
{
 | 
						|
  for (size_t i = 0; i < query.size(); i++) {
 | 
						|
    if (query[i].isAtom()) {
 | 
						|
      continue;
 | 
						|
    }
 | 
						|
    bool found = false;
 | 
						|
    Parfactors newPfs;
 | 
						|
    ParfactorList::iterator it = pfList.begin();
 | 
						|
    while (it != pfList.end()) {
 | 
						|
      if ((*it)->containsGround (query[i])) {
 | 
						|
        found = true;
 | 
						|
        std::pair<ConstraintTree*, ConstraintTree*> split;
 | 
						|
        LogVars queryLvs (
 | 
						|
            (*it)->constr()->logVars().begin(),
 | 
						|
            (*it)->constr()->logVars().begin() + query[i].arity());
 | 
						|
        split = (*it)->constr()->split (query[i].args());
 | 
						|
        ConstraintTree* commCt = split.first;
 | 
						|
        ConstraintTree* exclCt = split.second;
 | 
						|
        newPfs.push_back (new Parfactor (*it, commCt));
 | 
						|
        if (exclCt->empty() == false) {
 | 
						|
          newPfs.push_back (new Parfactor (*it, exclCt));
 | 
						|
        } else {
 | 
						|
          delete exclCt;
 | 
						|
        }
 | 
						|
        it = pfList.removeAndDelete (it);
 | 
						|
      } else {
 | 
						|
        ++ it;
 | 
						|
      }
 | 
						|
    }
 | 
						|
    if (found == false) {
 | 
						|
      std::cerr << "Error: could not find a parfactor with ground " ;
 | 
						|
      std::cerr << "`" << query[i] << "'." << std::endl;
 | 
						|
      exit (EXIT_FAILURE);
 | 
						|
    }
 | 
						|
    pfList.add (newPfs);
 | 
						|
  }
 | 
						|
  if (Globals::verbosity > 2) {
 | 
						|
    Util::printAsteriskLine();
 | 
						|
    std::cout << "SHATTERED AGAINST THE QUERY" << std::endl;
 | 
						|
    for (size_t i = 0; i < query.size(); i++) {
 | 
						|
      std::cout << " -> " << query[i] << std::endl;
 | 
						|
    }
 | 
						|
    Util::printAsteriskLine();
 | 
						|
    pfList.print();
 | 
						|
  }
 | 
						|
}
 | 
						|
 | 
						|
 | 
						|
 | 
						|
void
 | 
						|
runWeakBayesBall (ParfactorList& pfList, const Grounds& query)
 | 
						|
{
 | 
						|
  std::queue<PrvGroup> todo; // groups to process
 | 
						|
  std::set<PrvGroup> done;   // processed or in queue
 | 
						|
  for (size_t i = 0; i < query.size(); i++) {
 | 
						|
    ParfactorList::iterator it = pfList.begin();
 | 
						|
    while (it != pfList.end()) {
 | 
						|
      PrvGroup group = (*it)->findGroup (query[i]);
 | 
						|
      if (group != std::numeric_limits<PrvGroup>::max()) {
 | 
						|
        todo.push (group);
 | 
						|
        done.insert (group);
 | 
						|
        break;
 | 
						|
      }
 | 
						|
      ++ it;
 | 
						|
    }
 | 
						|
  }
 | 
						|
 | 
						|
  std::set<Parfactor*> requiredPfs;
 | 
						|
  while (todo.empty() == false) {
 | 
						|
    PrvGroup group = todo.front();
 | 
						|
    ParfactorList::iterator it = pfList.begin();
 | 
						|
    while (it != pfList.end()) {
 | 
						|
      if (Util::contains (requiredPfs, *it) == false &&
 | 
						|
          (*it)->containsGroup (group)) {
 | 
						|
        std::vector<PrvGroup> groups = (*it)->getAllGroups();
 | 
						|
        for (size_t i = 0; i < groups.size(); i++) {
 | 
						|
          if (Util::contains (done, groups[i]) == false) {
 | 
						|
            todo.push (groups[i]);
 | 
						|
            done.insert (groups[i]);
 | 
						|
          }
 | 
						|
        }
 | 
						|
        requiredPfs.insert (*it);
 | 
						|
      }
 | 
						|
      ++ it;
 | 
						|
    }
 | 
						|
    todo.pop();
 | 
						|
  }
 | 
						|
 | 
						|
  ParfactorList::iterator it = pfList.begin();
 | 
						|
  bool foundNotRequired = false;
 | 
						|
  while (it != pfList.end()) {
 | 
						|
    if (Util::contains (requiredPfs, *it) == false) {
 | 
						|
      if (Globals::verbosity > 2) {
 | 
						|
        if (foundNotRequired == false) {
 | 
						|
          Util::printHeader ("PARFACTORS TO DISCARD");
 | 
						|
          foundNotRequired = true;
 | 
						|
        }
 | 
						|
        (*it)->print();
 | 
						|
      }
 | 
						|
      it = pfList.removeAndDelete (it);
 | 
						|
    } else {
 | 
						|
      ++ it;
 | 
						|
    }
 | 
						|
  }
 | 
						|
}
 | 
						|
 | 
						|
 | 
						|
 | 
						|
void
 | 
						|
absorveEvidence (ParfactorList& pfList, ObservedFormulas& obsFormulas)
 | 
						|
{
 | 
						|
  for (size_t i = 0; i < obsFormulas.size(); i++) {
 | 
						|
    Parfactors newPfs;
 | 
						|
    ParfactorList::iterator it  = pfList.begin();
 | 
						|
    while (it != pfList.end()) {
 | 
						|
      Parfactor* pf = *it;
 | 
						|
      it = pfList.remove (it);
 | 
						|
      Parfactors absorvedPfs = absorve (obsFormulas[i], pf);
 | 
						|
      if (absorvedPfs.empty() == false) {
 | 
						|
        if (absorvedPfs.size() == 1 && !absorvedPfs[0]) {
 | 
						|
          // just remove pf;
 | 
						|
        } else {
 | 
						|
          Util::addToVector (newPfs, absorvedPfs);
 | 
						|
        }
 | 
						|
        delete pf;
 | 
						|
      } else {
 | 
						|
        it = pfList.insertShattered (it, pf);
 | 
						|
        ++ it;
 | 
						|
      }
 | 
						|
    }
 | 
						|
    pfList.add (newPfs);
 | 
						|
  }
 | 
						|
  if (Globals::verbosity > 2 && obsFormulas.empty() == false) {
 | 
						|
    Util::printAsteriskLine();
 | 
						|
    std::cout << "AFTER EVIDENCE ABSORVED" << std::endl;
 | 
						|
    for (size_t i = 0; i < obsFormulas.size(); i++) {
 | 
						|
      std::cout << " -> " << obsFormulas[i] << std::endl;
 | 
						|
    }
 | 
						|
    Util::printAsteriskLine();
 | 
						|
    pfList.print();
 | 
						|
  }
 | 
						|
}
 | 
						|
 | 
						|
 | 
						|
 | 
						|
Parfactors
 | 
						|
countNormalize (Parfactor* g, const LogVarSet& set)
 | 
						|
{
 | 
						|
  Parfactors normPfs;
 | 
						|
  if (set.empty()) {
 | 
						|
    normPfs.push_back (new Parfactor (*g));
 | 
						|
  } else {
 | 
						|
    ConstraintTrees normCts = g->constr()->countNormalize (set);
 | 
						|
    for (size_t i = 0; i < normCts.size(); i++) {
 | 
						|
      normPfs.push_back (new Parfactor (g, normCts[i]));
 | 
						|
    }
 | 
						|
  }
 | 
						|
  return normPfs;
 | 
						|
}
 | 
						|
 | 
						|
 | 
						|
 | 
						|
Parfactor
 | 
						|
calcGroundMultiplication (Parfactor pf)
 | 
						|
{
 | 
						|
  LogVarSet lvs = pf.constr()->logVarSet();
 | 
						|
  lvs -= pf.constr()->singletons();
 | 
						|
  Parfactors newPfs = {new Parfactor (pf)};
 | 
						|
  for (size_t i = 0; i < lvs.size(); i++) {
 | 
						|
    Parfactors pfs = newPfs;
 | 
						|
    newPfs.clear();
 | 
						|
    for (size_t j = 0; j < pfs.size(); j++) {
 | 
						|
      bool countedLv = pfs[j]->countedLogVars().contains (lvs[i]);
 | 
						|
      if (countedLv) {
 | 
						|
        pfs[j]->fullExpand (lvs[i]);
 | 
						|
        newPfs.push_back (pfs[j]);
 | 
						|
      } else {
 | 
						|
        ConstraintTrees cts = pfs[j]->constr()->ground (lvs[i]);
 | 
						|
        for (size_t k = 0; k < cts.size(); k++) {
 | 
						|
          newPfs.push_back (new Parfactor (pfs[j], cts[k]));
 | 
						|
        }
 | 
						|
        delete pfs[j];
 | 
						|
      }
 | 
						|
    }
 | 
						|
  }
 | 
						|
  ParfactorList pfList (newPfs);
 | 
						|
  Parfactors groundShatteredPfs (pfList.begin(),pfList.end());
 | 
						|
  for (size_t i = 1; i < groundShatteredPfs.size(); i++) {
 | 
						|
     groundShatteredPfs[0]->multiply (*groundShatteredPfs[i]);
 | 
						|
  }
 | 
						|
  return Parfactor (*groundShatteredPfs[0]);
 | 
						|
}
 | 
						|
 | 
						|
 | 
						|
 | 
						|
namespace {
 | 
						|
 | 
						|
Parfactors
 | 
						|
absorve (ObservedFormula& obsFormula, Parfactor* g)
 | 
						|
{
 | 
						|
  Parfactors absorvedPfs;
 | 
						|
  const ProbFormulas& formulas = g->arguments();
 | 
						|
  for (size_t i = 0; i < formulas.size(); i++) {
 | 
						|
    if (obsFormula.functor() == formulas[i].functor() &&
 | 
						|
        obsFormula.arity()   == formulas[i].arity()) {
 | 
						|
 | 
						|
      if (obsFormula.isAtom()) {
 | 
						|
        if (formulas.size() > 1) {
 | 
						|
          g->absorveEvidence (formulas[i], obsFormula.evidence());
 | 
						|
        } else {
 | 
						|
          // hack to erase parfactor g
 | 
						|
          absorvedPfs.push_back (0);
 | 
						|
        }
 | 
						|
        break;
 | 
						|
      }
 | 
						|
 | 
						|
      g->constr()->moveToTop (formulas[i].logVars());
 | 
						|
      std::pair<ConstraintTree*, ConstraintTree*> res;
 | 
						|
      res = g->constr()->split (
 | 
						|
          formulas[i].logVars(),
 | 
						|
          &(obsFormula.constr()),
 | 
						|
          obsFormula.constr().logVars());
 | 
						|
      ConstraintTree* commCt = res.first;
 | 
						|
      ConstraintTree* exclCt = res.second;
 | 
						|
 | 
						|
      if (commCt->empty() == false) {
 | 
						|
        if (formulas.size() > 1) {
 | 
						|
          LogVarSet excl = g->exclusiveLogVars (i);
 | 
						|
          Parfactor tempPf (g, commCt);
 | 
						|
          Parfactors countNormPfs = LiftedOperations::countNormalize (
 | 
						|
              &tempPf, excl);
 | 
						|
          for (size_t j = 0; j < countNormPfs.size(); j++) {
 | 
						|
            countNormPfs[j]->absorveEvidence (
 | 
						|
                formulas[i], obsFormula.evidence());
 | 
						|
            absorvedPfs.push_back (countNormPfs[j]);
 | 
						|
          }
 | 
						|
        } else {
 | 
						|
          delete commCt;
 | 
						|
        }
 | 
						|
        if (exclCt->empty() == false) {
 | 
						|
          absorvedPfs.push_back (new Parfactor (g, exclCt));
 | 
						|
        } else {
 | 
						|
          delete exclCt;
 | 
						|
        }
 | 
						|
        if (absorvedPfs.empty()) {
 | 
						|
          // hack to erase parfactor g
 | 
						|
          absorvedPfs.push_back (0);
 | 
						|
        }
 | 
						|
        break;
 | 
						|
      } else {
 | 
						|
        delete commCt;
 | 
						|
        delete exclCt;
 | 
						|
      }
 | 
						|
    }
 | 
						|
  }
 | 
						|
  return absorvedPfs;
 | 
						|
}
 | 
						|
 | 
						|
}
 | 
						|
 | 
						|
}  // namespace LiftedOperations
 | 
						|
 | 
						|
}  // namespace Horus
 | 
						|
 |