use logarithms to calculate the cost of each fove operator

This commit is contained in:
Tiago Gomes 2012-04-18 16:40:12 +01:00
parent 406276b62b
commit 61ee95d92a
2 changed files with 25 additions and 23 deletions

View File

@ -44,8 +44,8 @@ LiftedOperator::printValidOps (
unsigned double
SumOutOperator::getCost (void) SumOutOperator::getLogCost (void)
{ {
TinySet<unsigned> groupSet; TinySet<unsigned> groupSet;
ParfactorList::const_iterator pfIter = pfList_.begin(); ParfactorList::const_iterator pfIter = pfList_.begin();
@ -56,7 +56,7 @@ SumOutOperator::getCost (void)
} }
++ pfIter; ++ pfIter;
} }
unsigned cost = 1; double cost = 1.0;
for (unsigned i = 0; i < groupSet.size(); i++) { for (unsigned i = 0; i < groupSet.size(); i++) {
pfIter = pfList_.begin(); pfIter = pfList_.begin();
while (pfIter != pfList_.end()) { while (pfIter != pfList_.end()) {
@ -68,7 +68,7 @@ SumOutOperator::getCost (void)
++ pfIter; ++ pfIter;
} }
} }
return cost; return std::log (cost);
} }
@ -143,7 +143,7 @@ SumOutOperator::toString (void)
TupleSet tupleSet = (*pfIters[0])->constr()->tupleSet (f.logVars()); TupleSet tupleSet = (*pfIters[0])->constr()->tupleSet (f.logVars());
ss << "sum out " << f.functor() << "/" << f.arity(); ss << "sum out " << f.functor() << "/" << f.arity();
ss << "|" << tupleSet << " (group " << group_ << ")"; ss << "|" << tupleSet << " (group " << group_ << ")";
ss << " [cost=" << getCost() << "]" << endl; ss << " [cost=" << getLogCost() << "]" << endl;
return ss.str(); return ss.str();
} }
@ -228,10 +228,10 @@ SumOutOperator::isToEliminate (
unsigned double
CountingOperator::getCost (void) CountingOperator::getLogCost (void)
{ {
unsigned cost = 0; double cost = 0.0;
int fIdx = (*pfIter_)->indexOfLogVar (X_); int fIdx = (*pfIter_)->indexOfLogVar (X_);
unsigned range = (*pfIter_)->range (fIdx); unsigned range = (*pfIter_)->range (fIdx);
unsigned size = (*pfIter_)->size() / range; unsigned size = (*pfIter_)->size() / range;
@ -240,7 +240,7 @@ CountingOperator::getCost (void)
for (unsigned i = 0; i < counts.size(); i++) { for (unsigned i = 0; i < counts.size(); i++) {
cost += size * HistogramSet::nrHistograms (counts[i], range); cost += size * HistogramSet::nrHistograms (counts[i], range);
} }
return cost; return std::log (cost);
} }
@ -295,7 +295,7 @@ CountingOperator::toString (void)
stringstream ss; stringstream ss;
ss << "count convert " << X_ << " in " ; ss << "count convert " << X_ << " in " ;
ss << (*pfIter_)->getLabel(); ss << (*pfIter_)->getLabel();
ss << " [cost=" << getCost() << "]" << endl; ss << " [cost=" << getLogCost() << "]" << endl;
Parfactors pfs = FoveSolver::countNormalize (*pfIter_, X_); Parfactors pfs = FoveSolver::countNormalize (*pfIter_, X_);
if ((*pfIter_)->constr()->isCountNormalized (X_) == false) { if ((*pfIter_)->constr()->isCountNormalized (X_) == false) {
for (unsigned i = 0; i < pfs.size(); i++) { for (unsigned i = 0; i < pfs.size(); i++) {
@ -334,20 +334,22 @@ CountingOperator::validOp (Parfactor* g, LogVar X)
unsigned double
GroundOperator::getCost (void) GroundOperator::getLogCost (void)
{ {
unsigned cost = 0; double cost = 0.0;
bool isCountingLv = (*pfIter_)->countedLogVars().contains (X_); bool isCountingLv = (*pfIter_)->countedLogVars().contains (X_);
if (isCountingLv) { if (isCountingLv) {
int fIdx = (*pfIter_)->indexOfLogVar (X_); int fIdx = (*pfIter_)->indexOfLogVar (X_);
unsigned currSize = (*pfIter_)->size(); unsigned currSize = (*pfIter_)->size();
unsigned nrHists = (*pfIter_)->range (fIdx); unsigned nrHists = (*pfIter_)->range (fIdx);
unsigned range = (*pfIter_)->argument (fIdx).range();
unsigned nrSymbols = (*pfIter_)->constr()->getConditionalCount (X_); unsigned nrSymbols = (*pfIter_)->constr()->getConditionalCount (X_);
cost = (currSize / nrHists) * (std::pow (range, nrSymbols)); unsigned range = (*pfIter_)->argument (fIdx).range();
double power = std::log (range) * nrSymbols;
cost = std::log (currSize / nrHists) + power;
} else { } else {
cost = (*pfIter_)->constr()->nrSymbols (X_) * (*pfIter_)->size(); unsigned currSize = (*pfIter_)->size();
cost = std::log ((*pfIter_)->constr()->nrSymbols (X_) * currSize);
} }
return cost; return cost;
} }
@ -401,7 +403,7 @@ GroundOperator::toString (void)
? ss << "full expanding " ? ss << "full expanding "
: ss << "grounding " ; : ss << "grounding " ;
ss << X_ << " in " << (*pfIter_)->getLabel(); ss << X_ << " in " << (*pfIter_)->getLabel();
ss << " [cost=" << getCost() << "]" << endl; ss << " [cost=" << getLogCost() << "]" << endl;
return ss.str(); return ss.str();
} }
@ -525,12 +527,12 @@ FoveSolver::runSolver (const Grounds& query)
LiftedOperator* LiftedOperator*
FoveSolver::getBestOperation (const Grounds& query) FoveSolver::getBestOperation (const Grounds& query)
{ {
unsigned bestCost = Util::maxUnsigned(); double bestCost = 0.0;
LiftedOperator* bestOp = 0; LiftedOperator* bestOp = 0;
vector<LiftedOperator*> validOps; vector<LiftedOperator*> validOps;
validOps = LiftedOperator::getValidOps (pfList_, query); validOps = LiftedOperator::getValidOps (pfList_, query);
for (unsigned i = 0; i < validOps.size(); i++) { for (unsigned i = 0; i < validOps.size(); i++) {
unsigned cost = validOps[i]->getCost(); double cost = validOps[i]->getLogCost();
if ((bestOp == 0) || (cost < bestCost)) { if ((bestOp == 0) || (cost < bestCost)) {
bestOp = validOps[i]; bestOp = validOps[i];
bestCost = cost; bestCost = cost;

View File

@ -8,7 +8,7 @@
class LiftedOperator class LiftedOperator
{ {
public: public:
virtual unsigned getCost (void) = 0; virtual double getLogCost (void) = 0;
virtual void apply (void) = 0; virtual void apply (void) = 0;
@ -28,7 +28,7 @@ class SumOutOperator : public LiftedOperator
SumOutOperator (unsigned group, ParfactorList& pfList) SumOutOperator (unsigned group, ParfactorList& pfList)
: group_(group), pfList_(pfList) { } : group_(group), pfList_(pfList) { }
unsigned getCost (void); double getLogCost (void);
void apply (void); void apply (void);
@ -60,7 +60,7 @@ class CountingOperator : public LiftedOperator
ParfactorList& pfList) ParfactorList& pfList)
: pfIter_(pfIter), X_(X), pfList_(pfList) { } : pfIter_(pfIter), X_(X), pfList_(pfList) { }
unsigned getCost (void); double getLogCost (void);
void apply (void); void apply (void);
@ -87,7 +87,7 @@ class GroundOperator : public LiftedOperator
ParfactorList& pfList) ParfactorList& pfList)
: pfIter_(pfIter), X_(X), pfList_(pfList) { } : pfIter_(pfIter), X_(X), pfList_(pfList) { }
unsigned getCost (void); double getLogCost (void);
void apply (void); void apply (void);