factor out some lifted operations in a new class
This commit is contained in:
parent
8ab622e0aa
commit
07bcc89a76
@ -9,6 +9,7 @@
|
|||||||
|
|
||||||
#include "ParfactorList.h"
|
#include "ParfactorList.h"
|
||||||
#include "FactorGraph.h"
|
#include "FactorGraph.h"
|
||||||
|
#include "LiftedOperations.h"
|
||||||
#include "LiftedVe.h"
|
#include "LiftedVe.h"
|
||||||
#include "VarElim.h"
|
#include "VarElim.h"
|
||||||
#include "LiftedBp.h"
|
#include "LiftedBp.h"
|
||||||
@ -281,7 +282,7 @@ runLiftedSolver (void)
|
|||||||
YAP_Term taskList = YAP_ARG2;
|
YAP_Term taskList = YAP_ARG2;
|
||||||
vector<Params> results;
|
vector<Params> results;
|
||||||
ParfactorList pfListCopy (*network->first);
|
ParfactorList pfListCopy (*network->first);
|
||||||
LiftedVe::absorveEvidence (pfListCopy, *network->second);
|
LiftedOperations::absorveEvidence (pfListCopy, *network->second);
|
||||||
while (taskList != YAP_TermNil()) {
|
while (taskList != YAP_TermNil()) {
|
||||||
Grounds queryVars;
|
Grounds queryVars;
|
||||||
YAP_Term jointList = YAP_HeadOfTerm (taskList);
|
YAP_Term jointList = YAP_HeadOfTerm (taskList);
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
#include "LiftedBp.h"
|
#include "LiftedBp.h"
|
||||||
#include "WeightedBp.h"
|
#include "WeightedBp.h"
|
||||||
#include "FactorGraph.h"
|
#include "FactorGraph.h"
|
||||||
|
#include "LiftedOperations.h"
|
||||||
#include "LiftedVe.h"
|
#include "LiftedVe.h"
|
||||||
|
|
||||||
|
|
||||||
@ -101,7 +102,7 @@ LiftedBp::iterate (void)
|
|||||||
for (size_t i = 0; i < args.size(); i++) {
|
for (size_t i = 0; i < args.size(); i++) {
|
||||||
LogVarSet lvs = (*it)->logVarSet() - args[i].logVars();
|
LogVarSet lvs = (*it)->logVarSet() - args[i].logVars();
|
||||||
if ((*it)->constr()->isCountNormalized (lvs) == false) {
|
if ((*it)->constr()->isCountNormalized (lvs) == false) {
|
||||||
Parfactors pfs = LiftedVe::countNormalize (*it, lvs);
|
Parfactors pfs = LiftedOperations::countNormalize (*it, lvs);
|
||||||
it = pfList_.removeAndDelete (it);
|
it = pfList_.removeAndDelete (it);
|
||||||
pfList_.add (pfs);
|
pfList_.add (pfs);
|
||||||
return false;
|
return false;
|
||||||
|
@ -39,3 +39,4 @@ class LiftedBp
|
|||||||
};
|
};
|
||||||
|
|
||||||
#endif // HORUS_LIFTEDBP_H
|
#endif // HORUS_LIFTEDBP_H
|
||||||
|
|
||||||
|
211
packages/CLPBN/horus/LiftedOperations.cpp
Normal file
211
packages/CLPBN/horus/LiftedOperations.cpp
Normal file
@ -0,0 +1,211 @@
|
|||||||
|
#include "LiftedOperations.h"
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
LiftedOperations::shatterAgainstQuery (
|
||||||
|
ParfactorList& pfList,
|
||||||
|
const Grounds& query)
|
||||||
|
{
|
||||||
|
for (size_t i = 0; i < query.size(); i++) {
|
||||||
|
if (query[i].isAtom()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
bool found = false;
|
||||||
|
Parfactors newPfs;
|
||||||
|
ParfactorList::iterator it = pfList.begin();
|
||||||
|
while (it != pfList.end()) {
|
||||||
|
if ((*it)->containsGround (query[i])) {
|
||||||
|
found = true;
|
||||||
|
std::pair<ConstraintTree*, ConstraintTree*> split;
|
||||||
|
LogVars queryLvs (
|
||||||
|
(*it)->constr()->logVars().begin(),
|
||||||
|
(*it)->constr()->logVars().begin() + query[i].arity());
|
||||||
|
split = (*it)->constr()->split (query[i].args());
|
||||||
|
ConstraintTree* commCt = split.first;
|
||||||
|
ConstraintTree* exclCt = split.second;
|
||||||
|
newPfs.push_back (new Parfactor (*it, commCt));
|
||||||
|
if (exclCt->empty() == false) {
|
||||||
|
newPfs.push_back (new Parfactor (*it, exclCt));
|
||||||
|
} else {
|
||||||
|
delete exclCt;
|
||||||
|
}
|
||||||
|
it = pfList.removeAndDelete (it);
|
||||||
|
} else {
|
||||||
|
++ it;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (found == false) {
|
||||||
|
cerr << "error: could not find a parfactor with ground " ;
|
||||||
|
cerr << "`" << query[i] << "'" << endl;
|
||||||
|
exit (0);
|
||||||
|
}
|
||||||
|
pfList.add (newPfs);
|
||||||
|
}
|
||||||
|
if (Globals::verbosity > 2) {
|
||||||
|
Util::printAsteriskLine();
|
||||||
|
cout << "SHATTERED AGAINST THE QUERY" << endl;
|
||||||
|
for (size_t i = 0; i < query.size(); i++) {
|
||||||
|
cout << " -> " << query[i] << endl;
|
||||||
|
}
|
||||||
|
Util::printAsteriskLine();
|
||||||
|
pfList.print();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
LiftedOperations::absorveEvidence (
|
||||||
|
ParfactorList& pfList,
|
||||||
|
ObservedFormulas& obsFormulas)
|
||||||
|
{
|
||||||
|
for (size_t i = 0; i < obsFormulas.size(); i++) {
|
||||||
|
Parfactors newPfs;
|
||||||
|
ParfactorList::iterator it = pfList.begin();
|
||||||
|
while (it != pfList.end()) {
|
||||||
|
Parfactor* pf = *it;
|
||||||
|
it = pfList.remove (it);
|
||||||
|
Parfactors absorvedPfs = absorve (obsFormulas[i], pf);
|
||||||
|
if (absorvedPfs.empty() == false) {
|
||||||
|
if (absorvedPfs.size() == 1 && absorvedPfs[0] == 0) {
|
||||||
|
// just remove pf;
|
||||||
|
} else {
|
||||||
|
Util::addToVector (newPfs, absorvedPfs);
|
||||||
|
}
|
||||||
|
delete pf;
|
||||||
|
} else {
|
||||||
|
it = pfList.insertShattered (it, pf);
|
||||||
|
++ it;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pfList.add (newPfs);
|
||||||
|
}
|
||||||
|
if (Globals::verbosity > 2 && obsFormulas.empty() == false) {
|
||||||
|
Util::printAsteriskLine();
|
||||||
|
cout << "AFTER EVIDENCE ABSORVED" << endl;
|
||||||
|
for (size_t i = 0; i < obsFormulas.size(); i++) {
|
||||||
|
cout << " -> " << obsFormulas[i] << endl;
|
||||||
|
}
|
||||||
|
Util::printAsteriskLine();
|
||||||
|
pfList.print();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Parfactors
|
||||||
|
LiftedOperations::countNormalize (
|
||||||
|
Parfactor* g,
|
||||||
|
const LogVarSet& set)
|
||||||
|
{
|
||||||
|
Parfactors normPfs;
|
||||||
|
if (set.empty()) {
|
||||||
|
normPfs.push_back (new Parfactor (*g));
|
||||||
|
} else {
|
||||||
|
ConstraintTrees normCts = g->constr()->countNormalize (set);
|
||||||
|
for (size_t i = 0; i < normCts.size(); i++) {
|
||||||
|
normPfs.push_back (new Parfactor (g, normCts[i]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return normPfs;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Parfactor
|
||||||
|
LiftedOperations::calcGroundMultiplication (Parfactor pf)
|
||||||
|
{
|
||||||
|
LogVarSet lvs = pf.constr()->logVarSet();
|
||||||
|
lvs -= pf.constr()->singletons();
|
||||||
|
Parfactors newPfs = {new Parfactor (pf)};
|
||||||
|
for (size_t i = 0; i < lvs.size(); i++) {
|
||||||
|
Parfactors pfs = newPfs;
|
||||||
|
newPfs.clear();
|
||||||
|
for (size_t j = 0; j < pfs.size(); j++) {
|
||||||
|
bool countedLv = pfs[j]->countedLogVars().contains (lvs[i]);
|
||||||
|
if (countedLv) {
|
||||||
|
pfs[j]->fullExpand (lvs[i]);
|
||||||
|
newPfs.push_back (pfs[j]);
|
||||||
|
} else {
|
||||||
|
ConstraintTrees cts = pfs[j]->constr()->ground (lvs[i]);
|
||||||
|
for (size_t k = 0; k < cts.size(); k++) {
|
||||||
|
newPfs.push_back (new Parfactor (pfs[j], cts[k]));
|
||||||
|
}
|
||||||
|
delete pfs[j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ParfactorList pfList (newPfs);
|
||||||
|
Parfactors groundShatteredPfs (pfList.begin(),pfList.end());
|
||||||
|
for (size_t i = 1; i < groundShatteredPfs.size(); i++) {
|
||||||
|
groundShatteredPfs[0]->multiply (*groundShatteredPfs[i]);
|
||||||
|
}
|
||||||
|
return Parfactor (*groundShatteredPfs[0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Parfactors
|
||||||
|
LiftedOperations::absorve (
|
||||||
|
ObservedFormula& obsFormula,
|
||||||
|
Parfactor* g)
|
||||||
|
{
|
||||||
|
Parfactors absorvedPfs;
|
||||||
|
const ProbFormulas& formulas = g->arguments();
|
||||||
|
for (size_t i = 0; i < formulas.size(); i++) {
|
||||||
|
if (obsFormula.functor() == formulas[i].functor() &&
|
||||||
|
obsFormula.arity() == formulas[i].arity()) {
|
||||||
|
|
||||||
|
if (obsFormula.isAtom()) {
|
||||||
|
if (formulas.size() > 1) {
|
||||||
|
g->absorveEvidence (formulas[i], obsFormula.evidence());
|
||||||
|
} else {
|
||||||
|
// hack to erase parfactor g
|
||||||
|
absorvedPfs.push_back (0);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
g->constr()->moveToTop (formulas[i].logVars());
|
||||||
|
std::pair<ConstraintTree*, ConstraintTree*> res;
|
||||||
|
res = g->constr()->split (
|
||||||
|
formulas[i].logVars(),
|
||||||
|
&(obsFormula.constr()),
|
||||||
|
obsFormula.constr().logVars());
|
||||||
|
ConstraintTree* commCt = res.first;
|
||||||
|
ConstraintTree* exclCt = res.second;
|
||||||
|
|
||||||
|
if (commCt->empty() == false) {
|
||||||
|
if (formulas.size() > 1) {
|
||||||
|
LogVarSet excl = g->exclusiveLogVars (i);
|
||||||
|
Parfactor tempPf (g, commCt);
|
||||||
|
Parfactors countNormPfs = LiftedOperations::countNormalize (
|
||||||
|
&tempPf, excl);
|
||||||
|
for (size_t j = 0; j < countNormPfs.size(); j++) {
|
||||||
|
countNormPfs[j]->absorveEvidence (
|
||||||
|
formulas[i], obsFormula.evidence());
|
||||||
|
absorvedPfs.push_back (countNormPfs[j]);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
delete commCt;
|
||||||
|
}
|
||||||
|
if (exclCt->empty() == false) {
|
||||||
|
absorvedPfs.push_back (new Parfactor (g, exclCt));
|
||||||
|
} else {
|
||||||
|
delete exclCt;
|
||||||
|
}
|
||||||
|
if (absorvedPfs.empty()) {
|
||||||
|
// hack to erase parfactor g
|
||||||
|
absorvedPfs.push_back (0);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
} else {
|
||||||
|
delete commCt;
|
||||||
|
delete exclCt;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return absorvedPfs;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
24
packages/CLPBN/horus/LiftedOperations.h
Normal file
24
packages/CLPBN/horus/LiftedOperations.h
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
#ifndef HORUS_LIFTEDOPERATIONS_H
|
||||||
|
#define HORUS_LIFTEDOPERATIONS_H
|
||||||
|
|
||||||
|
#include "ParfactorList.h"
|
||||||
|
|
||||||
|
class LiftedOperations
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
static void shatterAgainstQuery (
|
||||||
|
ParfactorList& pfList, const Grounds& query);
|
||||||
|
|
||||||
|
static void absorveEvidence (
|
||||||
|
ParfactorList& pfList, ObservedFormulas& obsFormulas);
|
||||||
|
|
||||||
|
static Parfactors countNormalize (Parfactor*, const LogVarSet&);
|
||||||
|
|
||||||
|
static Parfactor calcGroundMultiplication (Parfactor pf);
|
||||||
|
|
||||||
|
private:
|
||||||
|
|
||||||
|
static Parfactors absorve (ObservedFormula&, Parfactor*);
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif // HORUS_LIFTEDOPERATIONS_H
|
@ -2,6 +2,7 @@
|
|||||||
#include <set>
|
#include <set>
|
||||||
|
|
||||||
#include "LiftedVe.h"
|
#include "LiftedVe.h"
|
||||||
|
#include "LiftedOperations.h"
|
||||||
#include "Histogram.h"
|
#include "Histogram.h"
|
||||||
#include "Util.h"
|
#include "Util.h"
|
||||||
|
|
||||||
@ -221,7 +222,7 @@ SumOutOperator::apply (void)
|
|||||||
product->sumOutIndex (fIdx);
|
product->sumOutIndex (fIdx);
|
||||||
pfList_.addShattered (product);
|
pfList_.addShattered (product);
|
||||||
} else {
|
} else {
|
||||||
Parfactors pfs = LiftedVe::countNormalize (product, excl);
|
Parfactors pfs = LiftedOperations::countNormalize (product, excl);
|
||||||
for (size_t i = 0; i < pfs.size(); i++) {
|
for (size_t i = 0; i < pfs.size(); i++) {
|
||||||
pfs[i]->sumOutIndex (fIdx);
|
pfs[i]->sumOutIndex (fIdx);
|
||||||
pfList_.add (pfs[i]);
|
pfList_.add (pfs[i]);
|
||||||
@ -375,7 +376,7 @@ CountingOperator::apply (void)
|
|||||||
} else {
|
} else {
|
||||||
Parfactor* pf = *pfIter_;
|
Parfactor* pf = *pfIter_;
|
||||||
pfList_.remove (pfIter_);
|
pfList_.remove (pfIter_);
|
||||||
Parfactors pfs = LiftedVe::countNormalize (pf, X_);
|
Parfactors pfs = LiftedOperations::countNormalize (pf, X_);
|
||||||
for (size_t i = 0; i < pfs.size(); i++) {
|
for (size_t i = 0; i < pfs.size(); i++) {
|
||||||
unsigned condCount = pfs[i]->constr()->getConditionalCount (X_);
|
unsigned condCount = pfs[i]->constr()->getConditionalCount (X_);
|
||||||
bool cartProduct = pfs[i]->constr()->isCartesianProduct (
|
bool cartProduct = pfs[i]->constr()->isCartesianProduct (
|
||||||
@ -419,7 +420,7 @@ CountingOperator::toString (void)
|
|||||||
ss << "count convert " << X_ << " in " ;
|
ss << "count convert " << X_ << " in " ;
|
||||||
ss << (*pfIter_)->getLabel();
|
ss << (*pfIter_)->getLabel();
|
||||||
ss << " [cost=" << std::exp (getLogCost()) << "]" << endl;
|
ss << " [cost=" << std::exp (getLogCost()) << "]" << endl;
|
||||||
Parfactors pfs = LiftedVe::countNormalize (*pfIter_, X_);
|
Parfactors pfs = LiftedOperations::countNormalize (*pfIter_, X_);
|
||||||
if ((*pfIter_)->constr()->isCountNormalized (X_) == false) {
|
if ((*pfIter_)->constr()->isCountNormalized (X_) == false) {
|
||||||
for (size_t i = 0; i < pfs.size(); i++) {
|
for (size_t i = 0; i < pfs.size(); i++) {
|
||||||
ss << " º " << pfs[i]->getLabel() << endl;
|
ss << " º " << pfs[i]->getLabel() << endl;
|
||||||
@ -655,102 +656,11 @@ LiftedVe::printSolverFlags (void) const
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
LiftedVe::absorveEvidence (
|
|
||||||
ParfactorList& pfList,
|
|
||||||
ObservedFormulas& obsFormulas)
|
|
||||||
{
|
|
||||||
for (size_t i = 0; i < obsFormulas.size(); i++) {
|
|
||||||
Parfactors newPfs;
|
|
||||||
ParfactorList::iterator it = pfList.begin();
|
|
||||||
while (it != pfList.end()) {
|
|
||||||
Parfactor* pf = *it;
|
|
||||||
it = pfList.remove (it);
|
|
||||||
Parfactors absorvedPfs = absorve (obsFormulas[i], pf);
|
|
||||||
if (absorvedPfs.empty() == false) {
|
|
||||||
if (absorvedPfs.size() == 1 && absorvedPfs[0] == 0) {
|
|
||||||
// just remove pf;
|
|
||||||
} else {
|
|
||||||
Util::addToVector (newPfs, absorvedPfs);
|
|
||||||
}
|
|
||||||
delete pf;
|
|
||||||
} else {
|
|
||||||
it = pfList.insertShattered (it, pf);
|
|
||||||
++ it;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
pfList.add (newPfs);
|
|
||||||
}
|
|
||||||
if (Globals::verbosity > 2 && obsFormulas.empty() == false) {
|
|
||||||
Util::printAsteriskLine();
|
|
||||||
cout << "AFTER EVIDENCE ABSORVED" << endl;
|
|
||||||
for (size_t i = 0; i < obsFormulas.size(); i++) {
|
|
||||||
cout << " -> " << obsFormulas[i] << endl;
|
|
||||||
}
|
|
||||||
Util::printAsteriskLine();
|
|
||||||
pfList.print();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Parfactors
|
|
||||||
LiftedVe::countNormalize (
|
|
||||||
Parfactor* g,
|
|
||||||
const LogVarSet& set)
|
|
||||||
{
|
|
||||||
Parfactors normPfs;
|
|
||||||
if (set.empty()) {
|
|
||||||
normPfs.push_back (new Parfactor (*g));
|
|
||||||
} else {
|
|
||||||
ConstraintTrees normCts = g->constr()->countNormalize (set);
|
|
||||||
for (size_t i = 0; i < normCts.size(); i++) {
|
|
||||||
normPfs.push_back (new Parfactor (g, normCts[i]));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return normPfs;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Parfactor
|
|
||||||
LiftedVe::calcGroundMultiplication (Parfactor pf)
|
|
||||||
{
|
|
||||||
LogVarSet lvs = pf.constr()->logVarSet();
|
|
||||||
lvs -= pf.constr()->singletons();
|
|
||||||
Parfactors newPfs = {new Parfactor (pf)};
|
|
||||||
for (size_t i = 0; i < lvs.size(); i++) {
|
|
||||||
Parfactors pfs = newPfs;
|
|
||||||
newPfs.clear();
|
|
||||||
for (size_t j = 0; j < pfs.size(); j++) {
|
|
||||||
bool countedLv = pfs[j]->countedLogVars().contains (lvs[i]);
|
|
||||||
if (countedLv) {
|
|
||||||
pfs[j]->fullExpand (lvs[i]);
|
|
||||||
newPfs.push_back (pfs[j]);
|
|
||||||
} else {
|
|
||||||
ConstraintTrees cts = pfs[j]->constr()->ground (lvs[i]);
|
|
||||||
for (size_t k = 0; k < cts.size(); k++) {
|
|
||||||
newPfs.push_back (new Parfactor (pfs[j], cts[k]));
|
|
||||||
}
|
|
||||||
delete pfs[j];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ParfactorList pfList (newPfs);
|
|
||||||
Parfactors groundShatteredPfs (pfList.begin(),pfList.end());
|
|
||||||
for (size_t i = 1; i < groundShatteredPfs.size(); i++) {
|
|
||||||
groundShatteredPfs[0]->multiply (*groundShatteredPfs[i]);
|
|
||||||
}
|
|
||||||
return Parfactor (*groundShatteredPfs[0]);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
LiftedVe::runSolver (const Grounds& query)
|
LiftedVe::runSolver (const Grounds& query)
|
||||||
{
|
{
|
||||||
largestCost_ = std::log (0);
|
largestCost_ = std::log (0);
|
||||||
shatterAgainstQuery (query);
|
LiftedOperations::shatterAgainstQuery (pfList_, query);
|
||||||
runWeakBayesBall (query);
|
runWeakBayesBall (query);
|
||||||
while (true) {
|
while (true) {
|
||||||
if (Globals::verbosity > 2) {
|
if (Globals::verbosity > 2) {
|
||||||
@ -876,118 +786,3 @@ LiftedVe::runWeakBayesBall (const Grounds& query)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
LiftedVe::shatterAgainstQuery (const Grounds& query)
|
|
||||||
{
|
|
||||||
for (size_t i = 0; i < query.size(); i++) {
|
|
||||||
if (query[i].isAtom()) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
bool found = false;
|
|
||||||
Parfactors newPfs;
|
|
||||||
ParfactorList::iterator it = pfList_.begin();
|
|
||||||
while (it != pfList_.end()) {
|
|
||||||
if ((*it)->containsGround (query[i])) {
|
|
||||||
found = true;
|
|
||||||
std::pair<ConstraintTree*, ConstraintTree*> split;
|
|
||||||
LogVars queryLvs (
|
|
||||||
(*it)->constr()->logVars().begin(),
|
|
||||||
(*it)->constr()->logVars().begin() + query[i].arity());
|
|
||||||
split = (*it)->constr()->split (query[i].args());
|
|
||||||
ConstraintTree* commCt = split.first;
|
|
||||||
ConstraintTree* exclCt = split.second;
|
|
||||||
newPfs.push_back (new Parfactor (*it, commCt));
|
|
||||||
if (exclCt->empty() == false) {
|
|
||||||
newPfs.push_back (new Parfactor (*it, exclCt));
|
|
||||||
} else {
|
|
||||||
delete exclCt;
|
|
||||||
}
|
|
||||||
it = pfList_.removeAndDelete (it);
|
|
||||||
} else {
|
|
||||||
++ it;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (found == false) {
|
|
||||||
cerr << "error: could not find a parfactor with ground " ;
|
|
||||||
cerr << "`" << query[i] << "'" << endl;
|
|
||||||
exit (0);
|
|
||||||
}
|
|
||||||
pfList_.add (newPfs);
|
|
||||||
}
|
|
||||||
if (Globals::verbosity > 2) {
|
|
||||||
Util::printAsteriskLine();
|
|
||||||
cout << "SHATTERED AGAINST THE QUERY" << endl;
|
|
||||||
for (size_t i = 0; i < query.size(); i++) {
|
|
||||||
cout << " -> " << query[i] << endl;
|
|
||||||
}
|
|
||||||
Util::printAsteriskLine();
|
|
||||||
pfList_.print();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Parfactors
|
|
||||||
LiftedVe::absorve (
|
|
||||||
ObservedFormula& obsFormula,
|
|
||||||
Parfactor* g)
|
|
||||||
{
|
|
||||||
Parfactors absorvedPfs;
|
|
||||||
const ProbFormulas& formulas = g->arguments();
|
|
||||||
for (size_t i = 0; i < formulas.size(); i++) {
|
|
||||||
if (obsFormula.functor() == formulas[i].functor() &&
|
|
||||||
obsFormula.arity() == formulas[i].arity()) {
|
|
||||||
|
|
||||||
if (obsFormula.isAtom()) {
|
|
||||||
if (formulas.size() > 1) {
|
|
||||||
g->absorveEvidence (formulas[i], obsFormula.evidence());
|
|
||||||
} else {
|
|
||||||
// hack to erase parfactor g
|
|
||||||
absorvedPfs.push_back (0);
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
g->constr()->moveToTop (formulas[i].logVars());
|
|
||||||
std::pair<ConstraintTree*, ConstraintTree*> res;
|
|
||||||
res = g->constr()->split (
|
|
||||||
formulas[i].logVars(),
|
|
||||||
&(obsFormula.constr()),
|
|
||||||
obsFormula.constr().logVars());
|
|
||||||
ConstraintTree* commCt = res.first;
|
|
||||||
ConstraintTree* exclCt = res.second;
|
|
||||||
|
|
||||||
if (commCt->empty() == false) {
|
|
||||||
if (formulas.size() > 1) {
|
|
||||||
LogVarSet excl = g->exclusiveLogVars (i);
|
|
||||||
Parfactor tempPf (g, commCt);
|
|
||||||
Parfactors countNormPfs = countNormalize (&tempPf, excl);
|
|
||||||
for (size_t j = 0; j < countNormPfs.size(); j++) {
|
|
||||||
countNormPfs[j]->absorveEvidence (
|
|
||||||
formulas[i], obsFormula.evidence());
|
|
||||||
absorvedPfs.push_back (countNormPfs[j]);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
delete commCt;
|
|
||||||
}
|
|
||||||
if (exclCt->empty() == false) {
|
|
||||||
absorvedPfs.push_back (new Parfactor (g, exclCt));
|
|
||||||
} else {
|
|
||||||
delete exclCt;
|
|
||||||
}
|
|
||||||
if (absorvedPfs.empty()) {
|
|
||||||
// hack to erase parfactor g
|
|
||||||
absorvedPfs.push_back (0);
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
} else {
|
|
||||||
delete commCt;
|
|
||||||
delete exclCt;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return absorvedPfs;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
@ -144,10 +144,6 @@ class LiftedVe
|
|||||||
static void absorveEvidence (
|
static void absorveEvidence (
|
||||||
ParfactorList& pfList, ObservedFormulas& obsFormulas);
|
ParfactorList& pfList, ObservedFormulas& obsFormulas);
|
||||||
|
|
||||||
static Parfactors countNormalize (Parfactor*, const LogVarSet&);
|
|
||||||
|
|
||||||
static Parfactor calcGroundMultiplication (Parfactor pf);
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void runSolver (const Grounds&);
|
void runSolver (const Grounds&);
|
||||||
|
|
||||||
|
@ -59,6 +59,7 @@ HEADERS = \
|
|||||||
$(srcdir)/LiftedBp.h \
|
$(srcdir)/LiftedBp.h \
|
||||||
$(srcdir)/LiftedCircuit.h \
|
$(srcdir)/LiftedCircuit.h \
|
||||||
$(srcdir)/LiftedKc.h \
|
$(srcdir)/LiftedKc.h \
|
||||||
|
$(srcdir)/LiftedOperations.h \
|
||||||
$(srcdir)/LiftedUtils.h \
|
$(srcdir)/LiftedUtils.h \
|
||||||
$(srcdir)/LiftedVe.h \
|
$(srcdir)/LiftedVe.h \
|
||||||
$(srcdir)/LiftedWCNF.h \
|
$(srcdir)/LiftedWCNF.h \
|
||||||
@ -87,6 +88,7 @@ CPP_SOURCES = \
|
|||||||
$(srcdir)/LiftedBp.cpp \
|
$(srcdir)/LiftedBp.cpp \
|
||||||
$(srcdir)/LiftedCircuit.cpp \
|
$(srcdir)/LiftedCircuit.cpp \
|
||||||
$(srcdir)/LiftedKc.cpp \
|
$(srcdir)/LiftedKc.cpp \
|
||||||
|
$(srcdir)/LiftedOperations.cpp \
|
||||||
$(srcdir)/LiftedUtils.cpp \
|
$(srcdir)/LiftedUtils.cpp \
|
||||||
$(srcdir)/LiftedVe.cpp \
|
$(srcdir)/LiftedVe.cpp \
|
||||||
$(srcdir)/LiftedWCNF.cpp \
|
$(srcdir)/LiftedWCNF.cpp \
|
||||||
@ -113,6 +115,7 @@ OBJS = \
|
|||||||
LiftedBp.o \
|
LiftedBp.o \
|
||||||
LiftedCircuit.o \
|
LiftedCircuit.o \
|
||||||
LiftedKc.o \
|
LiftedKc.o \
|
||||||
|
LiftedOperations.o \
|
||||||
LiftedUtils.o \
|
LiftedUtils.o \
|
||||||
LiftedVe.o \
|
LiftedVe.o \
|
||||||
LiftedWCNF.o \
|
LiftedWCNF.o \
|
||||||
|
Reference in New Issue
Block a user