This repository has been archived on 2023-08-20. You can view files and clone it, but cannot push or open issues or pull requests.
yap-6.3/packages/CLPBN/horus/FactorGraph.cpp

492 lines
12 KiB
C++
Raw Normal View History

2013-02-07 20:09:10 +00:00
#include <cassert>
#include <cstdlib>

#include <algorithm>
#include <fstream>
#include <iostream>
#include <map>
#include <string>
#include <vector>

#include "FactorGraph.h"
#include "BayesBall.h"
#include "Util.h"
namespace Horus {
2013-02-07 23:53:13 +00:00
// Class-wide (static) switches shared by every FactorGraph; all of them
// start disabled. Judging by their names they toggle the libDAI / UAI /
// GraphViz exports and printing of the graph — confirm against the callers.
bool FactorGraph::exportLd_  {false};
bool FactorGraph::exportUai_ {false};
bool FactorGraph::exportGv_  {false};
bool FactorGraph::printFg_   {false};
2012-05-23 14:56:01 +01:00
/// Deep copy: clones every variable node and every factor node of @p fg and
/// rebuilds the variable/factor adjacency between the clones.
FactorGraph::FactorGraph (const FactorGraph& fg)
{
  // Variables are cloned first and appended in order, so the clone of the
  // variable at position k of fg ends up at position k of varNodes_.
  const VarNodes& srcVars = fg.varNodes();
  for (size_t v = 0; v < srcVars.size(); ++v) {
    addVarNode (new VarNode (srcVars[v]));
  }
  const FacNodes& srcFacs = fg.facNodes();
  for (size_t fi = 0; fi < srcFacs.size(); ++fi) {
    FacNode* clone = new FacNode (srcFacs[fi]->factor());
    addFacNode (clone);
    const VarNodes& srcNeighs = srcFacs[fi]->neighbors();
    for (size_t n = 0; n < srcNeighs.size(); ++n) {
      // A neighbour's index in fg is also the position of its clone here.
      VarNode* cloneVar = varNodes_[srcNeighs[n]->getIndex()];
      addEdge (cloneVar, clone);
    }
  }
  bayesFactors_ = fg.bayesianFactors();
}
/// Deletes every variable node and then every factor node owned by the graph.
FactorGraph::~FactorGraph()
{
  for (auto node : varNodes_) {
    delete node;
  }
  for (auto node : facNodes_) {
    delete node;
  }
}
2012-05-23 14:56:01 +01:00
void
FactorGraph::readFromUaiFormat (const char* fileName)
{
std::ifstream is (fileName);
if (!is.is_open()) {
2013-02-07 13:37:15 +00:00
std::cerr << "Error: couldn't open file '" << fileName << "'." ;
std::cerr << std::endl;
exit (EXIT_FAILURE);
2012-05-23 14:56:01 +01:00
}
ignoreLines (is);
2013-02-07 13:37:15 +00:00
std::string line;
2012-05-23 14:56:01 +01:00
getline (is, line);
if (line == "BAYES") {
bayesFactors_ = true;
} else if (line == "MARKOV") {
bayesFactors_ = false;
} else {
2013-02-07 13:37:15 +00:00
std::cerr << "Error: the type of network is missing." << std::endl;
exit (EXIT_FAILURE);
2012-05-23 14:56:01 +01:00
}
// read the number of vars
ignoreLines (is);
unsigned nrVars;
is >> nrVars;
// read the range of each var
ignoreLines (is);
Ranges ranges (nrVars);
for (unsigned i = 0; i < nrVars; i++) {
is >> ranges[i];
}
unsigned nrFactors;
unsigned nrArgs;
unsigned vid;
is >> nrFactors;
2013-02-07 13:37:15 +00:00
std::vector<VarIds> allVarIds;
std::vector<Ranges> allRanges;
2012-05-23 14:56:01 +01:00
for (unsigned i = 0; i < nrFactors; i++) {
ignoreLines (is);
is >> nrArgs;
allVarIds.push_back ({ });
allRanges.push_back ({ });
2012-05-23 14:56:01 +01:00
for (unsigned j = 0; j < nrArgs; j++) {
is >> vid;
if (vid >= ranges.size()) {
2013-02-07 13:37:15 +00:00
std::cerr << "Error: invalid variable identifier `" << vid << "'. " ;
std::cerr << "Identifiers must be between 0 and " ;
std::cerr << ranges.size() - 1 ;
std::cerr << "." << std::endl;
exit (EXIT_FAILURE);
2012-05-23 14:56:01 +01:00
}
allVarIds.back().push_back (vid);
allRanges.back().push_back (ranges[vid]);
2012-05-23 14:56:01 +01:00
}
}
// read the parameters
unsigned nrParams;
for (unsigned i = 0; i < nrFactors; i++) {
ignoreLines (is);
is >> nrParams;
if (nrParams != Util::sizeExpected (allRanges[i])) {
2013-02-07 13:37:15 +00:00
std::cerr << "Error: invalid number of parameters for factor nº " << i ;
std::cerr << ", " << Util::sizeExpected (allRanges[i]);
2013-02-07 13:37:15 +00:00
std::cerr << " expected, " << nrParams << " given." << std::endl;
exit (EXIT_FAILURE);
2012-05-23 14:56:01 +01:00
}
Params params (nrParams);
for (unsigned j = 0; j < nrParams; j++) {
is >> params[j];
}
if (Globals::logDomain) {
Util::log (params);
2012-05-23 14:56:01 +01:00
}
Factor f (allVarIds[i], allRanges[i], params);
if (bayesFactors_ && allVarIds[i].size() > 1) {
// In this format the child is the last variable,
// move it to be the first
std::swap (allVarIds[i].front(), allVarIds[i].back());
f.reorderArguments (allVarIds[i]);
}
addFactor (f);
2012-05-23 14:56:01 +01:00
}
is.close();
}
void
FactorGraph::readFromLibDaiFormat (const char* fileName)
{
std::ifstream is (fileName);
if (!is.is_open()) {
2013-02-07 13:37:15 +00:00
std::cerr << "Error: couldn't open file '" << fileName << "'." ;
std::cerr << std::endl;
exit (EXIT_FAILURE);
2012-05-23 14:56:01 +01:00
}
ignoreLines (is);
unsigned nrFactors;
unsigned nrArgs;
VarId vid;
is >> nrFactors;
for (unsigned i = 0; i < nrFactors; i++) {
ignoreLines (is);
// read the factor arguments
is >> nrArgs;
VarIds vids;
for (unsigned j = 0; j < nrArgs; j++) {
ignoreLines (is);
is >> vid;
vids.push_back (vid);
}
// read ranges
Ranges ranges (nrArgs);
for (unsigned j = 0; j < nrArgs; j++) {
ignoreLines (is);
is >> ranges[j];
VarNode* var = getVarNode (vids[j]);
2012-12-27 12:54:58 +00:00
if (var && ranges[j] != var->range()) {
2013-02-07 13:37:15 +00:00
std::cerr << "Error: variable `" << vids[j] << "' appears in two or " ;
std::cerr << "more factors with a different range." << std::endl;
2012-05-23 14:56:01 +01:00
}
}
// read parameters
ignoreLines (is);
unsigned nNonzeros;
is >> nNonzeros;
Params params (Util::sizeExpected (ranges), 0);
2012-05-23 14:56:01 +01:00
for (unsigned j = 0; j < nNonzeros; j++) {
ignoreLines (is);
unsigned index;
is >> index;
ignoreLines (is);
double val;
is >> val;
params[index] = val;
}
if (Globals::logDomain) {
Util::log (params);
2012-05-23 14:56:01 +01:00
}
2012-05-28 14:42:20 +01:00
std::reverse (vids.begin(), vids.end());
2012-05-23 14:56:01 +01:00
Factor f (vids, ranges, params);
2012-05-28 14:42:20 +01:00
std::reverse (vids.begin(), vids.end());
2012-12-20 23:19:10 +00:00
f.reorderArguments (vids);
2012-05-23 14:56:01 +01:00
addFactor (f);
}
is.close();
}
void
FactorGraph::addFactor (const Factor& factor)
{
FacNode* fn = new FacNode (factor);
addFacNode (fn);
const VarIds& vids = fn->factor().arguments();
2012-05-24 22:55:20 +01:00
for (size_t i = 0; i < vids.size(); i++) {
2012-05-23 14:56:01 +01:00
VarMap::const_iterator it = varMap_.find (vids[i]);
if (it != varMap_.end()) {
2012-12-20 23:19:10 +00:00
addEdge (it->second, fn);
2012-05-23 14:56:01 +01:00
} else {
VarNode* vn = new VarNode (vids[i], fn->factor().range (i));
addVarNode (vn);
addEdge (vn, fn);
}
}
}
/// Takes ownership of @p vn: appends it to varNodes_, records its position
/// as its index and registers it in varMap_ under its variable id (an
/// existing entry for the same id is not replaced).
void
FactorGraph::addVarNode (VarNode* vn)
{
  const size_t position = varNodes_.size();
  varNodes_.push_back (vn);
  vn->setIndex (position);
  varMap_.insert (std::make_pair (vn->varId(), vn));
}
/// Takes ownership of @p fn: appends it to facNodes_ and records its
/// position as its index.
void
FactorGraph::addFacNode (FacNode* fn)
{
  const size_t position = facNodes_.size();
  facNodes_.push_back (fn);
  fn->setIndex (position);
}
// Connects variable node vn and factor node fn: each endpoint is registered
// as a neighbor of the other, so the adjacency can later be walked from
// either side (see the neighbors() uses in containsCycle()).
void
FactorGraph::addEdge (VarNode* vn, FacNode* fn)
{
vn->addNeighbor (fn);
fn->addNeighbor (vn);
}
bool
FactorGraph::isTree() const
2012-05-23 14:56:01 +01:00
{
return !containsCycle();
}
2012-06-19 14:32:12 +01:00
/// Returns the directed structure of a Bayesian-network factor graph,
/// building it on first use: one node per variable and, for every factor,
/// an edge from each of its other arguments to its first argument (the
/// child). Only meaningful when the graph holds Bayesian factors (asserted).
/// The structure is cached and is not rebuilt if the graph changes later.
BayesBallGraph&
FactorGraph::getStructure()
{
  assert (bayesFactors_);
  if (!structure_.empty()) {
    return structure_;
  }
  for (auto vn : varNodes_) {
    structure_.addNode (new BBNode (vn));
  }
  for (auto fn : facNodes_) {
    const VarIds& args = fn->factor().arguments();
    // args[0] is the child; every remaining argument is one of its parents.
    for (size_t k = 1; k < args.size(); ++k) {
      structure_.addEdge (args[k], args[0]);
    }
  }
  return structure_;
}
void
FactorGraph::print() const
2012-05-23 14:56:01 +01:00
{
2013-02-07 13:37:15 +00:00
using std::cout;
using std::endl;
2012-05-24 22:55:20 +01:00
for (size_t i = 0; i < varNodes_.size(); i++) {
2012-05-23 14:56:01 +01:00
cout << "var id = " << varNodes_[i]->varId() << endl;
cout << "label = " << varNodes_[i]->label() << endl;
cout << "range = " << varNodes_[i]->range() << endl;
cout << "evidence = " << varNodes_[i]->getEvidence() << endl;
cout << "factors = " ;
2012-05-24 22:55:20 +01:00
for (size_t j = 0; j < varNodes_[i]->neighbors().size(); j++) {
2012-05-23 14:56:01 +01:00
cout << varNodes_[i]->neighbors()[j]->getLabel() << " " ;
}
cout << endl << endl;
}
2012-05-24 22:55:20 +01:00
for (size_t i = 0; i < facNodes_.size(); i++) {
2012-05-23 14:56:01 +01:00
facNodes_[i]->factor().print();
}
}
void
FactorGraph::exportToLibDai (const char* fileName) const
2012-05-23 14:56:01 +01:00
{
2013-02-07 13:37:15 +00:00
std::ofstream out (fileName);
2012-05-23 14:56:01 +01:00
if (!out.is_open()) {
2013-02-07 13:37:15 +00:00
std::cerr << "Error: couldn't open file '" << fileName << "'." ;
std::cerr << std::endl;
return;
2012-05-23 14:56:01 +01:00
}
2013-02-07 13:37:15 +00:00
out << facNodes_.size() << std::endl << std::endl;
2012-05-24 22:55:20 +01:00
for (size_t i = 0; i < facNodes_.size(); i++) {
Factor f (facNodes_[i]->factor());
2013-02-07 13:37:15 +00:00
out << f.nrArguments() << std::endl;
out << Util::elementsToString (f.arguments()) << std::endl;
out << Util::elementsToString (f.ranges()) << std::endl;
VarIds args = f.arguments();
std::reverse (args.begin(), args.end());
f.reorderArguments (args);
if (Globals::logDomain) {
Util::exp (f.params());
}
2013-02-07 13:37:15 +00:00
out << f.size() << std::endl;
for (size_t j = 0; j < f.size(); j++) {
2013-02-07 13:37:15 +00:00
out << j << " " << f[j] << std::endl;
2012-05-23 14:56:01 +01:00
}
2013-02-07 13:37:15 +00:00
out << std::endl;
2012-05-23 14:56:01 +01:00
}
out.close();
}
void
FactorGraph::exportToUai (const char* fileName) const
2012-05-23 14:56:01 +01:00
{
2013-02-07 13:37:15 +00:00
std::ofstream out (fileName);
2012-05-23 14:56:01 +01:00
if (!out.is_open()) {
2013-02-07 13:37:15 +00:00
std::cerr << "Error: couldn't open file '" << fileName << "'." ;
std::cerr << std::endl;
return;
2012-05-23 14:56:01 +01:00
}
out << (bayesFactors_ ? "BAYES" : "MARKOV") ;
2013-02-07 13:37:15 +00:00
out << std::endl << std::endl;
out << varNodes_.size() << std::endl;
2012-05-28 17:00:46 +01:00
VarNodes sortedVns = varNodes_;
std::sort (sortedVns.begin(), sortedVns.end(), sortByVarId());
for (size_t i = 0; i < sortedVns.size(); i++) {
out << ((i != 0) ? " " : "") << sortedVns[i]->range();
2012-05-23 14:56:01 +01:00
}
2013-02-07 13:37:15 +00:00
out << std::endl << facNodes_.size() << std::endl;
2012-05-24 22:55:20 +01:00
for (size_t i = 0; i < facNodes_.size(); i++) {
2012-05-28 17:00:46 +01:00
VarIds args = facNodes_[i]->factor().arguments();
if (bayesFactors_) {
std::swap (args.front(), args.back());
}
out << args.size() << " " << Util::elementsToString (args);
2013-02-07 13:37:15 +00:00
out << std::endl;
2012-05-23 14:56:01 +01:00
}
2013-02-07 13:37:15 +00:00
out << std::endl;
2012-05-24 22:55:20 +01:00
for (size_t i = 0; i < facNodes_.size(); i++) {
Factor f = facNodes_[i]->factor();
if (bayesFactors_) {
VarIds args = f.arguments();
std::swap (args.front(), args.back());
f.reorderArguments (args);
}
Params params = f.params();
if (Globals::logDomain) {
Util::exp (params);
2012-05-23 14:56:01 +01:00
}
2013-02-07 13:37:15 +00:00
out << params.size() << std::endl << " " ;
out << Util::elementsToString (params);
2013-02-07 13:37:15 +00:00
out << std::endl << std::endl;
2012-05-23 14:56:01 +01:00
}
out.close();
}
void
FactorGraph::exportToGraphViz (const char* fileName) const
2012-05-23 14:56:01 +01:00
{
2013-02-07 13:37:15 +00:00
std::ofstream out (fileName);
2012-05-23 14:56:01 +01:00
if (!out.is_open()) {
2013-02-07 13:37:15 +00:00
std::cerr << "Error: couldn't open file '" << fileName << "'." ;
std::cerr << std::endl;
return;
2012-05-23 14:56:01 +01:00
}
2013-02-07 13:37:15 +00:00
out << "graph \"" << fileName << "\" {" << std::endl;
for (size_t i = 0; i < varNodes_.size(); i++) {
if (varNodes_[i]->hasEvidence()) {
out << '"' << varNodes_[i]->label() << '"' ;
2013-02-07 13:37:15 +00:00
out << " [style=filled, fillcolor=yellow]" << std::endl;
2012-05-23 14:56:01 +01:00
}
}
for (size_t i = 0; i < facNodes_.size(); i++) {
out << '"' << facNodes_[i]->getLabel() << '"' ;
out << " [label=\"" << facNodes_[i]->getLabel();
2013-02-07 13:37:15 +00:00
out << "\"" << ", shape=box]" << std::endl;
}
for (size_t i = 0; i < facNodes_.size(); i++) {
const VarNodes& myVars = facNodes_[i]->neighbors();
for (size_t j = 0; j < myVars.size(); j++) {
out << '"' << facNodes_[i]->getLabel() << '"' ;
out << " -- " ;
2013-02-07 13:37:15 +00:00
out << '"' << myVars[j]->label() << '"' << std::endl;
2012-05-23 14:56:01 +01:00
}
}
2013-02-07 13:37:15 +00:00
out << "}" << std::endl;
2012-05-23 14:56:01 +01:00
out.close();
}
/// Advances @p is past comment lines (starting with '#') and blank lines so
/// the next extraction reads actual data.
///
/// Spaces, tabs and '\r' in front of a line's first significant character
/// are consumed too. Without that, the '\r' left over by CRLF line endings
/// (or an indented comment) hid a following comment from the peek() test,
/// and the next numeric extraction then failed on '#'. Numeric extraction
/// skips that whitespace anyway, so data reads are unaffected.
/// Stops at end of file or when the stream is in a failed state (peek()
/// then returns EOF).
void
FactorGraph::ignoreLines (std::ifstream& is) const
{
  std::string ignoreStr;
  while (true) {
    const int c = is.peek();
    if (c == ' ' || c == '\t' || c == '\r') {
      is.get();
    } else if (c == '#' || c == '\n') {
      getline (is, ignoreStr);
    } else {
      break;
    }
  }
}
bool
FactorGraph::containsCycle() const
2012-05-23 14:56:01 +01:00
{
2013-02-07 13:37:15 +00:00
std::vector<bool> visitedVars (varNodes_.size(), false);
std::vector<bool> visitedFactors (facNodes_.size(), false);
2012-05-24 22:55:20 +01:00
for (size_t i = 0; i < varNodes_.size(); i++) {
2012-05-23 14:56:01 +01:00
int v = varNodes_[i]->getIndex();
if (!visitedVars[v]) {
if (containsCycle (varNodes_[i], 0, visitedVars, visitedFactors)) {
return true;
}
}
}
return false;
}
bool
FactorGraph::containsCycle (
const VarNode* v,
const FacNode* p,
2013-02-07 13:37:15 +00:00
std::vector<bool>& visitedVars,
std::vector<bool>& visitedFactors) const
2012-05-23 14:56:01 +01:00
{
visitedVars[v->getIndex()] = true;
const FacNodes& adjacencies = v->neighbors();
2012-05-24 22:55:20 +01:00
for (size_t i = 0; i < adjacencies.size(); i++) {
2012-05-23 14:56:01 +01:00
int w = adjacencies[i]->getIndex();
if (!visitedFactors[w]) {
if (containsCycle (adjacencies[i], v, visitedVars, visitedFactors)) {
return true;
}
}
else if (visitedFactors[w] && adjacencies[i] != p) {
return true;
}
}
return false; // no cycle detected in this component
}
bool
FactorGraph::containsCycle (
const FacNode* v,
const VarNode* p,
2013-02-07 13:37:15 +00:00
std::vector<bool>& visitedVars,
std::vector<bool>& visitedFactors) const
2012-05-23 14:56:01 +01:00
{
visitedFactors[v->getIndex()] = true;
const VarNodes& adjacencies = v->neighbors();
2012-05-24 22:55:20 +01:00
for (size_t i = 0; i < adjacencies.size(); i++) {
2012-05-23 14:56:01 +01:00
int w = adjacencies[i]->getIndex();
if (!visitedVars[w]) {
if (containsCycle (adjacencies[i], v, visitedVars, visitedFactors)) {
return true;
}
}
else if (visitedVars[w] && adjacencies[i] != p) {
return true;
}
}
return false; // no cycle detected in this component
}
} // namespace Horus
2013-02-07 23:53:13 +00:00