// 276 lines, 7.4 KiB, C++
#include <vector>
|
|
#include <queue>
|
|
#include <iostream>
|
|
|
|
#include "LiftedOperations.h"
|
|
|
|
|
|
void
|
|
LiftedOperations::shatterAgainstQuery (
|
|
ParfactorList& pfList,
|
|
const Grounds& query)
|
|
{
|
|
for (size_t i = 0; i < query.size(); i++) {
|
|
if (query[i].isAtom()) {
|
|
continue;
|
|
}
|
|
bool found = false;
|
|
Parfactors newPfs;
|
|
ParfactorList::iterator it = pfList.begin();
|
|
while (it != pfList.end()) {
|
|
if ((*it)->containsGround (query[i])) {
|
|
found = true;
|
|
std::pair<ConstraintTree*, ConstraintTree*> split;
|
|
LogVars queryLvs (
|
|
(*it)->constr()->logVars().begin(),
|
|
(*it)->constr()->logVars().begin() + query[i].arity());
|
|
split = (*it)->constr()->split (query[i].args());
|
|
ConstraintTree* commCt = split.first;
|
|
ConstraintTree* exclCt = split.second;
|
|
newPfs.push_back (new Parfactor (*it, commCt));
|
|
if (exclCt->empty() == false) {
|
|
newPfs.push_back (new Parfactor (*it, exclCt));
|
|
} else {
|
|
delete exclCt;
|
|
}
|
|
it = pfList.removeAndDelete (it);
|
|
} else {
|
|
++ it;
|
|
}
|
|
}
|
|
if (found == false) {
|
|
std::cerr << "Error: could not find a parfactor with ground " ;
|
|
std::cerr << "`" << query[i] << "'." << std::endl;
|
|
exit (EXIT_FAILURE);
|
|
}
|
|
pfList.add (newPfs);
|
|
}
|
|
if (Globals::verbosity > 2) {
|
|
Util::printAsteriskLine();
|
|
std::cout << "SHATTERED AGAINST THE QUERY" << std::endl;
|
|
for (size_t i = 0; i < query.size(); i++) {
|
|
std::cout << " -> " << query[i] << std::endl;
|
|
}
|
|
Util::printAsteriskLine();
|
|
pfList.print();
|
|
}
|
|
}
|
|
|
|
|
|
|
|
void
|
|
LiftedOperations::runWeakBayesBall (
|
|
ParfactorList& pfList,
|
|
const Grounds& query)
|
|
{
|
|
std::queue<PrvGroup> todo; // groups to process
|
|
std::set<PrvGroup> done; // processed or in queue
|
|
for (size_t i = 0; i < query.size(); i++) {
|
|
ParfactorList::iterator it = pfList.begin();
|
|
while (it != pfList.end()) {
|
|
PrvGroup group = (*it)->findGroup (query[i]);
|
|
if (group != std::numeric_limits<PrvGroup>::max()) {
|
|
todo.push (group);
|
|
done.insert (group);
|
|
break;
|
|
}
|
|
++ it;
|
|
}
|
|
}
|
|
|
|
std::set<Parfactor*> requiredPfs;
|
|
while (todo.empty() == false) {
|
|
PrvGroup group = todo.front();
|
|
ParfactorList::iterator it = pfList.begin();
|
|
while (it != pfList.end()) {
|
|
if (Util::contains (requiredPfs, *it) == false &&
|
|
(*it)->containsGroup (group)) {
|
|
std::vector<PrvGroup> groups = (*it)->getAllGroups();
|
|
for (size_t i = 0; i < groups.size(); i++) {
|
|
if (Util::contains (done, groups[i]) == false) {
|
|
todo.push (groups[i]);
|
|
done.insert (groups[i]);
|
|
}
|
|
}
|
|
requiredPfs.insert (*it);
|
|
}
|
|
++ it;
|
|
}
|
|
todo.pop();
|
|
}
|
|
|
|
ParfactorList::iterator it = pfList.begin();
|
|
bool foundNotRequired = false;
|
|
while (it != pfList.end()) {
|
|
if (Util::contains (requiredPfs, *it) == false) {
|
|
if (Globals::verbosity > 2) {
|
|
if (foundNotRequired == false) {
|
|
Util::printHeader ("PARFACTORS TO DISCARD");
|
|
foundNotRequired = true;
|
|
}
|
|
(*it)->print();
|
|
}
|
|
it = pfList.removeAndDelete (it);
|
|
} else {
|
|
++ it;
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
|
|
void
|
|
LiftedOperations::absorveEvidence (
|
|
ParfactorList& pfList,
|
|
ObservedFormulas& obsFormulas)
|
|
{
|
|
for (size_t i = 0; i < obsFormulas.size(); i++) {
|
|
Parfactors newPfs;
|
|
ParfactorList::iterator it = pfList.begin();
|
|
while (it != pfList.end()) {
|
|
Parfactor* pf = *it;
|
|
it = pfList.remove (it);
|
|
Parfactors absorvedPfs = absorve (obsFormulas[i], pf);
|
|
if (absorvedPfs.empty() == false) {
|
|
if (absorvedPfs.size() == 1 && !absorvedPfs[0]) {
|
|
// just remove pf;
|
|
} else {
|
|
Util::addToVector (newPfs, absorvedPfs);
|
|
}
|
|
delete pf;
|
|
} else {
|
|
it = pfList.insertShattered (it, pf);
|
|
++ it;
|
|
}
|
|
}
|
|
pfList.add (newPfs);
|
|
}
|
|
if (Globals::verbosity > 2 && obsFormulas.empty() == false) {
|
|
Util::printAsteriskLine();
|
|
std::cout << "AFTER EVIDENCE ABSORVED" << std::endl;
|
|
for (size_t i = 0; i < obsFormulas.size(); i++) {
|
|
std::cout << " -> " << obsFormulas[i] << std::endl;
|
|
}
|
|
Util::printAsteriskLine();
|
|
pfList.print();
|
|
}
|
|
}
|
|
|
|
|
|
|
|
// Returns newly allocated parfactors derived from `g`, one per constraint
// tree produced by count-normalizing g's constraint with respect to `set`.
// With an empty `set` the result is a single copy of `g`.
// `g` itself is not modified or deleted here; the caller owns the returned
// pointers.
Parfactors
LiftedOperations::countNormalize (
    Parfactor* g,
    const LogVarSet& set)
{
  Parfactors result;
  if (set.empty()) {
    result.push_back (new Parfactor (*g));
    return result;
  }
  ConstraintTrees trees = g->constr()->countNormalize (set);
  for (size_t t = 0; t < trees.size(); t++) {
    result.push_back (new Parfactor (g, trees[t]));
  }
  return result;
}
|
|
|
|
|
|
|
|
// Grounds `pf` over every logical variable of its constraint that is not a
// singleton, then multiplies all resulting parfactors together and returns
// a copy of the product.
//
// For each such logical variable, a parfactor in which the variable is
// counted is expanded in place with fullExpand(); otherwise it is replaced
// by one new parfactor per constraint tree returned by ground().
Parfactor
LiftedOperations::calcGroundMultiplication (Parfactor pf)
{
  LogVarSet lvs = pf.constr()->logVarSet();
  lvs -= pf.constr()->singletons();

  Parfactors generation = {new Parfactor (pf)};
  for (size_t i = 0; i < lvs.size(); i++) {
    Parfactors next;
    for (size_t j = 0; j < generation.size(); j++) {
      Parfactor* item = generation[j];
      if (item->countedLogVars().contains (lvs[i])) {
        item->fullExpand (lvs[i]);
        next.push_back (item);  // same object carried to the next round
        continue;
      }
      ConstraintTrees cts = item->constr()->ground (lvs[i]);
      for (size_t k = 0; k < cts.size(); k++) {
        next.push_back (new Parfactor (item, cts[k]));
      }
      delete item;  // superseded by its grounded replacements
    }
    generation = next;
  }

  // NOTE(review): pfList presumably takes ownership of the parfactors and
  // frees them on destruction -- confirm against ParfactorList.
  ParfactorList pfList (generation);
  Parfactors groundShatteredPfs (pfList.begin(),pfList.end());
  Parfactor* product = groundShatteredPfs[0];
  for (size_t i = 1; i < groundShatteredPfs.size(); i++) {
    product->multiply (*groundShatteredPfs[i]);
  }
  return Parfactor (*product);
}
|
|
|
|
|
|
|
|
// Tries to absorb the evidence carried by `obsFormula` into parfactor `g`.
//
// Only the first argument of `g` that matches the observed formula's functor
// and arity and yields a non-empty common constraint (or any match, when the
// observed formula is an atom) is processed.
//
// Return value, as consumed by the caller:
//   - empty vector: `g` is to be kept (it may have been updated in place);
//   - vector holding a single null pointer: `g` is to be erased;
//   - otherwise: newly allocated parfactors that replace `g`.
Parfactors
LiftedOperations::absorve (
    ObservedFormula& obsFormula,
    Parfactor* g)
{
  Parfactors absorvedPfs;
  const ProbFormulas& formulas = g->arguments();
  for (size_t i = 0; i < formulas.size(); i++) {
    const bool sameSignature =
        obsFormula.functor() == formulas[i].functor() &&
        obsFormula.arity() == formulas[i].arity();
    if (!sameSignature) {
      continue;
    }

    if (obsFormula.isAtom()) {
      if (formulas.size() > 1) {
        g->absorveEvidence (formulas[i], obsFormula.evidence());
      } else {
        // the observed atom was g's only argument:
        // a null entry tells the caller to erase parfactor g
        absorvedPfs.push_back (nullptr);
      }
      break;
    }

    g->constr()->moveToTop (formulas[i].logVars());
    std::pair<ConstraintTree*, ConstraintTree*> parts =
        g->constr()->split (
            formulas[i].logVars(),
            &(obsFormula.constr()),
            obsFormula.constr().logVars());
    ConstraintTree* commCt = parts.first;
    ConstraintTree* exclCt = parts.second;

    if (commCt->empty()) {
      // no overlap with the evidence for this argument; try the next one
      delete commCt;
      delete exclCt;
      continue;
    }

    if (formulas.size() > 1) {
      LogVarSet excl = g->exclusiveLogVars (i);
      // NOTE(review): tempPf presumably owns commCt from here on and
      // releases it when it goes out of scope -- confirm against Parfactor.
      Parfactor tempPf (g, commCt);
      Parfactors countNormPfs = LiftedOperations::countNormalize (
          &tempPf, excl);
      for (size_t j = 0; j < countNormPfs.size(); j++) {
        countNormPfs[j]->absorveEvidence (
            formulas[i], obsFormula.evidence());
        absorvedPfs.push_back (countNormPfs[j]);
      }
    } else {
      delete commCt;
    }

    if (exclCt->empty()) {
      delete exclCt;
    } else {
      absorvedPfs.push_back (new Parfactor (g, exclCt));
    }

    if (absorvedPfs.empty()) {
      // nothing survives: a null entry tells the caller to erase parfactor g
      absorvedPfs.push_back (nullptr);
    }
    break;
  }
  return absorvedPfs;
}
|
|
|