diff --git a/packages/CLPBN/horus/LiftedCircuit.cpp b/packages/CLPBN/horus/LiftedCircuit.cpp index a805f806c..4332e5584 100644 --- a/packages/CLPBN/horus/LiftedCircuit.cpp +++ b/packages/CLPBN/horus/LiftedCircuit.cpp @@ -8,24 +8,32 @@ LiftedCircuit::LiftedCircuit (const LiftedWCNF* lwcnf) { root_ = 0; Clauses ccc = lwcnf->clauses(); - //ccc.erase (ccc.begin() + 3, ccc.end()); + ccc.erase (ccc.begin() + 5, ccc.end()); //Clause c2 = ccc.front(); //c2.removeLiteralByIndex (1); //ccc.push_back (c2); - //compile (&root_, lwcnf->clauses()); compile (&root_, ccc); cout << "done compiling..." << endl; - printToDot(); + exportToGraphViz("circuit.dot"); + smoothCircuit(); + exportToGraphViz("smooth.dot"); } void -LiftedCircuit::printToDot (void) +LiftedCircuit::smoothCircuit (void) +{ + smoothCircuit (root_); +} + + + +void +LiftedCircuit::exportToGraphViz (const char* fileName) { - const char* fileName = "circuit.dot" ; ofstream out (fileName); if (!out.is_open()) { cerr << "error: cannot open file to write at " ; @@ -34,7 +42,7 @@ LiftedCircuit::printToDot (void) } out << "digraph {" << endl; out << "ranksep=1" << endl; - printToDot (root_, out); + exportToGraphViz (root_, out); out << "}" << endl; out.close(); } @@ -52,7 +60,18 @@ LiftedCircuit::compile ( } if (clauses.size() == 1 && clauses[0].isUnit()) { + static int count = 0; count ++; *follow = new LeafNode (clauses[0]); + if (count == 1) { + Clause c (new ConstraintTree({})); + c.addLiteral (Literal (100,{})); + *follow = new LeafNode (c); + } + if (count == 2) { + Clause c (new ConstraintTree({})); + c.addLiteral (Literal (101,{})); + *follow = new LeafNode (c); + } return; } @@ -64,7 +83,7 @@ LiftedCircuit::compile ( return; } - if (tryShannonDecomposition (follow, clauses)) { + if (tryShannonDecomp (follow, clauses)) { return; } @@ -104,8 +123,8 @@ LiftedCircuit::tryUnitPropagation ( stringstream explanation; explanation << " UP of " << clauses[i]; AndNode* andNode = new AndNode (clauses, explanation.str()); - compile (andNode->leftFollow(), {clauses[i]}); - compile (andNode->rightFollow(), newClauses); + compile (andNode->leftBranch(), {clauses[i]}); + compile (andNode->rightBranch(), newClauses); (*follow) = andNode; return true; } @@ -139,8 +158,8 @@ LiftedCircuit::tryIndependence ( stringstream explanation; explanation << " independence" ; AndNode* andNode = new AndNode (clauses, explanation.str()); - compile (andNode->leftFollow(), {clauses[i]}); - compile (andNode->rightFollow(), newClauses); + compile (andNode->leftBranch(), {clauses[i]}); + compile (andNode->rightBranch(), newClauses); (*follow) = andNode; return true; } @@ -151,7 +170,7 @@ LiftedCircuit::tryIndependence ( bool -LiftedCircuit::tryShannonDecomposition ( +LiftedCircuit::tryShannonDecomp ( CircuitNode** follow, const Clauses& clauses) { @@ -175,8 +194,8 @@ LiftedCircuit::tryShannonDecomposition ( explanation << " SD on " << literals[j]; OrNode* orNode = new OrNode (clauses, explanation.str()); (*follow) = orNode; - compile (orNode->leftFollow(), leftClauses); - compile (orNode->rightFollow(), rightClauses); + compile (orNode->leftBranch(), leftClauses); + compile (orNode->rightBranch(), rightClauses); return true; } } @@ -186,6 +205,113 @@ LiftedCircuit::tryShannonDecomposition ( +TinySet +LiftedCircuit::smoothCircuit (CircuitNode* node) +{ + assert (node != 0); + TinySet propagatingLids; + + switch (getCircuitNodeType (node)) { + + case CircuitNodeType::OR_NODE: { + OrNode* casted = dynamic_cast(node); + TinySet lids1 = smoothCircuit (*casted->leftBranch()); + TinySet lids2 = smoothCircuit (*casted->rightBranch()); + TinySet missingLeft = lids2 - lids1; + TinySet missingRight = lids1 - lids2; + if (missingLeft.empty() == false) { + Clauses clauses; + for (size_t i = 0; i < missingLeft.size(); i++) { + Clause c = lwcnf_->createClauseForLiteral (missingLeft[i]); + c.addAndNegateLiteral (c.literals()[0]); + clauses.push_back (c); + } + SmoothNode* smoothNode = new SmoothNode (clauses); + CircuitNode** prev = casted->leftBranch(); + string explanation = " smoothing" ; + AndNode* andNode = new AndNode ((*prev)->clauses(), smoothNode, *prev, explanation); + *prev = andNode; + } + if (missingRight.empty() == false) { + Clauses clauses; + for (size_t i = 0; i < missingRight.size(); i++) { + Clause c = lwcnf_->createClauseForLiteral (missingRight[i]); + c.addAndNegateLiteral (c.literals()[0]); + clauses.push_back (c); + } + SmoothNode* smoothNode = new SmoothNode (clauses); + CircuitNode** prev = casted->rightBranch(); + string explanation = " smoothing" ; + AndNode* andNode = new AndNode ((*prev)->clauses(), smoothNode, *prev, explanation); + *prev = andNode; + } + propagatingLids |= lids1; + propagatingLids |= lids2; + break; + } + + case CircuitNodeType::AND_NODE: { + AndNode* casted = dynamic_cast(node); + TinySet lids1 = smoothCircuit (*casted->leftBranch()); + TinySet lids2 = smoothCircuit (*casted->rightBranch()); + propagatingLids |= lids1; + propagatingLids |= lids2; + break; + } + + case CircuitNodeType::SET_OR_NODE: { + // TODO + } + + case CircuitNodeType::SET_AND_NODE: { + // TODO + } + + case CircuitNodeType::INC_EXC_NODE: { + // TODO + } + + case CircuitNodeType::LEAF_NODE: { + propagatingLids.insert (node->clauses()[0].literals()[0].lid()); + } + + // case CircuitNodeType::SMOOTH_NODE: + // case CircuitNodeType::TRUE_NODE: + // case CircuitNodeType::FAIL_NODE: + + default: + break; + } + + return propagatingLids; +} + + + +CircuitNodeType +LiftedCircuit::getCircuitNodeType (const CircuitNode* node) const +{ + CircuitNodeType type; + if (dynamic_cast(node) != 0) { + type = CircuitNodeType::OR_NODE; + } else if (dynamic_cast(node) != 0) { + type = CircuitNodeType::AND_NODE; + } else if (dynamic_cast(node) != 0) { + type = CircuitNodeType::LEAF_NODE; + } else if (dynamic_cast(node) != 0) { + type = CircuitNodeType::SMOOTH_NODE; + } else if (dynamic_cast(node) != 0) { + type = CircuitNodeType::TRUE_NODE; + } else if (dynamic_cast(node) != 0) { + type = CircuitNodeType::FAIL_NODE; + } else { + assert (false); + } + return type; +} + + + string LiftedCircuit::escapeNode (const CircuitNode* node) const { @@ -197,19 +323,41 @@ LiftedCircuit::escapeNode (const CircuitNode* node) const void -LiftedCircuit::printToDot (CircuitNode* node, ofstream& os) +LiftedCircuit::exportToGraphViz (CircuitNode* node, ofstream& os) { assert (node != 0); static unsigned nrAndNodes = 0; static unsigned nrOrNodes = 0; - if (dynamic_cast(node) != 0) { - nrAndNodes ++; - AndNode* casted = dynamic_cast(node); - os << escapeNode (node) << " [shape=box,label=\"" ; + if (dynamic_cast(node) != 0) { + OrNode* casted = dynamic_cast(node); const Clauses& clauses = node->clauses(); + if (clauses.empty() == false) { + os << escapeNode (node) << " [shape=box,label=\"" ; for (size_t i = 0; i < clauses.size(); i++) { - os << clauses[i] << "\\n" ; + if (i != 0) os << "\\n" ; + os << clauses[i]; + } + os << "\"]" ; + os << endl; + } + os << "or" << nrOrNodes << " [label=\"∨\"]" << endl; + os << '"' << node << '"' << " -> " << "or" << nrOrNodes; + os << " [label=\"" << node->explanation() << "\"]" << endl; + os << "or" << nrOrNodes << " -> " ; + os << escapeNode (*casted->leftBranch()) << endl; + os << "or" << nrOrNodes << " -> " ; + os << escapeNode (*casted->rightBranch()) << endl; + nrOrNodes ++; + exportToGraphViz (*casted->leftBranch(), os); + exportToGraphViz (*casted->rightBranch(), os); + } else if (dynamic_cast(node) != 0) { + AndNode* casted = dynamic_cast(node); + const Clauses& clauses = node->clauses(); + os << escapeNode (node) << " [shape=box,label=\"" ; + for (size_t i = 0; i < clauses.size(); i++) { + if (i != 0) os << "\\n" ; + os << clauses[i]; } os << "\"]" ; os << endl; @@ -217,33 +365,26 @@ LiftedCircuit::printToDot (CircuitNode* node, ofstream& os) os << '"' << node << '"' << " -> " << "and" << nrAndNodes; os << " [label=\"" << node->explanation() << "\"]" << endl; os << "and" << nrAndNodes << " -> " ; - os << escapeNode (*casted->leftFollow()) << endl; + os << escapeNode (*casted->leftBranch()) << endl; os << "and" << nrAndNodes << " -> " ; - os << escapeNode (*casted->rightFollow()) << endl; - printToDot (*casted->leftFollow(), os); - printToDot (*casted->rightFollow(), os); - } else if (dynamic_cast(node) != 0) { - nrOrNodes ++; - OrNode* casted = dynamic_cast(node); - os << escapeNode (node) << " [shape=box,label=\"" ; - const Clauses& clauses = node->clauses(); - for (size_t i = 0; i < clauses.size(); i++) { - os << clauses[i] << "\\n" ; - } + os << escapeNode (*casted->rightBranch()) << endl; + nrAndNodes ++; + exportToGraphViz (*casted->leftBranch(), os); + exportToGraphViz (*casted->rightBranch(), os); + } else if (dynamic_cast(node) != 0) { + os << escapeNode (node); + os << " [shape=box,label=\"" ; + os << node->clauses()[0]; os << "\"]" ; os << endl; - os << "or" << nrOrNodes << " [label=\"∨\"]" << endl; - os << '"' << node << '"' << " -> " << "or" << nrOrNodes; - os << " [label=\"" << node->explanation() << "\"]" << endl; - os << "or" << nrOrNodes << " -> " ; - os << escapeNode (*casted->leftFollow()) << endl; - os << "or" << nrOrNodes << " -> " ; - os << escapeNode (*casted->rightFollow()) << endl; - printToDot (*casted->leftFollow(), os); - printToDot (*casted->rightFollow(), os); - } else if (dynamic_cast(node) != 0) { - os << escapeNode (node) << " [shape=box,label=\"" ; - os << node->clauses()[0] << "\\n" ; + } else if (dynamic_cast(node) != 0) { + os << escapeNode (node); + os << " [shape=box,style=filled,fillcolor=chartreuse,label=\"" ; + const Clauses& clauses = node->clauses(); + for (size_t i = 0; i < clauses.size(); i++) { + if (i != 0) os << "\\n" ; + os << clauses[i]; + } os << "\"]" ; os << endl; } else if (dynamic_cast(node) != 0) { @@ -252,15 +393,16 @@ LiftedCircuit::printToDot (CircuitNode* node, ofstream& os) os << endl; } else if (dynamic_cast(node) != 0) { os << escapeNode (node); - os << " [shape=box,style=filled,fillcolor=red,label=\"" ; + os << " [shape=box,style=filled,fillcolor=brown1,label=\"" ; const Clauses& clauses = node->clauses(); for (size_t i = 0; i < clauses.size(); i++) { - os << clauses[i] << "\\n" ; + if (i != 0) os << "\\n" ; + os << clauses[i]; } os << "\"]" ; os << endl; } else { - cout << "something really failled" << endl; + assert (false); } } diff --git a/packages/CLPBN/horus/LiftedCircuit.h b/packages/CLPBN/horus/LiftedCircuit.h index dfd74d82d..e386e78d5 100644 --- a/packages/CLPBN/horus/LiftedCircuit.h +++ b/packages/CLPBN/horus/LiftedCircuit.h @@ -4,6 +4,20 @@ #include "LiftedWCNF.h" +enum CircuitNodeType { + OR_NODE, + AND_NODE, + SET_OR_NODE, + SET_AND_NODE, + INC_EXC_NODE, + LEAF_NODE, + SMOOTH_NODE, + TRUE_NODE, + FAIL_NODE +}; + + + class CircuitNode { public: @@ -23,34 +37,49 @@ class CircuitNode -class AndNode : public CircuitNode -{ - public: - AndNode (const Clauses& clauses, string explanation = "") - : CircuitNode (clauses, explanation), - leftFollow_(0), rightFollow_(0) { } - - CircuitNode** leftFollow (void) { return &leftFollow_; } - CircuitNode** rightFollow (void) { return &rightFollow_; } - private: - CircuitNode* leftFollow_; - CircuitNode* rightFollow_; -}; - - - class OrNode : public CircuitNode { public: OrNode (const Clauses& clauses, string explanation = "") : CircuitNode (clauses, explanation), - leftFollow_(0), rightFollow_(0) { } + leftBranch_(0), rightBranch_(0) { } - CircuitNode** leftFollow (void) { return &leftFollow_; } - CircuitNode** rightFollow (void) { return &rightFollow_; } + CircuitNode** leftBranch (void) { return &leftBranch_; } + CircuitNode** rightBranch (void) { return &rightBranch_; } private: - CircuitNode* leftFollow_; - CircuitNode* rightFollow_; + CircuitNode* leftBranch_; + CircuitNode* rightBranch_; +}; + + + +class AndNode : public CircuitNode +{ + public: + AndNode (const Clauses& clauses, string explanation = "") + : CircuitNode (clauses, explanation), + leftBranch_(0), rightBranch_(0) { } + + AndNode ( + const Clauses& clauses, + CircuitNode* leftBranch, + CircuitNode* rightBranch, + string explanation = "") + : CircuitNode (clauses, explanation), + leftBranch_(leftBranch), rightBranch_(rightBranch) { } + + AndNode ( + CircuitNode* leftBranch, + CircuitNode* rightBranch, + string explanation = "") + : CircuitNode ({}, explanation), + leftBranch_(leftBranch), rightBranch_(rightBranch) { } + + CircuitNode** leftBranch (void) { return &leftBranch_; } + CircuitNode** rightBranch (void) { return &rightBranch_; } + private: + CircuitNode* leftBranch_; + CircuitNode* rightBranch_; }; @@ -88,7 +117,14 @@ class LeafNode : public CircuitNode { public: LeafNode (const Clause& clause) : CircuitNode ({clause}) { } - private: +}; + + + +class SmoothNode : public CircuitNode +{ + public: + SmoothNode (const Clauses& clauses) : CircuitNode (clauses) { } }; @@ -97,7 +133,6 @@ class TrueNode : public CircuitNode { public: TrueNode () : CircuitNode ({}) { } - private: }; @@ -107,7 +142,6 @@ class FailNode : public CircuitNode { public: FailNode (const Clauses& clauses) : CircuitNode (clauses) { } - private: }; @@ -117,21 +151,29 @@ class LiftedCircuit public: LiftedCircuit (const LiftedWCNF* lwcnf); - void printToDot (void); + void smoothCircuit (void); + + void exportToGraphViz (const char*); private: void compile (CircuitNode** follow, const Clauses& clauses); bool tryUnitPropagation (CircuitNode** follow, const Clauses& clauses); - bool tryIndependence (CircuitNode** follow, const Clauses& clauses); - bool tryShannonDecomposition (CircuitNode** follow, const Clauses& clauses); - - string escapeNode (const CircuitNode*) const; - void printToDot (CircuitNode* node, ofstream&); + bool tryIndependence (CircuitNode** follow, const Clauses& clauses); + bool tryShannonDecomp (CircuitNode** follow, const Clauses& clauses); + + TinySet smoothCircuit (CircuitNode* node); + + CircuitNodeType getCircuitNodeType (const CircuitNode* node) const; + + string escapeNode (const CircuitNode* node) const; + + void exportToGraphViz (CircuitNode* node, ofstream&); CircuitNode* root_; const LiftedWCNF* lwcnf_; }; #endif // HORUS_LIFTEDCIRCUIT_H + diff --git a/packages/CLPBN/horus/LiftedWCNF.cpp b/packages/CLPBN/horus/LiftedWCNF.cpp index 2e67c38eb..ac08bbbe0 100644 --- a/packages/CLPBN/horus/LiftedWCNF.cpp +++ b/packages/CLPBN/horus/LiftedWCNF.cpp @@ -150,7 +150,7 @@ LiftedWCNF::LiftedWCNF (const ParfactorList& pfList) { addIndicatorClauses (pfList); addParameterClauses (pfList); - printFormulasToIndicators(); + printFormulaIndicators(); printClauses(); } @@ -163,6 +163,31 @@ LiftedWCNF::~LiftedWCNF (void) +Clause +LiftedWCNF::createClauseForLiteral (LiteralId lid) const +{ + for (size_t i = 0; i < clauses_.size(); i++) { + const Literals& literals = clauses_[i].literals(); + for (size_t j = 0; j < literals.size(); j++) { + if (literals[j].lid() == lid) { + ConstraintTree* ct = new ConstraintTree (*clauses_[i].constr()); + ct->project (literals[j].logVars()); + Clause clause (ct); + clause.addLiteral (literals[j]); + return clause; + } + } + } + // FIXME + Clause c (new ConstraintTree({})); + c.addLiteral (Literal (lid,{})); + return c; + //assert (false); + //return Clause (0); +} + + + void LiftedWCNF::addIndicatorClauses (const ParfactorList& pfList) { @@ -245,7 +270,7 @@ LiftedWCNF::printClauses (void) const void -LiftedWCNF::printFormulasToIndicators (void) const +LiftedWCNF::printFormulaIndicators (void) const { set allGroups; ParfactorList::const_iterator it = pfList_.begin(); diff --git a/packages/CLPBN/horus/LiftedWCNF.h b/packages/CLPBN/horus/LiftedWCNF.h index e829a1fba..f71582b42 100644 --- a/packages/CLPBN/horus/LiftedWCNF.h +++ b/packages/CLPBN/horus/LiftedWCNF.h @@ -20,6 +20,8 @@ class Literal lid_(lit.lid_), logVars_(lit.logVars_), weight_(lit.weight_), negated_(negated) { } LiteralId lid (void) const { return lid_; } + + LogVars logVars (void) const { return logVars_; } double weight (void) const { return weight_; } @@ -102,9 +104,11 @@ class LiftedWCNF const Clauses& clauses (void) const { return clauses_; } + Clause createClauseForLiteral (LiteralId lid) const; + void printClauses (void) const; - void printFormulasToIndicators (void) const; + void printFormulaIndicators (void) const; private: void addIndicatorClauses (const ParfactorList& pfList); @@ -119,7 +123,7 @@ class LiftedWCNF Clauses clauses_; - unordered_map> map_; + unordered_map> map_; const ParfactorList& pfList_; @@ -127,3 +131,4 @@ class LiftedWCNF }; #endif // HORUS_LIFTEDWCNF_H +