Add support to markov networks

This commit is contained in:
Tiago Gomes
2012-04-05 18:38:56 +01:00
parent 6c3add2ebd
commit 0d23591058
32 changed files with 857 additions and 616 deletions

View File

@@ -11,6 +11,104 @@
#include "Util.h"
void
DAGraph::addNode (DAGraphNode* n)
{
nodes_.push_back (n);
assert (Util::contains (varMap_, n->varId()) == false);
varMap_[n->varId()] = n;
}
void
DAGraph::addEdge (VarId vid1, VarId vid2)
{
unordered_map<VarId, DAGraphNode*>::iterator it1;
unordered_map<VarId, DAGraphNode*>::iterator it2;
it1 = varMap_.find (vid1);
it2 = varMap_.find (vid2);
assert (it1 != varMap_.end());
assert (it2 != varMap_.end());
it1->second->addChild (it2->second);
it2->second->addParent (it1->second);
}
const DAGraphNode*
DAGraph::getNode (VarId vid) const
{
unordered_map<VarId, DAGraphNode*>::const_iterator it;
it = varMap_.find (vid);
return it != varMap_.end() ? it->second : 0;
}
DAGraphNode*
DAGraph::getNode (VarId vid)
{
unordered_map<VarId, DAGraphNode*>::const_iterator it;
it = varMap_.find (vid);
return it != varMap_.end() ? it->second : 0;
}
void
DAGraph::setIndexes (void)
{
for (unsigned i = 0; i < nodes_.size(); i++) {
nodes_[i]->setIndex (i);
}
}
void
DAGraph::clear (void)
{
for (unsigned i = 0; i < nodes_.size(); i++) {
nodes_[i]->clear();
}
}
void
DAGraph::exportToGraphViz (const char* fileName)
{
ofstream out (fileName);
if (!out.is_open()) {
cerr << "error: cannot open file to write at " ;
cerr << "DAGraph::exportToDotFile()" << endl;
abort();
}
out << "digraph {" << endl;
out << "ranksep=1" << endl;
for (unsigned i = 0; i < nodes_.size(); i++) {
out << nodes_[i]->varId() ;
out << " [" ;
out << "label=\"" << nodes_[i]->label() << "\"" ;
if (nodes_[i]->hasEvidence()) {
out << ",style=filled, fillcolor=yellow" ;
}
out << "]" << endl;
}
for (unsigned i = 0; i < nodes_.size(); i++) {
const vector<DAGraphNode*>& childs = nodes_[i]->childs();
for (unsigned j = 0; j < childs.size(); j++) {
out << nodes_[i]->varId() << " -> " << childs[j]->varId();
out << " [style=bold]" << endl ;
}
}
out << "}" << endl;
out.close();
}
BayesNet::~BayesNet (void)
{
@@ -36,8 +134,8 @@ BayesNet::readFromBifFormat (const char* fileName)
}
States states;
string label = var.getChildNode("NAME").getText();
unsigned nrStates = var.nChildNode ("OUTCOME");
for (unsigned j = 0; j < nrStates; j++) {
unsigned range = var.nChildNode ("OUTCOME");
for (unsigned j = 0; j < range; j++) {
if (var.getChildNode("OUTCOME", j).getText() == 0) {
stringstream ss;
ss << j + 1;
@@ -63,7 +161,7 @@ BayesNet::readFromBifFormat (const char* fileName)
abort();
}
BnNodeSet parents;
unsigned nParams = node->nrStates();
unsigned nParams = node->range();
for (int j = 0; j < def.nChildNode ("GIVEN"); j++) {
string parentLabel = def.getChildNode("GIVEN", j).getText();
BayesNode* parentNode = getBayesNode (parentLabel);
@@ -71,7 +169,7 @@ BayesNet::readFromBifFormat (const char* fileName)
cerr << "error: unknow variable `" << parentLabel << "'" << endl;
abort();
}
nParams *= parentNode->nrStates();
nParams *= parentNode->range();
parents.push_back (parentNode);
}
node->setParents (parents);
@@ -87,7 +185,7 @@ BayesNet::readFromBifFormat (const char* fileName)
cerr << "for variable `" << label << "'" << endl;
abort();
}
params = reorderParameters (params, node->nrStates());
params = reorderParameters (params, node->range());
if (Globals::logDomain) {
Util::toLog (params);
}
@@ -218,130 +316,6 @@ BayesNet::getLeafNodes (void) const
BayesNet*
BayesNet::getMinimalRequesiteNetwork (VarId vid) const
{
return getMinimalRequesiteNetwork (VarIds() = {vid});
}
BayesNet*
BayesNet::getMinimalRequesiteNetwork (const VarIds& queryVarIds) const
{
BnNodeSet queryVars;
Scheduling scheduling;
for (unsigned i = 0; i < queryVarIds.size(); i++) {
BayesNode* n = getBayesNode (queryVarIds[i]);
assert (n);
queryVars.push_back (n);
scheduling.push (ScheduleInfo (n, false, true));
}
vector<StateInfo*> states (nodes_.size(), 0);
while (!scheduling.empty()) {
ScheduleInfo& sch = scheduling.front();
StateInfo* state = states[sch.node->getIndex()];
if (!state) {
state = new StateInfo();
states[sch.node->getIndex()] = state;
} else {
state->visited = true;
}
if (!sch.node->hasEvidence() && sch.visitedFromChild) {
if (!state->markedOnTop) {
state->markedOnTop = true;
scheduleParents (sch.node, scheduling);
}
if (!state->markedOnBottom) {
state->markedOnBottom = true;
scheduleChilds (sch.node, scheduling);
}
}
if (sch.visitedFromParent) {
if (sch.node->hasEvidence() && !state->markedOnTop) {
state->markedOnTop = true;
scheduleParents (sch.node, scheduling);
}
if (!sch.node->hasEvidence() && !state->markedOnBottom) {
state->markedOnBottom = true;
scheduleChilds (sch.node, scheduling);
}
}
scheduling.pop();
}
/*
cout << "\t\ttop\tbottom" << endl;
cout << "variable\t\tmarked\tmarked\tvisited\tobserved" << endl;
Util::printDashedLine();
cout << endl;
for (unsigned i = 0; i < states.size(); i++) {
cout << nodes_[i]->label() << ":\t\t" ;
if (states[i]) {
states[i]->markedOnTop ? cout << "yes\t" : cout << "no\t" ;
states[i]->markedOnBottom ? cout << "yes\t" : cout << "no\t" ;
states[i]->visited ? cout << "yes\t" : cout << "no\t" ;
nodes_[i]->hasEvidence() ? cout << "yes" : cout << "no" ;
cout << endl;
} else {
cout << "no\tno\tno\t" ;
nodes_[i]->hasEvidence() ? cout << "yes" : cout << "no" ;
cout << endl;
}
}
cout << endl;
*/
BayesNet* bn = new BayesNet();
constructGraph (bn, states);
for (unsigned i = 0; i < nodes_.size(); i++) {
delete states[i];
}
return bn;
}
void
BayesNet::constructGraph (BayesNet* bn,
const vector<StateInfo*>& states) const
{
BnNodeSet mrnNodes;
vector<VarIds> parents;
for (unsigned i = 0; i < nodes_.size(); i++) {
bool isRequired = false;
if (states[i]) {
isRequired = (nodes_[i]->hasEvidence() && states[i]->visited)
||
states[i]->markedOnTop;
}
if (isRequired) {
parents.push_back (VarIds());
if (states[i]->markedOnTop) {
const BnNodeSet& ps = nodes_[i]->getParents();
for (unsigned j = 0; j < ps.size(); j++) {
parents.back().push_back (ps[j]->varId());
}
}
assert (bn->getBayesNode (nodes_[i]->varId()) == 0);
BayesNode* mrnNode = new BayesNode (nodes_[i]);
bn->addNode (mrnNode);
mrnNodes.push_back (mrnNode);
}
}
for (unsigned i = 0; i < mrnNodes.size(); i++) {
BnNodeSet ps;
for (unsigned j = 0; j < parents[i].size(); j++) {
assert (bn->getBayesNode (parents[i][j]) != 0);
ps.push_back (bn->getBayesNode (parents[i][j]));
}
mrnNodes[i]->setParents (ps);
}
bn->setIndexes();
}
bool
BayesNet::isPolyTree (void) const
@@ -458,7 +432,7 @@ BayesNet::exportToBifFormat (const char* fileName) const
out << "</GIVEN>" << endl;
}
Params params = revertParameterReorder (
nodes_[i]->params(), nodes_[i]->nrStates());
nodes_[i]->params(), nodes_[i]->range());
out << "\t<TABLE>" ;
for (unsigned j = 0; j < params.size(); j++) {
out << " " << params[j];