use logarithms to calculate the cost of each fove operator
This commit is contained in:
parent
406276b62b
commit
61ee95d92a
@ -44,8 +44,8 @@ LiftedOperator::printValidOps (
|
||||
|
||||
|
||||
|
||||
unsigned
|
||||
SumOutOperator::getCost (void)
|
||||
double
|
||||
SumOutOperator::getLogCost (void)
|
||||
{
|
||||
TinySet<unsigned> groupSet;
|
||||
ParfactorList::const_iterator pfIter = pfList_.begin();
|
||||
@ -56,7 +56,7 @@ SumOutOperator::getCost (void)
|
||||
}
|
||||
++ pfIter;
|
||||
}
|
||||
unsigned cost = 1;
|
||||
double cost = 1.0;
|
||||
for (unsigned i = 0; i < groupSet.size(); i++) {
|
||||
pfIter = pfList_.begin();
|
||||
while (pfIter != pfList_.end()) {
|
||||
@ -68,7 +68,7 @@ SumOutOperator::getCost (void)
|
||||
++ pfIter;
|
||||
}
|
||||
}
|
||||
return cost;
|
||||
return std::log (cost);
|
||||
}
|
||||
|
||||
|
||||
@ -143,7 +143,7 @@ SumOutOperator::toString (void)
|
||||
TupleSet tupleSet = (*pfIters[0])->constr()->tupleSet (f.logVars());
|
||||
ss << "sum out " << f.functor() << "/" << f.arity();
|
||||
ss << "|" << tupleSet << " (group " << group_ << ")";
|
||||
ss << " [cost=" << getCost() << "]" << endl;
|
||||
ss << " [cost=" << getLogCost() << "]" << endl;
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
@ -228,10 +228,10 @@ SumOutOperator::isToEliminate (
|
||||
|
||||
|
||||
|
||||
unsigned
|
||||
CountingOperator::getCost (void)
|
||||
double
|
||||
CountingOperator::getLogCost (void)
|
||||
{
|
||||
unsigned cost = 0;
|
||||
double cost = 0.0;
|
||||
int fIdx = (*pfIter_)->indexOfLogVar (X_);
|
||||
unsigned range = (*pfIter_)->range (fIdx);
|
||||
unsigned size = (*pfIter_)->size() / range;
|
||||
@ -240,7 +240,7 @@ CountingOperator::getCost (void)
|
||||
for (unsigned i = 0; i < counts.size(); i++) {
|
||||
cost += size * HistogramSet::nrHistograms (counts[i], range);
|
||||
}
|
||||
return cost;
|
||||
return std::log (cost);
|
||||
}
|
||||
|
||||
|
||||
@ -295,7 +295,7 @@ CountingOperator::toString (void)
|
||||
stringstream ss;
|
||||
ss << "count convert " << X_ << " in " ;
|
||||
ss << (*pfIter_)->getLabel();
|
||||
ss << " [cost=" << getCost() << "]" << endl;
|
||||
ss << " [cost=" << getLogCost() << "]" << endl;
|
||||
Parfactors pfs = FoveSolver::countNormalize (*pfIter_, X_);
|
||||
if ((*pfIter_)->constr()->isCountNormalized (X_) == false) {
|
||||
for (unsigned i = 0; i < pfs.size(); i++) {
|
||||
@ -334,20 +334,22 @@ CountingOperator::validOp (Parfactor* g, LogVar X)
|
||||
|
||||
|
||||
|
||||
unsigned
|
||||
GroundOperator::getCost (void)
|
||||
double
|
||||
GroundOperator::getLogCost (void)
|
||||
{
|
||||
unsigned cost = 0;
|
||||
double cost = 0.0;
|
||||
bool isCountingLv = (*pfIter_)->countedLogVars().contains (X_);
|
||||
if (isCountingLv) {
|
||||
int fIdx = (*pfIter_)->indexOfLogVar (X_);
|
||||
unsigned currSize = (*pfIter_)->size();
|
||||
unsigned nrHists = (*pfIter_)->range (fIdx);
|
||||
unsigned range = (*pfIter_)->argument (fIdx).range();
|
||||
unsigned nrSymbols = (*pfIter_)->constr()->getConditionalCount (X_);
|
||||
cost = (currSize / nrHists) * (std::pow (range, nrSymbols));
|
||||
unsigned range = (*pfIter_)->argument (fIdx).range();
|
||||
double power = std::log (range) * nrSymbols;
|
||||
cost = std::log (currSize / nrHists) + power;
|
||||
} else {
|
||||
cost = (*pfIter_)->constr()->nrSymbols (X_) * (*pfIter_)->size();
|
||||
unsigned currSize = (*pfIter_)->size();
|
||||
cost = std::log ((*pfIter_)->constr()->nrSymbols (X_) * currSize);
|
||||
}
|
||||
return cost;
|
||||
}
|
||||
@ -401,7 +403,7 @@ GroundOperator::toString (void)
|
||||
? ss << "full expanding "
|
||||
: ss << "grounding " ;
|
||||
ss << X_ << " in " << (*pfIter_)->getLabel();
|
||||
ss << " [cost=" << getCost() << "]" << endl;
|
||||
ss << " [cost=" << getLogCost() << "]" << endl;
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
@ -525,12 +527,12 @@ FoveSolver::runSolver (const Grounds& query)
|
||||
LiftedOperator*
|
||||
FoveSolver::getBestOperation (const Grounds& query)
|
||||
{
|
||||
unsigned bestCost = Util::maxUnsigned();
|
||||
double bestCost = 0.0;
|
||||
LiftedOperator* bestOp = 0;
|
||||
vector<LiftedOperator*> validOps;
|
||||
validOps = LiftedOperator::getValidOps (pfList_, query);
|
||||
for (unsigned i = 0; i < validOps.size(); i++) {
|
||||
unsigned cost = validOps[i]->getCost();
|
||||
double cost = validOps[i]->getLogCost();
|
||||
if ((bestOp == 0) || (cost < bestCost)) {
|
||||
bestOp = validOps[i];
|
||||
bestCost = cost;
|
||||
|
@ -8,7 +8,7 @@
|
||||
class LiftedOperator
|
||||
{
|
||||
public:
|
||||
virtual unsigned getCost (void) = 0;
|
||||
virtual double getLogCost (void) = 0;
|
||||
|
||||
virtual void apply (void) = 0;
|
||||
|
||||
@ -28,7 +28,7 @@ class SumOutOperator : public LiftedOperator
|
||||
SumOutOperator (unsigned group, ParfactorList& pfList)
|
||||
: group_(group), pfList_(pfList) { }
|
||||
|
||||
unsigned getCost (void);
|
||||
double getLogCost (void);
|
||||
|
||||
void apply (void);
|
||||
|
||||
@ -60,7 +60,7 @@ class CountingOperator : public LiftedOperator
|
||||
ParfactorList& pfList)
|
||||
: pfIter_(pfIter), X_(X), pfList_(pfList) { }
|
||||
|
||||
unsigned getCost (void);
|
||||
double getLogCost (void);
|
||||
|
||||
void apply (void);
|
||||
|
||||
@ -87,7 +87,7 @@ class GroundOperator : public LiftedOperator
|
||||
ParfactorList& pfList)
|
||||
: pfIter_(pfIter), X_(X), pfList_(pfList) { }
|
||||
|
||||
unsigned getCost (void);
|
||||
double getLogCost (void);
|
||||
|
||||
void apply (void);
|
||||
|
||||
|
Reference in New Issue
Block a user