This commit is contained in:
Tiago Gomes 2012-05-31 21:24:40 +01:00
parent e11ed1a226
commit 3f0f41c8a9
2 changed files with 13 additions and 15 deletions

View File

@ -14,7 +14,6 @@
BpSolver::BpSolver (const FactorGraph& fg) : Solver (fg) BpSolver::BpSolver (const FactorGraph& fg) : Solver (fg)
{ {
fg_ = &fg;
runned_ = false; runned_ = false;
} }
@ -74,8 +73,8 @@ BpSolver::getPosterioriOf (VarId vid)
if (runned_ == false) { if (runned_ == false) {
runSolver(); runSolver();
} }
assert (fg_->getVarNode (vid)); assert (fg.getVarNode (vid));
VarNode* var = fg_->getVarNode (vid); VarNode* var = fg.getVarNode (vid);
Params probs; Params probs;
if (var->hasEvidence()) { if (var->hasEvidence()) {
probs.resize (var->range(), LogAware::noEvidence()); probs.resize (var->range(), LogAware::noEvidence());
@ -107,7 +106,7 @@ BpSolver::getJointDistributionOf (const VarIds& jointVarIds)
if (runned_ == false) { if (runned_ == false) {
runSolver(); runSolver();
} }
VarNode* vn = fg_->getVarNode (jointVarIds[0]); VarNode* vn = fg.getVarNode (jointVarIds[0]);
const FacNodes& facNodes = vn->neighbors(); const FacNodes& facNodes = vn->neighbors();
size_t idx = facNodes.size(); size_t idx = facNodes.size();
for (size_t i = 0; i < facNodes.size(); i++) { for (size_t i = 0; i < facNodes.size(); i++) {
@ -190,7 +189,7 @@ BpSolver::runSolver (void)
void void
BpSolver::createLinks (void) BpSolver::createLinks (void)
{ {
const FacNodes& facNodes = fg_->facNodes(); const FacNodes& facNodes = fg.facNodes();
for (size_t i = 0; i < facNodes.size(); i++) { for (size_t i = 0; i < facNodes.size(); i++) {
const VarNodes& neighbors = facNodes[i]->neighbors(); const VarNodes& neighbors = facNodes[i]->neighbors();
for (size_t j = 0; j < neighbors.size(); j++) { for (size_t j = 0; j < neighbors.size(); j++) {
@ -366,12 +365,12 @@ BpSolver::getJointByConditioning (const VarIds& jointVarIds) const
{ {
VarNodes jointVars; VarNodes jointVars;
for (size_t i = 0; i < jointVarIds.size(); i++) { for (size_t i = 0; i < jointVarIds.size(); i++) {
assert (fg_->getVarNode (jointVarIds[i])); assert (fg.getVarNode (jointVarIds[i]));
jointVars.push_back (fg_->getVarNode (jointVarIds[i])); jointVars.push_back (fg.getVarNode (jointVarIds[i]));
} }
FactorGraph* fg = new FactorGraph (*fg_); FactorGraph* tempFg = new FactorGraph (fg);
BpSolver solver (*fg); BpSolver solver (*tempFg);
solver.runSolver(); solver.runSolver();
Params prevBeliefs = solver.getPosterioriOf (jointVarIds[0]); Params prevBeliefs = solver.getPosterioriOf (jointVarIds[0]);
@ -383,7 +382,7 @@ BpSolver::getJointByConditioning (const VarIds& jointVarIds) const
Vars observedVars; Vars observedVars;
Ranges observedRanges; Ranges observedRanges;
for (size_t j = 0; j < observedVids.size(); j++) { for (size_t j = 0; j < observedVids.size(); j++) {
observedVars.push_back (fg->getVarNode (observedVids[j])); observedVars.push_back (tempFg->getVarNode (observedVids[j]));
observedRanges.push_back (observedVars.back()->range()); observedRanges.push_back (observedVars.back()->range());
} }
Indexer indexer (observedRanges, false); Indexer indexer (observedRanges, false);
@ -391,7 +390,7 @@ BpSolver::getJointByConditioning (const VarIds& jointVarIds) const
for (size_t j = 0; j < observedVars.size(); j++) { for (size_t j = 0; j < observedVars.size(); j++) {
observedVars[j]->setEvidence (indexer[j]); observedVars[j]->setEvidence (indexer[j]);
} }
BpSolver solver (*fg); BpSolver solver (*tempFg);
solver.runSolver(); solver.runSolver();
Params beliefs = solver.getPosterioriOf (jointVarIds[i]); Params beliefs = solver.getPosterioriOf (jointVarIds[i]);
for (size_t k = 0; k < beliefs.size(); k++) { for (size_t k = 0; k < beliefs.size(); k++) {
@ -418,12 +417,12 @@ BpSolver::getJointByConditioning (const VarIds& jointVarIds) const
void void
BpSolver::initializeSolver (void) BpSolver::initializeSolver (void)
{ {
const VarNodes& varNodes = fg_->varNodes(); const VarNodes& varNodes = fg.varNodes();
varsI_.reserve (varNodes.size()); varsI_.reserve (varNodes.size());
for (size_t i = 0; i < varNodes.size(); i++) { for (size_t i = 0; i < varNodes.size(); i++) {
varsI_.push_back (new SPNodeInfo()); varsI_.push_back (new SPNodeInfo());
} }
const FacNodes& facNodes = fg_->facNodes(); const FacNodes& facNodes = fg.facNodes();
facsI_.reserve (facNodes.size()); facsI_.reserve (facNodes.size());
for (size_t i = 0; i < facNodes.size(); i++) { for (size_t i = 0; i < facNodes.size(); i++) {
facsI_.push_back (new SPNodeInfo()); facsI_.push_back (new SPNodeInfo());

View File

@ -165,7 +165,6 @@ class BpSolver : public Solver
vector<SPNodeInfo*> varsI_; vector<SPNodeInfo*> varsI_;
vector<SPNodeInfo*> facsI_; vector<SPNodeInfo*> facsI_;
bool runned_; bool runned_;
const FactorGraph* fg_;
typedef multiset<BpLink*, CompareResidual> SortedOrder; typedef multiset<BpLink*, CompareResidual> SortedOrder;
SortedOrder sortedOrder_; SortedOrder sortedOrder_;