refactor the way we calculate the grounding cost

This commit is contained in:
Tiago Gomes 2012-04-19 17:59:45 +01:00
parent 661ce08961
commit 2b7da4bc23
2 changed files with 116 additions and 63 deletions

View File

@ -37,25 +37,55 @@ LiftedOperator::printValidOps (
vector<LiftedOperator*> validOps; vector<LiftedOperator*> validOps;
validOps = LiftedOperator::getValidOps (pfList, query); validOps = LiftedOperator::getValidOps (pfList, query);
for (unsigned i = 0; i < validOps.size(); i++) { for (unsigned i = 0; i < validOps.size(); i++) {
cout << "-> " << validOps[i]->toString() << endl; cout << "-> " << validOps[i]->toString();
delete validOps[i]; delete validOps[i];
} }
} }
vector<unsigned>
LiftedOperator::getAllGroupss (ParfactorList& )
{
return { };
}
vector<ParfactorList::iterator>
LiftedOperator::getParfactorsWithGroup (
ParfactorList& pfList, unsigned group)
{
vector<ParfactorList::iterator> iters;
ParfactorList::iterator pflIt = pfList.begin();
while (pflIt != pfList.end()) {
if ((*pflIt)->containsGroup (group)) {
iters.push_back (pflIt);
}
++ pflIt;
}
return iters;
}
double double
SumOutOperator::getLogCost (void) SumOutOperator::getLogCost (void)
{ {
TinySet<unsigned> groupSet; TinySet<unsigned> groupSet;
ParfactorList::const_iterator pfIter = pfList_.begin(); ParfactorList::const_iterator pfIter = pfList_.begin();
unsigned nrProdFactors = 0;
while (pfIter != pfList_.end()) { while (pfIter != pfList_.end()) {
if ((*pfIter)->containsGroup (group_)) { if ((*pfIter)->containsGroup (group_)) {
vector<unsigned> groups = (*pfIter)->getAllGroups(); vector<unsigned> groups = (*pfIter)->getAllGroups();
groupSet |= TinySet<unsigned> (groups); groupSet |= TinySet<unsigned> (groups);
++ nrProdFactors;
} }
++ pfIter; ++ pfIter;
} }
if (nrProdFactors == 1) {
return 1.0; // best possible case
}
double cost = 1.0; double cost = 1.0;
for (unsigned i = 0; i < groupSet.size(); i++) { for (unsigned i = 0; i < groupSet.size(); i++) {
pfIter = pfList_.begin(); pfIter = pfList_.begin();
@ -76,8 +106,8 @@ SumOutOperator::getLogCost (void)
void void
SumOutOperator::apply (void) SumOutOperator::apply (void)
{ {
vector<ParfactorList::iterator> iters vector<ParfactorList::iterator> iters;
= parfactorsWithGroup (pfList_, group_); iters = getParfactorsWithGroup (pfList_, group_);
Parfactor* product = *(iters[0]); Parfactor* product = *(iters[0]);
pfList_.remove (iters[0]); pfList_.remove (iters[0]);
for (unsigned i = 1; i < iters.size(); i++) { for (unsigned i = 1; i < iters.size(); i++) {
@ -137,13 +167,13 @@ SumOutOperator::toString (void)
{ {
stringstream ss; stringstream ss;
vector<ParfactorList::iterator> pfIters; vector<ParfactorList::iterator> pfIters;
pfIters = parfactorsWithGroup (pfList_, group_); pfIters = getParfactorsWithGroup (pfList_, group_);
int idx = (*pfIters[0])->indexOfGroup (group_); int idx = (*pfIters[0])->indexOfGroup (group_);
ProbFormula f = (*pfIters[0])->argument (idx); ProbFormula f = (*pfIters[0])->argument (idx);
TupleSet tupleSet = (*pfIters[0])->constr()->tupleSet (f.logVars()); TupleSet tupleSet = (*pfIters[0])->constr()->tupleSet (f.logVars());
ss << "sum out " << f.functor() << "/" << f.arity(); ss << "sum out " << f.functor() << "/" << f.arity();
ss << "|" << tupleSet << " (group " << group_ << ")"; ss << "|" << tupleSet << " (group " << group_ << ")";
ss << " [cost=" << getLogCost() << "]" << endl; ss << " [cost=" << std::exp (getLogCost()) << "]" << endl;
return ss.str(); return ss.str();
} }
@ -156,7 +186,7 @@ SumOutOperator::validOp (
const Grounds& query) const Grounds& query)
{ {
vector<ParfactorList::iterator> pfIters; vector<ParfactorList::iterator> pfIters;
pfIters = parfactorsWithGroup (pfList, group); pfIters = getParfactorsWithGroup (pfList, group);
if (isToEliminate (*pfIters[0], group, query) == false) { if (isToEliminate (*pfIters[0], group, query) == false) {
return false; return false;
} }
@ -186,24 +216,6 @@ SumOutOperator::validOp (
vector<ParfactorList::iterator>
SumOutOperator::parfactorsWithGroup (
ParfactorList& pfList,
unsigned group)
{
vector<ParfactorList::iterator> iters;
ParfactorList::iterator pflIt = pfList.begin();
while (pflIt != pfList.end()) {
if ((*pflIt)->containsGroup (group)) {
iters.push_back (pflIt);
}
++ pflIt;
}
return iters;
}
bool bool
SumOutOperator::isToEliminate ( SumOutOperator::isToEliminate (
Parfactor* g, Parfactor* g,
@ -295,7 +307,7 @@ CountingOperator::toString (void)
stringstream ss; stringstream ss;
ss << "count convert " << X_ << " in " ; ss << "count convert " << X_ << " in " ;
ss << (*pfIter_)->getLabel(); ss << (*pfIter_)->getLabel();
ss << " [cost=" << getLogCost() << "]" << endl; ss << " [cost=" << std::exp (getLogCost()) << "]" << endl;
Parfactors pfs = FoveSolver::countNormalize (*pfIter_, X_); Parfactors pfs = FoveSolver::countNormalize (*pfIter_, X_);
if ((*pfIter_)->constr()->isCountNormalized (X_) == false) { if ((*pfIter_)->constr()->isCountNormalized (X_) == false) {
for (unsigned i = 0; i < pfs.size(); i++) { for (unsigned i = 0; i < pfs.size(); i++) {
@ -337,19 +349,29 @@ CountingOperator::validOp (Parfactor* g, LogVar X)
double double
GroundOperator::getLogCost (void) GroundOperator::getLogCost (void)
{ {
double cost = 0.0; double cost = std::log (0.0);
bool isCountingLv = (*pfIter_)->countedLogVars().contains (X_); vector<ParfactorList::iterator> pfIters;
if (isCountingLv) { pfIters = getParfactorsWithGroup (pfList_, group_);
int fIdx = (*pfIter_)->indexOfLogVar (X_); for (unsigned i = 0; i < pfIters.size(); i++) {
unsigned currSize = (*pfIter_)->size(); Parfactor* pf = *pfIters[i];
unsigned nrHists = (*pfIter_)->range (fIdx); int idx = pf->indexOfGroup (group_);
unsigned nrSymbols = (*pfIter_)->constr()->getConditionalCount (X_); ProbFormula f = pf->argument (idx);
unsigned range = (*pfIter_)->argument (fIdx).range(); LogVar X = f.logVars()[lvIndex_];
double power = std::log (range) * nrSymbols; double pfCost = 0.0;
cost = std::log (currSize / nrHists) + power; bool isCountingLv = pf->countedLogVars().contains (X);
} else { if (isCountingLv) {
unsigned currSize = (*pfIter_)->size(); int fIdx = pf->indexOfLogVar (X);
cost = std::log ((*pfIter_)->constr()->nrSymbols (X_) * currSize); unsigned currSize = pf->size();
unsigned nrHists = pf->range (fIdx);
unsigned nrSymbols = pf->constr()->getConditionalCount (X);
unsigned range = pf->argument (fIdx).range();
double power = std::log (range) * nrSymbols;
pfCost = std::log (currSize / nrHists) + power;
} else {
unsigned currSize = pf->size();
pfCost = std::log (pf->constr()->nrSymbols (X) * currSize);
}
cost = Util::logSum (cost, pfCost);
} }
return cost; return cost;
} }
@ -359,14 +381,21 @@ GroundOperator::getLogCost (void)
void void
GroundOperator::apply (void) GroundOperator::apply (void)
{ {
bool countedLv = (*pfIter_)->countedLogVars().contains (X_); // TODO if we update the correct groups
Parfactor* pf = *pfIter_; // we can skip shattering
pfList_.remove (pfIter_); ParfactorList::iterator pfIter;
pfIter = getParfactorsWithGroup (pfList_, group_).front();
Parfactor* pf = *pfIter;
int idx = pf->indexOfGroup (group_);
ProbFormula f = pf->argument (idx);
LogVar X = f.logVars()[lvIndex_];
bool countedLv = pf->countedLogVars().contains (X);
pfList_.remove (pfIter);
if (countedLv) { if (countedLv) {
pf->fullExpand (X_); pf->fullExpand (X);
pfList_.add (pf); pfList_.add (pf);
} else { } else {
ConstraintTrees cts = pf->constr()->ground (X_); ConstraintTrees cts = pf->constr()->ground (X);
for (unsigned i = 0; i < cts.size(); i++) { for (unsigned i = 0; i < cts.size(); i++) {
pfList_.add (new Parfactor (pf, cts[i])); pfList_.add (new Parfactor (pf, cts[i]));
} }
@ -380,15 +409,23 @@ vector<GroundOperator*>
GroundOperator::getValidOps (ParfactorList& pfList) GroundOperator::getValidOps (ParfactorList& pfList)
{ {
vector<GroundOperator*> validOps; vector<GroundOperator*> validOps;
ParfactorList::iterator pfIter = pfList.begin(); set<unsigned> allGroups;
while (pfIter != pfList.end()) { ParfactorList::const_iterator it = pfList.begin();
LogVarSet set = (*pfIter)->logVarSet(); while (it != pfList.end()) {
for (unsigned i = 0; i < set.size(); i++) { const ProbFormulas& formulas = (*it)->arguments();
if ((*pfIter)->constr()->isSingleton (set[i]) == false) { for (unsigned i = 0; i < formulas.size(); i++) {
validOps.push_back (new GroundOperator (pfIter, set[i], pfList)); if (Util::contains (allGroups, formulas[i].group()) == false) {
const LogVars& lvs = formulas[i].logVars();
for (unsigned j = 0; j < lvs.size(); j++) {
if ((*it)->constr()->isSingleton (lvs[j]) == false) {
validOps.push_back (new GroundOperator (
formulas[i].group(), j, pfList));
}
}
allGroups.insert (formulas[i].group());
} }
} }
++ pfIter; ++ it;
} }
return validOps; return validOps;
} }
@ -399,11 +436,25 @@ string
GroundOperator::toString (void) GroundOperator::toString (void)
{ {
stringstream ss; stringstream ss;
((*pfIter_)->countedLogVars().contains (X_)) vector<ParfactorList::iterator> pfIters;
? ss << "full expanding " pfIters = getParfactorsWithGroup (pfList_, group_);
: ss << "grounding " ; Parfactor* pf = *(getParfactorsWithGroup (pfList_, group_).front());
ss << X_ << " in " << (*pfIter_)->getLabel(); int idx = pf->indexOfGroup (group_);
ss << " [cost=" << getLogCost() << "]" << endl; ProbFormula f = pf->argument (idx);
LogVar lv = f.logVars()[lvIndex_];
TupleSet tupleSet = pf->constr()->tupleSet ({lv});
string pos = "th";
if (lvIndex_ == 0) {
pos = "st" ;
} else if (lvIndex_ == 1) {
pos = "nd" ;
} else if (lvIndex_ == 2) {
pos = "rd" ;
}
ss << "grounding " << lvIndex_ + 1 << pos << " log var in " ;
ss << f.functor() << "/" << f.arity();
ss << "|" << tupleSet << " (group " << group_ << ")";
ss << " [cost=" << std::exp (getLogCost()) << "]" << endl;
return ss.str(); return ss.str();
} }

View File

@ -18,6 +18,11 @@ class LiftedOperator
ParfactorList&, const Grounds&); ParfactorList&, const Grounds&);
static void printValidOps (ParfactorList&, const Grounds&); static void printValidOps (ParfactorList&, const Grounds&);
static vector<unsigned> getAllGroupss (ParfactorList&);
static vector<ParfactorList::iterator> getParfactorsWithGroup (
ParfactorList&, unsigned group);
}; };
@ -40,9 +45,6 @@ class SumOutOperator : public LiftedOperator
private: private:
static bool validOp (unsigned, ParfactorList&, const Grounds&); static bool validOp (unsigned, ParfactorList&, const Grounds&);
static vector<ParfactorList::iterator> parfactorsWithGroup (
ParfactorList& pfList, unsigned group);
static bool isToEliminate (Parfactor*, unsigned, const Grounds&); static bool isToEliminate (Parfactor*, unsigned, const Grounds&);
unsigned group_; unsigned group_;
@ -82,10 +84,10 @@ class GroundOperator : public LiftedOperator
{ {
public: public:
GroundOperator ( GroundOperator (
ParfactorList::iterator pfIter, unsigned group,
LogVar X, unsigned lvIndex,
ParfactorList& pfList) ParfactorList& pfList)
: pfIter_(pfIter), X_(X), pfList_(pfList) { } : group_(group), lvIndex_(lvIndex), pfList_(pfList) { }
double getLogCost (void); double getLogCost (void);
@ -96,8 +98,8 @@ class GroundOperator : public LiftedOperator
string toString (void); string toString (void);
private: private:
ParfactorList::iterator pfIter_; unsigned group_;
LogVar X_; unsigned lvIndex_;
ParfactorList& pfList_; ParfactorList& pfList_;
}; };