diff --git a/packages/CLPBN/horus/FactorGraph.cpp b/packages/CLPBN/horus/FactorGraph.cpp index 85925abf6..3d781628a 100644 --- a/packages/CLPBN/horus/FactorGraph.cpp +++ b/packages/CLPBN/horus/FactorGraph.cpp @@ -55,8 +55,12 @@ FactorGraph::readFromUaiFormat (const char* fileName) ignoreLines (is); string line; getline (is, line); - if (line != "MARKOV") { - cerr << "Error: the network must be a MARKOV network." << endl; + if (line == "BAYES") { + bayesFactors_ = true; + } else if (line == "MARKOV") { + bayesFactors_ = false; + } else { + cerr << "Error: the type of network is missing." << endl; exit (EXIT_FAILURE); } // read the number of vars @@ -73,13 +77,13 @@ FactorGraph::readFromUaiFormat (const char* fileName) unsigned nrArgs; unsigned vid; is >> nrFactors; - vector factorVarIds; - vector factorRanges; + vector allVarIds; + vector allRanges; for (unsigned i = 0; i < nrFactors; i++) { ignoreLines (is); is >> nrArgs; - factorVarIds.push_back ({ }); - factorRanges.push_back ({ }); + allVarIds.push_back ({ }); + allRanges.push_back ({ }); for (unsigned j = 0; j < nrArgs; j++) { is >> vid; if (vid >= ranges.size()) { @@ -88,8 +92,8 @@ FactorGraph::readFromUaiFormat (const char* fileName) cerr << "." << endl; exit (EXIT_FAILURE); } - factorVarIds.back().push_back (vid); - factorRanges.back().push_back (ranges[vid]); + allVarIds.back().push_back (vid); + allRanges.back().push_back (ranges[vid]); } } // read the parameters @@ -97,9 +101,9 @@ FactorGraph::readFromUaiFormat (const char* fileName) for (unsigned i = 0; i < nrFactors; i++) { ignoreLines (is); is >> nrParams; - if (nrParams != Util::sizeExpected (factorRanges[i])) { + if (nrParams != Util::sizeExpected (allRanges[i])) { cerr << "Error: invalid number of parameters for factor nÂș " << i ; - cerr << ", " << Util::sizeExpected (factorRanges[i]); + cerr << ", " << Util::sizeExpected (allRanges[i]); cerr << " expected, " << nrParams << " given." << endl; exit (EXIT_FAILURE); } @@ -110,7 +114,14 @@ FactorGraph::readFromUaiFormat (const char* fileName) if (Globals::logDomain) { Util::log (params); } - addFactor (Factor (factorVarIds[i], factorRanges[i], params)); + Factor f (allVarIds[i], allRanges[i], params); + if (bayesFactors_ && allVarIds[i].size() > 1) { + // In this format the child is the last variable, + // move it to be the first + std::swap (allVarIds[i].front(), allVarIds[i].back()); + f.reorderArguments (allVarIds[i]); + } + addFactor (f); } is.close(); } @@ -318,7 +329,8 @@ FactorGraph::exportToUaiFormat (const char* fileName) const cerr << "Error: couldn't open file '" << fileName << "'." ; return; } - out << "MARKOV" << endl; + out << (bayesFactors_ ? "BAYES" : "MARKOV") ; + out << endl << endl; out << varNodes_.size() << endl; VarNodes sortedVns = varNodes_; std::sort (sortedVns.begin(), sortedVns.end(), sortByVarId()); @@ -328,11 +340,20 @@ FactorGraph::exportToUaiFormat (const char* fileName) const out << endl << facNodes_.size() << endl; for (size_t i = 0; i < facNodes_.size(); i++) { VarIds args = facNodes_[i]->factor().arguments(); + if (bayesFactors_) { + std::swap (args.front(), args.back()); + } out << args.size() << " " << Util::elementsToString (args) << endl; } out << endl; for (size_t i = 0; i < facNodes_.size(); i++) { - Params params = facNodes_[i]->factor().params(); + Factor f = facNodes_[i]->factor(); + if (bayesFactors_) { + VarIds args = f.arguments(); + std::swap (args.front(), args.back()); + f.reorderArguments (args); + } + Params params = f.params(); if (Globals::logDomain) { Util::exp (params); }