new version of bp
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
#include "VarElimSolver.h"
|
||||
#include "ElimGraph.h"
|
||||
#include "Factor.h"
|
||||
#include "Util.h"
|
||||
|
||||
|
||||
VarElimSolver::VarElimSolver (const BayesNet& bn) : Solver (&bn)
|
||||
@@ -30,23 +31,23 @@ VarElimSolver::~VarElimSolver (void)
|
||||
|
||||
|
||||
|
||||
ParamSet
|
||||
Params
|
||||
VarElimSolver::getPosterioriOf (VarId vid)
|
||||
{
|
||||
assert (factorGraph_->getFgVarNode (vid));
|
||||
FgVarNode* vn = factorGraph_->getFgVarNode (vid);
|
||||
assert (vn);
|
||||
if (vn->hasEvidence()) {
|
||||
ParamSet params (vn->nrStates(), 0.0);
|
||||
Params params (vn->nrStates(), 0.0);
|
||||
params[vn->getEvidence()] = 1.0;
|
||||
return params;
|
||||
}
|
||||
return getJointDistributionOf (VarIdSet() = {vid});
|
||||
return getJointDistributionOf (VarIds() = {vid});
|
||||
}
|
||||
|
||||
|
||||
|
||||
ParamSet
|
||||
VarElimSolver::getJointDistributionOf (const VarIdSet& vids)
|
||||
Params
|
||||
VarElimSolver::getJointDistributionOf (const VarIds& vids)
|
||||
{
|
||||
factorList_.clear();
|
||||
varFactors_.clear();
|
||||
@@ -55,10 +56,11 @@ VarElimSolver::getJointDistributionOf (const VarIdSet& vids)
|
||||
introduceEvidence();
|
||||
chooseEliminationOrder (vids);
|
||||
processFactorList (vids);
|
||||
ParamSet params = factorList_.back()->getParameters();
|
||||
factorList_.back()->freeDistribution();
|
||||
Params params = factorList_.back()->getParameters();
|
||||
if (Globals::logDomain) {
|
||||
Util::fromLog (params);
|
||||
}
|
||||
delete factorList_.back();
|
||||
Util::normalize (params);
|
||||
return params;
|
||||
}
|
||||
|
||||
@@ -99,7 +101,7 @@ VarElimSolver::introduceEvidence (void)
|
||||
if (factor->nrVariables() == 1) {
|
||||
factorList_[idxs[j]] = 0;
|
||||
} else {
|
||||
factorList_[idxs[j]]->removeInconsistentEntries (
|
||||
factorList_[idxs[j]]->absorveEvidence (
|
||||
varNodes[i]->varId(), varNodes[i]->getEvidence());
|
||||
}
|
||||
}
|
||||
@@ -110,10 +112,10 @@ VarElimSolver::introduceEvidence (void)
|
||||
|
||||
|
||||
void
|
||||
VarElimSolver::chooseEliminationOrder (const VarIdSet& vids)
|
||||
VarElimSolver::chooseEliminationOrder (const VarIds& vids)
|
||||
{
|
||||
if (bayesNet_) {
|
||||
ElimGraph graph = ElimGraph (*bayesNet_);
|
||||
ElimGraph graph (*bayesNet_);
|
||||
elimOrder_ = graph.getEliminatingOrder (vids);
|
||||
} else {
|
||||
const FgVarSet& varNodes = factorGraph_->getVarNodes();
|
||||
@@ -130,33 +132,31 @@ VarElimSolver::chooseEliminationOrder (const VarIdSet& vids)
|
||||
|
||||
|
||||
void
|
||||
VarElimSolver::processFactorList (const VarIdSet& vids)
|
||||
VarElimSolver::processFactorList (const VarIds& vids)
|
||||
{
|
||||
for (unsigned i = 0; i < elimOrder_.size(); i++) {
|
||||
// cout << "-----------------------------------------" << endl;
|
||||
// cout << "Eliminating " << elimOrder_[i];
|
||||
// cout << " in the following factors:" << endl;
|
||||
// printActiveFactors();
|
||||
eliminate (elimOrder_[i]);
|
||||
}
|
||||
Factor* thisIsTheEnd = new Factor();
|
||||
|
||||
Factor* finalFactor = new Factor();
|
||||
for (unsigned i = 0; i < factorList_.size(); i++) {
|
||||
if (factorList_[i]) {
|
||||
thisIsTheEnd->multiplyByFactor (*factorList_[i]);
|
||||
factorList_[i]->freeDistribution();
|
||||
finalFactor->multiply (*factorList_[i]);
|
||||
delete factorList_[i];
|
||||
factorList_[i] = 0;
|
||||
}
|
||||
}
|
||||
VarIdSet vidsWithoutEvidence;
|
||||
|
||||
VarIds unobservedVids;
|
||||
for (unsigned i = 0; i < vids.size(); i++) {
|
||||
if (factorGraph_->getFgVarNode (vids[i])->hasEvidence() == false) {
|
||||
vidsWithoutEvidence.push_back (vids[i]);
|
||||
unobservedVids.push_back (vids[i]);
|
||||
}
|
||||
}
|
||||
thisIsTheEnd->orderVariables (vidsWithoutEvidence);
|
||||
factorList_.push_back (thisIsTheEnd);
|
||||
|
||||
finalFactor->reorderVariables (unobservedVids);
|
||||
finalFactor->normalize();
|
||||
factorList_.push_back (finalFactor);
|
||||
}
|
||||
|
||||
|
||||
@@ -164,30 +164,25 @@ VarElimSolver::processFactorList (const VarIdSet& vids)
|
||||
void
|
||||
VarElimSolver::eliminate (VarId elimVar)
|
||||
{
|
||||
FgVarNode* vn = factorGraph_->getFgVarNode (elimVar);
|
||||
Factor* result = 0;
|
||||
FgVarNode* vn = factorGraph_->getFgVarNode (elimVar);
|
||||
vector<unsigned>& idxs = varFactors_.find (elimVar)->second;
|
||||
//cout << "eliminating " << setw (5) << elimVar << ":" ;
|
||||
for (unsigned i = 0; i < idxs.size(); i++) {
|
||||
unsigned idx = idxs[i];
|
||||
if (factorList_[idx]) {
|
||||
if (result == 0) {
|
||||
result = new Factor(*factorList_[idx]);
|
||||
//cout << " " << factorList_[idx]->label();
|
||||
} else {
|
||||
result->multiplyByFactor (*factorList_[idx]);
|
||||
//cout << " x " << factorList_[idx]->label();
|
||||
result->multiply (*factorList_[idx]);
|
||||
}
|
||||
factorList_[idx]->freeDistribution();
|
||||
delete factorList_[idx];
|
||||
factorList_[idx] = 0;
|
||||
}
|
||||
}
|
||||
if (result != 0 && result->nrVariables() != 1) {
|
||||
result->removeVariable (vn->varId());
|
||||
result->sumOut (vn->varId());
|
||||
factorList_.push_back (result);
|
||||
// cout << endl <<" factor size=" << result->size() << endl;
|
||||
const VarIdSet& resultVarIds = result->getVarIds();
|
||||
const VarIds& resultVarIds = result->getVarIds();
|
||||
for (unsigned i = 0; i < resultVarIds.size(); i++) {
|
||||
vector<unsigned>& idxs =
|
||||
varFactors_.find (resultVarIds[i])->second;
|
||||
@@ -203,7 +198,7 @@ VarElimSolver::printActiveFactors (void)
|
||||
{
|
||||
for (unsigned i = 0; i < factorList_.size(); i++) {
|
||||
if (factorList_[i] != 0) {
|
||||
factorList_[i]->printFactor();
|
||||
factorList_[i]->print();
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user