From 911b241ad663a911af52babcf5d702c5239b4350 Mon Sep 17 00:00:00 2001 From: Tiago Gomes Date: Tue, 3 Apr 2012 11:58:21 +0100 Subject: [PATCH] fix align of logical variables --- packages/CLPBN/clpbn/bp/ConstraintTree.cpp | 4 ++++ packages/CLPBN/clpbn/bp/Factor.cpp | 1 - packages/CLPBN/clpbn/bp/Factor.h | 6 ++++++ packages/CLPBN/clpbn/bp/FoveSolver.cpp | 2 +- packages/CLPBN/clpbn/bp/HorusYap.cpp | 21 ++++++++++----------- packages/CLPBN/clpbn/bp/LiftedUtils.cpp | 19 +++++++++++++++++++ packages/CLPBN/clpbn/bp/LiftedUtils.h | 3 +++ packages/CLPBN/clpbn/bp/Parfactor.cpp | 15 ++------------- packages/CLPBN/clpbn/bp/Parfactor.h | 4 +--- packages/CLPBN/clpbn/bp/ParfactorList.cpp | 21 ++++++++++++--------- packages/CLPBN/clpbn/bp/ParfactorList.h | 4 ++-- 11 files changed, 60 insertions(+), 40 deletions(-) diff --git a/packages/CLPBN/clpbn/bp/ConstraintTree.cpp b/packages/CLPBN/clpbn/bp/ConstraintTree.cpp index cf83863df..51f06789c 100644 --- a/packages/CLPBN/clpbn/bp/ConstraintTree.cpp +++ b/packages/CLPBN/clpbn/bp/ConstraintTree.cpp @@ -347,6 +347,10 @@ ConstraintTree::rename (LogVar X_old, LogVar X_new) void ConstraintTree::applySubstitution (const Substitution& theta) { + LogVars discardedLvs = theta.getDiscardedLogVars(); + for (unsigned i = 0; i < discardedLvs.size(); i++) { + remove(discardedLvs[i]); + } for (unsigned i = 0; i < logVars_.size(); i++) { logVars_[i] = theta.newNameFor (logVars_[i]); } diff --git a/packages/CLPBN/clpbn/bp/Factor.cpp b/packages/CLPBN/clpbn/bp/Factor.cpp index 1980c2ade..ee548a47f 100644 --- a/packages/CLPBN/clpbn/bp/Factor.cpp +++ b/packages/CLPBN/clpbn/bp/Factor.cpp @@ -245,7 +245,6 @@ Factor::multiply (Factor& g) return; } TFactor::multiply (g); - cout << "Factor mult called" << endl; } diff --git a/packages/CLPBN/clpbn/bp/Factor.h b/packages/CLPBN/clpbn/bp/Factor.h index 1956eb04a..04a11fdae 100644 --- a/packages/CLPBN/clpbn/bp/Factor.h +++ b/packages/CLPBN/clpbn/bp/Factor.h @@ -62,6 +62,12 @@ class TFactor return args_[idx]; } + T& argument (unsigned idx) + { + assert (idx < args_.size()); + return args_[idx]; + } + unsigned range (unsigned idx) const { assert (idx < ranges_.size()); diff --git a/packages/CLPBN/clpbn/bp/FoveSolver.cpp b/packages/CLPBN/clpbn/bp/FoveSolver.cpp index 854348967..a205596c7 100644 --- a/packages/CLPBN/clpbn/bp/FoveSolver.cpp +++ b/packages/CLPBN/clpbn/bp/FoveSolver.cpp @@ -554,7 +554,7 @@ FoveSolver::runWeakBayesBall (const Grounds& query) for (unsigned i = 0; i < query.size(); i++) { ParfactorList::iterator it = pfList_.begin(); while (it != pfList_.end()) { - int group = (*it)->groupWithGround (query[i]); + int group = (*it)->findGroup (query[i]); if (group != -1) { todo.push (group); done.insert (group); diff --git a/packages/CLPBN/clpbn/bp/HorusYap.cpp b/packages/CLPBN/clpbn/bp/HorusYap.cpp index 6e2f4fd67..83aae4fe3 100644 --- a/packages/CLPBN/clpbn/bp/HorusYap.cpp +++ b/packages/CLPBN/clpbn/bp/HorusYap.cpp @@ -452,7 +452,6 @@ setBayesNetParams (void) int setExtraVarsInfo (void) { - // BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1); GraphicalModel::clearVariablesInformation(); YAP_Term varsInfoL = YAP_ARG2; while (varsInfoL != YAP_TermNil()) { @@ -583,15 +582,15 @@ freeParfactors (void) extern "C" void init_predicates (void) { - YAP_UserCPredicate ("create_lifted_network", createLiftedNetwork, 3); - YAP_UserCPredicate ("create_ground_network", createGroundNetwork, 2); - YAP_UserCPredicate ("run_lifted_solver", runLiftedSolver, 3); - YAP_UserCPredicate ("run_ground_solver", runGroundSolver, 3); - YAP_UserCPredicate ("set_parfactors_params", setParfactorsParams, 2); - YAP_UserCPredicate ("set_bayes_net_params", setBayesNetParams, 2); - YAP_UserCPredicate ("set_extra_vars_info", setExtraVarsInfo, 2); - YAP_UserCPredicate ("set_horus_flag", setHorusFlag, 2); - YAP_UserCPredicate ("free_parfactors", freeParfactors, 1); - YAP_UserCPredicate ("free_bayesian_network", freeBayesNetwork, 1); + YAP_UserCPredicate ("create_lifted_network", createLiftedNetwork, 3); + YAP_UserCPredicate ("create_ground_network", createGroundNetwork, 2); + YAP_UserCPredicate ("run_lifted_solver", runLiftedSolver, 3); + YAP_UserCPredicate ("run_ground_solver", runGroundSolver, 3); + YAP_UserCPredicate ("set_parfactors_params", setParfactorsParams, 2); + YAP_UserCPredicate ("set_bayes_net_params", setBayesNetParams, 2); + YAP_UserCPredicate ("set_extra_vars_info", setExtraVarsInfo, 2); + YAP_UserCPredicate ("set_horus_flag", setHorusFlag, 2); + YAP_UserCPredicate ("free_parfactors", freeParfactors, 1); + YAP_UserCPredicate ("free_bayesian_network", freeBayesNetwork, 1); } diff --git a/packages/CLPBN/clpbn/bp/LiftedUtils.cpp b/packages/CLPBN/clpbn/bp/LiftedUtils.cpp index 7ab70fd81..6df19b85d 100644 --- a/packages/CLPBN/clpbn/bp/LiftedUtils.cpp +++ b/packages/CLPBN/clpbn/bp/LiftedUtils.cpp @@ -95,6 +95,25 @@ ostream& operator<< (ostream &os, const Ground& gr) +LogVars +Substitution::getDiscardedLogVars (void) const +{ + LogVars discardedLvs; + set doneLvs; + unordered_map::const_iterator it; + it = subs_.begin(); + while (it != subs_.end()) { + if (Util::contains (doneLvs, it->second)) { + discardedLvs.push_back (it->first); + } else { + doneLvs.insert (it->second); + } + it ++; + } + return discardedLvs; +} + + ostream& operator<< (ostream &os, const Substitution& theta) { diff --git a/packages/CLPBN/clpbn/bp/LiftedUtils.h b/packages/CLPBN/clpbn/bp/LiftedUtils.h index 38c540779..a698ca2f4 100644 --- a/packages/CLPBN/clpbn/bp/LiftedUtils.h +++ b/packages/CLPBN/clpbn/bp/LiftedUtils.h @@ -141,10 +141,13 @@ class Substitution return subs_.find (X)->second; } + LogVars getDiscardedLogVars (void) const; + friend ostream& operator<< (ostream &os, const Substitution& theta); private: unordered_map subs_; + }; diff --git a/packages/CLPBN/clpbn/bp/Parfactor.cpp b/packages/CLPBN/clpbn/bp/Parfactor.cpp index 96c5b8f0e..e1e550481 100644 --- a/packages/CLPBN/clpbn/bp/Parfactor.cpp +++ b/packages/CLPBN/clpbn/bp/Parfactor.cpp @@ -191,7 +191,6 @@ Parfactor::multiply (Parfactor& g) { alignAndExponentiate (this, &g); TFactor::multiply (g); - cout << "calling lifted mult" << endl; constr_->join (g.constr(), true); } @@ -377,15 +376,6 @@ Parfactor::absorveEvidence (const ProbFormula& formula, unsigned evidence) -void -Parfactor::setFormulaGroup (const ProbFormula& f, int group) -{ - assert (indexOf (f) != -1); - args_[indexOf (f)].setGroup (group); -} - - - void Parfactor::setNewGroups (void) { @@ -415,7 +405,7 @@ Parfactor::applySubstitution (const Substitution& theta) int -Parfactor::groupWithGround (const Ground& ground) const +Parfactor::findGroup (const Ground& ground) const { int group = -1; for (unsigned i = 0; i < args_.size(); i++) { @@ -436,7 +426,7 @@ Parfactor::groupWithGround (const Ground& ground) const bool Parfactor::containsGround (const Ground& ground) const { - return groupWithGround (ground) != -1; + return findGroup (ground) != -1; } @@ -670,7 +660,6 @@ Parfactor::align ( LogVar freeLogVar = 0; Substitution theta1; Substitution theta2; - const LogVarSet& allLvs1 = g1->logVarSet(); for (unsigned i = 0; i < allLvs1.size(); i++) { theta1.add (allLvs1[i], freeLogVar); diff --git a/packages/CLPBN/clpbn/bp/Parfactor.h b/packages/CLPBN/clpbn/bp/Parfactor.h index c26707aeb..4c206e209 100644 --- a/packages/CLPBN/clpbn/bp/Parfactor.h +++ b/packages/CLPBN/clpbn/bp/Parfactor.h @@ -60,13 +60,11 @@ class Parfactor : public TFactor void absorveEvidence (const ProbFormula&, unsigned); - void setFormulaGroup (const ProbFormula&, int); - void setNewGroups (void); void applySubstitution (const Substitution&); - int groupWithGround (const Ground&) const; + int findGroup (const Ground&) const; bool containsGround (const Ground&) const; diff --git a/packages/CLPBN/clpbn/bp/ParfactorList.cpp b/packages/CLPBN/clpbn/bp/ParfactorList.cpp index abdb29c00..249126a99 100644 --- a/packages/CLPBN/clpbn/bp/ParfactorList.cpp +++ b/packages/CLPBN/clpbn/bp/ParfactorList.cpp @@ -197,8 +197,8 @@ ParfactorList::shatter (Parfactor* g1, Parfactor* g2) for (unsigned i = 0; i < formulas1.size(); i++) { for (unsigned j = 0; j < formulas2.size(); j++) { if (formulas1[i].sameSkeletonAs (formulas2[j])) { - std::pair res - = shatter (formulas1[i], g1, formulas2[j], g2); + std::pair res; + res = shatter (i, g1, j, g2); if (res.first.empty() == false || res.second.empty() == false) { return res; @@ -213,9 +213,11 @@ ParfactorList::shatter (Parfactor* g1, Parfactor* g2) std::pair ParfactorList::shatter ( - ProbFormula& f1, Parfactor* g1, - ProbFormula& f2, Parfactor* g2) + unsigned fIdx1, Parfactor* g1, + unsigned fIdx2, Parfactor* g2) { + ProbFormula& f1 = g1->argument (fIdx1); + ProbFormula& f2 = g2->argument (fIdx2); // cout << endl; // Util::printDashLine(); // cout << "-> SHATTERING (#" << g1 << ", #" << g2 << ")" << endl; @@ -299,8 +301,8 @@ ParfactorList::shatter ( } else { group = ProbFormula::getNewGroup(); } - Parfactors res1 = shatter (g1, f1, commCt1, exclCt1, group); - Parfactors res2 = shatter (g2, f2, commCt2, exclCt2, group); + Parfactors res1 = shatter (g1, fIdx1, commCt1, exclCt1, group); + Parfactors res2 = shatter (g2, fIdx2, commCt2, exclCt2, group); return make_pair (res1, res2); } @@ -309,15 +311,16 @@ ParfactorList::shatter ( Parfactors ParfactorList::shatter ( Parfactor* g, - const ProbFormula& f, + unsigned fIdx, ConstraintTree* commCt, ConstraintTree* exclCt, unsigned commGroup) { + ProbFormula& f = g->argument (fIdx); if (exclCt->empty()) { delete commCt; delete exclCt; - g->setFormulaGroup (f, commGroup); + f.setGroup (commGroup); return { }; } @@ -346,7 +349,7 @@ ParfactorList::shatter ( } else { Parfactor* newPf = new Parfactor (g, commCt); newPf->setNewGroups(); - newPf->setFormulaGroup (f, commGroup); + newPf->argument (fIdx).setGroup (commGroup); result.push_back (newPf); newPf = new Parfactor (g, exclCt); newPf->setNewGroups(); diff --git a/packages/CLPBN/clpbn/bp/ParfactorList.h b/packages/CLPBN/clpbn/bp/ParfactorList.h index e6350a90c..0d4992cfe 100644 --- a/packages/CLPBN/clpbn/bp/ParfactorList.h +++ b/packages/CLPBN/clpbn/bp/ParfactorList.h @@ -67,11 +67,11 @@ class ParfactorList Parfactor*, Parfactor*); std::pair shatter ( - ProbFormula&, Parfactor*, ProbFormula&, Parfactor*); + unsigned, Parfactor*, unsigned, Parfactor*); Parfactors shatter ( Parfactor*, - const ProbFormula&, + unsigned, ConstraintTree*, ConstraintTree*, unsigned);