factor out some lifted operations in a new class
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
#include <set>
|
||||
|
||||
#include "LiftedVe.h"
|
||||
#include "LiftedOperations.h"
|
||||
#include "Histogram.h"
|
||||
#include "Util.h"
|
||||
|
||||
@@ -221,7 +222,7 @@ SumOutOperator::apply (void)
|
||||
product->sumOutIndex (fIdx);
|
||||
pfList_.addShattered (product);
|
||||
} else {
|
||||
Parfactors pfs = LiftedVe::countNormalize (product, excl);
|
||||
Parfactors pfs = LiftedOperations::countNormalize (product, excl);
|
||||
for (size_t i = 0; i < pfs.size(); i++) {
|
||||
pfs[i]->sumOutIndex (fIdx);
|
||||
pfList_.add (pfs[i]);
|
||||
@@ -375,7 +376,7 @@ CountingOperator::apply (void)
|
||||
} else {
|
||||
Parfactor* pf = *pfIter_;
|
||||
pfList_.remove (pfIter_);
|
||||
Parfactors pfs = LiftedVe::countNormalize (pf, X_);
|
||||
Parfactors pfs = LiftedOperations::countNormalize (pf, X_);
|
||||
for (size_t i = 0; i < pfs.size(); i++) {
|
||||
unsigned condCount = pfs[i]->constr()->getConditionalCount (X_);
|
||||
bool cartProduct = pfs[i]->constr()->isCartesianProduct (
|
||||
@@ -419,7 +420,7 @@ CountingOperator::toString (void)
|
||||
ss << "count convert " << X_ << " in " ;
|
||||
ss << (*pfIter_)->getLabel();
|
||||
ss << " [cost=" << std::exp (getLogCost()) << "]" << endl;
|
||||
Parfactors pfs = LiftedVe::countNormalize (*pfIter_, X_);
|
||||
Parfactors pfs = LiftedOperations::countNormalize (*pfIter_, X_);
|
||||
if ((*pfIter_)->constr()->isCountNormalized (X_) == false) {
|
||||
for (size_t i = 0; i < pfs.size(); i++) {
|
||||
ss << " º " << pfs[i]->getLabel() << endl;
|
||||
@@ -655,102 +656,11 @@ LiftedVe::printSolverFlags (void) const
|
||||
|
||||
|
||||
|
||||
void
|
||||
LiftedVe::absorveEvidence (
|
||||
ParfactorList& pfList,
|
||||
ObservedFormulas& obsFormulas)
|
||||
{
|
||||
for (size_t i = 0; i < obsFormulas.size(); i++) {
|
||||
Parfactors newPfs;
|
||||
ParfactorList::iterator it = pfList.begin();
|
||||
while (it != pfList.end()) {
|
||||
Parfactor* pf = *it;
|
||||
it = pfList.remove (it);
|
||||
Parfactors absorvedPfs = absorve (obsFormulas[i], pf);
|
||||
if (absorvedPfs.empty() == false) {
|
||||
if (absorvedPfs.size() == 1 && absorvedPfs[0] == 0) {
|
||||
// just remove pf;
|
||||
} else {
|
||||
Util::addToVector (newPfs, absorvedPfs);
|
||||
}
|
||||
delete pf;
|
||||
} else {
|
||||
it = pfList.insertShattered (it, pf);
|
||||
++ it;
|
||||
}
|
||||
}
|
||||
pfList.add (newPfs);
|
||||
}
|
||||
if (Globals::verbosity > 2 && obsFormulas.empty() == false) {
|
||||
Util::printAsteriskLine();
|
||||
cout << "AFTER EVIDENCE ABSORVED" << endl;
|
||||
for (size_t i = 0; i < obsFormulas.size(); i++) {
|
||||
cout << " -> " << obsFormulas[i] << endl;
|
||||
}
|
||||
Util::printAsteriskLine();
|
||||
pfList.print();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
Parfactors
|
||||
LiftedVe::countNormalize (
|
||||
Parfactor* g,
|
||||
const LogVarSet& set)
|
||||
{
|
||||
Parfactors normPfs;
|
||||
if (set.empty()) {
|
||||
normPfs.push_back (new Parfactor (*g));
|
||||
} else {
|
||||
ConstraintTrees normCts = g->constr()->countNormalize (set);
|
||||
for (size_t i = 0; i < normCts.size(); i++) {
|
||||
normPfs.push_back (new Parfactor (g, normCts[i]));
|
||||
}
|
||||
}
|
||||
return normPfs;
|
||||
}
|
||||
|
||||
|
||||
|
||||
Parfactor
|
||||
LiftedVe::calcGroundMultiplication (Parfactor pf)
|
||||
{
|
||||
LogVarSet lvs = pf.constr()->logVarSet();
|
||||
lvs -= pf.constr()->singletons();
|
||||
Parfactors newPfs = {new Parfactor (pf)};
|
||||
for (size_t i = 0; i < lvs.size(); i++) {
|
||||
Parfactors pfs = newPfs;
|
||||
newPfs.clear();
|
||||
for (size_t j = 0; j < pfs.size(); j++) {
|
||||
bool countedLv = pfs[j]->countedLogVars().contains (lvs[i]);
|
||||
if (countedLv) {
|
||||
pfs[j]->fullExpand (lvs[i]);
|
||||
newPfs.push_back (pfs[j]);
|
||||
} else {
|
||||
ConstraintTrees cts = pfs[j]->constr()->ground (lvs[i]);
|
||||
for (size_t k = 0; k < cts.size(); k++) {
|
||||
newPfs.push_back (new Parfactor (pfs[j], cts[k]));
|
||||
}
|
||||
delete pfs[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
ParfactorList pfList (newPfs);
|
||||
Parfactors groundShatteredPfs (pfList.begin(),pfList.end());
|
||||
for (size_t i = 1; i < groundShatteredPfs.size(); i++) {
|
||||
groundShatteredPfs[0]->multiply (*groundShatteredPfs[i]);
|
||||
}
|
||||
return Parfactor (*groundShatteredPfs[0]);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
LiftedVe::runSolver (const Grounds& query)
|
||||
{
|
||||
largestCost_ = std::log (0);
|
||||
shatterAgainstQuery (query);
|
||||
LiftedOperations::shatterAgainstQuery (pfList_, query);
|
||||
runWeakBayesBall (query);
|
||||
while (true) {
|
||||
if (Globals::verbosity > 2) {
|
||||
@@ -876,118 +786,3 @@ LiftedVe::runWeakBayesBall (const Grounds& query)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
LiftedVe::shatterAgainstQuery (const Grounds& query)
|
||||
{
|
||||
for (size_t i = 0; i < query.size(); i++) {
|
||||
if (query[i].isAtom()) {
|
||||
continue;
|
||||
}
|
||||
bool found = false;
|
||||
Parfactors newPfs;
|
||||
ParfactorList::iterator it = pfList_.begin();
|
||||
while (it != pfList_.end()) {
|
||||
if ((*it)->containsGround (query[i])) {
|
||||
found = true;
|
||||
std::pair<ConstraintTree*, ConstraintTree*> split;
|
||||
LogVars queryLvs (
|
||||
(*it)->constr()->logVars().begin(),
|
||||
(*it)->constr()->logVars().begin() + query[i].arity());
|
||||
split = (*it)->constr()->split (query[i].args());
|
||||
ConstraintTree* commCt = split.first;
|
||||
ConstraintTree* exclCt = split.second;
|
||||
newPfs.push_back (new Parfactor (*it, commCt));
|
||||
if (exclCt->empty() == false) {
|
||||
newPfs.push_back (new Parfactor (*it, exclCt));
|
||||
} else {
|
||||
delete exclCt;
|
||||
}
|
||||
it = pfList_.removeAndDelete (it);
|
||||
} else {
|
||||
++ it;
|
||||
}
|
||||
}
|
||||
if (found == false) {
|
||||
cerr << "error: could not find a parfactor with ground " ;
|
||||
cerr << "`" << query[i] << "'" << endl;
|
||||
exit (0);
|
||||
}
|
||||
pfList_.add (newPfs);
|
||||
}
|
||||
if (Globals::verbosity > 2) {
|
||||
Util::printAsteriskLine();
|
||||
cout << "SHATTERED AGAINST THE QUERY" << endl;
|
||||
for (size_t i = 0; i < query.size(); i++) {
|
||||
cout << " -> " << query[i] << endl;
|
||||
}
|
||||
Util::printAsteriskLine();
|
||||
pfList_.print();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
Parfactors
|
||||
LiftedVe::absorve (
|
||||
ObservedFormula& obsFormula,
|
||||
Parfactor* g)
|
||||
{
|
||||
Parfactors absorvedPfs;
|
||||
const ProbFormulas& formulas = g->arguments();
|
||||
for (size_t i = 0; i < formulas.size(); i++) {
|
||||
if (obsFormula.functor() == formulas[i].functor() &&
|
||||
obsFormula.arity() == formulas[i].arity()) {
|
||||
|
||||
if (obsFormula.isAtom()) {
|
||||
if (formulas.size() > 1) {
|
||||
g->absorveEvidence (formulas[i], obsFormula.evidence());
|
||||
} else {
|
||||
// hack to erase parfactor g
|
||||
absorvedPfs.push_back (0);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
g->constr()->moveToTop (formulas[i].logVars());
|
||||
std::pair<ConstraintTree*, ConstraintTree*> res;
|
||||
res = g->constr()->split (
|
||||
formulas[i].logVars(),
|
||||
&(obsFormula.constr()),
|
||||
obsFormula.constr().logVars());
|
||||
ConstraintTree* commCt = res.first;
|
||||
ConstraintTree* exclCt = res.second;
|
||||
|
||||
if (commCt->empty() == false) {
|
||||
if (formulas.size() > 1) {
|
||||
LogVarSet excl = g->exclusiveLogVars (i);
|
||||
Parfactor tempPf (g, commCt);
|
||||
Parfactors countNormPfs = countNormalize (&tempPf, excl);
|
||||
for (size_t j = 0; j < countNormPfs.size(); j++) {
|
||||
countNormPfs[j]->absorveEvidence (
|
||||
formulas[i], obsFormula.evidence());
|
||||
absorvedPfs.push_back (countNormPfs[j]);
|
||||
}
|
||||
} else {
|
||||
delete commCt;
|
||||
}
|
||||
if (exclCt->empty() == false) {
|
||||
absorvedPfs.push_back (new Parfactor (g, exclCt));
|
||||
} else {
|
||||
delete exclCt;
|
||||
}
|
||||
if (absorvedPfs.empty()) {
|
||||
// hack to erase parfactor g
|
||||
absorvedPfs.push_back (0);
|
||||
}
|
||||
break;
|
||||
} else {
|
||||
delete commCt;
|
||||
delete exclCt;
|
||||
}
|
||||
}
|
||||
}
|
||||
return absorvedPfs;
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user