This repository has been archived on 2023-08-20. You can view files and clone it, but cannot push or open issues or pull requests.
yap-6.3/packages/CLPBN/horus/HorusCli.cpp

188 lines
4.8 KiB
C++
Raw Normal View History

2012-05-23 14:56:01 +01:00
#include <cstdlib>
#include <iostream>
#include <sstream>
#include "FactorGraph.h"
#include "VarElim.h"
#include "BeliefProp.h"
#include "CountingBp.h"
2012-05-23 14:56:01 +01:00
using namespace std;
int readHorusFlags (int, const char* []);
void readFactorGraph (FactorGraph&, const char*);
VarIds readQueryAndEvidence (FactorGraph&, int, const char* [], int);
void runSolver (const FactorGraph&, const VarIds&);
// Command-line synopsis printed after argument errors. Built from adjacent
// string literals rather than a backslash-newline splice, so indentation or
// stray text on the continuation line cannot leak into the message.
const std::string USAGE = "usage: ./hcli [HORUS_FLAG=VALUE] "
    "MODEL_FILE [VARIABLE | OBSERVED_VARIABLE=EVIDENCE] ...";
2012-05-23 14:56:01 +01:00
int
main (int argc, const char* argv[])
{
if (argc <= 1) {
2012-12-20 18:07:50 +00:00
cerr << "Error: no probabilistic graphical model was given." << endl;
2012-05-23 14:56:01 +01:00
cerr << USAGE << endl;
exit (EXIT_FAILURE);
2012-05-23 14:56:01 +01:00
}
int idx = readHorusFlags (argc, argv);
FactorGraph fg;
readFactorGraph (fg, argv[idx]);
VarIds queryIds = readQueryAndEvidence (fg, argc, argv, idx + 1);
runSolver (fg, queryIds);
return 0;
}
int
readHorusFlags (int argc, const char* argv[])
{
int i = 1;
for (; i < argc; i++) {
const string& arg = argv[i];
size_t pos = arg.find ('=');
if (pos == std::string::npos) {
return i;
}
string leftArg = arg.substr (0, pos);
string rightArg = arg.substr (pos + 1);
if (leftArg.empty()) {
2012-12-20 18:07:50 +00:00
cerr << "Error: missing left argument." << endl;
2012-05-23 14:56:01 +01:00
cerr << USAGE << endl;
exit (EXIT_FAILURE);
2012-05-23 14:56:01 +01:00
}
if (rightArg.empty()) {
2012-12-20 18:07:50 +00:00
cerr << "Error: missing right argument." << endl;
2012-05-23 14:56:01 +01:00
cerr << USAGE << endl;
exit (EXIT_FAILURE);
2012-05-23 14:56:01 +01:00
}
Util::setHorusFlag (leftArg, rightArg);
}
return i + 1;
}
// Load the model file named by s into fg, choosing the parser from the
// file-name extension: ".uai" -> UAI format, ".fg" -> libDAI format.
// Any other extension (including none) prints an error and exits with
// EXIT_FAILURE.
void
readFactorGraph (FactorGraph& fg, const char* s)
{
  const string fileName (s);
  const size_t dotPos = fileName.find_last_of ('.');
  // Without a '.', npos + 1 would wrap to 0 and the whole name would be taken
  // as the extension (e.g. a file literally named "fg" would be accepted).
  const string extension = (dotPos == string::npos)
      ? string()
      : fileName.substr (dotPos + 1);
  if (extension == "uai") {
    fg.readFromUaiFormat (fileName.c_str());
  } else if (extension == "fg") {
    fg.readFromLibDaiFormat (fileName.c_str());
  } else {
    cerr << "Error: the probabilistic graphical model must be " ;
    cerr << "defined either in a UAI or libDAI file." << endl;
    exit (EXIT_FAILURE);
  }
}
VarIds
readQueryAndEvidence (
FactorGraph& fg,
int argc,
const char* argv[],
int start)
{
VarIds queryIds;
for (int i = start; i < argc; i++) {
const string& arg = argv[i];
if (arg.find ('=') == std::string::npos) {
if (Util::isInteger (arg) == false) {
2012-12-20 18:07:50 +00:00
cerr << "Error: `" << arg << "' " ;
cerr << "is not a variable id." ;
2012-05-23 14:56:01 +01:00
cerr << endl;
exit (EXIT_FAILURE);
2012-05-23 14:56:01 +01:00
}
VarId vid = Util::stringToUnsigned (arg);
VarNode* queryVar = fg.getVarNode (vid);
if (queryVar == false) {
2012-12-20 18:07:50 +00:00
cerr << "Error: unknow variable with id " ;
cerr << "`" << vid << "'." << endl;
exit (EXIT_FAILURE);
2012-05-23 14:56:01 +01:00
}
queryIds.push_back (vid);
} else {
size_t pos = arg.find ('=');
string leftArg = arg.substr (0, pos);
string rightArg = arg.substr (pos + 1);
if (leftArg.empty()) {
2012-12-20 18:07:50 +00:00
cerr << "Error: missing left argument." << endl;
2012-05-23 14:56:01 +01:00
cerr << USAGE << endl;
exit (EXIT_FAILURE);
2012-05-23 14:56:01 +01:00
}
if (Util::isInteger (leftArg) == false) {
2012-12-20 18:07:50 +00:00
cerr << "Error: `" << leftArg << "' " ;
cerr << "is not a variable id." << endl ;
exit (EXIT_FAILURE);
2012-05-23 14:56:01 +01:00
}
VarId vid = Util::stringToUnsigned (leftArg);
VarNode* observedVar = fg.getVarNode (vid);
if (observedVar == false) {
2012-12-20 18:07:50 +00:00
cerr << "Error: unknow variable with id " ;
cerr << "`" << vid << "'." << endl;
exit (EXIT_FAILURE);
2012-05-23 14:56:01 +01:00
}
if (rightArg.empty()) {
2012-12-20 18:07:50 +00:00
cerr << "Error: missing right argument." << endl;
2012-05-23 14:56:01 +01:00
cerr << USAGE << endl;
exit (EXIT_FAILURE);
2012-05-23 14:56:01 +01:00
}
if (Util::isInteger (rightArg) == false) {
2012-12-20 18:07:50 +00:00
cerr << "Error: `" << rightArg << "' " ;
cerr << "is not a state index." << endl ;
exit (EXIT_FAILURE);
2012-05-23 14:56:01 +01:00
}
unsigned stateIdx = Util::stringToUnsigned (rightArg);
if (observedVar->isValidState (stateIdx) == false) {
2012-12-20 18:07:50 +00:00
cerr << "Error: `" << stateIdx << "' " ;
2012-05-23 14:56:01 +01:00
cerr << "is not a valid state index for variable with id " ;
2012-12-20 18:07:50 +00:00
cerr << "`" << vid << "'." << endl;
exit (EXIT_FAILURE);
2012-05-23 14:56:01 +01:00
}
observedVar->setEvidence (stateIdx);
}
}
return queryIds;
}
void
runSolver (const FactorGraph& fg, const VarIds& queryIds)
{
2012-11-14 21:55:51 +00:00
GroundSolver* solver = 0;
switch (Globals::groundSolver) {
2012-11-14 21:55:51 +00:00
case GroundSolverType::VE:
solver = new VarElim (fg);
2012-05-23 14:56:01 +01:00
break;
2012-11-14 21:55:51 +00:00
case GroundSolverType::BP:
solver = new BeliefProp (fg);
2012-05-23 14:56:01 +01:00
break;
2012-11-14 21:55:51 +00:00
case GroundSolverType::CBP:
solver = new CountingBp (fg);
2012-05-23 14:56:01 +01:00
break;
default:
assert (false);
}
if (Globals::verbosity > 0) {
solver->printSolverFlags();
cout << endl;
}
2012-05-28 18:26:15 +01:00
if (queryIds.empty()) {
2012-05-23 14:56:01 +01:00
solver->printAllPosterioris();
} else {
solver->printAnswer (queryIds);
}
delete solver;
}