This repository has been archived on 2023-08-20. You can view files and clone it, but cannot push or open issues or pull requests.
yap-6.3/packages/CLPBN/horus/LiftedVe.cpp

735 lines
19 KiB
C++
Raw Normal View History

2013-02-07 20:09:10 +00:00
#include <cassert>
2012-12-27 12:54:58 +00:00
2012-05-23 14:56:01 +01:00
#include <set>
2013-02-07 20:09:10 +00:00
#include <queue>
#include <algorithm>
#include <iostream>
#include <sstream>
2012-05-23 14:56:01 +01:00
#include "LiftedVe.h"
#include "LiftedOperations.h"
2012-05-23 14:56:01 +01:00
#include "Histogram.h"
#include "Util.h"
2013-02-07 13:37:15 +00:00
std::vector<LiftedOperator*>
2012-05-23 14:56:01 +01:00
LiftedOperator::getValidOps (
ParfactorList& pfList,
const Grounds& query)
{
2013-02-07 13:37:15 +00:00
std::vector<LiftedOperator*> validOps;
std::vector<ProductOperator*> multOps;
2012-05-23 14:56:01 +01:00
multOps = ProductOperator::getValidOps (pfList);
validOps.insert (validOps.end(), multOps.begin(), multOps.end());
if (Globals::verbosity > 1 || multOps.empty()) {
2013-02-07 13:37:15 +00:00
std::vector<SumOutOperator*> sumOutOps;
std::vector<CountingOperator*> countOps;
std::vector<GroundOperator*> groundOps;
2012-05-23 14:56:01 +01:00
sumOutOps = SumOutOperator::getValidOps (pfList, query);
countOps = CountingOperator::getValidOps (pfList);
groundOps = GroundOperator::getValidOps (pfList);
validOps.insert (validOps.end(), sumOutOps.begin(), sumOutOps.end());
validOps.insert (validOps.end(), countOps.begin(), countOps.end());
validOps.insert (validOps.end(), groundOps.begin(), groundOps.end());
}
return validOps;
}
/**
 * Writes a description of every currently valid operator to std::cout
 * (each line prefixed with "-> ") and frees each operator after printing it.
 */
void
LiftedOperator::printValidOps (
    ParfactorList& pfList,
    const Grounds& query)
{
  const std::vector<LiftedOperator*> ops
      = LiftedOperator::getValidOps (pfList, query);
  for (LiftedOperator* op : ops) {
    std::cout << "-> " << op->toString();
    delete op;
  }
}
2013-02-07 13:37:15 +00:00
std::vector<ParfactorList::iterator>
2012-05-23 14:56:01 +01:00
LiftedOperator::getParfactorsWithGroup (
2012-05-24 23:38:44 +01:00
ParfactorList& pfList, PrvGroup group)
2012-05-23 14:56:01 +01:00
{
2013-02-07 13:37:15 +00:00
std::vector<ParfactorList::iterator> iters;
2012-05-23 14:56:01 +01:00
ParfactorList::iterator pflIt = pfList.begin();
while (pflIt != pfList.end()) {
if ((*pflIt)->containsGroup (group)) {
iters.push_back (pflIt);
}
++ pflIt;
}
return iters;
}
/**
 * A product is treated as free: returns log(0), i.e. negative infinity,
 * so it never loses against an operator with a finite log cost.
 */
