Improvements
Factor nodes now contain a factor object instead of a pointer. Refactor the way .fg and .uai formats are readed.
This commit is contained in:
@@ -18,21 +18,20 @@ bool FactorGraph::orderFactorVariables = false;
|
||||
|
||||
FactorGraph::FactorGraph (const FactorGraph& fg)
|
||||
{
|
||||
const VarNodes& vars = fg.varNodes();
|
||||
for (unsigned i = 0; i < vars.size(); i++) {
|
||||
VarNode* varNode = new VarNode (vars[i]);
|
||||
addVariable (varNode);
|
||||
const VarNodes& varNodes = fg.varNodes();
|
||||
for (unsigned i = 0; i < varNodes.size(); i++) {
|
||||
addVarNode (new VarNode (varNodes[i]));
|
||||
}
|
||||
|
||||
const FactorNodes& facs = fg.factorNodes();
|
||||
for (unsigned i = 0; i < facs.size(); i++) {
|
||||
FactorNode* facNode = new FactorNode (facs[i]);
|
||||
addFactor (facNode);
|
||||
const VarNodes& neighs = facs[i]->neighbors();
|
||||
const FactorNodes& facNodes = fg.factorNodes();
|
||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||
FactorNode* facNode = new FactorNode (facNodes[i]->factor());
|
||||
addFactorNode (facNode);
|
||||
const VarNodes& neighs = facNodes[i]->neighbors();
|
||||
for (unsigned j = 0; j < neighs.size(); j++) {
|
||||
addEdge (facNode, varNodes_[neighs[j]->getIndex()]);
|
||||
}
|
||||
}
|
||||
setIndexes();
|
||||
}
|
||||
|
||||
|
||||
@@ -40,82 +39,70 @@ FactorGraph::FactorGraph (const FactorGraph& fg)
|
||||
void
|
||||
FactorGraph::readFromUaiFormat (const char* fileName)
|
||||
{
|
||||
ifstream is (fileName);
|
||||
ifstream is (fileName);
|
||||
if (!is.is_open()) {
|
||||
cerr << "error: cannot read from file " + std::string (fileName) << endl;
|
||||
abort();
|
||||
}
|
||||
|
||||
ignoreLines (is);
|
||||
string line;
|
||||
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
|
||||
getline (is, line);
|
||||
if (line != "MARKOV") {
|
||||
cerr << "error: the network must be a MARKOV network " << endl;
|
||||
abort();
|
||||
}
|
||||
|
||||
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
|
||||
unsigned nVars;
|
||||
is >> nVars;
|
||||
|
||||
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
|
||||
vector<int> domainSizes (nVars);
|
||||
for (unsigned i = 0; i < nVars; i++) {
|
||||
unsigned ds;
|
||||
is >> ds;
|
||||
domainSizes[i] = ds;
|
||||
// read the number of vars
|
||||
ignoreLines (is);
|
||||
unsigned nrVars;
|
||||
is >> nrVars;
|
||||
// read the range of each var
|
||||
ignoreLines (is);
|
||||
Ranges ranges (nrVars);
|
||||
for (unsigned i = 0; i < nrVars; i++) {
|
||||
is >> ranges[i];
|
||||
}
|
||||
|
||||
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
|
||||
for (unsigned i = 0; i < nVars; i++) {
|
||||
addVariable (new VarNode (i, domainSizes[i]));
|
||||
}
|
||||
|
||||
unsigned nFactors;
|
||||
is >> nFactors;
|
||||
for (unsigned i = 0; i < nFactors; i++) {
|
||||
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
|
||||
unsigned nFactorVars;
|
||||
is >> nFactorVars;
|
||||
Vars neighs;
|
||||
for (unsigned j = 0; j < nFactorVars; j++) {
|
||||
unsigned vid;
|
||||
unsigned nrFactors;
|
||||
unsigned nrArgs;
|
||||
unsigned vid;
|
||||
is >> nrFactors;
|
||||
vector<VarIds> factorVarIds;
|
||||
vector<Ranges> factorRanges;
|
||||
for (unsigned i = 0; i < nrFactors; i++) {
|
||||
ignoreLines (is);
|
||||
is >> nrArgs;
|
||||
factorVarIds.push_back ({ });
|
||||
factorRanges.push_back ({ });
|
||||
for (unsigned j = 0; j < nrArgs; j++) {
|
||||
is >> vid;
|
||||
VarNode* neigh = getVarNode (vid);
|
||||
if (!neigh) {
|
||||
cerr << "error: invalid variable identifier (" << vid << ")" << endl;
|
||||
if (vid >= ranges.size()) {
|
||||
cerr << "error: invalid variable identifier `" << vid << "'" << endl;
|
||||
cerr << "identifiers must be between 0 and " << ranges.size() - 1 ;
|
||||
cerr << endl;
|
||||
abort();
|
||||
}
|
||||
neighs.push_back (neigh);
|
||||
}
|
||||
FactorNode* fn = new FactorNode (new Factor (neighs));
|
||||
addFactor (fn);
|
||||
for (unsigned j = 0; j < neighs.size(); j++) {
|
||||
addEdge (fn, static_cast<VarNode*> (neighs[j]));
|
||||
factorVarIds.back().push_back (vid);
|
||||
factorRanges.back().push_back (ranges[vid]);
|
||||
}
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < nFactors; i++) {
|
||||
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
|
||||
unsigned nParams;
|
||||
is >> nParams;
|
||||
if (facNodes_[i]->params().size() != nParams) {
|
||||
cerr << "error: invalid number of parameters for factor " ;
|
||||
cerr << facNodes_[i]->getLabel() ;
|
||||
cerr << ", expected: " << facNodes_[i]->params().size();
|
||||
cerr << ", given: " << nParams << endl;
|
||||
// read the parameters
|
||||
unsigned nrParams;
|
||||
for (unsigned i = 0; i < nrFactors; i++) {
|
||||
ignoreLines (is);
|
||||
is >> nrParams;
|
||||
if (nrParams != Util::expectedSize (factorRanges[i])) {
|
||||
cerr << "error: invalid number of parameters for factor nº " << i ;
|
||||
cerr << ", expected: " << Util::expectedSize (factorRanges[i]);
|
||||
cerr << ", given: " << nrParams << endl;
|
||||
abort();
|
||||
}
|
||||
Params params (nParams);
|
||||
for (unsigned j = 0; j < nParams; j++) {
|
||||
double param;
|
||||
is >> param;
|
||||
params[j] = param;
|
||||
Params params (nrParams);
|
||||
for (unsigned j = 0; j < nrParams; j++) {
|
||||
is >> params[j];
|
||||
}
|
||||
if (Globals::logDomain) {
|
||||
Util::toLog (params);
|
||||
}
|
||||
facNodes_[i]->factor()->setParams (params);
|
||||
addFactor (Factor (factorVarIds[i], factorRanges[i], params));
|
||||
}
|
||||
is.close();
|
||||
setIndexes();
|
||||
@@ -131,79 +118,51 @@ FactorGraph::readFromLibDaiFormat (const char* fileName)
|
||||
cerr << "error: cannot read from file " + std::string (fileName) << endl;
|
||||
abort();
|
||||
}
|
||||
|
||||
string line;
|
||||
unsigned nFactors;
|
||||
|
||||
while ((is.peek()) == '#') getline (is, line);
|
||||
is >> nFactors;
|
||||
|
||||
if (is.fail()) {
|
||||
cerr << "error: cannot read the number of factors" << endl;
|
||||
abort();
|
||||
}
|
||||
|
||||
getline (is, line);
|
||||
if (is.fail() || line.size() > 0) {
|
||||
cerr << "error: cannot read the number of factors" << endl;
|
||||
abort();
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < nFactors; i++) {
|
||||
unsigned nVars;
|
||||
while ((is.peek()) == '#') getline (is, line);
|
||||
|
||||
is >> nVars;
|
||||
ignoreLines (is);
|
||||
unsigned nrFactors;
|
||||
unsigned nrArgs;
|
||||
VarId vid;
|
||||
is >> nrFactors;
|
||||
for (unsigned i = 0; i < nrFactors; i++) {
|
||||
ignoreLines (is);
|
||||
// read the factor arguments
|
||||
is >> nrArgs;
|
||||
VarIds vids;
|
||||
for (unsigned j = 0; j < nVars; j++) {
|
||||
VarId vid;
|
||||
while ((is.peek()) == '#') getline (is, line);
|
||||
for (unsigned j = 0; j < nrArgs; j++) {
|
||||
ignoreLines (is);
|
||||
is >> vid;
|
||||
vids.push_back (vid);
|
||||
}
|
||||
|
||||
Vars neighs;
|
||||
unsigned nParams = 1;
|
||||
for (unsigned j = 0; j < nVars; j++) {
|
||||
unsigned dsize;
|
||||
while ((is.peek()) == '#') getline (is, line);
|
||||
is >> dsize;
|
||||
// read ranges
|
||||
Ranges ranges (nrArgs);
|
||||
for (unsigned j = 0; j < nrArgs; j++) {
|
||||
ignoreLines (is);
|
||||
is >> ranges[j];
|
||||
VarNode* var = getVarNode (vids[j]);
|
||||
if (var == 0) {
|
||||
var = new VarNode (vids[j], dsize);
|
||||
addVariable (var);
|
||||
} else {
|
||||
if (var->range() != dsize) {
|
||||
cerr << "error: variable `" << vids[j] << "' appears in two or " ;
|
||||
cerr << "more factors with different domain sizes" << endl;
|
||||
}
|
||||
if (var != 0 && ranges[j] != var->range()) {
|
||||
cerr << "error: variable `" << vids[j] << "' appears in two or " ;
|
||||
cerr << "more factors with a different range" << endl;
|
||||
}
|
||||
neighs.push_back (var);
|
||||
nParams *= var->range();
|
||||
}
|
||||
Params params (nParams, 0);
|
||||
// read parameters
|
||||
ignoreLines (is);
|
||||
unsigned nNonzeros;
|
||||
while ((is.peek()) == '#') getline (is, line);
|
||||
is >> nNonzeros;
|
||||
|
||||
Params params (Util::expectedSize (ranges), 0);
|
||||
for (unsigned j = 0; j < nNonzeros; j++) {
|
||||
ignoreLines (is);
|
||||
unsigned index;
|
||||
double val;
|
||||
while ((is.peek()) == '#') getline (is, line);
|
||||
is >> index;
|
||||
while ((is.peek()) == '#') getline (is, line);
|
||||
ignoreLines (is);
|
||||
double val;
|
||||
is >> val;
|
||||
params[index] = val;
|
||||
}
|
||||
reverse (neighs.begin(), neighs.end());
|
||||
reverse (vids.begin(), vids.end());
|
||||
if (Globals::logDomain) {
|
||||
Util::toLog (params);
|
||||
}
|
||||
FactorNode* fn = new FactorNode (new Factor (neighs, params));
|
||||
addFactor (fn);
|
||||
for (unsigned j = 0; j < neighs.size(); j++) {
|
||||
addEdge (fn, static_cast<VarNode*> (neighs[j]));
|
||||
}
|
||||
addFactor (Factor (vids, ranges, params));
|
||||
}
|
||||
is.close();
|
||||
setIndexes();
|
||||
@@ -223,30 +182,11 @@ FactorGraph::~FactorGraph (void)
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::addVariable (VarNode* vn)
|
||||
{
|
||||
varNodes_.push_back (vn);
|
||||
vn->setIndex (varNodes_.size() - 1);
|
||||
varMap_.insert (make_pair (vn->varId(), varNodes_.size() - 1));
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::addFactor (FactorNode* fn)
|
||||
{
|
||||
facNodes_.push_back (fn);
|
||||
fn->setIndex (facNodes_.size() - 1);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::addFactor (const Factor& factor)
|
||||
{
|
||||
FactorNode* fn = new FactorNode (factor);
|
||||
addFactor (fn);
|
||||
addFactorNode (fn);
|
||||
const VarIds& vids = factor.arguments();
|
||||
for (unsigned i = 0; i < vids.size(); i++) {
|
||||
bool found = false;
|
||||
@@ -258,7 +198,7 @@ FactorGraph::addFactor (const Factor& factor)
|
||||
}
|
||||
if (found == false) {
|
||||
VarNode* vn = new VarNode (vids[i], factor.range (i));
|
||||
addVariable (vn);
|
||||
addVarNode (vn);
|
||||
addEdge (vn, fn);
|
||||
}
|
||||
}
|
||||
@@ -266,6 +206,25 @@ FactorGraph::addFactor (const Factor& factor)
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::addVarNode (VarNode* vn)
|
||||
{
|
||||
varNodes_.push_back (vn);
|
||||
vn->setIndex (varNodes_.size() - 1);
|
||||
varMap_.insert (make_pair (vn->varId(), varNodes_.size() - 1));
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::addFactorNode (FactorNode* fn)
|
||||
{
|
||||
facNodes_.push_back (fn);
|
||||
fn->setIndex (facNodes_.size() - 1);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::addEdge (VarNode* vn, FactorNode* fn)
|
||||
{
|
||||
@@ -301,7 +260,7 @@ FactorGraph::getStructure (void)
|
||||
structure_.addNode (new DAGraphNode (varNodes_[i]));
|
||||
}
|
||||
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
||||
const VarIds& vids = facNodes_[i]->factor()->arguments();
|
||||
const VarIds& vids = facNodes_[i]->factor().arguments();
|
||||
for (unsigned j = 1; j < vids.size(); j++) {
|
||||
structure_.addEdge (vids[j], vids[0]);
|
||||
}
|
||||
@@ -340,7 +299,7 @@ FactorGraph::print (void) const
|
||||
cout << endl << endl;
|
||||
}
|
||||
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
||||
facNodes_[i]->factor()->print();
|
||||
facNodes_[i]->factor().print();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -446,7 +405,7 @@ FactorGraph::exportToLibDaiFormat (const char* fileName) const
|
||||
out << factorVars[j]->range() << " " ;
|
||||
}
|
||||
out << endl;
|
||||
Params params = facNodes_[i]->factor()->params();
|
||||
Params params = facNodes_[i]->factor().params();
|
||||
if (Globals::logDomain) {
|
||||
Util::fromLog (params);
|
||||
}
|
||||
@@ -461,6 +420,17 @@ FactorGraph::exportToLibDaiFormat (const char* fileName) const
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::ignoreLines (std::ifstream& is) const
|
||||
{
|
||||
string ignoreStr;
|
||||
while (is.peek() == '#' || is.peek() == '\n') {
|
||||
getline (is, ignoreStr);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
FactorGraph::containsCycle (void) const
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user