new version of bp
This commit is contained in:
@@ -9,6 +9,30 @@
|
||||
#include "FactorGraph.h"
|
||||
#include "Factor.h"
|
||||
#include "BayesNet.h"
|
||||
#include "Util.h"
|
||||
|
||||
|
||||
bool FactorGraph::orderFactorVariables = false;
|
||||
|
||||
|
||||
FactorGraph::FactorGraph (const FactorGraph& fg)
|
||||
{
|
||||
const FgVarSet& vars = fg.getVarNodes();
|
||||
for (unsigned i = 0; i < vars.size(); i++) {
|
||||
FgVarNode* varNode = new FgVarNode (vars[i]);
|
||||
addVariable (varNode);
|
||||
}
|
||||
|
||||
const FgFacSet& facs = fg.getFactorNodes();
|
||||
for (unsigned i = 0; i < facs.size(); i++) {
|
||||
FgFacNode* facNode = new FgFacNode (facs[i]);
|
||||
addFactor (facNode);
|
||||
const FgVarSet& neighs = facs[i]->neighbors();
|
||||
for (unsigned j = 0; j < neighs.size(); j++) {
|
||||
addEdge (facNode, varNodes_[neighs[j]->getIndex()]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -30,6 +54,10 @@ FactorGraph::FactorGraph (const BayesNet& bn)
|
||||
}
|
||||
FgFacNode* fn = new FgFacNode (
|
||||
new Factor (neighs, nodes[i]->getDistribution()));
|
||||
if (orderFactorVariables) {
|
||||
sort (neighs.begin(), neighs.end(), CompVarId());
|
||||
fn->factor()->orderVariables();
|
||||
}
|
||||
addFactor (fn);
|
||||
for (unsigned j = 0; j < neighs.size(); j++) {
|
||||
addEdge (fn, static_cast<FgVarNode*> (neighs[j]));
|
||||
@@ -110,13 +138,13 @@ FactorGraph::readFromUaiFormat (const char* fileName)
|
||||
cerr << ", given: " << nParams << endl;
|
||||
abort();
|
||||
}
|
||||
ParamSet params (nParams);
|
||||
Params params (nParams);
|
||||
for (unsigned j = 0; j < nParams; j++) {
|
||||
double param;
|
||||
is >> param;
|
||||
params[j] = param;
|
||||
}
|
||||
if (NSPACE == NumberSpace::LOGARITHM) {
|
||||
if (Globals::logDomain) {
|
||||
Util::toLog (params);
|
||||
}
|
||||
facNodes_[i]->factor()->setParameters (params);
|
||||
@@ -158,7 +186,7 @@ FactorGraph::readFromLibDaiFormat (const char* fileName)
|
||||
while ((is.peek()) == '#') getline (is, line);
|
||||
|
||||
is >> nVars;
|
||||
VarIdSet vids;
|
||||
VarIds vids;
|
||||
for (unsigned j = 0; j < nVars; j++) {
|
||||
VarId vid;
|
||||
while ((is.peek()) == '#') getline (is, line);
|
||||
@@ -185,15 +213,14 @@ FactorGraph::readFromLibDaiFormat (const char* fileName)
|
||||
neighs.push_back (var);
|
||||
nParams *= var->nrStates();
|
||||
}
|
||||
ParamSet params (nParams, 0);
|
||||
Params params (nParams, 0);
|
||||
unsigned nNonzeros;
|
||||
while ((is.peek()) == '#')
|
||||
getline (is, line);
|
||||
while ((is.peek()) == '#') getline (is, line);
|
||||
is >> nNonzeros;
|
||||
|
||||
for (unsigned j = 0; j < nNonzeros; j++) {
|
||||
unsigned index;
|
||||
Param val;
|
||||
double val;
|
||||
while ((is.peek()) == '#') getline (is, line);
|
||||
is >> index;
|
||||
while ((is.peek()) == '#') getline (is, line);
|
||||
@@ -201,7 +228,7 @@ FactorGraph::readFromLibDaiFormat (const char* fileName)
|
||||
params[index] = val;
|
||||
}
|
||||
reverse (neighs.begin(), neighs.end());
|
||||
if (NSPACE == NumberSpace::LOGARITHM) {
|
||||
if (Globals::logDomain) {
|
||||
Util::toLog (params);
|
||||
}
|
||||
FgFacNode* fn = new FgFacNode (new Factor (neighs, params));
|
||||
@@ -233,7 +260,7 @@ FactorGraph::addVariable (FgVarNode* vn)
|
||||
{
|
||||
varNodes_.push_back (vn);
|
||||
vn->setIndex (varNodes_.size() - 1);
|
||||
indexMap_.insert (make_pair (vn->varId(), varNodes_.size() - 1));
|
||||
varMap_.insert (make_pair (vn->varId(), varNodes_.size() - 1));
|
||||
}
|
||||
|
||||
|
||||
@@ -246,6 +273,7 @@ FactorGraph::addFactor (FgFacNode* fn)
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::addEdge (FgVarNode* vn, FgFacNode* fn)
|
||||
{
|
||||
@@ -326,10 +354,10 @@ void
|
||||
FactorGraph::printGraphicalModel (void) const
|
||||
{
|
||||
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
||||
cout << "VarId = " << varNodes_[i]->varId() << endl;
|
||||
cout << "Label = " << varNodes_[i]->label() << endl;
|
||||
cout << "Nr States = " << varNodes_[i]->nrStates() << endl;
|
||||
cout << "Evidence = " << varNodes_[i]->getEvidence() << endl;
|
||||
cout << "VarId = " << varNodes_[i]->varId() << endl;
|
||||
cout << "Label = " << varNodes_[i]->label() << endl;
|
||||
cout << "Nr States = " << varNodes_[i]->nrStates() << endl;
|
||||
cout << "Evidence = " << varNodes_[i]->getEvidence() << endl;
|
||||
cout << "Factors = " ;
|
||||
for (unsigned j = 0; j < varNodes_[i]->neighbors().size(); j++) {
|
||||
cout << varNodes_[i]->neighbors()[j]->getLabel() << " " ;
|
||||
@@ -337,7 +365,7 @@ FactorGraph::printGraphicalModel (void) const
|
||||
cout << endl << endl;
|
||||
}
|
||||
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
||||
facNodes_[i]->factor()->printFactor();
|
||||
facNodes_[i]->factor()->print();
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
@@ -412,7 +440,10 @@ FactorGraph::exportToUaiFormat (const char* fileName) const
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
||||
const ParamSet& params = facNodes_[i]->getParameters();
|
||||
Params params = facNodes_[i]->getParameters();
|
||||
if (Globals::logDomain) {
|
||||
Util::fromLog (params);
|
||||
}
|
||||
out << endl << params.size() << endl << " " ;
|
||||
for (unsigned j = 0; j < params.size(); j++) {
|
||||
out << params[j] << " " ;
|
||||
@@ -446,7 +477,10 @@ FactorGraph::exportToLibDaiFormat (const char* fileName) const
|
||||
out << factorVars[j]->nrStates() << " " ;
|
||||
}
|
||||
out << endl;
|
||||
const ParamSet& params = facNodes_[i]->factor()->getParameters();
|
||||
Params params = facNodes_[i]->factor()->getParameters();
|
||||
if (Globals::logDomain) {
|
||||
Util::fromLog (params);
|
||||
}
|
||||
out << params.size() << endl;
|
||||
for (unsigned j = 0; j < params.size(); j++) {
|
||||
out << j << " " << params[j] << endl;
|
||||
|
Reference in New Issue
Block a user