log domain calculations fixes for lifted knowledge compilation

This commit is contained in:
Tiago Gomes 2012-11-14 18:40:03 +00:00
parent 9202e286f8
commit d9e48e6290
2 changed files with 19 additions and 29 deletions

View File

@ -18,6 +18,11 @@ AndNode::weight (void) const
{
double lw = leftBranch_->weight();
double rw = rightBranch_->weight();
if (Globals::logDomain) {
// cout << "andw1 = " << std::exp(lw + rw) << endl;
} else {
// cout << "andw2 = " << lw * rw << endl;
}
return Globals::logDomain ? lw + rw : lw * rw;
}
@ -34,18 +39,11 @@ SetOrNode::weight (void) const
for (unsigned i = 0; i < nrGroundings_ + 1; i++) {
nrGrsStack.push (make_pair (nrGroundings_ - i, i));
if (Globals::logDomain) {
double w = std::log (Util::nrCombinations (nrGroundings_, i));
weightSum = Util::logSum (weightSum, w + follow_->weight());
} else {
// cout << endl;
// cout << "nr groundings = " << nrGroundings_ << endl;
// cout << "nr positives = " << nrPositives() << endl;
// cout << "nr negatives = " << nrNegatives() << endl;
// cout << "i = " << i << endl;
// cout << "nr combos = " ;
// cout << Util::nrCombinations (nrGroundings_, i) << endl;
double nrCombs = Util::nrCombinations (nrGroundings_, i);
double w = follow_->weight();
weightSum = Util::logSum (weightSum, std::log (nrCombs) + w);
} else {
double w = follow_->weight();
// cout << "weight = " << w << endl;
weightSum += Util::nrCombinations (nrGroundings_, i) * w;
}
}
@ -100,20 +98,14 @@ LeafNode::weight (void) const
ct.project (lvs);
nrGroundings = ct.size();
}
// cout << "calc weight for " << clauses().front() << endl;
if (c.posCountedLogVars().empty() == false) {
// cout << " -> nr pos = " << SetOrNode::nrPositives() << endl;
nrGroundings *= std::pow (SetOrNode::nrPositives(),
c.nrPosCountedLogVars());
}
if (c.negCountedLogVars().empty() == false) {
//cout << " -> nr neg = " << SetOrNode::nrNegatives() << endl;
nrGroundings *= std::pow (SetOrNode::nrNegatives(),
c.nrNegCountedLogVars());
}
// cout << " -> nr groundings = " << nrGroundings << endl;
// cout << " -> lit weight = " << weight << endl;
// cout << " -> ret weight = " << std::pow (weight, nrGroundings) << endl;
return Globals::logDomain
? weight * nrGroundings
: std::pow (weight, nrGroundings);
@ -139,26 +131,19 @@ SmoothNode::weight (void) const
ct.project (lvs);
nrGroundings = ct.size();
}
// cout << "calc smooth weight for " << cs[i] << endl;
if (cs[i].posCountedLogVars().empty() == false) {
// cout << " -> nr pos = " << SetOrNode::nrPositives() << endl;
nrGroundings *= std::pow (SetOrNode::nrPositives(),
cs[i].nrPosCountedLogVars());
}
if (cs[i].negCountedLogVars().empty() == false) {
// cout << " -> nr neg = " << SetOrNode::nrNegatives() << endl;
nrGroundings *= std::pow (SetOrNode::nrNegatives(),
cs[i].nrNegCountedLogVars());
}
// cout << " -> pos+neg = " << posWeight + negWeight << endl;
// cout << " -> nrgroun = " << nrGroundings << endl;
if (Globals::logDomain) {
totalWeight += (Util::logSum (posWeight, negWeight)
* std::log (nrGroundings));
totalWeight += Util::logSum (posWeight, negWeight) * nrGroundings;
} else {
totalWeight *= std::pow (posWeight + negWeight, nrGroundings);
}
// cout << " -> smooth weight = " << totalWeight << endl;
}
return totalWeight;
}
@ -195,7 +180,8 @@ LiftedCircuit::LiftedCircuit (const LiftedWCNF* lwcnf)
exportToGraphViz("circuit.smooth.dot");
cout << "--------------------------------------------------" << endl;
cout << "--------------------------------------------------" << endl;
cout << "WEIGHTED MODEL COUNT = " << getWeightedModelCount() << endl;
double wmc = LogAware::exp (getWeightedModelCount());
cout << "WEIGHTED MODEL COUNT = " << wmc << endl;
}

View File

@ -51,9 +51,11 @@ LiftedKc::solveQuery (const Grounds& query)
vector<LiteralId> litIds = lwcnf_->prvGroupLiterals (groups[i]);
for (size_t j = 0; j < litIds.size(); j++) {
if (indexer[i] == j) {
lwcnf_->addWeight (litIds[j], 1.0, 1.0); // TODO not log aware
lwcnf_->addWeight (litIds[j], LogAware::one(),
LogAware::one());
} else {
lwcnf_->addWeight (litIds[j], 0.0, 1.0); // TODO not log aware
lwcnf_->addWeight (litIds[j], LogAware::zero(),
LogAware::one());
}
}
}
@ -63,8 +65,10 @@ LiftedKc::solveQuery (const Grounds& query)
params.push_back (circuit_->getWeightedModelCount());
++ indexer;
}
cout << "params: " << params << endl;
LogAware::normalize (params);
if (Globals::logDomain) {
Util::exp (params);
}
return params;
}