From 07bcc89a76bb55b3efa1361e63f108397cfaaeeb Mon Sep 17 00:00:00 2001 From: Tiago Gomes Date: Fri, 9 Nov 2012 23:52:35 +0000 Subject: [PATCH] factor out some lifted operations in a new class --- packages/CLPBN/horus/HorusYap.cpp | 3 +- packages/CLPBN/horus/LiftedBp.cpp | 3 +- packages/CLPBN/horus/LiftedBp.h | 1 + packages/CLPBN/horus/LiftedOperations.cpp | 211 +++++++++++++++++++++ packages/CLPBN/horus/LiftedOperations.h | 24 +++ packages/CLPBN/horus/LiftedVe.cpp | 215 +--------------------- packages/CLPBN/horus/LiftedVe.h | 4 - packages/CLPBN/horus/Makefile.in | 3 + 8 files changed, 248 insertions(+), 216 deletions(-) create mode 100644 packages/CLPBN/horus/LiftedOperations.cpp create mode 100644 packages/CLPBN/horus/LiftedOperations.h diff --git a/packages/CLPBN/horus/HorusYap.cpp b/packages/CLPBN/horus/HorusYap.cpp index 7246beb5e..cd31c0612 100644 --- a/packages/CLPBN/horus/HorusYap.cpp +++ b/packages/CLPBN/horus/HorusYap.cpp @@ -9,6 +9,7 @@ #include "ParfactorList.h" #include "FactorGraph.h" +#include "LiftedOperations.h" #include "LiftedVe.h" #include "VarElim.h" #include "LiftedBp.h" @@ -281,7 +282,7 @@ runLiftedSolver (void) YAP_Term taskList = YAP_ARG2; vector results; ParfactorList pfListCopy (*network->first); - LiftedVe::absorveEvidence (pfListCopy, *network->second); + LiftedOperations::absorveEvidence (pfListCopy, *network->second); while (taskList != YAP_TermNil()) { Grounds queryVars; YAP_Term jointList = YAP_HeadOfTerm (taskList); diff --git a/packages/CLPBN/horus/LiftedBp.cpp b/packages/CLPBN/horus/LiftedBp.cpp index 05a5ea6af..1c5a93023 100644 --- a/packages/CLPBN/horus/LiftedBp.cpp +++ b/packages/CLPBN/horus/LiftedBp.cpp @@ -1,6 +1,7 @@ #include "LiftedBp.h" #include "WeightedBp.h" #include "FactorGraph.h" +#include "LiftedOperations.h" #include "LiftedVe.h" @@ -101,7 +102,7 @@ LiftedBp::iterate (void) for (size_t i = 0; i < args.size(); i++) { LogVarSet lvs = (*it)->logVarSet() - args[i].logVars(); if ((*it)->constr()->isCountNormalized (lvs) == false) { - Parfactors pfs = LiftedVe::countNormalize (*it, lvs); + Parfactors pfs = LiftedOperations::countNormalize (*it, lvs); it = pfList_.removeAndDelete (it); pfList_.add (pfs); return false; diff --git a/packages/CLPBN/horus/LiftedBp.h b/packages/CLPBN/horus/LiftedBp.h index c34956320..29edf0ac8 100644 --- a/packages/CLPBN/horus/LiftedBp.h +++ b/packages/CLPBN/horus/LiftedBp.h @@ -39,3 +39,4 @@ class LiftedBp }; #endif // HORUS_LIFTEDBP_H + diff --git a/packages/CLPBN/horus/LiftedOperations.cpp b/packages/CLPBN/horus/LiftedOperations.cpp new file mode 100644 index 000000000..1c9c8b413 --- /dev/null +++ b/packages/CLPBN/horus/LiftedOperations.cpp @@ -0,0 +1,211 @@ +#include "LiftedOperations.h" + + +void +LiftedOperations::shatterAgainstQuery ( + ParfactorList& pfList, + const Grounds& query) +{ + for (size_t i = 0; i < query.size(); i++) { + if (query[i].isAtom()) { + continue; + } + bool found = false; + Parfactors newPfs; + ParfactorList::iterator it = pfList.begin(); + while (it != pfList.end()) { + if ((*it)->containsGround (query[i])) { + found = true; + std::pair split; + LogVars queryLvs ( + (*it)->constr()->logVars().begin(), + (*it)->constr()->logVars().begin() + query[i].arity()); + split = (*it)->constr()->split (query[i].args()); + ConstraintTree* commCt = split.first; + ConstraintTree* exclCt = split.second; + newPfs.push_back (new Parfactor (*it, commCt)); + if (exclCt->empty() == false) { + newPfs.push_back (new Parfactor (*it, exclCt)); + } else { + delete exclCt; + } + it = pfList.removeAndDelete (it); + } else { + ++ it; + } + } + if (found == false) { + cerr << "error: could not find a parfactor with ground " ; + cerr << "`" << query[i] << "'" << endl; + exit (0); + } + pfList.add (newPfs); + } + if (Globals::verbosity > 2) { + Util::printAsteriskLine(); + cout << "SHATTERED AGAINST THE QUERY" << endl; + for (size_t i = 0; i < query.size(); i++) { + cout << " -> " << query[i] << endl; + } + Util::printAsteriskLine(); + pfList.print(); + } +} + + + +void +LiftedOperations::absorveEvidence ( + ParfactorList& pfList, + ObservedFormulas& obsFormulas) +{ + for (size_t i = 0; i < obsFormulas.size(); i++) { + Parfactors newPfs; + ParfactorList::iterator it = pfList.begin(); + while (it != pfList.end()) { + Parfactor* pf = *it; + it = pfList.remove (it); + Parfactors absorvedPfs = absorve (obsFormulas[i], pf); + if (absorvedPfs.empty() == false) { + if (absorvedPfs.size() == 1 && absorvedPfs[0] == 0) { + // just remove pf; + } else { + Util::addToVector (newPfs, absorvedPfs); + } + delete pf; + } else { + it = pfList.insertShattered (it, pf); + ++ it; + } + } + pfList.add (newPfs); + } + if (Globals::verbosity > 2 && obsFormulas.empty() == false) { + Util::printAsteriskLine(); + cout << "AFTER EVIDENCE ABSORVED" << endl; + for (size_t i = 0; i < obsFormulas.size(); i++) { + cout << " -> " << obsFormulas[i] << endl; + } + Util::printAsteriskLine(); + pfList.print(); + } +} + + + +Parfactors +LiftedOperations::countNormalize ( + Parfactor* g, + const LogVarSet& set) +{ + Parfactors normPfs; + if (set.empty()) { + normPfs.push_back (new Parfactor (*g)); + } else { + ConstraintTrees normCts = g->constr()->countNormalize (set); + for (size_t i = 0; i < normCts.size(); i++) { + normPfs.push_back (new Parfactor (g, normCts[i])); + } + } + return normPfs; +} + + + +Parfactor +LiftedOperations::calcGroundMultiplication (Parfactor pf) +{ + LogVarSet lvs = pf.constr()->logVarSet(); + lvs -= pf.constr()->singletons(); + Parfactors newPfs = {new Parfactor (pf)}; + for (size_t i = 0; i < lvs.size(); i++) { + Parfactors pfs = newPfs; + newPfs.clear(); + for (size_t j = 0; j < pfs.size(); j++) { + bool countedLv = pfs[j]->countedLogVars().contains (lvs[i]); + if (countedLv) { + pfs[j]->fullExpand (lvs[i]); + newPfs.push_back (pfs[j]); + } else { + ConstraintTrees cts = pfs[j]->constr()->ground (lvs[i]); + for (size_t k = 0; k < cts.size(); k++) { + newPfs.push_back (new Parfactor (pfs[j], cts[k])); + } + delete pfs[j]; + } + } + } + ParfactorList pfList (newPfs); + Parfactors groundShatteredPfs (pfList.begin(),pfList.end()); + for (size_t i = 1; i < groundShatteredPfs.size(); i++) { + groundShatteredPfs[0]->multiply (*groundShatteredPfs[i]); + } + return Parfactor (*groundShatteredPfs[0]); +} + + + +Parfactors +LiftedOperations::absorve ( + ObservedFormula& obsFormula, + Parfactor* g) +{ + Parfactors absorvedPfs; + const ProbFormulas& formulas = g->arguments(); + for (size_t i = 0; i < formulas.size(); i++) { + if (obsFormula.functor() == formulas[i].functor() && + obsFormula.arity() == formulas[i].arity()) { + + if (obsFormula.isAtom()) { + if (formulas.size() > 1) { + g->absorveEvidence (formulas[i], obsFormula.evidence()); + } else { + // hack to erase parfactor g + absorvedPfs.push_back (0); + } + break; + } + + g->constr()->moveToTop (formulas[i].logVars()); + std::pair res; + res = g->constr()->split ( + formulas[i].logVars(), + &(obsFormula.constr()), + obsFormula.constr().logVars()); + ConstraintTree* commCt = res.first; + ConstraintTree* exclCt = res.second; + + if (commCt->empty() == false) { + if (formulas.size() > 1) { + LogVarSet excl = g->exclusiveLogVars (i); + Parfactor tempPf (g, commCt); + Parfactors countNormPfs = LiftedOperations::countNormalize ( + &tempPf, excl); + for (size_t j = 0; j < countNormPfs.size(); j++) { + countNormPfs[j]->absorveEvidence ( + formulas[i], obsFormula.evidence()); + absorvedPfs.push_back (countNormPfs[j]); + } + } else { + delete commCt; + } + if (exclCt->empty() == false) { + absorvedPfs.push_back (new Parfactor (g, exclCt)); + } else { + delete exclCt; + } + if (absorvedPfs.empty()) { + // hack to erase parfactor g + absorvedPfs.push_back (0); + } + break; + } else { + delete commCt; + delete exclCt; + } + } + } + return absorvedPfs; +} + + diff --git a/packages/CLPBN/horus/LiftedOperations.h b/packages/CLPBN/horus/LiftedOperations.h new file mode 100644 index 000000000..ae6b6f0a9 --- /dev/null +++ b/packages/CLPBN/horus/LiftedOperations.h @@ -0,0 +1,24 @@ +#ifndef HORUS_LIFTEDOPERATIONS_H +#define HORUS_LIFTEDOPERATIONS_H + +#include "ParfactorList.h" + +class LiftedOperations +{ + public: + static void shatterAgainstQuery ( + ParfactorList& pfList, const Grounds& query); + + static void absorveEvidence ( + ParfactorList& pfList, ObservedFormulas& obsFormulas); + + static Parfactors countNormalize (Parfactor*, const LogVarSet&); + + static Parfactor calcGroundMultiplication (Parfactor pf); + + private: + + static Parfactors absorve (ObservedFormula&, Parfactor*); +}; + +#endif // HORUS_LIFTEDOPERATIONS_H diff --git a/packages/CLPBN/horus/LiftedVe.cpp b/packages/CLPBN/horus/LiftedVe.cpp index 2437906ce..0b5435295 100644 --- a/packages/CLPBN/horus/LiftedVe.cpp +++ b/packages/CLPBN/horus/LiftedVe.cpp @@ -2,6 +2,7 @@ #include #include "LiftedVe.h" +#include "LiftedOperations.h" #include "Histogram.h" #include "Util.h" @@ -221,7 +222,7 @@ SumOutOperator::apply (void) product->sumOutIndex (fIdx); pfList_.addShattered (product); } else { - Parfactors pfs = LiftedVe::countNormalize (product, excl); + Parfactors pfs = LiftedOperations::countNormalize (product, excl); for (size_t i = 0; i < pfs.size(); i++) { pfs[i]->sumOutIndex (fIdx); pfList_.add (pfs[i]); @@ -375,7 +376,7 @@ CountingOperator::apply (void) } else { Parfactor* pf = *pfIter_; pfList_.remove (pfIter_); - Parfactors pfs = LiftedVe::countNormalize (pf, X_); + Parfactors pfs = LiftedOperations::countNormalize (pf, X_); for (size_t i = 0; i < pfs.size(); i++) { unsigned condCount = pfs[i]->constr()->getConditionalCount (X_); bool cartProduct = pfs[i]->constr()->isCartesianProduct ( @@ -419,7 +420,7 @@ CountingOperator::toString (void) ss << "count convert " << X_ << " in " ; ss << (*pfIter_)->getLabel(); ss << " [cost=" << std::exp (getLogCost()) << "]" << endl; - Parfactors pfs = LiftedVe::countNormalize (*pfIter_, X_); + Parfactors pfs = LiftedOperations::countNormalize (*pfIter_, X_); if ((*pfIter_)->constr()->isCountNormalized (X_) == false) { for (size_t i = 0; i < pfs.size(); i++) { ss << " ยบ " << pfs[i]->getLabel() << endl; @@ -655,102 +656,11 @@ LiftedVe::printSolverFlags (void) const -void -LiftedVe::absorveEvidence ( - ParfactorList& pfList, - ObservedFormulas& obsFormulas) -{ - for (size_t i = 0; i < obsFormulas.size(); i++) { - Parfactors newPfs; - ParfactorList::iterator it = pfList.begin(); - while (it != pfList.end()) { - Parfactor* pf = *it; - it = pfList.remove (it); - Parfactors absorvedPfs = absorve (obsFormulas[i], pf); - if (absorvedPfs.empty() == false) { - if (absorvedPfs.size() == 1 && absorvedPfs[0] == 0) { - // just remove pf; - } else { - Util::addToVector (newPfs, absorvedPfs); - } - delete pf; - } else { - it = pfList.insertShattered (it, pf); - ++ it; - } - } - pfList.add (newPfs); - } - if (Globals::verbosity > 2 && obsFormulas.empty() == false) { - Util::printAsteriskLine(); - cout << "AFTER EVIDENCE ABSORVED" << endl; - for (size_t i = 0; i < obsFormulas.size(); i++) { - cout << " -> " << obsFormulas[i] << endl; - } - Util::printAsteriskLine(); - pfList.print(); - } -} - - - -Parfactors -LiftedVe::countNormalize ( - Parfactor* g, - const LogVarSet& set) -{ - Parfactors normPfs; - if (set.empty()) { - normPfs.push_back (new Parfactor (*g)); - } else { - ConstraintTrees normCts = g->constr()->countNormalize (set); - for (size_t i = 0; i < normCts.size(); i++) { - normPfs.push_back (new Parfactor (g, normCts[i])); - } - } - return normPfs; -} - - - -Parfactor -LiftedVe::calcGroundMultiplication (Parfactor pf) -{ - LogVarSet lvs = pf.constr()->logVarSet(); - lvs -= pf.constr()->singletons(); - Parfactors newPfs = {new Parfactor (pf)}; - for (size_t i = 0; i < lvs.size(); i++) { - Parfactors pfs = newPfs; - newPfs.clear(); - for (size_t j = 0; j < pfs.size(); j++) { - bool countedLv = pfs[j]->countedLogVars().contains (lvs[i]); - if (countedLv) { - pfs[j]->fullExpand (lvs[i]); - newPfs.push_back (pfs[j]); - } else { - ConstraintTrees cts = pfs[j]->constr()->ground (lvs[i]); - for (size_t k = 0; k < cts.size(); k++) { - newPfs.push_back (new Parfactor (pfs[j], cts[k])); - } - delete pfs[j]; - } - } - } - ParfactorList pfList (newPfs); - Parfactors groundShatteredPfs (pfList.begin(),pfList.end()); - for (size_t i = 1; i < groundShatteredPfs.size(); i++) { - groundShatteredPfs[0]->multiply (*groundShatteredPfs[i]); - } - return Parfactor (*groundShatteredPfs[0]); -} - - - void LiftedVe::runSolver (const Grounds& query) { largestCost_ = std::log (0); - shatterAgainstQuery (query); + LiftedOperations::shatterAgainstQuery (pfList_, query); runWeakBayesBall (query); while (true) { if (Globals::verbosity > 2) { @@ -876,118 +786,3 @@ LiftedVe::runWeakBayesBall (const Grounds& query) } } - - -void -LiftedVe::shatterAgainstQuery (const Grounds& query) -{ - for (size_t i = 0; i < query.size(); i++) { - if (query[i].isAtom()) { - continue; - } - bool found = false; - Parfactors newPfs; - ParfactorList::iterator it = pfList_.begin(); - while (it != pfList_.end()) { - if ((*it)->containsGround (query[i])) { - found = true; - std::pair split; - LogVars queryLvs ( - (*it)->constr()->logVars().begin(), - (*it)->constr()->logVars().begin() + query[i].arity()); - split = (*it)->constr()->split (query[i].args()); - ConstraintTree* commCt = split.first; - ConstraintTree* exclCt = split.second; - newPfs.push_back (new Parfactor (*it, commCt)); - if (exclCt->empty() == false) { - newPfs.push_back (new Parfactor (*it, exclCt)); - } else { - delete exclCt; - } - it = pfList_.removeAndDelete (it); - } else { - ++ it; - } - } - if (found == false) { - cerr << "error: could not find a parfactor with ground " ; - cerr << "`" << query[i] << "'" << endl; - exit (0); - } - pfList_.add (newPfs); - } - if (Globals::verbosity > 2) { - Util::printAsteriskLine(); - cout << "SHATTERED AGAINST THE QUERY" << endl; - for (size_t i = 0; i < query.size(); i++) { - cout << " -> " << query[i] << endl; - } - Util::printAsteriskLine(); - pfList_.print(); - } -} - - - -Parfactors -LiftedVe::absorve ( - ObservedFormula& obsFormula, - Parfactor* g) -{ - Parfactors absorvedPfs; - const ProbFormulas& formulas = g->arguments(); - for (size_t i = 0; i < formulas.size(); i++) { - if (obsFormula.functor() == formulas[i].functor() && - obsFormula.arity() == formulas[i].arity()) { - - if (obsFormula.isAtom()) { - if (formulas.size() > 1) { - g->absorveEvidence (formulas[i], obsFormula.evidence()); - } else { - // hack to erase parfactor g - absorvedPfs.push_back (0); - } - break; - } - - g->constr()->moveToTop (formulas[i].logVars()); - std::pair res; - res = g->constr()->split ( - formulas[i].logVars(), - &(obsFormula.constr()), - obsFormula.constr().logVars()); - ConstraintTree* commCt = res.first; - ConstraintTree* exclCt = res.second; - - if (commCt->empty() == false) { - if (formulas.size() > 1) { - LogVarSet excl = g->exclusiveLogVars (i); - Parfactor tempPf (g, commCt); - Parfactors countNormPfs = countNormalize (&tempPf, excl); - for (size_t j = 0; j < countNormPfs.size(); j++) { - countNormPfs[j]->absorveEvidence ( - formulas[i], obsFormula.evidence()); - absorvedPfs.push_back (countNormPfs[j]); - } - } else { - delete commCt; - } - if (exclCt->empty() == false) { - absorvedPfs.push_back (new Parfactor (g, exclCt)); - } else { - delete exclCt; - } - if (absorvedPfs.empty()) { - // hack to erase parfactor g - absorvedPfs.push_back (0); - } - break; - } else { - delete commCt; - delete exclCt; - } - } - } - return absorvedPfs; -} - diff --git a/packages/CLPBN/horus/LiftedVe.h b/packages/CLPBN/horus/LiftedVe.h index e79ffc265..85adacc18 100644 --- a/packages/CLPBN/horus/LiftedVe.h +++ b/packages/CLPBN/horus/LiftedVe.h @@ -144,10 +144,6 @@ class LiftedVe static void absorveEvidence ( ParfactorList& pfList, ObservedFormulas& obsFormulas); - static Parfactors countNormalize (Parfactor*, const LogVarSet&); - - static Parfactor calcGroundMultiplication (Parfactor pf); - private: void runSolver (const Grounds&); diff --git a/packages/CLPBN/horus/Makefile.in b/packages/CLPBN/horus/Makefile.in index e4c5b5178..a87b3574e 100644 --- a/packages/CLPBN/horus/Makefile.in +++ b/packages/CLPBN/horus/Makefile.in @@ -59,6 +59,7 @@ HEADERS = \ $(srcdir)/LiftedBp.h \ $(srcdir)/LiftedCircuit.h \ $(srcdir)/LiftedKc.h \ + $(srcdir)/LiftedOperations.h \ $(srcdir)/LiftedUtils.h \ $(srcdir)/LiftedVe.h \ $(srcdir)/LiftedWCNF.h \ @@ -87,6 +88,7 @@ CPP_SOURCES = \ $(srcdir)/LiftedBp.cpp \ $(srcdir)/LiftedCircuit.cpp \ $(srcdir)/LiftedKc.cpp \ + $(srcdir)/LiftedOperations.cpp \ $(srcdir)/LiftedUtils.cpp \ $(srcdir)/LiftedVe.cpp \ $(srcdir)/LiftedWCNF.cpp \ @@ -113,6 +115,7 @@ OBJS = \ LiftedBp.o \ LiftedCircuit.o \ LiftedKc.o \ + LiftedOperations.o \ LiftedUtils.o \ LiftedVe.o \ LiftedWCNF.o \