initial support for weighted model countign
This commit is contained in:
parent
eac6b954a8
commit
68ef63207f
@ -3,22 +3,92 @@
|
|||||||
#include "LiftedCircuit.h"
|
#include "LiftedCircuit.h"
|
||||||
|
|
||||||
|
|
||||||
|
double
|
||||||
|
OrNode::weight (void) const
|
||||||
|
{
|
||||||
|
double lw = leftBranch_->weight();
|
||||||
|
double rw = rightBranch_->weight();
|
||||||
|
return Globals::logDomain ? Util::logSum (lw, rw) : lw + rw;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
double
|
||||||
|
AndNode::weight (void) const
|
||||||
|
{
|
||||||
|
double lw = leftBranch_->weight();
|
||||||
|
double rw = rightBranch_->weight();
|
||||||
|
return Globals::logDomain ? lw + rw : lw * rw;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
double
|
||||||
|
LeafNode::weight (void) const
|
||||||
|
{
|
||||||
|
assert (clauses().size() == 1);
|
||||||
|
assert (clauses()[0].isUnit());
|
||||||
|
Clause c = clauses()[0];
|
||||||
|
double weight = c.literals()[0].weight();
|
||||||
|
unsigned nrGroundings = c.constr()->size();
|
||||||
|
assert (nrGroundings != 0);
|
||||||
|
double www = Globals::logDomain
|
||||||
|
? weight * nrGroundings
|
||||||
|
: std::pow (weight, nrGroundings);
|
||||||
|
|
||||||
|
cout << "leaf w: " << www << endl;
|
||||||
|
|
||||||
|
return Globals::logDomain
|
||||||
|
? weight * nrGroundings
|
||||||
|
: std::pow (weight, nrGroundings);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
double
|
||||||
|
SmoothNode::weight (void) const
|
||||||
|
{
|
||||||
|
Clauses cs = clauses();
|
||||||
|
double totalWeight = LogAware::multIdenty();
|
||||||
|
for (size_t i = 0; i < cs.size(); i++) {
|
||||||
|
double posWeight = cs[i].literals()[0].weight();
|
||||||
|
double negWeight = cs[i].literals()[1].weight();
|
||||||
|
unsigned nrGroundings = cs[i].constr()->size();
|
||||||
|
if (Globals::logDomain) {
|
||||||
|
totalWeight += (Util::logSum (posWeight, negWeight) * nrGroundings);
|
||||||
|
} else {
|
||||||
|
totalWeight *= std::pow (posWeight + negWeight, nrGroundings);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return totalWeight;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
double
|
||||||
|
TrueNode::weight (void) const
|
||||||
|
{
|
||||||
|
return LogAware::multIdenty();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
LiftedCircuit::LiftedCircuit (const LiftedWCNF* lwcnf)
|
LiftedCircuit::LiftedCircuit (const LiftedWCNF* lwcnf)
|
||||||
: lwcnf_(lwcnf)
|
: lwcnf_(lwcnf)
|
||||||
{
|
{
|
||||||
root_ = 0;
|
root_ = 0;
|
||||||
Clauses ccc = lwcnf->clauses();
|
Clauses ccc = lwcnf->clauses();
|
||||||
ccc.erase (ccc.begin() + 5, ccc.end());
|
//ccc.erase (ccc.begin() + 5, ccc.end());
|
||||||
//Clause c2 = ccc.front();
|
//Clause c2 = ccc.front();
|
||||||
//c2.removeLiteralByIndex (1);
|
//c2.removeLiteralByIndex (1);
|
||||||
//ccc.push_back (c2);
|
//ccc.push_back (c2);
|
||||||
|
|
||||||
//compile (&root_, lwcnf->clauses());
|
//compile (&root_, lwcnf->clauses());
|
||||||
compile (&root_, ccc);
|
compile (&root_, ccc);
|
||||||
cout << "done compiling..." << endl;
|
|
||||||
exportToGraphViz("circuit.dot");
|
exportToGraphViz("circuit.dot");
|
||||||
smoothCircuit();
|
smoothCircuit();
|
||||||
exportToGraphViz("smooth.dot");
|
exportToGraphViz("smooth.dot");
|
||||||
|
cout << "WEIGHTED MODEL COUNT = " << getWeightedModelCount() << endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -31,6 +101,14 @@ LiftedCircuit::smoothCircuit (void)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
double
|
||||||
|
LiftedCircuit::getWeightedModelCount (void) const
|
||||||
|
{
|
||||||
|
return root_->weight();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
LiftedCircuit::exportToGraphViz (const char* fileName)
|
LiftedCircuit::exportToGraphViz (const char* fileName)
|
||||||
{
|
{
|
||||||
@ -63,14 +141,14 @@ LiftedCircuit::compile (
|
|||||||
static int count = 0; count ++;
|
static int count = 0; count ++;
|
||||||
*follow = new LeafNode (clauses[0]);
|
*follow = new LeafNode (clauses[0]);
|
||||||
if (count == 1) {
|
if (count == 1) {
|
||||||
Clause c (new ConstraintTree({}));
|
// Clause c (new ConstraintTree({}));
|
||||||
c.addLiteral (Literal (100,{}));
|
// c.addLiteral (Literal (100,{}));
|
||||||
*follow = new LeafNode (c);
|
// *follow = new LeafNode (c);
|
||||||
}
|
}
|
||||||
if (count == 2) {
|
if (count == 2) {
|
||||||
Clause c (new ConstraintTree({}));
|
// Clause c (new ConstraintTree({}));
|
||||||
c.addLiteral (Literal (101,{}));
|
// c.addLiteral (Literal (101,{}));
|
||||||
*follow = new LeafNode (c);
|
// *follow = new LeafNode (c);
|
||||||
}
|
}
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@ -326,10 +404,13 @@ void
|
|||||||
LiftedCircuit::exportToGraphViz (CircuitNode* node, ofstream& os)
|
LiftedCircuit::exportToGraphViz (CircuitNode* node, ofstream& os)
|
||||||
{
|
{
|
||||||
assert (node != 0);
|
assert (node != 0);
|
||||||
|
|
||||||
static unsigned nrAndNodes = 0;
|
static unsigned nrAndNodes = 0;
|
||||||
static unsigned nrOrNodes = 0;
|
static unsigned nrOrNodes = 0;
|
||||||
|
|
||||||
if (dynamic_cast<OrNode*>(node) != 0) {
|
switch (getCircuitNodeType (node)) {
|
||||||
|
|
||||||
|
case OR_NODE: {
|
||||||
OrNode* casted = dynamic_cast<OrNode*>(node);
|
OrNode* casted = dynamic_cast<OrNode*>(node);
|
||||||
const Clauses& clauses = node->clauses();
|
const Clauses& clauses = node->clauses();
|
||||||
if (clauses.empty() == false) {
|
if (clauses.empty() == false) {
|
||||||
@ -351,7 +432,10 @@ LiftedCircuit::exportToGraphViz (CircuitNode* node, ofstream& os)
|
|||||||
nrOrNodes ++;
|
nrOrNodes ++;
|
||||||
exportToGraphViz (*casted->leftBranch(), os);
|
exportToGraphViz (*casted->leftBranch(), os);
|
||||||
exportToGraphViz (*casted->rightBranch(), os);
|
exportToGraphViz (*casted->rightBranch(), os);
|
||||||
} else if (dynamic_cast<AndNode*>(node) != 0) {
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
case AND_NODE: {
|
||||||
AndNode* casted = dynamic_cast<AndNode*>(node);
|
AndNode* casted = dynamic_cast<AndNode*>(node);
|
||||||
const Clauses& clauses = node->clauses();
|
const Clauses& clauses = node->clauses();
|
||||||
os << escapeNode (node) << " [shape=box,label=\"" ;
|
os << escapeNode (node) << " [shape=box,label=\"" ;
|
||||||
@ -371,13 +455,19 @@ LiftedCircuit::exportToGraphViz (CircuitNode* node, ofstream& os)
|
|||||||
nrAndNodes ++;
|
nrAndNodes ++;
|
||||||
exportToGraphViz (*casted->leftBranch(), os);
|
exportToGraphViz (*casted->leftBranch(), os);
|
||||||
exportToGraphViz (*casted->rightBranch(), os);
|
exportToGraphViz (*casted->rightBranch(), os);
|
||||||
} else if (dynamic_cast<LeafNode*>(node) != 0) {
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
case LEAF_NODE: {
|
||||||
os << escapeNode (node);
|
os << escapeNode (node);
|
||||||
os << " [shape=box,label=\"" ;
|
os << " [shape=box,label=\"" ;
|
||||||
os << node->clauses()[0];
|
os << node->clauses()[0];
|
||||||
os << "\"]" ;
|
os << "\"]" ;
|
||||||
os << endl;
|
os << endl;
|
||||||
} else if (dynamic_cast<SmoothNode*>(node) != 0) {
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
case SMOOTH_NODE: {
|
||||||
os << escapeNode (node);
|
os << escapeNode (node);
|
||||||
os << " [shape=box,style=filled,fillcolor=chartreuse,label=\"" ;
|
os << " [shape=box,style=filled,fillcolor=chartreuse,label=\"" ;
|
||||||
const Clauses& clauses = node->clauses();
|
const Clauses& clauses = node->clauses();
|
||||||
@ -387,11 +477,17 @@ LiftedCircuit::exportToGraphViz (CircuitNode* node, ofstream& os)
|
|||||||
}
|
}
|
||||||
os << "\"]" ;
|
os << "\"]" ;
|
||||||
os << endl;
|
os << endl;
|
||||||
} else if (dynamic_cast<TrueNode*>(node) != 0) {
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
case TRUE_NODE: {
|
||||||
os << escapeNode (node);
|
os << escapeNode (node);
|
||||||
os << " [shape=box,label=\"⊤\"]" ;
|
os << " [shape=box,label=\"⊤\"]" ;
|
||||||
os << endl;
|
os << endl;
|
||||||
} else if (dynamic_cast<FailNode*>(node) != 0) {
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
case FAIL_NODE: {
|
||||||
os << escapeNode (node);
|
os << escapeNode (node);
|
||||||
os << " [shape=box,style=filled,fillcolor=brown1,label=\"" ;
|
os << " [shape=box,style=filled,fillcolor=brown1,label=\"" ;
|
||||||
const Clauses& clauses = node->clauses();
|
const Clauses& clauses = node->clauses();
|
||||||
@ -401,7 +497,10 @@ LiftedCircuit::exportToGraphViz (CircuitNode* node, ofstream& os)
|
|||||||
}
|
}
|
||||||
os << "\"]" ;
|
os << "\"]" ;
|
||||||
os << endl;
|
os << endl;
|
||||||
} else {
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
default:
|
||||||
assert (false);
|
assert (false);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -24,7 +24,9 @@ class CircuitNode
|
|||||||
CircuitNode (const Clauses& clauses, string explanation = "")
|
CircuitNode (const Clauses& clauses, string explanation = "")
|
||||||
: clauses_(clauses), explanation_(explanation) { }
|
: clauses_(clauses), explanation_(explanation) { }
|
||||||
|
|
||||||
const Clauses& clauses (void) { return clauses_; }
|
const Clauses& clauses (void) const { return clauses_; }
|
||||||
|
|
||||||
|
Clauses clauses (void) { return clauses_; }
|
||||||
|
|
||||||
virtual double weight (void) const { return 0; }
|
virtual double weight (void) const { return 0; }
|
||||||
|
|
||||||
@ -44,6 +46,8 @@ class OrNode : public CircuitNode
|
|||||||
: CircuitNode (clauses, explanation),
|
: CircuitNode (clauses, explanation),
|
||||||
leftBranch_(0), rightBranch_(0) { }
|
leftBranch_(0), rightBranch_(0) { }
|
||||||
|
|
||||||
|
double weight (void) const;
|
||||||
|
|
||||||
CircuitNode** leftBranch (void) { return &leftBranch_; }
|
CircuitNode** leftBranch (void) { return &leftBranch_; }
|
||||||
CircuitNode** rightBranch (void) { return &rightBranch_; }
|
CircuitNode** rightBranch (void) { return &rightBranch_; }
|
||||||
private:
|
private:
|
||||||
@ -75,6 +79,8 @@ class AndNode : public CircuitNode
|
|||||||
: CircuitNode ({}, explanation),
|
: CircuitNode ({}, explanation),
|
||||||
leftBranch_(leftBranch), rightBranch_(rightBranch) { }
|
leftBranch_(leftBranch), rightBranch_(rightBranch) { }
|
||||||
|
|
||||||
|
double weight (void) const;
|
||||||
|
|
||||||
CircuitNode** leftBranch (void) { return &leftBranch_; }
|
CircuitNode** leftBranch (void) { return &leftBranch_; }
|
||||||
CircuitNode** rightBranch (void) { return &rightBranch_; }
|
CircuitNode** rightBranch (void) { return &rightBranch_; }
|
||||||
private:
|
private:
|
||||||
@ -117,6 +123,8 @@ class LeafNode : public CircuitNode
|
|||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
LeafNode (const Clause& clause) : CircuitNode ({clause}) { }
|
LeafNode (const Clause& clause) : CircuitNode ({clause}) { }
|
||||||
|
|
||||||
|
double weight (void) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
@ -125,6 +133,8 @@ class SmoothNode : public CircuitNode
|
|||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
SmoothNode (const Clauses& clauses) : CircuitNode (clauses) { }
|
SmoothNode (const Clauses& clauses) : CircuitNode (clauses) { }
|
||||||
|
|
||||||
|
double weight (void) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
@ -133,6 +143,8 @@ class TrueNode : public CircuitNode
|
|||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
TrueNode () : CircuitNode ({}) { }
|
TrueNode () : CircuitNode ({}) { }
|
||||||
|
|
||||||
|
double weight (void) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
@ -153,6 +165,8 @@ class LiftedCircuit
|
|||||||
|
|
||||||
void smoothCircuit (void);
|
void smoothCircuit (void);
|
||||||
|
|
||||||
|
double getWeightedModelCount (void) const;
|
||||||
|
|
||||||
void exportToGraphViz (const char*);
|
void exportToGraphViz (const char*);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
@ -150,8 +150,15 @@ LiftedWCNF::LiftedWCNF (const ParfactorList& pfList)
|
|||||||
{
|
{
|
||||||
addIndicatorClauses (pfList);
|
addIndicatorClauses (pfList);
|
||||||
addParameterClauses (pfList);
|
addParameterClauses (pfList);
|
||||||
|
cout << "FORMULA INDICATORS:" << endl;
|
||||||
printFormulaIndicators();
|
printFormulaIndicators();
|
||||||
|
cout << endl;
|
||||||
|
cout << "WEIGHTS:" << endl;
|
||||||
|
printWeights();
|
||||||
|
cout << endl;
|
||||||
|
cout << "CLAUSES:" << endl;
|
||||||
printClauses();
|
printClauses();
|
||||||
|
cout << endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -237,6 +244,8 @@ LiftedWCNF::addParameterClauses (const ParfactorList& pfList)
|
|||||||
while (indexer.valid()) {
|
while (indexer.valid()) {
|
||||||
LiteralId paramVarLid = freeLiteralId_;
|
LiteralId paramVarLid = freeLiteralId_;
|
||||||
|
|
||||||
|
double weight = (**it)[indexer];
|
||||||
|
|
||||||
Clause clause1 ((*it)->constr());
|
Clause clause1 ((*it)->constr());
|
||||||
|
|
||||||
for (unsigned i = 0; i < groups.size(); i++) {
|
for (unsigned i = 0; i < groups.size(); i++) {
|
||||||
@ -245,11 +254,11 @@ LiftedWCNF::addParameterClauses (const ParfactorList& pfList)
|
|||||||
clause1.addAndNegateLiteral (Literal (lid, (*it)->argument(i).logVars()));
|
clause1.addAndNegateLiteral (Literal (lid, (*it)->argument(i).logVars()));
|
||||||
|
|
||||||
Clause tempClause ((*it)->constr());
|
Clause tempClause ((*it)->constr());
|
||||||
tempClause.addAndNegateLiteral (Literal (paramVarLid, LogVars(), 1.0));
|
tempClause.addAndNegateLiteral (Literal (paramVarLid, 1.0));
|
||||||
tempClause.addLiteral (Literal (lid, (*it)->argument(i).logVars()));
|
tempClause.addLiteral (Literal (lid, (*it)->argument(i).logVars()));
|
||||||
clauses_.push_back (tempClause);
|
clauses_.push_back (tempClause);
|
||||||
}
|
}
|
||||||
clause1.addLiteral (Literal (paramVarLid, LogVars(), 1.0));
|
clause1.addLiteral (Literal (paramVarLid, weight));
|
||||||
clauses_.push_back (clause1);
|
clauses_.push_back (clause1);
|
||||||
freeLiteralId_ ++;
|
freeLiteralId_ ++;
|
||||||
++ indexer;
|
++ indexer;
|
||||||
@ -259,16 +268,6 @@ LiftedWCNF::addParameterClauses (const ParfactorList& pfList)
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
LiftedWCNF::printClauses (void) const
|
|
||||||
{
|
|
||||||
for (unsigned i = 0; i < clauses_.size(); i++) {
|
|
||||||
cout << clauses_[i] << endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
LiftedWCNF::printFormulaIndicators (void) const
|
LiftedWCNF::printFormulaIndicators (void) const
|
||||||
{
|
{
|
||||||
@ -293,3 +292,55 @@ LiftedWCNF::printFormulaIndicators (void) const
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
LiftedWCNF::printWeights (void) const
|
||||||
|
{
|
||||||
|
for (LiteralId i = 0; i < freeLiteralId_; i++) {
|
||||||
|
|
||||||
|
bool found = false;
|
||||||
|
for (size_t j = 0; j < clauses_.size(); j++) {
|
||||||
|
Literals literals = clauses_[j].literals();
|
||||||
|
for (size_t k = 0; k < literals.size(); k++) {
|
||||||
|
if (literals[k].lid() == i && literals[k].isPositive()) {
|
||||||
|
cout << "weight(" << literals[k] << ") = " << literals[k].weight();
|
||||||
|
cout << endl;
|
||||||
|
found = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (found == true) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
found = false;
|
||||||
|
for (size_t j = 0; j < clauses_.size(); j++) {
|
||||||
|
Literals literals = clauses_[j].literals();
|
||||||
|
for (size_t k = 0; k < literals.size(); k++) {
|
||||||
|
if (literals[k].lid() == i && literals[k].isNegative()) {
|
||||||
|
cout << "weight(" << literals[k] << ") = " << literals[k].weight();
|
||||||
|
cout << endl;
|
||||||
|
found = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (found == true) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
LiftedWCNF::printClauses (void) const
|
||||||
|
{
|
||||||
|
for (unsigned i = 0; i < clauses_.size(); i++) {
|
||||||
|
cout << clauses_[i] << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@ -13,6 +13,9 @@ class ConstraintTree;
|
|||||||
class Literal
|
class Literal
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
|
Literal (LiteralId lid, double w = -1.0) :
|
||||||
|
lid_(lid), weight_(w), negated_(false) { }
|
||||||
|
|
||||||
Literal (LiteralId lid, const LogVars& lvs, double w = -1.0) :
|
Literal (LiteralId lid, const LogVars& lvs, double w = -1.0) :
|
||||||
lid_(lid), logVars_(lvs), weight_(w), negated_(false) { }
|
lid_(lid), logVars_(lvs), weight_(w), negated_(false) { }
|
||||||
|
|
||||||
@ -23,7 +26,8 @@ class Literal
|
|||||||
|
|
||||||
LogVars logVars (void) const { return logVars_; }
|
LogVars logVars (void) const { return logVars_; }
|
||||||
|
|
||||||
double weight (void) const { return weight_; }
|
// FIXME not log aware
|
||||||
|
double weight (void) const { return weight_ < 0.0 ? 1.0 : weight_; }
|
||||||
|
|
||||||
void negate (void) { negated_ = !negated_; }
|
void negate (void) { negated_ = !negated_; }
|
||||||
|
|
||||||
@ -106,10 +110,12 @@ class LiftedWCNF
|
|||||||
|
|
||||||
Clause createClauseForLiteral (LiteralId lid) const;
|
Clause createClauseForLiteral (LiteralId lid) const;
|
||||||
|
|
||||||
void printClauses (void) const;
|
|
||||||
|
|
||||||
void printFormulaIndicators (void) const;
|
void printFormulaIndicators (void) const;
|
||||||
|
|
||||||
|
void printWeights (void) const;
|
||||||
|
|
||||||
|
void printClauses (void) const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void addIndicatorClauses (const ParfactorList& pfList);
|
void addIndicatorClauses (const ParfactorList& pfList);
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user