double
ProductOperator::getLogCost (void)
{
  const double noCost = 0.0;
  return std::log (noCost);
}
void
ProductOperator::apply (void)
{
Parfactor* g1 = *g1_;
Parfactor* g2 = *g2_;
g1->multiply (*g2);
pfList_.remove (g1_);
pfList_.removeAndDelete (g2_);
pfList_.addShattered (g1);
}
2013-02-07 13:37:15 +00:00
std::vector<ProductOperator*>
2012-05-23 14:56:01 +01:00
ProductOperator::getValidOps (ParfactorList& pfList)
{
2013-02-07 13:37:15 +00:00
std::vector<ProductOperator*> validOps;
2012-05-23 14:56:01 +01:00
ParfactorList::iterator it1 = pfList.begin();
ParfactorList::iterator penultimate = -- pfList.end();
2013-02-07 13:37:15 +00:00
std::set<Parfactor*> pfs;
2012-05-23 14:56:01 +01:00
while (it1 != penultimate) {
if (Util::contains (pfs, *it1)) {
++ it1;
continue;
}
ParfactorList::iterator it2 = it1;
++ it2;
while (it2 != pfList.end()) {
if (Util::contains (pfs, *it2)) {
++ it2;
continue;
} else {
if (validOp (*it1, *it2)) {
pfs.insert (*it1);
pfs.insert (*it2);
validOps.push_back (new ProductOperator (
it1, it2, pfList));
if (Globals::verbosity < 2) {
return validOps;
}
break;
}
}
++ it2;
}
++ it1;
}
return validOps;
}
2013-02-07 13:37:15 +00:00
std::string
2012-05-23 14:56:01 +01:00
ProductOperator::toString (void)
{
2013-02-07 13:37:15 +00:00
std::stringstream ss;
2012-05-23 14:56:01 +01:00
ss << "just multiplicate " ;
ss << (*g1_)->getAllGroups();
2012-12-20 23:19:10 +00:00
ss << " x " ;
2012-05-23 14:56:01 +01:00
ss << (*g2_)->getAllGroups();
2013-02-07 13:37:15 +00:00
ss << " [cost=" << std::exp (getLogCost()) << "]" << std::endl;
2012-05-23 14:56:01 +01:00
return ss.str();
}
/**
 * Decides whether g1 and g2 can be multiplied directly. All of the
 * following must hold:
 *  - the group set of one parfactor includes the other's;
 *  - each shared group appears in exactly one formula of each parfactor,
 *    with the same range on both sides;
 *  - Parfactor::canMultiply (g1, g2) agrees.
 */
