support smoothing on atom counting nodes (beta)
This commit is contained in:
parent
b8cef8798a
commit
4518a3db5d
@ -23,6 +23,10 @@ AndNode::weight (void) const
|
||||
|
||||
|
||||
|
||||
stack<pair<unsigned, unsigned>> SetOrNode::nrGrsStack;
|
||||
|
||||
|
||||
|
||||
double
|
||||
SetOrNode::weight (void) const
|
||||
{
|
||||
@ -583,82 +587,195 @@ LiftedCircuit::shatterCountedLogVarsAux (
|
||||
|
||||
|
||||
|
||||
TinySet<LiteralId>
|
||||
LogVarTypes
|
||||
unionTypes (const LogVarTypes& types1, const LogVarTypes& types2)
|
||||
{
|
||||
if (types1.empty()) {
|
||||
return types2;
|
||||
}
|
||||
if (types2.empty()) {
|
||||
return types1;
|
||||
}
|
||||
assert (types1.size() == types2.size());
|
||||
LogVarTypes res;
|
||||
for (size_t i = 0; i < types1.size(); i++) {
|
||||
if (types1[i] == LogVarType::POS_LV
|
||||
&& types2[i] == LogVarType::POS_LV) {
|
||||
res.push_back (LogVarType::POS_LV);
|
||||
} else if (types1[i] == LogVarType::NEG_LV
|
||||
&& types2[i] == LogVarType::NEG_LV) {
|
||||
res.push_back (LogVarType::NEG_LV);
|
||||
} else {
|
||||
res.push_back (LogVarType::FULL_LV);
|
||||
}
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
|
||||
|
||||
vector<LogVarTypes>
|
||||
getAllPossibleTypes (unsigned nrLogVars)
|
||||
{
|
||||
if (nrLogVars == 0) {
|
||||
return {};
|
||||
}
|
||||
if (nrLogVars == 1) {
|
||||
return {{LogVarType::POS_LV},{LogVarType::NEG_LV}};
|
||||
}
|
||||
vector<LogVarTypes> res;
|
||||
Indexer indexer (vector<unsigned> (nrLogVars, 2));
|
||||
while (indexer.valid()) {
|
||||
LogVarTypes types;
|
||||
for (size_t i = 0; i < nrLogVars; i++) {
|
||||
if (indexer[i] == 0) {
|
||||
types.push_back (LogVarType::POS_LV);
|
||||
} else {
|
||||
types.push_back (LogVarType::NEG_LV);
|
||||
}
|
||||
}
|
||||
res.push_back (types);
|
||||
++ indexer;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
containsTypes (const LogVarTypes& typesA, const LogVarTypes& typesB)
|
||||
{
|
||||
for (size_t i = 0; i < typesA.size(); i++) {
|
||||
if (typesA[i] == LogVarType::FULL_LV) {
|
||||
|
||||
} else if (typesA[i] == LogVarType::POS_LV
|
||||
&& typesB[i] == LogVarType::POS_LV) {
|
||||
|
||||
} else if (typesA[i] == LogVarType::NEG_LV
|
||||
&& typesB[i] == LogVarType::NEG_LV) {
|
||||
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
|
||||
LitLvTypesSet
|
||||
LiftedCircuit::smoothCircuit (CircuitNode* node)
|
||||
{
|
||||
assert (node != 0);
|
||||
TinySet<LiteralId> propagatingLids;
|
||||
LitLvTypesSet propagLits;
|
||||
|
||||
switch (getCircuitNodeType (node)) {
|
||||
|
||||
case CircuitNodeType::OR_NODE: {
|
||||
OrNode* casted = dynamic_cast<OrNode*>(node);
|
||||
TinySet<LiteralId> lids1 = smoothCircuit (*casted->leftBranch());
|
||||
TinySet<LiteralId> lids2 = smoothCircuit (*casted->rightBranch());
|
||||
TinySet<LiteralId> missingLeft = lids2 - lids1;
|
||||
TinySet<LiteralId> missingRight = lids1 - lids2;
|
||||
LitLvTypesSet lids1 = smoothCircuit (*casted->leftBranch());
|
||||
LitLvTypesSet lids2 = smoothCircuit (*casted->rightBranch());
|
||||
LitLvTypesSet missingLeft = lids2 - lids1;
|
||||
LitLvTypesSet missingRight = lids1 - lids2;
|
||||
createSmoothNode (missingLeft, casted->leftBranch());
|
||||
createSmoothNode (missingRight, casted->rightBranch());
|
||||
propagatingLids |= lids1;
|
||||
propagatingLids |= lids2;
|
||||
propagLits |= lids1;
|
||||
propagLits |= lids2;
|
||||
break;
|
||||
}
|
||||
|
||||
case CircuitNodeType::AND_NODE: {
|
||||
AndNode* casted = dynamic_cast<AndNode*>(node);
|
||||
TinySet<LiteralId> lids1 = smoothCircuit (*casted->leftBranch());
|
||||
TinySet<LiteralId> lids2 = smoothCircuit (*casted->rightBranch());
|
||||
propagatingLids |= lids1;
|
||||
propagatingLids |= lids2;
|
||||
LitLvTypesSet lids1 = smoothCircuit (*casted->leftBranch());
|
||||
LitLvTypesSet lids2 = smoothCircuit (*casted->rightBranch());
|
||||
propagLits |= lids1;
|
||||
propagLits |= lids2;
|
||||
break;
|
||||
}
|
||||
|
||||
case CircuitNodeType::SET_OR_NODE: {
|
||||
// TODO
|
||||
SetOrNode* casted = dynamic_cast<SetOrNode*>(node);
|
||||
propagLits = smoothCircuit (*casted->follow());
|
||||
TinySet<pair<LiteralId,unsigned>> litSet;
|
||||
for (size_t i = 0; i < propagLits.size(); i++) {
|
||||
litSet.insert (make_pair (propagLits[i].lid(),
|
||||
propagLits[i].logVarTypes().size()));
|
||||
}
|
||||
LitLvTypesSet missingLids;
|
||||
for (size_t i = 0; i < litSet.size(); i++) {
|
||||
vector<LogVarTypes> allTypes = getAllPossibleTypes (litSet[i].second);
|
||||
for (size_t j = 0; j < allTypes.size(); j++) {
|
||||
bool typeFound = false;
|
||||
for (size_t k = 0; k < propagLits.size(); k++) {
|
||||
if (litSet[i].first == propagLits[k].lid()
|
||||
&& containsTypes (allTypes[j], propagLits[k].logVarTypes())) {
|
||||
typeFound = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (typeFound == false) {
|
||||
missingLids.insert (LiteralLvTypes (litSet[i].first, allTypes[j]));
|
||||
}
|
||||
}
|
||||
}
|
||||
createSmoothNode (missingLids, casted->follow());
|
||||
// TODO change propagLits to full lvs
|
||||
break;
|
||||
}
|
||||
|
||||
case CircuitNodeType::SET_AND_NODE: {
|
||||
SetAndNode* casted = dynamic_cast<SetAndNode*>(node);
|
||||
propagatingLids = smoothCircuit (*casted->follow());
|
||||
propagLits = smoothCircuit (*casted->follow());
|
||||
break;
|
||||
}
|
||||
|
||||
case CircuitNodeType::INC_EXC_NODE: {
|
||||
IncExcNode* casted = dynamic_cast<IncExcNode*>(node);
|
||||
TinySet<LiteralId> lids1 = smoothCircuit (*casted->plus1Branch());
|
||||
TinySet<LiteralId> lids2 = smoothCircuit (*casted->plus2Branch());
|
||||
TinySet<LiteralId> missingPlus1 = lids2 - lids1;
|
||||
TinySet<LiteralId> missingPlus2 = lids1 - lids2;
|
||||
LitLvTypesSet lids1 = smoothCircuit (*casted->plus1Branch());
|
||||
LitLvTypesSet lids2 = smoothCircuit (*casted->plus2Branch());
|
||||
LitLvTypesSet missingPlus1 = lids2 - lids1;
|
||||
LitLvTypesSet missingPlus2 = lids1 - lids2;
|
||||
createSmoothNode (missingPlus1, casted->plus1Branch());
|
||||
createSmoothNode (missingPlus2, casted->plus2Branch());
|
||||
propagatingLids |= lids1;
|
||||
propagatingLids |= lids2;
|
||||
propagLits |= lids1;
|
||||
propagLits |= lids2;
|
||||
break;
|
||||
}
|
||||
|
||||
case CircuitNodeType::LEAF_NODE: {
|
||||
propagatingLids.insert (node->clauses()[0].literals()[0].lid());
|
||||
propagLits.insert (LiteralLvTypes (
|
||||
node->clauses()[0].literals()[0].lid(),
|
||||
node->clauses()[0].logVarTypes(0)));
|
||||
}
|
||||
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
return propagatingLids;
|
||||
return propagLits;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
LiftedCircuit::createSmoothNode (
|
||||
const TinySet<LiteralId>& missingLids,
|
||||
const LitLvTypesSet& missingLits,
|
||||
CircuitNode** prev)
|
||||
{
|
||||
if (missingLids.empty() == false) {
|
||||
if (missingLits.empty() == false) {
|
||||
Clauses clauses;
|
||||
for (size_t i = 0; i < missingLids.size(); i++) {
|
||||
Clause c = lwcnf_->createClauseForLiteral (missingLids[i]);
|
||||
for (size_t i = 0; i < missingLits.size(); i++) {
|
||||
LiteralId lid = missingLits[i].lid();
|
||||
const LogVarTypes& types = missingLits[i].logVarTypes();
|
||||
Clause c = lwcnf_->createClauseForLiteral (lid);
|
||||
for (size_t j = 0; j < types.size(); j++) {
|
||||
LogVar X = c.literals().front().logVars()[j];
|
||||
if (types[j] == LogVarType::POS_LV) {
|
||||
c.addPositiveCountedLogVar (X);
|
||||
} else if (types[j] == LogVarType::NEG_LV) {
|
||||
c.addNegativeCountedLogVar (X);
|
||||
}
|
||||
}
|
||||
c.addAndNegateLiteral (c.literals()[0]);
|
||||
clauses.push_back (c);
|
||||
}
|
||||
|
@ -234,9 +234,9 @@ class LiftedCircuit
|
||||
|
||||
bool shatterCountedLogVarsAux (Clauses& clauses, size_t idx1, size_t idx2);
|
||||
|
||||
TinySet<LiteralId> smoothCircuit (CircuitNode* node);
|
||||
LitLvTypesSet smoothCircuit (CircuitNode* node);
|
||||
|
||||
void createSmoothNode (const TinySet<LiteralId>& lids,
|
||||
void createSmoothNode (const LitLvTypesSet& lids,
|
||||
CircuitNode** prev);
|
||||
|
||||
CircuitNodeType getCircuitNodeType (const CircuitNode* node) const;
|
||||
|
@ -303,13 +303,13 @@ LiftedWCNF::LiftedWCNF (const ParfactorList& pfList)
|
||||
vector<vector<string>> names = {{"p1","p1"},{"p2","p2"}};
|
||||
|
||||
Clause c1 (names);
|
||||
c1.addLiteral (Literal (0, LogVars()={0}));
|
||||
c1.addAndNegateLiteral (Literal (1, {0,1}));
|
||||
c1.addLiteral (Literal (0, LogVars()={0}, 3.0));
|
||||
c1.addAndNegateLiteral (Literal (1, {0,1}, 1.0));
|
||||
clauses_.push_back(c1);
|
||||
|
||||
Clause c2 (names);
|
||||
c2.addLiteral (Literal (0, LogVars()={0}));
|
||||
c2.addAndNegateLiteral (Literal (1, {1,0}));
|
||||
c2.addLiteral (Literal (0, LogVars()={0}, 2.0));
|
||||
c2.addAndNegateLiteral (Literal (1, {1,0}, 5.0));
|
||||
clauses_.push_back(c2);
|
||||
|
||||
cout << "FORMULA INDICATORS:" << endl;
|
||||
|
@ -146,6 +146,38 @@ typedef vector<Clause> Clauses;
|
||||
|
||||
|
||||
|
||||
class LiteralLvTypes
|
||||
{
|
||||
public:
|
||||
struct CompareLiteralLvTypes
|
||||
{
|
||||
bool operator() (
|
||||
const LiteralLvTypes& types1,
|
||||
const LiteralLvTypes& types2) const
|
||||
{
|
||||
if (types1.lid_ < types2.lid_) {
|
||||
return true;
|
||||
}
|
||||
return types1.lvTypes_ < types2.lvTypes_;
|
||||
}
|
||||
};
|
||||
|
||||
LiteralLvTypes (LiteralId lid, const LogVarTypes& lvTypes) :
|
||||
lid_(lid), lvTypes_(lvTypes) { }
|
||||
|
||||
LiteralId lid (void) const { return lid_; }
|
||||
|
||||
const LogVarTypes& logVarTypes (void) const { return lvTypes_; }
|
||||
|
||||
private:
|
||||
LiteralId lid_;
|
||||
LogVarTypes lvTypes_;
|
||||
};
|
||||
|
||||
typedef TinySet<LiteralLvTypes,LiteralLvTypes::CompareLiteralLvTypes> LitLvTypesSet;
|
||||
|
||||
|
||||
|
||||
class LiftedWCNF
|
||||
{
|
||||
public:
|
||||
|
Reference in New Issue
Block a user