use PrvGroup

This commit is contained in:
Tiago Gomes 2012-05-24 23:38:44 +01:00
parent af6601837c
commit 18d4777d9a
8 changed files with 69 additions and 65 deletions

View File

@ -52,7 +52,7 @@ LiftedOperator::printValidOps (
vector<ParfactorList::iterator>
LiftedOperator::getParfactorsWithGroup (
ParfactorList& pfList, unsigned group)
ParfactorList& pfList, PrvGroup group)
{
vector<ParfactorList::iterator> iters;
ParfactorList::iterator pflIt = pfList.begin();
@ -144,10 +144,10 @@ ProductOperator::toString (void)
bool
ProductOperator::validOp (Parfactor* g1, Parfactor* g2)
{
TinySet<unsigned> g1_gs (g1->getAllGroups());
TinySet<unsigned> g2_gs (g2->getAllGroups());
TinySet<PrvGroup> g1_gs (g1->getAllGroups());
TinySet<PrvGroup> g2_gs (g2->getAllGroups());
if (g1_gs.contains (g2_gs) || g2_gs.contains (g1_gs)) {
TinySet<unsigned> intersect = g1_gs & g2_gs;
TinySet<PrvGroup> intersect = g1_gs & g2_gs;
for (size_t i = 0; i < intersect.size(); i++) {
if (g1->nrFormulasWithGroup (intersect[i]) != 1 ||
g2->nrFormulasWithGroup (intersect[i]) != 1) {
@ -169,13 +169,13 @@ ProductOperator::validOp (Parfactor* g1, Parfactor* g2)
double
SumOutOperator::getLogCost (void)
{
TinySet<unsigned> groupSet;
TinySet<PrvGroup> groupSet;
ParfactorList::const_iterator pfIter = pfList_.begin();
unsigned nrProdFactors = 0;
while (pfIter != pfList_.end()) {
if ((*pfIter)->containsGroup (group_)) {
vector<unsigned> groups = (*pfIter)->getAllGroups();
groupSet |= TinySet<unsigned> (groups);
vector<PrvGroup> groups = (*pfIter)->getAllGroups();
groupSet |= TinySet<PrvGroup> (groups);
++ nrProdFactors;
}
++ pfIter;
@ -239,7 +239,7 @@ SumOutOperator::getValidOps (
const Grounds& query)
{
vector<SumOutOperator*> validOps;
set<unsigned> allGroups;
set<PrvGroup> allGroups;
ParfactorList::const_iterator it = pfList.begin();
while (it != pfList.end()) {
const ProbFormulas& formulas = (*it)->arguments();
@ -248,7 +248,7 @@ SumOutOperator::getValidOps (
}
++ it;
}
set<unsigned>::const_iterator groupIt = allGroups.begin();
set<PrvGroup>::const_iterator groupIt = allGroups.begin();
while (groupIt != allGroups.end()) {
if (validOp (*groupIt, pfList, query)) {
validOps.push_back (new SumOutOperator (*groupIt, pfList));
@ -279,7 +279,7 @@ SumOutOperator::toString (void)
bool
SumOutOperator::validOp (
unsigned group,
PrvGroup group,
ParfactorList& pfList,
const Grounds& query)
{
@ -312,7 +312,7 @@ SumOutOperator::validOp (
bool
SumOutOperator::isToEliminate (
Parfactor* g,
unsigned group,
PrvGroup group,
const Grounds& query)
{
size_t fIdx = g->indexOfGroup (group);
@ -345,7 +345,7 @@ CountingOperator::getLogCost (void)
for (size_t i = 0; i < counts.size(); i++) {
cost += size * HistogramSet::nrHistograms (counts[i], range);
}
unsigned group = (*pfIter_)->argument (fIdx).group();
PrvGroup group = (*pfIter_)->argument (fIdx).group();
size_t lvIndex = Util::indexOf (
(*pfIter_)->argument (fIdx).logVars(), X_);
assert (lvIndex != (*pfIter_)->argument (fIdx).logVars().size());
@ -456,7 +456,7 @@ CountingOperator::validOp (Parfactor* g, LogVar X)
double
GroundOperator::getLogCost (void)
{
vector<pair<unsigned, unsigned>> affectedFormulas;
vector<pair<PrvGroup, unsigned>> affectedFormulas;
affectedFormulas = getAffectedFormulas();
// cout << "affected formulas: " ;
// for (size_t i = 0; i < affectedFormulas.size(); i++) {
@ -542,7 +542,7 @@ vector<GroundOperator*>
GroundOperator::getValidOps (ParfactorList& pfList)
{
vector<GroundOperator*> validOps;
set<unsigned> allGroups;
set<PrvGroup> allGroups;
ParfactorList::const_iterator it = pfList.begin();
while (it != pfList.end()) {
const ProbFormulas& formulas = (*it)->arguments();
@ -593,15 +593,15 @@ GroundOperator::toString (void)
vector<pair<unsigned, unsigned>>
vector<pair<PrvGroup, unsigned>>
GroundOperator::getAffectedFormulas (void)
{
vector<pair<unsigned, unsigned>> affectedFormulas;
vector<pair<PrvGroup, unsigned>> affectedFormulas;
affectedFormulas.push_back (make_pair (group_, lvIndex_));
queue<pair<unsigned, unsigned>> q;
queue<pair<PrvGroup, unsigned>> q;
q.push (make_pair (group_, lvIndex_));
while (q.empty() == false) {
pair<unsigned, unsigned> front = q.front();
pair<PrvGroup, unsigned> front = q.front();
ParfactorList::iterator pflIt = pfList_.begin();
while (pflIt != pfList_.end()) {
size_t idx = (*pflIt)->indexOfGroup (front.first);
@ -611,7 +611,7 @@ GroundOperator::getAffectedFormulas (void)
const ProbFormulas& fs = (*pflIt)->arguments();
for (size_t i = 0; i < fs.size(); i++) {
if ((int)i != idx && fs[i].contains (X)) {
pair<unsigned, unsigned> pair = make_pair (
pair<PrvGroup, unsigned> pair = make_pair (
fs[i].group(), fs[i].indexOf (X));
if (Util::contains (affectedFormulas, pair) == false) {
q.push (pair);
@ -830,13 +830,13 @@ FoveSolver::getBestOperation (const Grounds& query)
void
FoveSolver::runWeakBayesBall (const Grounds& query)
{
queue<unsigned> todo; // groups to process
set<unsigned> done; // processed or in queue
queue<PrvGroup> todo; // groups to process
set<PrvGroup> done; // processed or in queue
for (size_t i = 0; i < query.size(); i++) {
ParfactorList::iterator it = pfList_.begin();
while (it != pfList_.end()) {
int group = (*it)->findGroup (query[i]);
if (group != -1) {
PrvGroup group = (*it)->findGroup (query[i]);
if (group != numeric_limits<PrvGroup>::max()) {
todo.push (group);
done.insert (group);
break;
@ -847,12 +847,12 @@ FoveSolver::runWeakBayesBall (const Grounds& query)
set<Parfactor*> requiredPfs;
while (todo.empty() == false) {
unsigned group = todo.front();
PrvGroup group = todo.front();
ParfactorList::iterator it = pfList_.begin();
while (it != pfList_.end()) {
if (Util::contains (requiredPfs, *it) == false &&
(*it)->containsGroup (group)) {
vector<unsigned> groups = (*it)->getAllGroups();
vector<PrvGroup> groups = (*it)->getAllGroups();
for (size_t i = 0; i < groups.size(); i++) {
if (Util::contains (done, groups[i]) == false) {
todo.push (groups[i]);

View File

@ -20,7 +20,7 @@ class LiftedOperator
static void printValidOps (ParfactorList&, const Grounds&);
static vector<ParfactorList::iterator> getParfactorsWithGroup (
ParfactorList&, unsigned group);
ParfactorList&, PrvGroup group);
};
@ -53,7 +53,7 @@ class ProductOperator : public LiftedOperator
class SumOutOperator : public LiftedOperator
{
public:
SumOutOperator (unsigned group, ParfactorList& pfList)
SumOutOperator (PrvGroup group, ParfactorList& pfList)
: group_(group), pfList_(pfList) { }
double getLogCost (void);
@ -66,11 +66,11 @@ class SumOutOperator : public LiftedOperator
string toString (void);
private:
static bool validOp (unsigned, ParfactorList&, const Grounds&);
static bool validOp (PrvGroup, ParfactorList&, const Grounds&);
static bool isToEliminate (Parfactor*, unsigned, const Grounds&);
static bool isToEliminate (Parfactor*, PrvGroup, const Grounds&);
unsigned group_;
PrvGroup group_;
ParfactorList& pfList_;
};
@ -107,7 +107,7 @@ class GroundOperator : public LiftedOperator
{
public:
GroundOperator (
unsigned group,
PrvGroup group,
unsigned lvIndex,
ParfactorList& pfList)
: group_(group), lvIndex_(lvIndex), pfList_(pfList) { }
@ -121,9 +121,9 @@ class GroundOperator : public LiftedOperator
string toString (void);
private:
vector<pair<unsigned, unsigned>> getAffectedFormulas (void);
vector<pair<PrvGroup, unsigned>> getAffectedFormulas (void);
unsigned group_;
PrvGroup group_;
unsigned lvIndex_;
ParfactorList& pfList_;
};

View File

@ -428,10 +428,10 @@ Parfactor::applySubstitution (const Substitution& theta)
int
PrvGroup
Parfactor::findGroup (const Ground& ground) const
{
int group = -1;
PrvGroup group = numeric_limits<PrvGroup>::max();
for (size_t i = 0; i < args_.size(); i++) {
if (args_[i].functor() == ground.functor() &&
args_[i].arity() == ground.arity()) {
@ -450,13 +450,13 @@ Parfactor::findGroup (const Ground& ground) const
bool
Parfactor::containsGround (const Ground& ground) const
{
return findGroup (ground) != -1;
return findGroup (ground) != numeric_limits<PrvGroup>::max();
}
bool
Parfactor::containsGroup (unsigned group) const
Parfactor::containsGroup (PrvGroup group) const
{
for (size_t i = 0; i < args_.size(); i++) {
if (args_[i].group() == group) {
@ -499,7 +499,7 @@ Parfactor::indexOfLogVar (LogVar X) const
int
Parfactor::indexOfGroup (unsigned group) const
Parfactor::indexOfGroup (PrvGroup group) const
{
size_t pos = args_.size();
for (size_t i = 0; i < args_.size(); i++) {
@ -514,7 +514,7 @@ Parfactor::indexOfGroup (unsigned group) const
unsigned
Parfactor::nrFormulasWithGroup (unsigned group) const
Parfactor::nrFormulasWithGroup (PrvGroup group) const
{
unsigned count = 0;
for (size_t i = 0; i < args_.size(); i++) {
@ -527,10 +527,10 @@ Parfactor::nrFormulasWithGroup (unsigned group) const
vector<unsigned>
vector<PrvGroup>
Parfactor::getAllGroups (void) const
{
vector<unsigned> groups (args_.size());
vector<PrvGroup> groups (args_.size());
for (size_t i = 0; i < args_.size(); i++) {
groups[i] = args_[i].group();
}
@ -726,8 +726,11 @@ Parfactor::simplifyCountingFormulas (size_t fIdx)
void
Parfactor::simplifyGrounds (void)
{
if (args_.size() == 1) {
return;
}
LogVarSet singletons = constr_->singletons();
for (long i = 0; i < (int)args_.size() - 1; i++) {
for (long i = 0; i < (long)args_.size() - 1; i++) {
for (size_t j = i + 1; j < args_.size(); j++) {
if (args_[i].group() == args_[j].group() &&
singletons.contains (args_[i].logVarSet()) &&

View File

@ -18,7 +18,7 @@ class Parfactor : public TFactor<ProbFormula>
const ProbFormulas&,
const Params&,
const Tuples&,
unsigned);
unsigned distId);
Parfactor (const Parfactor*, const Tuple&);
@ -66,21 +66,21 @@ class Parfactor : public TFactor<ProbFormula>
void applySubstitution (const Substitution&);
int findGroup (const Ground&) const;
PrvGroup findGroup (const Ground&) const;
bool containsGround (const Ground&) const;
bool containsGroup (unsigned) const;
bool containsGroup (PrvGroup) const;
unsigned nrFormulas (LogVar) const;
int indexOfLogVar (LogVar) const;
int indexOfGroup (unsigned) const;
int indexOfGroup (PrvGroup) const;
unsigned nrFormulasWithGroup (unsigned) const;
unsigned nrFormulasWithGroup (PrvGroup) const;
vector<unsigned> getAllGroups (void) const;
vector<PrvGroup> getAllGroups (void) const;
void print (bool = false) const;

View File

@ -340,7 +340,7 @@ ParfactorList::shatterAgainstMySelf (
return { };
}
unsigned newGroup = ProbFormula::getNewGroup();
PrvGroup newGroup = ProbFormula::getNewGroup();
Parfactors res1 = shatter (g, fIdx1, commCt1, exclCt1, newGroup);
if (res1.empty()) {
res1.push_back (g);
@ -488,7 +488,7 @@ ParfactorList::shatter (
return { };
}
unsigned group;
PrvGroup group;
if (exclCt1->empty()) {
group = f1.group();
} else if (exclCt2->empty()) {
@ -509,7 +509,7 @@ ParfactorList::shatter (
size_t fIdx,
ConstraintTree* commCt,
ConstraintTree* exclCt,
unsigned commGroup)
PrvGroup commGroup)
{
ProbFormula& f = g->argument (fIdx);
if (exclCt->empty()) {
@ -556,7 +556,7 @@ ParfactorList::shatter (
void
ParfactorList::updateGroups (unsigned oldGroup, unsigned newGroup)
ParfactorList::updateGroups (PrvGroup oldGroup, PrvGroup newGroup)
{
for (ParfactorList::iterator it = pfList_.begin();
it != pfList_.end(); it++) {

View File

@ -80,12 +80,12 @@ class ParfactorList
Parfactors shatter (
Parfactor*,
unsigned,
size_t,
ConstraintTree*,
ConstraintTree*,
unsigned);
PrvGroup);
void updateGroups (unsigned group1, unsigned group2);
void updateGroups (PrvGroup group1, PrvGroup group2);
bool proper (
const ProbFormula&, ConstraintTree,

View File

@ -1,7 +1,7 @@
#include "ProbFormula.h"
int ProbFormula::freeGroup_ = 0;
PrvGroup ProbFormula::freeGroup_ = 0;
@ -120,10 +120,11 @@ std::ostream& operator<< (ostream &os, const ProbFormula& f)
unsigned
PrvGroup
ProbFormula::getNewGroup (void)
{
freeGroup_ ++;
assert (freeGroup_ != numeric_limits<PrvGroup>::max());
return freeGroup_;
}

View File

@ -7,17 +7,17 @@
#include "LiftedUtils.h"
#include "Horus.h"
typedef unsigned PrvGroup;
typedef unsigned long PrvGroup;
class ProbFormula
{
public:
ProbFormula (Symbol f, const LogVars& lvs, unsigned range)
: functor_(f), logVars_(lvs), range_(range),
countedLogVar_(), group_(Util::maxUnsigned()) { }
countedLogVar_(), group_(numeric_limits<PrvGroup>::max()) { }
ProbFormula (Symbol f, unsigned r)
: functor_(f), range_(r), group_(Util::maxUnsigned()) { }
: functor_(f), range_(r), group_(numeric_limits<PrvGroup>::max()) { }
Symbol functor (void) const { return functor_; }
@ -31,9 +31,9 @@ class ProbFormula
LogVarSet logVarSet (void) const { return LogVarSet (logVars_); }
unsigned group (void) const { return group_; }
PrvGroup group (void) const { return group_; }
void setGroup (unsigned g) { group_ = g; }
void setGroup (PrvGroup g) { group_ = g; }
bool sameSkeletonAs (const ProbFormula&) const;
@ -55,7 +55,7 @@ class ProbFormula
void rename (LogVar, LogVar);
static unsigned getNewGroup (void);
static PrvGroup getNewGroup (void);
friend std::ostream& operator<< (ostream &os, const ProbFormula& f);
@ -66,8 +66,8 @@ class ProbFormula
LogVars logVars_;
unsigned range_;
LogVar countedLogVar_;
unsigned group_;
static int freeGroup_;
PrvGroup group_;
static PrvGroup freeGroup_;
};
typedef vector<ProbFormula> ProbFormulas;