diff --git a/packages/CLPBN/horus/LiftedCircuit.cpp b/packages/CLPBN/horus/LiftedCircuit.cpp index 0b1ebcdf1..9586c7bbb 100644 --- a/packages/CLPBN/horus/LiftedCircuit.cpp +++ b/packages/CLPBN/horus/LiftedCircuit.cpp @@ -8,7 +8,6 @@ OrNode::weight (void) const { double lw = leftBranch_->weight(); double rw = rightBranch_->weight(); - cout << ">>>> OR NODE res = " << lw << " + " << rw << endl; return Globals::logDomain ? Util::logSum (lw, rw) : lw + rw; } @@ -28,7 +27,6 @@ double SetOrNode::weight (void) const { // TODO - assert (false); return 0.0; } @@ -38,11 +36,10 @@ SetOrNode::weight (void) const double SetAndNode::weight (void) const { - unsigned nrGroundings = 2; // FIXME double w = follow_->weight(); return Globals::logDomain - ? w * nrGroundings - : std::pow (w, nrGroundings); + ? w * nrGroundings_ + : std::pow (w, nrGroundings_); } @@ -108,7 +105,9 @@ LiftedCircuit::LiftedCircuit (const LiftedWCNF* lwcnf) //ccc.push_back (c2); //compile (&root_, lwcnf->clauses()); - compile (&root_, ccc); + Clauses cccc = {ccc[6],ccc[4]}; + cccc.front().removeLiteral (2); + compile (&root_, cccc); exportToGraphViz("circuit.dot"); smoothCircuit(); exportToGraphViz("smooth.dot"); @@ -189,6 +188,10 @@ LiftedCircuit::compile ( return; } + if (tryInclusionExclusion (follow, clauses)) { + return; + } + if (tryIndepPartialGrounding (follow, clauses)) { return; } @@ -309,9 +312,9 @@ LiftedCircuit::tryShannonDecomp ( stringstream explanation; explanation << " SD on " << literals[j]; OrNode* orNode = new OrNode (clauses, explanation.str()); - (*follow) = orNode; compile (orNode->leftBranch(), leftClauses); compile (orNode->rightBranch(), rightClauses); + (*follow) = orNode; return true; } } @@ -321,6 +324,56 @@ LiftedCircuit::tryShannonDecomp ( +bool +LiftedCircuit::tryInclusionExclusion ( + CircuitNode** follow, + Clauses& clauses) +{ + for (size_t i = 0; i < clauses.size(); i++) { + const Literals& literals = clauses[i].literals(); + for (size_t j = 0; j < literals.size(); j++) { + bool indep = true; + for (size_t k = 0; k < literals.size(); k++) { + LogVarSet intersect = literals[j].logVarSet() + & literals[k].logVarSet(); + if (j != k && intersect.empty() == false) { + indep = false; + break; + } + } + if (indep) { + // TODO i am almost sure that this will + // have to be count normalized too! + ConstraintTree really = clauses[i].constr(); + Clause c1 (really.projectedCopy ( + literals[j].logVars())); + c1.addLiteral (literals[j]); + Clause c2 = clauses[i]; + c2.removeLiteral (j); + Clauses plus1Clauses = clauses; + Clauses plus2Clauses = clauses; + Clauses minusClauses = clauses; + plus1Clauses.erase (plus1Clauses.begin() + i); + plus2Clauses.erase (plus2Clauses.begin() + i); + minusClauses.erase (minusClauses.begin() + i); + plus1Clauses.push_back (c1); + plus2Clauses.push_back (c2); + minusClauses.push_back (c1); + minusClauses.push_back (c2); + IncExcNode* ieNode = new IncExcNode (clauses); + compile (ieNode->plus1Branch(), plus1Clauses); + compile (ieNode->plus2Branch(), plus2Clauses); + compile (ieNode->minusBranch(), minusClauses); + *follow = ieNode; + return true; + } + } + } + return false; +} + + + bool LiftedCircuit::tryIndepPartialGrounding ( CircuitNode** follow, @@ -328,7 +381,6 @@ LiftedCircuit::tryIndepPartialGrounding ( { // assumes that all literals have logical variables // else, shannon decomp was possible - vector lvIndices; LogVarSet lvs = clauses[0].ipgCandidates(); for (size_t i = 0; i < lvs.size(); i++) { @@ -503,9 +555,13 @@ LiftedCircuit::getCircuitNodeType (const CircuitNode* node) const } else if (dynamic_cast(node) != 0) { type = CircuitNodeType::AND_NODE; } else if (dynamic_cast(node) != 0) { + // TODO + assert (false); type = CircuitNodeType::SET_OR_NODE; } else if (dynamic_cast(node) != 0) { type = CircuitNodeType::SET_AND_NODE; + } else if (dynamic_cast(node) != 0) { + type = CircuitNodeType::INC_EXC_NODE; } else if (dynamic_cast(node) != 0) { type = CircuitNodeType::LEAF_NODE; } else if (dynamic_cast(node) != 0) { @@ -611,7 +667,8 @@ LiftedCircuit::exportToGraphViz (CircuitNode* node, ofstream& os) } case SET_OR_NODE: { - assert (false); // not yet implemented + // TODO + assert (false); } case SET_AND_NODE: { @@ -639,6 +696,43 @@ LiftedCircuit::exportToGraphViz (CircuitNode* node, ofstream& os) break; } + case INC_EXC_NODE: { + IncExcNode* casted = dynamic_cast(node); + const Clauses& clauses = node->clauses(); + os << escapeNode (node) << " [shape=box,label=\"" ; + for (size_t i = 0; i < clauses.size(); i++) { + if (i != 0) os << "\\n" ; + os << clauses[i]; + } + os << "\"]" ; + os << endl; + + os << auxNode << " [label=\"IncExc\"]" << endl; + os << escapeNode (node) << " -> " << auxNode; + os << " [label=\"" << node->explanation() << "\"]" ; + os << endl; + + os << auxNode << " -> " ; + os << escapeNode (*casted->plus1Branch()); + os << " [label=\" " << (*casted->plus1Branch())->weight() << "\"]" ; + os << endl; + + os << auxNode << " -> " ; + os << escapeNode (*casted->plus2Branch()); + os << " [label=\" " << (*casted->plus2Branch())->weight() << "\"]" ; + os << endl; + + os << auxNode << " -> " ; + os << escapeNode (*casted->minusBranch()) << endl; + os << " [label=\" " << (*casted->minusBranch())->weight() << "\"]" ; + os << endl; + + exportToGraphViz (*casted->plus1Branch(), os); + exportToGraphViz (*casted->plus2Branch(), os); + exportToGraphViz (*casted->minusBranch(), os); + break; + } + case LEAF_NODE: { os << escapeNode (node); os << " [shape=box,label=\"" ; diff --git a/packages/CLPBN/horus/LiftedCircuit.h b/packages/CLPBN/horus/LiftedCircuit.h index 0db93679c..98ce61086 100644 --- a/packages/CLPBN/horus/LiftedCircuit.h +++ b/packages/CLPBN/horus/LiftedCircuit.h @@ -109,7 +109,7 @@ class SetAndNode : public CircuitNode { public: SetAndNode (unsigned nrGroundings, const Clauses& clauses) - : CircuitNode (clauses, "IPG"), nrGroundings_(nrGroundings), + : CircuitNode (clauses, " IPG"), nrGroundings_(nrGroundings), follow_(0) { } double weight (void) const; @@ -122,13 +122,20 @@ class SetAndNode : public CircuitNode -class IncExclNode : public CircuitNode +class IncExcNode : public CircuitNode { public: + IncExcNode (const Clauses& clauses) + : CircuitNode (clauses), plus1Branch_(0), + plus2Branch_(0), minusBranch_(0) { } + + CircuitNode** plus1Branch (void) { return &plus1Branch_; } + CircuitNode** plus2Branch (void) { return &plus2Branch_; } + CircuitNode** minusBranch (void) { return &minusBranch_; } private: - CircuitNode* xFollow_; - CircuitNode* yFollow_; - CircuitNode* zFollow_; + CircuitNode* plus1Branch_; + CircuitNode* plus2Branch_; + CircuitNode* minusBranch_; }; @@ -156,7 +163,7 @@ class SmoothNode : public CircuitNode class TrueNode : public CircuitNode { public: - TrueNode () : CircuitNode ({}) { } + TrueNode (void) : CircuitNode ({}) { } double weight (void) const; }; @@ -190,6 +197,7 @@ class LiftedCircuit bool tryUnitPropagation (CircuitNode** follow, Clauses& clauses); bool tryIndependence (CircuitNode** follow, Clauses& clauses); bool tryShannonDecomp (CircuitNode** follow, Clauses& clauses); + bool tryInclusionExclusion (CircuitNode** follow, Clauses& clauses); bool tryIndepPartialGrounding (CircuitNode** follow, Clauses& clauses); bool tryIndepPartialGroundingAux (Clauses& clauses, ConstraintTree& ct, vector& indices); diff --git a/packages/CLPBN/horus/LiftedWCNF.cpp b/packages/CLPBN/horus/LiftedWCNF.cpp index 657016988..d69f4bd01 100644 --- a/packages/CLPBN/horus/LiftedWCNF.cpp +++ b/packages/CLPBN/horus/LiftedWCNF.cpp @@ -203,7 +203,7 @@ std::ostream& operator<< (ostream &os, const Clause& clause) os << clause.literals_[i].toString (clause.ipgLogVars_); } if (clause.constr_.empty() == false) { - ConstraintTree copy = clause.constr_; + ConstraintTree copy (clause.constr_); copy.moveToTop (copy.logVarSet().elements()); os << " | " << copy.tupleSet(); } diff --git a/packages/CLPBN/horus/LiftedWCNF.h b/packages/CLPBN/horus/LiftedWCNF.h index 564ec0fd4..89ebec029 100644 --- a/packages/CLPBN/horus/LiftedWCNF.h +++ b/packages/CLPBN/horus/LiftedWCNF.h @@ -25,8 +25,10 @@ class Literal LiteralId lid (void) const { return lid_; } LogVars logVars (void) const { return logVars_; } + + LogVarSet logVarSet (void) const { return LogVarSet (logVars_); } - // FIXME not log aware :( + // FIXME this is not log aware :( double weight (void) const { return weight_ < 0.0 ? 1.0 : weight_; } void negate (void) { negated_ = !negated_; } @@ -96,8 +98,9 @@ class Clause friend std::ostream& operator<< (ostream &os, const Clause& clause); - private: void removeLiteral (size_t idx); + + private: LogVarSet getLogVarSetExcluding (size_t idx) const;