Add support to markov networks
This commit is contained in:
@@ -11,6 +11,104 @@
|
||||
#include "Util.h"
|
||||
|
||||
|
||||
void
|
||||
DAGraph::addNode (DAGraphNode* n)
|
||||
{
|
||||
nodes_.push_back (n);
|
||||
assert (Util::contains (varMap_, n->varId()) == false);
|
||||
varMap_[n->varId()] = n;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
DAGraph::addEdge (VarId vid1, VarId vid2)
|
||||
{
|
||||
unordered_map<VarId, DAGraphNode*>::iterator it1;
|
||||
unordered_map<VarId, DAGraphNode*>::iterator it2;
|
||||
it1 = varMap_.find (vid1);
|
||||
it2 = varMap_.find (vid2);
|
||||
assert (it1 != varMap_.end());
|
||||
assert (it2 != varMap_.end());
|
||||
it1->second->addChild (it2->second);
|
||||
it2->second->addParent (it1->second);
|
||||
}
|
||||
|
||||
|
||||
|
||||
const DAGraphNode*
|
||||
DAGraph::getNode (VarId vid) const
|
||||
{
|
||||
unordered_map<VarId, DAGraphNode*>::const_iterator it;
|
||||
it = varMap_.find (vid);
|
||||
return it != varMap_.end() ? it->second : 0;
|
||||
}
|
||||
|
||||
|
||||
|
||||
DAGraphNode*
|
||||
DAGraph::getNode (VarId vid)
|
||||
{
|
||||
unordered_map<VarId, DAGraphNode*>::const_iterator it;
|
||||
it = varMap_.find (vid);
|
||||
return it != varMap_.end() ? it->second : 0;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
DAGraph::setIndexes (void)
|
||||
{
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
nodes_[i]->setIndex (i);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
DAGraph::clear (void)
|
||||
{
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
nodes_[i]->clear();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
DAGraph::exportToGraphViz (const char* fileName)
|
||||
{
|
||||
ofstream out (fileName);
|
||||
if (!out.is_open()) {
|
||||
cerr << "error: cannot open file to write at " ;
|
||||
cerr << "DAGraph::exportToDotFile()" << endl;
|
||||
abort();
|
||||
}
|
||||
out << "digraph {" << endl;
|
||||
out << "ranksep=1" << endl;
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
out << nodes_[i]->varId() ;
|
||||
out << " [" ;
|
||||
out << "label=\"" << nodes_[i]->label() << "\"" ;
|
||||
if (nodes_[i]->hasEvidence()) {
|
||||
out << ",style=filled, fillcolor=yellow" ;
|
||||
}
|
||||
out << "]" << endl;
|
||||
}
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
const vector<DAGraphNode*>& childs = nodes_[i]->childs();
|
||||
for (unsigned j = 0; j < childs.size(); j++) {
|
||||
out << nodes_[i]->varId() << " -> " << childs[j]->varId();
|
||||
out << " [style=bold]" << endl ;
|
||||
}
|
||||
}
|
||||
out << "}" << endl;
|
||||
out.close();
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
BayesNet::~BayesNet (void)
|
||||
{
|
||||
@@ -36,8 +134,8 @@ BayesNet::readFromBifFormat (const char* fileName)
|
||||
}
|
||||
States states;
|
||||
string label = var.getChildNode("NAME").getText();
|
||||
unsigned nrStates = var.nChildNode ("OUTCOME");
|
||||
for (unsigned j = 0; j < nrStates; j++) {
|
||||
unsigned range = var.nChildNode ("OUTCOME");
|
||||
for (unsigned j = 0; j < range; j++) {
|
||||
if (var.getChildNode("OUTCOME", j).getText() == 0) {
|
||||
stringstream ss;
|
||||
ss << j + 1;
|
||||
@@ -63,7 +161,7 @@ BayesNet::readFromBifFormat (const char* fileName)
|
||||
abort();
|
||||
}
|
||||
BnNodeSet parents;
|
||||
unsigned nParams = node->nrStates();
|
||||
unsigned nParams = node->range();
|
||||
for (int j = 0; j < def.nChildNode ("GIVEN"); j++) {
|
||||
string parentLabel = def.getChildNode("GIVEN", j).getText();
|
||||
BayesNode* parentNode = getBayesNode (parentLabel);
|
||||
@@ -71,7 +169,7 @@ BayesNet::readFromBifFormat (const char* fileName)
|
||||
cerr << "error: unknow variable `" << parentLabel << "'" << endl;
|
||||
abort();
|
||||
}
|
||||
nParams *= parentNode->nrStates();
|
||||
nParams *= parentNode->range();
|
||||
parents.push_back (parentNode);
|
||||
}
|
||||
node->setParents (parents);
|
||||
@@ -87,7 +185,7 @@ BayesNet::readFromBifFormat (const char* fileName)
|
||||
cerr << "for variable `" << label << "'" << endl;
|
||||
abort();
|
||||
}
|
||||
params = reorderParameters (params, node->nrStates());
|
||||
params = reorderParameters (params, node->range());
|
||||
if (Globals::logDomain) {
|
||||
Util::toLog (params);
|
||||
}
|
||||
@@ -218,130 +316,6 @@ BayesNet::getLeafNodes (void) const
|
||||
|
||||
|
||||
|
||||
BayesNet*
|
||||
BayesNet::getMinimalRequesiteNetwork (VarId vid) const
|
||||
{
|
||||
return getMinimalRequesiteNetwork (VarIds() = {vid});
|
||||
}
|
||||
|
||||
|
||||
|
||||
BayesNet*
|
||||
BayesNet::getMinimalRequesiteNetwork (const VarIds& queryVarIds) const
|
||||
{
|
||||
BnNodeSet queryVars;
|
||||
Scheduling scheduling;
|
||||
for (unsigned i = 0; i < queryVarIds.size(); i++) {
|
||||
BayesNode* n = getBayesNode (queryVarIds[i]);
|
||||
assert (n);
|
||||
queryVars.push_back (n);
|
||||
scheduling.push (ScheduleInfo (n, false, true));
|
||||
}
|
||||
|
||||
vector<StateInfo*> states (nodes_.size(), 0);
|
||||
|
||||
while (!scheduling.empty()) {
|
||||
ScheduleInfo& sch = scheduling.front();
|
||||
StateInfo* state = states[sch.node->getIndex()];
|
||||
if (!state) {
|
||||
state = new StateInfo();
|
||||
states[sch.node->getIndex()] = state;
|
||||
} else {
|
||||
state->visited = true;
|
||||
}
|
||||
if (!sch.node->hasEvidence() && sch.visitedFromChild) {
|
||||
if (!state->markedOnTop) {
|
||||
state->markedOnTop = true;
|
||||
scheduleParents (sch.node, scheduling);
|
||||
}
|
||||
if (!state->markedOnBottom) {
|
||||
state->markedOnBottom = true;
|
||||
scheduleChilds (sch.node, scheduling);
|
||||
}
|
||||
}
|
||||
if (sch.visitedFromParent) {
|
||||
if (sch.node->hasEvidence() && !state->markedOnTop) {
|
||||
state->markedOnTop = true;
|
||||
scheduleParents (sch.node, scheduling);
|
||||
}
|
||||
if (!sch.node->hasEvidence() && !state->markedOnBottom) {
|
||||
state->markedOnBottom = true;
|
||||
scheduleChilds (sch.node, scheduling);
|
||||
}
|
||||
}
|
||||
scheduling.pop();
|
||||
}
|
||||
/*
|
||||
cout << "\t\ttop\tbottom" << endl;
|
||||
cout << "variable\t\tmarked\tmarked\tvisited\tobserved" << endl;
|
||||
Util::printDashedLine();
|
||||
cout << endl;
|
||||
for (unsigned i = 0; i < states.size(); i++) {
|
||||
cout << nodes_[i]->label() << ":\t\t" ;
|
||||
if (states[i]) {
|
||||
states[i]->markedOnTop ? cout << "yes\t" : cout << "no\t" ;
|
||||
states[i]->markedOnBottom ? cout << "yes\t" : cout << "no\t" ;
|
||||
states[i]->visited ? cout << "yes\t" : cout << "no\t" ;
|
||||
nodes_[i]->hasEvidence() ? cout << "yes" : cout << "no" ;
|
||||
cout << endl;
|
||||
} else {
|
||||
cout << "no\tno\tno\t" ;
|
||||
nodes_[i]->hasEvidence() ? cout << "yes" : cout << "no" ;
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
cout << endl;
|
||||
*/
|
||||
BayesNet* bn = new BayesNet();
|
||||
constructGraph (bn, states);
|
||||
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
delete states[i];
|
||||
}
|
||||
return bn;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BayesNet::constructGraph (BayesNet* bn,
|
||||
const vector<StateInfo*>& states) const
|
||||
{
|
||||
BnNodeSet mrnNodes;
|
||||
vector<VarIds> parents;
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
bool isRequired = false;
|
||||
if (states[i]) {
|
||||
isRequired = (nodes_[i]->hasEvidence() && states[i]->visited)
|
||||
||
|
||||
states[i]->markedOnTop;
|
||||
}
|
||||
if (isRequired) {
|
||||
parents.push_back (VarIds());
|
||||
if (states[i]->markedOnTop) {
|
||||
const BnNodeSet& ps = nodes_[i]->getParents();
|
||||
for (unsigned j = 0; j < ps.size(); j++) {
|
||||
parents.back().push_back (ps[j]->varId());
|
||||
}
|
||||
}
|
||||
assert (bn->getBayesNode (nodes_[i]->varId()) == 0);
|
||||
BayesNode* mrnNode = new BayesNode (nodes_[i]);
|
||||
bn->addNode (mrnNode);
|
||||
mrnNodes.push_back (mrnNode);
|
||||
}
|
||||
}
|
||||
for (unsigned i = 0; i < mrnNodes.size(); i++) {
|
||||
BnNodeSet ps;
|
||||
for (unsigned j = 0; j < parents[i].size(); j++) {
|
||||
assert (bn->getBayesNode (parents[i][j]) != 0);
|
||||
ps.push_back (bn->getBayesNode (parents[i][j]));
|
||||
}
|
||||
mrnNodes[i]->setParents (ps);
|
||||
}
|
||||
bn->setIndexes();
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
BayesNet::isPolyTree (void) const
|
||||
@@ -458,7 +432,7 @@ BayesNet::exportToBifFormat (const char* fileName) const
|
||||
out << "</GIVEN>" << endl;
|
||||
}
|
||||
Params params = revertParameterReorder (
|
||||
nodes_[i]->params(), nodes_[i]->nrStates());
|
||||
nodes_[i]->params(), nodes_[i]->range());
|
||||
out << "\t<TABLE>" ;
|
||||
for (unsigned j = 0; j < params.size(); j++) {
|
||||
out << " " << params[j];
|
||||
|
Reference in New Issue
Block a user