cleanups, refactorings & renamings
This commit is contained in:
parent
83c1e58674
commit
07c6509a79
@ -91,8 +91,8 @@ LeafNode::weight (void) const
|
||||
: lwcnf_.negWeight (c.literals().front().lid());
|
||||
LogVarSet lvs = c.constr().logVarSet();
|
||||
lvs -= c.ipgLogVars();
|
||||
lvs -= c.positiveCountedLogVars();
|
||||
lvs -= c.negativeCountedLogVars();
|
||||
lvs -= c.posCountedLogVars();
|
||||
lvs -= c.negCountedLogVars();
|
||||
unsigned nrGroundings = 1;
|
||||
if (lvs.empty() == false) {
|
||||
ConstraintTree ct = c.constr();
|
||||
@ -100,15 +100,15 @@ LeafNode::weight (void) const
|
||||
nrGroundings = ct.size();
|
||||
}
|
||||
// cout << "calc weight for " << clauses().front() << endl;
|
||||
if (c.positiveCountedLogVars().empty() == false) {
|
||||
if (c.posCountedLogVars().empty() == false) {
|
||||
// cout << " -> nr pos = " << SetOrNode::nrPositives() << endl;
|
||||
nrGroundings *= std::pow (SetOrNode::nrPositives(),
|
||||
c.nrPositiveCountedLogVars());
|
||||
c.nrPosCountedLogVars());
|
||||
}
|
||||
if (c.negativeCountedLogVars().empty() == false) {
|
||||
if (c.negCountedLogVars().empty() == false) {
|
||||
//cout << " -> nr neg = " << SetOrNode::nrNegatives() << endl;
|
||||
nrGroundings *= std::pow (SetOrNode::nrNegatives(),
|
||||
c.nrNegativeCountedLogVars());
|
||||
c.nrNegCountedLogVars());
|
||||
}
|
||||
// cout << " -> nr groundings = " << nrGroundings << endl;
|
||||
// cout << " -> lit weight = " << weight << endl;
|
||||
@ -130,8 +130,8 @@ SmoothNode::weight (void) const
|
||||
double negWeight = lwcnf_.negWeight (cs[i].literals()[0].lid());
|
||||
LogVarSet lvs = cs[i].constr().logVarSet();
|
||||
lvs -= cs[i].ipgLogVars();
|
||||
lvs -= cs[i].positiveCountedLogVars();
|
||||
lvs -= cs[i].negativeCountedLogVars();
|
||||
lvs -= cs[i].posCountedLogVars();
|
||||
lvs -= cs[i].negCountedLogVars();
|
||||
unsigned nrGroundings = 1;
|
||||
if (lvs.empty() == false) {
|
||||
ConstraintTree ct = cs[i].constr();
|
||||
@ -139,15 +139,15 @@ SmoothNode::weight (void) const
|
||||
nrGroundings = ct.size();
|
||||
}
|
||||
// cout << "calc smooth weight for " << cs[i] << endl;
|
||||
if (cs[i].positiveCountedLogVars().empty() == false) {
|
||||
if (cs[i].posCountedLogVars().empty() == false) {
|
||||
// cout << " -> nr pos = " << SetOrNode::nrPositives() << endl;
|
||||
nrGroundings *= std::pow (SetOrNode::nrPositives(),
|
||||
cs[i].nrPositiveCountedLogVars());
|
||||
cs[i].nrPosCountedLogVars());
|
||||
}
|
||||
if (cs[i].negativeCountedLogVars().empty() == false) {
|
||||
if (cs[i].negCountedLogVars().empty() == false) {
|
||||
// cout << " -> nr neg = " << SetOrNode::nrNegatives() << endl;
|
||||
nrGroundings *= std::pow (SetOrNode::nrNegatives(),
|
||||
cs[i].nrNegativeCountedLogVars());
|
||||
cs[i].nrNegCountedLogVars());
|
||||
}
|
||||
// cout << " -> pos+neg = " << posWeight + negWeight << endl;
|
||||
// cout << " -> nrgroun = " << nrGroundings << endl;
|
||||
@ -527,8 +527,8 @@ LiftedCircuit::tryAtomCounting (
|
||||
Clauses& clauses)
|
||||
{
|
||||
for (size_t i = 0; i < clauses.size(); i++) {
|
||||
if (clauses[i].nrPositiveCountedLogVars() > 0
|
||||
|| clauses[i].nrNegativeCountedLogVars() > 0) {
|
||||
if (clauses[i].nrPosCountedLogVars() > 0
|
||||
|| clauses[i].nrNegCountedLogVars() > 0) {
|
||||
// only allow one atom counting node per branch
|
||||
return false;
|
||||
}
|
||||
@ -545,9 +545,9 @@ LiftedCircuit::tryAtomCounting (
|
||||
Clause c1 (clauses[i].constr().projectedCopy (literals[j].logVars()));
|
||||
Clause c2 (clauses[i].constr().projectedCopy (literals[j].logVars()));
|
||||
c1.addLiteral (literals[j]);
|
||||
c2.addLiteralNegated (literals[j]);
|
||||
c1.addPositiveCountedLogVar (literals[j].logVars().front());
|
||||
c2.addNegativeCountedLogVar (literals[j].logVars().front());
|
||||
c2.addLiteralComplemented (literals[j]);
|
||||
c1.addPosCountedLogVar (literals[j].logVars().front());
|
||||
c2.addNegCountedLogVar (literals[j].logVars().front());
|
||||
clauses.push_back (c1);
|
||||
clauses.push_back (c2);
|
||||
shatterCountedLogVars (clauses);
|
||||
@ -633,15 +633,15 @@ LiftedCircuit::shatterCountedLogVarsAux (
|
||||
if (clauses[idx1].isCountedLogVar (lvs1[k])
|
||||
&& clauses[idx2].isCountedLogVar (lvs2[k]) == false) {
|
||||
clauses.push_back (clauses[idx2]);
|
||||
clauses[idx2].addPositiveCountedLogVar (lvs2[k]);
|
||||
clauses.back().addNegativeCountedLogVar (lvs2[k]);
|
||||
clauses[idx2].addPosCountedLogVar (lvs2[k]);
|
||||
clauses.back().addNegCountedLogVar (lvs2[k]);
|
||||
return true;
|
||||
}
|
||||
if (clauses[idx2].isCountedLogVar (lvs2[k])
|
||||
&& clauses[idx1].isCountedLogVar (lvs1[k]) == false) {
|
||||
clauses.push_back (clauses[idx1]);
|
||||
clauses[idx1].addPositiveCountedLogVar (lvs1[k]);
|
||||
clauses.back().addNegativeCountedLogVar (lvs1[k]);
|
||||
clauses[idx1].addPosCountedLogVar (lvs1[k]);
|
||||
clauses.back().addNegCountedLogVar (lvs1[k]);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
@ -704,7 +704,7 @@ LiftedCircuit::smoothCircuit (CircuitNode* node)
|
||||
}
|
||||
}
|
||||
if (typeFound == false) {
|
||||
missingLids.insert (LiteralLvTypes (litSet[i].first, allTypes[j]));
|
||||
missingLids.insert (LitLvTypes (litSet[i].first, allTypes[j]));
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -735,7 +735,7 @@ LiftedCircuit::smoothCircuit (CircuitNode* node)
|
||||
}
|
||||
|
||||
case CircuitNodeType::LEAF_NODE: {
|
||||
propagLits.insert (LiteralLvTypes (
|
||||
propagLits.insert (LitLvTypes (
|
||||
node->clauses()[0].literals()[0].lid(),
|
||||
node->clauses()[0].logVarTypes(0)));
|
||||
}
|
||||
@ -759,16 +759,16 @@ LiftedCircuit::createSmoothNode (
|
||||
for (size_t i = 0; i < missingLits.size(); i++) {
|
||||
LiteralId lid = missingLits[i].lid();
|
||||
const LogVarTypes& types = missingLits[i].logVarTypes();
|
||||
Clause c = lwcnf_->createClauseForLiteral (lid);
|
||||
Clause c = lwcnf_->createClause (lid);
|
||||
for (size_t j = 0; j < types.size(); j++) {
|
||||
LogVar X = c.literals().front().logVars()[j];
|
||||
if (types[j] == LogVarType::POS_LV) {
|
||||
c.addPositiveCountedLogVar (X);
|
||||
c.addPosCountedLogVar (X);
|
||||
} else if (types[j] == LogVarType::NEG_LV) {
|
||||
c.addNegativeCountedLogVar (X);
|
||||
c.addNegCountedLogVar (X);
|
||||
}
|
||||
}
|
||||
c.addLiteralNegated (c.literals()[0]);
|
||||
c.addLiteralComplemented (c.literals()[0]);
|
||||
clauses.push_back (c);
|
||||
}
|
||||
SmoothNode* smoothNode = new SmoothNode (clauses, *lwcnf_);
|
||||
|
@ -73,6 +73,16 @@ std::ostream& operator<< (ostream &os, const Literal& lit)
|
||||
|
||||
|
||||
|
||||
void
|
||||
Clause::addLiteralComplemented (const Literal& lit)
|
||||
{
|
||||
assert (constr_.logVarSet().contains (lit.logVars()));
|
||||
literals_.push_back (lit);
|
||||
literals_.back().complement();
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
Clause::containsLiteral (LiteralId lid) const
|
||||
{
|
||||
@ -320,7 +330,7 @@ Clause::getLogVarSetExcluding (size_t idx) const
|
||||
|
||||
|
||||
LiftedWCNF::LiftedWCNF (const ParfactorList& pfList)
|
||||
: pfList_(pfList), freeLiteralId_(0)
|
||||
: freeLiteralId_(0), pfList_(pfList)
|
||||
{
|
||||
//addIndicatorClauses (pfList);
|
||||
//addParameterClauses (pfList);
|
||||
@ -350,12 +360,12 @@ LiftedWCNF::LiftedWCNF (const ParfactorList& pfList)
|
||||
|
||||
Clause c1 (names);
|
||||
c1.addLiteral (Literal (0, LogVars() = {0}));
|
||||
c1.addLiteralNegated (Literal (1, {0,1}));
|
||||
c1.addLiteralComplemented (Literal (1, {0,1}));
|
||||
clauses_.push_back(c1);
|
||||
|
||||
Clause c2 (names);
|
||||
c2.addLiteral (Literal (0, LogVars()={0}));
|
||||
c2.addLiteralNegated (Literal (1, {1,0}));
|
||||
c2.addLiteralComplemented (Literal (1, {1,0}));
|
||||
clauses_.push_back(c2);
|
||||
|
||||
addWeight (0, 3.0, 4.0);
|
||||
@ -384,8 +394,28 @@ LiftedWCNF::~LiftedWCNF (void)
|
||||
|
||||
|
||||
|
||||
double
|
||||
LiftedWCNF::posWeight (LiteralId lid) const
|
||||
{
|
||||
unordered_map<LiteralId, std::pair<double,double>>::const_iterator it;
|
||||
it = weights_.find (lid);
|
||||
return it != weights_.end() ? it->second.first : 1.0;
|
||||
}
|
||||
|
||||
|
||||
|
||||
double
|
||||
LiftedWCNF::negWeight (LiteralId lid) const
|
||||
{
|
||||
unordered_map<LiteralId, std::pair<double,double>>::const_iterator it;
|
||||
it = weights_.find (lid);
|
||||
return it != weights_.end() ? it->second.second : 1.0;
|
||||
}
|
||||
|
||||
|
||||
|
||||
Clause
|
||||
LiftedWCNF::createClauseForLiteral (LiteralId lid) const
|
||||
LiftedWCNF::createClause (LiteralId lid) const
|
||||
{
|
||||
for (size_t i = 0; i < clauses_.size(); i++) {
|
||||
const Literals& literals = clauses_[i].literals();
|
||||
@ -399,12 +429,25 @@ LiftedWCNF::createClauseForLiteral (LiteralId lid) const
|
||||
}
|
||||
}
|
||||
}
|
||||
// FIXME
|
||||
Clause c (ConstraintTree({}));
|
||||
c.addLiteral (Literal (lid,LogVars() = {}));
|
||||
return c;
|
||||
//assert (false);
|
||||
//return Clause (0);
|
||||
abort(); // we should not reach this point
|
||||
return Clause (ConstraintTree({}));
|
||||
}
|
||||
|
||||
|
||||
|
||||
LiteralId
|
||||
LiftedWCNF::getLiteralId (PrvGroup prvGroup, unsigned range)
|
||||
{
|
||||
assert (Util::contains (map_, prvGroup));
|
||||
return map_[prvGroup][range];
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
LiftedWCNF::addWeight (LiteralId lid, double posW, double negW)
|
||||
{
|
||||
weights_[lid] = make_pair (posW, negW);
|
||||
}
|
||||
|
||||
|
||||
@ -434,8 +477,8 @@ LiftedWCNF::addIndicatorClauses (const ParfactorList& pfList)
|
||||
ConstraintTree tempConstr2 = *(*it)->constr();
|
||||
tempConstr2.project (formulas[i].logVars());
|
||||
Clause clause2 (tempConstr2);
|
||||
clause2.addLiteralNegated (Literal (clause.literals()[j]));
|
||||
clause2.addLiteralNegated (Literal (clause.literals()[k]));
|
||||
clause2.addLiteralComplemented (Literal (clause.literals()[j]));
|
||||
clause2.addLiteralComplemented (Literal (clause.literals()[k]));
|
||||
clauses_.push_back (clause2);
|
||||
}
|
||||
}
|
||||
@ -470,12 +513,12 @@ LiftedWCNF::addParameterClauses (const ParfactorList& pfList)
|
||||
for (unsigned i = 0; i < groups.size(); i++) {
|
||||
LiteralId lid = getLiteralId (groups[i], indexer[i]);
|
||||
|
||||
clause1.addLiteralNegated (
|
||||
clause1.addLiteralComplemented (
|
||||
Literal (lid, (*it)->argument(i).logVars()));
|
||||
|
||||
ConstraintTree ct = *(*it)->constr();
|
||||
Clause tempClause (ct);
|
||||
tempClause.addLiteralNegated (Literal (
|
||||
tempClause.addLiteralComplemented (Literal (
|
||||
paramVarLid, (*it)->constr()->logVars()));
|
||||
tempClause.addLiteral (Literal (lid, (*it)->argument(i).logVars()));
|
||||
clauses_.push_back (tempClause);
|
||||
|
@ -17,9 +17,10 @@ enum LogVarType
|
||||
NEG_LV
|
||||
};
|
||||
|
||||
|
||||
typedef vector<LogVarType> LogVarTypes;
|
||||
|
||||
|
||||
|
||||
class Literal
|
||||
{
|
||||
public:
|
||||
@ -32,12 +33,12 @@ class Literal
|
||||
LiteralId lid (void) const { return lid_; }
|
||||
|
||||
LogVars logVars (void) const { return logVars_; }
|
||||
|
||||
size_t nrLogVars (void) const { return logVars_.size(); }
|
||||
|
||||
LogVarSet logVarSet (void) const { return LogVarSet (logVars_); }
|
||||
|
||||
size_t nrLogVars (void) const { return logVars_.size(); }
|
||||
|
||||
void negate (void) { negated_ = !negated_; }
|
||||
|
||||
void complement (void) { negated_ = !negated_; }
|
||||
|
||||
bool isPositive (void) const { return negated_ == false; }
|
||||
|
||||
@ -51,12 +52,12 @@ class Literal
|
||||
LogVarSet posCountedLvs = LogVarSet(),
|
||||
LogVarSet negCountedLvs = LogVarSet()) const;
|
||||
|
||||
friend std::ostream& operator<< (ostream &os, const Literal& lit);
|
||||
friend std::ostream& operator<< (std::ostream &os, const Literal& lit);
|
||||
|
||||
private:
|
||||
LiteralId lid_;
|
||||
LogVars logVars_;
|
||||
bool negated_;
|
||||
LiteralId lid_;
|
||||
LogVars logVars_;
|
||||
bool negated_;
|
||||
};
|
||||
|
||||
typedef vector<Literal> Literals;
|
||||
@ -72,13 +73,7 @@ class Clause
|
||||
|
||||
void addLiteral (const Literal& l) { literals_.push_back (l); }
|
||||
|
||||
void addLiteralNegated (const Literal& l)
|
||||
{
|
||||
literals_.push_back (l);
|
||||
literals_.back().negate();
|
||||
}
|
||||
|
||||
const vector<Literal>& literals (void) const { return literals_; }
|
||||
const Literals& literals (void) const { return literals_; }
|
||||
|
||||
const ConstraintTree& constr (void) const { return constr_; }
|
||||
|
||||
@ -90,17 +85,19 @@ class Clause
|
||||
|
||||
void addIpgLogVar (LogVar X) { ipgLogVars_.insert (X); }
|
||||
|
||||
void addPositiveCountedLogVar (LogVar X) { posCountedLvs_.insert (X); }
|
||||
void addPosCountedLogVar (LogVar X) { posCountedLvs_.insert (X); }
|
||||
|
||||
void addNegativeCountedLogVar (LogVar X) { negCountedLvs_.insert (X); }
|
||||
void addNegCountedLogVar (LogVar X) { negCountedLvs_.insert (X); }
|
||||
|
||||
LogVarSet positiveCountedLogVars (void) const { return posCountedLvs_; }
|
||||
LogVarSet posCountedLogVars (void) const { return posCountedLvs_; }
|
||||
|
||||
LogVarSet negativeCountedLogVars (void) const { return negCountedLvs_; }
|
||||
LogVarSet negCountedLogVars (void) const { return negCountedLvs_; }
|
||||
|
||||
unsigned nrPositiveCountedLogVars (void) const { return posCountedLvs_.size(); }
|
||||
unsigned nrPosCountedLogVars (void) const { return posCountedLvs_.size(); }
|
||||
|
||||
unsigned nrNegativeCountedLogVars (void) const { return negCountedLvs_.size(); }
|
||||
unsigned nrNegCountedLogVars (void) const { return negCountedLvs_.size(); }
|
||||
|
||||
void addLiteralComplemented (const Literal& lit);
|
||||
|
||||
bool containsLiteral (LiteralId lid) const;
|
||||
|
||||
@ -137,7 +134,7 @@ class Clause
|
||||
private:
|
||||
LogVarSet getLogVarSetExcluding (size_t idx) const;
|
||||
|
||||
vector<Literal> literals_;
|
||||
Literals literals_;
|
||||
LogVarSet ipgLogVars_;
|
||||
LogVarSet posCountedLvs_;
|
||||
LogVarSet negCountedLvs_;
|
||||
@ -148,14 +145,14 @@ typedef vector<Clause> Clauses;
|
||||
|
||||
|
||||
|
||||
class LiteralLvTypes
|
||||
class LitLvTypes
|
||||
{
|
||||
public:
|
||||
struct CompareLiteralLvTypes
|
||||
struct CompareLitLvTypes
|
||||
{
|
||||
bool operator() (
|
||||
const LiteralLvTypes& types1,
|
||||
const LiteralLvTypes& types2) const
|
||||
const LitLvTypes& types1,
|
||||
const LitLvTypes& types2) const
|
||||
{
|
||||
if (types1.lid_ < types2.lid_) {
|
||||
return true;
|
||||
@ -164,7 +161,7 @@ class LiteralLvTypes
|
||||
}
|
||||
};
|
||||
|
||||
LiteralLvTypes (LiteralId lid, const LogVarTypes& lvTypes) :
|
||||
LitLvTypes (LiteralId lid, const LogVarTypes& lvTypes) :
|
||||
lid_(lid), lvTypes_(lvTypes) { }
|
||||
|
||||
LiteralId lid (void) const { return lid_; }
|
||||
@ -172,14 +169,14 @@ class LiteralLvTypes
|
||||
const LogVarTypes& logVarTypes (void) const { return lvTypes_; }
|
||||
|
||||
void setAllFullLogVars (void) {
|
||||
lvTypes_ = LogVarTypes (lvTypes_.size(), LogVarType::FULL_LV); }
|
||||
std::fill (lvTypes_.begin(), lvTypes_.end(), LogVarType::FULL_LV); }
|
||||
|
||||
private:
|
||||
LiteralId lid_;
|
||||
LogVarTypes lvTypes_;
|
||||
};
|
||||
|
||||
typedef TinySet<LiteralLvTypes,LiteralLvTypes::CompareLiteralLvTypes> LitLvTypesSet;
|
||||
typedef TinySet<LitLvTypes,LitLvTypes::CompareLitLvTypes> LitLvTypesSet;
|
||||
|
||||
|
||||
|
||||
@ -192,21 +189,11 @@ class LiftedWCNF
|
||||
|
||||
const Clauses& clauses (void) const { return clauses_; }
|
||||
|
||||
double posWeight (LiteralId lid) const
|
||||
{
|
||||
unordered_map<LiteralId, std::pair<double,double>>::const_iterator it;
|
||||
it = weights_.find (lid);
|
||||
return it != weights_.end() ? it->second.first : 1.0;
|
||||
}
|
||||
double posWeight (LiteralId lid) const;
|
||||
|
||||
double negWeight (LiteralId lid) const
|
||||
{
|
||||
unordered_map<LiteralId, std::pair<double,double>>::const_iterator it;
|
||||
it = weights_.find (lid);
|
||||
return it != weights_.end() ? it->second.second : 1.0;
|
||||
}
|
||||
double negWeight (LiteralId lid) const;
|
||||
|
||||
Clause createClauseForLiteral (LiteralId lid) const;
|
||||
Clause createClause (LiteralId lid) const;
|
||||
|
||||
void printFormulaIndicators (void) const;
|
||||
|
||||
@ -216,30 +203,19 @@ class LiftedWCNF
|
||||
|
||||
private:
|
||||
|
||||
LiteralId getLiteralId (PrvGroup prvGroup, unsigned range)
|
||||
{
|
||||
assert (Util::contains (map_, prvGroup));
|
||||
return map_[prvGroup][range];
|
||||
}
|
||||
LiteralId getLiteralId (PrvGroup prvGroup, unsigned range);
|
||||
|
||||
void addWeight (LiteralId lid, double posW, double negW)
|
||||
{
|
||||
weights_[lid] = make_pair (posW, negW);
|
||||
}
|
||||
void addWeight (LiteralId lid, double posW, double negW);
|
||||
|
||||
void addIndicatorClauses (const ParfactorList& pfList);
|
||||
|
||||
void addParameterClauses (const ParfactorList& pfList);
|
||||
|
||||
Clauses clauses_;
|
||||
|
||||
unordered_map<PrvGroup, vector<LiteralId>> map_;
|
||||
|
||||
unordered_map<LiteralId, std::pair<double,double>> weights_;
|
||||
|
||||
const ParfactorList& pfList_;
|
||||
|
||||
LiteralId freeLiteralId_;
|
||||
Clauses clauses_;
|
||||
LiteralId freeLiteralId_;
|
||||
const ParfactorList& pfList_;
|
||||
unordered_map<PrvGroup, vector<LiteralId>> map_;
|
||||
unordered_map<LiteralId, std::pair<double,double>> weights_;
|
||||
};
|
||||
|
||||
#endif // HORUS_LIFTEDWCNF_H
|
||||
|
Reference in New Issue
Block a user