more work to support inference with lifted knowledge compilation

This commit is contained in:
Tiago Gomes 2012-11-09 18:42:21 +00:00
parent c53220aa61
commit 8ab622e0aa
6 changed files with 127 additions and 87 deletions

View File

@ -190,7 +190,7 @@ LiftedCircuit::LiftedCircuit (const LiftedWCNF* lwcnf)
Clauses clauses = lwcnf->clauses(); Clauses clauses = lwcnf->clauses();
compile (&root_, clauses); compile (&root_, clauses);
exportToGraphViz("circuit.dot"); exportToGraphViz("circuit.dot");
smoothCircuit(); smoothCircuit (root_);
exportToGraphViz("circuit.smooth.dot"); exportToGraphViz("circuit.smooth.dot");
cout << "--------------------------------------------------" << endl; cout << "--------------------------------------------------" << endl;
cout << "--------------------------------------------------" << endl; cout << "--------------------------------------------------" << endl;
@ -199,14 +199,6 @@ LiftedCircuit::LiftedCircuit (const LiftedWCNF* lwcnf)
void
LiftedCircuit::smoothCircuit (void)
{
smoothCircuit (root_);
}
double double
LiftedCircuit::getWeightedModelCount (void) const LiftedCircuit::getWeightedModelCount (void) const
{ {

View File

@ -25,15 +25,15 @@ class CircuitNode
public: public:
CircuitNode (const Clauses& clauses, string explanation = "") CircuitNode (const Clauses& clauses, string explanation = "")
: clauses_(clauses), explanation_(explanation) { } : clauses_(clauses), explanation_(explanation) { }
const Clauses& clauses (void) const { return clauses_; } const Clauses& clauses (void) const { return clauses_; }
Clauses clauses (void) { return clauses_; } Clauses clauses (void) { return clauses_; }
virtual double weight (void) const = 0; virtual double weight (void) const = 0;
string explanation (void) const { return explanation_; } string explanation (void) const { return explanation_; }
private: private:
Clauses clauses_; Clauses clauses_;
string explanation_; string explanation_;
@ -50,7 +50,7 @@ class OrNode : public CircuitNode
CircuitNode** leftBranch (void) { return &leftBranch_; } CircuitNode** leftBranch (void) { return &leftBranch_; }
CircuitNode** rightBranch (void) { return &rightBranch_; } CircuitNode** rightBranch (void) { return &rightBranch_; }
double weight (void) const; double weight (void) const;
private: private:
@ -66,7 +66,7 @@ class AndNode : public CircuitNode
AndNode (const Clauses& clauses, string explanation = "") AndNode (const Clauses& clauses, string explanation = "")
: CircuitNode (clauses, explanation), : CircuitNode (clauses, explanation),
leftBranch_(0), rightBranch_(0) { } leftBranch_(0), rightBranch_(0) { }
AndNode ( AndNode (
const Clauses& clauses, const Clauses& clauses,
CircuitNode* leftBranch, CircuitNode* leftBranch,
@ -74,7 +74,7 @@ class AndNode : public CircuitNode
string explanation = "") string explanation = "")
: CircuitNode (clauses, explanation), : CircuitNode (clauses, explanation),
leftBranch_(leftBranch), rightBranch_(rightBranch) { } leftBranch_(leftBranch), rightBranch_(rightBranch) { }
AndNode ( AndNode (
CircuitNode* leftBranch, CircuitNode* leftBranch,
CircuitNode* rightBranch, CircuitNode* rightBranch,
@ -84,9 +84,9 @@ class AndNode : public CircuitNode
CircuitNode** leftBranch (void) { return &leftBranch_; } CircuitNode** leftBranch (void) { return &leftBranch_; }
CircuitNode** rightBranch (void) { return &rightBranch_; } CircuitNode** rightBranch (void) { return &rightBranch_; }
double weight (void) const; double weight (void) const;
private: private:
CircuitNode* leftBranch_; CircuitNode* leftBranch_;
CircuitNode* rightBranch_; CircuitNode* rightBranch_;
@ -106,7 +106,7 @@ class SetOrNode : public CircuitNode
static unsigned nrPositives (void) { return nrGrsStack.top().first; } static unsigned nrPositives (void) { return nrGrsStack.top().first; }
static unsigned nrNegatives (void) { return nrGrsStack.top().second; } static unsigned nrNegatives (void) { return nrGrsStack.top().second; }
double weight (void) const; double weight (void) const;
private: private:
@ -126,7 +126,7 @@ class SetAndNode : public CircuitNode
nrGroundings_(nrGroundings) { } nrGroundings_(nrGroundings) { }
CircuitNode** follow (void) { return &follow_; } CircuitNode** follow (void) { return &follow_; }
double weight (void) const; double weight (void) const;
private: private:
@ -146,7 +146,7 @@ class IncExcNode : public CircuitNode
CircuitNode** plus1Branch (void) { return &plus1Branch_; } CircuitNode** plus1Branch (void) { return &plus1Branch_; }
CircuitNode** plus2Branch (void) { return &plus2Branch_; } CircuitNode** plus2Branch (void) { return &plus2Branch_; }
CircuitNode** minusBranch (void) { return &minusBranch_; } CircuitNode** minusBranch (void) { return &minusBranch_; }
double weight (void) const; double weight (void) const;
private: private:
@ -164,7 +164,7 @@ class LeafNode : public CircuitNode
: CircuitNode (Clauses() = {clause}), lwcnf_(lwcnf) { } : CircuitNode (Clauses() = {clause}), lwcnf_(lwcnf) { }
double weight (void) const; double weight (void) const;
private: private:
const LiftedWCNF& lwcnf_; const LiftedWCNF& lwcnf_;
}; };
@ -189,7 +189,7 @@ class TrueNode : public CircuitNode
{ {
public: public:
TrueNode (void) : CircuitNode ({}) { } TrueNode (void) : CircuitNode ({}) { }
double weight (void) const; double weight (void) const;
}; };
@ -200,7 +200,7 @@ class CompilationFailedNode : public CircuitNode
public: public:
CompilationFailedNode (const Clauses& clauses) CompilationFailedNode (const Clauses& clauses)
: CircuitNode (clauses) { } : CircuitNode (clauses) { }
double weight (void) const; double weight (void) const;
}; };
@ -210,62 +210,60 @@ class LiftedCircuit
{ {
public: public:
LiftedCircuit (const LiftedWCNF* lwcnf); LiftedCircuit (const LiftedWCNF* lwcnf);
void smoothCircuit (void);
double getWeightedModelCount (void) const; double getWeightedModelCount (void) const;
void exportToGraphViz (const char*); void exportToGraphViz (const char*);
private: private:
void compile (CircuitNode** follow, Clauses& clauses); void compile (CircuitNode** follow, Clauses& clauses);
bool tryUnitPropagation (CircuitNode** follow, Clauses& clauses); bool tryUnitPropagation (CircuitNode** follow, Clauses& clauses);
bool tryIndependence (CircuitNode** follow, Clauses& clauses); bool tryIndependence (CircuitNode** follow, Clauses& clauses);
bool tryShannonDecomp (CircuitNode** follow, Clauses& clauses); bool tryShannonDecomp (CircuitNode** follow, Clauses& clauses);
bool tryInclusionExclusion (CircuitNode** follow, Clauses& clauses); bool tryInclusionExclusion (CircuitNode** follow, Clauses& clauses);
bool tryIndepPartialGrounding (CircuitNode** follow, Clauses& clauses); bool tryIndepPartialGrounding (CircuitNode** follow, Clauses& clauses);
bool tryIndepPartialGroundingAux (Clauses& clauses, ConstraintTree& ct, bool tryIndepPartialGroundingAux (Clauses& clauses, ConstraintTree& ct,
LogVars& rootLogVars); LogVars& rootLogVars);
bool tryAtomCounting (CircuitNode** follow, Clauses& clauses); bool tryAtomCounting (CircuitNode** follow, Clauses& clauses);
bool tryGrounding (CircuitNode** follow, Clauses& clauses); bool tryGrounding (CircuitNode** follow, Clauses& clauses);
void shatterCountedLogVars (Clauses& clauses); void shatterCountedLogVars (Clauses& clauses);
bool shatterCountedLogVarsAux (Clauses& clauses); bool shatterCountedLogVarsAux (Clauses& clauses);
bool shatterCountedLogVarsAux (Clauses& clauses, size_t idx1, size_t idx2); bool shatterCountedLogVarsAux (Clauses& clauses, size_t idx1, size_t idx2);
bool independentClause (Clause& clause, Clauses& otherClauses) const; bool independentClause (Clause& clause, Clauses& otherClauses) const;
bool independentLiteral (const Literal& lit, bool independentLiteral (const Literal& lit,
const Literals& otherLits) const; const Literals& otherLits) const;
LitLvTypesSet smoothCircuit (CircuitNode* node); LitLvTypesSet smoothCircuit (CircuitNode* node);
void createSmoothNode (const LitLvTypesSet& lids, void createSmoothNode (const LitLvTypesSet& lids,
CircuitNode** prev); CircuitNode** prev);
vector<LogVarTypes> getAllPossibleTypes (unsigned nrLogVars) const; vector<LogVarTypes> getAllPossibleTypes (unsigned nrLogVars) const;
bool containsTypes (const LogVarTypes& typesA, bool containsTypes (const LogVarTypes& typesA,
const LogVarTypes& typesB) const; const LogVarTypes& typesB) const;
CircuitNodeType getCircuitNodeType (const CircuitNode* node) const; CircuitNodeType getCircuitNodeType (const CircuitNode* node) const;
void exportToGraphViz (CircuitNode* node, ofstream&); void exportToGraphViz (CircuitNode* node, ofstream&);
void printClauses (const CircuitNode* node, ofstream&, void printClauses (const CircuitNode* node, ofstream&,
string extraOptions = ""); string extraOptions = "");
string escapeNode (const CircuitNode* node) const; string escapeNode (const CircuitNode* node) const;
CircuitNode* root_; CircuitNode* root_;

View File

@ -1,6 +1,7 @@
#include "LiftedKc.h" #include "LiftedKc.h"
#include "LiftedWCNF.h" #include "LiftedWCNF.h"
#include "LiftedCircuit.h" #include "LiftedCircuit.h"
#include "Indexer.h"
LiftedKc::LiftedKc (const ParfactorList& pfList) LiftedKc::LiftedKc (const ParfactorList& pfList)
@ -19,9 +20,46 @@ LiftedKc::~LiftedKc (void)
Params Params
LiftedKc::solveQuery (const Grounds&) LiftedKc::solveQuery (const Grounds& query)
{ {
return Params(); vector<PrvGroup> groups;
Ranges ranges;
for (size_t i = 0; i < query.size(); i++) {
ParfactorList::const_iterator it = pfList_.begin();
while (it != pfList_.end()) {
size_t idx = (*it)->indexOfGround (query[i]);
if (idx != (*it)->nrArguments()) {
groups.push_back ((*it)->argument (idx).group());
ranges.push_back ((*it)->range (idx));
break;
}
++ it;
}
}
cout << "groups: " << groups << endl;
cout << "ranges: " << ranges << endl;
Params params;
Indexer indexer (ranges);
while (indexer.valid()) {
for (size_t i = 0; i < groups.size(); i++) {
vector<LiteralId> litIds = lwcnf_->prvGroupLiterals (groups[i]);
for (size_t j = 0; j < litIds.size(); j++) {
if (indexer[i] == j) {
lwcnf_->addWeight (litIds[j], 1.0, 1.0); // TODO not log aware
} else {
lwcnf_->addWeight (litIds[j], 0.0, 1.0); // TODO not log aware
}
}
}
// cout << "new weights ----- ----- -----" << endl;
// lwcnf_->printWeights();
// circuit_->exportToGraphViz ("ccircuit.dot");
params.push_back (circuit_->getWeightedModelCount());
++ indexer;
}
cout << "params: " << params << endl;
LogAware::normalize (params);
return params;
} }

View File

@ -20,7 +20,7 @@ class LiftedKc
private: private:
LiftedWCNF* lwcnf_; LiftedWCNF* lwcnf_;
LiftedCircuit* circuit_; LiftedCircuit* circuit_;
const ParfactorList& pfList_; const ParfactorList& pfList_;
}; };

