support setting flags in horus cli and fix a bug when ordering the variables

This commit is contained in:
Tiago Gomes 2012-04-21 17:14:19 +01:00
parent 085ebe1e96
commit 8c689665a0
11 changed files with 260 additions and 189 deletions

View File

@ -24,6 +24,12 @@ CbpSolver::CbpSolver (const FactorGraph& fg) : BpSolver (fg)
Statistics::updateCompressingStatistics (nrGroundVars,
nrGroundFacs, nrClusterVars, nrClusterFacs, nrNeighborless);
}
// cout << "uncompressed factor graph:" << endl;
// cout << " " << fg.nrVarNodes() << " variables " << endl;
// cout << " " << fg.nrFacNodes() << " factors " << endl;
// cout << "compressed factor graph:" << endl;
// cout << " " << fg_->nrVarNodes() << " variables " << endl;
// cout << " " << fg_->nrFacNodes() << " factors " << endl;
// Util::printHeader ("Compressed Factor Graph");
// fg_->print();
// Util::printHeader ("Uncompressed Factor Graph");
@ -60,7 +66,7 @@ CbpSolver::printSolverFlags (void) const
ss << ",max_iter=" << BpOptions::maxIter;
ss << ",accuracy=" << BpOptions::accuracy;
ss << ",log_domain=" << Util::toString (Globals::logDomain);
ss << ",order_vars=" << Util::toString (FactorGraph::orderVariables);
ss << ",order_vars=" << Util::toString (FactorGraph::orderVars);
ss << ",chkif=" <<
Util::toString (CFactorGraph::checkForIdenticalFactors);
ss << "]" ;
@ -69,7 +75,6 @@ CbpSolver::printSolverFlags (void) const
Params
CbpSolver::getPosterioriOf (VarId vid)
{

View File

@ -4,7 +4,7 @@
#include "ElimGraph.h"
ElimHeuristic ElimGraph::elimHeuristic_ = MIN_NEIGHBORS;
ElimHeuristic ElimGraph::elimHeuristic = MIN_NEIGHBORS;
ElimGraph::ElimGraph (const vector<Factor*>& factors)
@ -174,7 +174,7 @@ ElimGraph::getLowestCostNode (void) const
for (unsigned i = 0; i < nodes_.size(); i++) {
if (marked_[i]) continue;
unsigned cost = 0;
switch (elimHeuristic_) {
switch (elimHeuristic) {
case MIN_NEIGHBORS:
cost = getNeighborsCost (nodes_[i]);
break;

View File

@ -47,15 +47,7 @@ class ElimGraph
static VarIds getEliminationOrder (const vector<Factor*>, VarIds);
static void setEliminationHeuristic (ElimHeuristic h)
{
elimHeuristic_ = h;
}
static ElimHeuristic getEliminationHeuristic (void)
{
return elimHeuristic_;
}
static ElimHeuristic elimHeuristic;
private:
@ -86,7 +78,6 @@ class ElimGraph
vector<EgNode*> nodes_;
vector<bool> marked_;
unordered_map<VarId, EgNode*> varMap_;
static ElimHeuristic elimHeuristic_;
};
#endif // HORUS_ELIMGRAPH_H

View File

@ -13,7 +13,7 @@
#include "Util.h"
bool FactorGraph::orderVariables = false;
bool FactorGraph::orderVars = false;
FactorGraph::FactorGraph (const FactorGraph& fg)
@ -186,14 +186,17 @@ void
FactorGraph::addFactor (const Factor& factor)
{
FacNode* fn = new FacNode (factor);
if (orderVars) {
fn->factor().reorderAccordingVarIds();
}
addFacNode (fn);
const VarIds& vids = factor.arguments();
const VarIds& vids = fn->factor().arguments();
for (unsigned i = 0; i < vids.size(); i++) {
VarMap::const_iterator it = varMap_.find (vids[i]);
if (it != varMap_.end()) {
addEdge (it->second, fn);
} else {
VarNode* vn = new VarNode (vids[i], factor.range (i));
VarNode* vn = new VarNode (vids[i], fn->factor().range (i));
addVarNode (vn);
addEdge (vn, fn);
}
@ -261,6 +264,7 @@ FactorGraph::getStructure (void)
void
FactorGraph::print (void) const
{
/*
for (unsigned i = 0; i < varNodes_.size(); i++) {
cout << "var id = " << varNodes_[i]->varId() << endl;
cout << "label = " << varNodes_[i]->label() << endl;
@ -272,6 +276,7 @@ FactorGraph::print (void) const
}
cout << endl << endl;
}
*/
for (unsigned i = 0; i < facNodes_.size(); i++) {
facNodes_[i]->factor().print();
}

View File

@ -117,7 +117,7 @@ class FactorGraph
void exportToLibDaiFormat (const char*) const;
static bool orderVariables;
static bool orderVars;
private:
// DISALLOW_COPY_AND_ASSIGN (FactorGraph);

View File

@ -10,43 +10,68 @@
using namespace std;
void processArguments (FactorGraph&, int, const char* []);
int readHorusFlags (int, const char* []);
void readFactorGraph (FactorGraph&, const char*);
VarIds readQueryAndEvidence (FactorGraph&, int, const char* [], int);
void runSolver (const FactorGraph&, const VarIds&);
const string USAGE = "usage: \
./hcli ve|bp|cbp NETWORK_FILE [VARIABLE | OBSERVED_VARIABLE=EVIDENCE]..." ;
const string USAGE = "usage: ./hcli [HORUS_FLAG=VALUE] \
NETWORK_FILE [VARIABLE | OBSERVED_VARIABLE=EVIDENCE] ..." ;
int
main (int argc, const char* argv[])
{
if (argc <= 1) {
cerr << "error: no solver specified" << endl;
cerr << "error: no graphical model specified" << endl;
cerr << USAGE << endl;
exit (0);
}
if (argc <= 2) {
cerr << "error: no graphical model specified" << endl;
cerr << USAGE << endl;
exit (0);
}
string solver (argv[1]);
if (solver == "ve") {
Globals::infAlgorithm = InfAlgorithms::VE;
} else if (solver == "bp") {
Globals::infAlgorithm = InfAlgorithms::BP;
} else if (solver == "cbp") {
Globals::infAlgorithm = InfAlgorithms::CBP;
} else {
cerr << "error: unknow solver `" << solver << "'" << endl ;
cerr << USAGE << endl;
exit(0);
}
string fileName (argv[2]);
string extension = fileName.substr (
fileName.find_last_of ('.') + 1);
int idx = readHorusFlags (argc, argv);
FactorGraph fg;
readFactorGraph (fg, argv[idx]);
VarIds queryIds = readQueryAndEvidence (fg, argc, argv, idx + 1);
runSolver (fg, queryIds);
return 0;
}
int
readHorusFlags (int argc, const char* argv[])
{
int i = 1;
for (; i < argc; i++) {
const string& arg = argv[i];
size_t pos = arg.find ('=');
if (pos == std::string::npos) {
return i;
}
string leftArg = arg.substr (0, pos);
string rightArg = arg.substr (pos + 1);
if (leftArg.empty()) {
cerr << "error: missing left argument" << endl;
cerr << USAGE << endl;
exit (0);
}
if (rightArg.empty()) {
cerr << "error: missing right argument" << endl;
cerr << USAGE << endl;
exit (0);
}
Util::setHorusFlag (leftArg, rightArg);
}
return i + 1;
}
void
readFactorGraph (FactorGraph& fg, const char* s)
{
string fileName (s);
string extension = fileName.substr (fileName.find_last_of ('.') + 1);
if (extension == "uai") {
fg.readFromUaiFormat (fileName.c_str());
} else if (extension == "fg") {
@ -56,90 +81,78 @@ main (int argc, const char* argv[])
cerr << "in a UAI or libDAI file" << endl;
exit (0);
}
processArguments (fg, argc, argv);
return 0;
}
void
processArguments (FactorGraph& fg, int argc, const char* argv[])
VarIds
readQueryAndEvidence (
FactorGraph& fg,
int argc,
const char* argv[],
int start)
{
VarIds queryIds;
for (int i = 3; i < argc; i++) {
for (int i = start; i < argc; i++) {
const string& arg = argv[i];
if (arg.find ('=') == std::string::npos) {
if (!Util::isInteger (arg)) {
if (Util::isInteger (arg) == false) {
cerr << "error: `" << arg << "' " ;
cerr << "is not a valid variable id" ;
cerr << endl;
exit (0);
}
VarId vid;
stringstream ss;
ss << arg;
ss >> vid;
VarNode* queryVar = fg.getVarNode (vid);
if (queryVar) {
queryIds.push_back (vid);
} else {
cerr << "error: there isn't a variable with " ;
cerr << "`" << vid << "' as id" ;
cerr << endl;
exit (0);
}
} else {
size_t pos = arg.find ('=');
if (arg.substr (0, pos).empty()) {
cerr << "error: missing left argument" << endl;
cerr << USAGE << endl;
exit (0);
}
if (arg.substr (pos + 1).empty()) {
cerr << "error: missing right argument" << endl;
cerr << USAGE << endl;
exit (0);
}
if (!Util::isInteger (arg.substr (0, pos))) {
cerr << "error: `" << arg.substr (0, pos) << "' " ;
cerr << "is not a variable id" ;
cerr << endl;
exit (0);
}
VarId vid;
stringstream ss;
ss << arg.substr (0, pos);
ss >> vid;
VarNode* var = fg.getVarNode (vid);
if (var) {
if (!Util::isInteger (arg.substr (pos + 1))) {
cerr << "error: `" << arg.substr (pos + 1) << "' " ;
cerr << "is not a state index" ;
cerr << endl;
exit (0);
}
int stateIndex;
stringstream ss;
ss << arg.substr (pos + 1);
ss >> stateIndex;
if (var->isValidState (stateIndex)) {
var->setEvidence (stateIndex);
} else {
cerr << "error: `" << stateIndex << "' " ;
cerr << "is not a valid state index for variable " ;
cerr << "`" << var->varId() << "'" ;
cerr << endl;
exit (0);
}
} else {
cerr << "error: there isn't a variable with " ;
cerr << "`" << vid << "' as id" ;
cerr << endl;
VarId vid = Util::stringToUnsigned (arg);
VarNode* queryVar = fg.getVarNode (vid);
if (queryVar == false) {
cerr << "error: unknow variable with id " ;
cerr << "`" << vid << "'" << endl;
exit (0);
}
queryIds.push_back (vid);
} else {
size_t pos = arg.find ('=');
string leftArg = arg.substr (0, pos);
string rightArg = arg.substr (pos + 1);
if (leftArg.empty()) {
cerr << "error: missing left argument" << endl;
cerr << USAGE << endl;
exit (0);
}
if (Util::isInteger (leftArg) == false) {
cerr << "error: `" << leftArg << "' " ;
cerr << "is not a variable id" << endl ;
exit (0);
continue;
}
VarId vid = Util::stringToUnsigned (leftArg);
VarNode* observedVar = fg.getVarNode (vid);
if (observedVar == false) {
cerr << "error: unknow variable with id " ;
cerr << "`" << vid << "'" << endl;
exit (0);
}
if (rightArg.empty()) {
cerr << "error: missing right argument" << endl;
cerr << USAGE << endl;
exit (0);
}
if (Util::isInteger (rightArg) == false) {
cerr << "error: `" << rightArg << "' " ;
cerr << "is not a state index" << endl ;
exit (0);
}
unsigned stateIdx = Util::stringToUnsigned (rightArg);
if (observedVar->isValidState (stateIdx) == false) {
cerr << "error: `" << stateIdx << "' " ;
cerr << "is not a valid state index for variable with id " ;
cerr << "`" << vid << "'" << endl;
exit (0);
}
observedVar->setEvidence (stateIdx);
}
}
runSolver (fg, queryIds);
return queryIds;
}
@ -161,6 +174,8 @@ runSolver (const FactorGraph& fg, const VarIds& queryIds)
default:
assert (false);
}
solver->printSolverFlags();
cout << endl;
if (queryIds.size() == 0) {
solver->printAllPosterioris();
} else {

View File

@ -67,12 +67,12 @@ int createLiftedNetwork (void)
if (Constants::DEBUG > 2) {
Util::printHeader ("INITIAL PARFACTORS");
for (unsigned i = 0; i < parfactors.size(); i++) {
parfactors[i]->print();
parfactors[i]->print();
}
}
ParfactorList* pfList = new ParfactorList (parfactors);
if (Constants::DEBUG >= 2) {
Util::printHeader ("SHATTERED PARFACTORS");
pfList->print();
@ -402,10 +402,10 @@ void runBpSolver (
*fg, VarIds (vids.begin(),vids.end()));
}
if (Globals::infAlgorithm == InfAlgorithms::BP) {
solver = new BpSolver (*mfg);
solver = new BpSolver (*fg); // FIXME
} else if (Globals::infAlgorithm == InfAlgorithms::CBP) {
CFactorGraph::checkForIdenticalFactors = false;
solver = new CbpSolver (*mfg);
solver = new CbpSolver (*fg); // FIXME
} else {
cerr << "error: unknow solver" << endl;
abort();
@ -507,80 +507,19 @@ int
setHorusFlag (void)
{
string key ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG1)));
if (key == "inf_alg") {
string value ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG2)));
if ( value == "ve") {
Globals::infAlgorithm = InfAlgorithms::VE;
} else if (value == "bp") {
Globals::infAlgorithm = InfAlgorithms::BP;
} else if (value == "cbp") {
Globals::infAlgorithm = InfAlgorithms::CBP;
} else {
cerr << "warning: invalid value `" << value << "' " ;
cerr << "for `" << key << "'" << endl;
return FALSE;
}
} else if (key == "elim_heuristic") {
string value ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG2)));
if ( value == "min_neighbors") {
ElimGraph::setEliminationHeuristic (ElimHeuristic::MIN_NEIGHBORS);
} else if (value == "min_weight") {
ElimGraph::setEliminationHeuristic (ElimHeuristic::MIN_WEIGHT);
} else if (value == "min_fill") {
ElimGraph::setEliminationHeuristic (ElimHeuristic::MIN_FILL);
} else if (value == "weighted_min_fill") {
ElimGraph::setEliminationHeuristic (ElimHeuristic::WEIGHTED_MIN_FILL);
} else {
cerr << "warning: invalid value `" << value << "' " ;
cerr << "for `" << key << "'" << endl;
return FALSE;
}
} else if (key == "schedule") {
string value ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG2)));
if ( value == "seq_fixed") {
BpOptions::schedule = BpOptions::Schedule::SEQ_FIXED;
} else if (value == "seq_random") {
BpOptions::schedule = BpOptions::Schedule::SEQ_RANDOM;
} else if (value == "parallel") {
BpOptions::schedule = BpOptions::Schedule::PARALLEL;
} else if (value == "max_residual") {
BpOptions::schedule = BpOptions::Schedule::MAX_RESIDUAL;
} else {
cerr << "warning: invalid value `" << value << "' " ;
cerr << "for `" << key << "'" << endl;
return FALSE;
}
} else if (key == "accuracy") {
BpOptions::accuracy = (double) YAP_FloatOfTerm (YAP_ARG2);
string value;
if (key == "accuracy") {
stringstream ss;
ss << (float) YAP_FloatOfTerm (YAP_ARG2);
ss >> value;
} else if (key == "max_iter") {
BpOptions::maxIter = (int) YAP_IntOfTerm (YAP_ARG2);
} else if (key == "use_logarithms") {
string value ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG2)));
if ( value == "true") {
Globals::logDomain = true;
} else if (value == "false") {
Globals::logDomain = false;
} else {
cerr << "warning: invalid value `" << value << "' " ;
cerr << "for `" << key << "'" << endl;
return FALSE;
}
} else if (key == "order_variables") {
string value ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG2)));
if ( value == "true") {
FactorGraph::orderVariables = true;
} else if (value == "false") {
FactorGraph::orderVariables = false;
} else {
cerr << "warning: invalid value `" << value << "' " ;
cerr << "for `" << key << "'" << endl;
return FALSE;
}
stringstream ss;
ss << (int) YAP_IntOfTerm (YAP_ARG2);
ss >> value;
} else {
cerr << "warning: invalid key `" << key << "'" << endl;
return FALSE;
value = ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG2)));
}
return TRUE;
return Util::setHorusFlag (key, value);
}

