diff --git a/packages/CLPBN/clpbn/bp/ElimGraph.cpp b/packages/CLPBN/clpbn/bp/ElimGraph.cpp index d468e5204..155a74eb3 100644 --- a/packages/CLPBN/clpbn/bp/ElimGraph.cpp +++ b/packages/CLPBN/clpbn/bp/ElimGraph.cpp @@ -54,16 +54,20 @@ VarIds ElimGraph::getEliminatingOrder (const VarIds& exclude) { VarIds elimOrder; - marked_.resize (nodes_.size(), false); - for (unsigned i = 0; i < exclude.size(); i++) { - assert (getEgNode (exclude[i])); - EgNode* node = getEgNode (exclude[i]); - marked_[*node] = true; + unmarked_.reserve (nodes_.size()); + for (unsigned i = 0; i < nodes_.size(); i++) { + if (Util::contains (exclude, nodes_[i]->varId()) == false) { + unmarked_.insert (nodes_[i]); + } } unsigned nVarsToEliminate = nodes_.size() - exclude.size(); for (unsigned i = 0; i < nVarsToEliminate; i++) { EgNode* node = getLowestCostNode(); - marked_[*node] = true; + unmarked_.remove (node); + const EGNeighs& neighs = node->neighbors(); + for (unsigned j = 0; j < neighs.size(); j++) { + neighs[j]->removeNeighbor (node); + } elimOrder.push_back (node->varId()); connectAllNeighbors (node); } @@ -77,7 +81,7 @@ ElimGraph::print (void) const { for (unsigned i = 0; i < nodes_.size(); i++) { cout << "node " << nodes_[i]->label() << " neighs:" ; - vector neighs = nodes_[i]->neighbors(); + EGNeighs neighs = nodes_[i]->neighbors(); for (unsigned j = 0; j < neighs.size(); j++) { cout << " " << neighs[j]->label(); } @@ -120,7 +124,7 @@ ElimGraph::exportToGraphViz ( } for (unsigned i = 0; i < nodes_.size(); i++) { - vector neighs = nodes_[i]->neighbors(); + EGNeighs neighs = nodes_[i]->neighbors(); for (unsigned j = 0; j < neighs.size(); j++) { out << '"' << nodes_[i]->label() << '"' << " -- " ; out << '"' << neighs[j]->label() << '"' << endl; @@ -171,29 +175,29 @@ ElimGraph::getLowestCostNode (void) const { EgNode* bestNode = 0; unsigned minCost = std::numeric_limits::max(); - for (unsigned i = 0; i < nodes_.size(); i++) { - if (marked_[i]) continue; - unsigned cost = 0; - switch (elimHeuristic) { - case MIN_NEIGHBORS: - cost = getNeighborsCost (nodes_[i]); - break; - case MIN_WEIGHT: - cost = getWeightCost (nodes_[i]); - break; - case MIN_FILL: - cost = getFillCost (nodes_[i]); - break; - case WEIGHTED_MIN_FILL: - cost = getWeightedFillCost (nodes_[i]); - break; - default: - assert (false); - } - if (cost < minCost) { - bestNode = nodes_[i]; - minCost = cost; - } + unsigned cost = 0; + EGNeighs::const_iterator it; + switch (elimHeuristic) { + case MIN_NEIGHBORS: { + for (it = unmarked_.begin(); it != unmarked_.end(); ++ it) { + cost = getNeighborsCost (*it); + if (cost < minCost) { + bestNode = *it; + minCost = cost; + } + }} + break; + case MIN_WEIGHT: + //cost = getWeightCost (unmarked_[i]); + break; + case MIN_FILL: + //cost = getFillCost (unmarked_[i]); + break; + case WEIGHTED_MIN_FILL: + //cost = getWeightedFillCost (unmarked_[i]); + break; + default: + assert (false); } assert (bestNode); return bestNode; @@ -201,88 +205,14 @@ ElimGraph::getLowestCostNode (void) const -unsigned -ElimGraph::getNeighborsCost (const EgNode* n) const -{ - unsigned cost = 0; - const vector& neighs = n->neighbors(); - for (unsigned i = 0; i < neighs.size(); i++) { - if (marked_[*neighs[i]] == false) { - cost ++; - } - } - return cost; -} - - - -unsigned -ElimGraph::getWeightCost (const EgNode* n) const -{ - unsigned cost = 1; - const vector& neighs = n->neighbors(); - for (unsigned i = 0; i < neighs.size(); i++) { - if (marked_[*neighs[i]] == false) { - cost *= neighs[i]->range(); - } - } - return cost; -} - - - -unsigned -ElimGraph::getFillCost (const EgNode* n) const -{ - unsigned cost = 0; - const vector& neighs = n->neighbors(); - if (neighs.size() > 0) { - for (unsigned i = 0; i < neighs.size() - 1; i++) { - if (marked_[*neighs[i]] == true) continue; - for (unsigned j = i+1; j < neighs.size(); j++) { - if (marked_[*neighs[j]] == true) continue; - if (!neighbors (neighs[i], neighs[j])) { - cost ++; - } - } - } - } - return cost; -} - - - -unsigned -ElimGraph::getWeightedFillCost (const EgNode* n) const -{ - unsigned cost = 0; - const vector& neighs = n->neighbors(); - if (neighs.size() > 0) { - for (unsigned i = 0; i < neighs.size() - 1; i++) { - if (marked_[*neighs[i]] == true) continue; - for (unsigned j = i+1; j < neighs.size(); j++) { - if (marked_[*neighs[j]] == true) continue; - if (!neighbors (neighs[i], neighs[j])) { - cost += neighs[i]->range() * neighs[j]->range(); - } - } - } - } - return cost; -} - - - void ElimGraph::connectAllNeighbors (const EgNode* n) { - const vector& neighs = n->neighbors(); + const EGNeighs& neighs = n->neighbors(); if (neighs.size() > 0) { for (unsigned i = 0; i < neighs.size() - 1; i++) { - if (marked_[*neighs[i]] == true) continue; for (unsigned j = i+1; j < neighs.size(); j++) { - if (marked_[*neighs[j]] == true) continue; - if (!neighbors (neighs[i], neighs[j])) { + if ( ! neighbors (neighs[i], neighs[j])) { addEdge (neighs[i], neighs[j]); } } @@ -290,17 +220,3 @@ ElimGraph::connectAllNeighbors (const EgNode* n) } } - - -bool -ElimGraph::neighbors (const EgNode* n1, const EgNode* n2) const -{ - const vector& neighs = n1->neighbors(); - for (unsigned i = 0; i < neighs.size(); i++) { - if (neighs[i] == n2) { - return true; - } - } - return false; -} - diff --git a/packages/CLPBN/clpbn/bp/ElimGraph.h b/packages/CLPBN/clpbn/bp/ElimGraph.h index 8ffba954d..5dca758da 100644 --- a/packages/CLPBN/clpbn/bp/ElimGraph.h +++ b/packages/CLPBN/clpbn/bp/ElimGraph.h @@ -4,8 +4,10 @@ #include "unordered_map" #include "FactorGraph.h" +#include "TinySet.h" #include "Horus.h" + using namespace std; enum ElimHeuristic @@ -17,17 +19,26 @@ enum ElimHeuristic }; +class EgNode; + +typedef TinySet EGNeighs; + + class EgNode : public Var { public: EgNode (VarId vid, unsigned range) : Var (vid, range) { } - void addNeighbor (EgNode* n) { neighs_.push_back (n); } + void addNeighbor (EgNode* n) { neighs_.insert (n); } - const vector& neighbors (void) const { return neighs_; } + void removeNeighbor (EgNode* n) { neighs_.remove (n); } + + bool isNeighbor (EgNode* n) const { return neighs_.contains (n); } + + const EGNeighs& neighbors (void) const { return neighs_; } private: - vector neighs_; + EGNeighs neighs_; }; @@ -58,25 +69,68 @@ class ElimGraph n2->addNeighbor (n1); } + unsigned getNeighborsCost (const EgNode* n) const + { + return n->neighbors().size(); + } + + unsigned getWeightCost (const EgNode* n) const + { + unsigned cost = 1; + const EGNeighs& neighs = n->neighbors(); + for (unsigned i = 0; i < neighs.size(); i++) { + cost *= neighs[i]->range(); + } + return cost; + } + + unsigned getFillCost (const EgNode* n) const + { + unsigned cost = 0; + const EGNeighs& neighs = n->neighbors(); + if (neighs.size() > 0) { + for (unsigned i = 0; i < neighs.size() - 1; i++) { + for (unsigned j = i + 1; j < neighs.size(); j++) { + if ( ! neighbors (neighs[i], neighs[j])) { + cost ++; + } + } + } + } + return cost; + } + + unsigned getWeightedFillCost (const EgNode* n) const + { + unsigned cost = 0; + const EGNeighs& neighs = n->neighbors(); + if (neighs.size() > 0) { + for (unsigned i = 0; i < neighs.size() - 1; i++) { + for (unsigned j = i+1; j < neighs.size(); j++) { + if ( ! neighbors (neighs[i], neighs[j])) { + cost += neighs[i]->range() * neighs[j]->range(); + } + } + } + } + return cost; + } + + bool neighbors (EgNode* n1, EgNode* n2) const + { + return n1->isNeighbor (n2); + } + void addNode (EgNode*); EgNode* getEgNode (VarId) const; + EgNode* getLowestCostNode (void) const; - unsigned getNeighborsCost (const EgNode*) const; - - unsigned getWeightCost (const EgNode*) const; - - unsigned getFillCost (const EgNode*) const; - - unsigned getWeightedFillCost (const EgNode*) const; - void connectAllNeighbors (const EgNode*); - bool neighbors (const EgNode*, const EgNode*) const; - vector nodes_; - vector marked_; + TinySet unmarked_; unordered_map varMap_; }; diff --git a/packages/CLPBN/clpbn/bp/TODO b/packages/CLPBN/clpbn/bp/TODO index 60eb32466..e2c0a5284 100644 --- a/packages/CLPBN/clpbn/bp/TODO +++ b/packages/CLPBN/clpbn/bp/TODO @@ -5,4 +5,8 @@ - Consider using hashs instead of vectors of colors to calculate the groups in counting bp - use more psize_t instead of unsigned for looping through params +- Find a way to decrease the time required to find an + elimination order for variable elimination +- Add a sequential elimination heuristic +