bool
ProductOperator::validOp (Parfactor* g1, Parfactor* g2)
{
  TinySet<PrvGroup> groups1 (g1->getAllGroups());
  TinySet<PrvGroup> groups2 (g2->getAllGroups());
  const bool nested = groups1.contains (groups2)
      || groups2.contains (groups1);
  if (nested == false) {
    return false;
  }
  TinySet<PrvGroup> shared = groups1 & groups2;
  for (size_t k = 0; k < shared.size(); k++) {
    const bool uniqueInBoth = g1->nrFormulasWithGroup (shared[k]) == 1
        && g2->nrFormulasWithGroup (shared[k]) == 1;
    if (uniqueInBoth == false) {
      return false;
    }
    const size_t pos1 = g1->indexOfGroup (shared[k]);
    const size_t pos2 = g2->indexOfGroup (shared[k]);
    if (g1->range (pos1) != g2->range (pos2)) {
      return false;
    }
  }
  return Parfactor::canMultiply (g1, g2);
}
double
SumOutOperator::getLogCost (void)
{
2012-05-24 23:38:44 +01:00
TinySet<PrvGroup> groupSet;
2012-05-23 14:56:01 +01:00
ParfactorList::const_iterator pfIter = pfList_.begin();
unsigned nrProdFactors = 0;
while (pfIter != pfList_.end()) {
if ((*pfIter)->containsGroup (group_)) {
2013-02-07 13:37:15 +00:00
std::vector<PrvGroup> groups = (*pfIter)->getAllGroups();
2012-05-24 23:38:44 +01:00
groupSet |= TinySet<PrvGroup> (groups);
2012-05-23 14:56:01 +01:00
++ nrProdFactors;
}
++ pfIter;
}
if (nrProdFactors == 1) {
// best possible case
return std::log (0.0);
}
double cost = 1.0;
2012-05-24 22:55:20 +01:00
for (size_t i = 0; i < groupSet.size(); i++) {
2012-05-23 14:56:01 +01:00
pfIter = pfList_.begin();
while (pfIter != pfList_.end()) {
if ((*pfIter)->containsGroup (groupSet[i])) {
2012-05-24 22:55:20 +01:00
size_t idx = (*pfIter)->indexOfGroup (groupSet[i]);
2012-05-23 14:56:01 +01:00
cost *= (*pfIter)->range (idx);
break;
}
++ pfIter;
}
}
return std::log (cost);
}
void
SumOutOperator::apply (void)
{
2013-02-07 13:37:15 +00:00
std::vector<ParfactorList::iterator> iters;
2012-05-23 14:56:01 +01:00
iters = getParfactorsWithGroup (pfList_, group_);
Parfactor* product = *(iters[0]);
pfList_.remove (iters[0]);
2012-05-24 22:55:20 +01:00
for (size_t i = 1; i < iters.size(); i++) {
2012-05-23 14:56:01 +01:00
product->multiply (**(iters[i]));
pfList_.removeAndDelete (iters[i]);
}
if (product->nrArguments() == 1) {
delete product;
return;
}
2012-05-24 22:55:20 +01:00
size_t fIdx = product->indexOfGroup (group_);
2012-05-23 14:56:01 +01:00
LogVarSet excl = product->exclusiveLogVars (fIdx);
if (product->constr()->isCountNormalized (excl)) {
2012-05-25 20:15:05 +01:00
product->sumOutIndex (fIdx);
2012-05-23 14:56:01 +01:00
pfList_.addShattered (product);
} else {
Parfactors pfs = LiftedOperations::countNormalize (product, excl);
2012-05-24 22:55:20 +01:00
for (size_t i = 0; i < pfs.size(); i++) {
2012-05-25 20:15:05 +01:00
pfs[i]->sumOutIndex (fIdx);
2012-05-23 14:56:01 +01:00
pfList_.add (pfs[i]);
}
delete product;
}
}
2013-02-07 13:37:15 +00:00
std::vector<SumOutOperator*>
2012-05-23 14:56:01 +01:00
SumOutOperator::getValidOps (
ParfactorList& pfList,
const Grounds& query)
{
2013-02-07 13:37:15 +00:00
std::vector<SumOutOperator*> validOps;
std::set<PrvGroup> allGroups;
2012-05-23 14:56:01 +01:00
ParfactorList::const_iterator it = pfList.begin();
while (it != pfList.end()) {
const ProbFormulas& formulas = (*it)->arguments();
2012-05-24 22:55:20 +01:00
for (size_t i = 0; i < formulas.size(); i++) {
2012-05-23 14:56:01 +01:00
allGroups.insert (formulas[i].group());
}
++ it;
}
2013-02-07 13:37:15 +00:00
std::set<PrvGroup>::const_iterator groupIt = allGroups.begin();
2012-05-23 14:56:01 +01:00
while (groupIt != allGroups.end()) {
if (validOp (*groupIt, pfList, query)) {
validOps.push_back (new SumOutOperator (*groupIt, pfList));
}
++ groupIt;
}
return validOps;
}
2013-02-07 13:37:15 +00:00
std::string
2012-05-23 14:56:01 +01:00
SumOutOperator::toString (void)
{
2013-02-07 13:37:15 +00:00
std::stringstream ss;
std::vector<ParfactorList::iterator> pfIters;
2012-05-23 14:56:01 +01:00
pfIters = getParfactorsWithGroup (pfList_, group_);
2012-05-24 22:55:20 +01:00
size_t idx = (*pfIters[0])->indexOfGroup (group_);
2012-05-23 14:56:01 +01:00
ProbFormula f = (*pfIters[0])->argument (idx);
TupleSet tupleSet = (*pfIters[0])->constr()->tupleSet (f.logVars());
ss << "sum out " << f.functor() << "/" << f.arity();
ss << "|" << tupleSet << " (group " << group_ << ")";
2013-02-07 13:37:15 +00:00
ss << " [cost=" << std::exp (getLogCost()) << "]" << std::endl;
2012-05-23 14:56:01 +01:00
return ss.str();
}
bool
SumOutOperator::validOp (
2012-05-24 23:38:44 +01:00
PrvGroup group,
2012-05-23 14:56:01 +01:00
ParfactorList& pfList,
const Grounds& query)
{
2013-02-07 13:37:15 +00:00
std::vector<ParfactorList::iterator> pfIters;
2012-05-23 14:56:01 +01:00
pfIters = getParfactorsWithGroup (pfList, group);
if (isToEliminate (*pfIters[0], group, query) == false) {
return false;
}
int range = -1;
2012-05-24 22:55:20 +01:00
for (size_t i = 0; i < pfIters.size(); i++) {
2012-05-23 14:56:01 +01:00
if ((*pfIters[i])->nrFormulasWithGroup (group) > 1) {
return false;
}
2012-05-24 22:55:20 +01:00
size_t fIdx = (*pfIters[i])->indexOfGroup (group);
2012-05-23 14:56:01 +01:00
if ((*pfIters[i])->argument (fIdx).contains (
(*pfIters[i])->elimLogVars()) == false) {
return false;
}
if (range == -1) {
range = (*pfIters[i])->range (fIdx);
} else if ((int)(*pfIters[i])->range (fIdx) != range) {
return false;
}
}
return true;
}
bool
SumOutOperator::isToEliminate (
Parfactor* g,
2012-05-24 23:38:44 +01:00
PrvGroup group,
2012-05-23 14:56:01 +01:00
const Grounds& query)
{
2012-05-24 22:55:20 +01:00
size_t fIdx = g->indexOfGroup (group);
2012-05-23 14:56:01 +01:00
const ProbFormula& formula = g->argument (fIdx);
bool toElim = true;
2012-05-24 22:55:20 +01:00
for (size_t i = 0; i < query.size(); i++) {
2012-05-23 14:56:01 +01:00
if (formula.functor() == query[i].functor() &&
formula.arity() == query[i].arity()) {
g->constr()->moveToTop (formula.logVars());
if (g->constr()->containsTuple (query[i].args())) {
toElim = false;
break;
}
}
}
return toElim;
}
double
CountingOperator::getLogCost (void)
{
double cost = 0.0;
2012-05-24 22:55:20 +01:00
size_t fIdx = (*pfIter_)->indexOfLogVar (X_);
2012-05-23 14:56:01 +01:00
unsigned range = (*pfIter_)->range (fIdx);
unsigned size = (*pfIter_)->size() / range;
TinySet<unsigned> counts;
counts = (*pfIter_)->constr()->getConditionalCounts (X_);
2012-05-24 22:55:20 +01:00
for (size_t i = 0; i < counts.size(); i++) {
2012-05-23 14:56:01 +01:00
cost += size * HistogramSet::nrHistograms (counts[i], range);
}
2012-05-24 23:38:44 +01:00
PrvGroup group = (*pfIter_)->argument (fIdx).group();
2012-05-24 22:55:20 +01:00
size_t lvIndex = Util::indexOf (
2012-05-23 14:56:01 +01:00
(*pfIter_)->argument (fIdx).logVars(), X_);
2012-05-24 22:55:20 +01:00
assert (lvIndex != (*pfIter_)->argument (fIdx).logVars().size());
2012-05-23 14:56:01 +01:00
ParfactorList::iterator pfIter = pfList_.begin();
while (pfIter != pfList_.end()) {
if (pfIter != pfIter_) {
2012-05-24 22:55:20 +01:00
size_t fIdx2 = (*pfIter)->indexOfGroup (group);
if (fIdx2 != (*pfIter)->nrArguments()) {
2012-05-23 14:56:01 +01:00
LogVar Y = ((*pfIter)->argument (fIdx2).logVars()[lvIndex]);
if ((*pfIter)->canCountConvert (Y) == false) {
// the real cost should be the cost of grounding Y
cost *= 10.0;
}
}
}
++ pfIter;
}
return std::log (cost);
}
void
CountingOperator::apply (void)
{
if ((*pfIter_)->constr()->isCountNormalized (X_)) {
(*pfIter_)->countConvert (X_);
} else {
Parfactor* pf = *pfIter_;
pfList_.remove (pfIter_);
Parfactors pfs = LiftedOperations::countNormalize (pf, X_);
2012-05-24 22:55:20 +01:00
for (size_t i = 0; i < pfs.size(); i++) {
2012-05-23 14:56:01 +01:00
unsigned condCount = pfs[i]->constr()->getConditionalCount (X_);
bool cartProduct = pfs[i]->constr()->isCartesianProduct (
pfs[i]->countedLogVars() | X_);
if (condCount > 1 && cartProduct) {
pfs[i]->countConvert (X_);
}
pfList_.add (pfs[i]);
}
delete pf;
}
}
2013-02-07 13:37:15 +00:00
std::vector<CountingOperator*>
2012-05-23 14:56:01 +01:00
CountingOperator::getValidOps (ParfactorList& pfList)
{
2013-02-07 13:37:15 +00:00
std::vector<CountingOperator*> validOps;
2012-05-23 14:56:01 +01:00
ParfactorList::iterator it = pfList.begin();
while (it != pfList.end()) {
LogVarSet candidates = (*it)->uncountedLogVars();
2012-05-24 22:55:20 +01:00
for (size_t i = 0; i < candidates.size(); i++) {
2012-05-23 14:56:01 +01:00
if (validOp (*it, candidates[i])) {
validOps.push_back (new CountingOperator (
it, candidates[i], pfList));
} else {
}
}
++ it;
}
return validOps;
}
2013-02-07 13:37:15 +00:00
std::string
2012-05-23 14:56:01 +01:00
CountingOperator::toString (void)
{
2013-02-07 13:37:15 +00:00
std::stringstream ss;
2012-05-23 14:56:01 +01:00
ss << "count convert " << X_ << " in " ;
ss << (*pfIter_)->getLabel();
2013-02-07 13:37:15 +00:00
ss << " [cost=" << std::exp (getLogCost()) << "]" << std::endl;
Parfactors pfs = LiftedOperations::countNormalize (*pfIter_, X_);
2012-05-23 14:56:01 +01:00
if ((*pfIter_)->constr()->isCountNormalized (X_) == false) {
2012-05-24 22:55:20 +01:00
for (size_t i = 0; i < pfs.size(); i++) {
2013-02-07 13:37:15 +00:00
ss << " º " << pfs[i]->getLabel() << std::endl;
2012-05-23 14:56:01 +01:00
}
}
2012-05-24 22:55:20 +01:00
for (size_t i = 0; i < pfs.size(); i++) {
2012-05-23 14:56:01 +01:00
delete pfs[i];
}
return ss.str();
}
/**
 * X can be count-converted in g when it appears in exactly one formula and
 * that formula is not already a counting formula. If X is count normalized,
 * g->canCountConvert (X) must also hold; if it is not, the operator is
 * accepted (apply() splits the parfactor first).
 */
