use PrvGroup
This commit is contained in:
@@ -52,7 +52,7 @@ LiftedOperator::printValidOps (
|
||||
|
||||
vector<ParfactorList::iterator>
|
||||
LiftedOperator::getParfactorsWithGroup (
|
||||
ParfactorList& pfList, unsigned group)
|
||||
ParfactorList& pfList, PrvGroup group)
|
||||
{
|
||||
vector<ParfactorList::iterator> iters;
|
||||
ParfactorList::iterator pflIt = pfList.begin();
|
||||
@@ -144,10 +144,10 @@ ProductOperator::toString (void)
|
||||
bool
|
||||
ProductOperator::validOp (Parfactor* g1, Parfactor* g2)
|
||||
{
|
||||
TinySet<unsigned> g1_gs (g1->getAllGroups());
|
||||
TinySet<unsigned> g2_gs (g2->getAllGroups());
|
||||
TinySet<PrvGroup> g1_gs (g1->getAllGroups());
|
||||
TinySet<PrvGroup> g2_gs (g2->getAllGroups());
|
||||
if (g1_gs.contains (g2_gs) || g2_gs.contains (g1_gs)) {
|
||||
TinySet<unsigned> intersect = g1_gs & g2_gs;
|
||||
TinySet<PrvGroup> intersect = g1_gs & g2_gs;
|
||||
for (size_t i = 0; i < intersect.size(); i++) {
|
||||
if (g1->nrFormulasWithGroup (intersect[i]) != 1 ||
|
||||
g2->nrFormulasWithGroup (intersect[i]) != 1) {
|
||||
@@ -169,13 +169,13 @@ ProductOperator::validOp (Parfactor* g1, Parfactor* g2)
|
||||
double
|
||||
SumOutOperator::getLogCost (void)
|
||||
{
|
||||
TinySet<unsigned> groupSet;
|
||||
TinySet<PrvGroup> groupSet;
|
||||
ParfactorList::const_iterator pfIter = pfList_.begin();
|
||||
unsigned nrProdFactors = 0;
|
||||
while (pfIter != pfList_.end()) {
|
||||
if ((*pfIter)->containsGroup (group_)) {
|
||||
vector<unsigned> groups = (*pfIter)->getAllGroups();
|
||||
groupSet |= TinySet<unsigned> (groups);
|
||||
vector<PrvGroup> groups = (*pfIter)->getAllGroups();
|
||||
groupSet |= TinySet<PrvGroup> (groups);
|
||||
++ nrProdFactors;
|
||||
}
|
||||
++ pfIter;
|
||||
@@ -239,7 +239,7 @@ SumOutOperator::getValidOps (
|
||||
const Grounds& query)
|
||||
{
|
||||
vector<SumOutOperator*> validOps;
|
||||
set<unsigned> allGroups;
|
||||
set<PrvGroup> allGroups;
|
||||
ParfactorList::const_iterator it = pfList.begin();
|
||||
while (it != pfList.end()) {
|
||||
const ProbFormulas& formulas = (*it)->arguments();
|
||||
@@ -248,7 +248,7 @@ SumOutOperator::getValidOps (
|
||||
}
|
||||
++ it;
|
||||
}
|
||||
set<unsigned>::const_iterator groupIt = allGroups.begin();
|
||||
set<PrvGroup>::const_iterator groupIt = allGroups.begin();
|
||||
while (groupIt != allGroups.end()) {
|
||||
if (validOp (*groupIt, pfList, query)) {
|
||||
validOps.push_back (new SumOutOperator (*groupIt, pfList));
|
||||
@@ -279,7 +279,7 @@ SumOutOperator::toString (void)
|
||||
|
||||
bool
|
||||
SumOutOperator::validOp (
|
||||
unsigned group,
|
||||
PrvGroup group,
|
||||
ParfactorList& pfList,
|
||||
const Grounds& query)
|
||||
{
|
||||
@@ -312,7 +312,7 @@ SumOutOperator::validOp (
|
||||
bool
|
||||
SumOutOperator::isToEliminate (
|
||||
Parfactor* g,
|
||||
unsigned group,
|
||||
PrvGroup group,
|
||||
const Grounds& query)
|
||||
{
|
||||
size_t fIdx = g->indexOfGroup (group);
|
||||
@@ -345,7 +345,7 @@ CountingOperator::getLogCost (void)
|
||||
for (size_t i = 0; i < counts.size(); i++) {
|
||||
cost += size * HistogramSet::nrHistograms (counts[i], range);
|
||||
}
|
||||
unsigned group = (*pfIter_)->argument (fIdx).group();
|
||||
PrvGroup group = (*pfIter_)->argument (fIdx).group();
|
||||
size_t lvIndex = Util::indexOf (
|
||||
(*pfIter_)->argument (fIdx).logVars(), X_);
|
||||
assert (lvIndex != (*pfIter_)->argument (fIdx).logVars().size());
|
||||
@@ -456,7 +456,7 @@ CountingOperator::validOp (Parfactor* g, LogVar X)
|
||||
double
|
||||
GroundOperator::getLogCost (void)
|
||||
{
|
||||
vector<pair<unsigned, unsigned>> affectedFormulas;
|
||||
vector<pair<PrvGroup, unsigned>> affectedFormulas;
|
||||
affectedFormulas = getAffectedFormulas();
|
||||
// cout << "affected formulas: " ;
|
||||
// for (size_t i = 0; i < affectedFormulas.size(); i++) {
|
||||
@@ -542,7 +542,7 @@ vector<GroundOperator*>
|
||||
GroundOperator::getValidOps (ParfactorList& pfList)
|
||||
{
|
||||
vector<GroundOperator*> validOps;
|
||||
set<unsigned> allGroups;
|
||||
set<PrvGroup> allGroups;
|
||||
ParfactorList::const_iterator it = pfList.begin();
|
||||
while (it != pfList.end()) {
|
||||
const ProbFormulas& formulas = (*it)->arguments();
|
||||
@@ -593,15 +593,15 @@ GroundOperator::toString (void)
|
||||
|
||||
|
||||
|
||||
vector<pair<unsigned, unsigned>>
|
||||
vector<pair<PrvGroup, unsigned>>
|
||||
GroundOperator::getAffectedFormulas (void)
|
||||
{
|
||||
vector<pair<unsigned, unsigned>> affectedFormulas;
|
||||
vector<pair<PrvGroup, unsigned>> affectedFormulas;
|
||||
affectedFormulas.push_back (make_pair (group_, lvIndex_));
|
||||
queue<pair<unsigned, unsigned>> q;
|
||||
queue<pair<PrvGroup, unsigned>> q;
|
||||
q.push (make_pair (group_, lvIndex_));
|
||||
while (q.empty() == false) {
|
||||
pair<unsigned, unsigned> front = q.front();
|
||||
pair<PrvGroup, unsigned> front = q.front();
|
||||
ParfactorList::iterator pflIt = pfList_.begin();
|
||||
while (pflIt != pfList_.end()) {
|
||||
size_t idx = (*pflIt)->indexOfGroup (front.first);
|
||||
@@ -611,7 +611,7 @@ GroundOperator::getAffectedFormulas (void)
|
||||
const ProbFormulas& fs = (*pflIt)->arguments();
|
||||
for (size_t i = 0; i < fs.size(); i++) {
|
||||
if ((int)i != idx && fs[i].contains (X)) {
|
||||
pair<unsigned, unsigned> pair = make_pair (
|
||||
pair<PrvGroup, unsigned> pair = make_pair (
|
||||
fs[i].group(), fs[i].indexOf (X));
|
||||
if (Util::contains (affectedFormulas, pair) == false) {
|
||||
q.push (pair);
|
||||
@@ -830,13 +830,13 @@ FoveSolver::getBestOperation (const Grounds& query)
|
||||
void
|
||||
FoveSolver::runWeakBayesBall (const Grounds& query)
|
||||
{
|
||||
queue<unsigned> todo; // groups to process
|
||||
set<unsigned> done; // processed or in queue
|
||||
queue<PrvGroup> todo; // groups to process
|
||||
set<PrvGroup> done; // processed or in queue
|
||||
for (size_t i = 0; i < query.size(); i++) {
|
||||
ParfactorList::iterator it = pfList_.begin();
|
||||
while (it != pfList_.end()) {
|
||||
int group = (*it)->findGroup (query[i]);
|
||||
if (group != -1) {
|
||||
PrvGroup group = (*it)->findGroup (query[i]);
|
||||
if (group != numeric_limits<PrvGroup>::max()) {
|
||||
todo.push (group);
|
||||
done.insert (group);
|
||||
break;
|
||||
@@ -847,12 +847,12 @@ FoveSolver::runWeakBayesBall (const Grounds& query)
|
||||
|
||||
set<Parfactor*> requiredPfs;
|
||||
while (todo.empty() == false) {
|
||||
unsigned group = todo.front();
|
||||
PrvGroup group = todo.front();
|
||||
ParfactorList::iterator it = pfList_.begin();
|
||||
while (it != pfList_.end()) {
|
||||
if (Util::contains (requiredPfs, *it) == false &&
|
||||
(*it)->containsGroup (group)) {
|
||||
vector<unsigned> groups = (*it)->getAllGroups();
|
||||
vector<PrvGroup> groups = (*it)->getAllGroups();
|
||||
for (size_t i = 0; i < groups.size(); i++) {
|
||||
if (Util::contains (done, groups[i]) == false) {
|
||||
todo.push (groups[i]);
|
||||
|
Reference in New Issue
Block a user