#include <cassert>
#include <cstdlib>
#include <iostream>
#include <memory>
#include <string>

#include "FactorGraph.h"
#include "VarElim.h"
#include "BeliefProp.h"
#include "CountingBp.h"


namespace {

int readHorusFlags (int, const char* []);

void readFactorGraph (Horus::FactorGraph&, const char*);

Horus::VarIds readQueryAndEvidence (
    Horus::FactorGraph&, int, const char* [], int);

void runSolver (const Horus::FactorGraph&, const Horus::VarIds&);

const std::string usage = "usage: ./hcli [solver=hve|bp|cbp] \
[<OPTION>=<VALUE>]... <FILE> [<VAR>|<VAR>=<EVIDENCE>]... " ;

}



// Entry point of the Horus command-line interface.
//
// Expected invocation (see `usage`):
//   hcli [<OPTION>=<VALUE>]... <FILE> [<VAR>|<VAR>=<EVIDENCE>]...
// The leading options are applied via readHorusFlags, the factor graph is
// loaded from <FILE>, and the trailing arguments become evidence and query
// variables. Depending on the FactorGraph flags the graph is then exported
// to libDAI / UAI / GraphViz files or printed. Finally, the selected ground
// solver is run. Returns 0 on success and EXIT_FAILURE when no model file is
// given.
int
main (int argc, const char* argv[])
{
  // Index of the first non-option argument, i.e. the <FILE>. When there are
  // no arguments at all we skip flag parsing and treat the file as missing.
  int idx = (argc > 1) ? readHorusFlags (argc, argv) : argc;
  if (idx >= argc) {
    // Either nothing was given or every argument was an option; in both
    // cases argv has no <FILE> entry, and reading argv[idx] would run past
    // the end of the argument vector.
    std::cerr << "Error: no probabilistic graphical model was given." ;
    std::cerr << std::endl << usage << std::endl;
    return EXIT_FAILURE;
  }
  Horus::FactorGraph fg;
  readFactorGraph (fg, argv[idx]);
  // Everything after <FILE> is a query variable or an observation.
  Horus::VarIds queryIds
      = readQueryAndEvidence (fg, argc, argv, idx + 1);
  if (Horus::FactorGraph::exportToLibDai()) {
    fg.exportToLibDai ("model.fg");
  }
  if (Horus::FactorGraph::exportToUai()) {
    fg.exportToUai ("model.uai");
  }
  if (Horus::FactorGraph::exportGraphViz()) {
    fg.exportToGraphViz ("model.dot");
  }
  if (Horus::FactorGraph::printFactorGraph()) {
    fg.print();
  }
  if (Horus::Globals::verbosity > 0) {
    std::cout << "factor graph contains " ;
    std::cout << fg.nrVarNodes() << " variables and " ;
    std::cout << fg.nrFacNodes() << " factors " << std::endl;
  }
  runSolver (fg, queryIds);
  return 0;
}



