refactor the way we calculate the grounding cost

This commit is contained in:
Tiago Gomes 2012-04-19 17:59:45 +01:00
parent 661ce08961
commit 2b7da4bc23
2 changed files with 116 additions and 63 deletions

View File

@ -37,25 +37,55 @@ LiftedOperator::printValidOps (
vector<LiftedOperator*> validOps;
validOps = LiftedOperator::getValidOps (pfList, query);
for (unsigned i = 0; i < validOps.size(); i++) {
cout << "-> " << validOps[i]->toString() << endl;
cout << "-> " << validOps[i]->toString();
delete validOps[i];
}
}
vector<unsigned>
LiftedOperator::getAllGroupss (ParfactorList& )
{
return { };
}
vector<ParfactorList::iterator>
LiftedOperator::getParfactorsWithGroup (
ParfactorList& pfList, unsigned group)
{
vector<ParfactorList::iterator> iters;
ParfactorList::iterator pflIt = pfList.begin();
while (pflIt != pfList.end()) {
if ((*pflIt)->containsGroup (group)) {
iters.push_back (pflIt);
}
++ pflIt;
}
return iters;
}
double
SumOutOperator::getLogCost (void)
{
TinySet<unsigned> groupSet;
ParfactorList::const_iterator pfIter = pfList_.begin();
unsigned nrProdFactors = 0;
while (pfIter != pfList_.end()) {
if ((*pfIter)->containsGroup (group_)) {
vector<unsigned> groups = (*pfIter)->getAllGroups();
groupSet |= TinySet<unsigned> (groups);
++ nrProdFactors;
}
++ pfIter;
}
if (nrProdFactors == 1) {
return 1.0; // best possible case
}
double cost = 1.0;
for (unsigned i = 0; i < groupSet.size(); i++) {
pfIter = pfList_.begin();
@ -76,8 +106,8 @@ SumOutOperator::getLogCost (void)
void
SumOutOperator::apply (void)
{
vector<ParfactorList::iterator> iters
= parfactorsWithGroup (pfList_, group_);
vector<ParfactorList::iterator> iters;
iters = getParfactorsWithGroup (pfList_, group_);
Parfactor* product = *(iters[0]);
pfList_.remove (iters[0]);
for (unsigned i = 1; i < iters.size(); i++) {
@ -137,13 +167,13 @@ SumOutOperator::toString (void)
{
stringstream ss;
vector<ParfactorList::iterator> pfIters;
pfIters = parfactorsWithGroup (pfList_, group_);
pfIters = getParfactorsWithGroup (pfList_, group_);
int idx = (*pfIters[0])->indexOfGroup (group_);
ProbFormula f = (*pfIters[0])->argument (idx);
TupleSet tupleSet = (*pfIters[0])->constr()->tupleSet (f.logVars());
ss << "sum out " << f.functor() << "/" << f.arity();
ss << "|" << tupleSet << " (group " << group_ << ")";
ss << " [cost=" << getLogCost() << "]" << endl;
ss << " [cost=" << std::exp (getLogCost()) << "]" << endl;
return ss.str();
}
@ -156,7 +186,7 @@ SumOutOperator::validOp (
const Grounds& query)
{
vector<ParfactorList::iterator> pfIters;
pfIters = parfactorsWithGroup (pfList, group);
pfIters = getParfactorsWithGroup (pfList, group);
if (isToEliminate (*pfIters[0], group, query) == false) {
return false;
}
@ -186,24 +216,6 @@ SumOutOperator::validOp (
vector<ParfactorList::iterator>
SumOutOperator::parfactorsWithGroup (
ParfactorList& pfList,
unsigned group)
{
vector<ParfactorList::iterator> iters;
ParfactorList::iterator pflIt = pfList.begin();
while (pflIt != pfList.end()) {
if ((*pflIt)->containsGroup (group)) {
iters.push_back (pflIt);
}
++ pflIt;
}
return iters;
}
bool
SumOutOperator::isToEliminate (
Parfactor* g,
@ -295,7 +307,7 @@ CountingOperator::toString (void)
stringstream ss;
ss << "count convert " << X_ << " in " ;
ss << (*pfIter_)->getLabel();
ss << " [cost=" << getLogCost() << "]" << endl;
ss << " [cost=" << std::exp (getLogCost()) << "]" << endl;
Parfactors pfs = FoveSolver::countNormalize (*pfIter_, X_);
if ((*pfIter_)->constr()->isCountNormalized (X_) == false) {
for (unsigned i = 0; i < pfs.size(); i++) {
@ -337,19 +349,29 @@ CountingOperator::validOp (Parfactor* g, LogVar X)
double
GroundOperator::getLogCost (void)
{
double cost = 0.0;
bool isCountingLv = (*pfIter_)->countedLogVars().contains (X_);
if (isCountingLv) {
int fIdx = (*pfIter_)->indexOfLogVar (X_);
unsigned currSize = (*pfIter_)->size();
unsigned nrHists = (*pfIter_)->range (fIdx);
unsigned nrSymbols = (*pfIter_)->constr()->getConditionalCount (X_);
unsigned range = (*pfIter_)->argument (fIdx).range();
double power = std::log (range) * nrSymbols;
cost = std::log (currSize / nrHists) + power;
} else {
unsigned currSize = (*pfIter_)->size();
cost = std::log ((*pfIter_)->constr()->nrSymbols (X_) * currSize);
double cost = std::log (0.0);
vector<ParfactorList::iterator> pfIters;
pfIters = getParfactorsWithGroup (pfList_, group_);
for (unsigned i = 0; i < pfIters.size(); i++) {
Parfactor* pf = *pfIters[i];
int idx = pf->indexOfGroup (group_);
ProbFormula f = pf->argument (idx);
LogVar X = f.logVars()[lvIndex_];
double pfCost = 0.0;
bool isCountingLv = pf->countedLogVars().contains (X);
if (isCountingLv) {
int fIdx = pf->indexOfLogVar (X);
unsigned currSize = pf->size();
unsigned nrHists = pf->range (fIdx);
unsigned nrSymbols = pf->constr()->getConditionalCount (X);
unsigned range = pf->argument (fIdx).range();
double power = std::log (range) * nrSymbols;
pfCost = std::log (currSize / nrHists) + power;
} else {
unsigned currSize = pf->size();
pfCost = std::log (pf->constr()->nrSymbols (X) * currSize);
}
cost = Util::logSum (cost, pfCost);
}
return cost;
}
@ -359,14 +381,21 @@ GroundOperator::getLogCost (void)
void
GroundOperator::apply (void)
{
bool countedLv = (*pfIter_)->countedLogVars().contains (X_);
Parfactor* pf = *pfIter_;
pfList_.remove (pfIter_);
// TODO if we update the correct groups
// we can skip shattering
ParfactorList::iterator pfIter;
pfIter = getParfactorsWithGroup (pfList_, group_).front();
Parfactor* pf = *pfIter;
int idx = pf->indexOfGroup (group_);
ProbFormula f = pf->argument (idx);
LogVar X = f.logVars()[lvIndex_];
bool countedLv = pf->countedLogVars().contains (X);
pfList_.remove (pfIter);
if (countedLv) {
pf->fullExpand (X_);
pf->fullExpand (X);
pfList_.add (pf);
} else {
ConstraintTrees cts = pf->constr()->ground (X_);
ConstraintTrees cts = pf->constr()->ground (X);
for (unsigned i = 0; i < cts.size(); i++) {
pfList_.add (new Parfactor (pf, cts[i]));
}
@ -380,15 +409,23 @@ vector<GroundOperator*>
GroundOperator::getValidOps (ParfactorList& pfList)
{
vector<GroundOperator*> validOps;
ParfactorList::iterator pfIter = pfList.begin();
while (pfIter != pfList.end()) {
LogVarSet set = (*pfIter)->logVarSet();
for (unsigned i = 0; i < set.size(); i++) {
if ((*pfIter)->constr()->isSingleton (set[i]) == false) {
validOps.push_back (new GroundOperator (pfIter, set[i], pfList));
set<unsigned> allGroups;
ParfactorList::const_iterator it = pfList.begin();
while (it != pfList.end()) {
const ProbFormulas& formulas = (*it)->arguments();
for (unsigned i = 0; i < formulas.size(); i++) {
if (Util::contains (allGroups, formulas[i].group()) == false) {
const LogVars& lvs = formulas[i].logVars();
for (unsigned j = 0; j < lvs.size(); j++) {
if ((*it)->constr()->isSingleton (lvs[j]) == false) {
validOps.push_back (new GroundOperator (
formulas[i].group(), j, pfList));
}
}
allGroups.insert (formulas[i].group());
}
}
++ pfIter;
++ it;
}
return validOps;
}
@ -399,11 +436,25 @@ string
GroundOperator::toString (void)
{
stringstream ss;
((*pfIter_)->countedLogVars().contains (X_))
? ss << "full expanding "
: ss << "grounding " ;
ss << X_ << " in " << (*pfIter_)->getLabel();
ss << " [cost=" << getLogCost() << "]" << endl;
vector<ParfactorList::iterator> pfIters;
pfIters = getParfactorsWithGroup (pfList_, group_);
Parfactor* pf = *(getParfactorsWithGroup (pfList_, group_).front());
int idx = pf->indexOfGroup (group_);
ProbFormula f = pf->argument (idx);
LogVar lv = f.logVars()[lvIndex_];
TupleSet tupleSet = pf->constr()->tupleSet ({lv});
string pos = "th";
if (lvIndex_ == 0) {
pos = "st" ;
} else if (lvIndex_ == 1) {
pos = "nd" ;
} else if (lvIndex_ == 2) {
pos = "rd" ;
}
ss << "grounding " << lvIndex_ + 1 << pos << " log var in " ;
ss << f.functor() << "/" << f.arity();
ss << "|" << tupleSet << " (group " << group_ << ")";
ss << " [cost=" << std::exp (getLogCost()) << "]" << endl;
return ss.str();
}

View File

@ -18,6 +18,11 @@ class LiftedOperator
ParfactorList&, const Grounds&);
static void printValidOps (ParfactorList&, const Grounds&);
static vector<unsigned> getAllGroupss (ParfactorList&);
static vector<ParfactorList::iterator> getParfactorsWithGroup (
ParfactorList&, unsigned group);
};
@ -40,9 +45,6 @@ class SumOutOperator : public LiftedOperator
private:
static bool validOp (unsigned, ParfactorList&, const Grounds&);
static vector<ParfactorList::iterator> parfactorsWithGroup (
ParfactorList& pfList, unsigned group);
static bool isToEliminate (Parfactor*, unsigned, const Grounds&);
unsigned group_;
@ -82,10 +84,10 @@ class GroundOperator : public LiftedOperator
{
public:
GroundOperator (
ParfactorList::iterator pfIter,
LogVar X,
unsigned group,
unsigned lvIndex,
ParfactorList& pfList)
: pfIter_(pfIter), X_(X), pfList_(pfList) { }
: group_(group), lvIndex_(lvIndex), pfList_(pfList) { }
double getLogCost (void);
@ -96,8 +98,8 @@ class GroundOperator : public LiftedOperator
string toString (void);
private:
ParfactorList::iterator pfIter_;
LogVar X_;
unsigned group_;
unsigned lvIndex_;
ParfactorList& pfList_;
};