use PrvGroup
This commit is contained in:
parent
af6601837c
commit
18d4777d9a
@ -52,7 +52,7 @@ LiftedOperator::printValidOps (
|
||||
|
||||
vector<ParfactorList::iterator>
|
||||
LiftedOperator::getParfactorsWithGroup (
|
||||
ParfactorList& pfList, unsigned group)
|
||||
ParfactorList& pfList, PrvGroup group)
|
||||
{
|
||||
vector<ParfactorList::iterator> iters;
|
||||
ParfactorList::iterator pflIt = pfList.begin();
|
||||
@ -144,10 +144,10 @@ ProductOperator::toString (void)
|
||||
bool
|
||||
ProductOperator::validOp (Parfactor* g1, Parfactor* g2)
|
||||
{
|
||||
TinySet<unsigned> g1_gs (g1->getAllGroups());
|
||||
TinySet<unsigned> g2_gs (g2->getAllGroups());
|
||||
TinySet<PrvGroup> g1_gs (g1->getAllGroups());
|
||||
TinySet<PrvGroup> g2_gs (g2->getAllGroups());
|
||||
if (g1_gs.contains (g2_gs) || g2_gs.contains (g1_gs)) {
|
||||
TinySet<unsigned> intersect = g1_gs & g2_gs;
|
||||
TinySet<PrvGroup> intersect = g1_gs & g2_gs;
|
||||
for (size_t i = 0; i < intersect.size(); i++) {
|
||||
if (g1->nrFormulasWithGroup (intersect[i]) != 1 ||
|
||||
g2->nrFormulasWithGroup (intersect[i]) != 1) {
|
||||
@ -169,13 +169,13 @@ ProductOperator::validOp (Parfactor* g1, Parfactor* g2)
|
||||
double
|
||||
SumOutOperator::getLogCost (void)
|
||||
{
|
||||
TinySet<unsigned> groupSet;
|
||||
TinySet<PrvGroup> groupSet;
|
||||
ParfactorList::const_iterator pfIter = pfList_.begin();
|
||||
unsigned nrProdFactors = 0;
|
||||
while (pfIter != pfList_.end()) {
|
||||
if ((*pfIter)->containsGroup (group_)) {
|
||||
vector<unsigned> groups = (*pfIter)->getAllGroups();
|
||||
groupSet |= TinySet<unsigned> (groups);
|
||||
vector<PrvGroup> groups = (*pfIter)->getAllGroups();
|
||||
groupSet |= TinySet<PrvGroup> (groups);
|
||||
++ nrProdFactors;
|
||||
}
|
||||
++ pfIter;
|
||||
@ -239,7 +239,7 @@ SumOutOperator::getValidOps (
|
||||
const Grounds& query)
|
||||
{
|
||||
vector<SumOutOperator*> validOps;
|
||||
set<unsigned> allGroups;
|
||||
set<PrvGroup> allGroups;
|
||||
ParfactorList::const_iterator it = pfList.begin();
|
||||
while (it != pfList.end()) {
|
||||
const ProbFormulas& formulas = (*it)->arguments();
|
||||
@ -248,7 +248,7 @@ SumOutOperator::getValidOps (
|
||||
}
|
||||
++ it;
|
||||
}
|
||||
set<unsigned>::const_iterator groupIt = allGroups.begin();
|
||||
set<PrvGroup>::const_iterator groupIt = allGroups.begin();
|
||||
while (groupIt != allGroups.end()) {
|
||||
if (validOp (*groupIt, pfList, query)) {
|
||||
validOps.push_back (new SumOutOperator (*groupIt, pfList));
|
||||
@ -279,7 +279,7 @@ SumOutOperator::toString (void)
|
||||
|
||||
bool
|
||||
SumOutOperator::validOp (
|
||||
unsigned group,
|
||||
PrvGroup group,
|
||||
ParfactorList& pfList,
|
||||
const Grounds& query)
|
||||
{
|
||||
@ -312,7 +312,7 @@ SumOutOperator::validOp (
|
||||
bool
|
||||
SumOutOperator::isToEliminate (
|
||||
Parfactor* g,
|
||||
unsigned group,
|
||||
PrvGroup group,
|
||||
const Grounds& query)
|
||||
{
|
||||
size_t fIdx = g->indexOfGroup (group);
|
||||
@ -345,7 +345,7 @@ CountingOperator::getLogCost (void)
|
||||
for (size_t i = 0; i < counts.size(); i++) {
|
||||
cost += size * HistogramSet::nrHistograms (counts[i], range);
|
||||
}
|
||||
unsigned group = (*pfIter_)->argument (fIdx).group();
|
||||
PrvGroup group = (*pfIter_)->argument (fIdx).group();
|
||||
size_t lvIndex = Util::indexOf (
|
||||
(*pfIter_)->argument (fIdx).logVars(), X_);
|
||||
assert (lvIndex != (*pfIter_)->argument (fIdx).logVars().size());
|
||||
@ -456,7 +456,7 @@ CountingOperator::validOp (Parfactor* g, LogVar X)
|
||||
double
|
||||
GroundOperator::getLogCost (void)
|
||||
{
|
||||
vector<pair<unsigned, unsigned>> affectedFormulas;
|
||||
vector<pair<PrvGroup, unsigned>> affectedFormulas;
|
||||
affectedFormulas = getAffectedFormulas();
|
||||
// cout << "affected formulas: " ;
|
||||
// for (size_t i = 0; i < affectedFormulas.size(); i++) {
|
||||
@ -542,7 +542,7 @@ vector<GroundOperator*>
|
||||
GroundOperator::getValidOps (ParfactorList& pfList)
|
||||
{
|
||||
vector<GroundOperator*> validOps;
|
||||
set<unsigned> allGroups;
|
||||
set<PrvGroup> allGroups;
|
||||
ParfactorList::const_iterator it = pfList.begin();
|
||||
while (it != pfList.end()) {
|
||||
const ProbFormulas& formulas = (*it)->arguments();
|
||||
@ -593,15 +593,15 @@ GroundOperator::toString (void)
|
||||
|
||||
|
||||
|
||||
vector<pair<unsigned, unsigned>>
|
||||
vector<pair<PrvGroup, unsigned>>
|
||||
GroundOperator::getAffectedFormulas (void)
|
||||
{
|
||||
vector<pair<unsigned, unsigned>> affectedFormulas;
|
||||
vector<pair<PrvGroup, unsigned>> affectedFormulas;
|
||||
affectedFormulas.push_back (make_pair (group_, lvIndex_));
|
||||
queue<pair<unsigned, unsigned>> q;
|
||||
queue<pair<PrvGroup, unsigned>> q;
|
||||
q.push (make_pair (group_, lvIndex_));
|
||||
while (q.empty() == false) {
|
||||
pair<unsigned, unsigned> front = q.front();
|
||||
pair<PrvGroup, unsigned> front = q.front();
|
||||
ParfactorList::iterator pflIt = pfList_.begin();
|
||||
while (pflIt != pfList_.end()) {
|
||||
size_t idx = (*pflIt)->indexOfGroup (front.first);
|
||||
@ -611,7 +611,7 @@ GroundOperator::getAffectedFormulas (void)
|
||||
const ProbFormulas& fs = (*pflIt)->arguments();
|
||||
for (size_t i = 0; i < fs.size(); i++) {
|
||||
if ((int)i != idx && fs[i].contains (X)) {
|
||||
pair<unsigned, unsigned> pair = make_pair (
|
||||
pair<PrvGroup, unsigned> pair = make_pair (
|
||||
fs[i].group(), fs[i].indexOf (X));
|
||||
if (Util::contains (affectedFormulas, pair) == false) {
|
||||
q.push (pair);
|
||||
@ -830,13 +830,13 @@ FoveSolver::getBestOperation (const Grounds& query)
|
||||
void
|
||||
FoveSolver::runWeakBayesBall (const Grounds& query)
|
||||
{
|
||||
queue<unsigned> todo; // groups to process
|
||||
set<unsigned> done; // processed or in queue
|
||||
queue<PrvGroup> todo; // groups to process
|
||||
set<PrvGroup> done; // processed or in queue
|
||||
for (size_t i = 0; i < query.size(); i++) {
|
||||
ParfactorList::iterator it = pfList_.begin();
|
||||
while (it != pfList_.end()) {
|
||||
int group = (*it)->findGroup (query[i]);
|
||||
if (group != -1) {
|
||||
PrvGroup group = (*it)->findGroup (query[i]);
|
||||
if (group != numeric_limits<PrvGroup>::max()) {
|
||||
todo.push (group);
|
||||
done.insert (group);
|
||||
break;
|
||||
@ -847,12 +847,12 @@ FoveSolver::runWeakBayesBall (const Grounds& query)
|
||||
|
||||
set<Parfactor*> requiredPfs;
|
||||
while (todo.empty() == false) {
|
||||
unsigned group = todo.front();
|
||||
PrvGroup group = todo.front();
|
||||
ParfactorList::iterator it = pfList_.begin();
|
||||
while (it != pfList_.end()) {
|
||||
if (Util::contains (requiredPfs, *it) == false &&
|
||||
(*it)->containsGroup (group)) {
|
||||
vector<unsigned> groups = (*it)->getAllGroups();
|
||||
vector<PrvGroup> groups = (*it)->getAllGroups();
|
||||
for (size_t i = 0; i < groups.size(); i++) {
|
||||
if (Util::contains (done, groups[i]) == false) {
|
||||
todo.push (groups[i]);
|
||||
|
@ -20,7 +20,7 @@ class LiftedOperator
|
||||
static void printValidOps (ParfactorList&, const Grounds&);
|
||||
|
||||
static vector<ParfactorList::iterator> getParfactorsWithGroup (
|
||||
ParfactorList&, unsigned group);
|
||||
ParfactorList&, PrvGroup group);
|
||||
};
|
||||
|
||||
|
||||
@ -53,7 +53,7 @@ class ProductOperator : public LiftedOperator
|
||||
class SumOutOperator : public LiftedOperator
|
||||
{
|
||||
public:
|
||||
SumOutOperator (unsigned group, ParfactorList& pfList)
|
||||
SumOutOperator (PrvGroup group, ParfactorList& pfList)
|
||||
: group_(group), pfList_(pfList) { }
|
||||
|
||||
double getLogCost (void);
|
||||
@ -66,11 +66,11 @@ class SumOutOperator : public LiftedOperator
|
||||
string toString (void);
|
||||
|
||||
private:
|
||||
static bool validOp (unsigned, ParfactorList&, const Grounds&);
|
||||
static bool validOp (PrvGroup, ParfactorList&, const Grounds&);
|
||||
|
||||
static bool isToEliminate (Parfactor*, unsigned, const Grounds&);
|
||||
static bool isToEliminate (Parfactor*, PrvGroup, const Grounds&);
|
||||
|
||||
unsigned group_;
|
||||
PrvGroup group_;
|
||||
ParfactorList& pfList_;
|
||||
};
|
||||
|
||||
@ -107,7 +107,7 @@ class GroundOperator : public LiftedOperator
|
||||
{
|
||||
public:
|
||||
GroundOperator (
|
||||
unsigned group,
|
||||
PrvGroup group,
|
||||
unsigned lvIndex,
|
||||
ParfactorList& pfList)
|
||||
: group_(group), lvIndex_(lvIndex), pfList_(pfList) { }
|
||||
@ -121,9 +121,9 @@ class GroundOperator : public LiftedOperator
|
||||
string toString (void);
|
||||
|
||||
private:
|
||||
vector<pair<unsigned, unsigned>> getAffectedFormulas (void);
|
||||
vector<pair<PrvGroup, unsigned>> getAffectedFormulas (void);
|
||||
|
||||
unsigned group_;
|
||||
PrvGroup group_;
|
||||
unsigned lvIndex_;
|
||||
ParfactorList& pfList_;
|
||||
};
|
||||
|
@ -428,10 +428,10 @@ Parfactor::applySubstitution (const Substitution& theta)
|
||||
|
||||
|
||||
|
||||
int
|
||||
PrvGroup
|
||||
Parfactor::findGroup (const Ground& ground) const
|
||||
{
|
||||
int group = -1;
|
||||
PrvGroup group = numeric_limits<PrvGroup>::max();
|
||||
for (size_t i = 0; i < args_.size(); i++) {
|
||||
if (args_[i].functor() == ground.functor() &&
|
||||
args_[i].arity() == ground.arity()) {
|
||||
@ -450,13 +450,13 @@ Parfactor::findGroup (const Ground& ground) const
|
||||
bool
|
||||
Parfactor::containsGround (const Ground& ground) const
|
||||
{
|
||||
return findGroup (ground) != -1;
|
||||
return findGroup (ground) != numeric_limits<PrvGroup>::max();
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
Parfactor::containsGroup (unsigned group) const
|
||||
Parfactor::containsGroup (PrvGroup group) const
|
||||
{
|
||||
for (size_t i = 0; i < args_.size(); i++) {
|
||||
if (args_[i].group() == group) {
|
||||
@ -499,7 +499,7 @@ Parfactor::indexOfLogVar (LogVar X) const
|
||||
|
||||
|
||||
int
|
||||
Parfactor::indexOfGroup (unsigned group) const
|
||||
Parfactor::indexOfGroup (PrvGroup group) const
|
||||
{
|
||||
size_t pos = args_.size();
|
||||
for (size_t i = 0; i < args_.size(); i++) {
|
||||
@ -514,7 +514,7 @@ Parfactor::indexOfGroup (unsigned group) const
|
||||
|
||||
|
||||
unsigned
|
||||
Parfactor::nrFormulasWithGroup (unsigned group) const
|
||||
Parfactor::nrFormulasWithGroup (PrvGroup group) const
|
||||
{
|
||||
unsigned count = 0;
|
||||
for (size_t i = 0; i < args_.size(); i++) {
|
||||
@ -527,10 +527,10 @@ Parfactor::nrFormulasWithGroup (unsigned group) const
|
||||
|
||||
|
||||
|
||||
vector<unsigned>
|
||||
vector<PrvGroup>
|
||||
Parfactor::getAllGroups (void) const
|
||||
{
|
||||
vector<unsigned> groups (args_.size());
|
||||
vector<PrvGroup> groups (args_.size());
|
||||
for (size_t i = 0; i < args_.size(); i++) {
|
||||
groups[i] = args_[i].group();
|
||||
}
|
||||
@ -726,8 +726,11 @@ Parfactor::simplifyCountingFormulas (size_t fIdx)
|
||||
void
|
||||
Parfactor::simplifyGrounds (void)
|
||||
{
|
||||
if (args_.size() == 1) {
|
||||
return;
|
||||
}
|
||||
LogVarSet singletons = constr_->singletons();
|
||||
for (long i = 0; i < (int)args_.size() - 1; i++) {
|
||||
for (long i = 0; i < (long)args_.size() - 1; i++) {
|
||||
for (size_t j = i + 1; j < args_.size(); j++) {
|
||||
if (args_[i].group() == args_[j].group() &&
|
||||
singletons.contains (args_[i].logVarSet()) &&
|
||||
|
@ -18,7 +18,7 @@ class Parfactor : public TFactor<ProbFormula>
|
||||
const ProbFormulas&,
|
||||
const Params&,
|
||||
const Tuples&,
|
||||
unsigned);
|
||||
unsigned distId);
|
||||
|
||||
Parfactor (const Parfactor*, const Tuple&);
|
||||
|
||||
@ -66,21 +66,21 @@ class Parfactor : public TFactor<ProbFormula>
|
||||
|
||||
void applySubstitution (const Substitution&);
|
||||
|
||||
int findGroup (const Ground&) const;
|
||||
PrvGroup findGroup (const Ground&) const;
|
||||
|
||||
bool containsGround (const Ground&) const;
|
||||
|
||||
bool containsGroup (unsigned) const;
|
||||
bool containsGroup (PrvGroup) const;
|
||||
|
||||
unsigned nrFormulas (LogVar) const;
|
||||
|
||||
int indexOfLogVar (LogVar) const;
|
||||
|
||||
int indexOfGroup (unsigned) const;
|
||||
int indexOfGroup (PrvGroup) const;
|
||||
|
||||
unsigned nrFormulasWithGroup (unsigned) const;
|
||||
unsigned nrFormulasWithGroup (PrvGroup) const;
|
||||
|
||||
vector<unsigned> getAllGroups (void) const;
|
||||
vector<PrvGroup> getAllGroups (void) const;
|
||||
|
||||
void print (bool = false) const;
|
||||
|
||||
|
@ -340,7 +340,7 @@ ParfactorList::shatterAgainstMySelf (
|
||||
return { };
|
||||
}
|
||||
|
||||
unsigned newGroup = ProbFormula::getNewGroup();
|
||||
PrvGroup newGroup = ProbFormula::getNewGroup();
|
||||
Parfactors res1 = shatter (g, fIdx1, commCt1, exclCt1, newGroup);
|
||||
if (res1.empty()) {
|
||||
res1.push_back (g);
|
||||
@ -488,7 +488,7 @@ ParfactorList::shatter (
|
||||
return { };
|
||||
}
|
||||
|
||||
unsigned group;
|
||||
PrvGroup group;
|
||||
if (exclCt1->empty()) {
|
||||
group = f1.group();
|
||||
} else if (exclCt2->empty()) {
|
||||
@ -509,7 +509,7 @@ ParfactorList::shatter (
|
||||
size_t fIdx,
|
||||
ConstraintTree* commCt,
|
||||
ConstraintTree* exclCt,
|
||||
unsigned commGroup)
|
||||
PrvGroup commGroup)
|
||||
{
|
||||
ProbFormula& f = g->argument (fIdx);
|
||||
if (exclCt->empty()) {
|
||||
@ -556,7 +556,7 @@ ParfactorList::shatter (
|
||||
|
||||
|
||||
void
|
||||
ParfactorList::updateGroups (unsigned oldGroup, unsigned newGroup)
|
||||
ParfactorList::updateGroups (PrvGroup oldGroup, PrvGroup newGroup)
|
||||
{
|
||||
for (ParfactorList::iterator it = pfList_.begin();
|
||||
it != pfList_.end(); it++) {
|
||||
|
@ -80,12 +80,12 @@ class ParfactorList
|
||||
|
||||
Parfactors shatter (
|
||||
Parfactor*,
|
||||
unsigned,
|
||||
size_t,
|
||||
ConstraintTree*,
|
||||
ConstraintTree*,
|
||||
unsigned);
|
||||
PrvGroup);
|
||||
|
||||
void updateGroups (unsigned group1, unsigned group2);
|
||||
void updateGroups (PrvGroup group1, PrvGroup group2);
|
||||
|
||||
bool proper (
|
||||
const ProbFormula&, ConstraintTree,
|
||||
|
@ -1,7 +1,7 @@
|
||||
#include "ProbFormula.h"
|
||||
|
||||
|
||||
int ProbFormula::freeGroup_ = 0;
|
||||
PrvGroup ProbFormula::freeGroup_ = 0;
|
||||
|
||||
|
||||
|
||||
@ -120,10 +120,11 @@ std::ostream& operator<< (ostream &os, const ProbFormula& f)
|
||||
|
||||
|
||||
|
||||
unsigned
|
||||
PrvGroup
|
||||
ProbFormula::getNewGroup (void)
|
||||
{
|
||||
freeGroup_ ++;
|
||||
assert (freeGroup_ != numeric_limits<PrvGroup>::max());
|
||||
return freeGroup_;
|
||||
}
|
||||
|
||||
|
@ -7,17 +7,17 @@
|
||||
#include "LiftedUtils.h"
|
||||
#include "Horus.h"
|
||||
|
||||
typedef unsigned PrvGroup;
|
||||
typedef unsigned long PrvGroup;
|
||||
|
||||
class ProbFormula
|
||||
{
|
||||
public:
|
||||
ProbFormula (Symbol f, const LogVars& lvs, unsigned range)
|
||||
: functor_(f), logVars_(lvs), range_(range),
|
||||
countedLogVar_(), group_(Util::maxUnsigned()) { }
|
||||
countedLogVar_(), group_(numeric_limits<PrvGroup>::max()) { }
|
||||
|
||||
ProbFormula (Symbol f, unsigned r)
|
||||
: functor_(f), range_(r), group_(Util::maxUnsigned()) { }
|
||||
: functor_(f), range_(r), group_(numeric_limits<PrvGroup>::max()) { }
|
||||
|
||||
Symbol functor (void) const { return functor_; }
|
||||
|
||||
@ -31,9 +31,9 @@ class ProbFormula
|
||||
|
||||
LogVarSet logVarSet (void) const { return LogVarSet (logVars_); }
|
||||
|
||||
unsigned group (void) const { return group_; }
|
||||
PrvGroup group (void) const { return group_; }
|
||||
|
||||
void setGroup (unsigned g) { group_ = g; }
|
||||
void setGroup (PrvGroup g) { group_ = g; }
|
||||
|
||||
bool sameSkeletonAs (const ProbFormula&) const;
|
||||
|
||||
@ -55,7 +55,7 @@ class ProbFormula
|
||||
|
||||
void rename (LogVar, LogVar);
|
||||
|
||||
static unsigned getNewGroup (void);
|
||||
static PrvGroup getNewGroup (void);
|
||||
|
||||
friend std::ostream& operator<< (ostream &os, const ProbFormula& f);
|
||||
|
||||
@ -66,8 +66,8 @@ class ProbFormula
|
||||
LogVars logVars_;
|
||||
unsigned range_;
|
||||
LogVar countedLogVar_;
|
||||
unsigned group_;
|
||||
static int freeGroup_;
|
||||
PrvGroup group_;
|
||||
static PrvGroup freeGroup_;
|
||||
};
|
||||
|
||||
typedef vector<ProbFormula> ProbFormulas;
|
||||
|
Reference in New Issue
Block a user