2011-05-17 12:00:33 +01:00
|
|
|
#include <cstdlib>
|
2011-07-22 21:33:30 +01:00
|
|
|
|
|
|
|
#include <iostream>
|
2011-05-17 12:00:33 +01:00
|
|
|
#include <sstream>
|
|
|
|
|
|
|
|
#include "FactorGraph.h"
|
2011-12-12 15:29:51 +00:00
|
|
|
#include "VarElimSolver.h"
|
2012-04-05 23:00:48 +01:00
|
|
|
#include "BpSolver.h"
|
2011-12-12 15:29:51 +00:00
|
|
|
#include "CbpSolver.h"
|
|
|
|
|
2011-05-17 12:00:33 +01:00
|
|
|
using namespace std;
|
|
|
|
|
2011-12-12 15:29:51 +00:00
|
|
|
void processArguments (FactorGraph&, int, const char* []);
|
2012-04-05 23:00:48 +01:00
|
|
|
void runSolver (Solver*, const VarIds&);
|
2011-05-17 12:00:33 +01:00
|
|
|
|
|
|
|
// One-line command synopsis, printed together with some argument errors.
const std::string USAGE =
  "usage: ./hcli FILE [VARIABLE | OBSERVED_VARIABLE=EVIDENCE]...";
|
|
|
|
|
|
|
|
|
|
|
|
int
|
|
|
|
main (int argc, const char* argv[])
|
2011-12-12 15:29:51 +00:00
|
|
|
{
|
2011-05-17 12:00:33 +01:00
|
|
|
if (!argv[1]) {
|
|
|
|
cerr << "error: no graphical model specified" << endl;
|
|
|
|
cerr << USAGE << endl;
|
|
|
|
exit (0);
|
|
|
|
}
|
2011-07-22 21:33:30 +01:00
|
|
|
const string& fileName = argv[1];
|
2012-04-05 19:34:37 +01:00
|
|
|
const string& extension = fileName.substr (
|
|
|
|
fileName.find_last_of ('.') + 1);
|
|
|
|
FactorGraph fg;
|
|
|
|
if (extension == "uai") {
|
2011-12-12 15:29:51 +00:00
|
|
|
fg.readFromUaiFormat (argv[1]);
|
|
|
|
processArguments (fg, argc, argv);
|
|
|
|
} else if (extension == "fg") {
|
|
|
|
fg.readFromLibDaiFormat (argv[1]);
|
|
|
|
processArguments (fg, argc, argv);
|
2011-05-17 12:00:33 +01:00
|
|
|
} else {
|
|
|
|
cerr << "error: the graphical model must be defined either " ;
|
2012-04-05 19:34:37 +01:00
|
|
|
cerr << "in a UAI or libDAI file" << endl;
|
2011-05-17 12:00:33 +01:00
|
|
|
exit (0);
|
|
|
|
}
|
|
|
|
return 0;
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void
|
2011-12-12 15:29:51 +00:00
|
|
|
processArguments (FactorGraph& fg, int argc, const char* argv[])
|
2011-05-17 12:00:33 +01:00
|
|
|
{
|
2012-04-05 23:00:48 +01:00
|
|
|
VarIds queryIds;
|
2011-05-17 12:00:33 +01:00
|
|
|
for (int i = 2; i < argc; i++) {
|
2011-07-22 21:33:30 +01:00
|
|
|
const string& arg = argv[i];
|
2011-05-17 12:00:33 +01:00
|
|
|
if (arg.find ('=') == std::string::npos) {
|
|
|
|
if (!Util::isInteger (arg)) {
|
|
|
|
cerr << "error: `" << arg << "' " ;
|
|
|
|
cerr << "is not a valid variable id" ;
|
|
|
|
cerr << endl;
|
|
|
|
exit (0);
|
|
|
|
}
|
2011-12-12 15:29:51 +00:00
|
|
|
VarId vid;
|
2011-05-17 12:00:33 +01:00
|
|
|
stringstream ss;
|
|
|
|
ss << arg;
|
2011-07-22 21:33:30 +01:00
|
|
|
ss >> vid;
|
2012-04-05 23:00:48 +01:00
|
|
|
VarNode* queryVar = fg.getVarNode (vid);
|
2011-05-17 12:00:33 +01:00
|
|
|
if (queryVar) {
|
2012-04-05 23:00:48 +01:00
|
|
|
queryIds.push_back (vid);
|
2011-05-17 12:00:33 +01:00
|
|
|
} else {
|
|
|
|
cerr << "error: there isn't a variable with " ;
|
2011-07-22 21:33:30 +01:00
|
|
|
cerr << "`" << vid << "' as id" ;
|
2011-05-17 12:00:33 +01:00
|
|
|
cerr << endl;
|
|
|
|
exit (0);
|
|
|
|
}
|
|
|
|
} else {
|
|
|
|
size_t pos = arg.find ('=');
|
|
|
|
if (arg.substr (0, pos).empty()) {
|
|
|
|
cerr << "error: missing left argument" << endl;
|
|
|
|
cerr << USAGE << endl;
|
|
|
|
exit (0);
|
|
|
|
}
|
|
|
|
if (arg.substr (pos + 1).empty()) {
|
|
|
|
cerr << "error: missing right argument" << endl;
|
|
|
|
cerr << USAGE << endl;
|
|
|
|
exit (0);
|
|
|
|
}
|
|
|
|
if (!Util::isInteger (arg.substr (0, pos))) {
|
|
|
|
cerr << "error: `" << arg.substr (0, pos) << "' " ;
|
|
|
|
cerr << "is not a variable id" ;
|
|
|
|
cerr << endl;
|
|
|
|
exit (0);
|
|
|
|
}
|
2011-12-12 15:29:51 +00:00
|
|
|
VarId vid;
|
2011-05-17 12:00:33 +01:00
|
|
|
stringstream ss;
|
|
|
|
ss << arg.substr (0, pos);
|
2011-07-22 21:33:30 +01:00
|
|
|
ss >> vid;
|
2012-04-05 23:00:48 +01:00
|
|
|
VarNode* var = fg.getVarNode (vid);
|
2011-05-17 12:00:33 +01:00
|
|
|
if (var) {
|
|
|
|
if (!Util::isInteger (arg.substr (pos + 1))) {
|
|
|
|
cerr << "error: `" << arg.substr (pos + 1) << "' " ;
|
|
|
|
cerr << "is not a state index" ;
|
|
|
|
cerr << endl;
|
|
|
|
exit (0);
|
|
|
|
}
|
|
|
|
int stateIndex;
|
|
|
|
stringstream ss;
|
|
|
|
ss << arg.substr (pos + 1);
|
|
|
|
ss >> stateIndex;
|
2011-12-12 15:29:51 +00:00
|
|
|
if (var->isValidState (stateIndex)) {
|
2011-05-17 12:00:33 +01:00
|
|
|
var->setEvidence (stateIndex);
|
|
|
|
} else {
|
|
|
|
cerr << "error: `" << stateIndex << "' " ;
|
|
|
|
cerr << "is not a valid state index for variable " ;
|
2011-12-12 15:29:51 +00:00
|
|
|
cerr << "`" << var->varId() << "'" ;
|
2011-05-17 12:00:33 +01:00
|
|
|
cerr << endl;
|
|
|
|
exit (0);
|
|
|
|
}
|
|
|
|
} else {
|
|
|
|
cerr << "error: there isn't a variable with " ;
|
2011-07-22 21:33:30 +01:00
|
|
|
cerr << "`" << vid << "' as id" ;
|
2011-05-17 12:00:33 +01:00
|
|
|
cerr << endl;
|
|
|
|
exit (0);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
2011-12-12 15:29:51 +00:00
|
|
|
Solver* solver = 0;
|
2012-03-31 23:27:37 +01:00
|
|
|
switch (Globals::infAlgorithm) {
|
2011-12-12 15:29:51 +00:00
|
|
|
case InfAlgorithms::VE:
|
|
|
|
solver = new VarElimSolver (fg);
|
|
|
|
break;
|
2012-04-05 23:00:48 +01:00
|
|
|
case InfAlgorithms::BP:
|
|
|
|
solver = new BpSolver (fg);
|
2011-12-12 15:29:51 +00:00
|
|
|
break;
|
|
|
|
case InfAlgorithms::CBP:
|
|
|
|
solver = new CbpSolver (fg);
|
|
|
|
break;
|
|
|
|
default:
|
|
|
|
assert (false);
|
|
|
|
}
|
2012-04-05 23:00:48 +01:00
|
|
|
runSolver (solver, queryIds);
|
2011-07-22 21:33:30 +01:00
|
|
|
}
|
2011-05-17 12:00:33 +01:00
|
|
|
|
2011-07-22 21:33:30 +01:00
|
|
|
|
|
|
|
|
|
|
|
void
|
2012-04-05 23:00:48 +01:00
|
|
|
runSolver (Solver* solver, const VarIds& queryIds)
|
2011-07-22 21:33:30 +01:00
|
|
|
{
|
2012-04-05 23:00:48 +01:00
|
|
|
if (queryIds.size() == 0) {
|
2011-07-22 21:33:30 +01:00
|
|
|
solver->runSolver();
|
|
|
|
solver->printAllPosterioris();
|
2012-04-05 23:00:48 +01:00
|
|
|
} else if (queryIds.size() == 1) {
|
2011-07-22 21:33:30 +01:00
|
|
|
solver->runSolver();
|
2012-04-05 23:00:48 +01:00
|
|
|
solver->printPosterioriOf (queryIds[0]);
|
2011-05-17 12:00:33 +01:00
|
|
|
} else {
|
2011-12-12 15:29:51 +00:00
|
|
|
solver->runSolver();
|
2012-04-05 23:00:48 +01:00
|
|
|
solver->printJointDistributionOf (queryIds);
|
2011-05-17 12:00:33 +01:00
|
|
|
}
|
2011-07-22 21:33:30 +01:00
|
|
|
delete solver;
|
2011-05-17 12:00:33 +01:00
|
|
|
}
|
|
|
|
|