Add support to markov networks
This commit is contained in:
@@ -9,6 +9,7 @@
|
||||
#include "FactorGraph.h"
|
||||
#include "Factor.h"
|
||||
#include "BayesNet.h"
|
||||
#include "BayesBall.h"
|
||||
#include "Util.h"
|
||||
|
||||
|
||||
@@ -205,13 +206,13 @@ FactorGraph::readFromLibDaiFormat (const char* fileName)
|
||||
var = new FgVarNode (vids[j], dsize);
|
||||
addVariable (var);
|
||||
} else {
|
||||
if (var->nrStates() != dsize) {
|
||||
if (var->range() != dsize) {
|
||||
cerr << "error: variable `" << vids[j] << "' appears in two or " ;
|
||||
cerr << "more factors with different domain sizes" << endl;
|
||||
}
|
||||
}
|
||||
neighs.push_back (var);
|
||||
nParams *= var->nrStates();
|
||||
nParams *= var->range();
|
||||
}
|
||||
Params params (nParams, 0);
|
||||
unsigned nNonzeros;
|
||||
@@ -274,6 +275,30 @@ FactorGraph::addFactor (FgFacNode* fn)
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::addFactor (const Factor& factor)
|
||||
{
|
||||
FgFacNode* fn = new FgFacNode (factor);
|
||||
addFactor (fn);
|
||||
const VarIds& vids = factor.arguments();
|
||||
for (unsigned i = 0; i < vids.size(); i++) {
|
||||
bool found = false;
|
||||
for (unsigned j = 0; j < varNodes_.size(); j++) {
|
||||
if (varNodes_[j]->varId() == vids[i]) {
|
||||
addEdge (varNodes_[j], fn);
|
||||
found = true;
|
||||
}
|
||||
}
|
||||
if (found == false) {
|
||||
FgVarNode* vn = new FgVarNode (vids[i], factor.range (i));
|
||||
addVariable (vn);
|
||||
addEdge (vn, fn);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::addEdge (FgVarNode* vn, FgFacNode* fn)
|
||||
{
|
||||
@@ -322,6 +347,26 @@ FactorGraph::isTree (void) const
|
||||
|
||||
|
||||
|
||||
DAGraph&
|
||||
FactorGraph::getStructure (void)
|
||||
{
|
||||
assert (fromBayesNet_);
|
||||
if (structure_.empty()) {
|
||||
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
||||
structure_.addNode (new DAGraphNode (varNodes_[i]));
|
||||
}
|
||||
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
||||
const VarIds& vids = facNodes_[i]->factor()->arguments();
|
||||
for (unsigned j = 1; j < vids.size(); j++) {
|
||||
structure_.addEdge (vids[j], vids[0]);
|
||||
}
|
||||
}
|
||||
}
|
||||
return structure_;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::setIndexes (void)
|
||||
{
|
||||
@@ -339,11 +384,11 @@ void
|
||||
FactorGraph::printGraphicalModel (void) const
|
||||
{
|
||||
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
||||
cout << "VarId = " << varNodes_[i]->varId() << endl;
|
||||
cout << "Label = " << varNodes_[i]->label() << endl;
|
||||
cout << "Nr States = " << varNodes_[i]->nrStates() << endl;
|
||||
cout << "Evidence = " << varNodes_[i]->getEvidence() << endl;
|
||||
cout << "Factors = " ;
|
||||
cout << "var id = " << varNodes_[i]->varId() << endl;
|
||||
cout << "label = " << varNodes_[i]->label() << endl;
|
||||
cout << "range = " << varNodes_[i]->range() << endl;
|
||||
cout << "evidence = " << varNodes_[i]->getEvidence() << endl;
|
||||
cout << "factors = " ;
|
||||
for (unsigned j = 0; j < varNodes_[i]->neighbors().size(); j++) {
|
||||
cout << varNodes_[i]->neighbors()[j]->getLabel() << " " ;
|
||||
}
|
||||
@@ -351,7 +396,6 @@ FactorGraph::printGraphicalModel (void) const
|
||||
}
|
||||
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
||||
facNodes_[i]->factor()->print();
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -366,22 +410,18 @@ FactorGraph::exportToGraphViz (const char* fileName) const
|
||||
cerr << "FactorGraph::exportToDotFile()" << endl;
|
||||
abort();
|
||||
}
|
||||
|
||||
out << "graph \"" << fileName << "\" {" << endl;
|
||||
|
||||
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
||||
if (varNodes_[i]->hasEvidence()) {
|
||||
out << '"' << varNodes_[i]->label() << '"' ;
|
||||
out << " [style=filled, fillcolor=yellow]" << endl;
|
||||
}
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
||||
out << '"' << facNodes_[i]->getLabel() << '"' ;
|
||||
out << " [label=\"" << facNodes_[i]->getLabel();
|
||||
out << "\"" << ", shape=box]" << endl;
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
||||
const FgVarSet& myVars = facNodes_[i]->neighbors();
|
||||
for (unsigned j = 0; j < myVars.size(); j++) {
|
||||
@@ -390,7 +430,6 @@ FactorGraph::exportToGraphViz (const char* fileName) const
|
||||
out << '"' << myVars[j]->label() << '"' << endl;
|
||||
}
|
||||
}
|
||||
|
||||
out << "}" << endl;
|
||||
out.close();
|
||||
}
|
||||
@@ -410,7 +449,7 @@ FactorGraph::exportToUaiFormat (const char* fileName) const
|
||||
out << "MARKOV" << endl;
|
||||
out << varNodes_.size() << endl;
|
||||
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
||||
out << varNodes_[i]->nrStates() << " " ;
|
||||
out << varNodes_[i]->range() << " " ;
|
||||
}
|
||||
out << endl;
|
||||
|
||||
@@ -459,7 +498,7 @@ FactorGraph::exportToLibDaiFormat (const char* fileName) const
|
||||
}
|
||||
out << endl;
|
||||
for (unsigned j = 0; j < factorVars.size(); j++) {
|
||||
out << factorVars[j]->nrStates() << " " ;
|
||||
out << factorVars[j]->range() << " " ;
|
||||
}
|
||||
out << endl;
|
||||
Params params = facNodes_[i]->factor()->params();
|
||||
@@ -496,10 +535,11 @@ FactorGraph::containsCycle (void) const
|
||||
|
||||
|
||||
bool
|
||||
FactorGraph::containsCycle (const FgVarNode* v,
|
||||
const FgFacNode* p,
|
||||
vector<bool>& visitedVars,
|
||||
vector<bool>& visitedFactors) const
|
||||
FactorGraph::containsCycle (
|
||||
const FgVarNode* v,
|
||||
const FgFacNode* p,
|
||||
vector<bool>& visitedVars,
|
||||
vector<bool>& visitedFactors) const
|
||||
{
|
||||
visitedVars[v->getIndex()] = true;
|
||||
const FgFacSet& adjacencies = v->neighbors();
|
||||
@@ -520,10 +560,11 @@ FactorGraph::containsCycle (const FgVarNode* v,
|
||||
|
||||
|
||||
bool
|
||||
FactorGraph::containsCycle (const FgFacNode* v,
|
||||
const FgVarNode* p,
|
||||
vector<bool>& visitedVars,
|
||||
vector<bool>& visitedFactors) const
|
||||
FactorGraph::containsCycle (
|
||||
const FgFacNode* v,
|
||||
const FgVarNode* p,
|
||||
vector<bool>& visitedVars,
|
||||
vector<bool>& visitedFactors) const
|
||||
{
|
||||
visitedFactors[v->getIndex()] = true;
|
||||
const FgVarSet& adjacencies = v->neighbors();
|
||||
|
Reference in New Issue
Block a user