From 7d9af75c352d6cddbc7678fa166dc9bc9c306a2c Mon Sep 17 00:00:00 2001 From: Tiago Gomes Date: Thu, 27 Dec 2012 23:21:32 +0000 Subject: [PATCH] Improve solver flags --- packages/CLPBN/clpbn/horus.yap | 26 ++++++++++++------- packages/CLPBN/horus/BeliefProp.cpp | 2 +- packages/CLPBN/horus/CountingBp.cpp | 9 +++---- packages/CLPBN/horus/CountingBp.h | 4 ++- packages/CLPBN/horus/HorusYap.cpp | 6 ++--- packages/CLPBN/horus/LiftedBp.cpp | 2 +- packages/CLPBN/horus/Util.cpp | 40 ++++++++++++++--------------- 7 files changed, 49 insertions(+), 40 deletions(-) diff --git a/packages/CLPBN/clpbn/horus.yap b/packages/CLPBN/clpbn/horus.yap index 344f11d86..976d481b9 100644 --- a/packages/CLPBN/clpbn/horus.yap +++ b/packages/CLPBN/clpbn/horus.yap @@ -34,15 +34,23 @@ warning :- set_horus_flag(K,V) :- cpp_set_horus_flag(K,V). -:- cpp_set_horus_flag(schedule, seq_fixed). -%:- cpp_set_horus_flag(schedule, seq_random). -%:- cpp_set_horus_flag(schedule, parallel). -%:- cpp_set_horus_flag(schedule, max_residual). - -:- cpp_set_horus_flag(accuracy, 0.0001). - -:- cpp_set_horus_flag(max_iter, 1000). +:- cpp_set_horus_flag(verbosity, 0). :- cpp_set_horus_flag(use_logarithms, false). -% :- cpp_set_horus_flag(use_logarithms, true). +%:- cpp_set_horus_flag(use_logarithms, true). + +%:- cpp_set_horus_flag(hve_elim_heuristic, sequential). +%:- cpp_set_horus_flag(hve_elim_heuristic, min_neighbors). +%:- cpp_set_horus_flag(hve_elim_heuristic, min_weight). +%:- cpp_set_horus_flag(hve_elim_heuristic, min_fill). +:- cpp_set_horus_flag(hve_elim_heuristic, weighted_min_fill). + +:- cpp_set_horus_flag(bp_msg_schedule, seq_fixed). +%:- cpp_set_horus_flag(bp_msg_schedule, seq_random). +%:- cpp_set_horus_flag(bp_msg_schedule, parallel). +%:- cpp_set_horus_flag(bp_msg_schedule, max_residual). + +:- cpp_set_horus_flag(bp_accuracy, 0.0001). + +:- cpp_set_horus_flag(bp_max_iter, 1000). diff --git a/packages/CLPBN/horus/BeliefProp.cpp b/packages/CLPBN/horus/BeliefProp.cpp index bf8f30a79..d009cd7a9 100644 --- a/packages/CLPBN/horus/BeliefProp.cpp +++ b/packages/CLPBN/horus/BeliefProp.cpp @@ -52,7 +52,7 @@ BeliefProp::printSolverFlags (void) const { stringstream ss; ss << "belief propagation [" ; - ss << "schedule=" ; + ss << "msg_schedule=" ; switch (schedule_) { case MsgSchedule::SEQ_FIXED: ss << "seq_fixed"; break; case MsgSchedule::SEQ_RANDOM: ss << "seq_random"; break; diff --git a/packages/CLPBN/horus/CountingBp.cpp b/packages/CLPBN/horus/CountingBp.cpp index 006bf99fd..39b47eab3 100644 --- a/packages/CLPBN/horus/CountingBp.cpp +++ b/packages/CLPBN/horus/CountingBp.cpp @@ -2,7 +2,7 @@ #include "WeightedBp.h" -bool CountingBp::checkForIdenticalFactors = true; +bool CountingBp::fif_ = true; CountingBp::CountingBp (const FactorGraph& fg) @@ -36,7 +36,7 @@ CountingBp::printSolverFlags (void) const { stringstream ss; ss << "counting bp [" ; - ss << "schedule=" ; + ss << "msg_schedule=" ; switch (WeightedBp::msgSchedule()) { case MsgSchedule::SEQ_FIXED: ss << "seq_fixed"; break; case MsgSchedule::SEQ_RANDOM: ss << "seq_random"; break; @@ -46,7 +46,7 @@ CountingBp::printSolverFlags (void) const ss << ",max_iter=" << WeightedBp::maxIterations(); ss << ",accuracy=" << WeightedBp::accuracy(); ss << ",log_domain=" << Util::toString (Globals::logDomain); - ss << ",chkif=" << Util::toString (CountingBp::checkForIdenticalFactors); + ss << ",fif=" << Util::toString (CountingBp::fif_); ss << "]" ; cout << ss.str() << endl; } @@ -93,8 +93,7 @@ void CountingBp::findIdenticalFactors() { const FacNodes& facNodes = fg.facNodes(); - if (checkForIdenticalFactors == false || - facNodes.size() == 1) { + if (fif_ == false || facNodes.size() == 1) { return; } for (size_t i = 0; i < facNodes.size(); i++) { diff --git a/packages/CLPBN/horus/CountingBp.h b/packages/CLPBN/horus/CountingBp.h index 4f674e687..605fa8b22 100644 --- a/packages/CLPBN/horus/CountingBp.h +++ b/packages/CLPBN/horus/CountingBp.h @@ -116,7 +116,7 @@ class CountingBp : public GroundSolver Params solveQuery (VarIds); - static bool checkForIdenticalFactors; + static void setFindIdenticalFactorsFlag (bool fif) { fif_ = fif; } private: Color getNewColor (void) @@ -179,6 +179,8 @@ class CountingBp : public GroundSolver const FactorGraph* compressedFg_; WeightedBp* solver_; + static bool fif_; + DISALLOW_COPY_AND_ASSIGN (CountingBp); }; diff --git a/packages/CLPBN/horus/HorusYap.cpp b/packages/CLPBN/horus/HorusYap.cpp index 77e900bb0..dbd210412 100644 --- a/packages/CLPBN/horus/HorusYap.cpp +++ b/packages/CLPBN/horus/HorusYap.cpp @@ -200,7 +200,7 @@ runGroundSolver (void) } GroundSolver* solver = 0; - CountingBp::checkForIdenticalFactors = false; + CountingBp::setFindIdenticalFactorsFlag (false); switch (Globals::groundSolver) { case GroundSolverType::VE: solver = new VarElim (*mfg); break; case GroundSolverType::BP: solver = new BeliefProp (*mfg); break; @@ -320,11 +320,11 @@ setHorusFlag (void) stringstream ss; ss << (int) YAP_IntOfTerm (YAP_ARG2); ss >> value; - } else if (key == "accuracy") { + } else if (key == "bp_accuracy") { stringstream ss; ss << (float) YAP_FloatOfTerm (YAP_ARG2); ss >> value; - } else if (key == "max_iter") { + } else if (key == "bp_max_iter") { stringstream ss; ss << (int) YAP_IntOfTerm (YAP_ARG2); ss >> value; diff --git a/packages/CLPBN/horus/LiftedBp.cpp b/packages/CLPBN/horus/LiftedBp.cpp index 18f056f8a..b748cc9e1 100644 --- a/packages/CLPBN/horus/LiftedBp.cpp +++ b/packages/CLPBN/horus/LiftedBp.cpp @@ -62,7 +62,7 @@ LiftedBp::printSolverFlags (void) const { stringstream ss; ss << "lifted bp [" ; - ss << "schedule=" ; + ss << "msg_schedule=" ; switch (WeightedBp::msgSchedule()) { case MsgSchedule::SEQ_FIXED: ss << "seq_fixed"; break; case MsgSchedule::SEQ_RANDOM: ss << "seq_random"; break; diff --git a/packages/CLPBN/horus/Util.cpp b/packages/CLPBN/horus/Util.cpp index 9fad10705..4d682a1b0 100644 --- a/packages/CLPBN/horus/Util.cpp +++ b/packages/CLPBN/horus/Util.cpp @@ -192,11 +192,7 @@ bool setHorusFlag (string key, string value) { bool returnVal = true; - if (key == "verbosity") { - stringstream ss; - ss << value; - ss >> Globals::verbosity; - } else if (key == "lifted_solver") { + if ( key == "lifted_solver") { if ( value == "lve") { Globals::liftedSolver = LiftedSolverType::LVE; } else if (value == "lbp") { @@ -209,7 +205,7 @@ setHorusFlag (string key, string value) returnVal = false; } } else if (key == "ground_solver") { - if ( value == "ve") { + if ( value == "ve" || value == "hve") { Globals::groundSolver = GroundSolverType::VE; } else if (value == "bp") { Globals::groundSolver = GroundSolverType::BP; @@ -220,7 +216,21 @@ setHorusFlag (string key, string value) cerr << "for `" << key << "'" << endl; returnVal = false; } - } else if (key == "elim_heuristic") { + } else if (key == "verbosity") { + stringstream ss; + ss << value; + ss >> Globals::verbosity; + } else if (key == "use_logarithms") { + if ( value == "true") { + Globals::logDomain = true; + } else if (value == "false") { + Globals::logDomain = false; + } else { + cerr << "warning: invalid value `" << value << "' " ; + cerr << "for `" << key << "'" << endl; + returnVal = false; + } + } else if (key == "ve_elim_heuristic" || key == "hve_elim_heuristic") { if ( value == "sequential") { ElimGraph::setElimHeuristic (ElimHeuristic::SEQUENTIAL); } else if (value == "min_neighbors") { @@ -236,7 +246,7 @@ setHorusFlag (string key, string value) cerr << "for `" << key << "'" << endl; returnVal = false; } - } else if (key == "schedule") { + } else if (key == "bp_msg_schedule") { if ( value == "seq_fixed") { BeliefProp::setMsgSchedule (MsgSchedule::SEQ_FIXED); } else if (value == "seq_random") { @@ -250,28 +260,18 @@ setHorusFlag (string key, string value) cerr << "for `" << key << "'" << endl; returnVal = false; } - } else if (key == "accuracy") { + } else if (key == "bp_accuracy") { stringstream ss; double acc; ss << value; ss >> acc; BeliefProp::setAccuracy (acc); - } else if (key == "max_iter") { + } else if (key == "bp_max_iter") { stringstream ss; unsigned mi; ss << value; ss >> mi; BeliefProp::setMaxIterations (mi); - } else if (key == "use_logarithms") { - if ( value == "true") { - Globals::logDomain = true; - } else if (value == "false") { - Globals::logDomain = false; - } else { - cerr << "warning: invalid value `" << value << "' " ; - cerr << "for `" << key << "'" << endl; - returnVal = false; - } } else { cerr << "warning: invalid key `" << key << "'" << endl; returnVal = false;