use logarithms to calculate the cost of each fove operator
This commit is contained in:
parent
406276b62b
commit
61ee95d92a
@ -44,8 +44,8 @@ LiftedOperator::printValidOps (
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
unsigned
|
double
|
||||||
SumOutOperator::getCost (void)
|
SumOutOperator::getLogCost (void)
|
||||||
{
|
{
|
||||||
TinySet<unsigned> groupSet;
|
TinySet<unsigned> groupSet;
|
||||||
ParfactorList::const_iterator pfIter = pfList_.begin();
|
ParfactorList::const_iterator pfIter = pfList_.begin();
|
||||||
@ -56,7 +56,7 @@ SumOutOperator::getCost (void)
|
|||||||
}
|
}
|
||||||
++ pfIter;
|
++ pfIter;
|
||||||
}
|
}
|
||||||
unsigned cost = 1;
|
double cost = 1.0;
|
||||||
for (unsigned i = 0; i < groupSet.size(); i++) {
|
for (unsigned i = 0; i < groupSet.size(); i++) {
|
||||||
pfIter = pfList_.begin();
|
pfIter = pfList_.begin();
|
||||||
while (pfIter != pfList_.end()) {
|
while (pfIter != pfList_.end()) {
|
||||||
@ -68,7 +68,7 @@ SumOutOperator::getCost (void)
|
|||||||
++ pfIter;
|
++ pfIter;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return cost;
|
return std::log (cost);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -143,7 +143,7 @@ SumOutOperator::toString (void)
|
|||||||
TupleSet tupleSet = (*pfIters[0])->constr()->tupleSet (f.logVars());
|
TupleSet tupleSet = (*pfIters[0])->constr()->tupleSet (f.logVars());
|
||||||
ss << "sum out " << f.functor() << "/" << f.arity();
|
ss << "sum out " << f.functor() << "/" << f.arity();
|
||||||
ss << "|" << tupleSet << " (group " << group_ << ")";
|
ss << "|" << tupleSet << " (group " << group_ << ")";
|
||||||
ss << " [cost=" << getCost() << "]" << endl;
|
ss << " [cost=" << getLogCost() << "]" << endl;
|
||||||
return ss.str();
|
return ss.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -228,10 +228,10 @@ SumOutOperator::isToEliminate (
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
unsigned
|
double
|
||||||
CountingOperator::getCost (void)
|
CountingOperator::getLogCost (void)
|
||||||
{
|
{
|
||||||
unsigned cost = 0;
|
double cost = 0.0;
|
||||||
int fIdx = (*pfIter_)->indexOfLogVar (X_);
|
int fIdx = (*pfIter_)->indexOfLogVar (X_);
|
||||||
unsigned range = (*pfIter_)->range (fIdx);
|
unsigned range = (*pfIter_)->range (fIdx);
|
||||||
unsigned size = (*pfIter_)->size() / range;
|
unsigned size = (*pfIter_)->size() / range;
|
||||||
@ -240,7 +240,7 @@ CountingOperator::getCost (void)
|
|||||||
for (unsigned i = 0; i < counts.size(); i++) {
|
for (unsigned i = 0; i < counts.size(); i++) {
|
||||||
cost += size * HistogramSet::nrHistograms (counts[i], range);
|
cost += size * HistogramSet::nrHistograms (counts[i], range);
|
||||||
}
|
}
|
||||||
return cost;
|
return std::log (cost);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -295,7 +295,7 @@ CountingOperator::toString (void)
|
|||||||
stringstream ss;
|
stringstream ss;
|
||||||
ss << "count convert " << X_ << " in " ;
|
ss << "count convert " << X_ << " in " ;
|
||||||
ss << (*pfIter_)->getLabel();
|
ss << (*pfIter_)->getLabel();
|
||||||
ss << " [cost=" << getCost() << "]" << endl;
|
ss << " [cost=" << getLogCost() << "]" << endl;
|
||||||
Parfactors pfs = FoveSolver::countNormalize (*pfIter_, X_);
|
Parfactors pfs = FoveSolver::countNormalize (*pfIter_, X_);
|
||||||
if ((*pfIter_)->constr()->isCountNormalized (X_) == false) {
|
if ((*pfIter_)->constr()->isCountNormalized (X_) == false) {
|
||||||
for (unsigned i = 0; i < pfs.size(); i++) {
|
for (unsigned i = 0; i < pfs.size(); i++) {
|
||||||
@ -334,20 +334,22 @@ CountingOperator::validOp (Parfactor* g, LogVar X)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
unsigned
|
double
|
||||||
GroundOperator::getCost (void)
|
GroundOperator::getLogCost (void)
|
||||||
{
|
{
|
||||||
unsigned cost = 0;
|
double cost = 0.0;
|
||||||
bool isCountingLv = (*pfIter_)->countedLogVars().contains (X_);
|
bool isCountingLv = (*pfIter_)->countedLogVars().contains (X_);
|
||||||
if (isCountingLv) {
|
if (isCountingLv) {
|
||||||
int fIdx = (*pfIter_)->indexOfLogVar (X_);
|
int fIdx = (*pfIter_)->indexOfLogVar (X_);
|
||||||
unsigned currSize = (*pfIter_)->size();
|
unsigned currSize = (*pfIter_)->size();
|
||||||
unsigned nrHists = (*pfIter_)->range (fIdx);
|
unsigned nrHists = (*pfIter_)->range (fIdx);
|
||||||
unsigned range = (*pfIter_)->argument (fIdx).range();
|
|
||||||
unsigned nrSymbols = (*pfIter_)->constr()->getConditionalCount (X_);
|
unsigned nrSymbols = (*pfIter_)->constr()->getConditionalCount (X_);
|
||||||
cost = (currSize / nrHists) * (std::pow (range, nrSymbols));
|
unsigned range = (*pfIter_)->argument (fIdx).range();
|
||||||
|
double power = std::log (range) * nrSymbols;
|
||||||
|
cost = std::log (currSize / nrHists) + power;
|
||||||
} else {
|
} else {
|
||||||
cost = (*pfIter_)->constr()->nrSymbols (X_) * (*pfIter_)->size();
|
unsigned currSize = (*pfIter_)->size();
|
||||||
|
cost = std::log ((*pfIter_)->constr()->nrSymbols (X_) * currSize);
|
||||||
}
|
}
|
||||||
return cost;
|
return cost;
|
||||||
}
|
}
|
||||||
@ -401,7 +403,7 @@ GroundOperator::toString (void)
|
|||||||
? ss << "full expanding "
|
? ss << "full expanding "
|
||||||
: ss << "grounding " ;
|
: ss << "grounding " ;
|
||||||
ss << X_ << " in " << (*pfIter_)->getLabel();
|
ss << X_ << " in " << (*pfIter_)->getLabel();
|
||||||
ss << " [cost=" << getCost() << "]" << endl;
|
ss << " [cost=" << getLogCost() << "]" << endl;
|
||||||
return ss.str();
|
return ss.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -525,12 +527,12 @@ FoveSolver::runSolver (const Grounds& query)
|
|||||||
LiftedOperator*
|
LiftedOperator*
|
||||||
FoveSolver::getBestOperation (const Grounds& query)
|
FoveSolver::getBestOperation (const Grounds& query)
|
||||||
{
|
{
|
||||||
unsigned bestCost = Util::maxUnsigned();
|
double bestCost = 0.0;
|
||||||
LiftedOperator* bestOp = 0;
|
LiftedOperator* bestOp = 0;
|
||||||
vector<LiftedOperator*> validOps;
|
vector<LiftedOperator*> validOps;
|
||||||
validOps = LiftedOperator::getValidOps (pfList_, query);
|
validOps = LiftedOperator::getValidOps (pfList_, query);
|
||||||
for (unsigned i = 0; i < validOps.size(); i++) {
|
for (unsigned i = 0; i < validOps.size(); i++) {
|
||||||
unsigned cost = validOps[i]->getCost();
|
double cost = validOps[i]->getLogCost();
|
||||||
if ((bestOp == 0) || (cost < bestCost)) {
|
if ((bestOp == 0) || (cost < bestCost)) {
|
||||||
bestOp = validOps[i];
|
bestOp = validOps[i];
|
||||||
bestCost = cost;
|
bestCost = cost;
|
||||||
|
@ -8,7 +8,7 @@
|
|||||||
class LiftedOperator
|
class LiftedOperator
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
virtual unsigned getCost (void) = 0;
|
virtual double getLogCost (void) = 0;
|
||||||
|
|
||||||
virtual void apply (void) = 0;
|
virtual void apply (void) = 0;
|
||||||
|
|
||||||
@ -28,7 +28,7 @@ class SumOutOperator : public LiftedOperator
|
|||||||
SumOutOperator (unsigned group, ParfactorList& pfList)
|
SumOutOperator (unsigned group, ParfactorList& pfList)
|
||||||
: group_(group), pfList_(pfList) { }
|
: group_(group), pfList_(pfList) { }
|
||||||
|
|
||||||
unsigned getCost (void);
|
double getLogCost (void);
|
||||||
|
|
||||||
void apply (void);
|
void apply (void);
|
||||||
|
|
||||||
@ -60,7 +60,7 @@ class CountingOperator : public LiftedOperator
|
|||||||
ParfactorList& pfList)
|
ParfactorList& pfList)
|
||||||
: pfIter_(pfIter), X_(X), pfList_(pfList) { }
|
: pfIter_(pfIter), X_(X), pfList_(pfList) { }
|
||||||
|
|
||||||
unsigned getCost (void);
|
double getLogCost (void);
|
||||||
|
|
||||||
void apply (void);
|
void apply (void);
|
||||||
|
|
||||||
@ -87,7 +87,7 @@ class GroundOperator : public LiftedOperator
|
|||||||
ParfactorList& pfList)
|
ParfactorList& pfList)
|
||||||
: pfIter_(pfIter), X_(X), pfList_(pfList) { }
|
: pfIter_(pfIter), X_(X), pfList_(pfList) { }
|
||||||
|
|
||||||
unsigned getCost (void);
|
double getLogCost (void);
|
||||||
|
|
||||||
void apply (void);
|
void apply (void);
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user