From 57339760b9615eda045303816f47e74ac2bd037c Mon Sep 17 00:00:00 2001 From: Tiago Gomes Date: Thu, 20 Dec 2012 21:11:51 +0000 Subject: [PATCH] Merge LiftedKc and LiftedCircuit in one file --- packages/CLPBN/horus/LiftedCircuit.cpp | 1232 ----------------------- packages/CLPBN/horus/LiftedCircuit.h | 279 ------ packages/CLPBN/horus/LiftedKc.cpp | 1234 +++++++++++++++++++++++- packages/CLPBN/horus/LiftedKc.h | 274 +++++- packages/CLPBN/horus/Makefile.in | 7 +- 5 files changed, 1506 insertions(+), 1520 deletions(-) delete mode 100644 packages/CLPBN/horus/LiftedCircuit.cpp delete mode 100644 packages/CLPBN/horus/LiftedCircuit.h diff --git a/packages/CLPBN/horus/LiftedCircuit.cpp b/packages/CLPBN/horus/LiftedCircuit.cpp deleted file mode 100644 index 863f86f29..000000000 --- a/packages/CLPBN/horus/LiftedCircuit.cpp +++ /dev/null @@ -1,1232 +0,0 @@ -#include - -#include "LiftedCircuit.h" - - -OrNode::~OrNode (void) -{ - delete leftBranch_; - delete rightBranch_; -} - - - -double -OrNode::weight (void) const -{ - double lw = leftBranch_->weight(); - double rw = rightBranch_->weight(); - return Globals::logDomain ? Util::logSum (lw, rw) : lw + rw; -} - - - -AndNode::~AndNode (void) -{ - delete leftBranch_; - delete rightBranch_; -} - - - -double -AndNode::weight (void) const -{ - double lw = leftBranch_->weight(); - double rw = rightBranch_->weight(); - return Globals::logDomain ? lw + rw : lw * rw; -} - - - -int SetOrNode::nrPos_ = -1; -int SetOrNode::nrNeg_ = -1; - - - -SetOrNode::~SetOrNode (void) -{ - delete follow_; -} - - - -double -SetOrNode::weight (void) const -{ - double weightSum = LogAware::addIdenty(); - for (unsigned i = 0; i < nrGroundings_ + 1; i++) { - nrPos_ = nrGroundings_ - i; - nrNeg_ = i; - if (Globals::logDomain) { - double nrCombs = Util::nrCombinations (nrGroundings_, i); - double w = follow_->weight(); - weightSum = Util::logSum (weightSum, std::log (nrCombs) + w); - } else { - double w = follow_->weight(); - weightSum += Util::nrCombinations (nrGroundings_, i) * w; - } - } - nrPos_ = -1; - nrNeg_ = -1; - return weightSum; -} - - - -SetAndNode::~SetAndNode (void) -{ - delete follow_; -} - - - -double -SetAndNode::weight (void) const -{ - return LogAware::pow (follow_->weight(), nrGroundings_); -} - - - -IncExcNode::~IncExcNode (void) -{ - delete plus1Branch_; - delete plus2Branch_; - delete minusBranch_; -} - - - -double -IncExcNode::weight (void) const -{ - double w = 0.0; - if (Globals::logDomain) { - w = Util::logSum (plus1Branch_->weight(), plus2Branch_->weight()); - w = std::log (std::exp (w) - std::exp (minusBranch_->weight())); - } else { - w = plus1Branch_->weight() + plus2Branch_->weight(); - w -= minusBranch_->weight(); - } - return w; -} - - - -LeafNode::~LeafNode (void) -{ - delete clause_; -} - - - -double -LeafNode::weight (void) const -{ - assert (clause_->isUnit()); - if (clause_->posCountedLogVars().empty() == false - || clause_->negCountedLogVars().empty() == false) { - if (SetOrNode::isSet() == false) { - // return a NaN if we have a SetOrNode - // ancester that is not set. This can only - // happen when calculating the weights - // for the edge labels in graphviz - return 0.0 / 0.0; - } - } - double weight = clause_->literals()[0].isPositive() - ? lwcnf_.posWeight (clause_->literals().front().lid()) - : lwcnf_.negWeight (clause_->literals().front().lid()); - LogVarSet lvs = clause_->constr().logVarSet(); - lvs -= clause_->ipgLogVars(); - lvs -= clause_->posCountedLogVars(); - lvs -= clause_->negCountedLogVars(); - unsigned nrGroundings = 1; - if (lvs.empty() == false) { - nrGroundings = clause_->constr().projectedCopy (lvs).size(); - } - if (clause_->posCountedLogVars().empty() == false) { - nrGroundings *= std::pow (SetOrNode::nrPositives(), - clause_->nrPosCountedLogVars()); - } - if (clause_->negCountedLogVars().empty() == false) { - nrGroundings *= std::pow (SetOrNode::nrNegatives(), - clause_->nrNegCountedLogVars()); - } - return LogAware::pow (weight, nrGroundings); -} - - - -SmoothNode::~SmoothNode (void) -{ - Clause::deleteClauses (clauses_); -} - - - -double -SmoothNode::weight (void) const -{ - Clauses cs = clauses(); - double totalWeight = LogAware::multIdenty(); - for (size_t i = 0; i < cs.size(); i++) { - double posWeight = lwcnf_.posWeight (cs[i]->literals()[0].lid()); - double negWeight = lwcnf_.negWeight (cs[i]->literals()[0].lid()); - LogVarSet lvs = cs[i]->constr().logVarSet(); - lvs -= cs[i]->ipgLogVars(); - lvs -= cs[i]->posCountedLogVars(); - lvs -= cs[i]->negCountedLogVars(); - unsigned nrGroundings = 1; - if (lvs.empty() == false) { - nrGroundings = cs[i]->constr().projectedCopy (lvs).size(); - } - if (cs[i]->posCountedLogVars().empty() == false) { - nrGroundings *= std::pow (SetOrNode::nrPositives(), - cs[i]->nrPosCountedLogVars()); - } - if (cs[i]->negCountedLogVars().empty() == false) { - nrGroundings *= std::pow (SetOrNode::nrNegatives(), - cs[i]->nrNegCountedLogVars()); - } - if (Globals::logDomain) { - totalWeight += Util::logSum (posWeight, negWeight) * nrGroundings; - } else { - totalWeight *= std::pow (posWeight + negWeight, nrGroundings); - } - } - return totalWeight; -} - - - -double -TrueNode::weight (void) const -{ - return LogAware::multIdenty(); -} - - - -double -CompilationFailedNode::weight (void) const -{ - // weighted model counting in compilation - // failed nodes should give NaN - return 0.0 / 0.0; -} - - - -LiftedCircuit::LiftedCircuit (const LiftedWCNF* lwcnf) - : lwcnf_(lwcnf) -{ - root_ = 0; - compilationSucceeded_ = true; - Clauses clauses = Clause::copyClauses (lwcnf->clauses()); - compile (&root_, clauses); - if (compilationSucceeded_) { - smoothCircuit (root_); - } - if (Globals::verbosity > 1) { - if (compilationSucceeded_) { - double wmc = LogAware::exp (getWeightedModelCount()); - cout << "Weighted model count = " << wmc << endl << endl; - } - cout << "Exporting circuit to graphviz (circuit.dot)..." ; - cout << endl << endl; - exportToGraphViz ("circuit.dot"); - } -} - - - -LiftedCircuit::~LiftedCircuit (void) -{ - delete root_; - unordered_map::iterator it; - it = originClausesMap_.begin(); - while (it != originClausesMap_.end()) { - Clause::deleteClauses (it->second); - ++ it; - } -} - - - -bool -LiftedCircuit::isCompilationSucceeded (void) const -{ - return compilationSucceeded_; -} - - - -double -LiftedCircuit::getWeightedModelCount (void) const -{ - assert (compilationSucceeded_); - return root_->weight(); -} - - - -void -LiftedCircuit::exportToGraphViz (const char* fileName) -{ - ofstream out (fileName); - if (!out.is_open()) { - cerr << "Error: couldn't open file '" << fileName << "'." ; - return; - } - out << "digraph {" << endl; - out << "ranksep=1" << endl; - exportToGraphViz (root_, out); - out << "}" << endl; - out.close(); -} - - - -void -LiftedCircuit::compile ( - CircuitNode** follow, - Clauses& clauses) -{ - if (compilationSucceeded_ == false - && Globals::verbosity <= 1) { - return; - } - - if (clauses.empty()) { - *follow = new TrueNode(); - return; - } - - if (clauses.size() == 1 && clauses[0]->isUnit()) { - *follow = new LeafNode (clauses[0], *lwcnf_); - return; - } - - if (tryUnitPropagation (follow, clauses)) { - return; - } - - if (tryIndependence (follow, clauses)) { - return; - } - - if (tryShannonDecomp (follow, clauses)) { - return; - } - - if (tryInclusionExclusion (follow, clauses)) { - return; - } - - if (tryIndepPartialGrounding (follow, clauses)) { - return; - } - - if (tryAtomCounting (follow, clauses)) { - return; - } - - *follow = new CompilationFailedNode(); - if (Globals::verbosity > 1) { - originClausesMap_[*follow] = clauses; - explanationMap_[*follow] = "" ; - } - compilationSucceeded_ = false; -} - - - -bool -LiftedCircuit::tryUnitPropagation ( - CircuitNode** follow, - Clauses& clauses) -{ - if (Globals::verbosity > 1) { - backupClauses_ = Clause::copyClauses (clauses); - } - for (size_t i = 0; i < clauses.size(); i++) { - if (clauses[i]->isUnit()) { - Clauses propagClauses; - for (size_t j = 0; j < clauses.size(); j++) { - if (i != j) { - LiteralId lid = clauses[i]->literals()[0].lid(); - LogVarTypes types = clauses[i]->logVarTypes (0); - if (clauses[i]->literals()[0].isPositive()) { - if (clauses[j]->containsPositiveLiteral (lid, types) == false) { - clauses[j]->removeNegativeLiterals (lid, types); - if (clauses[j]->nrLiterals() > 0) { - propagClauses.push_back (clauses[j]); - } else { - delete clauses[j]; - } - } else { - delete clauses[j]; - } - } else if (clauses[i]->literals()[0].isNegative()) { - if (clauses[j]->containsNegativeLiteral (lid, types) == false) { - clauses[j]->removePositiveLiterals (lid, types); - if (clauses[j]->nrLiterals() > 0) { - propagClauses.push_back (clauses[j]); - } else { - delete clauses[j]; - } - } else { - delete clauses[j]; - } - } - } - } - - AndNode* andNode = new AndNode(); - if (Globals::verbosity > 1) { - originClausesMap_[andNode] = backupClauses_; - stringstream explanation; - explanation << " UP on " << clauses[i]->literals()[0]; - explanationMap_[andNode] = explanation.str(); - } - - Clauses unitClause = { clauses[i] }; - compile (andNode->leftBranch(), unitClause); - compile (andNode->rightBranch(), propagClauses); - (*follow) = andNode; - return true; - } - } - if (Globals::verbosity > 1) { - Clause::deleteClauses (backupClauses_); - } - return false; -} - - - -bool -LiftedCircuit::tryIndependence ( - CircuitNode** follow, - Clauses& clauses) -{ - if (clauses.size() == 1) { - return false; - } - if (Globals::verbosity > 1) { - backupClauses_ = Clause::copyClauses (clauses); - } - Clauses depClauses = { clauses[0] }; - Clauses indepClauses (clauses.begin() + 1, clauses.end()); - bool finish = false; - while (finish == false) { - finish = true; - for (size_t i = 0; i < indepClauses.size(); i++) { - if (independentClause (*indepClauses[i], depClauses) == false) { - depClauses.push_back (indepClauses[i]); - indepClauses.erase (indepClauses.begin() + i); - finish = false; - break; - } - } - } - if (indepClauses.empty() == false) { - AndNode* andNode = new AndNode (); - if (Globals::verbosity > 1) { - originClausesMap_[andNode] = backupClauses_; - explanationMap_[andNode] = " Independence" ; - } - compile (andNode->leftBranch(), depClauses); - compile (andNode->rightBranch(), indepClauses); - (*follow) = andNode; - return true; - } - if (Globals::verbosity > 1) { - Clause::deleteClauses (backupClauses_); - } - return false; -} - - - -bool -LiftedCircuit::tryShannonDecomp ( - CircuitNode** follow, - Clauses& clauses) -{ - if (Globals::verbosity > 1) { - backupClauses_ = Clause::copyClauses (clauses); - } - for (size_t i = 0; i < clauses.size(); i++) { - const Literals& literals = clauses[i]->literals(); - for (size_t j = 0; j < literals.size(); j++) { - if (literals[j].isGround ( - clauses[i]->constr(), clauses[i]->ipgLogVars())) { - - Clause* c1 = lwcnf_->createClause (literals[j].lid()); - Clause* c2 = new Clause (*c1); - c2->literals().front().complement(); - - Clauses otherClauses = Clause::copyClauses (clauses); - clauses.push_back (c1); - otherClauses.push_back (c2); - - OrNode* orNode = new OrNode(); - if (Globals::verbosity > 1) { - originClausesMap_[orNode] = backupClauses_; - stringstream explanation; - explanation << " SD on " << literals[j]; - explanationMap_[orNode] = explanation.str(); - } - - compile (orNode->leftBranch(), clauses); - compile (orNode->rightBranch(), otherClauses); - (*follow) = orNode; - return true; - } - } - } - if (Globals::verbosity > 1) { - Clause::deleteClauses (backupClauses_); - } - return false; -} - - - -bool -LiftedCircuit::tryInclusionExclusion ( - CircuitNode** follow, - Clauses& clauses) -{ - if (Globals::verbosity > 1) { - backupClauses_ = Clause::copyClauses (clauses); - } - for (size_t i = 0; i < clauses.size(); i++) { - Literals depLits = { clauses[i]->literals().front() }; - Literals indepLits (clauses[i]->literals().begin() + 1, - clauses[i]->literals().end()); - bool finish = false; - while (finish == false) { - finish = true; - for (size_t j = 0; j < indepLits.size(); j++) { - if (independentLiteral (indepLits[j], depLits) == false) { - depLits.push_back (indepLits[j]); - indepLits.erase (indepLits.begin() + j); - finish = false; - break; - } - } - } - if (indepLits.empty() == false) { - LogVarSet lvs1; - for (size_t j = 0; j < depLits.size(); j++) { - lvs1 |= depLits[j].logVarSet(); - } - if (clauses[i]->constr().isCountNormalized (lvs1) == false) { - break; - } - LogVarSet lvs2; - for (size_t j = 0; j < indepLits.size(); j++) { - lvs2 |= indepLits[j].logVarSet(); - } - if (clauses[i]->constr().isCountNormalized (lvs2) == false) { - break; - } - Clause* c1 = new Clause (clauses[i]->constr().projectedCopy (lvs1)); - for (size_t j = 0; j < depLits.size(); j++) { - c1->addLiteral (depLits[j]); - } - Clause* c2 = new Clause (clauses[i]->constr().projectedCopy (lvs2)); - for (size_t j = 0; j < indepLits.size(); j++) { - c2->addLiteral (indepLits[j]); - } - - clauses.erase (clauses.begin() + i); - Clauses plus1Clauses = Clause::copyClauses (clauses); - Clauses plus2Clauses = Clause::copyClauses (clauses); - - plus1Clauses.push_back (c1); - plus2Clauses.push_back (c2); - clauses.push_back (c1); - clauses.push_back (c2); - - IncExcNode* ieNode = new IncExcNode(); - if (Globals::verbosity > 1) { - originClausesMap_[ieNode] = backupClauses_; - stringstream explanation; - explanation << " IncExc on clause nº " << i + 1; - explanationMap_[ieNode] = explanation.str(); - } - compile (ieNode->plus1Branch(), plus1Clauses); - compile (ieNode->plus2Branch(), plus2Clauses); - compile (ieNode->minusBranch(), clauses); - *follow = ieNode; - return true; - } - } - if (Globals::verbosity > 1) { - Clause::deleteClauses (backupClauses_); - } - return false; -} - - - -bool -LiftedCircuit::tryIndepPartialGrounding ( - CircuitNode** follow, - Clauses& clauses) -{ - // assumes that all literals have logical variables - // else, shannon decomp was possible - if (Globals::verbosity > 1) { - backupClauses_ = Clause::copyClauses (clauses); - } - LogVars rootLogVars; - LogVarSet lvs = clauses[0]->ipgCandidates(); - for (size_t i = 0; i < lvs.size(); i++) { - rootLogVars.clear(); - rootLogVars.push_back (lvs[i]); - ConstraintTree ct = clauses[0]->constr().projectedCopy ({lvs[i]}); - if (tryIndepPartialGroundingAux (clauses, ct, rootLogVars)) { - for (size_t j = 0; j < clauses.size(); j++) { - clauses[j]->addIpgLogVar (rootLogVars[j]); - } - SetAndNode* setAndNode = new SetAndNode (ct.size()); - if (Globals::verbosity > 1) { - originClausesMap_[setAndNode] = backupClauses_; - explanationMap_[setAndNode] = " IPG" ; - } - *follow = setAndNode; - compile (setAndNode->follow(), clauses); - return true; - } - } - if (Globals::verbosity > 1) { - Clause::deleteClauses (backupClauses_); - } - return false; -} - - - -bool -LiftedCircuit::tryIndepPartialGroundingAux ( - Clauses& clauses, - ConstraintTree& ct, - LogVars& rootLogVars) -{ - for (size_t i = 1; i < clauses.size(); i++) { - LogVarSet lvs = clauses[i]->ipgCandidates(); - for (size_t j = 0; j < lvs.size(); j++) { - ConstraintTree ct2 = clauses[i]->constr().projectedCopy ({lvs[j]}); - if (ct.tupleSet() == ct2.tupleSet()) { - rootLogVars.push_back (lvs[j]); - break; - } - } - if (rootLogVars.size() != i + 1) { - return false; - } - } - // verifies if the IPG logical vars appear in the same positions - unordered_map positions; - for (size_t i = 0; i < clauses.size(); i++) { - const Literals& literals = clauses[i]->literals(); - for (size_t j = 0; j < literals.size(); j++) { - size_t idx = literals[j].indexOfLogVar (rootLogVars[i]); - assert (idx != literals[j].nrLogVars()); - unordered_map::iterator it; - it = positions.find (literals[j].lid()); - if (it != positions.end()) { - if (it->second != idx) { - return false; - } - } else { - positions[literals[j].lid()] = idx; - } - } - } - return true; -} - - - -bool -LiftedCircuit::tryAtomCounting ( - CircuitNode** follow, - Clauses& clauses) -{ - for (size_t i = 0; i < clauses.size(); i++) { - if (clauses[i]->nrPosCountedLogVars() > 0 - || clauses[i]->nrNegCountedLogVars() > 0) { - // only allow one atom counting node per branch - return false; - } - } - if (Globals::verbosity > 1) { - backupClauses_ = Clause::copyClauses (clauses); - } - for (size_t i = 0; i < clauses.size(); i++) { - Literals literals = clauses[i]->literals(); - for (size_t j = 0; j < literals.size(); j++) { - if (literals[j].nrLogVars() == 1 - && ! clauses[i]->isIpgLogVar (literals[j].logVars().front()) - && ! clauses[i]->isCountedLogVar (literals[j].logVars().front())) { - unsigned nrGroundings = clauses[i]->constr().projectedCopy ( - literals[j].logVars()).size(); - SetOrNode* setOrNode = new SetOrNode (nrGroundings); - if (Globals::verbosity > 1) { - originClausesMap_[setOrNode] = backupClauses_; - explanationMap_[setOrNode] = " AC" ; - } - Clause* c1 = new Clause ( - clauses[i]->constr().projectedCopy (literals[j].logVars())); - Clause* c2 = new Clause ( - clauses[i]->constr().projectedCopy (literals[j].logVars())); - c1->addLiteral (literals[j]); - c2->addLiteralComplemented (literals[j]); - c1->addPosCountedLogVar (literals[j].logVars().front()); - c2->addNegCountedLogVar (literals[j].logVars().front()); - clauses.push_back (c1); - clauses.push_back (c2); - shatterCountedLogVars (clauses); - compile (setOrNode->follow(), clauses); - *follow = setOrNode; - return true; - } - } - } - if (Globals::verbosity > 1) { - Clause::deleteClauses (backupClauses_); - } - return false; -} - - - -void -LiftedCircuit::shatterCountedLogVars (Clauses& clauses) -{ - while (shatterCountedLogVarsAux (clauses)) ; -} - - - -bool -LiftedCircuit::shatterCountedLogVarsAux (Clauses& clauses) -{ - for (size_t i = 0; i < clauses.size() - 1; i++) { - for (size_t j = i + 1; j < clauses.size(); j++) { - bool splitedSome = shatterCountedLogVarsAux (clauses, i, j); - if (splitedSome) { - return true; - } - } - } - return false; -} - - - -bool -LiftedCircuit::shatterCountedLogVarsAux ( - Clauses& clauses, - size_t idx1, - size_t idx2) -{ - Literals lits1 = clauses[idx1]->literals(); - Literals lits2 = clauses[idx2]->literals(); - for (size_t i = 0; i < lits1.size(); i++) { - for (size_t j = 0; j < lits2.size(); j++) { - if (lits1[i].lid() == lits2[j].lid()) { - LogVars lvs1 = lits1[i].logVars(); - LogVars lvs2 = lits2[j].logVars(); - for (size_t k = 0; k < lvs1.size(); k++) { - if (clauses[idx1]->isCountedLogVar (lvs1[k]) - && clauses[idx2]->isCountedLogVar (lvs2[k]) == false) { - clauses.push_back (new Clause (*clauses[idx2])); - clauses[idx2]->addPosCountedLogVar (lvs2[k]); - clauses.back()->addNegCountedLogVar (lvs2[k]); - return true; - } - if (clauses[idx2]->isCountedLogVar (lvs2[k]) - && clauses[idx1]->isCountedLogVar (lvs1[k]) == false) { - clauses.push_back (new Clause (*clauses[idx1])); - clauses[idx1]->addPosCountedLogVar (lvs1[k]); - clauses.back()->addNegCountedLogVar (lvs1[k]); - return true; - } - } - } - } - } - return false; -} - - - -bool -LiftedCircuit::independentClause ( - Clause& clause, - Clauses& otherClauses) const -{ - for (size_t i = 0; i < otherClauses.size(); i++) { - if (Clause::independentClauses (clause, *otherClauses[i]) == false) { - return false; - } - } - return true; -} - - - -bool -LiftedCircuit::independentLiteral ( - const Literal& lit, - const Literals& otherLits) const -{ - for (size_t i = 0; i < otherLits.size(); i++) { - if (lit.lid() == otherLits[i].lid() - || (lit.logVarSet() & otherLits[i].logVarSet()).empty() == false) { - return false; - } - } - return true; -} - - - -LitLvTypesSet -LiftedCircuit::smoothCircuit (CircuitNode* node) -{ - assert (node != 0); - LitLvTypesSet propagLits; - - switch (getCircuitNodeType (node)) { - - case CircuitNodeType::OR_NODE: { - OrNode* casted = dynamic_cast(node); - LitLvTypesSet lids1 = smoothCircuit (*casted->leftBranch()); - LitLvTypesSet lids2 = smoothCircuit (*casted->rightBranch()); - LitLvTypesSet missingLeft = lids2 - lids1; - LitLvTypesSet missingRight = lids1 - lids2; - createSmoothNode (missingLeft, casted->leftBranch()); - createSmoothNode (missingRight, casted->rightBranch()); - propagLits |= lids1; - propagLits |= lids2; - break; - } - - case CircuitNodeType::AND_NODE: { - AndNode* casted = dynamic_cast(node); - LitLvTypesSet lids1 = smoothCircuit (*casted->leftBranch()); - LitLvTypesSet lids2 = smoothCircuit (*casted->rightBranch()); - propagLits |= lids1; - propagLits |= lids2; - break; - } - - case CircuitNodeType::SET_OR_NODE: { - SetOrNode* casted = dynamic_cast(node); - propagLits = smoothCircuit (*casted->follow()); - TinySet> litSet; - for (size_t i = 0; i < propagLits.size(); i++) { - litSet.insert (make_pair (propagLits[i].lid(), - propagLits[i].logVarTypes().size())); - } - LitLvTypesSet missingLids; - for (size_t i = 0; i < litSet.size(); i++) { - vector allTypes = getAllPossibleTypes (litSet[i].second); - for (size_t j = 0; j < allTypes.size(); j++) { - bool typeFound = false; - for (size_t k = 0; k < propagLits.size(); k++) { - if (litSet[i].first == propagLits[k].lid() - && containsTypes (propagLits[k].logVarTypes(), allTypes[j])) { - typeFound = true; - break; - } - } - if (typeFound == false) { - missingLids.insert (LitLvTypes (litSet[i].first, allTypes[j])); - } - } - } - createSmoothNode (missingLids, casted->follow()); - // setAllFullLogVars() can cause repeated elements in - // the set. Fix this by reconstructing the set again - LitLvTypesSet copy = propagLits; - propagLits.clear(); - for (size_t i = 0; i < copy.size(); i++) { - copy[i].setAllFullLogVars(); - propagLits.insert (copy[i]); - } - break; - } - - case CircuitNodeType::SET_AND_NODE: { - SetAndNode* casted = dynamic_cast(node); - propagLits = smoothCircuit (*casted->follow()); - break; - } - - case CircuitNodeType::INC_EXC_NODE: { - IncExcNode* casted = dynamic_cast(node); - LitLvTypesSet lids1 = smoothCircuit (*casted->plus1Branch()); - LitLvTypesSet lids2 = smoothCircuit (*casted->plus2Branch()); - LitLvTypesSet missingPlus1 = lids2 - lids1; - LitLvTypesSet missingPlus2 = lids1 - lids2; - createSmoothNode (missingPlus1, casted->plus1Branch()); - createSmoothNode (missingPlus2, casted->plus2Branch()); - propagLits |= lids1; - propagLits |= lids2; - break; - } - - case CircuitNodeType::LEAF_NODE: { - LeafNode* casted = dynamic_cast(node); - propagLits.insert (LitLvTypes ( - casted->clause()->literals()[0].lid(), - casted->clause()->logVarTypes(0))); - } - - default: - break; - } - - return propagLits; -} - - - -void -LiftedCircuit::createSmoothNode ( - const LitLvTypesSet& missingLits, - CircuitNode** prev) -{ - if (missingLits.empty() == false) { - if (Globals::verbosity > 1) { - unordered_map::iterator it; - it = originClausesMap_.find (*prev); - if (it != originClausesMap_.end()) { - backupClauses_ = it->second; - } else { - backupClauses_ = Clause::copyClauses ( - {((dynamic_cast(*prev))->clause())}); - } - } - Clauses clauses; - for (size_t i = 0; i < missingLits.size(); i++) { - LiteralId lid = missingLits[i].lid(); - const LogVarTypes& types = missingLits[i].logVarTypes(); - Clause* c = lwcnf_->createClause (lid); - for (size_t j = 0; j < types.size(); j++) { - LogVar X = c->literals().front().logVars()[j]; - if (types[j] == LogVarType::POS_LV) { - c->addPosCountedLogVar (X); - } else if (types[j] == LogVarType::NEG_LV) { - c->addNegCountedLogVar (X); - } - } - c->addLiteralComplemented (c->literals()[0]); - clauses.push_back (c); - } - SmoothNode* smoothNode = new SmoothNode (clauses, *lwcnf_); - *prev = new AndNode (smoothNode, *prev); - if (Globals::verbosity > 1) { - originClausesMap_[*prev] = backupClauses_; - explanationMap_[*prev] = " Smoothing" ; - } - } -} - - - -vector -LiftedCircuit::getAllPossibleTypes (unsigned nrLogVars) const -{ - if (nrLogVars == 0) { - return {}; - } - if (nrLogVars == 1) { - return {{LogVarType::POS_LV},{LogVarType::NEG_LV}}; - } - vector res; - Ranges ranges (nrLogVars, 2); - Indexer indexer (ranges); - while (indexer.valid()) { - LogVarTypes types; - for (size_t i = 0; i < nrLogVars; i++) { - if (indexer[i] == 0) { - types.push_back (LogVarType::POS_LV); - } else { - types.push_back (LogVarType::NEG_LV); - } - } - res.push_back (types); - ++ indexer; - } - return res; -} - - - -bool -LiftedCircuit::containsTypes ( - const LogVarTypes& typesA, - const LogVarTypes& typesB) const -{ - for (size_t i = 0; i < typesA.size(); i++) { - if (typesA[i] == LogVarType::FULL_LV) { - - } else if (typesA[i] == LogVarType::POS_LV - && typesB[i] == LogVarType::POS_LV) { - - } else if (typesA[i] == LogVarType::NEG_LV - && typesB[i] == LogVarType::NEG_LV) { - - } else { - return false; - } - } - return true; -} - - - -CircuitNodeType -LiftedCircuit::getCircuitNodeType (const CircuitNode* node) const -{ - CircuitNodeType type; - if (dynamic_cast(node) != 0) { - type = CircuitNodeType::OR_NODE; - } else if (dynamic_cast(node) != 0) { - type = CircuitNodeType::AND_NODE; - } else if (dynamic_cast(node) != 0) { - type = CircuitNodeType::SET_OR_NODE; - } else if (dynamic_cast(node) != 0) { - type = CircuitNodeType::SET_AND_NODE; - } else if (dynamic_cast(node) != 0) { - type = CircuitNodeType::INC_EXC_NODE; - } else if (dynamic_cast(node) != 0) { - type = CircuitNodeType::LEAF_NODE; - } else if (dynamic_cast(node) != 0) { - type = CircuitNodeType::SMOOTH_NODE; - } else if (dynamic_cast(node) != 0) { - type = CircuitNodeType::TRUE_NODE; - } else if (dynamic_cast(node) != 0) { - type = CircuitNodeType::COMPILATION_FAILED_NODE; - } else { - assert (false); - } - return type; -} - - - -void -LiftedCircuit::exportToGraphViz (CircuitNode* node, ofstream& os) -{ - assert (node != 0); - - static unsigned nrAuxNodes = 0; - stringstream ss; - ss << "n" << nrAuxNodes; - string auxNode = ss.str(); - nrAuxNodes ++; - string opStyle = "shape=circle,width=0.7,margin=\"0.0,0.0\"," ; - - switch (getCircuitNodeType (node)) { - - case OR_NODE: { - OrNode* casted = dynamic_cast(node); - printClauses (casted, os); - - os << auxNode << " [" << opStyle << "label=\"∨\"]" << endl; - os << escapeNode (node) << " -> " << auxNode; - os << " [label=\"" << getExplanationString (node) << "\"]" ; - os << endl; - - os << auxNode << " -> " ; - os << escapeNode (*casted->leftBranch()); - os << " [label=\" " << (*casted->leftBranch())->weight() << "\"]" ; - os << endl; - - os << auxNode << " -> " ; - os << escapeNode (*casted->rightBranch()); - os << " [label=\" " << (*casted->rightBranch())->weight() << "\"]" ; - os << endl; - - exportToGraphViz (*casted->leftBranch(), os); - exportToGraphViz (*casted->rightBranch(), os); - break; - } - - case AND_NODE: { - AndNode* casted = dynamic_cast(node); - printClauses (casted, os); - - os << auxNode << " [" << opStyle << "label=\"∧\"]" << endl; - os << escapeNode (node) << " -> " << auxNode; - os << " [label=\"" << getExplanationString (node) << "\"]" ; - os << endl; - - os << auxNode << " -> " ; - os << escapeNode (*casted->leftBranch()); - os << " [label=\" " << (*casted->leftBranch())->weight() << "\"]" ; - os << endl; - - os << auxNode << " -> " ; - os << escapeNode (*casted->rightBranch()) << endl; - os << " [label=\" " << (*casted->rightBranch())->weight() << "\"]" ; - os << endl; - - exportToGraphViz (*casted->leftBranch(), os); - exportToGraphViz (*casted->rightBranch(), os); - break; - } - - case SET_OR_NODE: { - SetOrNode* casted = dynamic_cast(node); - printClauses (casted, os); - - os << auxNode << " [" << opStyle << "label=\"∨(X)\"]" << endl; - os << escapeNode (node) << " -> " << auxNode; - os << " [label=\"" << getExplanationString (node) << "\"]" ; - os << endl; - - os << auxNode << " -> " ; - os << escapeNode (*casted->follow()); - os << " [label=\" " << (*casted->follow())->weight() << "\"]" ; - os << endl; - - exportToGraphViz (*casted->follow(), os); - break; - } - - case SET_AND_NODE: { - SetAndNode* casted = dynamic_cast(node); - printClauses (casted, os); - - os << auxNode << " [" << opStyle << "label=\"∧(X)\"]" << endl; - os << escapeNode (node) << " -> " << auxNode; - os << " [label=\"" << getExplanationString (node) << "\"]" ; - os << endl; - - os << auxNode << " -> " ; - os << escapeNode (*casted->follow()); - os << " [label=\" " << (*casted->follow())->weight() << "\"]" ; - os << endl; - - exportToGraphViz (*casted->follow(), os); - break; - } - - case INC_EXC_NODE: { - IncExcNode* casted = dynamic_cast(node); - printClauses (casted, os); - - os << auxNode << " [" << opStyle << "label=\"+ - +\"]" ; - os << endl; - os << escapeNode (node) << " -> " << auxNode; - os << " [label=\"" << getExplanationString (node) << "\"]" ; - os << endl; - - os << auxNode << " -> " ; - os << escapeNode (*casted->plus1Branch()); - os << " [label=\" " << (*casted->plus1Branch())->weight() << "\"]" ; - os << endl; - - os << auxNode << " -> " ; - os << escapeNode (*casted->minusBranch()) << endl; - os << " [label=\" " << (*casted->minusBranch())->weight() << "\"]" ; - os << endl; - - os << auxNode << " -> " ; - os << escapeNode (*casted->plus2Branch()); - os << " [label=\" " << (*casted->plus2Branch())->weight() << "\"]" ; - os << endl; - - exportToGraphViz (*casted->plus1Branch(), os); - exportToGraphViz (*casted->plus2Branch(), os); - exportToGraphViz (*casted->minusBranch(), os); - break; - } - - case LEAF_NODE: { - printClauses (node, os, "style=filled,fillcolor=palegreen,"); - break; - } - - case SMOOTH_NODE: { - printClauses (node, os, "style=filled,fillcolor=lightblue,"); - break; - } - - case TRUE_NODE: { - os << escapeNode (node); - os << " [shape=box,label=\"⊤\"]" ; - os << endl; - break; - } - - case COMPILATION_FAILED_NODE: { - printClauses (node, os, "style=filled,fillcolor=salmon,"); - break; - } - - default: - assert (false); - } -} - - - -string -LiftedCircuit::escapeNode (const CircuitNode* node) const -{ - stringstream ss; - ss << "\"" << node << "\"" ; - return ss.str(); -} - - - -string -LiftedCircuit::getExplanationString (CircuitNode* node) -{ - return Util::contains (explanationMap_, node) - ? explanationMap_[node] - : "" ; -} - - - -void -LiftedCircuit::printClauses ( - CircuitNode* node, - ofstream& os, - string extraOptions) -{ - Clauses clauses; - if (Util::contains (originClausesMap_, node)) { - clauses = originClausesMap_[node]; - } else if (getCircuitNodeType (node) == CircuitNodeType::LEAF_NODE) { - clauses = { (dynamic_cast(node))->clause() } ; - } else if (getCircuitNodeType (node) == CircuitNodeType::SMOOTH_NODE) { - clauses = (dynamic_cast(node))->clauses(); - } - assert (clauses.empty() == false); - os << escapeNode (node); - os << " [shape=box," << extraOptions << "label=\"" ; - for (size_t i = 0; i < clauses.size(); i++) { - if (i != 0) os << "\\n" ; - os << *clauses[i]; - } - os << "\"]" ; - os << endl; -} - diff --git a/packages/CLPBN/horus/LiftedCircuit.h b/packages/CLPBN/horus/LiftedCircuit.h deleted file mode 100644 index e3883211b..000000000 --- a/packages/CLPBN/horus/LiftedCircuit.h +++ /dev/null @@ -1,279 +0,0 @@ -#ifndef HORUS_LIFTEDCIRCUIT_H -#define HORUS_LIFTEDCIRCUIT_H - -#include - -#include "LiftedWCNF.h" - - -enum CircuitNodeType { - OR_NODE, - AND_NODE, - SET_OR_NODE, - SET_AND_NODE, - INC_EXC_NODE, - LEAF_NODE, - SMOOTH_NODE, - TRUE_NODE, - COMPILATION_FAILED_NODE -}; - - - -class CircuitNode -{ - public: - CircuitNode (void) { } - - virtual ~CircuitNode (void) { } - - virtual double weight (void) const = 0; -}; - - - -class OrNode : public CircuitNode -{ - public: - OrNode (void) : CircuitNode(), leftBranch_(0), rightBranch_(0) { } - - ~OrNode (void); - - CircuitNode** leftBranch (void) { return &leftBranch_; } - CircuitNode** rightBranch (void) { return &rightBranch_; } - - double weight (void) const; - - private: - CircuitNode* leftBranch_; - CircuitNode* rightBranch_; -}; - - - -class AndNode : public CircuitNode -{ - public: - AndNode (void) : CircuitNode(), leftBranch_(0), rightBranch_(0) { } - - AndNode (CircuitNode* leftBranch, CircuitNode* rightBranch) - : CircuitNode(), leftBranch_(leftBranch), rightBranch_(rightBranch) { } - - ~AndNode (void); - - CircuitNode** leftBranch (void) { return &leftBranch_; } - CircuitNode** rightBranch (void) { return &rightBranch_; } - - double weight (void) const; - - private: - CircuitNode* leftBranch_; - CircuitNode* rightBranch_; -}; - - - -class SetOrNode : public CircuitNode -{ - public: - SetOrNode (unsigned nrGroundings) - : CircuitNode(), follow_(0), nrGroundings_(nrGroundings) { } - - ~SetOrNode (void); - - CircuitNode** follow (void) { return &follow_; } - - static unsigned nrPositives (void) { return nrPos_; } - - static unsigned nrNegatives (void) { return nrNeg_; } - - static bool isSet (void) { return nrPos_ >= 0; } - - double weight (void) const; - - private: - CircuitNode* follow_; - unsigned nrGroundings_; - static int nrPos_; - static int nrNeg_; -}; - - - -class SetAndNode : public CircuitNode -{ - public: - SetAndNode (unsigned nrGroundings) - : CircuitNode(), follow_(0), nrGroundings_(nrGroundings) { } - - ~SetAndNode (void); - - CircuitNode** follow (void) { return &follow_; } - - double weight (void) const; - - private: - CircuitNode* follow_; - unsigned nrGroundings_; -}; - - - -class IncExcNode : public CircuitNode -{ - public: - IncExcNode (void) - : CircuitNode(), plus1Branch_(0), plus2Branch_(0), minusBranch_(0) { } - - ~IncExcNode (void); - - CircuitNode** plus1Branch (void) { return &plus1Branch_; } - CircuitNode** plus2Branch (void) { return &plus2Branch_; } - CircuitNode** minusBranch (void) { return &minusBranch_; } - - double weight (void) const; - - private: - CircuitNode* plus1Branch_; - CircuitNode* plus2Branch_; - CircuitNode* minusBranch_; -}; - - - -class LeafNode : public CircuitNode -{ - public: - LeafNode (Clause* clause, const LiftedWCNF& lwcnf) - : CircuitNode(), clause_(clause), lwcnf_(lwcnf) { } - - ~LeafNode (void); - - const Clause* clause (void) const { return clause_; } - - Clause* clause (void) { return clause_; } - - double weight (void) const; - - private: - Clause* clause_; - const LiftedWCNF& lwcnf_; -}; - - - -class SmoothNode : public CircuitNode -{ - public: - SmoothNode (const Clauses& clauses, const LiftedWCNF& lwcnf) - : CircuitNode(), clauses_(clauses), lwcnf_(lwcnf) { } - - ~SmoothNode (void); - - const Clauses& clauses (void) const { return clauses_; } - - Clauses clauses (void) { return clauses_; } - - double weight (void) const; - - private: - Clauses clauses_; - const LiftedWCNF& lwcnf_; -}; - - - -class TrueNode : public CircuitNode -{ - public: - TrueNode (void) : CircuitNode() { } - - double weight (void) const; -}; - - - -class CompilationFailedNode : public CircuitNode -{ - public: - CompilationFailedNode (void) : CircuitNode() { } - - double weight (void) const; -}; - - - -class LiftedCircuit -{ - public: - LiftedCircuit (const LiftedWCNF* lwcnf); - - ~LiftedCircuit (void); - - bool isCompilationSucceeded (void) const; - - double getWeightedModelCount (void) const; - - void exportToGraphViz (const char*); - - private: - - void compile (CircuitNode** follow, Clauses& clauses); - - bool tryUnitPropagation (CircuitNode** follow, Clauses& clauses); - - bool tryIndependence (CircuitNode** follow, Clauses& clauses); - - bool tryShannonDecomp (CircuitNode** follow, Clauses& clauses); - - bool tryInclusionExclusion (CircuitNode** follow, Clauses& clauses); - - bool tryIndepPartialGrounding (CircuitNode** follow, Clauses& clauses); - - bool tryIndepPartialGroundingAux (Clauses& clauses, ConstraintTree& ct, - LogVars& rootLogVars); - - bool tryAtomCounting (CircuitNode** follow, Clauses& clauses); - - void shatterCountedLogVars (Clauses& clauses); - - bool shatterCountedLogVarsAux (Clauses& clauses); - - bool shatterCountedLogVarsAux (Clauses& clauses, size_t idx1, size_t idx2); - - bool independentClause (Clause& clause, Clauses& otherClauses) const; - - bool independentLiteral (const Literal& lit, - const Literals& otherLits) const; - - LitLvTypesSet smoothCircuit (CircuitNode* node); - - void createSmoothNode (const LitLvTypesSet& lids, - CircuitNode** prev); - - vector getAllPossibleTypes (unsigned nrLogVars) const; - - bool containsTypes (const LogVarTypes& typesA, - const LogVarTypes& typesB) const; - - CircuitNodeType getCircuitNodeType (const CircuitNode* node) const; - - void exportToGraphViz (CircuitNode* node, ofstream&); - - void printClauses (CircuitNode* node, ofstream&, - string extraOptions = ""); - - string escapeNode (const CircuitNode* node) const; - - string getExplanationString (CircuitNode* node); - - CircuitNode* root_; - const LiftedWCNF* lwcnf_; - bool compilationSucceeded_; - Clauses backupClauses_; - unordered_map originClausesMap_; - unordered_map explanationMap_; -}; - -#endif // HORUS_LIFTEDCIRCUIT_H - diff --git a/packages/CLPBN/horus/LiftedKc.cpp b/packages/CLPBN/horus/LiftedKc.cpp index 678bacbec..45848ab70 100644 --- a/packages/CLPBN/horus/LiftedKc.cpp +++ b/packages/CLPBN/horus/LiftedKc.cpp @@ -1,10 +1,1240 @@ +#include + #include "LiftedKc.h" -#include "LiftedWCNF.h" -#include "LiftedCircuit.h" #include "LiftedOperations.h" #include "Indexer.h" + +OrNode::~OrNode (void) +{ + delete leftBranch_; + delete rightBranch_; +} + + + +double +OrNode::weight (void) const +{ + double lw = leftBranch_->weight(); + double rw = rightBranch_->weight(); + return Globals::logDomain ? Util::logSum (lw, rw) : lw + rw; +} + + + +AndNode::~AndNode (void) +{ + delete leftBranch_; + delete rightBranch_; +} + + + +double +AndNode::weight (void) const +{ + double lw = leftBranch_->weight(); + double rw = rightBranch_->weight(); + return Globals::logDomain ? lw + rw : lw * rw; +} + + + +int SetOrNode::nrPos_ = -1; +int SetOrNode::nrNeg_ = -1; + + + +SetOrNode::~SetOrNode (void) +{ + delete follow_; +} + + + +double +SetOrNode::weight (void) const +{ + double weightSum = LogAware::addIdenty(); + for (unsigned i = 0; i < nrGroundings_ + 1; i++) { + nrPos_ = nrGroundings_ - i; + nrNeg_ = i; + if (Globals::logDomain) { + double nrCombs = Util::nrCombinations (nrGroundings_, i); + double w = follow_->weight(); + weightSum = Util::logSum (weightSum, std::log (nrCombs) + w); + } else { + double w = follow_->weight(); + weightSum += Util::nrCombinations (nrGroundings_, i) * w; + } + } + nrPos_ = -1; + nrNeg_ = -1; + return weightSum; +} + + + +SetAndNode::~SetAndNode (void) +{ + delete follow_; +} + + + +double +SetAndNode::weight (void) const +{ + return LogAware::pow (follow_->weight(), nrGroundings_); +} + + + +IncExcNode::~IncExcNode (void) +{ + delete plus1Branch_; + delete plus2Branch_; + delete minusBranch_; +} + + + +double +IncExcNode::weight (void) const +{ + double w = 0.0; + if (Globals::logDomain) { + w = Util::logSum (plus1Branch_->weight(), plus2Branch_->weight()); + w = std::log (std::exp (w) - std::exp (minusBranch_->weight())); + } else { + w = plus1Branch_->weight() + plus2Branch_->weight(); + w -= minusBranch_->weight(); + } + return w; +} + + + +LeafNode::~LeafNode (void) +{ + delete clause_; +} + + + +double +LeafNode::weight (void) const +{ + assert (clause_->isUnit()); + if (clause_->posCountedLogVars().empty() == false + || clause_->negCountedLogVars().empty() == false) { + if (SetOrNode::isSet() == false) { + // return a NaN if we have a SetOrNode + // ancester that is not set. This can only + // happen when calculating the weights + // for the edge labels in graphviz + return 0.0 / 0.0; + } + } + double weight = clause_->literals()[0].isPositive() + ? lwcnf_.posWeight (clause_->literals().front().lid()) + : lwcnf_.negWeight (clause_->literals().front().lid()); + LogVarSet lvs = clause_->constr().logVarSet(); + lvs -= clause_->ipgLogVars(); + lvs -= clause_->posCountedLogVars(); + lvs -= clause_->negCountedLogVars(); + unsigned nrGroundings = 1; + if (lvs.empty() == false) { + nrGroundings = clause_->constr().projectedCopy (lvs).size(); + } + if (clause_->posCountedLogVars().empty() == false) { + nrGroundings *= std::pow (SetOrNode::nrPositives(), + clause_->nrPosCountedLogVars()); + } + if (clause_->negCountedLogVars().empty() == false) { + nrGroundings *= std::pow (SetOrNode::nrNegatives(), + clause_->nrNegCountedLogVars()); + } + return LogAware::pow (weight, nrGroundings); +} + + + +SmoothNode::~SmoothNode (void) +{ + Clause::deleteClauses (clauses_); +} + + + +double +SmoothNode::weight (void) const +{ + Clauses cs = clauses(); + double totalWeight = LogAware::multIdenty(); + for (size_t i = 0; i < cs.size(); i++) { + double posWeight = lwcnf_.posWeight (cs[i]->literals()[0].lid()); + double negWeight = lwcnf_.negWeight (cs[i]->literals()[0].lid()); + LogVarSet lvs = cs[i]->constr().logVarSet(); + lvs -= cs[i]->ipgLogVars(); + lvs -= cs[i]->posCountedLogVars(); + lvs -= cs[i]->negCountedLogVars(); + unsigned nrGroundings = 1; + if (lvs.empty() == false) { + nrGroundings = cs[i]->constr().projectedCopy (lvs).size(); + } + if (cs[i]->posCountedLogVars().empty() == false) { + nrGroundings *= std::pow (SetOrNode::nrPositives(), + cs[i]->nrPosCountedLogVars()); + } + if (cs[i]->negCountedLogVars().empty() == false) { + nrGroundings *= std::pow (SetOrNode::nrNegatives(), + cs[i]->nrNegCountedLogVars()); + } + if (Globals::logDomain) { + totalWeight += Util::logSum (posWeight, negWeight) * nrGroundings; + } else { + totalWeight *= std::pow (posWeight + negWeight, nrGroundings); + } + } + return totalWeight; +} + + + +double +TrueNode::weight (void) const +{ + return LogAware::multIdenty(); +} + + + +double +CompilationFailedNode::weight (void) const +{ + // weighted model counting in compilation + // failed nodes should give NaN + return 0.0 / 0.0; +} + + + +LiftedCircuit::LiftedCircuit (const LiftedWCNF* lwcnf) + : lwcnf_(lwcnf) +{ + root_ = 0; + compilationSucceeded_ = true; + Clauses clauses = Clause::copyClauses (lwcnf->clauses()); + compile (&root_, clauses); + if (compilationSucceeded_) { + smoothCircuit (root_); + } + if (Globals::verbosity > 1) { + if (compilationSucceeded_) { + double wmc = LogAware::exp (getWeightedModelCount()); + cout << "Weighted model count = " << wmc << endl << endl; + } + cout << "Exporting circuit to graphviz (circuit.dot)..." ; + cout << endl << endl; + exportToGraphViz ("circuit.dot"); + } +} + + + +LiftedCircuit::~LiftedCircuit (void) +{ + delete root_; + unordered_map::iterator it; + it = originClausesMap_.begin(); + while (it != originClausesMap_.end()) { + Clause::deleteClauses (it->second); + ++ it; + } +} + + + +bool +LiftedCircuit::isCompilationSucceeded (void) const +{ + return compilationSucceeded_; +} + + + +double +LiftedCircuit::getWeightedModelCount (void) const +{ + assert (compilationSucceeded_); + return root_->weight(); +} + + + +void +LiftedCircuit::exportToGraphViz (const char* fileName) +{ + ofstream out (fileName); + if (!out.is_open()) { + cerr << "Error: couldn't open file '" << fileName << "'." ; + return; + } + out << "digraph {" << endl; + out << "ranksep=1" << endl; + exportToGraphViz (root_, out); + out << "}" << endl; + out.close(); +} + + + +void +LiftedCircuit::compile ( + CircuitNode** follow, + Clauses& clauses) +{ + if (compilationSucceeded_ == false + && Globals::verbosity <= 1) { + return; + } + + if (clauses.empty()) { + *follow = new TrueNode(); + return; + } + + if (clauses.size() == 1 && clauses[0]->isUnit()) { + *follow = new LeafNode (clauses[0], *lwcnf_); + return; + } + + if (tryUnitPropagation (follow, clauses)) { + return; + } + + if (tryIndependence (follow, clauses)) { + return; + } + + if (tryShannonDecomp (follow, clauses)) { + return; + } + + if (tryInclusionExclusion (follow, clauses)) { + return; + } + + if (tryIndepPartialGrounding (follow, clauses)) { + return; + } + + if (tryAtomCounting (follow, clauses)) { + return; + } + + *follow = new CompilationFailedNode(); + if (Globals::verbosity > 1) { + originClausesMap_[*follow] = clauses; + explanationMap_[*follow] = "" ; + } + compilationSucceeded_ = false; +} + + + +bool +LiftedCircuit::tryUnitPropagation ( + CircuitNode** follow, + Clauses& clauses) +{ + if (Globals::verbosity > 1) { + backupClauses_ = Clause::copyClauses (clauses); + } + for (size_t i = 0; i < clauses.size(); i++) { + if (clauses[i]->isUnit()) { + Clauses propagClauses; + for (size_t j = 0; j < clauses.size(); j++) { + if (i != j) { + LiteralId lid = clauses[i]->literals()[0].lid(); + LogVarTypes types = clauses[i]->logVarTypes (0); + if (clauses[i]->literals()[0].isPositive()) { + if (clauses[j]->containsPositiveLiteral (lid, types) == false) { + clauses[j]->removeNegativeLiterals (lid, types); + if (clauses[j]->nrLiterals() > 0) { + propagClauses.push_back (clauses[j]); + } else { + delete clauses[j]; + } + } else { + delete clauses[j]; + } + } else if (clauses[i]->literals()[0].isNegative()) { + if (clauses[j]->containsNegativeLiteral (lid, types) == false) { + clauses[j]->removePositiveLiterals (lid, types); + if (clauses[j]->nrLiterals() > 0) { + propagClauses.push_back (clauses[j]); + } else { + delete clauses[j]; + } + } else { + delete clauses[j]; + } + } + } + } + + AndNode* andNode = new AndNode(); + if (Globals::verbosity > 1) { + originClausesMap_[andNode] = backupClauses_; + stringstream explanation; + explanation << " UP on " << clauses[i]->literals()[0]; + explanationMap_[andNode] = explanation.str(); + } + + Clauses unitClause = { clauses[i] }; + compile (andNode->leftBranch(), unitClause); + compile (andNode->rightBranch(), propagClauses); + (*follow) = andNode; + return true; + } + } + if (Globals::verbosity > 1) { + Clause::deleteClauses (backupClauses_); + } + return false; +} + + + +bool +LiftedCircuit::tryIndependence ( + CircuitNode** follow, + Clauses& clauses) +{ + if (clauses.size() == 1) { + return false; + } + if (Globals::verbosity > 1) { + backupClauses_ = Clause::copyClauses (clauses); + } + Clauses depClauses = { clauses[0] }; + Clauses indepClauses (clauses.begin() + 1, clauses.end()); + bool finish = false; + while (finish == false) { + finish = true; + for (size_t i = 0; i < indepClauses.size(); i++) { + if (independentClause (*indepClauses[i], depClauses) == false) { + depClauses.push_back (indepClauses[i]); + indepClauses.erase (indepClauses.begin() + i); + finish = false; + break; + } + } + } + if (indepClauses.empty() == false) { + AndNode* andNode = new AndNode (); + if (Globals::verbosity > 1) { + originClausesMap_[andNode] = backupClauses_; + explanationMap_[andNode] = " Independence" ; + } + compile (andNode->leftBranch(), depClauses); + compile (andNode->rightBranch(), indepClauses); + (*follow) = andNode; + return true; + } + if (Globals::verbosity > 1) { + Clause::deleteClauses (backupClauses_); + } + return false; +} + + + +bool +LiftedCircuit::tryShannonDecomp ( + CircuitNode** follow, + Clauses& clauses) +{ + if (Globals::verbosity > 1) { + backupClauses_ = Clause::copyClauses (clauses); + } + for (size_t i = 0; i < clauses.size(); i++) { + const Literals& literals = clauses[i]->literals(); + for (size_t j = 0; j < literals.size(); j++) { + if (literals[j].isGround ( + clauses[i]->constr(), clauses[i]->ipgLogVars())) { + + Clause* c1 = lwcnf_->createClause (literals[j].lid()); + Clause* c2 = new Clause (*c1); + c2->literals().front().complement(); + + Clauses otherClauses = Clause::copyClauses (clauses); + clauses.push_back (c1); + otherClauses.push_back (c2); + + OrNode* orNode = new OrNode(); + if (Globals::verbosity > 1) { + originClausesMap_[orNode] = backupClauses_; + stringstream explanation; + explanation << " SD on " << literals[j]; + explanationMap_[orNode] = explanation.str(); + } + + compile (orNode->leftBranch(), clauses); + compile (orNode->rightBranch(), otherClauses); + (*follow) = orNode; + return true; + } + } + } + if (Globals::verbosity > 1) { + Clause::deleteClauses (backupClauses_); + } + return false; +} + + + +bool +LiftedCircuit::tryInclusionExclusion ( + CircuitNode** follow, + Clauses& clauses) +{ + if (Globals::verbosity > 1) { + backupClauses_ = Clause::copyClauses (clauses); + } + for (size_t i = 0; i < clauses.size(); i++) { + Literals depLits = { clauses[i]->literals().front() }; + Literals indepLits (clauses[i]->literals().begin() + 1, + clauses[i]->literals().end()); + bool finish = false; + while (finish == false) { + finish = true; + for (size_t j = 0; j < indepLits.size(); j++) { + if (independentLiteral (indepLits[j], depLits) == false) { + depLits.push_back (indepLits[j]); + indepLits.erase (indepLits.begin() + j); + finish = false; + break; + } + } + } + if (indepLits.empty() == false) { + LogVarSet lvs1; + for (size_t j = 0; j < depLits.size(); j++) { + lvs1 |= depLits[j].logVarSet(); + } + if (clauses[i]->constr().isCountNormalized (lvs1) == false) { + break; + } + LogVarSet lvs2; + for (size_t j = 0; j < indepLits.size(); j++) { + lvs2 |= indepLits[j].logVarSet(); + } + if (clauses[i]->constr().isCountNormalized (lvs2) == false) { + break; + } + Clause* c1 = new Clause (clauses[i]->constr().projectedCopy (lvs1)); + for (size_t j = 0; j < depLits.size(); j++) { + c1->addLiteral (depLits[j]); + } + Clause* c2 = new Clause (clauses[i]->constr().projectedCopy (lvs2)); + for (size_t j = 0; j < indepLits.size(); j++) { + c2->addLiteral (indepLits[j]); + } + + clauses.erase (clauses.begin() + i); + Clauses plus1Clauses = Clause::copyClauses (clauses); + Clauses plus2Clauses = Clause::copyClauses (clauses); + + plus1Clauses.push_back (c1); + plus2Clauses.push_back (c2); + clauses.push_back (c1); + clauses.push_back (c2); + + IncExcNode* ieNode = new IncExcNode(); + if (Globals::verbosity > 1) { + originClausesMap_[ieNode] = backupClauses_; + stringstream explanation; + explanation << " IncExc on clause nº " << i + 1; + explanationMap_[ieNode] = explanation.str(); + } + compile (ieNode->plus1Branch(), plus1Clauses); + compile (ieNode->plus2Branch(), plus2Clauses); + compile (ieNode->minusBranch(), clauses); + *follow = ieNode; + return true; + } + } + if (Globals::verbosity > 1) { + Clause::deleteClauses (backupClauses_); + } + return false; +} + + + +bool +LiftedCircuit::tryIndepPartialGrounding ( + CircuitNode** follow, + Clauses& clauses) +{ + // assumes that all literals have logical variables + // else, shannon decomp was possible + if (Globals::verbosity > 1) { + backupClauses_ = Clause::copyClauses (clauses); + } + LogVars rootLogVars; + LogVarSet lvs = clauses[0]->ipgCandidates(); + for (size_t i = 0; i < lvs.size(); i++) { + rootLogVars.clear(); + rootLogVars.push_back (lvs[i]); + ConstraintTree ct = clauses[0]->constr().projectedCopy ({lvs[i]}); + if (tryIndepPartialGroundingAux (clauses, ct, rootLogVars)) { + for (size_t j = 0; j < clauses.size(); j++) { + clauses[j]->addIpgLogVar (rootLogVars[j]); + } + SetAndNode* setAndNode = new SetAndNode (ct.size()); + if (Globals::verbosity > 1) { + originClausesMap_[setAndNode] = backupClauses_; + explanationMap_[setAndNode] = " IPG" ; + } + *follow = setAndNode; + compile (setAndNode->follow(), clauses); + return true; + } + } + if (Globals::verbosity > 1) { + Clause::deleteClauses (backupClauses_); + } + return false; +} + + + +bool +LiftedCircuit::tryIndepPartialGroundingAux ( + Clauses& clauses, + ConstraintTree& ct, + LogVars& rootLogVars) +{ + for (size_t i = 1; i < clauses.size(); i++) { + LogVarSet lvs = clauses[i]->ipgCandidates(); + for (size_t j = 0; j < lvs.size(); j++) { + ConstraintTree ct2 = clauses[i]->constr().projectedCopy ({lvs[j]}); + if (ct.tupleSet() == ct2.tupleSet()) { + rootLogVars.push_back (lvs[j]); + break; + } + } + if (rootLogVars.size() != i + 1) { + return false; + } + } + // verifies if the IPG logical vars appear in the same positions + unordered_map positions; + for (size_t i = 0; i < clauses.size(); i++) { + const Literals& literals = clauses[i]->literals(); + for (size_t j = 0; j < literals.size(); j++) { + size_t idx = literals[j].indexOfLogVar (rootLogVars[i]); + assert (idx != literals[j].nrLogVars()); + unordered_map::iterator it; + it = positions.find (literals[j].lid()); + if (it != positions.end()) { + if (it->second != idx) { + return false; + } + } else { + positions[literals[j].lid()] = idx; + } + } + } + return true; +} + + + +bool +LiftedCircuit::tryAtomCounting ( + CircuitNode** follow, + Clauses& clauses) +{ + for (size_t i = 0; i < clauses.size(); i++) { + if (clauses[i]->nrPosCountedLogVars() > 0 + || clauses[i]->nrNegCountedLogVars() > 0) { + // only allow one atom counting node per branch + return false; + } + } + if (Globals::verbosity > 1) { + backupClauses_ = Clause::copyClauses (clauses); + } + for (size_t i = 0; i < clauses.size(); i++) { + Literals literals = clauses[i]->literals(); + for (size_t j = 0; j < literals.size(); j++) { + if (literals[j].nrLogVars() == 1 + && ! clauses[i]->isIpgLogVar (literals[j].logVars().front()) + && ! clauses[i]->isCountedLogVar (literals[j].logVars().front())) { + unsigned nrGroundings = clauses[i]->constr().projectedCopy ( + literals[j].logVars()).size(); + SetOrNode* setOrNode = new SetOrNode (nrGroundings); + if (Globals::verbosity > 1) { + originClausesMap_[setOrNode] = backupClauses_; + explanationMap_[setOrNode] = " AC" ; + } + Clause* c1 = new Clause ( + clauses[i]->constr().projectedCopy (literals[j].logVars())); + Clause* c2 = new Clause ( + clauses[i]->constr().projectedCopy (literals[j].logVars())); + c1->addLiteral (literals[j]); + c2->addLiteralComplemented (literals[j]); + c1->addPosCountedLogVar (literals[j].logVars().front()); + c2->addNegCountedLogVar (literals[j].logVars().front()); + clauses.push_back (c1); + clauses.push_back (c2); + shatterCountedLogVars (clauses); + compile (setOrNode->follow(), clauses); + *follow = setOrNode; + return true; + } + } + } + if (Globals::verbosity > 1) { + Clause::deleteClauses (backupClauses_); + } + return false; +} + + + +void +LiftedCircuit::shatterCountedLogVars (Clauses& clauses) +{ + while (shatterCountedLogVarsAux (clauses)) ; +} + + + +bool +LiftedCircuit::shatterCountedLogVarsAux (Clauses& clauses) +{ + for (size_t i = 0; i < clauses.size() - 1; i++) { + for (size_t j = i + 1; j < clauses.size(); j++) { + bool splitedSome = shatterCountedLogVarsAux (clauses, i, j); + if (splitedSome) { + return true; + } + } + } + return false; +} + + + +bool +LiftedCircuit::shatterCountedLogVarsAux ( + Clauses& clauses, + size_t idx1, + size_t idx2) +{ + Literals lits1 = clauses[idx1]->literals(); + Literals lits2 = clauses[idx2]->literals(); + for (size_t i = 0; i < lits1.size(); i++) { + for (size_t j = 0; j < lits2.size(); j++) { + if (lits1[i].lid() == lits2[j].lid()) { + LogVars lvs1 = lits1[i].logVars(); + LogVars lvs2 = lits2[j].logVars(); + for (size_t k = 0; k < lvs1.size(); k++) { + if (clauses[idx1]->isCountedLogVar (lvs1[k]) + && clauses[idx2]->isCountedLogVar (lvs2[k]) == false) { + clauses.push_back (new Clause (*clauses[idx2])); + clauses[idx2]->addPosCountedLogVar (lvs2[k]); + clauses.back()->addNegCountedLogVar (lvs2[k]); + return true; + } + if (clauses[idx2]->isCountedLogVar (lvs2[k]) + && clauses[idx1]->isCountedLogVar (lvs1[k]) == false) { + clauses.push_back (new Clause (*clauses[idx1])); + clauses[idx1]->addPosCountedLogVar (lvs1[k]); + clauses.back()->addNegCountedLogVar (lvs1[k]); + return true; + } + } + } + } + } + return false; +} + + + +bool +LiftedCircuit::independentClause ( + Clause& clause, + Clauses& otherClauses) const +{ + for (size_t i = 0; i < otherClauses.size(); i++) { + if (Clause::independentClauses (clause, *otherClauses[i]) == false) { + return false; + } + } + return true; +} + + + +bool +LiftedCircuit::independentLiteral ( + const Literal& lit, + const Literals& otherLits) const +{ + for (size_t i = 0; i < otherLits.size(); i++) { + if (lit.lid() == otherLits[i].lid() + || (lit.logVarSet() & otherLits[i].logVarSet()).empty() == false) { + return false; + } + } + return true; +} + + + +LitLvTypesSet +LiftedCircuit::smoothCircuit (CircuitNode* node) +{ + assert (node != 0); + LitLvTypesSet propagLits; + + switch (getCircuitNodeType (node)) { + + case CircuitNodeType::OR_NODE: { + OrNode* casted = dynamic_cast(node); + LitLvTypesSet lids1 = smoothCircuit (*casted->leftBranch()); + LitLvTypesSet lids2 = smoothCircuit (*casted->rightBranch()); + LitLvTypesSet missingLeft = lids2 - lids1; + LitLvTypesSet missingRight = lids1 - lids2; + createSmoothNode (missingLeft, casted->leftBranch()); + createSmoothNode (missingRight, casted->rightBranch()); + propagLits |= lids1; + propagLits |= lids2; + break; + } + + case CircuitNodeType::AND_NODE: { + AndNode* casted = dynamic_cast(node); + LitLvTypesSet lids1 = smoothCircuit (*casted->leftBranch()); + LitLvTypesSet lids2 = smoothCircuit (*casted->rightBranch()); + propagLits |= lids1; + propagLits |= lids2; + break; + } + + case CircuitNodeType::SET_OR_NODE: { + SetOrNode* casted = dynamic_cast(node); + propagLits = smoothCircuit (*casted->follow()); + TinySet> litSet; + for (size_t i = 0; i < propagLits.size(); i++) { + litSet.insert (make_pair (propagLits[i].lid(), + propagLits[i].logVarTypes().size())); + } + LitLvTypesSet missingLids; + for (size_t i = 0; i < litSet.size(); i++) { + vector allTypes = getAllPossibleTypes (litSet[i].second); + for (size_t j = 0; j < allTypes.size(); j++) { + bool typeFound = false; + for (size_t k = 0; k < propagLits.size(); k++) { + if (litSet[i].first == propagLits[k].lid() + && containsTypes (propagLits[k].logVarTypes(), allTypes[j])) { + typeFound = true; + break; + } + } + if (typeFound == false) { + missingLids.insert (LitLvTypes (litSet[i].first, allTypes[j])); + } + } + } + createSmoothNode (missingLids, casted->follow()); + // setAllFullLogVars() can cause repeated elements in + // the set. Fix this by reconstructing the set again + LitLvTypesSet copy = propagLits; + propagLits.clear(); + for (size_t i = 0; i < copy.size(); i++) { + copy[i].setAllFullLogVars(); + propagLits.insert (copy[i]); + } + break; + } + + case CircuitNodeType::SET_AND_NODE: { + SetAndNode* casted = dynamic_cast(node); + propagLits = smoothCircuit (*casted->follow()); + break; + } + + case CircuitNodeType::INC_EXC_NODE: { + IncExcNode* casted = dynamic_cast(node); + LitLvTypesSet lids1 = smoothCircuit (*casted->plus1Branch()); + LitLvTypesSet lids2 = smoothCircuit (*casted->plus2Branch()); + LitLvTypesSet missingPlus1 = lids2 - lids1; + LitLvTypesSet missingPlus2 = lids1 - lids2; + createSmoothNode (missingPlus1, casted->plus1Branch()); + createSmoothNode (missingPlus2, casted->plus2Branch()); + propagLits |= lids1; + propagLits |= lids2; + break; + } + + case CircuitNodeType::LEAF_NODE: { + LeafNode* casted = dynamic_cast(node); + propagLits.insert (LitLvTypes ( + casted->clause()->literals()[0].lid(), + casted->clause()->logVarTypes(0))); + } + + default: + break; + } + + return propagLits; +} + + + +void +LiftedCircuit::createSmoothNode ( + const LitLvTypesSet& missingLits, + CircuitNode** prev) +{ + if (missingLits.empty() == false) { + if (Globals::verbosity > 1) { + unordered_map::iterator it; + it = originClausesMap_.find (*prev); + if (it != originClausesMap_.end()) { + backupClauses_ = it->second; + } else { + backupClauses_ = Clause::copyClauses ( + {((dynamic_cast(*prev))->clause())}); + } + } + Clauses clauses; + for (size_t i = 0; i < missingLits.size(); i++) { + LiteralId lid = missingLits[i].lid(); + const LogVarTypes& types = missingLits[i].logVarTypes(); + Clause* c = lwcnf_->createClause (lid); + for (size_t j = 0; j < types.size(); j++) { + LogVar X = c->literals().front().logVars()[j]; + if (types[j] == LogVarType::POS_LV) { + c->addPosCountedLogVar (X); + } else if (types[j] == LogVarType::NEG_LV) { + c->addNegCountedLogVar (X); + } + } + c->addLiteralComplemented (c->literals()[0]); + clauses.push_back (c); + } + SmoothNode* smoothNode = new SmoothNode (clauses, *lwcnf_); + *prev = new AndNode (smoothNode, *prev); + if (Globals::verbosity > 1) { + originClausesMap_[*prev] = backupClauses_; + explanationMap_[*prev] = " Smoothing" ; + } + } +} + + + +vector +LiftedCircuit::getAllPossibleTypes (unsigned nrLogVars) const +{ + if (nrLogVars == 0) { + return {}; + } + if (nrLogVars == 1) { + return {{LogVarType::POS_LV},{LogVarType::NEG_LV}}; + } + vector res; + Ranges ranges (nrLogVars, 2); + Indexer indexer (ranges); + while (indexer.valid()) { + LogVarTypes types; + for (size_t i = 0; i < nrLogVars; i++) { + if (indexer[i] == 0) { + types.push_back (LogVarType::POS_LV); + } else { + types.push_back (LogVarType::NEG_LV); + } + } + res.push_back (types); + ++ indexer; + } + return res; +} + + + +bool +LiftedCircuit::containsTypes ( + const LogVarTypes& typesA, + const LogVarTypes& typesB) const +{ + for (size_t i = 0; i < typesA.size(); i++) { + if (typesA[i] == LogVarType::FULL_LV) { + + } else if (typesA[i] == LogVarType::POS_LV + && typesB[i] == LogVarType::POS_LV) { + + } else if (typesA[i] == LogVarType::NEG_LV + && typesB[i] == LogVarType::NEG_LV) { + + } else { + return false; + } + } + return true; +} + + + +CircuitNodeType +LiftedCircuit::getCircuitNodeType (const CircuitNode* node) const +{ + CircuitNodeType type; + if (dynamic_cast(node) != 0) { + type = CircuitNodeType::OR_NODE; + } else if (dynamic_cast(node) != 0) { + type = CircuitNodeType::AND_NODE; + } else if (dynamic_cast(node) != 0) { + type = CircuitNodeType::SET_OR_NODE; + } else if (dynamic_cast(node) != 0) { + type = CircuitNodeType::SET_AND_NODE; + } else if (dynamic_cast(node) != 0) { + type = CircuitNodeType::INC_EXC_NODE; + } else if (dynamic_cast(node) != 0) { + type = CircuitNodeType::LEAF_NODE; + } else if (dynamic_cast(node) != 0) { + type = CircuitNodeType::SMOOTH_NODE; + } else if (dynamic_cast(node) != 0) { + type = CircuitNodeType::TRUE_NODE; + } else if (dynamic_cast(node) != 0) { + type = CircuitNodeType::COMPILATION_FAILED_NODE; + } else { + assert (false); + } + return type; +} + + + +void +LiftedCircuit::exportToGraphViz (CircuitNode* node, ofstream& os) +{ + assert (node != 0); + + static unsigned nrAuxNodes = 0; + stringstream ss; + ss << "n" << nrAuxNodes; + string auxNode = ss.str(); + nrAuxNodes ++; + string opStyle = "shape=circle,width=0.7,margin=\"0.0,0.0\"," ; + + switch (getCircuitNodeType (node)) { + + case OR_NODE: { + OrNode* casted = dynamic_cast(node); + printClauses (casted, os); + + os << auxNode << " [" << opStyle << "label=\"∨\"]" << endl; + os << escapeNode (node) << " -> " << auxNode; + os << " [label=\"" << getExplanationString (node) << "\"]" ; + os << endl; + + os << auxNode << " -> " ; + os << escapeNode (*casted->leftBranch()); + os << " [label=\" " << (*casted->leftBranch())->weight() << "\"]" ; + os << endl; + + os << auxNode << " -> " ; + os << escapeNode (*casted->rightBranch()); + os << " [label=\" " << (*casted->rightBranch())->weight() << "\"]" ; + os << endl; + + exportToGraphViz (*casted->leftBranch(), os); + exportToGraphViz (*casted->rightBranch(), os); + break; + } + + case AND_NODE: { + AndNode* casted = dynamic_cast(node); + printClauses (casted, os); + + os << auxNode << " [" << opStyle << "label=\"∧\"]" << endl; + os << escapeNode (node) << " -> " << auxNode; + os << " [label=\"" << getExplanationString (node) << "\"]" ; + os << endl; + + os << auxNode << " -> " ; + os << escapeNode (*casted->leftBranch()); + os << " [label=\" " << (*casted->leftBranch())->weight() << "\"]" ; + os << endl; + + os << auxNode << " -> " ; + os << escapeNode (*casted->rightBranch()) << endl; + os << " [label=\" " << (*casted->rightBranch())->weight() << "\"]" ; + os << endl; + + exportToGraphViz (*casted->leftBranch(), os); + exportToGraphViz (*casted->rightBranch(), os); + break; + } + + case SET_OR_NODE: { + SetOrNode* casted = dynamic_cast(node); + printClauses (casted, os); + + os << auxNode << " [" << opStyle << "label=\"∨(X)\"]" << endl; + os << escapeNode (node) << " -> " << auxNode; + os << " [label=\"" << getExplanationString (node) << "\"]" ; + os << endl; + + os << auxNode << " -> " ; + os << escapeNode (*casted->follow()); + os << " [label=\" " << (*casted->follow())->weight() << "\"]" ; + os << endl; + + exportToGraphViz (*casted->follow(), os); + break; + } + + case SET_AND_NODE: { + SetAndNode* casted = dynamic_cast(node); + printClauses (casted, os); + + os << auxNode << " [" << opStyle << "label=\"∧(X)\"]" << endl; + os << escapeNode (node) << " -> " << auxNode; + os << " [label=\"" << getExplanationString (node) << "\"]" ; + os << endl; + + os << auxNode << " -> " ; + os << escapeNode (*casted->follow()); + os << " [label=\" " << (*casted->follow())->weight() << "\"]" ; + os << endl; + + exportToGraphViz (*casted->follow(), os); + break; + } + + case INC_EXC_NODE: { + IncExcNode* casted = dynamic_cast(node); + printClauses (casted, os); + + os << auxNode << " [" << opStyle << "label=\"+ - +\"]" ; + os << endl; + os << escapeNode (node) << " -> " << auxNode; + os << " [label=\"" << getExplanationString (node) << "\"]" ; + os << endl; + + os << auxNode << " -> " ; + os << escapeNode (*casted->plus1Branch()); + os << " [label=\" " << (*casted->plus1Branch())->weight() << "\"]" ; + os << endl; + + os << auxNode << " -> " ; + os << escapeNode (*casted->minusBranch()) << endl; + os << " [label=\" " << (*casted->minusBranch())->weight() << "\"]" ; + os << endl; + + os << auxNode << " -> " ; + os << escapeNode (*casted->plus2Branch()); + os << " [label=\" " << (*casted->plus2Branch())->weight() << "\"]" ; + os << endl; + + exportToGraphViz (*casted->plus1Branch(), os); + exportToGraphViz (*casted->plus2Branch(), os); + exportToGraphViz (*casted->minusBranch(), os); + break; + } + + case LEAF_NODE: { + printClauses (node, os, "style=filled,fillcolor=palegreen,"); + break; + } + + case SMOOTH_NODE: { + printClauses (node, os, "style=filled,fillcolor=lightblue,"); + break; + } + + case TRUE_NODE: { + os << escapeNode (node); + os << " [shape=box,label=\"⊤\"]" ; + os << endl; + break; + } + + case COMPILATION_FAILED_NODE: { + printClauses (node, os, "style=filled,fillcolor=salmon,"); + break; + } + + default: + assert (false); + } +} + + + +string +LiftedCircuit::escapeNode (const CircuitNode* node) const +{ + stringstream ss; + ss << "\"" << node << "\"" ; + return ss.str(); +} + + + +string +LiftedCircuit::getExplanationString (CircuitNode* node) +{ + return Util::contains (explanationMap_, node) + ? explanationMap_[node] + : "" ; +} + + + +void +LiftedCircuit::printClauses ( + CircuitNode* node, + ofstream& os, + string extraOptions) +{ + Clauses clauses; + if (Util::contains (originClausesMap_, node)) { + clauses = originClausesMap_[node]; + } else if (getCircuitNodeType (node) == CircuitNodeType::LEAF_NODE) { + clauses = { (dynamic_cast(node))->clause() } ; + } else if (getCircuitNodeType (node) == CircuitNodeType::SMOOTH_NODE) { + clauses = (dynamic_cast(node))->clauses(); + } + assert (clauses.empty() == false); + os << escapeNode (node); + os << " [shape=box," << extraOptions << "label=\"" ; + for (size_t i = 0; i < clauses.size(); i++) { + if (i != 0) os << "\\n" ; + os << *clauses[i]; + } + os << "\"]" ; + os << endl; +} + + + LiftedKc::~LiftedKc (void) { delete lwcnf_; diff --git a/packages/CLPBN/horus/LiftedKc.h b/packages/CLPBN/horus/LiftedKc.h index cba6499e1..a4cd2dbeb 100644 --- a/packages/CLPBN/horus/LiftedKc.h +++ b/packages/CLPBN/horus/LiftedKc.h @@ -1,11 +1,281 @@ #ifndef HORUS_LIFTEDKC_H #define HORUS_LIFTEDKC_H + +#include "LiftedWCNF.h" #include "LiftedSolver.h" #include "ParfactorList.h" -class LiftedWCNF; -class LiftedCircuit; + +enum CircuitNodeType { + OR_NODE, + AND_NODE, + SET_OR_NODE, + SET_AND_NODE, + INC_EXC_NODE, + LEAF_NODE, + SMOOTH_NODE, + TRUE_NODE, + COMPILATION_FAILED_NODE +}; + + + +class CircuitNode +{ + public: + CircuitNode (void) { } + + virtual ~CircuitNode (void) { } + + virtual double weight (void) const = 0; +}; + + + +class OrNode : public CircuitNode +{ + public: + OrNode (void) : CircuitNode(), leftBranch_(0), rightBranch_(0) { } + + ~OrNode (void); + + CircuitNode** leftBranch (void) { return &leftBranch_; } + CircuitNode** rightBranch (void) { return &rightBranch_; } + + double weight (void) const; + + private: + CircuitNode* leftBranch_; + CircuitNode* rightBranch_; +}; + + + +class AndNode : public CircuitNode +{ + public: + AndNode (void) : CircuitNode(), leftBranch_(0), rightBranch_(0) { } + + AndNode (CircuitNode* leftBranch, CircuitNode* rightBranch) + : CircuitNode(), leftBranch_(leftBranch), rightBranch_(rightBranch) { } + + ~AndNode (void); + + CircuitNode** leftBranch (void) { return &leftBranch_; } + CircuitNode** rightBranch (void) { return &rightBranch_; } + + double weight (void) const; + + private: + CircuitNode* leftBranch_; + CircuitNode* rightBranch_; +}; + + + +class SetOrNode : public CircuitNode +{ + public: + SetOrNode (unsigned nrGroundings) + : CircuitNode(), follow_(0), nrGroundings_(nrGroundings) { } + + ~SetOrNode (void); + + CircuitNode** follow (void) { return &follow_; } + + static unsigned nrPositives (void) { return nrPos_; } + + static unsigned nrNegatives (void) { return nrNeg_; } + + static bool isSet (void) { return nrPos_ >= 0; } + + double weight (void) const; + + private: + CircuitNode* follow_; + unsigned nrGroundings_; + static int nrPos_; + static int nrNeg_; +}; + + + +class SetAndNode : public CircuitNode +{ + public: + SetAndNode (unsigned nrGroundings) + : CircuitNode(), follow_(0), nrGroundings_(nrGroundings) { } + + ~SetAndNode (void); + + CircuitNode** follow (void) { return &follow_; } + + double weight (void) const; + + private: + CircuitNode* follow_; + unsigned nrGroundings_; +}; + + + +class IncExcNode : public CircuitNode +{ + public: + IncExcNode (void) + : CircuitNode(), plus1Branch_(0), plus2Branch_(0), minusBranch_(0) { } + + ~IncExcNode (void); + + CircuitNode** plus1Branch (void) { return &plus1Branch_; } + CircuitNode** plus2Branch (void) { return &plus2Branch_; } + CircuitNode** minusBranch (void) { return &minusBranch_; } + + double weight (void) const; + + private: + CircuitNode* plus1Branch_; + CircuitNode* plus2Branch_; + CircuitNode* minusBranch_; +}; + + + +class LeafNode : public CircuitNode +{ + public: + LeafNode (Clause* clause, const LiftedWCNF& lwcnf) + : CircuitNode(), clause_(clause), lwcnf_(lwcnf) { } + + ~LeafNode (void); + + const Clause* clause (void) const { return clause_; } + + Clause* clause (void) { return clause_; } + + double weight (void) const; + + private: + Clause* clause_; + const LiftedWCNF& lwcnf_; +}; + + + +class SmoothNode : public CircuitNode +{ + public: + SmoothNode (const Clauses& clauses, const LiftedWCNF& lwcnf) + : CircuitNode(), clauses_(clauses), lwcnf_(lwcnf) { } + + ~SmoothNode (void); + + const Clauses& clauses (void) const { return clauses_; } + + Clauses clauses (void) { return clauses_; } + + double weight (void) const; + + private: + Clauses clauses_; + const LiftedWCNF& lwcnf_; +}; + + + +class TrueNode : public CircuitNode +{ + public: + TrueNode (void) : CircuitNode() { } + + double weight (void) const; +}; + + + +class CompilationFailedNode : public CircuitNode +{ + public: + CompilationFailedNode (void) : CircuitNode() { } + + double weight (void) const; +}; + + + +class LiftedCircuit +{ + public: + LiftedCircuit (const LiftedWCNF* lwcnf); + + ~LiftedCircuit (void); + + bool isCompilationSucceeded (void) const; + + double getWeightedModelCount (void) const; + + void exportToGraphViz (const char*); + + private: + + void compile (CircuitNode** follow, Clauses& clauses); + + bool tryUnitPropagation (CircuitNode** follow, Clauses& clauses); + + bool tryIndependence (CircuitNode** follow, Clauses& clauses); + + bool tryShannonDecomp (CircuitNode** follow, Clauses& clauses); + + bool tryInclusionExclusion (CircuitNode** follow, Clauses& clauses); + + bool tryIndepPartialGrounding (CircuitNode** follow, Clauses& clauses); + + bool tryIndepPartialGroundingAux (Clauses& clauses, ConstraintTree& ct, + LogVars& rootLogVars); + + bool tryAtomCounting (CircuitNode** follow, Clauses& clauses); + + void shatterCountedLogVars (Clauses& clauses); + + bool shatterCountedLogVarsAux (Clauses& clauses); + + bool shatterCountedLogVarsAux (Clauses& clauses, size_t idx1, size_t idx2); + + bool independentClause (Clause& clause, Clauses& otherClauses) const; + + bool independentLiteral (const Literal& lit, + const Literals& otherLits) const; + + LitLvTypesSet smoothCircuit (CircuitNode* node); + + void createSmoothNode (const LitLvTypesSet& lids, + CircuitNode** prev); + + vector getAllPossibleTypes (unsigned nrLogVars) const; + + bool containsTypes (const LogVarTypes& typesA, + const LogVarTypes& typesB) const; + + CircuitNodeType getCircuitNodeType (const CircuitNode* node) const; + + void exportToGraphViz (CircuitNode* node, ofstream&); + + void printClauses (CircuitNode* node, ofstream&, + string extraOptions = ""); + + string escapeNode (const CircuitNode* node) const; + + string getExplanationString (CircuitNode* node); + + CircuitNode* root_; + const LiftedWCNF* lwcnf_; + bool compilationSucceeded_; + Clauses backupClauses_; + unordered_map originClausesMap_; + unordered_map explanationMap_; +}; + class LiftedKc : public LiftedSolver diff --git a/packages/CLPBN/horus/Makefile.in b/packages/CLPBN/horus/Makefile.in index d19803ee7..24e7d0b87 100644 --- a/packages/CLPBN/horus/Makefile.in +++ b/packages/CLPBN/horus/Makefile.in @@ -23,10 +23,10 @@ CC=@CC@ CXX=@CXX@ # normal -#CXXFLAGS= -std=c++0x @SHLIB_CXXFLAGS@ $(YAP_EXTRAS) $(DEFS) -D_YAP_NOT_INSTALLED_=1 -I$(srcdir) -I../../.. -I$(srcdir)/../../../include @CPPFLAGS@ -DNDEBUG +CXXFLAGS= -std=c++0x @SHLIB_CXXFLAGS@ $(YAP_EXTRAS) $(DEFS) -D_YAP_NOT_INSTALLED_=1 -I$(srcdir) -I../../.. -I$(srcdir)/../../../include @CPPFLAGS@ -DNDEBUG # debug -CXXFLAGS= -std=c++0x @SHLIB_CXXFLAGS@ $(YAP_EXTRAS) $(DEFS) -D_YAP_NOT_INSTALLED_=1 -I$(srcdir) -I../../.. -I$(srcdir)/../../../include @CPPFLAGS@ -g -O0 -Wextra +#CXXFLAGS= -std=c++0x @SHLIB_CXXFLAGS@ $(YAP_EXTRAS) $(DEFS) -D_YAP_NOT_INSTALLED_=1 -I$(srcdir) -I../../.. -I$(srcdir)/../../../include @CPPFLAGS@ -g -O0 -Wextra # @@ -57,7 +57,6 @@ HEADERS = \ $(srcdir)/Horus.h \ $(srcdir)/Indexer.h \ $(srcdir)/LiftedBp.h \ - $(srcdir)/LiftedCircuit.h \ $(srcdir)/LiftedKc.h \ $(srcdir)/LiftedOperations.h \ $(srcdir)/LiftedSolver.h \ @@ -87,7 +86,6 @@ CPP_SOURCES = \ $(srcdir)/HorusCli.cpp \ $(srcdir)/HorusYap.cpp \ $(srcdir)/LiftedBp.cpp \ - $(srcdir)/LiftedCircuit.cpp \ $(srcdir)/LiftedKc.cpp \ $(srcdir)/LiftedOperations.cpp \ $(srcdir)/LiftedUtils.cpp \ @@ -114,7 +112,6 @@ OBJS = \ Histogram.o \ HorusYap.o \ LiftedBp.o \ - LiftedCircuit.o \ LiftedKc.o \ LiftedOperations.o \ LiftedUtils.o \