move more code around

This commit is contained in:
Tiago Gomes 2012-11-10 00:18:20 +00:00
parent 07bcc89a76
commit 4776817603
5 changed files with 81 additions and 88 deletions

View File

@ -2,7 +2,6 @@
#include "WeightedBp.h"
#include "FactorGraph.h"
#include "LiftedOperations.h"
#include "LiftedVe.h"
LiftedBp::LiftedBp (const ParfactorList& pfList)
@ -190,12 +189,12 @@ LiftedBp::rangeOfGround (const Ground& gr)
Params
LiftedBp::getJointByConditioning (
const ParfactorList& pfList,
const Grounds& grounds)
const Grounds& query)
{
LiftedBp solver (pfList);
Params prevBeliefs = solver.solveQuery ({grounds[0]});
Grounds obsGrounds = {grounds[0]};
for (size_t i = 1; i < grounds.size(); i++) {
Params prevBeliefs = solver.solveQuery ({query[0]});
Grounds obsGrounds = {query[0]};
for (size_t i = 1; i < query.size(); i++) {
Params newBeliefs;
vector<ObservedFormula> obsFs;
Ranges obsRanges;
@ -210,16 +209,16 @@ LiftedBp::getJointByConditioning (
obsFs[j].setEvidence (indexer[j]);
}
ParfactorList tempPfList (pfList);
LiftedVe::absorveEvidence (tempPfList, obsFs);
LiftedOperations::absorveEvidence (tempPfList, obsFs);
LiftedBp solver (tempPfList);
Params beliefs = solver.solveQuery ({grounds[i]});
Params beliefs = solver.solveQuery ({query[i]});
for (size_t k = 0; k < beliefs.size(); k++) {
newBeliefs.push_back (beliefs[k]);
}
++ indexer;
}
int count = -1;
unsigned range = rangeOfGround (grounds[i]);
unsigned range = rangeOfGround (query[i]);
for (size_t j = 0; j < newBeliefs.size(); j++) {
if (j % range == 0) {
count ++;
@ -227,7 +226,7 @@ LiftedBp::getJointByConditioning (
newBeliefs[j] *= prevBeliefs[count];
}
prevBeliefs = newBeliefs;
obsGrounds.push_back (grounds[i]);
obsGrounds.push_back (query[i]);
}
return prevBeliefs;
}

View File

@ -54,6 +54,67 @@ LiftedOperations::shatterAgainstQuery (
void
LiftedOperations::runWeakBayesBall (
ParfactorList& pfList,
const Grounds& query)
{
queue<PrvGroup> todo; // groups to process
set<PrvGroup> done; // processed or in queue
for (size_t i = 0; i < query.size(); i++) {
ParfactorList::iterator it = pfList.begin();
while (it != pfList.end()) {
PrvGroup group = (*it)->findGroup (query[i]);
if (group != numeric_limits<PrvGroup>::max()) {
todo.push (group);
done.insert (group);
break;
}
++ it;
}
}
set<Parfactor*> requiredPfs;
while (todo.empty() == false) {
PrvGroup group = todo.front();
ParfactorList::iterator it = pfList.begin();
while (it != pfList.end()) {
if (Util::contains (requiredPfs, *it) == false &&
(*it)->containsGroup (group)) {
vector<PrvGroup> groups = (*it)->getAllGroups();
for (size_t i = 0; i < groups.size(); i++) {
if (Util::contains (done, groups[i]) == false) {
todo.push (groups[i]);
done.insert (groups[i]);
}
}
requiredPfs.insert (*it);
}
++ it;
}
todo.pop();
}
ParfactorList::iterator it = pfList.begin();
bool foundNotRequired = false;
while (it != pfList.end()) {
if (Util::contains (requiredPfs, *it) == false) {
if (Globals::verbosity > 2) {
if (foundNotRequired == false) {
Util::printHeader ("PARFACTORS TO DISCARD");
foundNotRequired = true;
}
(*it)->print();
}
it = pfList.removeAndDelete (it);
} else {
++ it;
}
}
}
void
LiftedOperations::absorveEvidence (
ParfactorList& pfList,

View File

@ -9,6 +9,9 @@ class LiftedOperations
static void shatterAgainstQuery (
ParfactorList& pfList, const Grounds& query);
static void runWeakBayesBall (
ParfactorList& pfList, const Grounds&);
static void absorveEvidence (
ParfactorList& pfList, ObservedFormulas& obsFormulas);
@ -17,7 +20,6 @@ class LiftedOperations
static Parfactor calcGroundMultiplication (Parfactor pf);
private:
static Parfactors absorve (ObservedFormula&, Parfactor*);
};

View File

@ -661,7 +661,7 @@ LiftedVe::runSolver (const Grounds& query)
{
largestCost_ = std::log (0);
LiftedOperations::shatterAgainstQuery (pfList_, query);
runWeakBayesBall (query);
LiftedOperations::runWeakBayesBall (pfList_, query);
while (true) {
if (Globals::verbosity > 2) {
Util::printDashedLine();
@ -727,62 +727,3 @@ LiftedVe::getBestOperation (const Grounds& query)
return bestOp;
}
void
LiftedVe::runWeakBayesBall (const Grounds& query)
{
queue<PrvGroup> todo; // groups to process
set<PrvGroup> done; // processed or in queue
for (size_t i = 0; i < query.size(); i++) {
ParfactorList::iterator it = pfList_.begin();
while (it != pfList_.end()) {
PrvGroup group = (*it)->findGroup (query[i]);
if (group != numeric_limits<PrvGroup>::max()) {
todo.push (group);
done.insert (group);
break;
}
++ it;
}
}
set<Parfactor*> requiredPfs;
while (todo.empty() == false) {
PrvGroup group = todo.front();
ParfactorList::iterator it = pfList_.begin();
while (it != pfList_.end()) {
if (Util::contains (requiredPfs, *it) == false &&
(*it)->containsGroup (group)) {
vector<PrvGroup> groups = (*it)->getAllGroups();
for (size_t i = 0; i < groups.size(); i++) {
if (Util::contains (done, groups[i]) == false) {
todo.push (groups[i]);
done.insert (groups[i]);
}
}
requiredPfs.insert (*it);
}
++ it;
}
todo.pop();
}
ParfactorList::iterator it = pfList_.begin();
bool foundNotRequired = false;
while (it != pfList_.end()) {
if (Util::contains (requiredPfs, *it) == false) {
if (Globals::verbosity > 2) {
if (foundNotRequired == false) {
Util::printHeader ("PARFACTORS TO DISCARD");
foundNotRequired = true;
}
(*it)->print();
}
it = pfList_.removeAndDelete (it);
} else {
++ it;
}
}
}

View File

@ -45,9 +45,9 @@ class ProductOperator : public LiftedOperator
private:
static bool validOp (Parfactor*, Parfactor*);
ParfactorList::iterator g1_;
ParfactorList::iterator g2_;
ParfactorList& pfList_;
ParfactorList::iterator g1_;
ParfactorList::iterator g2_;
ParfactorList& pfList_;
};
@ -125,9 +125,9 @@ class GroundOperator : public LiftedOperator
private:
vector<pair<PrvGroup, unsigned>> getAffectedFormulas (void);
PrvGroup group_;
unsigned lvIndex_;
ParfactorList& pfList_;
PrvGroup group_;
unsigned lvIndex_;
ParfactorList& pfList_;
};
@ -141,23 +141,13 @@ class LiftedVe
void printSolverFlags (void) const;
static void absorveEvidence (
ParfactorList& pfList, ObservedFormulas& obsFormulas);
private:
void runSolver (const Grounds&);
LiftedOperator* getBestOperation (const Grounds&);
void runWeakBayesBall (const Grounds&);
void shatterAgainstQuery (const Grounds&);
static Parfactors absorve (ObservedFormula&, Parfactor*);
ParfactorList pfList_;
double largestCost_;
ParfactorList pfList_;
double largestCost_;
};
#endif // HORUS_LIFTEDVE_H