From 68ef63207fc4fd324a7b9ee4e3cb646d39005fa5 Mon Sep 17 00:00:00 2001 From: Tiago Gomes Date: Thu, 25 Oct 2012 12:22:52 +0100 Subject: [PATCH] initial support for weighted model countign --- packages/CLPBN/horus/LiftedCircuit.cpp | 255 +++++++++++++++++-------- packages/CLPBN/horus/LiftedCircuit.h | 18 +- packages/CLPBN/horus/LiftedWCNF.cpp | 75 ++++++-- packages/CLPBN/horus/LiftedWCNF.h | 14 +- 4 files changed, 266 insertions(+), 96 deletions(-) diff --git a/packages/CLPBN/horus/LiftedCircuit.cpp b/packages/CLPBN/horus/LiftedCircuit.cpp index 4332e5584..f3b27d4b9 100644 --- a/packages/CLPBN/horus/LiftedCircuit.cpp +++ b/packages/CLPBN/horus/LiftedCircuit.cpp @@ -3,22 +3,92 @@ #include "LiftedCircuit.h" +double +OrNode::weight (void) const +{ + double lw = leftBranch_->weight(); + double rw = rightBranch_->weight(); + return Globals::logDomain ? Util::logSum (lw, rw) : lw + rw; +} + + + +double +AndNode::weight (void) const +{ + double lw = leftBranch_->weight(); + double rw = rightBranch_->weight(); + return Globals::logDomain ? lw + rw : lw * rw; +} + + + +double +LeafNode::weight (void) const +{ + assert (clauses().size() == 1); + assert (clauses()[0].isUnit()); + Clause c = clauses()[0]; + double weight = c.literals()[0].weight(); + unsigned nrGroundings = c.constr()->size(); + assert (nrGroundings != 0); + double www = Globals::logDomain + ? weight * nrGroundings + : std::pow (weight, nrGroundings); + + cout << "leaf w: " << www << endl; + + return Globals::logDomain + ? weight * nrGroundings + : std::pow (weight, nrGroundings); +} + + + +double +SmoothNode::weight (void) const +{ + Clauses cs = clauses(); + double totalWeight = LogAware::multIdenty(); + for (size_t i = 0; i < cs.size(); i++) { + double posWeight = cs[i].literals()[0].weight(); + double negWeight = cs[i].literals()[1].weight(); + unsigned nrGroundings = cs[i].constr()->size(); + if (Globals::logDomain) { + totalWeight += (Util::logSum (posWeight, negWeight) * nrGroundings); + } else { + totalWeight *= std::pow (posWeight + negWeight, nrGroundings); + } + } + return totalWeight; +} + + + +double +TrueNode::weight (void) const +{ + return LogAware::multIdenty(); +} + + + LiftedCircuit::LiftedCircuit (const LiftedWCNF* lwcnf) : lwcnf_(lwcnf) { root_ = 0; Clauses ccc = lwcnf->clauses(); - ccc.erase (ccc.begin() + 5, ccc.end()); + //ccc.erase (ccc.begin() + 5, ccc.end()); //Clause c2 = ccc.front(); //c2.removeLiteralByIndex (1); //ccc.push_back (c2); //compile (&root_, lwcnf->clauses()); compile (&root_, ccc); - cout << "done compiling..." << endl; exportToGraphViz("circuit.dot"); smoothCircuit(); exportToGraphViz("smooth.dot"); + cout << "WEIGHTED MODEL COUNT = " << getWeightedModelCount() << endl; } @@ -31,6 +101,14 @@ LiftedCircuit::smoothCircuit (void) +double +LiftedCircuit::getWeightedModelCount (void) const +{ + return root_->weight(); +} + + + void LiftedCircuit::exportToGraphViz (const char* fileName) { @@ -63,14 +141,14 @@ LiftedCircuit::compile ( static int count = 0; count ++; *follow = new LeafNode (clauses[0]); if (count == 1) { - Clause c (new ConstraintTree({})); - c.addLiteral (Literal (100,{})); - *follow = new LeafNode (c); + // Clause c (new ConstraintTree({})); + // c.addLiteral (Literal (100,{})); + // *follow = new LeafNode (c); } if (count == 2) { - Clause c (new ConstraintTree({})); - c.addLiteral (Literal (101,{})); - *follow = new LeafNode (c); + // Clause c (new ConstraintTree({})); + // c.addLiteral (Literal (101,{})); + // *follow = new LeafNode (c); } return; } @@ -326,83 +404,104 @@ void LiftedCircuit::exportToGraphViz (CircuitNode* node, ofstream& os) { assert (node != 0); + static unsigned nrAndNodes = 0; static unsigned nrOrNodes = 0; - if (dynamic_cast(node) != 0) { - OrNode* casted = dynamic_cast(node); - const Clauses& clauses = node->clauses(); - if (clauses.empty() == false) { - os << escapeNode (node) << " [shape=box,label=\"" ; - for (size_t i = 0; i < clauses.size(); i++) { - if (i != 0) os << "\\n" ; - os << clauses[i]; + switch (getCircuitNodeType (node)) { + + case OR_NODE: { + OrNode* casted = dynamic_cast(node); + const Clauses& clauses = node->clauses(); + if (clauses.empty() == false) { + os << escapeNode (node) << " [shape=box,label=\"" ; + for (size_t i = 0; i < clauses.size(); i++) { + if (i != 0) os << "\\n" ; + os << clauses[i]; + } + os << "\"]" ; + os << endl; + } + os << "or" << nrOrNodes << " [label=\"∨\"]" << endl; + os << '"' << node << '"' << " -> " << "or" << nrOrNodes; + os << " [label=\"" << node->explanation() << "\"]" << endl; + os << "or" << nrOrNodes << " -> " ; + os << escapeNode (*casted->leftBranch()) << endl; + os << "or" << nrOrNodes << " -> " ; + os << escapeNode (*casted->rightBranch()) << endl; + nrOrNodes ++; + exportToGraphViz (*casted->leftBranch(), os); + exportToGraphViz (*casted->rightBranch(), os); + break; } - os << "\"]" ; - os << endl; + + case AND_NODE: { + AndNode* casted = dynamic_cast(node); + const Clauses& clauses = node->clauses(); + os << escapeNode (node) << " [shape=box,label=\"" ; + for (size_t i = 0; i < clauses.size(); i++) { + if (i != 0) os << "\\n" ; + os << clauses[i]; + } + os << "\"]" ; + os << endl; + os << "and" << nrAndNodes << " [label=\"∧\"]" << endl; + os << '"' << node << '"' << " -> " << "and" << nrAndNodes; + os << " [label=\"" << node->explanation() << "\"]" << endl; + os << "and" << nrAndNodes << " -> " ; + os << escapeNode (*casted->leftBranch()) << endl; + os << "and" << nrAndNodes << " -> " ; + os << escapeNode (*casted->rightBranch()) << endl; + nrAndNodes ++; + exportToGraphViz (*casted->leftBranch(), os); + exportToGraphViz (*casted->rightBranch(), os); + break; } - os << "or" << nrOrNodes << " [label=\"∨\"]" << endl; - os << '"' << node << '"' << " -> " << "or" << nrOrNodes; - os << " [label=\"" << node->explanation() << "\"]" << endl; - os << "or" << nrOrNodes << " -> " ; - os << escapeNode (*casted->leftBranch()) << endl; - os << "or" << nrOrNodes << " -> " ; - os << escapeNode (*casted->rightBranch()) << endl; - nrOrNodes ++; - exportToGraphViz (*casted->leftBranch(), os); - exportToGraphViz (*casted->rightBranch(), os); - } else if (dynamic_cast(node) != 0) { - AndNode* casted = dynamic_cast(node); - const Clauses& clauses = node->clauses(); - os << escapeNode (node) << " [shape=box,label=\"" ; - for (size_t i = 0; i < clauses.size(); i++) { - if (i != 0) os << "\\n" ; - os << clauses[i]; + + case LEAF_NODE: { + os << escapeNode (node); + os << " [shape=box,label=\"" ; + os << node->clauses()[0]; + os << "\"]" ; + os << endl; + break; } - os << "\"]" ; - os << endl; - os << "and" << nrAndNodes << " [label=\"∧\"]" << endl; - os << '"' << node << '"' << " -> " << "and" << nrAndNodes; - os << " [label=\"" << node->explanation() << "\"]" << endl; - os << "and" << nrAndNodes << " -> " ; - os << escapeNode (*casted->leftBranch()) << endl; - os << "and" << nrAndNodes << " -> " ; - os << escapeNode (*casted->rightBranch()) << endl; - nrAndNodes ++; - exportToGraphViz (*casted->leftBranch(), os); - exportToGraphViz (*casted->rightBranch(), os); - } else if (dynamic_cast(node) != 0) { - os << escapeNode (node); - os << " [shape=box,label=\"" ; - os << node->clauses()[0]; - os << "\"]" ; - os << endl; - } else if (dynamic_cast(node) != 0) { - os << escapeNode (node); - os << " [shape=box,style=filled,fillcolor=chartreuse,label=\"" ; - const Clauses& clauses = node->clauses(); - for (size_t i = 0; i < clauses.size(); i++) { - if (i != 0) os << "\\n" ; - os << clauses[i]; + + case SMOOTH_NODE: { + os << escapeNode (node); + os << " [shape=box,style=filled,fillcolor=chartreuse,label=\"" ; + const Clauses& clauses = node->clauses(); + for (size_t i = 0; i < clauses.size(); i++) { + if (i != 0) os << "\\n" ; + os << clauses[i]; + } + os << "\"]" ; + os << endl; + break; } - os << "\"]" ; - os << endl; - } else if (dynamic_cast(node) != 0) { - os << escapeNode (node); - os << " [shape=box,label=\"⊤\"]" ; - os << endl; - } else if (dynamic_cast(node) != 0) { - os << escapeNode (node); - os << " [shape=box,style=filled,fillcolor=brown1,label=\"" ; - const Clauses& clauses = node->clauses(); - for (size_t i = 0; i < clauses.size(); i++) { - if (i != 0) os << "\\n" ; - os << clauses[i]; + + case TRUE_NODE: { + os << escapeNode (node); + os << " [shape=box,label=\"⊤\"]" ; + os << endl; + break; } - os << "\"]" ; - os << endl; - } else { - assert (false); - } + + case FAIL_NODE: { + os << escapeNode (node); + os << " [shape=box,style=filled,fillcolor=brown1,label=\"" ; + const Clauses& clauses = node->clauses(); + for (size_t i = 0; i < clauses.size(); i++) { + if (i != 0) os << "\\n" ; + os << clauses[i]; + } + os << "\"]" ; + os << endl; + break; + } + + default: + assert (false); + } } diff --git a/packages/CLPBN/horus/LiftedCircuit.h b/packages/CLPBN/horus/LiftedCircuit.h index e386e78d5..839d14a70 100644 --- a/packages/CLPBN/horus/LiftedCircuit.h +++ b/packages/CLPBN/horus/LiftedCircuit.h @@ -24,7 +24,9 @@ class CircuitNode CircuitNode (const Clauses& clauses, string explanation = "") : clauses_(clauses), explanation_(explanation) { } - const Clauses& clauses (void) { return clauses_; } + const Clauses& clauses (void) const { return clauses_; } + + Clauses clauses (void) { return clauses_; } virtual double weight (void) const { return 0; } @@ -43,6 +45,8 @@ class OrNode : public CircuitNode OrNode (const Clauses& clauses, string explanation = "") : CircuitNode (clauses, explanation), leftBranch_(0), rightBranch_(0) { } + + double weight (void) const; CircuitNode** leftBranch (void) { return &leftBranch_; } CircuitNode** rightBranch (void) { return &rightBranch_; } @@ -74,6 +78,8 @@ class AndNode : public CircuitNode string explanation = "") : CircuitNode ({}, explanation), leftBranch_(leftBranch), rightBranch_(rightBranch) { } + + double weight (void) const; CircuitNode** leftBranch (void) { return &leftBranch_; } CircuitNode** rightBranch (void) { return &rightBranch_; } @@ -117,6 +123,8 @@ class LeafNode : public CircuitNode { public: LeafNode (const Clause& clause) : CircuitNode ({clause}) { } + + double weight (void) const; }; @@ -125,6 +133,8 @@ class SmoothNode : public CircuitNode { public: SmoothNode (const Clauses& clauses) : CircuitNode (clauses) { } + + double weight (void) const; }; @@ -133,6 +143,8 @@ class TrueNode : public CircuitNode { public: TrueNode () : CircuitNode ({}) { } + + double weight (void) const; }; @@ -153,8 +165,10 @@ class LiftedCircuit void smoothCircuit (void); + double getWeightedModelCount (void) const; + void exportToGraphViz (const char*); - + private: void compile (CircuitNode** follow, const Clauses& clauses); diff --git a/packages/CLPBN/horus/LiftedWCNF.cpp b/packages/CLPBN/horus/LiftedWCNF.cpp index ac08bbbe0..df2efd59e 100644 --- a/packages/CLPBN/horus/LiftedWCNF.cpp +++ b/packages/CLPBN/horus/LiftedWCNF.cpp @@ -150,8 +150,15 @@ LiftedWCNF::LiftedWCNF (const ParfactorList& pfList) { addIndicatorClauses (pfList); addParameterClauses (pfList); + cout << "FORMULA INDICATORS:" << endl; printFormulaIndicators(); + cout << endl; + cout << "WEIGHTS:" << endl; + printWeights(); + cout << endl; + cout << "CLAUSES:" << endl; printClauses(); + cout << endl; } @@ -237,6 +244,8 @@ LiftedWCNF::addParameterClauses (const ParfactorList& pfList) while (indexer.valid()) { LiteralId paramVarLid = freeLiteralId_; + double weight = (**it)[indexer]; + Clause clause1 ((*it)->constr()); for (unsigned i = 0; i < groups.size(); i++) { @@ -245,11 +254,11 @@ LiftedWCNF::addParameterClauses (const ParfactorList& pfList) clause1.addAndNegateLiteral (Literal (lid, (*it)->argument(i).logVars())); Clause tempClause ((*it)->constr()); - tempClause.addAndNegateLiteral (Literal (paramVarLid, LogVars(), 1.0)); + tempClause.addAndNegateLiteral (Literal (paramVarLid, 1.0)); tempClause.addLiteral (Literal (lid, (*it)->argument(i).logVars())); clauses_.push_back (tempClause); } - clause1.addLiteral (Literal (paramVarLid, LogVars(), 1.0)); + clause1.addLiteral (Literal (paramVarLid, weight)); clauses_.push_back (clause1); freeLiteralId_ ++; ++ indexer; @@ -259,16 +268,6 @@ LiftedWCNF::addParameterClauses (const ParfactorList& pfList) } -void -LiftedWCNF::printClauses (void) const -{ - for (unsigned i = 0; i < clauses_.size(); i++) { - cout << clauses_[i] << endl; - } -} - - - void LiftedWCNF::printFormulaIndicators (void) const { @@ -293,3 +292,55 @@ LiftedWCNF::printFormulaIndicators (void) const } } + + +void +LiftedWCNF::printWeights (void) const +{ + for (LiteralId i = 0; i < freeLiteralId_; i++) { + + bool found = false; + for (size_t j = 0; j < clauses_.size(); j++) { + Literals literals = clauses_[j].literals(); + for (size_t k = 0; k < literals.size(); k++) { + if (literals[k].lid() == i && literals[k].isPositive()) { + cout << "weight(" << literals[k] << ") = " << literals[k].weight(); + cout << endl; + found = true; + break; + } + } + if (found == true) { + break; + } + } + + found = false; + for (size_t j = 0; j < clauses_.size(); j++) { + Literals literals = clauses_[j].literals(); + for (size_t k = 0; k < literals.size(); k++) { + if (literals[k].lid() == i && literals[k].isNegative()) { + cout << "weight(" << literals[k] << ") = " << literals[k].weight(); + cout << endl; + found = true; + break; + } + } + if (found == true) { + break; + } + } + + } +} + + + +void +LiftedWCNF::printClauses (void) const +{ + for (unsigned i = 0; i < clauses_.size(); i++) { + cout << clauses_[i] << endl; + } +} + diff --git a/packages/CLPBN/horus/LiftedWCNF.h b/packages/CLPBN/horus/LiftedWCNF.h index f71582b42..78a75db4e 100644 --- a/packages/CLPBN/horus/LiftedWCNF.h +++ b/packages/CLPBN/horus/LiftedWCNF.h @@ -13,8 +13,11 @@ class ConstraintTree; class Literal { public: + Literal (LiteralId lid, double w = -1.0) : + lid_(lid), weight_(w), negated_(false) { } + Literal (LiteralId lid, const LogVars& lvs, double w = -1.0) : - lid_(lid), logVars_(lvs), weight_(w), negated_(false) { } + lid_(lid), logVars_(lvs), weight_(w), negated_(false) { } Literal (const Literal& lit, bool negated) : lid_(lit.lid_), logVars_(lit.logVars_), weight_(lit.weight_), negated_(negated) { } @@ -23,7 +26,8 @@ class Literal LogVars logVars (void) const { return logVars_; } - double weight (void) const { return weight_; } + // FIXME not log aware + double weight (void) const { return weight_ < 0.0 ? 1.0 : weight_; } void negate (void) { negated_ = !negated_; } @@ -106,9 +110,11 @@ class LiftedWCNF Clause createClauseForLiteral (LiteralId lid) const; - void printClauses (void) const; - void printFormulaIndicators (void) const; + + void printWeights (void) const; + + void printClauses (void) const; private: void addIndicatorClauses (const ParfactorList& pfList);