This repository has been archived on 2023-08-20. You can view files and clone it, but cannot push or open issues or pull requests.
yap-6.3/packages/CLPBN/horus/HorusCli.cpp
Tiago Gomes ef4ebb4d7f Use camel case for constants and enumerators.
All capitals case should be reserved for macros and besides there is no big need to emphasize constness in general.
2013-02-13 18:54:15 +00:00

210 lines
5.8 KiB
C++

#include <cassert>
#include <string>
#include <iostream>
#include "FactorGraph.h"
#include "VarElim.h"
#include "BeliefProp.h"
#include "CountingBp.h"
int readHorusFlags (int, const char* []);
void readFactorGraph (Horus::FactorGraph&, const char*);
Horus::VarIds readQueryAndEvidence (
Horus::FactorGraph&, int, const char* [], int);
void runSolver (const Horus::FactorGraph&, const Horus::VarIds&);
const std::string usage = "usage: ./hcli [solver=hve|bp|cbp] \
[<OPTION>=<VALUE>]... <FILE> [<VAR>|<VAR>=<EVIDENCE>]... " ;
int
main (int argc, const char* argv[])
{
if (argc <= 1) {
std::cerr << "Error: no probabilistic graphical model was given." ;
std::cerr << std::endl << usage << std::endl;
exit (EXIT_FAILURE);
}
int idx = readHorusFlags (argc, argv);
Horus::FactorGraph fg;
readFactorGraph (fg, argv[idx]);
Horus::VarIds queryIds
= readQueryAndEvidence (fg, argc, argv, idx + 1);
if (Horus::FactorGraph::exportToLibDai()) {
fg.exportToLibDai ("model.fg");
}
if (Horus::FactorGraph::exportToUai()) {
fg.exportToUai ("model.uai");
}
if (Horus::FactorGraph::exportGraphViz()) {
fg.exportToGraphViz ("model.dot");
}
if (Horus::FactorGraph::printFactorGraph()) {
fg.print();
}
if (Horus::Globals::verbosity > 0) {
std::cout << "factor graph contains " ;
std::cout << fg.nrVarNodes() << " variables and " ;
std::cout << fg.nrFacNodes() << " factors " << std::endl;
}
runSolver (fg, queryIds);
return 0;
}
int
readHorusFlags (int argc, const char* argv[])
{
int i = 1;
for (; i < argc; i++) {
const std::string& arg = argv[i];
size_t pos = arg.find ('=');
if (pos == std::string::npos) {
return i;
}
std::string leftArg = arg.substr (0, pos);
std::string rightArg = arg.substr (pos + 1);
if (leftArg.empty()) {
std::cerr << "Error: missing left argument." << std::endl;
std::cerr << usage << std::endl;
exit (EXIT_FAILURE);
}
if (rightArg.empty()) {
std::cerr << "Error: missing right argument." << std::endl;
std::cerr << usage << std::endl;
exit (EXIT_FAILURE);
}
Horus::Util::setHorusFlag (leftArg, rightArg);
}
return i + 1;
}
void
readFactorGraph (Horus::FactorGraph& fg, const char* s)
{
std::string fileName (s);
std::string extension = fileName.substr (fileName.find_last_of ('.') + 1);
if (extension == "uai") {
fg.readFromUaiFormat (fileName.c_str());
} else if (extension == "fg") {
fg.readFromLibDaiFormat (fileName.c_str());
} else {
std::cerr << "Error: the probabilistic graphical model must be " ;
std::cerr << "defined either in a UAI or libDAI file." << std::endl;
exit (EXIT_FAILURE);
}
}
Horus::VarIds
readQueryAndEvidence (
Horus::FactorGraph& fg,
int argc,
const char* argv[],
int start)
{
Horus::VarIds queryIds;
for (int i = start; i < argc; i++) {
const std::string& arg = argv[i];
if (arg.find ('=') == std::string::npos) {
if (Horus::Util::isInteger (arg) == false) {
std::cerr << "Error: `" << arg << "' " ;
std::cerr << "is not a variable id." ;
std::cerr << std::endl;
exit (EXIT_FAILURE);
}
Horus::VarId vid = Horus::Util::stringToUnsigned (arg);
Horus::VarNode* queryVar = fg.getVarNode (vid);
if (queryVar == false) {
std::cerr << "Error: unknow variable with id " ;
std::cerr << "`" << vid << "'." << std::endl;
exit (EXIT_FAILURE);
}
queryIds.push_back (vid);
} else {
size_t pos = arg.find ('=');
std::string leftArg = arg.substr (0, pos);
std::string rightArg = arg.substr (pos + 1);
if (leftArg.empty()) {
std::cerr << "Error: missing left argument." << std::endl;
std::cerr << usage << std::endl;
exit (EXIT_FAILURE);
}
if (Horus::Util::isInteger (leftArg) == false) {
std::cerr << "Error: `" << leftArg << "' " ;
std::cerr << "is not a variable id." << std::endl;
exit (EXIT_FAILURE);
}
Horus::VarId vid = Horus::Util::stringToUnsigned (leftArg);
Horus::VarNode* observedVar = fg.getVarNode (vid);
if (observedVar == false) {
std::cerr << "Error: unknow variable with id " ;
std::cerr << "`" << vid << "'." << std::endl;
exit (EXIT_FAILURE);
}
if (rightArg.empty()) {
std::cerr << "Error: missing right argument." << std::endl;
std::cerr << usage << std::endl;
exit (EXIT_FAILURE);
}
if (Horus::Util::isInteger (rightArg) == false) {
std::cerr << "Error: `" << rightArg << "' " ;
std::cerr << "is not a state index." << std::endl;
exit (EXIT_FAILURE);
}
unsigned stateIdx = Horus::Util::stringToUnsigned (rightArg);
if (observedVar->isValidState (stateIdx) == false) {
std::cerr << "Error: `" << stateIdx << "' " ;
std::cerr << "is not a valid state index for variable with id " ;
std::cerr << "`" << vid << "'." << std::endl;
exit (EXIT_FAILURE);
}
observedVar->setEvidence (stateIdx);
}
}
return queryIds;
}
void
runSolver (
const Horus::FactorGraph& fg,
const Horus::VarIds& queryIds)
{
Horus::GroundSolver* solver = 0;
switch (Horus::Globals::groundSolver) {
case Horus::GroundSolverType::veSolver:
solver = new Horus::VarElim (fg);
break;
case Horus::GroundSolverType::bpSolver:
solver = new Horus::BeliefProp (fg);
break;
case Horus::GroundSolverType::CbpSolver:
solver = new Horus::CountingBp (fg);
break;
default:
assert (false);
}
if (Horus::Globals::verbosity > 0) {
solver->printSolverFlags();
std::cout << std::endl;
}
if (queryIds.empty()) {
solver->printAllPosterioris();
} else {
solver->printAnswer (queryIds);
}
delete solver;
}