This repository has been archived on 2023-08-20. You can view files and clone it, but cannot push or open issues or pull requests.
yap-6.3/packages/CLPBN/horus/ElimGraph.cpp

249 lines
5.7 KiB
C++
Raw Normal View History

2012-05-23 14:56:01 +01:00
#include <limits>
#include <fstream>
#include "ElimGraph.h"
// Heuristic consulted by getLowestCostNode() when choosing the next
// variable to eliminate; getEliminationOrder() additionally handles
// SEQUENTIAL itself. Initially MIN_NEIGHBORS.
ElimHeuristic ElimGraph::elimHeuristic = ElimHeuristic::MIN_NEIGHBORS;
// Builds the elimination graph for `factors`: one node per variable that
// appears in some factor, and an edge between every pair of variables
// that appear together in a factor. Null entries in `factors` are skipped.
// Nodes are allocated here and owned by the graph (freed in ~ElimGraph).
ElimGraph::ElimGraph (const vector<Factor*>& factors)
{
  for (size_t i = 0; i < factors.size(); i++) {
    if (factors[i] == 0) { // if contained just one var with evidence
      continue;
    }
    const VarIds& vids = factors[i]->arguments();
    // Create the missing nodes first, in argument order. This also covers
    // factors with a single argument, which contribute no edges.
    for (size_t j = 0; j < vids.size(); j++) {
      if (getEgNode (vids[j]) == 0) {
        addNode (new EgNode (vids[j], factors[i]->range (j)));
      }
    }
    // Connect every pair of arguments so the factor's scope forms a clique.
    // The bound is written as `j + 1 < size` because `size() - 1` would
    // underflow for a factor without arguments.
    for (size_t j = 0; j + 1 < vids.size(); j++) {
      EgNode* n1 = getEgNode (vids[j]);
      for (size_t k = j + 1; k < vids.size(); k++) {
        EgNode* n2 = getEgNode (vids[k]);
        if (neighbors (n1, n2) == false) {
          addEdge (n1, n2);
        }
      }
    }
  }
}
// Frees every node stored in nodes_; the nodes are heap-allocated when the
// graph is built, so the graph is responsible for deleting them.
ElimGraph::~ElimGraph (void)
{
  const size_t nrNodes = nodes_.size();
  for (size_t idx = 0; idx < nrNodes; ++idx) {
    delete nodes_[idx];
  }
}
// Returns an elimination order covering every graph variable not listed in
// `exclude`. At each step the unmarked node with the lowest cost (see
// getLowestCostNode / elimHeuristic) is chosen, detached from its
// neighbors, and those neighbors are then connected pairwise.
VarIds
ElimGraph::getEliminatingOrder (const VarIds& exclude)
{
  VarIds elimOrder;
  unmarked_.reserve (nodes_.size());
  for (size_t i = 0; i < nodes_.size(); i++) {
    if (Util::contains (exclude, nodes_[i]->varId()) == false) {
      unmarked_.insert (nodes_[i]);
    }
  }
  // Count what was actually marked rather than nodes_.size() -
  // exclude.size(). `exclude` may mention variables that are not in the
  // graph, or repeat one. The old difference was then too small, leaving
  // variables uneliminated, or it underflowed and kept asking
  // getLowestCostNode for a node after none were left.
  const size_t nrVarsToEliminate = unmarked_.size();
  for (size_t i = 0; i < nrVarsToEliminate; i++) {
    EgNode* node = getLowestCostNode();
    unmarked_.remove (node);
    const EGNeighs& neighs = node->neighbors();
    // Detach the chosen node from its neighbors' adjacency lists; its own
    // list is kept so connectAllNeighbors can still read it below.
    for (size_t j = 0; j < neighs.size(); j++) {
      neighs[j]->removeNeighbor (node);
    }
    elimOrder.push_back (node->varId());
    connectAllNeighbors (node);
  }
  return elimOrder;
}
void
ElimGraph::print (void) const
{
2012-05-24 22:55:20 +01:00
for (size_t i = 0; i < nodes_.size(); i++) {
2012-05-23 14:56:01 +01:00
cout << "node " << nodes_[i]->label() << " neighs:" ;
EGNeighs neighs = nodes_[i]->neighbors();
2012-05-24 22:55:20 +01:00
for (size_t j = 0; j < neighs.size(); j++) {
2012-05-23 14:56:01 +01:00
cout << " " << neighs[j]->label();
}
cout << endl;
}
}
void
ElimGraph::exportToGraphViz (
const char* fileName,
bool showNeighborless,
const VarIds& highlightVarIds) const
{
ofstream out (fileName);
if (!out.is_open()) {
cerr << "error: cannot open file to write at " ;
cerr << "Markov::exportToDotFile()" << endl;
abort();
}
out << "strict graph {" << endl;
2012-05-24 22:55:20 +01:00
for (size_t i = 0; i < nodes_.size(); i++) {
2012-05-23 14:56:01 +01:00
if (showNeighborless || nodes_[i]->neighbors().size() != 0) {
out << '"' << nodes_[i]->label() << '"' << endl;
}
}
2012-05-24 22:55:20 +01:00
for (size_t i = 0; i < highlightVarIds.size(); i++) {
2012-05-23 14:56:01 +01:00
EgNode* node =getEgNode (highlightVarIds[i]);
if (node) {
out << '"' << node->label() << '"' ;
out << " [shape=box3d]" << endl;
} else {
cout << "error: invalid variable id: " << highlightVarIds[i] << endl;
abort();
}
}
2012-05-24 22:55:20 +01:00
for (size_t i = 0; i < nodes_.size(); i++) {
2012-05-23 14:56:01 +01:00
EGNeighs neighs = nodes_[i]->neighbors();
2012-05-24 22:55:20 +01:00
for (size_t j = 0; j < neighs.size(); j++) {
2012-05-23 14:56:01 +01:00
out << '"' << nodes_[i]->label() << '"' << " -- " ;
out << '"' << neighs[j]->label() << '"' << endl;
}
}
out << "}" << endl;
out.close();
}
// Returns an elimination order for the variables of `factors`, leaving out
// `excludedVids`.
// With the SEQUENTIAL heuristic, the result is the distinct arguments of
// all factors in TinySet order (presumably ascending VarId; confirm against
// TinySet). Otherwise an ElimGraph is built and its heuristic order is
// returned.
// Null entries in `factors` are ignored in both cases.
VarIds
ElimGraph::getEliminationOrder (
    const Factors& factors,
    VarIds excludedVids)
{
  if (elimHeuristic == ElimHeuristic::SEQUENTIAL) {
    VarIds allVids;
    Factors::const_iterator first = factors.begin();
    Factors::const_iterator end = factors.end();
    for (; first != end; ++first) {
      // The ElimGraph constructor tolerates null factors, so skip them
      // here too instead of dereferencing them.
      if (*first != 0) {
        Util::addToVector (allVids, (*first)->arguments());
      }
    }
    TinySet<VarId> elimOrder (allVids);
    elimOrder -= TinySet<VarId> (excludedVids);
    return elimOrder.elements();
  }
  ElimGraph graph (factors);
  return graph.getEliminatingOrder (excludedVids);
}
// Appends `n` to nodes_, records its position there as its index, and
// registers it under its variable id for lookup by getEgNode().
void
ElimGraph::addNode (EgNode* n)
{
  const VarId vid = n->varId();
  nodes_.push_back (n);
  const size_t position = nodes_.size() - 1;
  n->setIndex (position);
  varMap_.insert (make_pair (vid, n));
}
// Looks up the node for variable `vid`; returns 0 when the graph has none.
EgNode*
ElimGraph::getEgNode (VarId vid) const
{
  const unordered_map<VarId, EgNode*>::const_iterator found
      = varMap_.find (vid);
  if (found == varMap_.end()) {
    return 0;
  }
  return found->second;
}
// Returns the unmarked node with the smallest cost under elimHeuristic.
// Supported heuristics: MIN_NEIGHBORS, MIN_WEIGHT, MIN_FILL and
// WEIGHTED_MIN_FILL. Ties go to the first such node met while iterating
// unmarked_.
// Asserts when unmarked_ is empty or the heuristic has no cost function
// here (e.g. SEQUENTIAL). Without asserts, it returns 0 in those cases.
EgNode*
ElimGraph::getLowestCostNode (void) const
{
  EgNode* bestNode = 0;
  unsigned minCost = std::numeric_limits<unsigned>::max();
  EGNeighs::const_iterator it;
  for (it = unmarked_.begin(); it != unmarked_.end(); ++ it) {
    unsigned cost = 0;
    switch (elimHeuristic) {
      case MIN_NEIGHBORS:     cost = getNeighborsCost (*it);    break;
      case MIN_WEIGHT:        cost = getWeightCost (*it);       break;
      case MIN_FILL:          cost = getFillCost (*it);         break;
      case WEIGHTED_MIN_FILL: cost = getWeightedFillCost (*it); break;
      default:
        // No cost function for this heuristic.
        assert (false);
        return 0;
    }
    // Always accept the first candidate. With a strict `<` alone, a node
    // whose cost equals numeric_limits<unsigned>::max() could never be
    // picked, and bestNode would stay null.
    if (bestNode == 0 || cost < minCost) {
      bestNode = *it;
      minCost = cost;
    }
  }
  assert (bestNode);
  return bestNode;
}
void
ElimGraph::connectAllNeighbors (const EgNode* n)
{
const EGNeighs& neighs = n->neighbors();
if (neighs.size() > 0) {
2012-05-24 22:55:20 +01:00
for (size_t i = 0; i < neighs.size() - 1; i++) {
for (size_t j = i + 1; j < neighs.size(); j++) {
2012-05-23 14:56:01 +01:00
if ( ! neighbors (neighs[i], neighs[j])) {
addEdge (neighs[i], neighs[j]);
}
}
}
}
}