cleanups, refactorings & renamings

This commit is contained in:
Tiago Gomes 2012-11-07 15:28:33 +00:00
parent 83c1e58674
commit 07c6509a79
3 changed files with 121 additions and 102 deletions

View File

@ -91,8 +91,8 @@ LeafNode::weight (void) const
: lwcnf_.negWeight (c.literals().front().lid());
LogVarSet lvs = c.constr().logVarSet();
lvs -= c.ipgLogVars();
lvs -= c.positiveCountedLogVars();
lvs -= c.negativeCountedLogVars();
lvs -= c.posCountedLogVars();
lvs -= c.negCountedLogVars();
unsigned nrGroundings = 1;
if (lvs.empty() == false) {
ConstraintTree ct = c.constr();
@ -100,15 +100,15 @@ LeafNode::weight (void) const
nrGroundings = ct.size();
}
// cout << "calc weight for " << clauses().front() << endl;
if (c.positiveCountedLogVars().empty() == false) {
if (c.posCountedLogVars().empty() == false) {
// cout << " -> nr pos = " << SetOrNode::nrPositives() << endl;
nrGroundings *= std::pow (SetOrNode::nrPositives(),
c.nrPositiveCountedLogVars());
c.nrPosCountedLogVars());
}
if (c.negativeCountedLogVars().empty() == false) {
if (c.negCountedLogVars().empty() == false) {
//cout << " -> nr neg = " << SetOrNode::nrNegatives() << endl;
nrGroundings *= std::pow (SetOrNode::nrNegatives(),
c.nrNegativeCountedLogVars());
c.nrNegCountedLogVars());
}
// cout << " -> nr groundings = " << nrGroundings << endl;
// cout << " -> lit weight = " << weight << endl;
@ -130,8 +130,8 @@ SmoothNode::weight (void) const
double negWeight = lwcnf_.negWeight (cs[i].literals()[0].lid());
LogVarSet lvs = cs[i].constr().logVarSet();
lvs -= cs[i].ipgLogVars();
lvs -= cs[i].positiveCountedLogVars();
lvs -= cs[i].negativeCountedLogVars();
lvs -= cs[i].posCountedLogVars();
lvs -= cs[i].negCountedLogVars();
unsigned nrGroundings = 1;
if (lvs.empty() == false) {
ConstraintTree ct = cs[i].constr();
@ -139,15 +139,15 @@ SmoothNode::weight (void) const
nrGroundings = ct.size();
}
// cout << "calc smooth weight for " << cs[i] << endl;
if (cs[i].positiveCountedLogVars().empty() == false) {
if (cs[i].posCountedLogVars().empty() == false) {
// cout << " -> nr pos = " << SetOrNode::nrPositives() << endl;
nrGroundings *= std::pow (SetOrNode::nrPositives(),
cs[i].nrPositiveCountedLogVars());
cs[i].nrPosCountedLogVars());
}
if (cs[i].negativeCountedLogVars().empty() == false) {
if (cs[i].negCountedLogVars().empty() == false) {
// cout << " -> nr neg = " << SetOrNode::nrNegatives() << endl;
nrGroundings *= std::pow (SetOrNode::nrNegatives(),
cs[i].nrNegativeCountedLogVars());
cs[i].nrNegCountedLogVars());
}
// cout << " -> pos+neg = " << posWeight + negWeight << endl;
// cout << " -> nrgroun = " << nrGroundings << endl;
@ -527,8 +527,8 @@ LiftedCircuit::tryAtomCounting (
Clauses& clauses)
{
for (size_t i = 0; i < clauses.size(); i++) {
if (clauses[i].nrPositiveCountedLogVars() > 0
|| clauses[i].nrNegativeCountedLogVars() > 0) {
if (clauses[i].nrPosCountedLogVars() > 0
|| clauses[i].nrNegCountedLogVars() > 0) {
// only allow one atom counting node per branch
return false;
}
@ -545,9 +545,9 @@ LiftedCircuit::tryAtomCounting (
Clause c1 (clauses[i].constr().projectedCopy (literals[j].logVars()));
Clause c2 (clauses[i].constr().projectedCopy (literals[j].logVars()));
c1.addLiteral (literals[j]);
c2.addLiteralNegated (literals[j]);
c1.addPositiveCountedLogVar (literals[j].logVars().front());
c2.addNegativeCountedLogVar (literals[j].logVars().front());
c2.addLiteralComplemented (literals[j]);
c1.addPosCountedLogVar (literals[j].logVars().front());
c2.addNegCountedLogVar (literals[j].logVars().front());
clauses.push_back (c1);
clauses.push_back (c2);
shatterCountedLogVars (clauses);
@ -633,15 +633,15 @@ LiftedCircuit::shatterCountedLogVarsAux (
if (clauses[idx1].isCountedLogVar (lvs1[k])
&& clauses[idx2].isCountedLogVar (lvs2[k]) == false) {
clauses.push_back (clauses[idx2]);
clauses[idx2].addPositiveCountedLogVar (lvs2[k]);
clauses.back().addNegativeCountedLogVar (lvs2[k]);
clauses[idx2].addPosCountedLogVar (lvs2[k]);
clauses.back().addNegCountedLogVar (lvs2[k]);
return true;
}
if (clauses[idx2].isCountedLogVar (lvs2[k])
&& clauses[idx1].isCountedLogVar (lvs1[k]) == false) {
clauses.push_back (clauses[idx1]);
clauses[idx1].addPositiveCountedLogVar (lvs1[k]);
clauses.back().addNegativeCountedLogVar (lvs1[k]);
clauses[idx1].addPosCountedLogVar (lvs1[k]);
clauses.back().addNegCountedLogVar (lvs1[k]);
return true;
}
}
@ -704,7 +704,7 @@ LiftedCircuit::smoothCircuit (CircuitNode* node)
}
}
if (typeFound == false) {
missingLids.insert (LiteralLvTypes (litSet[i].first, allTypes[j]));
missingLids.insert (LitLvTypes (litSet[i].first, allTypes[j]));
}
}
}
@ -735,7 +735,7 @@ LiftedCircuit::smoothCircuit (CircuitNode* node)
}
case CircuitNodeType::LEAF_NODE: {
propagLits.insert (LiteralLvTypes (
propagLits.insert (LitLvTypes (
node->clauses()[0].literals()[0].lid(),
node->clauses()[0].logVarTypes(0)));
}
@ -759,16 +759,16 @@ LiftedCircuit::createSmoothNode (
for (size_t i = 0; i < missingLits.size(); i++) {
LiteralId lid = missingLits[i].lid();
const LogVarTypes& types = missingLits[i].logVarTypes();
Clause c = lwcnf_->createClauseForLiteral (lid);
Clause c = lwcnf_->createClause (lid);
for (size_t j = 0; j < types.size(); j++) {
LogVar X = c.literals().front().logVars()[j];
if (types[j] == LogVarType::POS_LV) {
c.addPositiveCountedLogVar (X);
c.addPosCountedLogVar (X);
} else if (types[j] == LogVarType::NEG_LV) {
c.addNegativeCountedLogVar (X);
c.addNegCountedLogVar (X);
}
}
c.addLiteralNegated (c.literals()[0]);
c.addLiteralComplemented (c.literals()[0]);
clauses.push_back (c);
}
SmoothNode* smoothNode = new SmoothNode (clauses, *lwcnf_);

View File

@ -73,6 +73,16 @@ std::ostream& operator<< (ostream &os, const Literal& lit)
void
Clause::addLiteralComplemented (const Literal& lit)
{
assert (constr_.logVarSet().contains (lit.logVars()));
literals_.push_back (lit);
literals_.back().complement();
}
bool
Clause::containsLiteral (LiteralId lid) const
{
@ -320,7 +330,7 @@ Clause::getLogVarSetExcluding (size_t idx) const
LiftedWCNF::LiftedWCNF (const ParfactorList& pfList)
: pfList_(pfList), freeLiteralId_(0)
: freeLiteralId_(0), pfList_(pfList)
{
//addIndicatorClauses (pfList);
//addParameterClauses (pfList);
@ -350,12 +360,12 @@ LiftedWCNF::LiftedWCNF (const ParfactorList& pfList)
Clause c1 (names);
c1.addLiteral (Literal (0, LogVars() = {0}));
c1.addLiteralNegated (Literal (1, {0,1}));
c1.addLiteralComplemented (Literal (1, {0,1}));
clauses_.push_back(c1);
Clause c2 (names);
c2.addLiteral (Literal (0, LogVars()={0}));
c2.addLiteralNegated (Literal (1, {1,0}));
c2.addLiteralComplemented (Literal (1, {1,0}));
clauses_.push_back(c2);
addWeight (0, 3.0, 4.0);
@ -384,8 +394,28 @@ LiftedWCNF::~LiftedWCNF (void)
double
LiftedWCNF::posWeight (LiteralId lid) const
{
unordered_map<LiteralId, std::pair<double,double>>::const_iterator it;
it = weights_.find (lid);
return it != weights_.end() ? it->second.first : 1.0;
}
double
LiftedWCNF::negWeight (LiteralId lid) const
{
unordered_map<LiteralId, std::pair<double,double>>::const_iterator it;
it = weights_.find (lid);
return it != weights_.end() ? it->second.second : 1.0;
}
Clause
LiftedWCNF::createClauseForLiteral (LiteralId lid) const
LiftedWCNF::createClause (LiteralId lid) const
{
for (size_t i = 0; i < clauses_.size(); i++) {
const Literals& literals = clauses_[i].literals();
@ -399,12 +429,25 @@ LiftedWCNF::createClauseForLiteral (LiteralId lid) const
}
}
}
// FIXME
Clause c (ConstraintTree({}));
c.addLiteral (Literal (lid,LogVars() = {}));
return c;
//assert (false);
//return Clause (0);
abort(); // we should not reach this point
return Clause (ConstraintTree({}));
}
LiteralId
LiftedWCNF::getLiteralId (PrvGroup prvGroup, unsigned range)
{
assert (Util::contains (map_, prvGroup));
return map_[prvGroup][range];
}
void
LiftedWCNF::addWeight (LiteralId lid, double posW, double negW)
{
weights_[lid] = make_pair (posW, negW);
}
@ -434,8 +477,8 @@ LiftedWCNF::addIndicatorClauses (const ParfactorList& pfList)
ConstraintTree tempConstr2 = *(*it)->constr();
tempConstr2.project (formulas[i].logVars());
Clause clause2 (tempConstr2);
clause2.addLiteralNegated (Literal (clause.literals()[j]));
clause2.addLiteralNegated (Literal (clause.literals()[k]));
clause2.addLiteralComplemented (Literal (clause.literals()[j]));
clause2.addLiteralComplemented (Literal (clause.literals()[k]));
clauses_.push_back (clause2);
}
}
@ -470,12 +513,12 @@ LiftedWCNF::addParameterClauses (const ParfactorList& pfList)
for (unsigned i = 0; i < groups.size(); i++) {
LiteralId lid = getLiteralId (groups[i], indexer[i]);
clause1.addLiteralNegated (
clause1.addLiteralComplemented (
Literal (lid, (*it)->argument(i).logVars()));
ConstraintTree ct = *(*it)->constr();
Clause tempClause (ct);
tempClause.addLiteralNegated (Literal (
tempClause.addLiteralComplemented (Literal (
paramVarLid, (*it)->constr()->logVars()));
tempClause.addLiteral (Literal (lid, (*it)->argument(i).logVars()));
clauses_.push_back (tempClause);

View File

@ -17,9 +17,10 @@ enum LogVarType
NEG_LV
};
typedef vector<LogVarType> LogVarTypes;
class Literal
{
public:
@ -33,11 +34,11 @@ class Literal
LogVars logVars (void) const { return logVars_; }
LogVarSet logVarSet (void) const { return LogVarSet (logVars_); }
size_t nrLogVars (void) const { return logVars_.size(); }
void negate (void) { negated_ = !negated_; }
LogVarSet logVarSet (void) const { return LogVarSet (logVars_); }
void complement (void) { negated_ = !negated_; }
bool isPositive (void) const { return negated_ == false; }
@ -51,12 +52,12 @@ class Literal
LogVarSet posCountedLvs = LogVarSet(),
LogVarSet negCountedLvs = LogVarSet()) const;
friend std::ostream& operator<< (ostream &os, const Literal& lit);
friend std::ostream& operator<< (std::ostream &os, const Literal& lit);
private:
LiteralId lid_;
LogVars logVars_;
bool negated_;
LiteralId lid_;
LogVars logVars_;
bool negated_;
};
typedef vector<Literal> Literals;
@ -72,13 +73,7 @@ class Clause
void addLiteral (const Literal& l) { literals_.push_back (l); }
void addLiteralNegated (const Literal& l)
{
literals_.push_back (l);
literals_.back().negate();
}
const vector<Literal>& literals (void) const { return literals_; }
const Literals& literals (void) const { return literals_; }
const ConstraintTree& constr (void) const { return constr_; }
@ -90,17 +85,19 @@ class Clause
void addIpgLogVar (LogVar X) { ipgLogVars_.insert (X); }
void addPositiveCountedLogVar (LogVar X) { posCountedLvs_.insert (X); }
void addPosCountedLogVar (LogVar X) { posCountedLvs_.insert (X); }
void addNegativeCountedLogVar (LogVar X) { negCountedLvs_.insert (X); }
void addNegCountedLogVar (LogVar X) { negCountedLvs_.insert (X); }
LogVarSet positiveCountedLogVars (void) const { return posCountedLvs_; }
LogVarSet posCountedLogVars (void) const { return posCountedLvs_; }
LogVarSet negativeCountedLogVars (void) const { return negCountedLvs_; }
LogVarSet negCountedLogVars (void) const { return negCountedLvs_; }
unsigned nrPositiveCountedLogVars (void) const { return posCountedLvs_.size(); }
unsigned nrPosCountedLogVars (void) const { return posCountedLvs_.size(); }
unsigned nrNegativeCountedLogVars (void) const { return negCountedLvs_.size(); }
unsigned nrNegCountedLogVars (void) const { return negCountedLvs_.size(); }
void addLiteralComplemented (const Literal& lit);
bool containsLiteral (LiteralId lid) const;
@ -137,7 +134,7 @@ class Clause
private:
LogVarSet getLogVarSetExcluding (size_t idx) const;
vector<Literal> literals_;
Literals literals_;
LogVarSet ipgLogVars_;
LogVarSet posCountedLvs_;
LogVarSet negCountedLvs_;
@ -148,14 +145,14 @@ typedef vector<Clause> Clauses;
class LiteralLvTypes
class LitLvTypes
{
public:
struct CompareLiteralLvTypes
struct CompareLitLvTypes
{
bool operator() (
const LiteralLvTypes& types1,
const LiteralLvTypes& types2) const
const LitLvTypes& types1,
const LitLvTypes& types2) const
{
if (types1.lid_ < types2.lid_) {
return true;
@ -164,7 +161,7 @@ class LiteralLvTypes
}
};
LiteralLvTypes (LiteralId lid, const LogVarTypes& lvTypes) :
LitLvTypes (LiteralId lid, const LogVarTypes& lvTypes) :
lid_(lid), lvTypes_(lvTypes) { }
LiteralId lid (void) const { return lid_; }
@ -172,14 +169,14 @@ class LiteralLvTypes
const LogVarTypes& logVarTypes (void) const { return lvTypes_; }
void setAllFullLogVars (void) {
lvTypes_ = LogVarTypes (lvTypes_.size(), LogVarType::FULL_LV); }
std::fill (lvTypes_.begin(), lvTypes_.end(), LogVarType::FULL_LV); }
private:
LiteralId lid_;
LogVarTypes lvTypes_;
};
typedef TinySet<LiteralLvTypes,LiteralLvTypes::CompareLiteralLvTypes> LitLvTypesSet;
typedef TinySet<LitLvTypes,LitLvTypes::CompareLitLvTypes> LitLvTypesSet;
@ -192,21 +189,11 @@ class LiftedWCNF
const Clauses& clauses (void) const { return clauses_; }
double posWeight (LiteralId lid) const
{
unordered_map<LiteralId, std::pair<double,double>>::const_iterator it;
it = weights_.find (lid);
return it != weights_.end() ? it->second.first : 1.0;
}
double posWeight (LiteralId lid) const;
double negWeight (LiteralId lid) const
{
unordered_map<LiteralId, std::pair<double,double>>::const_iterator it;
it = weights_.find (lid);
return it != weights_.end() ? it->second.second : 1.0;
}
double negWeight (LiteralId lid) const;
Clause createClauseForLiteral (LiteralId lid) const;
Clause createClause (LiteralId lid) const;
void printFormulaIndicators (void) const;
@ -216,30 +203,19 @@ class LiftedWCNF
private:
LiteralId getLiteralId (PrvGroup prvGroup, unsigned range)
{
assert (Util::contains (map_, prvGroup));
return map_[prvGroup][range];
}
LiteralId getLiteralId (PrvGroup prvGroup, unsigned range);
void addWeight (LiteralId lid, double posW, double negW)
{
weights_[lid] = make_pair (posW, negW);
}
void addWeight (LiteralId lid, double posW, double negW);
void addIndicatorClauses (const ParfactorList& pfList);
void addParameterClauses (const ParfactorList& pfList);
Clauses clauses_;
unordered_map<PrvGroup, vector<LiteralId>> map_;
unordered_map<LiteralId, std::pair<double,double>> weights_;
const ParfactorList& pfList_;
LiteralId freeLiteralId_;
Clauses clauses_;
LiteralId freeLiteralId_;
const ParfactorList& pfList_;
unordered_map<PrvGroup, vector<LiteralId>> map_;
unordered_map<LiteralId, std::pair<double,double>> weights_;
};
#endif // HORUS_LIFTEDWCNF_H