factor out some lifted operations in a new class

This commit is contained in:
Tiago Gomes 2012-11-09 23:52:35 +00:00
parent 8ab622e0aa
commit 07bcc89a76
8 changed files with 248 additions and 216 deletions

View File

@ -9,6 +9,7 @@
#include "ParfactorList.h" #include "ParfactorList.h"
#include "FactorGraph.h" #include "FactorGraph.h"
#include "LiftedOperations.h"
#include "LiftedVe.h" #include "LiftedVe.h"
#include "VarElim.h" #include "VarElim.h"
#include "LiftedBp.h" #include "LiftedBp.h"
@ -281,7 +282,7 @@ runLiftedSolver (void)
YAP_Term taskList = YAP_ARG2; YAP_Term taskList = YAP_ARG2;
vector<Params> results; vector<Params> results;
ParfactorList pfListCopy (*network->first); ParfactorList pfListCopy (*network->first);
LiftedVe::absorveEvidence (pfListCopy, *network->second); LiftedOperations::absorveEvidence (pfListCopy, *network->second);
while (taskList != YAP_TermNil()) { while (taskList != YAP_TermNil()) {
Grounds queryVars; Grounds queryVars;
YAP_Term jointList = YAP_HeadOfTerm (taskList); YAP_Term jointList = YAP_HeadOfTerm (taskList);

View File

@ -1,6 +1,7 @@
#include "LiftedBp.h" #include "LiftedBp.h"
#include "WeightedBp.h" #include "WeightedBp.h"
#include "FactorGraph.h" #include "FactorGraph.h"
#include "LiftedOperations.h"
#include "LiftedVe.h" #include "LiftedVe.h"
@ -101,7 +102,7 @@ LiftedBp::iterate (void)
for (size_t i = 0; i < args.size(); i++) { for (size_t i = 0; i < args.size(); i++) {
LogVarSet lvs = (*it)->logVarSet() - args[i].logVars(); LogVarSet lvs = (*it)->logVarSet() - args[i].logVars();
if ((*it)->constr()->isCountNormalized (lvs) == false) { if ((*it)->constr()->isCountNormalized (lvs) == false) {
Parfactors pfs = LiftedVe::countNormalize (*it, lvs); Parfactors pfs = LiftedOperations::countNormalize (*it, lvs);
it = pfList_.removeAndDelete (it); it = pfList_.removeAndDelete (it);
pfList_.add (pfs); pfList_.add (pfs);
return false; return false;

View File

@ -39,3 +39,4 @@ class LiftedBp
}; };
#endif // HORUS_LIFTEDBP_H #endif // HORUS_LIFTEDBP_H

View File

@ -0,0 +1,211 @@
#include "LiftedOperations.h"
void
LiftedOperations::shatterAgainstQuery (
ParfactorList& pfList,
const Grounds& query)
{
for (size_t i = 0; i < query.size(); i++) {
if (query[i].isAtom()) {
continue;
}
bool found = false;
Parfactors newPfs;
ParfactorList::iterator it = pfList.begin();
while (it != pfList.end()) {
if ((*it)->containsGround (query[i])) {
found = true;
std::pair<ConstraintTree*, ConstraintTree*> split;
LogVars queryLvs (
(*it)->constr()->logVars().begin(),
(*it)->constr()->logVars().begin() + query[i].arity());
split = (*it)->constr()->split (query[i].args());
ConstraintTree* commCt = split.first;
ConstraintTree* exclCt = split.second;
newPfs.push_back (new Parfactor (*it, commCt));
if (exclCt->empty() == false) {
newPfs.push_back (new Parfactor (*it, exclCt));
} else {
delete exclCt;
}
it = pfList.removeAndDelete (it);
} else {
++ it;
}
}
if (found == false) {
cerr << "error: could not find a parfactor with ground " ;
cerr << "`" << query[i] << "'" << endl;
exit (0);
}
pfList.add (newPfs);
}
if (Globals::verbosity > 2) {
Util::printAsteriskLine();
cout << "SHATTERED AGAINST THE QUERY" << endl;
for (size_t i = 0; i < query.size(); i++) {
cout << " -> " << query[i] << endl;
}
Util::printAsteriskLine();
pfList.print();
}
}
void
LiftedOperations::absorveEvidence (
ParfactorList& pfList,
ObservedFormulas& obsFormulas)
{
for (size_t i = 0; i < obsFormulas.size(); i++) {
Parfactors newPfs;
ParfactorList::iterator it = pfList.begin();
while (it != pfList.end()) {
Parfactor* pf = *it;
it = pfList.remove (it);
Parfactors absorvedPfs = absorve (obsFormulas[i], pf);
if (absorvedPfs.empty() == false) {
if (absorvedPfs.size() == 1 && absorvedPfs[0] == 0) {
// just remove pf;
} else {
Util::addToVector (newPfs, absorvedPfs);
}
delete pf;
} else {
it = pfList.insertShattered (it, pf);
++ it;
}
}
pfList.add (newPfs);
}
if (Globals::verbosity > 2 && obsFormulas.empty() == false) {
Util::printAsteriskLine();
cout << "AFTER EVIDENCE ABSORVED" << endl;
for (size_t i = 0; i < obsFormulas.size(); i++) {
cout << " -> " << obsFormulas[i] << endl;
}
Util::printAsteriskLine();
pfList.print();
}
}
Parfactors
LiftedOperations::countNormalize (
Parfactor* g,
const LogVarSet& set)
{
Parfactors normPfs;
if (set.empty()) {
normPfs.push_back (new Parfactor (*g));
} else {
ConstraintTrees normCts = g->constr()->countNormalize (set);
for (size_t i = 0; i < normCts.size(); i++) {
normPfs.push_back (new Parfactor (g, normCts[i]));
}
}
return normPfs;
}
Parfactor
LiftedOperations::calcGroundMultiplication (Parfactor pf)
{
LogVarSet lvs = pf.constr()->logVarSet();
lvs -= pf.constr()->singletons();
Parfactors newPfs = {new Parfactor (pf)};
for (size_t i = 0; i < lvs.size(); i++) {
Parfactors pfs = newPfs;
newPfs.clear();
for (size_t j = 0; j < pfs.size(); j++) {
bool countedLv = pfs[j]->countedLogVars().contains (lvs[i]);
if (countedLv) {
pfs[j]->fullExpand (lvs[i]);
newPfs.push_back (pfs[j]);
} else {
ConstraintTrees cts = pfs[j]->constr()->ground (lvs[i]);
for (size_t k = 0; k < cts.size(); k++) {
newPfs.push_back (new Parfactor (pfs[j], cts[k]));
}
delete pfs[j];
}
}
}
ParfactorList pfList (newPfs);
Parfactors groundShatteredPfs (pfList.begin(),pfList.end());
for (size_t i = 1; i < groundShatteredPfs.size(); i++) {
groundShatteredPfs[0]->multiply (*groundShatteredPfs[i]);
}
return Parfactor (*groundShatteredPfs[0]);
}
Parfactors
LiftedOperations::absorve (
ObservedFormula& obsFormula,
Parfactor* g)
{
Parfactors absorvedPfs;
const ProbFormulas& formulas = g->arguments();
for (size_t i = 0; i < formulas.size(); i++) {
if (obsFormula.functor() == formulas[i].functor() &&
obsFormula.arity() == formulas[i].arity()) {
if (obsFormula.isAtom()) {
if (formulas.size() > 1) {
g->absorveEvidence (formulas[i], obsFormula.evidence());
} else {
// hack to erase parfactor g
absorvedPfs.push_back (0);
}
break;
}
g->constr()->moveToTop (formulas[i].logVars());
std::pair<ConstraintTree*, ConstraintTree*> res;
res = g->constr()->split (
formulas[i].logVars(),
&(obsFormula.constr()),
obsFormula.constr().logVars());
ConstraintTree* commCt = res.first;
ConstraintTree* exclCt = res.second;
if (commCt->empty() == false) {
if (formulas.size() > 1) {
LogVarSet excl = g->exclusiveLogVars (i);
Parfactor tempPf (g, commCt);
Parfactors countNormPfs = LiftedOperations::countNormalize (
&tempPf, excl);
for (size_t j = 0; j < countNormPfs.size(); j++) {
countNormPfs[j]->absorveEvidence (
formulas[i], obsFormula.evidence());
absorvedPfs.push_back (countNormPfs[j]);
}
} else {
delete commCt;
}
if (exclCt->empty() == false) {
absorvedPfs.push_back (new Parfactor (g, exclCt));
} else {
delete exclCt;
}
if (absorvedPfs.empty()) {
// hack to erase parfactor g
absorvedPfs.push_back (0);
}
break;
} else {
delete commCt;
delete exclCt;
}
}
}
return absorvedPfs;
}

View File

@ -0,0 +1,24 @@
#ifndef HORUS_LIFTEDOPERATIONS_H
#define HORUS_LIFTEDOPERATIONS_H
#include "ParfactorList.h"
class LiftedOperations
{
public:
static void shatterAgainstQuery (
ParfactorList& pfList, const Grounds& query);
static void absorveEvidence (
ParfactorList& pfList, ObservedFormulas& obsFormulas);
static Parfactors countNormalize (Parfactor*, const LogVarSet&);
static Parfactor calcGroundMultiplication (Parfactor pf);
private:
static Parfactors absorve (ObservedFormula&, Parfactor*);
};
#endif // HORUS_LIFTEDOPERATIONS_H

View File

@ -2,6 +2,7 @@
#include <set> #include <set>
#include "LiftedVe.h" #include "LiftedVe.h"
#include "LiftedOperations.h"
#include "Histogram.h" #include "Histogram.h"
#include "Util.h" #include "Util.h"
@ -221,7 +222,7 @@ SumOutOperator::apply (void)
product->sumOutIndex (fIdx); product->sumOutIndex (fIdx);
pfList_.addShattered (product); pfList_.addShattered (product);
} else { } else {
Parfactors pfs = LiftedVe::countNormalize (product, excl); Parfactors pfs = LiftedOperations::countNormalize (product, excl);
for (size_t i = 0; i < pfs.size(); i++) { for (size_t i = 0; i < pfs.size(); i++) {
pfs[i]->sumOutIndex (fIdx); pfs[i]->sumOutIndex (fIdx);
pfList_.add (pfs[i]); pfList_.add (pfs[i]);
@ -375,7 +376,7 @@ CountingOperator::apply (void)
} else { } else {
Parfactor* pf = *pfIter_; Parfactor* pf = *pfIter_;
pfList_.remove (pfIter_); pfList_.remove (pfIter_);
Parfactors pfs = LiftedVe::countNormalize (pf, X_); Parfactors pfs = LiftedOperations::countNormalize (pf, X_);
for (size_t i = 0; i < pfs.size(); i++) { for (size_t i = 0; i < pfs.size(); i++) {
unsigned condCount = pfs[i]->constr()->getConditionalCount (X_); unsigned condCount = pfs[i]->constr()->getConditionalCount (X_);
bool cartProduct = pfs[i]->constr()->isCartesianProduct ( bool cartProduct = pfs[i]->constr()->isCartesianProduct (
@ -419,7 +420,7 @@ CountingOperator::toString (void)
ss << "count convert " << X_ << " in " ; ss << "count convert " << X_ << " in " ;
ss << (*pfIter_)->getLabel(); ss << (*pfIter_)->getLabel();
ss << " [cost=" << std::exp (getLogCost()) << "]" << endl; ss << " [cost=" << std::exp (getLogCost()) << "]" << endl;
Parfactors pfs = LiftedVe::countNormalize (*pfIter_, X_); Parfactors pfs = LiftedOperations::countNormalize (*pfIter_, X_);
if ((*pfIter_)->constr()->isCountNormalized (X_) == false) { if ((*pfIter_)->constr()->isCountNormalized (X_) == false) {
for (size_t i = 0; i < pfs.size(); i++) { for (size_t i = 0; i < pfs.size(); i++) {
ss << " º " << pfs[i]->getLabel() << endl; ss << " º " << pfs[i]->getLabel() << endl;
@ -655,102 +656,11 @@ LiftedVe::printSolverFlags (void) const
void
LiftedVe::absorveEvidence (
ParfactorList& pfList,
ObservedFormulas& obsFormulas)
{
for (size_t i = 0; i < obsFormulas.size(); i++) {
Parfactors newPfs;
ParfactorList::iterator it = pfList.begin();
while (it != pfList.end()) {
Parfactor* pf = *it;
it = pfList.remove (it);
Parfactors absorvedPfs = absorve (obsFormulas[i], pf);
if (absorvedPfs.empty() == false) {
if (absorvedPfs.size() == 1 && absorvedPfs[0] == 0) {
// just remove pf;
} else {
Util::addToVector (newPfs, absorvedPfs);
}
delete pf;
} else {
it = pfList.insertShattered (it, pf);
++ it;
}
}
pfList.add (newPfs);
}
if (Globals::verbosity > 2 && obsFormulas.empty() == false) {
Util::printAsteriskLine();
cout << "AFTER EVIDENCE ABSORVED" << endl;
for (size_t i = 0; i < obsFormulas.size(); i++) {
cout << " -> " << obsFormulas[i] << endl;
}
Util::printAsteriskLine();
pfList.print();
}
}
Parfactors
LiftedVe::countNormalize (
Parfactor* g,
const LogVarSet& set)
{
Parfactors normPfs;
if (set.empty()) {
normPfs.push_back (new Parfactor (*g));
} else {
ConstraintTrees normCts = g->constr()->countNormalize (set);
for (size_t i = 0; i < normCts.size(); i++) {
normPfs.push_back (new Parfactor (g, normCts[i]));
}
}
return normPfs;
}
Parfactor
LiftedVe::calcGroundMultiplication (Parfactor pf)
{
LogVarSet lvs = pf.constr()->logVarSet();
lvs -= pf.constr()->singletons();
Parfactors newPfs = {new Parfactor (pf)};
for (size_t i = 0; i < lvs.size(); i++) {
Parfactors pfs = newPfs;
newPfs.clear();
for (size_t j = 0; j < pfs.size(); j++) {
bool countedLv = pfs[j]->countedLogVars().contains (lvs[i]);
if (countedLv) {
pfs[j]->fullExpand (lvs[i]);
newPfs.push_back (pfs[j]);
} else {
ConstraintTrees cts = pfs[j]->constr()->ground (lvs[i]);
for (size_t k = 0; k < cts.size(); k++) {
newPfs.push_back (new Parfactor (pfs[j], cts[k]));
}
delete pfs[j];
}
}
}
ParfactorList pfList (newPfs);
Parfactors groundShatteredPfs (pfList.begin(),pfList.end());
for (size_t i = 1; i < groundShatteredPfs.size(); i++) {
groundShatteredPfs[0]->multiply (*groundShatteredPfs[i]);
}
return Parfactor (*groundShatteredPfs[0]);
}
void void
LiftedVe::runSolver (const Grounds& query) LiftedVe::runSolver (const Grounds& query)
{ {
largestCost_ = std::log (0); largestCost_ = std::log (0);
shatterAgainstQuery (query); LiftedOperations::shatterAgainstQuery (pfList_, query);
runWeakBayesBall (query); runWeakBayesBall (query);
while (true) { while (true) {
if (Globals::verbosity > 2) { if (Globals::verbosity > 2) {
@ -876,118 +786,3 @@ LiftedVe::runWeakBayesBall (const Grounds& query)
} }
} }
void
LiftedVe::shatterAgainstQuery (const Grounds& query)
{
for (size_t i = 0; i < query.size(); i++) {
if (query[i].isAtom()) {
continue;
}
bool found = false;
Parfactors newPfs;
ParfactorList::iterator it = pfList_.begin();
while (it != pfList_.end()) {
if ((*it)->containsGround (query[i])) {
found = true;
std::pair<ConstraintTree*, ConstraintTree*> split;
LogVars queryLvs (
(*it)->constr()->logVars().begin(),
(*it)->constr()->logVars().begin() + query[i].arity());
split = (*it)->constr()->split (query[i].args());
ConstraintTree* commCt = split.first;
ConstraintTree* exclCt = split.second;
newPfs.push_back (new Parfactor (*it, commCt));
if (exclCt->empty() == false) {
newPfs.push_back (new Parfactor (*it, exclCt));
} else {
delete exclCt;
}
it = pfList_.removeAndDelete (it);
} else {
++ it;
}
}
if (found == false) {
cerr << "error: could not find a parfactor with ground " ;
cerr << "`" << query[i] << "'" << endl;
exit (0);
}
pfList_.add (newPfs);
}
if (Globals::verbosity > 2) {
Util::printAsteriskLine();
cout << "SHATTERED AGAINST THE QUERY" << endl;
for (size_t i = 0; i < query.size(); i++) {
cout << " -> " << query[i] << endl;
}
Util::printAsteriskLine();
pfList_.print();
}
}
Parfactors
LiftedVe::absorve (
ObservedFormula& obsFormula,
Parfactor* g)
{
Parfactors absorvedPfs;
const ProbFormulas& formulas = g->arguments();
for (size_t i = 0; i < formulas.size(); i++) {
if (obsFormula.functor() == formulas[i].functor() &&
obsFormula.arity() == formulas[i].arity()) {
if (obsFormula.isAtom()) {
if (formulas.size() > 1) {
g->absorveEvidence (formulas[i], obsFormula.evidence());
} else {
// hack to erase parfactor g
absorvedPfs.push_back (0);
}
break;
}
g->constr()->moveToTop (formulas[i].logVars());
std::pair<ConstraintTree*, ConstraintTree*> res;
res = g->constr()->split (
formulas[i].logVars(),
&(obsFormula.constr()),
obsFormula.constr().logVars());
ConstraintTree* commCt = res.first;
ConstraintTree* exclCt = res.second;
if (commCt->empty() == false) {
if (formulas.size() > 1) {
LogVarSet excl = g->exclusiveLogVars (i);
Parfactor tempPf (g, commCt);
Parfactors countNormPfs = countNormalize (&tempPf, excl);
for (size_t j = 0; j < countNormPfs.size(); j++) {
countNormPfs[j]->absorveEvidence (
formulas[i], obsFormula.evidence());
absorvedPfs.push_back (countNormPfs[j]);
}
} else {
delete commCt;
}
if (exclCt->empty() == false) {
absorvedPfs.push_back (new Parfactor (g, exclCt));
} else {
delete exclCt;
}
if (absorvedPfs.empty()) {
// hack to erase parfactor g
absorvedPfs.push_back (0);
}
break;
} else {
delete commCt;
delete exclCt;
}
}
}
return absorvedPfs;
}

View File

@ -144,10 +144,6 @@ class LiftedVe
static void absorveEvidence ( static void absorveEvidence (
ParfactorList& pfList, ObservedFormulas& obsFormulas); ParfactorList& pfList, ObservedFormulas& obsFormulas);
static Parfactors countNormalize (Parfactor*, const LogVarSet&);
static Parfactor calcGroundMultiplication (Parfactor pf);
private: private:
void runSolver (const Grounds&); void runSolver (const Grounds&);

View File

@ -59,6 +59,7 @@ HEADERS = \
$(srcdir)/LiftedBp.h \ $(srcdir)/LiftedBp.h \
$(srcdir)/LiftedCircuit.h \ $(srcdir)/LiftedCircuit.h \
$(srcdir)/LiftedKc.h \ $(srcdir)/LiftedKc.h \
$(srcdir)/LiftedOperations.h \
$(srcdir)/LiftedUtils.h \ $(srcdir)/LiftedUtils.h \
$(srcdir)/LiftedVe.h \ $(srcdir)/LiftedVe.h \
$(srcdir)/LiftedWCNF.h \ $(srcdir)/LiftedWCNF.h \
@ -87,6 +88,7 @@ CPP_SOURCES = \
$(srcdir)/LiftedBp.cpp \ $(srcdir)/LiftedBp.cpp \
$(srcdir)/LiftedCircuit.cpp \ $(srcdir)/LiftedCircuit.cpp \
$(srcdir)/LiftedKc.cpp \ $(srcdir)/LiftedKc.cpp \
$(srcdir)/LiftedOperations.cpp \
$(srcdir)/LiftedUtils.cpp \ $(srcdir)/LiftedUtils.cpp \
$(srcdir)/LiftedVe.cpp \ $(srcdir)/LiftedVe.cpp \
$(srcdir)/LiftedWCNF.cpp \ $(srcdir)/LiftedWCNF.cpp \
@ -113,6 +115,7 @@ OBJS = \
LiftedBp.o \ LiftedBp.o \
LiftedCircuit.o \ LiftedCircuit.o \
LiftedKc.o \ LiftedKc.o \
LiftedOperations.o \
LiftedUtils.o \ LiftedUtils.o \
LiftedVe.o \ LiftedVe.o \
LiftedWCNF.o \ LiftedWCNF.o \