initial support for weighted model countign

This commit is contained in:
Tiago Gomes 2012-10-25 12:22:52 +01:00
parent eac6b954a8
commit 68ef63207f
4 changed files with 266 additions and 96 deletions

View File

@ -3,22 +3,92 @@
#include "LiftedCircuit.h" #include "LiftedCircuit.h"
double
OrNode::weight (void) const
{
double lw = leftBranch_->weight();
double rw = rightBranch_->weight();
return Globals::logDomain ? Util::logSum (lw, rw) : lw + rw;
}
double
AndNode::weight (void) const
{
double lw = leftBranch_->weight();
double rw = rightBranch_->weight();
return Globals::logDomain ? lw + rw : lw * rw;
}
double
LeafNode::weight (void) const
{
assert (clauses().size() == 1);
assert (clauses()[0].isUnit());
Clause c = clauses()[0];
double weight = c.literals()[0].weight();
unsigned nrGroundings = c.constr()->size();
assert (nrGroundings != 0);
double www = Globals::logDomain
? weight * nrGroundings
: std::pow (weight, nrGroundings);
cout << "leaf w: " << www << endl;
return Globals::logDomain
? weight * nrGroundings
: std::pow (weight, nrGroundings);
}
double
SmoothNode::weight (void) const
{
Clauses cs = clauses();
double totalWeight = LogAware::multIdenty();
for (size_t i = 0; i < cs.size(); i++) {
double posWeight = cs[i].literals()[0].weight();
double negWeight = cs[i].literals()[1].weight();
unsigned nrGroundings = cs[i].constr()->size();
if (Globals::logDomain) {
totalWeight += (Util::logSum (posWeight, negWeight) * nrGroundings);
} else {
totalWeight *= std::pow (posWeight + negWeight, nrGroundings);
}
}
return totalWeight;
}
double
TrueNode::weight (void) const
{
return LogAware::multIdenty();
}
LiftedCircuit::LiftedCircuit (const LiftedWCNF* lwcnf) LiftedCircuit::LiftedCircuit (const LiftedWCNF* lwcnf)
: lwcnf_(lwcnf) : lwcnf_(lwcnf)
{ {
root_ = 0; root_ = 0;
Clauses ccc = lwcnf->clauses(); Clauses ccc = lwcnf->clauses();
ccc.erase (ccc.begin() + 5, ccc.end()); //ccc.erase (ccc.begin() + 5, ccc.end());
//Clause c2 = ccc.front(); //Clause c2 = ccc.front();
//c2.removeLiteralByIndex (1); //c2.removeLiteralByIndex (1);
//ccc.push_back (c2); //ccc.push_back (c2);
//compile (&root_, lwcnf->clauses()); //compile (&root_, lwcnf->clauses());
compile (&root_, ccc); compile (&root_, ccc);
cout << "done compiling..." << endl;
exportToGraphViz("circuit.dot"); exportToGraphViz("circuit.dot");
smoothCircuit(); smoothCircuit();
exportToGraphViz("smooth.dot"); exportToGraphViz("smooth.dot");
cout << "WEIGHTED MODEL COUNT = " << getWeightedModelCount() << endl;
} }
@ -31,6 +101,14 @@ LiftedCircuit::smoothCircuit (void)
double
LiftedCircuit::getWeightedModelCount (void) const
{
return root_->weight();
}
void void
LiftedCircuit::exportToGraphViz (const char* fileName) LiftedCircuit::exportToGraphViz (const char* fileName)
{ {
@ -63,14 +141,14 @@ LiftedCircuit::compile (
static int count = 0; count ++; static int count = 0; count ++;
*follow = new LeafNode (clauses[0]); *follow = new LeafNode (clauses[0]);
if (count == 1) { if (count == 1) {
Clause c (new ConstraintTree({})); // Clause c (new ConstraintTree({}));
c.addLiteral (Literal (100,{})); // c.addLiteral (Literal (100,{}));
*follow = new LeafNode (c); // *follow = new LeafNode (c);
} }
if (count == 2) { if (count == 2) {
Clause c (new ConstraintTree({})); // Clause c (new ConstraintTree({}));
c.addLiteral (Literal (101,{})); // c.addLiteral (Literal (101,{}));
*follow = new LeafNode (c); // *follow = new LeafNode (c);
} }
return; return;
} }
@ -326,10 +404,13 @@ void
LiftedCircuit::exportToGraphViz (CircuitNode* node, ofstream& os) LiftedCircuit::exportToGraphViz (CircuitNode* node, ofstream& os)
{ {
assert (node != 0); assert (node != 0);
static unsigned nrAndNodes = 0; static unsigned nrAndNodes = 0;
static unsigned nrOrNodes = 0; static unsigned nrOrNodes = 0;
if (dynamic_cast<OrNode*>(node) != 0) { switch (getCircuitNodeType (node)) {
case OR_NODE: {
OrNode* casted = dynamic_cast<OrNode*>(node); OrNode* casted = dynamic_cast<OrNode*>(node);
const Clauses& clauses = node->clauses(); const Clauses& clauses = node->clauses();
if (clauses.empty() == false) { if (clauses.empty() == false) {
@ -351,7 +432,10 @@ LiftedCircuit::exportToGraphViz (CircuitNode* node, ofstream& os)
nrOrNodes ++; nrOrNodes ++;
exportToGraphViz (*casted->leftBranch(), os); exportToGraphViz (*casted->leftBranch(), os);
exportToGraphViz (*casted->rightBranch(), os); exportToGraphViz (*casted->rightBranch(), os);
} else if (dynamic_cast<AndNode*>(node) != 0) { break;
}
case AND_NODE: {
AndNode* casted = dynamic_cast<AndNode*>(node); AndNode* casted = dynamic_cast<AndNode*>(node);
const Clauses& clauses = node->clauses(); const Clauses& clauses = node->clauses();
os << escapeNode (node) << " [shape=box,label=\"" ; os << escapeNode (node) << " [shape=box,label=\"" ;
@ -371,13 +455,19 @@ LiftedCircuit::exportToGraphViz (CircuitNode* node, ofstream& os)
nrAndNodes ++; nrAndNodes ++;
exportToGraphViz (*casted->leftBranch(), os); exportToGraphViz (*casted->leftBranch(), os);
exportToGraphViz (*casted->rightBranch(), os); exportToGraphViz (*casted->rightBranch(), os);
} else if (dynamic_cast<LeafNode*>(node) != 0) { break;
}
case LEAF_NODE: {
os << escapeNode (node); os << escapeNode (node);
os << " [shape=box,label=\"" ; os << " [shape=box,label=\"" ;
os << node->clauses()[0]; os << node->clauses()[0];
os << "\"]" ; os << "\"]" ;
os << endl; os << endl;
} else if (dynamic_cast<SmoothNode*>(node) != 0) { break;
}
case SMOOTH_NODE: {
os << escapeNode (node); os << escapeNode (node);
os << " [shape=box,style=filled,fillcolor=chartreuse,label=\"" ; os << " [shape=box,style=filled,fillcolor=chartreuse,label=\"" ;
const Clauses& clauses = node->clauses(); const Clauses& clauses = node->clauses();
@ -387,11 +477,17 @@ LiftedCircuit::exportToGraphViz (CircuitNode* node, ofstream& os)
} }
os << "\"]" ; os << "\"]" ;
os << endl; os << endl;
} else if (dynamic_cast<TrueNode*>(node) != 0) { break;
}
case TRUE_NODE: {
os << escapeNode (node); os << escapeNode (node);
os << " [shape=box,label=\"\"]" ; os << " [shape=box,label=\"\"]" ;
os << endl; os << endl;
} else if (dynamic_cast<FailNode*>(node) != 0) { break;
}
case FAIL_NODE: {
os << escapeNode (node); os << escapeNode (node);
os << " [shape=box,style=filled,fillcolor=brown1,label=\"" ; os << " [shape=box,style=filled,fillcolor=brown1,label=\"" ;
const Clauses& clauses = node->clauses(); const Clauses& clauses = node->clauses();
@ -401,7 +497,10 @@ LiftedCircuit::exportToGraphViz (CircuitNode* node, ofstream& os)
} }
os << "\"]" ; os << "\"]" ;
os << endl; os << endl;
} else { break;
}
default:
assert (false); assert (false);
} }
} }

