diff --git a/packages/CLPBN/horus/ConstraintTree.cpp b/packages/CLPBN/horus/ConstraintTree.cpp index 59705e1e1..bfabc982c 100644 --- a/packages/CLPBN/horus/ConstraintTree.cpp +++ b/packages/CLPBN/horus/ConstraintTree.cpp @@ -112,6 +112,8 @@ CTNode::copySubtree (const CTNode* root1) const CTNode* n1 = stack.back().first; CTNode* n2 = stack.back().second; stack.pop_back(); + // cout << "n2 childs: " << n2->childs(); + // cout << "n1 childs: " << n1->childs(); n2->childs().reserve (n1->nrChilds()); stack.reserve (n1->nrChilds()); for (CTChilds::const_iterator chIt = n1->childs().begin(); @@ -185,6 +187,28 @@ ConstraintTree::ConstraintTree ( +ConstraintTree::ConstraintTree (vector> names) +{ + assert (names.empty() == false); + assert (names.front().empty() == false); + unsigned nrLvs = names[0].size(); + for (size_t i = 0; i < nrLvs; i++) { + logVars_.push_back (LogVar (i)); + } + root_ = new CTNode (0, 0); + logVarSet_ = LogVarSet (logVars_); + for (size_t i = 0; i < names.size(); i++) { + Tuple t; + for (size_t j = 0; j < names[i].size(); j++) { + assert (names[i].size() == nrLvs); + t.push_back (LiftedUtils::getSymbol (names[i][j])); + } + addTuple (t); + } +} + + + ConstraintTree::ConstraintTree (const ConstraintTree& ct) { *this = ct; diff --git a/packages/CLPBN/horus/ConstraintTree.h b/packages/CLPBN/horus/ConstraintTree.h index 071a96a5e..0b48c3650 100644 --- a/packages/CLPBN/horus/ConstraintTree.h +++ b/packages/CLPBN/horus/ConstraintTree.h @@ -108,6 +108,8 @@ class ConstraintTree ConstraintTree (const LogVars&); ConstraintTree (const LogVars&, const Tuples&); + + ConstraintTree (vector> names); ConstraintTree (const ConstraintTree&); diff --git a/packages/CLPBN/horus/LiftedCircuit.cpp b/packages/CLPBN/horus/LiftedCircuit.cpp index e31f13c79..baa058170 100644 --- a/packages/CLPBN/horus/LiftedCircuit.cpp +++ b/packages/CLPBN/horus/LiftedCircuit.cpp @@ -115,7 +115,7 @@ CompilationFailedNode::weight (void) const { // we should not perform model counting // in compilation failed nodes - abort(); + // abort(); return 0.0; } @@ -125,19 +125,11 @@ LiftedCircuit::LiftedCircuit (const LiftedWCNF* lwcnf) : lwcnf_(lwcnf) { root_ = 0; - Clauses ccc = lwcnf->clauses(); - //ccc.erase (ccc.begin() + 5, ccc.end()); - //Clause c2 = ccc.front(); - //c2.removeLiteralByIndex (1); - //ccc.push_back (c2); - - //compile (&root_, lwcnf->clauses()); - Clauses cccc = {ccc[6],ccc[4]}; - cccc.front().removeLiteral (2); - compile (&root_, cccc); + Clauses clauses = lwcnf->clauses(); + compile (&root_, clauses); exportToGraphViz("circuit.dot"); smoothCircuit(); - exportToGraphViz("smooth.dot"); + exportToGraphViz("circuit.smooth.dot"); cout << "WEIGHTED MODEL COUNT = " << getWeightedModelCount() << endl; } @@ -188,18 +180,7 @@ LiftedCircuit::compile ( } if (clauses.size() == 1 && clauses[0].isUnit()) { - static int count = 0; count ++; *follow = new LeafNode (clauses[0]); - if (count == 1) { - // Clause c (new ConstraintTree({})); - // c.addLiteral (Literal (100,{})); - // *follow = new LeafNode (c); - } - if (count == 2) { - // Clause c (new ConstraintTree({})); - // c.addLiteral (Literal (101,{})); - // *follow = new LeafNode (c); - } return; } @@ -219,7 +200,11 @@ LiftedCircuit::compile ( return; } - if (tryIndepPartialGrounding (follow, clauses)) { + //if (tryIndepPartialGrounding (follow, clauses)) { + // return; + //} + + if (tryAtomCounting (follow, clauses)) { return; } @@ -229,7 +214,81 @@ LiftedCircuit::compile ( // assert (false); *follow = new CompilationFailedNode (clauses); +} + +void +LiftedCircuit::propagate ( + const Clause& c, + const Clause& unitClause, + Clauses& newClauses) +{ +/* + Literals literals = c.literals(); + for (size_t i = 0; i < literals.size(); i++) { + if (literals_[i].lid() == lid && literals[i].isPositive()) { + + return true; + } + } +*/ +} + + +bool +shatterCountedLogVars (Clauses& clauses, size_t idx1, size_t idx2) +{ + Literals lits1 = clauses[idx1].literals(); + Literals lits2 = clauses[idx2].literals(); + for (size_t i = 0; i < lits1.size(); i++) { + for (size_t j = 0; j < lits2.size(); j++) { + if (lits1[i].lid() == lits2[j].lid()) { + LogVars lvs1 = lits1[i].logVars(); + LogVars lvs2 = lits2[j].logVars(); + for (size_t k = 0; k < lvs1.size(); k++) { + if (clauses[idx1].isCountedLogVar (lvs1[k]) + && clauses[idx2].isCountedLogVar (lvs2[k]) == false) { + clauses.push_back (clauses[idx2]); + clauses[idx2].addPositiveCountedLogVar (lvs2[k]); + clauses.back().addNegativeCountedLogVar (lvs2[k]); + return true; + } + if (clauses[idx2].isCountedLogVar (lvs2[k]) + && clauses[idx1].isCountedLogVar (lvs1[k]) == false) { + clauses.push_back (clauses[idx1]); + clauses[idx1].addPositiveCountedLogVar (lvs1[k]); + clauses.back().addNegativeCountedLogVar (lvs1[k]); + return true; + } + } + } + } + } + return false; +} + + + +bool +shatterCountedLogVarsAux (Clauses& clauses) +{ + for (size_t i = 0; i < clauses.size() - 1; i++) { + for (size_t j = i + 1; j < clauses.size(); j++) { + bool splitedSome = shatterCountedLogVars (clauses, i, j); + if (splitedSome) { + return true; + } + } + } + return false; +} + + + +void +shatterCountedLogVars (Clauses& clauses) +{ + while (shatterCountedLogVarsAux (clauses)) ; } @@ -239,22 +298,29 @@ LiftedCircuit::tryUnitPropagation ( CircuitNode** follow, Clauses& clauses) { + cout << "ALL CLAUSES:" << endl; + Clause::printClauses (clauses); + for (size_t i = 0; i < clauses.size(); i++) { if (clauses[i].isUnit()) { + cout << clauses[i] << " is unit!" << endl; Clauses newClauses; for (size_t j = 0; j < clauses.size(); j++) { if (i != j) { LiteralId lid = clauses[i].literals()[0].lid(); + LogVarTypes types = clauses[i].logVarTypes (0); if (clauses[i].literals()[0].isPositive()) { - if (clauses[j].containsPositiveLiteral (lid) == false) { + if (clauses[j].containsPositiveLiteral (lid, types) == false) { Clause newClause = clauses[j]; - newClause.removeNegativeLiterals (lid); + cout << "removing negative literals on " << newClause << endl; + newClause.removeNegativeLiterals (lid, types); newClauses.push_back (newClause); } } else if (clauses[i].literals()[0].isNegative()) { - if (clauses[j].containsNegativeLiteral (lid) == false) { + if (clauses[j].containsNegativeLiteral (lid, types) == false) { Clause newClause = clauses[j]; - newClause.removePositiveLiterals (lid); + cout << "removing negative literals on " << newClause << endl; + newClause.removePositiveLiterals (lid, types); newClauses.push_back (newClause); } } @@ -264,6 +330,7 @@ LiftedCircuit::tryUnitPropagation ( explanation << " UP on" << clauses[i].literals()[0]; AndNode* andNode = new AndNode (clauses, explanation.str()); Clauses leftClauses = {clauses[i]}; + cout << "new clauses: " << newClauses << endl; compile (andNode->leftBranch(), leftClauses); compile (andNode->rightBranch(), newClauses); (*follow) = andNode; @@ -454,6 +521,37 @@ LiftedCircuit::tryIndepPartialGroundingAux ( +bool +LiftedCircuit::tryAtomCounting ( + CircuitNode** follow, + Clauses& clauses) +{ + for (size_t i = 0; i < clauses.size(); i++) { + Literals literals = clauses[i].literals(); + for (size_t j = 0; j < literals.size(); j++) { + if (literals[j].logVars().size() == 1) { + // TODO check if not already in ipg and countedlvs + SetOrNode* setOrNode = new SetOrNode (clauses); + Clause c1 (clauses[i].constr().projectedCopy (literals[j].logVars())); + Clause c2 (clauses[i].constr().projectedCopy (literals[j].logVars())); + c1.addLiteral (literals[j]); + c2.addAndNegateLiteral (literals[j]); + c1.addPositiveCountedLogVar (literals[j].logVars().front()); + c2.addNegativeCountedLogVar (literals[j].logVars().front()); + clauses.push_back (c1); + clauses.push_back (c2); + shatterCountedLogVars (clauses); + compile (setOrNode->follow(), clauses); + *follow = setOrNode; + return true; + } + } + } + return false; +} + + + bool LiftedCircuit::tryGrounding ( CircuitNode**, @@ -638,11 +736,11 @@ LiftedCircuit::exportToGraphViz (CircuitNode* node, ofstream& os) exportToGraphViz (*casted->rightBranch(), os); break; } - + case AND_NODE: { AndNode* casted = dynamic_cast(node); printClauses (casted, os); - + os << auxNode << " [label=\"∧\"]" << endl; os << escapeNode (node) << " -> " << auxNode; os << " [label=\"" << node->explanation() << "\"]" ; @@ -662,34 +760,47 @@ LiftedCircuit::exportToGraphViz (CircuitNode* node, ofstream& os) exportToGraphViz (*casted->rightBranch(), os); break; } - + case SET_OR_NODE: { - // TODO - assert (false); - } - - case SET_AND_NODE: { - SetAndNode* casted = dynamic_cast(node); + SetOrNode* casted = dynamic_cast(node); printClauses (casted, os); - - os << auxNode << " [label=\"∧(X)\"]" << endl; + + os << auxNode << " [label=\"∨(X)\"]" << endl; os << escapeNode (node) << " -> " << auxNode; os << " [label=\"" << node->explanation() << "\"]" ; os << endl; - + os << auxNode << " -> " ; os << escapeNode (*casted->follow()); os << " [label=\" " << (*casted->follow())->weight() << "\"]" ; os << endl; - + exportToGraphViz (*casted->follow(), os); break; } - + + case SET_AND_NODE: { + SetAndNode* casted = dynamic_cast(node); + printClauses (casted, os); + + os << auxNode << " [label=\"∧(X)\"]" << endl; + os << escapeNode (node) << " -> " << auxNode; + os << " [label=\"" << node->explanation() << "\"]" ; + os << endl; + + os << auxNode << " -> " ; + os << escapeNode (*casted->follow()); + os << " [label=\" " << (*casted->follow())->weight() << "\"]" ; + os << endl; + + exportToGraphViz (*casted->follow(), os); + break; + } + case INC_EXC_NODE: { IncExcNode* casted = dynamic_cast(node); printClauses (casted, os); - + os << auxNode << " [label=\"IncExc\"]" << endl; os << escapeNode (node) << " -> " << auxNode; os << " [label=\"" << node->explanation() << "\"]" ; @@ -699,7 +810,7 @@ LiftedCircuit::exportToGraphViz (CircuitNode* node, ofstream& os) os << escapeNode (*casted->plus1Branch()); os << " [label=\" " << (*casted->plus1Branch())->weight() << "\"]" ; os << endl; - + os << auxNode << " -> " ; os << escapeNode (*casted->plus2Branch()); os << " [label=\" " << (*casted->plus2Branch())->weight() << "\"]" ; @@ -709,35 +820,35 @@ LiftedCircuit::exportToGraphViz (CircuitNode* node, ofstream& os) os << escapeNode (*casted->minusBranch()) << endl; os << " [label=\" " << (*casted->minusBranch())->weight() << "\"]" ; os << endl; - + exportToGraphViz (*casted->plus1Branch(), os); exportToGraphViz (*casted->plus2Branch(), os); exportToGraphViz (*casted->minusBranch(), os); break; } - + case LEAF_NODE: { printClauses (node, os, "style=filled,fillcolor=palegreen,"); break; } - + case SMOOTH_NODE: { printClauses (node, os, "style=filled,fillcolor=lightblue,"); break; } - + case TRUE_NODE: { os << escapeNode (node); os << " [shape=box,label=\"⊤\"]" ; os << endl; break; } - + case COMPILATION_FAILED_NODE: { printClauses (node, os, "style=filled,fillcolor=salmon,"); break; } - + default: assert (false); } diff --git a/packages/CLPBN/horus/LiftedCircuit.h b/packages/CLPBN/horus/LiftedCircuit.h index f7db3d334..a89821c84 100644 --- a/packages/CLPBN/horus/LiftedCircuit.h +++ b/packages/CLPBN/horus/LiftedCircuit.h @@ -186,7 +186,7 @@ class CompilationFailedNode : public CircuitNode class LiftedCircuit { public: - LiftedCircuit (const LiftedWCNF* lwcnf); + LiftedCircuit (const LiftedWCNF* lwcnf); void smoothCircuit (void); @@ -205,8 +205,11 @@ class LiftedCircuit bool tryIndepPartialGrounding (CircuitNode** follow, Clauses& clauses); bool tryIndepPartialGroundingAux (Clauses& clauses, ConstraintTree& ct, vector& indices); + bool tryAtomCounting (CircuitNode** follow, Clauses& clauses); bool tryGrounding (CircuitNode** follow, Clauses& clauses); - + + void propagate (const Clause& c, const Clause& uc, Clauses& newClauses); + TinySet smoothCircuit (CircuitNode* node); void createSmoothNode (const TinySet& lids, diff --git a/packages/CLPBN/horus/LiftedWCNF.cpp b/packages/CLPBN/horus/LiftedWCNF.cpp index d69f4bd01..29affb2ac 100644 --- a/packages/CLPBN/horus/LiftedWCNF.cpp +++ b/packages/CLPBN/horus/LiftedWCNF.cpp @@ -17,7 +17,10 @@ Literal::isGround (ConstraintTree constr, LogVarSet ipgLogVars) const string -Literal::toString (LogVarSet ipgLogVars) const +Literal::toString ( + LogVarSet ipgLogVars, + LogVarSet posCountedLvs, + LogVarSet negCountedLvs) const { stringstream ss; negated_ ? ss << "¬" : ss << "" ; @@ -27,10 +30,14 @@ Literal::toString (LogVarSet ipgLogVars) const ss << "(" ; for (size_t i = 0; i < logVars_.size(); i++) { if (i != 0) ss << ","; - if (ipgLogVars.contains (logVars_[i])) { + if (posCountedLvs.contains (logVars_[i])) { + ss << "+" << logVars_[i]; + } else if (negCountedLvs.contains (logVars_[i])) { + ss << "-" << logVars_[i]; + } else if (ipgLogVars.contains (logVars_[i])) { LogVar X = logVars_[i]; const string labels[] = { - "a", "b", "c", "d", "e", "f", + "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "m" }; (X >= 12) ? ss << "x_" << X : ss << labels[X]; } else { @@ -66,10 +73,14 @@ Clause::containsLiteral (LiteralId lid) const bool -Clause::containsPositiveLiteral (LiteralId lid) const +Clause::containsPositiveLiteral ( + LiteralId lid, + const LogVarTypes& types) const { for (size_t i = 0; i < literals_.size(); i++) { - if (literals_[i].lid() == lid && literals_[i].isPositive()) { + if (literals_[i].lid() == lid + && literals_[i].isPositive() + && logVarTypes (i) == types) { return true; } } @@ -79,10 +90,14 @@ Clause::containsPositiveLiteral (LiteralId lid) const bool -Clause::containsNegativeLiteral (LiteralId lid) const +Clause::containsNegativeLiteral ( + LiteralId lid, + const LogVarTypes& types) const { for (size_t i = 0; i < literals_.size(); i++) { - if (literals_[i].lid() == lid && literals_[i].isNegative()) { + if (literals_[i].lid() == lid + && literals_[i].isNegative() + && logVarTypes (i) == types) { return true; } } @@ -107,11 +122,15 @@ Clause::removeLiterals (LiteralId lid) void -Clause::removePositiveLiterals (LiteralId lid) +Clause::removePositiveLiterals ( + LiteralId lid, + const LogVarTypes& types) { size_t i = 0; while (i != literals_.size()) { - if (literals_[i].lid() == lid && literals_[i].isPositive()) { + if (literals_[i].lid() == lid + && literals_[i].isPositive() + && logVarTypes (i) == types) { removeLiteral (i); } else { i ++; @@ -122,11 +141,15 @@ Clause::removePositiveLiterals (LiteralId lid) void -Clause::removeNegativeLiterals (LiteralId lid) +Clause::removeNegativeLiterals ( + LiteralId lid, + const LogVarTypes& types) { size_t i = 0; while (i != literals_.size()) { - if (literals_[i].lid() == lid && literals_[i].isNegative()) { + if (literals_[i].lid() == lid + && literals_[i].isNegative() + && logVarTypes (i) == types) { removeLiteral (i); } else { i ++; @@ -171,6 +194,34 @@ Clause::lidSet (void) const +bool +Clause::isCountedLogVar (LogVar X) const +{ + assert (constr_.logVarSet().contains (X)); + return posCountedLvs_.contains (X) + || negCountedLvs_.contains (X); +} + + + +bool +Clause::isPositiveCountedLogVar (LogVar X) const +{ + assert (constr_.logVarSet().contains (X)); + return posCountedLvs_.contains (X); +} + + + +bool +Clause::isNegativeCountedLogVar (LogVar X) const +{ + assert (constr_.logVarSet().contains (X)); + return negCountedLvs_.contains (X); +} + + + void Clause::removeLiteral (size_t idx) { @@ -182,6 +233,35 @@ Clause::removeLiteral (size_t idx) +LogVarTypes +Clause::logVarTypes (size_t litIdx) const +{ + LogVarTypes types; + const LogVars lvs = literals_[litIdx].logVars(); + for (size_t i = 0; i < lvs.size(); i++) { + if (posCountedLvs_.contains (lvs[i])) { + types.push_back (LogVarType::POS_LV); + } else if (negCountedLvs_.contains (lvs[i])) { + types.push_back (LogVarType::NEG_LV); + } else { + types.push_back (LogVarType::FULL_LV); + } + } + return types; +} + + + +void +Clause::printClauses (const Clauses& clauses) +{ + for (size_t i = 0; i < clauses.size(); i++) { + cout << clauses[i] << endl; + } +} + + + LogVarSet Clause::getLogVarSetExcluding (size_t idx) const { @@ -200,7 +280,8 @@ std::ostream& operator<< (ostream &os, const Clause& clause) { for (unsigned i = 0; i < clause.literals_.size(); i++) { if (i != 0) os << " v " ; - os << clause.literals_[i].toString (clause.ipgLogVars_); + os << clause.literals_[i].toString (clause.ipgLogVars_, + clause.posCountedLvs_, clause.negCountedLvs_); } if (clause.constr_.empty() == false) { ConstraintTree copy (clause.constr_); @@ -215,10 +296,23 @@ std::ostream& operator<< (ostream &os, const Clause& clause) LiftedWCNF::LiftedWCNF (const ParfactorList& pfList) : pfList_(pfList), freeLiteralId_(0) { - addIndicatorClauses (pfList); - addParameterClauses (pfList); + //addIndicatorClauses (pfList); + //addParameterClauses (pfList); + + vector> names = {{"p1","p1"},{"p2","p2"}}; + + Clause c1 (names); + c1.addLiteral (Literal (0, LogVars()={0})); + c1.addAndNegateLiteral (Literal (1, {0,1})); + clauses_.push_back(c1); + + Clause c2 (names); + c2.addLiteral (Literal (0, LogVars()={0})); + c2.addAndNegateLiteral (Literal (1, {1,0})); + clauses_.push_back(c2); + cout << "FORMULA INDICATORS:" << endl; - printFormulaIndicators(); + // printFormulaIndicators(); cout << endl; cout << "WEIGHTS:" << endl; printWeights(); diff --git a/packages/CLPBN/horus/LiftedWCNF.h b/packages/CLPBN/horus/LiftedWCNF.h index 89ebec029..079c07a11 100644 --- a/packages/CLPBN/horus/LiftedWCNF.h +++ b/packages/CLPBN/horus/LiftedWCNF.h @@ -10,6 +10,16 @@ typedef long LiteralId; class ConstraintTree; +enum LogVarType +{ + FULL_LV, + POS_LV, + NEG_LV +}; + + +typedef vector LogVarTypes; + class Literal { public: @@ -39,7 +49,9 @@ class Literal bool isGround (ConstraintTree constr, LogVarSet ipgLogVars) const; - string toString (LogVarSet ipgLogVars = LogVarSet()) const; + string toString (LogVarSet ipgLogVars = LogVarSet(), + LogVarSet posCountedLvs = LogVarSet(), + LogVarSet negCountedLvs = LogVarSet()) const; friend std::ostream& operator<< (ostream &os, const Literal& lit); @@ -58,6 +70,8 @@ class Clause public: Clause (const ConstraintTree& ct) : constr_(ct) { } + Clause (vector> names) : constr_(ConstraintTree (names)) { } + void addLiteral (const Literal& l) { literals_.push_back (l); } void addAndNegateLiteral (const Literal& l) @@ -68,15 +82,15 @@ class Clause bool containsLiteral (LiteralId lid) const; - bool containsPositiveLiteral (LiteralId lid) const; + bool containsPositiveLiteral (LiteralId lid, const LogVarTypes&) const; - bool containsNegativeLiteral (LiteralId lid) const; + bool containsNegativeLiteral (LiteralId lid, const LogVarTypes&) const; void removeLiterals (LiteralId lid); - void removePositiveLiterals (LiteralId lid); + void removePositiveLiterals (LiteralId lid, const LogVarTypes&); - void removeNegativeLiterals (LiteralId lid); + void removeNegativeLiterals (LiteralId lid, const LogVarTypes&); const vector& literals (void) const { return literals_; } @@ -92,20 +106,36 @@ class Clause void addIpgLogVar (LogVar X) { ipgLogVars_.insert (X); } + void addPositiveCountedLogVar (LogVar X) { posCountedLvs_.insert (X); } + + void addNegativeCountedLogVar (LogVar X) { negCountedLvs_.insert (X); } + LogVarSet ipgCandidates (void) const; TinySet lidSet (void) const; + bool isCountedLogVar (LogVar X) const; + + bool isPositiveCountedLogVar (LogVar X) const; + + bool isNegativeCountedLogVar (LogVar X) const; + friend std::ostream& operator<< (ostream &os, const Clause& clause); void removeLiteral (size_t idx); + LogVarTypes logVarTypes (size_t litIdx) const; + + static void printClauses (const vector& clauses); + private: LogVarSet getLogVarSetExcluding (size_t idx) const; vector literals_; LogVarSet ipgLogVars_; + LogVarSet posCountedLvs_; + LogVarSet negCountedLvs_; ConstraintTree constr_; };