log domain calculations fixes for lifted knowledge compilation
This commit is contained in:
parent
9202e286f8
commit
d9e48e6290
@ -18,6 +18,11 @@ AndNode::weight (void) const
|
||||
{
|
||||
double lw = leftBranch_->weight();
|
||||
double rw = rightBranch_->weight();
|
||||
if (Globals::logDomain) {
|
||||
// cout << "andw1 = " << std::exp(lw + rw) << endl;
|
||||
} else {
|
||||
// cout << "andw2 = " << lw * rw << endl;
|
||||
}
|
||||
return Globals::logDomain ? lw + rw : lw * rw;
|
||||
}
|
||||
|
||||
@ -34,18 +39,11 @@ SetOrNode::weight (void) const
|
||||
for (unsigned i = 0; i < nrGroundings_ + 1; i++) {
|
||||
nrGrsStack.push (make_pair (nrGroundings_ - i, i));
|
||||
if (Globals::logDomain) {
|
||||
double w = std::log (Util::nrCombinations (nrGroundings_, i));
|
||||
weightSum = Util::logSum (weightSum, w + follow_->weight());
|
||||
} else {
|
||||
// cout << endl;
|
||||
// cout << "nr groundings = " << nrGroundings_ << endl;
|
||||
// cout << "nr positives = " << nrPositives() << endl;
|
||||
// cout << "nr negatives = " << nrNegatives() << endl;
|
||||
// cout << "i = " << i << endl;
|
||||
// cout << "nr combos = " ;
|
||||
// cout << Util::nrCombinations (nrGroundings_, i) << endl;
|
||||
double nrCombs = Util::nrCombinations (nrGroundings_, i);
|
||||
double w = follow_->weight();
|
||||
weightSum = Util::logSum (weightSum, std::log (nrCombs) + w);
|
||||
} else {
|
||||
double w = follow_->weight();
|
||||
// cout << "weight = " << w << endl;
|
||||
weightSum += Util::nrCombinations (nrGroundings_, i) * w;
|
||||
}
|
||||
}
|
||||
@ -100,20 +98,14 @@ LeafNode::weight (void) const
|
||||
ct.project (lvs);
|
||||
nrGroundings = ct.size();
|
||||
}
|
||||
// cout << "calc weight for " << clauses().front() << endl;
|
||||
if (c.posCountedLogVars().empty() == false) {
|
||||
// cout << " -> nr pos = " << SetOrNode::nrPositives() << endl;
|
||||
nrGroundings *= std::pow (SetOrNode::nrPositives(),
|
||||
c.nrPosCountedLogVars());
|
||||
}
|
||||
if (c.negCountedLogVars().empty() == false) {
|
||||
//cout << " -> nr neg = " << SetOrNode::nrNegatives() << endl;
|
||||
nrGroundings *= std::pow (SetOrNode::nrNegatives(),
|
||||
c.nrNegCountedLogVars());
|
||||
}
|
||||
// cout << " -> nr groundings = " << nrGroundings << endl;
|
||||
// cout << " -> lit weight = " << weight << endl;
|
||||
// cout << " -> ret weight = " << std::pow (weight, nrGroundings) << endl;
|
||||
return Globals::logDomain
|
||||
? weight * nrGroundings
|
||||
: std::pow (weight, nrGroundings);
|
||||
@ -139,26 +131,19 @@ SmoothNode::weight (void) const
|
||||
ct.project (lvs);
|
||||
nrGroundings = ct.size();
|
||||
}
|
||||
// cout << "calc smooth weight for " << cs[i] << endl;
|
||||
if (cs[i].posCountedLogVars().empty() == false) {
|
||||
// cout << " -> nr pos = " << SetOrNode::nrPositives() << endl;
|
||||
nrGroundings *= std::pow (SetOrNode::nrPositives(),
|
||||
cs[i].nrPosCountedLogVars());
|
||||
}
|
||||
if (cs[i].negCountedLogVars().empty() == false) {
|
||||
// cout << " -> nr neg = " << SetOrNode::nrNegatives() << endl;
|
||||
nrGroundings *= std::pow (SetOrNode::nrNegatives(),
|
||||
cs[i].nrNegCountedLogVars());
|
||||
}
|
||||
// cout << " -> pos+neg = " << posWeight + negWeight << endl;
|
||||
// cout << " -> nrgroun = " << nrGroundings << endl;
|
||||
if (Globals::logDomain) {
|
||||
totalWeight += (Util::logSum (posWeight, negWeight)
|
||||
* std::log (nrGroundings));
|
||||
totalWeight += Util::logSum (posWeight, negWeight) * nrGroundings;
|
||||
} else {
|
||||
totalWeight *= std::pow (posWeight + negWeight, nrGroundings);
|
||||
}
|
||||
// cout << " -> smooth weight = " << totalWeight << endl;
|
||||
}
|
||||
return totalWeight;
|
||||
}
|
||||
@ -195,7 +180,8 @@ LiftedCircuit::LiftedCircuit (const LiftedWCNF* lwcnf)
|
||||
exportToGraphViz("circuit.smooth.dot");
|
||||
cout << "--------------------------------------------------" << endl;
|
||||
cout << "--------------------------------------------------" << endl;
|
||||
cout << "WEIGHTED MODEL COUNT = " << getWeightedModelCount() << endl;
|
||||
double wmc = LogAware::exp (getWeightedModelCount());
|
||||
cout << "WEIGHTED MODEL COUNT = " << wmc << endl;
|
||||
}
|
||||
|
||||
|
||||
|
@ -51,9 +51,11 @@ LiftedKc::solveQuery (const Grounds& query)
|
||||
vector<LiteralId> litIds = lwcnf_->prvGroupLiterals (groups[i]);
|
||||
for (size_t j = 0; j < litIds.size(); j++) {
|
||||
if (indexer[i] == j) {
|
||||
lwcnf_->addWeight (litIds[j], 1.0, 1.0); // TODO not log aware
|
||||
lwcnf_->addWeight (litIds[j], LogAware::one(),
|
||||
LogAware::one());
|
||||
} else {
|
||||
lwcnf_->addWeight (litIds[j], 0.0, 1.0); // TODO not log aware
|
||||
lwcnf_->addWeight (litIds[j], LogAware::zero(),
|
||||
LogAware::one());
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -63,8 +65,10 @@ LiftedKc::solveQuery (const Grounds& query)
|
||||
params.push_back (circuit_->getWeightedModelCount());
|
||||
++ indexer;
|
||||
}
|
||||
cout << "params: " << params << endl;
|
||||
LogAware::normalize (params);
|
||||
if (Globals::logDomain) {
|
||||
Util::exp (params);
|
||||
}
|
||||
return params;
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user