View File

@ -24,7 +24,9 @@ class CircuitNode
CircuitNode (const Clauses& clauses, string explanation = "") CircuitNode (const Clauses& clauses, string explanation = "")
: clauses_(clauses), explanation_(explanation) { } : clauses_(clauses), explanation_(explanation) { }
const Clauses& clauses (void) { return clauses_; } const Clauses& clauses (void) const { return clauses_; }
Clauses clauses (void) { return clauses_; }
virtual double weight (void) const { return 0; } virtual double weight (void) const { return 0; }
@ -44,6 +46,8 @@ class OrNode : public CircuitNode
: CircuitNode (clauses, explanation), : CircuitNode (clauses, explanation),
leftBranch_(0), rightBranch_(0) { } leftBranch_(0), rightBranch_(0) { }
double weight (void) const;
CircuitNode** leftBranch (void) { return &leftBranch_; } CircuitNode** leftBranch (void) { return &leftBranch_; }
CircuitNode** rightBranch (void) { return &rightBranch_; } CircuitNode** rightBranch (void) { return &rightBranch_; }
private: private:
@ -75,6 +79,8 @@ class AndNode : public CircuitNode
: CircuitNode ({}, explanation), : CircuitNode ({}, explanation),
leftBranch_(leftBranch), rightBranch_(rightBranch) { } leftBranch_(leftBranch), rightBranch_(rightBranch) { }
double weight (void) const;
CircuitNode** leftBranch (void) { return &leftBranch_; } CircuitNode** leftBranch (void) { return &leftBranch_; }
CircuitNode** rightBranch (void) { return &rightBranch_; } CircuitNode** rightBranch (void) { return &rightBranch_; }
private: private:
@ -117,6 +123,8 @@ class LeafNode : public CircuitNode
{ {
public: public:
LeafNode (const Clause& clause) : CircuitNode ({clause}) { } LeafNode (const Clause& clause) : CircuitNode ({clause}) { }
double weight (void) const;
}; };
@ -125,6 +133,8 @@ class SmoothNode : public CircuitNode
{ {
public: public:
SmoothNode (const Clauses& clauses) : CircuitNode (clauses) { } SmoothNode (const Clauses& clauses) : CircuitNode (clauses) { }
double weight (void) const;
}; };
@ -133,6 +143,8 @@ class TrueNode : public CircuitNode
{ {
public: public:
TrueNode () : CircuitNode ({}) { } TrueNode () : CircuitNode ({}) { }
double weight (void) const;
}; };
@ -153,6 +165,8 @@ class LiftedCircuit
void smoothCircuit (void); void smoothCircuit (void);
double getWeightedModelCount (void) const;
void exportToGraphViz (const char*); void exportToGraphViz (const char*);
private: private:

