2012-05-23 14:56:01 +01:00
|
|
|
#include <cassert>
|
|
|
|
|
|
|
|
#include "BayesBall.h"
|
|
|
|
|
|
|
|
|
2013-02-08 21:12:46 +00:00
|
|
|
namespace Horus {
|
2013-02-07 23:53:13 +00:00
|
|
|
|
2013-02-06 00:24:02 +00:00
|
|
|
// Prepares a Bayes-Ball run over `fg`.
//
// `fg_` refers to the input graph; `dag_` is initialised from
// fg.getStructure() and then cleared so the traversal starts from a clean
// state (presumably clear() resets the visited/marked flags of every node --
// confirm against the BayesBallGraph definition).
BayesBall::BayesBall (FactorGraph& fg)
    : fg_(fg),
      dag_(fg.getStructure())
{
  // Start with no node visited or marked.
  dag_.clear();
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Convenience entry point: runs Bayes-Ball over `fg` for the query
// variables `vids` using a throw-away BayesBall instance.
//
// Returns the graph produced by the member overload
// getMinimalFactorGraph (const VarIds&), which allocates it with `new`.
FactorGraph*
BayesBall::getMinimalFactorGraph (FactorGraph& fg, VarIds vids)
{
  return BayesBall (fg).getMinimalFactorGraph (vids);
}
|
|
|
|
|
|
|
|
|
|
|
|
|
2012-05-23 14:56:01 +01:00
|
|
|
// Runs the Bayes-Ball traversal starting from the query variables and
// builds the smallest factor graph needed to answer queries on them.
//
// Traversal rules applied to each dequeued node `n`:
//  - reached from a child and `n` has no evidence:
//      mark "above" (once) and schedule its parents,
//      mark "below" (once) and schedule its children;
//  - reached from a parent:
//      if `n` has evidence, mark "above" (once) and schedule its parents;
//      otherwise mark "below" (once) and schedule its children.
// Every dequeued node is flagged as visited.
//
// @param queryIds  ids of the query variables; each must exist in the
//                  Bayes-Ball structure (checked with assert).
// @return a newly allocated FactorGraph filled by constructGraph(). This
//         function keeps no pointer to it, so the caller must delete it.
FactorGraph*
BayesBall::getMinimalFactorGraph (const VarIds& queryIds)
{
  // Only valid for graphs made of Bayesian (directed) factors.
  assert (fg_.bayesianFactors());
  Scheduling scheduling;
  // Seed the queue with every query node. The flags presumably mean
  // (fromParent = false, fromChild = true) -- confirm against ScheduleInfo.
  for (size_t i = 0; i < queryIds.size(); i++) {
    BBNode* n = dag_.getNode (queryIds[i]);
    assert (n);
    scheduling.push (ScheduleInfo (n, false, true));
  }
  while (!scheduling.empty()) {
    // Copy the entry and remove it before scheduling neighbours, so no
    // reference into the queue is held while the queue is being grown.
    // Popping first does not change the FIFO processing order.
    const ScheduleInfo sch = scheduling.front();
    scheduling.pop();
    BBNode* n = sch.node;
    n->setAsVisited();
    if (n->hasEvidence() == false && sch.visitedFromChild) {
      if (n->isMarkedAbove() == false) {
        n->markAbove();
        scheduleParents (n, scheduling);
      }
      if (n->isMarkedBelow() == false) {
        n->markBelow();
        scheduleChilds (n, scheduling);
      }
    }
    if (sch.visitedFromParent) {
      if (n->hasEvidence() && n->isMarkedAbove() == false) {
        n->markAbove();
        scheduleParents (n, scheduling);
      }
      if (n->hasEvidence() == false && n->isMarkedBelow() == false) {
        n->markBelow();
        scheduleChilds (n, scheduling);
      }
    }
  }
  FactorGraph* fg = new FactorGraph();
  constructGraph (fg);
  return fg;
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void
|
|
|
|
BayesBall::constructGraph (FactorGraph* fg) const
|
|
|
|
{
|
|
|
|
const FacNodes& facNodes = fg_.facNodes();
|
2012-05-24 22:55:20 +01:00
|
|
|
for (size_t i = 0; i < facNodes.size(); i++) {
|
2012-06-19 14:32:12 +01:00
|
|
|
const BBNode* n = dag_.getNode (
|
2012-05-23 14:56:01 +01:00
|
|
|
facNodes[i]->factor().argument (0));
|
2013-02-16 16:17:14 +00:00
|
|
|
if (n->isMarkedAbove()) {
|
2012-05-23 14:56:01 +01:00
|
|
|
fg->addFactor (facNodes[i]->factor());
|
|
|
|
} else if (n->hasEvidence() && n->isVisited()) {
|
|
|
|
VarIds varIds = { facNodes[i]->factor().argument (0) };
|
|
|
|
Ranges ranges = { facNodes[i]->factor().range (0) };
|
2013-02-08 21:01:53 +00:00
|
|
|
Params params (ranges[0], LogAware::noEvidence());
|
|
|
|
params[n->getEvidence()] = LogAware::withEvidence();
|
2012-05-23 14:56:01 +01:00
|
|
|
fg->addFactor (Factor (varIds, ranges, params));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
const VarNodes& varNodes = fg_.varNodes();
|
2012-05-24 22:55:20 +01:00
|
|
|
for (size_t i = 0; i < varNodes.size(); i++) {
|
2012-05-23 14:56:01 +01:00
|
|
|
if (varNodes[i]->hasEvidence()) {
|
|
|
|
VarNode* vn = fg->getVarNode (varNodes[i]->varId());
|
|
|
|
if (vn) {
|
|
|
|
vn->setEvidence (varNodes[i]->getEvidence());
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2013-02-08 21:12:46 +00:00
|
|
|
} // namespace Horus
|
2013-02-07 23:53:13 +00:00
|
|
|
|