namespace {

int
readHorusFlags (int argc, const char* argv[])
{
  int i = 1;
  for (; i < argc; i++) {
    const std::string& arg = argv[i];
    size_t pos = arg.find ('=');
    if (pos == std::string::npos) {
      return i;
    }
    std::string leftArg  = arg.substr (0, pos);
    std::string rightArg = arg.substr (pos + 1);
    if (leftArg.empty()) {
      std::cerr << "Error: missing left argument." << std::endl;
      std::cerr << usage << std::endl;
      exit (EXIT_FAILURE);
    }
    if (rightArg.empty()) {
      std::cerr << "Error: missing right argument." << std::endl;
      std::cerr << usage << std::endl;
      exit (EXIT_FAILURE);
    }
    Horus::Util::setHorusFlag (leftArg, rightArg);
  }
  return i + 1;
}



void
readFactorGraph (Horus::FactorGraph& fg, const char* s)
{
  std::string fileName (s);
  std::string extension = fileName.substr (fileName.find_last_of ('.') + 1);
  if (extension == "uai") {
    fg = Horus::FactorGraph::readFromUaiFormat (fileName.c_str());
  } else if (extension == "fg") {
    fg = Horus::FactorGraph::readFromLibDaiFormat (fileName.c_str());
  } else {
    std::cerr << "Error: the probabilistic graphical model must be " ;
    std::cerr << "defined either in a UAI or libDAI file." << std::endl;
    exit (EXIT_FAILURE);
  }
}



Horus::VarIds
readQueryAndEvidence (
    Horus::FactorGraph& fg,
    int argc,
    const char* argv[],
    int start)
{
  Horus::VarIds queryIds;
  for (int i = start; i < argc; i++) {
    const std::string& arg = argv[i];
    if (arg.find ('=') == std::string::npos) {
      if (Horus::Util::isInteger (arg) == false) {
        std::cerr << "Error: `" << arg << "' " ;
        std::cerr << "is not a variable id." ;
        std::cerr << std::endl;
        exit (EXIT_FAILURE);
      }
      Horus::VarId vid = Horus::Util::stringToUnsigned (arg);
      Horus::VarNode* queryVar = fg.getVarNode (vid);
      if (queryVar == nullptr) {
        std::cerr << "Error: unknow variable with id " ;
        std::cerr << "`" << vid << "'."  << std::endl;
        exit (EXIT_FAILURE);
      }
      queryIds.push_back (vid);
    } else {
      size_t pos = arg.find ('=');
      std::string leftArg  = arg.substr (0, pos);
      std::string rightArg = arg.substr (pos + 1);
      if (leftArg.empty()) {
        std::cerr << "Error: missing left argument." << std::endl;
        std::cerr << usage << std::endl;
        exit (EXIT_FAILURE);
      }
      if (Horus::Util::isInteger (leftArg) == false) {
        std::cerr << "Error: `" << leftArg << "' " ;
        std::cerr << "is not a variable id." << std::endl;
        exit (EXIT_FAILURE);
      }
      Horus::VarId vid = Horus::Util::stringToUnsigned (leftArg);
      Horus::VarNode* observedVar = fg.getVarNode (vid);
      if (observedVar == nullptr) {
        std::cerr << "Error: unknow variable with id " ;
        std::cerr << "`" << vid << "'."  << std::endl;
        exit (EXIT_FAILURE);
      }
      if (rightArg.empty()) {
        std::cerr << "Error: missing right argument." << std::endl;
        std::cerr << usage << std::endl;
        exit (EXIT_FAILURE);
      }
      if (Horus::Util::isInteger (rightArg) == false) {
        std::cerr << "Error: `" << rightArg << "' " ;
        std::cerr << "is not a state index." << std::endl;
        exit (EXIT_FAILURE);
      }
      unsigned stateIdx = Horus::Util::stringToUnsigned (rightArg);
      if (observedVar->isValidState (stateIdx) == false) {
        std::cerr << "Error: `" << stateIdx << "' " ;
        std::cerr << "is not a valid state index for variable with id " ;
        std::cerr << "`" << vid << "'."  << std::endl;
        exit (EXIT_FAILURE);
      }
      observedVar->setEvidence (stateIdx);
    }
  }
  return queryIds;
}



void
runSolver (
    const Horus::FactorGraph& fg,
    const Horus::VarIds& queryIds)
{
  Horus::GroundSolver* solver = 0;
  switch (Horus::Globals::groundSolver) {
    case Horus::GroundSolverType::veSolver:
      solver = new Horus::VarElim (fg);
      break;
    case Horus::GroundSolverType::bpSolver:
      solver = new Horus::BeliefProp (fg);
      break;
    case Horus::GroundSolverType::CbpSolver:
      solver = new Horus::CountingBp (fg);
      break;
    default:
      assert (false);
  }
  if (Horus::Globals::verbosity > 0) {
    solver->printSolverFlags();
    std::cout << std::endl;
  }
  if (queryIds.empty()) {
    solver->printAllPosterioris();
  } else {
    solver->printAnswer (queryIds);
  }
  delete solver;
}

}