use PrvGroup

This commit is contained in:
Tiago Gomes 2012-05-24 23:38:44 +01:00
parent af6601837c
commit 18d4777d9a
8 changed files with 69 additions and 65 deletions

View File

@ -52,7 +52,7 @@ LiftedOperator::printValidOps (
vector<ParfactorList::iterator> vector<ParfactorList::iterator>
LiftedOperator::getParfactorsWithGroup ( LiftedOperator::getParfactorsWithGroup (
ParfactorList& pfList, unsigned group) ParfactorList& pfList, PrvGroup group)
{ {
vector<ParfactorList::iterator> iters; vector<ParfactorList::iterator> iters;
ParfactorList::iterator pflIt = pfList.begin(); ParfactorList::iterator pflIt = pfList.begin();
@ -144,10 +144,10 @@ ProductOperator::toString (void)
bool bool
ProductOperator::validOp (Parfactor* g1, Parfactor* g2) ProductOperator::validOp (Parfactor* g1, Parfactor* g2)
{ {
TinySet<unsigned> g1_gs (g1->getAllGroups()); TinySet<PrvGroup> g1_gs (g1->getAllGroups());
TinySet<unsigned> g2_gs (g2->getAllGroups()); TinySet<PrvGroup> g2_gs (g2->getAllGroups());
if (g1_gs.contains (g2_gs) || g2_gs.contains (g1_gs)) { if (g1_gs.contains (g2_gs) || g2_gs.contains (g1_gs)) {
TinySet<unsigned> intersect = g1_gs & g2_gs; TinySet<PrvGroup> intersect = g1_gs & g2_gs;
for (size_t i = 0; i < intersect.size(); i++) { for (size_t i = 0; i < intersect.size(); i++) {
if (g1->nrFormulasWithGroup (intersect[i]) != 1 || if (g1->nrFormulasWithGroup (intersect[i]) != 1 ||
g2->nrFormulasWithGroup (intersect[i]) != 1) { g2->nrFormulasWithGroup (intersect[i]) != 1) {
@ -169,13 +169,13 @@ ProductOperator::validOp (Parfactor* g1, Parfactor* g2)
double double
SumOutOperator::getLogCost (void) SumOutOperator::getLogCost (void)
{ {
TinySet<unsigned> groupSet; TinySet<PrvGroup> groupSet;
ParfactorList::const_iterator pfIter = pfList_.begin(); ParfactorList::const_iterator pfIter = pfList_.begin();
unsigned nrProdFactors = 0; unsigned nrProdFactors = 0;
while (pfIter != pfList_.end()) { while (pfIter != pfList_.end()) {
if ((*pfIter)->containsGroup (group_)) { if ((*pfIter)->containsGroup (group_)) {
vector<unsigned> groups = (*pfIter)->getAllGroups(); vector<PrvGroup> groups = (*pfIter)->getAllGroups();
groupSet |= TinySet<unsigned> (groups); groupSet |= TinySet<PrvGroup> (groups);
++ nrProdFactors; ++ nrProdFactors;
} }
++ pfIter; ++ pfIter;
@ -239,7 +239,7 @@ SumOutOperator::getValidOps (
const Grounds& query) const Grounds& query)
{ {
vector<SumOutOperator*> validOps; vector<SumOutOperator*> validOps;
set<unsigned> allGroups; set<PrvGroup> allGroups;
ParfactorList::const_iterator it = pfList.begin(); ParfactorList::const_iterator it = pfList.begin();
while (it != pfList.end()) { while (it != pfList.end()) {
const ProbFormulas& formulas = (*it)->arguments(); const ProbFormulas& formulas = (*it)->arguments();
@ -248,7 +248,7 @@ SumOutOperator::getValidOps (
} }
++ it; ++ it;
} }
set<unsigned>::const_iterator groupIt = allGroups.begin(); set<PrvGroup>::const_iterator groupIt = allGroups.begin();
while (groupIt != allGroups.end()) { while (groupIt != allGroups.end()) {
if (validOp (*groupIt, pfList, query)) { if (validOp (*groupIt, pfList, query)) {
validOps.push_back (new SumOutOperator (*groupIt, pfList)); validOps.push_back (new SumOutOperator (*groupIt, pfList));
@ -279,7 +279,7 @@ SumOutOperator::toString (void)
bool bool
SumOutOperator::validOp ( SumOutOperator::validOp (
unsigned group, PrvGroup group,
ParfactorList& pfList, ParfactorList& pfList,
const Grounds& query) const Grounds& query)
{ {
@ -312,7 +312,7 @@ SumOutOperator::validOp (
bool bool
SumOutOperator::isToEliminate ( SumOutOperator::isToEliminate (
Parfactor* g, Parfactor* g,
unsigned group, PrvGroup group,
const Grounds& query) const Grounds& query)
{ {
size_t fIdx = g->indexOfGroup (group); size_t fIdx = g->indexOfGroup (group);
@ -345,7 +345,7 @@ CountingOperator::getLogCost (void)
for (size_t i = 0; i < counts.size(); i++) { for (size_t i = 0; i < counts.size(); i++) {
cost += size * HistogramSet::nrHistograms (counts[i], range); cost += size * HistogramSet::nrHistograms (counts[i], range);
} }
unsigned group = (*pfIter_)->argument (fIdx).group(); PrvGroup group = (*pfIter_)->argument (fIdx).group();
size_t lvIndex = Util::indexOf ( size_t lvIndex = Util::indexOf (
(*pfIter_)->argument (fIdx).logVars(), X_); (*pfIter_)->argument (fIdx).logVars(), X_);
assert (lvIndex != (*pfIter_)->argument (fIdx).logVars().size()); assert (lvIndex != (*pfIter_)->argument (fIdx).logVars().size());
@ -456,7 +456,7 @@ CountingOperator::validOp (Parfactor* g, LogVar X)
double double
GroundOperator::getLogCost (void) GroundOperator::getLogCost (void)
{ {
vector<pair<unsigned, unsigned>> affectedFormulas; vector<pair<PrvGroup, unsigned>> affectedFormulas;
affectedFormulas = getAffectedFormulas(); affectedFormulas = getAffectedFormulas();
// cout << "affected formulas: " ; // cout << "affected formulas: " ;
// for (size_t i = 0; i < affectedFormulas.size(); i++) { // for (size_t i = 0; i < affectedFormulas.size(); i++) {
@ -542,7 +542,7 @@ vector<GroundOperator*>
GroundOperator::getValidOps (ParfactorList& pfList) GroundOperator::getValidOps (ParfactorList& pfList)
{ {
vector<GroundOperator*> validOps; vector<GroundOperator*> validOps;
set<unsigned> allGroups; set<PrvGroup> allGroups;
ParfactorList::const_iterator it = pfList.begin(); ParfactorList::const_iterator it = pfList.begin();
while (it != pfList.end()) { while (it != pfList.end()) {
const ProbFormulas& formulas = (*it)->arguments(); const ProbFormulas& formulas = (*it)->arguments();
@ -593,15 +593,15 @@ GroundOperator::toString (void)
vector<pair<unsigned, unsigned>> vector<pair<PrvGroup, unsigned>>
GroundOperator::getAffectedFormulas (void) GroundOperator::getAffectedFormulas (void)
{ {
vector<pair<unsigned, unsigned>> affectedFormulas; vector<pair<PrvGroup, unsigned>> affectedFormulas;
affectedFormulas.push_back (make_pair (group_, lvIndex_)); affectedFormulas.push_back (make_pair (group_, lvIndex_));
queue<pair<unsigned, unsigned>> q; queue<pair<PrvGroup, unsigned>> q;
q.push (make_pair (group_, lvIndex_)); q.push (make_pair (group_, lvIndex_));
while (q.empty() == false) { while (q.empty() == false) {
pair<unsigned, unsigned> front = q.front(); pair<PrvGroup, unsigned> front = q.front();
ParfactorList::iterator pflIt = pfList_.begin(); ParfactorList::iterator pflIt = pfList_.begin();
while (pflIt != pfList_.end()) { while (pflIt != pfList_.end()) {
size_t idx = (*pflIt)->indexOfGroup (front.first); size_t idx = (*pflIt)->indexOfGroup (front.first);
@ -611,7 +611,7 @@ GroundOperator::getAffectedFormulas (void)
const ProbFormulas& fs = (*pflIt)->arguments(); const ProbFormulas& fs = (*pflIt)->arguments();
for (size_t i = 0; i < fs.size(); i++) { for (size_t i = 0; i < fs.size(); i++) {
if ((int)i != idx && fs[i].contains (X)) { if ((int)i != idx && fs[i].contains (X)) {
pair<unsigned, unsigned> pair = make_pair ( pair<PrvGroup, unsigned> pair = make_pair (
fs[i].group(), fs[i].indexOf (X)); fs[i].group(), fs[i].indexOf (X));
if (Util::contains (affectedFormulas, pair) == false) { if (Util::contains (affectedFormulas, pair) == false) {
q.push (pair); q.push (pair);
@ -830,13 +830,13 @@ FoveSolver::getBestOperation (const Grounds& query)
void void
FoveSolver::runWeakBayesBall (const Grounds& query) FoveSolver::runWeakBayesBall (const Grounds& query)
{ {
queue<unsigned> todo; // groups to process queue<PrvGroup> todo; // groups to process
set<unsigned> done; // processed or in queue set<PrvGroup> done; // processed or in queue
for (size_t i = 0; i < query.size(); i++) { for (size_t i = 0; i < query.size(); i++) {
ParfactorList::iterator it = pfList_.begin(); ParfactorList::iterator it = pfList_.begin();
while (it != pfList_.end()) { while (it != pfList_.end()) {
int group = (*it)->findGroup (query[i]); PrvGroup group = (*it)->findGroup (query[i]);
if (group != -1) { if (group != numeric_limits<PrvGroup>::max()) {
todo.push (group); todo.push (group);
done.insert (group); done.insert (group);
break; break;
@ -847,12 +847,12 @@ FoveSolver::runWeakBayesBall (const Grounds& query)
set<Parfactor*> requiredPfs; set<Parfactor*> requiredPfs;
while (todo.empty() == false) { while (todo.empty() == false) {
unsigned group = todo.front(); PrvGroup group = todo.front();
ParfactorList::iterator it = pfList_.begin(); ParfactorList::iterator it = pfList_.begin();
while (it != pfList_.end()) { while (it != pfList_.end()) {
if (Util::contains (requiredPfs, *it) == false && if (Util::contains (requiredPfs, *it) == false &&
(*it)->containsGroup (group)) { (*it)->containsGroup (group)) {
vector<unsigned> groups = (*it)->getAllGroups(); vector<PrvGroup> groups = (*it)->getAllGroups();
for (size_t i = 0; i < groups.size(); i++) { for (size_t i = 0; i < groups.size(); i++) {
if (Util::contains (done, groups[i]) == false) { if (Util::contains (done, groups[i]) == false) {
todo.push (groups[i]); todo.push (groups[i]);

View File

@ -20,7 +20,7 @@ class LiftedOperator
static void printValidOps (ParfactorList&, const Grounds&); static void printValidOps (ParfactorList&, const Grounds&);
static vector<ParfactorList::iterator> getParfactorsWithGroup ( static vector<ParfactorList::iterator> getParfactorsWithGroup (
ParfactorList&, unsigned group); ParfactorList&, PrvGroup group);
}; };
@ -53,7 +53,7 @@ class ProductOperator : public LiftedOperator
class SumOutOperator : public LiftedOperator class SumOutOperator : public LiftedOperator
{ {
public: public:
SumOutOperator (unsigned group, ParfactorList& pfList) SumOutOperator (PrvGroup group, ParfactorList& pfList)
: group_(group), pfList_(pfList) { } : group_(group), pfList_(pfList) { }
double getLogCost (void); double getLogCost (void);
@ -66,11 +66,11 @@ class SumOutOperator : public LiftedOperator
string toString (void); string toString (void);
private: private:
static bool validOp (unsigned, ParfactorList&, const Grounds&); static bool validOp (PrvGroup, ParfactorList&, const Grounds&);
static bool isToEliminate (Parfactor*, unsigned, const Grounds&); static bool isToEliminate (Parfactor*, PrvGroup, const Grounds&);
unsigned group_; PrvGroup group_;
ParfactorList& pfList_; ParfactorList& pfList_;
}; };
@ -107,7 +107,7 @@ class GroundOperator : public LiftedOperator
{ {
public: public:
GroundOperator ( GroundOperator (
unsigned group, PrvGroup group,
unsigned lvIndex, unsigned lvIndex,
ParfactorList& pfList) ParfactorList& pfList)
: group_(group), lvIndex_(lvIndex), pfList_(pfList) { } : group_(group), lvIndex_(lvIndex), pfList_(pfList) { }
@ -121,9 +121,9 @@ class GroundOperator : public LiftedOperator
string toString (void); string toString (void);
private: private:
vector<pair<unsigned, unsigned>> getAffectedFormulas (void); vector<pair<PrvGroup, unsigned>> getAffectedFormulas (void);
unsigned group_; PrvGroup group_;
unsigned lvIndex_; unsigned lvIndex_;
ParfactorList& pfList_; ParfactorList& pfList_;
}; };

View File

@ -428,10 +428,10 @@ Parfactor::applySubstitution (const Substitution& theta)
int PrvGroup
Parfactor::findGroup (const Ground& ground) const Parfactor::findGroup (const Ground& ground) const
{ {
int group = -1; PrvGroup group = numeric_limits<PrvGroup>::max();
for (size_t i = 0; i < args_.size(); i++) { for (size_t i = 0; i < args_.size(); i++) {
if (args_[i].functor() == ground.functor() && if (args_[i].functor() == ground.functor() &&
args_[i].arity() == ground.arity()) { args_[i].arity() == ground.arity()) {
@ -450,13 +450,13 @@ Parfactor::findGroup (const Ground& ground) const
bool bool
Parfactor::containsGround (const Ground& ground) const Parfactor::containsGround (const Ground& ground) const
{ {
return findGroup (ground) != -1; return findGroup (ground) != numeric_limits<PrvGroup>::max();
} }
bool bool
Parfactor::containsGroup (unsigned group) const Parfactor::containsGroup (PrvGroup group) const
{ {
for (size_t i = 0; i < args_.size(); i++) { for (size_t i = 0; i < args_.size(); i++) {
if (args_[i].group() == group) { if (args_[i].group() == group) {
@ -499,7 +499,7 @@ Parfactor::indexOfLogVar (LogVar X) const
int int
Parfactor::indexOfGroup (unsigned group) const Parfactor::indexOfGroup (PrvGroup group) const
{ {
size_t pos = args_.size(); size_t pos = args_.size();
for (size_t i = 0; i < args_.size(); i++) { for (size_t i = 0; i < args_.size(); i++) {
@ -514,7 +514,7 @@ Parfactor::indexOfGroup (unsigned group) const
unsigned unsigned
Parfactor::nrFormulasWithGroup (unsigned group) const Parfactor::nrFormulasWithGroup (PrvGroup group) const
{ {
unsigned count = 0; unsigned count = 0;
for (size_t i = 0; i < args_.size(); i++) { for (size_t i = 0; i < args_.size(); i++) {
@ -527,10 +527,10 @@ Parfactor::nrFormulasWithGroup (unsigned group) const
vector<unsigned> vector<PrvGroup>
Parfactor::getAllGroups (void) const Parfactor::getAllGroups (void) const
{ {
vector<unsigned> groups (args_.size()); vector<PrvGroup> groups (args_.size());
for (size_t i = 0; i < args_.size(); i++) { for (size_t i = 0; i < args_.size(); i++) {
groups[i] = args_[i].group(); groups[i] = args_[i].group();
} }
@ -726,8 +726,11 @@ Parfactor::simplifyCountingFormulas (size_t fIdx)
void void
Parfactor::simplifyGrounds (void) Parfactor::simplifyGrounds (void)
{ {
if (args_.size() == 1) {
return;
}
LogVarSet singletons = constr_->singletons(); LogVarSet singletons = constr_->singletons();
for (long i = 0; i < (int)args_.size() - 1; i++) { for (long i = 0; i < (long)args_.size() - 1; i++) {
for (size_t j = i + 1; j < args_.size(); j++) { for (size_t j = i + 1; j < args_.size(); j++) {
if (args_[i].group() == args_[j].group() && if (args_[i].group() == args_[j].group() &&
singletons.contains (args_[i].logVarSet()) && singletons.contains (args_[i].logVarSet()) &&

View File

@ -18,7 +18,7 @@ class Parfactor : public TFactor<ProbFormula>
const ProbFormulas&, const ProbFormulas&,
const Params&, const Params&,
const Tuples&, const Tuples&,
unsigned); unsigned distId);
Parfactor (const Parfactor*, const Tuple&); Parfactor (const Parfactor*, const Tuple&);
@ -66,21 +66,21 @@ class Parfactor : public TFactor<ProbFormula>
void applySubstitution (const Substitution&); void applySubstitution (const Substitution&);
int findGroup (const Ground&) const; PrvGroup findGroup (const Ground&) const;
bool containsGround (const Ground&) const; bool containsGround (const Ground&) const;
bool containsGroup (unsigned) const; bool containsGroup (PrvGroup) const;
unsigned nrFormulas (LogVar) const; unsigned nrFormulas (LogVar) const;
int indexOfLogVar (LogVar) const; int indexOfLogVar (LogVar) const;
int indexOfGroup (unsigned) const; int indexOfGroup (PrvGroup) const;
unsigned nrFormulasWithGroup (unsigned) const; unsigned nrFormulasWithGroup (PrvGroup) const;
vector<unsigned> getAllGroups (void) const; vector<PrvGroup> getAllGroups (void) const;
void print (bool = false) const; void print (bool = false) const;

View File

@ -340,7 +340,7 @@ ParfactorList::shatterAgainstMySelf (
return { }; return { };
} }
unsigned newGroup = ProbFormula::getNewGroup(); PrvGroup newGroup = ProbFormula::getNewGroup();
Parfactors res1 = shatter (g, fIdx1, commCt1, exclCt1, newGroup); Parfactors res1 = shatter (g, fIdx1, commCt1, exclCt1, newGroup);
if (res1.empty()) { if (res1.empty()) {
res1.push_back (g); res1.push_back (g);
@ -488,7 +488,7 @@ ParfactorList::shatter (
return { }; return { };
} }
unsigned group; PrvGroup group;
if (exclCt1->empty()) { if (exclCt1->empty()) {
group = f1.group(); group = f1.group();
} else if (exclCt2->empty()) { } else if (exclCt2->empty()) {
@ -509,7 +509,7 @@ ParfactorList::shatter (
size_t fIdx, size_t fIdx,
ConstraintTree* commCt, ConstraintTree* commCt,
ConstraintTree* exclCt, ConstraintTree* exclCt,
unsigned commGroup) PrvGroup commGroup)
{ {
ProbFormula& f = g->argument (fIdx); ProbFormula& f = g->argument (fIdx);
if (exclCt->empty()) { if (exclCt->empty()) {
@ -556,7 +556,7 @@ ParfactorList::shatter (
void void
ParfactorList::updateGroups (unsigned oldGroup, unsigned newGroup) ParfactorList::updateGroups (PrvGroup oldGroup, PrvGroup newGroup)
{ {
for (ParfactorList::iterator it = pfList_.begin(); for (ParfactorList::iterator it = pfList_.begin();
it != pfList_.end(); it++) { it != pfList_.end(); it++) {

View File

@ -80,12 +80,12 @@ class ParfactorList
Parfactors shatter ( Parfactors shatter (
Parfactor*, Parfactor*,
unsigned, size_t,
ConstraintTree*, ConstraintTree*,
ConstraintTree*, ConstraintTree*,
unsigned); PrvGroup);
void updateGroups (unsigned group1, unsigned group2); void updateGroups (PrvGroup group1, PrvGroup group2);
bool proper ( bool proper (
const ProbFormula&, ConstraintTree, const ProbFormula&, ConstraintTree,

View File

@ -1,7 +1,7 @@
#include "ProbFormula.h" #include "ProbFormula.h"
int ProbFormula::freeGroup_ = 0; PrvGroup ProbFormula::freeGroup_ = 0;
@ -120,10 +120,11 @@ std::ostream& operator<< (ostream &os, const ProbFormula& f)
unsigned PrvGroup
ProbFormula::getNewGroup (void) ProbFormula::getNewGroup (void)
{ {
freeGroup_ ++; freeGroup_ ++;
assert (freeGroup_ != numeric_limits<PrvGroup>::max());
return freeGroup_; return freeGroup_;
} }

View File

@ -7,17 +7,17 @@
#include "LiftedUtils.h" #include "LiftedUtils.h"
#include "Horus.h" #include "Horus.h"
typedef unsigned PrvGroup; typedef unsigned long PrvGroup;
class ProbFormula class ProbFormula
{ {
public: public:
ProbFormula (Symbol f, const LogVars& lvs, unsigned range) ProbFormula (Symbol f, const LogVars& lvs, unsigned range)
: functor_(f), logVars_(lvs), range_(range), : functor_(f), logVars_(lvs), range_(range),
countedLogVar_(), group_(Util::maxUnsigned()) { } countedLogVar_(), group_(numeric_limits<PrvGroup>::max()) { }
ProbFormula (Symbol f, unsigned r) ProbFormula (Symbol f, unsigned r)
: functor_(f), range_(r), group_(Util::maxUnsigned()) { } : functor_(f), range_(r), group_(numeric_limits<PrvGroup>::max()) { }
Symbol functor (void) const { return functor_; } Symbol functor (void) const { return functor_; }
@ -31,9 +31,9 @@ class ProbFormula
LogVarSet logVarSet (void) const { return LogVarSet (logVars_); } LogVarSet logVarSet (void) const { return LogVarSet (logVars_); }
unsigned group (void) const { return group_; } PrvGroup group (void) const { return group_; }
void setGroup (unsigned g) { group_ = g; } void setGroup (PrvGroup g) { group_ = g; }
bool sameSkeletonAs (const ProbFormula&) const; bool sameSkeletonAs (const ProbFormula&) const;
@ -55,7 +55,7 @@ class ProbFormula
void rename (LogVar, LogVar); void rename (LogVar, LogVar);
static unsigned getNewGroup (void); static PrvGroup getNewGroup (void);
friend std::ostream& operator<< (ostream &os, const ProbFormula& f); friend std::ostream& operator<< (ostream &os, const ProbFormula& f);
@ -66,8 +66,8 @@ class ProbFormula
LogVars logVars_; LogVars logVars_;
unsigned range_; unsigned range_;
LogVar countedLogVar_; LogVar countedLogVar_;
unsigned group_; PrvGroup group_;
static int freeGroup_; static PrvGroup freeGroup_;
}; };
typedef vector<ProbFormula> ProbFormulas; typedef vector<ProbFormula> ProbFormulas;