Use a static method to create a FactorGraph

This commit is contained in:
Tiago Gomes 2013-03-14 16:57:34 +00:00
parent 2ef1651c6a
commit bc04d28092
3 changed files with 213 additions and 179 deletions

View File

@ -18,20 +18,7 @@ bool FactorGraph::printFg_ = false;
FactorGraph::FactorGraph (const FactorGraph& fg)
{
const VarNodes& varNodes = fg.varNodes();
for (size_t i = 0; i < varNodes.size(); i++) {
addVarNode (new VarNode (varNodes[i]));
}
const FacNodes& facNodes = fg.facNodes();
for (size_t i = 0; i < facNodes.size(); i++) {
FacNode* facNode = new FacNode (facNodes[i]->factor());
addFacNode (facNode);
const VarNodes& neighs = facNodes[i]->neighbors();
for (size_t j = 0; j < neighs.size(); j++) {
addEdge (varNodes_[neighs[j]->getIndex()], facNode);
}
}
bayesFactors_ = fg.bayesianFactors();
clone (fg);
}
@ -48,157 +35,6 @@ FactorGraph::~FactorGraph()
void
FactorGraph::readFromUaiFormat (const char* fileName)
{
std::ifstream is (fileName);
if (!is.is_open()) {
std::cerr << "Error: couldn't open file '" << fileName << "'." ;
std::cerr << std::endl;
exit (EXIT_FAILURE);
}
ignoreLines (is);
std::string line;
getline (is, line);
if (line == "BAYES") {
bayesFactors_ = true;
} else if (line == "MARKOV") {
bayesFactors_ = false;
} else {
std::cerr << "Error: the type of network is missing." << std::endl;
exit (EXIT_FAILURE);
}
// read the number of vars
ignoreLines (is);
unsigned nrVars;
is >> nrVars;
// read the range of each var
ignoreLines (is);
Ranges ranges (nrVars);
for (unsigned i = 0; i < nrVars; i++) {
is >> ranges[i];
}
unsigned nrFactors;
unsigned nrArgs;
unsigned vid;
is >> nrFactors;
std::vector<VarIds> allVarIds;
std::vector<Ranges> allRanges;
for (unsigned i = 0; i < nrFactors; i++) {
ignoreLines (is);
is >> nrArgs;
allVarIds.push_back ({ });
allRanges.push_back ({ });
for (unsigned j = 0; j < nrArgs; j++) {
is >> vid;
if (vid >= ranges.size()) {
std::cerr << "Error: invalid variable identifier `" << vid << "'" ;
std::cerr << ". Identifiers must be between 0 and " ;
std::cerr << ranges.size() - 1 << "." << std::endl;
exit (EXIT_FAILURE);
}
allVarIds.back().push_back (vid);
allRanges.back().push_back (ranges[vid]);
}
}
// read the parameters
unsigned nrParams;
for (unsigned i = 0; i < nrFactors; i++) {
ignoreLines (is);
is >> nrParams;
if (nrParams != Util::sizeExpected (allRanges[i])) {
std::cerr << "Error: invalid number of parameters for factor nº " ;
std::cerr << i << ", " << Util::sizeExpected (allRanges[i]);
std::cerr << " expected, " << nrParams << " given." << std::endl;
exit (EXIT_FAILURE);
}
Params params (nrParams);
for (unsigned j = 0; j < nrParams; j++) {
is >> params[j];
}
if (Globals::logDomain) {
Util::log (params);
}
Factor f (allVarIds[i], allRanges[i], params);
if (bayesFactors_ && allVarIds[i].size() > 1) {
// In this format the child is the last variable,
// move it to be the first
std::swap (allVarIds[i].front(), allVarIds[i].back());
f.reorderArguments (allVarIds[i]);
}
addFactor (f);
}
is.close();
}
void
FactorGraph::readFromLibDaiFormat (const char* fileName)
{
std::ifstream is (fileName);
if (!is.is_open()) {
std::cerr << "Error: couldn't open file '" << fileName << "'." ;
std::cerr << std::endl;
exit (EXIT_FAILURE);
}
ignoreLines (is);
unsigned nrFactors;
unsigned nrArgs;
VarId vid;
is >> nrFactors;
for (unsigned i = 0; i < nrFactors; i++) {
ignoreLines (is);
// read the factor arguments
is >> nrArgs;
VarIds vids;
for (unsigned j = 0; j < nrArgs; j++) {
ignoreLines (is);
is >> vid;
vids.push_back (vid);
}
// read ranges
Ranges ranges (nrArgs);
for (unsigned j = 0; j < nrArgs; j++) {
ignoreLines (is);
is >> ranges[j];
VarNode* var = getVarNode (vids[j]);
if (var && ranges[j] != var->range()) {
std::cerr << "Error: variable `" << vids[j] << "' appears" ;
std::cerr << " in two or more factors with a different range." ;
std::cerr << std::endl;
exit (EXIT_FAILURE);
}
}
// read parameters
ignoreLines (is);
unsigned nNonzeros;
is >> nNonzeros;
Params params (Util::sizeExpected (ranges), 0);
for (unsigned j = 0; j < nNonzeros; j++) {
ignoreLines (is);
unsigned index;
is >> index;
ignoreLines (is);
double val;
is >> val;
params[index] = val;
}
if (Globals::logDomain) {
Util::log (params);
}
std::reverse (vids.begin(), vids.end());
std::reverse (ranges.begin(), ranges.end());
Factor f (vids, ranges, params);
std::reverse (vids.begin(), vids.end());
f.reorderArguments (vids);
addFactor (f);
}
is.close();
}
void
FactorGraph::addFactor (const Factor& factor)
{
@ -412,13 +248,198 @@ FactorGraph::exportToGraphViz (const char* fileName) const
void
FactorGraph::ignoreLines (std::ifstream& is) const
FactorGraph&
FactorGraph::operator= (const FactorGraph& fg)
{
std::string ignoreStr;
while (is.peek() == '#' || is.peek() == '\n') {
getline (is, ignoreStr);
if (this != &fg) {
for (size_t i = 0; i < varNodes_.size(); i++) {
delete varNodes_[i];
}
varNodes_.clear();
for (size_t i = 0; i < facNodes_.size(); i++) {
delete facNodes_[i];
}
facNodes_.clear();
varMap_.clear();
clone (fg);
}
return *this;
}
FactorGraph
FactorGraph::readFromUaiFormat (const char* fileName)
{
std::ifstream is (fileName);
if (!is.is_open()) {
std::cerr << "Error: couldn't open file '" << fileName << "'." ;
std::cerr << std::endl;
exit (EXIT_FAILURE);
}
FactorGraph fg;
ignoreLines (is);
std::string line;
getline (is, line);
if (line == "BAYES") {
fg.bayesFactors_ = true;
} else if (line == "MARKOV") {
fg.bayesFactors_ = false;
} else {
std::cerr << "Error: the type of network is missing." << std::endl;
exit (EXIT_FAILURE);
}
// read the number of vars
ignoreLines (is);
unsigned nrVars;
is >> nrVars;
// read the range of each var
ignoreLines (is);
Ranges ranges (nrVars);
for (unsigned i = 0; i < nrVars; i++) {
is >> ranges[i];
}
unsigned nrFactors;
unsigned nrArgs;
unsigned vid;
is >> nrFactors;
std::vector<VarIds> allVarIds;
std::vector<Ranges> allRanges;
for (unsigned i = 0; i < nrFactors; i++) {
ignoreLines (is);
is >> nrArgs;
allVarIds.push_back ({ });
allRanges.push_back ({ });
for (unsigned j = 0; j < nrArgs; j++) {
is >> vid;
if (vid >= ranges.size()) {
std::cerr << "Error: invalid variable identifier `" << vid << "'" ;
std::cerr << ". Identifiers must be between 0 and " ;
std::cerr << ranges.size() - 1 << "." << std::endl;
exit (EXIT_FAILURE);
}
allVarIds.back().push_back (vid);
allRanges.back().push_back (ranges[vid]);
}
}
// read the parameters
unsigned nrParams;
for (unsigned i = 0; i < nrFactors; i++) {
ignoreLines (is);
is >> nrParams;
if (nrParams != Util::sizeExpected (allRanges[i])) {
std::cerr << "Error: invalid number of parameters for factor nº " ;
std::cerr << i << ", " << Util::sizeExpected (allRanges[i]);
std::cerr << " expected, " << nrParams << " given." << std::endl;
exit (EXIT_FAILURE);
}
Params params (nrParams);
for (unsigned j = 0; j < nrParams; j++) {
is >> params[j];
}
if (Globals::logDomain) {
Util::log (params);
}
Factor f (allVarIds[i], allRanges[i], params);
if (fg.bayesFactors_ && allVarIds[i].size() > 1) {
// In this format the child is the last variable,
// move it to be the first
std::swap (allVarIds[i].front(), allVarIds[i].back());
f.reorderArguments (allVarIds[i]);
}
fg.addFactor (f);
}
is.close();
return fg;
}
FactorGraph
FactorGraph::readFromLibDaiFormat (const char* fileName)
{
std::ifstream is (fileName);
if (!is.is_open()) {
std::cerr << "Error: couldn't open file '" << fileName << "'." ;
std::cerr << std::endl;
exit (EXIT_FAILURE);
}
FactorGraph fg;
ignoreLines (is);
unsigned nrFactors;
unsigned nrArgs;
VarId vid;
is >> nrFactors;
for (unsigned i = 0; i < nrFactors; i++) {
ignoreLines (is);
// read the factor arguments
is >> nrArgs;
VarIds vids;
for (unsigned j = 0; j < nrArgs; j++) {
ignoreLines (is);
is >> vid;
vids.push_back (vid);
}
// read ranges
Ranges ranges (nrArgs);
for (unsigned j = 0; j < nrArgs; j++) {
ignoreLines (is);
is >> ranges[j];
VarNode* var = fg.getVarNode (vids[j]);
if (var && ranges[j] != var->range()) {
std::cerr << "Error: variable `" << vids[j] << "' appears" ;
std::cerr << " in two or more factors with a different range." ;
std::cerr << std::endl;
exit (EXIT_FAILURE);
}
}
// read parameters
ignoreLines (is);
unsigned nNonzeros;
is >> nNonzeros;
Params params (Util::sizeExpected (ranges), 0);
for (unsigned j = 0; j < nNonzeros; j++) {
ignoreLines (is);
unsigned index;
is >> index;
ignoreLines (is);
double val;
is >> val;
params[index] = val;
}
if (Globals::logDomain) {
Util::log (params);
}
std::reverse (vids.begin(), vids.end());
std::reverse (ranges.begin(), ranges.end());
Factor f (vids, ranges, params);
std::reverse (vids.begin(), vids.end());
f.reorderArguments (vids);
fg.addFactor (f);
}
is.close();
return fg;
}
void
FactorGraph::clone (const FactorGraph& fg)
{
const VarNodes& varNodes = fg.varNodes();
for (size_t i = 0; i < varNodes.size(); i++) {
addVarNode (new VarNode (varNodes[i]));
}
const FacNodes& facNodes = fg.facNodes();
for (size_t i = 0; i < facNodes.size(); i++) {
FacNode* facNode = new FacNode (facNodes[i]->factor());
addFacNode (facNode);
const VarNodes& neighs = facNodes[i]->neighbors();
for (size_t j = 0; j < neighs.size(); j++) {
addEdge (varNodes_[neighs[j]->getIndex()], facNode);
}
}
bayesFactors_ = fg.bayesianFactors();
}
@ -489,5 +510,16 @@ FactorGraph::containsCycle (
return false; // no cycle detected in this component
}
void
FactorGraph::ignoreLines (std::ifstream& is)
{
std::string ignoreStr;
while (is.peek() == '#' || is.peek() == '\n') {
getline (is, ignoreStr);
}
}
} // namespace Horus

View File

@ -86,10 +86,6 @@ class FactorGraph {
VarNode* getVarNode (VarId vid) const;
void readFromUaiFormat (const char*);
void readFromLibDaiFormat (const char*);
void addFactor (const Factor& factor);
void addVarNode (VarNode*);
@ -110,6 +106,12 @@ class FactorGraph {
void exportToGraphViz (const char*) const;
FactorGraph& operator= (const FactorGraph&);
static FactorGraph readFromUaiFormat (const char*);
static FactorGraph readFromLibDaiFormat (const char*);
static bool exportToLibDai() { return exportLd_; }
static bool exportToUai() { return exportUai_; }
@ -137,7 +139,7 @@ class FactorGraph {
private:
typedef std::unordered_map<unsigned, VarNode*> VarMap;
void ignoreLines (std::ifstream&) const;
void clone (const FactorGraph& fg);
bool containsCycle() const;
@ -147,6 +149,8 @@ class FactorGraph {
bool containsCycle (const FacNode*, const VarNode*,
std::vector<bool>&, std::vector<bool>&) const;
static void ignoreLines (std::ifstream&);
VarNodes varNodes_;
FacNodes facNodes_;
VarMap varMap_;
@ -157,8 +161,6 @@ class FactorGraph {
static bool exportUai_;
static bool exportGv_;
static bool printFg_;
DISALLOW_ASSIGN (FactorGraph);
};

View File

@ -100,9 +100,9 @@ readFactorGraph (Horus::FactorGraph& fg, const char* s)
std::string fileName (s);
std::string extension = fileName.substr (fileName.find_last_of ('.') + 1);
if (extension == "uai") {
fg.readFromUaiFormat (fileName.c_str());
fg = Horus::FactorGraph::readFromUaiFormat (fileName.c_str());
} else if (extension == "fg") {
fg.readFromLibDaiFormat (fileName.c_str());
fg = Horus::FactorGraph::readFromLibDaiFormat (fileName.c_str());
} else {
std::cerr << "Error: the probabilistic graphical model must be " ;
std::cerr << "defined either in a UAI or libDAI file." << std::endl;