diff --git a/packages/CLPBN/horus/LiftedCircuit.cpp b/packages/CLPBN/horus/LiftedCircuit.cpp index fa572db5e..14203066a 100644 --- a/packages/CLPBN/horus/LiftedCircuit.cpp +++ b/packages/CLPBN/horus/LiftedCircuit.cpp @@ -26,13 +26,21 @@ AndNode::weight (void) const double SetOrNode::weight (void) const { - // TODO - return 0.0; + double weightSum = LogAware::addIdenty(); + for (unsigned i = 0; i < nrGroundings_ + 1; i++) { + nrGrsStack.push (make_pair (i, nrGroundings_ - i)); + if (Globals::logDomain) { + double w = std::log (Util::nrCombinations (nrGroundings_, i)); + weightSum = Util::logSum (weightSum, w + follow_->weight()); + } else { + weightSum += Util::nrCombinations (nrGroundings_, i) * follow_->weight(); + } + } + return weightSum; } - double SetAndNode::weight (void) const { @@ -67,14 +75,22 @@ LeafNode::weight (void) const assert (clauses()[0].isUnit()); Clause c = clauses()[0]; double weight = c.literals()[0].weight(); - LogVarSet lvs = c.constr().logVarSet() - c.ipgLogVars(); + LogVarSet lvs = c.constr().logVarSet(); + lvs -= c.ipgLogVars(); + lvs -= c.positiveCountedLogVars(); + lvs -= c.negativeCountedLogVars(); unsigned nrGroundings = 1; if (lvs.empty() == false) { ConstraintTree ct = c.constr(); ct.project (lvs); nrGroundings = ct.size(); } - assert (nrGroundings != 0); + // TODO this only works for one counted log var + if (c.positiveCountedLogVars().empty() == false) { + nrGroundings *= SetOrNode::nrPositives(); + } else if (c.negativeCountedLogVars().empty() == false) { + nrGroundings *= SetOrNode::nrNegatives(); + } return Globals::logDomain ? weight * nrGroundings : std::pow (weight, nrGroundings); @@ -85,6 +101,7 @@ LeafNode::weight (void) const double SmoothNode::weight (void) const { + // TODO and what happens if smoothing contains ipg or counted lvs ? Clauses cs = clauses(); double totalWeight = LogAware::multIdenty(); for (size_t i = 0; i < cs.size(); i++) { @@ -223,12 +240,11 @@ LiftedCircuit::tryUnitPropagation ( CircuitNode** follow, Clauses& clauses) { - cout << "ALL CLAUSES:" << endl; - Clause::printClauses (clauses); - + // cout << "ALL CLAUSES:" << endl; + // Clause::printClauses (clauses); for (size_t i = 0; i < clauses.size(); i++) { if (clauses[i].isUnit()) { - cout << clauses[i] << " is unit!" << endl; + // cout << clauses[i] << " is unit!" << endl; Clauses newClauses; for (size_t j = 0; j < clauses.size(); j++) { if (i != j) { @@ -453,7 +469,9 @@ LiftedCircuit::tryAtomCounting ( for (size_t j = 0; j < literals.size(); j++) { if (literals[j].logVars().size() == 1) { // TODO check if not already in ipg and countedlvs - SetOrNode* setOrNode = new SetOrNode (clauses); + unsigned nrGroundings = clauses[i].constr().projectedCopy ( + literals[j].logVars()).size(); + SetOrNode* setOrNode = new SetOrNode (nrGroundings, clauses); Clause c1 (clauses[i].constr().projectedCopy (literals[j].logVars())); Clause c2 (clauses[i].constr().projectedCopy (literals[j].logVars())); c1.addLiteral (literals[j]); diff --git a/packages/CLPBN/horus/LiftedCircuit.h b/packages/CLPBN/horus/LiftedCircuit.h index a687a202f..3688229cc 100644 --- a/packages/CLPBN/horus/LiftedCircuit.h +++ b/packages/CLPBN/horus/LiftedCircuit.h @@ -1,6 +1,8 @@ #ifndef HORUS_LIFTEDCIRCUIT_H #define HORUS_LIFTEDCIRCUIT_H +#include + #include "LiftedWCNF.h" @@ -45,11 +47,11 @@ class OrNode : public CircuitNode OrNode (const Clauses& clauses, string explanation = "") : CircuitNode (clauses, explanation), leftBranch_(0), rightBranch_(0) { } - - double weight (void) const; CircuitNode** leftBranch (void) { return &leftBranch_; } CircuitNode** rightBranch (void) { return &rightBranch_; } + + double weight (void) const; private: CircuitNode* leftBranch_; CircuitNode* rightBranch_; @@ -78,11 +80,12 @@ class AndNode : public CircuitNode string explanation = "") : CircuitNode ({}, explanation), leftBranch_(leftBranch), rightBranch_(rightBranch) { } - - double weight (void) const; CircuitNode** leftBranch (void) { return &leftBranch_; } CircuitNode** rightBranch (void) { return &rightBranch_; } + + double weight (void) const; + private: CircuitNode* leftBranch_; CircuitNode* rightBranch_; @@ -93,14 +96,23 @@ class AndNode : public CircuitNode class SetOrNode : public CircuitNode { public: - SetOrNode (const Clauses& clauses, string explanation = "") - : CircuitNode (clauses, explanation), follow_(0) { } - - double weight (void) const; - + SetOrNode (unsigned nrGroundings, const Clauses& clauses) + : CircuitNode (clauses, " AC"), follow_(0), + nrGroundings_(nrGroundings) { } + CircuitNode** follow (void) { return &follow_; } + + static unsigned nrPositives (void) { return nrGrsStack.top().first; } + + static unsigned nrNegatives (void) { return nrGrsStack.top().second; } + + double weight (void) const; + private: - CircuitNode* follow_; + CircuitNode* follow_; + unsigned nrGroundings_; + + static stack> nrGrsStack; }; @@ -109,15 +121,16 @@ class SetAndNode : public CircuitNode { public: SetAndNode (unsigned nrGroundings, const Clauses& clauses) - : CircuitNode (clauses, " IPG"), nrGroundings_(nrGroundings), - follow_(0) { } + : CircuitNode (clauses, " IPG"), follow_(0), + nrGroundings_(nrGroundings) { } + + CircuitNode** follow (void) { return &follow_; } double weight (void) const; - - CircuitNode** follow (void) { return &follow_; } + private: - unsigned nrGroundings_; CircuitNode* follow_; + unsigned nrGroundings_; }; @@ -129,11 +142,11 @@ class IncExcNode : public CircuitNode : CircuitNode (clauses), plus1Branch_(0), plus2Branch_(0), minusBranch_(0) { } - double weight (void) const; - CircuitNode** plus1Branch (void) { return &plus1Branch_; } CircuitNode** plus2Branch (void) { return &plus2Branch_; } CircuitNode** minusBranch (void) { return &minusBranch_; } + + double weight (void) const; private: CircuitNode* plus1Branch_; diff --git a/packages/CLPBN/horus/LiftedWCNF.cpp b/packages/CLPBN/horus/LiftedWCNF.cpp index 527035d08..e9003405c 100644 --- a/packages/CLPBN/horus/LiftedWCNF.cpp +++ b/packages/CLPBN/horus/LiftedWCNF.cpp @@ -119,7 +119,7 @@ Clause::removeLiterals (LiteralId lid) } } - + void Clause::removePositiveLiterals ( @@ -137,7 +137,7 @@ Clause::removePositiveLiterals ( } } } - + void diff --git a/packages/CLPBN/horus/LiftedWCNF.h b/packages/CLPBN/horus/LiftedWCNF.h index 260de8170..ac8a8758f 100644 --- a/packages/CLPBN/horus/LiftedWCNF.h +++ b/packages/CLPBN/horus/LiftedWCNF.h @@ -25,36 +25,36 @@ class Literal public: Literal (LiteralId lid, double w = -1.0) : lid_(lid), weight_(w), negated_(false) { } - + Literal (LiteralId lid, const LogVars& lvs, double w = -1.0) : lid_(lid), logVars_(lvs), weight_(w), negated_(false) { } - + Literal (const Literal& lit, bool negated) : lid_(lit.lid_), logVars_(lit.logVars_), weight_(lit.weight_), negated_(negated) { } LiteralId lid (void) const { return lid_; } - + LogVars logVars (void) const { return logVars_; } - + LogVarSet logVarSet (void) const { return LogVarSet (logVars_); } // FIXME this is not log aware :( double weight (void) const { return weight_ < 0.0 ? 1.0 : weight_; } - + void negate (void) { negated_ = !negated_; } bool isPositive (void) const { return negated_ == false; } - + bool isNegative (void) const { return negated_; } - + bool isGround (ConstraintTree constr, LogVarSet ipgLogVars) const; - + string toString (LogVarSet ipgLogVars = LogVarSet(), LogVarSet posCountedLvs = LogVarSet(), LogVarSet negCountedLvs = LogVarSet()) const; - + friend std::ostream& operator<< (ostream &os, const Literal& lit); - + private: LiteralId lid_; LogVars logVars_; @@ -81,7 +81,7 @@ class Clause literals_.push_back (l); literals_.back().negate(); } - + const vector& literals (void) const { return literals_; } const ConstraintTree& constr (void) const { return constr_; } @@ -98,6 +98,10 @@ class Clause void addNegativeCountedLogVar (LogVar X) { negCountedLvs_.insert (X); } + LogVarSet positiveCountedLogVars (void) const { return posCountedLvs_; } + + LogVarSet negativeCountedLogVars (void) const { return negCountedLvs_; } + bool containsLiteral (LiteralId lid) const; bool containsPositiveLiteral (LiteralId lid, const LogVarTypes&) const; @@ -115,26 +119,26 @@ class Clause bool isPositiveCountedLogVar (LogVar X) const; bool isNegativeCountedLogVar (LogVar X) const; - + TinySet lidSet (void) const; LogVarSet ipgCandidates (void) const; LogVarTypes logVarTypes (size_t litIdx) const; - + void removeLiteral (size_t litIdx); - + static void printClauses (const vector& clauses); friend std::ostream& operator<< (ostream &os, const Clause& clause); - + private: LogVarSet getLogVarSetExcluding (size_t idx) const; - + vector literals_; LogVarSet ipgLogVars_; LogVarSet posCountedLvs_; - LogVarSet negCountedLvs_; + LogVarSet negCountedLvs_; ConstraintTree constr_; }; @@ -146,19 +150,19 @@ class LiftedWCNF { public: LiftedWCNF (const ParfactorList& pfList); - + ~LiftedWCNF (void); - + const Clauses& clauses (void) const { return clauses_; } - + Clause createClauseForLiteral (LiteralId lid) const; - + void printFormulaIndicators (void) const; - + void printWeights (void) const; - + void printClauses (void) const; - + private: void addIndicatorClauses (const ParfactorList& pfList); @@ -171,11 +175,11 @@ class LiftedWCNF } Clauses clauses_; - + unordered_map> map_; - + const ParfactorList& pfList_; - + LiteralId freeLiteralId_; };