| 
									
										
										
										
											2014-05-28 00:09:36 +01:00
										 |  |  | #if __ANDROID__
 | 
					
						
							|  |  |  | #define assert(P)
 | 
					
						
							|  |  |  | #else
 | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  | #include <cassert>
 | 
					
						
							| 
									
										
										
										
											2014-05-28 00:09:36 +01:00
										 |  |  | #endif
 | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  | #include "BayesBall.h"
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2013-02-08 21:12:46 +00:00
										 |  |  | namespace Horus { | 
					
						
							| 
									
										
										
										
											2013-02-07 23:53:13 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2013-02-06 00:24:02 +00:00
										 |  |  | BayesBall::BayesBall (FactorGraph& fg) | 
					
						
							|  |  |  |     : fg_(fg) , dag_(fg.getStructure()) | 
					
						
							|  |  |  | { | 
					
						
							|  |  |  |   dag_.clear(); | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | FactorGraph* | 
					
						
							|  |  |  | BayesBall::getMinimalFactorGraph (FactorGraph& fg, VarIds vids) | 
					
						
							|  |  |  | { | 
					
						
							|  |  |  |   BayesBall bb (fg); | 
					
						
							|  |  |  |   return bb.getMinimalFactorGraph (vids); | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  | FactorGraph* | 
					
						
							|  |  |  | BayesBall::getMinimalFactorGraph (const VarIds& queryIds) | 
					
						
							|  |  |  | { | 
					
						
							| 
									
										
										
										
											2012-05-31 22:42:38 +01:00
										 |  |  |   assert (fg_.bayesianFactors()); | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  |   Scheduling scheduling; | 
					
						
							| 
									
										
										
										
											2012-05-24 22:55:20 +01:00
										 |  |  |   for (size_t i = 0; i < queryIds.size(); i++) { | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  |     assert (dag_.getNode (queryIds[i])); | 
					
						
							| 
									
										
										
										
											2012-06-19 14:32:12 +01:00
										 |  |  |     BBNode* n = dag_.getNode (queryIds[i]); | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  |     scheduling.push (ScheduleInfo (n, false, true)); | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   while (!scheduling.empty()) { | 
					
						
							|  |  |  |     ScheduleInfo& sch = scheduling.front(); | 
					
						
							| 
									
										
										
										
											2012-06-19 14:32:12 +01:00
										 |  |  |     BBNode* n = sch.node; | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  |     n->setAsVisited(); | 
					
						
							|  |  |  |     if (n->hasEvidence() == false && sch.visitedFromChild) { | 
					
						
							| 
									
										
										
										
											2013-02-16 16:17:14 +00:00
										 |  |  |       if (n->isMarkedAbove() == false) { | 
					
						
							|  |  |  |         n->markAbove(); | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  |         scheduleParents (n, scheduling); | 
					
						
							|  |  |  |       } | 
					
						
							| 
									
										
										
										
											2013-02-16 16:17:14 +00:00
										 |  |  |       if (n->isMarkedBelow() == false) { | 
					
						
							|  |  |  |         n->markBelow(); | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  |         scheduleChilds (n, scheduling); | 
					
						
							|  |  |  |       } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     if (sch.visitedFromParent) { | 
					
						
							| 
									
										
										
										
											2013-02-16 16:17:14 +00:00
										 |  |  |       if (n->hasEvidence() && n->isMarkedAbove() == false) { | 
					
						
							|  |  |  |         n->markAbove(); | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  |         scheduleParents (n, scheduling); | 
					
						
							|  |  |  |       } | 
					
						
							| 
									
										
										
										
											2013-02-16 16:17:14 +00:00
										 |  |  |       if (n->hasEvidence() == false && n->isMarkedBelow() == false) { | 
					
						
							|  |  |  |         n->markBelow(); | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  |         scheduleChilds (n, scheduling); | 
					
						
							|  |  |  |       } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     scheduling.pop(); | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   FactorGraph* fg = new FactorGraph(); | 
					
						
							|  |  |  |   constructGraph (fg); | 
					
						
							|  |  |  |   return fg; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | void | 
					
						
							|  |  |  | BayesBall::constructGraph (FactorGraph* fg) const | 
					
						
							|  |  |  | { | 
					
						
							|  |  |  |   const FacNodes& facNodes = fg_.facNodes(); | 
					
						
							| 
									
										
										
										
											2012-05-24 22:55:20 +01:00
										 |  |  |   for (size_t i = 0; i < facNodes.size(); i++) { | 
					
						
							| 
									
										
										
										
											2012-06-19 14:32:12 +01:00
										 |  |  |     const BBNode* n = dag_.getNode ( | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  |         facNodes[i]->factor().argument (0)); | 
					
						
							| 
									
										
										
										
											2013-02-16 16:17:14 +00:00
										 |  |  |     if (n->isMarkedAbove()) { | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  |       fg->addFactor (facNodes[i]->factor()); | 
					
						
							|  |  |  |     } else if (n->hasEvidence() && n->isVisited()) { | 
					
						
							|  |  |  |       VarIds varIds = { facNodes[i]->factor().argument (0) }; | 
					
						
							|  |  |  |       Ranges ranges = { facNodes[i]->factor().range (0) }; | 
					
						
							| 
									
										
										
										
											2013-02-08 21:01:53 +00:00
										 |  |  |       Params params (ranges[0], LogAware::noEvidence()); | 
					
						
							|  |  |  |       params[n->getEvidence()] = LogAware::withEvidence(); | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  |       fg->addFactor (Factor (varIds, ranges, params)); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  |   const VarNodes& varNodes = fg_.varNodes(); | 
					
						
							| 
									
										
										
										
											2012-05-24 22:55:20 +01:00
										 |  |  |   for (size_t i = 0; i < varNodes.size(); i++) { | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  |     if (varNodes[i]->hasEvidence()) { | 
					
						
							|  |  |  |       VarNode* vn = fg->getVarNode (varNodes[i]->varId()); | 
					
						
							|  |  |  |       if (vn) { | 
					
						
							|  |  |  |         vn->setEvidence (varNodes[i]->getEvidence()); | 
					
						
							|  |  |  |       } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2013-02-08 21:12:46 +00:00
										 |  |  | }  // namespace Horus
 | 
					
						
							| 
									
										
										
										
											2013-02-07 23:53:13 +00:00
										 |  |  | 
 |