Add support to markov networks
This commit is contained in:
parent
6c3add2ebd
commit
0d23591058
@ -40,7 +40,7 @@
|
||||
get_pfl_parameters/2
|
||||
]).
|
||||
|
||||
:- use_module(library(clpbn/horus)).
|
||||
% :- use_module(library(clpbn/horus)).
|
||||
|
||||
:- use_module(library(lists)).
|
||||
|
||||
@ -50,7 +50,7 @@
|
||||
|
||||
|
||||
:- use_module(horus,
|
||||
[create_ground_network/2,
|
||||
[create_ground_network/4,
|
||||
set_bayes_net_params/2,
|
||||
run_ground_solver/3,
|
||||
set_extra_vars_info/2,
|
||||
@ -75,11 +75,19 @@ call_bp_ground(QueryKeys, AllKeys, Factors, Evidence, Solutions) :-
|
||||
writeln(factorIds:FactorIds), writeln(''),
|
||||
writeln(evidence:Evidence), writeln(''),
|
||||
writeln(evIds:EvIds),
|
||||
create_ground_network(Type, FactorIds, GroundNetwork).
|
||||
%run_ground_solver(Network, QueryIds, EvIds, Solutions),
|
||||
create_ground_network(Type, FactorIds, EvIds, Network),
|
||||
run_ground_fixme_solver(ground(Network,Hash), QueryKeys, Solutions).
|
||||
%free_graphical_model(Network).
|
||||
|
||||
|
||||
run_ground_fixme_solver(ground(Network,Hash), QueryKeys, Solutions) :-
|
||||
%get_dists_parameters(DistIds, DistsParams),
|
||||
%set_bayes_net_params(Network, DistsParams),
|
||||
%vars_to_ids(QueryVars, QueryVarsIds),
|
||||
list_of_keys_to_ids(QueryKeys, Hash, QueryIds),
|
||||
run_ground_solver(Network, [QueryIds], Solutions).
|
||||
|
||||
|
||||
get_factors_type([f(bayes, _, _)|_], bayes) :- ! .
|
||||
get_factors_type([f(markov, _, _)|_], markov) :- ! .
|
||||
|
||||
|
80
packages/CLPBN/clpbn/bp/BayesBall.cpp
Normal file
80
packages/CLPBN/clpbn/bp/BayesBall.cpp
Normal file
@ -0,0 +1,80 @@
|
||||
#include <cstdlib>
|
||||
#include <cassert>
|
||||
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
|
||||
#include "xmlParser/xmlParser.h"
|
||||
|
||||
#include "BayesBall.h"
|
||||
#include "Util.h"
|
||||
|
||||
|
||||
|
||||
FactorGraph*
|
||||
BayesBall::getMinimalFactorGraph (const VarIds& queryIds)
|
||||
{
|
||||
assert (fg_.isFromBayesNetwork());
|
||||
|
||||
Scheduling scheduling;
|
||||
for (unsigned i = 0; i < queryIds.size(); i++) {
|
||||
assert (dag_.getNode (queryIds[i]));
|
||||
DAGraphNode* n = dag_.getNode (queryIds[i]);
|
||||
scheduling.push (ScheduleInfo (n, false, true));
|
||||
}
|
||||
|
||||
while (!scheduling.empty()) {
|
||||
ScheduleInfo& sch = scheduling.front();
|
||||
DAGraphNode* n = sch.node;
|
||||
n->setAsVisited();
|
||||
if (n->hasEvidence() == false && sch.visitedFromChild) {
|
||||
if (n->isMarkedOnTop() == false) {
|
||||
n->markOnTop();
|
||||
scheduleParents (n, scheduling);
|
||||
}
|
||||
if (n->isMarkedOnBottom() == false) {
|
||||
n->markOnBottom();
|
||||
scheduleChilds (n, scheduling);
|
||||
}
|
||||
}
|
||||
if (sch.visitedFromParent) {
|
||||
if (n->hasEvidence() && n->isMarkedOnTop() == false) {
|
||||
n->markOnTop();
|
||||
scheduleParents (n, scheduling);
|
||||
}
|
||||
if (n->hasEvidence() == false && n->isMarkedOnBottom() == false) {
|
||||
n->markOnBottom();
|
||||
scheduleChilds (n, scheduling);
|
||||
}
|
||||
}
|
||||
scheduling.pop();
|
||||
}
|
||||
|
||||
FactorGraph* fg = new FactorGraph();
|
||||
constructGraph (fg);
|
||||
return fg;
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BayesBall::constructGraph (FactorGraph* fg) const
|
||||
{
|
||||
const FgFacSet& facNodes = fg_.getFactorNodes();
|
||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||
const DAGraphNode* n = dag_.getNode (
|
||||
facNodes[i]->factor()->argument (0));
|
||||
if (n->isMarkedOnTop()) {
|
||||
fg->addFactor (Factor (*(facNodes[i]->factor())));
|
||||
} else if (n->hasEvidence() && n->isVisited()) {
|
||||
VarIds varIds = { facNodes[i]->factor()->argument (0) };
|
||||
Ranges ranges = { facNodes[i]->factor()->range (0) };
|
||||
Params params (ranges[0], LogAware::noEvidence());
|
||||
params[n->getEvidence()] = LogAware::withEvidence();
|
||||
fg->addFactor (Factor (varIds, ranges, params));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
86
packages/CLPBN/clpbn/bp/BayesBall.h
Normal file
86
packages/CLPBN/clpbn/bp/BayesBall.h
Normal file
@ -0,0 +1,86 @@
|
||||
#ifndef HORUS_BAYESBALL_H
|
||||
#define HORUS_BAYESBALL_H
|
||||
|
||||
#include <vector>
|
||||
#include <queue>
|
||||
#include <list>
|
||||
#include <map>
|
||||
|
||||
#include "GraphicalModel.h"
|
||||
#include "Horus.h"
|
||||
|
||||
#include "FactorGraph.h"
|
||||
#include "BayesNet.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
|
||||
struct ScheduleInfo
|
||||
{
|
||||
ScheduleInfo (DAGraphNode* n, bool vfp, bool vfc) :
|
||||
node(n), visitedFromParent(vfp), visitedFromChild(vfc) { }
|
||||
DAGraphNode* node;
|
||||
bool visitedFromParent;
|
||||
bool visitedFromChild;
|
||||
};
|
||||
|
||||
|
||||
typedef queue<ScheduleInfo, list<ScheduleInfo>> Scheduling;
|
||||
|
||||
|
||||
class BayesBall
|
||||
{
|
||||
public:
|
||||
BayesBall (FactorGraph& fg)
|
||||
: fg_(fg) , dag_(fg.getStructure())
|
||||
{
|
||||
dag_.clear();
|
||||
}
|
||||
|
||||
FactorGraph* getMinimalFactorGraph (const VarIds&);
|
||||
|
||||
static FactorGraph* getMinimalFactorGraph (FactorGraph& fg, VarIds vids)
|
||||
{
|
||||
BayesBall bb (fg);
|
||||
return bb.getMinimalFactorGraph (vids);
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
void constructGraph (FactorGraph* fg) const;
|
||||
|
||||
void scheduleParents (const DAGraphNode* n, Scheduling& sch) const;
|
||||
|
||||
void scheduleChilds (const DAGraphNode* n, Scheduling& sch) const;
|
||||
|
||||
FactorGraph& fg_;
|
||||
|
||||
DAGraph& dag_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
inline void
|
||||
BayesBall::scheduleParents (const DAGraphNode* n, Scheduling& sch) const
|
||||
{
|
||||
const vector<DAGraphNode*>& ps = n->parents();
|
||||
for (vector<DAGraphNode*>::const_iterator it = ps.begin();
|
||||
it != ps.end(); it++) {
|
||||
sch.push (ScheduleInfo (*it, false, true));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline void
|
||||
BayesBall::scheduleChilds (const DAGraphNode* n, Scheduling& sch) const
|
||||
{
|
||||
const vector<DAGraphNode*>& cs = n->childs();
|
||||
for (vector<DAGraphNode*>::const_iterator it = cs.begin();
|
||||
it != cs.end(); it++) {
|
||||
sch.push (ScheduleInfo (*it, true, false));
|
||||
}
|
||||
}
|
||||
|
||||
#endif // HORUS_BAYESBALL_H
|
||||
|
@ -11,6 +11,104 @@
|
||||
#include "Util.h"
|
||||
|
||||
|
||||
void
|
||||
DAGraph::addNode (DAGraphNode* n)
|
||||
{
|
||||
nodes_.push_back (n);
|
||||
assert (Util::contains (varMap_, n->varId()) == false);
|
||||
varMap_[n->varId()] = n;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
DAGraph::addEdge (VarId vid1, VarId vid2)
|
||||
{
|
||||
unordered_map<VarId, DAGraphNode*>::iterator it1;
|
||||
unordered_map<VarId, DAGraphNode*>::iterator it2;
|
||||
it1 = varMap_.find (vid1);
|
||||
it2 = varMap_.find (vid2);
|
||||
assert (it1 != varMap_.end());
|
||||
assert (it2 != varMap_.end());
|
||||
it1->second->addChild (it2->second);
|
||||
it2->second->addParent (it1->second);
|
||||
}
|
||||
|
||||
|
||||
|
||||
const DAGraphNode*
|
||||
DAGraph::getNode (VarId vid) const
|
||||
{
|
||||
unordered_map<VarId, DAGraphNode*>::const_iterator it;
|
||||
it = varMap_.find (vid);
|
||||
return it != varMap_.end() ? it->second : 0;
|
||||
}
|
||||
|
||||
|
||||
|
||||
DAGraphNode*
|
||||
DAGraph::getNode (VarId vid)
|
||||
{
|
||||
unordered_map<VarId, DAGraphNode*>::const_iterator it;
|
||||
it = varMap_.find (vid);
|
||||
return it != varMap_.end() ? it->second : 0;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
DAGraph::setIndexes (void)
|
||||
{
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
nodes_[i]->setIndex (i);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
DAGraph::clear (void)
|
||||
{
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
nodes_[i]->clear();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
DAGraph::exportToGraphViz (const char* fileName)
|
||||
{
|
||||
ofstream out (fileName);
|
||||
if (!out.is_open()) {
|
||||
cerr << "error: cannot open file to write at " ;
|
||||
cerr << "DAGraph::exportToDotFile()" << endl;
|
||||
abort();
|
||||
}
|
||||
out << "digraph {" << endl;
|
||||
out << "ranksep=1" << endl;
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
out << nodes_[i]->varId() ;
|
||||
out << " [" ;
|
||||
out << "label=\"" << nodes_[i]->label() << "\"" ;
|
||||
if (nodes_[i]->hasEvidence()) {
|
||||
out << ",style=filled, fillcolor=yellow" ;
|
||||
}
|
||||
out << "]" << endl;
|
||||
}
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
const vector<DAGraphNode*>& childs = nodes_[i]->childs();
|
||||
for (unsigned j = 0; j < childs.size(); j++) {
|
||||
out << nodes_[i]->varId() << " -> " << childs[j]->varId();
|
||||
out << " [style=bold]" << endl ;
|
||||
}
|
||||
}
|
||||
out << "}" << endl;
|
||||
out.close();
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
BayesNet::~BayesNet (void)
|
||||
{
|
||||
@ -36,8 +134,8 @@ BayesNet::readFromBifFormat (const char* fileName)
|
||||
}
|
||||
States states;
|
||||
string label = var.getChildNode("NAME").getText();
|
||||
unsigned nrStates = var.nChildNode ("OUTCOME");
|
||||
for (unsigned j = 0; j < nrStates; j++) {
|
||||
unsigned range = var.nChildNode ("OUTCOME");
|
||||
for (unsigned j = 0; j < range; j++) {
|
||||
if (var.getChildNode("OUTCOME", j).getText() == 0) {
|
||||
stringstream ss;
|
||||
ss << j + 1;
|
||||
@ -63,7 +161,7 @@ BayesNet::readFromBifFormat (const char* fileName)
|
||||
abort();
|
||||
}
|
||||
BnNodeSet parents;
|
||||
unsigned nParams = node->nrStates();
|
||||
unsigned nParams = node->range();
|
||||
for (int j = 0; j < def.nChildNode ("GIVEN"); j++) {
|
||||
string parentLabel = def.getChildNode("GIVEN", j).getText();
|
||||
BayesNode* parentNode = getBayesNode (parentLabel);
|
||||
@ -71,7 +169,7 @@ BayesNet::readFromBifFormat (const char* fileName)
|
||||
cerr << "error: unknow variable `" << parentLabel << "'" << endl;
|
||||
abort();
|
||||
}
|
||||
nParams *= parentNode->nrStates();
|
||||
nParams *= parentNode->range();
|
||||
parents.push_back (parentNode);
|
||||
}
|
||||
node->setParents (parents);
|
||||
@ -87,7 +185,7 @@ BayesNet::readFromBifFormat (const char* fileName)
|
||||
cerr << "for variable `" << label << "'" << endl;
|
||||
abort();
|
||||
}
|
||||
params = reorderParameters (params, node->nrStates());
|
||||
params = reorderParameters (params, node->range());
|
||||
if (Globals::logDomain) {
|
||||
Util::toLog (params);
|
||||
}
|
||||
@ -218,130 +316,6 @@ BayesNet::getLeafNodes (void) const
|
||||
|
||||
|
||||
|
||||
BayesNet*
|
||||
BayesNet::getMinimalRequesiteNetwork (VarId vid) const
|
||||
{
|
||||
return getMinimalRequesiteNetwork (VarIds() = {vid});
|
||||
}
|
||||
|
||||
|
||||
|
||||
BayesNet*
|
||||
BayesNet::getMinimalRequesiteNetwork (const VarIds& queryVarIds) const
|
||||
{
|
||||
BnNodeSet queryVars;
|
||||
Scheduling scheduling;
|
||||
for (unsigned i = 0; i < queryVarIds.size(); i++) {
|
||||
BayesNode* n = getBayesNode (queryVarIds[i]);
|
||||
assert (n);
|
||||
queryVars.push_back (n);
|
||||
scheduling.push (ScheduleInfo (n, false, true));
|
||||
}
|
||||
|
||||
vector<StateInfo*> states (nodes_.size(), 0);
|
||||
|
||||
while (!scheduling.empty()) {
|
||||
ScheduleInfo& sch = scheduling.front();
|
||||
StateInfo* state = states[sch.node->getIndex()];
|
||||
if (!state) {
|
||||
state = new StateInfo();
|
||||
states[sch.node->getIndex()] = state;
|
||||
} else {
|
||||
state->visited = true;
|
||||
}
|
||||
if (!sch.node->hasEvidence() && sch.visitedFromChild) {
|
||||
if (!state->markedOnTop) {
|
||||
state->markedOnTop = true;
|
||||
scheduleParents (sch.node, scheduling);
|
||||
}
|
||||
if (!state->markedOnBottom) {
|
||||
state->markedOnBottom = true;
|
||||
scheduleChilds (sch.node, scheduling);
|
||||
}
|
||||
}
|
||||
if (sch.visitedFromParent) {
|
||||
if (sch.node->hasEvidence() && !state->markedOnTop) {
|
||||
state->markedOnTop = true;
|
||||
scheduleParents (sch.node, scheduling);
|
||||
}
|
||||
if (!sch.node->hasEvidence() && !state->markedOnBottom) {
|
||||
state->markedOnBottom = true;
|
||||
scheduleChilds (sch.node, scheduling);
|
||||
}
|
||||
}
|
||||
scheduling.pop();
|
||||
}
|
||||
/*
|
||||
cout << "\t\ttop\tbottom" << endl;
|
||||
cout << "variable\t\tmarked\tmarked\tvisited\tobserved" << endl;
|
||||
Util::printDashedLine();
|
||||
cout << endl;
|
||||
for (unsigned i = 0; i < states.size(); i++) {
|
||||
cout << nodes_[i]->label() << ":\t\t" ;
|
||||
if (states[i]) {
|
||||
states[i]->markedOnTop ? cout << "yes\t" : cout << "no\t" ;
|
||||
states[i]->markedOnBottom ? cout << "yes\t" : cout << "no\t" ;
|
||||
states[i]->visited ? cout << "yes\t" : cout << "no\t" ;
|
||||
nodes_[i]->hasEvidence() ? cout << "yes" : cout << "no" ;
|
||||
cout << endl;
|
||||
} else {
|
||||
cout << "no\tno\tno\t" ;
|
||||
nodes_[i]->hasEvidence() ? cout << "yes" : cout << "no" ;
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
cout << endl;
|
||||
*/
|
||||
BayesNet* bn = new BayesNet();
|
||||
constructGraph (bn, states);
|
||||
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
delete states[i];
|
||||
}
|
||||
return bn;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BayesNet::constructGraph (BayesNet* bn,
|
||||
const vector<StateInfo*>& states) const
|
||||
{
|
||||
BnNodeSet mrnNodes;
|
||||
vector<VarIds> parents;
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
bool isRequired = false;
|
||||
if (states[i]) {
|
||||
isRequired = (nodes_[i]->hasEvidence() && states[i]->visited)
|
||||
||
|
||||
states[i]->markedOnTop;
|
||||
}
|
||||
if (isRequired) {
|
||||
parents.push_back (VarIds());
|
||||
if (states[i]->markedOnTop) {
|
||||
const BnNodeSet& ps = nodes_[i]->getParents();
|
||||
for (unsigned j = 0; j < ps.size(); j++) {
|
||||
parents.back().push_back (ps[j]->varId());
|
||||
}
|
||||
}
|
||||
assert (bn->getBayesNode (nodes_[i]->varId()) == 0);
|
||||
BayesNode* mrnNode = new BayesNode (nodes_[i]);
|
||||
bn->addNode (mrnNode);
|
||||
mrnNodes.push_back (mrnNode);
|
||||
}
|
||||
}
|
||||
for (unsigned i = 0; i < mrnNodes.size(); i++) {
|
||||
BnNodeSet ps;
|
||||
for (unsigned j = 0; j < parents[i].size(); j++) {
|
||||
assert (bn->getBayesNode (parents[i][j]) != 0);
|
||||
ps.push_back (bn->getBayesNode (parents[i][j]));
|
||||
}
|
||||
mrnNodes[i]->setParents (ps);
|
||||
}
|
||||
bn->setIndexes();
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
BayesNet::isPolyTree (void) const
|
||||
@ -458,7 +432,7 @@ BayesNet::exportToBifFormat (const char* fileName) const
|
||||
out << "</GIVEN>" << endl;
|
||||
}
|
||||
Params params = revertParameterReorder (
|
||||
nodes_[i]->params(), nodes_[i]->nrStates());
|
||||
nodes_[i]->params(), nodes_[i]->range());
|
||||
out << "\t<TABLE>" ;
|
||||
for (unsigned j = 0; j < params.size(); j++) {
|
||||
out << " " << params[j];
|
||||
|
@ -14,28 +14,79 @@
|
||||
using namespace std;
|
||||
|
||||
|
||||
struct ScheduleInfo
|
||||
|
||||
class VarNode;
|
||||
|
||||
class DAGraphNode : public VarNode
|
||||
{
|
||||
ScheduleInfo (BayesNode* n, bool vfp, bool vfc) :
|
||||
node(n), visitedFromParent(vfp), visitedFromChild(vfc) { }
|
||||
BayesNode* node;
|
||||
bool visitedFromParent;
|
||||
bool visitedFromChild;
|
||||
public:
|
||||
DAGraphNode (VarNode* vn) : VarNode (vn) , visited_(false),
|
||||
markedOnTop_(false), markedOnBottom_(false) { }
|
||||
|
||||
const vector<DAGraphNode*>& childs (void) const { return childs_; }
|
||||
|
||||
vector<DAGraphNode*>& childs (void) { return childs_; }
|
||||
|
||||
const vector<DAGraphNode*>& parents (void) const { return parents_; }
|
||||
|
||||
vector<DAGraphNode*>& parents (void) { return parents_; }
|
||||
|
||||
void addParent (DAGraphNode* p) { parents_.push_back (p); }
|
||||
|
||||
void addChild (DAGraphNode* c) { childs_.push_back (c); }
|
||||
|
||||
bool isVisited (void) const { return visited_; }
|
||||
|
||||
void setAsVisited (void) { visited_ = true; }
|
||||
|
||||
bool isMarkedOnTop (void) const { return markedOnTop_; }
|
||||
|
||||
void markOnTop (void) { markedOnTop_ = true; }
|
||||
|
||||
bool isMarkedOnBottom (void) const { return markedOnBottom_; }
|
||||
|
||||
void markOnBottom (void) { markedOnBottom_ = true; }
|
||||
|
||||
void clear (void) { visited_ = markedOnTop_ = markedOnBottom_ = false; }
|
||||
|
||||
private:
|
||||
bool visited_;
|
||||
bool markedOnTop_;
|
||||
bool markedOnBottom_;
|
||||
|
||||
vector<DAGraphNode*> childs_;
|
||||
vector<DAGraphNode*> parents_;
|
||||
};
|
||||
|
||||
|
||||
struct StateInfo
|
||||
class DAGraph
|
||||
{
|
||||
StateInfo (void) : visited(false), markedOnTop(false),
|
||||
markedOnBottom(false) { }
|
||||
bool visited;
|
||||
bool markedOnTop;
|
||||
bool markedOnBottom;
|
||||
public:
|
||||
DAGraph (void) { }
|
||||
|
||||
void addNode (DAGraphNode* n);
|
||||
|
||||
void addEdge (VarId vid1, VarId vid2);
|
||||
|
||||
const DAGraphNode* getNode (VarId vid) const;
|
||||
|
||||
DAGraphNode* getNode (VarId vid);
|
||||
|
||||
bool empty (void) const { return nodes_.empty(); }
|
||||
|
||||
void setIndexes (void);
|
||||
|
||||
void clear (void);
|
||||
|
||||
void exportToGraphViz (const char*);
|
||||
|
||||
private:
|
||||
vector<DAGraphNode*> nodes_;
|
||||
|
||||
unordered_map<VarId, DAGraphNode*> varMap_;
|
||||
};
|
||||
|
||||
|
||||
typedef queue<ScheduleInfo, list<ScheduleInfo> > Scheduling;
|
||||
|
||||
|
||||
class BayesNet : public GraphicalModel
|
||||
{
|
||||
@ -66,12 +117,6 @@ class BayesNet : public GraphicalModel
|
||||
|
||||
BnNodeSet getLeafNodes (void) const;
|
||||
|
||||
BayesNet* getMinimalRequesiteNetwork (VarId) const;
|
||||
|
||||
BayesNet* getMinimalRequesiteNetwork (const VarIds&) const;
|
||||
|
||||
void constructGraph (BayesNet*, const vector<StateInfo*>&) const;
|
||||
|
||||
bool isPolyTree (void) const;
|
||||
|
||||
void setIndexes (void);
|
||||
@ -96,37 +141,10 @@ class BayesNet : public GraphicalModel
|
||||
|
||||
Params revertParameterReorder (const Params&, unsigned) const;
|
||||
|
||||
void scheduleParents (const BayesNode*, Scheduling&) const;
|
||||
|
||||
void scheduleChilds (const BayesNode*, Scheduling&) const;
|
||||
|
||||
BnNodeSet nodes_;
|
||||
|
||||
typedef unordered_map<unsigned, unsigned> IndexMap;
|
||||
IndexMap varMap_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
inline void
|
||||
BayesNet::scheduleParents (const BayesNode* n, Scheduling& sch) const
|
||||
{
|
||||
const BnNodeSet& ps = n->getParents();
|
||||
for (BnNodeSet::const_iterator it = ps.begin(); it != ps.end(); it++) {
|
||||
sch.push (ScheduleInfo (*it, false, true));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline void
|
||||
BayesNet::scheduleChilds (const BayesNode* n, Scheduling& sch) const
|
||||
{
|
||||
const BnNodeSet& cs = n->getChilds();
|
||||
for (BnNodeSet::const_iterator it = cs.begin(); it != cs.end(); it++) {
|
||||
sch.push (ScheduleInfo (*it, true, false));
|
||||
}
|
||||
}
|
||||
|
||||
#endif // HORUS_BAYESNET_H
|
||||
|
||||
|
@ -127,7 +127,7 @@ BayesNode::getDomainHeaders (void) const
|
||||
States states = parents_[i]->states();
|
||||
unsigned index = 0;
|
||||
while (index < rowSize) {
|
||||
for (unsigned j = 0; j < parents_[i]->nrStates(); j++) {
|
||||
for (unsigned j = 0; j < parents_[i]->range(); j++) {
|
||||
for (unsigned r = 0; r < nReps; r++) {
|
||||
if (headers[index] != "") {
|
||||
headers[index] = states[j] + "," + headers[index];
|
||||
@ -138,7 +138,7 @@ BayesNode::getDomainHeaders (void) const
|
||||
}
|
||||
}
|
||||
}
|
||||
nReps *= parents_[i]->nrStates();
|
||||
nReps *= parents_[i]->range();
|
||||
}
|
||||
return headers;
|
||||
}
|
||||
|
@ -16,12 +16,12 @@ class BayesNode : public VarNode
|
||||
BayesNode (const VarNode& v) : VarNode (v) { }
|
||||
|
||||
BayesNode (const BayesNode* n) :
|
||||
VarNode (n->varId(), n->nrStates(), n->getEvidence()),
|
||||
VarNode (n->varId(), n->range(), n->getEvidence()),
|
||||
params_(n->params()), distId_(n->distId()) { }
|
||||
|
||||
BayesNode (VarId vid, unsigned nrStates, int ev,
|
||||
BayesNode (VarId vid, unsigned range, int ev,
|
||||
const Params& ps, unsigned id)
|
||||
: VarNode (vid, nrStates, ev) , params_(ps), distId_(id) { }
|
||||
: VarNode (vid, range, ev) , params_(ps), distId_(id) { }
|
||||
|
||||
const BnNodeSet& getParents (void) const { return parents_; }
|
||||
|
||||
@ -33,7 +33,7 @@ class BayesNode : public VarNode
|
||||
|
||||
unsigned getRowSize (void) const
|
||||
{
|
||||
return params_.size() / nrStates();
|
||||
return params_.size() / range();
|
||||
}
|
||||
|
||||
double getProbability (int row, unsigned col)
|
||||
|
@ -109,7 +109,7 @@ BnBpSolver::initializeSolver (void)
|
||||
for (unsigned i = 0; i < roots.size(); i++) {
|
||||
const Params& params = roots[i]->params();
|
||||
Params& piVals = ninf(roots[i])->getPiValues();
|
||||
for (unsigned ri = 0; ri < roots[i]->nrStates(); ri++) {
|
||||
for (unsigned ri = 0; ri < roots[i]->range(); ri++) {
|
||||
piVals[ri] = params[ri];
|
||||
}
|
||||
}
|
||||
@ -137,7 +137,7 @@ BnBpSolver::initializeSolver (void)
|
||||
if (nodes[i]->hasEvidence()) {
|
||||
Params& piVals = ninf(nodes[i])->getPiValues();
|
||||
Params& ldVals = ninf(nodes[i])->getLambdaValues();
|
||||
for (unsigned xi = 0; xi < nodes[i]->nrStates(); xi++) {
|
||||
for (unsigned xi = 0; xi < nodes[i]->range(); xi++) {
|
||||
piVals[xi] = LogAware::noEvidence();
|
||||
ldVals[xi] = LogAware::noEvidence();
|
||||
}
|
||||
@ -310,7 +310,7 @@ BnBpSolver::updatePiValues (BayesNode* x)
|
||||
const BnNodeSet& ps = x->getParents();
|
||||
Ranges ranges;
|
||||
for (unsigned i = 0; i < ps.size(); i++) {
|
||||
ranges.push_back (ps[i]->nrStates());
|
||||
ranges.push_back (ps[i]->range());
|
||||
}
|
||||
StatesIndexer indexer (ranges, false);
|
||||
stringstream* calcs1 = 0;
|
||||
@ -354,7 +354,7 @@ BnBpSolver::updatePiValues (BayesNode* x)
|
||||
++ indexer;
|
||||
}
|
||||
|
||||
for (unsigned xi = 0; xi < x->nrStates(); xi++) {
|
||||
for (unsigned xi = 0; xi < x->range(); xi++) {
|
||||
double sum = LogAware::addIdenty();
|
||||
if (Constants::DEBUG >= 5) {
|
||||
calcs1 = new stringstream;
|
||||
@ -409,7 +409,7 @@ BnBpSolver::updateLambdaValues (BayesNode* x)
|
||||
stringstream* calcs1 = 0;
|
||||
stringstream* calcs2 = 0;
|
||||
|
||||
for (unsigned xi = 0; xi < x->nrStates(); xi++) {
|
||||
for (unsigned xi = 0; xi < x->range(); xi++) {
|
||||
if (Constants::DEBUG >= 5) {
|
||||
calcs1 = new stringstream;
|
||||
calcs2 = new stringstream;
|
||||
@ -461,7 +461,7 @@ BnBpSolver::calculatePiMessage (BpLink* link)
|
||||
stringstream* calcs2 = 0;
|
||||
|
||||
const Params& zPiValues = ninf(z)->getPiValues();
|
||||
for (unsigned zi = 0; zi < z->nrStates(); zi++) {
|
||||
for (unsigned zi = 0; zi < z->range(); zi++) {
|
||||
double product = zPiValues[zi];
|
||||
if (Constants::DEBUG >= 5) {
|
||||
calcs1 = new stringstream;
|
||||
@ -526,12 +526,12 @@ BnBpSolver::calculateLambdaMessage (BpLink* link)
|
||||
const BnNodeSet& ps = y->getParents();
|
||||
Ranges ranges;
|
||||
for (unsigned i = 0; i < ps.size(); i++) {
|
||||
ranges.push_back (ps[i]->nrStates());
|
||||
ranges.push_back (ps[i]->range());
|
||||
}
|
||||
StatesIndexer indexer (ranges, false);
|
||||
|
||||
|
||||
unsigned N = indexer.size() / x->nrStates();
|
||||
unsigned N = indexer.size() / x->range();
|
||||
Params messageProducts (N);
|
||||
for (unsigned k = 0; k < N; k++) {
|
||||
while (indexer[parentIndex] != 0) {
|
||||
@ -579,13 +579,13 @@ BnBpSolver::calculateLambdaMessage (BpLink* link)
|
||||
}
|
||||
}
|
||||
|
||||
for (unsigned xi = 0; xi < x->nrStates(); xi++) {
|
||||
for (unsigned xi = 0; xi < x->range(); xi++) {
|
||||
if (Constants::DEBUG >= 5) {
|
||||
calcs1 = new stringstream;
|
||||
calcs2 = new stringstream;
|
||||
}
|
||||
double outerSum = LogAware::addIdenty();
|
||||
for (unsigned yi = 0; yi < y->nrStates(); yi++) {
|
||||
for (unsigned yi = 0; yi < y->range(); yi++) {
|
||||
if (Constants::DEBUG >= 5) {
|
||||
(yi != 0) ? *calcs1 << " + {" : *calcs1 << "{" ;
|
||||
(yi != 0) ? *calcs2 << " + {" : *calcs2 << "{" ;
|
||||
@ -645,6 +645,7 @@ BnBpSolver::calculateLambdaMessage (BpLink* link)
|
||||
Params
|
||||
BnBpSolver::getJointByConditioning (const VarIds& jointVarIds) const
|
||||
{
|
||||
/*
|
||||
BnNodeSet jointVars;
|
||||
for (unsigned i = 0; i < jointVarIds.size(); i++) {
|
||||
assert (bayesNet_->getBayesNode (jointVarIds[i]));
|
||||
@ -685,7 +686,7 @@ BnBpSolver::getJointByConditioning (const VarIds& jointVarIds) const
|
||||
|
||||
int count = -1;
|
||||
for (unsigned j = 0; j < newBeliefs.size(); j++) {
|
||||
if (j % jointVars[i]->nrStates() == 0) {
|
||||
if (j % jointVars[i]->range() == 0) {
|
||||
count ++;
|
||||
}
|
||||
newBeliefs[j] *= prevBeliefs[count];
|
||||
@ -695,6 +696,8 @@ BnBpSolver::getJointByConditioning (const VarIds& jointVarIds) const
|
||||
delete mrn;
|
||||
}
|
||||
return prevBeliefs;
|
||||
*/
|
||||
return Params();
|
||||
}
|
||||
|
||||
|
||||
@ -714,7 +717,7 @@ BnBpSolver::printPiLambdaValues (const BayesNode* var) const
|
||||
const Params& piVals = ninf(var)->getPiValues();
|
||||
const Params& ldVals = ninf(var)->getLambdaValues();
|
||||
const Params& beliefs = ninf(var)->getBeliefs();
|
||||
for (unsigned xi = 0; xi < var->nrStates(); xi++) {
|
||||
for (unsigned xi = 0; xi < var->range(); xi++) {
|
||||
cout << setw (10) << states[xi];
|
||||
cout << setw (19) << piVals[xi];
|
||||
cout << setw (19) << ldVals[xi];
|
||||
@ -741,8 +744,8 @@ BnBpSolver::printAllMessageStatus (void) const
|
||||
BpNodeInfo::BpNodeInfo (BayesNode* node)
|
||||
{
|
||||
node_ = node;
|
||||
piVals_.resize (node->nrStates(), LogAware::one());
|
||||
ldVals_.resize (node->nrStates(), LogAware::one());
|
||||
piVals_.resize (node->range(), LogAware::one());
|
||||
ldVals_.resize (node->range(), LogAware::one());
|
||||
}
|
||||
|
||||
|
||||
@ -751,20 +754,20 @@ Params
|
||||
BpNodeInfo::getBeliefs (void) const
|
||||
{
|
||||
double sum = 0.0;
|
||||
Params beliefs (node_->nrStates());
|
||||
Params beliefs (node_->range());
|
||||
if (Globals::logDomain) {
|
||||
for (unsigned xi = 0; xi < node_->nrStates(); xi++) {
|
||||
for (unsigned xi = 0; xi < node_->range(); xi++) {
|
||||
beliefs[xi] = exp (piVals_[xi] + ldVals_[xi]);
|
||||
sum += beliefs[xi];
|
||||
}
|
||||
} else {
|
||||
for (unsigned xi = 0; xi < node_->nrStates(); xi++) {
|
||||
for (unsigned xi = 0; xi < node_->range(); xi++) {
|
||||
beliefs[xi] = piVals_[xi] * ldVals_[xi];
|
||||
sum += beliefs[xi];
|
||||
}
|
||||
}
|
||||
assert (sum);
|
||||
for (unsigned xi = 0; xi < node_->nrStates(); xi++) {
|
||||
for (unsigned xi = 0; xi < node_->range(); xi++) {
|
||||
beliefs[xi] /= sum;
|
||||
}
|
||||
return beliefs;
|
||||
@ -779,7 +782,7 @@ BpNodeInfo::receivedBottomInfluence (void) const
|
||||
// this node neither its descendents have evidence,
|
||||
// we can use this to don't send lambda messages his parents
|
||||
bool childInfluenced = false;
|
||||
for (unsigned xi = 1; xi < node_->nrStates(); xi++) {
|
||||
for (unsigned xi = 1; xi < node_->range(); xi++) {
|
||||
if (ldVals_[xi] != ldVals_[0]) {
|
||||
childInfluenced = true;
|
||||
break;
|
||||
|
@ -27,11 +27,11 @@ class BpLink
|
||||
destin_ = d;
|
||||
orientation_ = o;
|
||||
if (orientation_ == LinkOrientation::DOWN) {
|
||||
v1_.resize (s->nrStates(), LogAware::tl (1.0 / s->nrStates()));
|
||||
v2_.resize (s->nrStates(), LogAware::tl (1.0 / s->nrStates()));
|
||||
v1_.resize (s->range(), LogAware::tl (1.0 / s->range()));
|
||||
v2_.resize (s->range(), LogAware::tl (1.0 / s->range()));
|
||||
} else {
|
||||
v1_.resize (d->nrStates(), LogAware::tl (1.0 / d->nrStates()));
|
||||
v2_.resize (d->nrStates(), LogAware::tl (1.0 / d->nrStates()));
|
||||
v1_.resize (d->range(), LogAware::tl (1.0 / d->range()));
|
||||
v2_.resize (d->range(), LogAware::tl (1.0 / d->range()));
|
||||
}
|
||||
currMsg_ = &v1_;
|
||||
nextMsg_ = &v2_;
|
||||
|
@ -51,7 +51,7 @@ CFactorGraph::setInitialColors (void)
|
||||
VarColorMap colorMap;
|
||||
const FgVarSet& varNodes = groundFg_->getVarNodes();
|
||||
for (unsigned i = 0; i < varNodes.size(); i++) {
|
||||
unsigned dsize = varNodes[i]->nrStates();
|
||||
unsigned dsize = varNodes[i]->range();
|
||||
VarColorMap::iterator it = colorMap.find (dsize);
|
||||
if (it == colorMap.end()) {
|
||||
it = colorMap.insert (make_pair (
|
||||
|
@ -20,10 +20,10 @@ CbpSolver::getPosterioriOf (VarId vid)
|
||||
FgVarNode* var = lfg_->getEquivalentVariable (vid);
|
||||
Params probs;
|
||||
if (var->hasEvidence()) {
|
||||
probs.resize (var->nrStates(), LogAware::noEvidence());
|
||||
probs.resize (var->range(), LogAware::noEvidence());
|
||||
probs[var->getEvidence()] = LogAware::withEvidence();
|
||||
} else {
|
||||
probs.resize (var->nrStates(), LogAware::multIdenty());
|
||||
probs.resize (var->range(), LogAware::multIdenty());
|
||||
const SpLinkSet& links = ninf(var)->getLinks();
|
||||
if (Globals::logDomain) {
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
@ -196,7 +196,7 @@ CbpSolver::getVar2FactorMsg (const SpLink* link) const
|
||||
const FgFacNode* dst = link->getFactor();
|
||||
const CbpSolverLink* l = static_cast<const CbpSolverLink*> (link);
|
||||
if (src->hasEvidence()) {
|
||||
msg.resize (src->nrStates(), LogAware::noEvidence());
|
||||
msg.resize (src->range(), LogAware::noEvidence());
|
||||
double value = link->getMessage()[src->getEvidence()];
|
||||
msg[src->getEvidence()] = LogAware::pow (value, l->getNumberOfEdges() - 1);
|
||||
} else {
|
||||
|
@ -12,7 +12,7 @@ class CbpSolverLink : public SpLink
|
||||
CbpSolverLink (FgFacNode* fn, FgVarNode* vn, unsigned c) : SpLink (fn, vn)
|
||||
{
|
||||
edgeCount_ = c;
|
||||
poweredMsg_.resize (vn->nrStates(), LogAware::one());
|
||||
poweredMsg_.resize (vn->range(), LogAware::one());
|
||||
}
|
||||
|
||||
unsigned getNumberOfEdges (void) const { return edgeCount_; }
|
||||
@ -53,7 +53,7 @@ class CbpSolver : public FgBpSolver
|
||||
Params getVar2FactorMsg (const SpLink*) const;
|
||||
void printLinkInformation (void) const;
|
||||
|
||||
CFactorGraph* lfg_;
|
||||
CFactorGraph* lfg_;
|
||||
};
|
||||
|
||||
#endif // HORUS_CBP_H
|
||||
|
@ -3,52 +3,37 @@
|
||||
#include <fstream>
|
||||
|
||||
#include "ElimGraph.h"
|
||||
#include "BayesNet.h"
|
||||
|
||||
|
||||
ElimHeuristic ElimGraph::elimHeuristic_ = MIN_NEIGHBORS;
|
||||
|
||||
|
||||
ElimGraph::ElimGraph (const BayesNet& bayesNet)
|
||||
ElimGraph::ElimGraph (const vector<Factor*>& factors)
|
||||
{
|
||||
const BnNodeSet& bnNodes = bayesNet.getBayesNodes();
|
||||
for (unsigned i = 0; i < bnNodes.size(); i++) {
|
||||
if (bnNodes[i]->hasEvidence() == false) {
|
||||
addNode (new EgNode (bnNodes[i]));
|
||||
}
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < bnNodes.size(); i++) {
|
||||
if (bnNodes[i]->hasEvidence() == false) {
|
||||
EgNode* n = getEgNode (bnNodes[i]->varId());
|
||||
const BnNodeSet& childs = bnNodes[i]->getChilds();
|
||||
for (unsigned j = 0; j < childs.size(); j++) {
|
||||
if (childs[j]->hasEvidence() == false) {
|
||||
addEdge (n, getEgNode (childs[j]->varId()));
|
||||
for (unsigned i = 0; i < factors.size(); i++) {
|
||||
const VarIds& vids = factors[i]->arguments();
|
||||
for (unsigned j = 0; j < vids.size() - 1; j++) {
|
||||
EgNode* n1 = getEgNode (vids[j]);
|
||||
if (n1 == 0) {
|
||||
n1 = new EgNode (vids[j], factors[i]->range (j));
|
||||
addNode (n1);
|
||||
}
|
||||
for (unsigned k = j + 1; k < vids.size(); k++) {
|
||||
EgNode* n2 = getEgNode (vids[k]);
|
||||
if (n2 == 0) {
|
||||
n2 = new EgNode (vids[k], factors[i]->range (k));
|
||||
addNode (n2);
|
||||
}
|
||||
if (neighbors (n1, n2) == false) {
|
||||
addEdge (n1, n2);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (vids.size() == 1) {
|
||||
if (getEgNode (vids[0]) == 0) {
|
||||
addNode (new EgNode (vids[0], factors[i]->range (0)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < bnNodes.size(); i++) {
|
||||
vector<EgNode*> neighs;
|
||||
const vector<BayesNode*>& parents = bnNodes[i]->getParents();
|
||||
for (unsigned i = 0; i < parents.size(); i++) {
|
||||
if (parents[i]->hasEvidence() == false) {
|
||||
neighs.push_back (getEgNode (parents[i]->varId()));
|
||||
}
|
||||
}
|
||||
if (neighs.size() > 0) {
|
||||
for (unsigned i = 0; i < neighs.size() - 1; i++) {
|
||||
for (unsigned j = i+1; j < neighs.size(); j++) {
|
||||
if (!neighbors (neighs[i], neighs[j])) {
|
||||
addEdge (neighs[i], neighs[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
setIndexes();
|
||||
}
|
||||
|
||||
@ -63,40 +48,16 @@ ElimGraph::~ElimGraph (void)
|
||||
|
||||
|
||||
|
||||
void
|
||||
ElimGraph::addNode (EgNode* n)
|
||||
{
|
||||
nodes_.push_back (n);
|
||||
varMap_.insert (make_pair (n->varId(), n));
|
||||
}
|
||||
|
||||
|
||||
|
||||
EgNode*
|
||||
ElimGraph::getEgNode (VarId vid) const
|
||||
{
|
||||
unordered_map<VarId,EgNode*>::const_iterator it =varMap_.find (vid);
|
||||
if (it ==varMap_.end()) {
|
||||
return 0;
|
||||
} else {
|
||||
return it->second;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
VarIds
|
||||
ElimGraph::getEliminatingOrder (const VarIds& exclude)
|
||||
{
|
||||
VarIds elimOrder;
|
||||
marked_.resize (nodes_.size(), false);
|
||||
|
||||
for (unsigned i = 0; i < exclude.size(); i++) {
|
||||
assert (getEgNode (exclude[i]));
|
||||
EgNode* node = getEgNode (exclude[i]);
|
||||
assert (node);
|
||||
marked_[*node] = true;
|
||||
}
|
||||
|
||||
unsigned nVarsToEliminate = nodes_.size() - exclude.size();
|
||||
for (unsigned i = 0; i < nVarsToEliminate; i++) {
|
||||
EgNode* node = getLowestCostNode();
|
||||
@ -109,6 +70,99 @@ ElimGraph::getEliminatingOrder (const VarIds& exclude)
|
||||
|
||||
|
||||
|
||||
void
|
||||
ElimGraph::print (void) const
|
||||
{
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
cout << "node " << nodes_[i]->label() << " neighs:" ;
|
||||
vector<EgNode*> neighs = nodes_[i]->neighbors();
|
||||
for (unsigned j = 0; j < neighs.size(); j++) {
|
||||
cout << " " << neighs[j]->label();
|
||||
}
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
ElimGraph::exportToGraphViz (
|
||||
const char* fileName,
|
||||
bool showNeighborless,
|
||||
const VarIds& highlightVarIds) const
|
||||
{
|
||||
ofstream out (fileName);
|
||||
if (!out.is_open()) {
|
||||
cerr << "error: cannot open file to write at " ;
|
||||
cerr << "Markov::exportToDotFile()" << endl;
|
||||
abort();
|
||||
}
|
||||
|
||||
out << "strict graph {" << endl;
|
||||
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
if (showNeighborless || nodes_[i]->neighbors().size() != 0) {
|
||||
out << '"' << nodes_[i]->label() << '"' << endl;
|
||||
}
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < highlightVarIds.size(); i++) {
|
||||
EgNode* node =getEgNode (highlightVarIds[i]);
|
||||
if (node) {
|
||||
out << '"' << node->label() << '"' ;
|
||||
out << " [shape=box3d]" << endl;
|
||||
} else {
|
||||
cout << "error: invalid variable id: " << highlightVarIds[i] << endl;
|
||||
abort();
|
||||
}
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
vector<EgNode*> neighs = nodes_[i]->neighbors();
|
||||
for (unsigned j = 0; j < neighs.size(); j++) {
|
||||
out << '"' << nodes_[i]->label() << '"' << " -- " ;
|
||||
out << '"' << neighs[j]->label() << '"' << endl;
|
||||
}
|
||||
}
|
||||
|
||||
out << "}" << endl;
|
||||
out.close();
|
||||
}
|
||||
|
||||
|
||||
|
||||
VarIds
|
||||
ElimGraph::getEliminationOrder (
|
||||
const vector<Factor*> factors,
|
||||
VarIds excludedVids)
|
||||
{
|
||||
ElimGraph graph (factors);
|
||||
// graph.print();
|
||||
// graph.exportToGraphViz ("_egg.dot");
|
||||
return graph.getEliminatingOrder (excludedVids);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
ElimGraph::addNode (EgNode* n)
|
||||
{
|
||||
nodes_.push_back (n);
|
||||
varMap_.insert (make_pair (n->varId(), n));
|
||||
}
|
||||
|
||||
|
||||
|
||||
EgNode*
|
||||
ElimGraph::getEgNode (VarId vid) const
|
||||
{
|
||||
unordered_map<VarId, EgNode*>::const_iterator it;
|
||||
it = varMap_.find (vid);
|
||||
return (it != varMap_.end()) ? it->second : 0;
|
||||
}
|
||||
|
||||
|
||||
|
||||
EgNode*
|
||||
ElimGraph::getLowestCostNode (void) const
|
||||
{
|
||||
@ -166,7 +220,7 @@ ElimGraph::getWeightCost (const EgNode* n) const
|
||||
const vector<EgNode*>& neighs = n->neighbors();
|
||||
for (unsigned i = 0; i < neighs.size(); i++) {
|
||||
if (marked_[*neighs[i]] == false) {
|
||||
cost *= neighs[i]->nrStates();
|
||||
cost *= neighs[i]->range();
|
||||
}
|
||||
}
|
||||
return cost;
|
||||
@ -206,7 +260,7 @@ ElimGraph::getWeightedFillCost (const EgNode* n) const
|
||||
for (unsigned j = i+1; j < neighs.size(); j++) {
|
||||
if (marked_[*neighs[j]] == true) continue;
|
||||
if (!neighbors (neighs[i], neighs[j])) {
|
||||
cost += neighs[i]->nrStates() * neighs[j]->nrStates();
|
||||
cost += neighs[i]->range() * neighs[j]->range();
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -257,68 +311,3 @@ ElimGraph::setIndexes (void)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
ElimGraph::printGraphicalModel (void) const
|
||||
{
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
cout << "node " << nodes_[i]->label() << " neighs:" ;
|
||||
vector<EgNode*> neighs = nodes_[i]->neighbors();
|
||||
for (unsigned j = 0; j < neighs.size(); j++) {
|
||||
cout << " " << neighs[j]->label();
|
||||
}
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
ElimGraph::exportToGraphViz (const char* fileName,
|
||||
bool showNeighborless,
|
||||
const VarIds& highlightVarIds) const
|
||||
{
|
||||
ofstream out (fileName);
|
||||
if (!out.is_open()) {
|
||||
cerr << "error: cannot open file to write at " ;
|
||||
cerr << "Markov::exportToDotFile()" << endl;
|
||||
abort();
|
||||
}
|
||||
|
||||
out << "strict graph {" << endl;
|
||||
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
if (showNeighborless || nodes_[i]->neighbors().size() != 0) {
|
||||
out << '"' << nodes_[i]->label() << '"' ;
|
||||
if (nodes_[i]->hasEvidence()) {
|
||||
out << " [style=filled, fillcolor=yellow]" << endl;
|
||||
} else {
|
||||
out << endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < highlightVarIds.size(); i++) {
|
||||
EgNode* node =getEgNode (highlightVarIds[i]);
|
||||
if (node) {
|
||||
out << '"' << node->label() << '"' ;
|
||||
out << " [shape=box3d]" << endl;
|
||||
} else {
|
||||
cout << "error: invalid variable id: " << highlightVarIds[i] << endl;
|
||||
abort();
|
||||
}
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
vector<EgNode*> neighs = nodes_[i]->neighbors();
|
||||
for (unsigned j = 0; j < neighs.size(); j++) {
|
||||
out << '"' << nodes_[i]->label() << '"' << " -- " ;
|
||||
out << '"' << neighs[j]->label() << '"' << endl;
|
||||
}
|
||||
}
|
||||
|
||||
out << "}" << endl;
|
||||
out.close();
|
||||
}
|
||||
|
||||
|
@ -20,7 +20,7 @@ enum ElimHeuristic
|
||||
class EgNode : public VarNode
|
||||
{
|
||||
public:
|
||||
EgNode (VarNode* var) : VarNode (var) { }
|
||||
EgNode (VarId vid, unsigned range) : VarNode (vid, range) { }
|
||||
|
||||
void addNeighbor (EgNode* n) { neighs_.push_back (n); }
|
||||
|
||||
@ -34,10 +34,26 @@ class EgNode : public VarNode
|
||||
class ElimGraph
|
||||
{
|
||||
public:
|
||||
ElimGraph (const BayesNet&);
|
||||
ElimGraph (const vector<Factor*>&); // TODO
|
||||
|
||||
~ElimGraph (void);
|
||||
|
||||
VarIds getEliminatingOrder (const VarIds&);
|
||||
|
||||
void print (void) const;
|
||||
|
||||
void exportToGraphViz (const char*, bool = true,
|
||||
const VarIds& = VarIds()) const;
|
||||
|
||||
static VarIds getEliminationOrder (const vector<Factor*>, VarIds);
|
||||
|
||||
static void setEliminationHeuristic (ElimHeuristic h)
|
||||
{
|
||||
elimHeuristic_ = h;
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
void addEdge (EgNode* n1, EgNode* n2)
|
||||
{
|
||||
assert (n1 != n2);
|
||||
@ -48,22 +64,6 @@ class ElimGraph
|
||||
void addNode (EgNode*);
|
||||
|
||||
EgNode* getEgNode (VarId) const;
|
||||
|
||||
VarIds getEliminatingOrder (const VarIds&);
|
||||
|
||||
void printGraphicalModel (void) const;
|
||||
|
||||
void exportToGraphViz (const char*, bool = true,
|
||||
const VarIds& = VarIds()) const;
|
||||
|
||||
void setIndexes();
|
||||
|
||||
static void setEliminationHeuristic (ElimHeuristic h)
|
||||
{
|
||||
elimHeuristic_ = h;
|
||||
}
|
||||
|
||||
private:
|
||||
EgNode* getLowestCostNode (void) const;
|
||||
|
||||
unsigned getNeighborsCost (const EgNode*) const;
|
||||
@ -78,9 +78,11 @@ class ElimGraph
|
||||
|
||||
bool neighbors (const EgNode*, const EgNode*) const;
|
||||
|
||||
void setIndexes (void);
|
||||
|
||||
vector<EgNode*> nodes_;
|
||||
vector<bool> marked_;
|
||||
unordered_map<VarId,EgNode*> varMap_;
|
||||
unordered_map<VarId, EgNode*> varMap_;
|
||||
static ElimHeuristic elimHeuristic_;
|
||||
};
|
||||
|
||||
|
@ -18,11 +18,11 @@ Factor::Factor (const Factor& g)
|
||||
|
||||
|
||||
|
||||
Factor::Factor (VarId vid, unsigned nrStates)
|
||||
Factor::Factor (VarId vid, unsigned range)
|
||||
{
|
||||
args_.push_back (vid);
|
||||
ranges_.push_back (nrStates);
|
||||
params_.resize (nrStates, 1.0);
|
||||
ranges_.push_back (range);
|
||||
params_.resize (range, 1.0);
|
||||
distId_ = Util::maxUnsigned();
|
||||
assert (params_.size() == Util::expectedSize (ranges_));
|
||||
}
|
||||
@ -34,8 +34,8 @@ Factor::Factor (const VarNodes& vars)
|
||||
int nrParams = 1;
|
||||
for (unsigned i = 0; i < vars.size(); i++) {
|
||||
args_.push_back (vars[i]->varId());
|
||||
ranges_.push_back (vars[i]->nrStates());
|
||||
nrParams *= vars[i]->nrStates();
|
||||
ranges_.push_back (vars[i]->range());
|
||||
nrParams *= vars[i]->range();
|
||||
}
|
||||
double val = 1.0 / nrParams;
|
||||
params_.resize (nrParams, val);
|
||||
@ -47,11 +47,11 @@ Factor::Factor (const VarNodes& vars)
|
||||
|
||||
Factor::Factor (
|
||||
VarId vid,
|
||||
unsigned nrStates,
|
||||
unsigned range,
|
||||
const Params& params)
|
||||
{
|
||||
args_.push_back (vid);
|
||||
ranges_.push_back (nrStates);
|
||||
ranges_.push_back (range);
|
||||
params_ = params;
|
||||
distId_ = Util::maxUnsigned();
|
||||
assert (params_.size() == Util::expectedSize (ranges_));
|
||||
@ -66,7 +66,7 @@ Factor::Factor (
|
||||
{
|
||||
for (unsigned i = 0; i < vars.size(); i++) {
|
||||
args_.push_back (vars[i]->varId());
|
||||
ranges_.push_back (vars[i]->nrStates());
|
||||
ranges_.push_back (vars[i]->range());
|
||||
}
|
||||
params_ = params;
|
||||
distId_ = distId;
|
||||
@ -186,8 +186,8 @@ Factor::sumOut (VarId vid)
|
||||
void
|
||||
Factor::sumOutFirstVariable (void)
|
||||
{
|
||||
unsigned nStates = ranges_.front();
|
||||
unsigned sep = params_.size() / nStates;
|
||||
unsigned range = ranges_.front();
|
||||
unsigned sep = params_.size() / range;
|
||||
if (Globals::logDomain) {
|
||||
for (unsigned i = sep; i < params_.size(); i++) {
|
||||
params_[i % sep] = Util::logSum (params_[i % sep], params_[i]);
|
||||
@ -207,14 +207,14 @@ Factor::sumOutFirstVariable (void)
|
||||
void
|
||||
Factor::sumOutLastVariable (void)
|
||||
{
|
||||
unsigned nStates = ranges_.back();
|
||||
unsigned range = ranges_.back();
|
||||
unsigned idx1 = 0;
|
||||
unsigned idx2 = 0;
|
||||
if (Globals::logDomain) {
|
||||
while (idx1 < params_.size()) {
|
||||
params_[idx2] = params_[idx1];
|
||||
idx1 ++;
|
||||
for (unsigned j = 1; j < nStates; j++) {
|
||||
for (unsigned j = 1; j < range; j++) {
|
||||
params_[idx2] = Util::logSum (params_[idx2], params_[idx1]);
|
||||
idx1 ++;
|
||||
}
|
||||
@ -224,7 +224,7 @@ Factor::sumOutLastVariable (void)
|
||||
while (idx1 < params_.size()) {
|
||||
params_[idx2] = params_[idx1];
|
||||
idx1 ++;
|
||||
for (unsigned j = 1; j < nStates; j++) {
|
||||
for (unsigned j = 1; j < range; j++) {
|
||||
params_[idx2] += params_[idx1];
|
||||
idx1 ++;
|
||||
}
|
||||
|
@ -9,6 +9,7 @@
|
||||
#include "FactorGraph.h"
|
||||
#include "Factor.h"
|
||||
#include "BayesNet.h"
|
||||
#include "BayesBall.h"
|
||||
#include "Util.h"
|
||||
|
||||
|
||||
@ -205,13 +206,13 @@ FactorGraph::readFromLibDaiFormat (const char* fileName)
|
||||
var = new FgVarNode (vids[j], dsize);
|
||||
addVariable (var);
|
||||
} else {
|
||||
if (var->nrStates() != dsize) {
|
||||
if (var->range() != dsize) {
|
||||
cerr << "error: variable `" << vids[j] << "' appears in two or " ;
|
||||
cerr << "more factors with different domain sizes" << endl;
|
||||
}
|
||||
}
|
||||
neighs.push_back (var);
|
||||
nParams *= var->nrStates();
|
||||
nParams *= var->range();
|
||||
}
|
||||
Params params (nParams, 0);
|
||||
unsigned nNonzeros;
|
||||
@ -274,6 +275,30 @@ FactorGraph::addFactor (FgFacNode* fn)
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::addFactor (const Factor& factor)
|
||||
{
|
||||
FgFacNode* fn = new FgFacNode (factor);
|
||||
addFactor (fn);
|
||||
const VarIds& vids = factor.arguments();
|
||||
for (unsigned i = 0; i < vids.size(); i++) {
|
||||
bool found = false;
|
||||
for (unsigned j = 0; j < varNodes_.size(); j++) {
|
||||
if (varNodes_[j]->varId() == vids[i]) {
|
||||
addEdge (varNodes_[j], fn);
|
||||
found = true;
|
||||
}
|
||||
}
|
||||
if (found == false) {
|
||||
FgVarNode* vn = new FgVarNode (vids[i], factor.range (i));
|
||||
addVariable (vn);
|
||||
addEdge (vn, fn);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::addEdge (FgVarNode* vn, FgFacNode* fn)
|
||||
{
|
||||
@ -322,6 +347,26 @@ FactorGraph::isTree (void) const
|
||||
|
||||
|
||||
|
||||
DAGraph&
|
||||
FactorGraph::getStructure (void)
|
||||
{
|
||||
assert (fromBayesNet_);
|
||||
if (structure_.empty()) {
|
||||
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
||||
structure_.addNode (new DAGraphNode (varNodes_[i]));
|
||||
}
|
||||
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
||||
const VarIds& vids = facNodes_[i]->factor()->arguments();
|
||||
for (unsigned j = 1; j < vids.size(); j++) {
|
||||
structure_.addEdge (vids[j], vids[0]);
|
||||
}
|
||||
}
|
||||
}
|
||||
return structure_;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::setIndexes (void)
|
||||
{
|
||||
@ -339,11 +384,11 @@ void
|
||||
FactorGraph::printGraphicalModel (void) const
|
||||
{
|
||||
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
||||
cout << "VarId = " << varNodes_[i]->varId() << endl;
|
||||
cout << "Label = " << varNodes_[i]->label() << endl;
|
||||
cout << "Nr States = " << varNodes_[i]->nrStates() << endl;
|
||||
cout << "Evidence = " << varNodes_[i]->getEvidence() << endl;
|
||||
cout << "Factors = " ;
|
||||
cout << "var id = " << varNodes_[i]->varId() << endl;
|
||||
cout << "label = " << varNodes_[i]->label() << endl;
|
||||
cout << "range = " << varNodes_[i]->range() << endl;
|
||||
cout << "evidence = " << varNodes_[i]->getEvidence() << endl;
|
||||
cout << "factors = " ;
|
||||
for (unsigned j = 0; j < varNodes_[i]->neighbors().size(); j++) {
|
||||
cout << varNodes_[i]->neighbors()[j]->getLabel() << " " ;
|
||||
}
|
||||
@ -351,7 +396,6 @@ FactorGraph::printGraphicalModel (void) const
|
||||
}
|
||||
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
||||
facNodes_[i]->factor()->print();
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
|
||||
@ -366,22 +410,18 @@ FactorGraph::exportToGraphViz (const char* fileName) const
|
||||
cerr << "FactorGraph::exportToDotFile()" << endl;
|
||||
abort();
|
||||
}
|
||||
|
||||
out << "graph \"" << fileName << "\" {" << endl;
|
||||
|
||||
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
||||
if (varNodes_[i]->hasEvidence()) {
|
||||
out << '"' << varNodes_[i]->label() << '"' ;
|
||||
out << " [style=filled, fillcolor=yellow]" << endl;
|
||||
}
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
||||
out << '"' << facNodes_[i]->getLabel() << '"' ;
|
||||
out << " [label=\"" << facNodes_[i]->getLabel();
|
||||
out << "\"" << ", shape=box]" << endl;
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
||||
const FgVarSet& myVars = facNodes_[i]->neighbors();
|
||||
for (unsigned j = 0; j < myVars.size(); j++) {
|
||||
@ -390,7 +430,6 @@ FactorGraph::exportToGraphViz (const char* fileName) const
|
||||
out << '"' << myVars[j]->label() << '"' << endl;
|
||||
}
|
||||
}
|
||||
|
||||
out << "}" << endl;
|
||||
out.close();
|
||||
}
|
||||
@ -410,7 +449,7 @@ FactorGraph::exportToUaiFormat (const char* fileName) const
|
||||
out << "MARKOV" << endl;
|
||||
out << varNodes_.size() << endl;
|
||||
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
||||
out << varNodes_[i]->nrStates() << " " ;
|
||||
out << varNodes_[i]->range() << " " ;
|
||||
}
|
||||
out << endl;
|
||||
|
||||
@ -459,7 +498,7 @@ FactorGraph::exportToLibDaiFormat (const char* fileName) const
|
||||
}
|
||||
out << endl;
|
||||
for (unsigned j = 0; j < factorVars.size(); j++) {
|
||||
out << factorVars[j]->nrStates() << " " ;
|
||||
out << factorVars[j]->range() << " " ;
|
||||
}
|
||||
out << endl;
|
||||
Params params = facNodes_[i]->factor()->params();
|
||||
@ -496,10 +535,11 @@ FactorGraph::containsCycle (void) const
|
||||
|
||||
|
||||
bool
|
||||
FactorGraph::containsCycle (const FgVarNode* v,
|
||||
const FgFacNode* p,
|
||||
vector<bool>& visitedVars,
|
||||
vector<bool>& visitedFactors) const
|
||||
FactorGraph::containsCycle (
|
||||
const FgVarNode* v,
|
||||
const FgFacNode* p,
|
||||
vector<bool>& visitedVars,
|
||||
vector<bool>& visitedFactors) const
|
||||
{
|
||||
visitedVars[v->getIndex()] = true;
|
||||
const FgFacSet& adjacencies = v->neighbors();
|
||||
@ -520,10 +560,11 @@ FactorGraph::containsCycle (const FgVarNode* v,
|
||||
|
||||
|
||||
bool
|
||||
FactorGraph::containsCycle (const FgFacNode* v,
|
||||
const FgVarNode* p,
|
||||
vector<bool>& visitedVars,
|
||||
vector<bool>& visitedFactors) const
|
||||
FactorGraph::containsCycle (
|
||||
const FgFacNode* v,
|
||||
const FgVarNode* p,
|
||||
vector<bool>& visitedVars,
|
||||
vector<bool>& visitedFactors) const
|
||||
{
|
||||
visitedFactors[v->getIndex()] = true;
|
||||
const FgVarSet& adjacencies = v->neighbors();
|
||||
|
@ -3,13 +3,14 @@
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "GraphicalModel.h"
|
||||
#include "Factor.h"
|
||||
#include "GraphicalModel.h"
|
||||
#include "BayesNet.h"
|
||||
#include "Horus.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
class BayesNet;
|
||||
|
||||
class FgFacNode;
|
||||
|
||||
|
||||
@ -42,6 +43,8 @@ class FgFacNode
|
||||
|
||||
FgFacNode (Factor* f) : factor_(new Factor(*f)), index_(-1) { }
|
||||
|
||||
FgFacNode (const Factor& f) : factor_(new Factor (f)), index_(-1) { }
|
||||
|
||||
Factor* factor() const { return factor_; }
|
||||
|
||||
void addNeighbor (FgVarNode* vn) { neighs_.push_back (vn); }
|
||||
@ -90,7 +93,7 @@ struct CompVarId
|
||||
class FactorGraph : public GraphicalModel
|
||||
{
|
||||
public:
|
||||
FactorGraph (void) { };
|
||||
FactorGraph (void) { }
|
||||
|
||||
FactorGraph (const FactorGraph&);
|
||||
|
||||
@ -102,14 +105,14 @@ class FactorGraph : public GraphicalModel
|
||||
|
||||
const FgFacSet& getFactorNodes (void) const { return facNodes_; }
|
||||
|
||||
void setFromBayesNetwork (void) { fromBayesNet_ = true; }
|
||||
|
||||
bool isFromBayesNetwork (void) const { return fromBayesNet_ ; }
|
||||
|
||||
FgVarNode* getFgVarNode (VarId vid) const
|
||||
{
|
||||
IndexMap::const_iterator it = varMap_.find (vid);
|
||||
if (it == varMap_.end()) {
|
||||
return 0;
|
||||
} else {
|
||||
return varNodes_[it->second];
|
||||
}
|
||||
return (it != varMap_.end()) ? varNodes_[it->second] : 0;
|
||||
}
|
||||
|
||||
void readFromUaiFormat (const char*);
|
||||
@ -120,6 +123,8 @@ class FactorGraph : public GraphicalModel
|
||||
|
||||
void addFactor (FgFacNode*);
|
||||
|
||||
void addFactor (const Factor& factor);
|
||||
|
||||
void addEdge (FgVarNode*, FgFacNode*);
|
||||
|
||||
void addEdge (FgFacNode*, FgVarNode*);
|
||||
@ -130,6 +135,8 @@ class FactorGraph : public GraphicalModel
|
||||
|
||||
bool isTree (void) const;
|
||||
|
||||
DAGraph& getStructure (void);
|
||||
|
||||
void setIndexes (void);
|
||||
|
||||
void printGraphicalModel (void) const;
|
||||
@ -156,6 +163,9 @@ class FactorGraph : public GraphicalModel
|
||||
FgVarSet varNodes_;
|
||||
FgFacSet facNodes_;
|
||||
|
||||
bool fromBayesNet_;
|
||||
DAGraph structure_;
|
||||
|
||||
typedef unordered_map<unsigned, unsigned> IndexMap;
|
||||
IndexMap varMap_;
|
||||
};
|
||||
|
@ -45,7 +45,7 @@ FgBpSolver::runSolver (void)
|
||||
if (Constants::DEBUG >= 2) {
|
||||
cout << endl;
|
||||
if (nIters_ < BpOptions::maxIter) {
|
||||
cout << "Sum-Product converged in " ;
|
||||
cout << "Sum-Product converged in " ;
|
||||
cout << nIters_ << " iterations" << endl;
|
||||
} else {
|
||||
cout << "The maximum number of iterations was hit, terminating..." ;
|
||||
@ -71,10 +71,10 @@ FgBpSolver::getPosterioriOf (VarId vid)
|
||||
FgVarNode* var = factorGraph_->getFgVarNode (vid);
|
||||
Params probs;
|
||||
if (var->hasEvidence()) {
|
||||
probs.resize (var->nrStates(), LogAware::noEvidence());
|
||||
probs.resize (var->range(), LogAware::noEvidence());
|
||||
probs[var->getEvidence()] = LogAware::withEvidence();
|
||||
} else {
|
||||
probs.resize (var->nrStates(), LogAware::multIdenty());
|
||||
probs.resize (var->range(), LogAware::multIdenty());
|
||||
const SpLinkSet& links = ninf(var)->getLinks();
|
||||
if (Globals::logDomain) {
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
@ -113,7 +113,7 @@ FgBpSolver::getJointDistributionOf (const VarIds& jointVarIds)
|
||||
const SpLinkSet& links = ninf(factorNodes[idx])->getLinks();
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
Factor msg (links[i]->getVariable()->varId(),
|
||||
links[i]->getVariable()->nrStates(),
|
||||
links[i]->getVariable()->range(),
|
||||
getVar2FactorMsg (links[i]));
|
||||
res.multiply (msg);
|
||||
}
|
||||
@ -325,7 +325,7 @@ FgBpSolver::calculateFactor2VariableMsg (SpLink* link) const
|
||||
// to factor `src', except from var `dst'
|
||||
unsigned msgSize = 1;
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
msgSize *= links[i]->getVariable()->nrStates();
|
||||
msgSize *= links[i]->getVariable()->range();
|
||||
}
|
||||
unsigned repetitions = 1;
|
||||
Params msgProduct (msgSize, LogAware::multIdenty());
|
||||
@ -333,9 +333,9 @@ FgBpSolver::calculateFactor2VariableMsg (SpLink* link) const
|
||||
for (int i = links.size() - 1; i >= 0; i--) {
|
||||
if (links[i]->getVariable() != dst) {
|
||||
Util::add (msgProduct, getVar2FactorMsg (links[i]), repetitions);
|
||||
repetitions *= links[i]->getVariable()->nrStates();
|
||||
repetitions *= links[i]->getVariable()->range();
|
||||
} else {
|
||||
unsigned ds = links[i]->getVariable()->nrStates();
|
||||
unsigned ds = links[i]->getVariable()->range();
|
||||
Util::add (msgProduct, Params (ds, 1.0), repetitions);
|
||||
repetitions *= ds;
|
||||
}
|
||||
@ -348,9 +348,9 @@ FgBpSolver::calculateFactor2VariableMsg (SpLink* link) const
|
||||
cout << ": " << endl;
|
||||
}
|
||||
Util::multiply (msgProduct, getVar2FactorMsg (links[i]), repetitions);
|
||||
repetitions *= links[i]->getVariable()->nrStates();
|
||||
repetitions *= links[i]->getVariable()->range();
|
||||
} else {
|
||||
unsigned ds = links[i]->getVariable()->nrStates();
|
||||
unsigned ds = links[i]->getVariable()->range();
|
||||
Util::multiply (msgProduct, Params (ds, 1.0), repetitions);
|
||||
repetitions *= ds;
|
||||
}
|
||||
@ -392,13 +392,13 @@ FgBpSolver::getVar2FactorMsg (const SpLink* link) const
|
||||
const FgFacNode* dst = link->getFactor();
|
||||
Params msg;
|
||||
if (src->hasEvidence()) {
|
||||
msg.resize (src->nrStates(), LogAware::noEvidence());
|
||||
msg.resize (src->range(), LogAware::noEvidence());
|
||||
msg[src->getEvidence()] = LogAware::withEvidence();
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << msg;
|
||||
}
|
||||
} else {
|
||||
msg.resize (src->nrStates(), LogAware::one());
|
||||
msg.resize (src->range(), LogAware::one());
|
||||
}
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << msg;
|
||||
@ -467,7 +467,7 @@ FgBpSolver::getJointByConditioning (const VarIds& jointVarIds) const
|
||||
|
||||
int count = -1;
|
||||
for (unsigned j = 0; j < newBeliefs.size(); j++) {
|
||||
if (j % jointVars[i]->nrStates() == 0) {
|
||||
if (j % jointVars[i]->range() == 0) {
|
||||
count ++;
|
||||
}
|
||||
newBeliefs[j] *= prevBeliefs[count];
|
||||
|
@ -20,8 +20,8 @@ class SpLink
|
||||
{
|
||||
fac_ = fn;
|
||||
var_ = vn;
|
||||
v1_.resize (vn->nrStates(), LogAware::tl (1.0 / vn->nrStates()));
|
||||
v2_.resize (vn->nrStates(), LogAware::tl (1.0 / vn->nrStates()));
|
||||
v1_.resize (vn->range(), LogAware::tl (1.0 / vn->range()));
|
||||
v2_.resize (vn->range(), LogAware::tl (1.0 / vn->range()));
|
||||
currMsg_ = &v1_;
|
||||
nextMsg_ = &v2_;
|
||||
msgSended_ = false;
|
||||
|
@ -50,7 +50,7 @@ extern InfAlgorithms infAlgorithm;
|
||||
namespace Constants {
|
||||
|
||||
// level of debug information
|
||||
const unsigned DEBUG = 2;
|
||||
const unsigned DEBUG = 1;
|
||||
|
||||
const int NO_EVIDENCE = -1;
|
||||
|
||||
|
@ -10,6 +10,8 @@
|
||||
#include "FgBpSolver.h"
|
||||
#include "CbpSolver.h"
|
||||
|
||||
#include "ElimGraph.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
void processArguments (BayesNet&, int, const char* []);
|
||||
@ -23,6 +25,31 @@ const string USAGE = "usage: \
|
||||
int
|
||||
main (int argc, const char* argv[])
|
||||
{
|
||||
VarIds vids1 = { 4, 1, 2, 3 } ;
|
||||
VarIds vids2 = { 4, 5 } ;
|
||||
VarIds vids3 = { 4, 6 } ;
|
||||
VarIds vids4 = { 4, 7 } ;
|
||||
// Factor f1 (vids1, {2,2,2,2},{0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,0.10,0.11,0.12,0.13,0.14,0.15,0.16});
|
||||
// Factor f2 (vids2, {2,2},{0.1,0.2,0.3,0.4});
|
||||
// Factor f3 (vids3, {2,2},{0.1,0.2,0.3,0.4});
|
||||
// Factor f4 (vids4, {2,2},{0.1,0.2,0.3,0.4});
|
||||
|
||||
|
||||
Factor* f1 = new Factor (vids1, {2,2,2,2},{0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,0.10,0.11,0.12,0.13,0.14,0.15,0.16});
|
||||
Factor* f2 = new Factor (vids2, {2,2},{0.1,0.2,0.3,0.4});
|
||||
Factor* f3 = new Factor (vids3, {2,2},{0.1,0.2,0.3,0.4});
|
||||
Factor* f4 = new Factor (vids4, {2,2},{0.1,0.2,0.3,0.4});
|
||||
Factor* f5 = new Factor (vids4, {2,2},{0.1,0.2,0.3,0.4});
|
||||
|
||||
vector<Factor*> fs = {f1,f2,f3,f4,f5};
|
||||
//FactorGraph fg;
|
||||
//fg.addFactor (f1);
|
||||
//fg.addFactor (f2);
|
||||
//fg.addFactor (f3);
|
||||
//fg.addFactor (f4);
|
||||
ElimGraph eg (fs);
|
||||
eg.exportToGraphViz ("_eg.dot");
|
||||
return 0;
|
||||
if (!argv[1]) {
|
||||
cerr << "error: no graphical model specified" << endl;
|
||||
cerr << USAGE << endl;
|
||||
|
@ -16,6 +16,7 @@
|
||||
#include "FgBpSolver.h"
|
||||
#include "CbpSolver.h"
|
||||
#include "ElimGraph.h"
|
||||
#include "BayesBall.h"
|
||||
|
||||
|
||||
using namespace std;
|
||||
@ -46,6 +47,7 @@ readUnsignedList (YAP_Term list)
|
||||
}
|
||||
|
||||
|
||||
|
||||
int createLiftedNetwork (void)
|
||||
{
|
||||
Parfactors parfactors;
|
||||
@ -212,20 +214,13 @@ void readLiftedEvidence (
|
||||
int
|
||||
createGroundNetwork (void)
|
||||
{
|
||||
Statistics::incrementPrimaryNetworksCounting();
|
||||
// cout << "creating network number " ;
|
||||
// cout << Statistics::getPrimaryNetworksCounting() << endl;
|
||||
// if (Statistics::getPrimaryNetworksCounting() > 98) {
|
||||
// Statistics::writeStatisticsToFile ("../../compressing.stats");
|
||||
// }
|
||||
BayesNet* bn = new BayesNet();
|
||||
|
||||
FactorGraph* fg = new FactorGraph();;
|
||||
string factorsType ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG1)));
|
||||
cout << "factors type: '" << factorsType << "'" << endl;
|
||||
|
||||
YAP_Term factorList = YAP_ARG2;
|
||||
while (factorList != YAP_TermNil()) {
|
||||
YAP_Term factor = YAP_HeadOfTerm (factorList);
|
||||
YAP_Term factor = YAP_HeadOfTerm (factorList);
|
||||
// read the var ids
|
||||
VarIds varIds = readUnsignedList (YAP_ArgOfTerm (1, factor));
|
||||
// read the ranges
|
||||
@ -234,45 +229,27 @@ createGroundNetwork (void)
|
||||
Params params = readParameters (YAP_ArgOfTerm (3, factor));
|
||||
// read dist id
|
||||
unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (4, factor));
|
||||
fg->addFactor (Factor (varIds, ranges, params, distId));
|
||||
factorList = YAP_TailOfTerm (factorList);
|
||||
Factor f (varIds, ranges, params, distId);
|
||||
f.print();
|
||||
}
|
||||
assert (false);
|
||||
/*
|
||||
vector<VarIds> parents;
|
||||
while (varList != YAP_TermNil()) {
|
||||
YAP_Term var = YAP_HeadOfTerm (varList);
|
||||
VarId vid = (VarId) YAP_IntOfTerm (YAP_ArgOfTerm (1, var));
|
||||
unsigned dsize = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (2, var));
|
||||
int evidence = (int) YAP_IntOfTerm (YAP_ArgOfTerm (3, var));
|
||||
YAP_Term parentL = YAP_ArgOfTerm (4, var);
|
||||
unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (5, var));
|
||||
parents.push_back (VarIds());
|
||||
while (parentL != YAP_TermNil()) {
|
||||
unsigned parentId = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (parentL));
|
||||
parents.back().push_back (parentId);
|
||||
parentL = YAP_TailOfTerm (parentL);
|
||||
}
|
||||
assert (bn->getBayesNode (vid) == 0);
|
||||
BayesNode* newNode = new BayesNode (
|
||||
vid, dsize, evidence, Params(), distId);
|
||||
bn->addNode (newNode);
|
||||
varList = YAP_TailOfTerm (varList);
|
||||
if (factorsType == "bayes") {
|
||||
fg->setFromBayesNetwork();
|
||||
}
|
||||
const BnNodeSet& nodes = bn->getBayesNodes();
|
||||
for (unsigned i = 0; i < nodes.size(); i++) {
|
||||
BnNodeSet ps;
|
||||
for (unsigned j = 0; j < parents[i].size(); j++) {
|
||||
assert (bn->getBayesNode (parents[i][j]) != 0);
|
||||
ps.push_back (bn->getBayesNode (parents[i][j]));
|
||||
}
|
||||
nodes[i]->setParents (ps);
|
||||
fg->setIndexes();
|
||||
|
||||
YAP_Term evidenceList = YAP_ARG3;
|
||||
while (evidenceList != YAP_TermNil()) {
|
||||
YAP_Term evTerm = YAP_HeadOfTerm (evidenceList);
|
||||
unsigned vid = (unsigned) YAP_IntOfTerm ((YAP_ArgOfTerm (1, evTerm)));
|
||||
unsigned ev = (unsigned) YAP_IntOfTerm ((YAP_ArgOfTerm (2, evTerm)));
|
||||
cout << vid << " == " << ev << endl;
|
||||
assert (fg->getFgVarNode (vid));
|
||||
fg->getFgVarNode (vid)->setEvidence (ev);
|
||||
evidenceList = YAP_TailOfTerm (evidenceList);
|
||||
}
|
||||
bn->setIndexes();
|
||||
*/
|
||||
YAP_Int p = (YAP_Int) (bn);
|
||||
return YAP_Unify (YAP_MkIntTerm (p), YAP_ARG2);
|
||||
|
||||
YAP_Int p = (YAP_Int) (fg);
|
||||
return YAP_Unify (YAP_MkIntTerm (p), YAP_ARG4);
|
||||
}
|
||||
|
||||
|
||||
@ -354,72 +331,33 @@ runLiftedSolver (void)
|
||||
|
||||
|
||||
|
||||
void runVeSolver (FactorGraph* fg, const vector<VarIds>& tasks,
|
||||
vector<Params>& results);
|
||||
void runBpSolver (FactorGraph* fg, const vector<VarIds>& tasks,
|
||||
vector<Params>& results);
|
||||
|
||||
|
||||
|
||||
int
|
||||
runGroundSolver (void)
|
||||
{
|
||||
BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
|
||||
YAP_Term taskList = YAP_ARG2;
|
||||
FactorGraph* fg = (FactorGraph*) YAP_IntOfTerm (YAP_ARG1);
|
||||
|
||||
vector<VarIds> tasks;
|
||||
std::set<VarId> vids;
|
||||
YAP_Term taskList = YAP_ARG2;
|
||||
while (taskList != YAP_TermNil()) {
|
||||
VarIds queryVars;
|
||||
YAP_Term jointList = YAP_HeadOfTerm (taskList);
|
||||
while (jointList != YAP_TermNil()) {
|
||||
VarId vid = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (jointList));
|
||||
assert (bn->getBayesNode (vid));
|
||||
queryVars.push_back (vid);
|
||||
vids.insert (vid);
|
||||
jointList = YAP_TailOfTerm (jointList);
|
||||
}
|
||||
tasks.push_back (queryVars);
|
||||
tasks.push_back (readUnsignedList (YAP_HeadOfTerm (taskList)));
|
||||
taskList = YAP_TailOfTerm (taskList);
|
||||
}
|
||||
|
||||
Solver* bpSolver = 0;
|
||||
GraphicalModel* graphicalModel = 0;
|
||||
CFactorGraph::checkForIdenticalFactors = false;
|
||||
if (Globals::infAlgorithm != InfAlgorithms::VE) {
|
||||
BayesNet* mrn = bn->getMinimalRequesiteNetwork (
|
||||
VarIds (vids.begin(), vids.end()));
|
||||
if (Globals::infAlgorithm == InfAlgorithms::BN_BP) {
|
||||
graphicalModel = mrn;
|
||||
bpSolver = new BnBpSolver (*static_cast<BayesNet*> (graphicalModel));
|
||||
} else if (Globals::infAlgorithm == InfAlgorithms::FG_BP) {
|
||||
graphicalModel = new FactorGraph (*mrn);
|
||||
bpSolver = new FgBpSolver (*static_cast<FactorGraph*> (graphicalModel));
|
||||
delete mrn;
|
||||
} else if (Globals::infAlgorithm == InfAlgorithms::CBP) {
|
||||
graphicalModel = new FactorGraph (*mrn);
|
||||
bpSolver = new CbpSolver (*static_cast<FactorGraph*> (graphicalModel));
|
||||
delete mrn;
|
||||
}
|
||||
bpSolver->runSolver();
|
||||
}
|
||||
|
||||
fg->printGraphicalModel();
|
||||
vector<Params> results;
|
||||
results.reserve (tasks.size());
|
||||
for (unsigned i = 0; i < tasks.size(); i++) {
|
||||
if (Globals::infAlgorithm == InfAlgorithms::VE) {
|
||||
BayesNet* mrn = bn->getMinimalRequesiteNetwork (tasks[i]);
|
||||
VarElimSolver* veSolver = new VarElimSolver (*mrn);
|
||||
if (tasks[i].size() == 1) {
|
||||
results.push_back (veSolver->getPosterioriOf (tasks[i][0]));
|
||||
} else {
|
||||
results.push_back (veSolver->getJointDistributionOf (tasks[i]));
|
||||
}
|
||||
delete mrn;
|
||||
delete veSolver;
|
||||
} else {
|
||||
if (tasks[i].size() == 1) {
|
||||
results.push_back (bpSolver->getPosterioriOf (tasks[i][0]));
|
||||
} else {
|
||||
results.push_back (bpSolver->getJointDistributionOf (tasks[i]));
|
||||
}
|
||||
}
|
||||
if (Globals::infAlgorithm == InfAlgorithms::VE) {
|
||||
runVeSolver (fg, tasks, results);
|
||||
} else {
|
||||
runBpSolver (fg, tasks, results);
|
||||
}
|
||||
delete bpSolver;
|
||||
delete graphicalModel;
|
||||
|
||||
cout << "results: " << results << endl;
|
||||
YAP_Term list = YAP_TermNil();
|
||||
for (int i = results.size() - 1; i >= 0; i--) {
|
||||
const Params& beliefs = results[i];
|
||||
@ -433,12 +371,77 @@ runGroundSolver (void)
|
||||
}
|
||||
list = YAP_MkPairTerm (queryBeliefsL, list);
|
||||
}
|
||||
|
||||
return YAP_Unify (list, YAP_ARG3);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void runVeSolver (
|
||||
FactorGraph* fg,
|
||||
const vector<VarIds>& tasks,
|
||||
vector<Params>& results)
|
||||
{
|
||||
results.reserve (tasks.size());
|
||||
for (unsigned i = 0; i < tasks.size(); i++) {
|
||||
FactorGraph* mfg = fg;
|
||||
if (fg->isFromBayesNetwork()) {
|
||||
mfg = BayesBall::getMinimalFactorGraph (*fg, tasks[i]);
|
||||
}
|
||||
VarElimSolver solver (*mfg);
|
||||
if (tasks[i].size() == 1) {
|
||||
results.push_back (solver.getPosterioriOf (tasks[i][0]));
|
||||
} else {
|
||||
results.push_back (solver.getJointDistributionOf (tasks[i]));
|
||||
}
|
||||
if (fg->isFromBayesNetwork()) {
|
||||
delete mfg;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void runBpSolver (
|
||||
FactorGraph* fg,
|
||||
const vector<VarIds>& tasks,
|
||||
vector<Params>& results)
|
||||
{
|
||||
std::set<VarId> vids;
|
||||
for (unsigned i = 0; i < tasks.size(); i++) {
|
||||
Util::addToSet (vids, tasks[i]);
|
||||
}
|
||||
Solver* solver = 0;
|
||||
FactorGraph* mfg = fg;
|
||||
if (fg->isFromBayesNetwork()) {
|
||||
mfg = BayesBall::getMinimalFactorGraph (
|
||||
*fg, VarIds (vids.begin(),vids.end()));
|
||||
}
|
||||
if (Globals::infAlgorithm == InfAlgorithms::FG_BP) {
|
||||
solver = new FgBpSolver (*mfg);
|
||||
} else if (Globals::infAlgorithm == InfAlgorithms::CBP) {
|
||||
CFactorGraph::checkForIdenticalFactors = false;
|
||||
solver = new CbpSolver (*mfg);
|
||||
} else {
|
||||
cerr << "error: unknow solver" << endl;
|
||||
abort();
|
||||
}
|
||||
solver->runSolver();
|
||||
results.reserve (tasks.size());
|
||||
for (unsigned i = 0; i < tasks.size(); i++) {
|
||||
if (tasks[i].size() == 1) {
|
||||
results.push_back (solver->getPosterioriOf (tasks[i][0]));
|
||||
} else {
|
||||
results.push_back (solver->getJointDistributionOf (tasks[i]));
|
||||
}
|
||||
}
|
||||
if (fg->isFromBayesNetwork()) {
|
||||
delete mfg;
|
||||
}
|
||||
delete solver;
|
||||
}
|
||||
|
||||
|
||||
|
||||
int
|
||||
setParfactorsParams (void)
|
||||
{
|
||||
@ -447,8 +450,8 @@ setParfactorsParams (void)
|
||||
YAP_Term distList = YAP_ARG2;
|
||||
unordered_map<unsigned, Params> paramsMap;
|
||||
while (distList != YAP_TermNil()) {
|
||||
YAP_Term dist = YAP_HeadOfTerm (distList);
|
||||
unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (1, dist));
|
||||
YAP_Term dist = YAP_HeadOfTerm (distList);
|
||||
unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (1, dist));
|
||||
assert (Util::contains (paramsMap, distId) == false);
|
||||
paramsMap[distId] = readParameters (YAP_ArgOfTerm (2, dist));
|
||||
distList = YAP_TailOfTerm (distList);
|
||||
@ -621,7 +624,7 @@ extern "C" void
|
||||
init_predicates (void)
|
||||
{
|
||||
YAP_UserCPredicate ("create_lifted_network", createLiftedNetwork, 3);
|
||||
YAP_UserCPredicate ("create_ground_network", createGroundNetwork, 3);
|
||||
YAP_UserCPredicate ("create_ground_network", createGroundNetwork, 4);
|
||||
YAP_UserCPredicate ("run_lifted_solver", runLiftedSolver, 3);
|
||||
YAP_UserCPredicate ("run_ground_solver", runGroundSolver, 3);
|
||||
YAP_UserCPredicate ("set_parfactors_params", setParfactorsParams, 2);
|
||||
|
@ -37,8 +37,8 @@ class StatesIndexer
|
||||
indices_.resize (vars.size(), 0);
|
||||
ranges_.reserve (vars.size());
|
||||
for (unsigned i = 0; i < vars.size(); i++) {
|
||||
ranges_.push_back (vars[i]->nrStates());
|
||||
size_ *= vars[i]->nrStates();
|
||||
ranges_.push_back (vars[i]->range());
|
||||
size_ *= vars[i]->range();
|
||||
}
|
||||
li_ = 0;
|
||||
if (calcOffsets) {
|
||||
|
@ -47,6 +47,7 @@ CWD=$(PWD)
|
||||
HEADERS = \
|
||||
$(srcdir)/GraphicalModel.h \
|
||||
$(srcdir)/BayesNet.h \
|
||||
$(srcdir)/BayesBall.h \
|
||||
$(srcdir)/BayesNode.h \
|
||||
$(srcdir)/ElimGraph.h \
|
||||
$(srcdir)/FactorGraph.h \
|
||||
@ -73,6 +74,7 @@ HEADERS = \
|
||||
|
||||
CPP_SOURCES = \
|
||||
$(srcdir)/BayesNet.cpp \
|
||||
$(srcdir)/BayesBall.cpp \
|
||||
$(srcdir)/BayesNode.cpp \
|
||||
$(srcdir)/ElimGraph.cpp \
|
||||
$(srcdir)/FactorGraph.cpp \
|
||||
@ -98,6 +100,7 @@ CPP_SOURCES = \
|
||||
|
||||
OBJS = \
|
||||
BayesNet.o \
|
||||
BayesBall.o \
|
||||
BayesNode.o \
|
||||
ElimGraph.o \
|
||||
FactorGraph.o \
|
||||
@ -121,6 +124,7 @@ OBJS = \
|
||||
|
||||
HCLI_OBJS = \
|
||||
BayesNet.o \
|
||||
BayesBall.o \
|
||||
BayesNode.o \
|
||||
ElimGraph.o \
|
||||
FactorGraph.o \
|
||||
|
@ -5,6 +5,7 @@
|
||||
#include <cassert>
|
||||
#include <limits>
|
||||
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
#include <set>
|
||||
#include <queue>
|
||||
@ -22,7 +23,9 @@ namespace Util {
|
||||
|
||||
template <typename T> void addToVector (vector<T>&, const vector<T>&);
|
||||
|
||||
template <typename T> void addToQueue (queue<T>&, const vector<T>&);
|
||||
template <typename T> void addToSet (set<T>&, const vector<T>&);
|
||||
|
||||
template <typename T> void addToQueue (queue<T>&, const vector<T>&);
|
||||
|
||||
template <typename T> bool contains (const vector<T>&, const T&);
|
||||
|
||||
@ -83,6 +86,14 @@ Util::addToVector (vector<T>& v, const vector<T>& elements)
|
||||
|
||||
|
||||
|
||||
template <typename T> void
|
||||
Util::addToSet (set<T>& s, const vector<T>& elements)
|
||||
{
|
||||
s.insert (elements.begin(), elements.end());
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T> void
|
||||
Util::addToQueue (queue<T>& q, const vector<T>& elements)
|
||||
{
|
||||
@ -110,8 +121,7 @@ Util::contains (const set<T>& s, const T& e)
|
||||
|
||||
|
||||
template <typename K, typename V> bool
|
||||
Util::contains (
|
||||
const unordered_map<K, V>& m, const K& k)
|
||||
Util::contains (const unordered_map<K, V>& m, const K& k)
|
||||
{
|
||||
return m.find (k) != m.end();
|
||||
}
|
||||
|
@ -6,17 +6,8 @@
|
||||
#include "Util.h"
|
||||
|
||||
|
||||
VarElimSolver::VarElimSolver (const BayesNet& bn) : Solver (&bn)
|
||||
{
|
||||
bayesNet_ = &bn;
|
||||
factorGraph_ = new FactorGraph (bn);
|
||||
}
|
||||
|
||||
|
||||
|
||||
VarElimSolver::VarElimSolver (const FactorGraph& fg) : Solver (&fg)
|
||||
{
|
||||
bayesNet_ = 0;
|
||||
factorGraph_ = &fg;
|
||||
}
|
||||
|
||||
@ -24,9 +15,7 @@ VarElimSolver::VarElimSolver (const FactorGraph& fg) : Solver (&fg)
|
||||
|
||||
VarElimSolver::~VarElimSolver (void)
|
||||
{
|
||||
if (bayesNet_) {
|
||||
delete factorGraph_;
|
||||
}
|
||||
delete factorList_.back();
|
||||
}
|
||||
|
||||
|
||||
@ -37,7 +26,7 @@ VarElimSolver::getPosterioriOf (VarId vid)
|
||||
assert (factorGraph_->getFgVarNode (vid));
|
||||
FgVarNode* vn = factorGraph_->getFgVarNode (vid);
|
||||
if (vn->hasEvidence()) {
|
||||
Params params (vn->nrStates(), 0.0);
|
||||
Params params (vn->range(), 0.0);
|
||||
params[vn->getEvidence()] = 1.0;
|
||||
return params;
|
||||
}
|
||||
@ -53,14 +42,13 @@ VarElimSolver::getJointDistributionOf (const VarIds& vids)
|
||||
varFactors_.clear();
|
||||
elimOrder_.clear();
|
||||
createFactorList();
|
||||
introduceEvidence();
|
||||
chooseEliminationOrder (vids);
|
||||
absorveEvidence();
|
||||
findEliminationOrder (vids);
|
||||
processFactorList (vids);
|
||||
Params params = factorList_.back()->params();
|
||||
if (Globals::logDomain) {
|
||||
Util::fromLog (params);
|
||||
}
|
||||
delete factorList_.back();
|
||||
return params;
|
||||
}
|
||||
|
||||
@ -89,7 +77,7 @@ VarElimSolver::createFactorList (void)
|
||||
|
||||
|
||||
void
|
||||
VarElimSolver::introduceEvidence (void)
|
||||
VarElimSolver::absorveEvidence (void)
|
||||
{
|
||||
const FgVarSet& varNodes = factorGraph_->getVarNodes();
|
||||
for (unsigned i = 0; i < varNodes.size(); i++) {
|
||||
@ -107,26 +95,15 @@ VarElimSolver::introduceEvidence (void)
|
||||
}
|
||||
}
|
||||
}
|
||||
printActiveFactors();
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
VarElimSolver::chooseEliminationOrder (const VarIds& vids)
|
||||
VarElimSolver::findEliminationOrder (const VarIds& vids)
|
||||
{
|
||||
if (bayesNet_) {
|
||||
ElimGraph graph (*bayesNet_);
|
||||
elimOrder_ = graph.getEliminatingOrder (vids);
|
||||
} else {
|
||||
const FgVarSet& varNodes = factorGraph_->getVarNodes();
|
||||
for (unsigned i = 0; i < varNodes.size(); i++) {
|
||||
VarId vid = varNodes[i]->varId();
|
||||
if (Util::contains (vids, vid) == false &&
|
||||
varNodes[i]->hasEvidence() == false) {
|
||||
elimOrder_.push_back (vid);
|
||||
}
|
||||
}
|
||||
}
|
||||
elimOrder_ = ElimGraph::getEliminationOrder (factorList_, vids);
|
||||
}
|
||||
|
||||
|
||||
@ -165,7 +142,6 @@ void
|
||||
VarElimSolver::eliminate (VarId elimVar)
|
||||
{
|
||||
Factor* result = 0;
|
||||
FgVarNode* vn = factorGraph_->getFgVarNode (elimVar);
|
||||
vector<unsigned>& idxs = varFactors_.find (elimVar)->second;
|
||||
for (unsigned i = 0; i < idxs.size(); i++) {
|
||||
unsigned idx = idxs[i];
|
||||
@ -180,7 +156,7 @@ VarElimSolver::eliminate (VarId elimVar)
|
||||
}
|
||||
}
|
||||
if (result != 0 && result->nrArguments() != 1) {
|
||||
result->sumOut (vn->varId());
|
||||
result->sumOut (elimVar);
|
||||
factorList_.push_back (result);
|
||||
const VarIds& resultVarIds = result->arguments();
|
||||
for (unsigned i = 0; i < resultVarIds.size(); i++) {
|
||||
@ -199,7 +175,6 @@ VarElimSolver::printActiveFactors (void)
|
||||
for (unsigned i = 0; i < factorList_.size(); i++) {
|
||||
if (factorList_[i] != 0) {
|
||||
factorList_[i]->print();
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -5,7 +5,6 @@
|
||||
|
||||
#include "Solver.h"
|
||||
#include "FactorGraph.h"
|
||||
#include "BayesNet.h"
|
||||
#include "Horus.h"
|
||||
|
||||
|
||||
@ -15,8 +14,6 @@ using namespace std;
|
||||
class VarElimSolver : public Solver
|
||||
{
|
||||
public:
|
||||
VarElimSolver (const BayesNet&);
|
||||
|
||||
VarElimSolver (const FactorGraph&);
|
||||
|
||||
~VarElimSolver (void);
|
||||
@ -30,9 +27,9 @@ class VarElimSolver : public Solver
|
||||
private:
|
||||
void createFactorList (void);
|
||||
|
||||
void introduceEvidence (void);
|
||||
void absorveEvidence (void);
|
||||
|
||||
void chooseEliminationOrder (const VarIds&);
|
||||
void findEliminationOrder (const VarIds&);
|
||||
|
||||
void processFactorList (const VarIds&);
|
||||
|
||||
@ -40,7 +37,6 @@ class VarElimSolver : public Solver
|
||||
|
||||
void printActiveFactors (void);
|
||||
|
||||
const BayesNet* bayesNet_;
|
||||
const FactorGraph* factorGraph_;
|
||||
vector<Factor*> factorList_;
|
||||
VarIds elimOrder_;
|
||||
|
@ -9,22 +9,22 @@ using namespace std;
|
||||
|
||||
VarNode::VarNode (const VarNode* v)
|
||||
{
|
||||
varId_ = v->varId();
|
||||
nrStates_ = v->nrStates();
|
||||
evidence_ = v->getEvidence();
|
||||
index_ = std::numeric_limits<unsigned>::max();
|
||||
varId_ = v->varId();
|
||||
range_ = v->range();
|
||||
evidence_ = v->getEvidence();
|
||||
index_ = std::numeric_limits<unsigned>::max();
|
||||
}
|
||||
|
||||
|
||||
|
||||
VarNode::VarNode (VarId varId, unsigned nrStates, int evidence)
|
||||
VarNode::VarNode (VarId varId, unsigned range, int evidence)
|
||||
{
|
||||
assert (nrStates != 0);
|
||||
assert (evidence < (int) nrStates);
|
||||
varId_ = varId;
|
||||
nrStates_ = nrStates;
|
||||
evidence_ = evidence;
|
||||
index_ = std::numeric_limits<unsigned>::max();
|
||||
assert (range != 0);
|
||||
assert (evidence < (int) range);
|
||||
varId_ = varId;
|
||||
range_ = range;
|
||||
evidence_ = evidence;
|
||||
index_ = std::numeric_limits<unsigned>::max();
|
||||
}
|
||||
|
||||
|
||||
@ -32,7 +32,7 @@ VarNode::VarNode (VarId varId, unsigned nrStates, int evidence)
|
||||
bool
|
||||
VarNode::isValidState (int stateIndex)
|
||||
{
|
||||
return stateIndex >= 0 && stateIndex < (int) nrStates_;
|
||||
return stateIndex >= 0 && stateIndex < (int) range_;
|
||||
}
|
||||
|
||||
|
||||
@ -49,7 +49,7 @@ VarNode::isValidState (const string& stateName)
|
||||
void
|
||||
VarNode::setEvidence (int ev)
|
||||
{
|
||||
assert (ev < (int) nrStates_);
|
||||
assert (ev < (int) range_);
|
||||
evidence_ = ev;
|
||||
}
|
||||
|
||||
@ -90,7 +90,7 @@ VarNode::states (void) const
|
||||
return GraphicalModel::getVarInformation (varId_).states;
|
||||
}
|
||||
States states;
|
||||
for (unsigned i = 0; i < nrStates_; i++) {
|
||||
for (unsigned i = 0; i < range_; i++) {
|
||||
stringstream ss;
|
||||
ss << i ;
|
||||
states.push_back (ss.str());
|
||||
|
@ -20,7 +20,7 @@ class VarNode
|
||||
|
||||
unsigned varId (void) const { return varId_; }
|
||||
|
||||
unsigned nrStates (void) const { return nrStates_; }
|
||||
unsigned range (void) const { return range_; }
|
||||
|
||||
int getEvidence (void) const { return evidence_; }
|
||||
|
||||
@ -37,15 +37,13 @@ class VarNode
|
||||
|
||||
bool operator== (const VarNode& var) const
|
||||
{
|
||||
cout << "equal operator called" << endl;
|
||||
assert (!(varId_ == var.varId() && nrStates_ != var.nrStates()));
|
||||
assert (!(varId_ == var.varId() && range_ != var.range()));
|
||||
return varId_ == var.varId();
|
||||
}
|
||||
|
||||
bool operator!= (const VarNode& var) const
|
||||
{
|
||||
cout << "diff operator called" << endl;
|
||||
assert (!(varId_ == var.varId() && nrStates_ != var.nrStates()));
|
||||
assert (!(varId_ == var.varId() && range_ != var.range()));
|
||||
return varId_ != var.varId();
|
||||
}
|
||||
|
||||
@ -63,7 +61,7 @@ class VarNode
|
||||
|
||||
private:
|
||||
VarId varId_;
|
||||
unsigned nrStates_;
|
||||
unsigned range_;
|
||||
int evidence_;
|
||||
unsigned index_;
|
||||
|
||||
|
@ -1,29 +1,41 @@
|
||||
|
||||
:- use_module(library(clpbn)).
|
||||
:- use_module(library(pfl)).
|
||||
|
||||
:- set_clpbn_flag(solver, bp).
|
||||
%:- set_pfl_flag(solver,ve).
|
||||
:- set_pfl_flag(solver,bp), clpbn_horus:set_horus_flag(inf_alg,ve).
|
||||
%:- set_pfl_flag(solver,bp), clpbn_horus:set_horus_flag(inf_alg,fg_bp).
|
||||
%:- set_pfl_flag(solver,fove).
|
||||
|
||||
r(R) :- r_cpt(RCpt),
|
||||
{ R = r with p([r1, r2], RCpt) }.
|
||||
% :- yap_flag(write_strings, off).
|
||||
|
||||
t(T) :- t_cpt(TCpt),
|
||||
{ T = t with p([t1, t2], TCpt) }.
|
||||
|
||||
a(A) :- r(R), t(T), a_cpt(ACpt),
|
||||
{ A = a with p([a1, a2], ACpt, [R, T]) }.
|
||||
bayes burglary::[b1,b3] ; [0.001, 0.999] ; [].
|
||||
|
||||
j(J) :- a(A), j_cpt(JCpt),
|
||||
{ J = j with p([j1, j2], JCpt, [A]) }.
|
||||
bayes earthquake::[e1,e2] ; [0.002, 0.998]; [].
|
||||
|
||||
bayes alarm::[a1,a2] , burglary, earthquake ; [0.95, 0.94, 0.29, 0.001, 0.05, 0.06, 0.71, 0.999] ; [].
|
||||
|
||||
bayes john_calls::[j1,j2] , alarm ; [0.9, 0.05, 0.1, 0.95] ; [].
|
||||
|
||||
bayes mary_calls::[m1,m2] , alarm ; [0.7, 0.01, 0.3, 0.99] ; [].
|
||||
|
||||
|
||||
b_cpt([0.001, 0.999]).
|
||||
|
||||
e_cpt([0.002, 0.998]).
|
||||
|
||||
m(M) :- a(A), m_cpt(MCpt),
|
||||
{ M = m with p([m1, m2], MCpt, [A]) }.
|
||||
|
||||
r_cpt([0.001, 0.999]).
|
||||
t_cpt([0.002, 0.998]).
|
||||
a_cpt([0.95, 0.94, 0.29, 0.001,
|
||||
0.05, 0.06, 0.71, 0.999]).
|
||||
j_cpt([0.9, 0.05,
|
||||
|
||||
jc_cpt([0.9, 0.05,
|
||||
0.1, 0.95]).
|
||||
m_cpt([0.7, 0.01,
|
||||
|
||||
mc_cpt([0.7, 0.01,
|
||||
0.3, 0.99]).
|
||||
|
||||
% ?- alarm(A).
|
||||
?- john_calls(J), mary_calls(m1).
|
||||
%?- john_calls(J), mary_calls(m1), alarm(a1).
|
||||
%?- john_calls(J), alarm(a1).
|
||||
|
||||
|
||||
|
@ -2,19 +2,19 @@
|
||||
:- use_module(library(pfl)).
|
||||
|
||||
%:- set_pfl_flag(solver,ve).
|
||||
:- set_pfl_flag(solver,bp), clpbn_bp:set_horus_flag(inf_alg,ve).
|
||||
:- set_pfl_flag(solver,bp), clpbn_horus:set_horus_flag(inf_alg,ve).
|
||||
% :- set_pfl_flag(solver,fove).
|
||||
|
||||
:- yap_flag(write_strings, off).
|
||||
|
||||
friendly(P1, P2) :-
|
||||
friends(P1, P2) :-
|
||||
person(P1),
|
||||
person(P2),
|
||||
P1 @> P2.
|
||||
P1 \= P2.
|
||||
|
||||
person(john).
|
||||
person(maggie).
|
||||
person(harry).
|
||||
%person(harry).
|
||||
%person(bill).
|
||||
%person(matt).
|
||||
%person(diana).
|
||||
@ -27,8 +27,13 @@ person(harry).
|
||||
|
||||
markov smokes(P)::[t,f] , cancer(P)::[t,f] ; [0.1, 0.2, 0.3, 0.4] ; [person(P)].
|
||||
|
||||
markov friend(P1,P2)::[t,f], smokes(P1)::[t,f], smokes(P2)::[t,f] ; [0.5, 0.6, 0.7, 0.8, 0.5, 0.6, 0.7, 0.8] ; [friendly(P1, P2)].
|
||||
markov friend(P1,P2)::[t,f], smokes(P1)::[t,f], smokes(P2)::[t,f] ; [0.5, 0.6, 0.7, 0.8, 0.5, 0.6, 0.7, 0.8] ; [friends(P1, P2)].
|
||||
|
||||
% ?- smokes(person_0, t), smokes(person_1, t), friend(person_0, person_1, F).
|
||||
|
||||
?- smokes(john, t), smokes(maggie, f), friend(john, maggie, X).
|
||||
% ?- smokes(john, t), smokes(maggie, f), friend(john, maggie, X).
|
||||
|
||||
?- smokes(john, t), friend(john, maggie, X).
|
||||
|
||||
% ?- friend(john, maggie, X).
|
||||
|
||||
|
@ -7,7 +7,7 @@
|
||||
|
||||
:- module(clpbn_horus,
|
||||
[create_lifted_network/3,
|
||||
create_ground_network/3,
|
||||
create_ground_network/4,
|
||||
set_parfactors_params/2,
|
||||
set_bayes_net_params/2,
|
||||
run_lifted_solver/3,
|
||||
|
Reference in New Issue
Block a user