support smoothing on atom counting nodes (beta)

This commit is contained in:
Tiago Gomes 2012-11-04 18:02:40 +00:00
parent b8cef8798a
commit 4518a3db5d
4 changed files with 181 additions and 32 deletions

View File

@ -23,6 +23,10 @@ AndNode::weight (void) const
stack<pair<unsigned, unsigned>> SetOrNode::nrGrsStack;
double double
SetOrNode::weight (void) const SetOrNode::weight (void) const
{ {
@ -583,82 +587,195 @@ LiftedCircuit::shatterCountedLogVarsAux (
TinySet<LiteralId> LogVarTypes
unionTypes (const LogVarTypes& types1, const LogVarTypes& types2)
{
if (types1.empty()) {
return types2;
}
if (types2.empty()) {
return types1;
}
assert (types1.size() == types2.size());
LogVarTypes res;
for (size_t i = 0; i < types1.size(); i++) {
if (types1[i] == LogVarType::POS_LV
&& types2[i] == LogVarType::POS_LV) {
res.push_back (LogVarType::POS_LV);
} else if (types1[i] == LogVarType::NEG_LV
&& types2[i] == LogVarType::NEG_LV) {
res.push_back (LogVarType::NEG_LV);
} else {
res.push_back (LogVarType::FULL_LV);
}
}
return res;
}
vector<LogVarTypes>
getAllPossibleTypes (unsigned nrLogVars)
{
if (nrLogVars == 0) {
return {};
}
if (nrLogVars == 1) {
return {{LogVarType::POS_LV},{LogVarType::NEG_LV}};
}
vector<LogVarTypes> res;
Indexer indexer (vector<unsigned> (nrLogVars, 2));
while (indexer.valid()) {
LogVarTypes types;
for (size_t i = 0; i < nrLogVars; i++) {
if (indexer[i] == 0) {
types.push_back (LogVarType::POS_LV);
} else {
types.push_back (LogVarType::NEG_LV);
}
}
res.push_back (types);
++ indexer;
}
return res;
}
bool
containsTypes (const LogVarTypes& typesA, const LogVarTypes& typesB)
{
for (size_t i = 0; i < typesA.size(); i++) {
if (typesA[i] == LogVarType::FULL_LV) {
} else if (typesA[i] == LogVarType::POS_LV
&& typesB[i] == LogVarType::POS_LV) {
} else if (typesA[i] == LogVarType::NEG_LV
&& typesB[i] == LogVarType::NEG_LV) {
} else {
return false;
}
}
return true;
}
LitLvTypesSet
LiftedCircuit::smoothCircuit (CircuitNode* node) LiftedCircuit::smoothCircuit (CircuitNode* node)
{ {
assert (node != 0); assert (node != 0);
TinySet<LiteralId> propagatingLids; LitLvTypesSet propagLits;
switch (getCircuitNodeType (node)) { switch (getCircuitNodeType (node)) {
case CircuitNodeType::OR_NODE: { case CircuitNodeType::OR_NODE: {
OrNode* casted = dynamic_cast<OrNode*>(node); OrNode* casted = dynamic_cast<OrNode*>(node);
TinySet<LiteralId> lids1 = smoothCircuit (*casted->leftBranch()); LitLvTypesSet lids1 = smoothCircuit (*casted->leftBranch());
TinySet<LiteralId> lids2 = smoothCircuit (*casted->rightBranch()); LitLvTypesSet lids2 = smoothCircuit (*casted->rightBranch());
TinySet<LiteralId> missingLeft = lids2 - lids1; LitLvTypesSet missingLeft = lids2 - lids1;
TinySet<LiteralId> missingRight = lids1 - lids2; LitLvTypesSet missingRight = lids1 - lids2;
createSmoothNode (missingLeft, casted->leftBranch()); createSmoothNode (missingLeft, casted->leftBranch());
createSmoothNode (missingRight, casted->rightBranch()); createSmoothNode (missingRight, casted->rightBranch());
propagatingLids |= lids1; propagLits |= lids1;
propagatingLids |= lids2; propagLits |= lids2;
break; break;
} }
case CircuitNodeType::AND_NODE: { case CircuitNodeType::AND_NODE: {
AndNode* casted = dynamic_cast<AndNode*>(node); AndNode* casted = dynamic_cast<AndNode*>(node);
TinySet<LiteralId> lids1 = smoothCircuit (*casted->leftBranch()); LitLvTypesSet lids1 = smoothCircuit (*casted->leftBranch());
TinySet<LiteralId> lids2 = smoothCircuit (*casted->rightBranch()); LitLvTypesSet lids2 = smoothCircuit (*casted->rightBranch());
propagatingLids |= lids1; propagLits |= lids1;
propagatingLids |= lids2; propagLits |= lids2;
break; break;
} }
case CircuitNodeType::SET_OR_NODE: { case CircuitNodeType::SET_OR_NODE: {
// TODO SetOrNode* casted = dynamic_cast<SetOrNode*>(node);
propagLits = smoothCircuit (*casted->follow());
TinySet<pair<LiteralId,unsigned>> litSet;
for (size_t i = 0; i < propagLits.size(); i++) {
litSet.insert (make_pair (propagLits[i].lid(),
propagLits[i].logVarTypes().size()));
}
LitLvTypesSet missingLids;
for (size_t i = 0; i < litSet.size(); i++) {
vector<LogVarTypes> allTypes = getAllPossibleTypes (litSet[i].second);
for (size_t j = 0; j < allTypes.size(); j++) {
bool typeFound = false;
for (size_t k = 0; k < propagLits.size(); k++) {
if (litSet[i].first == propagLits[k].lid()
&& containsTypes (allTypes[j], propagLits[k].logVarTypes())) {
typeFound = true;
break;
}
}
if (typeFound == false) {
missingLids.insert (LiteralLvTypes (litSet[i].first, allTypes[j]));
}
}
}
createSmoothNode (missingLids, casted->follow());
// TODO change propagLits to full lvs
break; break;
} }
case CircuitNodeType::SET_AND_NODE: { case CircuitNodeType::SET_AND_NODE: {
SetAndNode* casted = dynamic_cast<SetAndNode*>(node); SetAndNode* casted = dynamic_cast<SetAndNode*>(node);
propagatingLids = smoothCircuit (*casted->follow()); propagLits = smoothCircuit (*casted->follow());
break; break;
} }
case CircuitNodeType::INC_EXC_NODE: { case CircuitNodeType::INC_EXC_NODE: {
IncExcNode* casted = dynamic_cast<IncExcNode*>(node); IncExcNode* casted = dynamic_cast<IncExcNode*>(node);
TinySet<LiteralId> lids1 = smoothCircuit (*casted->plus1Branch()); LitLvTypesSet lids1 = smoothCircuit (*casted->plus1Branch());
TinySet<LiteralId> lids2 = smoothCircuit (*casted->plus2Branch()); LitLvTypesSet lids2 = smoothCircuit (*casted->plus2Branch());
TinySet<LiteralId> missingPlus1 = lids2 - lids1; LitLvTypesSet missingPlus1 = lids2 - lids1;
TinySet<LiteralId> missingPlus2 = lids1 - lids2; LitLvTypesSet missingPlus2 = lids1 - lids2;
createSmoothNode (missingPlus1, casted->plus1Branch()); createSmoothNode (missingPlus1, casted->plus1Branch());
createSmoothNode (missingPlus2, casted->plus2Branch()); createSmoothNode (missingPlus2, casted->plus2Branch());
propagatingLids |= lids1; propagLits |= lids1;
propagatingLids |= lids2; propagLits |= lids2;
break; break;
} }
case CircuitNodeType::LEAF_NODE: { case CircuitNodeType::LEAF_NODE: {
propagatingLids.insert (node->clauses()[0].literals()[0].lid()); propagLits.insert (LiteralLvTypes (
node->clauses()[0].literals()[0].lid(),
node->clauses()[0].logVarTypes(0)));
} }
default: default:
break; break;
} }
return propagatingLids; return propagLits;
} }
void void
LiftedCircuit::createSmoothNode ( LiftedCircuit::createSmoothNode (
const TinySet<LiteralId>& missingLids, const LitLvTypesSet& missingLits,
CircuitNode** prev) CircuitNode** prev)
{ {
if (missingLids.empty() == false) { if (missingLits.empty() == false) {
Clauses clauses; Clauses clauses;
for (size_t i = 0; i < missingLids.size(); i++) { for (size_t i = 0; i < missingLits.size(); i++) {
Clause c = lwcnf_->createClauseForLiteral (missingLids[i]); LiteralId lid = missingLits[i].lid();
const LogVarTypes& types = missingLits[i].logVarTypes();
Clause c = lwcnf_->createClauseForLiteral (lid);
for (size_t j = 0; j < types.size(); j++) {
LogVar X = c.literals().front().logVars()[j];
if (types[j] == LogVarType::POS_LV) {
c.addPositiveCountedLogVar (X);
} else if (types[j] == LogVarType::NEG_LV) {
c.addNegativeCountedLogVar (X);
}
}
c.addAndNegateLiteral (c.literals()[0]); c.addAndNegateLiteral (c.literals()[0]);
clauses.push_back (c); clauses.push_back (c);
} }

View File

@ -234,9 +234,9 @@ class LiftedCircuit
bool shatterCountedLogVarsAux (Clauses& clauses, size_t idx1, size_t idx2); bool shatterCountedLogVarsAux (Clauses& clauses, size_t idx1, size_t idx2);
TinySet<LiteralId> smoothCircuit (CircuitNode* node); LitLvTypesSet smoothCircuit (CircuitNode* node);
void createSmoothNode (const TinySet<LiteralId>& lids, void createSmoothNode (const LitLvTypesSet& lids,
CircuitNode** prev); CircuitNode** prev);
CircuitNodeType getCircuitNodeType (const CircuitNode* node) const; CircuitNodeType getCircuitNodeType (const CircuitNode* node) const;

View File

@ -303,13 +303,13 @@ LiftedWCNF::LiftedWCNF (const ParfactorList& pfList)
vector<vector<string>> names = {{"p1","p1"},{"p2","p2"}}; vector<vector<string>> names = {{"p1","p1"},{"p2","p2"}};
Clause c1 (names); Clause c1 (names);
c1.addLiteral (Literal (0, LogVars()={0})); c1.addLiteral (Literal (0, LogVars()={0}, 3.0));
c1.addAndNegateLiteral (Literal (1, {0,1})); c1.addAndNegateLiteral (Literal (1, {0,1}, 1.0));
clauses_.push_back(c1); clauses_.push_back(c1);
Clause c2 (names); Clause c2 (names);
c2.addLiteral (Literal (0, LogVars()={0})); c2.addLiteral (Literal (0, LogVars()={0}, 2.0));
c2.addAndNegateLiteral (Literal (1, {1,0})); c2.addAndNegateLiteral (Literal (1, {1,0}, 5.0));
clauses_.push_back(c2); clauses_.push_back(c2);
cout << "FORMULA INDICATORS:" << endl; cout << "FORMULA INDICATORS:" << endl;

View File

@ -146,6 +146,38 @@ typedef vector<Clause> Clauses;
class LiteralLvTypes
{
public:
struct CompareLiteralLvTypes
{
bool operator() (
const LiteralLvTypes& types1,
const LiteralLvTypes& types2) const
{
if (types1.lid_ < types2.lid_) {
return true;
}
return types1.lvTypes_ < types2.lvTypes_;
}
};
LiteralLvTypes (LiteralId lid, const LogVarTypes& lvTypes) :
lid_(lid), lvTypes_(lvTypes) { }
LiteralId lid (void) const { return lid_; }
const LogVarTypes& logVarTypes (void) const { return lvTypes_; }
private:
LiteralId lid_;
LogVarTypes lvTypes_;
};
typedef TinySet<LiteralLvTypes,LiteralLvTypes::CompareLiteralLvTypes> LitLvTypesSet;
class LiftedWCNF class LiftedWCNF
{ {
public: public: