log domain calculations fixes for lifted knowledge compilation
This commit is contained in:
parent
9202e286f8
commit
d9e48e6290
@ -18,6 +18,11 @@ AndNode::weight (void) const
|
|||||||
{
|
{
|
||||||
double lw = leftBranch_->weight();
|
double lw = leftBranch_->weight();
|
||||||
double rw = rightBranch_->weight();
|
double rw = rightBranch_->weight();
|
||||||
|
if (Globals::logDomain) {
|
||||||
|
// cout << "andw1 = " << std::exp(lw + rw) << endl;
|
||||||
|
} else {
|
||||||
|
// cout << "andw2 = " << lw * rw << endl;
|
||||||
|
}
|
||||||
return Globals::logDomain ? lw + rw : lw * rw;
|
return Globals::logDomain ? lw + rw : lw * rw;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -34,18 +39,11 @@ SetOrNode::weight (void) const
|
|||||||
for (unsigned i = 0; i < nrGroundings_ + 1; i++) {
|
for (unsigned i = 0; i < nrGroundings_ + 1; i++) {
|
||||||
nrGrsStack.push (make_pair (nrGroundings_ - i, i));
|
nrGrsStack.push (make_pair (nrGroundings_ - i, i));
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
double w = std::log (Util::nrCombinations (nrGroundings_, i));
|
double nrCombs = Util::nrCombinations (nrGroundings_, i);
|
||||||
weightSum = Util::logSum (weightSum, w + follow_->weight());
|
double w = follow_->weight();
|
||||||
} else {
|
weightSum = Util::logSum (weightSum, std::log (nrCombs) + w);
|
||||||
// cout << endl;
|
} else {
|
||||||
// cout << "nr groundings = " << nrGroundings_ << endl;
|
|
||||||
// cout << "nr positives = " << nrPositives() << endl;
|
|
||||||
// cout << "nr negatives = " << nrNegatives() << endl;
|
|
||||||
// cout << "i = " << i << endl;
|
|
||||||
// cout << "nr combos = " ;
|
|
||||||
// cout << Util::nrCombinations (nrGroundings_, i) << endl;
|
|
||||||
double w = follow_->weight();
|
double w = follow_->weight();
|
||||||
// cout << "weight = " << w << endl;
|
|
||||||
weightSum += Util::nrCombinations (nrGroundings_, i) * w;
|
weightSum += Util::nrCombinations (nrGroundings_, i) * w;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -100,20 +98,14 @@ LeafNode::weight (void) const
|
|||||||
ct.project (lvs);
|
ct.project (lvs);
|
||||||
nrGroundings = ct.size();
|
nrGroundings = ct.size();
|
||||||
}
|
}
|
||||||
// cout << "calc weight for " << clauses().front() << endl;
|
|
||||||
if (c.posCountedLogVars().empty() == false) {
|
if (c.posCountedLogVars().empty() == false) {
|
||||||
// cout << " -> nr pos = " << SetOrNode::nrPositives() << endl;
|
|
||||||
nrGroundings *= std::pow (SetOrNode::nrPositives(),
|
nrGroundings *= std::pow (SetOrNode::nrPositives(),
|
||||||
c.nrPosCountedLogVars());
|
c.nrPosCountedLogVars());
|
||||||
}
|
}
|
||||||
if (c.negCountedLogVars().empty() == false) {
|
if (c.negCountedLogVars().empty() == false) {
|
||||||
//cout << " -> nr neg = " << SetOrNode::nrNegatives() << endl;
|
|
||||||
nrGroundings *= std::pow (SetOrNode::nrNegatives(),
|
nrGroundings *= std::pow (SetOrNode::nrNegatives(),
|
||||||
c.nrNegCountedLogVars());
|
c.nrNegCountedLogVars());
|
||||||
}
|
}
|
||||||
// cout << " -> nr groundings = " << nrGroundings << endl;
|
|
||||||
// cout << " -> lit weight = " << weight << endl;
|
|
||||||
// cout << " -> ret weight = " << std::pow (weight, nrGroundings) << endl;
|
|
||||||
return Globals::logDomain
|
return Globals::logDomain
|
||||||
? weight * nrGroundings
|
? weight * nrGroundings
|
||||||
: std::pow (weight, nrGroundings);
|
: std::pow (weight, nrGroundings);
|
||||||
@ -139,26 +131,19 @@ SmoothNode::weight (void) const
|
|||||||
ct.project (lvs);
|
ct.project (lvs);
|
||||||
nrGroundings = ct.size();
|
nrGroundings = ct.size();
|
||||||
}
|
}
|
||||||
// cout << "calc smooth weight for " << cs[i] << endl;
|
|
||||||
if (cs[i].posCountedLogVars().empty() == false) {
|
if (cs[i].posCountedLogVars().empty() == false) {
|
||||||
// cout << " -> nr pos = " << SetOrNode::nrPositives() << endl;
|
|
||||||
nrGroundings *= std::pow (SetOrNode::nrPositives(),
|
nrGroundings *= std::pow (SetOrNode::nrPositives(),
|
||||||
cs[i].nrPosCountedLogVars());
|
cs[i].nrPosCountedLogVars());
|
||||||
}
|
}
|
||||||
if (cs[i].negCountedLogVars().empty() == false) {
|
if (cs[i].negCountedLogVars().empty() == false) {
|
||||||
// cout << " -> nr neg = " << SetOrNode::nrNegatives() << endl;
|
|
||||||
nrGroundings *= std::pow (SetOrNode::nrNegatives(),
|
nrGroundings *= std::pow (SetOrNode::nrNegatives(),
|
||||||
cs[i].nrNegCountedLogVars());
|
cs[i].nrNegCountedLogVars());
|
||||||
}
|
}
|
||||||
// cout << " -> pos+neg = " << posWeight + negWeight << endl;
|
|
||||||
// cout << " -> nrgroun = " << nrGroundings << endl;
|
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
totalWeight += (Util::logSum (posWeight, negWeight)
|
totalWeight += Util::logSum (posWeight, negWeight) * nrGroundings;
|
||||||
* std::log (nrGroundings));
|
|
||||||
} else {
|
} else {
|
||||||
totalWeight *= std::pow (posWeight + negWeight, nrGroundings);
|
totalWeight *= std::pow (posWeight + negWeight, nrGroundings);
|
||||||
}
|
}
|
||||||
// cout << " -> smooth weight = " << totalWeight << endl;
|
|
||||||
}
|
}
|
||||||
return totalWeight;
|
return totalWeight;
|
||||||
}
|
}
|
||||||
@ -195,7 +180,8 @@ LiftedCircuit::LiftedCircuit (const LiftedWCNF* lwcnf)
|
|||||||
exportToGraphViz("circuit.smooth.dot");
|
exportToGraphViz("circuit.smooth.dot");
|
||||||
cout << "--------------------------------------------------" << endl;
|
cout << "--------------------------------------------------" << endl;
|
||||||
cout << "--------------------------------------------------" << endl;
|
cout << "--------------------------------------------------" << endl;
|
||||||
cout << "WEIGHTED MODEL COUNT = " << getWeightedModelCount() << endl;
|
double wmc = LogAware::exp (getWeightedModelCount());
|
||||||
|
cout << "WEIGHTED MODEL COUNT = " << wmc << endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -51,9 +51,11 @@ LiftedKc::solveQuery (const Grounds& query)
|
|||||||
vector<LiteralId> litIds = lwcnf_->prvGroupLiterals (groups[i]);
|
vector<LiteralId> litIds = lwcnf_->prvGroupLiterals (groups[i]);
|
||||||
for (size_t j = 0; j < litIds.size(); j++) {
|
for (size_t j = 0; j < litIds.size(); j++) {
|
||||||
if (indexer[i] == j) {
|
if (indexer[i] == j) {
|
||||||
lwcnf_->addWeight (litIds[j], 1.0, 1.0); // TODO not log aware
|
lwcnf_->addWeight (litIds[j], LogAware::one(),
|
||||||
|
LogAware::one());
|
||||||
} else {
|
} else {
|
||||||
lwcnf_->addWeight (litIds[j], 0.0, 1.0); // TODO not log aware
|
lwcnf_->addWeight (litIds[j], LogAware::zero(),
|
||||||
|
LogAware::one());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -63,8 +65,10 @@ LiftedKc::solveQuery (const Grounds& query)
|
|||||||
params.push_back (circuit_->getWeightedModelCount());
|
params.push_back (circuit_->getWeightedModelCount());
|
||||||
++ indexer;
|
++ indexer;
|
||||||
}
|
}
|
||||||
cout << "params: " << params << endl;
|
|
||||||
LogAware::normalize (params);
|
LogAware::normalize (params);
|
||||||
|
if (Globals::logDomain) {
|
||||||
|
Util::exp (params);
|
||||||
|
}
|
||||||
return params;
|
return params;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user