diff --git a/packages/CLPBN/horus/LiftedCircuit.cpp b/packages/CLPBN/horus/LiftedCircuit.cpp index 14203066a..26989dbf4 100644 --- a/packages/CLPBN/horus/LiftedCircuit.cpp +++ b/packages/CLPBN/horus/LiftedCircuit.cpp @@ -23,6 +23,10 @@ AndNode::weight (void) const +stack> SetOrNode::nrGrsStack; + + + double SetOrNode::weight (void) const { @@ -583,82 +587,195 @@ LiftedCircuit::shatterCountedLogVarsAux ( -TinySet +LogVarTypes +unionTypes (const LogVarTypes& types1, const LogVarTypes& types2) +{ + if (types1.empty()) { + return types2; + } + if (types2.empty()) { + return types1; + } + assert (types1.size() == types2.size()); + LogVarTypes res; + for (size_t i = 0; i < types1.size(); i++) { + if (types1[i] == LogVarType::POS_LV + && types2[i] == LogVarType::POS_LV) { + res.push_back (LogVarType::POS_LV); + } else if (types1[i] == LogVarType::NEG_LV + && types2[i] == LogVarType::NEG_LV) { + res.push_back (LogVarType::NEG_LV); + } else { + res.push_back (LogVarType::FULL_LV); + } + } + return res; +} + + + +vector +getAllPossibleTypes (unsigned nrLogVars) +{ + if (nrLogVars == 0) { + return {}; + } + if (nrLogVars == 1) { + return {{LogVarType::POS_LV},{LogVarType::NEG_LV}}; + } + vector res; + Indexer indexer (vector (nrLogVars, 2)); + while (indexer.valid()) { + LogVarTypes types; + for (size_t i = 0; i < nrLogVars; i++) { + if (indexer[i] == 0) { + types.push_back (LogVarType::POS_LV); + } else { + types.push_back (LogVarType::NEG_LV); + } + } + res.push_back (types); + ++ indexer; + } + return res; +} + + + +bool +containsTypes (const LogVarTypes& typesA, const LogVarTypes& typesB) +{ + for (size_t i = 0; i < typesA.size(); i++) { + if (typesA[i] == LogVarType::FULL_LV) { + + } else if (typesA[i] == LogVarType::POS_LV + && typesB[i] == LogVarType::POS_LV) { + + } else if (typesA[i] == LogVarType::NEG_LV + && typesB[i] == LogVarType::NEG_LV) { + + } else { + return false; + } + } + return true; +} + + + +LitLvTypesSet LiftedCircuit::smoothCircuit (CircuitNode* node) { assert (node != 0); - TinySet propagatingLids; + LitLvTypesSet propagLits; switch (getCircuitNodeType (node)) { case CircuitNodeType::OR_NODE: { OrNode* casted = dynamic_cast(node); - TinySet lids1 = smoothCircuit (*casted->leftBranch()); - TinySet lids2 = smoothCircuit (*casted->rightBranch()); - TinySet missingLeft = lids2 - lids1; - TinySet missingRight = lids1 - lids2; + LitLvTypesSet lids1 = smoothCircuit (*casted->leftBranch()); + LitLvTypesSet lids2 = smoothCircuit (*casted->rightBranch()); + LitLvTypesSet missingLeft = lids2 - lids1; + LitLvTypesSet missingRight = lids1 - lids2; createSmoothNode (missingLeft, casted->leftBranch()); createSmoothNode (missingRight, casted->rightBranch()); - propagatingLids |= lids1; - propagatingLids |= lids2; + propagLits |= lids1; + propagLits |= lids2; break; } case CircuitNodeType::AND_NODE: { AndNode* casted = dynamic_cast(node); - TinySet lids1 = smoothCircuit (*casted->leftBranch()); - TinySet lids2 = smoothCircuit (*casted->rightBranch()); - propagatingLids |= lids1; - propagatingLids |= lids2; + LitLvTypesSet lids1 = smoothCircuit (*casted->leftBranch()); + LitLvTypesSet lids2 = smoothCircuit (*casted->rightBranch()); + propagLits |= lids1; + propagLits |= lids2; break; } case CircuitNodeType::SET_OR_NODE: { - // TODO + SetOrNode* casted = dynamic_cast(node); + propagLits = smoothCircuit (*casted->follow()); + TinySet> litSet; + for (size_t i = 0; i < propagLits.size(); i++) { + litSet.insert (make_pair (propagLits[i].lid(), + propagLits[i].logVarTypes().size())); + } + LitLvTypesSet missingLids; + for (size_t i = 0; i < litSet.size(); i++) { + vector allTypes = getAllPossibleTypes (litSet[i].second); + for (size_t j = 0; j < allTypes.size(); j++) { + bool typeFound = false; + for (size_t k = 0; k < propagLits.size(); k++) { + if (litSet[i].first == propagLits[k].lid() + && containsTypes (allTypes[j], propagLits[k].logVarTypes())) { + typeFound = true; + break; + } + } + if (typeFound == false) { + missingLids.insert (LiteralLvTypes (litSet[i].first, allTypes[j])); + } + } + } + createSmoothNode (missingLids, casted->follow()); + // TODO change propagLits to full lvs break; } case CircuitNodeType::SET_AND_NODE: { SetAndNode* casted = dynamic_cast(node); - propagatingLids = smoothCircuit (*casted->follow()); + propagLits = smoothCircuit (*casted->follow()); break; } case CircuitNodeType::INC_EXC_NODE: { IncExcNode* casted = dynamic_cast(node); - TinySet lids1 = smoothCircuit (*casted->plus1Branch()); - TinySet lids2 = smoothCircuit (*casted->plus2Branch()); - TinySet missingPlus1 = lids2 - lids1; - TinySet missingPlus2 = lids1 - lids2; + LitLvTypesSet lids1 = smoothCircuit (*casted->plus1Branch()); + LitLvTypesSet lids2 = smoothCircuit (*casted->plus2Branch()); + LitLvTypesSet missingPlus1 = lids2 - lids1; + LitLvTypesSet missingPlus2 = lids1 - lids2; createSmoothNode (missingPlus1, casted->plus1Branch()); createSmoothNode (missingPlus2, casted->plus2Branch()); - propagatingLids |= lids1; - propagatingLids |= lids2; + propagLits |= lids1; + propagLits |= lids2; break; } case CircuitNodeType::LEAF_NODE: { - propagatingLids.insert (node->clauses()[0].literals()[0].lid()); + propagLits.insert (LiteralLvTypes ( + node->clauses()[0].literals()[0].lid(), + node->clauses()[0].logVarTypes(0))); } default: break; } - return propagatingLids; + return propagLits; } void LiftedCircuit::createSmoothNode ( - const TinySet& missingLids, + const LitLvTypesSet& missingLits, CircuitNode** prev) { - if (missingLids.empty() == false) { + if (missingLits.empty() == false) { Clauses clauses; - for (size_t i = 0; i < missingLids.size(); i++) { - Clause c = lwcnf_->createClauseForLiteral (missingLids[i]); + for (size_t i = 0; i < missingLits.size(); i++) { + LiteralId lid = missingLits[i].lid(); + const LogVarTypes& types = missingLits[i].logVarTypes(); + Clause c = lwcnf_->createClauseForLiteral (lid); + for (size_t j = 0; j < types.size(); j++) { + LogVar X = c.literals().front().logVars()[j]; + if (types[j] == LogVarType::POS_LV) { + c.addPositiveCountedLogVar (X); + } else if (types[j] == LogVarType::NEG_LV) { + c.addNegativeCountedLogVar (X); + } + } c.addAndNegateLiteral (c.literals()[0]); clauses.push_back (c); } diff --git a/packages/CLPBN/horus/LiftedCircuit.h b/packages/CLPBN/horus/LiftedCircuit.h index 3688229cc..accc52f72 100644 --- a/packages/CLPBN/horus/LiftedCircuit.h +++ b/packages/CLPBN/horus/LiftedCircuit.h @@ -234,9 +234,9 @@ class LiftedCircuit bool shatterCountedLogVarsAux (Clauses& clauses, size_t idx1, size_t idx2); - TinySet smoothCircuit (CircuitNode* node); + LitLvTypesSet smoothCircuit (CircuitNode* node); - void createSmoothNode (const TinySet& lids, + void createSmoothNode (const LitLvTypesSet& lids, CircuitNode** prev); CircuitNodeType getCircuitNodeType (const CircuitNode* node) const; diff --git a/packages/CLPBN/horus/LiftedWCNF.cpp b/packages/CLPBN/horus/LiftedWCNF.cpp index e9003405c..290bdf274 100644 --- a/packages/CLPBN/horus/LiftedWCNF.cpp +++ b/packages/CLPBN/horus/LiftedWCNF.cpp @@ -303,13 +303,13 @@ LiftedWCNF::LiftedWCNF (const ParfactorList& pfList) vector> names = {{"p1","p1"},{"p2","p2"}}; Clause c1 (names); - c1.addLiteral (Literal (0, LogVars()={0})); - c1.addAndNegateLiteral (Literal (1, {0,1})); + c1.addLiteral (Literal (0, LogVars()={0}, 3.0)); + c1.addAndNegateLiteral (Literal (1, {0,1}, 1.0)); clauses_.push_back(c1); Clause c2 (names); - c2.addLiteral (Literal (0, LogVars()={0})); - c2.addAndNegateLiteral (Literal (1, {1,0})); + c2.addLiteral (Literal (0, LogVars()={0}, 2.0)); + c2.addAndNegateLiteral (Literal (1, {1,0}, 5.0)); clauses_.push_back(c2); cout << "FORMULA INDICATORS:" << endl; diff --git a/packages/CLPBN/horus/LiftedWCNF.h b/packages/CLPBN/horus/LiftedWCNF.h index ac8a8758f..43a3afa30 100644 --- a/packages/CLPBN/horus/LiftedWCNF.h +++ b/packages/CLPBN/horus/LiftedWCNF.h @@ -146,6 +146,38 @@ typedef vector Clauses; +class LiteralLvTypes +{ + public: + struct CompareLiteralLvTypes + { + bool operator() ( + const LiteralLvTypes& types1, + const LiteralLvTypes& types2) const + { + if (types1.lid_ < types2.lid_) { + return true; + } + return types1.lvTypes_ < types2.lvTypes_; + } + }; + + LiteralLvTypes (LiteralId lid, const LogVarTypes& lvTypes) : + lid_(lid), lvTypes_(lvTypes) { } + + LiteralId lid (void) const { return lid_; } + + const LogVarTypes& logVarTypes (void) const { return lvTypes_; } + + private: + LiteralId lid_; + LogVarTypes lvTypes_; +}; + +typedef TinySet LitLvTypesSet; + + + class LiftedWCNF { public: