This repository has been archived on 2023-08-20. You can view files and clone it, but cannot push or open issues or pull requests.
yap-6.3/packages/CLPBN/horus/LiftedKc.cpp

1584 lines
40 KiB
C++
Raw Normal View History

2013-02-07 20:09:10 +00:00
#include <cassert>
2013-02-19 23:59:05 +00:00
#include <cmath>
#include <fstream>
#include <limits>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>
2013-02-07 20:09:10 +00:00
#include <iostream>
#include "LiftedKc.h"
#include "LiftedWCNF.h"
#include "LiftedOperations.h"
#include "Indexer.h"
namespace Horus {
2013-02-07 23:53:13 +00:00
// Tag identifying the concrete CircuitNode subclass of a node. It is
// produced by LiftedCircuit::getCircuitNodeType() and used to dispatch
// on the node kind without spreading dynamic_cast chains around.
enum class CircuitNodeType {
orCnt,                 // OrNode
andCnt,                // AndNode
setOrCnt,              // SetOrNode
setAndCnt,             // SetAndNode
incExcCnt,             // IncExcNode
leafCnt,               // LeafNode
smoothCnt,             // SmoothNode
trueCnt,               // TrueNode
compilationFailedCnt   // CompilationFailedNode
};
// Abstract base class of every node of the compiled circuit.
// Nodes are deleted through base pointers, hence the virtual destructor.
class CircuitNode {
  public:
    CircuitNode() = default;

    virtual ~CircuitNode() = default;

    // Weight of the sub-circuit rooted at this node.
    virtual double weight() const = 0;
};
class OrNode : public CircuitNode {
public:
OrNode() : CircuitNode(), leftBranch_(0), rightBranch_(0) { }
~OrNode();
CircuitNode** leftBranch () { return &leftBranch_; }
CircuitNode** rightBranch() { return &rightBranch_; }
double weight() const;
private:
CircuitNode* leftBranch_;
CircuitNode* rightBranch_;
};
class AndNode : public CircuitNode {
public:
AndNode() : CircuitNode(), leftBranch_(0), rightBranch_(0) { }
AndNode (CircuitNode* leftBranch, CircuitNode* rightBranch)
2013-03-09 17:14:00 +00:00
: CircuitNode(), leftBranch_(leftBranch),
rightBranch_(rightBranch) { }
~AndNode();
CircuitNode** leftBranch () { return &leftBranch_; }
CircuitNode** rightBranch() { return &rightBranch_; }
double weight() const;
private:
CircuitNode* leftBranch_;
CircuitNode* rightBranch_;
};
class SetOrNode : public CircuitNode {
public:
SetOrNode (unsigned nrGroundings)
: CircuitNode(), follow_(0), nrGroundings_(nrGroundings) { }
~SetOrNode();
CircuitNode** follow() { return &follow_; }
static unsigned nrPositives() { return nrPos_; }
static unsigned nrNegatives() { return nrNeg_; }
static bool isSet() { return nrPos_ >= 0; }
double weight() const;
private:
CircuitNode* follow_;
unsigned nrGroundings_;
static int nrPos_;
static int nrNeg_;
};
class SetAndNode : public CircuitNode {
public:
SetAndNode (unsigned nrGroundings)
: CircuitNode(), follow_(0), nrGroundings_(nrGroundings) { }
~SetAndNode();
CircuitNode** follow() { return &follow_; }
double weight() const;
private:
CircuitNode* follow_;
unsigned nrGroundings_;
};
class IncExcNode : public CircuitNode {
public:
IncExcNode()
: CircuitNode(), plus1Branch_(0), plus2Branch_(0), minusBranch_(0) { }
~IncExcNode();
CircuitNode** plus1Branch() { return &plus1Branch_; }
CircuitNode** plus2Branch() { return &plus2Branch_; }
CircuitNode** minusBranch() { return &minusBranch_; }
double weight() const;
private:
CircuitNode* plus1Branch_;
CircuitNode* plus2Branch_;
CircuitNode* minusBranch_;
};
class LeafNode : public CircuitNode {
public:
LeafNode (Clause* clause, const LiftedWCNF& lwcnf)
: CircuitNode(), clause_(clause), lwcnf_(lwcnf) { }
~LeafNode();
const Clause* clause() const { return clause_; }
Clause* clause() { return clause_; }
double weight() const;
private:
Clause* clause_;
const LiftedWCNF& lwcnf_;
};
class SmoothNode : public CircuitNode {
public:
SmoothNode (const Clauses& clauses, const LiftedWCNF& lwcnf)
: CircuitNode(), clauses_(clauses), lwcnf_(lwcnf) { }
~SmoothNode();
const Clauses& clauses() const { return clauses_; }
Clauses clauses() { return clauses_; }
double weight() const;
private:
Clauses clauses_;
const LiftedWCNF& lwcnf_;
};
class TrueNode : public CircuitNode {
public:
TrueNode() : CircuitNode() { }
double weight() const;
};
class CompilationFailedNode : public CircuitNode {
public:
CompilationFailedNode() : CircuitNode() { }
double weight() const;
};
// Circuit obtained by compiling the clauses of a LiftedWCNF. The whole
// compilation (plus smoothing, when it succeeds) runs in the constructor.
class LiftedCircuit {
  public:
    LiftedCircuit (const LiftedWCNF* lwcnf);

    ~LiftedCircuit();

    bool isCompilationSucceeded() const;

    // Weight of the root node; asserts that compilation succeeded.
    double getWeightedModelCount() const;

    // Writes the circuit in graphviz format to the given file.
    void exportToGraphViz (const char*);

  private:
    // Entry point of the compilation: stores in *follow the sub-circuit
    // built for the given clauses, trying the rules below in order.
    void compile (CircuitNode** follow, Clauses& clauses);

    // Compilation rules. Each one returns true when it was applicable,
    // in which case *follow has been set.
    bool tryUnitPropagation (CircuitNode** follow, Clauses& clauses);

    bool tryIndependence (CircuitNode** follow, Clauses& clauses);

    bool tryShannonDecomp (CircuitNode** follow, Clauses& clauses);

    bool tryInclusionExclusion (CircuitNode** follow, Clauses& clauses);

    bool tryIndepPartialGrounding (CircuitNode** follow, Clauses& clauses);

    bool tryIndepPartialGroundingAux (Clauses& clauses, ConstraintTree& ct,
        LogVars& rootLogVars);

    bool tryAtomCounting (CircuitNode** follow, Clauses& clauses);

    // Helpers used by atom counting.
    void shatterCountedLogVars (Clauses& clauses);

    bool shatterCountedLogVarsAux (Clauses& clauses);

    bool shatterCountedLogVarsAux (Clauses& clauses,
        size_t idx1, size_t idx2);

    // Independence tests used by the rules above.
    bool independentClause (Clause& clause, Clauses& otherClauses) const;

    bool independentLiteral (const Literal& lit,
        const Literals& otherLits) const;

    // Smoothing of the compiled circuit.
    LitLvTypesSet smoothCircuit (CircuitNode* node);

    void createSmoothNode (const LitLvTypesSet& lids,
        CircuitNode** prev);

    std::vector<LogVarTypes> getAllPossibleTypes (unsigned nrLogVars) const;

    bool containsTypes (const LogVarTypes& typesA,
        const LogVarTypes& typesB) const;

    CircuitNodeType getCircuitNodeType (const CircuitNode* node) const;

    // Graphviz export helpers.
    void exportToGraphViz (CircuitNode* node, std::ofstream&);

    void printClauses (CircuitNode* node, std::ofstream&,
        std::string extraOptions = "");

    std::string escapeNode (const CircuitNode* node) const;

    std::string getExplanationString (CircuitNode* node);

    CircuitNode*       root_;
    const LiftedWCNF*  lwcnf_;
    bool               compilationSucceeded_;
    // Scratch copy of the clauses a rule started from. It is only
    // filled when Globals::verbosity > 1.
    Clauses            backupClauses_;
    // Per node bookkeeping used by the graphviz export. The clauses
    // stored in originClausesMap_ are released by the destructor.
    std::unordered_map<CircuitNode*, Clauses>      originClausesMap_;
    std::unordered_map<CircuitNode*, std::string>  explanationMap_;

    DISALLOW_COPY_AND_ASSIGN (LiftedCircuit);
};
// Releases both child sub-circuits (deleting a null branch is a no-op).
OrNode::~OrNode()
{
delete leftBranch_;
delete rightBranch_;
}
// Sum of the weights of both branches. In the log domain the weights are
// logarithms, so they are combined with Util::logSum instead.
// Both branches must have been set.
double
OrNode::weight() const
{
  const double left  = leftBranch_->weight();
  const double right = rightBranch_->weight();
  if (Globals::logDomain) {
    return Util::logSum (left, right);
  }
  return left + right;
}
// Releases both child sub-circuits (deleting a null branch is a no-op).
AndNode::~AndNode()
{
delete leftBranch_;
delete rightBranch_;
}
// Product of the weights of both branches. In the log domain the weights
// are logarithms, so the product becomes a sum.
// Both branches must have been set.
double
AndNode::weight() const
{
  const double left  = leftBranch_->weight();
  const double right = rightBranch_->weight();
  if (Globals::logDomain) {
    return left + right;
  }
  return left * right;
}
// Split currently evaluated by SetOrNode::weight(); -1 means that no
// split is selected (see SetOrNode::isSet()).
int SetOrNode::nrPos_ = -1;
int SetOrNode::nrNeg_ = -1;
// Releases the child sub-circuit (deleting a null child is a no-op).
SetOrNode::~SetOrNode()
{
delete follow_;
}
double
SetOrNode::weight() const
{
double weightSum = LogAware::addIdenty();
for (unsigned i = 0; i < nrGroundings_ + 1; i++) {
nrPos_ = nrGroundings_ - i;
nrNeg_ = i;
if (Globals::logDomain) {
double nrCombs = Util::nrCombinations (nrGroundings_, i);
double w = follow_->weight();
weightSum = Util::logSum (weightSum, std::log (nrCombs) + w);
} else {
double w = follow_->weight();
weightSum += Util::nrCombinations (nrGroundings_, i) * w;
}
}
nrPos_ = -1;
nrNeg_ = -1;
return weightSum;
}
// Releases the child sub-circuit (deleting a null child is a no-op).
SetAndNode::~SetAndNode()
{
delete follow_;
}
// Weight of the child raised to the number of groundings, computed with
// LogAware::pow. The child must have been set.
double
SetAndNode::weight() const
{
return LogAware::pow (follow_->weight(), nrGroundings_);
}
// Releases the three child sub-circuits (deleting a null branch is a
// no-op).
IncExcNode::~IncExcNode()
{
delete plus1Branch_;
delete plus2Branch_;
delete minusBranch_;
}
// weight(plus1) + weight(plus2) - weight(minus). In the log domain the
// subtraction is carried out by going back to the linear domain.
// The three branches must have been set.
double
IncExcNode::weight() const
{
  if (Globals::logDomain) {
    const double plusSum = Util::logSum (
        plus1Branch_->weight(), plus2Branch_->weight());
    return std::log (std::exp (plusSum) - std::exp (minusBranch_->weight()));
  }
  double result = plus1Branch_->weight() + plus2Branch_->weight();
  result -= minusBranch_->weight();
  return result;
}
// The leaf owns its clause.
LeafNode::~LeafNode()
{
delete clause_;
}
double
LeafNode::weight() const
{
assert (clause_->isUnit());
2012-12-20 23:19:10 +00:00
if (clause_->posCountedLogVars().empty() == false
|| clause_->negCountedLogVars().empty() == false) {
if (SetOrNode::isSet() == false) {
// return a NaN if we have a SetOrNode
// ancester that is not set. This can only
// happen when calculating the weights
// for the edge labels in graphviz
return 0.0 / 0.0;
}
}
double weight = clause_->literals()[0].isPositive()
? lwcnf_.posWeight (clause_->literals().front().lid())
: lwcnf_.negWeight (clause_->literals().front().lid());
LogVarSet lvs = clause_->constr().logVarSet();
lvs -= clause_->ipgLogVars();
lvs -= clause_->posCountedLogVars();
lvs -= clause_->negCountedLogVars();
unsigned nrGroundings = 1;
if (lvs.empty() == false) {
nrGroundings = clause_->constr().projectedCopy (lvs).size();
}
if (clause_->posCountedLogVars().empty() == false) {
nrGroundings *= std::pow (SetOrNode::nrPositives(),
clause_->nrPosCountedLogVars());
}
if (clause_->negCountedLogVars().empty() == false) {
nrGroundings *= std::pow (SetOrNode::nrNegatives(),
clause_->nrNegCountedLogVars());
}
return LogAware::pow (weight, nrGroundings);
}
// Releases the clauses held by this node through Clause::deleteClauses.
SmoothNode::~SmoothNode()
{
Clause::deleteClauses (clauses_);
}
double
SmoothNode::weight() const
{
Clauses cs = clauses();
double totalWeight = LogAware::multIdenty();
for (size_t i = 0; i < cs.size(); i++) {
double posWeight = lwcnf_.posWeight (cs[i]->literals()[0].lid());
double negWeight = lwcnf_.negWeight (cs[i]->literals()[0].lid());
LogVarSet lvs = cs[i]->constr().logVarSet();
lvs -= cs[i]->ipgLogVars();
lvs -= cs[i]->posCountedLogVars();
lvs -= cs[i]->negCountedLogVars();
unsigned nrGroundings = 1;
if (lvs.empty() == false) {
nrGroundings = cs[i]->constr().projectedCopy (lvs).size();
}
if (cs[i]->posCountedLogVars().empty() == false) {
nrGroundings *= std::pow (SetOrNode::nrPositives(),
cs[i]->nrPosCountedLogVars());
}
if (cs[i]->negCountedLogVars().empty() == false) {
nrGroundings *= std::pow (SetOrNode::nrNegatives(),
cs[i]->nrNegCountedLogVars());
}
if (Globals::logDomain) {
totalWeight += Util::logSum (posWeight, negWeight) * nrGroundings;
} else {
totalWeight *= std::pow (posWeight + negWeight, nrGroundings);
}
}
return totalWeight;
}
// A true node contributes the multiplicative identity.
double
TrueNode::weight() const
{
return LogAware::multIdenty();
}
// Weighted model counting in compilation failed nodes gives NaN.
double
CompilationFailedNode::weight() const
{
  // quiet_NaN() instead of the constant expression 0.0 / 0.0, which
  // compilers diagnose as a division by zero
  return std::numeric_limits<double>::quiet_NaN();
}
// Compiles a private copy of the clauses of lwcnf into a circuit and,
// when compilation succeeds, smooths it. With Globals::verbosity > 1 the
// weighted model count is printed (on success) and the circuit is
// exported to "circuit.dot".
LiftedCircuit::LiftedCircuit (const LiftedWCNF* lwcnf)
    : root_(nullptr), lwcnf_(lwcnf), compilationSucceeded_(true)
{
  Clauses clauses = Clause::copyClauses (lwcnf->clauses());
  compile (&root_, clauses);
  if (compilationSucceeded_) {
    smoothCircuit (root_);
  }
  if (Globals::verbosity <= 1) {
    return;
  }
  if (compilationSucceeded_) {
    const double wmc = LogAware::exp (getWeightedModelCount());
    std::cout << "Weighted model count = " << wmc;
    std::cout << std::endl << std::endl;
  }
  std::cout << "Exporting circuit to graphviz (circuit.dot)..." ;
  std::cout << std::endl << std::endl;
  exportToGraphViz ("circuit.dot");
}
// Destroys the circuit (each node deletes its own children) and then
// releases the clauses recorded in originClausesMap_.
LiftedCircuit::~LiftedCircuit()
{
  delete root_;
  for (auto& nodeAndClauses : originClausesMap_) {
    Clause::deleteClauses (nodeAndClauses.second);
  }
}
// False when compile() had to create a CompilationFailedNode.
bool
LiftedCircuit::isCompilationSucceeded() const
{
return compilationSucceeded_;
}
// Weight of the root of the circuit. Must only be called when the
// compilation succeeded (checked with an assert).
double
LiftedCircuit::getWeightedModelCount() const
{
assert (compilationSucceeded_);
return root_->weight();
}
// Writes the whole circuit, as a graphviz digraph, to fileName. When the
// file cannot be opened an error is printed to std::cerr and nothing is
// exported.
void
LiftedCircuit::exportToGraphViz (const char* fileName)
{
  std::ofstream out (fileName);
  if (out.is_open() == false) {
    std::cerr << "Error: couldn't open file '" << fileName << "'." ;
    std::cerr << std::endl;
    return;
  }
  // header
  out << "digraph {" << std::endl;
  out << "ranksep=1" << std::endl;
  // nodes and edges, starting at the root
  exportToGraphViz (root_, out);
  // footer
  out << "}" << std::endl;
  out.close();
}
// Builds in *follow the sub-circuit for the given clauses.
//
// Base cases: no clauses (TrueNode) and a single unit clause (LeafNode,
// which takes ownership of the clause). Otherwise the compilation rules
// are tried in a fixed order and the first applicable one wins. When no
// rule applies a CompilationFailedNode is created and the compilation is
// flagged as failed.
//
// Once the compilation has failed nothing else is compiled, unless
// Globals::verbosity > 1 (so that the whole circuit can be exported).
void
LiftedCircuit::compile (
    CircuitNode** follow,
    Clauses& clauses)
{
  if (! compilationSucceeded_ && Globals::verbosity <= 1) {
    return;
  }

  if (clauses.empty()) {
    *follow = new TrueNode();
    return;
  }

  if (clauses.size() == 1 && clauses[0]->isUnit()) {
    *follow = new LeafNode (clauses[0], *lwcnf_);
    return;
  }

  // short-circuit evaluation keeps the order of the rules
  const bool ruleApplied =
         tryUnitPropagation       (follow, clauses)
      || tryIndependence          (follow, clauses)
      || tryShannonDecomp         (follow, clauses)
      || tryInclusionExclusion    (follow, clauses)
      || tryIndepPartialGrounding (follow, clauses)
      || tryAtomCounting          (follow, clauses);
  if (ruleApplied) {
    return;
  }

  *follow = new CompilationFailedNode();
  if (Globals::verbosity > 1) {
    originClausesMap_[*follow] = clauses;
    explanationMap_[*follow]   = "" ;
  }
  compilationSucceeded_ = false;
}
// Unit propagation on the first unit clause found.
//
// Every other clause that contains the unit literal with the same sign
// is deleted; from the remaining ones the literals with the opposite
// sign are removed, and clauses left without literals are deleted too.
// The result is an AndNode whose left branch is the unit clause and
// whose right branch is compiled from the surviving clauses.
//
// Returns false, leaving the clauses untouched, when there is no unit
// clause.
bool
LiftedCircuit::tryUnitPropagation (
    CircuitNode** follow,
    Clauses& clauses)
{
  if (Globals::verbosity > 1) {
    backupClauses_ = Clause::copyClauses (clauses);
  }
  for (size_t i = 0; i < clauses.size(); i++) {
    if (clauses[i]->isUnit() == false) {
      continue;
    }
    Clause* unit = clauses[i];
    Clauses propagClauses;
    for (size_t j = 0; j < clauses.size(); j++) {
      if (j == i) {
        continue;
      }
      Clause* other = clauses[j];
      LiteralId lid = unit->literals()[0].lid();
      LogVarTypes types = unit->logVarTypes (0);
      // does the unit literal already make the other clause true?
      bool satisfied = false;
      if (unit->literals()[0].isPositive()) {
        satisfied = other->containsPositiveLiteral (lid, types);
        if (satisfied == false) {
          other->removeNegativeLiterals (lid, types);
        }
      } else if (unit->literals()[0].isNegative()) {
        satisfied = other->containsNegativeLiteral (lid, types);
        if (satisfied == false) {
          other->removePositiveLiterals (lid, types);
        }
      } else {
        continue;
      }
      if (satisfied == false && other->nrLiterals() > 0) {
        propagClauses.push_back (other);
      } else {
        delete other;
      }
    }

    AndNode* andNode = new AndNode();
    if (Globals::verbosity > 1) {
      originClausesMap_[andNode] = backupClauses_;
      std::stringstream explanation;
      explanation << " UP on " << unit->literals()[0];
      explanationMap_[andNode] = explanation.str();
    }

    Clauses unitClause = { unit };
    compile (andNode->leftBranch(), unitClause);
    compile (andNode->rightBranch(), propagClauses);
    (*follow) = andNode;
    return true;
  }
  if (Globals::verbosity > 1) {
    Clause::deleteClauses (backupClauses_);
  }
  return false;
}
// Splits the clauses in two groups that are independent of each other:
// the first clause plus everything that (transitively) depends on it,
// and the rest. When the second group is not empty an AndNode is
// created with one branch compiled from each group.
//
// Returns false when there is a single clause or when every clause ends
// up in the dependent group.
bool
LiftedCircuit::tryIndependence (
    CircuitNode** follow,
    Clauses& clauses)
{
  if (clauses.size() == 1) {
    return false;
  }
  if (Globals::verbosity > 1) {
    backupClauses_ = Clause::copyClauses (clauses);
  }
  Clauses depClauses = { clauses[0] };
  Clauses indepClauses (clauses.begin() + 1, clauses.end());
  // move clauses to the dependent group until a fixed point is reached;
  // the scan restarts after every move
  bool movedSome;
  do {
    movedSome = false;
    for (size_t i = 0; i < indepClauses.size(); i++) {
      if (independentClause (*indepClauses[i], depClauses)) {
        continue;
      }
      depClauses.push_back (indepClauses[i]);
      indepClauses.erase (indepClauses.begin() + i);
      movedSome = true;
      break;
    }
  } while (movedSome);

  if (indepClauses.empty()) {
    if (Globals::verbosity > 1) {
      Clause::deleteClauses (backupClauses_);
    }
    return false;
  }

  AndNode* andNode = new AndNode ();
  if (Globals::verbosity > 1) {
    originClausesMap_[andNode] = backupClauses_;
    explanationMap_[andNode]   = " Independence" ;
  }
  compile (andNode->leftBranch(), depClauses);
  compile (andNode->rightBranch(), indepClauses);
  (*follow) = andNode;
  return true;
}
// Shannon decomposition on the first ground literal found.
//
// Creates an OrNode whose left branch is compiled from the clauses plus
// a unit clause for the literal, and whose right branch is compiled from
// a copy of the clauses plus the complemented unit clause.
//
// Returns false when no literal is ground.
bool
LiftedCircuit::tryShannonDecomp (
    CircuitNode** follow,
    Clauses& clauses)
{
  if (Globals::verbosity > 1) {
    backupClauses_ = Clause::copyClauses (clauses);
  }
  for (size_t i = 0; i < clauses.size(); i++) {
    const Literals& literals = clauses[i]->literals();
    for (size_t j = 0; j < literals.size(); j++) {
      const bool ground = literals[j].isGround (
          clauses[i]->constr(), clauses[i]->ipgLogVars());
      if (ground == false) {
        continue;
      }
      Clause* posClause = lwcnf_->createClause (literals[j].lid());
      Clause* negClause = new Clause (*posClause);
      negClause->literals().front().complement();
      // the copy must be taken before extending the original clauses
      Clauses otherClauses = Clause::copyClauses (clauses);
      clauses.push_back (posClause);
      otherClauses.push_back (negClause);

      OrNode* orNode = new OrNode();
      if (Globals::verbosity > 1) {
        originClausesMap_[orNode] = backupClauses_;
        std::stringstream explanation;
        explanation << " SD on " << literals[j];
        explanationMap_[orNode] = explanation.str();
      }

      compile (orNode->leftBranch(), clauses);
      compile (orNode->rightBranch(), otherClauses);
      (*follow) = orNode;
      return true;
    }
  }
  if (Globals::verbosity > 1) {
    Clause::deleteClauses (backupClauses_);
  }
  return false;
}
// Inclusion-exclusion on the first clause whose literals can be split in
// two groups that share neither predicates nor logical variables.
//
// The clause is replaced by two clauses, c1 and c2, one per group. The
// IncExcNode gets: plus1 = other clauses + c1, plus2 = other clauses + c2
// and minus = other clauses + c1 + c2.
//
// NOTE(review): when the constraint of a candidate clause is not count
// normalized the search stops altogether (break) instead of moving on to
// the next clause; confirm whether this is intended.
bool
LiftedCircuit::tryInclusionExclusion (
    CircuitNode** follow,
    Clauses& clauses)
{
  if (Globals::verbosity > 1) {
    backupClauses_ = Clause::copyClauses (clauses);
  }
  for (size_t i = 0; i < clauses.size(); i++) {
    const Literals& clauseLits = clauses[i]->literals();
    Literals depLits = { clauseLits.front() };
    Literals indepLits (clauseLits.begin() + 1, clauseLits.end());
    // move literals to the dependent group until a fixed point is
    // reached; the scan restarts after every move
    bool movedSome;
    do {
      movedSome = false;
      for (size_t j = 0; j < indepLits.size(); j++) {
        if (independentLiteral (indepLits[j], depLits)) {
          continue;
        }
        depLits.push_back (indepLits[j]);
        indepLits.erase (indepLits.begin() + j);
        movedSome = true;
        break;
      }
    } while (movedSome);

    if (indepLits.empty()) {
      continue;
    }

    LogVarSet lvs1;
    for (size_t j = 0; j < depLits.size(); j++) {
      lvs1 |= depLits[j].logVarSet();
    }
    if (clauses[i]->constr().isCountNormalized (lvs1) == false) {
      break;
    }
    LogVarSet lvs2;
    for (size_t j = 0; j < indepLits.size(); j++) {
      lvs2 |= indepLits[j].logVarSet();
    }
    if (clauses[i]->constr().isCountNormalized (lvs2) == false) {
      break;
    }

    // one new clause for each group of literals
    Clause* c1 = new Clause (clauses[i]->constr().projectedCopy (lvs1));
    for (size_t j = 0; j < depLits.size(); j++) {
      c1->addLiteral (depLits[j]);
    }
    Clause* c2 = new Clause (clauses[i]->constr().projectedCopy (lvs2));
    for (size_t j = 0; j < indepLits.size(); j++) {
      c2->addLiteral (indepLits[j]);
    }

    clauses.erase (clauses.begin() + i);
    Clauses plus1Clauses = Clause::copyClauses (clauses);
    Clauses plus2Clauses = Clause::copyClauses (clauses);
    plus1Clauses.push_back (c1);
    plus2Clauses.push_back (c2);
    clauses.push_back (c1);
    clauses.push_back (c2);

    IncExcNode* ieNode = new IncExcNode();
    if (Globals::verbosity > 1) {
      originClausesMap_[ieNode] = backupClauses_;
      std::stringstream explanation;
      explanation << " IncExc on clause nº " << i + 1;
      explanationMap_[ieNode] = explanation.str();
    }
    compile (ieNode->plus1Branch(), plus1Clauses);
    compile (ieNode->plus2Branch(), plus2Clauses);
    compile (ieNode->minusBranch(), clauses);
    *follow = ieNode;
    return true;
  }
  if (Globals::verbosity > 1) {
    Clause::deleteClauses (backupClauses_);
  }
  return false;
}
// Independent partial grounding (IPG).
//
// For each IPG candidate of the first clause, looks for a matching root
// logical variable in every other clause. On success each clause gets
// its root variable registered as an IPG logical variable, and a
// SetAndNode (with as many groundings as the size of the projected
// constraint) is created on top of the compiled clauses.
//
// Assumes that all literals have logical variables, else Shannon
// decomposition would have been possible. The clauses must not be empty.
bool
LiftedCircuit::tryIndepPartialGrounding (
    CircuitNode** follow,
    Clauses& clauses)
{
  if (Globals::verbosity > 1) {
    backupClauses_ = Clause::copyClauses (clauses);
  }
  LogVars rootLogVars;
  LogVarSet candidates = clauses[0]->ipgCandidates();
  for (size_t i = 0; i < candidates.size(); i++) {
    rootLogVars.clear();
    rootLogVars.push_back (candidates[i]);
    ConstraintTree ct =
        clauses[0]->constr().projectedCopy ({candidates[i]});
    if (tryIndepPartialGroundingAux (clauses, ct, rootLogVars) == false) {
      continue;
    }
    // rootLogVars holds one logical variable for each clause
    for (size_t j = 0; j < clauses.size(); j++) {
      clauses[j]->addIpgLogVar (rootLogVars[j]);
    }
    SetAndNode* setAndNode = new SetAndNode (ct.size());
    if (Globals::verbosity > 1) {
      originClausesMap_[setAndNode] = backupClauses_;
      explanationMap_[setAndNode]   = " IPG" ;
    }
    *follow = setAndNode;
    compile (setAndNode->follow(), clauses);
    return true;
  }
  if (Globals::verbosity > 1) {
    Clause::deleteClauses (backupClauses_);
  }
  return false;
}
// Helper of tryIndepPartialGrounding.
//
// On entry rootLogVars holds the candidate of the first clause and ct
// its projected constraint. For every other clause, the first IPG
// candidate whose projected constraint has the same tuples as ct is
// appended to rootLogVars. Finally it is checked that, for each
// predicate, the root logical variable always appears in the same
// argument position.
//
// Returns true when both steps succeed, in which case rootLogVars[i] is
// the root logical variable of clauses[i].
bool
LiftedCircuit::tryIndepPartialGroundingAux (
    Clauses& clauses,
    ConstraintTree& ct,
    LogVars& rootLogVars)
{
  for (size_t i = 1; i < clauses.size(); i++) {
    LogVarSet candidates = clauses[i]->ipgCandidates();
    for (size_t j = 0; j < candidates.size(); j++) {
      ConstraintTree ct2 =
          clauses[i]->constr().projectedCopy ({candidates[j]});
      if (ct.tupleSet() == ct2.tupleSet()) {
        rootLogVars.push_back (candidates[j]);
        break;
      }
    }
    // no candidate of this clause matched
    if (rootLogVars.size() != i + 1) {
      return false;
    }
  }

  // verifies if the IPG logical vars appear in the same positions
  std::unordered_map<LiteralId, size_t> positions;
  for (size_t i = 0; i < clauses.size(); i++) {
    const Literals& literals = clauses[i]->literals();
    for (size_t j = 0; j < literals.size(); j++) {
      const size_t idx = literals[j].indexOfLogVar (rootLogVars[i]);
      assert (idx != literals[j].nrLogVars());
      const std::unordered_map<LiteralId, size_t>::iterator it
          = positions.find (literals[j].lid());
      if (it == positions.end()) {
        positions[literals[j].lid()] = idx;
      } else if (it->second != idx) {
        return false;
      }
    }
  }
  return true;
}
// Atom counting on the first literal with exactly one logical variable
// that is neither an IPG nor a counted logical variable.
//
// Two unit clauses are added for the literal (one positive, with the
// variable positively counted, and one complemented, with the variable
// negatively counted), the counted logical variables are shattered, and
// a SetOrNode is created on top of the compiled clauses.
//
// Only one atom counting node is allowed per branch: returns false as
// soon as some clause already has counted logical variables.
bool
LiftedCircuit::tryAtomCounting (
    CircuitNode** follow,
    Clauses& clauses)
{
  for (size_t i = 0; i < clauses.size(); i++) {
    const bool alreadyCounted = clauses[i]->nrPosCountedLogVars() > 0
        || clauses[i]->nrNegCountedLogVars() > 0;
    if (alreadyCounted) {
      return false;
    }
  }
  if (Globals::verbosity > 1) {
    backupClauses_ = Clause::copyClauses (clauses);
  }
  for (size_t i = 0; i < clauses.size(); i++) {
    Literals literals = clauses[i]->literals();
    for (size_t j = 0; j < literals.size(); j++) {
      const bool candidate = literals[j].nrLogVars() == 1
          && ! clauses[i]->isIpgLogVar (literals[j].logVars().front())
          && ! clauses[i]->isCountedLogVar (literals[j].logVars().front());
      if (candidate == false) {
        continue;
      }
      unsigned nrGroundings = clauses[i]->constr().projectedCopy (
          literals[j].logVars()).size();
      SetOrNode* setOrNode = new SetOrNode (nrGroundings);
      if (Globals::verbosity > 1) {
        originClausesMap_[setOrNode] = backupClauses_;
        explanationMap_[setOrNode]   = " AC" ;
      }
      Clause* posClause = new Clause (
          clauses[i]->constr().projectedCopy (literals[j].logVars()));
      Clause* negClause = new Clause (
          clauses[i]->constr().projectedCopy (literals[j].logVars()));
      posClause->addLiteral (literals[j]);
      negClause->addLiteralComplemented (literals[j]);
      posClause->addPosCountedLogVar (literals[j].logVars().front());
      negClause->addNegCountedLogVar (literals[j].logVars().front());
      clauses.push_back (posClause);
      clauses.push_back (negClause);
      shatterCountedLogVars (clauses);
      compile (setOrNode->follow(), clauses);
      *follow = setOrNode;
      return true;
    }
  }
  if (Globals::verbosity > 1) {
    Clause::deleteClauses (backupClauses_);
  }
  return false;
}
// Keeps splitting clauses until no pair of clauses needs to be split.
void
LiftedCircuit::shatterCountedLogVars (Clauses& clauses)
{
  bool splitedSome;
  do {
    splitedSome = shatterCountedLogVarsAux (clauses);
  } while (splitedSome);
}
// Tries to split on every pair of different clauses, stopping at the
// first split performed. Returns true when some clause was split (the
// container has grown, so the caller must scan again), false otherwise.
bool
LiftedCircuit::shatterCountedLogVarsAux (Clauses& clauses)
{
  // "i + 1 < size" rather than "i < size - 1": the latter wraps around
  // when there are no clauses, resulting in a (nearly) endless loop
  for (size_t i = 0; i + 1 < clauses.size(); i++) {
    for (size_t j = i + 1; j < clauses.size(); j++) {
      bool splitedSome = shatterCountedLogVarsAux (clauses, i, j);
      if (splitedSome) {
        return true;
      }
    }
  }
  return false;
}
// Compares the literals of clauses[idx1] and clauses[idx2] that have the
// same predicate. When a logical variable is counted in one of the
// clauses but the one in the same position is not counted in the other,
// the latter clause is split: the original gets the variable positively
// counted and a copy, appended to the clauses, gets it negatively
// counted.
//
// Returns true after the first split, false when nothing was split.
bool
LiftedCircuit::shatterCountedLogVarsAux (
    Clauses& clauses,
    size_t idx1,
    size_t idx2)
{
  // work on copies of the literals, as the clauses may be modified below
  Literals lits1 = clauses[idx1]->literals();
  Literals lits2 = clauses[idx2]->literals();
  for (size_t i = 0; i < lits1.size(); i++) {
    for (size_t j = 0; j < lits2.size(); j++) {
      if (lits1[i].lid() != lits2[j].lid()) {
        continue;
      }
      LogVars lvs1 = lits1[i].logVars();
      LogVars lvs2 = lits2[j].logVars();
      for (size_t k = 0; k < lvs1.size(); k++) {
        // counted in the first clause only: split the second one
        if (clauses[idx1]->isCountedLogVar (lvs1[k])
            && clauses[idx2]->isCountedLogVar (lvs2[k]) == false) {
          clauses.push_back (new Clause (*clauses[idx2]));
          clauses[idx2]->addPosCountedLogVar (lvs2[k]);
          clauses.back()->addNegCountedLogVar (lvs2[k]);
          return true;
        }
        // counted in the second clause only: split the first one
        if (clauses[idx2]->isCountedLogVar (lvs2[k])
            && clauses[idx1]->isCountedLogVar (lvs1[k]) == false) {
          clauses.push_back (new Clause (*clauses[idx1]));
          clauses[idx1]->addPosCountedLogVar (lvs1[k]);
          clauses.back()->addNegCountedLogVar (lvs1[k]);
          return true;
        }
      }
    }
  }
  return false;
}
// True when the clause is independent (as decided by
// Clause::independentClauses) of every clause in otherClauses.
// Trivially true when otherClauses is empty.
bool
LiftedCircuit::independentClause (
    Clause& clause,
    Clauses& otherClauses) const
{
  size_t idx = 0;
  while (idx < otherClauses.size()) {
    if (! Clause::independentClauses (clause, *otherClauses[idx])) {
      return false;
    }
    idx ++;
  }
  return true;
}
bool
LiftedCircuit::independentLiteral (
const Literal& lit,
const Literals& otherLits) const
{
for (size_t i = 0; i < otherLits.size(); i++) {
if (lit.lid() == otherLits[i].lid()
|| (lit.logVarSet() & otherLits[i].logVarSet()).empty() == false) {
return false;
}
}
return true;
}
// Smooths the sub-circuit rooted at node and returns the set of
// (literal, logical variable types) pairs that it mentions.
//
// - or / inc-exc nodes: the literals found in only one of the (plus)
//   branches are added to the other one through a smooth node. The
//   minus branch of an inc-exc node is not visited.
// - set-or nodes: a smooth node is added for every combination of
//   positive/negative types that is not covered below the node. The
//   returned literals have all their logical variables set to full.
// - and / set-and nodes just propagate what their children report.
// - leaves report the literal of their clause.
// - the other kinds of nodes report nothing.
LitLvTypesSet
LiftedCircuit::smoothCircuit (CircuitNode* node)
{
  assert (node);
  LitLvTypesSet propagLits;

  switch (getCircuitNodeType (node)) {

    case CircuitNodeType::orCnt: {
      OrNode* orNode = dynamic_cast<OrNode*>(node);
      LitLvTypesSet leftLits  = smoothCircuit (*orNode->leftBranch());
      LitLvTypesSet rightLits = smoothCircuit (*orNode->rightBranch());
      LitLvTypesSet missingLeft  = rightLits - leftLits;
      LitLvTypesSet missingRight = leftLits - rightLits;
      createSmoothNode (missingLeft,  orNode->leftBranch());
      createSmoothNode (missingRight, orNode->rightBranch());
      propagLits |= leftLits;
      propagLits |= rightLits;
      break;
    }

    case CircuitNodeType::andCnt: {
      AndNode* andNode = dynamic_cast<AndNode*>(node);
      LitLvTypesSet leftLits  = smoothCircuit (*andNode->leftBranch());
      LitLvTypesSet rightLits = smoothCircuit (*andNode->rightBranch());
      propagLits |= leftLits;
      propagLits |= rightLits;
      break;
    }

    case CircuitNodeType::setOrCnt: {
      SetOrNode* setOrNode = dynamic_cast<SetOrNode*>(node);
      propagLits = smoothCircuit (*setOrNode->follow());
      // the different (literal id, number of log vars) pairs found
      TinySet<std::pair<LiteralId,unsigned>> litSet;
      for (size_t i = 0; i < propagLits.size(); i++) {
        litSet.insert (std::make_pair (propagLits[i].lid(),
            propagLits[i].logVarTypes().size()));
      }
      LitLvTypesSet missingLids;
      for (size_t i = 0; i < litSet.size(); i++) {
        std::vector<LogVarTypes> allTypes
            = getAllPossibleTypes (litSet[i].second);
        for (size_t j = 0; j < allTypes.size(); j++) {
          // is this combination of types covered by some literal?
          bool typeFound = false;
          for (size_t k = 0; k < propagLits.size() && ! typeFound; k++) {
            typeFound = litSet[i].first == propagLits[k].lid()
                && containsTypes (propagLits[k].logVarTypes(), allTypes[j]);
          }
          if (! typeFound) {
            missingLids.insert (LitLvTypes (litSet[i].first, allTypes[j]));
          }
        }
      }
      createSmoothNode (missingLids, setOrNode->follow());
      // setAllFullLogVars() can cause repeated elements in
      // the set. Fix this by reconstructing the set again
      LitLvTypesSet copy = propagLits;
      propagLits.clear();
      for (size_t i = 0; i < copy.size(); i++) {
        copy[i].setAllFullLogVars();
        propagLits.insert (copy[i]);
      }
      break;
    }

    case CircuitNodeType::setAndCnt: {
      SetAndNode* setAndNode = dynamic_cast<SetAndNode*>(node);
      propagLits = smoothCircuit (*setAndNode->follow());
      break;
    }

    case CircuitNodeType::incExcCnt: {
      IncExcNode* ieNode = dynamic_cast<IncExcNode*>(node);
      LitLvTypesSet plus1Lits = smoothCircuit (*ieNode->plus1Branch());
      LitLvTypesSet plus2Lits = smoothCircuit (*ieNode->plus2Branch());
      LitLvTypesSet missingPlus1 = plus2Lits - plus1Lits;
      LitLvTypesSet missingPlus2 = plus1Lits - plus2Lits;
      createSmoothNode (missingPlus1, ieNode->plus1Branch());
      createSmoothNode (missingPlus2, ieNode->plus2Branch());
      propagLits |= plus1Lits;
      propagLits |= plus2Lits;
      break;
    }

    case CircuitNodeType::leafCnt: {
      LeafNode* leafNode = dynamic_cast<LeafNode*>(node);
      propagLits.insert (LitLvTypes (
          leafNode->clause()->literals()[0].lid(),
          leafNode->clause()->logVarTypes(0)));
      break;
    }

    default:
      break;
  }

  return propagLits;
}
void
LiftedCircuit::createSmoothNode (
const LitLvTypesSet& missingLits,
CircuitNode** prev)
{
if (missingLits.empty() == false) {
if (Globals::verbosity > 1) {
2013-02-07 13:37:15 +00:00
std::unordered_map<CircuitNode*, Clauses>::iterator it
= originClausesMap_.find (*prev);
if (it != originClausesMap_.end()) {
backupClauses_ = it->second;
} else {
backupClauses_ = Clause::copyClauses (
{((dynamic_cast<LeafNode*>(*prev))->clause())});
}
}
Clauses clauses;
for (size_t i = 0; i < missingLits.size(); i++) {
LiteralId lid = missingLits[i].lid();
const LogVarTypes& types = missingLits[i].logVarTypes();
Clause* c = lwcnf_->createClause (lid);
for (size_t j = 0; j < types.size(); j++) {
LogVar X = c->literals().front().logVars()[j];
if (types[j] == LogVarType::posLvt) {
c->addPosCountedLogVar (X);
} else if (types[j] == LogVarType::negLvt) {
c->addNegCountedLogVar (X);
}
}
c->addLiteralComplemented (c->literals()[0]);
clauses.push_back (c);
}
SmoothNode* smoothNode = new SmoothNode (clauses, *lwcnf_);
*prev = new AndNode (smoothNode, *prev);
if (Globals::verbosity > 1) {
originClausesMap_[*prev] = backupClauses_;
explanationMap_[*prev] = " Smoothing" ;
}
}
}
2013-02-07 13:37:15 +00:00
std::vector<LogVarTypes>
LiftedCircuit::getAllPossibleTypes (unsigned nrLogVars) const
{
2013-02-07 13:37:15 +00:00
std::vector<LogVarTypes> res;
if (nrLogVars == 0) {
// do nothing
2013-01-25 13:58:30 +00:00
} else if (nrLogVars == 1) {
res.push_back ({ LogVarType::posLvt });
res.push_back ({ LogVarType::negLvt });
} else {
Ranges ranges (nrLogVars, 2);
Indexer indexer (ranges);
while (indexer.valid()) {
LogVarTypes types;
for (size_t i = 0; i < nrLogVars; i++) {
if (indexer[i] == 0) {
types.push_back (LogVarType::posLvt);
} else {
types.push_back (LogVarType::negLvt);
}
}
res.push_back (types);
++ indexer;
}
}
return res;
}
// True when, in every position, typesA is either full or the same
// positive/negative type found in typesB. typesB is expected to have at
// least as many elements as typesA.
bool
LiftedCircuit::containsTypes (
    const LogVarTypes& typesA,
    const LogVarTypes& typesB) const
{
  for (size_t pos = 0; pos < typesA.size(); pos++) {
    const bool covered =
           typesA[pos] == LogVarType::fullLvt
        || (typesA[pos] == LogVarType::posLvt
            && typesB[pos] == LogVarType::posLvt)
        || (typesA[pos] == LogVarType::negLvt
            && typesB[pos] == LogVarType::negLvt);
    if (! covered) {
      return false;
    }
  }
  return true;
}
// Maps a node to the tag of its concrete class, by probing each class
// with dynamic_cast. A node of an unknown class (or a null pointer)
// triggers an assert; with asserts disabled orCnt is returned.
CircuitNodeType
LiftedCircuit::getCircuitNodeType (const CircuitNode* node) const
{
  if (dynamic_cast<const OrNode*>(node)) {
    return CircuitNodeType::orCnt;
  }
  if (dynamic_cast<const AndNode*>(node)) {
    return CircuitNodeType::andCnt;
  }
  if (dynamic_cast<const SetOrNode*>(node)) {
    return CircuitNodeType::setOrCnt;
  }
  if (dynamic_cast<const SetAndNode*>(node)) {
    return CircuitNodeType::setAndCnt;
  }
  if (dynamic_cast<const IncExcNode*>(node)) {
    return CircuitNodeType::incExcCnt;
  }
  if (dynamic_cast<const LeafNode*>(node)) {
    return CircuitNodeType::leafCnt;
  }
  if (dynamic_cast<const SmoothNode*>(node)) {
    return CircuitNodeType::smoothCnt;
  }
  if (dynamic_cast<const TrueNode*>(node)) {
    return CircuitNodeType::trueCnt;
  }
  if (dynamic_cast<const CompilationFailedNode*>(node)) {
    return CircuitNodeType::compilationFailedCnt;
  }
  assert (false);
  return CircuitNodeType::orCnt;
}
void
2013-02-07 13:37:15 +00:00
LiftedCircuit::exportToGraphViz (CircuitNode* node, std::ofstream& os)
{
2012-12-27 12:54:58 +00:00
assert (node);
static unsigned nrAuxNodes = 0;
2013-02-07 13:37:15 +00:00
std::stringstream ss;
ss << "n" << nrAuxNodes;
2013-02-07 13:37:15 +00:00
std::string auxNode = ss.str();
nrAuxNodes ++;
2013-02-07 13:37:15 +00:00
std::string opStyle = "shape=circle,width=0.7,margin=\"0.0,0.0\"," ;
switch (getCircuitNodeType (node)) {
case CircuitNodeType::orCnt: {
OrNode* casted = dynamic_cast<OrNode*>(node);
printClauses (casted, os);
2013-02-07 13:37:15 +00:00
os << auxNode << " [" << opStyle << "label=\"\"]" ;
os << std::endl;
os << escapeNode (node) << " -> " << auxNode;
os << " [label=\"" << getExplanationString (node) << "\"]" ;
2013-02-07 13:37:15 +00:00
os << std::endl;
os << auxNode << " -> " ;
os << escapeNode (*casted->leftBranch());
os << " [label=\" " << (*casted->leftBranch())->weight() << "\"]" ;
2013-02-07 13:37:15 +00:00
os << std::endl;
os << auxNode << " -> " ;
os << escapeNode (*casted->rightBranch());
os << " [label=\" " << (*casted->rightBranch())->weight() << "\"]" ;
2013-02-07 13:37:15 +00:00
os << std::endl;
exportToGraphViz (*casted->leftBranch(), os);
exportToGraphViz (*casted->rightBranch(), os);
break;
}
case CircuitNodeType::andCnt: {
AndNode* casted = dynamic_cast<AndNode*>(node);
printClauses (casted, os);
2013-02-07 13:37:15 +00:00
os << auxNode << " [" << opStyle << "label=\"\"]" ;
os << std::endl;
os << escapeNode (node) << " -> " << auxNode;
os << " [label=\"" << getExplanationString (node) << "\"]" ;
2013-02-07 13:37:15 +00:00
os << std::endl;
os << auxNode << " -> " ;
os << escapeNode (*casted->leftBranch());
os << " [label=\" " << (*casted->leftBranch())->weight() << "\"]" ;
2013-02-07 13:37:15 +00:00
os << std::endl;
os << auxNode << " -> " ;
2013-02-07 13:37:15 +00:00
os << escapeNode (*casted->rightBranch());
os << " [label=\" " << (*casted->rightBranch())->weight() << "\"]" ;
2013-02-07 13:37:15 +00:00
os << std::endl;
exportToGraphViz (*casted->leftBranch(), os);
exportToGraphViz (*casted->rightBranch(), os);
break;
}
case CircuitNodeType::setOrCnt: {
SetOrNode* casted = dynamic_cast<SetOrNode*>(node);
printClauses (casted, os);
2013-02-07 13:37:15 +00:00
os << auxNode << " [" << opStyle << "label=\"(X)\"]" ;
os << std::endl;
os << escapeNode (node) << " -> " << auxNode;
os << " [label=\"" << getExplanationString (node) << "\"]" ;
2013-02-07 13:37:15 +00:00
os << std::endl;
os << auxNode << " -> " ;
os << escapeNode (*casted->follow());
os << " [label=\" " << (*casted->follow())->weight() << "\"]" ;
2013-02-07 13:37:15 +00:00
os << std::endl;
exportToGraphViz (*casted->follow(), os);
break;
}
case CircuitNodeType::setAndCnt: {
SetAndNode* casted = dynamic_cast<SetAndNode*>(node);
printClauses (casted, os);
2013-02-07 13:37:15 +00:00
os << auxNode << " [" << opStyle << "label=\"∧(X)\"]" ;
os << std::endl;
os << escapeNode (node) << " -> " << auxNode;
os << " [label=\"" << getExplanationString (node) << "\"]" ;
2013-02-07 13:37:15 +00:00
os << std::endl;
os << auxNode << " -> " ;
os << escapeNode (*casted->follow());
os << " [label=\" " << (*casted->follow())->weight() << "\"]" ;
2013-02-07 13:37:15 +00:00
os << std::endl;
exportToGraphViz (*casted->follow(), os);
break;
}
case CircuitNodeType::incExcCnt: {
IncExcNode* casted = dynamic_cast<IncExcNode*>(node);
printClauses (casted, os);
os << auxNode << " [" << opStyle << "label=\"+ - +\"]" ;
2013-02-07 13:37:15 +00:00
os << std::endl;
os << escapeNode (node) << " -> " << auxNode;
os << " [label=\"" << getExplanationString (node) << "\"]" ;
2013-02-07 13:37:15 +00:00
os << std::endl;
os << auxNode << " -> " ;
os << escapeNode (*casted->plus1Branch());
os << " [label=\" " << (*casted->plus1Branch())->weight() << "\"]" ;
2013-02-07 13:37:15 +00:00
os << std::endl;
os << auxNode << " -> " ;
2013-02-07 13:37:15 +00:00
os << escapeNode (*casted->minusBranch()) << std::endl;
os << " [label=\" " << (*casted->minusBranch())->weight() << "\"]" ;
2013-02-07 13:37:15 +00:00
os << std::endl;
os << auxNode << " -> " ;
os << escapeNode (*casted->plus2Branch());
os << " [label=\" " << (*casted->plus2Branch())->weight() << "\"]" ;
2013-02-07 13:37:15 +00:00
os << std::endl;
exportToGraphViz (*casted->plus1Branch(), os);
exportToGraphViz (*casted->plus2Branch(), os);
exportToGraphViz (*casted->minusBranch(), os);
break;
}
case CircuitNodeType::leafCnt: {
printClauses (node, os, "style=filled,fillcolor=palegreen,");
break;
}
case CircuitNodeType::smoothCnt: {
printClauses (node, os, "style=filled,fillcolor=lightblue,");
break;
}
case CircuitNodeType::trueCnt: {
os << escapeNode (node);
os << " [shape=box,label=\"\"]" ;
2013-02-07 13:37:15 +00:00
os << std::endl;
break;
}
case CircuitNodeType::compilationFailedCnt: {
printClauses (node, os, "style=filled,fillcolor=salmon,");
break;
}
default:
assert (false);
}
}
2013-02-07 13:37:15 +00:00
// Builds the identifier used for `node` in the exported graph: the node's
// address, as printed by the stream's pointer inserter, wrapped in double
// quotes.
std::string
LiftedCircuit::escapeNode (const CircuitNode* node) const
{
  std::ostringstream quoted;
  quoted << '"' << static_cast<const void*>(node) << '"' ;
  return quoted.str();
}
2013-02-07 13:37:15 +00:00
// Returns the explanation text recorded for `node` in explanationMap_,
// or an empty string when the node has no entry. The map is not modified.
std::string
LiftedCircuit::getExplanationString (CircuitNode* node)
{
  auto entry = explanationMap_.find (node);
  if (entry == explanationMap_.end()) {
    return "" ;
  }
  return entry->second;
}
void
LiftedCircuit::printClauses (
CircuitNode* node,
2013-02-07 13:37:15 +00:00
std::ofstream& os,
std::string extraOptions)
{
Clauses clauses;
if (Util::contains (originClausesMap_, node)) {
clauses = originClausesMap_[node];
} else if (getCircuitNodeType (node) == CircuitNodeType::leafCnt) {
clauses = { (dynamic_cast<LeafNode*>(node))->clause() } ;
} else if (getCircuitNodeType (node) == CircuitNodeType::smoothCnt) {
clauses = (dynamic_cast<SmoothNode*>(node))->clauses();
}
assert (clauses.empty() == false);
os << escapeNode (node);
os << " [shape=box," << extraOptions << "label=\"" ;
for (size_t i = 0; i < clauses.size(); i++) {
if (i != 0) os << "\\n" ;
os << *clauses[i];
}
os << "\"]" ;
2013-02-07 13:37:15 +00:00
os << std::endl;
}
// Answers `query` by weighted model counting over a compiled circuit.
//
// Steps, in order:
//   1. copy the solver's parfactor list, then run shatterAgainstQuery and
//      runWeakBayesBall on the copy;
//   2. build a LiftedWCNF from it and compile a LiftedCircuit; if the
//      compilation did not succeed, print an error and exit the process;
//   3. for every query ground, take the group and range of the matching
//      argument in the first parfactor that contains it;
//   4. for every index combination produced by an Indexer over those
//      ranges, give the j-th literal of each group a positive weight of
//      LogAware::one() when j equals the index chosen for that group and
//      LogAware::zero() otherwise (negative weight is always one()), then
//      record the circuit's weighted model count;
//   5. normalize the collected values and, when Globals::logDomain is set,
//      pass them through Util::exp.
Params
LiftedKc::solveQuery (const Grounds& query)
{
  ParfactorList pfList (parfactorList);
  LiftedOperations::shatterAgainstQuery (pfList, query);
  LiftedOperations::runWeakBayesBall (pfList, query);
  LiftedWCNF lwcnf (pfList);
  LiftedCircuit circuit (&lwcnf);
  if (!circuit.isCompilationSucceeded()) {
    std::cerr << "Error: the circuit compilation has failed."
              << std::endl;
    exit (EXIT_FAILURE);
  }

  std::vector<PrvGroup> groups;
  Ranges ranges;
  for (size_t q = 0; q < query.size(); q++) {
    for (ParfactorList::const_iterator pf = pfList.begin();
        pf != pfList.end(); ++ pf) {
      const size_t argIdx = (*pf)->indexOfGround (query[q]);
      if (argIdx == (*pf)->nrArguments()) {
        continue;  // this parfactor does not contain the ground
      }
      groups.push_back ((*pf)->argument (argIdx).group());
      ranges.push_back ((*pf)->range (argIdx));
      break;
    }
  }
  assert (groups.size() == query.size());

  Params params;
  for (Indexer indexer (ranges); indexer.valid(); ++ indexer) {
    for (size_t g = 0; g < groups.size(); g++) {
      std::vector<LiteralId> litIds = lwcnf.prvGroupLiterals (groups[g]);
      for (size_t lit = 0; lit < litIds.size(); lit++) {
        lwcnf.addWeight (
            litIds[lit],
            indexer[g] == lit ? LogAware::one() : LogAware::zero(),
            LogAware::one());
      }
    }
    params.push_back (circuit.getWeightedModelCount());
  }

  LogAware::normalize (params);
  if (Globals::logDomain) {
    Util::exp (params);
  }
  return params;
}
// Prints the solver's flags on one line of standard output, in the form
// "lifted kc [log_domain=<value>]", where <value> is whatever
// Util::toString yields for Globals::logDomain.
void
LiftedKc::printSolverFlags() const
{
  std::string flags = "lifted kc [" ;
  flags += "log_domain=" ;
  flags += Util::toString (Globals::logDomain);
  flags += "]" ;
  std::cout << flags << std::endl;
}
} // namespace Horus
2013-02-07 23:53:13 +00:00