View File

@ -5,6 +5,7 @@
#include "Util.h"
#include "Indexer.h"
#include "ElimGraph.h"
namespace Globals {
@ -48,6 +49,34 @@ toString (const bool& b)
unsigned
stringToUnsigned (string str)
{
int val;
stringstream ss;
ss << str;
ss >> val;
if (val < 0) {
cerr << "error: the value tried to read is negative" << endl;
abort();
}
return static_cast<unsigned> (val);
}
double
stringToDouble (string str)
{
double val;
stringstream ss;
ss << str;
ss >> val;
return val;
}
void
toLog (Params& v)
{
@ -194,6 +223,87 @@ getStateLines (const Vars& vars)
bool
setHorusFlag (string key, string value)
{
bool returnVal = true;
if (key == "inf_alg") {
if ( value == "ve") {
Globals::infAlgorithm = InfAlgorithms::VE;
} else if (value == "bp") {
Globals::infAlgorithm = InfAlgorithms::BP;
} else if (value == "cbp") {
Globals::infAlgorithm = InfAlgorithms::CBP;
} else {
cerr << "warning: invalid value `" << value << "' " ;
cerr << "for `" << key << "'" << endl;
returnVal = false;
}
} else if (key == "elim_heuristic") {
if ( value == "min_neighbors") {
ElimGraph::elimHeuristic = ElimHeuristic::MIN_NEIGHBORS;
} else if (value == "min_weight") {
ElimGraph::elimHeuristic = ElimHeuristic::MIN_WEIGHT;
} else if (value == "min_fill") {
ElimGraph::elimHeuristic = ElimHeuristic::MIN_FILL;
} else if (value == "weighted_min_fill") {
ElimGraph::elimHeuristic = ElimHeuristic::WEIGHTED_MIN_FILL;
} else {
cerr << "warning: invalid value `" << value << "' " ;
cerr << "for `" << key << "'" << endl;
returnVal = false;
}
} else if (key == "schedule") {
if ( value == "seq_fixed") {
BpOptions::schedule = BpOptions::Schedule::SEQ_FIXED;
} else if (value == "seq_random") {
BpOptions::schedule = BpOptions::Schedule::SEQ_RANDOM;
} else if (value == "parallel") {
BpOptions::schedule = BpOptions::Schedule::PARALLEL;
} else if (value == "max_residual") {
BpOptions::schedule = BpOptions::Schedule::MAX_RESIDUAL;
} else {
cerr << "warning: invalid value `" << value << "' " ;
cerr << "for `" << key << "'" << endl;
returnVal = false;
}
} else if (key == "accuracy") {
stringstream ss;
ss << value;
ss >> BpOptions::accuracy;
} else if (key == "max_iter") {
stringstream ss;
ss << value;
ss >> BpOptions::maxIter;
} else if (key == "use_logarithms") {
if ( value == "true") {
Globals::logDomain = true;
} else if (value == "false") {
Globals::logDomain = false;
} else {
cerr << "warning: invalid value `" << value << "' " ;
cerr << "for `" << key << "'" << endl;
returnVal = false;
}
} else if (key == "order_vars") {
if ( value == "true") {
FactorGraph::orderVars = true;
} else if (value == "false") {
FactorGraph::orderVars = false;
} else {
cerr << "warning: invalid value `" << value << "' " ;
cerr << "for `" << key << "'" << endl;
returnVal = false;
}
} else {
cerr << "warning: invalid key `" << key << "'" << endl;
returnVal = false;
}
return returnVal;
}
void
printHeader (string header, std::ostream& os)
{

View File

@ -38,6 +38,10 @@ template <typename T> std::string toString (const T&);
template <> std::string toString (const bool&);
unsigned stringToUnsigned (string);
double stringToDouble (string);
void toLog (Params&);
void fromLog (Params&);
@ -52,6 +56,8 @@ void add (Params&, const Params&);
void add (Params&, const Params&, unsigned);
unsigned maxUnsigned (void);
double factorial (unsigned);
double logFactorial (unsigned);
@ -68,6 +74,8 @@ string parametersToString (const Params&, unsigned = Constants::PRECISION);
vector<string> getStateLines (const Vars&);
bool setHorusFlag (string key, string value);
void printHeader (string, std::ostream& os = std::cout);
void printSubHeader (string, std::ostream& os = std::cout);
@ -76,8 +84,6 @@ void printAsteriskLine (std::ostream& os = std::cout);
void printDashedLine (std::ostream& os = std::cout);
unsigned maxUnsigned (void);
};

View File

@ -38,7 +38,7 @@ VarElimSolver::printSolverFlags (void) const
stringstream ss;
ss << "variable elimination [" ;
ss << "elim_heuristic=" ;
ElimHeuristic eh = ElimGraph::getEliminationHeuristic();
ElimHeuristic eh = ElimGraph::elimHeuristic;
switch (eh) {
case MIN_NEIGHBORS: ss << "min_neighbors"; break;
case MIN_WEIGHT: ss << "min_weight"; break;

View File

@ -59,8 +59,8 @@ set_solver(S) :- throw(error('unknow solver ', S)).
:- set_horus_flag(max_iter, 1000).
:- set_horus_flag(order_variables, false).
%:- set_horus_flag(order_variables, true).
:- set_horus_flag(order_vars, false).
%:- set_horus_flag(order_vars, true).
:- set_horus_flag(use_logarithms, false).
% :- set_horus_flag(use_logarithms, true).