This repository has been archived on 2023-08-20. You can view files and clone it, but cannot push or open issues or pull requests.
yap-6.3/packages/CLPBN/horus/HorusCli.cpp

219 lines
5.9 KiB
C++
Raw Normal View History

2013-02-07 20:09:10 +00:00
#include <cassert>
#include <cstdlib>
#include <iostream>
#include <memory>
#include <string>
2012-05-23 14:56:01 +01:00
#include "FactorGraph.h"
#include "VarElim.h"
#include "BeliefProp.h"
#include "CountingBp.h"
2012-05-23 14:56:01 +01:00
namespace {
2012-05-23 14:56:01 +01:00
int readHorusFlags (int, const char* []);
2013-02-07 13:37:15 +00:00
void readFactorGraph (Horus::FactorGraph&, const char*);
2013-02-07 13:37:15 +00:00
Horus::VarIds readQueryAndEvidence (
Horus::FactorGraph&, int, const char* [], int);
2012-05-23 14:56:01 +01:00
void runSolver (const Horus::FactorGraph&, const Horus::VarIds&);
2012-05-23 14:56:01 +01:00
const std::string usage = "usage: ./hcli [solver=hve|bp|cbp] \
[<OPTION>=<VALUE>]... <FILE> [<VAR>|<VAR>=<EVIDENCE>]... " ;
2012-05-23 14:56:01 +01:00
}
2012-05-23 14:56:01 +01:00
int
main (int argc, const char* argv[])
{
if (argc <= 1) {
2013-02-07 13:37:15 +00:00
std::cerr << "Error: no probabilistic graphical model was given." ;
std::cerr << std::endl << usage << std::endl;
exit (EXIT_FAILURE);
2012-05-23 14:56:01 +01:00
}
int idx = readHorusFlags (argc, argv);
Horus::FactorGraph fg;
2012-05-23 14:56:01 +01:00
readFactorGraph (fg, argv[idx]);
Horus::VarIds queryIds
2013-02-07 23:53:13 +00:00
= readQueryAndEvidence (fg, argc, argv, idx + 1);
if (Horus::FactorGraph::exportToLibDai()) {
fg.exportToLibDai ("model.fg");
}
if (Horus::FactorGraph::exportToUai()) {
fg.exportToUai ("model.uai");
}
if (Horus::FactorGraph::exportGraphViz()) {
fg.exportToGraphViz ("model.dot");
}
if (Horus::FactorGraph::printFactorGraph()) {
fg.print();
}
if (Horus::Globals::verbosity > 0) {
2013-02-07 13:37:15 +00:00
std::cout << "factor graph contains " ;
std::cout << fg.nrVarNodes() << " variables and " ;
std::cout << fg.nrFacNodes() << " factors " << std::endl;
}
2012-05-23 14:56:01 +01:00
runSolver (fg, queryIds);
return 0;
}
namespace {
2012-05-23 14:56:01 +01:00
int
readHorusFlags (int argc, const char* argv[])
{
int i = 1;
for (; i < argc; i++) {
2013-02-07 13:37:15 +00:00
const std::string& arg = argv[i];
2012-05-23 14:56:01 +01:00
size_t pos = arg.find ('=');
if (pos == std::string::npos) {
return i;
}
2013-02-07 13:37:15 +00:00
std::string leftArg = arg.substr (0, pos);
std::string rightArg = arg.substr (pos + 1);
2012-05-23 14:56:01 +01:00
if (leftArg.empty()) {
2013-02-07 13:37:15 +00:00
std::cerr << "Error: missing left argument." << std::endl;
std::cerr << usage << std::endl;
exit (EXIT_FAILURE);
2012-05-23 14:56:01 +01:00
}
if (rightArg.empty()) {
2013-02-07 13:37:15 +00:00
std::cerr << "Error: missing right argument." << std::endl;
std::cerr << usage << std::endl;
exit (EXIT_FAILURE);
2012-05-23 14:56:01 +01:00
}
Horus::Util::setHorusFlag (leftArg, rightArg);
2012-05-23 14:56:01 +01:00
}
return i + 1;
}
void
readFactorGraph (Horus::FactorGraph& fg, const char* s)
2012-05-23 14:56:01 +01:00
{
2013-02-07 13:37:15 +00:00
std::string fileName (s);
std::string extension = fileName.substr (fileName.find_last_of ('.') + 1);
2012-05-23 14:56:01 +01:00
if (extension == "uai") {
fg = Horus::FactorGraph::readFromUaiFormat (fileName.c_str());
2012-05-23 14:56:01 +01:00
} else if (extension == "fg") {
fg = Horus::FactorGraph::readFromLibDaiFormat (fileName.c_str());
2012-05-23 14:56:01 +01:00
} else {
2013-02-07 13:37:15 +00:00
std::cerr << "Error: the probabilistic graphical model must be " ;
std::cerr << "defined either in a UAI or libDAI file." << std::endl;
exit (EXIT_FAILURE);
2012-05-23 14:56:01 +01:00
}
}
Horus::VarIds
2012-05-23 14:56:01 +01:00
readQueryAndEvidence (
Horus::FactorGraph& fg,
2012-05-23 14:56:01 +01:00
int argc,
const char* argv[],
int start)
{
Horus::VarIds queryIds;
2012-05-23 14:56:01 +01:00
for (int i = start; i < argc; i++) {
2013-02-07 13:37:15 +00:00
const std::string& arg = argv[i];
2012-05-23 14:56:01 +01:00
if (arg.find ('=') == std::string::npos) {
if (Horus::Util::isInteger (arg) == false) {
2013-02-07 13:37:15 +00:00
std::cerr << "Error: `" << arg << "' " ;
std::cerr << "is not a variable id." ;
std::cerr << std::endl;
exit (EXIT_FAILURE);
2012-05-23 14:56:01 +01:00
}
Horus::VarId vid = Horus::Util::stringToUnsigned (arg);
Horus::VarNode* queryVar = fg.getVarNode (vid);
2014-03-14 23:03:22 +00:00
if (queryVar == nullptr) {
2013-02-07 13:37:15 +00:00
std::cerr << "Error: unknow variable with id " ;
std::cerr << "`" << vid << "'." << std::endl;
exit (EXIT_FAILURE);
2012-05-23 14:56:01 +01:00
}
queryIds.push_back (vid);
} else {
size_t pos = arg.find ('=');
2013-02-07 13:37:15 +00:00
std::string leftArg = arg.substr (0, pos);
std::string rightArg = arg.substr (pos + 1);
2012-05-23 14:56:01 +01:00
if (leftArg.empty()) {
2013-02-07 13:37:15 +00:00
std::cerr << "Error: missing left argument." << std::endl;
std::cerr << usage << std::endl;
exit (EXIT_FAILURE);
2012-05-23 14:56:01 +01:00
}
if (Horus::Util::isInteger (leftArg) == false) {
2013-02-07 13:37:15 +00:00
std::cerr << "Error: `" << leftArg << "' " ;
std::cerr << "is not a variable id." << std::endl;
exit (EXIT_FAILURE);
2012-05-23 14:56:01 +01:00
}
Horus::VarId vid = Horus::Util::stringToUnsigned (leftArg);
Horus::VarNode* observedVar = fg.getVarNode (vid);
2014-03-14 23:03:22 +00:00
if (observedVar == nullptr) {
2013-02-07 13:37:15 +00:00
std::cerr << "Error: unknow variable with id " ;
std::cerr << "`" << vid << "'." << std::endl;
exit (EXIT_FAILURE);
2012-05-23 14:56:01 +01:00
}
if (rightArg.empty()) {
2013-02-07 13:37:15 +00:00
std::cerr << "Error: missing right argument." << std::endl;
std::cerr << usage << std::endl;
exit (EXIT_FAILURE);
2012-05-23 14:56:01 +01:00
}
if (Horus::Util::isInteger (rightArg) == false) {
2013-02-07 13:37:15 +00:00
std::cerr << "Error: `" << rightArg << "' " ;
std::cerr << "is not a state index." << std::endl;
exit (EXIT_FAILURE);
2012-05-23 14:56:01 +01:00
}
unsigned stateIdx = Horus::Util::stringToUnsigned (rightArg);
2012-05-23 14:56:01 +01:00
if (observedVar->isValidState (stateIdx) == false) {
2013-02-07 13:37:15 +00:00
std::cerr << "Error: `" << stateIdx << "' " ;
std::cerr << "is not a valid state index for variable with id " ;
std::cerr << "`" << vid << "'." << std::endl;
exit (EXIT_FAILURE);
2012-05-23 14:56:01 +01:00
}
observedVar->setEvidence (stateIdx);
}
}
return queryIds;
}
void
2013-02-07 23:53:13 +00:00
runSolver (
const Horus::FactorGraph& fg,
const Horus::VarIds& queryIds)
2012-05-23 14:56:01 +01:00
{
Horus::GroundSolver* solver = 0;
switch (Horus::Globals::groundSolver) {
case Horus::GroundSolverType::veSolver:
solver = new Horus::VarElim (fg);
2012-05-23 14:56:01 +01:00
break;
case Horus::GroundSolverType::bpSolver:
solver = new Horus::BeliefProp (fg);
2012-05-23 14:56:01 +01:00
break;
case Horus::GroundSolverType::CbpSolver:
solver = new Horus::CountingBp (fg);
2012-05-23 14:56:01 +01:00
break;
default:
assert (false);
}
if (Horus::Globals::verbosity > 0) {
2012-05-23 14:56:01 +01:00
solver->printSolverFlags();
2013-02-07 13:37:15 +00:00
std::cout << std::endl;
2012-05-23 14:56:01 +01:00
}
2012-05-28 18:26:15 +01:00
if (queryIds.empty()) {
2012-05-23 14:56:01 +01:00
solver->printAllPosterioris();
} else {
solver->printAnswer (queryIds);
}
delete solver;
}
}