Improve solver flags

This commit is contained in:
Tiago Gomes 2012-12-27 23:21:32 +00:00
parent d36b63ece3
commit 7d9af75c35
7 changed files with 49 additions and 40 deletions

View File

@ -34,15 +34,23 @@ warning :-
set_horus_flag(K,V) :- cpp_set_horus_flag(K,V). set_horus_flag(K,V) :- cpp_set_horus_flag(K,V).
:- cpp_set_horus_flag(schedule, seq_fixed). :- cpp_set_horus_flag(verbosity, 0).
%:- cpp_set_horus_flag(schedule, seq_random).
%:- cpp_set_horus_flag(schedule, parallel).
%:- cpp_set_horus_flag(schedule, max_residual).
:- cpp_set_horus_flag(accuracy, 0.0001).
:- cpp_set_horus_flag(max_iter, 1000).
:- cpp_set_horus_flag(use_logarithms, false). :- cpp_set_horus_flag(use_logarithms, false).
% :- cpp_set_horus_flag(use_logarithms, true). %:- cpp_set_horus_flag(use_logarithms, true).
%:- cpp_set_horus_flag(hve_elim_heuristic, sequential).
%:- cpp_set_horus_flag(hve_elim_heuristic, min_neighbors).
%:- cpp_set_horus_flag(hve_elim_heuristic, min_weight).
%:- cpp_set_horus_flag(hve_elim_heuristic, min_fill).
:- cpp_set_horus_flag(hve_elim_heuristic, weighted_min_fill).
:- cpp_set_horus_flag(bp_msg_schedule, seq_fixed).
%:- cpp_set_horus_flag(bp_msg_schedule, seq_random).
%:- cpp_set_horus_flag(bp_msg_schedule, parallel).
%:- cpp_set_horus_flag(bp_msg_schedule, max_residual).
:- cpp_set_horus_flag(bp_accuracy, 0.0001).
:- cpp_set_horus_flag(bp_max_iter, 1000).

View File