View File

@ -150,8 +150,15 @@ LiftedWCNF::LiftedWCNF (const ParfactorList& pfList)
{ {
addIndicatorClauses (pfList); addIndicatorClauses (pfList);
addParameterClauses (pfList); addParameterClauses (pfList);
cout << "FORMULA INDICATORS:" << endl;
printFormulaIndicators(); printFormulaIndicators();
cout << endl;
cout << "WEIGHTS:" << endl;
printWeights();
cout << endl;
cout << "CLAUSES:" << endl;
printClauses(); printClauses();
cout << endl;
} }
@ -237,6 +244,8 @@ LiftedWCNF::addParameterClauses (const ParfactorList& pfList)
while (indexer.valid()) { while (indexer.valid()) {
LiteralId paramVarLid = freeLiteralId_; LiteralId paramVarLid = freeLiteralId_;
double weight = (**it)[indexer];
Clause clause1 ((*it)->constr()); Clause clause1 ((*it)->constr());
for (unsigned i = 0; i < groups.size(); i++) { for (unsigned i = 0; i < groups.size(); i++) {
@ -245,11 +254,11 @@ LiftedWCNF::addParameterClauses (const ParfactorList& pfList)
clause1.addAndNegateLiteral (Literal (lid, (*it)->argument(i).logVars())); clause1.addAndNegateLiteral (Literal (lid, (*it)->argument(i).logVars()));
Clause tempClause ((*it)->constr()); Clause tempClause ((*it)->constr());
tempClause.addAndNegateLiteral (Literal (paramVarLid, LogVars(), 1.0)); tempClause.addAndNegateLiteral (Literal (paramVarLid, 1.0));
tempClause.addLiteral (Literal (lid, (*it)->argument(i).logVars())); tempClause.addLiteral (Literal (lid, (*it)->argument(i).logVars()));
clauses_.push_back (tempClause); clauses_.push_back (tempClause);
} }
clause1.addLiteral (Literal (paramVarLid, LogVars(), 1.0)); clause1.addLiteral (Literal (paramVarLid, weight));
clauses_.push_back (clause1); clauses_.push_back (clause1);
freeLiteralId_ ++; freeLiteralId_ ++;
++ indexer; ++ indexer;
@ -259,16 +268,6 @@ LiftedWCNF::addParameterClauses (const ParfactorList& pfList)
} }
void
LiftedWCNF::printClauses (void) const
{
for (unsigned i = 0; i < clauses_.size(); i++) {
cout << clauses_[i] << endl;
}
}
void void
LiftedWCNF::printFormulaIndicators (void) const LiftedWCNF::printFormulaIndicators (void) const
{ {
@ -293,3 +292,55 @@ LiftedWCNF::printFormulaIndicators (void) const
} }
} }
void
LiftedWCNF::printWeights (void) const
{
for (LiteralId i = 0; i < freeLiteralId_; i++) {
bool found = false;
for (size_t j = 0; j < clauses_.size(); j++) {
Literals literals = clauses_[j].literals();
for (size_t k = 0; k < literals.size(); k++) {
if (literals[k].lid() == i && literals[k].isPositive()) {
cout << "weight(" << literals[k] << ") = " << literals[k].weight();
cout << endl;
found = true;
break;
}
}
if (found == true) {
break;
}
}
found = false;
for (size_t j = 0; j < clauses_.size(); j++) {
Literals literals = clauses_[j].literals();
for (size_t k = 0; k < literals.size(); k++) {
if (literals[k].lid() == i && literals[k].isNegative()) {
cout << "weight(" << literals[k] << ") = " << literals[k].weight();
cout << endl;
found = true;
break;
}
}
if (found == true) {
break;
}
}
}
}
void
LiftedWCNF::printClauses (void) const
{
for (unsigned i = 0; i < clauses_.size(); i++) {
cout << clauses_[i] << endl;
}
}

