support smoothing on atom counting nodes (beta)

This commit is contained in:
Tiago Gomes 2012-11-04 18:02:40 +00:00
parent b8cef8798a
commit 4518a3db5d
4 changed files with 181 additions and 32 deletions

View File

@ -23,6 +23,10 @@ AndNode::weight (void) const
stack<pair<unsigned, unsigned>> SetOrNode::nrGrsStack;
double
SetOrNode::weight (void) const
{
@ -583,82 +587,195 @@ LiftedCircuit::shatterCountedLogVarsAux (
TinySet<LiteralId>
LogVarTypes
unionTypes (const LogVarTypes& types1, const LogVarTypes& types2)
{
if (types1.empty()) {
return types2;
}
if (types2.empty()) {
return types1;
}
assert (types1.size() == types2.size());
LogVarTypes res;
for (size_t i = 0; i < types1.size(); i++) {
if (types1[i] == LogVarType::POS_LV
&& types2[i] == LogVarType::POS_LV) {
res.push_back (LogVarType::POS_LV);
} else if (types1[i] == LogVarType::NEG_LV
&& types2[i] == LogVarType::NEG_LV) {
res.push_back (LogVarType::NEG_LV);
} else {
res.push_back (LogVarType::FULL_LV);
}
}
return res;
}
vector<LogVarTypes>
getAllPossibleTypes (unsigned nrLogVars)
{
if (nrLogVars == 0) {
return {};
}
if (nrLogVars == 1) {
return {{LogVarType::POS_LV},{LogVarType::NEG_LV}};
}
vector<LogVarTypes> res;
Indexer indexer (vector<unsigned> (nrLogVars, 2));
while (indexer.valid()) {
LogVarTypes types;
for (size_t i = 0; i < nrLogVars; i++) {
if (indexer[i] == 0) {
types.push_back (LogVarType::POS_LV);
} else {
types.push_back (LogVarType::NEG_LV);
}
}
res.push_back (types);
++ indexer;
}
return res;
}
bool
containsTypes (const LogVarTypes& typesA, const LogVarTypes& typesB)
{
for (size_t i = 0; i < typesA.size(); i++) {
if (typesA[i] == LogVarType::FULL_LV) {
} else if (typesA[i] == LogVarType::POS_LV
&& typesB[i] == LogVarType::POS_LV) {
} else if (typesA[i] == LogVarType::NEG_LV
&& typesB[i] == LogVarType::NEG_LV) {
} else {
return false;
}
}
return true;
}
LitLvTypesSet
LiftedCircuit::smoothCircuit (CircuitNode* node)
{
assert (node != 0);
TinySet<LiteralId> propagatingLids;
LitLvTypesSet propagLits;
switch (getCircuitNodeType (node)) {
case CircuitNodeType::OR_NODE: {
OrNode* casted = dynamic_cast<OrNode*>(node);
TinySet<LiteralId> lids1 = smoothCircuit (*casted->leftBranch());
TinySet<LiteralId> lids2 = smoothCircuit (*casted->rightBranch());
TinySet<LiteralId> missingLeft = lids2 - lids1;
TinySet<LiteralId> missingRight = lids1 - lids2;
LitLvTypesSet lids1 = smoothCircuit (*casted->leftBranch());
LitLvTypesSet lids2 = smoothCircuit (*casted->rightBranch());
LitLvTypesSet missingLeft = lids2 - lids1;
LitLvTypesSet missingRight = lids1 - lids2;
createSmoothNode (missingLeft, casted->leftBranch());
createSmoothNode (missingRight, casted->rightBranch());
propagatingLids |= lids1;
propagatingLids |= lids2;
propagLits |= lids1;
propagLits |= lids2;
break;
}
case CircuitNodeType::AND_NODE: {
AndNode* casted = dynamic_cast<AndNode*>(node);
TinySet<LiteralId> lids1 = smoothCircuit (*casted->leftBranch());
TinySet<LiteralId> lids2 = smoothCircuit (*casted->rightBranch());
propagatingLids |= lids1;
propagatingLids |= lids2;
LitLvTypesSet lids1 = smoothCircuit (*casted->leftBranch());
LitLvTypesSet lids2 = smoothCircuit (*casted->rightBranch());
propagLits |= lids1;
propagLits |= lids2;
break;
}
case CircuitNodeType::SET_OR_NODE: {
// TODO
SetOrNode* casted = dynamic_cast<SetOrNode*>(node);
propagLits = smoothCircuit (*casted->follow());
TinySet<pair<LiteralId,unsigned>> litSet;
for (size_t i = 0; i < propagLits.size(); i++) {
litSet.insert (make_pair (propagLits[i].lid(),
propagLits[i].logVarTypes().size()));
}
LitLvTypesSet missingLids;
for (size_t i = 0; i < litSet.size(); i++) {
vector<LogVarTypes> allTypes = getAllPossibleTypes (litSet[i].second);
for (size_t j = 0; j < allTypes.size(); j++) {
bool typeFound = false;
for (size_t k = 0; k < propagLits.size(); k++) {
if (litSet[i].first == propagLits[k].lid()
&& containsTypes (allTypes[j], propagLits[k].logVarTypes())) {
typeFound = true;
break;
}
}
if (typeFound == false) {
missingLids.insert (LiteralLvTypes (litSet[i].first, allTypes[j]));
}
}
}
createSmoothNode (missingLids, casted->follow());
// TODO change propagLits to full lvs
break;
}
case CircuitNodeType::SET_AND_NODE: {
SetAndNode* casted = dynamic_cast<SetAndNode*>(node);
propagatingLids = smoothCircuit (*casted->follow());
propagLits = smoothCircuit (*casted->follow());
break;
}
case CircuitNodeType::INC_EXC_NODE: {
IncExcNode* casted = dynamic_cast<IncExcNode*>(node);
TinySet<LiteralId> lids1 = smoothCircuit (*casted->plus1Branch());
TinySet<LiteralId> lids2 = smoothCircuit (*casted->plus2Branch());
TinySet<LiteralId> missingPlus1 = lids2 - lids1;
TinySet<LiteralId> missingPlus2 = lids1 - lids2;
LitLvTypesSet lids1 = smoothCircuit (*casted->plus1Branch());
LitLvTypesSet lids2 = smoothCircuit (*casted->plus2Branch());
LitLvTypesSet missingPlus1 = lids2 - lids1;
LitLvTypesSet missingPlus2 = lids1 - lids2;
createSmoothNode (missingPlus1, casted->plus1Branch());
createSmoothNode (missingPlus2, casted->plus2Branch());
propagatingLids |= lids1;
propagatingLids |= lids2;
propagLits |= lids1;
propagLits |= lids2;
break;
}
case CircuitNodeType::LEAF_NODE: {
propagatingLids.insert (node->clauses()[0].literals()[0].lid());
propagLits.insert (LiteralLvTypes (
node->clauses()[0].literals()[0].lid(),
node->clauses()[0].logVarTypes(0)));
}
default:
break;
}
return propagatingLids;
return propagLits;
}
void
LiftedCircuit::createSmoothNode (
const TinySet<LiteralId>& missingLids,
const LitLvTypesSet& missingLits,
CircuitNode** prev)
{
if (missingLids.empty() == false) {
if (missingLits.empty() == false) {
Clauses clauses;
for (size_t i = 0; i < missingLids.size(); i++) {
Clause c = lwcnf_->createClauseForLiteral (missingLids[i]);
for (size_t i = 0; i < missingLits.size(); i++) {
LiteralId lid = missingLits[i].lid();
const LogVarTypes& types = missingLits[i].logVarTypes();
Clause c = lwcnf_->createClauseForLiteral (lid);
for (size_t j = 0; j < types.size(); j++) {
LogVar X = c.literals().front().logVars()[j];
if (types[j] == LogVarType::POS_LV) {
c.addPositiveCountedLogVar (X);
} else if (types[j] == LogVarType::NEG_LV) {
c.addNegativeCountedLogVar (X);
}
}
c.addAndNegateLiteral (c.literals()[0]);
clauses.push_back (c);
}

View File

@ -234,9 +234,9 @@ class LiftedCircuit
bool shatterCountedLogVarsAux (Clauses& clauses, size_t idx1, size_t idx2);
TinySet<LiteralId> smoothCircuit (CircuitNode* node);
LitLvTypesSet smoothCircuit (CircuitNode* node);
void createSmoothNode (const TinySet<LiteralId>& lids,
void createSmoothNode (const LitLvTypesSet& lids,
CircuitNode** prev);
CircuitNodeType getCircuitNodeType (const CircuitNode* node) const;

View File

@ -303,13 +303,13 @@ LiftedWCNF::LiftedWCNF (const ParfactorList& pfList)
vector<vector<string>> names = {{"p1","p1"},{"p2","p2"}};
Clause c1 (names);
c1.addLiteral (Literal (0, LogVars()={0}));
c1.addAndNegateLiteral (Literal (1, {0,1}));
c1.addLiteral (Literal (0, LogVars()={0}, 3.0));
c1.addAndNegateLiteral (Literal (1, {0,1}, 1.0));
clauses_.push_back(c1);
Clause c2 (names);
c2.addLiteral (Literal (0, LogVars()={0}));
c2.addAndNegateLiteral (Literal (1, {1,0}));
c2.addLiteral (Literal (0, LogVars()={0}, 2.0));
c2.addAndNegateLiteral (Literal (1, {1,0}, 5.0));
clauses_.push_back(c2);
cout << "FORMULA INDICATORS:" << endl;

View File

@ -146,6 +146,38 @@ typedef vector<Clause> Clauses;
class LiteralLvTypes
{
public:
struct CompareLiteralLvTypes
{
bool operator() (
const LiteralLvTypes& types1,
const LiteralLvTypes& types2) const
{
if (types1.lid_ < types2.lid_) {
return true;
}
return types1.lvTypes_ < types2.lvTypes_;
}
};
LiteralLvTypes (LiteralId lid, const LogVarTypes& lvTypes) :
lid_(lid), lvTypes_(lvTypes) { }
LiteralId lid (void) const { return lid_; }
const LogVarTypes& logVarTypes (void) const { return lvTypes_; }
private:
LiteralId lid_;
LogVarTypes lvTypes_;
};
typedef TinySet<LiteralLvTypes,LiteralLvTypes::CompareLiteralLvTypes> LitLvTypesSet;
class LiftedWCNF
{
public: