This repository has been archived on 2023-08-20. You can view files and clone it, but cannot push or open issues or pull requests.
yap-6.3/packages/CLPBN/clpbn/bp/VarElimSolver.cpp

207 lines
4.7 KiB
C++
Raw Normal View History

2011-12-12 15:29:51 +00:00
#include <algorithm>
#include "VarElimSolver.h"
#include "ElimGraph.h"
#include "Factor.h"
2012-03-22 11:33:24 +00:00
#include "Util.h"
2011-12-12 15:29:51 +00:00
/**
 * Builds a solver for a Bayesian network.
 *
 * A factor graph is derived from the network and allocated here; the
 * solver owns it and the destructor frees it (bayesNet_ is non-null).
 */
VarElimSolver::VarElimSolver (const BayesNet& bn) : Solver (&bn)
{
  const BayesNet* network = &bn;
  bayesNet_ = network;
  factorGraph_ = new FactorGraph (*network);
}
/**
 * Builds a solver directly over an existing factor graph.
 *
 * The graph is borrowed, not copied; bayesNet_ is left null so the
 * destructor does not free it.
 */
VarElimSolver::VarElimSolver (const FactorGraph& fg) : Solver (&fg)
{
  factorGraph_ = &fg;
  bayesNet_    = 0;
}
/**
 * Releases the factor graph only when this solver allocated it.
 *
 * Only the BayesNet constructor allocates factorGraph_; a graph passed to
 * the FactorGraph constructor belongs to the caller.
 */
VarElimSolver::~VarElimSolver (void)
{
  const bool ownsFactorGraph = (bayesNet_ != 0);
  if (ownsFactorGraph) {
    delete factorGraph_;
  }
}
2012-03-22 11:33:24 +00:00
/**
 * Returns the marginal distribution of a single variable.
 *
 * If the variable has evidence, the result is 1.0 on the observed state and
 * 0.0 elsewhere. Otherwise it is computed as a one-variable joint
 * distribution via getJointDistributionOf.
 *
 * @param vid  identifier of the query variable (must exist in the graph)
 * @return     one entry per state of vid
 */
Params
VarElimSolver::getPosterioriOf (VarId vid)
{
  // Look the node up once and check it before dereferencing.
  FgVarNode* vn = factorGraph_->getFgVarNode (vid);
  assert (vn);
  if (vn->hasEvidence()) {
    Params params (vn->nrStates(), 0.0);
    params[vn->getEvidence()] = 1.0;
    return params;
  }
  // Construct the single-element query directly instead of assigning an
  // initializer list to a temporary (`VarIds() = {vid}`).
  return getJointDistributionOf (VarIds (1, vid));
}
2012-03-22 11:33:24 +00:00
/**
 * Computes the joint distribution of the query variables by variable
 * elimination.
 *
 * Each call rebuilds the working state from the factor graph:
 *  1. copies the factors;
 *  2. applies evidence;
 *  3. chooses and runs an elimination order;
 *  4. returns the parameters of the final normalized factor.
 * If Globals::logDomain is set, the result is converted out of log space.
 *
 * @param vids  query variables; only the unobserved ones are used to order
 *              the result factor
 * @return      parameters of the final factor
 */