bool
CountingOperator::validOp (Parfactor* g, LogVar X)
{
  if (g->nrFormulas (X) != 1) {
    return false;
  }
  const size_t fIdx = g->indexOfLogVar (X);
  if (g->argument (fIdx).isCounting()) {
    return false;
  }
  return !g->constr()->isCountNormalized (X) || g->canCountConvert (X);
}
double
GroundOperator::getLogCost (void)
{
2013-02-07 13:37:15 +00:00
std::vector<std::pair<PrvGroup, unsigned>> affectedFormulas;
2012-05-23 14:56:01 +01:00
affectedFormulas = getAffectedFormulas();
2013-02-07 13:37:15 +00:00
// std::cout << "affected formulas: " ;
2012-05-24 22:55:20 +01:00
// for (size_t i = 0; i < affectedFormulas.size(); i++) {
2013-02-07 13:37:15 +00:00
// std::cout << affectedFormulas[i].first << ":" ;
// std::cout << affectedFormulas[i].second << " " ;
2012-05-23 14:56:01 +01:00
// }
2013-02-07 13:37:15 +00:00
// std::cout << "cost =" ;
2012-05-23 14:56:01 +01:00
double totalCost = std::log (0.0);
ParfactorList::iterator pflIt = pfList_.begin();
while (pflIt != pfList_.end()) {
Parfactor* pf = *pflIt;
double reps = 0.0;
double pfSize = std::log (pf->size());
bool willBeAffected = false;
LogVarSet lvsToGround;
2012-05-24 22:55:20 +01:00
for (size_t i = 0; i < affectedFormulas.size(); i++) {
size_t fIdx = pf->indexOfGroup (affectedFormulas[i].first);
if (fIdx != pf->nrArguments()) {
2012-05-23 14:56:01 +01:00
ProbFormula f = pf->argument (fIdx);
LogVar X = f.logVars()[affectedFormulas[i].second];
bool isCountingLv = pf->countedLogVars().contains (X);
if (isCountingLv) {
unsigned nrHists = pf->range (fIdx);
unsigned nrSymbols = pf->constr()->getConditionalCount (X);
unsigned range = pf->argument (fIdx).range();
double power = std::log (range) * nrSymbols;
pfSize = (pfSize - std::log (nrHists)) + power;
} else {
if (lvsToGround.contains (X) == false) {
reps += std::log (pf->constr()->nrSymbols (X));
lvsToGround.insert (X);
}
}
willBeAffected = true;
}
}
if (willBeAffected) {
2013-02-07 13:37:15 +00:00
// std::cout << " + " << std::exp (reps) << "x" << std::exp (pfSize);
2012-05-23 14:56:01 +01:00
double pfCost = reps + pfSize;
totalCost = Util::logSum (totalCost, pfCost);
}
++ pflIt;
}
2013-02-07 13:37:15 +00:00
// std::cout << std::endl;
return totalCost + 3;
2012-05-23 14:56:01 +01:00
}
void
GroundOperator::apply (void)
{
ParfactorList::iterator pfIter;
pfIter = getParfactorsWithGroup (pfList_, group_).front();
Parfactor* pf = *pfIter;
2012-05-24 22:55:20 +01:00
size_t idx = pf->indexOfGroup (group_);
2012-05-23 14:56:01 +01:00
ProbFormula f = pf->argument (idx);
LogVar X = f.logVars()[lvIndex_];
bool countedLv = pf->countedLogVars().contains (X);
pfList_.remove (pfIter);
if (countedLv) {
pf->fullExpand (X);
pfList_.add (pf);
} else {
ConstraintTrees cts = pf->constr()->ground (X);
2012-05-24 22:55:20 +01:00
for (size_t i = 0; i < cts.size(); i++) {
2012-05-23 14:56:01 +01:00
pfList_.add (new Parfactor (pf, cts[i]));
}
delete pf;
}
ParfactorList::iterator pflIt = pfList_.begin();
while (pflIt != pfList_.end()) {
(*pflIt)->simplifyGrounds();
++ pflIt;
}
}
2013-02-07 13:37:15 +00:00
std::vector<GroundOperator*>
2012-05-23 14:56:01 +01:00
GroundOperator::getValidOps (ParfactorList& pfList)
{
2013-02-07 13:37:15 +00:00
std::vector<GroundOperator*> validOps;
std::set<PrvGroup> allGroups;
2012-05-23 14:56:01 +01:00
ParfactorList::const_iterator it = pfList.begin();
while (it != pfList.end()) {
const ProbFormulas& formulas = (*it)->arguments();
2012-05-24 22:55:20 +01:00
for (size_t i = 0; i < formulas.size(); i++) {
2012-05-23 14:56:01 +01:00
if (Util::contains (allGroups, formulas[i].group()) == false) {
const LogVars& lvs = formulas[i].logVars();
2012-05-24 22:55:20 +01:00
for (size_t j = 0; j < lvs.size(); j++) {
2012-05-23 14:56:01 +01:00
if ((*it)->constr()->isSingleton (lvs[j]) == false) {
validOps.push_back (new GroundOperator (
formulas[i].group(), j, pfList));
}
}
allGroups.insert (formulas[i].group());
}
}
++ it;
}
return validOps;
}
2013-02-07 13:37:15 +00:00
std::string
2012-05-23 14:56:01 +01:00
GroundOperator::toString (void)
{
2013-02-07 13:37:15 +00:00
std::stringstream ss;
std::vector<ParfactorList::iterator> pfIters;
2012-05-23 14:56:01 +01:00
pfIters = getParfactorsWithGroup (pfList_, group_);
Parfactor* pf = *(getParfactorsWithGroup (pfList_, group_).front());
2012-05-24 22:55:20 +01:00
size_t idx = pf->indexOfGroup (group_);
2012-05-23 14:56:01 +01:00
ProbFormula f = pf->argument (idx);
LogVar lv = f.logVars()[lvIndex_];
TupleSet tupleSet = pf->constr()->tupleSet ({lv});
2013-02-07 13:37:15 +00:00
std::string pos = "th";
2012-05-23 14:56:01 +01:00
if (lvIndex_ == 0) {
pos = "st" ;
} else if (lvIndex_ == 1) {
pos = "nd" ;
} else if (lvIndex_ == 2) {
pos = "rd" ;
}
ss << "grounding " << lvIndex_ + 1 << pos << " log var in " ;
ss << f.functor() << "/" << f.arity();
ss << "|" << tupleSet << " (group " << group_ << ")";
2013-02-07 13:37:15 +00:00
ss << " [cost=" << std::exp (getLogCost()) << "]" << std::endl;
2012-05-23 14:56:01 +01:00
return ss.str();
}
2013-02-07 13:37:15 +00:00
std::vector<std::pair<PrvGroup, unsigned>>
2012-05-23 14:56:01 +01:00
GroundOperator::getAffectedFormulas (void)
{
2013-02-07 13:37:15 +00:00
std::vector<std::pair<PrvGroup, unsigned>> affectedFormulas;
affectedFormulas.push_back (std::make_pair (group_, lvIndex_));
std::queue<std::pair<PrvGroup, unsigned>> q;
q.push (std::make_pair (group_, lvIndex_));
2012-05-23 14:56:01 +01:00
while (q.empty() == false) {
2013-02-07 13:37:15 +00:00
std::pair<PrvGroup, unsigned> front = q.front();
2012-05-23 14:56:01 +01:00
ParfactorList::iterator pflIt = pfList_.begin();
while (pflIt != pfList_.end()) {
2012-05-24 22:55:20 +01:00
size_t idx = (*pflIt)->indexOfGroup (front.first);
if (idx != (*pflIt)->nrArguments()) {
2012-05-23 14:56:01 +01:00
ProbFormula f = (*pflIt)->argument (idx);
LogVar X = f.logVars()[front.second];
const ProbFormulas& fs = (*pflIt)->arguments();
2012-05-24 22:55:20 +01:00
for (size_t i = 0; i < fs.size(); i++) {
2012-06-06 15:04:28 +01:00
if (i != idx && fs[i].contains (X)) {
2013-02-07 13:37:15 +00:00
std::pair<PrvGroup, unsigned> pair = std::make_pair (
2012-05-23 14:56:01 +01:00
fs[i].group(), fs[i].indexOf (X));
if (Util::contains (affectedFormulas, pair) == false) {
q.push (pair);
affectedFormulas.push_back (pair);
}
}
}
}
++ pflIt;
}
q.pop();
}
return affectedFormulas;
}
/**
 * Answers query:
 *  1. copies parfactorList into pfList_;
 *  2. runs the solver;
 *  3. normalizes the first remaining parfactor and returns its params,
 *     converted out of log space when Globals::logDomain is set.
 */
