2011-05-17 12:00:33 +01:00
|
|
|
#ifndef BP_SOLVER_H
|
|
|
|
#define BP_SOLVER_H
|
|
|
|
|
|
|
|
#include <iomanip>
|
|
|
|
|
|
|
|
#include "GraphicalModel.h"
|
|
|
|
#include "Variable.h"
|
|
|
|
|
|
|
|
using namespace std;
|
|
|
|
|
|
|
|
class Solver
|
|
|
|
{
|
|
|
|
public:
|
|
|
|
Solver (const GraphicalModel* gm)
|
|
|
|
{
|
|
|
|
gm_ = gm;
|
|
|
|
}
|
2011-07-22 21:33:30 +01:00
|
|
|
virtual ~Solver() {} // to call subclass destructor
|
2011-05-17 12:00:33 +01:00
|
|
|
virtual void runSolver (void) = 0;
|
2011-07-22 21:33:30 +01:00
|
|
|
virtual ParamSet getPosterioriOf (Vid) const = 0;
|
|
|
|
virtual ParamSet getJointDistributionOf (const VidSet&) = 0;
|
2011-05-17 12:00:33 +01:00
|
|
|
|
2011-07-22 21:33:30 +01:00
|
|
|
void printAllPosterioris (void) const
|
2011-05-17 12:00:33 +01:00
|
|
|
{
|
2011-07-22 21:33:30 +01:00
|
|
|
VarSet vars = gm_->getVariables();
|
|
|
|
for (unsigned i = 0; i < vars.size(); i++) {
|
|
|
|
printPosterioriOf (vars[i]->getVarId());
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
void printPosterioriOf (Vid vid) const
|
|
|
|
{
|
|
|
|
Variable* var = gm_->getVariable (vid);
|
2011-05-17 12:00:33 +01:00
|
|
|
cout << endl;
|
|
|
|
cout << setw (20) << left << var->getLabel() << "posteriori" ;
|
|
|
|
cout << endl;
|
|
|
|
cout << "------------------------------" ;
|
|
|
|
cout << endl;
|
|
|
|
const Domain& domain = var->getDomain();
|
2011-07-22 21:33:30 +01:00
|
|
|
ParamSet results = getPosterioriOf (vid);
|
|
|
|
for (unsigned xi = 0; xi < var->getDomainSize(); xi++) {
|
2011-05-17 12:00:33 +01:00
|
|
|
cout << setw (20) << domain[xi];
|
|
|
|
cout << setprecision (PRECISION) << results[xi];
|
|
|
|
cout << endl;
|
|
|
|
}
|
|
|
|
cout << endl;
|
|
|
|
}
|
|
|
|
|
2011-07-22 21:33:30 +01:00
|
|
|
void printJointDistributionOf (const VidSet& vids)
|
2011-05-17 12:00:33 +01:00
|
|
|
{
|
2011-07-22 21:33:30 +01:00
|
|
|
const ParamSet& jointDist = getJointDistributionOf (vids);
|
|
|
|
cout << endl;
|
|
|
|
cout << "joint distribution of " ;
|
|
|
|
VarSet vars;
|
|
|
|
for (unsigned i = 0; i < vids.size() - 1; i++) {
|
|
|
|
Variable* var = gm_->getVariable (vids[i]);
|
|
|
|
cout << var->getLabel() << ", " ;
|
|
|
|
vars.push_back (var);
|
2011-05-17 12:00:33 +01:00
|
|
|
}
|
2011-07-22 21:33:30 +01:00
|
|
|
Variable* var = gm_->getVariable (vids[vids.size() - 1]);
|
|
|
|
cout << var->getLabel() ;
|
|
|
|
vars.push_back (var);
|
|
|
|
cout << endl;
|
|
|
|
cout << "------------------------------" ;
|
|
|
|
cout << endl;
|
|
|
|
const vector<string>& domainConfs = Util::getInstantiations (vars);
|
|
|
|
for (unsigned i = 0; i < jointDist.size(); i++) {
|
|
|
|
cout << left << setw (20) << domainConfs[i];
|
|
|
|
cout << setprecision (PRECISION) << jointDist[i];
|
|
|
|
cout << endl;
|
|
|
|
}
|
|
|
|
cout << endl;
|
2011-05-17 12:00:33 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
private:
|
2011-07-22 21:33:30 +01:00
|
|
|
const GraphicalModel* gm_;
|
2011-05-17 12:00:33 +01:00
|
|
|
};
|
|
|
|
|
2011-07-22 21:33:30 +01:00
|
|
|
#endif //BP_SOLVER_H
|
|
|
|
|