be more precise when calculating the cost of grounding a log var in a formula

This commit is contained in:
Tiago Gomes 2012-04-27 01:18:54 +01:00
parent 995a11be83
commit af063dcda8
5 changed files with 103 additions and 23 deletions

View File

@ -352,31 +352,52 @@ CountingOperator::validOp (Parfactor* g, LogVar X)
double double
GroundOperator::getLogCost (void) GroundOperator::getLogCost (void)
{ {
double cost = std::log (0.0); vector<pair<unsigned, unsigned>> affectedFormulas;
vector<ParfactorList::iterator> pfIters; affectedFormulas = getAffectedFormulas();
pfIters = getParfactorsWithGroup (pfList_, group_); // cout << "affected formulas: " ;
for (unsigned i = 0; i < pfIters.size(); i++) { // for (unsigned i = 0; i < affectedFormulas.size(); i++) {
Parfactor* pf = *pfIters[i]; // cout << affectedFormulas[i].first << ":" ;
int idx = pf->indexOfGroup (group_); // cout << affectedFormulas[i].second << " " ;
ProbFormula f = pf->argument (idx); // }
LogVar X = f.logVars()[lvIndex_]; // cout << "cost =" ;
double pfCost = 0.0; double totalCost = std::log (0.0);
bool isCountingLv = pf->countedLogVars().contains (X); ParfactorList::iterator pflIt = pfList_.begin();
if (isCountingLv) { while (pflIt != pfList_.end()) {
int fIdx = pf->indexOfLogVar (X); Parfactor* pf = *pflIt;
unsigned currSize = pf->size(); double reps = 0.0;
unsigned nrHists = pf->range (fIdx); double pfSize = std::log (pf->size());
unsigned nrSymbols = pf->constr()->getConditionalCount (X); bool willBeAffected = false;
unsigned range = pf->argument (fIdx).range(); LogVarSet lvsToGround;
double power = std::log (range) * nrSymbols; for (unsigned i = 0; i < affectedFormulas.size(); i++) {
pfCost = std::log (currSize / nrHists) + power; int fIdx = pf->indexOfGroup (affectedFormulas[i].first);
} else { if (fIdx != -1) {
unsigned currSize = pf->size(); ProbFormula f = pf->argument (fIdx);
pfCost = std::log (pf->constr()->nrSymbols (X) * currSize); LogVar X = f.logVars()[affectedFormulas[i].second];
bool isCountingLv = pf->countedLogVars().contains (X);
if (isCountingLv) {
unsigned nrHists = pf->range (fIdx);
unsigned nrSymbols = pf->constr()->getConditionalCount (X);
unsigned range = pf->argument (fIdx).range();
double power = std::log (range) * nrSymbols;
pfSize = (pfSize - std::log (nrHists)) + power;
} else {
if (lvsToGround.contains (X) == false) {
reps += std::log (pf->constr()->nrSymbols (X));
lvsToGround.insert (X);
}
}
willBeAffected = true;
}
} }
cost = Util::logSum (cost, pfCost); if (willBeAffected) {
// cout << " + " << std::exp (reps) << "x" << std::exp (pfSize);
double pfCost = reps + pfSize;
totalCost = Util::logSum (totalCost, pfCost);
}
++ pflIt;
} }
return cost; // cout << endl;
return totalCost;
} }
@ -463,6 +484,42 @@ GroundOperator::toString (void)
vector<pair<unsigned, unsigned>>
GroundOperator::getAffectedFormulas (void)
{
vector<pair<unsigned, unsigned>> affectedFormulas;
affectedFormulas.push_back (make_pair (group_, lvIndex_));
queue<pair<unsigned, unsigned>> q;
q.push (make_pair (group_, lvIndex_));
while (q.empty() == false) {
pair<unsigned, unsigned> front = q.front();
ParfactorList::iterator pflIt = pfList_.begin();
while (pflIt != pfList_.end()) {
int idx = (*pflIt)->indexOfGroup (front.first);
if (idx != -1) {
ProbFormula f = (*pflIt)->argument (idx);
LogVar X = f.logVars()[front.second];
const ProbFormulas& fs = (*pflIt)->arguments();
for (unsigned i = 0; i < fs.size(); i++) {
if ((int)i != idx && fs[i].contains (X)) {
pair<unsigned, unsigned> pair = make_pair (
fs[i].group(), fs[i].indexOf (X));
if (Util::contains (affectedFormulas, pair) == false) {
q.push (pair);
affectedFormulas.push_back (pair);
}
}
}
}
++ pflIt;
}
q.pop();
}
return affectedFormulas;
}
Params Params
FoveSolver::getPosterioriOf (const Ground& query) FoveSolver::getPosterioriOf (const Ground& query)
{ {

View File

@ -98,6 +98,8 @@ class GroundOperator : public LiftedOperator
string toString (void); string toString (void);
private: private:
vector<pair<unsigned, unsigned>> getAffectedFormulas (void);
unsigned group_; unsigned group_;
unsigned lvIndex_; unsigned lvIndex_;
ParfactorList& pfList_; ParfactorList& pfList_;

View File

@ -141,6 +141,11 @@ class Substitution
return subs_.find (X)->second; return subs_.find (X)->second;
} }
bool containsReplacementFor (LogVar X) const
{
return Util::contains (subs_, X);
}
LogVars getDiscardedLogVars (void) const; LogVars getDiscardedLogVars (void) const;
friend ostream& operator<< (ostream &os, const Substitution& theta); friend ostream& operator<< (ostream &os, const Substitution& theta);

View File

@ -29,6 +29,20 @@ ProbFormula::contains (LogVarSet s) const
int
ProbFormula::indexOf (LogVar X) const
{
int pos = std::distance (
logVars_.begin(),
std::find (logVars_.begin(), logVars_.end(), X));
if (pos == (int)logVars_.size()) {
pos = -1;
}
return pos;
}
bool bool
ProbFormula::isAtom (void) const ProbFormula::isAtom (void) const
{ {

View File

@ -41,6 +41,8 @@ class ProbFormula
bool contains (LogVarSet) const; bool contains (LogVarSet) const;
int indexOf (LogVar) const;
bool isAtom (void) const; bool isAtom (void) const;
bool isCounting (void) const; bool isCounting (void) const;