View File

@ -13,6 +13,9 @@ class ConstraintTree;
class Literal class Literal
{ {
public: public:
Literal (LiteralId lid, double w = -1.0) :
lid_(lid), weight_(w), negated_(false) { }
Literal (LiteralId lid, const LogVars& lvs, double w = -1.0) : Literal (LiteralId lid, const LogVars& lvs, double w = -1.0) :
lid_(lid), logVars_(lvs), weight_(w), negated_(false) { } lid_(lid), logVars_(lvs), weight_(w), negated_(false) { }
@ -23,7 +26,8 @@ class Literal
LogVars logVars (void) const { return logVars_; } LogVars logVars (void) const { return logVars_; }
double weight (void) const { return weight_; } // FIXME not log aware
double weight (void) const { return weight_ < 0.0 ? 1.0 : weight_; }
void negate (void) { negated_ = !negated_; } void negate (void) { negated_ = !negated_; }
@ -106,10 +110,12 @@ class LiftedWCNF
Clause createClauseForLiteral (LiteralId lid) const; Clause createClauseForLiteral (LiteralId lid) const;
void printClauses (void) const;
void printFormulaIndicators (void) const; void printFormulaIndicators (void) const;
void printWeights (void) const;
void printClauses (void) const;
private: private:
void addIndicatorClauses (const ParfactorList& pfList); void addIndicatorClauses (const ParfactorList& pfList);