move more code around
This commit is contained in:
parent
07bcc89a76
commit
4776817603
@ -2,7 +2,6 @@
|
||||
#include "WeightedBp.h"
|
||||
#include "FactorGraph.h"
|
||||
#include "LiftedOperations.h"
|
||||
#include "LiftedVe.h"
|
||||
|
||||
|
||||
LiftedBp::LiftedBp (const ParfactorList& pfList)
|
||||
@ -190,12 +189,12 @@ LiftedBp::rangeOfGround (const Ground& gr)
|
||||
Params
|
||||
LiftedBp::getJointByConditioning (
|
||||
const ParfactorList& pfList,
|
||||
const Grounds& grounds)
|
||||
const Grounds& query)
|
||||
{
|
||||
LiftedBp solver (pfList);
|
||||
Params prevBeliefs = solver.solveQuery ({grounds[0]});
|
||||
Grounds obsGrounds = {grounds[0]};
|
||||
for (size_t i = 1; i < grounds.size(); i++) {
|
||||
Params prevBeliefs = solver.solveQuery ({query[0]});
|
||||
Grounds obsGrounds = {query[0]};
|
||||
for (size_t i = 1; i < query.size(); i++) {
|
||||
Params newBeliefs;
|
||||
vector<ObservedFormula> obsFs;
|
||||
Ranges obsRanges;
|
||||
@ -210,16 +209,16 @@ LiftedBp::getJointByConditioning (
|
||||
obsFs[j].setEvidence (indexer[j]);
|
||||
}
|
||||
ParfactorList tempPfList (pfList);
|
||||
LiftedVe::absorveEvidence (tempPfList, obsFs);
|
||||
LiftedOperations::absorveEvidence (tempPfList, obsFs);
|
||||
LiftedBp solver (tempPfList);
|
||||
Params beliefs = solver.solveQuery ({grounds[i]});
|
||||
Params beliefs = solver.solveQuery ({query[i]});
|
||||
for (size_t k = 0; k < beliefs.size(); k++) {
|
||||
newBeliefs.push_back (beliefs[k]);
|
||||
}
|
||||
++ indexer;
|
||||
}
|
||||
int count = -1;
|
||||
unsigned range = rangeOfGround (grounds[i]);
|
||||
unsigned range = rangeOfGround (query[i]);
|
||||
for (size_t j = 0; j < newBeliefs.size(); j++) {
|
||||
if (j % range == 0) {
|
||||
count ++;
|
||||
@ -227,7 +226,7 @@ LiftedBp::getJointByConditioning (
|
||||
newBeliefs[j] *= prevBeliefs[count];
|
||||
}
|
||||
prevBeliefs = newBeliefs;
|
||||
obsGrounds.push_back (grounds[i]);
|
||||
obsGrounds.push_back (query[i]);
|
||||
}
|
||||
return prevBeliefs;
|
||||
}
|
||||
|
@ -54,6 +54,67 @@ LiftedOperations::shatterAgainstQuery (
|
||||
|
||||
|
||||
|
||||
void
|
||||
LiftedOperations::runWeakBayesBall (
|
||||
ParfactorList& pfList,
|
||||
const Grounds& query)
|
||||
{
|
||||
queue<PrvGroup> todo; // groups to process
|
||||
set<PrvGroup> done; // processed or in queue
|
||||
for (size_t i = 0; i < query.size(); i++) {
|
||||
ParfactorList::iterator it = pfList.begin();
|
||||
while (it != pfList.end()) {
|
||||
PrvGroup group = (*it)->findGroup (query[i]);
|
||||
if (group != numeric_limits<PrvGroup>::max()) {
|
||||
todo.push (group);
|
||||
done.insert (group);
|
||||
break;
|
||||
}
|
||||
++ it;
|
||||
}
|
||||
}
|
||||
|
||||
set<Parfactor*> requiredPfs;
|
||||
while (todo.empty() == false) {
|
||||
PrvGroup group = todo.front();
|
||||
ParfactorList::iterator it = pfList.begin();
|
||||
while (it != pfList.end()) {
|
||||
if (Util::contains (requiredPfs, *it) == false &&
|
||||
(*it)->containsGroup (group)) {
|
||||
vector<PrvGroup> groups = (*it)->getAllGroups();
|
||||
for (size_t i = 0; i < groups.size(); i++) {
|
||||
if (Util::contains (done, groups[i]) == false) {
|
||||
todo.push (groups[i]);
|
||||
done.insert (groups[i]);
|
||||
}
|
||||
}
|
||||
requiredPfs.insert (*it);
|
||||
}
|
||||
++ it;
|
||||
}
|
||||
todo.pop();
|
||||
}
|
||||
|
||||
ParfactorList::iterator it = pfList.begin();
|
||||
bool foundNotRequired = false;
|
||||
while (it != pfList.end()) {
|
||||
if (Util::contains (requiredPfs, *it) == false) {
|
||||
if (Globals::verbosity > 2) {
|
||||
if (foundNotRequired == false) {
|
||||
Util::printHeader ("PARFACTORS TO DISCARD");
|
||||
foundNotRequired = true;
|
||||
}
|
||||
(*it)->print();
|
||||
}
|
||||
it = pfList.removeAndDelete (it);
|
||||
} else {
|
||||
++ it;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
LiftedOperations::absorveEvidence (
|
||||
ParfactorList& pfList,
|
||||
|
@ -9,6 +9,9 @@ class LiftedOperations
|
||||
static void shatterAgainstQuery (
|
||||
ParfactorList& pfList, const Grounds& query);
|
||||
|
||||
static void runWeakBayesBall (
|
||||
ParfactorList& pfList, const Grounds&);
|
||||
|
||||
static void absorveEvidence (
|
||||
ParfactorList& pfList, ObservedFormulas& obsFormulas);
|
||||
|
||||
@ -17,7 +20,6 @@ class LiftedOperations
|
||||
static Parfactor calcGroundMultiplication (Parfactor pf);
|
||||
|
||||
private:
|
||||
|
||||
static Parfactors absorve (ObservedFormula&, Parfactor*);
|
||||
};
|
||||
|
||||
|
@ -661,7 +661,7 @@ LiftedVe::runSolver (const Grounds& query)
|
||||
{
|
||||
largestCost_ = std::log (0);
|
||||
LiftedOperations::shatterAgainstQuery (pfList_, query);
|
||||
runWeakBayesBall (query);
|
||||
LiftedOperations::runWeakBayesBall (pfList_, query);
|
||||
while (true) {
|
||||
if (Globals::verbosity > 2) {
|
||||
Util::printDashedLine();
|
||||
@ -727,62 +727,3 @@ LiftedVe::getBestOperation (const Grounds& query)
|
||||
return bestOp;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
LiftedVe::runWeakBayesBall (const Grounds& query)
|
||||
{
|
||||
queue<PrvGroup> todo; // groups to process
|
||||
set<PrvGroup> done; // processed or in queue
|
||||
for (size_t i = 0; i < query.size(); i++) {
|
||||
ParfactorList::iterator it = pfList_.begin();
|
||||
while (it != pfList_.end()) {
|
||||
PrvGroup group = (*it)->findGroup (query[i]);
|
||||
if (group != numeric_limits<PrvGroup>::max()) {
|
||||
todo.push (group);
|
||||
done.insert (group);
|
||||
break;
|
||||
}
|
||||
++ it;
|
||||
}
|
||||
}
|
||||
|
||||
set<Parfactor*> requiredPfs;
|
||||
while (todo.empty() == false) {
|
||||
PrvGroup group = todo.front();
|
||||
ParfactorList::iterator it = pfList_.begin();
|
||||
while (it != pfList_.end()) {
|
||||
if (Util::contains (requiredPfs, *it) == false &&
|
||||
(*it)->containsGroup (group)) {
|
||||
vector<PrvGroup> groups = (*it)->getAllGroups();
|
||||
for (size_t i = 0; i < groups.size(); i++) {
|
||||
if (Util::contains (done, groups[i]) == false) {
|
||||
todo.push (groups[i]);
|
||||
done.insert (groups[i]);
|
||||
}
|
||||
}
|
||||
requiredPfs.insert (*it);
|
||||
}
|
||||
++ it;
|
||||
}
|
||||
todo.pop();
|
||||
}
|
||||
|
||||
ParfactorList::iterator it = pfList_.begin();
|
||||
bool foundNotRequired = false;
|
||||
while (it != pfList_.end()) {
|
||||
if (Util::contains (requiredPfs, *it) == false) {
|
||||
if (Globals::verbosity > 2) {
|
||||
if (foundNotRequired == false) {
|
||||
Util::printHeader ("PARFACTORS TO DISCARD");
|
||||
foundNotRequired = true;
|
||||
}
|
||||
(*it)->print();
|
||||
}
|
||||
it = pfList_.removeAndDelete (it);
|
||||
} else {
|
||||
++ it;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -45,9 +45,9 @@ class ProductOperator : public LiftedOperator
|
||||
private:
|
||||
static bool validOp (Parfactor*, Parfactor*);
|
||||
|
||||
ParfactorList::iterator g1_;
|
||||
ParfactorList::iterator g2_;
|
||||
ParfactorList& pfList_;
|
||||
ParfactorList::iterator g1_;
|
||||
ParfactorList::iterator g2_;
|
||||
ParfactorList& pfList_;
|
||||
};
|
||||
|
||||
|
||||
@ -125,9 +125,9 @@ class GroundOperator : public LiftedOperator
|
||||
private:
|
||||
vector<pair<PrvGroup, unsigned>> getAffectedFormulas (void);
|
||||
|
||||
PrvGroup group_;
|
||||
unsigned lvIndex_;
|
||||
ParfactorList& pfList_;
|
||||
PrvGroup group_;
|
||||
unsigned lvIndex_;
|
||||
ParfactorList& pfList_;
|
||||
};
|
||||
|
||||
|
||||
@ -141,23 +141,13 @@ class LiftedVe
|
||||
|
||||
void printSolverFlags (void) const;
|
||||
|
||||
static void absorveEvidence (
|
||||
ParfactorList& pfList, ObservedFormulas& obsFormulas);
|
||||
|
||||
private:
|
||||
void runSolver (const Grounds&);
|
||||
|
||||
LiftedOperator* getBestOperation (const Grounds&);
|
||||
|
||||
void runWeakBayesBall (const Grounds&);
|
||||
|
||||
void shatterAgainstQuery (const Grounds&);
|
||||
|
||||
static Parfactors absorve (ObservedFormula&, Parfactor*);
|
||||
|
||||
ParfactorList pfList_;
|
||||
|
||||
double largestCost_;
|
||||
ParfactorList pfList_;
|
||||
double largestCost_;
|
||||
};
|
||||
|
||||
#endif // HORUS_LIFTEDVE_H
|
||||
|
Reference in New Issue
Block a user