View File

@ -352,8 +352,8 @@ Clause::getLogVarSetExcluding (size_t idx) const
LiftedWCNF::LiftedWCNF (const ParfactorList& pfList) LiftedWCNF::LiftedWCNF (const ParfactorList& pfList)
: freeLiteralId_(0), pfList_(pfList) : freeLiteralId_(0), pfList_(pfList)
{ {
//addIndicatorClauses (pfList); addIndicatorClauses (pfList);
//addParameterClauses (pfList); addParameterClauses (pfList);
/* /*
vector<vector<string>> names = { vector<vector<string>> names = {
@ -377,25 +377,25 @@ LiftedWCNF::LiftedWCNF (const ParfactorList& pfList)
freeLiteralId_ = 2; freeLiteralId_ = 2;
*/ */
Literal lit1 (0, {0});
Literal lit2 (1, {0});
Literal lit3 (2, {1});
Literal lit4 (3, {1});
vector<vector<string>> names = {{"p1","p2"},{"p3","p4"}};
Clause c1 (names);
c1.addLiteral (lit1);
c1.addLiteral (lit2);
c1.addLiteral (lit3);
c1.addLiteral (lit4);
//c1.addPosCountedLogVar (0);
clauses_.push_back (c1);
Clause c2 (names); //Literal lit1 (0, {0});
c2.addLiteral (lit1); //Literal lit2 (1, {0});
c2.addLiteral (lit3); //Literal lit3 (2, {1});
c2.addNegCountedLogVar (0); //Literal lit4 (3, {1});
//vector<vector<string>> names = {{"p1","p2"},{"p3","p4"}};
//Clause c1 (names);
//c1.addLiteral (lit1);
//c1.addLiteral (lit2);
//c1.addLiteral (lit3);
//c1.addLiteral (lit4);
//c1.addPosCountedLogVar (0);
//clauses_.push_back (c1);
//Clause c2 (names);
//c2.addLiteral (lit1);
//c2.addLiteral (lit3);
//c2.addNegCountedLogVar (0);
//clauses_.push_back (c2); //clauses_.push_back (c2);
/* /*
Clause c3; Clause c3;
@ -408,17 +408,18 @@ LiftedWCNF::LiftedWCNF (const ParfactorList& pfList)
c4.addLiteral (lit3); c4.addLiteral (lit3);
clauses_.push_back (c4); clauses_.push_back (c4);
*/ */
freeLiteralId_ = 4; //freeLiteralId_ = 4;
cout << "FORMULA INDICATORS:" << endl; cout << "FORMULA INDICATORS:" << endl;
// printFormulaIndicators(); printFormulaIndicators();
cout << endl; cout << endl;
cout << "WEIGHTS:" << endl; cout << "WEIGHTS:" << endl;
printWeights(); printWeights();
cout << endl; cout << endl;
cout << "CLAUSES:" << endl; cout << "CLAUSES:" << endl;
printClauses(); printClauses();
// abort();
cout << endl; cout << endl;
} }
@ -431,6 +432,14 @@ LiftedWCNF::~LiftedWCNF (void)
void
LiftedWCNF::addWeight (LiteralId lid, double posW, double negW)
{
weights_[lid] = make_pair (posW, negW);
}
double double
LiftedWCNF::posWeight (LiteralId lid) const LiftedWCNF::posWeight (LiteralId lid) const
{ {
@ -451,6 +460,15 @@ LiftedWCNF::negWeight (LiteralId lid) const
vector<LiteralId>
LiftedWCNF::prvGroupLiterals (PrvGroup prvGroup)
{
assert (Util::contains (map_, prvGroup));
return map_[prvGroup];
}
Clause Clause
LiftedWCNF::createClause (LiteralId lid) const LiftedWCNF::createClause (LiteralId lid) const
{ {
@ -481,14 +499,6 @@ LiftedWCNF::getLiteralId (PrvGroup prvGroup, unsigned range)
void
LiftedWCNF::addWeight (LiteralId lid, double posW, double negW)
{
weights_[lid] = make_pair (posW, negW);
}
void void
LiftedWCNF::addIndicatorClauses (const ParfactorList& pfList) LiftedWCNF::addIndicatorClauses (const ParfactorList& pfList)
{ {

View File

@ -188,13 +188,17 @@ class LiftedWCNF
LiftedWCNF (const ParfactorList& pfList); LiftedWCNF (const ParfactorList& pfList);
~LiftedWCNF (void); ~LiftedWCNF (void);
const Clauses& clauses (void) const { return clauses_; } const Clauses& clauses (void) const { return clauses_; }
void addWeight (LiteralId lid, double posW, double negW);
double posWeight (LiteralId lid) const; double posWeight (LiteralId lid) const;
double negWeight (LiteralId lid) const; double negWeight (LiteralId lid) const;
vector<LiteralId> prvGroupLiterals (PrvGroup prvGroup);
Clause createClause (LiteralId lid) const; Clause createClause (LiteralId lid) const;
void printFormulaIndicators (void) const; void printFormulaIndicators (void) const;
@ -204,10 +208,8 @@ class LiftedWCNF
void printClauses (void) const; void printClauses (void) const;
private: private:
LiteralId getLiteralId (PrvGroup prvGroup, unsigned range); LiteralId getLiteralId (PrvGroup prvGroup, unsigned range);
void addWeight (LiteralId lid, double posW, double negW);
void addIndicatorClauses (const ParfactorList& pfList); void addIndicatorClauses (const ParfactorList& pfList);