This repository has been archived on 2023-08-20. You can view files and clone it, but cannot push or open issues or pull requests.
yap-6.3/packages/CLPBN/clpbn/bp/FactorGraph.cpp

525 lines
12 KiB
C++
Raw Normal View History

#include <set>
2011-12-12 15:29:51 +00:00
#include <vector>
#include <algorithm>
#include <iostream>
#include <fstream>
#include <sstream>
#include "FactorGraph.h"
#include "Factor.h"
#include "BayesNet.h"
2011-12-12 15:29:51 +00:00
FactorGraph::FactorGraph (const BayesNet& bn)
{
2011-12-12 15:29:51 +00:00
const BnNodeSet& nodes = bn.getBayesNodes();
for (unsigned i = 0; i < nodes.size(); i++) {
FgVarNode* varNode = new FgVarNode (nodes[i]);
addVariable (varNode);
}
for (unsigned i = 0; i < nodes.size(); i++) {
const BnNodeSet& parents = nodes[i]->getParents();
if (!(nodes[i]->hasEvidence() && parents.size() == 0)) {
VarNodes neighs;
neighs.push_back (varNodes_[nodes[i]->getIndex()]);
for (unsigned j = 0; j < parents.size(); j++) {
neighs.push_back (varNodes_[parents[j]->getIndex()]);
}
FgFacNode* fn = new FgFacNode (
new Factor (neighs, nodes[i]->getDistribution()));
addFactor (fn);
for (unsigned j = 0; j < neighs.size(); j++) {
addEdge (fn, static_cast<FgVarNode*> (neighs[j]));
}
}
}
setIndexes();
}
void
FactorGraph::readFromUaiFormat (const char* fileName)
{
ifstream is (fileName);
if (!is.is_open()) {
cerr << "error: cannot read from file " + std::string (fileName) << endl;
abort();
}
string line;
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
getline (is, line);
if (line != "MARKOV") {
cerr << "error: the network must be a MARKOV network " << endl;
abort();
}
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
2011-12-12 15:29:51 +00:00
unsigned nVars;
is >> nVars;
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
vector<int> domainSizes (nVars);
2011-12-12 15:29:51 +00:00
for (unsigned i = 0; i < nVars; i++) {
unsigned ds;
is >> ds;
domainSizes[i] = ds;
}
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
2011-12-12 15:29:51 +00:00
for (unsigned i = 0; i < nVars; i++) {
addVariable (new FgVarNode (i, domainSizes[i]));
}
2011-12-12 15:29:51 +00:00
unsigned nFactors;
is >> nFactors;
2011-12-12 15:29:51 +00:00
for (unsigned i = 0; i < nFactors; i++) {
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
2011-12-12 15:29:51 +00:00
unsigned nFactorVars;
is >> nFactorVars;
2011-12-12 15:29:51 +00:00
VarNodes neighs;
for (unsigned j = 0; j < nFactorVars; j++) {
unsigned vid;
is >> vid;
2011-12-12 15:29:51 +00:00
FgVarNode* neigh = getFgVarNode (vid);
if (!neigh) {
cerr << "error: invalid variable identifier (" << vid << ")" << endl;
abort();
}
2011-12-12 15:29:51 +00:00
neighs.push_back (neigh);
}
2011-12-12 15:29:51 +00:00
FgFacNode* fn = new FgFacNode (new Factor (neighs));
addFactor (fn);
for (unsigned j = 0; j < neighs.size(); j++) {
addEdge (fn, static_cast<FgVarNode*> (neighs[j]));
}
}
2011-12-12 15:29:51 +00:00
for (unsigned i = 0; i < nFactors; i++) {
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
2011-12-12 15:29:51 +00:00
unsigned nParams;
is >> nParams;
2011-12-12 15:29:51 +00:00
if (facNodes_[i]->getParameters().size() != nParams) {
cerr << "error: invalid number of parameters for factor " ;
cerr << facNodes_[i]->getLabel() ;
cerr << ", expected: " << facNodes_[i]->getParameters().size();
cerr << ", given: " << nParams << endl;
abort();
}
ParamSet params (nParams);
2011-12-12 15:29:51 +00:00
for (unsigned j = 0; j < nParams; j++) {
double param;
is >> param;
params[j] = param;
}
2011-12-12 15:29:51 +00:00
if (NSPACE == NumberSpace::LOGARITHM) {
Util::toLog (params);
}
facNodes_[i]->factor()->setParameters (params);
}
is.close();
2011-12-12 15:29:51 +00:00
setIndexes();
}
2011-12-12 15:29:51 +00:00
void
FactorGraph::readFromLibDaiFormat (const char* fileName)
{
2011-12-12 15:29:51 +00:00
ifstream is (fileName);
if (!is.is_open()) {
cerr << "error: cannot read from file " + std::string (fileName) << endl;
abort();
}
2011-12-12 15:29:51 +00:00
string line;
unsigned nFactors;
while ((is.peek()) == '#') getline (is, line);
is >> nFactors;
if (is.fail()) {
cerr << "error: cannot read the number of factors" << endl;
abort();
}
getline (is, line);
if (is.fail() || line.size() > 0) {
cerr << "error: cannot read the number of factors" << endl;
abort();
}
for (unsigned i = 0; i < nFactors; i++) {
unsigned nVars;
while ((is.peek()) == '#') getline (is, line);
is >> nVars;
VarIdSet vids;
for (unsigned j = 0; j < nVars; j++) {
VarId vid;
while ((is.peek()) == '#') getline (is, line);
is >> vid;
vids.push_back (vid);
}
VarNodes neighs;
unsigned nParams = 1;
for (unsigned j = 0; j < nVars; j++) {
unsigned dsize;
while ((is.peek()) == '#') getline (is, line);
is >> dsize;
FgVarNode* var = getFgVarNode (vids[j]);
if (var == 0) {
var = new FgVarNode (vids[j], dsize);
addVariable (var);
} else {
if (var->nrStates() != dsize) {
cerr << "error: variable `" << vids[j] << "' appears in two or " ;
cerr << "more factors with different domain sizes" << endl;
}
}
2011-12-12 15:29:51 +00:00
neighs.push_back (var);
nParams *= var->nrStates();
}
ParamSet params (nParams, 0);
unsigned nNonzeros;
while ((is.peek()) == '#')
getline (is, line);
is >> nNonzeros;
for (unsigned j = 0; j < nNonzeros; j++) {
unsigned index;
Param val;
while ((is.peek()) == '#') getline (is, line);
is >> index;
while ((is.peek()) == '#') getline (is, line);
is >> val;
params[index] = val;
}
reverse (neighs.begin(), neighs.end());
if (NSPACE == NumberSpace::LOGARITHM) {
Util::toLog (params);
}
FgFacNode* fn = new FgFacNode (new Factor (neighs, params));
addFactor (fn);
for (unsigned j = 0; j < neighs.size(); j++) {
addEdge (fn, static_cast<FgVarNode*> (neighs[j]));
}
}
2011-12-12 15:29:51 +00:00
is.close();
setIndexes();
}
FactorGraph::~FactorGraph (void)
{
for (unsigned i = 0; i < varNodes_.size(); i++) {
delete varNodes_[i];
}
2011-12-12 15:29:51 +00:00
for (unsigned i = 0; i < facNodes_.size(); i++) {
delete facNodes_[i];
}
}
void
2011-12-12 15:29:51 +00:00
FactorGraph::addVariable (FgVarNode* vn)
{
2011-12-12 15:29:51 +00:00
varNodes_.push_back (vn);
vn->setIndex (varNodes_.size() - 1);
indexMap_.insert (make_pair (vn->varId(), varNodes_.size() - 1));
}
void
2011-12-12 15:29:51 +00:00
FactorGraph::addFactor (FgFacNode* fn)
{
2011-12-12 15:29:51 +00:00
facNodes_.push_back (fn);
fn->setIndex (facNodes_.size() - 1);
}
void
2011-12-12 15:29:51 +00:00
FactorGraph::addEdge (FgVarNode* vn, FgFacNode* fn)
{
2011-12-12 15:29:51 +00:00
vn->addNeighbor (fn);
fn->addNeighbor (vn);
}
void
2011-12-12 15:29:51 +00:00
FactorGraph::addEdge (FgFacNode* fn, FgVarNode* vn)
{
2011-12-12 15:29:51 +00:00
fn->addNeighbor (vn);
vn->addNeighbor (fn);
}
VarNode*
FactorGraph::getVariableNode (VarId vid) const
{
FgVarNode* vn = getFgVarNode (vid);
assert (vn);
return vn;
}
2011-12-12 15:29:51 +00:00
VarNodes
FactorGraph::getVariableNodes (void) const
{
2011-12-12 15:29:51 +00:00
VarNodes vars;
for (unsigned i = 0; i < varNodes_.size(); i++) {
vars.push_back (varNodes_[i]);
}
return vars;
}
2011-12-12 15:29:51 +00:00
bool
FactorGraph::isTree (void) const
{
2011-12-12 15:29:51 +00:00
return !containsCycle();
}
void
FactorGraph::setIndexes (void)
{
for (unsigned i = 0; i < varNodes_.size(); i++) {
varNodes_[i]->setIndex (i);
}
2011-12-12 15:29:51 +00:00
for (unsigned i = 0; i < facNodes_.size(); i++) {
facNodes_[i]->setIndex (i);
}
}
void
FactorGraph::freeDistributions (void)
{
set<Distribution*> dists;
2011-12-12 15:29:51 +00:00
for (unsigned i = 0; i < facNodes_.size(); i++) {
dists.insert (facNodes_[i]->factor()->getDistribution());
}
for (set<Distribution*>::iterator it = dists.begin();
it != dists.end(); it++) {
delete *it;
}
}
void
FactorGraph::printGraphicalModel (void) const
{
for (unsigned i = 0; i < varNodes_.size(); i++) {
2011-12-12 15:29:51 +00:00
cout << "VarId = " << varNodes_[i]->varId() << endl;
cout << "Label = " << varNodes_[i]->label() << endl;
cout << "Nr States = " << varNodes_[i]->nrStates() << endl;
cout << "Evidence = " << varNodes_[i]->getEvidence() << endl;
cout << "Factors = " ;
2011-12-12 15:29:51 +00:00
for (unsigned j = 0; j < varNodes_[i]->neighbors().size(); j++) {
cout << varNodes_[i]->neighbors()[j]->getLabel() << " " ;
}
cout << endl << endl;
}
2011-12-12 15:29:51 +00:00
for (unsigned i = 0; i < facNodes_.size(); i++) {
facNodes_[i]->factor()->printFactor();
cout << endl;
}
}
void
2011-12-12 15:29:51 +00:00
FactorGraph::exportToGraphViz (const char* fileName) const
{
ofstream out (fileName);
if (!out.is_open()) {
cerr << "error: cannot open file to write at " ;
cerr << "FactorGraph::exportToDotFile()" << endl;
abort();
}
out << "graph \"" << fileName << "\" {" << endl;
for (unsigned i = 0; i < varNodes_.size(); i++) {
if (varNodes_[i]->hasEvidence()) {
2011-12-12 15:29:51 +00:00
out << '"' << varNodes_[i]->label() << '"' ;
out << " [style=filled, fillcolor=yellow]" << endl;
}
}
2011-12-12 15:29:51 +00:00
for (unsigned i = 0; i < facNodes_.size(); i++) {
out << '"' << facNodes_[i]->getLabel() << '"' ;
out << " [label=\"" << facNodes_[i]->getLabel();
out << "\"" << ", shape=box]" << endl;
}
2011-12-12 15:29:51 +00:00
for (unsigned i = 0; i < facNodes_.size(); i++) {
const FgVarSet& myVars = facNodes_[i]->neighbors();
for (unsigned j = 0; j < myVars.size(); j++) {
2011-12-12 15:29:51 +00:00
out << '"' << facNodes_[i]->getLabel() << '"' ;
out << " -- " ;
2011-12-12 15:29:51 +00:00
out << '"' << myVars[j]->label() << '"' << endl;
}
}
out << "}" << endl;
out.close();
}
void
FactorGraph::exportToUaiFormat (const char* fileName) const
{
ofstream out (fileName);
if (!out.is_open()) {
cerr << "error: cannot open file to write at " ;
cerr << "FactorGraph::exportToUaiFormat()" << endl;
abort();
}
out << "MARKOV" << endl;
out << varNodes_.size() << endl;
for (unsigned i = 0; i < varNodes_.size(); i++) {
2011-12-12 15:29:51 +00:00
out << varNodes_[i]->nrStates() << " " ;
}
out << endl;
2011-12-12 15:29:51 +00:00
out << facNodes_.size() << endl;
for (unsigned i = 0; i < facNodes_.size(); i++) {
const FgVarSet& factorVars = facNodes_[i]->neighbors();
out << factorVars.size();
for (unsigned j = 0; j < factorVars.size(); j++) {
out << " " << factorVars[j]->getIndex();
}
out << endl;
}
2011-12-12 15:29:51 +00:00
for (unsigned i = 0; i < facNodes_.size(); i++) {
const ParamSet& params = facNodes_[i]->getParameters();
out << endl << params.size() << endl << " " ;
for (unsigned j = 0; j < params.size(); j++) {
out << params[j] << " " ;
}
out << endl;
}
out.close();
}
2011-12-12 15:29:51 +00:00
void
FactorGraph::exportToLibDaiFormat (const char* fileName) const
{
ofstream out (fileName);
if (!out.is_open()) {
cerr << "error: cannot open file to write at " ;
cerr << "FactorGraph::exportToLibDaiFormat()" << endl;
abort();
}
out << facNodes_.size() << endl << endl;
for (unsigned i = 0; i < facNodes_.size(); i++) {
const FgVarSet& factorVars = facNodes_[i]->neighbors();
out << factorVars.size() << endl;
for (int j = factorVars.size() - 1; j >= 0; j--) {
out << factorVars[j]->varId() << " " ;
}
out << endl;
for (unsigned j = 0; j < factorVars.size(); j++) {
out << factorVars[j]->nrStates() << " " ;
}
out << endl;
const ParamSet& params = facNodes_[i]->factor()->getParameters();
out << params.size() << endl;
for (unsigned j = 0; j < params.size(); j++) {
out << j << " " << params[j] << endl;
}
out << endl;
}
out.close();
}
bool
FactorGraph::containsCycle (void) const
{
vector<bool> visitedVars (varNodes_.size(), false);
vector<bool> visitedFactors (facNodes_.size(), false);
for (unsigned i = 0; i < varNodes_.size(); i++) {
int v = varNodes_[i]->getIndex();
if (!visitedVars[v]) {
if (containsCycle (varNodes_[i], 0, visitedVars, visitedFactors)) {
return true;
}
}
}
return false;
}
bool
FactorGraph::containsCycle (const FgVarNode* v,
const FgFacNode* p,
vector<bool>& visitedVars,
vector<bool>& visitedFactors) const
{
visitedVars[v->getIndex()] = true;
const FgFacSet& adjacencies = v->neighbors();
for (unsigned i = 0; i < adjacencies.size(); i++) {
int w = adjacencies[i]->getIndex();
if (!visitedFactors[w]) {
if (containsCycle (adjacencies[i], v, visitedVars, visitedFactors)) {
return true;
}
}
else if (visitedFactors[w] && adjacencies[i] != p) {
return true;
}
}
return false; // no cycle detected in this component
}
bool
FactorGraph::containsCycle (const FgFacNode* v,
const FgVarNode* p,
vector<bool>& visitedVars,
vector<bool>& visitedFactors) const
{
visitedFactors[v->getIndex()] = true;
const FgVarSet& adjacencies = v->neighbors();
for (unsigned i = 0; i < adjacencies.size(); i++) {
int w = adjacencies[i]->getIndex();
if (!visitedVars[w]) {
if (containsCycle (adjacencies[i], v, visitedVars, visitedFactors)) {
return true;
}
}
else if (visitedVars[w] && adjacencies[i] != p) {
return true;
}
}
return false; // no cycle detected in this component
}