more work to support inference with lifted knowledge compilation
This commit is contained in:
parent
c53220aa61
commit
8ab622e0aa
@ -190,7 +190,7 @@ LiftedCircuit::LiftedCircuit (const LiftedWCNF* lwcnf)
|
||||
Clauses clauses = lwcnf->clauses();
|
||||
compile (&root_, clauses);
|
||||
exportToGraphViz("circuit.dot");
|
||||
smoothCircuit();
|
||||
smoothCircuit (root_);
|
||||
exportToGraphViz("circuit.smooth.dot");
|
||||
cout << "--------------------------------------------------" << endl;
|
||||
cout << "--------------------------------------------------" << endl;
|
||||
@ -199,14 +199,6 @@ LiftedCircuit::LiftedCircuit (const LiftedWCNF* lwcnf)
|
||||
|
||||
|
||||
|
||||
void
|
||||
LiftedCircuit::smoothCircuit (void)
|
||||
{
|
||||
smoothCircuit (root_);
|
||||
}
|
||||
|
||||
|
||||
|
||||
double
|
||||
LiftedCircuit::getWeightedModelCount (void) const
|
||||
{
|
||||
|
@ -25,15 +25,15 @@ class CircuitNode
|
||||
public:
|
||||
CircuitNode (const Clauses& clauses, string explanation = "")
|
||||
: clauses_(clauses), explanation_(explanation) { }
|
||||
|
||||
|
||||
const Clauses& clauses (void) const { return clauses_; }
|
||||
|
||||
|
||||
Clauses clauses (void) { return clauses_; }
|
||||
|
||||
|
||||
virtual double weight (void) const = 0;
|
||||
|
||||
|
||||
string explanation (void) const { return explanation_; }
|
||||
|
||||
|
||||
private:
|
||||
Clauses clauses_;
|
||||
string explanation_;
|
||||
@ -50,7 +50,7 @@ class OrNode : public CircuitNode
|
||||
|
||||
CircuitNode** leftBranch (void) { return &leftBranch_; }
|
||||
CircuitNode** rightBranch (void) { return &rightBranch_; }
|
||||
|
||||
|
||||
double weight (void) const;
|
||||
|
||||
private:
|
||||
@ -66,7 +66,7 @@ class AndNode : public CircuitNode
|
||||
AndNode (const Clauses& clauses, string explanation = "")
|
||||
: CircuitNode (clauses, explanation),
|
||||
leftBranch_(0), rightBranch_(0) { }
|
||||
|
||||
|
||||
AndNode (
|
||||
const Clauses& clauses,
|
||||
CircuitNode* leftBranch,
|
||||
@ -74,7 +74,7 @@ class AndNode : public CircuitNode
|
||||
string explanation = "")
|
||||
: CircuitNode (clauses, explanation),
|
||||
leftBranch_(leftBranch), rightBranch_(rightBranch) { }
|
||||
|
||||
|
||||
AndNode (
|
||||
CircuitNode* leftBranch,
|
||||
CircuitNode* rightBranch,
|
||||
@ -84,9 +84,9 @@ class AndNode : public CircuitNode
|
||||
|
||||
CircuitNode** leftBranch (void) { return &leftBranch_; }
|
||||
CircuitNode** rightBranch (void) { return &rightBranch_; }
|
||||
|
||||
|
||||
double weight (void) const;
|
||||
|
||||
|
||||
private:
|
||||
CircuitNode* leftBranch_;
|
||||
CircuitNode* rightBranch_;
|
||||
@ -106,7 +106,7 @@ class SetOrNode : public CircuitNode
|
||||
static unsigned nrPositives (void) { return nrGrsStack.top().first; }
|
||||
|
||||
static unsigned nrNegatives (void) { return nrGrsStack.top().second; }
|
||||
|
||||
|
||||
double weight (void) const;
|
||||
|
||||
private:
|
||||
@ -126,7 +126,7 @@ class SetAndNode : public CircuitNode
|
||||
nrGroundings_(nrGroundings) { }
|
||||
|
||||
CircuitNode** follow (void) { return &follow_; }
|
||||
|
||||
|
||||
double weight (void) const;
|
||||
|
||||
private:
|
||||
@ -146,7 +146,7 @@ class IncExcNode : public CircuitNode
|
||||
CircuitNode** plus1Branch (void) { return &plus1Branch_; }
|
||||
CircuitNode** plus2Branch (void) { return &plus2Branch_; }
|
||||
CircuitNode** minusBranch (void) { return &minusBranch_; }
|
||||
|
||||
|
||||
double weight (void) const;
|
||||
|
||||
private:
|
||||
@ -164,7 +164,7 @@ class LeafNode : public CircuitNode
|
||||
: CircuitNode (Clauses() = {clause}), lwcnf_(lwcnf) { }
|
||||
|
||||
double weight (void) const;
|
||||
|
||||
|
||||
private:
|
||||
const LiftedWCNF& lwcnf_;
|
||||
};
|
||||
@ -189,7 +189,7 @@ class TrueNode : public CircuitNode
|
||||
{
|
||||
public:
|
||||
TrueNode (void) : CircuitNode ({}) { }
|
||||
|
||||
|
||||
double weight (void) const;
|
||||
};
|
||||
|
||||
@ -200,7 +200,7 @@ class CompilationFailedNode : public CircuitNode
|
||||
public:
|
||||
CompilationFailedNode (const Clauses& clauses)
|
||||
: CircuitNode (clauses) { }
|
||||
|
||||
|
||||
double weight (void) const;
|
||||
};
|
||||
|
||||
@ -210,62 +210,60 @@ class LiftedCircuit
|
||||
{
|
||||
public:
|
||||
LiftedCircuit (const LiftedWCNF* lwcnf);
|
||||
|
||||
void smoothCircuit (void);
|
||||
|
||||
|
||||
double getWeightedModelCount (void) const;
|
||||
|
||||
|
||||
void exportToGraphViz (const char*);
|
||||
|
||||
|
||||
private:
|
||||
|
||||
void compile (CircuitNode** follow, Clauses& clauses);
|
||||
|
||||
bool tryUnitPropagation (CircuitNode** follow, Clauses& clauses);
|
||||
|
||||
|
||||
bool tryIndependence (CircuitNode** follow, Clauses& clauses);
|
||||
|
||||
|
||||
bool tryShannonDecomp (CircuitNode** follow, Clauses& clauses);
|
||||
|
||||
|
||||
bool tryInclusionExclusion (CircuitNode** follow, Clauses& clauses);
|
||||
|
||||
|
||||
bool tryIndepPartialGrounding (CircuitNode** follow, Clauses& clauses);
|
||||
|
||||
|
||||
bool tryIndepPartialGroundingAux (Clauses& clauses, ConstraintTree& ct,
|
||||
LogVars& rootLogVars);
|
||||
|
||||
|
||||
bool tryAtomCounting (CircuitNode** follow, Clauses& clauses);
|
||||
|
||||
|
||||
bool tryGrounding (CircuitNode** follow, Clauses& clauses);
|
||||
|
||||
|
||||
void shatterCountedLogVars (Clauses& clauses);
|
||||
|
||||
|
||||
bool shatterCountedLogVarsAux (Clauses& clauses);
|
||||
|
||||
bool shatterCountedLogVarsAux (Clauses& clauses, size_t idx1, size_t idx2);
|
||||
|
||||
|
||||
bool independentClause (Clause& clause, Clauses& otherClauses) const;
|
||||
|
||||
|
||||
bool independentLiteral (const Literal& lit,
|
||||
const Literals& otherLits) const;
|
||||
|
||||
LitLvTypesSet smoothCircuit (CircuitNode* node);
|
||||
|
||||
|
||||
void createSmoothNode (const LitLvTypesSet& lids,
|
||||
CircuitNode** prev);
|
||||
|
||||
|
||||
vector<LogVarTypes> getAllPossibleTypes (unsigned nrLogVars) const;
|
||||
|
||||
|
||||
bool containsTypes (const LogVarTypes& typesA,
|
||||
const LogVarTypes& typesB) const;
|
||||
|
||||
|
||||
CircuitNodeType getCircuitNodeType (const CircuitNode* node) const;
|
||||
|
||||
|
||||
void exportToGraphViz (CircuitNode* node, ofstream&);
|
||||
|
||||
|
||||
void printClauses (const CircuitNode* node, ofstream&,
|
||||
string extraOptions = "");
|
||||
|
||||
|
||||
string escapeNode (const CircuitNode* node) const;
|
||||
|
||||
CircuitNode* root_;
|
||||
|
@ -1,6 +1,7 @@
|
||||
#include "LiftedKc.h"
|
||||
#include "LiftedWCNF.h"
|
||||
#include "LiftedCircuit.h"
|
||||
#include "Indexer.h"
|
||||
|
||||
|
||||
LiftedKc::LiftedKc (const ParfactorList& pfList)
|
||||
@ -19,9 +20,46 @@ LiftedKc::~LiftedKc (void)
|
||||
|
||||
|
||||
Params
|
||||
LiftedKc::solveQuery (const Grounds&)
|
||||
LiftedKc::solveQuery (const Grounds& query)
|
||||
{
|
||||
return Params();
|
||||
vector<PrvGroup> groups;
|
||||
Ranges ranges;
|
||||
for (size_t i = 0; i < query.size(); i++) {
|
||||
ParfactorList::const_iterator it = pfList_.begin();
|
||||
while (it != pfList_.end()) {
|
||||
size_t idx = (*it)->indexOfGround (query[i]);
|
||||
if (idx != (*it)->nrArguments()) {
|
||||
groups.push_back ((*it)->argument (idx).group());
|
||||
ranges.push_back ((*it)->range (idx));
|
||||
break;
|
||||
}
|
||||
++ it;
|
||||
}
|
||||
}
|
||||
cout << "groups: " << groups << endl;
|
||||
cout << "ranges: " << ranges << endl;
|
||||
Params params;
|
||||
Indexer indexer (ranges);
|
||||
while (indexer.valid()) {
|
||||
for (size_t i = 0; i < groups.size(); i++) {
|
||||
vector<LiteralId> litIds = lwcnf_->prvGroupLiterals (groups[i]);
|
||||
for (size_t j = 0; j < litIds.size(); j++) {
|
||||
if (indexer[i] == j) {
|
||||
lwcnf_->addWeight (litIds[j], 1.0, 1.0); // TODO not log aware
|
||||
} else {
|
||||
lwcnf_->addWeight (litIds[j], 0.0, 1.0); // TODO not log aware
|
||||
}
|
||||
}
|
||||
}
|
||||
// cout << "new weights ----- ----- -----" << endl;
|
||||
// lwcnf_->printWeights();
|
||||
// circuit_->exportToGraphViz ("ccircuit.dot");
|
||||
params.push_back (circuit_->getWeightedModelCount());
|
||||
++ indexer;
|
||||
}
|
||||
cout << "params: " << params << endl;
|
||||
LogAware::normalize (params);
|
||||
return params;
|
||||
}
|
||||
|
||||
|
||||
|
@ -20,7 +20,7 @@ class LiftedKc
|
||||
private:
|
||||
LiftedWCNF* lwcnf_;
|
||||
LiftedCircuit* circuit_;
|
||||
|
||||
|
||||
const ParfactorList& pfList_;
|
||||
};
|
||||
|
||||
|
@ -352,8 +352,8 @@ Clause::getLogVarSetExcluding (size_t idx) const
|
||||
LiftedWCNF::LiftedWCNF (const ParfactorList& pfList)
|
||||
: freeLiteralId_(0), pfList_(pfList)
|
||||
{
|
||||
//addIndicatorClauses (pfList);
|
||||
//addParameterClauses (pfList);
|
||||
addIndicatorClauses (pfList);
|
||||
addParameterClauses (pfList);
|
||||
|
||||
/*
|
||||
vector<vector<string>> names = {
|
||||
@ -377,25 +377,25 @@ LiftedWCNF::LiftedWCNF (const ParfactorList& pfList)
|
||||
|
||||
freeLiteralId_ = 2;
|
||||
*/
|
||||
|
||||
Literal lit1 (0, {0});
|
||||
Literal lit2 (1, {0});
|
||||
Literal lit3 (2, {1});
|
||||
Literal lit4 (3, {1});
|
||||
|
||||
vector<vector<string>> names = {{"p1","p2"},{"p3","p4"}};
|
||||
Clause c1 (names);
|
||||
c1.addLiteral (lit1);
|
||||
c1.addLiteral (lit2);
|
||||
c1.addLiteral (lit3);
|
||||
c1.addLiteral (lit4);
|
||||
//c1.addPosCountedLogVar (0);
|
||||
clauses_.push_back (c1);
|
||||
|
||||
Clause c2 (names);
|
||||
c2.addLiteral (lit1);
|
||||
c2.addLiteral (lit3);
|
||||
c2.addNegCountedLogVar (0);
|
||||
//Literal lit1 (0, {0});
|
||||
//Literal lit2 (1, {0});
|
||||
//Literal lit3 (2, {1});
|
||||
//Literal lit4 (3, {1});
|
||||
|
||||
//vector<vector<string>> names = {{"p1","p2"},{"p3","p4"}};
|
||||
//Clause c1 (names);
|
||||
//c1.addLiteral (lit1);
|
||||
//c1.addLiteral (lit2);
|
||||
//c1.addLiteral (lit3);
|
||||
//c1.addLiteral (lit4);
|
||||
//c1.addPosCountedLogVar (0);
|
||||
//clauses_.push_back (c1);
|
||||
|
||||
//Clause c2 (names);
|
||||
//c2.addLiteral (lit1);
|
||||
//c2.addLiteral (lit3);
|
||||
//c2.addNegCountedLogVar (0);
|
||||
//clauses_.push_back (c2);
|
||||
/*
|
||||
Clause c3;
|
||||
@ -408,17 +408,18 @@ LiftedWCNF::LiftedWCNF (const ParfactorList& pfList)
|
||||
c4.addLiteral (lit3);
|
||||
clauses_.push_back (c4);
|
||||
*/
|
||||
freeLiteralId_ = 4;
|
||||
//freeLiteralId_ = 4;
|
||||
|
||||
cout << "FORMULA INDICATORS:" << endl;
|
||||
// printFormulaIndicators();
|
||||
printFormulaIndicators();
|
||||
cout << endl;
|
||||
|
||||
cout << "WEIGHTS:" << endl;
|
||||
printWeights();
|
||||
cout << endl;
|
||||
|
||||
cout << "CLAUSES:" << endl;
|
||||
printClauses();
|
||||
// abort();
|
||||
cout << endl;
|
||||
}
|
||||
|
||||
@ -431,6 +432,14 @@ LiftedWCNF::~LiftedWCNF (void)
|
||||
|
||||
|
||||
|
||||
void
|
||||
LiftedWCNF::addWeight (LiteralId lid, double posW, double negW)
|
||||
{
|
||||
weights_[lid] = make_pair (posW, negW);
|
||||
}
|
||||
|
||||
|
||||
|
||||
double
|
||||
LiftedWCNF::posWeight (LiteralId lid) const
|
||||
{
|
||||
@ -451,6 +460,15 @@ LiftedWCNF::negWeight (LiteralId lid) const
|
||||
|
||||
|
||||
|
||||
vector<LiteralId>
|
||||
LiftedWCNF::prvGroupLiterals (PrvGroup prvGroup)
|
||||
{
|
||||
assert (Util::contains (map_, prvGroup));
|
||||
return map_[prvGroup];
|
||||
}
|
||||
|
||||
|
||||
|
||||
Clause
|
||||
LiftedWCNF::createClause (LiteralId lid) const
|
||||
{
|
||||
@ -481,14 +499,6 @@ LiftedWCNF::getLiteralId (PrvGroup prvGroup, unsigned range)
|
||||
|
||||
|
||||
|
||||
void
|
||||
LiftedWCNF::addWeight (LiteralId lid, double posW, double negW)
|
||||
{
|
||||
weights_[lid] = make_pair (posW, negW);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
LiftedWCNF::addIndicatorClauses (const ParfactorList& pfList)
|
||||
{
|
||||
|
@ -188,13 +188,17 @@ class LiftedWCNF
|
||||
LiftedWCNF (const ParfactorList& pfList);
|
||||
|
||||
~LiftedWCNF (void);
|
||||
|
||||
|
||||
const Clauses& clauses (void) const { return clauses_; }
|
||||
|
||||
|
||||
void addWeight (LiteralId lid, double posW, double negW);
|
||||
|
||||
double posWeight (LiteralId lid) const;
|
||||
|
||||
double negWeight (LiteralId lid) const;
|
||||
|
||||
vector<LiteralId> prvGroupLiterals (PrvGroup prvGroup);
|
||||
|
||||
Clause createClause (LiteralId lid) const;
|
||||
|
||||
void printFormulaIndicators (void) const;
|
||||
@ -204,10 +208,8 @@ class LiftedWCNF
|
||||
void printClauses (void) const;
|
||||
|
||||
private:
|
||||
|
||||
|
||||
LiteralId getLiteralId (PrvGroup prvGroup, unsigned range);
|
||||
|
||||
void addWeight (LiteralId lid, double posW, double negW);
|
||||
|
||||
void addIndicatorClauses (const ParfactorList& pfList);
|
||||
|
||||
|
Reference in New Issue
Block a user