| 
									
										
										
										
											2013-02-07 20:09:10 +00:00
#include <cassert>
#include <cstdlib>
#include <iostream>
#include <memory>
#include <string>

#include "FactorGraph.h"
#include "VarElim.h"
#include "BeliefProp.h"
#include "CountingBp.h"
 | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2013-02-20 11:52:42 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2013-02-20 00:05:45 +00:00
										 |  |  | namespace { | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  | int readHorusFlags (int, const char* []); | 
					
						
							| 
									
										
										
										
											2013-02-07 13:37:15 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2013-02-08 21:12:46 +00:00
										 |  |  | void readFactorGraph (Horus::FactorGraph&, const char*); | 
					
						
							| 
									
										
										
										
											2013-02-07 13:37:15 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2013-02-08 21:12:46 +00:00
										 |  |  | Horus::VarIds readQueryAndEvidence ( | 
					
						
							|  |  |  |     Horus::FactorGraph&, int, const char* [], int); | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2013-02-08 21:12:46 +00:00
										 |  |  | void runSolver (const Horus::FactorGraph&, const Horus::VarIds&); | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2013-02-13 18:54:15 +00:00
										 |  |  | const std::string usage = "usage: ./hcli [solver=hve|bp|cbp] \
 | 
					
						
							| 
									
										
										
										
											2013-01-10 22:59:12 +00:00
										 |  |  | [<OPTION>=<VALUE>]... <FILE> [<VAR>|<VAR>=<EVIDENCE>]... " ; | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2013-02-20 00:05:45 +00:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  | int | 
					
						
							|  |  |  | main (int argc, const char* argv[]) | 
					
						
							|  |  |  | { | 
					
						
							|  |  |  |   if (argc <= 1) { | 
					
						
							| 
									
										
										
										
											2013-02-07 13:37:15 +00:00
										 |  |  |     std::cerr << "Error: no probabilistic graphical model was given." ; | 
					
						
							| 
									
										
										
										
											2013-02-13 18:54:15 +00:00
										 |  |  |     std::cerr << std::endl << usage << std::endl; | 
					
						
							| 
									
										
										
										
											2012-12-20 17:37:59 +00:00
										 |  |  |     exit (EXIT_FAILURE); | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  |   } | 
					
						
							|  |  |  |   int idx = readHorusFlags (argc, argv); | 
					
						
							| 
									
										
										
										
											2013-02-08 21:12:46 +00:00
										 |  |  |   Horus::FactorGraph fg; | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  |   readFactorGraph (fg, argv[idx]); | 
					
						
							| 
									
										
										
										
											2013-02-08 21:12:46 +00:00
										 |  |  |   Horus::VarIds queryIds | 
					
						
							| 
									
										
										
										
											2013-02-07 23:53:13 +00:00
										 |  |  |       = readQueryAndEvidence (fg, argc, argv, idx + 1); | 
					
						
							| 
									
										
										
										
											2013-02-08 21:12:46 +00:00
										 |  |  |   if (Horus::FactorGraph::exportToLibDai()) { | 
					
						
							| 
									
										
										
										
											2013-01-08 17:01:03 +00:00
										 |  |  |     fg.exportToLibDai ("model.fg"); | 
					
						
							|  |  |  |   } | 
					
						
							| 
									
										
										
										
											2013-02-08 21:12:46 +00:00
										 |  |  |   if (Horus::FactorGraph::exportToUai()) { | 
					
						
							| 
									
										
										
										
											2013-01-08 17:01:03 +00:00
										 |  |  |     fg.exportToUai ("model.uai"); | 
					
						
							|  |  |  |   } | 
					
						
							| 
									
										
										
										
											2013-02-08 21:12:46 +00:00
										 |  |  |   if (Horus::FactorGraph::exportGraphViz()) { | 
					
						
							| 
									
										
										
										
											2013-01-08 17:01:03 +00:00
										 |  |  |     fg.exportToGraphViz ("model.dot"); | 
					
						
							|  |  |  |   } | 
					
						
							| 
									
										
										
										
											2013-02-08 21:12:46 +00:00
										 |  |  |   if (Horus::FactorGraph::printFactorGraph()) { | 
					
						
							| 
									
										
										
										
											2013-01-08 17:01:03 +00:00
										 |  |  |     fg.print(); | 
					
						
							|  |  |  |   } | 
					
						
							| 
									
										
										
										
											2013-02-08 21:12:46 +00:00
										 |  |  |   if (Horus::Globals::verbosity > 0) { | 
					
						
							| 
									
										
										
										
											2013-02-07 13:37:15 +00:00
										 |  |  |     std::cout << "factor graph contains " ; | 
					
						
							|  |  |  |     std::cout << fg.nrVarNodes() << " variables and " ; | 
					
						
							|  |  |  |     std::cout << fg.nrFacNodes() << " factors " << std::endl; | 
					
						
							| 
									
										
										
										
											2013-01-08 17:01:03 +00:00
										 |  |  |   } | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  |   runSolver (fg, queryIds); | 
					
						
							|  |  |  |   return 0; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2013-02-20 00:05:45 +00:00
										 |  |  | namespace { | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  | int | 
					
						
							|  |  |  | readHorusFlags (int argc, const char* argv[]) | 
					
						
							|  |  |  | { | 
					
						
							|  |  |  |   int i = 1; | 
					
						
							|  |  |  |   for (; i < argc; i++) { | 
					
						
							| 
									
										
										
										
											2013-02-07 13:37:15 +00:00
										 |  |  |     const std::string& arg = argv[i]; | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  |     size_t pos = arg.find ('='); | 
					
						
							|  |  |  |     if (pos == std::string::npos) { | 
					
						
							|  |  |  |       return i; | 
					
						
							|  |  |  |     } | 
					
						
							| 
									
										
										
										
											2013-02-07 13:37:15 +00:00
										 |  |  |     std::string leftArg  = arg.substr (0, pos); | 
					
						
							|  |  |  |     std::string rightArg = arg.substr (pos + 1); | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  |     if (leftArg.empty()) { | 
					
						
							| 
									
										
										
										
											2013-02-07 13:37:15 +00:00
										 |  |  |       std::cerr << "Error: missing left argument." << std::endl; | 
					
						
							| 
									
										
										
										
											2013-02-13 18:54:15 +00:00
										 |  |  |       std::cerr << usage << std::endl; | 
					
						
							| 
									
										
										
										
											2012-12-20 17:37:59 +00:00
										 |  |  |       exit (EXIT_FAILURE); | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  |     } | 
					
						
							|  |  |  |     if (rightArg.empty()) { | 
					
						
							| 
									
										
										
										
											2013-02-07 13:37:15 +00:00
										 |  |  |       std::cerr << "Error: missing right argument." << std::endl; | 
					
						
							| 
									
										
										
										
											2013-02-13 18:54:15 +00:00
										 |  |  |       std::cerr << usage << std::endl; | 
					
						
							| 
									
										
										
										
											2012-12-20 17:37:59 +00:00
										 |  |  |       exit (EXIT_FAILURE); | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  |     } | 
					
						
							| 
									
										
										
										
											2013-02-08 21:12:46 +00:00
										 |  |  |     Horus::Util::setHorusFlag (leftArg, rightArg); | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  |   } | 
					
						
							|  |  |  |   return i + 1; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | void | 
					
						
							| 
									
										
										
										
											2013-02-08 21:12:46 +00:00
										 |  |  | readFactorGraph (Horus::FactorGraph& fg, const char* s) | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  | { | 
					
						
							| 
									
										
										
										
											2013-02-07 13:37:15 +00:00
										 |  |  |   std::string fileName (s); | 
					
						
							|  |  |  |   std::string extension = fileName.substr (fileName.find_last_of ('.') + 1); | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  |   if (extension == "uai") { | 
					
						
							| 
									
										
										
										
											2013-03-14 16:57:34 +00:00
										 |  |  |     fg = Horus::FactorGraph::readFromUaiFormat (fileName.c_str()); | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  |   } else if (extension == "fg") { | 
					
						
							| 
									
										
										
										
											2013-03-14 16:57:34 +00:00
										 |  |  |     fg = Horus::FactorGraph::readFromLibDaiFormat (fileName.c_str()); | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  |   } else { | 
					
						
							| 
									
										
										
										
											2013-02-07 13:37:15 +00:00
										 |  |  |     std::cerr << "Error: the probabilistic graphical model must be " ; | 
					
						
							|  |  |  |     std::cerr << "defined either in a UAI or libDAI file." << std::endl; | 
					
						
							| 
									
										
										
										
											2012-12-20 17:37:59 +00:00
										 |  |  |     exit (EXIT_FAILURE); | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  |   } | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2013-02-08 21:12:46 +00:00
										 |  |  | Horus::VarIds | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  | readQueryAndEvidence ( | 
					
						
							| 
									
										
										
										
											2013-02-08 21:12:46 +00:00
										 |  |  |     Horus::FactorGraph& fg, | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  |     int argc, | 
					
						
							|  |  |  |     const char* argv[], | 
					
						
							|  |  |  |     int start) | 
					
						
							|  |  |  | { | 
					
						
							| 
									
										
										
										
											2013-02-08 21:12:46 +00:00
										 |  |  |   Horus::VarIds queryIds; | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  |   for (int i = start; i < argc; i++) { | 
					
						
							| 
									
										
										
										
											2013-02-07 13:37:15 +00:00
										 |  |  |     const std::string& arg = argv[i]; | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  |     if (arg.find ('=') == std::string::npos) { | 
					
						
							| 
									
										
										
										
											2013-02-08 21:12:46 +00:00
										 |  |  |       if (Horus::Util::isInteger (arg) == false) { | 
					
						
							| 
									
										
										
										
											2013-02-07 13:37:15 +00:00
										 |  |  |         std::cerr << "Error: `" << arg << "' " ; | 
					
						
							|  |  |  |         std::cerr << "is not a variable id." ; | 
					
						
							|  |  |  |         std::cerr << std::endl; | 
					
						
							| 
									
										
										
										
											2012-12-20 17:37:59 +00:00
										 |  |  |         exit (EXIT_FAILURE); | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  |       } | 
					
						
							| 
									
										
										
										
											2013-02-08 21:12:46 +00:00
										 |  |  |       Horus::VarId vid = Horus::Util::stringToUnsigned (arg); | 
					
						
							|  |  |  |       Horus::VarNode* queryVar = fg.getVarNode (vid); | 
					
						
							| 
									
										
										
										
											2014-03-14 23:03:22 +00:00
										 |  |  |       if (queryVar == nullptr) { | 
					
						
							| 
									
										
										
										
											2013-02-07 13:37:15 +00:00
										 |  |  |         std::cerr << "Error: unknow variable with id " ; | 
					
						
							|  |  |  |         std::cerr << "`" << vid << "'."  << std::endl; | 
					
						
							| 
									
										
										
										
											2012-12-20 17:37:59 +00:00
										 |  |  |         exit (EXIT_FAILURE); | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  |       } | 
					
						
							|  |  |  |       queryIds.push_back (vid); | 
					
						
							|  |  |  |     } else { | 
					
						
							|  |  |  |       size_t pos = arg.find ('='); | 
					
						
							| 
									
										
										
										
											2013-02-07 13:37:15 +00:00
										 |  |  |       std::string leftArg  = arg.substr (0, pos); | 
					
						
							|  |  |  |       std::string rightArg = arg.substr (pos + 1); | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  |       if (leftArg.empty()) { | 
					
						
							| 
									
										
										
										
											2013-02-07 13:37:15 +00:00
										 |  |  |         std::cerr << "Error: missing left argument." << std::endl; | 
					
						
							| 
									
										
										
										
											2013-02-13 18:54:15 +00:00
										 |  |  |         std::cerr << usage << std::endl; | 
					
						
							| 
									
										
										
										
											2012-12-20 17:37:59 +00:00
										 |  |  |         exit (EXIT_FAILURE); | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  |       } | 
					
						
							| 
									
										
										
										
											2013-02-08 21:12:46 +00:00
										 |  |  |       if (Horus::Util::isInteger (leftArg) == false) { | 
					
						
							| 
									
										
										
										
											2013-02-07 13:37:15 +00:00
										 |  |  |         std::cerr << "Error: `" << leftArg << "' " ; | 
					
						
							|  |  |  |         std::cerr << "is not a variable id." << std::endl; | 
					
						
							| 
									
										
										
										
											2012-12-20 17:37:59 +00:00
										 |  |  |         exit (EXIT_FAILURE); | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  |       } | 
					
						
							| 
									
										
										
										
											2013-02-08 21:12:46 +00:00
										 |  |  |       Horus::VarId vid = Horus::Util::stringToUnsigned (leftArg); | 
					
						
							|  |  |  |       Horus::VarNode* observedVar = fg.getVarNode (vid); | 
					
						
							| 
									
										
										
										
											2014-03-14 23:03:22 +00:00
										 |  |  |       if (observedVar == nullptr) { | 
					
						
							| 
									
										
										
										
											2013-02-07 13:37:15 +00:00
										 |  |  |         std::cerr << "Error: unknow variable with id " ; | 
					
						
							|  |  |  |         std::cerr << "`" << vid << "'."  << std::endl; | 
					
						
							| 
									
										
										
										
											2012-12-20 17:37:59 +00:00
										 |  |  |         exit (EXIT_FAILURE); | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  |       } | 
					
						
							|  |  |  |       if (rightArg.empty()) { | 
					
						
							| 
									
										
										
										
											2013-02-07 13:37:15 +00:00
										 |  |  |         std::cerr << "Error: missing right argument." << std::endl; | 
					
						
							| 
									
										
										
										
											2013-02-13 18:54:15 +00:00
										 |  |  |         std::cerr << usage << std::endl; | 
					
						
							| 
									
										
										
										
											2012-12-20 17:37:59 +00:00
										 |  |  |         exit (EXIT_FAILURE); | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  |       } | 
					
						
							| 
									
										
										
										
											2013-02-08 21:12:46 +00:00
										 |  |  |       if (Horus::Util::isInteger (rightArg) == false) { | 
					
						
							| 
									
										
										
										
											2013-02-07 13:37:15 +00:00
										 |  |  |         std::cerr << "Error: `" << rightArg << "' " ; | 
					
						
							|  |  |  |         std::cerr << "is not a state index." << std::endl; | 
					
						
							| 
									
										
										
										
											2012-12-20 17:37:59 +00:00
										 |  |  |         exit (EXIT_FAILURE); | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  |       } | 
					
						
							| 
									
										
										
										
											2013-02-08 21:12:46 +00:00
										 |  |  |       unsigned stateIdx = Horus::Util::stringToUnsigned (rightArg); | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  |       if (observedVar->isValidState (stateIdx) == false) { | 
					
						
							| 
									
										
										
										
											2013-02-07 13:37:15 +00:00
										 |  |  |         std::cerr << "Error: `" << stateIdx << "' " ; | 
					
						
							|  |  |  |         std::cerr << "is not a valid state index for variable with id " ; | 
					
						
							|  |  |  |         std::cerr << "`" << vid << "'."  << std::endl; | 
					
						
							| 
									
										
										
										
											2012-12-20 17:37:59 +00:00
										 |  |  |         exit (EXIT_FAILURE); | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  |       } | 
					
						
							|  |  |  |       observedVar->setEvidence (stateIdx); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  |   return queryIds; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | void | 
					
						
							| 
									
										
										
										
											2013-02-07 23:53:13 +00:00
										 |  |  | runSolver ( | 
					
						
							| 
									
										
										
										
											2013-02-08 21:12:46 +00:00
										 |  |  |     const Horus::FactorGraph& fg, | 
					
						
							|  |  |  |     const Horus::VarIds& queryIds) | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  | { | 
					
						
							| 
									
										
										
										
											2013-02-08 21:12:46 +00:00
										 |  |  |   Horus::GroundSolver* solver = 0; | 
					
						
							|  |  |  |   switch (Horus::Globals::groundSolver) { | 
					
						
							| 
									
										
										
										
											2013-02-13 18:54:15 +00:00
										 |  |  |     case Horus::GroundSolverType::veSolver: | 
					
						
							| 
									
										
										
										
											2013-02-08 21:12:46 +00:00
										 |  |  |       solver = new Horus::VarElim (fg); | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  |       break; | 
					
						
							| 
									
										
										
										
											2013-02-13 18:54:15 +00:00
										 |  |  |     case Horus::GroundSolverType::bpSolver: | 
					
						
							| 
									
										
										
										
											2013-02-08 21:12:46 +00:00
										 |  |  |       solver = new Horus::BeliefProp (fg); | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  |       break; | 
					
						
							| 
									
										
										
										
											2013-02-13 18:54:15 +00:00
										 |  |  |     case Horus::GroundSolverType::CbpSolver: | 
					
						
							| 
									
										
										
										
											2013-02-08 21:12:46 +00:00
										 |  |  |       solver = new Horus::CountingBp (fg); | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  |       break; | 
					
						
							|  |  |  |     default: | 
					
						
							|  |  |  |       assert (false); | 
					
						
							|  |  |  |   } | 
					
						
							| 
									
										
										
										
											2013-02-08 21:12:46 +00:00
										 |  |  |   if (Horus::Globals::verbosity > 0) { | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  |     solver->printSolverFlags(); | 
					
						
							| 
									
										
										
										
											2013-02-07 13:37:15 +00:00
										 |  |  |     std::cout << std::endl; | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  |   } | 
					
						
							| 
									
										
										
										
											2012-05-28 18:26:15 +01:00
										 |  |  |   if (queryIds.empty()) { | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  |     solver->printAllPosterioris(); | 
					
						
							|  |  |  |   } else { | 
					
						
							|  |  |  |     solver->printAnswer (queryIds); | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  |   delete solver; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2013-02-20 00:05:45 +00:00
										 |  |  | } | 
					
						
							|  |  |  | 
 |