Use a static method to create a FactorGraph
This commit is contained in:
parent
2ef1651c6a
commit
bc04d28092
@ -18,20 +18,7 @@ bool FactorGraph::printFg_ = false;
|
||||
|
||||
FactorGraph::FactorGraph (const FactorGraph& fg)
|
||||
{
|
||||
const VarNodes& varNodes = fg.varNodes();
|
||||
for (size_t i = 0; i < varNodes.size(); i++) {
|
||||
addVarNode (new VarNode (varNodes[i]));
|
||||
}
|
||||
const FacNodes& facNodes = fg.facNodes();
|
||||
for (size_t i = 0; i < facNodes.size(); i++) {
|
||||
FacNode* facNode = new FacNode (facNodes[i]->factor());
|
||||
addFacNode (facNode);
|
||||
const VarNodes& neighs = facNodes[i]->neighbors();
|
||||
for (size_t j = 0; j < neighs.size(); j++) {
|
||||
addEdge (varNodes_[neighs[j]->getIndex()], facNode);
|
||||
}
|
||||
}
|
||||
bayesFactors_ = fg.bayesianFactors();
|
||||
clone (fg);
|
||||
}
|
||||
|
||||
|
||||
@ -48,157 +35,6 @@ FactorGraph::~FactorGraph()
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::readFromUaiFormat (const char* fileName)
|
||||
{
|
||||
std::ifstream is (fileName);
|
||||
if (!is.is_open()) {
|
||||
std::cerr << "Error: couldn't open file '" << fileName << "'." ;
|
||||
std::cerr << std::endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
ignoreLines (is);
|
||||
std::string line;
|
||||
getline (is, line);
|
||||
if (line == "BAYES") {
|
||||
bayesFactors_ = true;
|
||||
} else if (line == "MARKOV") {
|
||||
bayesFactors_ = false;
|
||||
} else {
|
||||
std::cerr << "Error: the type of network is missing." << std::endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
// read the number of vars
|
||||
ignoreLines (is);
|
||||
unsigned nrVars;
|
||||
is >> nrVars;
|
||||
// read the range of each var
|
||||
ignoreLines (is);
|
||||
Ranges ranges (nrVars);
|
||||
for (unsigned i = 0; i < nrVars; i++) {
|
||||
is >> ranges[i];
|
||||
}
|
||||
unsigned nrFactors;
|
||||
unsigned nrArgs;
|
||||
unsigned vid;
|
||||
is >> nrFactors;
|
||||
std::vector<VarIds> allVarIds;
|
||||
std::vector<Ranges> allRanges;
|
||||
for (unsigned i = 0; i < nrFactors; i++) {
|
||||
ignoreLines (is);
|
||||
is >> nrArgs;
|
||||
allVarIds.push_back ({ });
|
||||
allRanges.push_back ({ });
|
||||
for (unsigned j = 0; j < nrArgs; j++) {
|
||||
is >> vid;
|
||||
if (vid >= ranges.size()) {
|
||||
std::cerr << "Error: invalid variable identifier `" << vid << "'" ;
|
||||
std::cerr << ". Identifiers must be between 0 and " ;
|
||||
std::cerr << ranges.size() - 1 << "." << std::endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
allVarIds.back().push_back (vid);
|
||||
allRanges.back().push_back (ranges[vid]);
|
||||
}
|
||||
}
|
||||
// read the parameters
|
||||
unsigned nrParams;
|
||||
for (unsigned i = 0; i < nrFactors; i++) {
|
||||
ignoreLines (is);
|
||||
is >> nrParams;
|
||||
if (nrParams != Util::sizeExpected (allRanges[i])) {
|
||||
std::cerr << "Error: invalid number of parameters for factor nº " ;
|
||||
std::cerr << i << ", " << Util::sizeExpected (allRanges[i]);
|
||||
std::cerr << " expected, " << nrParams << " given." << std::endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
Params params (nrParams);
|
||||
for (unsigned j = 0; j < nrParams; j++) {
|
||||
is >> params[j];
|
||||
}
|
||||
if (Globals::logDomain) {
|
||||
Util::log (params);
|
||||
}
|
||||
Factor f (allVarIds[i], allRanges[i], params);
|
||||
if (bayesFactors_ && allVarIds[i].size() > 1) {
|
||||
// In this format the child is the last variable,
|
||||
// move it to be the first
|
||||
std::swap (allVarIds[i].front(), allVarIds[i].back());
|
||||
f.reorderArguments (allVarIds[i]);
|
||||
}
|
||||
addFactor (f);
|
||||
}
|
||||
is.close();
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::readFromLibDaiFormat (const char* fileName)
|
||||
{
|
||||
std::ifstream is (fileName);
|
||||
if (!is.is_open()) {
|
||||
std::cerr << "Error: couldn't open file '" << fileName << "'." ;
|
||||
std::cerr << std::endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
ignoreLines (is);
|
||||
unsigned nrFactors;
|
||||
unsigned nrArgs;
|
||||
VarId vid;
|
||||
is >> nrFactors;
|
||||
for (unsigned i = 0; i < nrFactors; i++) {
|
||||
ignoreLines (is);
|
||||
// read the factor arguments
|
||||
is >> nrArgs;
|
||||
VarIds vids;
|
||||
for (unsigned j = 0; j < nrArgs; j++) {
|
||||
ignoreLines (is);
|
||||
is >> vid;
|
||||
vids.push_back (vid);
|
||||
}
|
||||
// read ranges
|
||||
Ranges ranges (nrArgs);
|
||||
for (unsigned j = 0; j < nrArgs; j++) {
|
||||
ignoreLines (is);
|
||||
is >> ranges[j];
|
||||
VarNode* var = getVarNode (vids[j]);
|
||||
if (var && ranges[j] != var->range()) {
|
||||
std::cerr << "Error: variable `" << vids[j] << "' appears" ;
|
||||
std::cerr << " in two or more factors with a different range." ;
|
||||
std::cerr << std::endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
}
|
||||
// read parameters
|
||||
ignoreLines (is);
|
||||
unsigned nNonzeros;
|
||||
is >> nNonzeros;
|
||||
Params params (Util::sizeExpected (ranges), 0);
|
||||
for (unsigned j = 0; j < nNonzeros; j++) {
|
||||
ignoreLines (is);
|
||||
unsigned index;
|
||||
is >> index;
|
||||
ignoreLines (is);
|
||||
double val;
|
||||
is >> val;
|
||||
params[index] = val;
|
||||
}
|
||||
if (Globals::logDomain) {
|
||||
Util::log (params);
|
||||
}
|
||||
std::reverse (vids.begin(), vids.end());
|
||||
std::reverse (ranges.begin(), ranges.end());
|
||||
Factor f (vids, ranges, params);
|
||||
std::reverse (vids.begin(), vids.end());
|
||||
f.reorderArguments (vids);
|
||||
addFactor (f);
|
||||
}
|
||||
is.close();
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::addFactor (const Factor& factor)
|
||||
{
|
||||
@ -412,13 +248,198 @@ FactorGraph::exportToGraphViz (const char* fileName) const
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::ignoreLines (std::ifstream& is) const
|
||||
FactorGraph&
|
||||
FactorGraph::operator= (const FactorGraph& fg)
|
||||
{
|
||||
std::string ignoreStr;
|
||||
while (is.peek() == '#' || is.peek() == '\n') {
|
||||
getline (is, ignoreStr);
|
||||
if (this != &fg) {
|
||||
for (size_t i = 0; i < varNodes_.size(); i++) {
|
||||
delete varNodes_[i];
|
||||
}
|
||||
varNodes_.clear();
|
||||
for (size_t i = 0; i < facNodes_.size(); i++) {
|
||||
delete facNodes_[i];
|
||||
}
|
||||
facNodes_.clear();
|
||||
varMap_.clear();
|
||||
clone (fg);
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
|
||||
|
||||
FactorGraph
|
||||
FactorGraph::readFromUaiFormat (const char* fileName)
|
||||
{
|
||||
std::ifstream is (fileName);
|
||||
if (!is.is_open()) {
|
||||
std::cerr << "Error: couldn't open file '" << fileName << "'." ;
|
||||
std::cerr << std::endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
FactorGraph fg;
|
||||
ignoreLines (is);
|
||||
std::string line;
|
||||
getline (is, line);
|
||||
if (line == "BAYES") {
|
||||
fg.bayesFactors_ = true;
|
||||
} else if (line == "MARKOV") {
|
||||
fg.bayesFactors_ = false;
|
||||
} else {
|
||||
std::cerr << "Error: the type of network is missing." << std::endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
// read the number of vars
|
||||
ignoreLines (is);
|
||||
unsigned nrVars;
|
||||
is >> nrVars;
|
||||
// read the range of each var
|
||||
ignoreLines (is);
|
||||
Ranges ranges (nrVars);
|
||||
for (unsigned i = 0; i < nrVars; i++) {
|
||||
is >> ranges[i];
|
||||
}
|
||||
unsigned nrFactors;
|
||||
unsigned nrArgs;
|
||||
unsigned vid;
|
||||
is >> nrFactors;
|
||||
std::vector<VarIds> allVarIds;
|
||||
std::vector<Ranges> allRanges;
|
||||
for (unsigned i = 0; i < nrFactors; i++) {
|
||||
ignoreLines (is);
|
||||
is >> nrArgs;
|
||||
allVarIds.push_back ({ });
|
||||
allRanges.push_back ({ });
|
||||
for (unsigned j = 0; j < nrArgs; j++) {
|
||||
is >> vid;
|
||||
if (vid >= ranges.size()) {
|
||||
std::cerr << "Error: invalid variable identifier `" << vid << "'" ;
|
||||
std::cerr << ". Identifiers must be between 0 and " ;
|
||||
std::cerr << ranges.size() - 1 << "." << std::endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
allVarIds.back().push_back (vid);
|
||||
allRanges.back().push_back (ranges[vid]);
|
||||
}
|
||||
}
|
||||
// read the parameters
|
||||
unsigned nrParams;
|
||||
for (unsigned i = 0; i < nrFactors; i++) {
|
||||
ignoreLines (is);
|
||||
is >> nrParams;
|
||||
if (nrParams != Util::sizeExpected (allRanges[i])) {
|
||||
std::cerr << "Error: invalid number of parameters for factor nº " ;
|
||||
std::cerr << i << ", " << Util::sizeExpected (allRanges[i]);
|
||||
std::cerr << " expected, " << nrParams << " given." << std::endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
Params params (nrParams);
|
||||
for (unsigned j = 0; j < nrParams; j++) {
|
||||
is >> params[j];
|
||||
}
|
||||
if (Globals::logDomain) {
|
||||
Util::log (params);
|
||||
}
|
||||
Factor f (allVarIds[i], allRanges[i], params);
|
||||
if (fg.bayesFactors_ && allVarIds[i].size() > 1) {
|
||||
// In this format the child is the last variable,
|
||||
// move it to be the first
|
||||
std::swap (allVarIds[i].front(), allVarIds[i].back());
|
||||
f.reorderArguments (allVarIds[i]);
|
||||
}
|
||||
fg.addFactor (f);
|
||||
}
|
||||
is.close();
|
||||
return fg;
|
||||
}
|
||||
|
||||
|
||||
|
||||
FactorGraph
|
||||
FactorGraph::readFromLibDaiFormat (const char* fileName)
|
||||
{
|
||||
std::ifstream is (fileName);
|
||||
if (!is.is_open()) {
|
||||
std::cerr << "Error: couldn't open file '" << fileName << "'." ;
|
||||
std::cerr << std::endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
FactorGraph fg;
|
||||
ignoreLines (is);
|
||||
unsigned nrFactors;
|
||||
unsigned nrArgs;
|
||||
VarId vid;
|
||||
is >> nrFactors;
|
||||
for (unsigned i = 0; i < nrFactors; i++) {
|
||||
ignoreLines (is);
|
||||
// read the factor arguments
|
||||
is >> nrArgs;
|
||||
VarIds vids;
|
||||
for (unsigned j = 0; j < nrArgs; j++) {
|
||||
ignoreLines (is);
|
||||
is >> vid;
|
||||
vids.push_back (vid);
|
||||
}
|
||||
// read ranges
|
||||
Ranges ranges (nrArgs);
|
||||
for (unsigned j = 0; j < nrArgs; j++) {
|
||||
ignoreLines (is);
|
||||
is >> ranges[j];
|
||||
VarNode* var = fg.getVarNode (vids[j]);
|
||||
if (var && ranges[j] != var->range()) {
|
||||
std::cerr << "Error: variable `" << vids[j] << "' appears" ;
|
||||
std::cerr << " in two or more factors with a different range." ;
|
||||
std::cerr << std::endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
}
|
||||
// read parameters
|
||||
ignoreLines (is);
|
||||
unsigned nNonzeros;
|
||||
is >> nNonzeros;
|
||||
Params params (Util::sizeExpected (ranges), 0);
|
||||
for (unsigned j = 0; j < nNonzeros; j++) {
|
||||
ignoreLines (is);
|
||||
unsigned index;
|
||||
is >> index;
|
||||
ignoreLines (is);
|
||||
double val;
|
||||
is >> val;
|
||||
params[index] = val;
|
||||
}
|
||||
if (Globals::logDomain) {
|
||||
Util::log (params);
|
||||
}
|
||||
std::reverse (vids.begin(), vids.end());
|
||||
std::reverse (ranges.begin(), ranges.end());
|
||||
Factor f (vids, ranges, params);
|
||||
std::reverse (vids.begin(), vids.end());
|
||||
f.reorderArguments (vids);
|
||||
fg.addFactor (f);
|
||||
}
|
||||
is.close();
|
||||
return fg;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::clone (const FactorGraph& fg)
|
||||
{
|
||||
const VarNodes& varNodes = fg.varNodes();
|
||||
for (size_t i = 0; i < varNodes.size(); i++) {
|
||||
addVarNode (new VarNode (varNodes[i]));
|
||||
}
|
||||
const FacNodes& facNodes = fg.facNodes();
|
||||
for (size_t i = 0; i < facNodes.size(); i++) {
|
||||
FacNode* facNode = new FacNode (facNodes[i]->factor());
|
||||
addFacNode (facNode);
|
||||
const VarNodes& neighs = facNodes[i]->neighbors();
|
||||
for (size_t j = 0; j < neighs.size(); j++) {
|
||||
addEdge (varNodes_[neighs[j]->getIndex()], facNode);
|
||||
}
|
||||
}
|
||||
bayesFactors_ = fg.bayesianFactors();
|
||||
}
|
||||
|
||||
|
||||
@ -489,5 +510,16 @@ FactorGraph::containsCycle (
|
||||
return false; // no cycle detected in this component
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::ignoreLines (std::ifstream& is)
|
||||
{
|
||||
std::string ignoreStr;
|
||||
while (is.peek() == '#' || is.peek() == '\n') {
|
||||
getline (is, ignoreStr);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace Horus
|
||||
|
||||
|
@ -86,10 +86,6 @@ class FactorGraph {
|
||||
|
||||
VarNode* getVarNode (VarId vid) const;
|
||||
|
||||
void readFromUaiFormat (const char*);
|
||||
|
||||
void readFromLibDaiFormat (const char*);
|
||||
|
||||
void addFactor (const Factor& factor);
|
||||
|
||||
void addVarNode (VarNode*);
|
||||
@ -110,6 +106,12 @@ class FactorGraph {
|
||||
|
||||
void exportToGraphViz (const char*) const;
|
||||
|
||||
FactorGraph& operator= (const FactorGraph&);
|
||||
|
||||
static FactorGraph readFromUaiFormat (const char*);
|
||||
|
||||
static FactorGraph readFromLibDaiFormat (const char*);
|
||||
|
||||
static bool exportToLibDai() { return exportLd_; }
|
||||
|
||||
static bool exportToUai() { return exportUai_; }
|
||||
@ -137,7 +139,7 @@ class FactorGraph {
|
||||
private:
|
||||
typedef std::unordered_map<unsigned, VarNode*> VarMap;
|
||||
|
||||
void ignoreLines (std::ifstream&) const;
|
||||
void clone (const FactorGraph& fg);
|
||||
|
||||
bool containsCycle() const;
|
||||
|
||||
@ -147,6 +149,8 @@ class FactorGraph {
|
||||
bool containsCycle (const FacNode*, const VarNode*,
|
||||
std::vector<bool>&, std::vector<bool>&) const;
|
||||
|
||||
static void ignoreLines (std::ifstream&);
|
||||
|
||||
VarNodes varNodes_;
|
||||
FacNodes facNodes_;
|
||||
VarMap varMap_;
|
||||
@ -157,8 +161,6 @@ class FactorGraph {
|
||||
static bool exportUai_;
|
||||
static bool exportGv_;
|
||||
static bool printFg_;
|
||||
|
||||
DISALLOW_ASSIGN (FactorGraph);
|
||||
};
|
||||
|
||||
|
||||
|
@ -100,9 +100,9 @@ readFactorGraph (Horus::FactorGraph& fg, const char* s)
|
||||
std::string fileName (s);
|
||||
std::string extension = fileName.substr (fileName.find_last_of ('.') + 1);
|
||||
if (extension == "uai") {
|
||||
fg.readFromUaiFormat (fileName.c_str());
|
||||
fg = Horus::FactorGraph::readFromUaiFormat (fileName.c_str());
|
||||
} else if (extension == "fg") {
|
||||
fg.readFromLibDaiFormat (fileName.c_str());
|
||||
fg = Horus::FactorGraph::readFromLibDaiFormat (fileName.c_str());
|
||||
} else {
|
||||
std::cerr << "Error: the probabilistic graphical model must be " ;
|
||||
std::cerr << "defined either in a UAI or libDAI file." << std::endl;
|
||||
|
Reference in New Issue
Block a user