be more precise when calculating the cost of grounding a log var in a formula
This commit is contained in:
parent
995a11be83
commit
af063dcda8
@ -352,31 +352,52 @@ CountingOperator::validOp (Parfactor* g, LogVar X)
|
||||
double
|
||||
GroundOperator::getLogCost (void)
|
||||
{
|
||||
double cost = std::log (0.0);
|
||||
vector<ParfactorList::iterator> pfIters;
|
||||
pfIters = getParfactorsWithGroup (pfList_, group_);
|
||||
for (unsigned i = 0; i < pfIters.size(); i++) {
|
||||
Parfactor* pf = *pfIters[i];
|
||||
int idx = pf->indexOfGroup (group_);
|
||||
ProbFormula f = pf->argument (idx);
|
||||
LogVar X = f.logVars()[lvIndex_];
|
||||
double pfCost = 0.0;
|
||||
bool isCountingLv = pf->countedLogVars().contains (X);
|
||||
if (isCountingLv) {
|
||||
int fIdx = pf->indexOfLogVar (X);
|
||||
unsigned currSize = pf->size();
|
||||
unsigned nrHists = pf->range (fIdx);
|
||||
unsigned nrSymbols = pf->constr()->getConditionalCount (X);
|
||||
unsigned range = pf->argument (fIdx).range();
|
||||
double power = std::log (range) * nrSymbols;
|
||||
pfCost = std::log (currSize / nrHists) + power;
|
||||
} else {
|
||||
unsigned currSize = pf->size();
|
||||
pfCost = std::log (pf->constr()->nrSymbols (X) * currSize);
|
||||
vector<pair<unsigned, unsigned>> affectedFormulas;
|
||||
affectedFormulas = getAffectedFormulas();
|
||||
// cout << "affected formulas: " ;
|
||||
// for (unsigned i = 0; i < affectedFormulas.size(); i++) {
|
||||
// cout << affectedFormulas[i].first << ":" ;
|
||||
// cout << affectedFormulas[i].second << " " ;
|
||||
// }
|
||||
// cout << "cost =" ;
|
||||
double totalCost = std::log (0.0);
|
||||
ParfactorList::iterator pflIt = pfList_.begin();
|
||||
while (pflIt != pfList_.end()) {
|
||||
Parfactor* pf = *pflIt;
|
||||
double reps = 0.0;
|
||||
double pfSize = std::log (pf->size());
|
||||
bool willBeAffected = false;
|
||||
LogVarSet lvsToGround;
|
||||
for (unsigned i = 0; i < affectedFormulas.size(); i++) {
|
||||
int fIdx = pf->indexOfGroup (affectedFormulas[i].first);
|
||||
if (fIdx != -1) {
|
||||
ProbFormula f = pf->argument (fIdx);
|
||||
LogVar X = f.logVars()[affectedFormulas[i].second];
|
||||
bool isCountingLv = pf->countedLogVars().contains (X);
|
||||
if (isCountingLv) {
|
||||
unsigned nrHists = pf->range (fIdx);
|
||||
unsigned nrSymbols = pf->constr()->getConditionalCount (X);
|
||||
unsigned range = pf->argument (fIdx).range();
|
||||
double power = std::log (range) * nrSymbols;
|
||||
pfSize = (pfSize - std::log (nrHists)) + power;
|
||||
} else {
|
||||
if (lvsToGround.contains (X) == false) {
|
||||
reps += std::log (pf->constr()->nrSymbols (X));
|
||||
lvsToGround.insert (X);
|
||||
}
|
||||
}
|
||||
willBeAffected = true;
|
||||
}
|
||||
}
|
||||
cost = Util::logSum (cost, pfCost);
|
||||
if (willBeAffected) {
|
||||
// cout << " + " << std::exp (reps) << "x" << std::exp (pfSize);
|
||||
double pfCost = reps + pfSize;
|
||||
totalCost = Util::logSum (totalCost, pfCost);
|
||||
}
|
||||
++ pflIt;
|
||||
}
|
||||
return cost;
|
||||
// cout << endl;
|
||||
return totalCost;
|
||||
}
|
||||
|
||||
|
||||
@ -463,6 +484,42 @@ GroundOperator::toString (void)
|
||||
|
||||
|
||||
|
||||
vector<pair<unsigned, unsigned>>
|
||||
GroundOperator::getAffectedFormulas (void)
|
||||
{
|
||||
vector<pair<unsigned, unsigned>> affectedFormulas;
|
||||
affectedFormulas.push_back (make_pair (group_, lvIndex_));
|
||||
queue<pair<unsigned, unsigned>> q;
|
||||
q.push (make_pair (group_, lvIndex_));
|
||||
while (q.empty() == false) {
|
||||
pair<unsigned, unsigned> front = q.front();
|
||||
ParfactorList::iterator pflIt = pfList_.begin();
|
||||
while (pflIt != pfList_.end()) {
|
||||
int idx = (*pflIt)->indexOfGroup (front.first);
|
||||
if (idx != -1) {
|
||||
ProbFormula f = (*pflIt)->argument (idx);
|
||||
LogVar X = f.logVars()[front.second];
|
||||
const ProbFormulas& fs = (*pflIt)->arguments();
|
||||
for (unsigned i = 0; i < fs.size(); i++) {
|
||||
if ((int)i != idx && fs[i].contains (X)) {
|
||||
pair<unsigned, unsigned> pair = make_pair (
|
||||
fs[i].group(), fs[i].indexOf (X));
|
||||
if (Util::contains (affectedFormulas, pair) == false) {
|
||||
q.push (pair);
|
||||
affectedFormulas.push_back (pair);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
++ pflIt;
|
||||
}
|
||||
q.pop();
|
||||
}
|
||||
return affectedFormulas;
|
||||
}
|
||||
|
||||
|
||||
|
||||
Params
|
||||
FoveSolver::getPosterioriOf (const Ground& query)
|
||||
{
|
||||
|
@ -98,6 +98,8 @@ class GroundOperator : public LiftedOperator
|
||||
string toString (void);
|
||||
|
||||
private:
|
||||
vector<pair<unsigned, unsigned>> getAffectedFormulas (void);
|
||||
|
||||
unsigned group_;
|
||||
unsigned lvIndex_;
|
||||
ParfactorList& pfList_;
|
||||
|
@ -141,6 +141,11 @@ class Substitution
|
||||
return subs_.find (X)->second;
|
||||
}
|
||||
|
||||
bool containsReplacementFor (LogVar X) const
|
||||
{
|
||||
return Util::contains (subs_, X);
|
||||
}
|
||||
|
||||
LogVars getDiscardedLogVars (void) const;
|
||||
|
||||
friend ostream& operator<< (ostream &os, const Substitution& theta);
|
||||
|
@ -29,6 +29,20 @@ ProbFormula::contains (LogVarSet s) const
|
||||
|
||||
|
||||
|
||||
int
|
||||
ProbFormula::indexOf (LogVar X) const
|
||||
{
|
||||
int pos = std::distance (
|
||||
logVars_.begin(),
|
||||
std::find (logVars_.begin(), logVars_.end(), X));
|
||||
if (pos == (int)logVars_.size()) {
|
||||
pos = -1;
|
||||
}
|
||||
return pos;
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
ProbFormula::isAtom (void) const
|
||||
{
|
||||
|
@ -41,6 +41,8 @@ class ProbFormula
|
||||
|
||||
bool contains (LogVarSet) const;
|
||||
|
||||
int indexOf (LogVar) const;
|
||||
|
||||
bool isAtom (void) const;
|
||||
|
||||
bool isCounting (void) const;
|
||||
|
Reference in New Issue
Block a user