be more precise when calculating the cost of grounding a log var in a formula
This commit is contained in:
parent
995a11be83
commit
af063dcda8
@ -352,31 +352,52 @@ CountingOperator::validOp (Parfactor* g, LogVar X)
|
|||||||
double
|
double
|
||||||
GroundOperator::getLogCost (void)
|
GroundOperator::getLogCost (void)
|
||||||
{
|
{
|
||||||
double cost = std::log (0.0);
|
vector<pair<unsigned, unsigned>> affectedFormulas;
|
||||||
vector<ParfactorList::iterator> pfIters;
|
affectedFormulas = getAffectedFormulas();
|
||||||
pfIters = getParfactorsWithGroup (pfList_, group_);
|
// cout << "affected formulas: " ;
|
||||||
for (unsigned i = 0; i < pfIters.size(); i++) {
|
// for (unsigned i = 0; i < affectedFormulas.size(); i++) {
|
||||||
Parfactor* pf = *pfIters[i];
|
// cout << affectedFormulas[i].first << ":" ;
|
||||||
int idx = pf->indexOfGroup (group_);
|
// cout << affectedFormulas[i].second << " " ;
|
||||||
ProbFormula f = pf->argument (idx);
|
// }
|
||||||
LogVar X = f.logVars()[lvIndex_];
|
// cout << "cost =" ;
|
||||||
double pfCost = 0.0;
|
double totalCost = std::log (0.0);
|
||||||
bool isCountingLv = pf->countedLogVars().contains (X);
|
ParfactorList::iterator pflIt = pfList_.begin();
|
||||||
if (isCountingLv) {
|
while (pflIt != pfList_.end()) {
|
||||||
int fIdx = pf->indexOfLogVar (X);
|
Parfactor* pf = *pflIt;
|
||||||
unsigned currSize = pf->size();
|
double reps = 0.0;
|
||||||
unsigned nrHists = pf->range (fIdx);
|
double pfSize = std::log (pf->size());
|
||||||
unsigned nrSymbols = pf->constr()->getConditionalCount (X);
|
bool willBeAffected = false;
|
||||||
unsigned range = pf->argument (fIdx).range();
|
LogVarSet lvsToGround;
|
||||||
double power = std::log (range) * nrSymbols;
|
for (unsigned i = 0; i < affectedFormulas.size(); i++) {
|
||||||
pfCost = std::log (currSize / nrHists) + power;
|
int fIdx = pf->indexOfGroup (affectedFormulas[i].first);
|
||||||
} else {
|
if (fIdx != -1) {
|
||||||
unsigned currSize = pf->size();
|
ProbFormula f = pf->argument (fIdx);
|
||||||
pfCost = std::log (pf->constr()->nrSymbols (X) * currSize);
|
LogVar X = f.logVars()[affectedFormulas[i].second];
|
||||||
|
bool isCountingLv = pf->countedLogVars().contains (X);
|
||||||
|
if (isCountingLv) {
|
||||||
|
unsigned nrHists = pf->range (fIdx);
|
||||||
|
unsigned nrSymbols = pf->constr()->getConditionalCount (X);
|
||||||
|
unsigned range = pf->argument (fIdx).range();
|
||||||
|
double power = std::log (range) * nrSymbols;
|
||||||
|
pfSize = (pfSize - std::log (nrHists)) + power;
|
||||||
|
} else {
|
||||||
|
if (lvsToGround.contains (X) == false) {
|
||||||
|
reps += std::log (pf->constr()->nrSymbols (X));
|
||||||
|
lvsToGround.insert (X);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
willBeAffected = true;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
cost = Util::logSum (cost, pfCost);
|
if (willBeAffected) {
|
||||||
|
// cout << " + " << std::exp (reps) << "x" << std::exp (pfSize);
|
||||||
|
double pfCost = reps + pfSize;
|
||||||
|
totalCost = Util::logSum (totalCost, pfCost);
|
||||||
|
}
|
||||||
|
++ pflIt;
|
||||||
}
|
}
|
||||||
return cost;
|
// cout << endl;
|
||||||
|
return totalCost;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -463,6 +484,42 @@ GroundOperator::toString (void)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
vector<pair<unsigned, unsigned>>
|
||||||
|
GroundOperator::getAffectedFormulas (void)
|
||||||
|
{
|
||||||
|
vector<pair<unsigned, unsigned>> affectedFormulas;
|
||||||
|
affectedFormulas.push_back (make_pair (group_, lvIndex_));
|
||||||
|
queue<pair<unsigned, unsigned>> q;
|
||||||
|
q.push (make_pair (group_, lvIndex_));
|
||||||
|
while (q.empty() == false) {
|
||||||
|
pair<unsigned, unsigned> front = q.front();
|
||||||
|
ParfactorList::iterator pflIt = pfList_.begin();
|
||||||
|
while (pflIt != pfList_.end()) {
|
||||||
|
int idx = (*pflIt)->indexOfGroup (front.first);
|
||||||
|
if (idx != -1) {
|
||||||
|
ProbFormula f = (*pflIt)->argument (idx);
|
||||||
|
LogVar X = f.logVars()[front.second];
|
||||||
|
const ProbFormulas& fs = (*pflIt)->arguments();
|
||||||
|
for (unsigned i = 0; i < fs.size(); i++) {
|
||||||
|
if ((int)i != idx && fs[i].contains (X)) {
|
||||||
|
pair<unsigned, unsigned> pair = make_pair (
|
||||||
|
fs[i].group(), fs[i].indexOf (X));
|
||||||
|
if (Util::contains (affectedFormulas, pair) == false) {
|
||||||
|
q.push (pair);
|
||||||
|
affectedFormulas.push_back (pair);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
++ pflIt;
|
||||||
|
}
|
||||||
|
q.pop();
|
||||||
|
}
|
||||||
|
return affectedFormulas;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Params
|
Params
|
||||||
FoveSolver::getPosterioriOf (const Ground& query)
|
FoveSolver::getPosterioriOf (const Ground& query)
|
||||||
{
|
{
|
||||||
|
@ -98,6 +98,8 @@ class GroundOperator : public LiftedOperator
|
|||||||
string toString (void);
|
string toString (void);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
vector<pair<unsigned, unsigned>> getAffectedFormulas (void);
|
||||||
|
|
||||||
unsigned group_;
|
unsigned group_;
|
||||||
unsigned lvIndex_;
|
unsigned lvIndex_;
|
||||||
ParfactorList& pfList_;
|
ParfactorList& pfList_;
|
||||||
|
@ -141,6 +141,11 @@ class Substitution
|
|||||||
return subs_.find (X)->second;
|
return subs_.find (X)->second;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool containsReplacementFor (LogVar X) const
|
||||||
|
{
|
||||||
|
return Util::contains (subs_, X);
|
||||||
|
}
|
||||||
|
|
||||||
LogVars getDiscardedLogVars (void) const;
|
LogVars getDiscardedLogVars (void) const;
|
||||||
|
|
||||||
friend ostream& operator<< (ostream &os, const Substitution& theta);
|
friend ostream& operator<< (ostream &os, const Substitution& theta);
|
||||||
|
@ -29,6 +29,20 @@ ProbFormula::contains (LogVarSet s) const
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
int
|
||||||
|
ProbFormula::indexOf (LogVar X) const
|
||||||
|
{
|
||||||
|
int pos = std::distance (
|
||||||
|
logVars_.begin(),
|
||||||
|
std::find (logVars_.begin(), logVars_.end(), X));
|
||||||
|
if (pos == (int)logVars_.size()) {
|
||||||
|
pos = -1;
|
||||||
|
}
|
||||||
|
return pos;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
bool
|
bool
|
||||||
ProbFormula::isAtom (void) const
|
ProbFormula::isAtom (void) const
|
||||||
{
|
{
|
||||||
|
@ -41,6 +41,8 @@ class ProbFormula
|
|||||||
|
|
||||||
bool contains (LogVarSet) const;
|
bool contains (LogVarSet) const;
|
||||||
|
|
||||||
|
int indexOf (LogVar) const;
|
||||||
|
|
||||||
bool isAtom (void) const;
|
bool isAtom (void) const;
|
||||||
|
|
||||||
bool isCounting (void) const;
|
bool isCounting (void) const;
|
||||||
|
Reference in New Issue
Block a user