Ouchgit statusgit status! forgot to add these to index
This commit is contained in:
parent
4522850cd6
commit
51fd48cd46
107
packages/CLPBN/horus/GroundSolver.cpp
Normal file
107
packages/CLPBN/horus/GroundSolver.cpp
Normal file
@ -0,0 +1,107 @@
|
||||
#include "GroundSolver.h"
|
||||
#include "Util.h"
|
||||
#include "BeliefProp.h"
|
||||
#include "CountingBp.h"
|
||||
#include "VarElim.h"
|
||||
|
||||
|
||||
void
|
||||
GroundSolver::printAnswer (const VarIds& vids)
|
||||
{
|
||||
Vars unobservedVars;
|
||||
VarIds unobservedVids;
|
||||
for (size_t i = 0; i < vids.size(); i++) {
|
||||
VarNode* vn = fg.getVarNode (vids[i]);
|
||||
if (vn->hasEvidence() == false) {
|
||||
unobservedVars.push_back (vn);
|
||||
unobservedVids.push_back (vids[i]);
|
||||
}
|
||||
}
|
||||
if (unobservedVids.empty() == false) {
|
||||
Params res = solveQuery (unobservedVids);
|
||||
vector<string> stateLines = Util::getStateLines (unobservedVars);
|
||||
for (size_t i = 0; i < res.size(); i++) {
|
||||
cout << "P(" << stateLines[i] << ") = " ;
|
||||
cout << std::setprecision (Constants::PRECISION) << res[i];
|
||||
cout << endl;
|
||||
}
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
GroundSolver::printAllPosterioris (void)
|
||||
{
|
||||
VarNodes vars = fg.varNodes();
|
||||
std::sort (vars.begin(), vars.end(), sortByVarId());
|
||||
for (size_t i = 0; i < vars.size(); i++) {
|
||||
printAnswer ({vars[i]->varId()});
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
Params
|
||||
GroundSolver::getJointByConditioning (
|
||||
GroundSolverType solverType,
|
||||
FactorGraph fg,
|
||||
const VarIds& jointVarIds) const
|
||||
{
|
||||
VarNodes jointVars;
|
||||
for (size_t i = 0; i < jointVarIds.size(); i++) {
|
||||
assert (fg.getVarNode (jointVarIds[i]));
|
||||
jointVars.push_back (fg.getVarNode (jointVarIds[i]));
|
||||
}
|
||||
|
||||
GroundSolver* solver = 0;
|
||||
switch (solverType) {
|
||||
case GroundSolverType::BP: solver = new BeliefProp (fg); break;
|
||||
case GroundSolverType::CBP: solver = new CountingBp (fg); break;
|
||||
case GroundSolverType::VE: solver = new VarElim (fg); break;
|
||||
}
|
||||
Params prevBeliefs = solver->solveQuery ({jointVarIds[0]});
|
||||
VarIds observedVids = {jointVars[0]->varId()};
|
||||
|
||||
for (size_t i = 1; i < jointVarIds.size(); i++) {
|
||||
assert (jointVars[i]->hasEvidence() == false);
|
||||
Params newBeliefs;
|
||||
Vars observedVars;
|
||||
Ranges observedRanges;
|
||||
for (size_t j = 0; j < observedVids.size(); j++) {
|
||||
observedVars.push_back (fg.getVarNode (observedVids[j]));
|
||||
observedRanges.push_back (observedVars.back()->range());
|
||||
}
|
||||
Indexer indexer (observedRanges, false);
|
||||
while (indexer.valid()) {
|
||||
for (size_t j = 0; j < observedVars.size(); j++) {
|
||||
observedVars[j]->setEvidence (indexer[j]);
|
||||
}
|
||||
delete solver;
|
||||
switch (solverType) {
|
||||
case GroundSolverType::BP: solver = new BeliefProp (fg); break;
|
||||
case GroundSolverType::CBP: solver = new CountingBp (fg); break;
|
||||
case GroundSolverType::VE: solver = new VarElim (fg); break;
|
||||
}
|
||||
Params beliefs = solver->solveQuery ({jointVarIds[i]});
|
||||
for (size_t k = 0; k < beliefs.size(); k++) {
|
||||
newBeliefs.push_back (beliefs[k]);
|
||||
}
|
||||
++ indexer;
|
||||
}
|
||||
|
||||
int count = -1;
|
||||
for (size_t j = 0; j < newBeliefs.size(); j++) {
|
||||
if (j % jointVars[i]->range() == 0) {
|
||||
count ++;
|
||||
}
|
||||
newBeliefs[j] *= prevBeliefs[count];
|
||||
}
|
||||
prevBeliefs = newBeliefs;
|
||||
observedVids.push_back (jointVars[i]->varId());
|
||||
}
|
||||
delete solver;
|
||||
return prevBeliefs;
|
||||
}
|
||||
|
36
packages/CLPBN/horus/GroundSolver.h
Normal file
36
packages/CLPBN/horus/GroundSolver.h
Normal file
@ -0,0 +1,36 @@
|
||||
#ifndef HORUS_GROUNDSOLVER_H
|
||||
#define HORUS_GROUNDSOLVER_H
|
||||
|
||||
#include <iomanip>
|
||||
|
||||
#include "FactorGraph.h"
|
||||
#include "Var.h"
|
||||
#include "Horus.h"
|
||||
|
||||
|
||||
using namespace std;
|
||||
|
||||
class GroundSolver
|
||||
{
|
||||
public:
|
||||
GroundSolver (const FactorGraph& factorGraph) : fg(factorGraph) { }
|
||||
|
||||
virtual ~GroundSolver() { } // ensure that subclass destructor is called
|
||||
|
||||
virtual Params solveQuery (VarIds queryVids) = 0;
|
||||
|
||||
virtual void printSolverFlags (void) const = 0;
|
||||
|
||||
void printAnswer (const VarIds& vids);
|
||||
|
||||
void printAllPosterioris (void);
|
||||
|
||||
Params getJointByConditioning (GroundSolverType,
|
||||
FactorGraph, const VarIds& jointVarIds) const;
|
||||
|
||||
protected:
|
||||
const FactorGraph& fg;
|
||||
};
|
||||
|
||||
#endif // HORUS_GROUNDSOLVER_H
|
||||
|
27
packages/CLPBN/horus/LiftedSolver.h
Normal file
27
packages/CLPBN/horus/LiftedSolver.h
Normal file
@ -0,0 +1,27 @@
|
||||
#ifndef HORUS_LIFTEDSOLVER_H
|
||||
#define HORUS_LIFTEDSOLVER_H
|
||||
|
||||
#include "ParfactorList.h"
|
||||
#include "Horus.h"
|
||||
|
||||
|
||||
using namespace std;
|
||||
|
||||
class LiftedSolver
|
||||
{
|
||||
public:
|
||||
LiftedSolver (const ParfactorList& pfList)
|
||||
: parfactorList(pfList) { }
|
||||
|
||||
virtual ~LiftedSolver() { } // ensure that subclass destructor is called
|
||||
|
||||
virtual Params solveQuery (const Grounds& query) = 0;
|
||||
|
||||
virtual void printSolverFlags (void) const = 0;
|
||||
|
||||
protected:
|
||||
const ParfactorList& parfactorList;
|
||||
};
|
||||
|
||||
#endif // HORUS_LIFTEDSOLVER_H
|
||||
|
Reference in New Issue
Block a user