diff --git a/packages/CLPBN/horus/LiftedCircuit.cpp b/packages/CLPBN/horus/LiftedCircuit.cpp index f5d43ef8f..b3f5a653a 100644 --- a/packages/CLPBN/horus/LiftedCircuit.cpp +++ b/packages/CLPBN/horus/LiftedCircuit.cpp @@ -18,6 +18,11 @@ AndNode::weight (void) const { double lw = leftBranch_->weight(); double rw = rightBranch_->weight(); + if (Globals::logDomain) { +// cout << "andw1 = " << std::exp(lw + rw) << endl; + } else { +// cout << "andw2 = " << lw * rw << endl; + } return Globals::logDomain ? lw + rw : lw * rw; } @@ -34,18 +39,11 @@ SetOrNode::weight (void) const for (unsigned i = 0; i < nrGroundings_ + 1; i++) { nrGrsStack.push (make_pair (nrGroundings_ - i, i)); if (Globals::logDomain) { - double w = std::log (Util::nrCombinations (nrGroundings_, i)); - weightSum = Util::logSum (weightSum, w + follow_->weight()); - } else { - // cout << endl; - // cout << "nr groundings = " << nrGroundings_ << endl; - // cout << "nr positives = " << nrPositives() << endl; - // cout << "nr negatives = " << nrNegatives() << endl; - // cout << "i = " << i << endl; - // cout << "nr combos = " ; - // cout << Util::nrCombinations (nrGroundings_, i) << endl; + double nrCombs = Util::nrCombinations (nrGroundings_, i); + double w = follow_->weight(); + weightSum = Util::logSum (weightSum, std::log (nrCombs) + w); + } else { double w = follow_->weight(); - // cout << "weight = " << w << endl; weightSum += Util::nrCombinations (nrGroundings_, i) * w; } } @@ -100,20 +98,14 @@ LeafNode::weight (void) const ct.project (lvs); nrGroundings = ct.size(); } - // cout << "calc weight for " << clauses().front() << endl; if (c.posCountedLogVars().empty() == false) { - // cout << " -> nr pos = " << SetOrNode::nrPositives() << endl; nrGroundings *= std::pow (SetOrNode::nrPositives(), c.nrPosCountedLogVars()); } if (c.negCountedLogVars().empty() == false) { - //cout << " -> nr neg = " << SetOrNode::nrNegatives() << endl; nrGroundings *= std::pow (SetOrNode::nrNegatives(), c.nrNegCountedLogVars()); } - // cout << " -> nr groundings = " << nrGroundings << endl; - // cout << " -> lit weight = " << weight << endl; - // cout << " -> ret weight = " << std::pow (weight, nrGroundings) << endl; return Globals::logDomain ? weight * nrGroundings : std::pow (weight, nrGroundings); @@ -139,26 +131,19 @@ SmoothNode::weight (void) const ct.project (lvs); nrGroundings = ct.size(); } - // cout << "calc smooth weight for " << cs[i] << endl; if (cs[i].posCountedLogVars().empty() == false) { - // cout << " -> nr pos = " << SetOrNode::nrPositives() << endl; nrGroundings *= std::pow (SetOrNode::nrPositives(), cs[i].nrPosCountedLogVars()); } if (cs[i].negCountedLogVars().empty() == false) { - // cout << " -> nr neg = " << SetOrNode::nrNegatives() << endl; nrGroundings *= std::pow (SetOrNode::nrNegatives(), cs[i].nrNegCountedLogVars()); } - // cout << " -> pos+neg = " << posWeight + negWeight << endl; - // cout << " -> nrgroun = " << nrGroundings << endl; if (Globals::logDomain) { - totalWeight += (Util::logSum (posWeight, negWeight) - * std::log (nrGroundings)); + totalWeight += Util::logSum (posWeight, negWeight) * nrGroundings; } else { totalWeight *= std::pow (posWeight + negWeight, nrGroundings); } - // cout << " -> smooth weight = " << totalWeight << endl; } return totalWeight; } @@ -195,7 +180,8 @@ LiftedCircuit::LiftedCircuit (const LiftedWCNF* lwcnf) exportToGraphViz("circuit.smooth.dot"); cout << "--------------------------------------------------" << endl; cout << "--------------------------------------------------" << endl; - cout << "WEIGHTED MODEL COUNT = " << getWeightedModelCount() << endl; + double wmc = LogAware::exp (getWeightedModelCount()); + cout << "WEIGHTED MODEL COUNT = " << wmc << endl; } diff --git a/packages/CLPBN/horus/LiftedKc.cpp b/packages/CLPBN/horus/LiftedKc.cpp index ee7f56bfe..63d6c1d62 100644 --- a/packages/CLPBN/horus/LiftedKc.cpp +++ b/packages/CLPBN/horus/LiftedKc.cpp @@ -51,9 +51,11 @@ LiftedKc::solveQuery (const Grounds& query) vector litIds = lwcnf_->prvGroupLiterals (groups[i]); for (size_t j = 0; j < litIds.size(); j++) { if (indexer[i] == j) { - lwcnf_->addWeight (litIds[j], 1.0, 1.0); // TODO not log aware + lwcnf_->addWeight (litIds[j], LogAware::one(), + LogAware::one()); } else { - lwcnf_->addWeight (litIds[j], 0.0, 1.0); // TODO not log aware + lwcnf_->addWeight (litIds[j], LogAware::zero(), + LogAware::one()); } } } @@ -63,8 +65,10 @@ LiftedKc::solveQuery (const Grounds& query) params.push_back (circuit_->getWeightedModelCount()); ++ indexer; } - cout << "params: " << params << endl; LogAware::normalize (params); + if (Globals::logDomain) { + Util::exp (params); + } return params; }