cleanup
This commit is contained in:
parent
e11ed1a226
commit
3f0f41c8a9
|
@ -14,7 +14,6 @@
|
||||||
|
|
||||||
BpSolver::BpSolver (const FactorGraph& fg) : Solver (fg)
|
BpSolver::BpSolver (const FactorGraph& fg) : Solver (fg)
|
||||||
{
|
{
|
||||||
fg_ = &fg;
|
|
||||||
runned_ = false;
|
runned_ = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -74,8 +73,8 @@ BpSolver::getPosterioriOf (VarId vid)
|
||||||
if (runned_ == false) {
|
if (runned_ == false) {
|
||||||
runSolver();
|
runSolver();
|
||||||
}
|
}
|
||||||
assert (fg_->getVarNode (vid));
|
assert (fg.getVarNode (vid));
|
||||||
VarNode* var = fg_->getVarNode (vid);
|
VarNode* var = fg.getVarNode (vid);
|
||||||
Params probs;
|
Params probs;
|
||||||
if (var->hasEvidence()) {
|
if (var->hasEvidence()) {
|
||||||
probs.resize (var->range(), LogAware::noEvidence());
|
probs.resize (var->range(), LogAware::noEvidence());
|
||||||
|
@ -107,7 +106,7 @@ BpSolver::getJointDistributionOf (const VarIds& jointVarIds)
|
||||||
if (runned_ == false) {
|
if (runned_ == false) {
|
||||||
runSolver();
|
runSolver();
|
||||||
}
|
}
|
||||||
VarNode* vn = fg_->getVarNode (jointVarIds[0]);
|
VarNode* vn = fg.getVarNode (jointVarIds[0]);
|
||||||
const FacNodes& facNodes = vn->neighbors();
|
const FacNodes& facNodes = vn->neighbors();
|
||||||
size_t idx = facNodes.size();
|
size_t idx = facNodes.size();
|
||||||
for (size_t i = 0; i < facNodes.size(); i++) {
|
for (size_t i = 0; i < facNodes.size(); i++) {
|
||||||
|
@ -190,7 +189,7 @@ BpSolver::runSolver (void)
|
||||||
void
|
void
|
||||||
BpSolver::createLinks (void)
|
BpSolver::createLinks (void)
|
||||||
{
|
{
|
||||||
const FacNodes& facNodes = fg_->facNodes();
|
const FacNodes& facNodes = fg.facNodes();
|
||||||
for (size_t i = 0; i < facNodes.size(); i++) {
|
for (size_t i = 0; i < facNodes.size(); i++) {
|
||||||
const VarNodes& neighbors = facNodes[i]->neighbors();
|
const VarNodes& neighbors = facNodes[i]->neighbors();
|
||||||
for (size_t j = 0; j < neighbors.size(); j++) {
|
for (size_t j = 0; j < neighbors.size(); j++) {
|
||||||
|
@ -366,12 +365,12 @@ BpSolver::getJointByConditioning (const VarIds& jointVarIds) const
|
||||||
{
|
{
|
||||||
VarNodes jointVars;
|
VarNodes jointVars;
|
||||||
for (size_t i = 0; i < jointVarIds.size(); i++) {
|
for (size_t i = 0; i < jointVarIds.size(); i++) {
|
||||||
assert (fg_->getVarNode (jointVarIds[i]));
|
assert (fg.getVarNode (jointVarIds[i]));
|
||||||
jointVars.push_back (fg_->getVarNode (jointVarIds[i]));
|
jointVars.push_back (fg.getVarNode (jointVarIds[i]));
|
||||||
}
|
}
|
||||||
|
|
||||||
FactorGraph* fg = new FactorGraph (*fg_);
|
FactorGraph* tempFg = new FactorGraph (fg);
|
||||||
BpSolver solver (*fg);
|
BpSolver solver (*tempFg);
|
||||||
solver.runSolver();
|
solver.runSolver();
|
||||||
Params prevBeliefs = solver.getPosterioriOf (jointVarIds[0]);
|
Params prevBeliefs = solver.getPosterioriOf (jointVarIds[0]);
|
||||||
|
|
||||||
|
@ -383,7 +382,7 @@ BpSolver::getJointByConditioning (const VarIds& jointVarIds) const
|
||||||
Vars observedVars;
|
Vars observedVars;
|
||||||
Ranges observedRanges;
|
Ranges observedRanges;
|
||||||
for (size_t j = 0; j < observedVids.size(); j++) {
|
for (size_t j = 0; j < observedVids.size(); j++) {
|
||||||
observedVars.push_back (fg->getVarNode (observedVids[j]));
|
observedVars.push_back (tempFg->getVarNode (observedVids[j]));
|
||||||
observedRanges.push_back (observedVars.back()->range());
|
observedRanges.push_back (observedVars.back()->range());
|
||||||
}
|
}
|
||||||
Indexer indexer (observedRanges, false);
|
Indexer indexer (observedRanges, false);
|
||||||
|
@ -391,7 +390,7 @@ BpSolver::getJointByConditioning (const VarIds& jointVarIds) const
|
||||||
for (size_t j = 0; j < observedVars.size(); j++) {
|
for (size_t j = 0; j < observedVars.size(); j++) {
|
||||||
observedVars[j]->setEvidence (indexer[j]);
|
observedVars[j]->setEvidence (indexer[j]);
|
||||||
}
|
}
|
||||||
BpSolver solver (*fg);
|
BpSolver solver (*tempFg);
|
||||||
solver.runSolver();
|
solver.runSolver();
|
||||||
Params beliefs = solver.getPosterioriOf (jointVarIds[i]);
|
Params beliefs = solver.getPosterioriOf (jointVarIds[i]);
|
||||||
for (size_t k = 0; k < beliefs.size(); k++) {
|
for (size_t k = 0; k < beliefs.size(); k++) {
|
||||||
|
@ -418,12 +417,12 @@ BpSolver::getJointByConditioning (const VarIds& jointVarIds) const
|
||||||
void
|
void
|
||||||
BpSolver::initializeSolver (void)
|
BpSolver::initializeSolver (void)
|
||||||
{
|
{
|
||||||
const VarNodes& varNodes = fg_->varNodes();
|
const VarNodes& varNodes = fg.varNodes();
|
||||||
varsI_.reserve (varNodes.size());
|
varsI_.reserve (varNodes.size());
|
||||||
for (size_t i = 0; i < varNodes.size(); i++) {
|
for (size_t i = 0; i < varNodes.size(); i++) {
|
||||||
varsI_.push_back (new SPNodeInfo());
|
varsI_.push_back (new SPNodeInfo());
|
||||||
}
|
}
|
||||||
const FacNodes& facNodes = fg_->facNodes();
|
const FacNodes& facNodes = fg.facNodes();
|
||||||
facsI_.reserve (facNodes.size());
|
facsI_.reserve (facNodes.size());
|
||||||
for (size_t i = 0; i < facNodes.size(); i++) {
|
for (size_t i = 0; i < facNodes.size(); i++) {
|
||||||
facsI_.push_back (new SPNodeInfo());
|
facsI_.push_back (new SPNodeInfo());
|
||||||
|
|
|
@ -165,7 +165,6 @@ class BpSolver : public Solver
|
||||||
vector<SPNodeInfo*> varsI_;
|
vector<SPNodeInfo*> varsI_;
|
||||||
vector<SPNodeInfo*> facsI_;
|
vector<SPNodeInfo*> facsI_;
|
||||||
bool runned_;
|
bool runned_;
|
||||||
const FactorGraph* fg_;
|
|
||||||
|
|
||||||
typedef multiset<BpLink*, CompareResidual> SortedOrder;
|
typedef multiset<BpLink*, CompareResidual> SortedOrder;
|
||||||
SortedOrder sortedOrder_;
|
SortedOrder sortedOrder_;
|
||||||
|
|
Reference in New Issue