Add support for bayesian networks defined in an UAI file format

This commit is contained in:
Tiago Gomes 2013-01-07 22:17:05 +00:00
parent 82a4cc508b
commit ba32ebc5f5

View File

@ -55,8 +55,12 @@ FactorGraph::readFromUaiFormat (const char* fileName)
ignoreLines (is); ignoreLines (is);
string line; string line;
getline (is, line); getline (is, line);
if (line != "MARKOV") { if (line == "BAYES") {
cerr << "Error: the network must be a MARKOV network." << endl; bayesFactors_ = true;
} else if (line == "MARKOV") {
bayesFactors_ = false;
} else {
cerr << "Error: the type of network is missing." << endl;
exit (EXIT_FAILURE); exit (EXIT_FAILURE);
} }
// read the number of vars // read the number of vars
@ -73,13 +77,13 @@ FactorGraph::readFromUaiFormat (const char* fileName)
unsigned nrArgs; unsigned nrArgs;
unsigned vid; unsigned vid;
is >> nrFactors; is >> nrFactors;
vector<VarIds> factorVarIds; vector<VarIds> allVarIds;
vector<Ranges> factorRanges; vector<Ranges> allRanges;
for (unsigned i = 0; i < nrFactors; i++) { for (unsigned i = 0; i < nrFactors; i++) {
ignoreLines (is); ignoreLines (is);
is >> nrArgs; is >> nrArgs;
factorVarIds.push_back ({ }); allVarIds.push_back ({ });
factorRanges.push_back ({ }); allRanges.push_back ({ });
for (unsigned j = 0; j < nrArgs; j++) { for (unsigned j = 0; j < nrArgs; j++) {
is >> vid; is >> vid;
if (vid >= ranges.size()) { if (vid >= ranges.size()) {
@ -88,8 +92,8 @@ FactorGraph::readFromUaiFormat (const char* fileName)
cerr << "." << endl; cerr << "." << endl;
exit (EXIT_FAILURE); exit (EXIT_FAILURE);
} }
factorVarIds.back().push_back (vid); allVarIds.back().push_back (vid);
factorRanges.back().push_back (ranges[vid]); allRanges.back().push_back (ranges[vid]);
} }
} }
// read the parameters // read the parameters
@ -97,9 +101,9 @@ FactorGraph::readFromUaiFormat (const char* fileName)
for (unsigned i = 0; i < nrFactors; i++) { for (unsigned i = 0; i < nrFactors; i++) {
ignoreLines (is); ignoreLines (is);
is >> nrParams; is >> nrParams;
if (nrParams != Util::sizeExpected (factorRanges[i])) { if (nrParams != Util::sizeExpected (allRanges[i])) {
cerr << "Error: invalid number of parameters for factor nº " << i ; cerr << "Error: invalid number of parameters for factor nº " << i ;
cerr << ", " << Util::sizeExpected (factorRanges[i]); cerr << ", " << Util::sizeExpected (allRanges[i]);
cerr << " expected, " << nrParams << " given." << endl; cerr << " expected, " << nrParams << " given." << endl;
exit (EXIT_FAILURE); exit (EXIT_FAILURE);
} }
@ -110,7 +114,14 @@ FactorGraph::readFromUaiFormat (const char* fileName)
if (Globals::logDomain) { if (Globals::logDomain) {
Util::log (params); Util::log (params);
} }
addFactor (Factor (factorVarIds[i], factorRanges[i], params)); Factor f (allVarIds[i], allRanges[i], params);
if (bayesFactors_ && allVarIds[i].size() > 1) {
// In this format the child is the last variable,
// move it to be the first
std::swap (allVarIds[i].front(), allVarIds[i].back());
f.reorderArguments (allVarIds[i]);
}
addFactor (f);
} }
is.close(); is.close();
} }
@ -318,7 +329,8 @@ FactorGraph::exportToUaiFormat (const char* fileName) const
cerr << "Error: couldn't open file '" << fileName << "'." ; cerr << "Error: couldn't open file '" << fileName << "'." ;
return; return;
} }
out << "MARKOV" << endl; out << (bayesFactors_ ? "BAYES" : "MARKOV") ;
out << endl << endl;
out << varNodes_.size() << endl; out << varNodes_.size() << endl;
VarNodes sortedVns = varNodes_; VarNodes sortedVns = varNodes_;
std::sort (sortedVns.begin(), sortedVns.end(), sortByVarId()); std::sort (sortedVns.begin(), sortedVns.end(), sortByVarId());
@ -328,11 +340,20 @@ FactorGraph::exportToUaiFormat (const char* fileName) const
out << endl << facNodes_.size() << endl; out << endl << facNodes_.size() << endl;
for (size_t i = 0; i < facNodes_.size(); i++) { for (size_t i = 0; i < facNodes_.size(); i++) {
VarIds args = facNodes_[i]->factor().arguments(); VarIds args = facNodes_[i]->factor().arguments();
if (bayesFactors_) {
std::swap (args.front(), args.back());
}
out << args.size() << " " << Util::elementsToString (args) << endl; out << args.size() << " " << Util::elementsToString (args) << endl;
} }
out << endl; out << endl;
for (size_t i = 0; i < facNodes_.size(); i++) { for (size_t i = 0; i < facNodes_.size(); i++) {
Params params = facNodes_[i]->factor().params(); Factor f = facNodes_[i]->factor();
if (bayesFactors_) {
VarIds args = f.arguments();
std::swap (args.front(), args.back());
f.reorderArguments (args);
}
Params params = f.params();
if (Globals::logDomain) { if (Globals::logDomain) {
Util::exp (params); Util::exp (params);
} }