use logarithms to calculate the cost of each fove operator

This commit is contained in:
Tiago Gomes 2012-04-18 16:40:12 +01:00
parent 406276b62b
commit 61ee95d92a
2 changed files with 25 additions and 23 deletions

View File

@ -44,8 +44,8 @@ LiftedOperator::printValidOps (
unsigned
SumOutOperator::getCost (void)
double
SumOutOperator::getLogCost (void)
{
TinySet<unsigned> groupSet;
ParfactorList::const_iterator pfIter = pfList_.begin();
@ -56,7 +56,7 @@ SumOutOperator::getCost (void)
}
++ pfIter;
}
unsigned cost = 1;
double cost = 1.0;
for (unsigned i = 0; i < groupSet.size(); i++) {
pfIter = pfList_.begin();
while (pfIter != pfList_.end()) {
@ -68,7 +68,7 @@ SumOutOperator::getCost (void)
++ pfIter;
}
}
return cost;
return std::log (cost);
}
@ -143,7 +143,7 @@ SumOutOperator::toString (void)
TupleSet tupleSet = (*pfIters[0])->constr()->tupleSet (f.logVars());
ss << "sum out " << f.functor() << "/" << f.arity();
ss << "|" << tupleSet << " (group " << group_ << ")";
ss << " [cost=" << getCost() << "]" << endl;
ss << " [cost=" << getLogCost() << "]" << endl;
return ss.str();
}
@ -228,10 +228,10 @@ SumOutOperator::isToEliminate (
unsigned
CountingOperator::getCost (void)
double
CountingOperator::getLogCost (void)
{
unsigned cost = 0;
double cost = 0.0;
int fIdx = (*pfIter_)->indexOfLogVar (X_);
unsigned range = (*pfIter_)->range (fIdx);
unsigned size = (*pfIter_)->size() / range;
@ -240,7 +240,7 @@ CountingOperator::getCost (void)
for (unsigned i = 0; i < counts.size(); i++) {
cost += size * HistogramSet::nrHistograms (counts[i], range);
}
return cost;
return std::log (cost);
}
@ -295,7 +295,7 @@ CountingOperator::toString (void)
stringstream ss;
ss << "count convert " << X_ << " in " ;
ss << (*pfIter_)->getLabel();
ss << " [cost=" << getCost() << "]" << endl;
ss << " [cost=" << getLogCost() << "]" << endl;
Parfactors pfs = FoveSolver::countNormalize (*pfIter_, X_);
if ((*pfIter_)->constr()->isCountNormalized (X_) == false) {
for (unsigned i = 0; i < pfs.size(); i++) {
@ -334,20 +334,22 @@ CountingOperator::validOp (Parfactor* g, LogVar X)
unsigned
GroundOperator::getCost (void)
double
GroundOperator::getLogCost (void)
{
unsigned cost = 0;
double cost = 0.0;
bool isCountingLv = (*pfIter_)->countedLogVars().contains (X_);
if (isCountingLv) {
int fIdx = (*pfIter_)->indexOfLogVar (X_);
unsigned currSize = (*pfIter_)->size();
unsigned nrHists = (*pfIter_)->range (fIdx);
unsigned range = (*pfIter_)->argument (fIdx).range();
unsigned nrSymbols = (*pfIter_)->constr()->getConditionalCount (X_);
cost = (currSize / nrHists) * (std::pow (range, nrSymbols));
unsigned range = (*pfIter_)->argument (fIdx).range();
double power = std::log (range) * nrSymbols;
cost = std::log (currSize / nrHists) + power;
} else {
cost = (*pfIter_)->constr()->nrSymbols (X_) * (*pfIter_)->size();
unsigned currSize = (*pfIter_)->size();
cost = std::log ((*pfIter_)->constr()->nrSymbols (X_) * currSize);
}
return cost;
}
@ -401,7 +403,7 @@ GroundOperator::toString (void)
? ss << "full expanding "
: ss << "grounding " ;
ss << X_ << " in " << (*pfIter_)->getLabel();
ss << " [cost=" << getCost() << "]" << endl;
ss << " [cost=" << getLogCost() << "]" << endl;
return ss.str();
}
@ -525,12 +527,12 @@ FoveSolver::runSolver (const Grounds& query)
LiftedOperator*
FoveSolver::getBestOperation (const Grounds& query)
{
unsigned bestCost = Util::maxUnsigned();
double bestCost = 0.0;
LiftedOperator* bestOp = 0;
vector<LiftedOperator*> validOps;
validOps = LiftedOperator::getValidOps (pfList_, query);
for (unsigned i = 0; i < validOps.size(); i++) {
unsigned cost = validOps[i]->getCost();
double cost = validOps[i]->getLogCost();
if ((bestOp == 0) || (cost < bestCost)) {
bestOp = validOps[i];
bestCost = cost;

View File

@ -8,7 +8,7 @@
class LiftedOperator
{
public:
virtual unsigned getCost (void) = 0;
virtual double getLogCost (void) = 0;
virtual void apply (void) = 0;
@ -28,7 +28,7 @@ class SumOutOperator : public LiftedOperator
SumOutOperator (unsigned group, ParfactorList& pfList)
: group_(group), pfList_(pfList) { }
unsigned getCost (void);
double getLogCost (void);
void apply (void);
@ -60,7 +60,7 @@ class CountingOperator : public LiftedOperator
ParfactorList& pfList)
: pfIter_(pfIter), X_(X), pfList_(pfList) { }
unsigned getCost (void);
double getLogCost (void);
void apply (void);
@ -87,7 +87,7 @@ class GroundOperator : public LiftedOperator
ParfactorList& pfList)
: pfIter_(pfIter), X_(X), pfList_(pfList) { }
unsigned getCost (void);
double getLogCost (void);
void apply (void);