219 lines
5.9 KiB
C++
219 lines
5.9 KiB
C++
#include <cassert>
|
|
|
|
#include <string>
|
|
#include <iostream>
|
|
|
|
#include "FactorGraph.h"
|
|
#include "VarElim.h"
|
|
#include "BeliefProp.h"
|
|
#include "CountingBp.h"
|
|
|
|
|
|
namespace {
|
|
|
|
int readHorusFlags (int, const char* []);
|
|
|
|
void readFactorGraph (Horus::FactorGraph&, const char*);
|
|
|
|
Horus::VarIds readQueryAndEvidence (
|
|
Horus::FactorGraph&, int, const char* [], int);
|
|
|
|
void runSolver (const Horus::FactorGraph&, const Horus::VarIds&);
|
|
|
|
const std::string usage = "usage: ./hcli [solver=hve|bp|cbp] \
|
|
[<OPTION>=<VALUE>]... <FILE> [<VAR>|<VAR>=<EVIDENCE>]... " ;
|
|
|
|
}
|
|
|
|
|
|
|
|
int
|
|
main (int argc, const char* argv[])
|
|
{
|
|
if (argc <= 1) {
|
|
std::cerr << "Error: no probabilistic graphical model was given." ;
|
|
std::cerr << std::endl << usage << std::endl;
|
|
exit (EXIT_FAILURE);
|
|
}
|
|
int idx = readHorusFlags (argc, argv);
|
|
Horus::FactorGraph fg;
|
|
readFactorGraph (fg, argv[idx]);
|
|
Horus::VarIds queryIds
|
|
= readQueryAndEvidence (fg, argc, argv, idx + 1);
|
|
if (Horus::FactorGraph::exportToLibDai()) {
|
|
fg.exportToLibDai ("model.fg");
|
|
}
|
|
if (Horus::FactorGraph::exportToUai()) {
|
|
fg.exportToUai ("model.uai");
|
|
}
|
|
if (Horus::FactorGraph::exportGraphViz()) {
|
|
fg.exportToGraphViz ("model.dot");
|
|
}
|
|
if (Horus::FactorGraph::printFactorGraph()) {
|
|
fg.print();
|
|
}
|
|
if (Horus::Globals::verbosity > 0) {
|
|
std::cout << "factor graph contains " ;
|
|
std::cout << fg.nrVarNodes() << " variables and " ;
|
|
std::cout << fg.nrFacNodes() << " factors " << std::endl;
|
|
}
|
|
runSolver (fg, queryIds);
|
|
return 0;
|
|
}
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
int
|
|
readHorusFlags (int argc, const char* argv[])
|
|
{
|
|
int i = 1;
|
|
for (; i < argc; i++) {
|
|
const std::string& arg = argv[i];
|
|
size_t pos = arg.find ('=');
|
|
if (pos == std::string::npos) {
|
|
return i;
|
|
}
|
|
std::string leftArg = arg.substr (0, pos);
|
|
std::string rightArg = arg.substr (pos + 1);
|
|
if (leftArg.empty()) {
|
|
std::cerr << "Error: missing left argument." << std::endl;
|
|
std::cerr << usage << std::endl;
|
|
exit (EXIT_FAILURE);
|
|
}
|
|
if (rightArg.empty()) {
|
|
std::cerr << "Error: missing right argument." << std::endl;
|
|
std::cerr << usage << std::endl;
|
|
exit (EXIT_FAILURE);
|
|
}
|
|
Horus::Util::setHorusFlag (leftArg, rightArg);
|
|
}
|
|
return i + 1;
|
|
}
|
|
|
|
|
|
|
|
void
|
|
readFactorGraph (Horus::FactorGraph& fg, const char* s)
|
|
{
|
|
std::string fileName (s);
|
|
std::string extension = fileName.substr (fileName.find_last_of ('.') + 1);
|
|
if (extension == "uai") {
|
|
fg = Horus::FactorGraph::readFromUaiFormat (fileName.c_str());
|
|
} else if (extension == "fg") {
|
|
fg = Horus::FactorGraph::readFromLibDaiFormat (fileName.c_str());
|
|
} else {
|
|
std::cerr << "Error: the probabilistic graphical model must be " ;
|
|
std::cerr << "defined either in a UAI or libDAI file." << std::endl;
|
|
exit (EXIT_FAILURE);
|
|
}
|
|
}
|
|
|
|
|
|
|
|
Horus::VarIds
|
|
readQueryAndEvidence (
|
|
Horus::FactorGraph& fg,
|
|
int argc,
|
|
const char* argv[],
|
|
int start)
|
|
{
|
|
Horus::VarIds queryIds;
|
|
for (int i = start; i < argc; i++) {
|
|
const std::string& arg = argv[i];
|
|
if (arg.find ('=') == std::string::npos) {
|
|
if (Horus::Util::isInteger (arg) == false) {
|
|
std::cerr << "Error: `" << arg << "' " ;
|
|
std::cerr << "is not a variable id." ;
|
|
std::cerr << std::endl;
|
|
exit (EXIT_FAILURE);
|
|
}
|
|
Horus::VarId vid = Horus::Util::stringToUnsigned (arg);
|
|
Horus::VarNode* queryVar = fg.getVarNode (vid);
|
|
if (queryVar == nullptr) {
|
|
std::cerr << "Error: unknow variable with id " ;
|
|
std::cerr << "`" << vid << "'." << std::endl;
|
|
exit (EXIT_FAILURE);
|
|
}
|
|
queryIds.push_back (vid);
|
|
} else {
|
|
size_t pos = arg.find ('=');
|
|
std::string leftArg = arg.substr (0, pos);
|
|
std::string rightArg = arg.substr (pos + 1);
|
|
if (leftArg.empty()) {
|
|
std::cerr << "Error: missing left argument." << std::endl;
|
|
std::cerr << usage << std::endl;
|
|
exit (EXIT_FAILURE);
|
|
}
|
|
if (Horus::Util::isInteger (leftArg) == false) {
|
|
std::cerr << "Error: `" << leftArg << "' " ;
|
|
std::cerr << "is not a variable id." << std::endl;
|
|
exit (EXIT_FAILURE);
|
|
}
|
|
Horus::VarId vid = Horus::Util::stringToUnsigned (leftArg);
|
|
Horus::VarNode* observedVar = fg.getVarNode (vid);
|
|
if (observedVar == nullptr) {
|
|
std::cerr << "Error: unknow variable with id " ;
|
|
std::cerr << "`" << vid << "'." << std::endl;
|
|
exit (EXIT_FAILURE);
|
|
}
|
|
if (rightArg.empty()) {
|
|
std::cerr << "Error: missing right argument." << std::endl;
|
|
std::cerr << usage << std::endl;
|
|
exit (EXIT_FAILURE);
|
|
}
|
|
if (Horus::Util::isInteger (rightArg) == false) {
|
|
std::cerr << "Error: `" << rightArg << "' " ;
|
|
std::cerr << "is not a state index." << std::endl;
|
|
exit (EXIT_FAILURE);
|
|
}
|
|
unsigned stateIdx = Horus::Util::stringToUnsigned (rightArg);
|
|
if (observedVar->isValidState (stateIdx) == false) {
|
|
std::cerr << "Error: `" << stateIdx << "' " ;
|
|
std::cerr << "is not a valid state index for variable with id " ;
|
|
std::cerr << "`" << vid << "'." << std::endl;
|
|
exit (EXIT_FAILURE);
|
|
}
|
|
observedVar->setEvidence (stateIdx);
|
|
}
|
|
}
|
|
return queryIds;
|
|
}
|
|
|
|
|
|
|
|
void
|
|
runSolver (
|
|
const Horus::FactorGraph& fg,
|
|
const Horus::VarIds& queryIds)
|
|
{
|
|
Horus::GroundSolver* solver = 0;
|
|
switch (Horus::Globals::groundSolver) {
|
|
case Horus::GroundSolverType::veSolver:
|
|
solver = new Horus::VarElim (fg);
|
|
break;
|
|
case Horus::GroundSolverType::bpSolver:
|
|
solver = new Horus::BeliefProp (fg);
|
|
break;
|
|
case Horus::GroundSolverType::CbpSolver:
|
|
solver = new Horus::CountingBp (fg);
|
|
break;
|
|
default:
|
|
assert (false);
|
|
}
|
|
if (Horus::Globals::verbosity > 0) {
|
|
solver->printSolverFlags();
|
|
std::cout << std::endl;
|
|
}
|
|
if (queryIds.empty()) {
|
|
solver->printAllPosterioris();
|
|
} else {
|
|
solver->printAnswer (queryIds);
|
|
}
|
|
delete solver;
|
|
}
|
|
|
|
}
|
|
|