initial support for weighted model countign
This commit is contained in:
@@ -3,22 +3,92 @@
|
||||
#include "LiftedCircuit.h"
|
||||
|
||||
|
||||
double
|
||||
OrNode::weight (void) const
|
||||
{
|
||||
double lw = leftBranch_->weight();
|
||||
double rw = rightBranch_->weight();
|
||||
return Globals::logDomain ? Util::logSum (lw, rw) : lw + rw;
|
||||
}
|
||||
|
||||
|
||||
|
||||
double
|
||||
AndNode::weight (void) const
|
||||
{
|
||||
double lw = leftBranch_->weight();
|
||||
double rw = rightBranch_->weight();
|
||||
return Globals::logDomain ? lw + rw : lw * rw;
|
||||
}
|
||||
|
||||
|
||||
|
||||
double
|
||||
LeafNode::weight (void) const
|
||||
{
|
||||
assert (clauses().size() == 1);
|
||||
assert (clauses()[0].isUnit());
|
||||
Clause c = clauses()[0];
|
||||
double weight = c.literals()[0].weight();
|
||||
unsigned nrGroundings = c.constr()->size();
|
||||
assert (nrGroundings != 0);
|
||||
double www = Globals::logDomain
|
||||
? weight * nrGroundings
|
||||
: std::pow (weight, nrGroundings);
|
||||
|
||||
cout << "leaf w: " << www << endl;
|
||||
|
||||
return Globals::logDomain
|
||||
? weight * nrGroundings
|
||||
: std::pow (weight, nrGroundings);
|
||||
}
|
||||
|
||||
|
||||
|
||||
double
|
||||
SmoothNode::weight (void) const
|
||||
{
|
||||
Clauses cs = clauses();
|
||||
double totalWeight = LogAware::multIdenty();
|
||||
for (size_t i = 0; i < cs.size(); i++) {
|
||||
double posWeight = cs[i].literals()[0].weight();
|
||||
double negWeight = cs[i].literals()[1].weight();
|
||||
unsigned nrGroundings = cs[i].constr()->size();
|
||||
if (Globals::logDomain) {
|
||||
totalWeight += (Util::logSum (posWeight, negWeight) * nrGroundings);
|
||||
} else {
|
||||
totalWeight *= std::pow (posWeight + negWeight, nrGroundings);
|
||||
}
|
||||
}
|
||||
return totalWeight;
|
||||
}
|
||||
|
||||
|
||||
|
||||
double
|
||||
TrueNode::weight (void) const
|
||||
{
|
||||
return LogAware::multIdenty();
|
||||
}
|
||||
|
||||
|
||||
|
||||
LiftedCircuit::LiftedCircuit (const LiftedWCNF* lwcnf)
|
||||
: lwcnf_(lwcnf)
|
||||
{
|
||||
root_ = 0;
|
||||
Clauses ccc = lwcnf->clauses();
|
||||
ccc.erase (ccc.begin() + 5, ccc.end());
|
||||
//ccc.erase (ccc.begin() + 5, ccc.end());
|
||||
//Clause c2 = ccc.front();
|
||||
//c2.removeLiteralByIndex (1);
|
||||
//ccc.push_back (c2);
|
||||
|
||||
//compile (&root_, lwcnf->clauses());
|
||||
compile (&root_, ccc);
|
||||
cout << "done compiling..." << endl;
|
||||
exportToGraphViz("circuit.dot");
|
||||
smoothCircuit();
|
||||
exportToGraphViz("smooth.dot");
|
||||
cout << "WEIGHTED MODEL COUNT = " << getWeightedModelCount() << endl;
|
||||
}
|
||||
|
||||
|
||||
@@ -31,6 +101,14 @@ LiftedCircuit::smoothCircuit (void)
|
||||
|
||||
|
||||
|
||||
double
|
||||
LiftedCircuit::getWeightedModelCount (void) const
|
||||
{
|
||||
return root_->weight();
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
LiftedCircuit::exportToGraphViz (const char* fileName)
|
||||
{
|
||||
@@ -63,14 +141,14 @@ LiftedCircuit::compile (
|
||||
static int count = 0; count ++;
|
||||
*follow = new LeafNode (clauses[0]);
|
||||
if (count == 1) {
|
||||
Clause c (new ConstraintTree({}));
|
||||
c.addLiteral (Literal (100,{}));
|
||||
*follow = new LeafNode (c);
|
||||
// Clause c (new ConstraintTree({}));
|
||||
// c.addLiteral (Literal (100,{}));
|
||||
// *follow = new LeafNode (c);
|
||||
}
|
||||
if (count == 2) {
|
||||
Clause c (new ConstraintTree({}));
|
||||
c.addLiteral (Literal (101,{}));
|
||||
*follow = new LeafNode (c);
|
||||
// Clause c (new ConstraintTree({}));
|
||||
// c.addLiteral (Literal (101,{}));
|
||||
// *follow = new LeafNode (c);
|
||||
}
|
||||
return;
|
||||
}
|
||||
@@ -326,83 +404,104 @@ void
|
||||
LiftedCircuit::exportToGraphViz (CircuitNode* node, ofstream& os)
|
||||
{
|
||||
assert (node != 0);
|
||||
|
||||
static unsigned nrAndNodes = 0;
|
||||
static unsigned nrOrNodes = 0;
|
||||
|
||||
if (dynamic_cast<OrNode*>(node) != 0) {
|
||||
OrNode* casted = dynamic_cast<OrNode*>(node);
|
||||
const Clauses& clauses = node->clauses();
|
||||
if (clauses.empty() == false) {
|
||||
os << escapeNode (node) << " [shape=box,label=\"" ;
|
||||
for (size_t i = 0; i < clauses.size(); i++) {
|
||||
if (i != 0) os << "\\n" ;
|
||||
os << clauses[i];
|
||||
switch (getCircuitNodeType (node)) {
|
||||
|
||||
case OR_NODE: {
|
||||
OrNode* casted = dynamic_cast<OrNode*>(node);
|
||||
const Clauses& clauses = node->clauses();
|
||||
if (clauses.empty() == false) {
|
||||
os << escapeNode (node) << " [shape=box,label=\"" ;
|
||||
for (size_t i = 0; i < clauses.size(); i++) {
|
||||
if (i != 0) os << "\\n" ;
|
||||
os << clauses[i];
|
||||
}
|
||||
os << "\"]" ;
|
||||
os << endl;
|
||||
}
|
||||
os << "or" << nrOrNodes << " [label=\"∨\"]" << endl;
|
||||
os << '"' << node << '"' << " -> " << "or" << nrOrNodes;
|
||||
os << " [label=\"" << node->explanation() << "\"]" << endl;
|
||||
os << "or" << nrOrNodes << " -> " ;
|
||||
os << escapeNode (*casted->leftBranch()) << endl;
|
||||
os << "or" << nrOrNodes << " -> " ;
|
||||
os << escapeNode (*casted->rightBranch()) << endl;
|
||||
nrOrNodes ++;
|
||||
exportToGraphViz (*casted->leftBranch(), os);
|
||||
exportToGraphViz (*casted->rightBranch(), os);
|
||||
break;
|
||||
}
|
||||
os << "\"]" ;
|
||||
os << endl;
|
||||
|
||||
case AND_NODE: {
|
||||
AndNode* casted = dynamic_cast<AndNode*>(node);
|
||||
const Clauses& clauses = node->clauses();
|
||||
os << escapeNode (node) << " [shape=box,label=\"" ;
|
||||
for (size_t i = 0; i < clauses.size(); i++) {
|
||||
if (i != 0) os << "\\n" ;
|
||||
os << clauses[i];
|
||||
}
|
||||
os << "\"]" ;
|
||||
os << endl;
|
||||
os << "and" << nrAndNodes << " [label=\"∧\"]" << endl;
|
||||
os << '"' << node << '"' << " -> " << "and" << nrAndNodes;
|
||||
os << " [label=\"" << node->explanation() << "\"]" << endl;
|
||||
os << "and" << nrAndNodes << " -> " ;
|
||||
os << escapeNode (*casted->leftBranch()) << endl;
|
||||
os << "and" << nrAndNodes << " -> " ;
|
||||
os << escapeNode (*casted->rightBranch()) << endl;
|
||||
nrAndNodes ++;
|
||||
exportToGraphViz (*casted->leftBranch(), os);
|
||||
exportToGraphViz (*casted->rightBranch(), os);
|
||||
break;
|
||||
}
|
||||
os << "or" << nrOrNodes << " [label=\"∨\"]" << endl;
|
||||
os << '"' << node << '"' << " -> " << "or" << nrOrNodes;
|
||||
os << " [label=\"" << node->explanation() << "\"]" << endl;
|
||||
os << "or" << nrOrNodes << " -> " ;
|
||||
os << escapeNode (*casted->leftBranch()) << endl;
|
||||
os << "or" << nrOrNodes << " -> " ;
|
||||
os << escapeNode (*casted->rightBranch()) << endl;
|
||||
nrOrNodes ++;
|
||||
exportToGraphViz (*casted->leftBranch(), os);
|
||||
exportToGraphViz (*casted->rightBranch(), os);
|
||||
} else if (dynamic_cast<AndNode*>(node) != 0) {
|
||||
AndNode* casted = dynamic_cast<AndNode*>(node);
|
||||
const Clauses& clauses = node->clauses();
|
||||
os << escapeNode (node) << " [shape=box,label=\"" ;
|
||||
for (size_t i = 0; i < clauses.size(); i++) {
|
||||
if (i != 0) os << "\\n" ;
|
||||
os << clauses[i];
|
||||
|
||||
case LEAF_NODE: {
|
||||
os << escapeNode (node);
|
||||
os << " [shape=box,label=\"" ;
|
||||
os << node->clauses()[0];
|
||||
os << "\"]" ;
|
||||
os << endl;
|
||||
break;
|
||||
}
|
||||
os << "\"]" ;
|
||||
os << endl;
|
||||
os << "and" << nrAndNodes << " [label=\"∧\"]" << endl;
|
||||
os << '"' << node << '"' << " -> " << "and" << nrAndNodes;
|
||||
os << " [label=\"" << node->explanation() << "\"]" << endl;
|
||||
os << "and" << nrAndNodes << " -> " ;
|
||||
os << escapeNode (*casted->leftBranch()) << endl;
|
||||
os << "and" << nrAndNodes << " -> " ;
|
||||
os << escapeNode (*casted->rightBranch()) << endl;
|
||||
nrAndNodes ++;
|
||||
exportToGraphViz (*casted->leftBranch(), os);
|
||||
exportToGraphViz (*casted->rightBranch(), os);
|
||||
} else if (dynamic_cast<LeafNode*>(node) != 0) {
|
||||
os << escapeNode (node);
|
||||
os << " [shape=box,label=\"" ;
|
||||
os << node->clauses()[0];
|
||||
os << "\"]" ;
|
||||
os << endl;
|
||||
} else if (dynamic_cast<SmoothNode*>(node) != 0) {
|
||||
os << escapeNode (node);
|
||||
os << " [shape=box,style=filled,fillcolor=chartreuse,label=\"" ;
|
||||
const Clauses& clauses = node->clauses();
|
||||
for (size_t i = 0; i < clauses.size(); i++) {
|
||||
if (i != 0) os << "\\n" ;
|
||||
os << clauses[i];
|
||||
|
||||
case SMOOTH_NODE: {
|
||||
os << escapeNode (node);
|
||||
os << " [shape=box,style=filled,fillcolor=chartreuse,label=\"" ;
|
||||
const Clauses& clauses = node->clauses();
|
||||
for (size_t i = 0; i < clauses.size(); i++) {
|
||||
if (i != 0) os << "\\n" ;
|
||||
os << clauses[i];
|
||||
}
|
||||
os << "\"]" ;
|
||||
os << endl;
|
||||
break;
|
||||
}
|
||||
os << "\"]" ;
|
||||
os << endl;
|
||||
} else if (dynamic_cast<TrueNode*>(node) != 0) {
|
||||
os << escapeNode (node);
|
||||
os << " [shape=box,label=\"⊤\"]" ;
|
||||
os << endl;
|
||||
} else if (dynamic_cast<FailNode*>(node) != 0) {
|
||||
os << escapeNode (node);
|
||||
os << " [shape=box,style=filled,fillcolor=brown1,label=\"" ;
|
||||
const Clauses& clauses = node->clauses();
|
||||
for (size_t i = 0; i < clauses.size(); i++) {
|
||||
if (i != 0) os << "\\n" ;
|
||||
os << clauses[i];
|
||||
|
||||
case TRUE_NODE: {
|
||||
os << escapeNode (node);
|
||||
os << " [shape=box,label=\"⊤\"]" ;
|
||||
os << endl;
|
||||
break;
|
||||
}
|
||||
os << "\"]" ;
|
||||
os << endl;
|
||||
} else {
|
||||
assert (false);
|
||||
}
|
||||
|
||||
case FAIL_NODE: {
|
||||
os << escapeNode (node);
|
||||
os << " [shape=box,style=filled,fillcolor=brown1,label=\"" ;
|
||||
const Clauses& clauses = node->clauses();
|
||||
for (size_t i = 0; i < clauses.size(); i++) {
|
||||
if (i != 0) os << "\\n" ;
|
||||
os << clauses[i];
|
||||
}
|
||||
os << "\"]" ;
|
||||
os << endl;
|
||||
break;
|
||||
}
|
||||
|
||||
default:
|
||||
assert (false);
|
||||
}
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user