| 
									
										
										
										
											2013-02-07 20:09:10 +00:00
#include <cassert>

#include <algorithm>
#include <vector>
#include <string>
#include <iostream>
#include <iomanip>
#include <memory>

#include "GroundSolver.h"
#include "VarElim.h"
#include "BeliefProp.h"
#include "CountingBp.h"
#include "Indexer.h"
#include "Util.h"
 | 
					
						
							| 
									
										
										
										
											2012-11-14 23:17:39 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2013-02-08 21:12:46 +00:00
										 |  |  | namespace Horus { | 
					
						
							| 
									
										
										
										
											2013-02-07 23:53:13 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2012-11-14 23:17:39 +00:00
										 |  |  | void | 
					
						
							|  |  |  | GroundSolver::printAnswer (const VarIds& vids) | 
					
						
							|  |  |  | { | 
					
						
							|  |  |  |   Vars   unobservedVars; | 
					
						
							|  |  |  |   VarIds unobservedVids; | 
					
						
							|  |  |  |   for (size_t i = 0; i < vids.size(); i++) { | 
					
						
							|  |  |  |     VarNode* vn = fg.getVarNode (vids[i]); | 
					
						
							|  |  |  |     if (vn->hasEvidence() == false) { | 
					
						
							|  |  |  |       unobservedVars.push_back (vn); | 
					
						
							|  |  |  |       unobservedVids.push_back (vids[i]); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  |   if (unobservedVids.empty() == false) { | 
					
						
							|  |  |  |     Params res = solveQuery (unobservedVids); | 
					
						
							| 
									
										
										
										
											2013-02-07 13:37:15 +00:00
										 |  |  |     std::vector<std::string> stateLines = | 
					
						
							| 
									
										
										
										
											2013-02-08 21:01:53 +00:00
										 |  |  |         Util::getStateLines (unobservedVars); | 
					
						
							| 
									
										
										
										
											2012-11-14 23:17:39 +00:00
										 |  |  |     for (size_t i = 0; i < res.size(); i++) { | 
					
						
							| 
									
										
										
										
											2013-02-07 13:37:15 +00:00
										 |  |  |       std::cout << "P(" << stateLines[i] << ") = " ; | 
					
						
							| 
									
										
										
										
											2013-02-13 18:54:15 +00:00
										 |  |  |       std::cout << std::setprecision (Constants::precision) << res[i]; | 
					
						
							| 
									
										
										
										
											2013-02-07 13:37:15 +00:00
										 |  |  |       std::cout << std::endl; | 
					
						
							| 
									
										
										
										
											2012-11-14 23:17:39 +00:00
										 |  |  |     } | 
					
						
							| 
									
										
										
										
											2013-02-07 13:37:15 +00:00
										 |  |  |     std::cout << std::endl; | 
					
						
							| 
									
										
										
										
											2012-11-14 23:17:39 +00:00
										 |  |  |   } | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | void | 
					
						
							| 
									
										
										
										
											2013-02-28 19:45:37 +00:00
										 |  |  | GroundSolver::printAllPosterioris() | 
					
						
							| 
									
										
										
										
											2012-11-14 23:17:39 +00:00
										 |  |  | { | 
					
						
							|  |  |  |   VarNodes vars = fg.varNodes(); | 
					
						
							|  |  |  |   std::sort (vars.begin(), vars.end(), sortByVarId()); | 
					
						
							|  |  |  |   for (size_t i = 0; i < vars.size(); i++) { | 
					
						
							|  |  |  |     printAnswer ({vars[i]->varId()}); | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | Params | 
					
						
							|  |  |  | GroundSolver::getJointByConditioning ( | 
					
						
							|  |  |  |     GroundSolverType solverType, | 
					
						
							|  |  |  |     FactorGraph fg, | 
					
						
							| 
									
										
										
										
											2012-12-27 12:54:58 +00:00
										 |  |  |     const VarIds& jointVarIds) | 
					
						
							| 
									
										
										
										
											2012-11-14 23:17:39 +00:00
										 |  |  | { | 
					
						
							|  |  |  |   VarNodes jointVars; | 
					
						
							|  |  |  |   for (size_t i = 0; i < jointVarIds.size(); i++) { | 
					
						
							|  |  |  |     assert (fg.getVarNode (jointVarIds[i])); | 
					
						
							|  |  |  |     jointVars.push_back (fg.getVarNode (jointVarIds[i])); | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   GroundSolver* solver = 0; | 
					
						
							|  |  |  |   switch (solverType) { | 
					
						
							| 
									
										
										
										
											2013-02-13 18:54:15 +00:00
										 |  |  |     case GroundSolverType::bpSolver:  solver = new BeliefProp (fg); break; | 
					
						
							|  |  |  |     case GroundSolverType::CbpSolver: solver = new CountingBp (fg); break; | 
					
						
							|  |  |  |     case GroundSolverType::veSolver:  solver = new VarElim (fg);    break; | 
					
						
							| 
									
										
										
										
											2012-11-14 23:17:39 +00:00
										 |  |  |   } | 
					
						
							|  |  |  |   Params prevBeliefs = solver->solveQuery ({jointVarIds[0]}); | 
					
						
							|  |  |  |   VarIds observedVids = {jointVars[0]->varId()}; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   for (size_t i = 1; i < jointVarIds.size(); i++) { | 
					
						
							|  |  |  |     assert (jointVars[i]->hasEvidence() == false); | 
					
						
							|  |  |  |     Params newBeliefs; | 
					
						
							|  |  |  |     Vars observedVars; | 
					
						
							|  |  |  |     Ranges observedRanges; | 
					
						
							|  |  |  |     for (size_t j = 0; j < observedVids.size(); j++) { | 
					
						
							|  |  |  |       observedVars.push_back (fg.getVarNode (observedVids[j])); | 
					
						
							|  |  |  |       observedRanges.push_back (observedVars.back()->range()); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     Indexer indexer (observedRanges, false); | 
					
						
							|  |  |  |     while (indexer.valid()) { | 
					
						
							|  |  |  |       for (size_t j = 0; j < observedVars.size(); j++) { | 
					
						
							|  |  |  |         observedVars[j]->setEvidence (indexer[j]); | 
					
						
							|  |  |  |       } | 
					
						
							|  |  |  |       delete solver; | 
					
						
							|  |  |  |       switch (solverType) { | 
					
						
							| 
									
										
										
										
											2013-02-13 18:54:15 +00:00
										 |  |  |         case GroundSolverType::bpSolver:  solver = new BeliefProp (fg); break; | 
					
						
							|  |  |  |         case GroundSolverType::CbpSolver: solver = new CountingBp (fg); break; | 
					
						
							|  |  |  |         case GroundSolverType::veSolver:  solver = new VarElim (fg);    break; | 
					
						
							| 
									
										
										
										
											2012-11-14 23:17:39 +00:00
										 |  |  |       } | 
					
						
							|  |  |  |       Params beliefs = solver->solveQuery ({jointVarIds[i]}); | 
					
						
							|  |  |  |       for (size_t k = 0; k < beliefs.size(); k++) { | 
					
						
							|  |  |  |         newBeliefs.push_back (beliefs[k]); | 
					
						
							|  |  |  |       } | 
					
						
							|  |  |  |       ++ indexer; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     int count = -1; | 
					
						
							|  |  |  |     for (size_t j = 0; j < newBeliefs.size(); j++) { | 
					
						
							|  |  |  |       if (j % jointVars[i]->range() == 0) { | 
					
						
							|  |  |  |         count ++; | 
					
						
							|  |  |  |       } | 
					
						
							|  |  |  |       newBeliefs[j] *= prevBeliefs[count]; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     prevBeliefs = newBeliefs; | 
					
						
							|  |  |  |     observedVids.push_back (jointVars[i]->varId()); | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  |   delete solver; | 
					
						
							|  |  |  |   return prevBeliefs; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2013-02-08 21:12:46 +00:00
										 |  |  | }  // namespace Horus
 | 
					
						
							| 
									
										
										
										
											2013-02-07 23:53:13 +00:00
										 |  |  | 
 |