Params
VarElimSolver::getJointDistributionOf (const VarIds& vids)
{
  factorList_.clear();
  varFactors_.clear();
  elimOrder_.clear();
  createFactorList();
  introduceEvidence();
  chooseEliminationOrder (vids);
  processFactorList (vids);
  // processFactorList clears every other slot and appends the final
  // product at the back.
  Factor* finalFactor = factorList_.back();
  Params params = finalFactor->params();
  if (Globals::logDomain) {
    Util::fromLog (params);
  }
  // Remove the slot as well as the object, so factorList_ never keeps a
  // dangling pointer (e.g. for a later printActiveFactors call).
  delete finalFactor;
  factorList_.pop_back();
  return params;
}
void
VarElimSolver::createFactorList (void)
{
const FgFacSet& factorNodes = factorGraph_->getFactorNodes();
factorList_.reserve (factorNodes.size() * 2);
for (unsigned i = 0; i < factorNodes.size(); i++) {
factorList_.push_back (new Factor (*factorNodes[i]->factor()));
const FgVarSet& neighs = factorNodes[i]->neighbors();
for (unsigned j = 0; j < neighs.size(); j++) {
unordered_map<VarId,vector<unsigned> >::iterator it
= varFactors_.find (neighs[j]->varId());
if (it == varFactors_.end()) {
it = varFactors_.insert (make_pair (
neighs[j]->varId(), vector<unsigned>())).first;
}
it->second.push_back (i);
}
}
}
void
VarElimSolver::introduceEvidence (void)
{
const FgVarSet& varNodes = factorGraph_->getVarNodes();
for (unsigned i = 0; i < varNodes.size(); i++) {
if (varNodes[i]->hasEvidence()) {
const vector<unsigned>& idxs =
varFactors_.find (varNodes[i]->varId())->second;
for (unsigned j = 0; j < idxs.size(); j++) {
Factor* factor = factorList_[idxs[j]];
2012-03-31 23:27:37 +01:00
if (factor->nrArguments() == 1) {
2011-12-12 15:29:51 +00:00
factorList_[idxs[j]] = 0;
} else {
2012-03-22 11:33:24 +00:00
factorList_[idxs[j]]->absorveEvidence (
2011-12-12 15:29:51 +00:00
varNodes[i]->varId(), varNodes[i]->getEvidence());
}
}
}
}
}
void
2012-03-22 11:33:24 +00:00
VarElimSolver::chooseEliminationOrder (const VarIds& vids)
2011-12-12 15:29:51 +00:00
{
if (bayesNet_) {
2012-03-22 11:33:24 +00:00
ElimGraph graph (*bayesNet_);
2011-12-12 15:29:51 +00:00
elimOrder_ = graph.getEliminatingOrder (vids);
} else {
const FgVarSet& varNodes = factorGraph_->getVarNodes();
for (unsigned i = 0; i < varNodes.size(); i++) {
VarId vid = varNodes[i]->varId();
2012-03-31 23:27:37 +01:00
if (Util::contains (vids, vid) == false &&
varNodes[i]->hasEvidence() == false) {
2011-12-12 15:29:51 +00:00
elimOrder_.push_back (vid);
}
}
}
}
void
2012-03-22 11:33:24 +00:00
VarElimSolver::processFactorList (const VarIds& vids)
2011-12-12 15:29:51 +00:00
{
for (unsigned i = 0; i < elimOrder_.size(); i++) {
eliminate (elimOrder_[i]);
}
2012-03-22 11:33:24 +00:00
Factor* finalFactor = new Factor();
2011-12-12 15:29:51 +00:00
for (unsigned i = 0; i < factorList_.size(); i++) {
if (factorList_[i]) {
2012-03-22 11:33:24 +00:00
finalFactor->multiply (*factorList_[i]);
2011-12-12 15:29:51 +00:00
delete factorList_[i];
factorList_[i] = 0;
}
}
2012-03-22 11:33:24 +00:00
VarIds unobservedVids;
2011-12-12 15:29:51 +00:00
for (unsigned i = 0; i < vids.size(); i++) {
if (factorGraph_->getFgVarNode (vids[i])->hasEvidence() == false) {
2012-03-22 11:33:24 +00:00
unobservedVids.push_back (vids[i]);
2011-12-12 15:29:51 +00:00
}
}
2012-03-22 11:33:24 +00:00
2012-03-31 23:27:37 +01:00
finalFactor->reorderArguments (unobservedVids);
2012-03-22 11:33:24 +00:00
finalFactor->normalize();
factorList_.push_back (finalFactor);
2011-12-12 15:29:51 +00:00
}
/**
 * Eliminates one variable from the working factor list.
 *
 * Multiplies together all live factors listed for elimVar, freeing and
 * clearing their slots.
 * - If the product has exactly one argument, it is discarded (and freed).
 * - Otherwise elimVar is summed out, the result is appended to factorList_,
 *   and its new index is recorded for each of its remaining arguments.
 *
 * @param elimVar  variable to eliminate
 */
void
VarElimSolver::eliminate (VarId elimVar)
{
  FgVarNode* vn = factorGraph_->getFgVarNode (elimVar);
  unordered_map<VarId,vector<unsigned> >::iterator elimIt
      = varFactors_.find (elimVar);
  if (elimIt == varFactors_.end()) {
    // Variable with no neighboring factor: nothing to eliminate
    // (dereferencing end() would be undefined behavior).
    return;
  }

  Factor* result = 0;
  const vector<unsigned>& elimIdxs = elimIt->second;
  for (unsigned i = 0; i < elimIdxs.size(); i++) {
    unsigned idx = elimIdxs[i];
    if (factorList_[idx]) {
      if (result == 0) {
        result = new Factor (*factorList_[idx]);
      } else {
        result->multiply (*factorList_[idx]);
      }
      delete factorList_[idx];
      factorList_[idx] = 0;
    }
  }

  if (result == 0) {
    return;
  }
  if (result->nrArguments() == 1) {
    // The product is dropped here; free it rather than leak it.
    delete result;
    return;
  }

  result->sumOut (vn->varId());
  factorList_.push_back (result);
  const unsigned resultIdx = factorList_.size() - 1;
  const VarIds& resultVarIds = result->arguments();
  for (unsigned i = 0; i < resultVarIds.size(); i++) {
    varFactors_[resultVarIds[i]].push_back (resultIdx);
  }
}
void
VarElimSolver::printActiveFactors (void)
{
for (unsigned i = 0; i < factorList_.size(); i++) {
if (factorList_[i] != 0) {
2012-03-22 11:33:24 +00:00
factorList_[i]->print();
2011-12-12 15:29:51 +00:00
cout << endl;
}
}
}