This repository has been archived on 2023-08-20. You can view files and clone it, but cannot push or open issues or pull requests.
yap-6.3/packages/CLPBN/horus/BayesBall.cpp

100 lines
2.4 KiB
C++
Raw Normal View History

2012-05-23 14:56:01 +01:00
#include <cassert>
#include <memory>

#include "BayesBall.h"
namespace Horus {
2013-02-07 23:53:13 +00:00
/**
 * Binds a Bayes-ball helper to a factor graph.
 *
 * @param fg the factor graph to prune. It is stored by the fg_ member and its
 *           structure (fg.getStructure()) is used to initialise dag_.
 */
BayesBall::BayesBall (FactorGraph& fg)
    : fg_  (fg),
      dag_ (fg.getStructure())
{
  // NOTE(review): clear() presumably resets the per-node visited / marked
  // flags so a traversal starts from a clean state (getMinimalFactorGraph
  // still looks nodes up afterwards, so it cannot drop them) — confirm
  // against the dag_ type's definition.
  dag_.clear();
}
/**
 * One-shot convenience wrapper: builds a BayesBall over fg and returns the
 * pruned factor graph for the given query variables.
 *
 * @param fg   the factor graph to prune.
 * @param vids ids of the query variables.
 * @return a newly allocated FactorGraph owned by the caller.
 */
FactorGraph*
BayesBall::getMinimalFactorGraph (FactorGraph& fg, VarIds vids)
{
  return BayesBall (fg).getMinimalFactorGraph (vids);
}
2012-05-23 14:56:01 +01:00
/**
 * Runs a Bayes-ball traversal from the query variables over dag_ and returns
 * a new factor graph holding only the part of fg_ that the traversal marked
 * as needed (see constructGraph).
 *
 * Preconditions (checked with assert only): fg_.bayesianFactors() is true and
 * every id in queryIds has a node in dag_.
 *
 * @param queryIds ids of the query variables.
 * @return a newly allocated FactorGraph; ownership passes to the caller, who
 *         must delete it.
 */
FactorGraph*
BayesBall::getMinimalFactorGraph (const VarIds& queryIds)
{
  assert (fg_.bayesianFactors());
  Scheduling scheduling;
  for (const auto& vid : queryIds) {
    BBNode* n = dag_.getNode (vid);
    assert (n);
    // Query nodes start as if the ball arrived from a child.
    // NOTE(review): ScheduleInfo's arguments are presumably
    // (node, visitedFromParent, visitedFromChild) — confirm in BayesBall.h.
    scheduling.push (ScheduleInfo (n, false, true));
  }
  while (!scheduling.empty()) {
    // Copy the entry and pop it before scheduling more nodes: the calls below
    // push into this same queue, and a reference to front() would only stay
    // valid for containers whose push never invalidates references. Popping
    // first does not change the FIFO processing order.
    const ScheduleInfo sch = scheduling.front();
    scheduling.pop();
    BBNode* n = sch.node;
    n->setAsVisited();
    if (n->hasEvidence() == false && sch.visitedFromChild) {
      // Unobserved node reached from a child: pass the ball up and down.
      if (n->isMarkedAbove() == false) {
        n->markAbove();
        scheduleParents (n, scheduling);
      }
      if (n->isMarkedBelow() == false) {
        n->markBelow();
        scheduleChilds (n, scheduling);
      }
    }
    if (sch.visitedFromParent) {
      // Reached from a parent: an observed node bounces the ball back to its
      // parents, an unobserved one passes it on to its children.
      if (n->hasEvidence() && n->isMarkedAbove() == false) {
        n->markAbove();
        scheduleParents (n, scheduling);
      }
      if (n->hasEvidence() == false && n->isMarkedBelow() == false) {
        n->markBelow();
        scheduleChilds (n, scheduling);
      }
    }
  }
  // Keep the new graph owned until it is fully built so it is released if
  // constructGraph() (e.g. an allocation inside addFactor) throws.
  std::unique_ptr<FactorGraph> fg (new FactorGraph());
  constructGraph (fg.get());
  return fg.release();
}
/**
 * Fills fg with the part of fg_ selected by the Bayes-ball marks stored in
 * dag_.
 *
 * For every factor of fg_, the dag_ node of its first argument decides its
 * fate:
 *   - marked above           -> the factor is copied unchanged;
 *   - observed and visited   -> it is replaced by a unary indicator factor
 *                               over that variable (withEvidence() at the
 *                               observed state, noEvidence() elsewhere);
 *   - otherwise              -> it is dropped.
 * Afterwards the evidence of fg_'s observed variables is copied onto the
 * variables of fg that exist there.
 *
 * @param fg destination graph (not owned); factors are appended to it.
 */
void
BayesBall::constructGraph (FactorGraph* fg) const
{
  for (const auto& facNode : fg_.facNodes()) {
    auto& factor = facNode->factor();
    // NOTE(review): argument 0 is presumably the CPT's child variable
    // (Bayesian factors); dag_ is assumed to hold a node for it — the result
    // of getNode() is not null-checked here.
    const BBNode* node = dag_.getNode (factor.argument (0));
    if (node->isMarkedAbove()) {
      fg->addFactor (factor);
      continue;
    }
    if (node->hasEvidence() && node->isVisited()) {
      VarIds args   = { factor.argument (0) };
      Ranges ranges = { factor.range (0) };
      Params indicator (ranges[0], LogAware::noEvidence());
      indicator[node->getEvidence()] = LogAware::withEvidence();
      fg->addFactor (Factor (args, ranges, indicator));
    }
  }
  // Carry observations over to the variables that made it into fg.
  for (const auto& varNode : fg_.varNodes()) {
    if (varNode->hasEvidence() == false) {
      continue;
    }
    VarNode* target = fg->getVarNode (varNode->varId());
    if (target) {
      target->setEvidence (varNode->getEvidence());
    }
  }
}
} // namespace Horus
2013-02-07 23:53:13 +00:00