From af063dcda86ee2f7461bb272bb2c31a2c9b58721 Mon Sep 17 00:00:00 2001 From: Tiago Gomes Date: Fri, 27 Apr 2012 01:18:54 +0100 Subject: [PATCH] be more precise when calculating the cost of grounding a log var in a formula --- packages/CLPBN/clpbn/bp/FoveSolver.cpp | 103 ++++++++++++++++++------ packages/CLPBN/clpbn/bp/FoveSolver.h | 2 + packages/CLPBN/clpbn/bp/LiftedUtils.h | 5 ++ packages/CLPBN/clpbn/bp/ProbFormula.cpp | 14 ++++ packages/CLPBN/clpbn/bp/ProbFormula.h | 2 + 5 files changed, 103 insertions(+), 23 deletions(-) diff --git a/packages/CLPBN/clpbn/bp/FoveSolver.cpp b/packages/CLPBN/clpbn/bp/FoveSolver.cpp index 08c1c43a2..4205e62c5 100644 --- a/packages/CLPBN/clpbn/bp/FoveSolver.cpp +++ b/packages/CLPBN/clpbn/bp/FoveSolver.cpp @@ -352,31 +352,52 @@ CountingOperator::validOp (Parfactor* g, LogVar X) double GroundOperator::getLogCost (void) { - double cost = std::log (0.0); - vector pfIters; - pfIters = getParfactorsWithGroup (pfList_, group_); - for (unsigned i = 0; i < pfIters.size(); i++) { - Parfactor* pf = *pfIters[i]; - int idx = pf->indexOfGroup (group_); - ProbFormula f = pf->argument (idx); - LogVar X = f.logVars()[lvIndex_]; - double pfCost = 0.0; - bool isCountingLv = pf->countedLogVars().contains (X); - if (isCountingLv) { - int fIdx = pf->indexOfLogVar (X); - unsigned currSize = pf->size(); - unsigned nrHists = pf->range (fIdx); - unsigned nrSymbols = pf->constr()->getConditionalCount (X); - unsigned range = pf->argument (fIdx).range(); - double power = std::log (range) * nrSymbols; - pfCost = std::log (currSize / nrHists) + power; - } else { - unsigned currSize = pf->size(); - pfCost = std::log (pf->constr()->nrSymbols (X) * currSize); + vector> affectedFormulas; + affectedFormulas = getAffectedFormulas(); + // cout << "affected formulas: " ; + // for (unsigned i = 0; i < affectedFormulas.size(); i++) { + // cout << affectedFormulas[i].first << ":" ; + // cout << affectedFormulas[i].second << " " ; + // } + // cout << "cost =" ; + double totalCost = std::log (0.0); + ParfactorList::iterator pflIt = pfList_.begin(); + while (pflIt != pfList_.end()) { + Parfactor* pf = *pflIt; + double reps = 0.0; + double pfSize = std::log (pf->size()); + bool willBeAffected = false; + LogVarSet lvsToGround; + for (unsigned i = 0; i < affectedFormulas.size(); i++) { + int fIdx = pf->indexOfGroup (affectedFormulas[i].first); + if (fIdx != -1) { + ProbFormula f = pf->argument (fIdx); + LogVar X = f.logVars()[affectedFormulas[i].second]; + bool isCountingLv = pf->countedLogVars().contains (X); + if (isCountingLv) { + unsigned nrHists = pf->range (fIdx); + unsigned nrSymbols = pf->constr()->getConditionalCount (X); + unsigned range = pf->argument (fIdx).range(); + double power = std::log (range) * nrSymbols; + pfSize = (pfSize - std::log (nrHists)) + power; + } else { + if (lvsToGround.contains (X) == false) { + reps += std::log (pf->constr()->nrSymbols (X)); + lvsToGround.insert (X); + } + } + willBeAffected = true; + } } - cost = Util::logSum (cost, pfCost); + if (willBeAffected) { + // cout << " + " << std::exp (reps) << "x" << std::exp (pfSize); + double pfCost = reps + pfSize; + totalCost = Util::logSum (totalCost, pfCost); + } + ++ pflIt; } - return cost; + // cout << endl; + return totalCost; } @@ -463,6 +484,42 @@ GroundOperator::toString (void) +vector> +GroundOperator::getAffectedFormulas (void) +{ + vector> affectedFormulas; + affectedFormulas.push_back (make_pair (group_, lvIndex_)); + queue> q; + q.push (make_pair (group_, lvIndex_)); + while (q.empty() == false) { + pair front = q.front(); + ParfactorList::iterator pflIt = pfList_.begin(); + while (pflIt != pfList_.end()) { + int idx = (*pflIt)->indexOfGroup (front.first); + if (idx != -1) { + ProbFormula f = (*pflIt)->argument (idx); + LogVar X = f.logVars()[front.second]; + const ProbFormulas& fs = (*pflIt)->arguments(); + for (unsigned i = 0; i < fs.size(); i++) { + if ((int)i != idx && fs[i].contains (X)) { + pair pair = make_pair ( + fs[i].group(), fs[i].indexOf (X)); + if (Util::contains (affectedFormulas, pair) == false) { + q.push (pair); + affectedFormulas.push_back (pair); + } + } + } + } + ++ pflIt; + } + q.pop(); + } + return affectedFormulas; +} + + + Params FoveSolver::getPosterioriOf (const Ground& query) { diff --git a/packages/CLPBN/clpbn/bp/FoveSolver.h b/packages/CLPBN/clpbn/bp/FoveSolver.h index b26d3d119..5b39aac8b 100644 --- a/packages/CLPBN/clpbn/bp/FoveSolver.h +++ b/packages/CLPBN/clpbn/bp/FoveSolver.h @@ -98,6 +98,8 @@ class GroundOperator : public LiftedOperator string toString (void); private: + vector> getAffectedFormulas (void); + unsigned group_; unsigned lvIndex_; ParfactorList& pfList_; diff --git a/packages/CLPBN/clpbn/bp/LiftedUtils.h b/packages/CLPBN/clpbn/bp/LiftedUtils.h index 7c925fbf4..d4ec22c55 100644 --- a/packages/CLPBN/clpbn/bp/LiftedUtils.h +++ b/packages/CLPBN/clpbn/bp/LiftedUtils.h @@ -141,6 +141,11 @@ class Substitution return subs_.find (X)->second; } + bool containsReplacementFor (LogVar X) const + { + return Util::contains (subs_, X); + } + LogVars getDiscardedLogVars (void) const; friend ostream& operator<< (ostream &os, const Substitution& theta); diff --git a/packages/CLPBN/clpbn/bp/ProbFormula.cpp b/packages/CLPBN/clpbn/bp/ProbFormula.cpp index 68c857285..f27221acf 100644 --- a/packages/CLPBN/clpbn/bp/ProbFormula.cpp +++ b/packages/CLPBN/clpbn/bp/ProbFormula.cpp @@ -29,6 +29,20 @@ ProbFormula::contains (LogVarSet s) const +int +ProbFormula::indexOf (LogVar X) const +{ + int pos = std::distance ( + logVars_.begin(), + std::find (logVars_.begin(), logVars_.end(), X)); + if (pos == (int)logVars_.size()) { + pos = -1; + } + return pos; +} + + + bool ProbFormula::isAtom (void) const { diff --git a/packages/CLPBN/clpbn/bp/ProbFormula.h b/packages/CLPBN/clpbn/bp/ProbFormula.h index 793183ba7..aa6c70bf9 100644 --- a/packages/CLPBN/clpbn/bp/ProbFormula.h +++ b/packages/CLPBN/clpbn/bp/ProbFormula.h @@ -41,6 +41,8 @@ class ProbFormula bool contains (LogVarSet) const; + int indexOf (LogVar) const; + bool isAtom (void) const; bool isCounting (void) const;