initial support for weighted model countign
This commit is contained in:
parent
eac6b954a8
commit
68ef63207f
@ -3,22 +3,92 @@
|
||||
#include "LiftedCircuit.h"
|
||||
|
||||
|
||||
double
|
||||
OrNode::weight (void) const
|
||||
{
|
||||
double lw = leftBranch_->weight();
|
||||
double rw = rightBranch_->weight();
|
||||
return Globals::logDomain ? Util::logSum (lw, rw) : lw + rw;
|
||||
}
|
||||
|
||||
|
||||
|
||||
double
|
||||
AndNode::weight (void) const
|
||||
{
|
||||
double lw = leftBranch_->weight();
|
||||
double rw = rightBranch_->weight();
|
||||
return Globals::logDomain ? lw + rw : lw * rw;
|
||||
}
|
||||
|
||||
|
||||
|
||||
double
|
||||
LeafNode::weight (void) const
|
||||
{
|
||||
assert (clauses().size() == 1);
|
||||
assert (clauses()[0].isUnit());
|
||||
Clause c = clauses()[0];
|
||||
double weight = c.literals()[0].weight();
|
||||
unsigned nrGroundings = c.constr()->size();
|
||||
assert (nrGroundings != 0);
|
||||
double www = Globals::logDomain
|
||||
? weight * nrGroundings
|
||||
: std::pow (weight, nrGroundings);
|
||||
|
||||
cout << "leaf w: " << www << endl;
|
||||
|
||||
return Globals::logDomain
|
||||
? weight * nrGroundings
|
||||
: std::pow (weight, nrGroundings);
|
||||
}
|
||||
|
||||
|
||||
|
||||
double
|
||||
SmoothNode::weight (void) const
|
||||
{
|
||||
Clauses cs = clauses();
|
||||
double totalWeight = LogAware::multIdenty();
|
||||
for (size_t i = 0; i < cs.size(); i++) {
|
||||
double posWeight = cs[i].literals()[0].weight();
|
||||
double negWeight = cs[i].literals()[1].weight();
|
||||
unsigned nrGroundings = cs[i].constr()->size();
|
||||
if (Globals::logDomain) {
|
||||
totalWeight += (Util::logSum (posWeight, negWeight) * nrGroundings);
|
||||
} else {
|
||||
totalWeight *= std::pow (posWeight + negWeight, nrGroundings);
|
||||
}
|
||||
}
|
||||
return totalWeight;
|
||||
}
|
||||
|
||||
|
||||
|
||||
double
|
||||
TrueNode::weight (void) const
|
||||
{
|
||||
return LogAware::multIdenty();
|
||||
}
|
||||
|
||||
|
||||
|
||||
LiftedCircuit::LiftedCircuit (const LiftedWCNF* lwcnf)
|
||||
: lwcnf_(lwcnf)
|
||||
{
|
||||
root_ = 0;
|
||||
Clauses ccc = lwcnf->clauses();
|
||||
ccc.erase (ccc.begin() + 5, ccc.end());
|
||||
//ccc.erase (ccc.begin() + 5, ccc.end());
|
||||
//Clause c2 = ccc.front();
|
||||
//c2.removeLiteralByIndex (1);
|
||||
//ccc.push_back (c2);
|
||||
|
||||
//compile (&root_, lwcnf->clauses());
|
||||
compile (&root_, ccc);
|
||||
cout << "done compiling..." << endl;
|
||||
exportToGraphViz("circuit.dot");
|
||||
smoothCircuit();
|
||||
exportToGraphViz("smooth.dot");
|
||||
cout << "WEIGHTED MODEL COUNT = " << getWeightedModelCount() << endl;
|
||||
}
|
||||
|
||||
|
||||
@ -31,6 +101,14 @@ LiftedCircuit::smoothCircuit (void)
|
||||
|
||||
|
||||
|
||||
double
|
||||
LiftedCircuit::getWeightedModelCount (void) const
|
||||
{
|
||||
return root_->weight();
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
LiftedCircuit::exportToGraphViz (const char* fileName)
|
||||
{
|
||||
@ -63,14 +141,14 @@ LiftedCircuit::compile (
|
||||
static int count = 0; count ++;
|
||||
*follow = new LeafNode (clauses[0]);
|
||||
if (count == 1) {
|
||||
Clause c (new ConstraintTree({}));
|
||||
c.addLiteral (Literal (100,{}));
|
||||
*follow = new LeafNode (c);
|
||||
// Clause c (new ConstraintTree({}));
|
||||
// c.addLiteral (Literal (100,{}));
|
||||
// *follow = new LeafNode (c);
|
||||
}
|
||||
if (count == 2) {
|
||||
Clause c (new ConstraintTree({}));
|
||||
c.addLiteral (Literal (101,{}));
|
||||
*follow = new LeafNode (c);
|
||||
// Clause c (new ConstraintTree({}));
|
||||
// c.addLiteral (Literal (101,{}));
|
||||
// *follow = new LeafNode (c);
|
||||
}
|
||||
return;
|
||||
}
|
||||
@ -326,83 +404,104 @@ void
|
||||
LiftedCircuit::exportToGraphViz (CircuitNode* node, ofstream& os)
|
||||
{
|
||||
assert (node != 0);
|
||||
|
||||
static unsigned nrAndNodes = 0;
|
||||
static unsigned nrOrNodes = 0;
|
||||
|
||||
if (dynamic_cast<OrNode*>(node) != 0) {
|
||||
OrNode* casted = dynamic_cast<OrNode*>(node);
|
||||
const Clauses& clauses = node->clauses();
|
||||
if (clauses.empty() == false) {
|
||||
os << escapeNode (node) << " [shape=box,label=\"" ;
|
||||
for (size_t i = 0; i < clauses.size(); i++) {
|
||||
if (i != 0) os << "\\n" ;
|
||||
os << clauses[i];
|
||||
switch (getCircuitNodeType (node)) {
|
||||
|
||||
case OR_NODE: {
|
||||
OrNode* casted = dynamic_cast<OrNode*>(node);
|
||||
const Clauses& clauses = node->clauses();
|
||||
if (clauses.empty() == false) {
|
||||
os << escapeNode (node) << " [shape=box,label=\"" ;
|
||||
for (size_t i = 0; i < clauses.size(); i++) {
|
||||
if (i != 0) os << "\\n" ;
|
||||
os << clauses[i];
|
||||
}
|
||||
os << "\"]" ;
|
||||
os << endl;
|
||||
}
|
||||
os << "or" << nrOrNodes << " [label=\"∨\"]" << endl;
|
||||
os << '"' << node << '"' << " -> " << "or" << nrOrNodes;
|
||||
os << " [label=\"" << node->explanation() << "\"]" << endl;
|
||||
os << "or" << nrOrNodes << " -> " ;
|
||||
os << escapeNode (*casted->leftBranch()) << endl;
|
||||
os << "or" << nrOrNodes << " -> " ;
|
||||
os << escapeNode (*casted->rightBranch()) << endl;
|
||||
nrOrNodes ++;
|
||||
exportToGraphViz (*casted->leftBranch(), os);
|
||||
exportToGraphViz (*casted->rightBranch(), os);
|
||||
break;
|
||||
}
|
||||
os << "\"]" ;
|
||||
os << endl;
|
||||
|
||||
case AND_NODE: {
|
||||
AndNode* casted = dynamic_cast<AndNode*>(node);
|
||||
const Clauses& clauses = node->clauses();
|
||||
os << escapeNode (node) << " [shape=box,label=\"" ;
|
||||
for (size_t i = 0; i < clauses.size(); i++) {
|
||||
if (i != 0) os << "\\n" ;
|
||||
os << clauses[i];
|
||||
}
|
||||
os << "\"]" ;
|
||||
os << endl;
|
||||
os << "and" << nrAndNodes << " [label=\"∧\"]" << endl;
|
||||
os << '"' << node << '"' << " -> " << "and" << nrAndNodes;
|
||||
os << " [label=\"" << node->explanation() << "\"]" << endl;
|
||||
os << "and" << nrAndNodes << " -> " ;
|
||||
os << escapeNode (*casted->leftBranch()) << endl;
|
||||
os << "and" << nrAndNodes << " -> " ;
|
||||
os << escapeNode (*casted->rightBranch()) << endl;
|
||||
nrAndNodes ++;
|
||||
exportToGraphViz (*casted->leftBranch(), os);
|
||||
exportToGraphViz (*casted->rightBranch(), os);
|
||||
break;
|
||||
}
|
||||
os << "or" << nrOrNodes << " [label=\"∨\"]" << endl;
|
||||
os << '"' << node << '"' << " -> " << "or" << nrOrNodes;
|
||||
os << " [label=\"" << node->explanation() << "\"]" << endl;
|
||||
os << "or" << nrOrNodes << " -> " ;
|
||||
os << escapeNode (*casted->leftBranch()) << endl;
|
||||
os << "or" << nrOrNodes << " -> " ;
|
||||
os << escapeNode (*casted->rightBranch()) << endl;
|
||||
nrOrNodes ++;
|
||||
exportToGraphViz (*casted->leftBranch(), os);
|
||||
exportToGraphViz (*casted->rightBranch(), os);
|
||||
} else if (dynamic_cast<AndNode*>(node) != 0) {
|
||||
AndNode* casted = dynamic_cast<AndNode*>(node);
|
||||
const Clauses& clauses = node->clauses();
|
||||
os << escapeNode (node) << " [shape=box,label=\"" ;
|
||||
for (size_t i = 0; i < clauses.size(); i++) {
|
||||
if (i != 0) os << "\\n" ;
|
||||
os << clauses[i];
|
||||
|
||||
case LEAF_NODE: {
|
||||
os << escapeNode (node);
|
||||
os << " [shape=box,label=\"" ;
|
||||
os << node->clauses()[0];
|
||||
os << "\"]" ;
|
||||
os << endl;
|
||||
break;
|
||||
}
|
||||
os << "\"]" ;
|
||||
os << endl;
|
||||
os << "and" << nrAndNodes << " [label=\"∧\"]" << endl;
|
||||
os << '"' << node << '"' << " -> " << "and" << nrAndNodes;
|
||||
os << " [label=\"" << node->explanation() << "\"]" << endl;
|
||||
os << "and" << nrAndNodes << " -> " ;
|
||||
os << escapeNode (*casted->leftBranch()) << endl;
|
||||
os << "and" << nrAndNodes << " -> " ;
|
||||
os << escapeNode (*casted->rightBranch()) << endl;
|
||||
nrAndNodes ++;
|
||||
exportToGraphViz (*casted->leftBranch(), os);
|
||||
exportToGraphViz (*casted->rightBranch(), os);
|
||||
} else if (dynamic_cast<LeafNode*>(node) != 0) {
|
||||
os << escapeNode (node);
|
||||
os << " [shape=box,label=\"" ;
|
||||
os << node->clauses()[0];
|
||||
os << "\"]" ;
|
||||
os << endl;
|
||||
} else if (dynamic_cast<SmoothNode*>(node) != 0) {
|
||||
os << escapeNode (node);
|
||||
os << " [shape=box,style=filled,fillcolor=chartreuse,label=\"" ;
|
||||
const Clauses& clauses = node->clauses();
|
||||
for (size_t i = 0; i < clauses.size(); i++) {
|
||||
if (i != 0) os << "\\n" ;
|
||||
os << clauses[i];
|
||||
|
||||
case SMOOTH_NODE: {
|
||||
os << escapeNode (node);
|
||||
os << " [shape=box,style=filled,fillcolor=chartreuse,label=\"" ;
|
||||
const Clauses& clauses = node->clauses();
|
||||
for (size_t i = 0; i < clauses.size(); i++) {
|
||||
if (i != 0) os << "\\n" ;
|
||||
os << clauses[i];
|
||||
}
|
||||
os << "\"]" ;
|
||||
os << endl;
|
||||
break;
|
||||
}
|
||||
os << "\"]" ;
|
||||
os << endl;
|
||||
} else if (dynamic_cast<TrueNode*>(node) != 0) {
|
||||
os << escapeNode (node);
|
||||
os << " [shape=box,label=\"⊤\"]" ;
|
||||
os << endl;
|
||||
} else if (dynamic_cast<FailNode*>(node) != 0) {
|
||||
os << escapeNode (node);
|
||||
os << " [shape=box,style=filled,fillcolor=brown1,label=\"" ;
|
||||
const Clauses& clauses = node->clauses();
|
||||
for (size_t i = 0; i < clauses.size(); i++) {
|
||||
if (i != 0) os << "\\n" ;
|
||||
os << clauses[i];
|
||||
|
||||
case TRUE_NODE: {
|
||||
os << escapeNode (node);
|
||||
os << " [shape=box,label=\"⊤\"]" ;
|
||||
os << endl;
|
||||
break;
|
||||
}
|
||||
os << "\"]" ;
|
||||
os << endl;
|
||||
} else {
|
||||
assert (false);
|
||||
}
|
||||
|
||||
case FAIL_NODE: {
|
||||
os << escapeNode (node);
|
||||
os << " [shape=box,style=filled,fillcolor=brown1,label=\"" ;
|
||||
const Clauses& clauses = node->clauses();
|
||||
for (size_t i = 0; i < clauses.size(); i++) {
|
||||
if (i != 0) os << "\\n" ;
|
||||
os << clauses[i];
|
||||
}
|
||||
os << "\"]" ;
|
||||
os << endl;
|
||||
break;
|
||||
}
|
||||
|
||||
default:
|
||||
assert (false);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -24,7 +24,9 @@ class CircuitNode
|
||||
CircuitNode (const Clauses& clauses, string explanation = "")
|
||||
: clauses_(clauses), explanation_(explanation) { }
|
||||
|
||||
const Clauses& clauses (void) { return clauses_; }
|
||||
const Clauses& clauses (void) const { return clauses_; }
|
||||
|
||||
Clauses clauses (void) { return clauses_; }
|
||||
|
||||
virtual double weight (void) const { return 0; }
|
||||
|
||||
@ -43,6 +45,8 @@ class OrNode : public CircuitNode
|
||||
OrNode (const Clauses& clauses, string explanation = "")
|
||||
: CircuitNode (clauses, explanation),
|
||||
leftBranch_(0), rightBranch_(0) { }
|
||||
|
||||
double weight (void) const;
|
||||
|
||||
CircuitNode** leftBranch (void) { return &leftBranch_; }
|
||||
CircuitNode** rightBranch (void) { return &rightBranch_; }
|
||||
@ -74,6 +78,8 @@ class AndNode : public CircuitNode
|
||||
string explanation = "")
|
||||
: CircuitNode ({}, explanation),
|
||||
leftBranch_(leftBranch), rightBranch_(rightBranch) { }
|
||||
|
||||
double weight (void) const;
|
||||
|
||||
CircuitNode** leftBranch (void) { return &leftBranch_; }
|
||||
CircuitNode** rightBranch (void) { return &rightBranch_; }
|
||||
@ -117,6 +123,8 @@ class LeafNode : public CircuitNode
|
||||
{
|
||||
public:
|
||||
LeafNode (const Clause& clause) : CircuitNode ({clause}) { }
|
||||
|
||||
double weight (void) const;
|
||||
};
|
||||
|
||||
|
||||
@ -125,6 +133,8 @@ class SmoothNode : public CircuitNode
|
||||
{
|
||||
public:
|
||||
SmoothNode (const Clauses& clauses) : CircuitNode (clauses) { }
|
||||
|
||||
double weight (void) const;
|
||||
};
|
||||
|
||||
|
||||
@ -133,6 +143,8 @@ class TrueNode : public CircuitNode
|
||||
{
|
||||
public:
|
||||
TrueNode () : CircuitNode ({}) { }
|
||||
|
||||
double weight (void) const;
|
||||
};
|
||||
|
||||
|
||||
@ -153,8 +165,10 @@ class LiftedCircuit
|
||||
|
||||
void smoothCircuit (void);
|
||||
|
||||
double getWeightedModelCount (void) const;
|
||||
|
||||
void exportToGraphViz (const char*);
|
||||
|
||||
|
||||
private:
|
||||
|
||||
void compile (CircuitNode** follow, const Clauses& clauses);
|
||||
|
@ -150,8 +150,15 @@ LiftedWCNF::LiftedWCNF (const ParfactorList& pfList)
|
||||
{
|
||||
addIndicatorClauses (pfList);
|
||||
addParameterClauses (pfList);
|
||||
cout << "FORMULA INDICATORS:" << endl;
|
||||
printFormulaIndicators();
|
||||
cout << endl;
|
||||
cout << "WEIGHTS:" << endl;
|
||||
printWeights();
|
||||
cout << endl;
|
||||
cout << "CLAUSES:" << endl;
|
||||
printClauses();
|
||||
cout << endl;
|
||||
}
|
||||
|
||||
|
||||
@ -237,6 +244,8 @@ LiftedWCNF::addParameterClauses (const ParfactorList& pfList)
|
||||
while (indexer.valid()) {
|
||||
LiteralId paramVarLid = freeLiteralId_;
|
||||
|
||||
double weight = (**it)[indexer];
|
||||
|
||||
Clause clause1 ((*it)->constr());
|
||||
|
||||
for (unsigned i = 0; i < groups.size(); i++) {
|
||||
@ -245,11 +254,11 @@ LiftedWCNF::addParameterClauses (const ParfactorList& pfList)
|
||||
clause1.addAndNegateLiteral (Literal (lid, (*it)->argument(i).logVars()));
|
||||
|
||||
Clause tempClause ((*it)->constr());
|
||||
tempClause.addAndNegateLiteral (Literal (paramVarLid, LogVars(), 1.0));
|
||||
tempClause.addAndNegateLiteral (Literal (paramVarLid, 1.0));
|
||||
tempClause.addLiteral (Literal (lid, (*it)->argument(i).logVars()));
|
||||
clauses_.push_back (tempClause);
|
||||
}
|
||||
clause1.addLiteral (Literal (paramVarLid, LogVars(), 1.0));
|
||||
clause1.addLiteral (Literal (paramVarLid, weight));
|
||||
clauses_.push_back (clause1);
|
||||
freeLiteralId_ ++;
|
||||
++ indexer;
|
||||
@ -259,16 +268,6 @@ LiftedWCNF::addParameterClauses (const ParfactorList& pfList)
|
||||
}
|
||||
|
||||
|
||||
void
|
||||
LiftedWCNF::printClauses (void) const
|
||||
{
|
||||
for (unsigned i = 0; i < clauses_.size(); i++) {
|
||||
cout << clauses_[i] << endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
LiftedWCNF::printFormulaIndicators (void) const
|
||||
{
|
||||
@ -293,3 +292,55 @@ LiftedWCNF::printFormulaIndicators (void) const
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
LiftedWCNF::printWeights (void) const
|
||||
{
|
||||
for (LiteralId i = 0; i < freeLiteralId_; i++) {
|
||||
|
||||
bool found = false;
|
||||
for (size_t j = 0; j < clauses_.size(); j++) {
|
||||
Literals literals = clauses_[j].literals();
|
||||
for (size_t k = 0; k < literals.size(); k++) {
|
||||
if (literals[k].lid() == i && literals[k].isPositive()) {
|
||||
cout << "weight(" << literals[k] << ") = " << literals[k].weight();
|
||||
cout << endl;
|
||||
found = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (found == true) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
found = false;
|
||||
for (size_t j = 0; j < clauses_.size(); j++) {
|
||||
Literals literals = clauses_[j].literals();
|
||||
for (size_t k = 0; k < literals.size(); k++) {
|
||||
if (literals[k].lid() == i && literals[k].isNegative()) {
|
||||
cout << "weight(" << literals[k] << ") = " << literals[k].weight();
|
||||
cout << endl;
|
||||
found = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (found == true) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
LiftedWCNF::printClauses (void) const
|
||||
{
|
||||
for (unsigned i = 0; i < clauses_.size(); i++) {
|
||||
cout << clauses_[i] << endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -13,8 +13,11 @@ class ConstraintTree;
|
||||
class Literal
|
||||
{
|
||||
public:
|
||||
Literal (LiteralId lid, double w = -1.0) :
|
||||
lid_(lid), weight_(w), negated_(false) { }
|
||||
|
||||
Literal (LiteralId lid, const LogVars& lvs, double w = -1.0) :
|
||||
lid_(lid), logVars_(lvs), weight_(w), negated_(false) { }
|
||||
lid_(lid), logVars_(lvs), weight_(w), negated_(false) { }
|
||||
|
||||
Literal (const Literal& lit, bool negated) :
|
||||
lid_(lit.lid_), logVars_(lit.logVars_), weight_(lit.weight_), negated_(negated) { }
|
||||
@ -23,7 +26,8 @@ class Literal
|
||||
|
||||
LogVars logVars (void) const { return logVars_; }
|
||||
|
||||
double weight (void) const { return weight_; }
|
||||
// FIXME not log aware
|
||||
double weight (void) const { return weight_ < 0.0 ? 1.0 : weight_; }
|
||||
|
||||
void negate (void) { negated_ = !negated_; }
|
||||
|
||||
@ -106,9 +110,11 @@ class LiftedWCNF
|
||||
|
||||
Clause createClauseForLiteral (LiteralId lid) const;
|
||||
|
||||
void printClauses (void) const;
|
||||
|
||||
void printFormulaIndicators (void) const;
|
||||
|
||||
void printWeights (void) const;
|
||||
|
||||
void printClauses (void) const;
|
||||
|
||||
private:
|
||||
void addIndicatorClauses (const ParfactorList& pfList);
|
||||
|
Reference in New Issue
Block a user