@ -52,7 +52,7 @@ BeliefProp::printSolverFlags (void) const
{ {
stringstream ss; stringstream ss;
ss << "belief propagation [" ; ss << "belief propagation [" ;
ss << "schedule=" ; ss << "msg_schedule=" ;
switch (schedule_) { switch (schedule_) {
case MsgSchedule::SEQ_FIXED: ss << "seq_fixed"; break; case MsgSchedule::SEQ_FIXED: ss << "seq_fixed"; break;
case MsgSchedule::SEQ_RANDOM: ss << "seq_random"; break; case MsgSchedule::SEQ_RANDOM: ss << "seq_random"; break;

View File

@ -2,7 +2,7 @@
#include "WeightedBp.h" #include "WeightedBp.h"
bool CountingBp::checkForIdenticalFactors = true; bool CountingBp::fif_ = true;
CountingBp::CountingBp (const FactorGraph& fg) CountingBp::CountingBp (const FactorGraph& fg)
@ -36,7 +36,7 @@ CountingBp::printSolverFlags (void) const
{ {
stringstream ss; stringstream ss;
ss << "counting bp [" ; ss << "counting bp [" ;
ss << "schedule=" ; ss << "msg_schedule=" ;
switch (WeightedBp::msgSchedule()) { switch (WeightedBp::msgSchedule()) {
case MsgSchedule::SEQ_FIXED: ss << "seq_fixed"; break; case MsgSchedule::SEQ_FIXED: ss << "seq_fixed"; break;
case MsgSchedule::SEQ_RANDOM: ss << "seq_random"; break; case MsgSchedule::SEQ_RANDOM: ss << "seq_random"; break;
@ -46,7 +46,7 @@ CountingBp::printSolverFlags (void) const
ss << ",max_iter=" << WeightedBp::maxIterations(); ss << ",max_iter=" << WeightedBp::maxIterations();
ss << ",accuracy=" << WeightedBp::accuracy(); ss << ",accuracy=" << WeightedBp::accuracy();
ss << ",log_domain=" << Util::toString (Globals::logDomain); ss << ",log_domain=" << Util::toString (Globals::logDomain);
ss << ",chkif=" << Util::toString (CountingBp::checkForIdenticalFactors); ss << ",fif=" << Util::toString (CountingBp::fif_);
ss << "]" ; ss << "]" ;
cout << ss.str() << endl; cout << ss.str() << endl;
} }
@ -93,8 +93,7 @@ void
CountingBp::findIdenticalFactors() CountingBp::findIdenticalFactors()
{ {
const FacNodes& facNodes = fg.facNodes(); const FacNodes& facNodes = fg.facNodes();
if (checkForIdenticalFactors == false || if (fif_ == false || facNodes.size() == 1) {
facNodes.size() == 1) {
return; return;
} }
for (size_t i = 0; i < facNodes.size(); i++) { for (size_t i = 0; i < facNodes.size(); i++) {

View File

@ -116,7 +116,7 @@ class CountingBp : public GroundSolver
Params solveQuery (VarIds); Params solveQuery (VarIds);
static bool checkForIdenticalFactors; static void setFindIdenticalFactorsFlag (bool fif) { fif_ = fif; }
private: private:
Color getNewColor (void) Color getNewColor (void)
@ -179,6 +179,8 @@ class CountingBp : public GroundSolver
const FactorGraph* compressedFg_; const FactorGraph* compressedFg_;
WeightedBp* solver_; WeightedBp* solver_;
static bool fif_;
DISALLOW_COPY_AND_ASSIGN (CountingBp); DISALLOW_COPY_AND_ASSIGN (CountingBp);
}; };

View File

@ -200,7 +200,7 @@ runGroundSolver (void)
} }
GroundSolver* solver = 0; GroundSolver* solver = 0;
CountingBp::checkForIdenticalFactors = false; CountingBp::setFindIdenticalFactorsFlag (false);
switch (Globals::groundSolver) { switch (Globals::groundSolver) {
case GroundSolverType::VE: solver = new VarElim (*mfg); break; case GroundSolverType::VE: solver = new VarElim (*mfg); break;
case GroundSolverType::BP: solver = new BeliefProp (*mfg); break; case GroundSolverType::BP: solver = new BeliefProp (*mfg); break;
@ -320,11 +320,11 @@ setHorusFlag (void)
stringstream ss; stringstream ss;
ss << (int) YAP_IntOfTerm (YAP_ARG2); ss << (int) YAP_IntOfTerm (YAP_ARG2);
ss >> value; ss >> value;
} else if (key == "accuracy") { } else if (key == "bp_accuracy") {
stringstream ss; stringstream ss;
ss << (float) YAP_FloatOfTerm (YAP_ARG2); ss << (float) YAP_FloatOfTerm (YAP_ARG2);
ss >> value; ss >> value;
} else if (key == "max_iter") { } else if (key == "bp_max_iter") {
stringstream ss; stringstream ss;
ss << (int) YAP_IntOfTerm (YAP_ARG2); ss << (int) YAP_IntOfTerm (YAP_ARG2);
ss >> value; ss >> value;

View File

@ -62,7 +62,7 @@ LiftedBp::printSolverFlags (void) const
{ {
stringstream ss; stringstream ss;
ss << "lifted bp [" ; ss << "lifted bp [" ;
ss << "schedule=" ; ss << "msg_schedule=" ;
switch (WeightedBp::msgSchedule()) { switch (WeightedBp::msgSchedule()) {
case MsgSchedule::SEQ_FIXED: ss << "seq_fixed"; break; case MsgSchedule::SEQ_FIXED: ss << "seq_fixed"; break;
case MsgSchedule::SEQ_RANDOM: ss << "seq_random"; break; case MsgSchedule::SEQ_RANDOM: ss << "seq_random"; break;

View File

@ -192,11 +192,7 @@ bool
setHorusFlag (string key, string value) setHorusFlag (string key, string value)
{ {
bool returnVal = true; bool returnVal = true;
if (key == "verbosity") { if ( key == "lifted_solver") {
stringstream ss;
ss << value;
ss >> Globals::verbosity;
} else if (key == "lifted_solver") {
if ( value == "lve") { if ( value == "lve") {
Globals::liftedSolver = LiftedSolverType::LVE; Globals::liftedSolver = LiftedSolverType::LVE;
} else if (value == "lbp") { } else if (value == "lbp") {
@ -209,7 +205,7 @@ setHorusFlag (string key, string value)
returnVal = false; returnVal = false;
} }
} else if (key == "ground_solver") { } else if (key == "ground_solver") {
if ( value == "ve") { if ( value == "ve" || value == "hve") {
Globals::groundSolver = GroundSolverType::VE; Globals::groundSolver = GroundSolverType::VE;
} else if (value == "bp") { } else if (value == "bp") {
Globals::groundSolver = GroundSolverType::BP; Globals::groundSolver = GroundSolverType::BP;
@ -220,7 +216,21 @@ setHorusFlag (string key, string value)
cerr << "for `" << key << "'" << endl; cerr << "for `" << key << "'" << endl;
returnVal = false; returnVal = false;
} }
} else if (key == "elim_heuristic") { } else if (key == "verbosity") {
stringstream ss;
ss << value;
ss >> Globals::verbosity;
} else if (key == "use_logarithms") {
if ( value == "true") {
Globals::logDomain = true;
} else if (value == "false") {
Globals::logDomain = false;
} else {
cerr << "warning: invalid value `" << value << "' " ;
cerr << "for `" << key << "'" << endl;
returnVal = false;
}
} else if (key == "ve_elim_heuristic" || key == "hve_elim_heuristic") {
if ( value == "sequential") { if ( value == "sequential") {
ElimGraph::setElimHeuristic (ElimHeuristic::SEQUENTIAL); ElimGraph::setElimHeuristic (ElimHeuristic::SEQUENTIAL);
} else if (value == "min_neighbors") { } else if (value == "min_neighbors") {
@ -236,7 +246,7 @@ setHorusFlag (string key, string value)
cerr << "for `" << key << "'" << endl; cerr << "for `" << key << "'" << endl;
returnVal = false; returnVal = false;
} }
} else if (key == "schedule") { } else if (key == "bp_msg_schedule") {
if ( value == "seq_fixed") { if ( value == "seq_fixed") {
BeliefProp::setMsgSchedule (MsgSchedule::SEQ_FIXED); BeliefProp::setMsgSchedule (MsgSchedule::SEQ_FIXED);
} else if (value == "seq_random") { } else if (value == "seq_random") {
@ -250,28 +260,18 @@ setHorusFlag (string key, string value)
cerr << "for `" << key << "'" << endl; cerr << "for `" << key << "'" << endl;
returnVal = false; returnVal = false;
} }
} else if (key == "accuracy") { } else if (key == "bp_accuracy") {
stringstream ss; stringstream ss;
double acc; double acc;
ss << value; ss << value;
ss >> acc; ss >> acc;
BeliefProp::setAccuracy (acc); BeliefProp::setAccuracy (acc);
} else if (key == "max_iter") { } else if (key == "bp_max_iter") {
stringstream ss; stringstream ss;
unsigned mi; unsigned mi;
ss << value; ss << value;
ss >> mi; ss >> mi;
BeliefProp::setMaxIterations (mi); BeliefProp::setMaxIterations (mi);
} else if (key == "use_logarithms") {
if ( value == "true") {
Globals::logDomain = true;
} else if (value == "false") {
Globals::logDomain = false;
} else {
cerr << "warning: invalid value `" << value << "' " ;
cerr << "for `" << key << "'" << endl;
returnVal = false;
}
} else { } else {
cerr << "warning: invalid key `" << key << "'" << endl; cerr << "warning: invalid key `" << key << "'" << endl;
returnVal = false; returnVal = false;