| 
									
										
										
										
											2013-02-07 20:09:10 +00:00
										 |  |  | #include <cassert>
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #include <iostream>
 | 
					
						
							|  |  |  | #include <sstream>
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2012-06-12 16:29:57 +01:00
										 |  |  | #include "CountingBp.h"
 | 
					
						
							|  |  |  | #include "WeightedBp.h"
 | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2013-02-08 21:12:46 +00:00
										 |  |  | namespace Horus { | 
					
						
							| 
									
										
										
										
											2013-02-07 23:53:13 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2013-02-20 12:42:05 +00:00
										 |  |  | class VarCluster { | 
					
						
							|  |  |  |   public: | 
					
						
							|  |  |  |     VarCluster (const VarNodes& vs) : members_(vs) { } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2013-02-28 19:45:37 +00:00
										 |  |  |     const VarNode* first() const { return members_.front(); } | 
					
						
							| 
									
										
										
										
											2013-02-20 12:42:05 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2013-02-28 19:45:37 +00:00
										 |  |  |     const VarNodes& members() const { return members_; } | 
					
						
							| 
									
										
										
										
											2013-02-20 12:42:05 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2013-02-28 19:45:37 +00:00
										 |  |  |     VarNode* representative() const { return repr_; } | 
					
						
							| 
									
										
										
										
											2013-02-20 12:42:05 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  |     void setRepresentative (VarNode* vn) { repr_ = vn; } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   private: | 
					
						
							|  |  |  |     VarNodes  members_; | 
					
						
							|  |  |  |     VarNode*  repr_; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     DISALLOW_COPY_AND_ASSIGN (VarCluster); | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class FacCluster { | 
					
						
							|  |  |  |   private: | 
					
						
							|  |  |  |     typedef std::vector<VarCluster*> VarClusters; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   public: | 
					
						
							|  |  |  |     FacCluster (const FacNodes& fcs, const VarClusters& vcs) | 
					
						
							|  |  |  |         : members_(fcs), varClusters_(vcs) { } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2013-02-28 19:45:37 +00:00
										 |  |  |     const FacNode* first() const { return members_.front(); } | 
					
						
							| 
									
										
										
										
											2013-02-20 12:42:05 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2013-02-28 19:45:37 +00:00
										 |  |  |     const FacNodes& members() const { return members_; } | 
					
						
							| 
									
										
										
										
											2013-02-20 12:42:05 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2013-02-28 19:45:37 +00:00
										 |  |  |     FacNode* representative() const { return repr_; } | 
					
						
							| 
									
										
										
										
											2013-02-20 12:42:05 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  |     void setRepresentative (FacNode* fn) { repr_ = fn; } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2013-02-28 19:45:37 +00:00
										 |  |  |     VarClusters& varClusters() { return varClusters_; } | 
					
						
							| 
									
										
										
										
											2013-02-20 12:42:05 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  |     FacNodes     members_; | 
					
						
							|  |  |  |     FacNode*     repr_; | 
					
						
							|  |  |  |     VarClusters  varClusters_; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     DISALLOW_COPY_AND_ASSIGN (FacCluster); | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2012-12-27 23:21:32 +00:00
										 |  |  | bool CountingBp::fif_ = true; | 
					
						
							| 
									
										
										
										
											2012-05-31 21:12:46 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2012-06-12 16:29:57 +01:00
										 |  |  | CountingBp::CountingBp (const FactorGraph& fg) | 
					
						
							| 
									
										
										
										
											2012-11-14 21:55:51 +00:00
										 |  |  |     : GroundSolver (fg), freeColor_(0) | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  | { | 
					
						
							| 
									
										
										
										
											2012-05-31 21:12:46 +01:00
										 |  |  |   findIdenticalFactors(); | 
					
						
							|  |  |  |   setInitialColors(); | 
					
						
							|  |  |  |   createGroups(); | 
					
						
							|  |  |  |   compressedFg_ = getCompressedFactorGraph(); | 
					
						
							| 
									
										
										
										
											2012-06-12 16:29:57 +01:00
										 |  |  |   solver_ = new WeightedBp (*compressedFg_, getWeights()); | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2013-02-28 19:45:37 +00:00
										 |  |  | CountingBp::~CountingBp() | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  | { | 
					
						
							| 
									
										
										
										
											2012-05-31 21:12:46 +01:00
										 |  |  |   delete solver_; | 
					
						
							|  |  |  |   delete compressedFg_; | 
					
						
							|  |  |  |   for (size_t i = 0; i  < varClusters_.size(); i++) { | 
					
						
							|  |  |  |     delete varClusters_[i]; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  |   for (size_t i = 0; i  < facClusters_.size(); i++) { | 
					
						
							|  |  |  |     delete facClusters_[i]; | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  |   } | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | void | 
					
						
							| 
									
										
										
										
											2013-02-28 19:45:37 +00:00
										 |  |  | CountingBp::printSolverFlags() const | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  | { | 
					
						
							| 
									
										
										
										
											2013-02-07 13:37:15 +00:00
										 |  |  |   std::stringstream ss; | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  |   ss << "counting bp [" ; | 
					
						
							| 
									
										
										
										
											2013-01-08 17:06:40 +00:00
										 |  |  |   ss << "bp_msg_schedule=" ; | 
					
						
							| 
									
										
										
										
											2013-02-20 23:34:03 +00:00
										 |  |  |   typedef WeightedBp::MsgSchedule MsgSchedule; | 
					
						
							| 
									
										
										
										
											2012-12-27 15:44:40 +00:00
										 |  |  |   switch (WeightedBp::msgSchedule()) { | 
					
						
							| 
									
										
										
										
											2013-02-13 18:54:15 +00:00
										 |  |  |     case MsgSchedule::seqFixedSch:    ss << "seq_fixed";    break; | 
					
						
							|  |  |  |     case MsgSchedule::seqRandomSch:   ss << "seq_random";   break; | 
					
						
							|  |  |  |     case MsgSchedule::parallelSch:    ss << "parallel";     break; | 
					
						
							|  |  |  |     case MsgSchedule::maxResidualSch: ss << "max_residual"; break; | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  |   } | 
					
						
							| 
									
										
										
										
											2013-01-08 17:06:40 +00:00
										 |  |  |   ss << ",bp_max_iter=" << WeightedBp::maxIterations(); | 
					
						
							|  |  |  |   ss << ",bp_accuracy=" << WeightedBp::accuracy(); | 
					
						
							| 
									
										
										
										
											2013-02-08 21:01:53 +00:00
										 |  |  |   ss << ",log_domain=" << Util::toString (Globals::logDomain); | 
					
						
							|  |  |  |   ss << ",fif=" << Util::toString (CountingBp::fif_); | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  |   ss << "]" ; | 
					
						
							| 
									
										
										
										
											2013-02-07 13:37:15 +00:00
										 |  |  |   std::cout << ss.str() << std::endl; | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2012-05-31 21:12:46 +01:00
										 |  |  | Params | 
					
						
							| 
									
										
										
										
											2012-06-12 16:29:57 +01:00
										 |  |  | CountingBp::solveQuery (VarIds queryVids) | 
					
						
							| 
									
										
										
										
											2012-05-31 21:12:46 +01:00
										 |  |  | { | 
					
						
							|  |  |  |   assert (queryVids.empty() == false); | 
					
						
							| 
									
										
										
										
											2012-05-31 23:06:53 +01:00
										 |  |  |   Params res; | 
					
						
							|  |  |  |   if (queryVids.size() == 1) { | 
					
						
							|  |  |  |     res = solver_->getPosterioriOf (getRepresentative (queryVids[0])); | 
					
						
							|  |  |  |   } else { | 
					
						
							| 
									
										
										
										
											2012-06-01 19:29:23 +01:00
										 |  |  |     VarNode* vn = fg.getVarNode (queryVids[0]); | 
					
						
							|  |  |  |     const FacNodes& facNodes = vn->neighbors(); | 
					
						
							|  |  |  |     size_t idx = facNodes.size(); | 
					
						
							|  |  |  |     for (size_t i = 0; i < facNodes.size(); i++) { | 
					
						
							|  |  |  |       if (facNodes[i]->factor().contains (queryVids)) { | 
					
						
							|  |  |  |         idx = i; | 
					
						
							|  |  |  |         break; | 
					
						
							|  |  |  |       } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     if (idx == facNodes.size()) { | 
					
						
							| 
									
										
										
										
											2012-11-14 21:55:51 +00:00
										 |  |  |       res = GroundSolver::getJointByConditioning ( | 
					
						
							| 
									
										
										
										
											2013-02-13 18:54:15 +00:00
										 |  |  |           GroundSolverType::CbpSolver, fg, queryVids); | 
					
						
							| 
									
										
										
										
											2012-06-13 12:47:41 +01:00
										 |  |  |     } else { | 
					
						
							|  |  |  |       VarIds reprArgs; | 
					
						
							|  |  |  |       for (size_t i = 0; i < queryVids.size(); i++) { | 
					
						
							|  |  |  |         reprArgs.push_back (getRepresentative (queryVids[i])); | 
					
						
							|  |  |  |       } | 
					
						
							| 
									
										
										
										
											2012-09-18 17:24:22 +01:00
										 |  |  |       FacNode* reprFac = getRepresentative (facNodes[idx]); | 
					
						
							| 
									
										
										
										
											2012-12-27 12:54:58 +00:00
										 |  |  |       assert (reprFac); | 
					
						
							| 
									
										
										
										
											2012-09-18 17:24:22 +01:00
										 |  |  |       res = solver_->getFactorJoint (reprFac, reprArgs); | 
					
						
							| 
									
										
										
										
											2012-05-31 23:06:53 +01:00
										 |  |  |     } | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  |   } | 
					
						
							| 
									
										
										
										
											2012-05-31 23:06:53 +01:00
										 |  |  |   return res; | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | void | 
					
						
							| 
									
										
										
										
											2012-06-12 16:29:57 +01:00
										 |  |  | CountingBp::findIdenticalFactors() | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  | { | 
					
						
							| 
									
										
										
										
											2012-05-31 21:12:46 +01:00
										 |  |  |   const FacNodes& facNodes = fg.facNodes(); | 
					
						
							| 
									
										
										
										
											2012-12-27 23:21:32 +00:00
										 |  |  |   if (fif_ == false || facNodes.size() == 1) { | 
					
						
							| 
									
										
										
										
											2012-05-31 21:12:46 +01:00
										 |  |  |     return; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  |   for (size_t i = 0; i < facNodes.size(); i++) { | 
					
						
							| 
									
										
										
										
											2013-02-08 21:01:53 +00:00
										 |  |  |     facNodes[i]->factor().setDistId (Util::maxUnsigned()); | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  |   } | 
					
						
							| 
									
										
										
										
											2012-05-31 21:12:46 +01:00
										 |  |  |   unsigned groupCount = 1; | 
					
						
							|  |  |  |   for (size_t i = 0; i < facNodes.size() - 1; i++) { | 
					
						
							|  |  |  |     Factor& f1 = facNodes[i]->factor(); | 
					
						
							| 
									
										
										
										
											2013-02-08 21:01:53 +00:00
										 |  |  |     if (f1.distId() != Util::maxUnsigned()) { | 
					
						
							| 
									
										
										
										
											2012-05-31 21:12:46 +01:00
										 |  |  |       continue; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     f1.setDistId (groupCount); | 
					
						
							|  |  |  |     for (size_t j = i + 1; j < facNodes.size(); j++) { | 
					
						
							|  |  |  |       Factor& f2 = facNodes[j]->factor(); | 
					
						
							| 
									
										
										
										
											2013-02-08 21:01:53 +00:00
										 |  |  |       if (f2.distId() != Util::maxUnsigned()) { | 
					
						
							| 
									
										
										
										
											2012-05-31 21:12:46 +01:00
										 |  |  |         continue; | 
					
						
							|  |  |  |       } | 
					
						
							|  |  |  |       if (f1.size()   == f2.size()   && | 
					
						
							|  |  |  |           f1.ranges() == f2.ranges() && | 
					
						
							|  |  |  |           f1.params() == f2.params()) { | 
					
						
							|  |  |  |         f2.setDistId (groupCount); | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  |       } | 
					
						
							|  |  |  |     } | 
					
						
							| 
									
										
										
										
											2012-05-31 21:12:46 +01:00
										 |  |  |     groupCount ++; | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  |   } | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | void | 
					
						
							| 
									
										
										
										
											2013-02-28 19:45:37 +00:00
										 |  |  | CountingBp::setInitialColors() | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  | { | 
					
						
							| 
									
										
										
										
											2012-05-31 21:12:46 +01:00
										 |  |  |   varColors_.resize (fg.nrVarNodes()); | 
					
						
							|  |  |  |   facColors_.resize (fg.nrFacNodes()); | 
					
						
							|  |  |  |   // create the initial variable colors
 | 
					
						
							|  |  |  |   VarColorMap colorMap; | 
					
						
							|  |  |  |   const VarNodes& varNodes = fg.varNodes(); | 
					
						
							|  |  |  |   for (size_t i = 0; i < varNodes.size(); i++) { | 
					
						
							|  |  |  |     unsigned range = varNodes[i]->range(); | 
					
						
							|  |  |  |     VarColorMap::iterator it = colorMap.find (range); | 
					
						
							|  |  |  |     if (it == colorMap.end()) { | 
					
						
							| 
									
										
										
										
											2013-02-07 13:37:15 +00:00
										 |  |  |       it = colorMap.insert (std::make_pair ( | 
					
						
							| 
									
										
										
										
											2012-12-20 23:19:10 +00:00
										 |  |  |           range, Colors (range + 1, -1))).first; | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  |     } | 
					
						
							| 
									
										
										
										
											2012-05-31 21:12:46 +01:00
										 |  |  |     unsigned idx = varNodes[i]->hasEvidence() | 
					
						
							|  |  |  |                  ? varNodes[i]->getEvidence() | 
					
						
							|  |  |  |                  : range; | 
					
						
							|  |  |  |     Colors& stateColors = it->second; | 
					
						
							|  |  |  |     if (stateColors[idx] == -1) { | 
					
						
							|  |  |  |       stateColors[idx] = getNewColor(); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     setColor (varNodes[i], stateColors[idx]); | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  |   } | 
					
						
							| 
									
										
										
										
											2012-05-31 21:12:46 +01:00
										 |  |  |   const FacNodes& facNodes = fg.facNodes(); | 
					
						
							|  |  |  |   // create the initial factor colors
 | 
					
						
							|  |  |  |   DistColorMap distColors; | 
					
						
							|  |  |  |   for (size_t i = 0; i < facNodes.size(); i++) { | 
					
						
							|  |  |  |     unsigned distId = facNodes[i]->factor().distId(); | 
					
						
							|  |  |  |     DistColorMap::iterator it = distColors.find (distId); | 
					
						
							|  |  |  |     if (it == distColors.end()) { | 
					
						
							| 
									
										
										
										
											2013-02-07 13:37:15 +00:00
										 |  |  |       it = distColors.insert (std::make_pair ( | 
					
						
							|  |  |  |           distId, getNewColor())).first; | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  |     } | 
					
						
							| 
									
										
										
										
											2012-05-31 21:12:46 +01:00
										 |  |  |     setColor (facNodes[i], it->second); | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2012-05-31 21:12:46 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  | void | 
					
						
							| 
									
										
										
										
											2013-02-28 19:45:37 +00:00
										 |  |  | CountingBp::createGroups() | 
					
						
							| 
									
										
										
										
											2012-05-31 21:12:46 +01:00
										 |  |  | { | 
					
						
							|  |  |  |   VarSignMap varGroups; | 
					
						
							|  |  |  |   FacSignMap facGroups; | 
					
						
							|  |  |  |   unsigned nIters = 0; | 
					
						
							|  |  |  |   bool groupsHaveChanged = true; | 
					
						
							|  |  |  |   const VarNodes& varNodes = fg.varNodes(); | 
					
						
							|  |  |  |   const FacNodes& facNodes = fg.facNodes(); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   while (groupsHaveChanged || nIters == 1) { | 
					
						
							|  |  |  |     nIters ++; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // set a new color to the variables with the same signature
 | 
					
						
							|  |  |  |     size_t prevVarGroupsSize = varGroups.size(); | 
					
						
							|  |  |  |     varGroups.clear(); | 
					
						
							|  |  |  |     for (size_t i = 0; i < varNodes.size(); i++) { | 
					
						
							| 
									
										
										
										
											2013-02-16 18:58:22 +00:00
										 |  |  |       VarSignature signature = getSignature (varNodes[i]); | 
					
						
							| 
									
										
										
										
											2012-05-31 21:12:46 +01:00
										 |  |  |       VarSignMap::iterator it = varGroups.find (signature); | 
					
						
							|  |  |  |       if (it == varGroups.end()) { | 
					
						
							| 
									
										
										
										
											2013-02-07 13:37:15 +00:00
										 |  |  |         it = varGroups.insert (std::make_pair ( | 
					
						
							|  |  |  |             signature, VarNodes())).first; | 
					
						
							| 
									
										
										
										
											2012-05-31 21:12:46 +01:00
										 |  |  |       } | 
					
						
							|  |  |  |       it->second.push_back (varNodes[i]); | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  |     } | 
					
						
							| 
									
										
										
										
											2012-05-31 21:12:46 +01:00
										 |  |  |     for (VarSignMap::iterator it = varGroups.begin(); | 
					
						
							|  |  |  |         it != varGroups.end(); ++it) { | 
					
						
							|  |  |  |       Color newColor = getNewColor(); | 
					
						
							|  |  |  |       VarNodes& groupMembers = it->second; | 
					
						
							|  |  |  |       for (size_t i = 0; i < groupMembers.size(); i++) { | 
					
						
							|  |  |  |         setColor (groupMembers[i], newColor); | 
					
						
							|  |  |  |       } | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  |     } | 
					
						
							| 
									
										
										
										
											2012-05-31 21:12:46 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  |     size_t prevFactorGroupsSize = facGroups.size(); | 
					
						
							|  |  |  |     facGroups.clear(); | 
					
						
							|  |  |  |     // set a new color to the factors with the same signature
 | 
					
						
							|  |  |  |     for (size_t i = 0; i < facNodes.size(); i++) { | 
					
						
							| 
									
										
										
										
											2013-02-16 18:58:22 +00:00
										 |  |  |       FacSignature signature = getSignature (facNodes[i]); | 
					
						
							| 
									
										
										
										
											2012-05-31 21:12:46 +01:00
										 |  |  |       FacSignMap::iterator it = facGroups.find (signature); | 
					
						
							|  |  |  |       if (it == facGroups.end()) { | 
					
						
							| 
									
										
										
										
											2013-02-07 13:37:15 +00:00
										 |  |  |         it = facGroups.insert (std::make_pair ( | 
					
						
							|  |  |  |             signature, FacNodes())).first; | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  |       } | 
					
						
							| 
									
										
										
										
											2012-05-31 21:12:46 +01:00
										 |  |  |       it->second.push_back (facNodes[i]); | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  |     } | 
					
						
							| 
									
										
										
										
											2012-05-31 21:12:46 +01:00
										 |  |  |     for (FacSignMap::iterator it = facGroups.begin(); | 
					
						
							|  |  |  |         it != facGroups.end(); ++it) { | 
					
						
							|  |  |  |       Color newColor = getNewColor(); | 
					
						
							|  |  |  |       FacNodes& groupMembers = it->second; | 
					
						
							|  |  |  |       for (size_t i = 0; i < groupMembers.size(); i++) { | 
					
						
							|  |  |  |         setColor (groupMembers[i], newColor); | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  |       } | 
					
						
							|  |  |  |     } | 
					
						
							| 
									
										
										
										
											2012-05-31 21:12:46 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  |     groupsHaveChanged = prevVarGroupsSize != varGroups.size() | 
					
						
							|  |  |  |         || prevFactorGroupsSize != facGroups.size(); | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  |   } | 
					
						
							| 
									
										
										
										
											2012-05-31 21:12:46 +01:00
										 |  |  |   // printGroups (varGroups, facGroups);
 | 
					
						
							|  |  |  |   createClusters (varGroups, facGroups); | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | void | 
					
						
							| 
									
										
										
										
											2012-06-12 16:29:57 +01:00
										 |  |  | CountingBp::createClusters ( | 
					
						
							| 
									
										
										
										
											2012-05-31 21:12:46 +01:00
										 |  |  |     const VarSignMap& varGroups, | 
					
						
							|  |  |  |     const FacSignMap& facGroups) | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  | { | 
					
						
							| 
									
										
										
										
											2012-05-31 21:12:46 +01:00
										 |  |  |   varClusters_.reserve (varGroups.size()); | 
					
						
							|  |  |  |   for (VarSignMap::const_iterator it = varGroups.begin(); | 
					
						
							|  |  |  |        it != varGroups.end(); ++it) { | 
					
						
							|  |  |  |     const VarNodes& groupVars = it->second; | 
					
						
							|  |  |  |     VarCluster* vc = new VarCluster (groupVars); | 
					
						
							|  |  |  |     for (size_t i = 0; i < groupVars.size(); i++) { | 
					
						
							| 
									
										
										
										
											2013-02-07 13:37:15 +00:00
										 |  |  |       varClusterMap_.insert (std::make_pair ( | 
					
						
							|  |  |  |           groupVars[i]->varId(), vc)); | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  |     } | 
					
						
							| 
									
										
										
										
											2012-05-31 21:12:46 +01:00
										 |  |  |     varClusters_.push_back (vc); | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  |   } | 
					
						
							| 
									
										
										
										
											2012-05-31 21:12:46 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  |   facClusters_.reserve (facGroups.size()); | 
					
						
							|  |  |  |   for (FacSignMap::const_iterator it = facGroups.begin(); | 
					
						
							|  |  |  |        it != facGroups.end(); ++it) { | 
					
						
							|  |  |  |     FacNode* groupFactor = it->second[0]; | 
					
						
							|  |  |  |     const VarNodes& neighs = groupFactor->neighbors(); | 
					
						
							|  |  |  |     VarClusters varClusters; | 
					
						
							|  |  |  |     varClusters.reserve (neighs.size()); | 
					
						
							|  |  |  |     for (size_t i = 0; i < neighs.size(); i++) { | 
					
						
							|  |  |  |       VarId vid = neighs[i]->varId(); | 
					
						
							| 
									
										
										
										
											2012-06-30 19:25:29 +01:00
										 |  |  |       varClusters.push_back (varClusterMap_.find (vid)->second); | 
					
						
							| 
									
										
										
										
											2012-05-31 21:12:46 +01:00
										 |  |  |     } | 
					
						
							|  |  |  |     facClusters_.push_back (new FacCluster (it->second, varClusters)); | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  |   } | 
					
						
							| 
									
										
										
										
											2012-05-31 21:12:46 +01:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2013-02-20 12:42:05 +00:00
										 |  |  | CountingBp::VarSignature | 
					
						
							| 
									
										
										
										
											2012-06-12 16:29:57 +01:00
										 |  |  | CountingBp::getSignature (const VarNode* varNode) | 
					
						
							| 
									
										
										
										
											2012-05-31 21:12:46 +01:00
										 |  |  | { | 
					
						
							|  |  |  |   VarSignature sign; | 
					
						
							| 
									
										
										
										
											2013-02-16 18:58:22 +00:00
										 |  |  |   const FacNodes& neighs = varNode->neighbors(); | 
					
						
							| 
									
										
										
										
											2012-05-31 21:12:46 +01:00
										 |  |  |   sign.reserve (neighs.size() + 1); | 
					
						
							|  |  |  |   for (size_t i = 0; i < neighs.size(); i++) { | 
					
						
							| 
									
										
										
										
											2013-02-07 13:37:15 +00:00
										 |  |  |     sign.push_back (std::make_pair ( | 
					
						
							| 
									
										
										
										
											2012-05-31 21:12:46 +01:00
										 |  |  |         getColor (neighs[i]), | 
					
						
							|  |  |  |         neighs[i]->factor().indexOf (varNode->varId()))); | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  |   } | 
					
						
							| 
									
										
										
										
											2012-05-31 21:12:46 +01:00
										 |  |  |   std::sort (sign.begin(), sign.end()); | 
					
						
							| 
									
										
										
										
											2013-02-07 13:37:15 +00:00
										 |  |  |   sign.push_back (std::make_pair (getColor (varNode), 0)); | 
					
						
							| 
									
										
										
										
											2012-05-31 21:12:46 +01:00
										 |  |  |   return sign; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2013-02-20 12:42:05 +00:00
										 |  |  | CountingBp::FacSignature | 
					
						
							| 
									
										
										
										
											2012-06-12 16:29:57 +01:00
										 |  |  | CountingBp::getSignature (const FacNode* facNode) | 
					
						
							| 
									
										
										
										
											2012-05-31 21:12:46 +01:00
										 |  |  | { | 
					
						
							|  |  |  |   FacSignature sign; | 
					
						
							| 
									
										
										
										
											2013-02-16 18:58:22 +00:00
										 |  |  |   const VarNodes& neighs = facNode->neighbors(); | 
					
						
							| 
									
										
										
										
											2012-05-31 21:12:46 +01:00
										 |  |  |   sign.reserve (neighs.size() + 1); | 
					
						
							|  |  |  |   for (size_t i = 0; i < neighs.size(); i++) { | 
					
						
							|  |  |  |     sign.push_back (getColor (neighs[i])); | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  |   } | 
					
						
							| 
									
										
										
										
											2012-05-31 21:12:46 +01:00
										 |  |  |   sign.push_back (getColor (facNode)); | 
					
						
							|  |  |  |   return sign; | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2012-06-13 12:47:41 +01:00
										 |  |  | VarId | 
					
						
							|  |  |  | CountingBp::getRepresentative (VarId vid) | 
					
						
							|  |  |  | { | 
					
						
							| 
									
										
										
										
											2013-02-08 21:01:53 +00:00
										 |  |  |   assert (Util::contains (varClusterMap_, vid)); | 
					
						
							| 
									
										
										
										
											2012-06-30 19:25:29 +01:00
										 |  |  |   VarCluster* vc = varClusterMap_.find (vid)->second; | 
					
						
							| 
									
										
										
										
											2012-06-13 12:47:41 +01:00
										 |  |  |   return vc->representative()->varId(); | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | FacNode* | 
					
						
							|  |  |  | CountingBp::getRepresentative (FacNode* fn) | 
					
						
							|  |  |  | { | 
					
						
							|  |  |  |   for (size_t i = 0; i < facClusters_.size(); i++) { | 
					
						
							| 
									
										
										
										
											2013-02-08 21:01:53 +00:00
										 |  |  |     if (Util::contains (facClusters_[i]->members(), fn)) { | 
					
						
							| 
									
										
										
										
											2012-06-13 12:47:41 +01:00
										 |  |  |       return facClusters_[i]->representative(); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  |   return 0; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2012-05-31 21:12:46 +01:00
										 |  |  | FactorGraph* | 
					
						
							| 
									
										
										
										
											2013-02-28 19:45:37 +00:00
										 |  |  | CountingBp::getCompressedFactorGraph() | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  | { | 
					
						
							| 
									
										
										
										
											2012-05-31 21:12:46 +01:00
										 |  |  |   FactorGraph* fg = new FactorGraph(); | 
					
						
							|  |  |  |   for (size_t i = 0; i < varClusters_.size(); i++) { | 
					
						
							|  |  |  |     VarNode* newVar = new VarNode (varClusters_[i]->first()); | 
					
						
							|  |  |  |     varClusters_[i]->setRepresentative (newVar); | 
					
						
							|  |  |  |     fg->addVarNode (newVar); | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  |   for (size_t i = 0; i < facClusters_.size(); i++) { | 
					
						
							|  |  |  |     Vars vars; | 
					
						
							|  |  |  |     const VarClusters& clusters = facClusters_[i]->varClusters(); | 
					
						
							|  |  |  |     for (size_t j = 0; j < clusters.size(); j++) { | 
					
						
							|  |  |  |       vars.push_back (clusters[j]->representative()); | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  |     } | 
					
						
							| 
									
										
										
										
											2012-05-31 21:12:46 +01:00
										 |  |  |     const Factor& groundFac = facClusters_[i]->first()->factor(); | 
					
						
							|  |  |  |     FacNode* fn = new FacNode (Factor ( | 
					
						
							|  |  |  |         vars, groundFac.params(), groundFac.distId())); | 
					
						
							|  |  |  |     facClusters_[i]->setRepresentative (fn); | 
					
						
							|  |  |  |     fg->addFacNode (fn); | 
					
						
							|  |  |  |     for (size_t j = 0; j < vars.size(); j++) { | 
					
						
							|  |  |  |       fg->addEdge (static_cast<VarNode*> (vars[j]), fn); | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  |     } | 
					
						
							|  |  |  |   } | 
					
						
							| 
									
										
										
										
											2012-05-31 21:12:46 +01:00
										 |  |  |   return fg; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2013-02-07 13:37:15 +00:00
										 |  |  | std::vector<std::vector<unsigned>> | 
					
						
							| 
									
										
										
										
											2013-02-28 19:45:37 +00:00
										 |  |  | CountingBp::getWeights() const | 
					
						
							| 
									
										
										
										
											2012-05-31 21:12:46 +01:00
										 |  |  | { | 
					
						
							| 
									
										
										
										
											2013-02-07 13:37:15 +00:00
										 |  |  |   std::vector<std::vector<unsigned>> weights; | 
					
						
							| 
									
										
										
										
											2012-05-31 21:12:46 +01:00
										 |  |  |   weights.reserve (facClusters_.size()); | 
					
						
							|  |  |  |   for (size_t i = 0; i < facClusters_.size(); i++) { | 
					
						
							|  |  |  |     const VarClusters& neighs = facClusters_[i]->varClusters(); | 
					
						
							|  |  |  |     weights.push_back ({ }); | 
					
						
							|  |  |  |     weights.back().reserve (neighs.size()); | 
					
						
							|  |  |  |     for (size_t j = 0; j < neighs.size(); j++) { | 
					
						
							|  |  |  |       weights.back().push_back (getWeight ( | 
					
						
							|  |  |  |           facClusters_[i], neighs[j], j)); | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  |     } | 
					
						
							|  |  |  |   } | 
					
						
							| 
									
										
										
										
											2012-05-31 21:12:46 +01:00
										 |  |  |   return weights; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | unsigned | 
					
						
							| 
									
										
										
										
											2012-06-12 16:29:57 +01:00
										 |  |  | CountingBp::getWeight ( | 
					
						
							| 
									
										
										
										
											2012-05-31 21:12:46 +01:00
										 |  |  |     const FacCluster* fc, | 
					
						
							|  |  |  |     const VarCluster* vc, | 
					
						
							|  |  |  |     size_t index) const | 
					
						
							|  |  |  | { | 
					
						
							|  |  |  |   unsigned weight = 0; | 
					
						
							|  |  |  |   VarId reprVid = vc->representative()->varId(); | 
					
						
							|  |  |  |   VarNode* groundVar = fg.getVarNode (reprVid); | 
					
						
							|  |  |  |   const FacNodes& neighs = groundVar->neighbors(); | 
					
						
							|  |  |  |   for (size_t i = 0; i < neighs.size(); i++) { | 
					
						
							|  |  |  |     FacNodes::const_iterator it; | 
					
						
							|  |  |  |     it = std::find (fc->members().begin(), fc->members().end(), neighs[i]); | 
					
						
							|  |  |  |     if (it != fc->members().end() && | 
					
						
							|  |  |  |         (*it)->factor().indexOf (reprVid) == index) { | 
					
						
							|  |  |  |       weight ++; | 
					
						
							|  |  |  |     } | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  |   } | 
					
						
							| 
									
										
										
										
											2012-05-31 21:12:46 +01:00
										 |  |  |   return weight; | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | void | 
					
						
							| 
									
										
										
										
											2012-06-12 16:29:57 +01:00
										 |  |  | CountingBp::printGroups ( | 
					
						
							| 
									
										
										
										
											2012-05-31 21:12:46 +01:00
										 |  |  |     const VarSignMap& varGroups, | 
					
						
							|  |  |  |     const FacSignMap& facGroups) const | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  | { | 
					
						
							| 
									
										
										
										
											2012-05-31 21:12:46 +01:00
										 |  |  |   unsigned count = 1; | 
					
						
							| 
									
										
										
										
											2013-02-07 13:37:15 +00:00
										 |  |  |   std::cout << "variable groups:" << std::endl; | 
					
						
							| 
									
										
										
										
											2012-05-31 21:12:46 +01:00
										 |  |  |   for (VarSignMap::const_iterator it = varGroups.begin(); | 
					
						
							|  |  |  |       it != varGroups.end(); ++it) { | 
					
						
							|  |  |  |     const VarNodes& groupMembers = it->second; | 
					
						
							|  |  |  |     if (groupMembers.size() > 0) { | 
					
						
							| 
									
										
										
										
											2013-02-07 13:37:15 +00:00
										 |  |  |       std::cout << count << ": " ; | 
					
						
							| 
									
										
										
										
											2012-05-31 21:12:46 +01:00
										 |  |  |       for (size_t i = 0; i < groupMembers.size(); i++) { | 
					
						
							| 
									
										
										
										
											2013-02-07 13:37:15 +00:00
										 |  |  |         std::cout << groupMembers[i]->label() << " " ; | 
					
						
							| 
									
										
										
										
											2012-05-31 21:12:46 +01:00
										 |  |  |       } | 
					
						
							|  |  |  |       count ++; | 
					
						
							| 
									
										
										
										
											2013-02-07 13:37:15 +00:00
										 |  |  |       std::cout << std::endl; | 
					
						
							| 
									
										
										
										
											2012-05-31 21:12:46 +01:00
										 |  |  |     } | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  |   count = 1; | 
					
						
							| 
									
										
										
										
											2013-02-07 13:37:15 +00:00
										 |  |  |   std::cout << std::endl << "factor groups:" << std::endl; | 
					
						
							| 
									
										
										
										
											2012-05-31 21:12:46 +01:00
										 |  |  |   for (FacSignMap::const_iterator it = facGroups.begin(); | 
					
						
							|  |  |  |       it != facGroups.end(); ++it) { | 
					
						
							|  |  |  |     const FacNodes& groupMembers = it->second; | 
					
						
							|  |  |  |     if (groupMembers.size() > 0) { | 
					
						
							| 
									
										
										
										
											2013-02-07 13:37:15 +00:00
										 |  |  |       std::cout << ++count << ": " ; | 
					
						
							| 
									
										
										
										
											2012-05-31 21:12:46 +01:00
										 |  |  |       for (size_t i = 0; i < groupMembers.size(); i++) { | 
					
						
							| 
									
										
										
										
											2013-02-07 13:37:15 +00:00
										 |  |  |         std::cout << groupMembers[i]->getLabel() << " " ; | 
					
						
							| 
									
										
										
										
											2012-05-31 21:12:46 +01:00
										 |  |  |       } | 
					
						
							|  |  |  |       count ++; | 
					
						
							| 
									
										
										
										
											2013-02-07 13:37:15 +00:00
										 |  |  |       std::cout << std::endl; | 
					
						
							| 
									
										
										
										
											2012-05-31 21:12:46 +01:00
										 |  |  |     } | 
					
						
							| 
									
										
										
										
											2012-05-23 14:56:01 +01:00
										 |  |  |   } | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2013-02-08 21:12:46 +00:00
										 |  |  | }  // namespace Horus
 | 
					
						
							| 
									
										
										
										
											2013-02-07 23:53:13 +00:00
										 |  |  | 
 |