fix weighted model counting in atom counting nodes
This commit is contained in:
		| @@ -32,12 +32,20 @@ SetOrNode::weight (void) const | |||||||
| { | { | ||||||
|   double weightSum = LogAware::addIdenty(); |   double weightSum = LogAware::addIdenty(); | ||||||
|   for (unsigned i = 0; i < nrGroundings_ + 1; i++) { |   for (unsigned i = 0; i < nrGroundings_ + 1; i++) { | ||||||
|     nrGrsStack.push (make_pair (i, nrGroundings_ - i)); |     nrGrsStack.push (make_pair (nrGroundings_ - i, i)); | ||||||
|     if (Globals::logDomain) { |     if (Globals::logDomain) { | ||||||
|       double w = std::log (Util::nrCombinations (nrGroundings_, i)); |       double w = std::log (Util::nrCombinations (nrGroundings_, i)); | ||||||
|       weightSum = Util::logSum (weightSum, w + follow_->weight()); |       weightSum = Util::logSum (weightSum, w + follow_->weight()); | ||||||
|     } else { |     } else { | ||||||
|       weightSum += Util::nrCombinations (nrGroundings_, i) * follow_->weight(); |       cout << endl; | ||||||
|  |       cout << "nr groundings = " << nrGroundings_ << endl; | ||||||
|  |       cout << "nr positives  = " << nrPositives() << endl; | ||||||
|  |       cout << "nr negatives  = " << nrNegatives() << endl;       | ||||||
|  |       cout << "i             = " << i << endl; | ||||||
|  |       cout << "nr combos     = " << Util::nrCombinations (nrGroundings_, i) << endl; | ||||||
|  |       double w = follow_->weight(); | ||||||
|  |       cout << "weight        = " << w << endl; | ||||||
|  |       weightSum += Util::nrCombinations (nrGroundings_, i) * w; | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
|   return weightSum; |   return weightSum; | ||||||
| @@ -78,7 +86,9 @@ LeafNode::weight (void) const | |||||||
|   assert (clauses().size() == 1); |   assert (clauses().size() == 1); | ||||||
|   assert (clauses()[0].isUnit()); |   assert (clauses()[0].isUnit()); | ||||||
|   Clause c = clauses()[0]; |   Clause c = clauses()[0]; | ||||||
|   double weight = c.literals()[0].weight(); |   double weight = c.literals()[0].isPositive() | ||||||
|  |       ? lwcnf_.posWeight (c.literals().front().lid()) | ||||||
|  |       : lwcnf_.negWeight (c.literals().front().lid()); | ||||||
|   LogVarSet lvs = c.constr().logVarSet(); |   LogVarSet lvs = c.constr().logVarSet(); | ||||||
|   lvs -= c.ipgLogVars(); |   lvs -= c.ipgLogVars(); | ||||||
|   lvs -= c.positiveCountedLogVars(); |   lvs -= c.positiveCountedLogVars(); | ||||||
| @@ -90,11 +100,20 @@ LeafNode::weight (void) const | |||||||
|     nrGroundings = ct.size(); |     nrGroundings = ct.size(); | ||||||
|   } |   } | ||||||
|   // TODO this only works for one counted log var |   // TODO this only works for one counted log var | ||||||
|  |   cout << "calc weight for " << clauses().front() << endl; | ||||||
|   if (c.positiveCountedLogVars().empty() == false) { |   if (c.positiveCountedLogVars().empty() == false) { | ||||||
|     nrGroundings *= SetOrNode::nrPositives(); |     cout << "  -> nr pos = " << SetOrNode::nrPositives() << endl; | ||||||
|   } else if (c.negativeCountedLogVars().empty() == false) { |     nrGroundings *= std::pow (SetOrNode::nrPositives(), | ||||||
|     nrGroundings *= SetOrNode::nrNegatives();     |         c.nrPositiveCountedLogVars()); | ||||||
|   } |   } | ||||||
|  |   if (c.negativeCountedLogVars().empty() == false) { | ||||||
|  |     cout << "  -> nr neg = " << SetOrNode::nrNegatives() << endl; | ||||||
|  |     nrGroundings *= std::pow (SetOrNode::nrNegatives(), | ||||||
|  |         c.nrNegativeCountedLogVars()); | ||||||
|  |   } | ||||||
|  |   cout << "  -> nr groundings = " << nrGroundings << endl; | ||||||
|  |   cout << "  -> lit weight    = " << weight << endl; | ||||||
|  |   cout << "  -> ret weight    = " << std::pow (weight, nrGroundings) << endl;   | ||||||
|   return Globals::logDomain  |   return Globals::logDomain  | ||||||
|       ? weight * nrGroundings |       ? weight * nrGroundings | ||||||
|       : std::pow (weight, nrGroundings); |       : std::pow (weight, nrGroundings); | ||||||
| @@ -109,14 +128,38 @@ SmoothNode::weight (void) const | |||||||
|   Clauses cs = clauses(); |   Clauses cs = clauses(); | ||||||
|   double totalWeight = LogAware::multIdenty(); |   double totalWeight = LogAware::multIdenty(); | ||||||
|   for (size_t i = 0; i < cs.size(); i++) { |   for (size_t i = 0; i < cs.size(); i++) { | ||||||
|     double posWeight = cs[i].literals()[0].weight(); |     double posWeight = lwcnf_.posWeight (cs[i].literals()[0].lid()); | ||||||
|     double negWeight = cs[i].literals()[1].weight(); |     double negWeight = lwcnf_.negWeight (cs[i].literals()[0].lid()); | ||||||
|     unsigned nrGroundings = cs[i].constr().size(); |     LogVarSet lvs = cs[i].constr().logVarSet(); | ||||||
|  |     lvs -= cs[i].ipgLogVars(); | ||||||
|  |     lvs -= cs[i].positiveCountedLogVars(); | ||||||
|  |     lvs -= cs[i].negativeCountedLogVars(); | ||||||
|  |     unsigned nrGroundings = 1; | ||||||
|  |     if (lvs.empty() == false) { | ||||||
|  |       ConstraintTree ct = cs[i].constr(); | ||||||
|  |       ct.project (lvs); | ||||||
|  |       nrGroundings = ct.size(); | ||||||
|  |     } | ||||||
|  |     cout << "calc smooth weight for " << cs[i] << endl; | ||||||
|  |     if (cs[i].positiveCountedLogVars().empty() == false) { | ||||||
|  |       cout << "  -> nr pos = " << SetOrNode::nrPositives() << endl; | ||||||
|  |       nrGroundings *= std::pow (SetOrNode::nrPositives(),  | ||||||
|  |           cs[i].nrPositiveCountedLogVars()); | ||||||
|  |     } | ||||||
|  |     if (cs[i].negativeCountedLogVars().empty() == false) { | ||||||
|  |       cout << "  -> nr neg = " << SetOrNode::nrNegatives() << endl; | ||||||
|  |       nrGroundings *= std::pow (SetOrNode::nrNegatives(), | ||||||
|  |           cs[i].nrNegativeCountedLogVars());       | ||||||
|  |     } | ||||||
|  |     cout << "  -> pos+neg = " << posWeight + negWeight << endl; | ||||||
|  |     cout << "  -> nrgroun = " << nrGroundings << endl;     | ||||||
|     if (Globals::logDomain) { |     if (Globals::logDomain) { | ||||||
|  |       // TODO i think i have to do log on nrGrounginds here! | ||||||
|       totalWeight += (Util::logSum (posWeight, negWeight) * nrGroundings); |       totalWeight += (Util::logSum (posWeight, negWeight) * nrGroundings); | ||||||
|     } else { |     } else { | ||||||
|       totalWeight *= std::pow (posWeight + negWeight, nrGroundings); |       totalWeight *= std::pow (posWeight + negWeight, nrGroundings); | ||||||
|     } |     } | ||||||
|  |     cout << "  -> smooth weight  = " << totalWeight << endl; | ||||||
|   } |   } | ||||||
|   return totalWeight; |   return totalWeight; | ||||||
| } | } | ||||||
| @@ -151,6 +194,8 @@ LiftedCircuit::LiftedCircuit (const LiftedWCNF* lwcnf) | |||||||
|   exportToGraphViz("circuit.dot"); |   exportToGraphViz("circuit.dot"); | ||||||
|   smoothCircuit(); |   smoothCircuit(); | ||||||
|   exportToGraphViz("circuit.smooth.dot"); |   exportToGraphViz("circuit.smooth.dot"); | ||||||
|  |   cout << "--------------------------------------------------" << endl; | ||||||
|  |   cout << "--------------------------------------------------" << endl;   | ||||||
|   cout << "WEIGHTED MODEL COUNT = " << getWeightedModelCount() << endl; |   cout << "WEIGHTED MODEL COUNT = " << getWeightedModelCount() << endl; | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -201,7 +246,7 @@ LiftedCircuit::compile ( | |||||||
|   } |   } | ||||||
|    |    | ||||||
|   if (clauses.size() == 1 && clauses[0].isUnit()) { |   if (clauses.size() == 1 && clauses[0].isUnit()) { | ||||||
|     *follow = new LeafNode (clauses[0]); |     *follow = new LeafNode (clauses[0], *lwcnf_); | ||||||
|     return; |     return; | ||||||
|   } |   } | ||||||
|  |  | ||||||
| @@ -587,33 +632,6 @@ LiftedCircuit::shatterCountedLogVarsAux ( | |||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| LogVarTypes |  | ||||||
| unionTypes (const LogVarTypes& types1, const LogVarTypes& types2) |  | ||||||
| { |  | ||||||
|   if (types1.empty()) { |  | ||||||
|     return types2; |  | ||||||
|   } |  | ||||||
|   if (types2.empty()) { |  | ||||||
|     return types1; |  | ||||||
|   }   |  | ||||||
|   assert (types1.size() == types2.size()); |  | ||||||
|   LogVarTypes res;   |  | ||||||
|   for (size_t i = 0; i < types1.size(); i++) { |  | ||||||
|     if (types1[i] == LogVarType::POS_LV |  | ||||||
|         && types2[i] == LogVarType::POS_LV) { |  | ||||||
|       res.push_back (LogVarType::POS_LV);         |  | ||||||
|     } else if (types1[i] == LogVarType::NEG_LV |  | ||||||
|         && types2[i] == LogVarType::NEG_LV) { |  | ||||||
|       res.push_back (LogVarType::NEG_LV); |  | ||||||
|     } else { |  | ||||||
|       res.push_back (LogVarType::FULL_LV); |  | ||||||
|     } |  | ||||||
|   } |  | ||||||
|   return res; |  | ||||||
| } |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| vector<LogVarTypes> | vector<LogVarTypes> | ||||||
| getAllPossibleTypes (unsigned nrLogVars) | getAllPossibleTypes (unsigned nrLogVars) | ||||||
| { | { | ||||||
| @@ -779,7 +797,7 @@ LiftedCircuit::createSmoothNode ( | |||||||
|       c.addAndNegateLiteral (c.literals()[0]); |       c.addAndNegateLiteral (c.literals()[0]); | ||||||
|       clauses.push_back (c); |       clauses.push_back (c); | ||||||
|     } |     } | ||||||
|     SmoothNode* smoothNode = new SmoothNode (clauses); |     SmoothNode* smoothNode = new SmoothNode (clauses, *lwcnf_); | ||||||
|     *prev = new AndNode ((*prev)->clauses(), smoothNode, |     *prev = new AndNode ((*prev)->clauses(), smoothNode, | ||||||
|         *prev, " Smoothing"); |         *prev, " Smoothing"); | ||||||
|   } |   } | ||||||
|   | |||||||
| @@ -159,9 +159,13 @@ class IncExcNode : public CircuitNode | |||||||
| class LeafNode : public CircuitNode | class LeafNode : public CircuitNode | ||||||
| { | { | ||||||
|   public: |   public: | ||||||
|     LeafNode (const Clause& clause) : CircuitNode ({clause}) { } |     LeafNode (const Clause& clause, const LiftedWCNF& lwcnf) | ||||||
|  |         : CircuitNode (Clauses() = {clause}), lwcnf_(lwcnf) { } | ||||||
|  |  | ||||||
|     double weight (void) const; |     double weight (void) const; | ||||||
|  |      | ||||||
|  |   private: | ||||||
|  |     const LiftedWCNF&  lwcnf_; | ||||||
| }; | }; | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -169,9 +173,13 @@ class LeafNode : public CircuitNode | |||||||
| class SmoothNode : public CircuitNode | class SmoothNode : public CircuitNode | ||||||
| { | { | ||||||
|   public: |   public: | ||||||
|     SmoothNode (const Clauses& clauses) : CircuitNode (clauses) { } |     SmoothNode (const Clauses& clauses, const LiftedWCNF& lwcnf) | ||||||
|  |         : CircuitNode (clauses), lwcnf_(lwcnf) { } | ||||||
|  |  | ||||||
|     double weight (void) const; |     double weight (void) const; | ||||||
|  |  | ||||||
|  |   private: | ||||||
|  |     const LiftedWCNF&  lwcnf_; | ||||||
| }; | }; | ||||||
|  |  | ||||||
|  |  | ||||||
|   | |||||||
| @@ -3,6 +3,7 @@ | |||||||
| #include "Indexer.h" | #include "Indexer.h" | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| bool | bool | ||||||
| Literal::isGround (ConstraintTree constr, LogVarSet ipgLogVars) const | Literal::isGround (ConstraintTree constr, LogVarSet ipgLogVars) const | ||||||
| { | { | ||||||
| @@ -24,7 +25,12 @@ Literal::toString ( | |||||||
| { | { | ||||||
|   stringstream ss; |   stringstream ss; | ||||||
|   negated_ ? ss << "¬" : ss << "" ; |   negated_ ? ss << "¬" : ss << "" ; | ||||||
|   weight_ < 0.0 ? ss << "λ" : ss << "Θ" ; |   // if (negated_ == false) { | ||||||
|  |   //   posWeight_ < 0.0 ? ss << "λ" : ss << "Θ" ; | ||||||
|  |   // } else { | ||||||
|  |   //   negWeight_ < 0.0 ? ss << "λ" : ss << "Θ" ;   | ||||||
|  |   // } | ||||||
|  |   ss << "λ" ; | ||||||
|   ss << lid_ ; |   ss << lid_ ; | ||||||
|   if (logVars_.empty() == false) { |   if (logVars_.empty() == false) { | ||||||
|     ss << "(" ; |     ss << "(" ; | ||||||
| @@ -300,18 +306,44 @@ LiftedWCNF::LiftedWCNF (const ParfactorList& pfList) | |||||||
|   //addIndicatorClauses (pfList); |   //addIndicatorClauses (pfList); | ||||||
|   //addParameterClauses (pfList); |   //addParameterClauses (pfList); | ||||||
|  |  | ||||||
|   vector<vector<string>> names = {{"p1","p1"},{"p2","p2"}}; |   vector<vector<string>> names = { | ||||||
|  | /* | ||||||
|  |       {"p1","p1"}, | ||||||
|  |       {"p1","p2"}, | ||||||
|  |       {"p2","p1"}, | ||||||
|  |       {"p2","p2"}, | ||||||
|  |       {"p1","p3"}, | ||||||
|  |       {"p2","p3"}, | ||||||
|  |       {"p3","p3"}, | ||||||
|  |       {"p3","p2"}, | ||||||
|  |       {"p3","p1"} | ||||||
|  | */ | ||||||
|  |       {"p1","p1"}, | ||||||
|  |       {"p1","p2"}, | ||||||
|  |       {"p1","p3"},       | ||||||
|  |       {"p2","p1"}, | ||||||
|  |       {"p2","p2"}, | ||||||
|  |       {"p2","p3"}, | ||||||
|  |       {"p3","p1"}, | ||||||
|  |       {"p3","p2"}, | ||||||
|  |       {"p3","p3"} | ||||||
|  |   }; | ||||||
|  |  | ||||||
|   Clause c1 (names); |   Clause c1 (names); | ||||||
|   c1.addLiteral (Literal (0, LogVars()={0}, 3.0)); |   c1.addLiteral (Literal (0, LogVars() = {0})); | ||||||
|   c1.addAndNegateLiteral (Literal (1, {0,1}, 1.0)); |   c1.addAndNegateLiteral (Literal (1, {0,1})); | ||||||
|   clauses_.push_back(c1); |   clauses_.push_back(c1); | ||||||
|      |      | ||||||
|   Clause c2 (names); |   Clause c2 (names); | ||||||
|   c2.addLiteral (Literal (0, LogVars()={0}, 2.0)); |   c2.addLiteral (Literal (0, LogVars()={0})); | ||||||
|   c2.addAndNegateLiteral (Literal (1, {1,0}, 5.0)); |   c2.addAndNegateLiteral (Literal (1, {1,0})); | ||||||
|   clauses_.push_back(c2); |   clauses_.push_back(c2); | ||||||
|    |    | ||||||
|  |   addWeight (0, 3.0, 4.0); | ||||||
|  |   addWeight (1, 2.0, 5.0); | ||||||
|  |    | ||||||
|  |   freeLiteralId_ = 2; | ||||||
|  |    | ||||||
|   cout << "FORMULA INDICATORS:" << endl; |   cout << "FORMULA INDICATORS:" << endl; | ||||||
|   // printFormulaIndicators(); |   // printFormulaIndicators(); | ||||||
|   cout << endl; |   cout << endl; | ||||||
| @@ -320,6 +352,7 @@ LiftedWCNF::LiftedWCNF (const ParfactorList& pfList) | |||||||
|   cout << endl; |   cout << endl; | ||||||
|   cout << "CLAUSES:" << endl; |   cout << "CLAUSES:" << endl; | ||||||
|   printClauses(); |   printClauses(); | ||||||
|  |   // abort(); | ||||||
|   cout << endl; |   cout << endl; | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -349,7 +382,7 @@ LiftedWCNF::createClauseForLiteral (LiteralId lid) const | |||||||
|   } |   } | ||||||
|   // FIXME |   // FIXME | ||||||
|   Clause c (ConstraintTree({})); |   Clause c (ConstraintTree({})); | ||||||
|   c.addLiteral (Literal (lid,{})); |   c.addLiteral (Literal (lid,LogVars() = {})); | ||||||
|   return c; |   return c; | ||||||
|   //assert (false); |   //assert (false); | ||||||
|   //return Clause (0); |   //return Clause (0); | ||||||
| @@ -410,22 +443,25 @@ LiftedWCNF::addParameterClauses (const ParfactorList& pfList) | |||||||
|       // ¬λu1 ... ¬λun v θxi|u1,...,un  -> clause1 |       // ¬λu1 ... ¬λun v θxi|u1,...,un  -> clause1 | ||||||
|       // ¬θxi|u1,...,un v λu1           -> tempClause |       // ¬θxi|u1,...,un v λu1           -> tempClause | ||||||
|       // ¬θxi|u1,...,un v λu2           -> tempClause |       // ¬θxi|u1,...,un v λu2           -> tempClause | ||||||
|       double weight = (**it)[indexer]; |       double posWeight = (**it)[indexer]; | ||||||
|  |       addWeight (paramVarLid, posWeight, 1.0); | ||||||
|        |        | ||||||
|       Clause clause1 (*(*it)->constr()); |       Clause clause1 (*(*it)->constr()); | ||||||
|  |  | ||||||
|       for (unsigned i = 0; i < groups.size(); i++) { |       for (unsigned i = 0; i < groups.size(); i++) { | ||||||
|         LiteralId lid = getLiteralId (groups[i], indexer[i]); |         LiteralId lid = getLiteralId (groups[i], indexer[i]); | ||||||
|  |  | ||||||
|         clause1.addAndNegateLiteral (Literal (lid, (*it)->argument(i).logVars())); |         clause1.addAndNegateLiteral ( | ||||||
|  |             Literal (lid, (*it)->argument(i).logVars())); | ||||||
|  |  | ||||||
|         ConstraintTree ct = *(*it)->constr(); |         ConstraintTree ct = *(*it)->constr(); | ||||||
|         Clause tempClause (ct); |         Clause tempClause (ct); | ||||||
|         tempClause.addAndNegateLiteral (Literal (paramVarLid, (*it)->constr()->logVars(), 1.0)); |         tempClause.addAndNegateLiteral (Literal ( | ||||||
|  |             paramVarLid, (*it)->constr()->logVars())); | ||||||
|         tempClause.addLiteral (Literal (lid, (*it)->argument(i).logVars())); |         tempClause.addLiteral (Literal (lid, (*it)->argument(i).logVars())); | ||||||
|         clauses_.push_back (tempClause);         |         clauses_.push_back (tempClause);         | ||||||
|       } |       } | ||||||
|       clause1.addLiteral (Literal (paramVarLid, (*it)->constr()->logVars(),weight)); |       clause1.addLiteral (Literal (paramVarLid, (*it)->constr()->logVars())); | ||||||
|       clauses_.push_back (clause1); |       clauses_.push_back (clause1); | ||||||
|       freeLiteralId_ ++; |       freeLiteralId_ ++; | ||||||
|       ++ indexer; |       ++ indexer; | ||||||
| @@ -464,43 +500,14 @@ LiftedWCNF::printFormulaIndicators (void) const | |||||||
| void | void | ||||||
| LiftedWCNF::printWeights (void) const | LiftedWCNF::printWeights (void) const | ||||||
| { | { | ||||||
|    for (LiteralId i = 0; i < freeLiteralId_; i++) { |   unordered_map<LiteralId, std::pair<double,double>>::const_iterator it; | ||||||
|  |   it = weights_.begin(); | ||||||
|      bool found = false; |   while (it != weights_.end()) { | ||||||
|      for (size_t j = 0; j < clauses_.size(); j++) { |     cout << "λ" << it->first << " weights: " ;      | ||||||
|        Literals literals = clauses_[j].literals(); |     cout << it->second.first << " " << it->second.second; | ||||||
|        for (size_t k = 0; k < literals.size(); k++) { |     cout << endl; | ||||||
|          if (literals[k].lid() == i && literals[k].isPositive()) { |     ++ it; | ||||||
|            cout << "weight(" << literals[k] << ") = " ; |   } | ||||||
|            cout << literals[k].weight(); |  | ||||||
|            cout << endl; |  | ||||||
|            found = true; |  | ||||||
|            break; |  | ||||||
|          } |  | ||||||
|        } |  | ||||||
|        if (found == true) { |  | ||||||
|          break; |  | ||||||
|        } |  | ||||||
|      } |  | ||||||
|       |  | ||||||
|      found = false; |  | ||||||
|      for (size_t j = 0; j < clauses_.size(); j++) { |  | ||||||
|        Literals literals = clauses_[j].literals(); |  | ||||||
|        for (size_t k = 0; k < literals.size(); k++) { |  | ||||||
|          if (literals[k].lid() == i && literals[k].isNegative()) { |  | ||||||
|            cout << "weight(" << literals[k] << ") = " ; |  | ||||||
|            cout << literals[k].weight(); |  | ||||||
|            cout << endl; |  | ||||||
|            found = true; |  | ||||||
|            break; |  | ||||||
|          } |  | ||||||
|        } |  | ||||||
|        if (found == true) { |  | ||||||
|          break; |  | ||||||
|        } |  | ||||||
|      } |  | ||||||
|       |  | ||||||
|    } |  | ||||||
| } | } | ||||||
|  |  | ||||||
|  |  | ||||||
|   | |||||||
| @@ -23,14 +23,11 @@ typedef vector<LogVarType> LogVarTypes; | |||||||
| class Literal | class Literal | ||||||
| { | { | ||||||
|   public: |   public: | ||||||
|     Literal (LiteralId lid, double w = -1.0) : |     Literal (LiteralId lid, const LogVars& lvs) : | ||||||
|        lid_(lid), weight_(w), negated_(false) { } |        lid_(lid), logVars_(lvs), negated_(false) { } | ||||||
|  |  | ||||||
|     Literal (LiteralId lid, const LogVars& lvs, double w = -1.0) : |  | ||||||
|        lid_(lid), logVars_(lvs), weight_(w), negated_(false) { } |  | ||||||
|  |  | ||||||
|     Literal (const Literal& lit, bool negated) : |     Literal (const Literal& lit, bool negated) : | ||||||
|        lid_(lit.lid_), logVars_(lit.logVars_), weight_(lit.weight_), negated_(negated) { } |        lid_(lit.lid_), logVars_(lit.logVars_), negated_(negated) { } | ||||||
|  |  | ||||||
|     LiteralId lid (void) const { return lid_; } |     LiteralId lid (void) const { return lid_; } | ||||||
|  |  | ||||||
| @@ -38,9 +35,6 @@ class Literal | |||||||
|  |  | ||||||
|     LogVarSet logVarSet (void) const { return LogVarSet (logVars_); } |     LogVarSet logVarSet (void) const { return LogVarSet (logVars_); } | ||||||
|  |  | ||||||
|     // FIXME this is not log aware :( |  | ||||||
|     double weight (void) const { return weight_ < 0.0 ? 1.0 : weight_; } |  | ||||||
|  |  | ||||||
|     void negate (void) { negated_ = !negated_; } |     void negate (void) { negated_ = !negated_; } | ||||||
|  |  | ||||||
|     bool isPositive (void) const { return negated_ == false; } |     bool isPositive (void) const { return negated_ == false; } | ||||||
| @@ -58,7 +52,6 @@ class Literal | |||||||
|   private: |   private: | ||||||
|     LiteralId    lid_; |     LiteralId    lid_; | ||||||
|     LogVars      logVars_; |     LogVars      logVars_; | ||||||
|     double       weight_; |  | ||||||
|     bool         negated_; |     bool         negated_; | ||||||
| }; | }; | ||||||
|  |  | ||||||
| @@ -102,6 +95,10 @@ class Clause | |||||||
|  |  | ||||||
|     LogVarSet negativeCountedLogVars (void) const { return negCountedLvs_; } |     LogVarSet negativeCountedLogVars (void) const { return negCountedLvs_; } | ||||||
|  |  | ||||||
|  |     unsigned nrPositiveCountedLogVars (void) const { return posCountedLvs_.size(); } | ||||||
|  |  | ||||||
|  |     unsigned nrNegativeCountedLogVars (void) const { return negCountedLvs_.size(); } | ||||||
|  |  | ||||||
|     bool containsLiteral (LiteralId lid) const; |     bool containsLiteral (LiteralId lid) const; | ||||||
|  |  | ||||||
|     bool containsPositiveLiteral (LiteralId lid, const LogVarTypes&) const; |     bool containsPositiveLiteral (LiteralId lid, const LogVarTypes&) const; | ||||||
| @@ -185,20 +182,31 @@ class LiftedWCNF | |||||||
|  |  | ||||||
|    ~LiftedWCNF (void); |    ~LiftedWCNF (void); | ||||||
|     |     | ||||||
|    const Clauses& clauses (void) const { return clauses_; } |     const Clauses& clauses (void) const { return clauses_; } | ||||||
|     |     | ||||||
|    Clause createClauseForLiteral (LiteralId lid) const; |     double posWeight (LiteralId lid) const | ||||||
|  |     { | ||||||
|  |       unordered_map<LiteralId, std::pair<double,double>>::const_iterator it; | ||||||
|  |       it = weights_.find (lid); | ||||||
|  |       return it != weights_.end() ? it->second.first : 1.0; | ||||||
|  |     } | ||||||
|  |  | ||||||
|    void printFormulaIndicators (void) const; |     double negWeight (LiteralId lid) const | ||||||
|  |     { | ||||||
|  |       unordered_map<LiteralId, std::pair<double,double>>::const_iterator it; | ||||||
|  |       it = weights_.find (lid); | ||||||
|  |       return it != weights_.end() ? it->second.second : 1.0; | ||||||
|  |     } | ||||||
|  |  | ||||||
|    void printWeights (void) const; |     Clause createClauseForLiteral (LiteralId lid) const; | ||||||
|  |  | ||||||
|    void printClauses (void) const; |     void printFormulaIndicators (void) const; | ||||||
|  |  | ||||||
|  |     void printWeights (void) const; | ||||||
|  |  | ||||||
|  |     void printClauses (void) const; | ||||||
|  |  | ||||||
|   private: |   private: | ||||||
|     void addIndicatorClauses (const ParfactorList& pfList); |  | ||||||
|  |  | ||||||
|     void addParameterClauses (const ParfactorList& pfList); |  | ||||||
|    |    | ||||||
|     LiteralId getLiteralId (PrvGroup prvGroup, unsigned range) |     LiteralId getLiteralId (PrvGroup prvGroup, unsigned range) | ||||||
|     { |     { | ||||||
| @@ -206,10 +214,21 @@ class LiftedWCNF | |||||||
|       return map_[prvGroup][range]; |       return map_[prvGroup][range]; | ||||||
|     } |     } | ||||||
|    |    | ||||||
|  |     void addWeight (LiteralId lid, double posW, double negW) | ||||||
|  |     { | ||||||
|  |       weights_[lid] = make_pair (posW, negW); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     void addIndicatorClauses (const ParfactorList& pfList); | ||||||
|  |  | ||||||
|  |     void addParameterClauses (const ParfactorList& pfList); | ||||||
|  |  | ||||||
|     Clauses clauses_; |     Clauses clauses_; | ||||||
|  |  | ||||||
|     unordered_map<PrvGroup, vector<LiteralId>> map_; |     unordered_map<PrvGroup, vector<LiteralId>> map_; | ||||||
|  |  | ||||||
|  |     unordered_map<LiteralId, std::pair<double,double>> weights_; | ||||||
|  |  | ||||||
|     const ParfactorList& pfList_; |     const ParfactorList& pfList_; | ||||||
|  |  | ||||||
|     LiteralId freeLiteralId_; |     LiteralId freeLiteralId_; | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user