Use a static method to create a FactorGraph
This commit is contained in:
parent
2ef1651c6a
commit
bc04d28092
|
@ -18,20 +18,7 @@ bool FactorGraph::printFg_ = false;
|
||||||
|
|
||||||
FactorGraph::FactorGraph (const FactorGraph& fg)
|
FactorGraph::FactorGraph (const FactorGraph& fg)
|
||||||
{
|
{
|
||||||
const VarNodes& varNodes = fg.varNodes();
|
clone (fg);
|
||||||
for (size_t i = 0; i < varNodes.size(); i++) {
|
|
||||||
addVarNode (new VarNode (varNodes[i]));
|
|
||||||
}
|
|
||||||
const FacNodes& facNodes = fg.facNodes();
|
|
||||||
for (size_t i = 0; i < facNodes.size(); i++) {
|
|
||||||
FacNode* facNode = new FacNode (facNodes[i]->factor());
|
|
||||||
addFacNode (facNode);
|
|
||||||
const VarNodes& neighs = facNodes[i]->neighbors();
|
|
||||||
for (size_t j = 0; j < neighs.size(); j++) {
|
|
||||||
addEdge (varNodes_[neighs[j]->getIndex()], facNode);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
bayesFactors_ = fg.bayesianFactors();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -48,157 +35,6 @@ FactorGraph::~FactorGraph()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
FactorGraph::readFromUaiFormat (const char* fileName)
|
|
||||||
{
|
|
||||||
std::ifstream is (fileName);
|
|
||||||
if (!is.is_open()) {
|
|
||||||
std::cerr << "Error: couldn't open file '" << fileName << "'." ;
|
|
||||||
std::cerr << std::endl;
|
|
||||||
exit (EXIT_FAILURE);
|
|
||||||
}
|
|
||||||
ignoreLines (is);
|
|
||||||
std::string line;
|
|
||||||
getline (is, line);
|
|
||||||
if (line == "BAYES") {
|
|
||||||
bayesFactors_ = true;
|
|
||||||
} else if (line == "MARKOV") {
|
|
||||||
bayesFactors_ = false;
|
|
||||||
} else {
|
|
||||||
std::cerr << "Error: the type of network is missing." << std::endl;
|
|
||||||
exit (EXIT_FAILURE);
|
|
||||||
}
|
|
||||||
// read the number of vars
|
|
||||||
ignoreLines (is);
|
|
||||||
unsigned nrVars;
|
|
||||||
is >> nrVars;
|
|
||||||
// read the range of each var
|
|
||||||
ignoreLines (is);
|
|
||||||
Ranges ranges (nrVars);
|
|
||||||
for (unsigned i = 0; i < nrVars; i++) {
|
|
||||||
is >> ranges[i];
|
|
||||||
}
|
|
||||||
unsigned nrFactors;
|
|
||||||
unsigned nrArgs;
|
|
||||||
unsigned vid;
|
|
||||||
is >> nrFactors;
|
|
||||||
std::vector<VarIds> allVarIds;
|
|
||||||
std::vector<Ranges> allRanges;
|
|
||||||
for (unsigned i = 0; i < nrFactors; i++) {
|
|
||||||
ignoreLines (is);
|
|
||||||
is >> nrArgs;
|
|
||||||
allVarIds.push_back ({ });
|
|
||||||
allRanges.push_back ({ });
|
|
||||||
for (unsigned j = 0; j < nrArgs; j++) {
|
|
||||||
is >> vid;
|
|
||||||
if (vid >= ranges.size()) {
|
|
||||||
std::cerr << "Error: invalid variable identifier `" << vid << "'" ;
|
|
||||||
std::cerr << ". Identifiers must be between 0 and " ;
|
|
||||||
std::cerr << ranges.size() - 1 << "." << std::endl;
|
|
||||||
exit (EXIT_FAILURE);
|
|
||||||
}
|
|
||||||
allVarIds.back().push_back (vid);
|
|
||||||
allRanges.back().push_back (ranges[vid]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// read the parameters
|
|
||||||
unsigned nrParams;
|
|
||||||
for (unsigned i = 0; i < nrFactors; i++) {
|
|
||||||
ignoreLines (is);
|
|
||||||
is >> nrParams;
|
|
||||||
if (nrParams != Util::sizeExpected (allRanges[i])) {
|
|
||||||
std::cerr << "Error: invalid number of parameters for factor nº " ;
|
|
||||||
std::cerr << i << ", " << Util::sizeExpected (allRanges[i]);
|
|
||||||
std::cerr << " expected, " << nrParams << " given." << std::endl;
|
|
||||||
exit (EXIT_FAILURE);
|
|
||||||
}
|
|
||||||
Params params (nrParams);
|
|
||||||
for (unsigned j = 0; j < nrParams; j++) {
|
|
||||||
is >> params[j];
|
|
||||||
}
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
Util::log (params);
|
|
||||||
}
|
|
||||||
Factor f (allVarIds[i], allRanges[i], params);
|
|
||||||
if (bayesFactors_ && allVarIds[i].size() > 1) {
|
|
||||||
// In this format the child is the last variable,
|
|
||||||
// move it to be the first
|
|
||||||
std::swap (allVarIds[i].front(), allVarIds[i].back());
|
|
||||||
f.reorderArguments (allVarIds[i]);
|
|
||||||
}
|
|
||||||
addFactor (f);
|
|
||||||
}
|
|
||||||
is.close();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
FactorGraph::readFromLibDaiFormat (const char* fileName)
|
|
||||||
{
|
|
||||||
std::ifstream is (fileName);
|
|
||||||
if (!is.is_open()) {
|
|
||||||
std::cerr << "Error: couldn't open file '" << fileName << "'." ;
|
|
||||||
std::cerr << std::endl;
|
|
||||||
exit (EXIT_FAILURE);
|
|
||||||
}
|
|
||||||
ignoreLines (is);
|
|
||||||
unsigned nrFactors;
|
|
||||||
unsigned nrArgs;
|
|
||||||
VarId vid;
|
|
||||||
is >> nrFactors;
|
|
||||||
for (unsigned i = 0; i < nrFactors; i++) {
|
|
||||||
ignoreLines (is);
|
|
||||||
// read the factor arguments
|
|
||||||
is >> nrArgs;
|
|
||||||
VarIds vids;
|
|
||||||
for (unsigned j = 0; j < nrArgs; j++) {
|
|
||||||
ignoreLines (is);
|
|
||||||
is >> vid;
|
|
||||||
vids.push_back (vid);
|
|
||||||
}
|
|
||||||
// read ranges
|
|
||||||
Ranges ranges (nrArgs);
|
|
||||||
for (unsigned j = 0; j < nrArgs; j++) {
|
|
||||||
ignoreLines (is);
|
|
||||||
is >> ranges[j];
|
|
||||||
VarNode* var = getVarNode (vids[j]);
|
|
||||||
if (var && ranges[j] != var->range()) {
|
|
||||||
std::cerr << "Error: variable `" << vids[j] << "' appears" ;
|
|
||||||
std::cerr << " in two or more factors with a different range." ;
|
|
||||||
std::cerr << std::endl;
|
|
||||||
exit (EXIT_FAILURE);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// read parameters
|
|
||||||
ignoreLines (is);
|
|
||||||
unsigned nNonzeros;
|
|
||||||
is >> nNonzeros;
|
|
||||||
Params params (Util::sizeExpected (ranges), 0);
|
|
||||||
for (unsigned j = 0; j < nNonzeros; j++) {
|
|
||||||
ignoreLines (is);
|
|
||||||
unsigned index;
|
|
||||||
is >> index;
|
|
||||||
ignoreLines (is);
|
|
||||||
double val;
|
|
||||||
is >> val;
|
|
||||||
params[index] = val;
|
|
||||||
}
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
Util::log (params);
|
|
||||||
}
|
|
||||||
std::reverse (vids.begin(), vids.end());
|
|
||||||
std::reverse (ranges.begin(), ranges.end());
|
|
||||||
Factor f (vids, ranges, params);
|
|
||||||
std::reverse (vids.begin(), vids.end());
|
|
||||||
f.reorderArguments (vids);
|
|
||||||
addFactor (f);
|
|
||||||
}
|
|
||||||
is.close();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
FactorGraph::addFactor (const Factor& factor)
|
FactorGraph::addFactor (const Factor& factor)
|
||||||
{
|
{
|
||||||
|
@ -412,13 +248,198 @@ FactorGraph::exportToGraphViz (const char* fileName) const
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
FactorGraph&
|
||||||
FactorGraph::ignoreLines (std::ifstream& is) const
|
FactorGraph::operator= (const FactorGraph& fg)
|
||||||
{
|
{
|
||||||
std::string ignoreStr;
|
if (this != &fg) {
|
||||||
while (is.peek() == '#' || is.peek() == '\n') {
|
for (size_t i = 0; i < varNodes_.size(); i++) {
|
||||||
getline (is, ignoreStr);
|
delete varNodes_[i];
|
||||||
}
|
}
|
||||||
|
varNodes_.clear();
|
||||||
|
for (size_t i = 0; i < facNodes_.size(); i++) {
|
||||||
|
delete facNodes_[i];
|
||||||
|
}
|
||||||
|
facNodes_.clear();
|
||||||
|
varMap_.clear();
|
||||||
|
clone (fg);
|
||||||
|
}
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
FactorGraph
|
||||||
|
FactorGraph::readFromUaiFormat (const char* fileName)
|
||||||
|
{
|
||||||
|
std::ifstream is (fileName);
|
||||||
|
if (!is.is_open()) {
|
||||||
|
std::cerr << "Error: couldn't open file '" << fileName << "'." ;
|
||||||
|
std::cerr << std::endl;
|
||||||
|
exit (EXIT_FAILURE);
|
||||||
|
}
|
||||||
|
FactorGraph fg;
|
||||||
|
ignoreLines (is);
|
||||||
|
std::string line;
|
||||||
|
getline (is, line);
|
||||||
|
if (line == "BAYES") {
|
||||||
|
fg.bayesFactors_ = true;
|
||||||
|
} else if (line == "MARKOV") {
|
||||||
|
fg.bayesFactors_ = false;
|
||||||
|
} else {
|
||||||
|
std::cerr << "Error: the type of network is missing." << std::endl;
|
||||||
|
exit (EXIT_FAILURE);
|
||||||
|
}
|
||||||
|
// read the number of vars
|
||||||
|
ignoreLines (is);
|
||||||
|
unsigned nrVars;
|
||||||
|
is >> nrVars;
|
||||||
|
// read the range of each var
|
||||||
|
ignoreLines (is);
|
||||||
|
Ranges ranges (nrVars);
|
||||||
|
for (unsigned i = 0; i < nrVars; i++) {
|
||||||
|
is >> ranges[i];
|
||||||
|
}
|
||||||
|
unsigned nrFactors;
|
||||||
|
unsigned nrArgs;
|
||||||
|
unsigned vid;
|
||||||
|
is >> nrFactors;
|
||||||
|
std::vector<VarIds> allVarIds;
|
||||||
|
std::vector<Ranges> allRanges;
|
||||||
|
for (unsigned i = 0; i < nrFactors; i++) {
|
||||||
|
ignoreLines (is);
|
||||||
|
is >> nrArgs;
|
||||||
|
allVarIds.push_back ({ });
|
||||||
|
allRanges.push_back ({ });
|
||||||
|
for (unsigned j = 0; j < nrArgs; j++) {
|
||||||
|
is >> vid;
|
||||||
|
if (vid >= ranges.size()) {
|
||||||
|
std::cerr << "Error: invalid variable identifier `" << vid << "'" ;
|
||||||
|
std::cerr << ". Identifiers must be between 0 and " ;
|
||||||
|
std::cerr << ranges.size() - 1 << "." << std::endl;
|
||||||
|
exit (EXIT_FAILURE);
|
||||||
|
}
|
||||||
|
allVarIds.back().push_back (vid);
|
||||||
|
allRanges.back().push_back (ranges[vid]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// read the parameters
|
||||||
|
unsigned nrParams;
|
||||||
|
for (unsigned i = 0; i < nrFactors; i++) {
|
||||||
|
ignoreLines (is);
|
||||||
|
is >> nrParams;
|
||||||
|
if (nrParams != Util::sizeExpected (allRanges[i])) {
|
||||||
|
std::cerr << "Error: invalid number of parameters for factor nº " ;
|
||||||
|
std::cerr << i << ", " << Util::sizeExpected (allRanges[i]);
|
||||||
|
std::cerr << " expected, " << nrParams << " given." << std::endl;
|
||||||
|
exit (EXIT_FAILURE);
|
||||||
|
}
|
||||||
|
Params params (nrParams);
|
||||||
|
for (unsigned j = 0; j < nrParams; j++) {
|
||||||
|
is >> params[j];
|
||||||
|
}
|
||||||
|
if (Globals::logDomain) {
|
||||||
|
Util::log (params);
|
||||||
|
}
|
||||||
|
Factor f (allVarIds[i], allRanges[i], params);
|
||||||
|
if (fg.bayesFactors_ && allVarIds[i].size() > 1) {
|
||||||
|
// In this format the child is the last variable,
|
||||||
|
// move it to be the first
|
||||||
|
std::swap (allVarIds[i].front(), allVarIds[i].back());
|
||||||
|
f.reorderArguments (allVarIds[i]);
|
||||||
|
}
|
||||||
|
fg.addFactor (f);
|
||||||
|
}
|
||||||
|
is.close();
|
||||||
|
return fg;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
FactorGraph
|
||||||
|
FactorGraph::readFromLibDaiFormat (const char* fileName)
|
||||||
|
{
|
||||||
|
std::ifstream is (fileName);
|
||||||
|
if (!is.is_open()) {
|
||||||
|
std::cerr << "Error: couldn't open file '" << fileName << "'." ;
|
||||||
|
std::cerr << std::endl;
|
||||||
|
exit (EXIT_FAILURE);
|
||||||
|
}
|
||||||
|
FactorGraph fg;
|
||||||
|
ignoreLines (is);
|
||||||
|
unsigned nrFactors;
|
||||||
|
unsigned nrArgs;
|
||||||
|
VarId vid;
|
||||||
|
is >> nrFactors;
|
||||||
|
for (unsigned i = 0; i < nrFactors; i++) {
|
||||||
|
ignoreLines (is);
|
||||||
|
// read the factor arguments
|
||||||
|
is >> nrArgs;
|
||||||
|
VarIds vids;
|
||||||
|
for (unsigned j = 0; j < nrArgs; j++) {
|
||||||
|
ignoreLines (is);
|
||||||
|
is >> vid;
|
||||||
|
vids.push_back (vid);
|
||||||
|
}
|
||||||
|
// read ranges
|
||||||
|
Ranges ranges (nrArgs);
|
||||||
|
for (unsigned j = 0; j < nrArgs; j++) {
|
||||||
|
ignoreLines (is);
|
||||||
|
is >> ranges[j];
|
||||||
|
VarNode* var = fg.getVarNode (vids[j]);
|
||||||
|
if (var && ranges[j] != var->range()) {
|
||||||
|
std::cerr << "Error: variable `" << vids[j] << "' appears" ;
|
||||||
|
std::cerr << " in two or more factors with a different range." ;
|
||||||
|
std::cerr << std::endl;
|
||||||
|
exit (EXIT_FAILURE);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// read parameters
|
||||||
|
ignoreLines (is);
|
||||||
|
unsigned nNonzeros;
|
||||||
|
is >> nNonzeros;
|
||||||
|
Params params (Util::sizeExpected (ranges), 0);
|
||||||
|
for (unsigned j = 0; j < nNonzeros; j++) {
|
||||||
|
ignoreLines (is);
|
||||||
|
unsigned index;
|
||||||
|
is >> index;
|
||||||
|
ignoreLines (is);
|
||||||
|
double val;
|
||||||
|
is >> val;
|
||||||
|
params[index] = val;
|
||||||
|
}
|
||||||
|
if (Globals::logDomain) {
|
||||||
|
Util::log (params);
|
||||||
|
}
|
||||||
|
std::reverse (vids.begin(), vids.end());
|
||||||
|
std::reverse (ranges.begin(), ranges.end());
|
||||||
|
Factor f (vids, ranges, params);
|
||||||
|
std::reverse (vids.begin(), vids.end());
|
||||||
|
f.reorderArguments (vids);
|
||||||
|
fg.addFactor (f);
|
||||||
|
}
|
||||||
|
is.close();
|
||||||
|
return fg;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
FactorGraph::clone (const FactorGraph& fg)
|
||||||
|
{
|
||||||
|
const VarNodes& varNodes = fg.varNodes();
|
||||||
|
for (size_t i = 0; i < varNodes.size(); i++) {
|
||||||
|
addVarNode (new VarNode (varNodes[i]));
|
||||||
|
}
|
||||||
|
const FacNodes& facNodes = fg.facNodes();
|
||||||
|
for (size_t i = 0; i < facNodes.size(); i++) {
|
||||||
|
FacNode* facNode = new FacNode (facNodes[i]->factor());
|
||||||
|
addFacNode (facNode);
|
||||||
|
const VarNodes& neighs = facNodes[i]->neighbors();
|
||||||
|
for (size_t j = 0; j < neighs.size(); j++) {
|
||||||
|
addEdge (varNodes_[neighs[j]->getIndex()], facNode);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
bayesFactors_ = fg.bayesianFactors();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -489,5 +510,16 @@ FactorGraph::containsCycle (
|
||||||
return false; // no cycle detected in this component
|
return false; // no cycle detected in this component
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
FactorGraph::ignoreLines (std::ifstream& is)
|
||||||
|
{
|
||||||
|
std::string ignoreStr;
|
||||||
|
while (is.peek() == '#' || is.peek() == '\n') {
|
||||||
|
getline (is, ignoreStr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace Horus
|
} // namespace Horus
|
||||||
|
|
||||||
|
|
|
@ -86,10 +86,6 @@ class FactorGraph {
|
||||||
|
|
||||||
VarNode* getVarNode (VarId vid) const;
|
VarNode* getVarNode (VarId vid) const;
|
||||||
|
|
||||||
void readFromUaiFormat (const char*);
|
|
||||||
|
|
||||||
void readFromLibDaiFormat (const char*);
|
|
||||||
|
|
||||||
void addFactor (const Factor& factor);
|
void addFactor (const Factor& factor);
|
||||||
|
|
||||||
void addVarNode (VarNode*);
|
void addVarNode (VarNode*);
|
||||||
|
@ -110,6 +106,12 @@ class FactorGraph {
|
||||||
|
|
||||||
void exportToGraphViz (const char*) const;
|
void exportToGraphViz (const char*) const;
|
||||||
|
|
||||||
|
FactorGraph& operator= (const FactorGraph&);
|
||||||
|
|
||||||
|
static FactorGraph readFromUaiFormat (const char*);
|
||||||
|
|
||||||
|
static FactorGraph readFromLibDaiFormat (const char*);
|
||||||
|
|
||||||
static bool exportToLibDai() { return exportLd_; }
|
static bool exportToLibDai() { return exportLd_; }
|
||||||
|
|
||||||
static bool exportToUai() { return exportUai_; }
|
static bool exportToUai() { return exportUai_; }
|
||||||
|
@ -137,7 +139,7 @@ class FactorGraph {
|
||||||
private:
|
private:
|
||||||
typedef std::unordered_map<unsigned, VarNode*> VarMap;
|
typedef std::unordered_map<unsigned, VarNode*> VarMap;
|
||||||
|
|
||||||
void ignoreLines (std::ifstream&) const;
|
void clone (const FactorGraph& fg);
|
||||||
|
|
||||||
bool containsCycle() const;
|
bool containsCycle() const;
|
||||||
|
|
||||||
|
@ -147,6 +149,8 @@ class FactorGraph {
|
||||||
bool containsCycle (const FacNode*, const VarNode*,
|
bool containsCycle (const FacNode*, const VarNode*,
|
||||||
std::vector<bool>&, std::vector<bool>&) const;
|
std::vector<bool>&, std::vector<bool>&) const;
|
||||||
|
|
||||||
|
static void ignoreLines (std::ifstream&);
|
||||||
|
|
||||||
VarNodes varNodes_;
|
VarNodes varNodes_;
|
||||||
FacNodes facNodes_;
|
FacNodes facNodes_;
|
||||||
VarMap varMap_;
|
VarMap varMap_;
|
||||||
|
@ -157,8 +161,6 @@ class FactorGraph {
|
||||||
static bool exportUai_;
|
static bool exportUai_;
|
||||||
static bool exportGv_;
|
static bool exportGv_;
|
||||||
static bool printFg_;
|
static bool printFg_;
|
||||||
|
|
||||||
DISALLOW_ASSIGN (FactorGraph);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -100,9 +100,9 @@ readFactorGraph (Horus::FactorGraph& fg, const char* s)
|
||||||
std::string fileName (s);
|
std::string fileName (s);
|
||||||
std::string extension = fileName.substr (fileName.find_last_of ('.') + 1);
|
std::string extension = fileName.substr (fileName.find_last_of ('.') + 1);
|
||||||
if (extension == "uai") {
|
if (extension == "uai") {
|
||||||
fg.readFromUaiFormat (fileName.c_str());
|
fg = Horus::FactorGraph::readFromUaiFormat (fileName.c_str());
|
||||||
} else if (extension == "fg") {
|
} else if (extension == "fg") {
|
||||||
fg.readFromLibDaiFormat (fileName.c_str());
|
fg = Horus::FactorGraph::readFromLibDaiFormat (fileName.c_str());
|
||||||
} else {
|
} else {
|
||||||
std::cerr << "Error: the probabilistic graphical model must be " ;
|
std::cerr << "Error: the probabilistic graphical model must be " ;
|
||||||
std::cerr << "defined either in a UAI or libDAI file." << std::endl;
|
std::cerr << "defined either in a UAI or libDAI file." << std::endl;
|
||||||
|
|
Reference in New Issue