first stab for atom counting
This commit is contained in:
@@ -115,7 +115,7 @@ CompilationFailedNode::weight (void) const
|
||||
{
|
||||
// we should not perform model counting
|
||||
// in compilation failed nodes
|
||||
abort();
|
||||
// abort();
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
@@ -125,19 +125,11 @@ LiftedCircuit::LiftedCircuit (const LiftedWCNF* lwcnf)
|
||||
: lwcnf_(lwcnf)
|
||||
{
|
||||
root_ = 0;
|
||||
Clauses ccc = lwcnf->clauses();
|
||||
//ccc.erase (ccc.begin() + 5, ccc.end());
|
||||
//Clause c2 = ccc.front();
|
||||
//c2.removeLiteralByIndex (1);
|
||||
//ccc.push_back (c2);
|
||||
|
||||
//compile (&root_, lwcnf->clauses());
|
||||
Clauses cccc = {ccc[6],ccc[4]};
|
||||
cccc.front().removeLiteral (2);
|
||||
compile (&root_, cccc);
|
||||
Clauses clauses = lwcnf->clauses();
|
||||
compile (&root_, clauses);
|
||||
exportToGraphViz("circuit.dot");
|
||||
smoothCircuit();
|
||||
exportToGraphViz("smooth.dot");
|
||||
exportToGraphViz("circuit.smooth.dot");
|
||||
cout << "WEIGHTED MODEL COUNT = " << getWeightedModelCount() << endl;
|
||||
}
|
||||
|
||||
@@ -188,18 +180,7 @@ LiftedCircuit::compile (
|
||||
}
|
||||
|
||||
if (clauses.size() == 1 && clauses[0].isUnit()) {
|
||||
static int count = 0; count ++;
|
||||
*follow = new LeafNode (clauses[0]);
|
||||
if (count == 1) {
|
||||
// Clause c (new ConstraintTree({}));
|
||||
// c.addLiteral (Literal (100,{}));
|
||||
// *follow = new LeafNode (c);
|
||||
}
|
||||
if (count == 2) {
|
||||
// Clause c (new ConstraintTree({}));
|
||||
// c.addLiteral (Literal (101,{}));
|
||||
// *follow = new LeafNode (c);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -219,7 +200,11 @@ LiftedCircuit::compile (
|
||||
return;
|
||||
}
|
||||
|
||||
if (tryIndepPartialGrounding (follow, clauses)) {
|
||||
//if (tryIndepPartialGrounding (follow, clauses)) {
|
||||
// return;
|
||||
//}
|
||||
|
||||
if (tryAtomCounting (follow, clauses)) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -229,7 +214,81 @@ LiftedCircuit::compile (
|
||||
|
||||
// assert (false);
|
||||
*follow = new CompilationFailedNode (clauses);
|
||||
}
|
||||
|
||||
|
||||
void
|
||||
LiftedCircuit::propagate (
|
||||
const Clause& c,
|
||||
const Clause& unitClause,
|
||||
Clauses& newClauses)
|
||||
{
|
||||
/*
|
||||
Literals literals = c.literals();
|
||||
for (size_t i = 0; i < literals.size(); i++) {
|
||||
if (literals_[i].lid() == lid && literals[i].isPositive()) {
|
||||
|
||||
return true;
|
||||
}
|
||||
}
|
||||
*/
|
||||
}
|
||||
|
||||
|
||||
bool
|
||||
shatterCountedLogVars (Clauses& clauses, size_t idx1, size_t idx2)
|
||||
{
|
||||
Literals lits1 = clauses[idx1].literals();
|
||||
Literals lits2 = clauses[idx2].literals();
|
||||
for (size_t i = 0; i < lits1.size(); i++) {
|
||||
for (size_t j = 0; j < lits2.size(); j++) {
|
||||
if (lits1[i].lid() == lits2[j].lid()) {
|
||||
LogVars lvs1 = lits1[i].logVars();
|
||||
LogVars lvs2 = lits2[j].logVars();
|
||||
for (size_t k = 0; k < lvs1.size(); k++) {
|
||||
if (clauses[idx1].isCountedLogVar (lvs1[k])
|
||||
&& clauses[idx2].isCountedLogVar (lvs2[k]) == false) {
|
||||
clauses.push_back (clauses[idx2]);
|
||||
clauses[idx2].addPositiveCountedLogVar (lvs2[k]);
|
||||
clauses.back().addNegativeCountedLogVar (lvs2[k]);
|
||||
return true;
|
||||
}
|
||||
if (clauses[idx2].isCountedLogVar (lvs2[k])
|
||||
&& clauses[idx1].isCountedLogVar (lvs1[k]) == false) {
|
||||
clauses.push_back (clauses[idx1]);
|
||||
clauses[idx1].addPositiveCountedLogVar (lvs1[k]);
|
||||
clauses.back().addNegativeCountedLogVar (lvs1[k]);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
shatterCountedLogVarsAux (Clauses& clauses)
|
||||
{
|
||||
for (size_t i = 0; i < clauses.size() - 1; i++) {
|
||||
for (size_t j = i + 1; j < clauses.size(); j++) {
|
||||
bool splitedSome = shatterCountedLogVars (clauses, i, j);
|
||||
if (splitedSome) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
shatterCountedLogVars (Clauses& clauses)
|
||||
{
|
||||
while (shatterCountedLogVarsAux (clauses)) ;
|
||||
}
|
||||
|
||||
|
||||
@@ -239,22 +298,29 @@ LiftedCircuit::tryUnitPropagation (
|
||||
CircuitNode** follow,
|
||||
Clauses& clauses)
|
||||
{
|
||||
cout << "ALL CLAUSES:" << endl;
|
||||
Clause::printClauses (clauses);
|
||||
|
||||
for (size_t i = 0; i < clauses.size(); i++) {
|
||||
if (clauses[i].isUnit()) {
|
||||
cout << clauses[i] << " is unit!" << endl;
|
||||
Clauses newClauses;
|
||||
for (size_t j = 0; j < clauses.size(); j++) {
|
||||
if (i != j) {
|
||||
LiteralId lid = clauses[i].literals()[0].lid();
|
||||
LogVarTypes types = clauses[i].logVarTypes (0);
|
||||
if (clauses[i].literals()[0].isPositive()) {
|
||||
if (clauses[j].containsPositiveLiteral (lid) == false) {
|
||||
if (clauses[j].containsPositiveLiteral (lid, types) == false) {
|
||||
Clause newClause = clauses[j];
|
||||
newClause.removeNegativeLiterals (lid);
|
||||
cout << "removing negative literals on " << newClause << endl;
|
||||
newClause.removeNegativeLiterals (lid, types);
|
||||
newClauses.push_back (newClause);
|
||||
}
|
||||
} else if (clauses[i].literals()[0].isNegative()) {
|
||||
if (clauses[j].containsNegativeLiteral (lid) == false) {
|
||||
if (clauses[j].containsNegativeLiteral (lid, types) == false) {
|
||||
Clause newClause = clauses[j];
|
||||
newClause.removePositiveLiterals (lid);
|
||||
cout << "removing negative literals on " << newClause << endl;
|
||||
newClause.removePositiveLiterals (lid, types);
|
||||
newClauses.push_back (newClause);
|
||||
}
|
||||
}
|
||||
@@ -264,6 +330,7 @@ LiftedCircuit::tryUnitPropagation (
|
||||
explanation << " UP on" << clauses[i].literals()[0];
|
||||
AndNode* andNode = new AndNode (clauses, explanation.str());
|
||||
Clauses leftClauses = {clauses[i]};
|
||||
cout << "new clauses: " << newClauses << endl;
|
||||
compile (andNode->leftBranch(), leftClauses);
|
||||
compile (andNode->rightBranch(), newClauses);
|
||||
(*follow) = andNode;
|
||||
@@ -454,6 +521,37 @@ LiftedCircuit::tryIndepPartialGroundingAux (
|
||||
|
||||
|
||||
|
||||
bool
|
||||
LiftedCircuit::tryAtomCounting (
|
||||
CircuitNode** follow,
|
||||
Clauses& clauses)
|
||||
{
|
||||
for (size_t i = 0; i < clauses.size(); i++) {
|
||||
Literals literals = clauses[i].literals();
|
||||
for (size_t j = 0; j < literals.size(); j++) {
|
||||
if (literals[j].logVars().size() == 1) {
|
||||
// TODO check if not already in ipg and countedlvs
|
||||
SetOrNode* setOrNode = new SetOrNode (clauses);
|
||||
Clause c1 (clauses[i].constr().projectedCopy (literals[j].logVars()));
|
||||
Clause c2 (clauses[i].constr().projectedCopy (literals[j].logVars()));
|
||||
c1.addLiteral (literals[j]);
|
||||
c2.addAndNegateLiteral (literals[j]);
|
||||
c1.addPositiveCountedLogVar (literals[j].logVars().front());
|
||||
c2.addNegativeCountedLogVar (literals[j].logVars().front());
|
||||
clauses.push_back (c1);
|
||||
clauses.push_back (c2);
|
||||
shatterCountedLogVars (clauses);
|
||||
compile (setOrNode->follow(), clauses);
|
||||
*follow = setOrNode;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
LiftedCircuit::tryGrounding (
|
||||
CircuitNode**,
|
||||
@@ -638,11 +736,11 @@ LiftedCircuit::exportToGraphViz (CircuitNode* node, ofstream& os)
|
||||
exportToGraphViz (*casted->rightBranch(), os);
|
||||
break;
|
||||
}
|
||||
|
||||
|
||||
case AND_NODE: {
|
||||
AndNode* casted = dynamic_cast<AndNode*>(node);
|
||||
printClauses (casted, os);
|
||||
|
||||
|
||||
os << auxNode << " [label=\"∧\"]" << endl;
|
||||
os << escapeNode (node) << " -> " << auxNode;
|
||||
os << " [label=\"" << node->explanation() << "\"]" ;
|
||||
@@ -662,34 +760,47 @@ LiftedCircuit::exportToGraphViz (CircuitNode* node, ofstream& os)
|
||||
exportToGraphViz (*casted->rightBranch(), os);
|
||||
break;
|
||||
}
|
||||
|
||||
|
||||
case SET_OR_NODE: {
|
||||
// TODO
|
||||
assert (false);
|
||||
}
|
||||
|
||||
case SET_AND_NODE: {
|
||||
SetAndNode* casted = dynamic_cast<SetAndNode*>(node);
|
||||
SetOrNode* casted = dynamic_cast<SetOrNode*>(node);
|
||||
printClauses (casted, os);
|
||||
|
||||
os << auxNode << " [label=\"∧(X)\"]" << endl;
|
||||
|
||||
os << auxNode << " [label=\"∨(X)\"]" << endl;
|
||||
os << escapeNode (node) << " -> " << auxNode;
|
||||
os << " [label=\"" << node->explanation() << "\"]" ;
|
||||
os << endl;
|
||||
|
||||
|
||||
os << auxNode << " -> " ;
|
||||
os << escapeNode (*casted->follow());
|
||||
os << " [label=\" " << (*casted->follow())->weight() << "\"]" ;
|
||||
os << endl;
|
||||
|
||||
|
||||
exportToGraphViz (*casted->follow(), os);
|
||||
break;
|
||||
}
|
||||
|
||||
|
||||
case SET_AND_NODE: {
|
||||
SetAndNode* casted = dynamic_cast<SetAndNode*>(node);
|
||||
printClauses (casted, os);
|
||||
|
||||
os << auxNode << " [label=\"∧(X)\"]" << endl;
|
||||
os << escapeNode (node) << " -> " << auxNode;
|
||||
os << " [label=\"" << node->explanation() << "\"]" ;
|
||||
os << endl;
|
||||
|
||||
os << auxNode << " -> " ;
|
||||
os << escapeNode (*casted->follow());
|
||||
os << " [label=\" " << (*casted->follow())->weight() << "\"]" ;
|
||||
os << endl;
|
||||
|
||||
exportToGraphViz (*casted->follow(), os);
|
||||
break;
|
||||
}
|
||||
|
||||
case INC_EXC_NODE: {
|
||||
IncExcNode* casted = dynamic_cast<IncExcNode*>(node);
|
||||
printClauses (casted, os);
|
||||
|
||||
|
||||
os << auxNode << " [label=\"IncExc\"]" << endl;
|
||||
os << escapeNode (node) << " -> " << auxNode;
|
||||
os << " [label=\"" << node->explanation() << "\"]" ;
|
||||
@@ -699,7 +810,7 @@ LiftedCircuit::exportToGraphViz (CircuitNode* node, ofstream& os)
|
||||
os << escapeNode (*casted->plus1Branch());
|
||||
os << " [label=\" " << (*casted->plus1Branch())->weight() << "\"]" ;
|
||||
os << endl;
|
||||
|
||||
|
||||
os << auxNode << " -> " ;
|
||||
os << escapeNode (*casted->plus2Branch());
|
||||
os << " [label=\" " << (*casted->plus2Branch())->weight() << "\"]" ;
|
||||
@@ -709,35 +820,35 @@ LiftedCircuit::exportToGraphViz (CircuitNode* node, ofstream& os)
|
||||
os << escapeNode (*casted->minusBranch()) << endl;
|
||||
os << " [label=\" " << (*casted->minusBranch())->weight() << "\"]" ;
|
||||
os << endl;
|
||||
|
||||
|
||||
exportToGraphViz (*casted->plus1Branch(), os);
|
||||
exportToGraphViz (*casted->plus2Branch(), os);
|
||||
exportToGraphViz (*casted->minusBranch(), os);
|
||||
break;
|
||||
}
|
||||
|
||||
|
||||
case LEAF_NODE: {
|
||||
printClauses (node, os, "style=filled,fillcolor=palegreen,");
|
||||
break;
|
||||
}
|
||||
|
||||
|
||||
case SMOOTH_NODE: {
|
||||
printClauses (node, os, "style=filled,fillcolor=lightblue,");
|
||||
break;
|
||||
}
|
||||
|
||||
|
||||
case TRUE_NODE: {
|
||||
os << escapeNode (node);
|
||||
os << " [shape=box,label=\"⊤\"]" ;
|
||||
os << endl;
|
||||
break;
|
||||
}
|
||||
|
||||
|
||||
case COMPILATION_FAILED_NODE: {
|
||||
printClauses (node, os, "style=filled,fillcolor=salmon,");
|
||||
break;
|
||||
}
|
||||
|
||||
|
||||
default:
|
||||
assert (false);
|
||||
}
|
||||
|
Reference in New Issue
Block a user