diff --git a/packages/CLPBN/horus/LiftedCircuit.cpp b/packages/CLPBN/horus/LiftedCircuit.cpp index cb7ba75a4..ab6ffac91 100644 --- a/packages/CLPBN/horus/LiftedCircuit.cpp +++ b/packages/CLPBN/horus/LiftedCircuit.cpp @@ -190,7 +190,7 @@ LiftedCircuit::LiftedCircuit (const LiftedWCNF* lwcnf) Clauses clauses = lwcnf->clauses(); compile (&root_, clauses); exportToGraphViz("circuit.dot"); - smoothCircuit(); + smoothCircuit (root_); exportToGraphViz("circuit.smooth.dot"); cout << "--------------------------------------------------" << endl; cout << "--------------------------------------------------" << endl; @@ -199,14 +199,6 @@ LiftedCircuit::LiftedCircuit (const LiftedWCNF* lwcnf) -void -LiftedCircuit::smoothCircuit (void) -{ - smoothCircuit (root_); -} - - - double LiftedCircuit::getWeightedModelCount (void) const { diff --git a/packages/CLPBN/horus/LiftedCircuit.h b/packages/CLPBN/horus/LiftedCircuit.h index 073de64fe..b040af3c5 100644 --- a/packages/CLPBN/horus/LiftedCircuit.h +++ b/packages/CLPBN/horus/LiftedCircuit.h @@ -25,15 +25,15 @@ class CircuitNode public: CircuitNode (const Clauses& clauses, string explanation = "") : clauses_(clauses), explanation_(explanation) { } - + const Clauses& clauses (void) const { return clauses_; } - + Clauses clauses (void) { return clauses_; } - + virtual double weight (void) const = 0; - + string explanation (void) const { return explanation_; } - + private: Clauses clauses_; string explanation_; @@ -50,7 +50,7 @@ class OrNode : public CircuitNode CircuitNode** leftBranch (void) { return &leftBranch_; } CircuitNode** rightBranch (void) { return &rightBranch_; } - + double weight (void) const; private: @@ -66,7 +66,7 @@ class AndNode : public CircuitNode AndNode (const Clauses& clauses, string explanation = "") : CircuitNode (clauses, explanation), leftBranch_(0), rightBranch_(0) { } - + AndNode ( const Clauses& clauses, CircuitNode* leftBranch, @@ -74,7 +74,7 @@ class AndNode : public CircuitNode string explanation = "") : CircuitNode (clauses, explanation), leftBranch_(leftBranch), rightBranch_(rightBranch) { } - + AndNode ( CircuitNode* leftBranch, CircuitNode* rightBranch, @@ -84,9 +84,9 @@ class AndNode : public CircuitNode CircuitNode** leftBranch (void) { return &leftBranch_; } CircuitNode** rightBranch (void) { return &rightBranch_; } - + double weight (void) const; - + private: CircuitNode* leftBranch_; CircuitNode* rightBranch_; @@ -106,7 +106,7 @@ class SetOrNode : public CircuitNode static unsigned nrPositives (void) { return nrGrsStack.top().first; } static unsigned nrNegatives (void) { return nrGrsStack.top().second; } - + double weight (void) const; private: @@ -126,7 +126,7 @@ class SetAndNode : public CircuitNode nrGroundings_(nrGroundings) { } CircuitNode** follow (void) { return &follow_; } - + double weight (void) const; private: @@ -146,7 +146,7 @@ class IncExcNode : public CircuitNode CircuitNode** plus1Branch (void) { return &plus1Branch_; } CircuitNode** plus2Branch (void) { return &plus2Branch_; } CircuitNode** minusBranch (void) { return &minusBranch_; } - + double weight (void) const; private: @@ -164,7 +164,7 @@ class LeafNode : public CircuitNode : CircuitNode (Clauses() = {clause}), lwcnf_(lwcnf) { } double weight (void) const; - + private: const LiftedWCNF& lwcnf_; }; @@ -189,7 +189,7 @@ class TrueNode : public CircuitNode { public: TrueNode (void) : CircuitNode ({}) { } - + double weight (void) const; }; @@ -200,7 +200,7 @@ class CompilationFailedNode : public CircuitNode public: CompilationFailedNode (const Clauses& clauses) : CircuitNode (clauses) { } - + double weight (void) const; }; @@ -210,62 +210,60 @@ class LiftedCircuit { public: LiftedCircuit (const LiftedWCNF* lwcnf); - - void smoothCircuit (void); - + double getWeightedModelCount (void) const; - + void exportToGraphViz (const char*); - + private: void compile (CircuitNode** follow, Clauses& clauses); bool tryUnitPropagation (CircuitNode** follow, Clauses& clauses); - + bool tryIndependence (CircuitNode** follow, Clauses& clauses); - + bool tryShannonDecomp (CircuitNode** follow, Clauses& clauses); - + bool tryInclusionExclusion (CircuitNode** follow, Clauses& clauses); - + bool tryIndepPartialGrounding (CircuitNode** follow, Clauses& clauses); - + bool tryIndepPartialGroundingAux (Clauses& clauses, ConstraintTree& ct, LogVars& rootLogVars); - + bool tryAtomCounting (CircuitNode** follow, Clauses& clauses); - + bool tryGrounding (CircuitNode** follow, Clauses& clauses); - + void shatterCountedLogVars (Clauses& clauses); - + bool shatterCountedLogVarsAux (Clauses& clauses); bool shatterCountedLogVarsAux (Clauses& clauses, size_t idx1, size_t idx2); - + bool independentClause (Clause& clause, Clauses& otherClauses) const; - + bool independentLiteral (const Literal& lit, const Literals& otherLits) const; LitLvTypesSet smoothCircuit (CircuitNode* node); - + void createSmoothNode (const LitLvTypesSet& lids, CircuitNode** prev); - + vector getAllPossibleTypes (unsigned nrLogVars) const; - + bool containsTypes (const LogVarTypes& typesA, const LogVarTypes& typesB) const; - + CircuitNodeType getCircuitNodeType (const CircuitNode* node) const; - + void exportToGraphViz (CircuitNode* node, ofstream&); - + void printClauses (const CircuitNode* node, ofstream&, string extraOptions = ""); - + string escapeNode (const CircuitNode* node) const; CircuitNode* root_; diff --git a/packages/CLPBN/horus/LiftedKc.cpp b/packages/CLPBN/horus/LiftedKc.cpp index 7566a6d90..fb036df33 100644 --- a/packages/CLPBN/horus/LiftedKc.cpp +++ b/packages/CLPBN/horus/LiftedKc.cpp @@ -1,6 +1,7 @@ #include "LiftedKc.h" #include "LiftedWCNF.h" #include "LiftedCircuit.h" +#include "Indexer.h" LiftedKc::LiftedKc (const ParfactorList& pfList) @@ -19,9 +20,46 @@ LiftedKc::~LiftedKc (void) Params -LiftedKc::solveQuery (const Grounds&) +LiftedKc::solveQuery (const Grounds& query) { - return Params(); + vector groups; + Ranges ranges; + for (size_t i = 0; i < query.size(); i++) { + ParfactorList::const_iterator it = pfList_.begin(); + while (it != pfList_.end()) { + size_t idx = (*it)->indexOfGround (query[i]); + if (idx != (*it)->nrArguments()) { + groups.push_back ((*it)->argument (idx).group()); + ranges.push_back ((*it)->range (idx)); + break; + } + ++ it; + } + } + cout << "groups: " << groups << endl; + cout << "ranges: " << ranges << endl; + Params params; + Indexer indexer (ranges); + while (indexer.valid()) { + for (size_t i = 0; i < groups.size(); i++) { + vector litIds = lwcnf_->prvGroupLiterals (groups[i]); + for (size_t j = 0; j < litIds.size(); j++) { + if (indexer[i] == j) { + lwcnf_->addWeight (litIds[j], 1.0, 1.0); // TODO not log aware + } else { + lwcnf_->addWeight (litIds[j], 0.0, 1.0); // TODO not log aware + } + } + } + // cout << "new weights ----- ----- -----" << endl; + // lwcnf_->printWeights(); + // circuit_->exportToGraphViz ("ccircuit.dot"); + params.push_back (circuit_->getWeightedModelCount()); + ++ indexer; + } + cout << "params: " << params << endl; + LogAware::normalize (params); + return params; } diff --git a/packages/CLPBN/horus/LiftedKc.h b/packages/CLPBN/horus/LiftedKc.h index fbef33e0a..eb0074213 100644 --- a/packages/CLPBN/horus/LiftedKc.h +++ b/packages/CLPBN/horus/LiftedKc.h @@ -20,7 +20,7 @@ class LiftedKc private: LiftedWCNF* lwcnf_; LiftedCircuit* circuit_; - + const ParfactorList& pfList_; }; diff --git a/packages/CLPBN/horus/LiftedWCNF.cpp b/packages/CLPBN/horus/LiftedWCNF.cpp index 31641187e..7bf249923 100644 --- a/packages/CLPBN/horus/LiftedWCNF.cpp +++ b/packages/CLPBN/horus/LiftedWCNF.cpp @@ -352,8 +352,8 @@ Clause::getLogVarSetExcluding (size_t idx) const LiftedWCNF::LiftedWCNF (const ParfactorList& pfList) : freeLiteralId_(0), pfList_(pfList) { - //addIndicatorClauses (pfList); - //addParameterClauses (pfList); + addIndicatorClauses (pfList); + addParameterClauses (pfList); /* vector> names = { @@ -377,25 +377,25 @@ LiftedWCNF::LiftedWCNF (const ParfactorList& pfList) freeLiteralId_ = 2; */ - - Literal lit1 (0, {0}); - Literal lit2 (1, {0}); - Literal lit3 (2, {1}); - Literal lit4 (3, {1}); - - vector> names = {{"p1","p2"},{"p3","p4"}}; - Clause c1 (names); - c1.addLiteral (lit1); - c1.addLiteral (lit2); - c1.addLiteral (lit3); - c1.addLiteral (lit4); - //c1.addPosCountedLogVar (0); - clauses_.push_back (c1); - Clause c2 (names); - c2.addLiteral (lit1); - c2.addLiteral (lit3); - c2.addNegCountedLogVar (0); + //Literal lit1 (0, {0}); + //Literal lit2 (1, {0}); + //Literal lit3 (2, {1}); + //Literal lit4 (3, {1}); + + //vector> names = {{"p1","p2"},{"p3","p4"}}; + //Clause c1 (names); + //c1.addLiteral (lit1); + //c1.addLiteral (lit2); + //c1.addLiteral (lit3); + //c1.addLiteral (lit4); + //c1.addPosCountedLogVar (0); + //clauses_.push_back (c1); + + //Clause c2 (names); + //c2.addLiteral (lit1); + //c2.addLiteral (lit3); + //c2.addNegCountedLogVar (0); //clauses_.push_back (c2); /* Clause c3; @@ -408,17 +408,18 @@ LiftedWCNF::LiftedWCNF (const ParfactorList& pfList) c4.addLiteral (lit3); clauses_.push_back (c4); */ - freeLiteralId_ = 4; + //freeLiteralId_ = 4; cout << "FORMULA INDICATORS:" << endl; - // printFormulaIndicators(); + printFormulaIndicators(); cout << endl; + cout << "WEIGHTS:" << endl; printWeights(); cout << endl; + cout << "CLAUSES:" << endl; printClauses(); - // abort(); cout << endl; } @@ -431,6 +432,14 @@ LiftedWCNF::~LiftedWCNF (void) +void +LiftedWCNF::addWeight (LiteralId lid, double posW, double negW) +{ + weights_[lid] = make_pair (posW, negW); +} + + + double LiftedWCNF::posWeight (LiteralId lid) const { @@ -451,6 +460,15 @@ LiftedWCNF::negWeight (LiteralId lid) const +vector +LiftedWCNF::prvGroupLiterals (PrvGroup prvGroup) +{ + assert (Util::contains (map_, prvGroup)); + return map_[prvGroup]; +} + + + Clause LiftedWCNF::createClause (LiteralId lid) const { @@ -481,14 +499,6 @@ LiftedWCNF::getLiteralId (PrvGroup prvGroup, unsigned range) -void -LiftedWCNF::addWeight (LiteralId lid, double posW, double negW) -{ - weights_[lid] = make_pair (posW, negW); -} - - - void LiftedWCNF::addIndicatorClauses (const ParfactorList& pfList) { diff --git a/packages/CLPBN/horus/LiftedWCNF.h b/packages/CLPBN/horus/LiftedWCNF.h index 263593cbf..4163b60d8 100644 --- a/packages/CLPBN/horus/LiftedWCNF.h +++ b/packages/CLPBN/horus/LiftedWCNF.h @@ -188,13 +188,17 @@ class LiftedWCNF LiftedWCNF (const ParfactorList& pfList); ~LiftedWCNF (void); - + const Clauses& clauses (void) const { return clauses_; } - + + void addWeight (LiteralId lid, double posW, double negW); + double posWeight (LiteralId lid) const; double negWeight (LiteralId lid) const; + vector prvGroupLiterals (PrvGroup prvGroup); + Clause createClause (LiteralId lid) const; void printFormulaIndicators (void) const; @@ -204,10 +208,8 @@ class LiftedWCNF void printClauses (void) const; private: - + LiteralId getLiteralId (PrvGroup prvGroup, unsigned range); - - void addWeight (LiteralId lid, double posW, double negW); void addIndicatorClauses (const ParfactorList& pfList);