refactor the way we calculate the grounding cost
This commit is contained in:
parent
661ce08961
commit
2b7da4bc23
@ -37,25 +37,55 @@ LiftedOperator::printValidOps (
|
||||
vector<LiftedOperator*> validOps;
|
||||
validOps = LiftedOperator::getValidOps (pfList, query);
|
||||
for (unsigned i = 0; i < validOps.size(); i++) {
|
||||
cout << "-> " << validOps[i]->toString() << endl;
|
||||
cout << "-> " << validOps[i]->toString();
|
||||
delete validOps[i];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
vector<unsigned>
|
||||
LiftedOperator::getAllGroupss (ParfactorList& )
|
||||
{
|
||||
return { };
|
||||
}
|
||||
|
||||
|
||||
|
||||
vector<ParfactorList::iterator>
|
||||
LiftedOperator::getParfactorsWithGroup (
|
||||
ParfactorList& pfList, unsigned group)
|
||||
{
|
||||
vector<ParfactorList::iterator> iters;
|
||||
ParfactorList::iterator pflIt = pfList.begin();
|
||||
while (pflIt != pfList.end()) {
|
||||
if ((*pflIt)->containsGroup (group)) {
|
||||
iters.push_back (pflIt);
|
||||
}
|
||||
++ pflIt;
|
||||
}
|
||||
return iters;
|
||||
}
|
||||
|
||||
|
||||
|
||||
double
|
||||
SumOutOperator::getLogCost (void)
|
||||
{
|
||||
TinySet<unsigned> groupSet;
|
||||
ParfactorList::const_iterator pfIter = pfList_.begin();
|
||||
unsigned nrProdFactors = 0;
|
||||
while (pfIter != pfList_.end()) {
|
||||
if ((*pfIter)->containsGroup (group_)) {
|
||||
vector<unsigned> groups = (*pfIter)->getAllGroups();
|
||||
groupSet |= TinySet<unsigned> (groups);
|
||||
++ nrProdFactors;
|
||||
}
|
||||
++ pfIter;
|
||||
}
|
||||
if (nrProdFactors == 1) {
|
||||
return 1.0; // best possible case
|
||||
}
|
||||
double cost = 1.0;
|
||||
for (unsigned i = 0; i < groupSet.size(); i++) {
|
||||
pfIter = pfList_.begin();
|
||||
@ -76,8 +106,8 @@ SumOutOperator::getLogCost (void)
|
||||
void
|
||||
SumOutOperator::apply (void)
|
||||
{
|
||||
vector<ParfactorList::iterator> iters
|
||||
= parfactorsWithGroup (pfList_, group_);
|
||||
vector<ParfactorList::iterator> iters;
|
||||
iters = getParfactorsWithGroup (pfList_, group_);
|
||||
Parfactor* product = *(iters[0]);
|
||||
pfList_.remove (iters[0]);
|
||||
for (unsigned i = 1; i < iters.size(); i++) {
|
||||
@ -137,13 +167,13 @@ SumOutOperator::toString (void)
|
||||
{
|
||||
stringstream ss;
|
||||
vector<ParfactorList::iterator> pfIters;
|
||||
pfIters = parfactorsWithGroup (pfList_, group_);
|
||||
pfIters = getParfactorsWithGroup (pfList_, group_);
|
||||
int idx = (*pfIters[0])->indexOfGroup (group_);
|
||||
ProbFormula f = (*pfIters[0])->argument (idx);
|
||||
TupleSet tupleSet = (*pfIters[0])->constr()->tupleSet (f.logVars());
|
||||
ss << "sum out " << f.functor() << "/" << f.arity();
|
||||
ss << "|" << tupleSet << " (group " << group_ << ")";
|
||||
ss << " [cost=" << getLogCost() << "]" << endl;
|
||||
ss << " [cost=" << std::exp (getLogCost()) << "]" << endl;
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
@ -156,7 +186,7 @@ SumOutOperator::validOp (
|
||||
const Grounds& query)
|
||||
{
|
||||
vector<ParfactorList::iterator> pfIters;
|
||||
pfIters = parfactorsWithGroup (pfList, group);
|
||||
pfIters = getParfactorsWithGroup (pfList, group);
|
||||
if (isToEliminate (*pfIters[0], group, query) == false) {
|
||||
return false;
|
||||
}
|
||||
@ -186,24 +216,6 @@ SumOutOperator::validOp (
|
||||
|
||||
|
||||
|
||||
vector<ParfactorList::iterator>
|
||||
SumOutOperator::parfactorsWithGroup (
|
||||
ParfactorList& pfList,
|
||||
unsigned group)
|
||||
{
|
||||
vector<ParfactorList::iterator> iters;
|
||||
ParfactorList::iterator pflIt = pfList.begin();
|
||||
while (pflIt != pfList.end()) {
|
||||
if ((*pflIt)->containsGroup (group)) {
|
||||
iters.push_back (pflIt);
|
||||
}
|
||||
++ pflIt;
|
||||
}
|
||||
return iters;
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
SumOutOperator::isToEliminate (
|
||||
Parfactor* g,
|
||||
@ -295,7 +307,7 @@ CountingOperator::toString (void)
|
||||
stringstream ss;
|
||||
ss << "count convert " << X_ << " in " ;
|
||||
ss << (*pfIter_)->getLabel();
|
||||
ss << " [cost=" << getLogCost() << "]" << endl;
|
||||
ss << " [cost=" << std::exp (getLogCost()) << "]" << endl;
|
||||
Parfactors pfs = FoveSolver::countNormalize (*pfIter_, X_);
|
||||
if ((*pfIter_)->constr()->isCountNormalized (X_) == false) {
|
||||
for (unsigned i = 0; i < pfs.size(); i++) {
|
||||
@ -337,19 +349,29 @@ CountingOperator::validOp (Parfactor* g, LogVar X)
|
||||
double
|
||||
GroundOperator::getLogCost (void)
|
||||
{
|
||||
double cost = 0.0;
|
||||
bool isCountingLv = (*pfIter_)->countedLogVars().contains (X_);
|
||||
if (isCountingLv) {
|
||||
int fIdx = (*pfIter_)->indexOfLogVar (X_);
|
||||
unsigned currSize = (*pfIter_)->size();
|
||||
unsigned nrHists = (*pfIter_)->range (fIdx);
|
||||
unsigned nrSymbols = (*pfIter_)->constr()->getConditionalCount (X_);
|
||||
unsigned range = (*pfIter_)->argument (fIdx).range();
|
||||
double power = std::log (range) * nrSymbols;
|
||||
cost = std::log (currSize / nrHists) + power;
|
||||
} else {
|
||||
unsigned currSize = (*pfIter_)->size();
|
||||
cost = std::log ((*pfIter_)->constr()->nrSymbols (X_) * currSize);
|
||||
double cost = std::log (0.0);
|
||||
vector<ParfactorList::iterator> pfIters;
|
||||
pfIters = getParfactorsWithGroup (pfList_, group_);
|
||||
for (unsigned i = 0; i < pfIters.size(); i++) {
|
||||
Parfactor* pf = *pfIters[i];
|
||||
int idx = pf->indexOfGroup (group_);
|
||||
ProbFormula f = pf->argument (idx);
|
||||
LogVar X = f.logVars()[lvIndex_];
|
||||
double pfCost = 0.0;
|
||||
bool isCountingLv = pf->countedLogVars().contains (X);
|
||||
if (isCountingLv) {
|
||||
int fIdx = pf->indexOfLogVar (X);
|
||||
unsigned currSize = pf->size();
|
||||
unsigned nrHists = pf->range (fIdx);
|
||||
unsigned nrSymbols = pf->constr()->getConditionalCount (X);
|
||||
unsigned range = pf->argument (fIdx).range();
|
||||
double power = std::log (range) * nrSymbols;
|
||||
pfCost = std::log (currSize / nrHists) + power;
|
||||
} else {
|
||||
unsigned currSize = pf->size();
|
||||
pfCost = std::log (pf->constr()->nrSymbols (X) * currSize);
|
||||
}
|
||||
cost = Util::logSum (cost, pfCost);
|
||||
}
|
||||
return cost;
|
||||
}
|
||||
@ -359,14 +381,21 @@ GroundOperator::getLogCost (void)
|
||||
void
|
||||
GroundOperator::apply (void)
|
||||
{
|
||||
bool countedLv = (*pfIter_)->countedLogVars().contains (X_);
|
||||
Parfactor* pf = *pfIter_;
|
||||
pfList_.remove (pfIter_);
|
||||
// TODO if we update the correct groups
|
||||
// we can skip shattering
|
||||
ParfactorList::iterator pfIter;
|
||||
pfIter = getParfactorsWithGroup (pfList_, group_).front();
|
||||
Parfactor* pf = *pfIter;
|
||||
int idx = pf->indexOfGroup (group_);
|
||||
ProbFormula f = pf->argument (idx);
|
||||
LogVar X = f.logVars()[lvIndex_];
|
||||
bool countedLv = pf->countedLogVars().contains (X);
|
||||
pfList_.remove (pfIter);
|
||||
if (countedLv) {
|
||||
pf->fullExpand (X_);
|
||||
pf->fullExpand (X);
|
||||
pfList_.add (pf);
|
||||
} else {
|
||||
ConstraintTrees cts = pf->constr()->ground (X_);
|
||||
ConstraintTrees cts = pf->constr()->ground (X);
|
||||
for (unsigned i = 0; i < cts.size(); i++) {
|
||||
pfList_.add (new Parfactor (pf, cts[i]));
|
||||
}
|
||||
@ -380,15 +409,23 @@ vector<GroundOperator*>
|
||||
GroundOperator::getValidOps (ParfactorList& pfList)
|
||||
{
|
||||
vector<GroundOperator*> validOps;
|
||||
ParfactorList::iterator pfIter = pfList.begin();
|
||||
while (pfIter != pfList.end()) {
|
||||
LogVarSet set = (*pfIter)->logVarSet();
|
||||
for (unsigned i = 0; i < set.size(); i++) {
|
||||
if ((*pfIter)->constr()->isSingleton (set[i]) == false) {
|
||||
validOps.push_back (new GroundOperator (pfIter, set[i], pfList));
|
||||
set<unsigned> allGroups;
|
||||
ParfactorList::const_iterator it = pfList.begin();
|
||||
while (it != pfList.end()) {
|
||||
const ProbFormulas& formulas = (*it)->arguments();
|
||||
for (unsigned i = 0; i < formulas.size(); i++) {
|
||||
if (Util::contains (allGroups, formulas[i].group()) == false) {
|
||||
const LogVars& lvs = formulas[i].logVars();
|
||||
for (unsigned j = 0; j < lvs.size(); j++) {
|
||||
if ((*it)->constr()->isSingleton (lvs[j]) == false) {
|
||||
validOps.push_back (new GroundOperator (
|
||||
formulas[i].group(), j, pfList));
|
||||
}
|
||||
}
|
||||
allGroups.insert (formulas[i].group());
|
||||
}
|
||||
}
|
||||
++ pfIter;
|
||||
++ it;
|
||||
}
|
||||
return validOps;
|
||||
}
|
||||
@ -399,11 +436,25 @@ string
|
||||
GroundOperator::toString (void)
|
||||
{
|
||||
stringstream ss;
|
||||
((*pfIter_)->countedLogVars().contains (X_))
|
||||
? ss << "full expanding "
|
||||
: ss << "grounding " ;
|
||||
ss << X_ << " in " << (*pfIter_)->getLabel();
|
||||
ss << " [cost=" << getLogCost() << "]" << endl;
|
||||
vector<ParfactorList::iterator> pfIters;
|
||||
pfIters = getParfactorsWithGroup (pfList_, group_);
|
||||
Parfactor* pf = *(getParfactorsWithGroup (pfList_, group_).front());
|
||||
int idx = pf->indexOfGroup (group_);
|
||||
ProbFormula f = pf->argument (idx);
|
||||
LogVar lv = f.logVars()[lvIndex_];
|
||||
TupleSet tupleSet = pf->constr()->tupleSet ({lv});
|
||||
string pos = "th";
|
||||
if (lvIndex_ == 0) {
|
||||
pos = "st" ;
|
||||
} else if (lvIndex_ == 1) {
|
||||
pos = "nd" ;
|
||||
} else if (lvIndex_ == 2) {
|
||||
pos = "rd" ;
|
||||
}
|
||||
ss << "grounding " << lvIndex_ + 1 << pos << " log var in " ;
|
||||
ss << f.functor() << "/" << f.arity();
|
||||
ss << "|" << tupleSet << " (group " << group_ << ")";
|
||||
ss << " [cost=" << std::exp (getLogCost()) << "]" << endl;
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
|
@ -18,6 +18,11 @@ class LiftedOperator
|
||||
ParfactorList&, const Grounds&);
|
||||
|
||||
static void printValidOps (ParfactorList&, const Grounds&);
|
||||
|
||||
static vector<unsigned> getAllGroupss (ParfactorList&);
|
||||
|
||||
static vector<ParfactorList::iterator> getParfactorsWithGroup (
|
||||
ParfactorList&, unsigned group);
|
||||
};
|
||||
|
||||
|
||||
@ -40,9 +45,6 @@ class SumOutOperator : public LiftedOperator
|
||||
private:
|
||||
static bool validOp (unsigned, ParfactorList&, const Grounds&);
|
||||
|
||||
static vector<ParfactorList::iterator> parfactorsWithGroup (
|
||||
ParfactorList& pfList, unsigned group);
|
||||
|
||||
static bool isToEliminate (Parfactor*, unsigned, const Grounds&);
|
||||
|
||||
unsigned group_;
|
||||
@ -82,10 +84,10 @@ class GroundOperator : public LiftedOperator
|
||||
{
|
||||
public:
|
||||
GroundOperator (
|
||||
ParfactorList::iterator pfIter,
|
||||
LogVar X,
|
||||
unsigned group,
|
||||
unsigned lvIndex,
|
||||
ParfactorList& pfList)
|
||||
: pfIter_(pfIter), X_(X), pfList_(pfList) { }
|
||||
: group_(group), lvIndex_(lvIndex), pfList_(pfList) { }
|
||||
|
||||
double getLogCost (void);
|
||||
|
||||
@ -96,8 +98,8 @@ class GroundOperator : public LiftedOperator
|
||||
string toString (void);
|
||||
|
||||
private:
|
||||
ParfactorList::iterator pfIter_;
|
||||
LogVar X_;
|
||||
unsigned group_;
|
||||
unsigned lvIndex_;
|
||||
ParfactorList& pfList_;
|
||||
};
|
||||
|
||||
|
Reference in New Issue
Block a user