diff --git a/packages/CLPBN/clpbn/bp/ConstraintTree.cpp b/packages/CLPBN/clpbn/bp/ConstraintTree.cpp index 3711d63a6..d854abb1b 100644 --- a/packages/CLPBN/clpbn/bp/ConstraintTree.cpp +++ b/packages/CLPBN/clpbn/bp/ConstraintTree.cpp @@ -359,10 +359,6 @@ ConstraintTree::rename (LogVar X_old, LogVar X_new) void ConstraintTree::applySubstitution (const Substitution& theta) { - LogVars discardedLvs = theta.getDiscardedLogVars(); - for (unsigned i = 0; i < discardedLvs.size(); i++) { - remove(discardedLvs[i]); - } for (unsigned i = 0; i < logVars_.size(); i++) { logVars_[i] = theta.newNameFor (logVars_[i]); } diff --git a/packages/CLPBN/clpbn/bp/FoveSolver.cpp b/packages/CLPBN/clpbn/bp/FoveSolver.cpp index 97712035a..09da088e8 100644 --- a/packages/CLPBN/clpbn/bp/FoveSolver.cpp +++ b/packages/CLPBN/clpbn/bp/FoveSolver.cpp @@ -13,17 +13,23 @@ LiftedOperator::getValidOps ( const Grounds& query) { vector validOps; - vector sumOutOps; - vector countOps; - vector groundOps; + vector multOps; - sumOutOps = SumOutOperator::getValidOps (pfList, query); - countOps = CountingOperator::getValidOps (pfList); - groundOps = GroundOperator::getValidOps (pfList); + multOps = ProductOperator::getValidOps (pfList); + validOps.insert (validOps.end(), multOps.begin(), multOps.end()); + + if (Globals::verbosity > 1 || multOps.empty()) { + vector sumOutOps; + vector countOps; + vector groundOps; + sumOutOps = SumOutOperator::getValidOps (pfList, query); + countOps = CountingOperator::getValidOps (pfList); + groundOps = GroundOperator::getValidOps (pfList); + validOps.insert (validOps.end(), sumOutOps.begin(), sumOutOps.end()); + validOps.insert (validOps.end(), countOps.begin(), countOps.end()); + validOps.insert (validOps.end(), groundOps.begin(), groundOps.end()); + } - validOps.insert (validOps.end(), sumOutOps.begin(), sumOutOps.end()); - validOps.insert (validOps.end(), countOps.begin(), countOps.end()); - validOps.insert (validOps.end(), groundOps.begin(), groundOps.end()); return validOps; } @@ -44,17 +50,9 @@ LiftedOperator::printValidOps ( -vector -LiftedOperator::getAllGroupss (ParfactorList& ) -{ - return { }; -} - - - vector LiftedOperator::getParfactorsWithGroup ( - ParfactorList& pfList, unsigned group) + ParfactorList& pfList, unsigned group) { vector iters; ParfactorList::iterator pflIt = pfList.begin(); @@ -69,6 +67,105 @@ LiftedOperator::getParfactorsWithGroup ( +double +ProductOperator::getLogCost (void) +{ + return std::log (0.0); +} + + + +void +ProductOperator::apply (void) +{ + Parfactor* g1 = *g1_; + Parfactor* g2 = *g2_; + g1->multiply (*g2); + pfList_.remove (g1_); + pfList_.removeAndDelete (g2_); + pfList_.addShattered (g1); +} + + + +vector +ProductOperator::getValidOps (ParfactorList& pfList) +{ + vector validOps; + ParfactorList::iterator it1 = pfList.begin(); + ParfactorList::iterator penultimate = -- pfList.end(); + set pfs; + while (it1 != penultimate) { + if (Util::contains (pfs, *it1)) { + ++ it1; + continue; + } + ParfactorList::iterator it2 = it1; + ++ it2; + while (it2 != pfList.end()) { + if (Util::contains (pfs, *it2)) { + ++ it2; + continue; + } else { + if (validOp (*it1, *it2)) { + pfs.insert (*it1); + pfs.insert (*it2); + validOps.push_back (new ProductOperator ( + it1, it2, pfList)); + if (Globals::verbosity < 2) { + return validOps; + } + break; + } + } + ++ it2; + } + ++ it1; + } + return validOps; +} + + + +string +ProductOperator::toString (void) +{ + stringstream ss; + ss << "just multiplicate " ; + ss << (*g1_)->getAllGroups(); + ss << " x " ; + ss << (*g2_)->getAllGroups(); + ss << " [cost=" << std::exp (getLogCost()) << "]" << endl; + return ss.str(); +} + + + +bool +ProductOperator::validOp (Parfactor* g1, Parfactor* g2) +{ + TinySet g1_gs (g1->getAllGroups()); + TinySet g2_gs (g2->getAllGroups()); + if (g1_gs.contains (g2_gs) || g2_gs.contains (g1_gs)) { + TinySet intersect = g1_gs & g2_gs; + for (unsigned i = 0; i < intersect.size(); i++) { + if (g1->nrFormulasWithGroup (intersect[i]) != 1 || + g2->nrFormulasWithGroup (intersect[i]) != 1) { + return false; + } + int idx1 = g1->indexOfGroup (intersect[i]); + int idx2 = g2->indexOfGroup (intersect[i]); + if (g1->range (idx1) != g2->range (idx2)) { + return false; + } + } + return Parfactor::canMultiply (g1, g2); + } + return false; +} + + + double SumOutOperator::getLogCost (void) { @@ -84,7 +181,8 @@ SumOutOperator::getLogCost (void) ++ pfIter; } if (nrProdFactors == 1) { - return std::log (0.0); // best possible case + // best possible case + return std::log (0.0); } double cost = 1.0; for (unsigned i = 0; i < groupSet.size(); i++) { @@ -190,38 +288,20 @@ SumOutOperator::validOp ( if (isToEliminate (*pfIters[0], group, query) == false) { return false; } - - unordered_map groupToRange; + int range = -1; for (unsigned i = 0; i < pfIters.size(); i++) { - const ProbFormulas& formulas = (*pfIters[i])->arguments(); - int fIdx = -1; - for (unsigned j = 0; j < formulas.size(); j++) { - if (formulas[j].group() == group) { - if (fIdx != -1) { - // only summout a group of rand vars if they don't - // appear in another position on the factor - return false; - } else { - fIdx = j; - } - } + if ((*pfIters[i])->nrFormulasWithGroup (group) > 1) { + return false; } + int fIdx = (*pfIters[i])->indexOfGroup (group); if ((*pfIters[i])->argument (fIdx).contains ( (*pfIters[i])->elimLogVars()) == false) { return false; } - vector ranges = (*pfIters[i])->ranges(); - vector groups = (*pfIters[i])->getAllGroups(); - for (unsigned i = 0; i < groups.size(); i++) { - unordered_map::iterator it; - it = groupToRange.find (groups[i]); - if (it == groupToRange.end()) { - groupToRange.insert (make_pair (groups[i], ranges[i])); - } else { - if (it->second != ranges[i]) { - return false; - } - } + if (range == -1) { + range = (*pfIters[i])->range (fIdx); + } else if ((int)(*pfIters[i])->range (fIdx) != range) { + return false; } } return true; @@ -265,8 +345,23 @@ CountingOperator::getLogCost (void) for (unsigned i = 0; i < counts.size(); i++) { cost += size * HistogramSet::nrHistograms (counts[i], range); } - if ((*pfIter_)->nrArguments() == 1) { - cost *= 3; // avoid counting conversion in the beginning + unsigned group = (*pfIter_)->argument (fIdx).group(); + int lvIndex = Util::vectorIndex ( + (*pfIter_)->argument (fIdx).logVars(), X_); + assert (lvIndex != -1); + ParfactorList::iterator pfIter = pfList_.begin(); + while (pfIter != pfList_.end()) { + if (pfIter != pfIter_) { + int fIdx2 = (*pfIter)->indexOfGroup (group); + if (fIdx2 != -1) { + LogVar Y = ((*pfIter)->argument (fIdx2).logVars()[lvIndex]); + if ((*pfIter)->canCountConvert (Y) == false) { + // the real cost should be the cost of grounding Y + cost *= 10.0; + } + } + } + ++ pfIter; } return std::log (cost); } @@ -308,6 +403,7 @@ CountingOperator::getValidOps (ParfactorList& pfList) if (validOp (*it, candidates[i])) { validOps.push_back (new CountingOperator ( it, candidates[i], pfList)); + } else { } } ++ it; @@ -350,12 +446,7 @@ CountingOperator::validOp (Parfactor* g, LogVar X) } bool countNormalized = g->constr()->isCountNormalized (X); if (countNormalized) { - unsigned condCount = g->constr()->getConditionalCount (X); - bool cartProduct = g->constr()->isCarteesianProduct ( - g->countedLogVars() | X); - if (condCount == 1 || cartProduct == false) { - return false; - } + return g->canCountConvert (X); } return true; } @@ -438,6 +529,11 @@ GroundOperator::apply (void) } delete pf; } + ParfactorList::iterator pflIt = pfList_.begin(); + while (pflIt != pfList_.end()) { + (*pflIt)->simplifyGrounds(); + ++ pflIt; + } } @@ -625,6 +721,39 @@ FoveSolver::countNormalize ( +Parfactor +FoveSolver::calcGroundMultiplication (Parfactor pf) +{ + LogVarSet lvs = pf.constr()->logVarSet(); + lvs -= pf.constr()->singletons(); + Parfactors newPfs = {new Parfactor (pf)}; + for (unsigned i = 0; i < lvs.size(); i++) { + Parfactors pfs = newPfs; + newPfs.clear(); + for (unsigned j = 0; j < pfs.size(); j++) { + bool countedLv = pfs[j]->countedLogVars().contains (lvs[i]); + if (countedLv) { + pfs[j]->fullExpand (lvs[i]); + newPfs.push_back (pfs[j]); + } else { + ConstraintTrees cts = pfs[j]->constr()->ground (lvs[i]); + for (unsigned k = 0; k < cts.size(); k++) { + newPfs.push_back (new Parfactor (pfs[j], cts[k])); + } + delete pfs[j]; + } + } + } + ParfactorList pfList (newPfs); + Parfactors groundShatteredPfs (pfList.begin(),pfList.end()); + for (unsigned i = 1; i < groundShatteredPfs.size(); i++) { + groundShatteredPfs[0]->multiply (*groundShatteredPfs[i]); + } + return Parfactor (*groundShatteredPfs[0]); +} + + + void FoveSolver::runSolver (const Grounds& query) { @@ -665,6 +794,7 @@ FoveSolver::runSolver (const Grounds& query) cout << "largest cost = " << std::exp (largestCost_) << endl; cout << endl; } + (*pfList_.begin())->simplifyGrounds(); (*pfList_.begin())->reorderAccordingGrounds (query); } @@ -769,8 +899,11 @@ FoveSolver::shatterAgainstQuery (const Grounds& query) while (it != pfList_.end()) { if ((*it)->containsGround (query[i])) { found = true; - std::pair split = - (*it)->constr()->split (query[i].args(), query[i].arity()); + std::pair split; + LogVars queryLvs ( + (*it)->constr()->logVars().begin(), + (*it)->constr()->logVars().begin() + query[i].arity()); + split = (*it)->constr()->split (query[i].args()); ConstraintTree* commCt = split.first; ConstraintTree* exclCt = split.second; newPfs.push_back (new Parfactor (*it, commCt)); @@ -826,8 +959,11 @@ FoveSolver::absorve ( } g->constr()->moveToTop (formulas[i].logVars()); - std::pair res - = g->constr()->split (&(obsFormula.constr()), formulas[i].arity()); + std::pair res; + res = g->constr()->split ( + formulas[i].logVars(), + &(obsFormula.constr()), + obsFormula.constr().logVars()); ConstraintTree* commCt = res.first; ConstraintTree* exclCt = res.second; diff --git a/packages/CLPBN/clpbn/bp/FoveSolver.h b/packages/CLPBN/clpbn/bp/FoveSolver.h index ef52fae7c..c0c1cfc71 100644 --- a/packages/CLPBN/clpbn/bp/FoveSolver.h +++ b/packages/CLPBN/clpbn/bp/FoveSolver.h @@ -19,14 +19,37 @@ class LiftedOperator static void printValidOps (ParfactorList&, const Grounds&); - static vector getAllGroupss (ParfactorList&); - static vector getParfactorsWithGroup ( ParfactorList&, unsigned group); }; +class ProductOperator : public LiftedOperator +{ + public: + ProductOperator ( + ParfactorList::iterator g1, ParfactorList::iterator g2, + ParfactorList& pfList) : g1_(g1), g2_(g2), pfList_(pfList) { } + + double getLogCost (void); + + void apply (void); + + static vector getValidOps (ParfactorList&); + + string toString (void); + + private: + static bool validOp (Parfactor*, Parfactor*); + + ParfactorList::iterator g1_; + ParfactorList::iterator g2_; + ParfactorList& pfList_; +}; + + + class SumOutOperator : public LiftedOperator { public: @@ -123,6 +146,8 @@ class FoveSolver static Parfactors countNormalize (Parfactor*, const LogVarSet&); + static Parfactor calcGroundMultiplication (Parfactor pf); + private: void runSolver (const Grounds&); diff --git a/packages/CLPBN/clpbn/bp/HorusYap.cpp b/packages/CLPBN/clpbn/bp/HorusYap.cpp index 876d0c063..9c108ffd6 100644 --- a/packages/CLPBN/clpbn/bp/HorusYap.cpp +++ b/packages/CLPBN/clpbn/bp/HorusYap.cpp @@ -68,6 +68,7 @@ int createLiftedNetwork (void) Util::printHeader ("INITIAL PARFACTORS"); for (unsigned i = 0; i < parfactors.size(); i++) { parfactors[i]->print(); + cout << endl; } } diff --git a/packages/CLPBN/clpbn/bp/LiftedUtils.h b/packages/CLPBN/clpbn/bp/LiftedUtils.h index 181e98789..d89fe477b 100644 --- a/packages/CLPBN/clpbn/bp/LiftedUtils.h +++ b/packages/CLPBN/clpbn/bp/LiftedUtils.h @@ -137,14 +137,20 @@ class Substitution LogVar newNameFor (LogVar X) const { - assert (Util::contains (subs_, X)); - return subs_.find (X)->second; + unordered_map::const_iterator it; + it = subs_.find (X); + if (it != subs_.end()) { + return subs_.find (X)->second; + } + return X; } bool containsReplacementFor (LogVar X) const { return Util::contains (subs_, X); } + + unsigned nrReplacements (void) const { return subs_.size(); } LogVars getDiscardedLogVars (void) const; diff --git a/packages/CLPBN/clpbn/bp/Parfactor.cpp b/packages/CLPBN/clpbn/bp/Parfactor.cpp index 7dca9e670..c5ad13d14 100644 --- a/packages/CLPBN/clpbn/bp/Parfactor.cpp +++ b/packages/CLPBN/clpbn/bp/Parfactor.cpp @@ -193,6 +193,32 @@ Parfactor::multiply (Parfactor& g) alignAndExponentiate (this, &g); TFactor::multiply (g); constr_->join (g.constr(), true); + simplifyGrounds(); + assert (constr_->isCarteesianProduct (countedLogVars())); +} + + + +bool +Parfactor::canCountConvert (LogVar X) +{ + if (nrFormulas (X) != 1) { + return false; + } + int fIdx = indexOfLogVar (X); + if (args_[fIdx].isCounting()) { + return false; + } + if (constr_->isCountNormalized (X) == false) { + return false; + } + if (constr_->getConditionalCount (X) == 1) { + return false; + } + if (constr_->isCarteesianProduct (countedLogVars() | X) == false) { + return false; + } + return true; } @@ -201,11 +227,10 @@ void Parfactor::countConvert (LogVar X) { int fIdx = indexOfLogVar (X); - assert (fIdx != -1); assert (constr_->isCountNormalized (X)); assert (constr_->getConditionalCount (X) > 1); - assert (constr_->isCarteesianProduct (countedLogVars() | X)); - + assert (canCountConvert (X)); + unsigned N = constr_->getConditionalCount (X); unsigned R = ranges_[fIdx]; unsigned H = HistogramSet::nrHistograms (N, R); @@ -245,6 +270,7 @@ Parfactor::countConvert (LogVar X) ++ mapIndexer; } args_[fIdx].setCountedLogVar (X); + simplifyCountingFormulas (fIdx); } @@ -307,10 +333,9 @@ Parfactor::fullExpand (LogVar X) unsigned N = constr_->getConditionalCount (X); unsigned R = args_[fIdx].range(); - vector originHists = HistogramSet::getHistograms (N, R); vector expandHists = HistogramSet::getHistograms (1, R); - + assert (ranges_[fIdx] == originHists.size()); vector sumIndexes; sumIndexes.reserve (N * R); @@ -496,6 +521,20 @@ Parfactor::indexOfGroup (unsigned group) const +unsigned +Parfactor::nrFormulasWithGroup (unsigned group) const +{ + unsigned count = 0; + for (unsigned i = 0; i < args_.size(); i++) { + if (args_[i].group() == group) { + count ++; + } + } + return count; +} + + + vector Parfactor::getAllGroups (void) const { @@ -672,22 +711,119 @@ Parfactor::expandPotential ( void -Parfactor::alignAndExponentiate (Parfactor* g1, Parfactor* g2) +Parfactor::simplifyCountingFormulas (int fIdx) { - LogVars X_1, X_2; - const ProbFormulas& formulas1 = g1->arguments(); - const ProbFormulas& formulas2 = g2->arguments(); - for (unsigned i = 0; i < formulas1.size(); i++) { - for (unsigned j = 0; j < formulas2.size(); j++) { - if (formulas1[i].group() == formulas2[j].group()) { - Util::addToVector (X_1, formulas1[i].logVars()); - Util::addToVector (X_2, formulas2[j].logVars()); + // check if we can simplify the parfactor + for (unsigned i = 0; i < args_.size(); i++) { + if ((int)i != fIdx && + args_[i].isCounting() && + args_[i].group() == args_[fIdx].group()) { + // if they only differ in the name of the counting log var + if ((args_[i].logVarSet() - args_[i].countedLogVar()) == + (args_[fIdx].logVarSet()) - args_[fIdx].countedLogVar() && + ranges_[i] == ranges_[fIdx]) { + simplifyParfactor (fIdx, i); + break; } } } - LogVarSet Y_1 = g1->logVarSet() - LogVarSet (X_1); - LogVarSet Y_2 = g2->logVarSet() - LogVarSet (X_2); - // counting log vars were already raised on counting conversion +} + + + +void +Parfactor::simplifyGrounds (void) +{ + LogVarSet singletons = constr_->singletons(); + for (int i = 0; i < (int)args_.size() - 1; i++) { + for (unsigned j = i + 1; j < args_.size(); j++) { + if (args_[i].group() == args_[j].group() && + singletons.contains (args_[i].logVarSet()) && + singletons.contains (args_[j].logVarSet())) { + simplifyParfactor (i, j); + i --; + break; + } + } + } +} + + + +bool +Parfactor::canMultiply (Parfactor* g1, Parfactor* g2) +{ + std::pair res = getAlignLogVars (g1, g2); + LogVarSet Xs_1 (res.first); + LogVarSet Xs_2 (res.second); + LogVarSet Y_1 = g1->logVarSet() - Xs_1; + LogVarSet Y_2 = g2->logVarSet() - Xs_2; + Y_1 -= g1->countedLogVars(); + Y_2 -= g2->countedLogVars(); + return g1->constr()->isCountNormalized (Y_1) && + g2->constr()->isCountNormalized (Y_2); +} + + + +void +Parfactor::simplifyParfactor (unsigned fIdx1, unsigned fIdx2) +{ + Params copy = params_; + params_.clear(); + StatesIndexer indexer (ranges_); + while (indexer.valid()) { + if (indexer[fIdx1] == indexer[fIdx2]) { + params_.push_back (copy[indexer]); + } + ++ indexer; + } + for (unsigned i = 0; i < args_[fIdx2].logVars().size(); i++) { + if (nrFormulas (args_[fIdx2].logVars()[i]) == 1) { + constr_->remove ({ args_[fIdx2].logVars()[i] }); + } + } + args_.erase (args_.begin() + fIdx2); + ranges_.erase (ranges_.begin() + fIdx2); +} + + + +std::pair +Parfactor::getAlignLogVars (Parfactor* g1, Parfactor* g2) +{ + g1->simplifyGrounds(); + g2->simplifyGrounds(); + LogVars Xs_1, Xs_2; + TinySet matchedI; + TinySet matchedJ; + ProbFormulas& formulas1 = g1->arguments(); + ProbFormulas& formulas2 = g2->arguments(); + for (unsigned i = 0; i < formulas1.size(); i++) { + for (unsigned j = 0; j < formulas2.size(); j++) { + if (formulas1[i].group() == formulas2[j].group() && + g1->range (i) == g2->range (j) && + matchedI.contains (i) == false && + matchedJ.contains (j) == false) { + Util::addToVector (Xs_1, formulas1[i].logVars()); + Util::addToVector (Xs_2, formulas2[j].logVars()); + matchedI.insert (i); + matchedJ.insert (j); + } + } + } + return make_pair (Xs_1, Xs_2); +} + + + +void +Parfactor::alignAndExponentiate (Parfactor* g1, Parfactor* g2) +{ + alignLogicalVars (g1, g2); + LogVarSet comm = g1->logVarSet() & g2->logVarSet(); + LogVarSet Y_1 = g1->logVarSet() - comm; + LogVarSet Y_2 = g2->logVarSet() - comm; Y_1 -= g1->countedLogVars(); Y_2 -= g2->countedLogVars(); assert (g1->constr()->isCountNormalized (Y_1)); @@ -696,30 +832,27 @@ Parfactor::alignAndExponentiate (Parfactor* g1, Parfactor* g2) unsigned condCount2 = g2->constr()->getConditionalCount (Y_2); LogAware::pow (g1->params(), 1.0 / condCount2); LogAware::pow (g2->params(), 1.0 / condCount1); - //cout << "g1::::::::::::::::::::::::" << endl; - //g1->print(); - //cout << "g2::::::::::::::::::::::::" << endl; - //g2->print(); - // the alignment should be done in the end or else X_1 and X_2 - // will refer to the old log var names on the code above - align (g1, X_1, g2, X_2); } void -Parfactor::align ( - Parfactor* g1, const LogVars& alignLvs1, - Parfactor* g2, const LogVars& alignLvs2) +Parfactor::alignLogicalVars (Parfactor* g1, Parfactor* g2) { - LogVar freeLogVar = 0; + std::pair res = getAlignLogVars (g1, g2); + const LogVars& alignLvs1 = res.first; + const LogVars& alignLvs2 = res.second; + // cout << "ALIGNING :::::::::::::::::" << endl; + // g1->print(); + // cout << "AND" << endl; + // g2->print(); + // cout << "-> align lvs1 = " << alignLvs1 << endl; + // cout << "-> align lvs2 = " << alignLvs2 << endl; + LogVar freeLogVar (0); Substitution theta1, theta2; for (unsigned i = 0; i < alignLvs1.size(); i++) { bool b1 = theta1.containsReplacementFor (alignLvs1[i]); bool b2 = theta2.containsReplacementFor (alignLvs2[i]); - // handle this type of situation: - // g1 = p(X), q(X) ; X in {x1} - // g2 = p(X), q(Y) ; X in {x1}, Y in {x1} if (b1 == false && b2 == false) { theta1.add (alignLvs1[i], freeLogVar); theta2.add (alignLvs2[i], freeLogVar); @@ -730,6 +863,7 @@ Parfactor::align ( theta2.add (alignLvs2[i], theta1.newNameFor (alignLvs1[i])); } } + const LogVarSet& allLvs1 = g1->logVarSet(); for (unsigned i = 0; i < allLvs1.size(); i++) { if (theta1.containsReplacementFor (allLvs1[i]) == false) { @@ -744,25 +878,33 @@ Parfactor::align ( ++ freeLogVar; } } - //cout << "theta1: " << theta1 << endl; - //cout << "theta2: " << theta2 << endl; + + // handle this type of situation: + // g1 = p(X), q(X) ; X in {(p1),(p2)} + // g2 = p(X), q(Y) ; (X,Y) in {(p1,p2),(p2,p1)} LogVars discardedLvs1 = theta1.getDiscardedLogVars(); for (unsigned i = 0; i < discardedLvs1.size(); i++) { - unsigned condCount = g1->constr()->getConditionalCount (discardedLvs1[i]); - cout << "discarding g1" << discardedLvs1[i]; - cout << " cc = " << condCount << endl; - LogAware::pow (g1->params(), condCount); + if (g1->constr()->isSingleton (discardedLvs1[i]) && + g1->nrFormulas (discardedLvs1[i]) == 1) { + g1->constr()->remove (discardedLvs1[i]); + } else { + LogVar X_new = ++ g1->constr()->logVarSet().back(); + theta1.rename (discardedLvs1[i], X_new); + } } LogVars discardedLvs2 = theta2.getDiscardedLogVars(); for (unsigned i = 0; i < discardedLvs2.size(); i++) { - unsigned condCount = g2->constr()->getConditionalCount (discardedLvs2[i]); - cout << "discarding g2" << discardedLvs2[i]; - cout << " cc = " << condCount << endl; - //if (condCount != 1) { - // theta2.rename (discardedLvs2[i], freeLogVar); - //} - LogAware::pow (g2->params(), condCount); + if (g2->constr()->isSingleton (discardedLvs2[i]) && + g2->nrFormulas (discardedLvs2[i]) == 1) { + g2->constr()->remove (discardedLvs2[i]); + } else { + LogVar X_new = ++ g2->constr()->logVarSet().back(); + theta2.rename (discardedLvs2[i], X_new); + } } + + // cout << "theta1: " << theta1 << endl; + // cout << "theta2: " << theta2 << endl; g1->applySubstitution (theta1); g2->applySubstitution (theta2); } diff --git a/packages/CLPBN/clpbn/bp/Parfactor.h b/packages/CLPBN/clpbn/bp/Parfactor.h index f2bd5257d..1a55e9c55 100644 --- a/packages/CLPBN/clpbn/bp/Parfactor.h +++ b/packages/CLPBN/clpbn/bp/Parfactor.h @@ -50,6 +50,8 @@ class Parfactor : public TFactor void multiply (Parfactor&); + bool canCountConvert (LogVar X); + void countConvert (LogVar); void expand (LogVar, LogVar, LogVar); @@ -76,6 +78,8 @@ class Parfactor : public TFactor int indexOfGroup (unsigned) const; + unsigned nrFormulasWithGroup (unsigned) const; + vector getAllGroups (void) const; void print (bool = false) const; @@ -86,16 +90,28 @@ class Parfactor : public TFactor string getLabel (void) const; + void simplifyGrounds (void); + + static bool canMultiply (Parfactor*, Parfactor*); + private: + + void simplifyCountingFormulas (int fIdx); + + void simplifyParfactor (unsigned fIdx1, unsigned fIdx2); + + static std::pair getAlignLogVars ( + Parfactor* g1, Parfactor* g2); + void expandPotential (int fIdx, unsigned newRange, const vector& sumIndexes); static void alignAndExponentiate (Parfactor*, Parfactor*); - static void align ( - Parfactor*, const LogVars&, Parfactor*, const LogVars&); + static void alignLogicalVars (Parfactor*, Parfactor*); ConstraintTree* constr_; + }; diff --git a/packages/CLPBN/clpbn/bp/TODO b/packages/CLPBN/clpbn/bp/TODO index a6bbb5930..e247b6de5 100644 --- a/packages/CLPBN/clpbn/bp/TODO +++ b/packages/CLPBN/clpbn/bp/TODO @@ -1,4 +1,10 @@ -TODO - - add a way to calculate combinations and factorials with large numbers - - refactor sumOut in parfactor -> is really ugly code - - Indexer: start receiving ranges as constant reference +- Refactor sum out in factor +- Add a way to sum out several vars at the same time +- Receive ranges as a constant reference in Indexer +- Merge TinySet and SortedVector classes +- Check if evidence remains in the compressed factor graph +- Consider using hashs instead of vectors of colors to calculate the groups in + counting bp +- use more psize_t instead of unsigned for looping through params +- use more Util::abort and Util::vectorIndex +- LogVar should not cast to int diff --git a/packages/CLPBN/clpbn/bp/examples/fail.yap b/packages/CLPBN/clpbn/bp/examples/fail.yap deleted file mode 100644 index aa5cfc5db..000000000 --- a/packages/CLPBN/clpbn/bp/examples/fail.yap +++ /dev/null @@ -1,21 +0,0 @@ - -:- use_module(library(pfl)). - -:- set_pfl_flag(solver,fove). -%:- set_pfl_flag(solver,bp), clpbn_horus:set_horus_flag(inf_alg,ve). -%:- set_pfl_flag(solver,bp), clpbn_horus:set_horus_flag(inf_alg,bp). -%:- set_pfl_flag(solver,bp), clpbn_horus:set_horus_flag(inf_alg,cbp). - - -t(ann). -t(dave). - -% p(ann,t). - -markov p(X)::[t,f] ; [0.1, 0.3] ; [t(X)]. - -% use standard Prolog queries: provide evidence first. - -?- p(ann,t), p(ann,X). -% ?- p(ann,X). -