Add support to markov networks
This commit is contained in:
@@ -3,52 +3,37 @@
|
||||
#include <fstream>
|
||||
|
||||
#include "ElimGraph.h"
|
||||
#include "BayesNet.h"
|
||||
|
||||
|
||||
ElimHeuristic ElimGraph::elimHeuristic_ = MIN_NEIGHBORS;
|
||||
|
||||
|
||||
ElimGraph::ElimGraph (const BayesNet& bayesNet)
|
||||
ElimGraph::ElimGraph (const vector<Factor*>& factors)
|
||||
{
|
||||
const BnNodeSet& bnNodes = bayesNet.getBayesNodes();
|
||||
for (unsigned i = 0; i < bnNodes.size(); i++) {
|
||||
if (bnNodes[i]->hasEvidence() == false) {
|
||||
addNode (new EgNode (bnNodes[i]));
|
||||
}
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < bnNodes.size(); i++) {
|
||||
if (bnNodes[i]->hasEvidence() == false) {
|
||||
EgNode* n = getEgNode (bnNodes[i]->varId());
|
||||
const BnNodeSet& childs = bnNodes[i]->getChilds();
|
||||
for (unsigned j = 0; j < childs.size(); j++) {
|
||||
if (childs[j]->hasEvidence() == false) {
|
||||
addEdge (n, getEgNode (childs[j]->varId()));
|
||||
for (unsigned i = 0; i < factors.size(); i++) {
|
||||
const VarIds& vids = factors[i]->arguments();
|
||||
for (unsigned j = 0; j < vids.size() - 1; j++) {
|
||||
EgNode* n1 = getEgNode (vids[j]);
|
||||
if (n1 == 0) {
|
||||
n1 = new EgNode (vids[j], factors[i]->range (j));
|
||||
addNode (n1);
|
||||
}
|
||||
for (unsigned k = j + 1; k < vids.size(); k++) {
|
||||
EgNode* n2 = getEgNode (vids[k]);
|
||||
if (n2 == 0) {
|
||||
n2 = new EgNode (vids[k], factors[i]->range (k));
|
||||
addNode (n2);
|
||||
}
|
||||
if (neighbors (n1, n2) == false) {
|
||||
addEdge (n1, n2);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (vids.size() == 1) {
|
||||
if (getEgNode (vids[0]) == 0) {
|
||||
addNode (new EgNode (vids[0], factors[i]->range (0)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < bnNodes.size(); i++) {
|
||||
vector<EgNode*> neighs;
|
||||
const vector<BayesNode*>& parents = bnNodes[i]->getParents();
|
||||
for (unsigned i = 0; i < parents.size(); i++) {
|
||||
if (parents[i]->hasEvidence() == false) {
|
||||
neighs.push_back (getEgNode (parents[i]->varId()));
|
||||
}
|
||||
}
|
||||
if (neighs.size() > 0) {
|
||||
for (unsigned i = 0; i < neighs.size() - 1; i++) {
|
||||
for (unsigned j = i+1; j < neighs.size(); j++) {
|
||||
if (!neighbors (neighs[i], neighs[j])) {
|
||||
addEdge (neighs[i], neighs[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
setIndexes();
|
||||
}
|
||||
|
||||
@@ -63,40 +48,16 @@ ElimGraph::~ElimGraph (void)
|
||||
|
||||
|
||||
|
||||
void
|
||||
ElimGraph::addNode (EgNode* n)
|
||||
{
|
||||
nodes_.push_back (n);
|
||||
varMap_.insert (make_pair (n->varId(), n));
|
||||
}
|
||||
|
||||
|
||||
|
||||
EgNode*
|
||||
ElimGraph::getEgNode (VarId vid) const
|
||||
{
|
||||
unordered_map<VarId,EgNode*>::const_iterator it =varMap_.find (vid);
|
||||
if (it ==varMap_.end()) {
|
||||
return 0;
|
||||
} else {
|
||||
return it->second;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
VarIds
|
||||
ElimGraph::getEliminatingOrder (const VarIds& exclude)
|
||||
{
|
||||
VarIds elimOrder;
|
||||
marked_.resize (nodes_.size(), false);
|
||||
|
||||
for (unsigned i = 0; i < exclude.size(); i++) {
|
||||
assert (getEgNode (exclude[i]));
|
||||
EgNode* node = getEgNode (exclude[i]);
|
||||
assert (node);
|
||||
marked_[*node] = true;
|
||||
}
|
||||
|
||||
unsigned nVarsToEliminate = nodes_.size() - exclude.size();
|
||||
for (unsigned i = 0; i < nVarsToEliminate; i++) {
|
||||
EgNode* node = getLowestCostNode();
|
||||
@@ -109,6 +70,99 @@ ElimGraph::getEliminatingOrder (const VarIds& exclude)
|
||||
|
||||
|
||||
|
||||
void
|
||||
ElimGraph::print (void) const
|
||||
{
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
cout << "node " << nodes_[i]->label() << " neighs:" ;
|
||||
vector<EgNode*> neighs = nodes_[i]->neighbors();
|
||||
for (unsigned j = 0; j < neighs.size(); j++) {
|
||||
cout << " " << neighs[j]->label();
|
||||
}
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
ElimGraph::exportToGraphViz (
|
||||
const char* fileName,
|
||||
bool showNeighborless,
|
||||
const VarIds& highlightVarIds) const
|
||||
{
|
||||
ofstream out (fileName);
|
||||
if (!out.is_open()) {
|
||||
cerr << "error: cannot open file to write at " ;
|
||||
cerr << "Markov::exportToDotFile()" << endl;
|
||||
abort();
|
||||
}
|
||||
|
||||
out << "strict graph {" << endl;
|
||||
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
if (showNeighborless || nodes_[i]->neighbors().size() != 0) {
|
||||
out << '"' << nodes_[i]->label() << '"' << endl;
|
||||
}
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < highlightVarIds.size(); i++) {
|
||||
EgNode* node =getEgNode (highlightVarIds[i]);
|
||||
if (node) {
|
||||
out << '"' << node->label() << '"' ;
|
||||
out << " [shape=box3d]" << endl;
|
||||
} else {
|
||||
cout << "error: invalid variable id: " << highlightVarIds[i] << endl;
|
||||
abort();
|
||||
}
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
vector<EgNode*> neighs = nodes_[i]->neighbors();
|
||||
for (unsigned j = 0; j < neighs.size(); j++) {
|
||||
out << '"' << nodes_[i]->label() << '"' << " -- " ;
|
||||
out << '"' << neighs[j]->label() << '"' << endl;
|
||||
}
|
||||
}
|
||||
|
||||
out << "}" << endl;
|
||||
out.close();
|
||||
}
|
||||
|
||||
|
||||
|
||||
VarIds
|
||||
ElimGraph::getEliminationOrder (
|
||||
const vector<Factor*> factors,
|
||||
VarIds excludedVids)
|
||||
{
|
||||
ElimGraph graph (factors);
|
||||
// graph.print();
|
||||
// graph.exportToGraphViz ("_egg.dot");
|
||||
return graph.getEliminatingOrder (excludedVids);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
ElimGraph::addNode (EgNode* n)
|
||||
{
|
||||
nodes_.push_back (n);
|
||||
varMap_.insert (make_pair (n->varId(), n));
|
||||
}
|
||||
|
||||
|
||||
|
||||
EgNode*
|
||||
ElimGraph::getEgNode (VarId vid) const
|
||||
{
|
||||
unordered_map<VarId, EgNode*>::const_iterator it;
|
||||
it = varMap_.find (vid);
|
||||
return (it != varMap_.end()) ? it->second : 0;
|
||||
}
|
||||
|
||||
|
||||
|
||||
EgNode*
|
||||
ElimGraph::getLowestCostNode (void) const
|
||||
{
|
||||
@@ -166,7 +220,7 @@ ElimGraph::getWeightCost (const EgNode* n) const
|
||||
const vector<EgNode*>& neighs = n->neighbors();
|
||||
for (unsigned i = 0; i < neighs.size(); i++) {
|
||||
if (marked_[*neighs[i]] == false) {
|
||||
cost *= neighs[i]->nrStates();
|
||||
cost *= neighs[i]->range();
|
||||
}
|
||||
}
|
||||
return cost;
|
||||
@@ -206,7 +260,7 @@ ElimGraph::getWeightedFillCost (const EgNode* n) const
|
||||
for (unsigned j = i+1; j < neighs.size(); j++) {
|
||||
if (marked_[*neighs[j]] == true) continue;
|
||||
if (!neighbors (neighs[i], neighs[j])) {
|
||||
cost += neighs[i]->nrStates() * neighs[j]->nrStates();
|
||||
cost += neighs[i]->range() * neighs[j]->range();
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -257,68 +311,3 @@ ElimGraph::setIndexes (void)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
ElimGraph::printGraphicalModel (void) const
|
||||
{
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
cout << "node " << nodes_[i]->label() << " neighs:" ;
|
||||
vector<EgNode*> neighs = nodes_[i]->neighbors();
|
||||
for (unsigned j = 0; j < neighs.size(); j++) {
|
||||
cout << " " << neighs[j]->label();
|
||||
}
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
ElimGraph::exportToGraphViz (const char* fileName,
|
||||
bool showNeighborless,
|
||||
const VarIds& highlightVarIds) const
|
||||
{
|
||||
ofstream out (fileName);
|
||||
if (!out.is_open()) {
|
||||
cerr << "error: cannot open file to write at " ;
|
||||
cerr << "Markov::exportToDotFile()" << endl;
|
||||
abort();
|
||||
}
|
||||
|
||||
out << "strict graph {" << endl;
|
||||
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
if (showNeighborless || nodes_[i]->neighbors().size() != 0) {
|
||||
out << '"' << nodes_[i]->label() << '"' ;
|
||||
if (nodes_[i]->hasEvidence()) {
|
||||
out << " [style=filled, fillcolor=yellow]" << endl;
|
||||
} else {
|
||||
out << endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < highlightVarIds.size(); i++) {
|
||||
EgNode* node =getEgNode (highlightVarIds[i]);
|
||||
if (node) {
|
||||
out << '"' << node->label() << '"' ;
|
||||
out << " [shape=box3d]" << endl;
|
||||
} else {
|
||||
cout << "error: invalid variable id: " << highlightVarIds[i] << endl;
|
||||
abort();
|
||||
}
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
vector<EgNode*> neighs = nodes_[i]->neighbors();
|
||||
for (unsigned j = 0; j < neighs.size(); j++) {
|
||||
out << '"' << nodes_[i]->label() << '"' << " -- " ;
|
||||
out << '"' << neighs[j]->label() << '"' << endl;
|
||||
}
|
||||
}
|
||||
|
||||
out << "}" << endl;
|
||||
out.close();
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user