be more precise when calculating the cost of grounding a log var in a formula

This commit is contained in:
Tiago Gomes 2012-04-27 01:18:54 +01:00
parent 995a11be83
commit af063dcda8
5 changed files with 103 additions and 23 deletions

View File

@ -352,31 +352,52 @@ CountingOperator::validOp (Parfactor* g, LogVar X)
double
GroundOperator::getLogCost (void)
{
double cost = std::log (0.0);
vector<ParfactorList::iterator> pfIters;
pfIters = getParfactorsWithGroup (pfList_, group_);
for (unsigned i = 0; i < pfIters.size(); i++) {
Parfactor* pf = *pfIters[i];
int idx = pf->indexOfGroup (group_);
ProbFormula f = pf->argument (idx);
LogVar X = f.logVars()[lvIndex_];
double pfCost = 0.0;
bool isCountingLv = pf->countedLogVars().contains (X);
if (isCountingLv) {
int fIdx = pf->indexOfLogVar (X);
unsigned currSize = pf->size();
unsigned nrHists = pf->range (fIdx);
unsigned nrSymbols = pf->constr()->getConditionalCount (X);
unsigned range = pf->argument (fIdx).range();
double power = std::log (range) * nrSymbols;
pfCost = std::log (currSize / nrHists) + power;
} else {
unsigned currSize = pf->size();
pfCost = std::log (pf->constr()->nrSymbols (X) * currSize);
vector<pair<unsigned, unsigned>> affectedFormulas;
affectedFormulas = getAffectedFormulas();
// cout << "affected formulas: " ;
// for (unsigned i = 0; i < affectedFormulas.size(); i++) {
// cout << affectedFormulas[i].first << ":" ;
// cout << affectedFormulas[i].second << " " ;
// }
// cout << "cost =" ;
double totalCost = std::log (0.0);
ParfactorList::iterator pflIt = pfList_.begin();
while (pflIt != pfList_.end()) {
Parfactor* pf = *pflIt;
double reps = 0.0;
double pfSize = std::log (pf->size());
bool willBeAffected = false;
LogVarSet lvsToGround;
for (unsigned i = 0; i < affectedFormulas.size(); i++) {
int fIdx = pf->indexOfGroup (affectedFormulas[i].first);
if (fIdx != -1) {
ProbFormula f = pf->argument (fIdx);
LogVar X = f.logVars()[affectedFormulas[i].second];
bool isCountingLv = pf->countedLogVars().contains (X);
if (isCountingLv) {
unsigned nrHists = pf->range (fIdx);
unsigned nrSymbols = pf->constr()->getConditionalCount (X);
unsigned range = pf->argument (fIdx).range();
double power = std::log (range) * nrSymbols;
pfSize = (pfSize - std::log (nrHists)) + power;
} else {
if (lvsToGround.contains (X) == false) {
reps += std::log (pf->constr()->nrSymbols (X));
lvsToGround.insert (X);
}
}
willBeAffected = true;
}
}
cost = Util::logSum (cost, pfCost);
if (willBeAffected) {
// cout << " + " << std::exp (reps) << "x" << std::exp (pfSize);
double pfCost = reps + pfSize;
totalCost = Util::logSum (totalCost, pfCost);
}
++ pflIt;
}
return cost;
// cout << endl;
return totalCost;
}
@ -463,6 +484,42 @@ GroundOperator::toString (void)
vector<pair<unsigned, unsigned>>
GroundOperator::getAffectedFormulas (void)
{
vector<pair<unsigned, unsigned>> affectedFormulas;
affectedFormulas.push_back (make_pair (group_, lvIndex_));
queue<pair<unsigned, unsigned>> q;
q.push (make_pair (group_, lvIndex_));
while (q.empty() == false) {
pair<unsigned, unsigned> front = q.front();
ParfactorList::iterator pflIt = pfList_.begin();
while (pflIt != pfList_.end()) {
int idx = (*pflIt)->indexOfGroup (front.first);
if (idx != -1) {
ProbFormula f = (*pflIt)->argument (idx);
LogVar X = f.logVars()[front.second];
const ProbFormulas& fs = (*pflIt)->arguments();
for (unsigned i = 0; i < fs.size(); i++) {
if ((int)i != idx && fs[i].contains (X)) {
pair<unsigned, unsigned> pair = make_pair (
fs[i].group(), fs[i].indexOf (X));
if (Util::contains (affectedFormulas, pair) == false) {
q.push (pair);
affectedFormulas.push_back (pair);
}
}
}
}
++ pflIt;
}
q.pop();
}
return affectedFormulas;
}
Params
FoveSolver::getPosterioriOf (const Ground& query)
{

View File

@ -98,6 +98,8 @@ class GroundOperator : public LiftedOperator
string toString (void);
private:
vector<pair<unsigned, unsigned>> getAffectedFormulas (void);
unsigned group_;
unsigned lvIndex_;
ParfactorList& pfList_;

View File

@ -141,6 +141,11 @@ class Substitution
return subs_.find (X)->second;
}
bool containsReplacementFor (LogVar X) const
{
return Util::contains (subs_, X);
}
LogVars getDiscardedLogVars (void) const;
friend ostream& operator<< (ostream &os, const Substitution& theta);

View File

@ -29,6 +29,20 @@ ProbFormula::contains (LogVarSet s) const
int
ProbFormula::indexOf (LogVar X) const
{
int pos = std::distance (
logVars_.begin(),
std::find (logVars_.begin(), logVars_.end(), X));
if (pos == (int)logVars_.size()) {
pos = -1;
}
return pos;
}
bool
ProbFormula::isAtom (void) const
{

View File

@ -41,6 +41,8 @@ class ProbFormula
bool contains (LogVarSet) const;
int indexOf (LogVar) const;
bool isAtom (void) const;
bool isCounting (void) const;