This repository has been archived on 2023-08-20. You can view files and clone it, but cannot push or open issues or pull requests.
yap-6.3/packages/CLPBN/clpbn/bp/BayesNet.cpp

622 lines
14 KiB
C++
Raw Normal View History

#include <cstdlib>
#include <cassert>
#include <iostream>
#include <fstream>
#include <sstream>
#include "xmlParser/xmlParser.h"
#include "BayesNet.h"
2012-03-22 11:33:24 +00:00
#include "Util.h"
2011-12-12 15:29:51 +00:00
// Destroys the network by deleting every node it owns. The distributions
// registered through addDistribution() are NOT released here; that is
// freeDistributions()'s job.
BayesNet::~BayesNet (void)
{
  for (BnNodeSet::iterator it = nodes_.begin(); it != nodes_.end(); ++it) {
    delete *it;
  }
}
void
BayesNet::readFromBifFormat (const char* fileName)
{
XMLNode xMainNode = XMLNode::openFileHelper (fileName, "BIF");
// only the first network is parsed, others are ignored
2011-12-12 15:29:51 +00:00
XMLNode xNode = xMainNode.getChildNode ("NETWORK");
unsigned nVars = xNode.nChildNode ("VARIABLE");
for (unsigned i = 0; i < nVars; i++) {
XMLNode var = xNode.getChildNode ("VARIABLE", i);
2011-12-12 15:29:51 +00:00
if (string (var.getAttribute ("TYPE")) != "nature") {
cerr << "error: only \"nature\" variables are supported" << endl;
abort();
}
2011-12-12 15:29:51 +00:00
States states;
string label = var.getChildNode("NAME").getText();
unsigned nrStates = var.nChildNode ("OUTCOME");
for (unsigned j = 0; j < nrStates; j++) {
if (var.getChildNode("OUTCOME", j).getText() == 0) {
stringstream ss;
ss << j + 1;
2011-12-12 15:29:51 +00:00
states.push_back (ss.str());
} else {
2011-12-12 15:29:51 +00:00
states.push_back (var.getChildNode("OUTCOME", j).getText());
}
}
2011-12-12 15:29:51 +00:00
addNode (label, states);
}
unsigned nDefs = xNode.nChildNode ("DEFINITION");
if (nVars != nDefs) {
cerr << "error: different number of variables and definitions" << endl;
abort();
}
for (unsigned i = 0; i < nDefs; i++) {
2011-12-12 15:29:51 +00:00
XMLNode def = xNode.getChildNode ("DEFINITION", i);
string label = def.getChildNode("FOR").getText();
BayesNode* node = getBayesNode (label);
if (!node) {
cerr << "error: unknow variable `" << label << "'" << endl;
abort();
}
BnNodeSet parents;
2011-12-12 15:29:51 +00:00
unsigned nParams = node->nrStates();
for (int j = 0; j < def.nChildNode ("GIVEN"); j++) {
string parentLabel = def.getChildNode("GIVEN", j).getText();
BayesNode* parentNode = getBayesNode (parentLabel);
2011-12-12 15:29:51 +00:00
if (!parentNode) {
cerr << "error: unknow variable `" << parentLabel << "'" << endl;
abort();
}
2011-12-12 15:29:51 +00:00
nParams *= parentNode->nrStates();
parents.push_back (parentNode);
}
2011-12-12 15:29:51 +00:00
node->setParents (parents);
unsigned count = 0;
2012-03-22 11:33:24 +00:00
Params params (nParams);
2011-12-12 15:29:51 +00:00
stringstream s (def.getChildNode("TABLE").getText());
while (!s.eof() && count < nParams) {
s >> params[count];
count ++;
}
if (count != nParams) {
cerr << "error: invalid number of parameters " ;
cerr << "for variable `" << label << "'" << endl;
abort();
}
2011-12-12 15:29:51 +00:00
params = reorderParameters (params, node->nrStates());
Distribution* dist = new Distribution (params);
node->setDistribution (dist);
addDistribution (dist);
}
2011-12-12 15:29:51 +00:00
setIndexes();
2012-03-22 11:33:24 +00:00
if (Globals::logDomain) {
2011-12-12 15:29:51 +00:00
distributionsToLogs();
}
}
// Creates a node for a new variable called `label` with the given states and
// appends it to the network. The variable id handed out is the node's
// position in nodes_, so the id-to-position map entry maps the id onto itself.
// The variable's label/states are registered with GraphicalModel.
BayesNode*
BayesNet::addNode (string label, const States& states)
{
  const unsigned position = nodes_.size();
  const VarId newVarId = position;
  varMap_.insert (make_pair (newVarId, position));
  GraphicalModel::addVariableInformation (newVarId, label, states);
  BayesNode* created = new BayesNode (VarNode (newVarId, states.size()));
  nodes_.push_back (created);
  return created;
}
// Appends a node for the already-known variable `vid` (domain size `dsize`;
// `evidence` and `dist` are forwarded to the BayesNode constructor) and
// records, in varMap_, the position at which it was stored.
BayesNode*
BayesNet::addNode (VarId vid, unsigned dsize, int evidence, Distribution* dist)
{
  const unsigned slot = nodes_.size();
  varMap_.insert (make_pair (vid, slot));
  BayesNode* fresh = new BayesNode (vid, dsize, evidence, dist);
  nodes_.push_back (fresh);
  return fresh;
}
// Returns the node of variable `vid`, or 0 when the id is not in this network.
BayesNode*
BayesNet::getBayesNode (VarId vid) const
{
  IndexMap::const_iterator found = varMap_.find (vid);
  return (found != varMap_.end()) ? nodes_[found->second] : 0;
}
// Returns the first node whose label equals `label` (linear scan), or 0 if
// no node has that label.
BayesNode*
BayesNet::getBayesNode (string label) const
{
  for (BnNodeSet::const_iterator it = nodes_.begin(); it != nodes_.end(); ++it) {
    if ((*it)->label() == label) {
      return *it;
    }
  }
  return 0;
}
// Returns the node of variable `vid` seen through its VarNode base.
// The id is required to exist; this is only checked by assert.
VarNode*
BayesNet::getVariableNode (VarId vid) const
{
  VarNode* var = getBayesNode (vid);
  assert (var != 0);
  return var;
}
// Returns all nodes of the network, in storage order, as VarNode pointers.
VarNodes
BayesNet::getVariableNodes (void) const
{
  VarNodes result;
  for (BnNodeSet::const_iterator it = nodes_.begin(); it != nodes_.end(); ++it) {
    result.push_back (*it);
  }
  return result;
}
// Registers `dist` with the network (appended at the end of dists_), which
// makes it reachable from getDistribution(), distributionsToLogs() and
// freeDistributions().
void
BayesNet::addDistribution (Distribution* dist)
{
  dists_.insert (dists_.end(), dist);
}
// Returns the first registered distribution whose id equals `distId`, or 0
// if there is none.
Distribution*
BayesNet::getDistribution (unsigned distId) const
{
  const int wanted = static_cast<int> (distId);
  for (unsigned k = 0; k < dists_.size(); k++) {
    if (dists_[k]->id == wanted) {
      return dists_[k];
    }
  }
  return 0;
}
// Read-only access to every node of the network, in storage order.
const BnNodeSet&
BayesNet::getBayesNodes (void) const
{
  return this->nodes_;
}
unsigned
2011-12-12 15:29:51 +00:00
BayesNet::nrNodes (void) const
{
return nodes_.size();
}
// Collects, in storage order, every node for which isRoot() holds.
BnNodeSet
BayesNet::getRootNodes (void) const
{
  BnNodeSet result;
  for (BnNodeSet::const_iterator it = nodes_.begin(); it != nodes_.end(); ++it) {
    if ((*it)->isRoot()) {
      result.push_back (*it);
    }
  }
  return result;
}
// Collects, in storage order, every node for which isLeaf() holds.
BnNodeSet
BayesNet::getLeafNodes (void) const
{
  BnNodeSet result;
  for (BnNodeSet::const_iterator it = nodes_.begin(); it != nodes_.end(); ++it) {
    if ((*it)->isLeaf()) {
      result.push_back (*it);
    }
  }
  return result;
}
// Convenience overload: minimal requisite network for the single query
// variable `vid`. See the VarIds overload.
BayesNet*
BayesNet::getMinimalRequesiteNetwork (VarId vid) const
{
  // Build the one-element query list with the fill constructor instead of
  // assigning a braced list to a temporary (`VarIds() = {vid}`).
  return getMinimalRequesiteNetwork (VarIds (1, vid));
}
// Returns a newly allocated BayesNet that contains only the nodes of this
// network that are requisite for querying `queryVarIds` under the current
// evidence. The caller is responsible for deleting the result.
//
// The traversal marks nodes "on top" and "on bottom" using what appear to be
// Shachter's Bayes-Ball rules:
//  - a node without evidence that is reached from a child is marked on top
//    (its parents are scheduled) and on bottom (its children are scheduled);
//  - a node reached from a parent is marked on top if it has evidence, and on
//    bottom if it has not.
// constructGraph() then copies the selected nodes into the new network.
// Aborts if a query id does not belong to this network.
BayesNet*
BayesNet::getMinimalRequesiteNetwork (const VarIds& queryVarIds) const
{
  Scheduling scheduling;
  for (unsigned i = 0; i < queryVarIds.size(); i++) {
    BayesNode* n = getBayesNode (queryVarIds[i]);
    // Use an explicit check rather than only an assert: with NDEBUG an
    // unknown id would otherwise be dereferenced as a null node below.
    if (!n) {
      cerr << "error: invalid variable id: " << queryVarIds[i] << endl;
      abort();
    }
    scheduling.push (ScheduleInfo (n, false, true));
  }

  // per-node traversal state, indexed by node index; 0 = never reached
  vector<StateInfo*> states (nodes_.size(), 0);
  while (!scheduling.empty()) {
    // Copy the entry and pop it before scheduling neighbours, so no reference
    // into the queue is held while new entries are pushed onto it.
    ScheduleInfo sch = scheduling.front();
    scheduling.pop();

    StateInfo* state = states[sch.node->getIndex()];
    if (!state) {
      // first visit: the initial flag values come from StateInfo's
      // constructor (declared elsewhere)
      state = new StateInfo();
      states[sch.node->getIndex()] = state;
    } else {
      state->visited = true;
    }
    if (!sch.node->hasEvidence() && sch.visitedFromChild) {
      if (!state->markedOnTop) {
        state->markedOnTop = true;
        scheduleParents (sch.node, scheduling);
      }
      if (!state->markedOnBottom) {
        state->markedOnBottom = true;
        scheduleChilds (sch.node, scheduling);
      }
    }
    if (sch.visitedFromParent) {
      if (sch.node->hasEvidence() && !state->markedOnTop) {
        state->markedOnTop = true;
        scheduleParents (sch.node, scheduling);
      }
      if (!sch.node->hasEvidence() && !state->markedOnBottom) {
        state->markedOnBottom = true;
        scheduleChilds (sch.node, scheduling);
      }
    }
  }

  BayesNet* bn = new BayesNet();
  constructGraph (bn, states);
  for (unsigned i = 0; i < states.size(); i++) {
    delete states[i];
  }
  return bn;
}
// Fills `bn` with the requisite nodes of this network and links their parents.
// `states` is indexed like nodes_ and holds the traversal marks computed by
// getMinimalRequesiteNetwork() (0 for nodes never reached). A node is kept if
// it has evidence and was visited, or if it was marked on top. Only nodes
// marked on top keep their parent edges. Kept nodes share their Distribution
// pointer with the corresponding node of this network.
void
BayesNet::constructGraph (BayesNet* bn,
                          const vector<StateInfo*>& states) const
{
  BnNodeSet kept;
  vector<VarIds> keptParents;   // keptParents[k] = parent ids of kept[k]
  for (unsigned i = 0; i < nodes_.size(); i++) {
    const StateInfo* st = states[i];
    if (st == 0) {
      continue;
    }
    BayesNode* src = nodes_[i];
    const bool keep = (src->hasEvidence() && st->visited) || st->markedOnTop;
    if (!keep) {
      continue;
    }
    VarIds parentIds;
    if (st->markedOnTop) {
      const BnNodeSet& ps = src->getParents();
      for (unsigned j = 0; j < ps.size(); j++) {
        parentIds.push_back (ps[j]->varId());
      }
    }
    keptParents.push_back (parentIds);

    assert (bn->getBayesNode (src->varId()) == 0);
    kept.push_back (bn->addNode (src->varId(),
                                 src->nrStates(),
                                 src->getEvidence(),
                                 src->getDistribution()));
  }

  // second pass: every kept node exists in bn now, so parents can be resolved
  for (unsigned k = 0; k < kept.size(); k++) {
    BnNodeSet linked;
    for (unsigned j = 0; j < keptParents[k].size(); j++) {
      BayesNode* parent = bn->getBayesNode (keptParents[k][j]);
      assert (parent != 0);
      linked.push_back (parent);
    }
    kept[k]->setParents (linked);
  }
  bn->setIndexes();
}
bool
2011-12-12 15:29:51 +00:00
BayesNet::isPolyTree (void) const
{
return !containsUndirectedCycle();
}
void
BayesNet::setIndexes (void)
{
for (unsigned i = 0; i < nodes_.size(); i++) {
nodes_[i]->setIndex (i);
}
}
2011-12-12 15:29:51 +00:00
// Converts the parameters of every registered distribution in place via
// Util::toLog.
void
BayesNet::distributionsToLogs (void)
{
  for (unsigned d = 0, total = dists_.size(); d < total; ++d) {
    Util::toLog (dists_[d]->params);
  }
}
// Deletes every distribution registered through addDistribution() and then
// empties the registry. Clearing it means a repeated call does not
// double-delete, and a later getDistribution() does not read freed memory.
// NOTE(review): nodes that were handed one of these distributions (e.g. via
// setDistribution in readFromBifFormat) still hold the pointer afterwards.
void
BayesNet::freeDistributions (void)
{
  for (unsigned i = 0; i < dists_.size(); i++) {
    delete dists_[i];
  }
  dists_.clear();
}
void
BayesNet::printGraphicalModel (void) const
{
for (unsigned i = 0; i < nodes_.size(); i++) {
cout << *nodes_[i];
}
}
void
2011-12-12 15:29:51 +00:00
BayesNet::exportToGraphViz (const char* fileName,
bool showNeighborless,
2012-03-22 11:33:24 +00:00
const VarIds& highlightVarIds) const
{
ofstream out (fileName);
if (!out.is_open()) {
cerr << "error: cannot open file to write at " ;
cerr << "BayesNet::exportToDotFile()" << endl;
abort();
}
2011-12-12 15:29:51 +00:00
out << "digraph {" << endl;
out << "ranksep=1" << endl;
for (unsigned i = 0; i < nodes_.size(); i++) {
if (showNeighborless || nodes_[i]->hasNeighbors()) {
2011-12-12 15:29:51 +00:00
out << nodes_[i]->varId() ;
if (nodes_[i]->hasEvidence()) {
2011-12-12 15:29:51 +00:00
out << " [" ;
out << "label=\"" << nodes_[i]->label() << "\"," ;
out << "style=filled, fillcolor=yellow" ;
out << "]" ;
} else {
2011-12-12 15:29:51 +00:00
out << " [" ;
out << "label=\"" << nodes_[i]->label() << "\"" ;
out << "]" ;
}
2011-12-12 15:29:51 +00:00
out << endl;
}
}
2011-12-12 15:29:51 +00:00
for (unsigned i = 0; i < highlightVarIds.size(); i++) {
BayesNode* node = getBayesNode (highlightVarIds[i]);
if (node) {
2011-12-12 15:29:51 +00:00
out << node->varId() ;
out << " [shape=box3d]" << endl;
} else {
2011-12-12 15:29:51 +00:00
cout << "error: invalid variable id: " << highlightVarIds[i] << endl;
abort();
}
}
for (unsigned i = 0; i < nodes_.size(); i++) {
const BnNodeSet& childs = nodes_[i]->getChilds();
for (unsigned j = 0; j < childs.size(); j++) {
2011-12-12 15:29:51 +00:00
out << nodes_[i]->varId() << " -> " << childs[j]->varId() << " [style=bold]" << endl ;
}
}
out << "}" << endl;
out.close();
}
void
BayesNet::exportToBifFormat (const char* fileName) const
{
ofstream out (fileName);
if(!out.is_open()) {
cerr << "error: cannot open file to write at " ;
cerr << "BayesNet::exportToBifFile()" << endl;
abort();
}
out << "<?xml version=\"1.0\" encoding=\"US-ASCII\"?>" << endl;
out << "<BIF VERSION=\"0.3\">" << endl;
out << "<NETWORK>" << endl;
out << "<NAME>" << fileName << "</NAME>" << endl << endl;
for (unsigned i = 0; i < nodes_.size(); i++) {
out << "<VARIABLE TYPE=\"nature\">" << endl;
2011-12-12 15:29:51 +00:00
out << "\t<NAME>" << nodes_[i]->label() << "</NAME>" << endl;
const States& states = nodes_[i]->states();
for (unsigned j = 0; j < states.size(); j++) {
out << "\t<OUTCOME>" << states[j] << "</OUTCOME>" << endl;
}
out << "</VARIABLE>" << endl << endl;
}
for (unsigned i = 0; i < nodes_.size(); i++) {
out << "<DEFINITION>" << endl;
2011-12-12 15:29:51 +00:00
out << "\t<FOR>" << nodes_[i]->label() << "</FOR>" << endl;
const BnNodeSet& parents = nodes_[i]->getParents();
for (unsigned j = 0; j < parents.size(); j++) {
2011-12-12 15:29:51 +00:00
out << "\t<GIVEN>" << parents[j]->label();
out << "</GIVEN>" << endl;
}
2012-03-22 11:33:24 +00:00
Params params = revertParameterReorder (nodes_[i]->getParameters(),
2011-12-12 15:29:51 +00:00
nodes_[i]->nrStates());
out << "\t<TABLE>" ;
for (unsigned j = 0; j < params.size(); j++) {
out << " " << params[j];
}
out << " </TABLE>" << endl;
out << "</DEFINITION>" << endl << endl;
}
out << "</NETWORK>" << endl;
out << "</BIF>" << endl << endl;
out.close();
}
bool
BayesNet::containsUndirectedCycle (void) const
{
vector<bool> visited (nodes_.size(), false);
for (unsigned i = 0; i < nodes_.size(); i++) {
int v = nodes_[i]->getIndex();
if (!visited[v]) {
if (containsUndirectedCycle (v, -1, visited)) {
return true;
}
}
}
return false;
}
bool
2011-12-12 15:29:51 +00:00
BayesNet::containsUndirectedCycle (int v, int p, vector<bool>& visited) const
{
visited[v] = true;
vector<int> adjacencies = getAdjacentNodes (v);
for (unsigned i = 0; i < adjacencies.size(); i++) {
int w = adjacencies[i];
if (!visited[w]) {
if (containsUndirectedCycle (w, v, visited)) {
return true;
}
}
else if (visited[w] && w != p) {
return true;
}
}
return false; // no cycle detected in this component
}
// Returns the indexes of all neighbours of node index `v`: its parents first,
// then its children.
vector<int>
BayesNet::getAdjacentNodes (int v) const
{
  const BnNodeSet& parents = nodes_[v]->getParents();
  const BnNodeSet& childs  = nodes_[v]->getChilds();
  vector<int> neighbours;
  neighbours.reserve (parents.size() + childs.size());
  for (BnNodeSet::const_iterator it = parents.begin(); it != parents.end(); ++it) {
    neighbours.push_back ((*it)->getIndex());
  }
  for (BnNodeSet::const_iterator it = childs.begin(); it != childs.end(); ++it) {
    neighbours.push_back ((*it)->getIndex());
  }
  return neighbours;
}
// Converts a CPT from the BIF interchange order to the order used by CLPBN.
// `dsize` is the domain size of the child variable. `params.size()` is
// expected to be a multiple of it.
//
// The interchange format for bayesian networks keeps the probabilities
// in the following order:
//   p(a1|b1,c1) p(a2|b1,c1) p(a1|b1,c2) p(a2|b1,c2) p(a1|b2,c1) p(a2|b2,c1)
//   p(a1|b2,c2) p(a2|b2,c2).
//
// However, in clpbn we keep the probabilities in this order:
//   p(a1|b1,c1) p(a1|b1,c2) p(a1|b2,c1) p(a1|b2,c2) p(a2|b1,c1) p(a2|b1,c2)
//   p(a2|b2,c1) p(a2|b2,c2).
//
// That is, input element [row * dsize + state] moves to output position
// [state * rowSize + row].
Params
BayesNet::reorderParameters (const Params& params, unsigned dsize) const
{
  assert (dsize > 0 && params.size() % dsize == 0);
  if (dsize == 0) {
    // nothing sensible to reorder; avoids a division by zero below
    return params;
  }
  const unsigned rowSize = params.size() / dsize;
  // Bounded loops, so a malformed size can neither read out of bounds nor
  // loop forever (the former `while (size < params.size())` loop did both).
  Params reordered;
  for (unsigned state = 0; state < dsize; state++) {
    for (unsigned row = 0; row < rowSize; row++) {
      reordered.push_back (params[row * dsize + state]);
    }
  }
  return reordered;
}
// Inverse of reorderParameters(): converts a CPT from CLPBN order back to
// the BIF interchange order. `dsize` is the domain size of the child variable.
// `params.size()` is expected to be a multiple of it. Input element
// [state * rowSize + row] moves to output position [row * dsize + state].
Params
BayesNet::revertParameterReorder (const Params& params, unsigned dsize) const
{
  assert (dsize > 0 && params.size() % dsize == 0);
  if (dsize == 0) {
    // nothing sensible to reorder; avoids a division by zero below
    return params;
  }
  const unsigned rowSize = params.size() / dsize;
  // Bounded loops, so a malformed size can neither read out of bounds nor
  // produce repeated garbage entries (as the former open-ended loop could).
  Params reordered;
  for (unsigned row = 0; row < rowSize; row++) {
    for (unsigned state = 0; state < dsize; state++) {
      reordered.push_back (params[state * rowSize + row]);
    }
  }
  return reordered;
}