This commit is contained in:
Tiago Gomes 2012-05-31 21:24:40 +01:00
parent e11ed1a226
commit 3f0f41c8a9
2 changed files with 13 additions and 15 deletions

View File

@ -14,7 +14,6 @@
BpSolver::BpSolver (const FactorGraph& fg) : Solver (fg)
{
fg_ = &fg;
runned_ = false;
}
@ -74,8 +73,8 @@ BpSolver::getPosterioriOf (VarId vid)
if (runned_ == false) {
runSolver();
}
assert (fg_->getVarNode (vid));
VarNode* var = fg_->getVarNode (vid);
assert (fg.getVarNode (vid));
VarNode* var = fg.getVarNode (vid);
Params probs;
if (var->hasEvidence()) {
probs.resize (var->range(), LogAware::noEvidence());
@ -107,7 +106,7 @@ BpSolver::getJointDistributionOf (const VarIds& jointVarIds)
if (runned_ == false) {
runSolver();
}
VarNode* vn = fg_->getVarNode (jointVarIds[0]);
VarNode* vn = fg.getVarNode (jointVarIds[0]);
const FacNodes& facNodes = vn->neighbors();
size_t idx = facNodes.size();
for (size_t i = 0; i < facNodes.size(); i++) {
@ -190,7 +189,7 @@ BpSolver::runSolver (void)
void
BpSolver::createLinks (void)
{
const FacNodes& facNodes = fg_->facNodes();
const FacNodes& facNodes = fg.facNodes();
for (size_t i = 0; i < facNodes.size(); i++) {
const VarNodes& neighbors = facNodes[i]->neighbors();
for (size_t j = 0; j < neighbors.size(); j++) {
@ -366,12 +365,12 @@ BpSolver::getJointByConditioning (const VarIds& jointVarIds) const
{
VarNodes jointVars;
for (size_t i = 0; i < jointVarIds.size(); i++) {
assert (fg_->getVarNode (jointVarIds[i]));
jointVars.push_back (fg_->getVarNode (jointVarIds[i]));
assert (fg.getVarNode (jointVarIds[i]));
jointVars.push_back (fg.getVarNode (jointVarIds[i]));
}
FactorGraph* fg = new FactorGraph (*fg_);
BpSolver solver (*fg);
FactorGraph* tempFg = new FactorGraph (fg);
BpSolver solver (*tempFg);
solver.runSolver();
Params prevBeliefs = solver.getPosterioriOf (jointVarIds[0]);
@ -383,7 +382,7 @@ BpSolver::getJointByConditioning (const VarIds& jointVarIds) const
Vars observedVars;
Ranges observedRanges;
for (size_t j = 0; j < observedVids.size(); j++) {
observedVars.push_back (fg->getVarNode (observedVids[j]));
observedVars.push_back (tempFg->getVarNode (observedVids[j]));
observedRanges.push_back (observedVars.back()->range());
}
Indexer indexer (observedRanges, false);
@ -391,7 +390,7 @@ BpSolver::getJointByConditioning (const VarIds& jointVarIds) const
for (size_t j = 0; j < observedVars.size(); j++) {
observedVars[j]->setEvidence (indexer[j]);
}
BpSolver solver (*fg);
BpSolver solver (*tempFg);
solver.runSolver();
Params beliefs = solver.getPosterioriOf (jointVarIds[i]);
for (size_t k = 0; k < beliefs.size(); k++) {
@ -418,12 +417,12 @@ BpSolver::getJointByConditioning (const VarIds& jointVarIds) const
void
BpSolver::initializeSolver (void)
{
const VarNodes& varNodes = fg_->varNodes();
const VarNodes& varNodes = fg.varNodes();
varsI_.reserve (varNodes.size());
for (size_t i = 0; i < varNodes.size(); i++) {
varsI_.push_back (new SPNodeInfo());
}
const FacNodes& facNodes = fg_->facNodes();
const FacNodes& facNodes = fg.facNodes();
facsI_.reserve (facNodes.size());
for (size_t i = 0; i < facNodes.size(); i++) {
facsI_.push_back (new SPNodeInfo());

View File

@ -160,12 +160,11 @@ class BpSolver : public Solver
}
};
BpLinks links_;
BpLinks links_;
unsigned nIters_;
vector<SPNodeInfo*> varsI_;
vector<SPNodeInfo*> facsI_;
bool runned_;
const FactorGraph* fg_;
typedef multiset<BpLink*, CompareResidual> SortedOrder;
SortedOrder sortedOrder_;