This repository has been archived on 2023-08-20. You can view files and clone it, but cannot push or open issues or pull requests.
yap-6.3/packages/CLPBN/horus/LiftedOperations.cpp
Tiago Gomes ef4ebb4d7f Use camel case for constants and enumerators.
All capitals case should be reserved for macros and besides there is no big need to emphasize constness in general.
2013-02-13 18:54:15 +00:00

280 lines
7.5 KiB
C++

#include <vector>
#include <queue>
#include <iostream>
#include "LiftedOperations.h"
namespace Horus {
// Splits every parfactor in pfList that contains a (non-atom) query ground,
// so that the query's arguments end up isolated in their own parfactor.
//
// For each such ground, every parfactor that contains it is replaced by:
//   - a parfactor built over the first tree returned by
//     ConstraintTree::split() (commCt), and
//   - a parfactor built over the second tree (exclCt), but only when that
//     tree is not empty.
// Presumably commCt holds the tuples matching the query arguments and
// exclCt the remaining ones -- confirm against ConstraintTree::split().
//
// Terminates the process with EXIT_FAILURE when no parfactor contains one
// of the query grounds. When Globals::verbosity > 2 the query and the
// resulting parfactor list are printed.
void
LiftedOperations::shatterAgainstQuery (
    ParfactorList& pfList,
    const Grounds& query)
{
  for (size_t i = 0; i < query.size(); i++) {
    // Atoms are skipped: no splitting is performed for them here.
    if (query[i].isAtom()) {
      continue;
    }
    bool found = false;
    Parfactors newPfs;
    ParfactorList::iterator it = pfList.begin();
    while (it != pfList.end()) {
      if ((*it)->containsGround (query[i])) {
        found = true;
        std::pair<ConstraintTree*, ConstraintTree*> split =
            (*it)->constr()->split (query[i].args());
        ConstraintTree* commCt = split.first;
        ConstraintTree* exclCt = split.second;
        newPfs.push_back (new Parfactor (*it, commCt));
        if (exclCt->empty() == false) {
          newPfs.push_back (new Parfactor (*it, exclCt));
        } else {
          // Not handed to any parfactor, so it must be released here.
          delete exclCt;
        }
        // The original parfactor is superseded by the ones in newPfs.
        it = pfList.removeAndDelete (it);
      } else {
        ++ it;
      }
    }
    if (found == false) {
      std::cerr << "Error: could not find a parfactor with ground " ;
      std::cerr << "`" << query[i] << "'." << std::endl;
      exit (EXIT_FAILURE);
    }
    // The replacements are added only after the scan, so they are not
    // visited again for the same query ground.
    pfList.add (newPfs);
  }
  if (Globals::verbosity > 2) {
    Util::printAsteriskLine();
    std::cout << "SHATTERED AGAINST THE QUERY" << std::endl;
    for (size_t i = 0; i < query.size(); i++) {
      std::cout << " -> " << query[i] << std::endl;
    }
    Util::printAsteriskLine();
    pfList.print();
  }
}
// Removes from pfList every parfactor that cannot be reached from the
// query through shared groups.
//
// Starting from the group of each query ground (taken from the first
// parfactor where findGroup() reports one), groups are explored
// breadth-first: a parfactor containing an explored group is required, and
// all of its groups are scheduled for exploration in turn. Parfactors that
// were never marked as required are removed with removeAndDelete(); when
// Globals::verbosity > 2 they are printed first.
void
LiftedOperations::runWeakBayesBall (
    ParfactorList& pfList,
    const Grounds& query)
{
  std::queue<PrvGroup> pending;  // groups still to be explored
  std::set<PrvGroup> seen;       // groups explored or waiting in `pending'
  // Value that findGroup() is compared against to mean "no group found".
  const PrvGroup noGroup = std::numeric_limits<PrvGroup>::max();

  // Seed the search with the group of every query ground.
  for (size_t q = 0; q < query.size(); q++) {
    for (ParfactorList::iterator pfIt = pfList.begin();
        pfIt != pfList.end(); ++ pfIt) {
      PrvGroup queryGroup = (*pfIt)->findGroup (query[q]);
      if (queryGroup != noGroup) {
        pending.push (queryGroup);
        seen.insert (queryGroup);
        break;
      }
    }
  }

  // Breadth-first expansion over groups.
  std::set<Parfactor*> neededPfs;
  for (; pending.empty() == false; pending.pop()) {
    PrvGroup current = pending.front();
    for (ParfactorList::iterator pfIt = pfList.begin();
        pfIt != pfList.end(); ++ pfIt) {
      Parfactor* candidate = *pfIt;
      if (Util::contains (neededPfs, candidate) ||
          candidate->containsGroup (current) == false) {
        continue;
      }
      std::vector<PrvGroup> neighbours = candidate->getAllGroups();
      for (size_t n = 0; n < neighbours.size(); n++) {
        // insert() reports whether the group was new; only then enqueue it.
        if (seen.insert (neighbours[n]).second) {
          pending.push (neighbours[n]);
        }
      }
      neededPfs.insert (candidate);
    }
  }

  // Drop everything that was not reached.
  bool headerPrinted = false;
  ParfactorList::iterator pfIt = pfList.begin();
  while (pfIt != pfList.end()) {
    if (Util::contains (neededPfs, *pfIt)) {
      ++ pfIt;
      continue;
    }
    if (Globals::verbosity > 2) {
      if (headerPrinted == false) {
        Util::printHeader ("PARFACTORS TO DISCARD");
        headerPrinted = true;
      }
      (*pfIt)->print();
    }
    pfIt = pfList.removeAndDelete (pfIt);
  }
}
// Applies every observed formula in obsFormulas to the parfactors of pfList.
//
// For each formula, each parfactor is taken out of the list and handed to
// absorve(). The result is interpreted as follows:
//   - empty vector: the parfactor is put back with insertShattered()
//     (absorve() may still have updated it in place);
//   - a single null pointer: the parfactor is simply dropped;
//   - anything else: the returned parfactors replace the original one and
//     are added to the list once the scan for this formula is finished.
// In the last two cases the original parfactor is deleted here.
//
// When Globals::verbosity > 2 and there is at least one observed formula,
// the formulas and the resulting parfactor list are printed.
void
LiftedOperations::absorveEvidence (
    ParfactorList& pfList,
    ObservedFormulas& obsFormulas)
{
  for (size_t i = 0; i < obsFormulas.size(); i++) {
    Parfactors newPfs;
    ParfactorList::iterator it = pfList.begin();
    while (it != pfList.end()) {
      Parfactor* pf = *it;
      // pf is still used (and explicitly deleted) below, so remove() is
      // expected to only unlink it from the list.
      it = pfList.remove (it);
      Parfactors absorvedPfs = absorve (obsFormulas[i], pf);
      if (absorvedPfs.empty() == false) {
        if (absorvedPfs.size() == 1 && !absorvedPfs[0]) {
          // Null marker returned by absorve(): just remove pf.
        } else {
          Util::addToVector (newPfs, absorvedPfs);
        }
        delete pf;
      } else {
        it = pfList.insertShattered (it, pf);
        ++ it;
      }
    }
    pfList.add (newPfs);
  }
  if (Globals::verbosity > 2 && obsFormulas.empty() == false) {
    Util::printAsteriskLine();
    std::cout << "AFTER EVIDENCE ABSORBED" << std::endl;
    for (size_t i = 0; i < obsFormulas.size(); i++) {
      std::cout << " -> " << obsFormulas[i] << std::endl;
    }
    Util::printAsteriskLine();
    pfList.print();
  }
}
// Builds the parfactors obtained by count-normalizing g with respect to
// the logical variables in `set'.
//
// With an empty set the result is a single copy of g. Otherwise one new
// parfactor is created from g for each constraint tree returned by
// ConstraintTree::countNormalize(). g itself is not modified by this
// function; the returned parfactors are heap-allocated and returned as raw
// pointers.
Parfactors
LiftedOperations::countNormalize (
    Parfactor* g,
    const LogVarSet& set)
{
  Parfactors result;
  if (set.empty()) {
    result.push_back (new Parfactor (*g));
    return result;
  }
  ConstraintTrees trees = g->constr()->countNormalize (set);
  for (size_t t = 0; t < trees.size(); t++) {
    Parfactor* normalized = new Parfactor (g, trees[t]);
    result.push_back (normalized);
  }
  return result;
}
// Grounds pf on every logical variable that is not a singleton of its
// constraint and returns the product of the resulting parfactors.
//
// Variables are processed one at a time. A parfactor in which the variable
// is counted is expanded in place with fullExpand() and kept; otherwise it
// is replaced by one new parfactor per constraint tree returned by
// ConstraintTree::ground() and the old one is deleted. The final set is
// loaded into a ParfactorList, and all of its parfactors are multiplied
// into the first one, whose copy is returned.
Parfactor
LiftedOperations::calcGroundMultiplication (Parfactor pf)
{
  LogVarSet toGround = pf.constr()->logVarSet();
  toGround -= pf.constr()->singletons();

  Parfactors current = {new Parfactor (pf)};
  for (size_t v = 0; v < toGround.size(); v++) {
    // Move this round's inputs aside; `current' collects the outputs.
    Parfactors previous;
    previous.swap (current);
    for (size_t p = 0; p < previous.size(); p++) {
      Parfactor* source = previous[p];
      if (source->countedLogVars().contains (toGround[v])) {
        source->fullExpand (toGround[v]);
        current.push_back (source);
        continue;
      }
      ConstraintTrees groundedCts = source->constr()->ground (toGround[v]);
      for (size_t c = 0; c < groundedCts.size(); c++) {
        current.push_back (new Parfactor (source, groundedCts[c]));
      }
      // Fully replaced by the parfactors created above.
      delete source;
    }
  }

  ParfactorList shatteredList (current);
  Parfactors factors (shatteredList.begin(), shatteredList.end());
  for (size_t f = 1; f < factors.size(); f++) {
    factors[0]->multiply (*factors[f]);
  }
  // Return a copy; the pointers in `factors' also live in shatteredList.
  return Parfactor (*factors[0]);
}
// Absorbs the evidence carried by obsFormula into parfactor g, using the
// first argument of g whose functor and arity match the formula and for
// which the split described below yields a non-empty first tree.
//
// Return value, as consumed by the caller:
//   - empty vector: nothing has to be replaced (either no argument was
//     affected, or g was updated in place for an atom);
//   - vector holding a single null pointer: g should just be erased;
//   - otherwise: the newly allocated parfactors that replace g.
// g is never deleted here.
Parfactors
LiftedOperations::absorve (
    ObservedFormula& obsFormula,
    Parfactor* g)
{
  Parfactors result;
  const ProbFormulas& args = g->arguments();
  for (size_t idx = 0; idx < args.size(); idx++) {
    const bool sameSignature =
        obsFormula.functor() == args[idx].functor() &&
        obsFormula.arity()   == args[idx].arity();
    if (sameSignature == false) {
      continue;
    }

    if (obsFormula.isAtom()) {
      if (args.size() > 1) {
        // Update g in place; result stays empty.
        g->absorveEvidence (args[idx], obsFormula.evidence());
      } else {
        // hack to erase parfactor g
        result.push_back (nullptr);
      }
      break;
    }

    // Split g's constraint against the constraint of the observation.
    g->constr()->moveToTop (args[idx].logVars());
    std::pair<ConstraintTree*, ConstraintTree*> parts =
        g->constr()->split (
            args[idx].logVars(),
            &(obsFormula.constr()),
            obsFormula.constr().logVars());
    ConstraintTree* commCt = parts.first;
    ConstraintTree* exclCt = parts.second;

    if (commCt->empty()) {
      // This argument is not affected; try the next one.
      delete commCt;
      delete exclCt;
      continue;
    }

    if (args.size() > 1) {
      LogVarSet exclusiveLvs = g->exclusiveLogVars (idx);
      // commCt is handed to this temporary parfactor.
      Parfactor commPf (g, commCt);
      Parfactors normalized = LiftedOperations::countNormalize (
          &commPf, exclusiveLvs);
      for (size_t n = 0; n < normalized.size(); n++) {
        normalized[n]->absorveEvidence (
            args[idx], obsFormula.evidence());
        result.push_back (normalized[n]);
      }
    } else {
      delete commCt;
    }

    if (exclCt->empty()) {
      delete exclCt;
    } else {
      result.push_back (new Parfactor (g, exclCt));
    }

    if (result.empty()) {
      // hack to erase parfactor g
      result.push_back (nullptr);
    }
    break;
  }
  return result;
}
} // namespace Horus