log domain calculations fixes for lifted knowledge compilation

This commit is contained in:
Tiago Gomes 2012-11-14 18:40:03 +00:00
parent 9202e286f8
commit d9e48e6290
2 changed files with 19 additions and 29 deletions

View File

@ -18,6 +18,11 @@ AndNode::weight (void) const
{ {
double lw = leftBranch_->weight(); double lw = leftBranch_->weight();
double rw = rightBranch_->weight(); double rw = rightBranch_->weight();
if (Globals::logDomain) {
// cout << "andw1 = " << std::exp(lw + rw) << endl;
} else {
// cout << "andw2 = " << lw * rw << endl;
}
return Globals::logDomain ? lw + rw : lw * rw; return Globals::logDomain ? lw + rw : lw * rw;
} }
@ -34,18 +39,11 @@ SetOrNode::weight (void) const
for (unsigned i = 0; i < nrGroundings_ + 1; i++) { for (unsigned i = 0; i < nrGroundings_ + 1; i++) {
nrGrsStack.push (make_pair (nrGroundings_ - i, i)); nrGrsStack.push (make_pair (nrGroundings_ - i, i));
if (Globals::logDomain) { if (Globals::logDomain) {
double w = std::log (Util::nrCombinations (nrGroundings_, i)); double nrCombs = Util::nrCombinations (nrGroundings_, i);
weightSum = Util::logSum (weightSum, w + follow_->weight()); double w = follow_->weight();
} else { weightSum = Util::logSum (weightSum, std::log (nrCombs) + w);
// cout << endl; } else {
// cout << "nr groundings = " << nrGroundings_ << endl;
// cout << "nr positives = " << nrPositives() << endl;
// cout << "nr negatives = " << nrNegatives() << endl;
// cout << "i = " << i << endl;
// cout << "nr combos = " ;
// cout << Util::nrCombinations (nrGroundings_, i) << endl;
double w = follow_->weight(); double w = follow_->weight();
// cout << "weight = " << w << endl;
weightSum += Util::nrCombinations (nrGroundings_, i) * w; weightSum += Util::nrCombinations (nrGroundings_, i) * w;
} }
} }
@ -100,20 +98,14 @@ LeafNode::weight (void) const
ct.project (lvs); ct.project (lvs);
nrGroundings = ct.size(); nrGroundings = ct.size();
} }
// cout << "calc weight for " << clauses().front() << endl;
if (c.posCountedLogVars().empty() == false) { if (c.posCountedLogVars().empty() == false) {
// cout << " -> nr pos = " << SetOrNode::nrPositives() << endl;
nrGroundings *= std::pow (SetOrNode::nrPositives(), nrGroundings *= std::pow (SetOrNode::nrPositives(),
c.nrPosCountedLogVars()); c.nrPosCountedLogVars());
} }
if (c.negCountedLogVars().empty() == false) { if (c.negCountedLogVars().empty() == false) {
//cout << " -> nr neg = " << SetOrNode::nrNegatives() << endl;
nrGroundings *= std::pow (SetOrNode::nrNegatives(), nrGroundings *= std::pow (SetOrNode::nrNegatives(),
c.nrNegCountedLogVars()); c.nrNegCountedLogVars());
} }
// cout << " -> nr groundings = " << nrGroundings << endl;
// cout << " -> lit weight = " << weight << endl;
// cout << " -> ret weight = " << std::pow (weight, nrGroundings) << endl;
return Globals::logDomain return Globals::logDomain
? weight * nrGroundings ? weight * nrGroundings
: std::pow (weight, nrGroundings); : std::pow (weight, nrGroundings);
@ -139,26 +131,19 @@ SmoothNode::weight (void) const
ct.project (lvs); ct.project (lvs);
nrGroundings = ct.size(); nrGroundings = ct.size();
} }
// cout << "calc smooth weight for " << cs[i] << endl;
if (cs[i].posCountedLogVars().empty() == false) { if (cs[i].posCountedLogVars().empty() == false) {
// cout << " -> nr pos = " << SetOrNode::nrPositives() << endl;
nrGroundings *= std::pow (SetOrNode::nrPositives(), nrGroundings *= std::pow (SetOrNode::nrPositives(),
cs[i].nrPosCountedLogVars()); cs[i].nrPosCountedLogVars());
} }
if (cs[i].negCountedLogVars().empty() == false) { if (cs[i].negCountedLogVars().empty() == false) {
// cout << " -> nr neg = " << SetOrNode::nrNegatives() << endl;
nrGroundings *= std::pow (SetOrNode::nrNegatives(), nrGroundings *= std::pow (SetOrNode::nrNegatives(),
cs[i].nrNegCountedLogVars()); cs[i].nrNegCountedLogVars());
} }
// cout << " -> pos+neg = " << posWeight + negWeight << endl;
// cout << " -> nrgroun = " << nrGroundings << endl;
if (Globals::logDomain) { if (Globals::logDomain) {
totalWeight += (Util::logSum (posWeight, negWeight) totalWeight += Util::logSum (posWeight, negWeight) * nrGroundings;
* std::log (nrGroundings));
} else { } else {
totalWeight *= std::pow (posWeight + negWeight, nrGroundings); totalWeight *= std::pow (posWeight + negWeight, nrGroundings);
} }
// cout << " -> smooth weight = " << totalWeight << endl;
} }
return totalWeight; return totalWeight;
} }
@ -195,7 +180,8 @@ LiftedCircuit::LiftedCircuit (const LiftedWCNF* lwcnf)
exportToGraphViz("circuit.smooth.dot"); exportToGraphViz("circuit.smooth.dot");
cout << "--------------------------------------------------" << endl; cout << "--------------------------------------------------" << endl;
cout << "--------------------------------------------------" << endl; cout << "--------------------------------------------------" << endl;
cout << "WEIGHTED MODEL COUNT = " << getWeightedModelCount() << endl; double wmc = LogAware::exp (getWeightedModelCount());
cout << "WEIGHTED MODEL COUNT = " << wmc << endl;
} }

View File

@ -51,9 +51,11 @@ LiftedKc::solveQuery (const Grounds& query)
vector<LiteralId> litIds = lwcnf_->prvGroupLiterals (groups[i]); vector<LiteralId> litIds = lwcnf_->prvGroupLiterals (groups[i]);
for (size_t j = 0; j < litIds.size(); j++) { for (size_t j = 0; j < litIds.size(); j++) {
if (indexer[i] == j) { if (indexer[i] == j) {
lwcnf_->addWeight (litIds[j], 1.0, 1.0); // TODO not log aware lwcnf_->addWeight (litIds[j], LogAware::one(),
LogAware::one());
} else { } else {
lwcnf_->addWeight (litIds[j], 0.0, 1.0); // TODO not log aware lwcnf_->addWeight (litIds[j], LogAware::zero(),
LogAware::one());
} }
} }
} }
@ -63,8 +65,10 @@ LiftedKc::solveQuery (const Grounds& query)
params.push_back (circuit_->getWeightedModelCount()); params.push_back (circuit_->getWeightedModelCount());
++ indexer; ++ indexer;
} }
cout << "params: " << params << endl;
LogAware::normalize (params); LogAware::normalize (params);
if (Globals::logDomain) {
Util::exp (params);
}
return params; return params;
} }