new version of belief propagation solver.

This commit is contained in:
Vitor Santos Costa
2011-07-22 21:33:30 +01:00
parent a16a7d5b1c
commit 69e5fed10f
41 changed files with 3804 additions and 2238 deletions

View File

@@ -1,17 +1,19 @@
#include <iostream>
#include <cstdlib>
#include <iostream>
#include <sstream>
#include "BayesNet.h"
#include "BPSolver.h"
#include "FactorGraph.h"
#include "SPSolver.h"
#include "BPSolver.h"
#include "CountingBP.h"
using namespace std;
void BayesianNetwork (int, const char* []);
void markovNetwork (int, const char* []);
void runSolver (Solver*, const VarSet&);
const string USAGE = "usage: \
./hcli FILE [VARIABLE | OBSERVED_VARIABLE=EVIDENCE]..." ;
@@ -19,14 +21,40 @@ const string USAGE = "usage: \
int
main (int argc, const char* argv[])
{
{
/*
FactorGraph fg;
FgVarNode* varNode1 = new FgVarNode (0, 2);
FgVarNode* varNode2 = new FgVarNode (1, 2);
FgVarNode* varNode3 = new FgVarNode (2, 2);
fg.addVariable (varNode1);
fg.addVariable (varNode2);
fg.addVariable (varNode3);
Distribution* dist = new Distribution (ParamSet() = {1.2, 1.4, 2.0, 0.4});
fg.addFactor (new Factor (FgVarSet() = {varNode1, varNode2}, dist));
fg.addFactor (new Factor (FgVarSet() = {varNode3, varNode2}, dist));
//fg.printGraphicalModel();
//SPSolver sp (fg);
//sp.runSolver();
//sp.printAllPosterioris();
//ParamSet p = sp.getJointDistributionOf (VidSet() = {0, 1, 2});
//cout << Util::parametersToString (p) << endl;
CountingBP cbp (fg);
//cbp.runSolver();
//cbp.printAllPosterioris();
ParamSet p2 = cbp.getJointDistributionOf (VidSet() = {0, 1, 2});
cout << Util::parametersToString (p2) << endl;
fg.freeDistributions();
Statistics::printCompressingStats ("compressing.stats");
return 0;
*/
if (!argv[1]) {
cerr << "error: no graphical model specified" << endl;
cerr << USAGE << endl;
exit (0);
}
string fileName = argv[1];
string extension = fileName.substr (fileName.find_last_of ('.') + 1);
const string& fileName = argv[1];
const string& extension = fileName.substr (fileName.find_last_of ('.') + 1);
if (extension == "xml") {
BayesianNetwork (argc, argv);
} else if (extension == "uai") {
@@ -45,13 +73,13 @@ void
BayesianNetwork (int argc, const char* argv[])
{
BayesNet bn (argv[1]);
//bn.printNetwork();
//bn.printGraphicalModel();
NodeSet queryVars;
VarSet queryVars;
for (int i = 2; i < argc; i++) {
string arg = argv[i];
const string& arg = argv[i];
if (arg.find ('=') == std::string::npos) {
BayesNode* queryVar = bn.getNode (arg);
BayesNode* queryVar = bn.getBayesNode (arg);
if (queryVar) {
queryVars.push_back (queryVar);
} else {
@@ -61,9 +89,9 @@ BayesianNetwork (int argc, const char* argv[])
exit (0);
}
} else {
size_t pos = arg.find ('=');
string label = arg.substr (0, pos);
string state = arg.substr (pos + 1);
size_t pos = arg.find ('=');
const string& label = arg.substr (0, pos);
const string& state = arg.substr (pos + 1);
if (label.empty()) {
cerr << "error: missing left argument" << endl;
cerr << USAGE << endl;
@@ -74,7 +102,7 @@ BayesianNetwork (int argc, const char* argv[])
cerr << USAGE << endl;
exit (0);
}
BayesNode* node = bn.getNode (label);
BayesNode* node = bn.getBayesNode (label);
if (node) {
if (node->isValidState (state)) {
node->setEvidence (state);
@@ -94,19 +122,16 @@ BayesianNetwork (int argc, const char* argv[])
}
}
BPSolver solver (bn);
if (queryVars.size() == 0) {
solver.runSolver();
solver.printAllPosterioris();
} else if (queryVars.size() == 1) {
solver.runSolver();
solver.printPosterioriOf (queryVars[0]);
Solver* solver;
if (SolverOptions::convertBn2Fg) {
FactorGraph* fg = new FactorGraph (bn);
fg->printGraphicalModel();
solver = new SPSolver (*fg);
runSolver (solver, queryVars);
delete fg;
} else {
Domain domain = BayesNet::getInstantiations(queryVars);
ParamSet params = solver.getJointDistribution (queryVars);
for (unsigned i = 0; i < params.size(); i++) {
cout << domain[i] << "\t" << params[i] << endl;
}
solver = new BPSolver (bn);
runSolver (solver, queryVars);
}
bn.freeDistributions();
}
@@ -117,11 +142,11 @@ void
markovNetwork (int argc, const char* argv[])
{
FactorGraph fg (argv[1]);
//fg.printFactorGraph();
//fg.printGraphicalModel();
VarSet queryVars;
for (int i = 2; i < argc; i++) {
string arg = argv[i];
const string& arg = argv[i];
if (arg.find ('=') == std::string::npos) {
if (!Util::isInteger (arg)) {
cerr << "error: `" << arg << "' " ;
@@ -129,16 +154,16 @@ markovNetwork (int argc, const char* argv[])
cerr << endl;
exit (0);
}
unsigned varId;
Vid vid;
stringstream ss;
ss << arg;
ss >> varId;
Variable* queryVar = fg.getVariableById (varId);
ss >> vid;
Variable* queryVar = fg.getFgVarNode (vid);
if (queryVar) {
queryVars.push_back (queryVar);
} else {
cerr << "error: there isn't a variable with " ;
cerr << "`" << varId << "' as id" ;
cerr << "`" << vid << "' as id" ;
cerr << endl;
exit (0);
}
@@ -160,11 +185,11 @@ markovNetwork (int argc, const char* argv[])
cerr << endl;
exit (0);
}
unsigned varId;
Vid vid;
stringstream ss;
ss << arg.substr (0, pos);
ss >> varId;
Variable* var = fg.getVariableById (varId);
ss >> vid;
Variable* var = fg.getFgVarNode (vid);
if (var) {
if (!Util::isInteger (arg.substr (pos + 1))) {
cerr << "error: `" << arg.substr (pos + 1) << "' " ;
@@ -176,7 +201,6 @@ markovNetwork (int argc, const char* argv[])
stringstream ss;
ss << arg.substr (pos + 1);
ss >> stateIndex;
cout << "si: " << stateIndex << endl;
if (var->isValidStateIndex (stateIndex)) {
var->setEvidence (stateIndex);
} else {
@@ -188,27 +212,35 @@ markovNetwork (int argc, const char* argv[])
}
} else {
cerr << "error: there isn't a variable with " ;
cerr << "`" << varId << "' as id" ;
cerr << "`" << vid << "' as id" ;
cerr << endl;
exit (0);
}
}
}
SPSolver solver (fg);
if (queryVars.size() == 0) {
solver.runSolver();
solver.printAllPosterioris();
} else if (queryVars.size() == 1) {
solver.runSolver();
solver.printPosterioriOf (queryVars[0]);
} else {
assert (false); //FIXME
//Domain domain = BayesNet::getInstantiations(queryVars);
//ParamSet params = solver.getJointDistribution (queryVars);
//for (unsigned i = 0; i < params.size(); i++) {
// cout << domain[i] << "\t" << params[i] << endl;
//}
}
Solver* solver = new SPSolver (fg);
runSolver (solver, queryVars);
fg.freeDistributions();
}
void
runSolver (Solver* solver, const VarSet& queryVars)
{
VidSet vids;
for (unsigned i = 0; i < queryVars.size(); i++) {
vids.push_back (queryVars[i]->getVarId());
}
if (queryVars.size() == 0) {
solver->runSolver();
solver->printAllPosterioris();
} else if (queryVars.size() == 1) {
solver->runSolver();
solver->printPosterioriOf (vids[0]);
} else {
solver->printJointDistributionOf (vids);
}
delete solver;
}