From d074ca9a8fb36922aaeb03333a1bc2c705191c27 Mon Sep 17 00:00:00 2001 From: Tiago Gomes Date: Sat, 27 Oct 2012 00:13:11 +0100 Subject: [PATCH] add initial independent partial grounding support --- packages/CLPBN/horus/LiftedCircuit.cpp | 165 ++++++++++++++++++++++--- packages/CLPBN/horus/LiftedCircuit.h | 28 +++-- packages/CLPBN/horus/LiftedWCNF.cpp | 121 ++++++++++++------ packages/CLPBN/horus/LiftedWCNF.h | 23 +++- 4 files changed, 273 insertions(+), 64 deletions(-) diff --git a/packages/CLPBN/horus/LiftedCircuit.cpp b/packages/CLPBN/horus/LiftedCircuit.cpp index f3b27d4b9..8b6214428 100644 --- a/packages/CLPBN/horus/LiftedCircuit.cpp +++ b/packages/CLPBN/horus/LiftedCircuit.cpp @@ -23,6 +23,28 @@ AndNode::weight (void) const +double +SetOrNode::weight (void) const +{ + // TODO + assert (false); + return 0.0; +} + + + + +double +SetAndNode::weight (void) const +{ + unsigned nrGroundings = 2; // FIXME + return Globals::logDomain + ? follow_->weight() * nrGroundings + : std::pow (follow_->weight(), nrGroundings); +} + + + double LeafNode::weight (void) const { @@ -30,13 +52,13 @@ LeafNode::weight (void) const assert (clauses()[0].isUnit()); Clause c = clauses()[0]; double weight = c.literals()[0].weight(); - unsigned nrGroundings = c.constr()->size(); + unsigned nrGroundings = c.constr().size(); assert (nrGroundings != 0); double www = Globals::logDomain ? weight * nrGroundings : std::pow (weight, nrGroundings); - cout << "leaf w: " << www << endl; + cout << "leaf weight(" << clauses()[0].literals()[0] << "): " << www << endl; return Globals::logDomain ? weight * nrGroundings @@ -53,7 +75,7 @@ SmoothNode::weight (void) const for (size_t i = 0; i < cs.size(); i++) { double posWeight = cs[i].literals()[0].weight(); double negWeight = cs[i].literals()[1].weight(); - unsigned nrGroundings = cs[i].constr()->size(); + unsigned nrGroundings = cs[i].constr().size(); if (Globals::logDomain) { totalWeight += (Util::logSum (posWeight, negWeight) * nrGroundings); } else { @@ -130,7 +152,7 @@ LiftedCircuit::exportToGraphViz (const char* fileName) void LiftedCircuit::compile ( CircuitNode** follow, - const Clauses& clauses) + Clauses& clauses) { if (clauses.empty()) { *follow = new TrueNode (); @@ -165,6 +187,15 @@ LiftedCircuit::compile ( return; } + if (tryIndepPartialGrounding (follow, clauses)) { + return; + } + + if (tryGrounding (follow, clauses)) { + return; + } + + // assert (false); *follow = new FailNode (clauses); } @@ -174,7 +205,7 @@ LiftedCircuit::compile ( bool LiftedCircuit::tryUnitPropagation ( CircuitNode** follow, - const Clauses& clauses) + Clauses& clauses) { for (size_t i = 0; i < clauses.size(); i++) { if (clauses[i].isUnit()) { @@ -185,11 +216,14 @@ LiftedCircuit::tryUnitPropagation ( if (clauses[i].literals()[0].isPositive()) { if (clauses[j].containsPositiveLiteral (lid) == false) { Clause newClause = clauses[j]; + //cout << "new j : " << clauses[j] << endl; + //cout << "new clause: " << newClause << endl; + //cout << "clvs: " << clauses[j].constr()->logVars() << endl; newClause.removeNegativeLiterals (lid); newClauses.push_back (newClause); } - } - if (clauses[i].literals()[0].isNegative()) { + } else if (clauses[i].literals()[0].isNegative()) { + //cout << "unit prop of = " << clauses[i].literals()[0] << endl; if (clauses[j].containsNegativeLiteral (lid) == false) { Clause newClause = clauses[j]; newClause.removePositiveLiterals (lid); @@ -201,7 +235,8 @@ LiftedCircuit::tryUnitPropagation ( stringstream explanation; explanation << " UP of " << clauses[i]; AndNode* andNode = new AndNode (clauses, explanation.str()); - compile (andNode->leftBranch(), {clauses[i]}); + Clauses leftClauses = {clauses[i]}; + compile (andNode->leftBranch(), leftClauses); compile (andNode->rightBranch(), newClauses); (*follow) = andNode; return true; @@ -215,7 +250,7 @@ LiftedCircuit::tryUnitPropagation ( bool LiftedCircuit::tryIndependence ( CircuitNode** follow, - const Clauses& clauses) + Clauses& clauses) { if (clauses.size() == 1) { return false; @@ -236,7 +271,8 @@ LiftedCircuit::tryIndependence ( stringstream explanation; explanation << " independence" ; AndNode* andNode = new AndNode (clauses, explanation.str()); - compile (andNode->leftBranch(), {clauses[i]}); + Clauses indepClause = {clauses[i]}; + compile (andNode->leftBranch(), indepClause); compile (andNode->rightBranch(), newClauses); (*follow) = andNode; return true; @@ -250,16 +286,16 @@ LiftedCircuit::tryIndependence ( bool LiftedCircuit::tryShannonDecomp ( CircuitNode** follow, - const Clauses& clauses) + Clauses& clauses) { for (size_t i = 0; i < clauses.size(); i++) { const Literals& literals = clauses[i].literals(); for (size_t j = 0; j < literals.size(); j++) { - if (literals[j].isGround (clauses[i].constr())) { + if (literals[j].isGround (clauses[i].constr(),clauses[i].ipgLogVars())) { Literal posLit (literals[j], false); Literal negLit (literals[j], true); - ConstraintTree* ct1 = new ConstraintTree (*clauses[i].constr()); - ConstraintTree* ct2 = new ConstraintTree (*clauses[i].constr()); + ConstraintTree ct1 = clauses[i].constr(); + ConstraintTree ct2 = clauses[i].constr(); Clause c1 (ct1); Clause c2 (ct2); c1.addLiteral (posLit); @@ -283,6 +319,72 @@ LiftedCircuit::tryShannonDecomp ( +bool +LiftedCircuit::tryIndepPartialGrounding ( + CircuitNode** follow, + Clauses& clauses) +{ + // assumes that all literals have logical variables + LogVar X = clauses[0].constr().logVars()[0]; + ConstraintTree ct = clauses[0].constr(); + + // FIXME this is so weak ... + ct.project ({X}); + for (size_t i = 0; i < clauses.size(); i++) { + if (clauses[i].constr().logVars().size() == 1) { + if (ct.tupleSet() != clauses[i].constr().tupleSet()) { + return false; + } + } else { + return false; + } + } + + // FIXME this is so broken ... + Clauses newClauses = clauses; + for (size_t i = 0; i < clauses.size(); i++) { + newClauses[i].addIpgLogVar (clauses[i].constr().logVars()[0]); + } + + string explanation = " IPG" ; + SetAndNode* node = new SetAndNode (clauses, explanation); + *follow = node; + compile (node->follow(), newClauses); + return true; +} + + + +bool +LiftedCircuit::tryGrounding ( + CircuitNode** follow, + Clauses& clauses) +{ + return false; + /* + size_t bestClauseIdx = 0; + size_t bestLogVarIdx = 0; + unsigned minNrSymbols = Util::maxUnsigned(); + for (size_t i = 0; i < clauses.size(); i++) { + LogVarSet lvs = clauses[i].constr().logVars(); + ConstraintTree ct = clauses[i].constr(); + for (unsigned j = 0; j < lvs.size(); j++) { + unsigned nrSymbols = ct.nrSymbols (lvs[j]); + if (nrSymbols < minNrSymbols) { + minNrSymbols = nrSymbols; + bestClauseIdx = i; + bestLogVarIdx = j; + } + } + } + LogVar bestLogVar = clauses[bestClauseIdx].constr().logVars()[bestLogVarIdx]; + ConstraintTrees cts = clauses[bestClauseIdx].constr().ground (bestLogVar); + return true; + */ +} + + + TinySet LiftedCircuit::smoothCircuit (CircuitNode* node) { @@ -374,6 +476,10 @@ LiftedCircuit::getCircuitNodeType (const CircuitNode* node) const type = CircuitNodeType::OR_NODE; } else if (dynamic_cast(node) != 0) { type = CircuitNodeType::AND_NODE; + } else if (dynamic_cast(node) != 0) { + type = CircuitNodeType::SET_OR_NODE; + } else if (dynamic_cast(node) != 0) { + type = CircuitNodeType::SET_AND_NODE; } else if (dynamic_cast(node) != 0) { type = CircuitNodeType::LEAF_NODE; } else if (dynamic_cast(node) != 0) { @@ -404,9 +510,11 @@ void LiftedCircuit::exportToGraphViz (CircuitNode* node, ofstream& os) { assert (node != 0); - + + static unsigned nrOrNodes = 0; static unsigned nrAndNodes = 0; - static unsigned nrOrNodes = 0; + static unsigned nrSetOrNodes = 0; + static unsigned nrSetAndNodes = 0; switch (getCircuitNodeType (node)) { @@ -458,6 +566,31 @@ LiftedCircuit::exportToGraphViz (CircuitNode* node, ofstream& os) break; } + case SET_OR_NODE: { + nrSetOrNodes ++; + assert (false); // not yet implemented + } + + case SET_AND_NODE: { + SetAndNode* casted = dynamic_cast(node); + const Clauses& clauses = node->clauses(); + os << escapeNode (node) << " [shape=box,label=\"" ; + for (size_t i = 0; i < clauses.size(); i++) { + if (i != 0) os << "\\n" ; + os << clauses[i]; + } + os << "\"]" ; + os << endl; + os << "setand" << nrSetAndNodes << " [label=\"∧(X)\"]" << endl; + os << '"' << node << '"' << " -> " << "setand" << nrSetAndNodes; + os << " [label=\"" << node->explanation() << "\"]" << endl; + os << "setand" << nrSetAndNodes << " -> " ; + os << escapeNode (*casted->follow()) << endl; + nrSetAndNodes ++; + exportToGraphViz (*casted->follow(), os); + break; + } + case LEAF_NODE: { os << escapeNode (node); os << " [shape=box,label=\"" ; diff --git a/packages/CLPBN/horus/LiftedCircuit.h b/packages/CLPBN/horus/LiftedCircuit.h index 839d14a70..435a26517 100644 --- a/packages/CLPBN/horus/LiftedCircuit.h +++ b/packages/CLPBN/horus/LiftedCircuit.h @@ -46,7 +46,7 @@ class OrNode : public CircuitNode : CircuitNode (clauses, explanation), leftBranch_(0), rightBranch_(0) { } - double weight (void) const; + double weight (void) const; CircuitNode** leftBranch (void) { return &leftBranch_; } CircuitNode** rightBranch (void) { return &rightBranch_; } @@ -90,18 +90,30 @@ class AndNode : public CircuitNode -class SetAndNode : public CircuitNode +class SetOrNode : public CircuitNode { public: + SetOrNode (const Clauses& clauses, string explanation = "") + : CircuitNode (clauses, explanation), follow_(0) { } + + double weight (void) const; + + CircuitNode** follow (void) { return &follow_; } private: CircuitNode* follow_; }; -class SetOrNode : public CircuitNode +class SetAndNode : public CircuitNode { public: + SetAndNode (const Clauses& clauses, string explanation = "") + : CircuitNode (clauses, explanation), follow_(0) { } + + double weight (void) const; + + CircuitNode** follow (void) { return &follow_; } private: CircuitNode* follow_; }; @@ -171,11 +183,13 @@ class LiftedCircuit private: - void compile (CircuitNode** follow, const Clauses& clauses); + void compile (CircuitNode** follow, Clauses& clauses); - bool tryUnitPropagation (CircuitNode** follow, const Clauses& clauses); - bool tryIndependence (CircuitNode** follow, const Clauses& clauses); - bool tryShannonDecomp (CircuitNode** follow, const Clauses& clauses); + bool tryUnitPropagation (CircuitNode** follow, Clauses& clauses); + bool tryIndependence (CircuitNode** follow, Clauses& clauses); + bool tryShannonDecomp (CircuitNode** follow, Clauses& clauses); + bool tryIndepPartialGrounding (CircuitNode** follow, Clauses& clauses); + bool tryGrounding (CircuitNode** follow, Clauses& clauses); TinySet smoothCircuit (CircuitNode* node); diff --git a/packages/CLPBN/horus/LiftedWCNF.cpp b/packages/CLPBN/horus/LiftedWCNF.cpp index df2efd59e..4ad651621 100644 --- a/packages/CLPBN/horus/LiftedWCNF.cpp +++ b/packages/CLPBN/horus/LiftedWCNF.cpp @@ -4,30 +4,49 @@ bool -Literal::isGround (ConstraintTree* constr) const +Literal::isGround (ConstraintTree constr, LogVarSet ipgLogVars) const { if (logVars_.size() == 0) { return true; } - LogVarSet singletons = constr->singletons(); - return singletons.contains (logVars_); + LogVarSet lvs (logVars_); + lvs -= ipgLogVars; + return constr.singletons().contains (lvs); +} + + + +string +Literal::toString (LogVarSet ipgLogVars) const +{ + stringstream ss; + negated_ ? ss << "¬" : ss << "" ; + weight_ < 0.0 ? ss << "λ" : ss << "Θ" ; + ss << lid_ ; + if (logVars_.empty() == false) { + ss << "(" ; + for (size_t i = 0; i < logVars_.size(); i++) { + if (i != 0) ss << ","; + if (ipgLogVars.contains (logVars_[i])) { + LogVar X = logVars_[i]; + const string labels[] = { + "a", "b", "c", "d", "e", "f", + "g", "h", "i", "j", "k", "m" }; + (X >= 12) ? ss << "x_" << X : ss << labels[X]; + } else { + ss << logVars_[i]; + } + } + ss << ")" ; + } + return ss.str(); } std::ostream& operator<< (ostream &os, const Literal& lit) { - lit.negated_ ? os << "¬" : os << "" ; - lit.weight_ < 0.0 ? os << "λ" : os << "Θ" ; - os << lit.lid_ ; - if (lit.logVars_.empty() == false) { - os << "(" ; - for (size_t i = 0; i < lit.logVars_.size(); i++) { - if (i != 0) os << ","; - os << lit.logVars_[i]; - } - os << ")" ; - } + os << lit.toString(); return os; } @@ -78,7 +97,7 @@ Clause::removeLiterals (LiteralId lid) size_t i = 0; while (i != literals_.size()) { if (literals_[i].lid() == lid) { - literals_.erase (literals_.begin() + i); + removeLiteral (i); } else { i ++; } @@ -93,7 +112,7 @@ Clause::removePositiveLiterals (LiteralId lid) size_t i = 0; while (i != literals_.size()) { if (literals_[i].lid() == lid && literals_[i].isPositive()) { - literals_.erase (literals_.begin() + i); + removeLiteral (i); } else { i ++; } @@ -108,7 +127,7 @@ Clause::removeNegativeLiterals (LiteralId lid) size_t i = 0; while (i != literals_.size()) { if (literals_[i].lid() == lid && literals_[i].isNegative()) { - literals_.erase (literals_.begin() + i); + removeLiteral (i); } else { i ++; } @@ -129,14 +148,39 @@ Clause::lidSet (void) const +void +Clause::removeLiteral (size_t idx) +{ + LogVarSet lvs (literals_[idx].logVars()); + lvs -= getLogVarSetExcluding (idx); + constr_.remove (lvs); + literals_.erase (literals_.begin() + idx); +} + + + +LogVarSet +Clause::getLogVarSetExcluding (size_t idx) const +{ + LogVarSet lvs; + for (size_t i = 0; i < literals_.size(); i++) { + if (i != idx) { + lvs |= literals_[i].logVars(); + } + } + return lvs; +} + + + std::ostream& operator<< (ostream &os, const Clause& clause) { for (unsigned i = 0; i < clause.literals_.size(); i++) { if (i != 0) os << " v " ; - os << clause.literals_[i]; + os << clause.literals_[i].toString (clause.ipgLogVars_); } - if (clause.ct_->empty() == false) { - ConstraintTree copy (*clause.ct_); + if (clause.constr_.empty() == false) { + ConstraintTree copy = clause.constr_; copy.moveToTop (copy.logVarSet().elements()); os << " | " << copy.tupleSet(); } @@ -177,8 +221,8 @@ LiftedWCNF::createClauseForLiteral (LiteralId lid) const const Literals& literals = clauses_[i].literals(); for (size_t j = 0; j < literals.size(); j++) { if (literals[j].lid() == lid) { - ConstraintTree* ct = new ConstraintTree (*clauses_[i].constr()); - ct->project (literals[j].logVars()); + ConstraintTree ct = clauses_[i].constr(); + ct.project (literals[j].logVars()); Clause clause (ct); clause.addLiteral (literals[j]); return clause; @@ -186,7 +230,7 @@ LiftedWCNF::createClauseForLiteral (LiteralId lid) const } } // FIXME - Clause c (new ConstraintTree({})); + Clause c (ConstraintTree({})); c.addLiteral (Literal (lid,{})); return c; //assert (false); @@ -205,8 +249,8 @@ LiftedWCNF::addIndicatorClauses (const ParfactorList& pfList) for (size_t i = 0; i < formulas.size(); i++) { if (Util::contains (allGroups, formulas[i].group()) == false) { allGroups.insert (formulas[i].group()); - ConstraintTree* tempConstr = new ConstraintTree (*(*it)->constr()); - tempConstr->project (formulas[i].logVars()); + ConstraintTree tempConstr = *(*it)->constr(); + tempConstr.project (formulas[i].logVars()); Clause clause (tempConstr); vector lids; for (size_t j = 0; j < formulas[i].range(); j++) { @@ -217,8 +261,8 @@ LiftedWCNF::addIndicatorClauses (const ParfactorList& pfList) clauses_.push_back (clause); for (size_t j = 0; j < formulas[i].range() - 1; j++) { for (size_t k = j + 1; k < formulas[i].range(); k++) { - ConstraintTree* tempConstr2 = new ConstraintTree (*(*it)->constr()); - tempConstr2->project (formulas[i].logVars()); + ConstraintTree tempConstr2 = *(*it)->constr(); + tempConstr2.project (formulas[i].logVars()); Clause clause2 (tempConstr2); clause2.addAndNegateLiteral (Literal (clause.literals()[j])); clause2.addAndNegateLiteral (Literal (clause.literals()[k])); @@ -243,22 +287,27 @@ LiftedWCNF::addParameterClauses (const ParfactorList& pfList) vector groups = (*it)->getAllGroups(); while (indexer.valid()) { LiteralId paramVarLid = freeLiteralId_; - + // λu1 ∧ ... ∧ λun ∧ λxi <=> θxi|u1,...,un + // + // ¬λu1 ... ¬λun v θxi|u1,...,un -> clause1 + // ¬θxi|u1,...,un v λu1 -> tempClause + // ¬θxi|u1,...,un v λu2 -> tempClause double weight = (**it)[indexer]; - Clause clause1 ((*it)->constr()); + Clause clause1 (*(*it)->constr()); for (unsigned i = 0; i < groups.size(); i++) { LiteralId lid = getLiteralId (groups[i], indexer[i]); clause1.addAndNegateLiteral (Literal (lid, (*it)->argument(i).logVars())); - Clause tempClause ((*it)->constr()); - tempClause.addAndNegateLiteral (Literal (paramVarLid, 1.0)); + ConstraintTree ct = *(*it)->constr(); + Clause tempClause (ct); + tempClause.addAndNegateLiteral (Literal (paramVarLid, (*it)->constr()->logVars(), 1.0)); tempClause.addLiteral (Literal (lid, (*it)->argument(i).logVars())); clauses_.push_back (tempClause); } - clause1.addLiteral (Literal (paramVarLid, weight)); + clause1.addLiteral (Literal (paramVarLid, (*it)->constr()->logVars(),weight)); clauses_.push_back (clause1); freeLiteralId_ ++; ++ indexer; @@ -279,7 +328,7 @@ LiftedWCNF::printFormulaIndicators (void) const if (Util::contains (allGroups, formulas[i].group()) == false) { allGroups.insert (formulas[i].group()); cout << formulas[i] << " | " ; - ConstraintTree tempCt (*(*it)->constr()); + ConstraintTree tempCt = *(*it)->constr(); tempCt.project (formulas[i].logVars()); cout << tempCt.tupleSet(); cout << " indicators => " ; @@ -304,7 +353,8 @@ LiftedWCNF::printWeights (void) const Literals literals = clauses_[j].literals(); for (size_t k = 0; k < literals.size(); k++) { if (literals[k].lid() == i && literals[k].isPositive()) { - cout << "weight(" << literals[k] << ") = " << literals[k].weight(); + cout << "weight(" << literals[k] << ") = " ; + cout << literals[k].weight(); cout << endl; found = true; break; @@ -320,7 +370,8 @@ LiftedWCNF::printWeights (void) const Literals literals = clauses_[j].literals(); for (size_t k = 0; k < literals.size(); k++) { if (literals[k].lid() == i && literals[k].isNegative()) { - cout << "weight(" << literals[k] << ") = " << literals[k].weight(); + cout << "weight(" << literals[k] << ") = " ; + cout << literals[k].weight(); cout << endl; found = true; break; diff --git a/packages/CLPBN/horus/LiftedWCNF.h b/packages/CLPBN/horus/LiftedWCNF.h index 78a75db4e..155ec48e8 100644 --- a/packages/CLPBN/horus/LiftedWCNF.h +++ b/packages/CLPBN/horus/LiftedWCNF.h @@ -26,7 +26,7 @@ class Literal LogVars logVars (void) const { return logVars_; } - // FIXME not log aware + // FIXME not log aware :( double weight (void) const { return weight_ < 0.0 ? 1.0 : weight_; } void negate (void) { negated_ = !negated_; } @@ -35,7 +35,9 @@ class Literal bool isNegative (void) const { return negated_; } - bool isGround (ConstraintTree* constr) const; + bool isGround (ConstraintTree constr, LogVarSet ipgLogVars) const; + + string toString (LogVarSet ipgLogVars = LogVarSet()) const; friend std::ostream& operator<< (ostream &os, const Literal& lit); @@ -52,7 +54,7 @@ typedef vector Literals; class Clause { public: - Clause (ConstraintTree* ct) : ct_(ct) { } + Clause (const ConstraintTree& ct) : constr_(ct) { } void addLiteral (const Literal& l) { literals_.push_back (l); } @@ -78,19 +80,28 @@ class Clause void removeLiteralByIndex (size_t idx) { literals_.erase (literals_.begin() + idx); } - ConstraintTree* constr (void) const { return ct_; } + const ConstraintTree& constr (void) const { return constr_; } - ConstraintTree* constr (void) { return ct_; } + ConstraintTree constr (void) { return constr_; } bool isUnit (void) const { return literals_.size() == 1; } + LogVarSet ipgLogVars (void) const { return ipgLogVars_; } + + void addIpgLogVar (LogVar X) { ipgLogVars_.insert (X); } + TinySet lidSet (void) const; friend std::ostream& operator<< (ostream &os, const Clause& clause); private: + void removeLiteral (size_t idx); + + LogVarSet getLogVarSetExcluding (size_t idx) const; + vector literals_; - ConstraintTree* ct_; + LogVarSet ipgLogVars_; + ConstraintTree constr_; };