323 lines
		
	
	
		
			7.1 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
		
		
			
		
	
	
			323 lines
		
	
	
		
			7.1 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
| 
								 | 
							
								#include <limits>
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								#include "ElimGraph.h"
							 | 
						||
| 
								 | 
							
								#include "BayesNet.h"
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								ElimHeuristic ElimGraph::elimHeuristic_ = MIN_NEIGHBORS;
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								ElimGraph::ElimGraph (const BayesNet& bayesNet)
							 | 
						||
| 
								 | 
							
								{
							 | 
						||
| 
								 | 
							
								  const BnNodeSet& bnNodes = bayesNet.getBayesNodes();
							 | 
						||
| 
								 | 
							
								  for (unsigned i = 0; i < bnNodes.size(); i++) {
							 | 
						||
| 
								 | 
							
								    if (bnNodes[i]->hasEvidence() == false) {
							 | 
						||
| 
								 | 
							
								      addNode (new EgNode (bnNodes[i]));
							 | 
						||
| 
								 | 
							
								    }
							 | 
						||
| 
								 | 
							
								  }
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								  for (unsigned i = 0; i < bnNodes.size(); i++) {
							 | 
						||
| 
								 | 
							
								    if (bnNodes[i]->hasEvidence() == false) {
							 | 
						||
| 
								 | 
							
								      EgNode* n = getEgNode (bnNodes[i]->varId());
							 | 
						||
| 
								 | 
							
								      const BnNodeSet& childs = bnNodes[i]->getChilds();
							 | 
						||
| 
								 | 
							
								      for (unsigned j = 0; j < childs.size(); j++) {
							 | 
						||
| 
								 | 
							
								        if (childs[j]->hasEvidence() == false) {
							 | 
						||
| 
								 | 
							
								          addEdge (n, getEgNode (childs[j]->varId()));
							 | 
						||
| 
								 | 
							
								        }
							 | 
						||
| 
								 | 
							
								      }
							 | 
						||
| 
								 | 
							
								    }
							 | 
						||
| 
								 | 
							
								  }
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								  for (unsigned i = 0; i < bnNodes.size(); i++) {
							 | 
						||
| 
								 | 
							
								    vector<EgNode*> neighs;
							 | 
						||
| 
								 | 
							
								    const vector<BayesNode*>& parents = bnNodes[i]->getParents();
							 | 
						||
| 
								 | 
							
								    for (unsigned i = 0; i < parents.size(); i++) {
							 | 
						||
| 
								 | 
							
								      if (parents[i]->hasEvidence() == false) {
							 | 
						||
| 
								 | 
							
								        neighs.push_back (getEgNode (parents[i]->varId()));
							 | 
						||
| 
								 | 
							
								      }
							 | 
						||
| 
								 | 
							
								    }
							 | 
						||
| 
								 | 
							
								    if (neighs.size() > 0) {
							 | 
						||
| 
								 | 
							
								      for (unsigned i = 0; i < neighs.size() - 1; i++) {
							 | 
						||
| 
								 | 
							
								        for (unsigned j = i+1; j < neighs.size(); j++) {
							 | 
						||
| 
								 | 
							
								          if (!neighbors (neighs[i], neighs[j])) {
							 | 
						||
| 
								 | 
							
								            addEdge (neighs[i], neighs[j]);
							 | 
						||
| 
								 | 
							
								          }
							 | 
						||
| 
								 | 
							
								        }
							 | 
						||
| 
								 | 
							
								      }
							 | 
						||
| 
								 | 
							
								    }
							 | 
						||
| 
								 | 
							
								  }
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								  setIndexes();
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								ElimGraph::~ElimGraph (void)
							 | 
						||
| 
								 | 
							
								{
							 | 
						||
| 
								 | 
							
								  for (unsigned i = 0; i < nodes_.size(); i++) {
							 | 
						||
| 
								 | 
							
								    delete nodes_[i];    
							 | 
						||
| 
								 | 
							
								  }
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								void
							 | 
						||
| 
								 | 
							
								ElimGraph::addNode (EgNode* n)
							 | 
						||
| 
								 | 
							
								{
							 | 
						||
| 
								 | 
							
								  nodes_.push_back (n);
							 | 
						||
| 
								 | 
							
								  vid2nodes_.insert (make_pair (n->varId(), n));
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								EgNode*
							 | 
						||
| 
								 | 
							
								ElimGraph::getEgNode (VarId vid) const
							 | 
						||
| 
								 | 
							
								{
							 | 
						||
| 
								 | 
							
								  unordered_map<VarId,EgNode*>::const_iterator it = vid2nodes_.find (vid);
							 | 
						||
| 
								 | 
							
								  if (it == vid2nodes_.end()) {
							 | 
						||
| 
								 | 
							
								    return 0;
							 | 
						||
| 
								 | 
							
								  } else {
							 | 
						||
| 
								 | 
							
								    return it->second;
							 | 
						||
| 
								 | 
							
								  }
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								VarIdSet
							 | 
						||
| 
								 | 
							
								ElimGraph::getEliminatingOrder (const VarIdSet& exclude)
							 | 
						||
| 
								 | 
							
								{
							 | 
						||
| 
								 | 
							
								  VarIdSet elimOrder;
							 | 
						||
| 
								 | 
							
								  marked_.resize (nodes_.size(), false);
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								  for (unsigned i = 0; i < exclude.size(); i++) {
							 | 
						||
| 
								 | 
							
								    EgNode* node = getEgNode (exclude[i]);
							 | 
						||
| 
								 | 
							
								    assert (node);
							 | 
						||
| 
								 | 
							
								    marked_[*node] = true;
							 | 
						||
| 
								 | 
							
								  }
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								  unsigned nVarsToEliminate = nodes_.size() - exclude.size();
							 | 
						||
| 
								 | 
							
								  for (unsigned i = 0; i < nVarsToEliminate; i++) {
							 | 
						||
| 
								 | 
							
								    EgNode* node = getLowestCostNode();
							 | 
						||
| 
								 | 
							
								    marked_[*node] = true;
							 | 
						||
| 
								 | 
							
								    elimOrder.push_back (node->varId());
							 | 
						||
| 
								 | 
							
								    connectAllNeighbors (node);
							 | 
						||
| 
								 | 
							
								  }
							 | 
						||
| 
								 | 
							
								  return elimOrder;
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								EgNode*
							 | 
						||
| 
								 | 
							
								ElimGraph::getLowestCostNode (void) const
							 | 
						||
| 
								 | 
							
								{
							 | 
						||
| 
								 | 
							
								  EgNode* bestNode = 0;
							 | 
						||
| 
								 | 
							
								  unsigned minCost = std::numeric_limits<unsigned>::max();
							 | 
						||
| 
								 | 
							
								  for (unsigned i = 0; i < nodes_.size(); i++) {
							 | 
						||
| 
								 | 
							
								    if (marked_[i]) continue;
							 | 
						||
| 
								 | 
							
								    unsigned cost = 0;
							 | 
						||
| 
								 | 
							
								    switch (elimHeuristic_) {
							 | 
						||
| 
								 | 
							
								      case MIN_NEIGHBORS:
							 | 
						||
| 
								 | 
							
								        cost = getNeighborsCost (nodes_[i]);
							 | 
						||
| 
								 | 
							
								        break;
							 | 
						||
| 
								 | 
							
								      case MIN_WEIGHT:
							 | 
						||
| 
								 | 
							
								        cost = getWeightCost (nodes_[i]);
							 | 
						||
| 
								 | 
							
								        break;
							 | 
						||
| 
								 | 
							
								      case MIN_FILL:
							 | 
						||
| 
								 | 
							
								        cost = getFillCost (nodes_[i]);
							 | 
						||
| 
								 | 
							
								        break;
							 | 
						||
| 
								 | 
							
								      case WEIGHTED_MIN_FILL:
							 | 
						||
| 
								 | 
							
								        cost = getWeightedFillCost (nodes_[i]);
							 | 
						||
| 
								 | 
							
								        break;
							 | 
						||
| 
								 | 
							
								      default:
							 | 
						||
| 
								 | 
							
								        assert (false);
							 | 
						||
| 
								 | 
							
								    }
							 | 
						||
| 
								 | 
							
								    if (cost < minCost) {
							 | 
						||
| 
								 | 
							
								      bestNode = nodes_[i];
							 | 
						||
| 
								 | 
							
								      minCost  = cost;
							 | 
						||
| 
								 | 
							
								    }
							 | 
						||
| 
								 | 
							
								  }
							 | 
						||
| 
								 | 
							
								  assert (bestNode);
							 | 
						||
| 
								 | 
							
								  return bestNode;
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								unsigned
							 | 
						||
| 
								 | 
							
								ElimGraph::getNeighborsCost (const EgNode* n) const
							 | 
						||
| 
								 | 
							
								{
							 | 
						||
| 
								 | 
							
								  unsigned cost = 0;
							 | 
						||
| 
								 | 
							
								  const vector<EgNode*>& neighs = n->neighbors();
							 | 
						||
| 
								 | 
							
								  for (unsigned i = 0; i < neighs.size(); i++) {
							 | 
						||
| 
								 | 
							
								    if (marked_[*neighs[i]] == false) {
							 | 
						||
| 
								 | 
							
								      cost ++;
							 | 
						||
| 
								 | 
							
								    }
							 | 
						||
| 
								 | 
							
								  }
							 | 
						||
| 
								 | 
							
								  return cost;
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								unsigned
							 | 
						||
| 
								 | 
							
								ElimGraph::getWeightCost (const EgNode* n) const
							 | 
						||
| 
								 | 
							
								{
							 | 
						||
| 
								 | 
							
								  unsigned cost = 1;
							 | 
						||
| 
								 | 
							
								  const vector<EgNode*>& neighs = n->neighbors();
							 | 
						||
| 
								 | 
							
								  for (unsigned i = 0; i < neighs.size(); i++) {
							 | 
						||
| 
								 | 
							
								    if (marked_[*neighs[i]] == false) {
							 | 
						||
| 
								 | 
							
								      cost *= neighs[i]->nrStates();
							 | 
						||
| 
								 | 
							
								    }
							 | 
						||
| 
								 | 
							
								  }
							 | 
						||
| 
								 | 
							
								  return cost;
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								unsigned
							 | 
						||
| 
								 | 
							
								ElimGraph::getFillCost (const EgNode* n) const
							 | 
						||
| 
								 | 
							
								{
							 | 
						||
| 
								 | 
							
								  unsigned cost = 0;
							 | 
						||
| 
								 | 
							
								  const vector<EgNode*>& neighs = n->neighbors();
							 | 
						||
| 
								 | 
							
								  if (neighs.size() > 0) {
							 | 
						||
| 
								 | 
							
								    for (unsigned i = 0; i < neighs.size() - 1; i++) {
							 | 
						||
| 
								 | 
							
								      if (marked_[*neighs[i]] == true) continue;
							 | 
						||
| 
								 | 
							
								      for (unsigned j = i+1; j < neighs.size(); j++) {
							 | 
						||
| 
								 | 
							
								        if (marked_[*neighs[j]] == true) continue;
							 | 
						||
| 
								 | 
							
								        if (!neighbors (neighs[i], neighs[j])) {
							 | 
						||
| 
								 | 
							
								          cost ++;
							 | 
						||
| 
								 | 
							
								        }
							 | 
						||
| 
								 | 
							
								      }
							 | 
						||
| 
								 | 
							
								    }
							 | 
						||
| 
								 | 
							
								  }
							 | 
						||
| 
								 | 
							
								  return cost;
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								unsigned
							 | 
						||
| 
								 | 
							
								ElimGraph::getWeightedFillCost (const EgNode* n) const
							 | 
						||
| 
								 | 
							
								{
							 | 
						||
| 
								 | 
							
								  unsigned cost = 0;
							 | 
						||
| 
								 | 
							
								  const vector<EgNode*>& neighs = n->neighbors();
							 | 
						||
| 
								 | 
							
								  if (neighs.size() > 0) {
							 | 
						||
| 
								 | 
							
								    for (unsigned i = 0; i < neighs.size() - 1; i++) {
							 | 
						||
| 
								 | 
							
								      if (marked_[*neighs[i]] == true) continue;
							 | 
						||
| 
								 | 
							
								      for (unsigned j = i+1; j < neighs.size(); j++) {
							 | 
						||
| 
								 | 
							
								        if (marked_[*neighs[j]] == true) continue;
							 | 
						||
| 
								 | 
							
								        if (!neighbors (neighs[i], neighs[j])) {
							 | 
						||
| 
								 | 
							
								          cost += neighs[i]->nrStates() * neighs[j]->nrStates();
							 | 
						||
| 
								 | 
							
								        }
							 | 
						||
| 
								 | 
							
								      }
							 | 
						||
| 
								 | 
							
								    }
							 | 
						||
| 
								 | 
							
								  }
							 | 
						||
| 
								 | 
							
								  return cost;
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								void
							 | 
						||
| 
								 | 
							
								ElimGraph::connectAllNeighbors (const EgNode* n)
							 | 
						||
| 
								 | 
							
								{
							 | 
						||
| 
								 | 
							
								  const vector<EgNode*>& neighs = n->neighbors();
							 | 
						||
| 
								 | 
							
								  if (neighs.size() > 0) {
							 | 
						||
| 
								 | 
							
								    for (unsigned i = 0; i < neighs.size() - 1; i++) {
							 | 
						||
| 
								 | 
							
								      if (marked_[*neighs[i]] == true) continue;
							 | 
						||
| 
								 | 
							
								      for (unsigned j = i+1; j < neighs.size(); j++) {
							 | 
						||
| 
								 | 
							
								        if (marked_[*neighs[j]] == true) continue;
							 | 
						||
| 
								 | 
							
								        if (!neighbors (neighs[i], neighs[j])) {
							 | 
						||
| 
								 | 
							
								          addEdge (neighs[i], neighs[j]);
							 | 
						||
| 
								 | 
							
								        }
							 | 
						||
| 
								 | 
							
								      }
							 | 
						||
| 
								 | 
							
								    }
							 | 
						||
| 
								 | 
							
								  }
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								bool
							 | 
						||
| 
								 | 
							
								ElimGraph::neighbors (const EgNode* n1, const EgNode* n2) const
							 | 
						||
| 
								 | 
							
								{
							 | 
						||
| 
								 | 
							
								  const vector<EgNode*>& neighs = n1->neighbors();
							 | 
						||
| 
								 | 
							
								  for (unsigned i = 0; i < neighs.size(); i++) {
							 | 
						||
| 
								 | 
							
								    if (neighs[i] == n2) {
							 | 
						||
| 
								 | 
							
								      return true;
							 | 
						||
| 
								 | 
							
								    }
							 | 
						||
| 
								 | 
							
								  }
							 | 
						||
| 
								 | 
							
								  return false;
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								void
							 | 
						||
| 
								 | 
							
								ElimGraph::setIndexes (void)
							 | 
						||
| 
								 | 
							
								{
							 | 
						||
| 
								 | 
							
								  for (unsigned i = 0; i < nodes_.size(); i++) {
							 | 
						||
| 
								 | 
							
								    nodes_[i]->setIndex (i);
							 | 
						||
| 
								 | 
							
								  }
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								void
							 | 
						||
| 
								 | 
							
								ElimGraph::printGraphicalModel (void) const
							 | 
						||
| 
								 | 
							
								{
							 | 
						||
| 
								 | 
							
								  for (unsigned i = 0; i < nodes_.size(); i++) {
							 | 
						||
| 
								 | 
							
								    cout << "node " << nodes_[i]->label() << " neighs:" ;
							 | 
						||
| 
								 | 
							
								    vector<EgNode*> neighs = nodes_[i]->neighbors();
							 | 
						||
| 
								 | 
							
								    for (unsigned j = 0; j < neighs.size(); j++) {
							 | 
						||
| 
								 | 
							
								      cout << "  " << neighs[j]->label();
							 | 
						||
| 
								 | 
							
								    }
							 | 
						||
| 
								 | 
							
								    cout << endl;
							 | 
						||
| 
								 | 
							
								  }  
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								void
							 | 
						||
| 
								 | 
							
								ElimGraph::exportToGraphViz (const char* fileName,
							 | 
						||
| 
								 | 
							
								                             bool showNeighborless,
							 | 
						||
| 
								 | 
							
								                             const VarIdSet& highlightVarIds) const
							 | 
						||
| 
								 | 
							
								{
							 | 
						||
| 
								 | 
							
								  ofstream out (fileName);
							 | 
						||
| 
								 | 
							
								  if (!out.is_open()) {
							 | 
						||
| 
								 | 
							
								    cerr << "error: cannot open file to write at " ;
							 | 
						||
| 
								 | 
							
								    cerr << "Markov::exportToDotFile()" << endl;
							 | 
						||
| 
								 | 
							
								    abort();
							 | 
						||
| 
								 | 
							
								  }
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								  out << "strict graph {" << endl;
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								  for (unsigned i = 0; i < nodes_.size(); i++) {
							 | 
						||
| 
								 | 
							
								    if (showNeighborless || nodes_[i]->neighbors().size() != 0) {
							 | 
						||
| 
								 | 
							
								      out << '"' << nodes_[i]->label() << '"' ;
							 | 
						||
| 
								 | 
							
								      if (nodes_[i]->hasEvidence()) {
							 | 
						||
| 
								 | 
							
								        out << " [style=filled, fillcolor=yellow]" << endl;
							 | 
						||
| 
								 | 
							
								      } else {
							 | 
						||
| 
								 | 
							
								        out << endl;
							 | 
						||
| 
								 | 
							
								      }
							 | 
						||
| 
								 | 
							
								    }
							 | 
						||
| 
								 | 
							
								  }
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								  for (unsigned i = 0; i < highlightVarIds.size(); i++) {
							 | 
						||
| 
								 | 
							
								    EgNode* node =getEgNode (highlightVarIds[i]);
							 | 
						||
| 
								 | 
							
								    if (node) {
							 | 
						||
| 
								 | 
							
								      out << '"' << node->label() << '"' ;
							 | 
						||
| 
								 | 
							
								      out << " [shape=box3d]" << endl;
							 | 
						||
| 
								 | 
							
								    } else {
							 | 
						||
| 
								 | 
							
								      cout << "error: invalid variable id: " << highlightVarIds[i] << endl;
							 | 
						||
| 
								 | 
							
								      abort();
							 | 
						||
| 
								 | 
							
								    }
							 | 
						||
| 
								 | 
							
								  }
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								  for (unsigned i = 0; i < nodes_.size(); i++) {
							 | 
						||
| 
								 | 
							
								    vector<EgNode*> neighs = nodes_[i]->neighbors();
							 | 
						||
| 
								 | 
							
								    for (unsigned j = 0; j < neighs.size(); j++) {
							 | 
						||
| 
								 | 
							
								      out << '"' << nodes_[i]->label() << '"' << " -- " ;
							 | 
						||
| 
								 | 
							
								      out << '"' << neighs[j]->label() << '"' << endl;
							 | 
						||
| 
								 | 
							
								    }
							 | 
						||
| 
								 | 
							
								  }
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								  out << "}" << endl;
							 | 
						||
| 
								 | 
							
								  out.close();
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 |