This repository has been archived on 2023-08-20. You can view files and clone it, but cannot push or open issues or pull requests.
yap-6.3/packages/CLPBN/clpbn/bp/FactorGraph.cpp

468 lines
11 KiB
C++
Raw Normal View History

#include <set>
2011-12-12 15:29:51 +00:00
#include <vector>
#include <algorithm>
#include <iostream>
#include <fstream>
#include <sstream>
#include "FactorGraph.h"
#include "Factor.h"
#include "BayesNet.h"
2012-04-05 18:38:56 +01:00
#include "BayesBall.h"
2012-03-22 11:33:24 +00:00
#include "Util.h"
// Definition of the class-wide flag, off by default.
// NOTE(review): nothing in the visible part of this file reads it; the file
// readers only carry a TODO about ordering variables, so presumably that is
// what it is meant to control -- confirm against the header and callers.
bool FactorGraph::orderVariables = false;
2012-03-22 11:33:24 +00:00
// Copy constructor: rebuilds the structure of `fg` out of freshly allocated
// nodes. A new VarNode is constructed from every variable node of `fg`, a
// new FacNode from the factor of every factor node, and each new factor is
// linked to this graph's own variable nodes (located through the index of
// the corresponding neighbour in `fg`).
FactorGraph::FactorGraph (const FactorGraph& fg)
{
  // variables first, so that the factors below have something to link to
  const VarNodes& srcVars = fg.varNodes();
  for (unsigned v = 0; v < srcVars.size(); v++) {
    addVarNode (new VarNode (srcVars[v]));
  }

  const FacNodes& srcFacs = fg.facNodes();
  for (unsigned f = 0; f < srcFacs.size(); f++) {
    FacNode* copy = new FacNode (srcFacs[f]->factor());
    addFacNode (copy);
    const VarNodes& srcNeighs = srcFacs[f]->neighbors();
    for (unsigned k = 0; k < srcNeighs.size(); k++) {
      // the neighbour's index in `fg` is used as a position in varNodes_
      VarNode* ownVar = varNodes_[srcNeighs[k]->getIndex()];
      addEdge (ownVar, copy);
    }
  }

  fromBayesNet_ = fg.isFromBayesNetwork();
}
2011-12-12 15:29:51 +00:00
void
FactorGraph::readFromUaiFormat (const char* fileName)
{
2012-04-10 11:51:56 +01:00
std::ifstream is (fileName);
if (!is.is_open()) {
2012-04-10 11:51:56 +01:00
cerr << "error: cannot read from file " << fileName << endl;
abort();
}
ignoreLines (is);
string line;
getline (is, line);
if (line != "MARKOV") {
cerr << "error: the network must be a MARKOV network " << endl;
abort();
}
// read the number of vars
ignoreLines (is);
unsigned nrVars;
is >> nrVars;
// read the range of each var
ignoreLines (is);
Ranges ranges (nrVars);
for (unsigned i = 0; i < nrVars; i++) {
is >> ranges[i];
}
unsigned nrFactors;
unsigned nrArgs;
unsigned vid;
is >> nrFactors;
vector<VarIds> factorVarIds;
vector<Ranges> factorRanges;
for (unsigned i = 0; i < nrFactors; i++) {
ignoreLines (is);
is >> nrArgs;
factorVarIds.push_back ({ });
factorRanges.push_back ({ });
for (unsigned j = 0; j < nrArgs; j++) {
is >> vid;
if (vid >= ranges.size()) {
cerr << "error: invalid variable identifier `" << vid << "'" << endl;
cerr << "identifiers must be between 0 and " << ranges.size() - 1 ;
cerr << endl;
abort();
}
factorVarIds.back().push_back (vid);
factorRanges.back().push_back (ranges[vid]);
}
}
// read the parameters
unsigned nrParams;
for (unsigned i = 0; i < nrFactors; i++) {
ignoreLines (is);
is >> nrParams;
if (nrParams != Util::expectedSize (factorRanges[i])) {
cerr << "error: invalid number of parameters for factor nº " << i ;
cerr << ", expected: " << Util::expectedSize (factorRanges[i]);
cerr << ", given: " << nrParams << endl;
2011-12-12 15:29:51 +00:00
abort();
}
Params params (nrParams);
for (unsigned j = 0; j < nrParams; j++) {
is >> params[j];
}
2012-03-22 11:33:24 +00:00
if (Globals::logDomain) {
2011-12-12 15:29:51 +00:00
Util::toLog (params);
}
// TODO order vars is flag on
addFactor (Factor (factorVarIds[i], factorRanges[i], params));
}
is.close();
}
2011-12-12 15:29:51 +00:00
void
FactorGraph::readFromLibDaiFormat (const char* fileName)
{
2012-04-10 11:51:56 +01:00
std::ifstream is (fileName);
2011-12-12 15:29:51 +00:00
if (!is.is_open()) {
2012-04-10 11:51:56 +01:00
cerr << "error: cannot read from file " << fileName << endl;
2011-12-12 15:29:51 +00:00
abort();
}
ignoreLines (is);
unsigned nrFactors;
unsigned nrArgs;
VarId vid;
is >> nrFactors;
for (unsigned i = 0; i < nrFactors; i++) {
ignoreLines (is);
// read the factor arguments
is >> nrArgs;
2012-03-22 11:33:24 +00:00
VarIds vids;
for (unsigned j = 0; j < nrArgs; j++) {
ignoreLines (is);
2011-12-12 15:29:51 +00:00
is >> vid;
vids.push_back (vid);
}
// read ranges
Ranges ranges (nrArgs);
for (unsigned j = 0; j < nrArgs; j++) {
ignoreLines (is);
is >> ranges[j];
2012-04-05 23:00:48 +01:00
VarNode* var = getVarNode (vids[j]);
if (var != 0 && ranges[j] != var->range()) {
cerr << "error: variable `" << vids[j] << "' appears in two or " ;
cerr << "more factors with a different range" << endl;
}
2011-12-12 15:29:51 +00:00
}
// read parameters
ignoreLines (is);
2011-12-12 15:29:51 +00:00
unsigned nNonzeros;
is >> nNonzeros;
Params params (Util::expectedSize (ranges), 0);
2011-12-12 15:29:51 +00:00
for (unsigned j = 0; j < nNonzeros; j++) {
ignoreLines (is);
2011-12-12 15:29:51 +00:00
unsigned index;
is >> index;
ignoreLines (is);
double val;
2011-12-12 15:29:51 +00:00
is >> val;
params[index] = val;
}
reverse (vids.begin(), vids.end());
2012-03-22 11:33:24 +00:00
if (Globals::logDomain) {
2011-12-12 15:29:51 +00:00
Util::toLog (params);
}
// TODO order vars is flag on
addFactor (Factor (vids, ranges, params));
}
2011-12-12 15:29:51 +00:00
is.close();
}
// Destructor: deletes every variable node and every factor node held in
// varNodes_ and facNodes_.
FactorGraph::~FactorGraph (void)
{
  for (VarNode* vn : varNodes_) {
    delete vn;
  }
  for (FacNode* fn : facNodes_) {
    delete fn;
  }
}
2012-04-05 18:38:56 +01:00
void
FactorGraph::addFactor (const Factor& factor)
{
2012-04-10 11:51:56 +01:00
FacNode* fn = new FacNode (factor);
addFacNode (fn);
2012-04-05 18:38:56 +01:00
const VarIds& vids = factor.arguments();
for (unsigned i = 0; i < vids.size(); i++) {
VarMap::const_iterator it = varMap_.find (vids[i]);
if (it != varMap_.end()) {
addEdge (it->second, fn);
} else {
2012-04-05 23:00:48 +01:00
VarNode* vn = new VarNode (vids[i], factor.range (i));
addVarNode (vn);
2012-04-05 18:38:56 +01:00
addEdge (vn, fn);
}
}
}
// Appends `vn` to varNodes_, stores its position there as its index and
// registers it in varMap_ under its variable id (insert() keeps an entry
// that is already present for that id).
void
FactorGraph::addVarNode (VarNode* vn)
{
  // the size before the push is the position the node is about to take
  vn->setIndex (varNodes_.size());
  varNodes_.push_back (vn);
  varMap_.insert (VarMap::value_type (vn->varId(), vn));
}
void
2012-04-10 11:51:56 +01:00
FactorGraph::addFacNode (FacNode* fn)
{
facNodes_.push_back (fn);
fn->setIndex (facNodes_.size() - 1);
}
void
2012-04-10 11:51:56 +01:00
FactorGraph::addEdge (VarNode* vn, FacNode* fn)
{
2011-12-12 15:29:51 +00:00
vn->addNeighbor (fn);
fn->addNeighbor (vn);
}
2011-12-12 15:29:51 +00:00
bool
FactorGraph::isTree (void) const
{
2011-12-12 15:29:51 +00:00
return !containsCycle();
}
2012-04-05 18:38:56 +01:00
// Returns structure_, filling it on the first call (that is, while it is
// still empty): one DAGraphNode per variable node, then, for every factor,
// an addEdge() call from each argument after the first to the first
// argument. Asserts that the graph was built from a Bayesian network.
DAGraph&
FactorGraph::getStructure (void)
{
  assert (fromBayesNet_);
  if (!structure_.empty()) {
    return structure_;   // already built
  }
  for (unsigned v = 0; v < varNodes_.size(); v++) {
    structure_.addNode (new DAGraphNode (varNodes_[v]));
  }
  for (unsigned f = 0; f < facNodes_.size(); f++) {
    const VarIds& args = facNodes_[f]->factor().arguments();
    for (unsigned k = 1; k < args.size(); k++) {
      structure_.addEdge (args[k], args[0]);
    }
  }
  return structure_;
}
void
2012-04-05 23:00:48 +01:00
FactorGraph::print (void) const
{
for (unsigned i = 0; i < varNodes_.size(); i++) {
2012-04-05 18:38:56 +01:00
cout << "var id = " << varNodes_[i]->varId() << endl;
cout << "label = " << varNodes_[i]->label() << endl;
cout << "range = " << varNodes_[i]->range() << endl;
cout << "evidence = " << varNodes_[i]->getEvidence() << endl;
cout << "factors = " ;
2011-12-12 15:29:51 +00:00
for (unsigned j = 0; j < varNodes_[i]->neighbors().size(); j++) {
cout << varNodes_[i]->neighbors()[j]->getLabel() << " " ;
}
cout << endl << endl;
}
2011-12-12 15:29:51 +00:00
for (unsigned i = 0; i < facNodes_.size(); i++) {
facNodes_[i]->factor().print();
}
}
void
2011-12-12 15:29:51 +00:00
FactorGraph::exportToGraphViz (const char* fileName) const
{
ofstream out (fileName);
if (!out.is_open()) {
cerr << "error: cannot open file to write at " ;
cerr << "FactorGraph::exportToDotFile()" << endl;
abort();
}
out << "graph \"" << fileName << "\" {" << endl;
for (unsigned i = 0; i < varNodes_.size(); i++) {
if (varNodes_[i]->hasEvidence()) {
2011-12-12 15:29:51 +00:00
out << '"' << varNodes_[i]->label() << '"' ;
out << " [style=filled, fillcolor=yellow]" << endl;
}
}
2011-12-12 15:29:51 +00:00
for (unsigned i = 0; i < facNodes_.size(); i++) {
out << '"' << facNodes_[i]->getLabel() << '"' ;
out << " [label=\"" << facNodes_[i]->getLabel();
out << "\"" << ", shape=box]" << endl;
}
2011-12-12 15:29:51 +00:00
for (unsigned i = 0; i < facNodes_.size(); i++) {
2012-04-05 23:00:48 +01:00
const VarNodes& myVars = facNodes_[i]->neighbors();
for (unsigned j = 0; j < myVars.size(); j++) {
2011-12-12 15:29:51 +00:00
out << '"' << facNodes_[i]->getLabel() << '"' ;
out << " -- " ;
2011-12-12 15:29:51 +00:00
out << '"' << myVars[j]->label() << '"' << endl;
}
}
out << "}" << endl;
out.close();
}
void
FactorGraph::exportToUaiFormat (const char* fileName) const
{
ofstream out (fileName);
if (!out.is_open()) {
2012-04-10 11:51:56 +01:00
cerr << "error: cannot open file " << fileName << endl;
abort();
}
out << "MARKOV" << endl;
out << varNodes_.size() << endl;
for (unsigned i = 0; i < varNodes_.size(); i++) {
2012-04-05 18:38:56 +01:00
out << varNodes_[i]->range() << " " ;
}
out << endl;
2011-12-12 15:29:51 +00:00
out << facNodes_.size() << endl;
for (unsigned i = 0; i < facNodes_.size(); i++) {
2012-04-05 23:00:48 +01:00
const VarNodes& factorVars = facNodes_[i]->neighbors();
out << factorVars.size();
for (unsigned j = 0; j < factorVars.size(); j++) {
out << " " << factorVars[j]->getIndex();
}
out << endl;
}
2011-12-12 15:29:51 +00:00
for (unsigned i = 0; i < facNodes_.size(); i++) {
2012-04-10 12:53:52 +01:00
Params params = facNodes_[i]->factor().params();
2012-03-22 11:33:24 +00:00
if (Globals::logDomain) {
Util::fromLog (params);
}
out << endl << params.size() << endl << " " ;
for (unsigned j = 0; j < params.size(); j++) {
out << params[j] << " " ;
}
out << endl;
}
out.close();
}
2011-12-12 15:29:51 +00:00
void
FactorGraph::exportToLibDaiFormat (const char* fileName) const
{
ofstream out (fileName);
if (!out.is_open()) {
2012-04-10 11:51:56 +01:00
cerr << "error: cannot open file " << fileName << endl;
2011-12-12 15:29:51 +00:00
abort();
}
out << facNodes_.size() << endl << endl;
for (unsigned i = 0; i < facNodes_.size(); i++) {
2012-04-05 23:00:48 +01:00
const VarNodes& factorVars = facNodes_[i]->neighbors();
2011-12-12 15:29:51 +00:00
out << factorVars.size() << endl;
for (int j = factorVars.size() - 1; j >= 0; j--) {
out << factorVars[j]->varId() << " " ;
}
out << endl;
for (unsigned j = 0; j < factorVars.size(); j++) {
2012-04-05 18:38:56 +01:00
out << factorVars[j]->range() << " " ;
2011-12-12 15:29:51 +00:00
}
out << endl;
Params params = facNodes_[i]->factor().params();
2012-03-22 11:33:24 +00:00
if (Globals::logDomain) {
Util::fromLog (params);
}
2011-12-12 15:29:51 +00:00
out << params.size() << endl;
for (unsigned j = 0; j < params.size(); j++) {
out << j << " " << params[j] << endl;
}
out << endl;
}
out.close();
}
void
FactorGraph::ignoreLines (std::ifstream& is) const
{
string ignoreStr;
while (is.peek() == '#' || is.peek() == '\n') {
getline (is, ignoreStr);
}
}
2011-12-12 15:29:51 +00:00
bool
FactorGraph::containsCycle (void) const
{
vector<bool> visitedVars (varNodes_.size(), false);
vector<bool> visitedFactors (facNodes_.size(), false);
for (unsigned i = 0; i < varNodes_.size(); i++) {
int v = varNodes_[i]->getIndex();
if (!visitedVars[v]) {
if (containsCycle (varNodes_[i], 0, visitedVars, visitedFactors)) {
return true;
}
}
}
return false;
}
bool
2012-04-05 18:38:56 +01:00
FactorGraph::containsCycle (
2012-04-05 23:00:48 +01:00
const VarNode* v,
2012-04-10 11:51:56 +01:00
const FacNode* p,
2012-04-05 18:38:56 +01:00
vector<bool>& visitedVars,
vector<bool>& visitedFactors) const
2011-12-12 15:29:51 +00:00
{
visitedVars[v->getIndex()] = true;
2012-04-10 11:51:56 +01:00
const FacNodes& adjacencies = v->neighbors();
2011-12-12 15:29:51 +00:00
for (unsigned i = 0; i < adjacencies.size(); i++) {
int w = adjacencies[i]->getIndex();
if (!visitedFactors[w]) {
if (containsCycle (adjacencies[i], v, visitedVars, visitedFactors)) {
return true;
}
}
else if (visitedFactors[w] && adjacencies[i] != p) {
return true;
}
}
return false; // no cycle detected in this component
}
bool
2012-04-05 18:38:56 +01:00
FactorGraph::containsCycle (
2012-04-10 11:51:56 +01:00
const FacNode* v,
2012-04-05 23:00:48 +01:00
const VarNode* p,
2012-04-05 18:38:56 +01:00
vector<bool>& visitedVars,
vector<bool>& visitedFactors) const
2011-12-12 15:29:51 +00:00
{
visitedFactors[v->getIndex()] = true;
2012-04-05 23:00:48 +01:00
const VarNodes& adjacencies = v->neighbors();
2011-12-12 15:29:51 +00:00
for (unsigned i = 0; i < adjacencies.size(); i++) {
int w = adjacencies[i]->getIndex();
if (!visitedVars[w]) {
if (containsCycle (adjacencies[i], v, visitedVars, visitedFactors)) {
return true;
}
}
else if (visitedVars[w] && adjacencies[i] != p) {
return true;
}
}
return false; // no cycle detected in this component
}