move more code around
This commit is contained in:
parent
07bcc89a76
commit
4776817603
@ -2,7 +2,6 @@
|
|||||||
#include "WeightedBp.h"
|
#include "WeightedBp.h"
|
||||||
#include "FactorGraph.h"
|
#include "FactorGraph.h"
|
||||||
#include "LiftedOperations.h"
|
#include "LiftedOperations.h"
|
||||||
#include "LiftedVe.h"
|
|
||||||
|
|
||||||
|
|
||||||
LiftedBp::LiftedBp (const ParfactorList& pfList)
|
LiftedBp::LiftedBp (const ParfactorList& pfList)
|
||||||
@ -190,12 +189,12 @@ LiftedBp::rangeOfGround (const Ground& gr)
|
|||||||
Params
|
Params
|
||||||
LiftedBp::getJointByConditioning (
|
LiftedBp::getJointByConditioning (
|
||||||
const ParfactorList& pfList,
|
const ParfactorList& pfList,
|
||||||
const Grounds& grounds)
|
const Grounds& query)
|
||||||
{
|
{
|
||||||
LiftedBp solver (pfList);
|
LiftedBp solver (pfList);
|
||||||
Params prevBeliefs = solver.solveQuery ({grounds[0]});
|
Params prevBeliefs = solver.solveQuery ({query[0]});
|
||||||
Grounds obsGrounds = {grounds[0]};
|
Grounds obsGrounds = {query[0]};
|
||||||
for (size_t i = 1; i < grounds.size(); i++) {
|
for (size_t i = 1; i < query.size(); i++) {
|
||||||
Params newBeliefs;
|
Params newBeliefs;
|
||||||
vector<ObservedFormula> obsFs;
|
vector<ObservedFormula> obsFs;
|
||||||
Ranges obsRanges;
|
Ranges obsRanges;
|
||||||
@ -210,16 +209,16 @@ LiftedBp::getJointByConditioning (
|
|||||||
obsFs[j].setEvidence (indexer[j]);
|
obsFs[j].setEvidence (indexer[j]);
|
||||||
}
|
}
|
||||||
ParfactorList tempPfList (pfList);
|
ParfactorList tempPfList (pfList);
|
||||||
LiftedVe::absorveEvidence (tempPfList, obsFs);
|
LiftedOperations::absorveEvidence (tempPfList, obsFs);
|
||||||
LiftedBp solver (tempPfList);
|
LiftedBp solver (tempPfList);
|
||||||
Params beliefs = solver.solveQuery ({grounds[i]});
|
Params beliefs = solver.solveQuery ({query[i]});
|
||||||
for (size_t k = 0; k < beliefs.size(); k++) {
|
for (size_t k = 0; k < beliefs.size(); k++) {
|
||||||
newBeliefs.push_back (beliefs[k]);
|
newBeliefs.push_back (beliefs[k]);
|
||||||
}
|
}
|
||||||
++ indexer;
|
++ indexer;
|
||||||
}
|
}
|
||||||
int count = -1;
|
int count = -1;
|
||||||
unsigned range = rangeOfGround (grounds[i]);
|
unsigned range = rangeOfGround (query[i]);
|
||||||
for (size_t j = 0; j < newBeliefs.size(); j++) {
|
for (size_t j = 0; j < newBeliefs.size(); j++) {
|
||||||
if (j % range == 0) {
|
if (j % range == 0) {
|
||||||
count ++;
|
count ++;
|
||||||
@ -227,7 +226,7 @@ LiftedBp::getJointByConditioning (
|
|||||||
newBeliefs[j] *= prevBeliefs[count];
|
newBeliefs[j] *= prevBeliefs[count];
|
||||||
}
|
}
|
||||||
prevBeliefs = newBeliefs;
|
prevBeliefs = newBeliefs;
|
||||||
obsGrounds.push_back (grounds[i]);
|
obsGrounds.push_back (query[i]);
|
||||||
}
|
}
|
||||||
return prevBeliefs;
|
return prevBeliefs;
|
||||||
}
|
}
|
||||||
|
@ -54,6 +54,67 @@ LiftedOperations::shatterAgainstQuery (
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
LiftedOperations::runWeakBayesBall (
|
||||||
|
ParfactorList& pfList,
|
||||||
|
const Grounds& query)
|
||||||
|
{
|
||||||
|
queue<PrvGroup> todo; // groups to process
|
||||||
|
set<PrvGroup> done; // processed or in queue
|
||||||
|
for (size_t i = 0; i < query.size(); i++) {
|
||||||
|
ParfactorList::iterator it = pfList.begin();
|
||||||
|
while (it != pfList.end()) {
|
||||||
|
PrvGroup group = (*it)->findGroup (query[i]);
|
||||||
|
if (group != numeric_limits<PrvGroup>::max()) {
|
||||||
|
todo.push (group);
|
||||||
|
done.insert (group);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
++ it;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
set<Parfactor*> requiredPfs;
|
||||||
|
while (todo.empty() == false) {
|
||||||
|
PrvGroup group = todo.front();
|
||||||
|
ParfactorList::iterator it = pfList.begin();
|
||||||
|
while (it != pfList.end()) {
|
||||||
|
if (Util::contains (requiredPfs, *it) == false &&
|
||||||
|
(*it)->containsGroup (group)) {
|
||||||
|
vector<PrvGroup> groups = (*it)->getAllGroups();
|
||||||
|
for (size_t i = 0; i < groups.size(); i++) {
|
||||||
|
if (Util::contains (done, groups[i]) == false) {
|
||||||
|
todo.push (groups[i]);
|
||||||
|
done.insert (groups[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
requiredPfs.insert (*it);
|
||||||
|
}
|
||||||
|
++ it;
|
||||||
|
}
|
||||||
|
todo.pop();
|
||||||
|
}
|
||||||
|
|
||||||
|
ParfactorList::iterator it = pfList.begin();
|
||||||
|
bool foundNotRequired = false;
|
||||||
|
while (it != pfList.end()) {
|
||||||
|
if (Util::contains (requiredPfs, *it) == false) {
|
||||||
|
if (Globals::verbosity > 2) {
|
||||||
|
if (foundNotRequired == false) {
|
||||||
|
Util::printHeader ("PARFACTORS TO DISCARD");
|
||||||
|
foundNotRequired = true;
|
||||||
|
}
|
||||||
|
(*it)->print();
|
||||||
|
}
|
||||||
|
it = pfList.removeAndDelete (it);
|
||||||
|
} else {
|
||||||
|
++ it;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
LiftedOperations::absorveEvidence (
|
LiftedOperations::absorveEvidence (
|
||||||
ParfactorList& pfList,
|
ParfactorList& pfList,
|
||||||
|
@ -9,6 +9,9 @@ class LiftedOperations
|
|||||||
static void shatterAgainstQuery (
|
static void shatterAgainstQuery (
|
||||||
ParfactorList& pfList, const Grounds& query);
|
ParfactorList& pfList, const Grounds& query);
|
||||||
|
|
||||||
|
static void runWeakBayesBall (
|
||||||
|
ParfactorList& pfList, const Grounds&);
|
||||||
|
|
||||||
static void absorveEvidence (
|
static void absorveEvidence (
|
||||||
ParfactorList& pfList, ObservedFormulas& obsFormulas);
|
ParfactorList& pfList, ObservedFormulas& obsFormulas);
|
||||||
|
|
||||||
@ -17,7 +20,6 @@ class LiftedOperations
|
|||||||
static Parfactor calcGroundMultiplication (Parfactor pf);
|
static Parfactor calcGroundMultiplication (Parfactor pf);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
||||||
static Parfactors absorve (ObservedFormula&, Parfactor*);
|
static Parfactors absorve (ObservedFormula&, Parfactor*);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -661,7 +661,7 @@ LiftedVe::runSolver (const Grounds& query)
|
|||||||
{
|
{
|
||||||
largestCost_ = std::log (0);
|
largestCost_ = std::log (0);
|
||||||
LiftedOperations::shatterAgainstQuery (pfList_, query);
|
LiftedOperations::shatterAgainstQuery (pfList_, query);
|
||||||
runWeakBayesBall (query);
|
LiftedOperations::runWeakBayesBall (pfList_, query);
|
||||||
while (true) {
|
while (true) {
|
||||||
if (Globals::verbosity > 2) {
|
if (Globals::verbosity > 2) {
|
||||||
Util::printDashedLine();
|
Util::printDashedLine();
|
||||||
@ -727,62 +727,3 @@ LiftedVe::getBestOperation (const Grounds& query)
|
|||||||
return bestOp;
|
return bestOp;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
LiftedVe::runWeakBayesBall (const Grounds& query)
|
|
||||||
{
|
|
||||||
queue<PrvGroup> todo; // groups to process
|
|
||||||
set<PrvGroup> done; // processed or in queue
|
|
||||||
for (size_t i = 0; i < query.size(); i++) {
|
|
||||||
ParfactorList::iterator it = pfList_.begin();
|
|
||||||
while (it != pfList_.end()) {
|
|
||||||
PrvGroup group = (*it)->findGroup (query[i]);
|
|
||||||
if (group != numeric_limits<PrvGroup>::max()) {
|
|
||||||
todo.push (group);
|
|
||||||
done.insert (group);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
++ it;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
set<Parfactor*> requiredPfs;
|
|
||||||
while (todo.empty() == false) {
|
|
||||||
PrvGroup group = todo.front();
|
|
||||||
ParfactorList::iterator it = pfList_.begin();
|
|
||||||
while (it != pfList_.end()) {
|
|
||||||
if (Util::contains (requiredPfs, *it) == false &&
|
|
||||||
(*it)->containsGroup (group)) {
|
|
||||||
vector<PrvGroup> groups = (*it)->getAllGroups();
|
|
||||||
for (size_t i = 0; i < groups.size(); i++) {
|
|
||||||
if (Util::contains (done, groups[i]) == false) {
|
|
||||||
todo.push (groups[i]);
|
|
||||||
done.insert (groups[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
requiredPfs.insert (*it);
|
|
||||||
}
|
|
||||||
++ it;
|
|
||||||
}
|
|
||||||
todo.pop();
|
|
||||||
}
|
|
||||||
|
|
||||||
ParfactorList::iterator it = pfList_.begin();
|
|
||||||
bool foundNotRequired = false;
|
|
||||||
while (it != pfList_.end()) {
|
|
||||||
if (Util::contains (requiredPfs, *it) == false) {
|
|
||||||
if (Globals::verbosity > 2) {
|
|
||||||
if (foundNotRequired == false) {
|
|
||||||
Util::printHeader ("PARFACTORS TO DISCARD");
|
|
||||||
foundNotRequired = true;
|
|
||||||
}
|
|
||||||
(*it)->print();
|
|
||||||
}
|
|
||||||
it = pfList_.removeAndDelete (it);
|
|
||||||
} else {
|
|
||||||
++ it;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
@ -141,22 +141,12 @@ class LiftedVe
|
|||||||
|
|
||||||
void printSolverFlags (void) const;
|
void printSolverFlags (void) const;
|
||||||
|
|
||||||
static void absorveEvidence (
|
|
||||||
ParfactorList& pfList, ObservedFormulas& obsFormulas);
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void runSolver (const Grounds&);
|
void runSolver (const Grounds&);
|
||||||
|
|
||||||
LiftedOperator* getBestOperation (const Grounds&);
|
LiftedOperator* getBestOperation (const Grounds&);
|
||||||
|
|
||||||
void runWeakBayesBall (const Grounds&);
|
|
||||||
|
|
||||||
void shatterAgainstQuery (const Grounds&);
|
|
||||||
|
|
||||||
static Parfactors absorve (ObservedFormula&, Parfactor*);
|
|
||||||
|
|
||||||
ParfactorList pfList_;
|
ParfactorList pfList_;
|
||||||
|
|
||||||
double largestCost_;
|
double largestCost_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user