#include #include #include #include #include #include #include #include #include "LiftedVe.h" #include "LiftedOperations.h" #include "Histogram.h" #include "Util.h" namespace Horus { class LiftedOperator { public: virtual ~LiftedOperator() { } virtual double getLogCost() = 0; virtual void apply() = 0; virtual std::string toString() = 0; static std::vector getValidOps ( ParfactorList&, const Grounds&); static void printValidOps (ParfactorList&, const Grounds&); static std::vector getParfactorsWithGroup ( ParfactorList&, PrvGroup group); private: DISALLOW_ASSIGN (LiftedOperator); }; class ProductOperator : public LiftedOperator { public: ProductOperator ( ParfactorList::iterator g1, ParfactorList::iterator g2, ParfactorList& pfList) : g1_(g1), g2_(g2), pfList_(pfList) { } double getLogCost(); void apply(); static std::vector getValidOps (ParfactorList&); std::string toString(); private: static bool validOp (Parfactor*, Parfactor*); ParfactorList::iterator g1_; ParfactorList::iterator g2_; ParfactorList& pfList_; DISALLOW_COPY_AND_ASSIGN (ProductOperator); }; class SumOutOperator : public LiftedOperator { public: SumOutOperator (PrvGroup group, ParfactorList& pfList) : group_(group), pfList_(pfList) { } double getLogCost(); void apply(); static std::vector getValidOps ( ParfactorList&, const Grounds&); std::string toString(); private: static bool validOp (PrvGroup, ParfactorList&, const Grounds&); static bool isToEliminate (Parfactor*, PrvGroup, const Grounds&); PrvGroup group_; ParfactorList& pfList_; DISALLOW_COPY_AND_ASSIGN (SumOutOperator); }; class CountingOperator : public LiftedOperator { public: CountingOperator ( ParfactorList::iterator pfIter, LogVar X, ParfactorList& pfList) : pfIter_(pfIter), X_(X), pfList_(pfList) { } double getLogCost(); void apply(); static std::vector getValidOps (ParfactorList&); std::string toString(); private: static bool validOp (Parfactor*, LogVar); ParfactorList::iterator pfIter_; LogVar X_; ParfactorList& pfList_; DISALLOW_COPY_AND_ASSIGN (CountingOperator); }; class GroundOperator : public LiftedOperator { public: GroundOperator ( PrvGroup group, unsigned lvIndex, ParfactorList& pfList) : group_(group), lvIndex_(lvIndex), pfList_(pfList) { } double getLogCost(); void apply(); static std::vector getValidOps (ParfactorList&); std::string toString(); private: std::vector> getAffectedFormulas(); PrvGroup group_; unsigned lvIndex_; ParfactorList& pfList_; DISALLOW_COPY_AND_ASSIGN (GroundOperator); }; std::vector LiftedOperator::getValidOps ( ParfactorList& pfList, const Grounds& query) { std::vector validOps; std::vector multOps; multOps = ProductOperator::getValidOps (pfList); validOps.insert (validOps.end(), multOps.begin(), multOps.end()); if (Globals::verbosity > 1 || multOps.empty()) { std::vector sumOutOps; std::vector countOps; std::vector groundOps; sumOutOps = SumOutOperator::getValidOps (pfList, query); countOps = CountingOperator::getValidOps (pfList); groundOps = GroundOperator::getValidOps (pfList); validOps.insert (validOps.end(), sumOutOps.begin(), sumOutOps.end()); validOps.insert (validOps.end(), countOps.begin(), countOps.end()); validOps.insert (validOps.end(), groundOps.begin(), groundOps.end()); } return validOps; } void LiftedOperator::printValidOps ( ParfactorList& pfList, const Grounds& query) { std::vector validOps; validOps = LiftedOperator::getValidOps (pfList, query); for (size_t i = 0; i < validOps.size(); i++) { std::cout << "-> " << validOps[i]->toString(); delete validOps[i]; } } std::vector LiftedOperator::getParfactorsWithGroup ( ParfactorList& pfList, PrvGroup group) { std::vector iters; ParfactorList::iterator pflIt = pfList.begin(); while (pflIt != pfList.end()) { if ((*pflIt)->containsGroup (group)) { iters.push_back (pflIt); } ++ pflIt; } return iters; } double ProductOperator::getLogCost() { return std::log (0.0); } void ProductOperator::apply() { Parfactor* g1 = *g1_; Parfactor* g2 = *g2_; g1->multiply (*g2); pfList_.remove (g1_); pfList_.removeAndDelete (g2_); pfList_.addShattered (g1); } std::vector ProductOperator::getValidOps (ParfactorList& pfList) { std::vector validOps; ParfactorList::iterator it1 = pfList.begin(); ParfactorList::iterator penultimate = -- pfList.end(); std::set pfs; while (it1 != penultimate) { if (Util::contains (pfs, *it1)) { ++ it1; continue; } ParfactorList::iterator it2 = it1; ++ it2; while (it2 != pfList.end()) { if (Util::contains (pfs, *it2)) { ++ it2; continue; } else { if (validOp (*it1, *it2)) { pfs.insert (*it1); pfs.insert (*it2); validOps.push_back (new ProductOperator ( it1, it2, pfList)); if (Globals::verbosity < 2) { return validOps; } break; } } ++ it2; } ++ it1; } return validOps; } std::string ProductOperator::toString() { std::stringstream ss; ss << "just multiplicate " ; ss << (*g1_)->getAllGroups(); ss << " x " ; ss << (*g2_)->getAllGroups(); ss << " [cost=" << std::exp (getLogCost()) << "]" << std::endl; return ss.str(); } bool ProductOperator::validOp (Parfactor* g1, Parfactor* g2) { TinySet g1_gs (g1->getAllGroups()); TinySet g2_gs (g2->getAllGroups()); if (g1_gs.contains (g2_gs) || g2_gs.contains (g1_gs)) { TinySet intersect = g1_gs & g2_gs; for (size_t i = 0; i < intersect.size(); i++) { if (g1->nrFormulasWithGroup (intersect[i]) != 1 || g2->nrFormulasWithGroup (intersect[i]) != 1) { return false; } size_t idx1 = g1->indexOfGroup (intersect[i]); size_t idx2 = g2->indexOfGroup (intersect[i]); if (g1->range (idx1) != g2->range (idx2)) { return false; } } return Parfactor::canMultiply (g1, g2); } return false; } double SumOutOperator::getLogCost() { TinySet groupSet; ParfactorList::const_iterator pfIter = pfList_.begin(); unsigned nrProdFactors = 0; while (pfIter != pfList_.end()) { if ((*pfIter)->containsGroup (group_)) { std::vector groups = (*pfIter)->getAllGroups(); groupSet |= TinySet (groups); ++ nrProdFactors; } ++ pfIter; } if (nrProdFactors == 1) { // best possible case return std::log (0.0); } double cost = 1.0; for (size_t i = 0; i < groupSet.size(); i++) { pfIter = pfList_.begin(); while (pfIter != pfList_.end()) { if ((*pfIter)->containsGroup (groupSet[i])) { size_t idx = (*pfIter)->indexOfGroup (groupSet[i]); cost *= (*pfIter)->range (idx); break; } ++ pfIter; } } return std::log (cost); } void SumOutOperator::apply() { std::vector iters; iters = getParfactorsWithGroup (pfList_, group_); Parfactor* product = *(iters[0]); pfList_.remove (iters[0]); for (size_t i = 1; i < iters.size(); i++) { product->multiply (**(iters[i])); pfList_.removeAndDelete (iters[i]); } if (product->nrArguments() == 1) { delete product; return; } size_t fIdx = product->indexOfGroup (group_); LogVarSet excl = product->exclusiveLogVars (fIdx); if (product->constr()->isCountNormalized (excl)) { product->sumOutIndex (fIdx); pfList_.addShattered (product); } else { Parfactors pfs = LiftedOperations::countNormalize (product, excl); for (size_t i = 0; i < pfs.size(); i++) { pfs[i]->sumOutIndex (fIdx); pfList_.add (pfs[i]); } delete product; } } std::vector SumOutOperator::getValidOps ( ParfactorList& pfList, const Grounds& query) { std::vector validOps; std::set allGroups; ParfactorList::const_iterator it = pfList.begin(); while (it != pfList.end()) { const ProbFormulas& formulas = (*it)->arguments(); for (size_t i = 0; i < formulas.size(); i++) { allGroups.insert (formulas[i].group()); } ++ it; } std::set::const_iterator groupIt = allGroups.begin(); while (groupIt != allGroups.end()) { if (validOp (*groupIt, pfList, query)) { validOps.push_back (new SumOutOperator (*groupIt, pfList)); } ++ groupIt; } return validOps; } std::string SumOutOperator::toString() { std::stringstream ss; std::vector pfIters; pfIters = getParfactorsWithGroup (pfList_, group_); size_t idx = (*pfIters[0])->indexOfGroup (group_); ProbFormula f = (*pfIters[0])->argument (idx); TupleSet tupleSet = (*pfIters[0])->constr()->tupleSet (f.logVars()); ss << "sum out " << f.functor() << "/" << f.arity(); ss << "|" << tupleSet << " (group " << group_ << ")"; ss << " [cost=" << std::exp (getLogCost()) << "]" << std::endl; return ss.str(); } bool SumOutOperator::validOp ( PrvGroup group, ParfactorList& pfList, const Grounds& query) { std::vector pfIters; pfIters = getParfactorsWithGroup (pfList, group); if (isToEliminate (*pfIters[0], group, query) == false) { return false; } int range = -1; for (size_t i = 0; i < pfIters.size(); i++) { if ((*pfIters[i])->nrFormulasWithGroup (group) > 1) { return false; } size_t fIdx = (*pfIters[i])->indexOfGroup (group); if ((*pfIters[i])->argument (fIdx).contains ( (*pfIters[i])->elimLogVars()) == false) { return false; } if (range == -1) { range = (*pfIters[i])->range (fIdx); } else if ((int)(*pfIters[i])->range (fIdx) != range) { return false; } } return true; } bool SumOutOperator::isToEliminate ( Parfactor* g, PrvGroup group, const Grounds& query) { size_t fIdx = g->indexOfGroup (group); const ProbFormula& formula = g->argument (fIdx); bool toElim = true; for (size_t i = 0; i < query.size(); i++) { if (formula.functor() == query[i].functor() && formula.arity() == query[i].arity()) { g->constr()->moveToTop (formula.logVars()); if (g->constr()->containsTuple (query[i].args())) { toElim = false; break; } } } return toElim; } double CountingOperator::getLogCost() { double cost = 0.0; size_t fIdx = (*pfIter_)->indexOfLogVar (X_); unsigned range = (*pfIter_)->range (fIdx); unsigned size = (*pfIter_)->size() / range; TinySet counts; counts = (*pfIter_)->constr()->getConditionalCounts (X_); for (size_t i = 0; i < counts.size(); i++) { cost += size * HistogramSet::nrHistograms (counts[i], range); } PrvGroup group = (*pfIter_)->argument (fIdx).group(); size_t lvIndex = Util::indexOf ( (*pfIter_)->argument (fIdx).logVars(), X_); assert (lvIndex != (*pfIter_)->argument (fIdx).logVars().size()); ParfactorList::iterator pfIter = pfList_.begin(); while (pfIter != pfList_.end()) { if (pfIter != pfIter_) { size_t fIdx2 = (*pfIter)->indexOfGroup (group); if (fIdx2 != (*pfIter)->nrArguments()) { LogVar Y = ((*pfIter)->argument (fIdx2).logVars()[lvIndex]); if ((*pfIter)->canCountConvert (Y) == false) { // the real cost should be the cost of grounding Y cost *= 10.0; } } } ++ pfIter; } return std::log (cost); } void CountingOperator::apply() { if ((*pfIter_)->constr()->isCountNormalized (X_)) { (*pfIter_)->countConvert (X_); } else { Parfactor* pf = *pfIter_; pfList_.remove (pfIter_); Parfactors pfs = LiftedOperations::countNormalize (pf, X_); for (size_t i = 0; i < pfs.size(); i++) { unsigned condCount = pfs[i]->constr()->getConditionalCount (X_); bool cartProduct = pfs[i]->constr()->isCartesianProduct ( pfs[i]->countedLogVars() | X_); if (condCount > 1 && cartProduct) { pfs[i]->countConvert (X_); } pfList_.add (pfs[i]); } delete pf; } } std::vector CountingOperator::getValidOps (ParfactorList& pfList) { std::vector validOps; ParfactorList::iterator it = pfList.begin(); while (it != pfList.end()) { LogVarSet candidates = (*it)->uncountedLogVars(); for (size_t i = 0; i < candidates.size(); i++) { if (validOp (*it, candidates[i])) { validOps.push_back (new CountingOperator ( it, candidates[i], pfList)); } else { } } ++ it; } return validOps; } std::string CountingOperator::toString() { std::stringstream ss; ss << "count convert " << X_ << " in " ; ss << (*pfIter_)->getLabel(); ss << " [cost=" << std::exp (getLogCost()) << "]" << std::endl; Parfactors pfs = LiftedOperations::countNormalize (*pfIter_, X_); if ((*pfIter_)->constr()->isCountNormalized (X_) == false) { for (size_t i = 0; i < pfs.size(); i++) { ss << " ยบ " << pfs[i]->getLabel() << std::endl; } } for (size_t i = 0; i < pfs.size(); i++) { delete pfs[i]; } return ss.str(); } bool CountingOperator::validOp (Parfactor* g, LogVar X) { if (g->nrFormulas (X) != 1) { return false; } size_t fIdx = g->indexOfLogVar (X); if (g->argument (fIdx).isCounting()) { return false; } bool countNormalized = g->constr()->isCountNormalized (X); if (countNormalized) { return g->canCountConvert (X); } return true; } double GroundOperator::getLogCost() { std::vector> affectedFormulas; affectedFormulas = getAffectedFormulas(); // std::cout << "affected formulas: " ; // for (size_t i = 0; i < affectedFormulas.size(); i++) { // std::cout << affectedFormulas[i].first << ":" ; // std::cout << affectedFormulas[i].second << " " ; // } // std::cout << "cost =" ; double totalCost = std::log (0.0); ParfactorList::iterator pflIt = pfList_.begin(); while (pflIt != pfList_.end()) { Parfactor* pf = *pflIt; double reps = 0.0; double pfSize = std::log (pf->size()); bool willBeAffected = false; LogVarSet lvsToGround; for (size_t i = 0; i < affectedFormulas.size(); i++) { size_t fIdx = pf->indexOfGroup (affectedFormulas[i].first); if (fIdx != pf->nrArguments()) { ProbFormula f = pf->argument (fIdx); LogVar X = f.logVars()[affectedFormulas[i].second]; bool isCountingLv = pf->countedLogVars().contains (X); if (isCountingLv) { unsigned nrHists = pf->range (fIdx); unsigned nrSymbols = pf->constr()->getConditionalCount (X); unsigned range = pf->argument (fIdx).range(); double power = std::log (range) * nrSymbols; pfSize = (pfSize - std::log (nrHists)) + power; } else { if (lvsToGround.contains (X) == false) { reps += std::log (pf->constr()->nrSymbols (X)); lvsToGround.insert (X); } } willBeAffected = true; } } if (willBeAffected) { // std::cout << " + " << std::exp (reps) << "x" << std::exp (pfSize); double pfCost = reps + pfSize; totalCost = Util::logSum (totalCost, pfCost); } ++ pflIt; } // std::cout << std::endl; return totalCost + 3; } void GroundOperator::apply() { ParfactorList::iterator pfIter; pfIter = getParfactorsWithGroup (pfList_, group_).front(); Parfactor* pf = *pfIter; size_t idx = pf->indexOfGroup (group_); ProbFormula f = pf->argument (idx); LogVar X = f.logVars()[lvIndex_]; bool countedLv = pf->countedLogVars().contains (X); pfList_.remove (pfIter); if (countedLv) { pf->fullExpand (X); pfList_.add (pf); } else { ConstraintTrees cts = pf->constr()->ground (X); for (size_t i = 0; i < cts.size(); i++) { pfList_.add (new Parfactor (pf, cts[i])); } delete pf; } ParfactorList::iterator pflIt = pfList_.begin(); while (pflIt != pfList_.end()) { (*pflIt)->simplifyGrounds(); ++ pflIt; } } std::vector GroundOperator::getValidOps (ParfactorList& pfList) { std::vector validOps; std::set allGroups; ParfactorList::const_iterator it = pfList.begin(); while (it != pfList.end()) { const ProbFormulas& formulas = (*it)->arguments(); for (size_t i = 0; i < formulas.size(); i++) { if (Util::contains (allGroups, formulas[i].group()) == false) { const LogVars& lvs = formulas[i].logVars(); for (size_t j = 0; j < lvs.size(); j++) { if ((*it)->constr()->isSingleton (lvs[j]) == false) { validOps.push_back (new GroundOperator ( formulas[i].group(), j, pfList)); } } allGroups.insert (formulas[i].group()); } } ++ it; } return validOps; } std::string GroundOperator::toString() { std::stringstream ss; std::vector pfIters; pfIters = getParfactorsWithGroup (pfList_, group_); Parfactor* pf = *(getParfactorsWithGroup (pfList_, group_).front()); size_t idx = pf->indexOfGroup (group_); ProbFormula f = pf->argument (idx); LogVar lv = f.logVars()[lvIndex_]; TupleSet tupleSet = pf->constr()->tupleSet ({lv}); std::string pos = "th"; if (lvIndex_ == 0) { pos = "st" ; } else if (lvIndex_ == 1) { pos = "nd" ; } else if (lvIndex_ == 2) { pos = "rd" ; } ss << "grounding " << lvIndex_ + 1 << pos << " log var in " ; ss << f.functor() << "/" << f.arity(); ss << "|" << tupleSet << " (group " << group_ << ")"; ss << " [cost=" << std::exp (getLogCost()) << "]" << std::endl; return ss.str(); } std::vector> GroundOperator::getAffectedFormulas() { std::vector> affectedFormulas; affectedFormulas.push_back (std::make_pair (group_, lvIndex_)); std::queue> q; q.push (std::make_pair (group_, lvIndex_)); while (q.empty() == false) { std::pair front = q.front(); ParfactorList::iterator pflIt = pfList_.begin(); while (pflIt != pfList_.end()) { size_t idx = (*pflIt)->indexOfGroup (front.first); if (idx != (*pflIt)->nrArguments()) { ProbFormula f = (*pflIt)->argument (idx); LogVar X = f.logVars()[front.second]; const ProbFormulas& fs = (*pflIt)->arguments(); for (size_t i = 0; i < fs.size(); i++) { if (i != idx && fs[i].contains (X)) { std::pair pair = std::make_pair ( fs[i].group(), fs[i].indexOf (X)); if (Util::contains (affectedFormulas, pair) == false) { q.push (pair); affectedFormulas.push_back (pair); } } } } ++ pflIt; } q.pop(); } return affectedFormulas; } Params LiftedVe::solveQuery (const Grounds& query) { assert (query.empty() == false); pfList_ = parfactorList; runSolver (query); (*pfList_.begin())->normalize(); Params params = (*pfList_.begin())->params(); if (Globals::logDomain) { Util::exp (params); } return params; } void LiftedVe::printSolverFlags() const { std::stringstream ss; ss << "lve [" ; ss << "log_domain=" << Util::toString (Globals::logDomain); ss << "]" ; std::cout << ss.str() << std::endl; } void LiftedVe::runSolver (const Grounds& query) { largestCost_ = std::log (0); LiftedOperations::shatterAgainstQuery (pfList_, query); LiftedOperations::runWeakBayesBall (pfList_, query); while (true) { if (Globals::verbosity > 2) { Util::printDashedLine(); pfList_.print(); if (Globals::verbosity > 3) { LiftedOperator::printValidOps (pfList_, query); } } LiftedOperator* op = getBestOperation (query); if (op == 0) { break; } if (Globals::verbosity > 1) { std::cout << "best operation: " << op->toString(); if (Globals::verbosity > 2) { std::cout << std::endl; } } op->apply(); delete op; } assert (pfList_.size() > 0); if (pfList_.size() > 1) { ParfactorList::iterator pfIter = pfList_.begin(); ++ pfIter; while (pfIter != pfList_.end()) { (*pfList_.begin())->multiply (**pfIter); ++ pfIter; } } if (Globals::verbosity > 0) { std::cout << "largest cost = " << std::exp (largestCost_); std::cout << std::endl; std::cout << std::endl; } (*pfList_.begin())->simplifyGrounds(); (*pfList_.begin())->reorderAccordingGrounds (query); } LiftedOperator* LiftedVe::getBestOperation (const Grounds& query) { double bestCost = 0.0; LiftedOperator* bestOp = 0; std::vector validOps; validOps = LiftedOperator::getValidOps (pfList_, query); for (size_t i = 0; i < validOps.size(); i++) { double cost = validOps[i]->getLogCost(); if (!bestOp || cost < bestCost) { bestOp = validOps[i]; bestCost = cost; } } if (bestCost > largestCost_) { largestCost_ = bestCost; } for (size_t i = 0; i < validOps.size(); i++) { if (validOps[i] != bestOp) { delete validOps[i]; } } return bestOp; } } // namespace Horus