Params
LiftedVe::solveQuery (const Grounds& query)
{
  assert (query.empty() == false);
  pfList_ = parfactorList;
  runSolver (query);
  Parfactor* result = *pfList_.begin();
  result->normalize();
  Params params = result->params();
  if (Globals::logDomain) {
    Util::exp (params);
  }
  return params;
}
void
LiftedVe::printSolverFlags (void) const
2012-05-23 14:56:01 +01:00
{
2013-02-07 13:37:15 +00:00
std::stringstream ss;
ss << "lve [" ;
2012-05-23 14:56:01 +01:00
ss << "log_domain=" << Util::toString (Globals::logDomain);
ss << "]" ;
2013-02-07 13:37:15 +00:00
std::cout << ss.str() << std::endl;
2012-05-23 14:56:01 +01:00
}
void
LiftedVe::runSolver (const Grounds& query)
2012-05-23 14:56:01 +01:00
{
largestCost_ = std::log (0);
LiftedOperations::shatterAgainstQuery (pfList_, query);
2012-11-10 00:18:20 +00:00
LiftedOperations::runWeakBayesBall (pfList_, query);
2012-05-23 14:56:01 +01:00
while (true) {
if (Globals::verbosity > 2) {
Util::printDashedLine();
pfList_.print();
if (Globals::verbosity > 3) {
LiftedOperator::printValidOps (pfList_, query);
}
}
LiftedOperator* op = getBestOperation (query);
if (op == 0) {
break;
}
if (Globals::verbosity > 1) {
2013-02-07 13:37:15 +00:00
std::cout << "best operation: " << op->toString();
2012-05-23 14:56:01 +01:00
if (Globals::verbosity > 2) {
2013-02-07 13:37:15 +00:00
std::cout << std::endl;
2012-05-23 14:56:01 +01:00
}
}
op->apply();
delete op;
}
assert (pfList_.size() > 0);
if (pfList_.size() > 1) {
ParfactorList::iterator pfIter = pfList_.begin();
2012-05-28 14:12:18 +01:00
++ pfIter;
2012-05-23 14:56:01 +01:00
while (pfIter != pfList_.end()) {
(*pfList_.begin())->multiply (**pfIter);
++ pfIter;
}
}
if (Globals::verbosity > 0) {
2013-02-07 13:37:15 +00:00
std::cout << "largest cost = " << std::exp (largestCost_);
std::cout << std::endl;
std::cout << std::endl;
2012-05-23 14:56:01 +01:00
}
(*pfList_.begin())->simplifyGrounds();
(*pfList_.begin())->reorderAccordingGrounds (query);
}
/**
 * Picks the cheapest currently valid operator.
 *
 * Every candidate from LiftedOperator::getValidOps is scored with
 * getLogCost(); the first one with the lowest log cost wins, and all other
 * candidates are deleted. largestCost_ is raised to the chosen operator's
 * cost when that cost exceeds it.
 *
 * @param query query grounds, forwarded to getValidOps
 * @return the chosen operator (the caller takes ownership), or nullptr when
 *         no operator is applicable
 */
LiftedOperator*
LiftedVe::getBestOperation (const Grounds& query)
{
  double bestCost = 0.0;
  LiftedOperator* bestOp = nullptr;
  std::vector<LiftedOperator*> validOps
      = LiftedOperator::getValidOps (pfList_, query);
  for (size_t i = 0; i < validOps.size(); i++) {
    double cost = validOps[i]->getLogCost();
    if (bestOp == nullptr || cost < bestCost) {
      bestOp = validOps[i];
      bestCost = cost;
    }
  }
  // Only record the cost of an operator that will actually be applied.
  // With no candidates, bestCost still holds its 0.0 placeholder (= log 1),
  // which must not inflate the reported largest cost.
  if (bestOp != nullptr && bestCost > largestCost_) {
    largestCost_ = bestCost;
  }
  for (size_t i = 0; i < validOps.size(); i++) {
    if (validOps[i] != bestOp) {
      delete validOps[i];
    }
  }
  return bestOp;
}