2012-03-22 11:33:24 +00:00
|
|
|
|
|
|
|
#include <algorithm>
|
|
|
|
#include <set>
|
|
|
|
|
|
|
|
#include "FoveSolver.h"
|
|
|
|
#include "Histogram.h"
|
|
|
|
#include "Util.h"
|
|
|
|
|
|
|
|
|
|
|
|
vector<LiftedOperator*>
|
2012-03-31 23:27:37 +01:00
|
|
|
LiftedOperator::getValidOps (
|
|
|
|
ParfactorList& pfList,
|
|
|
|
const Grounds& query)
|
2012-03-22 11:33:24 +00:00
|
|
|
{
|
|
|
|
vector<LiftedOperator*> validOps;
|
|
|
|
vector<SumOutOperator*> sumOutOps;
|
|
|
|
vector<CountingOperator*> countOps;
|
|
|
|
vector<GroundOperator*> groundOps;
|
|
|
|
|
|
|
|
sumOutOps = SumOutOperator::getValidOps (pfList, query);
|
|
|
|
countOps = CountingOperator::getValidOps (pfList);
|
|
|
|
groundOps = GroundOperator::getValidOps (pfList);
|
|
|
|
|
|
|
|
validOps.insert (validOps.end(), sumOutOps.begin(), sumOutOps.end());
|
|
|
|
validOps.insert (validOps.end(), countOps.begin(), countOps.end());
|
|
|
|
validOps.insert (validOps.end(), groundOps.begin(), groundOps.end());
|
|
|
|
return validOps;
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void
|
2012-03-31 23:27:37 +01:00
|
|
|
LiftedOperator::printValidOps (
|
|
|
|
ParfactorList& pfList,
|
|
|
|
const Grounds& query)
|
2012-03-22 11:33:24 +00:00
|
|
|
{
|
|
|
|
vector<LiftedOperator*> validOps;
|
|
|
|
validOps = LiftedOperator::getValidOps (pfList, query);
|
|
|
|
for (unsigned i = 0; i < validOps.size(); i++) {
|
2012-04-19 17:59:45 +01:00
|
|
|
cout << "-> " << validOps[i]->toString();
|
2012-03-31 23:27:37 +01:00
|
|
|
delete validOps[i];
|
2012-03-22 11:33:24 +00:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
2012-04-19 17:59:45 +01:00
|
|
|
vector<unsigned>
|
|
|
|
LiftedOperator::getAllGroupss (ParfactorList& )
|
|
|
|
{
|
|
|
|
return { };
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
vector<ParfactorList::iterator>
|
|
|
|
LiftedOperator::getParfactorsWithGroup (
|
|
|
|
ParfactorList& pfList, unsigned group)
|
|
|
|
{
|
|
|
|
vector<ParfactorList::iterator> iters;
|
|
|
|
ParfactorList::iterator pflIt = pfList.begin();
|
|
|
|
while (pflIt != pfList.end()) {
|
|
|
|
if ((*pflIt)->containsGroup (group)) {
|
|
|
|
iters.push_back (pflIt);
|
|
|
|
}
|
|
|
|
++ pflIt;
|
|
|
|
}
|
|
|
|
return iters;
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
2012-04-18 16:40:12 +01:00
|
|
|
double
|
|
|
|
SumOutOperator::getLogCost (void)
|
2012-03-22 11:33:24 +00:00
|
|
|
{
|
|
|
|
TinySet<unsigned> groupSet;
|
|
|
|
ParfactorList::const_iterator pfIter = pfList_.begin();
|
2012-04-19 17:59:45 +01:00
|
|
|
unsigned nrProdFactors = 0;
|
2012-03-22 11:33:24 +00:00
|
|
|
while (pfIter != pfList_.end()) {
|
|
|
|
if ((*pfIter)->containsGroup (group_)) {
|
|
|
|
vector<unsigned> groups = (*pfIter)->getAllGroups();
|
|
|
|
groupSet |= TinySet<unsigned> (groups);
|
2012-04-19 17:59:45 +01:00
|
|
|
++ nrProdFactors;
|
2012-03-22 11:33:24 +00:00
|
|
|
}
|
|
|
|
++ pfIter;
|
|
|
|
}
|
2012-04-19 17:59:45 +01:00
|
|
|
if (nrProdFactors == 1) {
|
2012-05-03 00:56:19 +01:00
|
|
|
return std::log (0.0); // best possible case
|
2012-04-19 17:59:45 +01:00
|
|
|
}
|
2012-04-18 16:40:12 +01:00
|
|
|
double cost = 1.0;
|
2012-03-22 11:33:24 +00:00
|
|
|
for (unsigned i = 0; i < groupSet.size(); i++) {
|
|
|
|
pfIter = pfList_.begin();
|
|
|
|
while (pfIter != pfList_.end()) {
|
|
|
|
if ((*pfIter)->containsGroup (groupSet[i])) {
|
2012-03-31 23:27:37 +01:00
|
|
|
int idx = (*pfIter)->indexOfGroup (groupSet[i]);
|
2012-03-22 11:33:24 +00:00
|
|
|
cost *= (*pfIter)->range (idx);
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
++ pfIter;
|
|
|
|
}
|
|
|
|
}
|
2012-04-18 16:40:12 +01:00
|
|
|
return std::log (cost);
|
2012-03-22 11:33:24 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void
|
|
|
|
SumOutOperator::apply (void)
|
|
|
|
{
|
2012-04-19 17:59:45 +01:00
|
|
|
vector<ParfactorList::iterator> iters;
|
|
|
|
iters = getParfactorsWithGroup (pfList_, group_);
|
2012-03-22 11:33:24 +00:00
|
|
|
Parfactor* product = *(iters[0]);
|
|
|
|
pfList_.remove (iters[0]);
|
|
|
|
for (unsigned i = 1; i < iters.size(); i++) {
|
|
|
|
product->multiply (**(iters[i]));
|
2012-03-31 23:27:37 +01:00
|
|
|
pfList_.removeAndDelete (iters[i]);
|
2012-03-22 11:33:24 +00:00
|
|
|
}
|
2012-03-31 23:27:37 +01:00
|
|
|
if (product->nrArguments() == 1) {
|
2012-03-22 11:33:24 +00:00
|
|
|
delete product;
|
|
|
|
return;
|
|
|
|
}
|
2012-03-31 23:27:37 +01:00
|
|
|
int fIdx = product->indexOfGroup (group_);
|
2012-03-22 11:33:24 +00:00
|
|
|
LogVarSet excl = product->exclusiveLogVars (fIdx);
|
|
|
|
if (product->constr()->isCountNormalized (excl)) {
|
|
|
|
product->sumOut (fIdx);
|
|
|
|
pfList_.addShattered (product);
|
|
|
|
} else {
|
|
|
|
Parfactors pfs = FoveSolver::countNormalize (product, excl);
|
|
|
|
for (unsigned i = 0; i < pfs.size(); i++) {
|
|
|
|
pfs[i]->sumOut (fIdx);
|
|
|
|
pfList_.add (pfs[i]);
|
|
|
|
}
|
|
|
|
delete product;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
vector<SumOutOperator*>
|
2012-03-31 23:27:37 +01:00
|
|
|
SumOutOperator::getValidOps (
|
|
|
|
ParfactorList& pfList,
|
|
|
|
const Grounds& query)
|
2012-03-22 11:33:24 +00:00
|
|
|
{
|
|
|
|
vector<SumOutOperator*> validOps;
|
|
|
|
set<unsigned> allGroups;
|
|
|
|
ParfactorList::const_iterator it = pfList.begin();
|
|
|
|
while (it != pfList.end()) {
|
2012-03-31 23:27:37 +01:00
|
|
|
const ProbFormulas& formulas = (*it)->arguments();
|
2012-03-22 11:33:24 +00:00
|
|
|
for (unsigned i = 0; i < formulas.size(); i++) {
|
|
|
|
allGroups.insert (formulas[i].group());
|
|
|
|
}
|
|
|
|
++ it;
|
|
|
|
}
|
|
|
|
set<unsigned>::const_iterator groupIt = allGroups.begin();
|
|
|
|
while (groupIt != allGroups.end()) {
|
|
|
|
if (validOp (*groupIt, pfList, query)) {
|
|
|
|
validOps.push_back (new SumOutOperator (*groupIt, pfList));
|
|
|
|
}
|
|
|
|
++ groupIt;
|
|
|
|
}
|
|
|
|
return validOps;
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
string
|
|
|
|
SumOutOperator::toString (void)
|
|
|
|
{
|
|
|
|
stringstream ss;
|
|
|
|
vector<ParfactorList::iterator> pfIters;
|
2012-04-19 17:59:45 +01:00
|
|
|
pfIters = getParfactorsWithGroup (pfList_, group_);
|
2012-03-31 23:27:37 +01:00
|
|
|
int idx = (*pfIters[0])->indexOfGroup (group_);
|
|
|
|
ProbFormula f = (*pfIters[0])->argument (idx);
|
2012-03-22 11:33:24 +00:00
|
|
|
TupleSet tupleSet = (*pfIters[0])->constr()->tupleSet (f.logVars());
|
|
|
|
ss << "sum out " << f.functor() << "/" << f.arity();
|
|
|
|
ss << "|" << tupleSet << " (group " << group_ << ")";
|
2012-04-19 17:59:45 +01:00
|
|
|
ss << " [cost=" << std::exp (getLogCost()) << "]" << endl;
|
2012-03-22 11:33:24 +00:00
|
|
|
return ss.str();
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
bool
|
|
|
|
SumOutOperator::validOp (
|
|
|
|
unsigned group,
|
|
|
|
ParfactorList& pfList,
|
|
|
|
const Grounds& query)
|
|
|
|
{
|
|
|
|
vector<ParfactorList::iterator> pfIters;
|
2012-04-19 17:59:45 +01:00
|
|
|
pfIters = getParfactorsWithGroup (pfList, group);
|
2012-03-22 11:33:24 +00:00
|
|
|
if (isToEliminate (*pfIters[0], group, query) == false) {
|
|
|
|
return false;
|
|
|
|
}
|
2012-05-05 23:11:32 +01:00
|
|
|
|
2012-03-22 11:33:24 +00:00
|
|
|
unordered_map<unsigned, unsigned> groupToRange;
|
|
|
|
for (unsigned i = 0; i < pfIters.size(); i++) {
|
2012-05-05 23:11:32 +01:00
|
|
|
const ProbFormulas& formulas = (*pfIters[i])->arguments();
|
|
|
|
int fIdx = -1;
|
|
|
|
for (unsigned j = 0; j < formulas.size(); j++) {
|
|
|
|
if (formulas[j].group() == group) {
|
|
|
|
if (fIdx != -1) {
|
|
|
|
// only summout a group of rand vars if they don't
|
|
|
|
// appear in another position on the factor
|
|
|
|
return false;
|
|
|
|
} else {
|
|
|
|
fIdx = j;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
2012-03-31 23:27:37 +01:00
|
|
|
if ((*pfIters[i])->argument (fIdx).contains (
|
|
|
|
(*pfIters[i])->elimLogVars()) == false) {
|
2012-03-22 11:33:24 +00:00
|
|
|
return false;
|
|
|
|
}
|
|
|
|
vector<unsigned> ranges = (*pfIters[i])->ranges();
|
|
|
|
vector<unsigned> groups = (*pfIters[i])->getAllGroups();
|
|
|
|
for (unsigned i = 0; i < groups.size(); i++) {
|
|
|
|
unordered_map<unsigned, unsigned>::iterator it;
|
|
|
|
it = groupToRange.find (groups[i]);
|
|
|
|
if (it == groupToRange.end()) {
|
|
|
|
groupToRange.insert (make_pair (groups[i], ranges[i]));
|
|
|
|
} else {
|
|
|
|
if (it->second != ranges[i]) {
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return true;
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
bool
|
|
|
|
SumOutOperator::isToEliminate (
|
|
|
|
Parfactor* g,
|
|
|
|
unsigned group,
|
|
|
|
const Grounds& query)
|
|
|
|
{
|
2012-03-31 23:27:37 +01:00
|
|
|
int fIdx = g->indexOfGroup (group);
|
|
|
|
const ProbFormula& formula = g->argument (fIdx);
|
2012-03-22 11:33:24 +00:00
|
|
|
bool toElim = true;
|
|
|
|
for (unsigned i = 0; i < query.size(); i++) {
|
|
|
|
if (formula.functor() == query[i].functor() &&
|
|
|
|
formula.arity() == query[i].arity()) {
|
|
|
|
g->constr()->moveToTop (formula.logVars());
|
|
|
|
if (g->constr()->containsTuple (query[i].args())) {
|
|
|
|
toElim = false;
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return toElim;
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
2012-04-18 16:40:12 +01:00
|
|
|
double
|
|
|
|
CountingOperator::getLogCost (void)
|
2012-03-22 11:33:24 +00:00
|
|
|
{
|
2012-04-18 16:40:12 +01:00
|
|
|
double cost = 0.0;
|
2012-03-31 23:27:37 +01:00
|
|
|
int fIdx = (*pfIter_)->indexOfLogVar (X_);
|
2012-03-22 11:33:24 +00:00
|
|
|
unsigned range = (*pfIter_)->range (fIdx);
|
|
|
|
unsigned size = (*pfIter_)->size() / range;
|
|
|
|
TinySet<unsigned> counts;
|
|
|
|
counts = (*pfIter_)->constr()->getConditionalCounts (X_);
|
|
|
|
for (unsigned i = 0; i < counts.size(); i++) {
|
|
|
|
cost += size * HistogramSet::nrHistograms (counts[i], range);
|
|
|
|
}
|
2012-04-19 18:37:15 +01:00
|
|
|
if ((*pfIter_)->nrArguments() == 1) {
|
|
|
|
cost *= 3; // avoid counting conversion in the beginning
|
|
|
|
}
|
2012-04-18 16:40:12 +01:00
|
|
|
return std::log (cost);
|
2012-03-22 11:33:24 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void
|
|
|
|
CountingOperator::apply (void)
|
|
|
|
{
|
|
|
|
if ((*pfIter_)->constr()->isCountNormalized (X_)) {
|
|
|
|
(*pfIter_)->countConvert (X_);
|
|
|
|
} else {
|
2012-03-31 23:27:37 +01:00
|
|
|
Parfactor* pf = *pfIter_;
|
|
|
|
pfList_.remove (pfIter_);
|
|
|
|
Parfactors pfs = FoveSolver::countNormalize (pf, X_);
|
2012-03-22 11:33:24 +00:00
|
|
|
for (unsigned i = 0; i < pfs.size(); i++) {
|
|
|
|
unsigned condCount = pfs[i]->constr()->getConditionalCount (X_);
|
|
|
|
bool cartProduct = pfs[i]->constr()->isCarteesianProduct (
|
2012-03-31 23:27:37 +01:00
|
|
|
pfs[i]->countedLogVars() | X_);
|
2012-03-22 11:33:24 +00:00
|
|
|
if (condCount > 1 && cartProduct) {
|
|
|
|
pfs[i]->countConvert (X_);
|
|
|
|
}
|
|
|
|
pfList_.add (pfs[i]);
|
|
|
|
}
|
2012-03-31 23:27:37 +01:00
|
|
|
delete pf;
|
2012-03-22 11:33:24 +00:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
vector<CountingOperator*>
|
|
|
|
CountingOperator::getValidOps (ParfactorList& pfList)
|
|
|
|
{
|
|
|
|
vector<CountingOperator*> validOps;
|
|
|
|
ParfactorList::iterator it = pfList.begin();
|
|
|
|
while (it != pfList.end()) {
|
|
|
|
LogVarSet candidates = (*it)->uncountedLogVars();
|
|
|
|
for (unsigned i = 0; i < candidates.size(); i++) {
|
|
|
|
if (validOp (*it, candidates[i])) {
|
|
|
|
validOps.push_back (new CountingOperator (
|
|
|
|
it, candidates[i], pfList));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
++ it;
|
|
|
|
}
|
|
|
|
return validOps;
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
string
|
|
|
|
CountingOperator::toString (void)
|
|
|
|
{
|
|
|
|
stringstream ss;
|
|
|
|
ss << "count convert " << X_ << " in " ;
|
2012-03-31 23:27:37 +01:00
|
|
|
ss << (*pfIter_)->getLabel();
|
2012-04-19 17:59:45 +01:00
|
|
|
ss << " [cost=" << std::exp (getLogCost()) << "]" << endl;
|
2012-03-22 11:33:24 +00:00
|
|
|
Parfactors pfs = FoveSolver::countNormalize (*pfIter_, X_);
|
|
|
|
if ((*pfIter_)->constr()->isCountNormalized (X_) == false) {
|
|
|
|
for (unsigned i = 0; i < pfs.size(); i++) {
|
2012-03-31 23:27:37 +01:00
|
|
|
ss << " º " << pfs[i]->getLabel() << endl;
|
2012-03-22 11:33:24 +00:00
|
|
|
}
|
|
|
|
}
|
2012-03-31 23:27:37 +01:00
|
|
|
for (unsigned i = 0; i < pfs.size(); i++) {
|
|
|
|
delete pfs[i];
|
|
|
|
}
|
2012-03-22 11:33:24 +00:00
|
|
|
return ss.str();
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
bool
|
|
|
|
CountingOperator::validOp (Parfactor* g, LogVar X)
|
|
|
|
{
|
|
|
|
if (g->nrFormulas (X) != 1) {
|
|
|
|
return false;
|
|
|
|
}
|
2012-03-31 23:27:37 +01:00
|
|
|
int fIdx = g->indexOfLogVar (X);
|
|
|
|
if (g->argument (fIdx).isCounting()) {
|
2012-03-22 11:33:24 +00:00
|
|
|
return false;
|
|
|
|
}
|
|
|
|
bool countNormalized = g->constr()->isCountNormalized (X);
|
|
|
|
if (countNormalized) {
|
|
|
|
unsigned condCount = g->constr()->getConditionalCount (X);
|
|
|
|
bool cartProduct = g->constr()->isCarteesianProduct (
|
|
|
|
g->countedLogVars() | X);
|
|
|
|
if (condCount == 1 || cartProduct == false) {
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return true;
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
2012-04-18 16:40:12 +01:00
|
|
|
double
|
|
|
|
GroundOperator::getLogCost (void)
|
2012-03-22 11:33:24 +00:00
|
|
|
{
|
2012-04-27 01:18:54 +01:00
|
|
|
vector<pair<unsigned, unsigned>> affectedFormulas;
|
|
|
|
affectedFormulas = getAffectedFormulas();
|
|
|
|
// cout << "affected formulas: " ;
|
|
|
|
// for (unsigned i = 0; i < affectedFormulas.size(); i++) {
|
|
|
|
// cout << affectedFormulas[i].first << ":" ;
|
|
|
|
// cout << affectedFormulas[i].second << " " ;
|
|
|
|
// }
|
|
|
|
// cout << "cost =" ;
|
|
|
|
double totalCost = std::log (0.0);
|
|
|
|
ParfactorList::iterator pflIt = pfList_.begin();
|
|
|
|
while (pflIt != pfList_.end()) {
|
|
|
|
Parfactor* pf = *pflIt;
|
|
|
|
double reps = 0.0;
|
|
|
|
double pfSize = std::log (pf->size());
|
|
|
|
bool willBeAffected = false;
|
|
|
|
LogVarSet lvsToGround;
|
|
|
|
for (unsigned i = 0; i < affectedFormulas.size(); i++) {
|
|
|
|
int fIdx = pf->indexOfGroup (affectedFormulas[i].first);
|
|
|
|
if (fIdx != -1) {
|
|
|
|
ProbFormula f = pf->argument (fIdx);
|
|
|
|
LogVar X = f.logVars()[affectedFormulas[i].second];
|
|
|
|
bool isCountingLv = pf->countedLogVars().contains (X);
|
|
|
|
if (isCountingLv) {
|
|
|
|
unsigned nrHists = pf->range (fIdx);
|
|
|
|
unsigned nrSymbols = pf->constr()->getConditionalCount (X);
|
|
|
|
unsigned range = pf->argument (fIdx).range();
|
|
|
|
double power = std::log (range) * nrSymbols;
|
|
|
|
pfSize = (pfSize - std::log (nrHists)) + power;
|
|
|
|
} else {
|
|
|
|
if (lvsToGround.contains (X) == false) {
|
|
|
|
reps += std::log (pf->constr()->nrSymbols (X));
|
|
|
|
lvsToGround.insert (X);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
willBeAffected = true;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if (willBeAffected) {
|
|
|
|
// cout << " + " << std::exp (reps) << "x" << std::exp (pfSize);
|
|
|
|
double pfCost = reps + pfSize;
|
|
|
|
totalCost = Util::logSum (totalCost, pfCost);
|
2012-04-19 17:59:45 +01:00
|
|
|
}
|
2012-04-27 01:18:54 +01:00
|
|
|
++ pflIt;
|
2012-03-22 11:33:24 +00:00
|
|
|
}
|
2012-04-27 01:18:54 +01:00
|
|
|
// cout << endl;
|
|
|
|
return totalCost;
|
2012-03-22 11:33:24 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void
|
|
|
|
GroundOperator::apply (void)
|
|
|
|
{
|
2012-04-19 17:59:45 +01:00
|
|
|
// TODO if we update the correct groups
|
|
|
|
// we can skip shattering
|
|
|
|
ParfactorList::iterator pfIter;
|
|
|
|
pfIter = getParfactorsWithGroup (pfList_, group_).front();
|
|
|
|
Parfactor* pf = *pfIter;
|
|
|
|
int idx = pf->indexOfGroup (group_);
|
|
|
|
ProbFormula f = pf->argument (idx);
|
|
|
|
LogVar X = f.logVars()[lvIndex_];
|
|
|
|
bool countedLv = pf->countedLogVars().contains (X);
|
|
|
|
pfList_.remove (pfIter);
|
2012-03-22 11:33:24 +00:00
|
|
|
if (countedLv) {
|
2012-04-19 17:59:45 +01:00
|
|
|
pf->fullExpand (X);
|
2012-03-31 23:27:37 +01:00
|
|
|
pfList_.add (pf);
|
2012-03-22 11:33:24 +00:00
|
|
|
} else {
|
2012-04-19 17:59:45 +01:00
|
|
|
ConstraintTrees cts = pf->constr()->ground (X);
|
2012-03-22 11:33:24 +00:00
|
|
|
for (unsigned i = 0; i < cts.size(); i++) {
|
2012-03-31 23:27:37 +01:00
|
|
|
pfList_.add (new Parfactor (pf, cts[i]));
|
2012-03-22 11:33:24 +00:00
|
|
|
}
|
2012-03-31 23:27:37 +01:00
|
|
|
delete pf;
|
2012-03-22 11:33:24 +00:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
vector<GroundOperator*>
|
|
|
|
GroundOperator::getValidOps (ParfactorList& pfList)
|
|
|
|
{
|
|
|
|
vector<GroundOperator*> validOps;
|
2012-04-19 17:59:45 +01:00
|
|
|
set<unsigned> allGroups;
|
|
|
|
ParfactorList::const_iterator it = pfList.begin();
|
|
|
|
while (it != pfList.end()) {
|
|
|
|
const ProbFormulas& formulas = (*it)->arguments();
|
|
|
|
for (unsigned i = 0; i < formulas.size(); i++) {
|
|
|
|
if (Util::contains (allGroups, formulas[i].group()) == false) {
|
|
|
|
const LogVars& lvs = formulas[i].logVars();
|
|
|
|
for (unsigned j = 0; j < lvs.size(); j++) {
|
|
|
|
if ((*it)->constr()->isSingleton (lvs[j]) == false) {
|
|
|
|
validOps.push_back (new GroundOperator (
|
|
|
|
formulas[i].group(), j, pfList));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
allGroups.insert (formulas[i].group());
|
2012-03-22 11:33:24 +00:00
|
|
|
}
|
|
|
|
}
|
2012-04-19 17:59:45 +01:00
|
|
|
++ it;
|
2012-03-22 11:33:24 +00:00
|
|
|
}
|
|
|
|
return validOps;
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
string
|
|
|
|
GroundOperator::toString (void)
|
|
|
|
{
|
|
|
|
stringstream ss;
|
2012-04-19 17:59:45 +01:00
|
|
|
vector<ParfactorList::iterator> pfIters;
|
|
|
|
pfIters = getParfactorsWithGroup (pfList_, group_);
|
|
|
|
Parfactor* pf = *(getParfactorsWithGroup (pfList_, group_).front());
|
|
|
|
int idx = pf->indexOfGroup (group_);
|
|
|
|
ProbFormula f = pf->argument (idx);
|
|
|
|
LogVar lv = f.logVars()[lvIndex_];
|
|
|
|
TupleSet tupleSet = pf->constr()->tupleSet ({lv});
|
|
|
|
string pos = "th";
|
|
|
|
if (lvIndex_ == 0) {
|
|
|
|
pos = "st" ;
|
|
|
|
} else if (lvIndex_ == 1) {
|
|
|
|
pos = "nd" ;
|
|
|
|
} else if (lvIndex_ == 2) {
|
|
|
|
pos = "rd" ;
|
|
|
|
}
|
|
|
|
ss << "grounding " << lvIndex_ + 1 << pos << " log var in " ;
|
|
|
|
ss << f.functor() << "/" << f.arity();
|
|
|
|
ss << "|" << tupleSet << " (group " << group_ << ")";
|
|
|
|
ss << " [cost=" << std::exp (getLogCost()) << "]" << endl;
|
2012-03-22 11:33:24 +00:00
|
|
|
return ss.str();
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
2012-04-27 01:18:54 +01:00
|
|
|
vector<pair<unsigned, unsigned>>
|
|
|
|
GroundOperator::getAffectedFormulas (void)
|
|
|
|
{
|
|
|
|
vector<pair<unsigned, unsigned>> affectedFormulas;
|
|
|
|
affectedFormulas.push_back (make_pair (group_, lvIndex_));
|
|
|
|
queue<pair<unsigned, unsigned>> q;
|
|
|
|
q.push (make_pair (group_, lvIndex_));
|
|
|
|
while (q.empty() == false) {
|
|
|
|
pair<unsigned, unsigned> front = q.front();
|
|
|
|
ParfactorList::iterator pflIt = pfList_.begin();
|
|
|
|
while (pflIt != pfList_.end()) {
|
|
|
|
int idx = (*pflIt)->indexOfGroup (front.first);
|
|
|
|
if (idx != -1) {
|
|
|
|
ProbFormula f = (*pflIt)->argument (idx);
|
|
|
|
LogVar X = f.logVars()[front.second];
|
|
|
|
const ProbFormulas& fs = (*pflIt)->arguments();
|
|
|
|
for (unsigned i = 0; i < fs.size(); i++) {
|
|
|
|
if ((int)i != idx && fs[i].contains (X)) {
|
|
|
|
pair<unsigned, unsigned> pair = make_pair (
|
|
|
|
fs[i].group(), fs[i].indexOf (X));
|
|
|
|
if (Util::contains (affectedFormulas, pair) == false) {
|
|
|
|
q.push (pair);
|
|
|
|
affectedFormulas.push_back (pair);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
++ pflIt;
|
|
|
|
}
|
|
|
|
q.pop();
|
|
|
|
}
|
|
|
|
return affectedFormulas;
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
2012-03-22 11:33:24 +00:00
|
|
|
// Distribution of a single ground: a joint-distribution query with just
// that ground.
Params
FoveSolver::getPosterioriOf (const Ground& query)
{
  Grounds single;
  single.push_back (query);
  return getJointDistributionOf (single);
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Runs the solver for `query` and returns the normalized parameters of the
// resulting parfactor (the first one in pfList_). When Globals::logDomain is
// set, the parameters are converted out of log space first.
Params
FoveSolver::getJointDistributionOf (const Grounds& query)
{
  runSolver (query);
  Parfactor* result = *pfList_.begin();
  result->normalize();
  Params params = result->params();
  if (Globals::logDomain) {
    Util::fromLog (params);
  }
  return params;
}
|
|
|
|
|
|
|
|
|
|
|
|
|
2012-04-29 20:07:09 +01:00
|
|
|
void
|
|
|
|
FoveSolver::printSolverFlags (void) const
|
|
|
|
{
|
|
|
|
stringstream ss;
|
|
|
|
ss << "fove [" ;
|
|
|
|
ss << "log_domain=" << Util::toString (Globals::logDomain);
|
|
|
|
ss << "]" ;
|
|
|
|
cout << ss.str() << endl;
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
2012-03-22 11:33:24 +00:00
|
|
|
void
|
|
|
|
FoveSolver::absorveEvidence (
|
|
|
|
ParfactorList& pfList,
|
2012-03-31 23:27:37 +01:00
|
|
|
ObservedFormulas& obsFormulas)
|
2012-03-22 11:33:24 +00:00
|
|
|
{
|
2012-03-31 23:27:37 +01:00
|
|
|
for (unsigned i = 0; i < obsFormulas.size(); i++) {
|
|
|
|
Parfactors newPfs;
|
|
|
|
ParfactorList::iterator it = pfList.begin();
|
|
|
|
while (it != pfList.end()) {
|
|
|
|
Parfactor* pf = *it;
|
|
|
|
it = pfList.remove (it);
|
|
|
|
Parfactors absorvedPfs = absorve (obsFormulas[i], pf);
|
|
|
|
if (absorvedPfs.empty() == false) {
|
|
|
|
if (absorvedPfs.size() == 1 && absorvedPfs[0] == 0) {
|
|
|
|
// just remove pf;
|
|
|
|
} else {
|
|
|
|
Util::addToVector (newPfs, absorvedPfs);
|
|
|
|
}
|
|
|
|
delete pf;
|
|
|
|
} else {
|
|
|
|
it = pfList.insertShattered (it, pf);
|
|
|
|
++ it;
|
|
|
|
}
|
2012-03-22 11:33:24 +00:00
|
|
|
}
|
2012-03-31 23:27:37 +01:00
|
|
|
pfList.add (newPfs);
|
2012-03-22 11:33:24 +00:00
|
|
|
}
|
2012-04-29 20:07:09 +01:00
|
|
|
if (Globals::verbosity > 2 && obsFormulas.empty() == false) {
|
2012-03-31 23:27:37 +01:00
|
|
|
Util::printAsteriskLine();
|
2012-03-22 11:33:24 +00:00
|
|
|
cout << "AFTER EVIDENCE ABSORVED" << endl;
|
|
|
|
for (unsigned i = 0; i < obsFormulas.size(); i++) {
|
2012-03-31 23:27:37 +01:00
|
|
|
cout << " -> " << obsFormulas[i] << endl;
|
2012-03-22 11:33:24 +00:00
|
|
|
}
|
2012-03-31 23:27:37 +01:00
|
|
|
Util::printAsteriskLine();
|
|
|
|
pfList.print();
|
2012-03-22 11:33:24 +00:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Splits g into parfactors whose constraints are count-normalized with
// respect to `set`. An empty set yields a single plain copy of g. The
// returned parfactors are newly allocated and owned by the caller; g is not
// deleted here.
Parfactors
FoveSolver::countNormalize (
    Parfactor* g,
    const LogVarSet& set)
{
  Parfactors normPfs;
  if (set.empty()) {
    normPfs.push_back (new Parfactor (*g));
    return normPfs;
  }
  ConstraintTrees pieces = g->constr()->countNormalize (set);
  for (unsigned k = 0; k < pieces.size(); k++) {
    normPfs.push_back (new Parfactor (g, pieces[k]));
  }
  return normPfs;
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void
|
|
|
|
FoveSolver::runSolver (const Grounds& query)
|
|
|
|
{
|
2012-04-29 20:07:09 +01:00
|
|
|
largestCost_ = std::log (0);
|
2012-03-31 23:27:37 +01:00
|
|
|
shatterAgainstQuery (query);
|
|
|
|
runWeakBayesBall (query);
|
2012-03-22 11:33:24 +00:00
|
|
|
while (true) {
|
2012-04-29 20:07:09 +01:00
|
|
|
if (Globals::verbosity > 2) {
|
2012-03-31 23:27:37 +01:00
|
|
|
Util::printDashedLine();
|
|
|
|
pfList_.print();
|
2012-04-29 20:07:09 +01:00
|
|
|
if (Globals::verbosity > 3) {
|
|
|
|
LiftedOperator::printValidOps (pfList_, query);
|
|
|
|
}
|
2012-03-31 23:27:37 +01:00
|
|
|
}
|
2012-03-22 11:33:24 +00:00
|
|
|
LiftedOperator* op = getBestOperation (query);
|
|
|
|
if (op == 0) {
|
|
|
|
break;
|
|
|
|
}
|
2012-04-29 20:07:09 +01:00
|
|
|
if (Globals::verbosity > 1) {
|
|
|
|
cout << "best operation: " << op->toString();
|
|
|
|
if (Globals::verbosity > 2) {
|
|
|
|
cout << endl;
|
|
|
|
}
|
2012-03-31 23:27:37 +01:00
|
|
|
}
|
2012-03-22 11:33:24 +00:00
|
|
|
op->apply();
|
2012-03-31 23:27:37 +01:00
|
|
|
delete op;
|
2012-03-22 11:33:24 +00:00
|
|
|
}
|
2012-03-31 23:27:37 +01:00
|
|
|
assert (pfList_.size() > 0);
|
2012-03-22 11:33:24 +00:00
|
|
|
if (pfList_.size() > 1) {
|
|
|
|
ParfactorList::iterator pfIter = pfList_.begin();
|
|
|
|
pfIter ++;
|
|
|
|
while (pfIter != pfList_.end()) {
|
|
|
|
(*pfList_.begin())->multiply (**pfIter);
|
|
|
|
++ pfIter;
|
|
|
|
}
|
|
|
|
}
|
2012-04-29 20:07:09 +01:00
|
|
|
if (Globals::verbosity > 0) {
|
|
|
|
cout << "largest cost = " << std::exp (largestCost_) << endl;
|
|
|
|
cout << endl;
|
|
|
|
}
|
2012-03-22 11:33:24 +00:00
|
|
|
(*pfList_.begin())->reorderAccordingGrounds (query);
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Picks the valid lifted operator with the lowest estimated log cost (the
// earliest wins ties), deletes all the others and returns it. Returns 0 when
// no operator applies. The caller owns the returned operator.
// largestCost_ tracks the highest cost among the operators actually chosen.
LiftedOperator*
FoveSolver::getBestOperation (const Grounds& query)
{
  double bestCost = 0.0;
  LiftedOperator* bestOp = 0;
  vector<LiftedOperator*> validOps;
  validOps = LiftedOperator::getValidOps (pfList_, query);
  for (unsigned i = 0; i < validOps.size(); i++) {
    double cost = validOps[i]->getLogCost();
    if ((bestOp == 0) || (cost < bestCost)) {
      bestOp = validOps[i];
      bestCost = cost;
    }
  }
  // Only record the cost of an operator that will really be applied. With no
  // valid operator, bestCost is just its 0.0 placeholder (cost 1 in log
  // space), which would otherwise inflate largestCost_.
  if (bestOp != 0 && bestCost > largestCost_) {
    largestCost_ = bestCost;
  }
  for (unsigned i = 0; i < validOps.size(); i++) {
    if (validOps[i] != bestOp) {
      delete validOps[i];
    }
  }
  return bestOp;
}
|
|
|
|
|
|
|
|
|
|
|
|
|
2012-03-31 23:27:37 +01:00
|
|
|
void
|
|
|
|
FoveSolver::runWeakBayesBall (const Grounds& query)
|
|
|
|
{
|
|
|
|
queue<unsigned> todo; // groups to process
|
|
|
|
set<unsigned> done; // processed or in queue
|
|
|
|
for (unsigned i = 0; i < query.size(); i++) {
|
|
|
|
ParfactorList::iterator it = pfList_.begin();
|
|
|
|
while (it != pfList_.end()) {
|
2012-04-03 11:58:21 +01:00
|
|
|
int group = (*it)->findGroup (query[i]);
|
2012-03-31 23:27:37 +01:00
|
|
|
if (group != -1) {
|
|
|
|
todo.push (group);
|
|
|
|
done.insert (group);
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
++ it;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
set<Parfactor*> requiredPfs;
|
|
|
|
while (todo.empty() == false) {
|
|
|
|
unsigned group = todo.front();
|
|
|
|
ParfactorList::iterator it = pfList_.begin();
|
|
|
|
while (it != pfList_.end()) {
|
|
|
|
if (Util::contains (requiredPfs, *it) == false &&
|
|
|
|
(*it)->containsGroup (group)) {
|
|
|
|
vector<unsigned> groups = (*it)->getAllGroups();
|
|
|
|
for (unsigned i = 0; i < groups.size(); i++) {
|
|
|
|
if (Util::contains (done, groups[i]) == false) {
|
|
|
|
todo.push (groups[i]);
|
|
|
|
done.insert (groups[i]);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
requiredPfs.insert (*it);
|
|
|
|
}
|
|
|
|
++ it;
|
|
|
|
}
|
|
|
|
todo.pop();
|
|
|
|
}
|
|
|
|
|
|
|
|
ParfactorList::iterator it = pfList_.begin();
|
2012-04-29 20:07:09 +01:00
|
|
|
bool foundNotRequired = false;
|
2012-03-31 23:27:37 +01:00
|
|
|
while (it != pfList_.end()) {
|
|
|
|
if (Util::contains (requiredPfs, *it) == false) {
|
2012-04-29 20:07:09 +01:00
|
|
|
if (Globals::verbosity > 2) {
|
|
|
|
if (foundNotRequired == false) {
|
|
|
|
Util::printHeader ("PARFACTORS TO DISCARD");
|
|
|
|
foundNotRequired = true;
|
|
|
|
}
|
|
|
|
(*it)->print();
|
|
|
|
}
|
2012-03-31 23:27:37 +01:00
|
|
|
it = pfList_.removeAndDelete (it);
|
|
|
|
} else {
|
|
|
|
++ it;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
2012-03-22 11:33:24 +00:00
|
|
|
void
|
|
|
|
FoveSolver::shatterAgainstQuery (const Grounds& query)
|
|
|
|
{
|
|
|
|
for (unsigned i = 0; i < query.size(); i++) {
|
|
|
|
if (query[i].isAtom()) {
|
|
|
|
continue;
|
|
|
|
}
|
2012-04-11 15:36:50 +01:00
|
|
|
bool found = false;
|
2012-03-31 23:27:37 +01:00
|
|
|
Parfactors newPfs;
|
|
|
|
ParfactorList::iterator it = pfList_.begin();
|
|
|
|
while (it != pfList_.end()) {
|
|
|
|
if ((*it)->containsGround (query[i])) {
|
2012-04-11 15:36:50 +01:00
|
|
|
found = true;
|
2012-03-22 11:33:24 +00:00
|
|
|
std::pair<ConstraintTree*, ConstraintTree*> split =
|
2012-03-31 23:27:37 +01:00
|
|
|
(*it)->constr()->split (query[i].args(), query[i].arity());
|
2012-03-22 11:33:24 +00:00
|
|
|
ConstraintTree* commCt = split.first;
|
|
|
|
ConstraintTree* exclCt = split.second;
|
2012-03-31 23:27:37 +01:00
|
|
|
newPfs.push_back (new Parfactor (*it, commCt));
|
2012-03-22 11:33:24 +00:00
|
|
|
if (exclCt->empty() == false) {
|
2012-03-31 23:27:37 +01:00
|
|
|
newPfs.push_back (new Parfactor (*it, exclCt));
|
2012-03-22 11:33:24 +00:00
|
|
|
} else {
|
|
|
|
delete exclCt;
|
|
|
|
}
|
2012-03-31 23:27:37 +01:00
|
|
|
it = pfList_.removeAndDelete (it);
|
2012-03-22 11:33:24 +00:00
|
|
|
} else {
|
2012-03-31 23:27:37 +01:00
|
|
|
++ it;
|
2012-03-22 11:33:24 +00:00
|
|
|
}
|
|
|
|
}
|
2012-04-11 15:36:50 +01:00
|
|
|
if (found == false) {
|
|
|
|
cerr << "error: could not find a parfactor with ground " ;
|
|
|
|
cerr << "`" << query[i] << "'" << endl;
|
|
|
|
exit (0);
|
|
|
|
}
|
2012-03-31 23:27:37 +01:00
|
|
|
pfList_.add (newPfs);
|
|
|
|
}
|
2012-04-29 20:07:09 +01:00
|
|
|
if (Globals::verbosity > 2) {
|
2012-03-31 23:27:37 +01:00
|
|
|
Util::printAsteriskLine();
|
|
|
|
cout << "SHATTERED AGAINST THE QUERY" << endl;
|
|
|
|
for (unsigned i = 0; i < query.size(); i++) {
|
|
|
|
cout << " -> " << query[i] << endl;
|
|
|
|
}
|
|
|
|
Util::printAsteriskLine();
|
|
|
|
pfList_.print();
|
2012-03-22 11:33:24 +00:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
2012-03-31 23:27:37 +01:00
|
|
|
// Absorbs the evidence of obsFormula into parfactor g, using the first
// argument of g with the same functor and arity. Result:
//  - empty: g was left untouched, or (atomic evidence, several arguments)
//    the evidence was absorbed into g in place;
//  - a single null entry: g must simply be erased (a deliberate hack);
//  - otherwise: the parfactors that replace g. These are count-normalized
//    pieces of the evidence-matching part with the evidence absorbed, plus
//    a parfactor holding the non-matching groundings.
// g itself is never deleted here; the caller decides what to do with it.
Parfactors
FoveSolver::absorve (
    ObservedFormula& obsFormula,
    Parfactor* g)
{
  Parfactors absorvedPfs;
  const ProbFormulas& formulas = g->arguments();
  for (unsigned i = 0; i < formulas.size(); i++) {
    if (obsFormula.functor() == formulas[i].functor() &&
        obsFormula.arity() == formulas[i].arity()) {

      if (obsFormula.isAtom()) {
        if (formulas.size() > 1) {
          g->absorveEvidence (formulas[i], obsFormula.evidence());
        } else {
          // hack to erase parfactor g
          absorvedPfs.push_back (0);
        }
        break;
      }

      g->constr()->moveToTop (formulas[i].logVars());
      std::pair<ConstraintTree*, ConstraintTree*> res
          = g->constr()->split (&(obsFormula.constr()), formulas[i].arity());
      ConstraintTree* commCt = res.first;
      ConstraintTree* exclCt = res.second;

      if (commCt->empty() == false) {
        if (formulas.size() > 1) {
          LogVarSet excl = g->exclusiveLogVars (i);
          // Normalize only the groundings matched by the evidence (commCt),
          // not all of g. The other groundings are kept separately in the
          // exclCt parfactor below, so using g here would absorb evidence
          // into them and count them twice. The temporary parfactor takes
          // ownership of commCt and frees it, which also stops a leak.
          Parfactor commPf (g, commCt);
          Parfactors countNormPfs = countNormalize (&commPf, excl);
          for (unsigned j = 0; j < countNormPfs.size(); j++) {
            countNormPfs[j]->absorveEvidence (
                formulas[i], obsFormula.evidence());
            absorvedPfs.push_back (countNormPfs[j]);
          }
        } else {
          delete commCt;
        }
        if (exclCt->empty() == false) {
          absorvedPfs.push_back (new Parfactor (g, exclCt));
        } else {
          delete exclCt;
        }
        if (absorvedPfs.empty()) {
          // hack to erase parfactor g
          absorvedPfs.push_back (0);
        }
        break;
      } else {
        delete commCt;
        delete exclCt;
      }
    }
  }
  return absorvedPfs;
}
|
|
|
|
|