Merge branch 'master' of https://github.com/tacgomes/yap6.3
This commit is contained in:
commit
eeb53aef71
137
packages/CLPBN/examples/complex.fg
Normal file
137
packages/CLPBN/examples/complex.fg
Normal file
@ -0,0 +1,137 @@
|
||||
10
|
||||
|
||||
2
|
||||
0 1
|
||||
2 2
|
||||
4
|
||||
0 1.02
|
||||
1 0.87
|
||||
2 0.88
|
||||
3 0.45
|
||||
|
||||
4
|
||||
1 2 3 4
|
||||
2 2 3 3
|
||||
36
|
||||
0 0.11
|
||||
1 1.11
|
||||
2 0.41
|
||||
3 0.12
|
||||
4 0.1
|
||||
5 0.17
|
||||
6 1.21
|
||||
7 1.1
|
||||
8 0.11
|
||||
9 0.41
|
||||
10 0.8
|
||||
11 0.71
|
||||
12 0.14
|
||||
13 0.24
|
||||
14 0.54
|
||||
15 1.4
|
||||
16 0.23
|
||||
17 0.24
|
||||
18 0.65
|
||||
19 0.05
|
||||
20 0.32
|
||||
21 0.12
|
||||
22 0.99
|
||||
23 0.69
|
||||
24 0.29
|
||||
25 1.29
|
||||
26 0.15
|
||||
27 1.24
|
||||
28 0.42
|
||||
29 0.124
|
||||
30 0.67
|
||||
31 0.078
|
||||
32 0.14
|
||||
33 0.55
|
||||
34 0.45
|
||||
35 0.1
|
||||
|
||||
3
|
||||
2 5 6
|
||||
2 2 3
|
||||
12
|
||||
0 0.15
|
||||
1 0.55
|
||||
2 2.21
|
||||
3 5.71
|
||||
4 0.44
|
||||
5 0.14
|
||||
6 0.5
|
||||
7 1.75
|
||||
8 1.29
|
||||
9 3.29
|
||||
10 0.36
|
||||
11 1.56
|
||||
|
||||
2
|
||||
7 2
|
||||
4 2
|
||||
8
|
||||
0 0.11
|
||||
1 0.59
|
||||
2 0.15
|
||||
3 0.124
|
||||
4 0.41
|
||||
5 2.11
|
||||
6 1.06
|
||||
7 0.929
|
||||
|
||||
1
|
||||
3
|
||||
3
|
||||
3
|
||||
0 0.1
|
||||
1 0.58
|
||||
2 0.74
|
||||
|
||||
1
|
||||
4
|
||||
3
|
||||
3
|
||||
0 3.2
|
||||
1 0.28
|
||||
2 1.24
|
||||
|
||||
2
|
||||
8 4
|
||||
2 3
|
||||
6
|
||||
0 0.19
|
||||
1 3.1
|
||||
2 0.49
|
||||
3 1.5
|
||||
4 2.1
|
||||
5 2.8
|
||||
|
||||
1
|
||||
5
|
||||
2
|
||||
2
|
||||
0 0.74
|
||||
1 0.14
|
||||
|
||||
1
|
||||
6
|
||||
3
|
||||
3
|
||||
0 0.032
|
||||
1 0.028
|
||||
2 0.24
|
||||
|
||||
2
|
||||
9 7
|
||||
2 4
|
||||
8
|
||||
0 0.61
|
||||
1 0.61
|
||||
2 1.4
|
||||
3 0.24
|
||||
4 0.09
|
||||
5 0.19
|
||||
6 1.4
|
||||
7 0.6
|
||||
|
79
packages/CLPBN/examples/complex.pfl
Normal file
79
packages/CLPBN/examples/complex.pfl
Normal file
@ -0,0 +1,79 @@
|
||||
:- use_module(library(pfl)).
|
||||
|
||||
%:- set_solver(ve).
|
||||
%:- set_solver(hve).
|
||||
%:- set_solver(jt).
|
||||
%:- set_solver(bdd).
|
||||
%:- set_solver(bp).
|
||||
%:- set_solver(cbp).
|
||||
%:- set_solver(gibbs).
|
||||
%:- set_solver(lve).
|
||||
%:- set_solver(lkc).
|
||||
%:- set_solver(lbp).
|
||||
|
||||
/*
|
||||
|
||||
v01 v02
|
||||
\ /
|
||||
\ /
|
||||
\ /
|
||||
v03 v04 v05
|
||||
/ \ | / \
|
||||
/ \ | / \
|
||||
/ \ | / \
|
||||
v06 v07 v08
|
||||
| |
|
||||
| |
|
||||
| |
|
||||
v09 v10
|
||||
|
||||
*/
|
||||
|
||||
markov v01::[a,b] ; table1 ; [].
|
||||
|
||||
markov v02::[a,b,c] ; table2 ; [].
|
||||
|
||||
markov v03::[a,b], v01, v02 ; table3 ; [].
|
||||
|
||||
markov v04::[a,b,c] ; table4 ; [].
|
||||
|
||||
markov v05::[a,b,c] ; table5 ; [].
|
||||
|
||||
markov v06::[a,b,c,d], v03 ; table6 ; [].
|
||||
|
||||
markov v07::[a,b], v03, v04, v05 ; table7 ; [].
|
||||
|
||||
markov v08::[a,b], v05 ; table8 ; [].
|
||||
|
||||
markov v09::[a,b], v06 ; table9 ; [].
|
||||
|
||||
markov v10::[a,b], v07 ; table10 ; [].
|
||||
|
||||
table1([ 0.74, 0.14 ]).
|
||||
|
||||
table2([ 0.032, 0.028, 0.24 ]).
|
||||
|
||||
table3([
|
||||
0.15, 0.44, 1.29, 2.21, 0.5, 0.36,
|
||||
0.55, 0.14, 3.29, 5.71, 1.75, 1.56
|
||||
]).
|
||||
|
||||
table4([ 0.1, 0.58, 0.74 ]).
|
||||
|
||||
table5([ 3.2, 0.28, 1.24 ]).
|
||||
|
||||
table6([ 0.11, 0.41, 0.59, 2.11, 0.15, 1.06, 0.124, 0.929 ]).
|
||||
|
||||
table7([
|
||||
0.11, 0.14, 0.29, 0.1, 0.23, 0.42, 0.11, 0.32, 0.14,
|
||||
0.41, 0.54, 0.15, 1.21, 0.65, 0.67, 0.8, 0.99, 0.45,
|
||||
1.11, 0.24, 1.29, 0.17, 0.24, 0.124, 0.41, 0.12, 0.55,
|
||||
0.12, 1.4, 1.24, 1.1, 0.05, 0.078, 0.71, 0.69, 0.1
|
||||
]).
|
||||
|
||||
table8([ 0.19, 0.49, 2.1, 3.1, 1.5, 2.8 ]).
|
||||
|
||||
table9([ 0.61, 1.4, 0.09, 1.4, 0.61, 0.24, 0.19, 0.6 ]).
|
||||
|
||||
table10([ 1.02, 0.88, 0.87, 0.45 ]).
|
||||
|
@ -3,6 +3,25 @@
|
||||
#include "BayesBall.h"
|
||||
|
||||
|
||||
namespace Horus {
|
||||
|
||||
BayesBall::BayesBall (FactorGraph& fg)
|
||||
: fg_(fg) , dag_(fg.getStructure())
|
||||
{
|
||||
dag_.clear();
|
||||
}
|
||||
|
||||
|
||||
|
||||
FactorGraph*
|
||||
BayesBall::getMinimalFactorGraph (FactorGraph& fg, VarIds vids)
|
||||
{
|
||||
BayesBall bb (fg);
|
||||
return bb.getMinimalFactorGraph (vids);
|
||||
}
|
||||
|
||||
|
||||
|
||||
FactorGraph*
|
||||
BayesBall::getMinimalFactorGraph (const VarIds& queryIds)
|
||||
{
|
||||
@ -19,22 +38,22 @@ BayesBall::getMinimalFactorGraph (const VarIds& queryIds)
|
||||
BBNode* n = sch.node;
|
||||
n->setAsVisited();
|
||||
if (n->hasEvidence() == false && sch.visitedFromChild) {
|
||||
if (n->isMarkedOnTop() == false) {
|
||||
n->markOnTop();
|
||||
if (n->isMarkedAbove() == false) {
|
||||
n->markAbove();
|
||||
scheduleParents (n, scheduling);
|
||||
}
|
||||
if (n->isMarkedOnBottom() == false) {
|
||||
n->markOnBottom();
|
||||
if (n->isMarkedBelow() == false) {
|
||||
n->markBelow();
|
||||
scheduleChilds (n, scheduling);
|
||||
}
|
||||
}
|
||||
if (sch.visitedFromParent) {
|
||||
if (n->hasEvidence() && n->isMarkedOnTop() == false) {
|
||||
n->markOnTop();
|
||||
if (n->hasEvidence() && n->isMarkedAbove() == false) {
|
||||
n->markAbove();
|
||||
scheduleParents (n, scheduling);
|
||||
}
|
||||
if (n->hasEvidence() == false && n->isMarkedOnBottom() == false) {
|
||||
n->markOnBottom();
|
||||
if (n->hasEvidence() == false && n->isMarkedBelow() == false) {
|
||||
n->markBelow();
|
||||
scheduleChilds (n, scheduling);
|
||||
}
|
||||
}
|
||||
@ -55,7 +74,7 @@ BayesBall::constructGraph (FactorGraph* fg) const
|
||||
for (size_t i = 0; i < facNodes.size(); i++) {
|
||||
const BBNode* n = dag_.getNode (
|
||||
facNodes[i]->factor().argument (0));
|
||||
if (n->isMarkedOnTop()) {
|
||||
if (n->isMarkedAbove()) {
|
||||
fg->addFactor (facNodes[i]->factor());
|
||||
} else if (n->hasEvidence() && n->isVisited()) {
|
||||
VarIds varIds = { facNodes[i]->factor().argument (0) };
|
||||
@ -76,3 +95,5 @@ BayesBall::constructGraph (FactorGraph* fg) const
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace Horus
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
#ifndef HORUS_BAYESBALL_H
|
||||
#define HORUS_BAYESBALL_H
|
||||
#ifndef YAP_PACKAGES_CLPBN_HORUS_BAYESBALL_H_
|
||||
#define YAP_PACKAGES_CLPBN_HORUS_BAYESBALL_H_
|
||||
|
||||
#include <vector>
|
||||
#include <queue>
|
||||
@ -9,41 +9,28 @@
|
||||
#include "BayesBallGraph.h"
|
||||
#include "Horus.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace Horus {
|
||||
|
||||
struct ScheduleInfo
|
||||
{
|
||||
ScheduleInfo (BBNode* n, bool vfp, bool vfc)
|
||||
: node(n), visitedFromParent(vfp), visitedFromChild(vfc) { }
|
||||
|
||||
BBNode* node;
|
||||
bool visitedFromParent;
|
||||
bool visitedFromChild;
|
||||
};
|
||||
|
||||
|
||||
typedef queue<ScheduleInfo, list<ScheduleInfo>> Scheduling;
|
||||
|
||||
|
||||
class BayesBall
|
||||
{
|
||||
class BayesBall {
|
||||
public:
|
||||
BayesBall (FactorGraph& fg)
|
||||
: fg_(fg) , dag_(fg.getStructure())
|
||||
{
|
||||
dag_.clear();
|
||||
}
|
||||
BayesBall (FactorGraph& fg);
|
||||
|
||||
FactorGraph* getMinimalFactorGraph (const VarIds&);
|
||||
|
||||
static FactorGraph* getMinimalFactorGraph (FactorGraph& fg, VarIds vids)
|
||||
{
|
||||
BayesBall bb (fg);
|
||||
return bb.getMinimalFactorGraph (vids);
|
||||
}
|
||||
static FactorGraph* getMinimalFactorGraph (FactorGraph& fg, VarIds vids);
|
||||
|
||||
private:
|
||||
struct ScheduleInfo {
|
||||
ScheduleInfo (BBNode* n, bool vfp, bool vfc)
|
||||
: node(n), visitedFromParent(vfp), visitedFromChild(vfc) { }
|
||||
|
||||
BBNode* node;
|
||||
bool visitedFromParent;
|
||||
bool visitedFromChild;
|
||||
};
|
||||
|
||||
typedef std::queue<ScheduleInfo, std::list<ScheduleInfo>> Scheduling;
|
||||
|
||||
void constructGraph (FactorGraph* fg) const;
|
||||
|
||||
@ -51,9 +38,8 @@ class BayesBall
|
||||
|
||||
void scheduleChilds (const BBNode* n, Scheduling& sch) const;
|
||||
|
||||
FactorGraph& fg_;
|
||||
|
||||
BayesBallGraph& dag_;
|
||||
FactorGraph& fg_;
|
||||
BayesBallGraph& dag_;
|
||||
};
|
||||
|
||||
|
||||
@ -61,8 +47,8 @@ class BayesBall
|
||||
inline void
|
||||
BayesBall::scheduleParents (const BBNode* n, Scheduling& sch) const
|
||||
{
|
||||
const vector<BBNode*>& ps = n->parents();
|
||||
for (vector<BBNode*>::const_iterator it = ps.begin();
|
||||
const std::vector<BBNode*>& ps = n->parents();
|
||||
for (std::vector<BBNode*>::const_iterator it = ps.begin();
|
||||
it != ps.end(); ++it) {
|
||||
sch.push (ScheduleInfo (*it, false, true));
|
||||
}
|
||||
@ -73,12 +59,14 @@ BayesBall::scheduleParents (const BBNode* n, Scheduling& sch) const
|
||||
inline void
|
||||
BayesBall::scheduleChilds (const BBNode* n, Scheduling& sch) const
|
||||
{
|
||||
const vector<BBNode*>& cs = n->childs();
|
||||
for (vector<BBNode*>::const_iterator it = cs.begin();
|
||||
const std::vector<BBNode*>& cs = n->childs();
|
||||
for (std::vector<BBNode*>::const_iterator it = cs.begin();
|
||||
it != cs.end(); ++it) {
|
||||
sch.push (ScheduleInfo (*it, true, false));
|
||||
}
|
||||
}
|
||||
|
||||
#endif // HORUS_BAYESBALL_H
|
||||
} // namespace Horus
|
||||
|
||||
#endif // YAP_PACKAGES_CLPBN_HORUS_BAYESBALL_H_
|
||||
|
||||
|
@ -1,14 +1,15 @@
|
||||
#include <cstdlib>
|
||||
#include <cassert>
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
|
||||
#include "BayesBallGraph.h"
|
||||
#include "Util.h"
|
||||
|
||||
|
||||
namespace Horus {
|
||||
|
||||
void
|
||||
BayesBallGraph::addNode (BBNode* n)
|
||||
{
|
||||
@ -22,8 +23,8 @@ BayesBallGraph::addNode (BBNode* n)
|
||||
void
|
||||
BayesBallGraph::addEdge (VarId vid1, VarId vid2)
|
||||
{
|
||||
unordered_map<VarId, BBNode*>::iterator it1;
|
||||
unordered_map<VarId, BBNode*>::iterator it2;
|
||||
std::unordered_map<VarId, BBNode*>::iterator it1;
|
||||
std::unordered_map<VarId, BBNode*>::iterator it2;
|
||||
it1 = varMap_.find (vid1);
|
||||
it2 = varMap_.find (vid2);
|
||||
assert (it1 != varMap_.end());
|
||||
@ -37,7 +38,7 @@ BayesBallGraph::addEdge (VarId vid1, VarId vid2)
|
||||
const BBNode*
|
||||
BayesBallGraph::getNode (VarId vid) const
|
||||
{
|
||||
unordered_map<VarId, BBNode*>::const_iterator it;
|
||||
std::unordered_map<VarId, BBNode*>::const_iterator it;
|
||||
it = varMap_.find (vid);
|
||||
return it != varMap_.end() ? it->second : 0;
|
||||
}
|
||||
@ -47,7 +48,7 @@ BayesBallGraph::getNode (VarId vid) const
|
||||
BBNode*
|
||||
BayesBallGraph::getNode (VarId vid)
|
||||
{
|
||||
unordered_map<VarId, BBNode*>::const_iterator it;
|
||||
std::unordered_map<VarId, BBNode*>::const_iterator it;
|
||||
it = varMap_.find (vid);
|
||||
return it != varMap_.end() ? it->second : 0;
|
||||
}
|
||||
@ -55,7 +56,7 @@ BayesBallGraph::getNode (VarId vid)
|
||||
|
||||
|
||||
void
|
||||
BayesBallGraph::setIndexes (void)
|
||||
BayesBallGraph::setIndexes()
|
||||
{
|
||||
for (size_t i = 0; i < nodes_.size(); i++) {
|
||||
nodes_[i]->setIndex (i);
|
||||
@ -65,7 +66,7 @@ BayesBallGraph::setIndexes (void)
|
||||
|
||||
|
||||
void
|
||||
BayesBallGraph::clear (void)
|
||||
BayesBallGraph::clear()
|
||||
{
|
||||
for (size_t i = 0; i < nodes_.size(); i++) {
|
||||
nodes_[i]->clear();
|
||||
@ -77,13 +78,14 @@ BayesBallGraph::clear (void)
|
||||
void
|
||||
BayesBallGraph::exportToGraphViz (const char* fileName)
|
||||
{
|
||||
ofstream out (fileName);
|
||||
std::ofstream out (fileName);
|
||||
if (!out.is_open()) {
|
||||
cerr << "Error: couldn't open file '" << fileName << "'." ;
|
||||
std::cerr << "Error: couldn't open file '" << fileName << "'." ;
|
||||
std::cerr << std::endl;
|
||||
return;
|
||||
}
|
||||
out << "digraph {" << endl;
|
||||
out << "ranksep=1" << endl;
|
||||
out << "digraph {" << std::endl;
|
||||
out << "ranksep=1" << std::endl;
|
||||
for (size_t i = 0; i < nodes_.size(); i++) {
|
||||
out << nodes_[i]->varId() ;
|
||||
out << " [" ;
|
||||
@ -91,16 +93,18 @@ BayesBallGraph::exportToGraphViz (const char* fileName)
|
||||
if (nodes_[i]->hasEvidence()) {
|
||||
out << ",style=filled, fillcolor=yellow" ;
|
||||
}
|
||||
out << "]" << endl;
|
||||
out << "]" << std::endl;
|
||||
}
|
||||
for (size_t i = 0; i < nodes_.size(); i++) {
|
||||
const vector<BBNode*>& childs = nodes_[i]->childs();
|
||||
const std::vector<BBNode*>& childs = nodes_[i]->childs();
|
||||
for (size_t j = 0; j < childs.size(); j++) {
|
||||
out << nodes_[i]->varId() << " -> " << childs[j]->varId();
|
||||
out << " [style=bold]" << endl ;
|
||||
out << " [style=bold]" << std::endl;
|
||||
}
|
||||
}
|
||||
out << "}" << endl;
|
||||
out << "}" << std::endl;
|
||||
out.close();
|
||||
}
|
||||
|
||||
} // namespace Horus
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
#ifndef HORUS_BAYESBALLGRAPH_H
|
||||
#define HORUS_BAYESBALLGRAPH_H
|
||||
#ifndef YAP_PACKAGES_CLPBN_HORUS_BAYESBALLGRAPH_H_
|
||||
#define YAP_PACKAGES_CLPBN_HORUS_BAYESBALLGRAPH_H_
|
||||
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
@ -7,54 +7,55 @@
|
||||
#include "Var.h"
|
||||
#include "Horus.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
class BBNode : public Var
|
||||
{
|
||||
namespace Horus {
|
||||
|
||||
class BBNode : public Var {
|
||||
public:
|
||||
BBNode (Var* v) : Var (v), visited_(false),
|
||||
markedOnTop_(false), markedOnBottom_(false) { }
|
||||
markedAbove_(false), markedBelow_(false) { }
|
||||
|
||||
const vector<BBNode*>& childs (void) const { return childs_; }
|
||||
const std::vector<BBNode*>& childs() const { return childs_; }
|
||||
|
||||
vector<BBNode*>& childs (void) { return childs_; }
|
||||
std::vector<BBNode*>& childs() { return childs_; }
|
||||
|
||||
const vector<BBNode*>& parents (void) const { return parents_; }
|
||||
const std::vector<BBNode*>& parents() const { return parents_; }
|
||||
|
||||
vector<BBNode*>& parents (void) { return parents_; }
|
||||
std::vector<BBNode*>& parents() { return parents_; }
|
||||
|
||||
void addParent (BBNode* p) { parents_.push_back (p); }
|
||||
|
||||
void addChild (BBNode* c) { childs_.push_back (c); }
|
||||
|
||||
bool isVisited (void) const { return visited_; }
|
||||
bool isVisited() const { return visited_; }
|
||||
|
||||
void setAsVisited (void) { visited_ = true; }
|
||||
void setAsVisited() { visited_ = true; }
|
||||
|
||||
bool isMarkedOnTop (void) const { return markedOnTop_; }
|
||||
bool isMarkedAbove() const { return markedAbove_; }
|
||||
|
||||
void markOnTop (void) { markedOnTop_ = true; }
|
||||
void markAbove() { markedAbove_ = true; }
|
||||
|
||||
bool isMarkedOnBottom (void) const { return markedOnBottom_; }
|
||||
bool isMarkedBelow() const { return markedBelow_; }
|
||||
|
||||
void markOnBottom (void) { markedOnBottom_ = true; }
|
||||
void markBelow() { markedBelow_ = true; }
|
||||
|
||||
void clear (void) { visited_ = markedOnTop_ = markedOnBottom_ = false; }
|
||||
void clear() { visited_ = markedAbove_ = markedBelow_ = false; }
|
||||
|
||||
private:
|
||||
bool visited_;
|
||||
bool markedOnTop_;
|
||||
bool markedOnBottom_;
|
||||
bool markedAbove_;
|
||||
bool markedBelow_;
|
||||
|
||||
vector<BBNode*> childs_;
|
||||
vector<BBNode*> parents_;
|
||||
std::vector<BBNode*> childs_;
|
||||
std::vector<BBNode*> parents_;
|
||||
};
|
||||
|
||||
|
||||
class BayesBallGraph
|
||||
{
|
||||
class BayesBallGraph {
|
||||
public:
|
||||
BayesBallGraph (void) { }
|
||||
BayesBallGraph() { }
|
||||
|
||||
bool empty() const { return nodes_.empty(); }
|
||||
|
||||
void addNode (BBNode* n);
|
||||
|
||||
@ -64,19 +65,18 @@ class BayesBallGraph
|
||||
|
||||
BBNode* getNode (VarId vid);
|
||||
|
||||
bool empty (void) const { return nodes_.empty(); }
|
||||
void setIndexes();
|
||||
|
||||
void setIndexes (void);
|
||||
|
||||
void clear (void);
|
||||
void clear();
|
||||
|
||||
void exportToGraphViz (const char*);
|
||||
|
||||
private:
|
||||
vector<BBNode*> nodes_;
|
||||
|
||||
unordered_map<VarId, BBNode*> varMap_;
|
||||
std::vector<BBNode*> nodes_;
|
||||
std::unordered_map<VarId, BBNode*> varMap_;
|
||||
};
|
||||
|
||||
#endif // HORUS_BAYESBALLGRAPH_H
|
||||
} // namespace Horus
|
||||
|
||||
#endif // YAP_PACKAGES_CLPBN_HORUS_BAYESBALLGRAPH_H_
|
||||
|
||||
|
@ -1,37 +1,39 @@
|
||||
#include <cassert>
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include <iostream>
|
||||
#include <iomanip>
|
||||
#include <sstream>
|
||||
|
||||
#include "BeliefProp.h"
|
||||
#include "Indexer.h"
|
||||
#include "Horus.h"
|
||||
|
||||
|
||||
double BeliefProp::accuracy_ = 0.0001;
|
||||
unsigned BeliefProp::maxIter_ = 1000;
|
||||
MsgSchedule BeliefProp::schedule_ = MsgSchedule::SEQ_FIXED;
|
||||
namespace Horus {
|
||||
|
||||
double BeliefProp::accuracy_ = 0.0001;
|
||||
unsigned BeliefProp::maxIter_ = 1000;
|
||||
|
||||
BeliefProp::MsgSchedule BeliefProp::schedule_ =
|
||||
MsgSchedule::seqFixedSch;
|
||||
|
||||
|
||||
BeliefProp::BeliefProp (const FactorGraph& fg) : GroundSolver (fg)
|
||||
|
||||
BeliefProp::BeliefProp (const FactorGraph& fg)
|
||||
: GroundSolver (fg), nIters_(0), runned_(false)
|
||||
{
|
||||
runned_ = false;
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
BeliefProp::~BeliefProp (void)
|
||||
BeliefProp::~BeliefProp()
|
||||
{
|
||||
for (size_t i = 0; i < varsI_.size(); i++) {
|
||||
delete varsI_[i];
|
||||
}
|
||||
for (size_t i = 0; i < facsI_.size(); i++) {
|
||||
delete facsI_[i];
|
||||
}
|
||||
for (size_t i = 0; i < links_.size(); i++) {
|
||||
delete links_[i];
|
||||
}
|
||||
links_.clear();
|
||||
}
|
||||
|
||||
|
||||
@ -48,22 +50,22 @@ BeliefProp::solveQuery (VarIds queryVids)
|
||||
|
||||
|
||||
void
|
||||
BeliefProp::printSolverFlags (void) const
|
||||
BeliefProp::printSolverFlags() const
|
||||
{
|
||||
stringstream ss;
|
||||
std::stringstream ss;
|
||||
ss << "belief propagation [" ;
|
||||
ss << "bp_msg_schedule=" ;
|
||||
switch (schedule_) {
|
||||
case MsgSchedule::SEQ_FIXED: ss << "seq_fixed"; break;
|
||||
case MsgSchedule::SEQ_RANDOM: ss << "seq_random"; break;
|
||||
case MsgSchedule::PARALLEL: ss << "parallel"; break;
|
||||
case MsgSchedule::MAX_RESIDUAL: ss << "max_residual"; break;
|
||||
case MsgSchedule::seqFixedSch: ss << "seq_fixed"; break;
|
||||
case MsgSchedule::seqRandomSch: ss << "seq_random"; break;
|
||||
case MsgSchedule::parallelSch: ss << "parallel"; break;
|
||||
case MsgSchedule::maxResidualSch: ss << "max_residual"; break;
|
||||
}
|
||||
ss << ",bp_max_iter=" << Util::toString (maxIter_);
|
||||
ss << ",bp_accuracy=" << Util::toString (accuracy_);
|
||||
ss << ",log_domain=" << Util::toString (Globals::logDomain);
|
||||
ss << ",bp_max_iter=" << Util::toString (maxIter_);
|
||||
ss << ",bp_accuracy=" << Util::toString (accuracy_);
|
||||
ss << ",log_domain=" << Util::toString (Globals::logDomain);
|
||||
ss << "]" ;
|
||||
cout << ss.str() << endl;
|
||||
std::cout << ss.str() << std::endl;
|
||||
}
|
||||
|
||||
|
||||
@ -82,7 +84,7 @@ BeliefProp::getPosterioriOf (VarId vid)
|
||||
probs[var->getEvidence()] = LogAware::withEvidence();
|
||||
} else {
|
||||
probs.resize (var->range(), LogAware::multIdenty());
|
||||
const BpLinks& links = ninf(var)->getLinks();
|
||||
const BpLinks& links = getLinks (var);
|
||||
if (Globals::logDomain) {
|
||||
for (size_t i = 0; i < links.size(); i++) {
|
||||
probs += links[i]->message();
|
||||
@ -133,7 +135,7 @@ BeliefProp::getFactorJoint (
|
||||
runSolver();
|
||||
}
|
||||
Factor res (fn->factor());
|
||||
const BpLinks& links = ninf(fn)->getLinks();
|
||||
const BpLinks& links = getLinks( fn);
|
||||
for (size_t i = 0; i < links.size(); i++) {
|
||||
Factor msg ({links[i]->varNode()->varId()},
|
||||
{links[i]->varNode()->range()},
|
||||
@ -152,26 +154,119 @@ BeliefProp::getFactorJoint (
|
||||
|
||||
|
||||
|
||||
BeliefProp::BpLink::BpLink (FacNode* fn, VarNode* vn)
|
||||
{
|
||||
fac_ = fn;
|
||||
var_ = vn;
|
||||
v1_.resize (vn->range(), LogAware::log (1.0 / vn->range()));
|
||||
v2_.resize (vn->range(), LogAware::log (1.0 / vn->range()));
|
||||
currMsg_ = &v1_;
|
||||
nextMsg_ = &v2_;
|
||||
residual_ = 0.0;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BeliefProp::runSolver (void)
|
||||
BeliefProp::BpLink::clearResidual()
|
||||
{
|
||||
residual_ = 0.0;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BeliefProp::BpLink::updateResidual()
|
||||
{
|
||||
residual_ = LogAware::getMaxNorm (v1_, v2_);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BeliefProp::BpLink::updateMessage()
|
||||
{
|
||||
swap (currMsg_, nextMsg_);
|
||||
}
|
||||
|
||||
|
||||
|
||||
std::string
|
||||
BeliefProp::BpLink::toString() const
|
||||
{
|
||||
std::stringstream ss;
|
||||
ss << fac_->getLabel();
|
||||
ss << " -- " ;
|
||||
ss << var_->label();
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BeliefProp::calculateAndUpdateMessage (BpLink* link, bool calcResidual)
|
||||
{
|
||||
if (Globals::verbosity > 2) {
|
||||
std::cout << "calculating & updating " << link->toString();
|
||||
std::cout << std::endl;
|
||||
}
|
||||
calcFactorToVarMsg (link);
|
||||
if (calcResidual) {
|
||||
link->updateResidual();
|
||||
}
|
||||
link->updateMessage();
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BeliefProp::calculateMessage (BpLink* link, bool calcResidual)
|
||||
{
|
||||
if (Globals::verbosity > 2) {
|
||||
std::cout << "calculating " << link->toString();
|
||||
std::cout << std::endl;
|
||||
}
|
||||
calcFactorToVarMsg (link);
|
||||
if (calcResidual) {
|
||||
link->updateResidual();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BeliefProp::updateMessage (BpLink* link)
|
||||
{
|
||||
link->updateMessage();
|
||||
if (Globals::verbosity > 2) {
|
||||
std::cout << "updating " << link->toString();
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BeliefProp::runSolver()
|
||||
{
|
||||
initializeSolver();
|
||||
nIters_ = 0;
|
||||
while (!converged() && nIters_ < maxIter_) {
|
||||
nIters_ ++;
|
||||
if (Globals::verbosity > 1) {
|
||||
Util::printHeader (string ("Iteration ") + Util::toString (nIters_));
|
||||
Util::printHeader (std::string ("Iteration ")
|
||||
+ Util::toString (nIters_));
|
||||
}
|
||||
switch (schedule_) {
|
||||
case MsgSchedule::SEQ_RANDOM:
|
||||
case MsgSchedule::seqRandomSch:
|
||||
std::random_shuffle (links_.begin(), links_.end());
|
||||
// no break
|
||||
case MsgSchedule::SEQ_FIXED:
|
||||
case MsgSchedule::seqFixedSch:
|
||||
for (size_t i = 0; i < links_.size(); i++) {
|
||||
calculateAndUpdateMessage (links_[i]);
|
||||
}
|
||||
break;
|
||||
case MsgSchedule::PARALLEL:
|
||||
case MsgSchedule::parallelSch:
|
||||
for (size_t i = 0; i < links_.size(); i++) {
|
||||
calculateMessage (links_[i]);
|
||||
}
|
||||
@ -179,20 +274,21 @@ BeliefProp::runSolver (void)
|
||||
updateMessage(links_[i]);
|
||||
}
|
||||
break;
|
||||
case MsgSchedule::MAX_RESIDUAL:
|
||||
case MsgSchedule::maxResidualSch:
|
||||
maxResidualSchedule();
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (Globals::verbosity > 0) {
|
||||
if (nIters_ < maxIter_) {
|
||||
cout << "Belief propagation converged in " ;
|
||||
cout << nIters_ << " iterations" << endl;
|
||||
std::cout << "Belief propagation converged in " ;
|
||||
std::cout << nIters_ << " iterations" << std::endl;
|
||||
} else {
|
||||
cout << "The maximum number of iterations was hit, terminating..." ;
|
||||
cout << endl;
|
||||
std::cout << "The maximum number of iterations was hit," ;
|
||||
std::cout << " terminating..." ;
|
||||
std::cout << std::endl;
|
||||
}
|
||||
cout << endl;
|
||||
std::cout << std::endl;
|
||||
}
|
||||
runned_ = true;
|
||||
}
|
||||
@ -200,7 +296,7 @@ BeliefProp::runSolver (void)
|
||||
|
||||
|
||||
void
|
||||
BeliefProp::createLinks (void)
|
||||
BeliefProp::createLinks()
|
||||
{
|
||||
const FacNodes& facNodes = fg.facNodes();
|
||||
for (size_t i = 0; i < facNodes.size(); i++) {
|
||||
@ -214,7 +310,7 @@ BeliefProp::createLinks (void)
|
||||
|
||||
|
||||
void
|
||||
BeliefProp::maxResidualSchedule (void)
|
||||
BeliefProp::maxResidualSchedule()
|
||||
{
|
||||
if (nIters_ == 1) {
|
||||
for (size_t i = 0; i < links_.size(); i++) {
|
||||
@ -227,11 +323,13 @@ BeliefProp::maxResidualSchedule (void)
|
||||
|
||||
for (size_t c = 0; c < links_.size(); c++) {
|
||||
if (Globals::verbosity > 1) {
|
||||
cout << "current residuals:" << endl;
|
||||
std::cout << "current residuals:" << std::endl;
|
||||
for (SortedOrder::iterator it = sortedOrder_.begin();
|
||||
it != sortedOrder_.end(); ++it) {
|
||||
cout << " " << setw (30) << left << (*it)->toString();
|
||||
cout << "residual = " << (*it)->residual() << endl;
|
||||
std::cout << " " << std::setw (30) << std::left;
|
||||
std::cout << (*it)->toString();
|
||||
std::cout << "residual = " << (*it)->residual();
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
@ -249,7 +347,7 @@ BeliefProp::maxResidualSchedule (void)
|
||||
const FacNodes& factorNeighbors = link->varNode()->neighbors();
|
||||
for (size_t i = 0; i < factorNeighbors.size(); i++) {
|
||||
if (factorNeighbors[i] != link->facNode()) {
|
||||
const BpLinks& links = ninf(factorNeighbors[i])->getLinks();
|
||||
const BpLinks& links = getLinks (factorNeighbors[i]);
|
||||
for (size_t j = 0; j < links.size(); j++) {
|
||||
if (links[j]->varNode() != link->varNode()) {
|
||||
calculateMessage (links[j]);
|
||||
@ -273,7 +371,7 @@ BeliefProp::calcFactorToVarMsg (BpLink* link)
|
||||
{
|
||||
FacNode* src = link->facNode();
|
||||
const VarNode* dst = link->varNode();
|
||||
const BpLinks& links = ninf(src)->getLinks();
|
||||
const BpLinks& links = getLinks (src);
|
||||
// calculate the product of messages that were sent
|
||||
// to factor `src', except from var `dst'
|
||||
unsigned reps = 1;
|
||||
@ -282,14 +380,14 @@ BeliefProp::calcFactorToVarMsg (BpLink* link)
|
||||
if (Globals::logDomain) {
|
||||
for (size_t i = links.size(); i-- > 0; ) {
|
||||
if (links[i]->varNode() != dst) {
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
cout << " message from " << links[i]->varNode()->label();
|
||||
cout << ": " ;
|
||||
if (Constants::showBpCalcs) {
|
||||
std::cout << " message from " << links[i]->varNode()->label();
|
||||
std::cout << ": " ;
|
||||
}
|
||||
Util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]),
|
||||
reps, std::plus<double>());
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
cout << endl;
|
||||
if (Constants::showBpCalcs) {
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
reps *= links[i]->varNode()->range();
|
||||
@ -297,14 +395,14 @@ BeliefProp::calcFactorToVarMsg (BpLink* link)
|
||||
} else {
|
||||
for (size_t i = links.size(); i-- > 0; ) {
|
||||
if (links[i]->varNode() != dst) {
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
cout << " message from " << links[i]->varNode()->label();
|
||||
cout << ": " ;
|
||||
if (Constants::showBpCalcs) {
|
||||
std::cout << " message from " << links[i]->varNode()->label();
|
||||
std::cout << ": " ;
|
||||
}
|
||||
Util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]),
|
||||
reps, std::multiplies<double>());
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
cout << endl;
|
||||
if (Constants::showBpCalcs) {
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
reps *= links[i]->varNode()->range();
|
||||
@ -313,27 +411,28 @@ BeliefProp::calcFactorToVarMsg (BpLink* link)
|
||||
Factor result (src->factor().arguments(),
|
||||
src->factor().ranges(), msgProduct);
|
||||
result.multiply (src->factor());
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
cout << " message product: " << msgProduct << endl;
|
||||
cout << " original factor: " << src->factor().params() << endl;
|
||||
cout << " factor product: " << result.params() << endl;
|
||||
if (Constants::showBpCalcs) {
|
||||
std::cout << " message product: " << msgProduct << std::endl;
|
||||
std::cout << " original factor: " << src->factor().params();
|
||||
std::cout << std::endl;
|
||||
std::cout << " factor product: " << result.params() << std::endl;
|
||||
}
|
||||
result.sumOutAllExcept (dst->varId());
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
cout << " marginalized: " << result.params() << endl;
|
||||
if (Constants::showBpCalcs) {
|
||||
std::cout << " marginalized: " << result.params() << std::endl;
|
||||
}
|
||||
link->nextMessage() = result.params();
|
||||
LogAware::normalize (link->nextMessage());
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
cout << " curr msg: " << link->message() << endl;
|
||||
cout << " next msg: " << link->nextMessage() << endl;
|
||||
if (Constants::showBpCalcs) {
|
||||
std::cout << " curr msg: " << link->message() << std::endl;
|
||||
std::cout << " next msg: " << link->nextMessage() << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
Params
|
||||
BeliefProp::getVarToFactorMsg (const BpLink* link) const
|
||||
BeliefProp::getVarToFactorMsg (const BpLink* link)
|
||||
{
|
||||
const VarNode* src = link->varNode();
|
||||
Params msg;
|
||||
@ -343,18 +442,18 @@ BeliefProp::getVarToFactorMsg (const BpLink* link) const
|
||||
} else {
|
||||
msg.resize (src->range(), LogAware::one());
|
||||
}
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
cout << msg;
|
||||
if (Constants::showBpCalcs) {
|
||||
std::cout << msg;
|
||||
}
|
||||
BpLinks::const_iterator it;
|
||||
const BpLinks& links = ninf (src)->getLinks();
|
||||
const BpLinks& links = getLinks (src);
|
||||
if (Globals::logDomain) {
|
||||
for (it = links.begin(); it != links.end(); ++it) {
|
||||
if (*it != link) {
|
||||
msg += (*it)->message();
|
||||
}
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
cout << " x " << (*it)->message();
|
||||
if (Constants::showBpCalcs) {
|
||||
std::cout << " x " << (*it)->message();
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@ -362,13 +461,13 @@ BeliefProp::getVarToFactorMsg (const BpLink* link) const
|
||||
if (*it != link) {
|
||||
msg *= (*it)->message();
|
||||
}
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
cout << " x " << (*it)->message();
|
||||
if (Constants::showBpCalcs) {
|
||||
std::cout << " x " << (*it)->message();
|
||||
}
|
||||
}
|
||||
}
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
cout << " = " << msg;
|
||||
if (Constants::showBpCalcs) {
|
||||
std::cout << " = " << msg;
|
||||
}
|
||||
return msg;
|
||||
}
|
||||
@ -379,37 +478,37 @@ Params
|
||||
BeliefProp::getJointByConditioning (const VarIds& jointVarIds) const
|
||||
{
|
||||
return GroundSolver::getJointByConditioning (
|
||||
GroundSolverType::BP, fg, jointVarIds);
|
||||
GroundSolverType::bpSolver, fg, jointVarIds);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BeliefProp::initializeSolver (void)
|
||||
BeliefProp::initializeSolver()
|
||||
{
|
||||
const VarNodes& varNodes = fg.varNodes();
|
||||
varsI_.reserve (varNodes.size());
|
||||
varsLinks_.reserve (varNodes.size());
|
||||
for (size_t i = 0; i < varNodes.size(); i++) {
|
||||
varsI_.push_back (new SPNodeInfo());
|
||||
varsLinks_.push_back (BpLinks());
|
||||
}
|
||||
const FacNodes& facNodes = fg.facNodes();
|
||||
facsI_.reserve (facNodes.size());
|
||||
facsLinks_.reserve (facNodes.size());
|
||||
for (size_t i = 0; i < facNodes.size(); i++) {
|
||||
facsI_.push_back (new SPNodeInfo());
|
||||
facsLinks_.push_back (BpLinks());
|
||||
}
|
||||
createLinks();
|
||||
for (size_t i = 0; i < links_.size(); i++) {
|
||||
FacNode* src = links_[i]->facNode();
|
||||
VarNode* dst = links_[i]->varNode();
|
||||
ninf (dst)->addBpLink (links_[i]);
|
||||
ninf (src)->addBpLink (links_[i]);
|
||||
getLinks (dst).push_back (links_[i]);
|
||||
getLinks (src).push_back (links_[i]);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
BeliefProp::converged (void)
|
||||
BeliefProp::converged()
|
||||
{
|
||||
if (links_.empty()) {
|
||||
return true;
|
||||
@ -418,16 +517,16 @@ BeliefProp::converged (void)
|
||||
return false;
|
||||
}
|
||||
if (Globals::verbosity > 2) {
|
||||
cout << endl;
|
||||
std::cout << std::endl;
|
||||
}
|
||||
if (nIters_ == 1) {
|
||||
if (Globals::verbosity > 1) {
|
||||
cout << "no residuals" << endl << endl;
|
||||
std::cout << "no residuals" << std::endl << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
bool converged = true;
|
||||
if (schedule_ == MsgSchedule::MAX_RESIDUAL) {
|
||||
if (schedule_ == MsgSchedule::maxResidualSch) {
|
||||
double maxResidual = (*(sortedOrder_.begin()))->residual();
|
||||
if (maxResidual > accuracy_) {
|
||||
converged = false;
|
||||
@ -438,7 +537,8 @@ BeliefProp::converged (void)
|
||||
for (size_t i = 0; i < links_.size(); i++) {
|
||||
double residual = links_[i]->residual();
|
||||
if (Globals::verbosity > 1) {
|
||||
cout << links_[i]->toString() + " residual = " << residual << endl;
|
||||
std::cout << links_[i]->toString() + " residual = " << residual;
|
||||
std::cout << std::endl;
|
||||
}
|
||||
if (residual > accuracy_) {
|
||||
converged = false;
|
||||
@ -448,7 +548,7 @@ BeliefProp::converged (void)
|
||||
}
|
||||
}
|
||||
if (Globals::verbosity > 1) {
|
||||
cout << endl;
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
return converged;
|
||||
@ -457,8 +557,10 @@ BeliefProp::converged (void)
|
||||
|
||||
|
||||
void
|
||||
BeliefProp::printLinkInformation (void) const
|
||||
BeliefProp::printLinkInformation() const
|
||||
{
|
||||
using std::cout;
|
||||
using std::endl;
|
||||
for (size_t i = 0; i < links_.size(); i++) {
|
||||
BpLink* l = links_[i];
|
||||
cout << l->toString() << ":" << endl;
|
||||
@ -470,3 +572,5 @@ BeliefProp::printLinkInformation (void) const
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace Horus
|
||||
|
||||
|
@ -1,111 +1,35 @@
|
||||
#ifndef HORUS_BELIEFPROP_H
|
||||
#define HORUS_BELIEFPROP_H
|
||||
#ifndef YAP_PACKAGES_CLPBN_HORUS_BELIEFPROP_H_
|
||||
#define YAP_PACKAGES_CLPBN_HORUS_BELIEFPROP_H_
|
||||
|
||||
#include <set>
|
||||
#include <vector>
|
||||
|
||||
#include <sstream>
|
||||
#include <set>
|
||||
#include <string>
|
||||
|
||||
#include "GroundSolver.h"
|
||||
#include "FactorGraph.h"
|
||||
|
||||
|
||||
using namespace std;
|
||||
|
||||
|
||||
enum MsgSchedule {
|
||||
SEQ_FIXED,
|
||||
SEQ_RANDOM,
|
||||
PARALLEL,
|
||||
MAX_RESIDUAL
|
||||
};
|
||||
|
||||
|
||||
class BpLink
|
||||
{
|
||||
public:
|
||||
BpLink (FacNode* fn, VarNode* vn)
|
||||
{
|
||||
fac_ = fn;
|
||||
var_ = vn;
|
||||
v1_.resize (vn->range(), LogAware::log (1.0 / vn->range()));
|
||||
v2_.resize (vn->range(), LogAware::log (1.0 / vn->range()));
|
||||
currMsg_ = &v1_;
|
||||
nextMsg_ = &v2_;
|
||||
residual_ = 0.0;
|
||||
}
|
||||
|
||||
virtual ~BpLink (void) { };
|
||||
|
||||
FacNode* facNode (void) const { return fac_; }
|
||||
|
||||
VarNode* varNode (void) const { return var_; }
|
||||
|
||||
const Params& message (void) const { return *currMsg_; }
|
||||
|
||||
Params& nextMessage (void) { return *nextMsg_; }
|
||||
|
||||
double residual (void) const { return residual_; }
|
||||
|
||||
void clearResidual (void) { residual_ = 0.0; }
|
||||
|
||||
void updateResidual (void)
|
||||
{
|
||||
residual_ = LogAware::getMaxNorm (v1_, v2_);
|
||||
}
|
||||
|
||||
virtual void updateMessage (void)
|
||||
{
|
||||
swap (currMsg_, nextMsg_);
|
||||
}
|
||||
|
||||
string toString (void) const
|
||||
{
|
||||
stringstream ss;
|
||||
ss << fac_->getLabel();
|
||||
ss << " -- " ;
|
||||
ss << var_->label();
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
protected:
|
||||
FacNode* fac_;
|
||||
VarNode* var_;
|
||||
Params v1_;
|
||||
Params v2_;
|
||||
Params* currMsg_;
|
||||
Params* nextMsg_;
|
||||
double residual_;
|
||||
namespace Horus {
|
||||
|
||||
class BeliefProp : public GroundSolver {
|
||||
private:
|
||||
DISALLOW_COPY_AND_ASSIGN (BpLink);
|
||||
};
|
||||
class SPNodeInfo;
|
||||
|
||||
typedef vector<BpLink*> BpLinks;
|
||||
|
||||
|
||||
class SPNodeInfo
|
||||
{
|
||||
public:
|
||||
SPNodeInfo (void) { }
|
||||
void addBpLink (BpLink* link) { links_.push_back (link); }
|
||||
const BpLinks& getLinks (void) { return links_; }
|
||||
private:
|
||||
BpLinks links_;
|
||||
DISALLOW_COPY_AND_ASSIGN (SPNodeInfo);
|
||||
};
|
||||
enum class MsgSchedule {
|
||||
seqFixedSch,
|
||||
seqRandomSch,
|
||||
parallelSch,
|
||||
maxResidualSch
|
||||
};
|
||||
|
||||
|
||||
class BeliefProp : public GroundSolver
|
||||
{
|
||||
public:
|
||||
BeliefProp (const FactorGraph&);
|
||||
|
||||
virtual ~BeliefProp (void);
|
||||
virtual ~BeliefProp();
|
||||
|
||||
Params solveQuery (VarIds);
|
||||
|
||||
virtual void printSolverFlags (void) const;
|
||||
virtual void printSolverFlags() const;
|
||||
|
||||
virtual Params getPosterioriOf (VarId);
|
||||
|
||||
@ -113,105 +37,128 @@ class BeliefProp : public GroundSolver
|
||||
|
||||
Params getFactorJoint (FacNode* fn, const VarIds&);
|
||||
|
||||
static double accuracy (void) { return accuracy_; }
|
||||
static double accuracy() { return accuracy_; }
|
||||
|
||||
static void setAccuracy (double acc) { accuracy_ = acc; }
|
||||
|
||||
static unsigned maxIterations (void) { return maxIter_; }
|
||||
static unsigned maxIterations() { return maxIter_; }
|
||||
|
||||
static void setMaxIterations (unsigned mi) { maxIter_ = mi; }
|
||||
|
||||
static MsgSchedule msgSchedule (void) { return schedule_; }
|
||||
static MsgSchedule msgSchedule() { return schedule_; }
|
||||
|
||||
static void setMsgSchedule (MsgSchedule sch) { schedule_ = sch; }
|
||||
|
||||
protected:
|
||||
SPNodeInfo* ninf (const VarNode* var) const
|
||||
{
|
||||
return varsI_[var->getIndex()];
|
||||
}
|
||||
class BpLink {
|
||||
public:
|
||||
BpLink (FacNode* fn, VarNode* vn);
|
||||
|
||||
SPNodeInfo* ninf (const FacNode* fac) const
|
||||
{
|
||||
return facsI_[fac->getIndex()];
|
||||
}
|
||||
virtual ~BpLink() { };
|
||||
|
||||
void calculateAndUpdateMessage (BpLink* link, bool calcResidual = true)
|
||||
{
|
||||
if (Globals::verbosity > 2) {
|
||||
cout << "calculating & updating " << link->toString() << endl;
|
||||
}
|
||||
calcFactorToVarMsg (link);
|
||||
if (calcResidual) {
|
||||
link->updateResidual();
|
||||
}
|
||||
link->updateMessage();
|
||||
}
|
||||
FacNode* facNode() const { return fac_; }
|
||||
|
||||
void calculateMessage (BpLink* link, bool calcResidual = true)
|
||||
{
|
||||
if (Globals::verbosity > 2) {
|
||||
cout << "calculating " << link->toString() << endl;
|
||||
}
|
||||
calcFactorToVarMsg (link);
|
||||
if (calcResidual) {
|
||||
link->updateResidual();
|
||||
}
|
||||
}
|
||||
VarNode* varNode() const { return var_; }
|
||||
|
||||
void updateMessage (BpLink* link)
|
||||
{
|
||||
link->updateMessage();
|
||||
if (Globals::verbosity > 2) {
|
||||
cout << "updating " << link->toString() << endl;
|
||||
}
|
||||
}
|
||||
const Params& message() const { return *currMsg_; }
|
||||
|
||||
struct CompareResidual
|
||||
{
|
||||
inline bool operator() (const BpLink* link1, const BpLink* link2)
|
||||
{
|
||||
return link1->residual() > link2->residual();
|
||||
}
|
||||
Params& nextMessage() { return *nextMsg_; }
|
||||
|
||||
double residual() const { return residual_; }
|
||||
|
||||
void clearResidual();
|
||||
|
||||
void updateResidual();
|
||||
|
||||
virtual void updateMessage();
|
||||
|
||||
std::string toString() const;
|
||||
|
||||
protected:
|
||||
FacNode* fac_;
|
||||
VarNode* var_;
|
||||
Params v1_;
|
||||
Params v2_;
|
||||
Params* currMsg_;
|
||||
Params* nextMsg_;
|
||||
double residual_;
|
||||
|
||||
private:
|
||||
DISALLOW_COPY_AND_ASSIGN (BpLink);
|
||||
};
|
||||
|
||||
void runSolver (void);
|
||||
struct CmpResidual {
|
||||
bool operator() (const BpLink* l1, const BpLink* l2) {
|
||||
return l1->residual() > l2->residual();
|
||||
}};
|
||||
|
||||
virtual void createLinks (void);
|
||||
typedef std::vector<BeliefProp::BpLink*> BpLinks;
|
||||
typedef std::multiset<BpLink*, CmpResidual> SortedOrder;
|
||||
typedef std::unordered_map<BpLink*, SortedOrder::iterator> BpLinkMap;
|
||||
|
||||
virtual void maxResidualSchedule (void);
|
||||
BpLinks& getLinks (const VarNode* var);
|
||||
|
||||
BpLinks& getLinks (const FacNode* fac);
|
||||
|
||||
void calculateAndUpdateMessage (BpLink* link, bool calcResidual = true);
|
||||
|
||||
void calculateMessage (BpLink* link, bool calcResidual = true);
|
||||
|
||||
void updateMessage (BpLink* link);
|
||||
|
||||
void runSolver();
|
||||
|
||||
virtual void createLinks();
|
||||
|
||||
virtual void maxResidualSchedule();
|
||||
|
||||
virtual void calcFactorToVarMsg (BpLink*);
|
||||
|
||||
virtual Params getVarToFactorMsg (const BpLink*) const;
|
||||
virtual Params getVarToFactorMsg (const BpLink*);
|
||||
|
||||
virtual Params getJointByConditioning (const VarIds&) const;
|
||||
|
||||
BpLinks links_;
|
||||
unsigned nIters_;
|
||||
vector<SPNodeInfo*> varsI_;
|
||||
vector<SPNodeInfo*> facsI_;
|
||||
bool runned_;
|
||||
BpLinks links_;
|
||||
unsigned nIters_;
|
||||
bool runned_;
|
||||
SortedOrder sortedOrder_;
|
||||
BpLinkMap linkMap_;
|
||||
|
||||
typedef multiset<BpLink*, CompareResidual> SortedOrder;
|
||||
SortedOrder sortedOrder_;
|
||||
|
||||
typedef unordered_map<BpLink*, SortedOrder::iterator> BpLinkMap;
|
||||
BpLinkMap linkMap_;
|
||||
|
||||
static double accuracy_;
|
||||
static unsigned maxIter_;
|
||||
static MsgSchedule schedule_;
|
||||
static double accuracy_;
|
||||
|
||||
private:
|
||||
void initializeSolver (void);
|
||||
void initializeSolver();
|
||||
|
||||
bool converged (void);
|
||||
bool converged();
|
||||
|
||||
virtual void printLinkInformation (void) const;
|
||||
virtual void printLinkInformation() const;
|
||||
|
||||
std::vector<BpLinks> varsLinks_;
|
||||
std::vector<BpLinks> facsLinks_;
|
||||
|
||||
static unsigned maxIter_;
|
||||
static MsgSchedule schedule_;
|
||||
|
||||
DISALLOW_COPY_AND_ASSIGN (BeliefProp);
|
||||
};
|
||||
|
||||
#endif // HORUS_BELIEFPROP_H
|
||||
|
||||
|
||||
inline BeliefProp::BpLinks&
|
||||
BeliefProp::getLinks (const VarNode* var)
|
||||
{
|
||||
return varsLinks_[var->getIndex()];
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline BeliefProp::BpLinks&
|
||||
BeliefProp::getLinks (const FacNode* fac)
|
||||
{
|
||||
return facsLinks_[fac->getIndex()];
|
||||
}
|
||||
|
||||
} // namespace Horus
|
||||
|
||||
#endif // YAP_PACKAGES_CLPBN_HORUS_BELIEFPROP_H_
|
||||
|
||||
|
@ -1,11 +1,88 @@
|
||||
#include <queue>
|
||||
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
#include <fstream>
|
||||
|
||||
#include "ConstraintTree.h"
|
||||
#include "Util.h"
|
||||
|
||||
|
||||
namespace Horus {
|
||||
|
||||
class CTNode {
|
||||
public:
|
||||
CTNode (const CTNode& n, const CTChilds& chs = CTChilds())
|
||||
: symbol_(n.symbol()), childs_(chs), level_(n.level()) { }
|
||||
|
||||
CTNode (Symbol s, unsigned l, const CTChilds& chs = CTChilds())
|
||||
: symbol_(s), childs_(chs), level_(l) { }
|
||||
|
||||
unsigned level() const { return level_; }
|
||||
|
||||
void setLevel (unsigned level) { level_ = level; }
|
||||
|
||||
Symbol symbol() const { return symbol_; }
|
||||
|
||||
void setSymbol (Symbol s) { symbol_ = s; }
|
||||
|
||||
CTChilds& childs() { return childs_; }
|
||||
|
||||
const CTChilds& childs() const { return childs_; }
|
||||
|
||||
size_t nrChilds() const { return childs_.size(); }
|
||||
|
||||
bool isRoot() const { return level_ == 0; }
|
||||
|
||||
bool isLeaf() const { return childs_.empty(); }
|
||||
|
||||
CTChilds::iterator findSymbol (Symbol symb);
|
||||
|
||||
void mergeSubtree (CTNode*, bool = true);
|
||||
|
||||
void removeChild (CTNode*);
|
||||
|
||||
void removeChilds();
|
||||
|
||||
void removeAndDeleteChild (CTNode*);
|
||||
|
||||
void removeAndDeleteAllChilds();
|
||||
|
||||
SymbolSet childSymbols() const;
|
||||
|
||||
static CTNode* copySubtree (const CTNode*);
|
||||
|
||||
static void deleteSubtree (CTNode*);
|
||||
|
||||
private:
|
||||
void updateChildLevels (CTNode*, unsigned);
|
||||
|
||||
Symbol symbol_;
|
||||
CTChilds childs_;
|
||||
unsigned level_;
|
||||
|
||||
DISALLOW_ASSIGN (CTNode);
|
||||
};
|
||||
|
||||
|
||||
|
||||
inline CTChilds::iterator
|
||||
CTNode::findSymbol (Symbol symb)
|
||||
{
|
||||
CTNode tmp (symb, 0);
|
||||
return childs_.find (&tmp);
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline bool
|
||||
CmpSymbol::operator() (const CTNode* n1, const CTNode* n2) const
|
||||
{
|
||||
return n1->symbol() < n2->symbol();
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
CTNode::mergeSubtree (CTNode* n, bool updateLevels)
|
||||
{
|
||||
@ -38,7 +115,7 @@ CTNode::removeChild (CTNode* child)
|
||||
|
||||
|
||||
void
|
||||
CTNode::removeChilds (void)
|
||||
CTNode::removeChilds()
|
||||
{
|
||||
childs_.clear();
|
||||
}
|
||||
@ -55,7 +132,7 @@ CTNode::removeAndDeleteChild (CTNode* child)
|
||||
|
||||
|
||||
void
|
||||
CTNode::removeAndDeleteAllChilds (void)
|
||||
CTNode::removeAndDeleteAllChilds()
|
||||
{
|
||||
for (CTChilds::const_iterator chIt = childs_.begin();
|
||||
chIt != childs_.end(); ++ chIt) {
|
||||
@ -67,7 +144,7 @@ CTNode::removeAndDeleteAllChilds (void)
|
||||
|
||||
|
||||
SymbolSet
|
||||
CTNode::childSymbols (void) const
|
||||
CTNode::childSymbols() const
|
||||
{
|
||||
SymbolSet symbols;
|
||||
for (CTChilds::const_iterator chIt = childs_.begin();
|
||||
@ -106,14 +183,14 @@ CTNode::copySubtree (const CTNode* root1)
|
||||
return new CTNode (*root1);
|
||||
}
|
||||
CTNode* root2 = new CTNode (*root1);
|
||||
typedef pair<const CTNode*, CTNode*> StackPair;
|
||||
vector<StackPair> stack = { StackPair (root1, root2) };
|
||||
typedef std::pair<const CTNode*, CTNode*> StackPair;
|
||||
std::vector<StackPair> stack = { StackPair (root1, root2) };
|
||||
while (stack.empty() == false) {
|
||||
const CTNode* n1 = stack.back().first;
|
||||
CTNode* n2 = stack.back().second;
|
||||
stack.pop_back();
|
||||
// cout << "n2 childs: " << n2->childs();
|
||||
// cout << "n1 childs: " << n1->childs();
|
||||
// std::cout << "n2 childs: " << n2->childs();
|
||||
// std::cout << "n1 childs: " << n1->childs();
|
||||
n2->childs().reserve (n1->nrChilds());
|
||||
stack.reserve (n1->nrChilds());
|
||||
for (CTChilds::const_iterator chIt = n1->childs().begin();
|
||||
@ -144,7 +221,8 @@ CTNode::deleteSubtree (CTNode* n)
|
||||
|
||||
|
||||
|
||||
ostream& operator<< (ostream &out, const CTNode& n)
|
||||
std::ostream&
|
||||
operator<< (std::ostream& out, const CTNode& n)
|
||||
{
|
||||
out << "(" << n.level() << ") " ;
|
||||
out << n.symbol();
|
||||
@ -187,7 +265,8 @@ ConstraintTree::ConstraintTree (
|
||||
|
||||
|
||||
|
||||
ConstraintTree::ConstraintTree (vector<vector<string>> names)
|
||||
ConstraintTree::ConstraintTree (
|
||||
std::vector<std::vector<std::string>> names)
|
||||
{
|
||||
assert (names.empty() == false);
|
||||
assert (names.front().empty() == false);
|
||||
@ -216,13 +295,33 @@ ConstraintTree::ConstraintTree (const ConstraintTree& ct)
|
||||
|
||||
|
||||
|
||||
ConstraintTree::~ConstraintTree (void)
|
||||
ConstraintTree::ConstraintTree (
|
||||
const CTChilds& rootChilds,
|
||||
const LogVars& logVars)
|
||||
: root_(new CTNode (Symbol (0), unsigned (0), rootChilds)),
|
||||
logVars_(logVars),
|
||||
logVarSet_(logVars)
|
||||
{
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
ConstraintTree::~ConstraintTree()
|
||||
{
|
||||
CTNode::deleteSubtree (root_);
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
ConstraintTree::empty() const
|
||||
{
|
||||
return root_->childs().empty();
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
ConstraintTree::addTuple (const Tuple& tuple)
|
||||
{
|
||||
@ -448,7 +547,7 @@ ConstraintTree::ConstraintTree::isSingleton (LogVar X)
|
||||
|
||||
|
||||
LogVarSet
|
||||
ConstraintTree::singletons (void)
|
||||
ConstraintTree::singletons()
|
||||
{
|
||||
LogVarSet singletons;
|
||||
for (size_t i = 0; i < logVars_.size(); i++) {
|
||||
@ -491,7 +590,7 @@ ConstraintTree::tupleSet (const LogVars& originalLvs)
|
||||
getTuples (root_, Tuples(), stopLevel, tuples, CTNodes() = {});
|
||||
|
||||
if (originalLvs.size() != uniqueLvs.size()) {
|
||||
vector<size_t> indexes;
|
||||
std::vector<size_t> indexes;
|
||||
indexes.reserve (originalLvs.size());
|
||||
for (size_t i = 0; i < originalLvs.size(); i++) {
|
||||
indexes.push_back (Util::indexOf (uniqueLvs, originalLvs[i]));
|
||||
@ -519,21 +618,22 @@ ConstraintTree::exportToGraphViz (
|
||||
const char* fileName,
|
||||
bool showLogVars) const
|
||||
{
|
||||
ofstream out (fileName);
|
||||
std::ofstream out (fileName);
|
||||
if (!out.is_open()) {
|
||||
cerr << "Error: couldn't open file '" << fileName << "'." ;
|
||||
std::cerr << "Error: couldn't open file '" << fileName << "'." ;
|
||||
std::cerr << std::endl;
|
||||
return;
|
||||
}
|
||||
out << "digraph {" << endl;
|
||||
out << "digraph {" << std::endl;
|
||||
ConstraintTree copy (*this);
|
||||
copy.moveToTop (copy.logVarSet_.elements());
|
||||
CTNodes nodes = getNodesBelow (copy.root_);
|
||||
out << "\"" << copy.root_ << "\"" << " [label=\"R\"]" << endl;
|
||||
out << "\"" << copy.root_ << "\"" << " [label=\"R\"]" << std::endl;
|
||||
for (CTNodes::const_iterator it = ++ nodes.begin();
|
||||
it != nodes.end(); ++ it) {
|
||||
out << "\"" << *it << "\"";
|
||||
out << " [label=\"" << **it << "\"]" ;
|
||||
out << endl;
|
||||
out << std::endl;
|
||||
}
|
||||
for (CTNodes::const_iterator it = nodes.begin();
|
||||
it != nodes.end(); ++ it) {
|
||||
@ -542,24 +642,24 @@ ConstraintTree::exportToGraphViz (
|
||||
chIt != childs.end(); ++ chIt) {
|
||||
out << "\"" << *it << "\"" ;
|
||||
out << " -> " ;
|
||||
out << "\"" << *chIt << "\"" << endl ;
|
||||
out << "\"" << *chIt << "\"" << std::endl ;
|
||||
}
|
||||
}
|
||||
if (showLogVars) {
|
||||
out << "Root [label=\"\", shape=plaintext]" << endl;
|
||||
out << "Root [label=\"\", shape=plaintext]" << std::endl;
|
||||
for (size_t i = 0; i < copy.logVars_.size(); i++) {
|
||||
out << copy.logVars_[i] << " [label=" ;
|
||||
out << copy.logVars_[i] << ", " ;
|
||||
out << "shape=plaintext, fontsize=14]" << endl;
|
||||
out << "shape=plaintext, fontsize=14]" << std::endl;
|
||||
}
|
||||
out << "Root -> " << copy.logVars_[0];
|
||||
out << " [style=invis]" << endl;
|
||||
out << " [style=invis]" << std::endl;
|
||||
for (size_t i = 0; i < copy.logVars_.size() - 1; i++) {
|
||||
out << copy.logVars_[i] << " -> " << copy.logVars_[i + 1];
|
||||
out << " [style=invis]" << endl;
|
||||
out << " [style=invis]" << std::endl;
|
||||
}
|
||||
}
|
||||
out << "}" << endl;
|
||||
out << "}" <<std::endl;
|
||||
out.close();
|
||||
}
|
||||
|
||||
@ -690,9 +790,9 @@ ConstraintTree::split (
|
||||
split (root_, ct->root(), commChilds, exclChilds, stopLevel);
|
||||
ConstraintTree* commCt = new ConstraintTree (commChilds, logVars_);
|
||||
ConstraintTree* exclCt = new ConstraintTree (exclChilds, logVars_);
|
||||
// cout << commCt->tupleSet() << " + " ;
|
||||
// cout << exclCt->tupleSet() << " = " ;
|
||||
// cout << tupleSet() << endl;
|
||||
// std::cout << commCt->tupleSet() << " + " ;
|
||||
// std::cout << exclCt->tupleSet() << " = " ;
|
||||
// std::cout << tupleSet() << std::endl;
|
||||
assert ((commCt->tupleSet() | exclCt->tupleSet()) == tupleSet());
|
||||
assert ((exclCt->tupleSet (stopLevel) & ct->tupleSet (stopLevel)).empty());
|
||||
return {commCt, exclCt};
|
||||
@ -710,20 +810,20 @@ ConstraintTree::countNormalize (const LogVarSet& Ys)
|
||||
}
|
||||
moveToTop (Zs.elements());
|
||||
ConstraintTrees cts;
|
||||
unordered_map<unsigned, ConstraintTree*> countMap;
|
||||
std::unordered_map<unsigned, ConstraintTree*> countMap;
|
||||
unsigned stopLevel = getLevel (Zs.back());
|
||||
const CTChilds& childs = root_->childs();
|
||||
|
||||
for (CTChilds::const_iterator chIt = childs.begin();
|
||||
chIt != childs.end(); ++ chIt) {
|
||||
const vector<pair<CTNode*, unsigned>>& res =
|
||||
const std::vector<std::pair<CTNode*, unsigned>>& res =
|
||||
countNormalize (*chIt, stopLevel);
|
||||
for (size_t j = 0; j < res.size(); j++) {
|
||||
unordered_map<unsigned, ConstraintTree*>::iterator it
|
||||
std::unordered_map<unsigned, ConstraintTree*>::iterator it
|
||||
= countMap.find (res[j].second);
|
||||
if (it == countMap.end()) {
|
||||
ConstraintTree* newCt = new ConstraintTree (logVars_);
|
||||
it = countMap.insert (make_pair (res[j].second, newCt)).first;
|
||||
it = countMap.insert (std::make_pair (res[j].second, newCt)).first;
|
||||
cts.push_back (newCt);
|
||||
}
|
||||
it->second->root_->mergeSubtree (res[j].first);
|
||||
@ -743,31 +843,31 @@ ConstraintTree::jointCountNormalize (
|
||||
LogVar X_new2)
|
||||
{
|
||||
unsigned N = getConditionalCount (X);
|
||||
// cout << "My tuples: " << tupleSet() << endl;
|
||||
// cout << "CommCt tuples: " << commCt->tupleSet() << endl;
|
||||
// cout << "ExclCt tuples: " << exclCt->tupleSet() << endl;
|
||||
// cout << "Counted Lv: " << X << endl;
|
||||
// cout << "X_new1: " << X_new1 << endl;
|
||||
// cout << "X_new2: " << X_new2 << endl;
|
||||
// cout << "Original N: " << N << endl;
|
||||
// cout << endl;
|
||||
// std::cout << "My tuples: " << tupleSet() << std::endl;
|
||||
// std::cout << "CommCt tuples: " << commCt->tupleSet() << std::endl;
|
||||
// std::cout << "ExclCt tuples: " << exclCt->tupleSet() << std::endl;
|
||||
// std::cout << "Counted Lv: " << X << std::endl;
|
||||
// std::cout << "X_new1: " << X_new1 << std::endl;
|
||||
// std::cout << "X_new2: " << X_new2 << std::endl;
|
||||
// std::cout << "Original N: " << N << std::endl;
|
||||
// std::cout << endl;
|
||||
|
||||
ConstraintTrees normCts1 = commCt->countNormalize (X);
|
||||
vector<unsigned> counts1 (normCts1.size());
|
||||
std::vector<unsigned> counts1 (normCts1.size());
|
||||
for (size_t i = 0; i < normCts1.size(); i++) {
|
||||
counts1[i] = normCts1[i]->getConditionalCount (X);
|
||||
// cout << "normCts1[" << i << "] #" << counts1[i] ;
|
||||
// cout << " " << normCts1[i]->tupleSet() << endl;
|
||||
// std::cout << "normCts1[" << i << "] #" << counts1[i] ;
|
||||
// std::cout << " " << normCts1[i]->tupleSet() << std::endl;
|
||||
}
|
||||
|
||||
ConstraintTrees normCts2 = exclCt->countNormalize (X);
|
||||
vector<unsigned> counts2 (normCts2.size());
|
||||
std::vector<unsigned> counts2 (normCts2.size());
|
||||
for (size_t i = 0; i < normCts2.size(); i++) {
|
||||
counts2[i] = normCts2[i]->getConditionalCount (X);
|
||||
// cout << "normCts2[" << i << "] #" << counts2[i] ;
|
||||
// cout << " " << normCts2[i]->tupleSet() << endl;
|
||||
// std::cout << "normCts2[" << i << "] #" << counts2[i] ;
|
||||
// std::cout << " " << normCts2[i]->tupleSet() << std::endl;
|
||||
}
|
||||
// cout << endl;
|
||||
// std::cout << std::endl;
|
||||
|
||||
ConstraintTree* excl1 = 0;
|
||||
for (size_t i = 0; i < normCts1.size(); i++) {
|
||||
@ -775,7 +875,7 @@ ConstraintTree::jointCountNormalize (
|
||||
excl1 = normCts1[i];
|
||||
normCts1.erase (normCts1.begin() + i);
|
||||
counts1.erase (counts1.begin() + i);
|
||||
// cout << "joint-count(" << N << ",0)" << endl;
|
||||
// std::cout << "joint-count(" << N << ",0)" << std::endl;
|
||||
break;
|
||||
}
|
||||
}
|
||||
@ -786,7 +886,7 @@ ConstraintTree::jointCountNormalize (
|
||||
excl2 = normCts2[i];
|
||||
normCts2.erase (normCts2.begin() + i);
|
||||
counts2.erase (counts2.begin() + i);
|
||||
// cout << "joint-count(0," << N << ")" << endl;
|
||||
// std::cout << "joint-count(0," << N << ")" << std::endl;
|
||||
break;
|
||||
}
|
||||
}
|
||||
@ -794,8 +894,8 @@ ConstraintTree::jointCountNormalize (
|
||||
for (size_t i = 0; i < normCts1.size(); i++) {
|
||||
unsigned j;
|
||||
for (j = 0; counts1[i] + counts2[j] != N; j++) ;
|
||||
// cout << "joint-count(" << counts1[i] ;
|
||||
// cout << "," << counts2[j] << ")" << endl;
|
||||
// std::cout << "joint-count(" << counts1[i] ;
|
||||
// std::cout << "," << counts2[j] << ")" << std::endl;
|
||||
const CTChilds& childs = normCts2[j]->root_->childs();
|
||||
for (CTChilds::const_iterator chIt = childs.begin();
|
||||
chIt != childs.end(); ++ chIt) {
|
||||
@ -930,7 +1030,7 @@ CTNodes
|
||||
ConstraintTree::getNodesBelow (CTNode* fromHere) const
|
||||
{
|
||||
CTNodes nodes;
|
||||
queue<CTNode*> queue;
|
||||
std::queue<CTNode*> queue;
|
||||
queue.push (fromHere);
|
||||
while (queue.empty() == false) {
|
||||
CTNode* node = queue.front();
|
||||
@ -1016,7 +1116,7 @@ ConstraintTree::swapLogVar (LogVar X)
|
||||
{
|
||||
size_t pos = Util::indexOf (logVars_, X);
|
||||
assert (pos != logVars_.size());
|
||||
const CTNodes& nodes = getNodesAtLevel (pos);
|
||||
CTNodes nodes = getNodesAtLevel (pos);
|
||||
for (CTNodes::const_iterator nodeIt = nodes.begin();
|
||||
nodeIt != nodes.end(); ++ nodeIt) {
|
||||
CTChilds childsCopy = (*nodeIt)->childs();
|
||||
@ -1098,7 +1198,7 @@ ConstraintTree::getTuples (
|
||||
|
||||
|
||||
unsigned
|
||||
ConstraintTree::size (void) const
|
||||
ConstraintTree::size() const
|
||||
{
|
||||
return countTuples (root_);
|
||||
}
|
||||
@ -1114,26 +1214,26 @@ ConstraintTree::nrSymbols (LogVar X)
|
||||
|
||||
|
||||
|
||||
vector<pair<CTNode*, unsigned>>
|
||||
std::vector<std::pair<CTNode*, unsigned>>
|
||||
ConstraintTree::countNormalize (
|
||||
const CTNode* n,
|
||||
unsigned stopLevel)
|
||||
{
|
||||
if (n->level() == stopLevel) {
|
||||
return vector<pair<CTNode*, unsigned>>() = {
|
||||
make_pair (CTNode::copySubtree (n), countTuples (n))
|
||||
return std::vector<std::pair<CTNode*, unsigned>>() = {
|
||||
std::make_pair (CTNode::copySubtree (n), countTuples (n))
|
||||
};
|
||||
}
|
||||
vector<pair<CTNode*, unsigned>> res;
|
||||
std::vector<std::pair<CTNode*, unsigned>> res;
|
||||
const CTChilds& childs = n->childs();
|
||||
for (CTChilds::const_iterator chIt = childs.begin();
|
||||
chIt != childs.end(); ++ chIt) {
|
||||
const vector<pair<CTNode*, unsigned>>& lowerRes =
|
||||
const std::vector<std::pair<CTNode*, unsigned>>& lowerRes =
|
||||
countNormalize (*chIt, stopLevel);
|
||||
for (size_t j = 0; j < lowerRes.size(); j++) {
|
||||
CTNode* newNode = new CTNode (*n);
|
||||
newNode->mergeSubtree (lowerRes[j].first);
|
||||
res.push_back (make_pair (newNode, lowerRes[j].second));
|
||||
res.push_back (std::make_pair (newNode, lowerRes[j].second));
|
||||
}
|
||||
}
|
||||
return res;
|
||||
@ -1172,3 +1272,5 @@ ConstraintTree::split (
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace Horus
|
||||
|
||||
|
@ -1,104 +1,35 @@
|
||||
#ifndef HORUS_CONSTRAINTTREE_H
|
||||
#define HORUS_CONSTRAINTTREE_H
|
||||
#ifndef YAP_PACKAGES_CLPBN_HORUS_CONSTRAINTTREE_H_
|
||||
#define YAP_PACKAGES_CLPBN_HORUS_CONSTRAINTTREE_H_
|
||||
|
||||
#include <cassert>
|
||||
#include <algorithm>
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
|
||||
#include "TinySet.h"
|
||||
#include "LiftedUtils.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace Horus {
|
||||
|
||||
class CTNode;
|
||||
typedef vector<CTNode*> CTNodes;
|
||||
|
||||
class ConstraintTree;
|
||||
typedef vector<ConstraintTree*> ConstraintTrees;
|
||||
|
||||
|
||||
class CTNode
|
||||
{
|
||||
public:
|
||||
struct CompareSymbol
|
||||
{
|
||||
bool operator() (const CTNode* n1, const CTNode* n2) const
|
||||
{
|
||||
return n1->symbol() < n2->symbol();
|
||||
}
|
||||
};
|
||||
typedef std::vector<CTNode*> CTNodes;
|
||||
typedef std::vector<ConstraintTree*> ConstraintTrees;
|
||||
|
||||
private:
|
||||
typedef TinySet<CTNode*, CompareSymbol> CTChilds_;
|
||||
|
||||
public:
|
||||
CTNode (const CTNode& n, const CTChilds_& chs = CTChilds_())
|
||||
: symbol_(n.symbol()), childs_(chs), level_(n.level()) { }
|
||||
|
||||
CTNode (Symbol s, unsigned l, const CTChilds_& chs = CTChilds_())
|
||||
: symbol_(s), childs_(chs), level_(l) { }
|
||||
|
||||
unsigned level (void) const { return level_; }
|
||||
|
||||
void setLevel (unsigned level) { level_ = level; }
|
||||
|
||||
Symbol symbol (void) const { return symbol_; }
|
||||
|
||||
void setSymbol (const Symbol s) { symbol_ = s; }
|
||||
|
||||
CTChilds_& childs (void) { return childs_; }
|
||||
|
||||
const CTChilds_& childs (void) const { return childs_; }
|
||||
|
||||
size_t nrChilds (void) const { return childs_.size(); }
|
||||
|
||||
bool isRoot (void) const { return level_ == 0; }
|
||||
|
||||
bool isLeaf (void) const { return childs_.empty(); }
|
||||
|
||||
CTChilds_::iterator findSymbol (Symbol symb)
|
||||
{
|
||||
CTNode tmp (symb, 0);
|
||||
return childs_.find (&tmp);
|
||||
}
|
||||
|
||||
void mergeSubtree (CTNode*, bool = true);
|
||||
|
||||
void removeChild (CTNode*);
|
||||
|
||||
void removeChilds (void);
|
||||
|
||||
void removeAndDeleteChild (CTNode*);
|
||||
|
||||
void removeAndDeleteAllChilds (void);
|
||||
|
||||
SymbolSet childSymbols (void) const;
|
||||
|
||||
static CTNode* copySubtree (const CTNode*);
|
||||
|
||||
static void deleteSubtree (CTNode*);
|
||||
|
||||
private:
|
||||
void updateChildLevels (CTNode*, unsigned);
|
||||
|
||||
Symbol symbol_;
|
||||
CTChilds_ childs_;
|
||||
unsigned level_;
|
||||
|
||||
DISALLOW_ASSIGN (CTNode);
|
||||
struct CmpSymbol {
|
||||
bool operator() (const CTNode* n1, const CTNode* n2) const;
|
||||
};
|
||||
|
||||
ostream& operator<< (ostream &out, const CTNode&);
|
||||
|
||||
typedef TinySet<CTNode*, CmpSymbol> CTChilds;
|
||||
|
||||
|
||||
typedef TinySet<CTNode*, CTNode::CompareSymbol> CTChilds;
|
||||
|
||||
|
||||
class ConstraintTree
|
||||
{
|
||||
class ConstraintTree {
|
||||
public:
|
||||
ConstraintTree (unsigned);
|
||||
|
||||
@ -106,38 +37,23 @@ class ConstraintTree
|
||||
|
||||
ConstraintTree (const LogVars&, const Tuples&);
|
||||
|
||||
ConstraintTree (vector<vector<string>> names);
|
||||
ConstraintTree (std::vector<std::vector<std::string>> names);
|
||||
|
||||
ConstraintTree (const ConstraintTree&);
|
||||
|
||||
ConstraintTree (const CTChilds& rootChilds, const LogVars& logVars)
|
||||
: root_(new CTNode (0, 0, rootChilds)),
|
||||
logVars_(logVars),
|
||||
logVarSet_(logVars) { }
|
||||
ConstraintTree (const CTChilds& rootChilds, const LogVars& logVars);
|
||||
|
||||
~ConstraintTree (void);
|
||||
~ConstraintTree();
|
||||
|
||||
CTNode* root (void) const { return root_; }
|
||||
CTNode* root() const { return root_; }
|
||||
|
||||
bool empty (void) const { return root_->childs().empty(); }
|
||||
bool empty() const;
|
||||
|
||||
const LogVars& logVars (void) const
|
||||
{
|
||||
assert (LogVarSet (logVars_) == logVarSet_);
|
||||
return logVars_;
|
||||
}
|
||||
const LogVars& logVars() const;
|
||||
|
||||
const LogVarSet& logVarSet (void) const
|
||||
{
|
||||
assert (LogVarSet (logVars_) == logVarSet_);
|
||||
return logVarSet_;
|
||||
}
|
||||
const LogVarSet& logVarSet() const;
|
||||
|
||||
size_t nrLogVars (void) const
|
||||
{
|
||||
return logVars_.size();
|
||||
assert (LogVarSet (logVars_) == logVarSet_);
|
||||
}
|
||||
size_t nrLogVars() const;
|
||||
|
||||
void addTuple (const Tuple&);
|
||||
|
||||
@ -163,13 +79,13 @@ class ConstraintTree
|
||||
|
||||
bool isSingleton (LogVar);
|
||||
|
||||
LogVarSet singletons (void);
|
||||
LogVarSet singletons();
|
||||
|
||||
TupleSet tupleSet (unsigned = 0) const;
|
||||
|
||||
TupleSet tupleSet (const LogVars&);
|
||||
|
||||
unsigned size (void) const;
|
||||
unsigned size() const;
|
||||
|
||||
unsigned nrSymbols (LogVar);
|
||||
|
||||
@ -218,11 +134,10 @@ class ConstraintTree
|
||||
|
||||
void getTuples (CTNode*, Tuples, unsigned, Tuples&, CTNodes&) const;
|
||||
|
||||
vector<std::pair<CTNode*, unsigned>> countNormalize (
|
||||
std::vector<std::pair<CTNode*, unsigned>> countNormalize (
|
||||
const CTNode*, unsigned);
|
||||
|
||||
static void split (
|
||||
CTNode*, CTNode*, CTChilds&, CTChilds&, unsigned);
|
||||
static void split (CTNode*, CTNode*, CTChilds&, CTChilds&, unsigned);
|
||||
|
||||
CTNode* root_;
|
||||
LogVars logVars_;
|
||||
@ -230,5 +145,33 @@ class ConstraintTree
|
||||
};
|
||||
|
||||
|
||||
#endif // HORUS_CONSTRAINTTREE_H
|
||||
|
||||
inline const LogVars&
|
||||
ConstraintTree::logVars() const
|
||||
{
|
||||
assert (LogVarSet (logVars_) == logVarSet_);
|
||||
return logVars_;
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline const LogVarSet&
|
||||
ConstraintTree::logVarSet() const
|
||||
{
|
||||
assert (LogVarSet (logVars_) == logVarSet_);
|
||||
return logVarSet_;
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline size_t
|
||||
ConstraintTree::nrLogVars() const
|
||||
{
|
||||
assert (LogVarSet (logVars_) == logVarSet_);
|
||||
return logVars_.size();
|
||||
}
|
||||
|
||||
} // namespace Horus
|
||||
|
||||
#endif // YAP_PACKAGES_CLPBN_HORUS_CONSTRAINTTREE_H_
|
||||
|
||||
|
@ -1,7 +1,62 @@
|
||||
#include <cassert>
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "CountingBp.h"
|
||||
#include "WeightedBp.h"
|
||||
|
||||
|
||||
namespace Horus {
|
||||
|
||||
class VarCluster {
|
||||
public:
|
||||
VarCluster (const VarNodes& vs) : members_(vs) { }
|
||||
|
||||
const VarNode* first() const { return members_.front(); }
|
||||
|
||||
const VarNodes& members() const { return members_; }
|
||||
|
||||
VarNode* representative() const { return repr_; }
|
||||
|
||||
void setRepresentative (VarNode* vn) { repr_ = vn; }
|
||||
|
||||
private:
|
||||
VarNodes members_;
|
||||
VarNode* repr_;
|
||||
|
||||
DISALLOW_COPY_AND_ASSIGN (VarCluster);
|
||||
};
|
||||
|
||||
|
||||
|
||||
class FacCluster {
|
||||
private:
|
||||
typedef std::vector<VarCluster*> VarClusters;
|
||||
|
||||
public:
|
||||
FacCluster (const FacNodes& fcs, const VarClusters& vcs)
|
||||
: members_(fcs), varClusters_(vcs) { }
|
||||
|
||||
const FacNode* first() const { return members_.front(); }
|
||||
|
||||
const FacNodes& members() const { return members_; }
|
||||
|
||||
FacNode* representative() const { return repr_; }
|
||||
|
||||
void setRepresentative (FacNode* fn) { repr_ = fn; }
|
||||
|
||||
VarClusters& varClusters() { return varClusters_; }
|
||||
|
||||
FacNodes members_;
|
||||
FacNode* repr_;
|
||||
VarClusters varClusters_;
|
||||
|
||||
DISALLOW_COPY_AND_ASSIGN (FacCluster);
|
||||
};
|
||||
|
||||
|
||||
|
||||
bool CountingBp::fif_ = true;
|
||||
|
||||
|
||||
@ -17,7 +72,7 @@ CountingBp::CountingBp (const FactorGraph& fg)
|
||||
|
||||
|
||||
|
||||
CountingBp::~CountingBp (void)
|
||||
CountingBp::~CountingBp()
|
||||
{
|
||||
delete solver_;
|
||||
delete compressedFg_;
|
||||
@ -32,23 +87,24 @@ CountingBp::~CountingBp (void)
|
||||
|
||||
|
||||
void
|
||||
CountingBp::printSolverFlags (void) const
|
||||
CountingBp::printSolverFlags() const
|
||||
{
|
||||
stringstream ss;
|
||||
std::stringstream ss;
|
||||
ss << "counting bp [" ;
|
||||
ss << "bp_msg_schedule=" ;
|
||||
typedef WeightedBp::MsgSchedule MsgSchedule;
|
||||
switch (WeightedBp::msgSchedule()) {
|
||||
case MsgSchedule::SEQ_FIXED: ss << "seq_fixed"; break;
|
||||
case MsgSchedule::SEQ_RANDOM: ss << "seq_random"; break;
|
||||
case MsgSchedule::PARALLEL: ss << "parallel"; break;
|
||||
case MsgSchedule::MAX_RESIDUAL: ss << "max_residual"; break;
|
||||
case MsgSchedule::seqFixedSch: ss << "seq_fixed"; break;
|
||||
case MsgSchedule::seqRandomSch: ss << "seq_random"; break;
|
||||
case MsgSchedule::parallelSch: ss << "parallel"; break;
|
||||
case MsgSchedule::maxResidualSch: ss << "max_residual"; break;
|
||||
}
|
||||
ss << ",bp_max_iter=" << WeightedBp::maxIterations();
|
||||
ss << ",bp_accuracy=" << WeightedBp::accuracy();
|
||||
ss << ",log_domain=" << Util::toString (Globals::logDomain);
|
||||
ss << ",fif=" << Util::toString (CountingBp::fif_);
|
||||
ss << "]" ;
|
||||
cout << ss.str() << endl;
|
||||
std::cout << ss.str() << std::endl;
|
||||
}
|
||||
|
||||
|
||||
@ -69,11 +125,10 @@ CountingBp::solveQuery (VarIds queryVids)
|
||||
idx = i;
|
||||
break;
|
||||
}
|
||||
cout << endl;
|
||||
}
|
||||
if (idx == facNodes.size()) {
|
||||
res = GroundSolver::getJointByConditioning (
|
||||
GroundSolverType::CBP, fg, queryVids);
|
||||
GroundSolverType::CbpSolver, fg, queryVids);
|
||||
} else {
|
||||
VarIds reprArgs;
|
||||
for (size_t i = 0; i < queryVids.size(); i++) {
|
||||
@ -124,7 +179,7 @@ CountingBp::findIdenticalFactors()
|
||||
|
||||
|
||||
void
|
||||
CountingBp::setInitialColors (void)
|
||||
CountingBp::setInitialColors()
|
||||
{
|
||||
varColors_.resize (fg.nrVarNodes());
|
||||
facColors_.resize (fg.nrFacNodes());
|
||||
@ -135,7 +190,7 @@ CountingBp::setInitialColors (void)
|
||||
unsigned range = varNodes[i]->range();
|
||||
VarColorMap::iterator it = colorMap.find (range);
|
||||
if (it == colorMap.end()) {
|
||||
it = colorMap.insert (make_pair (
|
||||
it = colorMap.insert (std::make_pair (
|
||||
range, Colors (range + 1, -1))).first;
|
||||
}
|
||||
unsigned idx = varNodes[i]->hasEvidence()
|
||||
@ -154,7 +209,8 @@ CountingBp::setInitialColors (void)
|
||||
unsigned distId = facNodes[i]->factor().distId();
|
||||
DistColorMap::iterator it = distColors.find (distId);
|
||||
if (it == distColors.end()) {
|
||||
it = distColors.insert (make_pair (distId, getNewColor())).first;
|
||||
it = distColors.insert (std::make_pair (
|
||||
distId, getNewColor())).first;
|
||||
}
|
||||
setColor (facNodes[i], it->second);
|
||||
}
|
||||
@ -163,7 +219,7 @@ CountingBp::setInitialColors (void)
|
||||
|
||||
|
||||
void
|
||||
CountingBp::createGroups (void)
|
||||
CountingBp::createGroups()
|
||||
{
|
||||
VarSignMap varGroups;
|
||||
FacSignMap facGroups;
|
||||
@ -179,10 +235,11 @@ CountingBp::createGroups (void)
|
||||
size_t prevVarGroupsSize = varGroups.size();
|
||||
varGroups.clear();
|
||||
for (size_t i = 0; i < varNodes.size(); i++) {
|
||||
const VarSignature& signature = getSignature (varNodes[i]);
|
||||
VarSignature signature = getSignature (varNodes[i]);
|
||||
VarSignMap::iterator it = varGroups.find (signature);
|
||||
if (it == varGroups.end()) {
|
||||
it = varGroups.insert (make_pair (signature, VarNodes())).first;
|
||||
it = varGroups.insert (std::make_pair (
|
||||
signature, VarNodes())).first;
|
||||
}
|
||||
it->second.push_back (varNodes[i]);
|
||||
}
|
||||
@ -199,10 +256,11 @@ CountingBp::createGroups (void)
|
||||
facGroups.clear();
|
||||
// set a new color to the factors with the same signature
|
||||
for (size_t i = 0; i < facNodes.size(); i++) {
|
||||
const FacSignature& signature = getSignature (facNodes[i]);
|
||||
FacSignature signature = getSignature (facNodes[i]);
|
||||
FacSignMap::iterator it = facGroups.find (signature);
|
||||
if (it == facGroups.end()) {
|
||||
it = facGroups.insert (make_pair (signature, FacNodes())).first;
|
||||
it = facGroups.insert (std::make_pair (
|
||||
signature, FacNodes())).first;
|
||||
}
|
||||
it->second.push_back (facNodes[i]);
|
||||
}
|
||||
@ -235,7 +293,8 @@ CountingBp::createClusters (
|
||||
const VarNodes& groupVars = it->second;
|
||||
VarCluster* vc = new VarCluster (groupVars);
|
||||
for (size_t i = 0; i < groupVars.size(); i++) {
|
||||
varClusterMap_.insert (make_pair (groupVars[i]->varId(), vc));
|
||||
varClusterMap_.insert (std::make_pair (
|
||||
groupVars[i]->varId(), vc));
|
||||
}
|
||||
varClusters_.push_back (vc);
|
||||
}
|
||||
@ -257,29 +316,29 @@ CountingBp::createClusters (
|
||||
|
||||
|
||||
|
||||
VarSignature
|
||||
CountingBp::VarSignature
|
||||
CountingBp::getSignature (const VarNode* varNode)
|
||||
{
|
||||
const FacNodes& neighs = varNode->neighbors();
|
||||
VarSignature sign;
|
||||
const FacNodes& neighs = varNode->neighbors();
|
||||
sign.reserve (neighs.size() + 1);
|
||||
for (size_t i = 0; i < neighs.size(); i++) {
|
||||
sign.push_back (make_pair (
|
||||
sign.push_back (std::make_pair (
|
||||
getColor (neighs[i]),
|
||||
neighs[i]->factor().indexOf (varNode->varId())));
|
||||
}
|
||||
std::sort (sign.begin(), sign.end());
|
||||
sign.push_back (make_pair (getColor (varNode), 0));
|
||||
sign.push_back (std::make_pair (getColor (varNode), 0));
|
||||
return sign;
|
||||
}
|
||||
|
||||
|
||||
|
||||
FacSignature
|
||||
CountingBp::FacSignature
|
||||
CountingBp::getSignature (const FacNode* facNode)
|
||||
{
|
||||
const VarNodes& neighs = facNode->neighbors();
|
||||
FacSignature sign;
|
||||
const VarNodes& neighs = facNode->neighbors();
|
||||
sign.reserve (neighs.size() + 1);
|
||||
for (size_t i = 0; i < neighs.size(); i++) {
|
||||
sign.push_back (getColor (neighs[i]));
|
||||
@ -314,7 +373,7 @@ CountingBp::getRepresentative (FacNode* fn)
|
||||
|
||||
|
||||
FactorGraph*
|
||||
CountingBp::getCompressedFactorGraph (void)
|
||||
CountingBp::getCompressedFactorGraph()
|
||||
{
|
||||
FactorGraph* fg = new FactorGraph();
|
||||
for (size_t i = 0; i < varClusters_.size(); i++) {
|
||||
@ -342,10 +401,10 @@ CountingBp::getCompressedFactorGraph (void)
|
||||
|
||||
|
||||
|
||||
vector<vector<unsigned>>
|
||||
CountingBp::getWeights (void) const
|
||||
std::vector<std::vector<unsigned>>
|
||||
CountingBp::getWeights() const
|
||||
{
|
||||
vector<vector<unsigned>> weights;
|
||||
std::vector<std::vector<unsigned>> weights;
|
||||
weights.reserve (facClusters_.size());
|
||||
for (size_t i = 0; i < facClusters_.size(); i++) {
|
||||
const VarClusters& neighs = facClusters_[i]->varClusters();
|
||||
@ -390,32 +449,34 @@ CountingBp::printGroups (
|
||||
const FacSignMap& facGroups) const
|
||||
{
|
||||
unsigned count = 1;
|
||||
cout << "variable groups:" << endl;
|
||||
std::cout << "variable groups:" << std::endl;
|
||||
for (VarSignMap::const_iterator it = varGroups.begin();
|
||||
it != varGroups.end(); ++it) {
|
||||
const VarNodes& groupMembers = it->second;
|
||||
if (groupMembers.size() > 0) {
|
||||
cout << count << ": " ;
|
||||
std::cout << count << ": " ;
|
||||
for (size_t i = 0; i < groupMembers.size(); i++) {
|
||||
cout << groupMembers[i]->label() << " " ;
|
||||
std::cout << groupMembers[i]->label() << " " ;
|
||||
}
|
||||
count ++;
|
||||
cout << endl;
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
count = 1;
|
||||
cout << endl << "factor groups:" << endl;
|
||||
std::cout << std::endl << "factor groups:" << std::endl;
|
||||
for (FacSignMap::const_iterator it = facGroups.begin();
|
||||
it != facGroups.end(); ++it) {
|
||||
const FacNodes& groupMembers = it->second;
|
||||
if (groupMembers.size() > 0) {
|
||||
cout << ++count << ": " ;
|
||||
std::cout << ++count << ": " ;
|
||||
for (size_t i = 0; i < groupMembers.size(); i++) {
|
||||
cout << groupMembers[i]->getLabel() << " " ;
|
||||
std::cout << groupMembers[i]->getLabel() << " " ;
|
||||
}
|
||||
count ++;
|
||||
cout << endl;
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace Horus
|
||||
|
||||
|
@ -1,155 +1,99 @@
|
||||
#ifndef HORUS_COUNTINGBP_H
|
||||
#define HORUS_COUNTINGBP_H
|
||||
#ifndef YAP_PACKAGES_CLPBN_HORUS_COUNTINGBP_H_
|
||||
#define YAP_PACKAGES_CLPBN_HORUS_COUNTINGBP_H_
|
||||
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "GroundSolver.h"
|
||||
#include "FactorGraph.h"
|
||||
#include "Horus.h"
|
||||
|
||||
|
||||
namespace Horus {
|
||||
|
||||
class VarCluster;
|
||||
class FacCluster;
|
||||
class WeightedBp;
|
||||
|
||||
typedef long Color;
|
||||
typedef vector<Color> Colors;
|
||||
typedef vector<std::pair<Color,unsigned>> VarSignature;
|
||||
typedef vector<Color> FacSignature;
|
||||
|
||||
typedef unordered_map<unsigned, Color> DistColorMap;
|
||||
typedef unordered_map<unsigned, Colors> VarColorMap;
|
||||
|
||||
typedef unordered_map<VarSignature, VarNodes> VarSignMap;
|
||||
typedef unordered_map<FacSignature, FacNodes> FacSignMap;
|
||||
|
||||
typedef unordered_map<VarId, VarCluster*> VarClusterMap;
|
||||
|
||||
typedef vector<VarCluster*> VarClusters;
|
||||
typedef vector<FacCluster*> FacClusters;
|
||||
|
||||
template <class T>
|
||||
inline size_t hash_combine (size_t seed, const T& v)
|
||||
template <class T> inline size_t
|
||||
hash_combine (size_t seed, const T& v)
|
||||
{
|
||||
return seed ^ (hash<T>()(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2));
|
||||
return seed ^ (std::hash<T>()(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2));
|
||||
}
|
||||
|
||||
} // namespace Horus
|
||||
|
||||
|
||||
namespace std {
|
||||
template <typename T1, typename T2> struct hash<std::pair<T1,T2>>
|
||||
{
|
||||
size_t operator() (const std::pair<T1,T2>& p) const
|
||||
{
|
||||
return hash_combine (std::hash<T1>()(p.first), p.second);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T> struct hash<std::vector<T>>
|
||||
{
|
||||
size_t operator() (const std::vector<T>& vec) const
|
||||
{
|
||||
size_t h = 0;
|
||||
typename vector<T>::const_iterator first = vec.begin();
|
||||
typename vector<T>::const_iterator last = vec.end();
|
||||
for (; first != last; ++first) {
|
||||
h = hash_combine (h, *first);
|
||||
}
|
||||
return h;
|
||||
}
|
||||
};
|
||||
}
|
||||
template <typename T1, typename T2> struct hash<std::pair<T1,T2>> {
|
||||
size_t operator() (const std::pair<T1,T2>& p) const {
|
||||
return Horus::hash_combine (std::hash<T1>()(p.first), p.second);
|
||||
}};
|
||||
|
||||
|
||||
class VarCluster
|
||||
template <typename T> struct hash<std::vector<T>>
|
||||
{
|
||||
public:
|
||||
VarCluster (const VarNodes& vs) : members_(vs) { }
|
||||
|
||||
const VarNode* first (void) const { return members_.front(); }
|
||||
|
||||
const VarNodes& members (void) const { return members_; }
|
||||
|
||||
VarNode* representative (void) const { return repr_; }
|
||||
|
||||
void setRepresentative (VarNode* vn) { repr_ = vn; }
|
||||
|
||||
private:
|
||||
VarNodes members_;
|
||||
VarNode* repr_;
|
||||
|
||||
DISALLOW_COPY_AND_ASSIGN (VarCluster);
|
||||
size_t operator() (const std::vector<T>& vec) const
|
||||
{
|
||||
size_t h = 0;
|
||||
typename std::vector<T>::const_iterator first = vec.begin();
|
||||
typename std::vector<T>::const_iterator last = vec.end();
|
||||
for (; first != last; ++first) {
|
||||
h = Horus::hash_combine (h, *first);
|
||||
}
|
||||
return h;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
class FacCluster
|
||||
{
|
||||
public:
|
||||
FacCluster (const FacNodes& fcs, const VarClusters& vcs)
|
||||
: members_(fcs), varClusters_(vcs) { }
|
||||
|
||||
const FacNode* first (void) const { return members_.front(); }
|
||||
|
||||
const FacNodes& members (void) const { return members_; }
|
||||
|
||||
FacNode* representative (void) const { return repr_; }
|
||||
|
||||
void setRepresentative (FacNode* fn) { repr_ = fn; }
|
||||
|
||||
VarClusters& varClusters (void) { return varClusters_; }
|
||||
|
||||
private:
|
||||
FacNodes members_;
|
||||
FacNode* repr_;
|
||||
VarClusters varClusters_;
|
||||
|
||||
DISALLOW_COPY_AND_ASSIGN (FacCluster);
|
||||
};
|
||||
} // namespace std
|
||||
|
||||
|
||||
class CountingBp : public GroundSolver
|
||||
{
|
||||
namespace Horus {
|
||||
|
||||
class CountingBp : public GroundSolver {
|
||||
public:
|
||||
CountingBp (const FactorGraph& fg);
|
||||
|
||||
~CountingBp (void);
|
||||
~CountingBp();
|
||||
|
||||
void printSolverFlags (void) const;
|
||||
void printSolverFlags() const;
|
||||
|
||||
Params solveQuery (VarIds);
|
||||
|
||||
static void setFindIdenticalFactorsFlag (bool fif) { fif_ = fif; }
|
||||
|
||||
private:
|
||||
Color getNewColor (void)
|
||||
{
|
||||
++ freeColor_;
|
||||
return freeColor_ - 1;
|
||||
}
|
||||
typedef long Color;
|
||||
typedef std::vector<Color> Colors;
|
||||
|
||||
Color getColor (const VarNode* vn) const
|
||||
{
|
||||
return varColors_[vn->getIndex()];
|
||||
}
|
||||
typedef std::vector<std::pair<Color,unsigned>> VarSignature;
|
||||
typedef std::vector<Color> FacSignature;
|
||||
|
||||
Color getColor (const FacNode* fn) const
|
||||
{
|
||||
return facColors_[fn->getIndex()];
|
||||
}
|
||||
typedef std::vector<VarCluster*> VarClusters;
|
||||
typedef std::vector<FacCluster*> FacClusters;
|
||||
|
||||
void setColor (const VarNode* vn, Color c)
|
||||
{
|
||||
varColors_[vn->getIndex()] = c;
|
||||
}
|
||||
typedef std::unordered_map<unsigned, Color> DistColorMap;
|
||||
typedef std::unordered_map<unsigned, Colors> VarColorMap;
|
||||
typedef std::unordered_map<VarSignature, VarNodes> VarSignMap;
|
||||
typedef std::unordered_map<FacSignature, FacNodes> FacSignMap;
|
||||
typedef std::unordered_map<VarId, VarCluster*> VarClusterMap;
|
||||
|
||||
void setColor (const FacNode* fn, Color c)
|
||||
{
|
||||
facColors_[fn->getIndex()] = c;
|
||||
}
|
||||
Color getNewColor();
|
||||
|
||||
void findIdenticalFactors (void);
|
||||
Color getColor (const VarNode* vn) const;
|
||||
|
||||
void setInitialColors (void);
|
||||
Color getColor (const FacNode* fn) const;
|
||||
|
||||
void createGroups (void);
|
||||
void setColor (const VarNode* vn, Color c);
|
||||
|
||||
void setColor (const FacNode* fn, Color c);
|
||||
|
||||
void findIdenticalFactors();
|
||||
|
||||
void setInitialColors();
|
||||
|
||||
void createGroups();
|
||||
|
||||
void createClusters (const VarSignMap&, const FacSignMap&);
|
||||
|
||||
@ -163,12 +107,12 @@ class CountingBp : public GroundSolver
|
||||
|
||||
FacNode* getRepresentative (FacNode*);
|
||||
|
||||
FactorGraph* getCompressedFactorGraph (void);
|
||||
FactorGraph* getCompressedFactorGraph();
|
||||
|
||||
vector<vector<unsigned>> getWeights (void) const;
|
||||
std::vector<std::vector<unsigned>> getWeights() const;
|
||||
|
||||
unsigned getWeight (const FacCluster*,
|
||||
const VarCluster*, size_t index) const;
|
||||
unsigned getWeight (const FacCluster*, const VarCluster*,
|
||||
size_t index) const;
|
||||
|
||||
Color freeColor_;
|
||||
Colors varColors_;
|
||||
@ -184,5 +128,48 @@ class CountingBp : public GroundSolver
|
||||
DISALLOW_COPY_AND_ASSIGN (CountingBp);
|
||||
};
|
||||
|
||||
#endif // HORUS_COUNTINGBP_H
|
||||
|
||||
|
||||
inline CountingBp::Color
|
||||
CountingBp::getNewColor()
|
||||
{
|
||||
++ freeColor_;
|
||||
return freeColor_ - 1;
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline CountingBp::Color
|
||||
CountingBp::getColor (const VarNode* vn) const
|
||||
{
|
||||
return varColors_[vn->getIndex()];
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline CountingBp::Color
|
||||
CountingBp::getColor (const FacNode* fn) const
|
||||
{
|
||||
return facColors_[fn->getIndex()];
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline void
|
||||
CountingBp::setColor (const VarNode* vn, CountingBp::Color c)
|
||||
{
|
||||
varColors_[vn->getIndex()] = c;
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline void
|
||||
CountingBp::setColor (const FacNode* fn, Color c)
|
||||
{
|
||||
facColors_[fn->getIndex()] = c;
|
||||
}
|
||||
|
||||
} // namespace Horus
|
||||
|
||||
#endif // YAP_PACKAGES_CLPBN_HORUS_COUNTINGBP_H_
|
||||
|
||||
|
@ -1,25 +1,30 @@
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
|
||||
#include "ElimGraph.h"
|
||||
|
||||
ElimHeuristic ElimGraph::elimHeuristic_ = MIN_NEIGHBORS;
|
||||
|
||||
namespace Horus {
|
||||
|
||||
ElimGraph::ElimHeuristic ElimGraph::elimHeuristic_ =
|
||||
ElimHeuristic::minNeighborsEh;
|
||||
|
||||
|
||||
ElimGraph::ElimGraph (const vector<Factor*>& factors)
|
||||
ElimGraph::ElimGraph (const std::vector<Factor*>& factors)
|
||||
{
|
||||
for (size_t i = 0; i < factors.size(); i++) {
|
||||
if (factors[i]) {
|
||||
const VarIds& args = factors[i]->arguments();
|
||||
for (size_t j = 0; j < args.size() - 1; j++) {
|
||||
EgNode* n1 = getEgNode (args[j]);
|
||||
EGNode* n1 = getEGNode (args[j]);
|
||||
if (!n1) {
|
||||
n1 = new EgNode (args[j], factors[i]->range (j));
|
||||
n1 = new EGNode (args[j], factors[i]->range (j));
|
||||
addNode (n1);
|
||||
}
|
||||
for (size_t k = j + 1; k < args.size(); k++) {
|
||||
EgNode* n2 = getEgNode (args[k]);
|
||||
EGNode* n2 = getEGNode (args[k]);
|
||||
if (!n2) {
|
||||
n2 = new EgNode (args[k], factors[i]->range (k));
|
||||
n2 = new EGNode (args[k], factors[i]->range (k));
|
||||
addNode (n2);
|
||||
}
|
||||
if (!neighbors (n1, n2)) {
|
||||
@ -27,8 +32,8 @@ ElimGraph::ElimGraph (const vector<Factor*>& factors)
|
||||
}
|
||||
}
|
||||
}
|
||||
if (args.size() == 1 && !getEgNode (args[0])) {
|
||||
addNode (new EgNode (args[0], factors[i]->range (0)));
|
||||
if (args.size() == 1 && !getEGNode (args[0])) {
|
||||
addNode (new EGNode (args[0], factors[i]->range (0)));
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -36,7 +41,7 @@ ElimGraph::ElimGraph (const vector<Factor*>& factors)
|
||||
|
||||
|
||||
|
||||
ElimGraph::~ElimGraph (void)
|
||||
ElimGraph::~ElimGraph()
|
||||
{
|
||||
for (size_t i = 0; i < nodes_.size(); i++) {
|
||||
delete nodes_[i];
|
||||
@ -57,7 +62,7 @@ ElimGraph::getEliminatingOrder (const VarIds& excludedVids)
|
||||
}
|
||||
size_t nrVarsToEliminate = nodes_.size() - excludedVids.size();
|
||||
for (size_t i = 0; i < nrVarsToEliminate; i++) {
|
||||
EgNode* node = getLowestCostNode();
|
||||
EGNode* node = getLowestCostNode();
|
||||
unmarked_.remove (node);
|
||||
const EGNeighs& neighs = node->neighbors();
|
||||
for (size_t j = 0; j < neighs.size(); j++) {
|
||||
@ -72,15 +77,15 @@ ElimGraph::getEliminatingOrder (const VarIds& excludedVids)
|
||||
|
||||
|
||||
void
|
||||
ElimGraph::print (void) const
|
||||
ElimGraph::print() const
|
||||
{
|
||||
for (size_t i = 0; i < nodes_.size(); i++) {
|
||||
cout << "node " << nodes_[i]->label() << " neighs:" ;
|
||||
std::cout << "node " << nodes_[i]->label() << " neighs:" ;
|
||||
EGNeighs neighs = nodes_[i]->neighbors();
|
||||
for (size_t j = 0; j < neighs.size(); j++) {
|
||||
cout << " " << neighs[j]->label();
|
||||
std::cout << " " << neighs[j]->label();
|
||||
}
|
||||
cout << endl;
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
@ -92,25 +97,27 @@ ElimGraph::exportToGraphViz (
|
||||
bool showNeighborless,
|
||||
const VarIds& highlightVarIds) const
|
||||
{
|
||||
ofstream out (fileName);
|
||||
std::ofstream out (fileName);
|
||||
if (!out.is_open()) {
|
||||
cerr << "Error: couldn't open file '" << fileName << "'." ;
|
||||
std::cerr << "Error: couldn't open file '" << fileName << "'." ;
|
||||
std::cerr << std::endl;
|
||||
return;
|
||||
}
|
||||
out << "strict graph {" << endl;
|
||||
out << "strict graph {" << std::endl;
|
||||
for (size_t i = 0; i < nodes_.size(); i++) {
|
||||
if (showNeighborless || nodes_[i]->neighbors().empty() == false) {
|
||||
out << '"' << nodes_[i]->label() << '"' << endl;
|
||||
out << '"' << nodes_[i]->label() << '"' << std::endl;
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i < highlightVarIds.size(); i++) {
|
||||
EgNode* node =getEgNode (highlightVarIds[i]);
|
||||
EGNode* node =getEGNode (highlightVarIds[i]);
|
||||
if (node) {
|
||||
out << '"' << node->label() << '"' ;
|
||||
out << " [shape=box3d]" << endl;
|
||||
out << " [shape=box3d]" << std::endl;
|
||||
} else {
|
||||
cerr << "Error: invalid variable id: " << highlightVarIds[i] << "." ;
|
||||
cerr << endl;
|
||||
std::cerr << "Error: invalid variable id: " ;
|
||||
std::cerr << highlightVarIds[i] << "." ;
|
||||
std::cerr << std::endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
}
|
||||
@ -118,10 +125,10 @@ ElimGraph::exportToGraphViz (
|
||||
EGNeighs neighs = nodes_[i]->neighbors();
|
||||
for (size_t j = 0; j < neighs.size(); j++) {
|
||||
out << '"' << nodes_[i]->label() << '"' << " -- " ;
|
||||
out << '"' << neighs[j]->label() << '"' << endl;
|
||||
out << '"' << neighs[j]->label() << '"' << std::endl;
|
||||
}
|
||||
}
|
||||
out << "}" << endl;
|
||||
out << "}" << std::endl;
|
||||
out.close();
|
||||
}
|
||||
|
||||
@ -132,7 +139,7 @@ ElimGraph::getEliminationOrder (
|
||||
const Factors& factors,
|
||||
VarIds excludedVids)
|
||||
{
|
||||
if (elimHeuristic_ == ElimHeuristic::SEQUENTIAL) {
|
||||
if (elimHeuristic_ == ElimHeuristic::sequentialEh) {
|
||||
VarIds allVids;
|
||||
Factors::const_iterator first = factors.begin();
|
||||
Factors::const_iterator end = factors.end();
|
||||
@ -150,33 +157,33 @@ ElimGraph::getEliminationOrder (
|
||||
|
||||
|
||||
void
|
||||
ElimGraph::addNode (EgNode* n)
|
||||
ElimGraph::addNode (EGNode* n)
|
||||
{
|
||||
nodes_.push_back (n);
|
||||
n->setIndex (nodes_.size() - 1);
|
||||
varMap_.insert (make_pair (n->varId(), n));
|
||||
varMap_.insert (std::make_pair (n->varId(), n));
|
||||
}
|
||||
|
||||
|
||||
|
||||
EgNode*
|
||||
ElimGraph::getEgNode (VarId vid) const
|
||||
ElimGraph::EGNode*
|
||||
ElimGraph::getEGNode (VarId vid) const
|
||||
{
|
||||
unordered_map<VarId, EgNode*>::const_iterator it;
|
||||
std::unordered_map<VarId, EGNode*>::const_iterator it;
|
||||
it = varMap_.find (vid);
|
||||
return (it != varMap_.end()) ? it->second : 0;
|
||||
}
|
||||
|
||||
|
||||
|
||||
EgNode*
|
||||
ElimGraph::getLowestCostNode (void) const
|
||||
ElimGraph::EGNode*
|
||||
ElimGraph::getLowestCostNode() const
|
||||
{
|
||||
EgNode* bestNode = 0;
|
||||
EGNode* bestNode = 0;
|
||||
unsigned minCost = Util::maxUnsigned();
|
||||
EGNeighs::const_iterator it;
|
||||
switch (elimHeuristic_) {
|
||||
case MIN_NEIGHBORS: {
|
||||
case ElimHeuristic::minNeighborsEh: {
|
||||
for (it = unmarked_.begin(); it != unmarked_.end(); ++ it) {
|
||||
unsigned cost = getNeighborsCost (*it);
|
||||
if (cost < minCost) {
|
||||
@ -185,7 +192,7 @@ ElimGraph::getLowestCostNode (void) const
|
||||
}
|
||||
}}
|
||||
break;
|
||||
case MIN_WEIGHT: {
|
||||
case ElimHeuristic::minWeightEh: {
|
||||
for (it = unmarked_.begin(); it != unmarked_.end(); ++ it) {
|
||||
unsigned cost = getWeightCost (*it);
|
||||
if (cost < minCost) {
|
||||
@ -194,7 +201,7 @@ ElimGraph::getLowestCostNode (void) const
|
||||
}
|
||||
}}
|
||||
break;
|
||||
case MIN_FILL: {
|
||||
case ElimHeuristic::minFillEh: {
|
||||
for (it = unmarked_.begin(); it != unmarked_.end(); ++ it) {
|
||||
unsigned cost = getFillCost (*it);
|
||||
if (cost < minCost) {
|
||||
@ -203,7 +210,7 @@ ElimGraph::getLowestCostNode (void) const
|
||||
}
|
||||
}}
|
||||
break;
|
||||
case WEIGHTED_MIN_FILL: {
|
||||
case ElimHeuristic::weightedMinFillEh: {
|
||||
for (it = unmarked_.begin(); it != unmarked_.end(); ++ it) {
|
||||
unsigned cost = getWeightedFillCost (*it);
|
||||
if (cost < minCost) {
|
||||
@ -222,7 +229,7 @@ ElimGraph::getLowestCostNode (void) const
|
||||
|
||||
|
||||
void
|
||||
ElimGraph::connectAllNeighbors (const EgNode* n)
|
||||
ElimGraph::connectAllNeighbors (const EGNode* n)
|
||||
{
|
||||
const EGNeighs& neighs = n->neighbors();
|
||||
if (neighs.size() > 0) {
|
||||
@ -236,3 +243,5 @@ ElimGraph::connectAllNeighbors (const EgNode* n)
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace Horus
|
||||
|
||||
|
@ -1,143 +1,177 @@
|
||||
#ifndef HORUS_ELIMGRAPH_H
|
||||
#define HORUS_ELIMGRAPH_H
|
||||
#ifndef YAP_PACKAGES_CLPBN_HORUS_ELIMGRAPH_H_
|
||||
#define YAP_PACKAGES_CLPBN_HORUS_ELIMGRAPH_H_
|
||||
|
||||
#include "unordered_map"
|
||||
#include <cassert>
|
||||
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "FactorGraph.h"
|
||||
#include "TinySet.h"
|
||||
#include "Horus.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
enum ElimHeuristic
|
||||
{
|
||||
SEQUENTIAL,
|
||||
MIN_NEIGHBORS,
|
||||
MIN_WEIGHT,
|
||||
MIN_FILL,
|
||||
WEIGHTED_MIN_FILL
|
||||
};
|
||||
namespace Horus {
|
||||
|
||||
|
||||
class EgNode;
|
||||
|
||||
typedef TinySet<EgNode*> EGNeighs;
|
||||
|
||||
|
||||
class EgNode : public Var
|
||||
{
|
||||
class ElimGraph {
|
||||
public:
|
||||
EgNode (VarId vid, unsigned range) : Var (vid, range) { }
|
||||
enum class ElimHeuristic {
|
||||
sequentialEh,
|
||||
minNeighborsEh,
|
||||
minWeightEh,
|
||||
minFillEh,
|
||||
weightedMinFillEh
|
||||
};
|
||||
|
||||
void addNeighbor (EgNode* n) { neighs_.insert (n); }
|
||||
|
||||
void removeNeighbor (EgNode* n) { neighs_.remove (n); }
|
||||
|
||||
bool isNeighbor (EgNode* n) const { return neighs_.contains (n); }
|
||||
|
||||
const EGNeighs& neighbors (void) const { return neighs_; }
|
||||
|
||||
private:
|
||||
EGNeighs neighs_;
|
||||
};
|
||||
|
||||
|
||||
class ElimGraph
|
||||
{
|
||||
public:
|
||||
ElimGraph (const Factors&);
|
||||
|
||||
~ElimGraph (void);
|
||||
~ElimGraph();
|
||||
|
||||
VarIds getEliminatingOrder (const VarIds&);
|
||||
|
||||
void print (void) const;
|
||||
void print() const;
|
||||
|
||||
void exportToGraphViz (const char*, bool = true,
|
||||
const VarIds& = VarIds()) const;
|
||||
|
||||
static VarIds getEliminationOrder (const Factors&, VarIds);
|
||||
|
||||
static ElimHeuristic elimHeuristic (void) { return elimHeuristic_; }
|
||||
static ElimHeuristic elimHeuristic() { return elimHeuristic_; }
|
||||
|
||||
static void setElimHeuristic (ElimHeuristic eh) { elimHeuristic_ = eh; }
|
||||
|
||||
private:
|
||||
void addEdge (EgNode* n1, EgNode* n2)
|
||||
{
|
||||
assert (n1 != n2);
|
||||
n1->addNeighbor (n2);
|
||||
n2->addNeighbor (n1);
|
||||
}
|
||||
class EGNode;
|
||||
|
||||
unsigned getNeighborsCost (const EgNode* n) const
|
||||
{
|
||||
return n->neighbors().size();
|
||||
}
|
||||
typedef TinySet<EGNode*> EGNeighs;
|
||||
|
||||
unsigned getWeightCost (const EgNode* n) const
|
||||
{
|
||||
unsigned cost = 1;
|
||||
const EGNeighs& neighs = n->neighbors();
|
||||
for (size_t i = 0; i < neighs.size(); i++) {
|
||||
cost *= neighs[i]->range();
|
||||
}
|
||||
return cost;
|
||||
}
|
||||
class EGNode : public Var {
|
||||
public:
|
||||
EGNode (VarId vid, unsigned range) : Var (vid, range) { }
|
||||
|
||||
unsigned getFillCost (const EgNode* n) const
|
||||
{
|
||||
unsigned cost = 0;
|
||||
const EGNeighs& neighs = n->neighbors();
|
||||
if (neighs.size() > 0) {
|
||||
for (size_t i = 0; i < neighs.size() - 1; i++) {
|
||||
for (size_t j = i + 1; j < neighs.size(); j++) {
|
||||
if ( ! neighbors (neighs[i], neighs[j])) {
|
||||
cost ++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return cost;
|
||||
}
|
||||
void addNeighbor (EGNode* n) { neighs_.insert (n); }
|
||||
|
||||
unsigned getWeightedFillCost (const EgNode* n) const
|
||||
{
|
||||
unsigned cost = 0;
|
||||
const EGNeighs& neighs = n->neighbors();
|
||||
if (neighs.size() > 0) {
|
||||
for (size_t i = 0; i < neighs.size() - 1; i++) {
|
||||
for (size_t j = i + 1; j < neighs.size(); j++) {
|
||||
if ( ! neighbors (neighs[i], neighs[j])) {
|
||||
cost += neighs[i]->range() * neighs[j]->range();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return cost;
|
||||
}
|
||||
void removeNeighbor (EGNode* n) { neighs_.remove (n); }
|
||||
|
||||
bool neighbors (EgNode* n1, EgNode* n2) const
|
||||
{
|
||||
return n1->isNeighbor (n2);
|
||||
}
|
||||
bool isNeighbor (EGNode* n) const { return neighs_.contains (n); }
|
||||
|
||||
void addNode (EgNode*);
|
||||
const EGNeighs& neighbors() const { return neighs_; }
|
||||
|
||||
EgNode* getEgNode (VarId) const;
|
||||
private:
|
||||
EGNeighs neighs_;
|
||||
};
|
||||
|
||||
EgNode* getLowestCostNode (void) const;
|
||||
void addEdge (EGNode* n1, EGNode* n2);
|
||||
|
||||
void connectAllNeighbors (const EgNode*);
|
||||
unsigned getNeighborsCost (const EGNode* n) const;
|
||||
|
||||
vector<EgNode*> nodes_;
|
||||
TinySet<EgNode*> unmarked_;
|
||||
unordered_map<VarId, EgNode*> varMap_;
|
||||
unsigned getWeightCost (const EGNode* n) const;
|
||||
|
||||
unsigned getFillCost (const EGNode* n) const;
|
||||
|
||||
unsigned getWeightedFillCost (const EGNode* n) const;
|
||||
|
||||
bool neighbors (EGNode* n1, EGNode* n2) const;
|
||||
|
||||
void addNode (EGNode*);
|
||||
|
||||
EGNode* getEGNode (VarId) const;
|
||||
|
||||
EGNode* getLowestCostNode() const;
|
||||
|
||||
void connectAllNeighbors (const EGNode*);
|
||||
|
||||
std::vector<EGNode*> nodes_;
|
||||
EGNeighs unmarked_;
|
||||
std::unordered_map<VarId, EGNode*> varMap_;
|
||||
|
||||
static ElimHeuristic elimHeuristic_;
|
||||
|
||||
DISALLOW_COPY_AND_ASSIGN (ElimGraph);
|
||||
};
|
||||
|
||||
#endif // HORUS_ELIMGRAPH_H
|
||||
|
||||
|
||||
/* Profiling shows that we should inline the following functions */
|
||||
|
||||
|
||||
|
||||
inline void
|
||||
ElimGraph::addEdge (EGNode* n1, EGNode* n2)
|
||||
{
|
||||
assert (n1 != n2);
|
||||
n1->addNeighbor (n2);
|
||||
n2->addNeighbor (n1);
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline unsigned
|
||||
ElimGraph::getNeighborsCost (const EGNode* n) const
|
||||
{
|
||||
return n->neighbors().size();
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline unsigned
|
||||
ElimGraph::getWeightCost (const EGNode* n) const
|
||||
{
|
||||
unsigned cost = 1;
|
||||
const EGNeighs& neighs = n->neighbors();
|
||||
for (size_t i = 0; i < neighs.size(); i++) {
|
||||
cost *= neighs[i]->range();
|
||||
}
|
||||
return cost;
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline unsigned
|
||||
ElimGraph::getFillCost (const EGNode* n) const
|
||||
{
|
||||
unsigned cost = 0;
|
||||
const EGNeighs& neighs = n->neighbors();
|
||||
if (neighs.size() > 0) {
|
||||
for (size_t i = 0; i < neighs.size() - 1; i++) {
|
||||
for (size_t j = i + 1; j < neighs.size(); j++) {
|
||||
if ( ! neighbors (neighs[i], neighs[j])) {
|
||||
cost ++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return cost;
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline unsigned
|
||||
ElimGraph::getWeightedFillCost (const EGNode* n) const
|
||||
{
|
||||
unsigned cost = 0;
|
||||
const EGNeighs& neighs = n->neighbors();
|
||||
if (neighs.size() > 0) {
|
||||
for (size_t i = 0; i < neighs.size() - 1; i++) {
|
||||
for (size_t j = i + 1; j < neighs.size(); j++) {
|
||||
if ( ! neighbors (neighs[i], neighs[j])) {
|
||||
cost += neighs[i]->range() * neighs[j]->range();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return cost;
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline bool
|
||||
ElimGraph::neighbors (EGNode* n1, EGNode* n2) const
|
||||
{
|
||||
return n1->isNeighbor (n2);
|
||||
}
|
||||
|
||||
} // namespace Horus
|
||||
|
||||
#endif // YAP_PACKAGES_CLPBN_HORUS_ELIMGRAPH_H_
|
||||
|
||||
|
@ -1,21 +1,15 @@
|
||||
#include <cstdlib>
|
||||
#include <cassert>
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "Factor.h"
|
||||
#include "Indexer.h"
|
||||
#include "Var.h"
|
||||
|
||||
|
||||
Factor::Factor (const Factor& g)
|
||||
{
|
||||
clone (g);
|
||||
}
|
||||
|
||||
|
||||
namespace Horus {
|
||||
|
||||
Factor::Factor (
|
||||
const VarIds& vids,
|
||||
@ -77,7 +71,7 @@ Factor::sumOutAllExcept (VarId vid)
|
||||
void
|
||||
Factor::sumOutAllExcept (const VarIds& vids)
|
||||
{
|
||||
vector<bool> mask (args_.size(), false);
|
||||
std::vector<bool> mask (args_.size(), false);
|
||||
for (unsigned i = 0; i < vids.size(); i++) {
|
||||
assert (indexOf (vids[i]) != args_.size());
|
||||
mask[indexOf (vids[i])] = true;
|
||||
@ -91,28 +85,30 @@ void
|
||||
Factor::sumOutAllExceptIndex (size_t idx)
|
||||
{
|
||||
assert (idx < args_.size());
|
||||
vector<bool> mask (args_.size(), false);
|
||||
std::vector<bool> mask (args_.size(), false);
|
||||
mask[idx] = true;
|
||||
sumOutArgs (mask);
|
||||
}
|
||||
|
||||
|
||||
void
|
||||
Factor::multiply (Factor& g)
|
||||
|
||||
Factor&
|
||||
Factor::multiply (const Factor& g)
|
||||
{
|
||||
if (args_.empty()) {
|
||||
clone (g);
|
||||
operator= (g);
|
||||
} else {
|
||||
TFactor<VarId>::multiply (g);
|
||||
GenericFactor<VarId>::multiply (g);
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
|
||||
|
||||
string
|
||||
Factor::getLabel (void) const
|
||||
std::string
|
||||
Factor::getLabel() const
|
||||
{
|
||||
stringstream ss;
|
||||
std::stringstream ss;
|
||||
ss << "f(" ;
|
||||
for (size_t i = 0; i < args_.size(); i++) {
|
||||
if (i != 0) ss << "," ;
|
||||
@ -125,19 +121,19 @@ Factor::getLabel (void) const
|
||||
|
||||
|
||||
void
|
||||
Factor::print (void) const
|
||||
Factor::print() const
|
||||
{
|
||||
Vars vars;
|
||||
for (size_t i = 0; i < args_.size(); i++) {
|
||||
vars.push_back (new Var (args_[i], ranges_[i]));
|
||||
}
|
||||
vector<string> jointStrings = Util::getStateLines (vars);
|
||||
std::vector<std::string> jointStrings = Util::getStateLines (vars);
|
||||
for (size_t i = 0; i < params_.size(); i++) {
|
||||
// cout << "[" << distId_ << "] " ;
|
||||
cout << "f(" << jointStrings[i] << ")" ;
|
||||
cout << " = " << params_[i] << endl;
|
||||
std::cout << "f(" << jointStrings[i] << ")" ;
|
||||
std::cout << " = " << params_[i] << std::endl;
|
||||
}
|
||||
cout << endl;
|
||||
std::cout << std::endl;
|
||||
for (size_t i = 0; i < vars.size(); i++) {
|
||||
delete vars[i];
|
||||
}
|
||||
@ -146,8 +142,9 @@ Factor::print (void) const
|
||||
|
||||
|
||||
void
|
||||
Factor::sumOutFirstVariable (void)
|
||||
Factor::sumOutFirstVariable()
|
||||
{
|
||||
assert (ranges_.front() == 2);
|
||||
size_t sep = params_.size() / 2;
|
||||
if (Globals::logDomain) {
|
||||
std::transform (
|
||||
@ -169,19 +166,21 @@ Factor::sumOutFirstVariable (void)
|
||||
|
||||
|
||||
void
|
||||
Factor::sumOutLastVariable (void)
|
||||
Factor::sumOutLastVariable()
|
||||
{
|
||||
assert (ranges_.back() == 2);
|
||||
Params::iterator first1 = params_.begin();
|
||||
Params::iterator first2 = params_.begin();
|
||||
Params::iterator last = params_.end();
|
||||
if (Globals::logDomain) {
|
||||
while (first2 != last) {
|
||||
// the arguments can be swaped, but that is ok
|
||||
*first1++ = Util::logSum (*first2++, *first2++);
|
||||
double tmp = *first2++;
|
||||
*first1++ = Util::logSum (tmp, *first2++);
|
||||
}
|
||||
} else {
|
||||
while (first2 != last) {
|
||||
*first1++ = (*first2++) + (*first2++);
|
||||
*first1 = *first2++;
|
||||
*first1++ += *first2++;
|
||||
}
|
||||
}
|
||||
params_.resize (params_.size() / 2);
|
||||
@ -192,7 +191,7 @@ Factor::sumOutLastVariable (void)
|
||||
|
||||
|
||||
void
|
||||
Factor::sumOutArgs (const vector<bool>& mask)
|
||||
Factor::sumOutArgs (const std::vector<bool>& mask)
|
||||
{
|
||||
assert (mask.size() == args_.size());
|
||||
size_t new_size = 1;
|
||||
@ -224,14 +223,5 @@ Factor::sumOutArgs (const vector<bool>& mask)
|
||||
params_ = newps;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Factor::clone (const Factor& g)
|
||||
{
|
||||
args_ = g.arguments();
|
||||
ranges_ = g.ranges();
|
||||
params_ = g.params();
|
||||
distId_ = g.distId();
|
||||
}
|
||||
} // namespace Horus
|
||||
|
||||
|
@ -1,262 +1,20 @@
|
||||
#ifndef HORUS_FACTOR_H
|
||||
#define HORUS_FACTOR_H
|
||||
#ifndef YAP_PACKAGES_CLPBN_HORUS_FACTOR_H_
|
||||
#define YAP_PACKAGES_CLPBN_HORUS_FACTOR_H_
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
#include "Indexer.h"
|
||||
#include "GenericFactor.h"
|
||||
#include "Util.h"
|
||||
|
||||
|
||||
using namespace std;
|
||||
namespace Horus {
|
||||
|
||||
|
||||
template <typename T>
|
||||
class TFactor
|
||||
{
|
||||
class Factor : public GenericFactor<VarId> {
|
||||
public:
|
||||
const vector<T>& arguments (void) const { return args_; }
|
||||
|
||||
vector<T>& arguments (void) { return args_; }
|
||||
|
||||
const Ranges& ranges (void) const { return ranges_; }
|
||||
|
||||
const Params& params (void) const { return params_; }
|
||||
|
||||
Params& params (void) { return params_; }
|
||||
|
||||
size_t nrArguments (void) const { return args_.size(); }
|
||||
|
||||
size_t size (void) const { return params_.size(); }
|
||||
|
||||
unsigned distId (void) const { return distId_; }
|
||||
|
||||
void setDistId (unsigned id) { distId_ = id; }
|
||||
|
||||
void normalize (void) { LogAware::normalize (params_); }
|
||||
|
||||
void randomize (void)
|
||||
{
|
||||
for (size_t i = 0; i < params_.size(); ++i) {
|
||||
params_[i] = (double) std::rand() / RAND_MAX;
|
||||
}
|
||||
}
|
||||
|
||||
void setParams (const Params& newParams)
|
||||
{
|
||||
params_ = newParams;
|
||||
assert (params_.size() == Util::sizeExpected (ranges_));
|
||||
}
|
||||
|
||||
size_t indexOf (const T& t) const
|
||||
{
|
||||
return Util::indexOf (args_, t);
|
||||
}
|
||||
|
||||
const T& argument (size_t idx) const
|
||||
{
|
||||
assert (idx < args_.size());
|
||||
return args_[idx];
|
||||
}
|
||||
|
||||
T& argument (size_t idx)
|
||||
{
|
||||
assert (idx < args_.size());
|
||||
return args_[idx];
|
||||
}
|
||||
|
||||
unsigned range (size_t idx) const
|
||||
{
|
||||
assert (idx < ranges_.size());
|
||||
return ranges_[idx];
|
||||
}
|
||||
|
||||
void multiply (TFactor<T>& g)
|
||||
{
|
||||
if (args_ == g.arguments()) {
|
||||
// optimization
|
||||
Globals::logDomain
|
||||
? params_ += g.params()
|
||||
: params_ *= g.params();
|
||||
return;
|
||||
}
|
||||
unsigned range_prod = 1;
|
||||
bool share_arguments = false;
|
||||
const vector<T>& g_args = g.arguments();
|
||||
const Ranges& g_ranges = g.ranges();
|
||||
const Params& g_params = g.params();
|
||||
for (size_t i = 0; i < g_args.size(); i++) {
|
||||
size_t idx = indexOf (g_args[i]);
|
||||
if (idx == args_.size()) {
|
||||
range_prod *= g_ranges[i];
|
||||
args_.push_back (g_args[i]);
|
||||
ranges_.push_back (g_ranges[i]);
|
||||
} else {
|
||||
share_arguments = true;
|
||||
}
|
||||
}
|
||||
if (share_arguments == false) {
|
||||
// optimization
|
||||
cartesianProduct (g_params.begin(), g_params.end());
|
||||
} else {
|
||||
extend (range_prod);
|
||||
Params::iterator it = params_.begin();
|
||||
MapIndexer indexer (args_, ranges_, g_args, g_ranges);
|
||||
if (Globals::logDomain) {
|
||||
for (; indexer.valid(); ++it, ++indexer) {
|
||||
*it += g_params[indexer];
|
||||
}
|
||||
} else {
|
||||
for (; indexer.valid(); ++it, ++indexer) {
|
||||
*it *= g_params[indexer];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void sumOutIndex (size_t idx)
|
||||
{
|
||||
assert (idx < args_.size());
|
||||
assert (args_.size() > 1);
|
||||
size_t new_size = params_.size() / ranges_[idx];
|
||||
Params newps (new_size, LogAware::addIdenty());
|
||||
Params::const_iterator first = params_.begin();
|
||||
Params::const_iterator last = params_.end();
|
||||
MapIndexer indexer (ranges_, idx);
|
||||
if (Globals::logDomain) {
|
||||
for (; first != last; ++indexer) {
|
||||
newps[indexer] = Util::logSum (newps[indexer], *first++);
|
||||
}
|
||||
} else {
|
||||
for (; first != last; ++indexer) {
|
||||
newps[indexer] += *first++;
|
||||
}
|
||||
}
|
||||
params_ = newps;
|
||||
args_.erase (args_.begin() + idx);
|
||||
ranges_.erase (ranges_.begin() + idx);
|
||||
}
|
||||
|
||||
void absorveEvidence (const T& arg, unsigned obsIdx)
|
||||
{
|
||||
size_t idx = indexOf (arg);
|
||||
assert (idx != args_.size());
|
||||
assert (obsIdx < ranges_[idx]);
|
||||
Params newps;
|
||||
newps.reserve (params_.size() / ranges_[idx]);
|
||||
Indexer indexer (ranges_);
|
||||
for (unsigned i = 0; i < obsIdx; ++i) {
|
||||
indexer.incrementDimension (idx);
|
||||
}
|
||||
while (indexer.valid()) {
|
||||
newps.push_back (params_[indexer]);
|
||||
indexer.incrementExceptDimension (idx);
|
||||
}
|
||||
params_ = newps;
|
||||
args_.erase (args_.begin() + idx);
|
||||
ranges_.erase (ranges_.begin() + idx);
|
||||
}
|
||||
|
||||
void reorderArguments (const vector<T> new_args)
|
||||
{
|
||||
assert (new_args.size() == args_.size());
|
||||
if (new_args == args_) {
|
||||
return; // already on the desired order
|
||||
}
|
||||
Ranges new_ranges;
|
||||
for (size_t i = 0; i < new_args.size(); i++) {
|
||||
size_t idx = indexOf (new_args[i]);
|
||||
assert (idx != args_.size());
|
||||
new_ranges.push_back (ranges_[idx]);
|
||||
}
|
||||
Params newps;
|
||||
newps.reserve (params_.size());
|
||||
MapIndexer indexer (new_args, new_ranges, args_, ranges_);
|
||||
for (; indexer.valid(); ++indexer) {
|
||||
newps.push_back (params_[indexer]);
|
||||
}
|
||||
params_ = newps;
|
||||
args_ = new_args;
|
||||
ranges_ = new_ranges;
|
||||
}
|
||||
|
||||
bool contains (const T& arg) const
|
||||
{
|
||||
return Util::contains (args_, arg);
|
||||
}
|
||||
|
||||
bool contains (const vector<T>& args) const
|
||||
{
|
||||
for (size_t i = 0; i < args.size(); i++) {
|
||||
if (contains (args[i]) == false) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
double& operator[] (size_t idx)
|
||||
{
|
||||
assert (idx < params_.size());
|
||||
return params_[idx];
|
||||
}
|
||||
|
||||
|
||||
protected:
|
||||
vector<T> args_;
|
||||
Ranges ranges_;
|
||||
Params params_;
|
||||
unsigned distId_;
|
||||
|
||||
private:
|
||||
void extend (unsigned range_prod)
|
||||
{
|
||||
Params backup = params_;
|
||||
params_.clear();
|
||||
params_.reserve (backup.size() * range_prod);
|
||||
Params::const_iterator first = backup.begin();
|
||||
Params::const_iterator last = backup.end();
|
||||
for (; first != last; ++first) {
|
||||
for (unsigned reps = 0; reps < range_prod; ++reps) {
|
||||
params_.push_back (*first);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void cartesianProduct (
|
||||
Params::const_iterator first2,
|
||||
Params::const_iterator last2)
|
||||
{
|
||||
Params backup = params_;
|
||||
params_.clear();
|
||||
params_.reserve (params_.size() * (last2 - first2));
|
||||
Params::const_iterator first1 = backup.begin();
|
||||
Params::const_iterator last1 = backup.end();
|
||||
Params::const_iterator tmp;
|
||||
if (Globals::logDomain) {
|
||||
for (; first1 != last1; ++first1) {
|
||||
for (tmp = first2; tmp != last2; ++tmp) {
|
||||
params_.push_back ((*first1) + (*tmp));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (; first1 != last1; ++first1) {
|
||||
for (tmp = first2; tmp != last2; ++tmp) {
|
||||
params_.push_back ((*first1) * (*tmp));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
|
||||
|
||||
class Factor : public TFactor<VarId>
|
||||
{
|
||||
public:
|
||||
Factor (void) { }
|
||||
|
||||
Factor (const Factor&);
|
||||
Factor() { }
|
||||
|
||||
Factor (const VarIds&, const Ranges&, const Params&,
|
||||
unsigned = Util::maxUnsigned());
|
||||
@ -272,23 +30,21 @@ class Factor : public TFactor<VarId>
|
||||
|
||||
void sumOutAllExceptIndex (size_t idx);
|
||||
|
||||
void multiply (Factor&);
|
||||
Factor& multiply (const Factor&);
|
||||
|
||||
string getLabel (void) const;
|
||||
std::string getLabel() const;
|
||||
|
||||
void print (void) const;
|
||||
void print() const;
|
||||
|
||||
private:
|
||||
void sumOutFirstVariable (void);
|
||||
void sumOutFirstVariable();
|
||||
|
||||
void sumOutLastVariable (void);
|
||||
void sumOutLastVariable();
|
||||
|
||||
void sumOutArgs (const vector<bool>& mask);
|
||||
|
||||
void clone (const Factor& f);
|
||||
|
||||
DISALLOW_ASSIGN (Factor);
|
||||
void sumOutArgs (const std::vector<bool>& mask);
|
||||
};
|
||||
|
||||
#endif // HORUS_FACTOR_H
|
||||
} // namespace Horus
|
||||
|
||||
#endif // YAP_PACKAGES_CLPBN_HORUS_FACTOR_H_
|
||||
|
||||
|
@ -1,17 +1,15 @@
|
||||
#include <cassert>
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include <set>
|
||||
#include <vector>
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <fstream>
|
||||
|
||||
#include "FactorGraph.h"
|
||||
#include "BayesBall.h"
|
||||
#include "Util.h"
|
||||
|
||||
|
||||
namespace Horus {
|
||||
|
||||
bool FactorGraph::exportLd_ = false;
|
||||
bool FactorGraph::exportUai_ = false;
|
||||
bool FactorGraph::exportGv_ = false;
|
||||
@ -20,25 +18,12 @@ bool FactorGraph::printFg_ = false;
|
||||
|
||||
FactorGraph::FactorGraph (const FactorGraph& fg)
|
||||
{
|
||||
const VarNodes& varNodes = fg.varNodes();
|
||||
for (size_t i = 0; i < varNodes.size(); i++) {
|
||||
addVarNode (new VarNode (varNodes[i]));
|
||||
}
|
||||
const FacNodes& facNodes = fg.facNodes();
|
||||
for (size_t i = 0; i < facNodes.size(); i++) {
|
||||
FacNode* facNode = new FacNode (facNodes[i]->factor());
|
||||
addFacNode (facNode);
|
||||
const VarNodes& neighs = facNodes[i]->neighbors();
|
||||
for (size_t j = 0; j < neighs.size(); j++) {
|
||||
addEdge (varNodes_[neighs[j]->getIndex()], facNode);
|
||||
}
|
||||
}
|
||||
bayesFactors_ = fg.bayesianFactors();
|
||||
clone (fg);
|
||||
}
|
||||
|
||||
|
||||
|
||||
FactorGraph::~FactorGraph (void)
|
||||
FactorGraph::~FactorGraph()
|
||||
{
|
||||
for (size_t i = 0; i < varNodes_.size(); i++) {
|
||||
delete varNodes_[i];
|
||||
@ -50,152 +35,6 @@ FactorGraph::~FactorGraph (void)
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::readFromUaiFormat (const char* fileName)
|
||||
{
|
||||
std::ifstream is (fileName);
|
||||
if (!is.is_open()) {
|
||||
cerr << "Error: couldn't open file '" << fileName << "'." ;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
ignoreLines (is);
|
||||
string line;
|
||||
getline (is, line);
|
||||
if (line == "BAYES") {
|
||||
bayesFactors_ = true;
|
||||
} else if (line == "MARKOV") {
|
||||
bayesFactors_ = false;
|
||||
} else {
|
||||
cerr << "Error: the type of network is missing." << endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
// read the number of vars
|
||||
ignoreLines (is);
|
||||
unsigned nrVars;
|
||||
is >> nrVars;
|
||||
// read the range of each var
|
||||
ignoreLines (is);
|
||||
Ranges ranges (nrVars);
|
||||
for (unsigned i = 0; i < nrVars; i++) {
|
||||
is >> ranges[i];
|
||||
}
|
||||
unsigned nrFactors;
|
||||
unsigned nrArgs;
|
||||
unsigned vid;
|
||||
is >> nrFactors;
|
||||
vector<VarIds> allVarIds;
|
||||
vector<Ranges> allRanges;
|
||||
for (unsigned i = 0; i < nrFactors; i++) {
|
||||
ignoreLines (is);
|
||||
is >> nrArgs;
|
||||
allVarIds.push_back ({ });
|
||||
allRanges.push_back ({ });
|
||||
for (unsigned j = 0; j < nrArgs; j++) {
|
||||
is >> vid;
|
||||
if (vid >= ranges.size()) {
|
||||
cerr << "Error: invalid variable identifier `" << vid << "'. " ;
|
||||
cerr << "Identifiers must be between 0 and " << ranges.size() - 1 ;
|
||||
cerr << "." << endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
allVarIds.back().push_back (vid);
|
||||
allRanges.back().push_back (ranges[vid]);
|
||||
}
|
||||
}
|
||||
// read the parameters
|
||||
unsigned nrParams;
|
||||
for (unsigned i = 0; i < nrFactors; i++) {
|
||||
ignoreLines (is);
|
||||
is >> nrParams;
|
||||
if (nrParams != Util::sizeExpected (allRanges[i])) {
|
||||
cerr << "Error: invalid number of parameters for factor nº " << i ;
|
||||
cerr << ", " << Util::sizeExpected (allRanges[i]);
|
||||
cerr << " expected, " << nrParams << " given." << endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
Params params (nrParams);
|
||||
for (unsigned j = 0; j < nrParams; j++) {
|
||||
is >> params[j];
|
||||
}
|
||||
if (Globals::logDomain) {
|
||||
Util::log (params);
|
||||
}
|
||||
Factor f (allVarIds[i], allRanges[i], params);
|
||||
if (bayesFactors_ && allVarIds[i].size() > 1) {
|
||||
// In this format the child is the last variable,
|
||||
// move it to be the first
|
||||
std::swap (allVarIds[i].front(), allVarIds[i].back());
|
||||
f.reorderArguments (allVarIds[i]);
|
||||
}
|
||||
addFactor (f);
|
||||
}
|
||||
is.close();
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::readFromLibDaiFormat (const char* fileName)
|
||||
{
|
||||
std::ifstream is (fileName);
|
||||
if (!is.is_open()) {
|
||||
cerr << "Error: couldn't open file '" << fileName << "'." ;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
ignoreLines (is);
|
||||
unsigned nrFactors;
|
||||
unsigned nrArgs;
|
||||
VarId vid;
|
||||
is >> nrFactors;
|
||||
for (unsigned i = 0; i < nrFactors; i++) {
|
||||
ignoreLines (is);
|
||||
// read the factor arguments
|
||||
is >> nrArgs;
|
||||
VarIds vids;
|
||||
for (unsigned j = 0; j < nrArgs; j++) {
|
||||
ignoreLines (is);
|
||||
is >> vid;
|
||||
vids.push_back (vid);
|
||||
}
|
||||
// read ranges
|
||||
Ranges ranges (nrArgs);
|
||||
for (unsigned j = 0; j < nrArgs; j++) {
|
||||
ignoreLines (is);
|
||||
is >> ranges[j];
|
||||
VarNode* var = getVarNode (vids[j]);
|
||||
if (var && ranges[j] != var->range()) {
|
||||
cerr << "Error: variable `" << vids[j] << "' appears in two or " ;
|
||||
cerr << "more factors with a different range." << endl;
|
||||
}
|
||||
}
|
||||
// read parameters
|
||||
ignoreLines (is);
|
||||
unsigned nNonzeros;
|
||||
is >> nNonzeros;
|
||||
Params params (Util::sizeExpected (ranges), 0);
|
||||
for (unsigned j = 0; j < nNonzeros; j++) {
|
||||
ignoreLines (is);
|
||||
unsigned index;
|
||||
is >> index;
|
||||
ignoreLines (is);
|
||||
double val;
|
||||
is >> val;
|
||||
params[index] = val;
|
||||
}
|
||||
if (Globals::logDomain) {
|
||||
Util::log (params);
|
||||
}
|
||||
std::reverse (vids.begin(), vids.end());
|
||||
Factor f (vids, ranges, params);
|
||||
std::reverse (vids.begin(), vids.end());
|
||||
f.reorderArguments (vids);
|
||||
addFactor (f);
|
||||
}
|
||||
is.close();
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::addFactor (const Factor& factor)
|
||||
{
|
||||
@ -221,7 +60,7 @@ FactorGraph::addVarNode (VarNode* vn)
|
||||
{
|
||||
varNodes_.push_back (vn);
|
||||
vn->setIndex (varNodes_.size() - 1);
|
||||
varMap_.insert (make_pair (vn->varId(), vn));
|
||||
varMap_.insert (std::make_pair (vn->varId(), vn));
|
||||
}
|
||||
|
||||
|
||||
@ -245,7 +84,7 @@ FactorGraph::addEdge (VarNode* vn, FacNode* fn)
|
||||
|
||||
|
||||
bool
|
||||
FactorGraph::isTree (void) const
|
||||
FactorGraph::isTree() const
|
||||
{
|
||||
return !containsCycle();
|
||||
}
|
||||
@ -253,7 +92,7 @@ FactorGraph::isTree (void) const
|
||||
|
||||
|
||||
BayesBallGraph&
|
||||
FactorGraph::getStructure (void)
|
||||
FactorGraph::getStructure()
|
||||
{
|
||||
assert (bayesFactors_);
|
||||
if (structure_.empty()) {
|
||||
@ -273,8 +112,10 @@ FactorGraph::getStructure (void)
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::print (void) const
|
||||
FactorGraph::print() const
|
||||
{
|
||||
using std::cout;
|
||||
using std::endl;
|
||||
for (size_t i = 0; i < varNodes_.size(); i++) {
|
||||
cout << "var id = " << varNodes_[i]->varId() << endl;
|
||||
cout << "label = " << varNodes_[i]->label() << endl;
|
||||
@ -296,28 +137,29 @@ FactorGraph::print (void) const
|
||||
void
|
||||
FactorGraph::exportToLibDai (const char* fileName) const
|
||||
{
|
||||
ofstream out (fileName);
|
||||
std::ofstream out (fileName);
|
||||
if (!out.is_open()) {
|
||||
cerr << "Error: couldn't open file '" << fileName << "'." ;
|
||||
std::cerr << "Error: couldn't open file '" << fileName << "'." ;
|
||||
std::cerr << std::endl;
|
||||
return;
|
||||
}
|
||||
out << facNodes_.size() << endl << endl;
|
||||
out << facNodes_.size() << std::endl << std::endl;
|
||||
for (size_t i = 0; i < facNodes_.size(); i++) {
|
||||
Factor f (facNodes_[i]->factor());
|
||||
out << f.nrArguments() << endl;
|
||||
out << Util::elementsToString (f.arguments()) << endl;
|
||||
out << Util::elementsToString (f.ranges()) << endl;
|
||||
out << f.nrArguments() << std::endl;
|
||||
out << Util::elementsToString (f.arguments()) << std::endl;
|
||||
out << Util::elementsToString (f.ranges()) << std::endl;
|
||||
VarIds args = f.arguments();
|
||||
std::reverse (args.begin(), args.end());
|
||||
f.reorderArguments (args);
|
||||
if (Globals::logDomain) {
|
||||
Util::exp (f.params());
|
||||
}
|
||||
out << f.size() << endl;
|
||||
out << f.size() << std::endl;
|
||||
for (size_t j = 0; j < f.size(); j++) {
|
||||
out << j << " " << f[j] << endl;
|
||||
out << j << " " << f[j] << std::endl;
|
||||
}
|
||||
out << endl;
|
||||
out << std::endl;
|
||||
}
|
||||
out.close();
|
||||
}
|
||||
@ -327,28 +169,30 @@ FactorGraph::exportToLibDai (const char* fileName) const
|
||||
void
|
||||
FactorGraph::exportToUai (const char* fileName) const
|
||||
{
|
||||
ofstream out (fileName);
|
||||
std::ofstream out (fileName);
|
||||
if (!out.is_open()) {
|
||||
cerr << "Error: couldn't open file '" << fileName << "'." ;
|
||||
std::cerr << "Error: couldn't open file '" << fileName << "'." ;
|
||||
std::cerr << std::endl;
|
||||
return;
|
||||
}
|
||||
out << (bayesFactors_ ? "BAYES" : "MARKOV") ;
|
||||
out << endl << endl;
|
||||
out << varNodes_.size() << endl;
|
||||
out << std::endl << std::endl;
|
||||
out << varNodes_.size() << std::endl;
|
||||
VarNodes sortedVns = varNodes_;
|
||||
std::sort (sortedVns.begin(), sortedVns.end(), sortByVarId());
|
||||
for (size_t i = 0; i < sortedVns.size(); i++) {
|
||||
out << ((i != 0) ? " " : "") << sortedVns[i]->range();
|
||||
}
|
||||
out << endl << facNodes_.size() << endl;
|
||||
out << std::endl << facNodes_.size() << std::endl;
|
||||
for (size_t i = 0; i < facNodes_.size(); i++) {
|
||||
VarIds args = facNodes_[i]->factor().arguments();
|
||||
if (bayesFactors_) {
|
||||
std::swap (args.front(), args.back());
|
||||
}
|
||||
out << args.size() << " " << Util::elementsToString (args) << endl;
|
||||
out << args.size() << " " << Util::elementsToString (args);
|
||||
out << std::endl;
|
||||
}
|
||||
out << endl;
|
||||
out << std::endl;
|
||||
for (size_t i = 0; i < facNodes_.size(); i++) {
|
||||
Factor f = facNodes_[i]->factor();
|
||||
if (bayesFactors_) {
|
||||
@ -360,8 +204,9 @@ FactorGraph::exportToUai (const char* fileName) const
|
||||
if (Globals::logDomain) {
|
||||
Util::exp (params);
|
||||
}
|
||||
out << params.size() << endl << " " ;
|
||||
out << Util::elementsToString (params) << endl << endl;
|
||||
out << params.size() << std::endl << " " ;
|
||||
out << Util::elementsToString (params);
|
||||
out << std::endl << std::endl;
|
||||
}
|
||||
out.close();
|
||||
}
|
||||
@ -371,53 +216,239 @@ FactorGraph::exportToUai (const char* fileName) const
|
||||
void
|
||||
FactorGraph::exportToGraphViz (const char* fileName) const
|
||||
{
|
||||
ofstream out (fileName);
|
||||
std::ofstream out (fileName);
|
||||
if (!out.is_open()) {
|
||||
cerr << "Error: couldn't open file '" << fileName << "'." ;
|
||||
std::cerr << "Error: couldn't open file '" << fileName << "'." ;
|
||||
std::cerr << std::endl;
|
||||
return;
|
||||
}
|
||||
out << "graph \"" << fileName << "\" {" << endl;
|
||||
out << "graph \"" << fileName << "\" {" << std::endl;
|
||||
for (size_t i = 0; i < varNodes_.size(); i++) {
|
||||
if (varNodes_[i]->hasEvidence()) {
|
||||
out << '"' << varNodes_[i]->label() << '"' ;
|
||||
out << " [style=filled, fillcolor=yellow]" << endl;
|
||||
out << " [style=filled, fillcolor=yellow]" << std::endl;
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i < facNodes_.size(); i++) {
|
||||
out << '"' << facNodes_[i]->getLabel() << '"' ;
|
||||
out << " [label=\"" << facNodes_[i]->getLabel();
|
||||
out << "\"" << ", shape=box]" << endl;
|
||||
out << "\"" << ", shape=box]" << std::endl;
|
||||
}
|
||||
for (size_t i = 0; i < facNodes_.size(); i++) {
|
||||
const VarNodes& myVars = facNodes_[i]->neighbors();
|
||||
for (size_t j = 0; j < myVars.size(); j++) {
|
||||
out << '"' << facNodes_[i]->getLabel() << '"' ;
|
||||
out << " -- " ;
|
||||
out << '"' << myVars[j]->label() << '"' << endl;
|
||||
out << '"' << myVars[j]->label() << '"' << std::endl;
|
||||
}
|
||||
}
|
||||
out << "}" << endl;
|
||||
out << "}" << std::endl;
|
||||
out.close();
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::ignoreLines (std::ifstream& is) const
|
||||
FactorGraph&
|
||||
FactorGraph::operator= (const FactorGraph& fg)
|
||||
{
|
||||
string ignoreStr;
|
||||
while (is.peek() == '#' || is.peek() == '\n') {
|
||||
getline (is, ignoreStr);
|
||||
if (this != &fg) {
|
||||
for (size_t i = 0; i < varNodes_.size(); i++) {
|
||||
delete varNodes_[i];
|
||||
}
|
||||
varNodes_.clear();
|
||||
for (size_t i = 0; i < facNodes_.size(); i++) {
|
||||
delete facNodes_[i];
|
||||
}
|
||||
facNodes_.clear();
|
||||
varMap_.clear();
|
||||
clone (fg);
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
|
||||
|
||||
FactorGraph
|
||||
FactorGraph::readFromUaiFormat (const char* fileName)
|
||||
{
|
||||
std::ifstream is (fileName);
|
||||
if (!is.is_open()) {
|
||||
std::cerr << "Error: couldn't open file '" << fileName << "'." ;
|
||||
std::cerr << std::endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
FactorGraph fg;
|
||||
ignoreLines (is);
|
||||
std::string line;
|
||||
getline (is, line);
|
||||
if (line == "BAYES") {
|
||||
fg.bayesFactors_ = true;
|
||||
} else if (line == "MARKOV") {
|
||||
fg.bayesFactors_ = false;
|
||||
} else {
|
||||
std::cerr << "Error: the type of network is missing." << std::endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
// read the number of vars
|
||||
ignoreLines (is);
|
||||
unsigned nrVars;
|
||||
is >> nrVars;
|
||||
// read the range of each var
|
||||
ignoreLines (is);
|
||||
Ranges ranges (nrVars);
|
||||
for (unsigned i = 0; i < nrVars; i++) {
|
||||
is >> ranges[i];
|
||||
}
|
||||
unsigned nrFactors;
|
||||
unsigned nrArgs;
|
||||
unsigned vid;
|
||||
is >> nrFactors;
|
||||
std::vector<VarIds> allVarIds;
|
||||
std::vector<Ranges> allRanges;
|
||||
for (unsigned i = 0; i < nrFactors; i++) {
|
||||
ignoreLines (is);
|
||||
is >> nrArgs;
|
||||
allVarIds.push_back ({ });
|
||||
allRanges.push_back ({ });
|
||||
for (unsigned j = 0; j < nrArgs; j++) {
|
||||
is >> vid;
|
||||
if (vid >= ranges.size()) {
|
||||
std::cerr << "Error: invalid variable identifier `" << vid << "'" ;
|
||||
std::cerr << ". Identifiers must be between 0 and " ;
|
||||
std::cerr << ranges.size() - 1 << "." << std::endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
allVarIds.back().push_back (vid);
|
||||
allRanges.back().push_back (ranges[vid]);
|
||||
}
|
||||
}
|
||||
// read the parameters
|
||||
unsigned nrParams;
|
||||
for (unsigned i = 0; i < nrFactors; i++) {
|
||||
ignoreLines (is);
|
||||
is >> nrParams;
|
||||
if (nrParams != Util::sizeExpected (allRanges[i])) {
|
||||
std::cerr << "Error: invalid number of parameters for factor nº " ;
|
||||
std::cerr << i << ", " << Util::sizeExpected (allRanges[i]);
|
||||
std::cerr << " expected, " << nrParams << " given." << std::endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
Params params (nrParams);
|
||||
for (unsigned j = 0; j < nrParams; j++) {
|
||||
is >> params[j];
|
||||
}
|
||||
if (Globals::logDomain) {
|
||||
Util::log (params);
|
||||
}
|
||||
Factor f (allVarIds[i], allRanges[i], params);
|
||||
if (fg.bayesFactors_ && allVarIds[i].size() > 1) {
|
||||
// In this format the child is the last variable,
|
||||
// move it to be the first
|
||||
std::swap (allVarIds[i].front(), allVarIds[i].back());
|
||||
f.reorderArguments (allVarIds[i]);
|
||||
}
|
||||
fg.addFactor (f);
|
||||
}
|
||||
is.close();
|
||||
return fg;
|
||||
}
|
||||
|
||||
|
||||
|
||||
FactorGraph
|
||||
FactorGraph::readFromLibDaiFormat (const char* fileName)
|
||||
{
|
||||
std::ifstream is (fileName);
|
||||
if (!is.is_open()) {
|
||||
std::cerr << "Error: couldn't open file '" << fileName << "'." ;
|
||||
std::cerr << std::endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
FactorGraph fg;
|
||||
ignoreLines (is);
|
||||
unsigned nrFactors;
|
||||
unsigned nrArgs;
|
||||
VarId vid;
|
||||
is >> nrFactors;
|
||||
for (unsigned i = 0; i < nrFactors; i++) {
|
||||
ignoreLines (is);
|
||||
// read the factor arguments
|
||||
is >> nrArgs;
|
||||
VarIds vids;
|
||||
for (unsigned j = 0; j < nrArgs; j++) {
|
||||
ignoreLines (is);
|
||||
is >> vid;
|
||||
vids.push_back (vid);
|
||||
}
|
||||
// read ranges
|
||||
Ranges ranges (nrArgs);
|
||||
for (unsigned j = 0; j < nrArgs; j++) {
|
||||
ignoreLines (is);
|
||||
is >> ranges[j];
|
||||
VarNode* var = fg.getVarNode (vids[j]);
|
||||
if (var && ranges[j] != var->range()) {
|
||||
std::cerr << "Error: variable `" << vids[j] << "' appears" ;
|
||||
std::cerr << " in two or more factors with a different range." ;
|
||||
std::cerr << std::endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
}
|
||||
// read parameters
|
||||
ignoreLines (is);
|
||||
unsigned nNonzeros;
|
||||
is >> nNonzeros;
|
||||
Params params (Util::sizeExpected (ranges), 0);
|
||||
for (unsigned j = 0; j < nNonzeros; j++) {
|
||||
ignoreLines (is);
|
||||
unsigned index;
|
||||
is >> index;
|
||||
ignoreLines (is);
|
||||
double val;
|
||||
is >> val;
|
||||
params[index] = val;
|
||||
}
|
||||
if (Globals::logDomain) {
|
||||
Util::log (params);
|
||||
}
|
||||
std::reverse (vids.begin(), vids.end());
|
||||
std::reverse (ranges.begin(), ranges.end());
|
||||
Factor f (vids, ranges, params);
|
||||
std::reverse (vids.begin(), vids.end());
|
||||
f.reorderArguments (vids);
|
||||
fg.addFactor (f);
|
||||
}
|
||||
is.close();
|
||||
return fg;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::clone (const FactorGraph& fg)
|
||||
{
|
||||
const VarNodes& varNodes = fg.varNodes();
|
||||
for (size_t i = 0; i < varNodes.size(); i++) {
|
||||
addVarNode (new VarNode (varNodes[i]));
|
||||
}
|
||||
const FacNodes& facNodes = fg.facNodes();
|
||||
for (size_t i = 0; i < facNodes.size(); i++) {
|
||||
FacNode* facNode = new FacNode (facNodes[i]->factor());
|
||||
addFacNode (facNode);
|
||||
const VarNodes& neighs = facNodes[i]->neighbors();
|
||||
for (size_t j = 0; j < neighs.size(); j++) {
|
||||
addEdge (varNodes_[neighs[j]->getIndex()], facNode);
|
||||
}
|
||||
}
|
||||
bayesFactors_ = fg.bayesianFactors();
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
FactorGraph::containsCycle (void) const
|
||||
FactorGraph::containsCycle() const
|
||||
{
|
||||
vector<bool> visitedVars (varNodes_.size(), false);
|
||||
vector<bool> visitedFactors (facNodes_.size(), false);
|
||||
std::vector<bool> visitedVars (varNodes_.size(), false);
|
||||
std::vector<bool> visitedFactors (facNodes_.size(), false);
|
||||
for (size_t i = 0; i < varNodes_.size(); i++) {
|
||||
int v = varNodes_[i]->getIndex();
|
||||
if (!visitedVars[v]) {
|
||||
@ -435,8 +466,8 @@ bool
|
||||
FactorGraph::containsCycle (
|
||||
const VarNode* v,
|
||||
const FacNode* p,
|
||||
vector<bool>& visitedVars,
|
||||
vector<bool>& visitedFactors) const
|
||||
std::vector<bool>& visitedVars,
|
||||
std::vector<bool>& visitedFactors) const
|
||||
{
|
||||
visitedVars[v->getIndex()] = true;
|
||||
const FacNodes& adjacencies = v->neighbors();
|
||||
@ -460,8 +491,8 @@ bool
|
||||
FactorGraph::containsCycle (
|
||||
const FacNode* v,
|
||||
const VarNode* p,
|
||||
vector<bool>& visitedVars,
|
||||
vector<bool>& visitedFactors) const
|
||||
std::vector<bool>& visitedVars,
|
||||
std::vector<bool>& visitedFactors) const
|
||||
{
|
||||
visitedFactors[v->getIndex()] = true;
|
||||
const VarNodes& adjacencies = v->neighbors();
|
||||
@ -479,3 +510,16 @@ FactorGraph::containsCycle (
|
||||
return false; // no cycle detected in this component
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::ignoreLines (std::ifstream& is)
|
||||
{
|
||||
std::string ignoreStr;
|
||||
while (is.peek() == '#' || is.peek() == '\n') {
|
||||
getline (is, ignoreStr);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace Horus
|
||||
|
||||
|
@ -1,28 +1,32 @@
|
||||
#ifndef HORUS_FACTORGRAPH_H
|
||||
#define HORUS_FACTORGRAPH_H
|
||||
#ifndef YAP_PACKAGES_CLPBN_HORUS_FACTORGRAPH_H_
|
||||
#define YAP_PACKAGES_CLPBN_HORUS_FACTORGRAPH_H_
|
||||
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
#include <string>
|
||||
#include <fstream>
|
||||
|
||||
#include "Factor.h"
|
||||
#include "BayesBallGraph.h"
|
||||
#include "Horus.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace Horus {
|
||||
|
||||
class FacNode;
|
||||
|
||||
class VarNode : public Var
|
||||
{
|
||||
|
||||
class VarNode : public Var {
|
||||
public:
|
||||
VarNode (VarId varId, unsigned nrStates,
|
||||
int evidence = Constants::NO_EVIDENCE)
|
||||
int evidence = Constants::unobserved)
|
||||
: Var (varId, nrStates, evidence) { }
|
||||
|
||||
VarNode (const Var* v) : Var (v) { }
|
||||
|
||||
void addNeighbor (FacNode* fn) { neighs_.push_back (fn); }
|
||||
|
||||
const FacNodes& neighbors (void) const { return neighs_; }
|
||||
const FacNodes& neighbors() const { return neighs_; }
|
||||
|
||||
private:
|
||||
FacNodes neighs_;
|
||||
@ -32,24 +36,23 @@ class VarNode : public Var
|
||||
|
||||
|
||||
|
||||
class FacNode
|
||||
{
|
||||
class FacNode {
|
||||
public:
|
||||
FacNode (const Factor& f) : factor_(f), index_(-1) { }
|
||||
|
||||
const Factor& factor (void) const { return factor_; }
|
||||
const Factor& factor() const { return factor_; }
|
||||
|
||||
Factor& factor (void) { return factor_; }
|
||||
Factor& factor() { return factor_; }
|
||||
|
||||
void addNeighbor (VarNode* vn) { neighs_.push_back (vn); }
|
||||
|
||||
const VarNodes& neighbors (void) const { return neighs_; }
|
||||
const VarNodes& neighbors() const { return neighs_; }
|
||||
|
||||
size_t getIndex (void) const { return index_; }
|
||||
size_t getIndex() const { return index_; }
|
||||
|
||||
void setIndex (size_t index) { index_ = index; }
|
||||
|
||||
string getLabel (void) { return factor_.getLabel(); }
|
||||
std::string getLabel() { return factor_.getLabel(); }
|
||||
|
||||
private:
|
||||
VarNodes neighs_;
|
||||
@ -61,36 +64,27 @@ class FacNode
|
||||
|
||||
|
||||
|
||||
class FactorGraph
|
||||
{
|
||||
class FactorGraph {
|
||||
public:
|
||||
FactorGraph (void) : bayesFactors_(false) { }
|
||||
FactorGraph() : bayesFactors_(false) { }
|
||||
|
||||
FactorGraph (const FactorGraph&);
|
||||
|
||||
~FactorGraph (void);
|
||||
~FactorGraph();
|
||||
|
||||
const VarNodes& varNodes (void) const { return varNodes_; }
|
||||
const VarNodes& varNodes() const { return varNodes_; }
|
||||
|
||||
const FacNodes& facNodes (void) const { return facNodes_; }
|
||||
const FacNodes& facNodes() const { return facNodes_; }
|
||||
|
||||
void setFactorsAsBayesian (void) { bayesFactors_ = true; }
|
||||
void setFactorsAsBayesian() { bayesFactors_ = true; }
|
||||
|
||||
bool bayesianFactors (void) const { return bayesFactors_; }
|
||||
bool bayesianFactors() const { return bayesFactors_; }
|
||||
|
||||
size_t nrVarNodes (void) const { return varNodes_.size(); }
|
||||
size_t nrVarNodes() const { return varNodes_.size(); }
|
||||
|
||||
size_t nrFacNodes (void) const { return facNodes_.size(); }
|
||||
size_t nrFacNodes() const { return facNodes_.size(); }
|
||||
|
||||
VarNode* getVarNode (VarId vid) const
|
||||
{
|
||||
VarMap::const_iterator it = varMap_.find (vid);
|
||||
return it != varMap_.end() ? it->second : 0;
|
||||
}
|
||||
|
||||
void readFromUaiFormat (const char*);
|
||||
|
||||
void readFromLibDaiFormat (const char*);
|
||||
VarNode* getVarNode (VarId vid) const;
|
||||
|
||||
void addFactor (const Factor& factor);
|
||||
|
||||
@ -100,11 +94,11 @@ class FactorGraph
|
||||
|
||||
void addEdge (VarNode*, FacNode*);
|
||||
|
||||
bool isTree (void) const;
|
||||
bool isTree() const;
|
||||
|
||||
BayesBallGraph& getStructure (void);
|
||||
BayesBallGraph& getStructure();
|
||||
|
||||
void print (void) const;
|
||||
void print() const;
|
||||
|
||||
void exportToLibDai (const char*) const;
|
||||
|
||||
@ -112,67 +106,80 @@ class FactorGraph
|
||||
|
||||
void exportToGraphViz (const char*) const;
|
||||
|
||||
static bool exportToLibDai (void) { return exportLd_; }
|
||||
FactorGraph& operator= (const FactorGraph&);
|
||||
|
||||
static bool exportToUai (void) { return exportUai_; }
|
||||
static FactorGraph readFromUaiFormat (const char*);
|
||||
|
||||
static bool exportGraphViz (void) { return exportGv_; }
|
||||
static FactorGraph readFromLibDaiFormat (const char*);
|
||||
|
||||
static bool printFactorGraph (void) { return printFg_; }
|
||||
static bool exportToLibDai() { return exportLd_; }
|
||||
|
||||
static void enableExportToLibDai (void) { exportLd_ = true; }
|
||||
static bool exportToUai() { return exportUai_; }
|
||||
|
||||
static void disableExportToLibDai (void) { exportLd_ = false; }
|
||||
static bool exportGraphViz() { return exportGv_; }
|
||||
|
||||
static void enableExportToUai (void) { exportUai_ = true; }
|
||||
static bool printFactorGraph() { return printFg_; }
|
||||
|
||||
static void disableExportToUai (void) { exportUai_ = false; }
|
||||
static void enableExportToLibDai() { exportLd_ = true; }
|
||||
|
||||
static void enableExportToGraphViz (void) { exportGv_ = true; }
|
||||
static void disableExportToLibDai() { exportLd_ = false; }
|
||||
|
||||
static void disableExportToGraphViz (void) { exportGv_ = false; }
|
||||
static void enableExportToUai() { exportUai_ = true; }
|
||||
|
||||
static void enablePrintFactorGraph (void) { printFg_ = true; }
|
||||
static void disableExportToUai() { exportUai_ = false; }
|
||||
|
||||
static void disablePrintFactorGraph (void) { printFg_ = false; }
|
||||
static void enableExportToGraphViz() { exportGv_ = true; }
|
||||
|
||||
static void disableExportToGraphViz() { exportGv_ = false; }
|
||||
|
||||
static void enablePrintFactorGraph() { printFg_ = true; }
|
||||
|
||||
static void disablePrintFactorGraph() { printFg_ = false; }
|
||||
|
||||
private:
|
||||
void ignoreLines (std::ifstream&) const;
|
||||
typedef std::unordered_map<unsigned, VarNode*> VarMap;
|
||||
|
||||
bool containsCycle (void) const;
|
||||
void clone (const FactorGraph& fg);
|
||||
|
||||
bool containsCycle() const;
|
||||
|
||||
bool containsCycle (const VarNode*, const FacNode*,
|
||||
vector<bool>&, vector<bool>&) const;
|
||||
std::vector<bool>&, std::vector<bool>&) const;
|
||||
|
||||
bool containsCycle (const FacNode*, const VarNode*,
|
||||
vector<bool>&, vector<bool>&) const;
|
||||
std::vector<bool>&, std::vector<bool>&) const;
|
||||
|
||||
VarNodes varNodes_;
|
||||
FacNodes facNodes_;
|
||||
static void ignoreLines (std::ifstream&);
|
||||
|
||||
VarNodes varNodes_;
|
||||
FacNodes facNodes_;
|
||||
VarMap varMap_;
|
||||
BayesBallGraph structure_;
|
||||
bool bayesFactors_;
|
||||
|
||||
typedef unordered_map<unsigned, VarNode*> VarMap;
|
||||
VarMap varMap_;
|
||||
|
||||
static bool exportLd_;
|
||||
static bool exportUai_;
|
||||
static bool exportGv_;
|
||||
static bool printFg_;
|
||||
|
||||
DISALLOW_ASSIGN (FactorGraph);
|
||||
static bool exportLd_;
|
||||
static bool exportUai_;
|
||||
static bool exportGv_;
|
||||
static bool printFg_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
struct sortByVarId
|
||||
inline VarNode*
|
||||
FactorGraph::getVarNode (VarId vid) const
|
||||
{
|
||||
VarMap::const_iterator it = varMap_.find (vid);
|
||||
return it != varMap_.end() ? it->second : 0;
|
||||
}
|
||||
|
||||
|
||||
|
||||
struct sortByVarId {
|
||||
bool operator()(VarNode* vn1, VarNode* vn2) {
|
||||
return vn1->varId() < vn2->varId();
|
||||
}
|
||||
};
|
||||
}};
|
||||
|
||||
} // namespace Horus
|
||||
|
||||
#endif // HORUS_FACTORGRAPH_H
|
||||
#endif // YAP_PACKAGES_CLPBN_HORUS_FACTORGRAPH_H_
|
||||
|
||||
|
256
packages/CLPBN/horus/GenericFactor.cpp
Normal file
256
packages/CLPBN/horus/GenericFactor.cpp
Normal file
@ -0,0 +1,256 @@
|
||||
#include <cassert>
|
||||
|
||||
#include "GenericFactor.h"
|
||||
#include "ProbFormula.h"
|
||||
#include "Indexer.h"
|
||||
|
||||
|
||||
namespace Horus {
|
||||
|
||||
template <typename T> const T&
|
||||
GenericFactor<T>::argument (size_t idx) const
|
||||
{
|
||||
assert (idx < args_.size());
|
||||
return args_[idx];
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T> T&
|
||||
GenericFactor<T>::argument (size_t idx)
|
||||
{
|
||||
assert (idx < args_.size());
|
||||
return args_[idx];
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T> unsigned
|
||||
GenericFactor<T>::range (size_t idx) const
|
||||
{
|
||||
assert (idx < ranges_.size());
|
||||
return ranges_[idx];
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T> bool
|
||||
GenericFactor<T>::contains (const T& arg) const
|
||||
{
|
||||
return Util::contains (args_, arg);
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T> bool
|
||||
GenericFactor<T>::contains (const std::vector<T>& args) const
|
||||
{
|
||||
for (size_t i = 0; i < args.size(); i++) {
|
||||
if (contains (args[i]) == false) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T> void
|
||||
GenericFactor<T>::setParams (const Params& newParams)
|
||||
{
|
||||
params_ = newParams;
|
||||
assert (params_.size() == Util::sizeExpected (ranges_));
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T> double
|
||||
GenericFactor<T>::operator[] (size_t idx) const
|
||||
{
|
||||
assert (idx < params_.size());
|
||||
return params_[idx];
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T> double&
|
||||
GenericFactor<T>::operator[] (size_t idx)
|
||||
{
|
||||
assert (idx < params_.size());
|
||||
return params_[idx];
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T> GenericFactor<T>&
|
||||
GenericFactor<T>::multiply (const GenericFactor<T>& g)
|
||||
{
|
||||
if (args_ == g.arguments()) {
|
||||
// optimization
|
||||
Globals::logDomain
|
||||
? params_ += g.params()
|
||||
: params_ *= g.params();
|
||||
return *this;
|
||||
}
|
||||
unsigned range_prod = 1;
|
||||
bool share_arguments = false;
|
||||
const std::vector<T>& g_args = g.arguments();
|
||||
const Ranges& g_ranges = g.ranges();
|
||||
const Params& g_params = g.params();
|
||||
for (size_t i = 0; i < g_args.size(); i++) {
|
||||
size_t idx = indexOf (g_args[i]);
|
||||
if (idx == args_.size()) {
|
||||
range_prod *= g_ranges[i];
|
||||
args_.push_back (g_args[i]);
|
||||
ranges_.push_back (g_ranges[i]);
|
||||
} else {
|
||||
share_arguments = true;
|
||||
}
|
||||
}
|
||||
if (share_arguments == false) {
|
||||
// optimization
|
||||
cartesianProduct (g_params.begin(), g_params.end());
|
||||
} else {
|
||||
extend (range_prod);
|
||||
Params::iterator it = params_.begin();
|
||||
MapIndexer indexer (args_, ranges_, g_args, g_ranges);
|
||||
if (Globals::logDomain) {
|
||||
for (; indexer.valid(); ++it, ++indexer) {
|
||||
*it += g_params[indexer];
|
||||
}
|
||||
} else {
|
||||
for (; indexer.valid(); ++it, ++indexer) {
|
||||
*it *= g_params[indexer];
|
||||
}
|
||||
}
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T> void
|
||||
GenericFactor<T>::sumOutIndex (size_t idx)
|
||||
{
|
||||
assert (idx < args_.size());
|
||||
assert (args_.size() > 1);
|
||||
size_t new_size = params_.size() / ranges_[idx];
|
||||
Params newps (new_size, LogAware::addIdenty());
|
||||
Params::const_iterator first = params_.begin();
|
||||
Params::const_iterator last = params_.end();
|
||||
MapIndexer indexer (ranges_, idx);
|
||||
if (Globals::logDomain) {
|
||||
for (; first != last; ++indexer) {
|
||||
newps[indexer] = Util::logSum (newps[indexer], *first++);
|
||||
}
|
||||
} else {
|
||||
for (; first != last; ++indexer) {
|
||||
newps[indexer] += *first++;
|
||||
}
|
||||
}
|
||||
params_ = newps;
|
||||
args_.erase (args_.begin() + idx);
|
||||
ranges_.erase (ranges_.begin() + idx);
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T> void
|
||||
GenericFactor<T>::absorveEvidence (const T& arg, unsigned obsIdx)
|
||||
{
|
||||
size_t idx = indexOf (arg);
|
||||
assert (idx != args_.size());
|
||||
assert (obsIdx < ranges_[idx]);
|
||||
Params newps;
|
||||
newps.reserve (params_.size() / ranges_[idx]);
|
||||
Indexer indexer (ranges_);
|
||||
for (unsigned i = 0; i < obsIdx; ++i) {
|
||||
indexer.incrementDimension (idx);
|
||||
}
|
||||
while (indexer.valid()) {
|
||||
newps.push_back (params_[indexer]);
|
||||
indexer.incrementExceptDimension (idx);
|
||||
}
|
||||
params_ = newps;
|
||||
args_.erase (args_.begin() + idx);
|
||||
ranges_.erase (ranges_.begin() + idx);
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T> void
|
||||
GenericFactor<T>::reorderArguments (const std::vector<T>& new_args)
|
||||
{
|
||||
assert (new_args.size() == args_.size());
|
||||
if (new_args == args_) {
|
||||
return; // already on the desired order
|
||||
}
|
||||
Ranges new_ranges;
|
||||
for (size_t i = 0; i < new_args.size(); i++) {
|
||||
size_t idx = indexOf (new_args[i]);
|
||||
assert (idx != args_.size());
|
||||
new_ranges.push_back (ranges_[idx]);
|
||||
}
|
||||
Params newps;
|
||||
newps.reserve (params_.size());
|
||||
MapIndexer indexer (new_args, new_ranges, args_, ranges_);
|
||||
for (; indexer.valid(); ++indexer) {
|
||||
newps.push_back (params_[indexer]);
|
||||
}
|
||||
params_ = newps;
|
||||
args_ = new_args;
|
||||
ranges_ = new_ranges;
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T> void
|
||||
GenericFactor<T>::extend (unsigned range_prod)
|
||||
{
|
||||
Params backup = params_;
|
||||
params_.clear();
|
||||
params_.reserve (backup.size() * range_prod);
|
||||
Params::const_iterator first = backup.begin();
|
||||
Params::const_iterator last = backup.end();
|
||||
for (; first != last; ++first) {
|
||||
for (unsigned reps = 0; reps < range_prod; ++reps) {
|
||||
params_.push_back (*first);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T> void
|
||||
GenericFactor<T>::cartesianProduct (
|
||||
Params::const_iterator first2,
|
||||
Params::const_iterator last2)
|
||||
{
|
||||
Params backup = params_;
|
||||
params_.clear();
|
||||
params_.reserve (params_.size() * (last2 - first2));
|
||||
Params::const_iterator first1 = backup.begin();
|
||||
Params::const_iterator last1 = backup.end();
|
||||
Params::const_iterator tmp;
|
||||
if (Globals::logDomain) {
|
||||
for (; first1 != last1; ++first1) {
|
||||
for (tmp = first2; tmp != last2; ++tmp) {
|
||||
params_.push_back ((*first1) + (*tmp));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (; first1 != last1; ++first1) {
|
||||
for (tmp = first2; tmp != last2; ++tmp) {
|
||||
params_.push_back ((*first1) * (*tmp));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
template class GenericFactor<VarId>;
|
||||
template class GenericFactor<ProbFormula>;
|
||||
|
||||
} // namespace Horus
|
||||
|
76
packages/CLPBN/horus/GenericFactor.h
Normal file
76
packages/CLPBN/horus/GenericFactor.h
Normal file
@ -0,0 +1,76 @@
|
||||
#ifndef YAP_PACKAGES_CLPBN_HORUS_GENERICFACTOR_H_
|
||||
#define YAP_PACKAGES_CLPBN_HORUS_GENERICFACTOR_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "Util.h"
|
||||
|
||||
|
||||
namespace Horus {
|
||||
|
||||
template <typename T>
|
||||
class GenericFactor {
|
||||
public:
|
||||
const std::vector<T>& arguments() const { return args_; }
|
||||
|
||||
std::vector<T>& arguments() { return args_; }
|
||||
|
||||
const Ranges& ranges() const { return ranges_; }
|
||||
|
||||
const Params& params() const { return params_; }
|
||||
|
||||
Params& params() { return params_; }
|
||||
|
||||
size_t nrArguments() const { return args_.size(); }
|
||||
|
||||
size_t size() const { return params_.size(); }
|
||||
|
||||
unsigned distId() const { return distId_; }
|
||||
|
||||
void setDistId (unsigned id) { distId_ = id; }
|
||||
|
||||
void normalize() { LogAware::normalize (params_); }
|
||||
|
||||
size_t indexOf (const T& t) const { return Util::indexOf (args_, t); }
|
||||
|
||||
const T& argument (size_t idx) const;
|
||||
|
||||
T& argument (size_t idx);
|
||||
|
||||
unsigned range (size_t idx) const;
|
||||
|
||||
bool contains (const T& arg) const;
|
||||
|
||||
bool contains (const std::vector<T>& args) const;
|
||||
|
||||
void setParams (const Params& newParams);
|
||||
|
||||
double operator[] (size_t idx) const;
|
||||
|
||||
double& operator[] (size_t idx);
|
||||
|
||||
GenericFactor<T>& multiply (const GenericFactor<T>& g);
|
||||
|
||||
void sumOutIndex (size_t idx);
|
||||
|
||||
void absorveEvidence (const T& arg, unsigned obsIdx);
|
||||
|
||||
void reorderArguments (const std::vector<T>& new_args);
|
||||
|
||||
protected:
|
||||
std::vector<T> args_;
|
||||
Ranges ranges_;
|
||||
Params params_;
|
||||
unsigned distId_;
|
||||
|
||||
private:
|
||||
void extend (unsigned range_prod);
|
||||
|
||||
void cartesianProduct (
|
||||
Params::const_iterator first2, Params::const_iterator last2);
|
||||
};
|
||||
|
||||
} // namespace Horus
|
||||
|
||||
#endif // YAP_PACKAGES_CLPBN_HORUS_GENERICFACTOR_H_
|
||||
|
@ -1,10 +1,20 @@
|
||||
#include <cassert>
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <iostream>
|
||||
#include <iomanip>
|
||||
|
||||
#include "GroundSolver.h"
|
||||
#include "VarElim.h"
|
||||
#include "BeliefProp.h"
|
||||
#include "CountingBp.h"
|
||||
#include "Indexer.h"
|
||||
#include "Util.h"
|
||||
|
||||
|
||||
namespace Horus {
|
||||
|
||||
void
|
||||
GroundSolver::printAnswer (const VarIds& vids)
|
||||
{
|
||||
@ -19,20 +29,21 @@ GroundSolver::printAnswer (const VarIds& vids)
|
||||
}
|
||||
if (unobservedVids.empty() == false) {
|
||||
Params res = solveQuery (unobservedVids);
|
||||
vector<string> stateLines = Util::getStateLines (unobservedVars);
|
||||
std::vector<std::string> stateLines =
|
||||
Util::getStateLines (unobservedVars);
|
||||
for (size_t i = 0; i < res.size(); i++) {
|
||||
cout << "P(" << stateLines[i] << ") = " ;
|
||||
cout << std::setprecision (Constants::PRECISION) << res[i];
|
||||
cout << endl;
|
||||
std::cout << "P(" << stateLines[i] << ") = " ;
|
||||
std::cout << std::setprecision (Constants::precision) << res[i];
|
||||
std::cout << std::endl;
|
||||
}
|
||||
cout << endl;
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
GroundSolver::printAllPosterioris (void)
|
||||
GroundSolver::printAllPosterioris()
|
||||
{
|
||||
VarNodes vars = fg.varNodes();
|
||||
std::sort (vars.begin(), vars.end(), sortByVarId());
|
||||
@ -57,9 +68,9 @@ GroundSolver::getJointByConditioning (
|
||||
|
||||
GroundSolver* solver = 0;
|
||||
switch (solverType) {
|
||||
case GroundSolverType::BP: solver = new BeliefProp (fg); break;
|
||||
case GroundSolverType::CBP: solver = new CountingBp (fg); break;
|
||||
case GroundSolverType::VE: solver = new VarElim (fg); break;
|
||||
case GroundSolverType::bpSolver: solver = new BeliefProp (fg); break;
|
||||
case GroundSolverType::CbpSolver: solver = new CountingBp (fg); break;
|
||||
case GroundSolverType::veSolver: solver = new VarElim (fg); break;
|
||||
}
|
||||
Params prevBeliefs = solver->solveQuery ({jointVarIds[0]});
|
||||
VarIds observedVids = {jointVars[0]->varId()};
|
||||
@ -80,9 +91,9 @@ GroundSolver::getJointByConditioning (
|
||||
}
|
||||
delete solver;
|
||||
switch (solverType) {
|
||||
case GroundSolverType::BP: solver = new BeliefProp (fg); break;
|
||||
case GroundSolverType::CBP: solver = new CountingBp (fg); break;
|
||||
case GroundSolverType::VE: solver = new VarElim (fg); break;
|
||||
case GroundSolverType::bpSolver: solver = new BeliefProp (fg); break;
|
||||
case GroundSolverType::CbpSolver: solver = new CountingBp (fg); break;
|
||||
case GroundSolverType::veSolver: solver = new VarElim (fg); break;
|
||||
}
|
||||
Params beliefs = solver->solveQuery ({jointVarIds[i]});
|
||||
for (size_t k = 0; k < beliefs.size(); k++) {
|
||||
@ -105,3 +116,5 @@ GroundSolver::getJointByConditioning (
|
||||
return prevBeliefs;
|
||||
}
|
||||
|
||||
} // namespace Horus
|
||||
|
||||
|
@ -1,16 +1,13 @@
|
||||
#ifndef HORUS_GROUNDSOLVER_H
|
||||
#define HORUS_GROUNDSOLVER_H
|
||||
|
||||
#include <iomanip>
|
||||
#ifndef YAP_PACKAGES_CLPBN_HORUS_GROUNDSOLVER_H_
|
||||
#define YAP_PACKAGES_CLPBN_HORUS_GROUNDSOLVER_H_
|
||||
|
||||
#include "FactorGraph.h"
|
||||
#include "Horus.h"
|
||||
|
||||
|
||||
using namespace std;
|
||||
namespace Horus {
|
||||
|
||||
class GroundSolver
|
||||
{
|
||||
class GroundSolver {
|
||||
public:
|
||||
GroundSolver (const FactorGraph& factorGraph) : fg(factorGraph) { }
|
||||
|
||||
@ -18,11 +15,11 @@ class GroundSolver
|
||||
|
||||
virtual Params solveQuery (VarIds queryVids) = 0;
|
||||
|
||||
virtual void printSolverFlags (void) const = 0;
|
||||
virtual void printSolverFlags() const = 0;
|
||||
|
||||
void printAnswer (const VarIds& vids);
|
||||
|
||||
void printAllPosterioris (void);
|
||||
void printAllPosterioris();
|
||||
|
||||
static Params getJointByConditioning (GroundSolverType,
|
||||
FactorGraph, const VarIds& jointVarIds);
|
||||
@ -30,8 +27,11 @@ class GroundSolver
|
||||
protected:
|
||||
const FactorGraph& fg;
|
||||
|
||||
private:
|
||||
DISALLOW_COPY_AND_ASSIGN (GroundSolver);
|
||||
};
|
||||
|
||||
#endif // HORUS_GROUNDSOLVER_H
|
||||
} // namespace Horus
|
||||
|
||||
#endif // YAP_PACKAGES_CLPBN_HORUS_GROUNDSOLVER_H_
|
||||
|
||||
|
@ -7,6 +7,8 @@
|
||||
#include "Util.h"
|
||||
|
||||
|
||||
namespace Horus {
|
||||
|
||||
HistogramSet::HistogramSet (unsigned size, unsigned range)
|
||||
{
|
||||
size_ = size;
|
||||
@ -17,7 +19,7 @@ HistogramSet::HistogramSet (unsigned size, unsigned range)
|
||||
|
||||
|
||||
void
|
||||
HistogramSet::nextHistogram (void)
|
||||
HistogramSet::nextHistogram()
|
||||
{
|
||||
for (size_t i = hist_.size() - 1; i-- > 0; ) {
|
||||
if (hist_[i] > 0) {
|
||||
@ -43,7 +45,7 @@ HistogramSet::operator[] (size_t idx) const
|
||||
|
||||
|
||||
unsigned
|
||||
HistogramSet::nrHistograms (void) const
|
||||
HistogramSet::nrHistograms() const
|
||||
{
|
||||
return HistogramSet::nrHistograms (size_, hist_.size());
|
||||
}
|
||||
@ -51,7 +53,7 @@ HistogramSet::nrHistograms (void) const
|
||||
|
||||
|
||||
void
|
||||
HistogramSet::reset (void)
|
||||
HistogramSet::reset()
|
||||
{
|
||||
std::fill (hist_.begin() + 1, hist_.end(), 0);
|
||||
hist_[0] = size_;
|
||||
@ -59,12 +61,12 @@ HistogramSet::reset (void)
|
||||
|
||||
|
||||
|
||||
vector<Histogram>
|
||||
std::vector<Histogram>
|
||||
HistogramSet::getHistograms (unsigned N, unsigned R)
|
||||
{
|
||||
HistogramSet hs (N, R);
|
||||
unsigned H = hs.nrHistograms();
|
||||
vector<Histogram> histograms;
|
||||
std::vector<Histogram> histograms;
|
||||
histograms.reserve (H);
|
||||
for (unsigned i = 0; i < H; i++) {
|
||||
histograms.push_back (hs.hist_);
|
||||
@ -86,9 +88,9 @@ HistogramSet::nrHistograms (unsigned N, unsigned R)
|
||||
size_t
|
||||
HistogramSet::findIndex (
|
||||
const Histogram& h,
|
||||
const vector<Histogram>& hists)
|
||||
const std::vector<Histogram>& hists)
|
||||
{
|
||||
vector<Histogram>::const_iterator it = std::lower_bound (
|
||||
std::vector<Histogram>::const_iterator it = std::lower_bound (
|
||||
hists.begin(), hists.end(), h, std::greater<Histogram>());
|
||||
assert (it != hists.end() && *it == h);
|
||||
return std::distance (hists.begin(), it);
|
||||
@ -96,13 +98,13 @@ HistogramSet::findIndex (
|
||||
|
||||
|
||||
|
||||
vector<double>
|
||||
std::vector<double>
|
||||
HistogramSet::getNumAssigns (unsigned N, unsigned R)
|
||||
{
|
||||
HistogramSet hs (N, R);
|
||||
double N_fac = Util::logFactorial (N);
|
||||
unsigned H = hs.nrHistograms();
|
||||
vector<double> numAssigns;
|
||||
std::vector<double> numAssigns;
|
||||
numAssigns.reserve (H);
|
||||
for (unsigned h = 0; h < H; h++) {
|
||||
double prod = 0.0;
|
||||
@ -118,14 +120,6 @@ HistogramSet::getNumAssigns (unsigned N, unsigned R)
|
||||
|
||||
|
||||
|
||||
ostream& operator<< (ostream &os, const HistogramSet& hs)
|
||||
{
|
||||
os << "#" << hs.hist_;
|
||||
return os;
|
||||
}
|
||||
|
||||
|
||||
|
||||
unsigned
|
||||
HistogramSet::maxCount (size_t idx) const
|
||||
{
|
||||
@ -144,3 +138,14 @@ HistogramSet::clearAfter (size_t idx)
|
||||
std::fill (hist_.begin() + idx + 1, hist_.end(), 0);
|
||||
}
|
||||
|
||||
|
||||
|
||||
std::ostream&
|
||||
operator<< (std::ostream& os, const HistogramSet& hs)
|
||||
{
|
||||
os << "#" << hs.hist_;
|
||||
return os;
|
||||
}
|
||||
|
||||
} // namespace Horus
|
||||
|
||||
|
@ -1,50 +1,51 @@
|
||||
#ifndef HORUS_HISTOGRAM_H
|
||||
#define HORUS_HISTOGRAM_H
|
||||
#ifndef YAP_PACKAGES_CLPBN_HORUS_HISTOGRAM_H_
|
||||
#define YAP_PACKAGES_CLPBN_HORUS_HISTOGRAM_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include <ostream>
|
||||
|
||||
#include "Horus.h"
|
||||
|
||||
using namespace std;
|
||||
typedef std::vector<unsigned> Histogram;
|
||||
|
||||
typedef vector<unsigned> Histogram;
|
||||
|
||||
class HistogramSet
|
||||
{
|
||||
namespace Horus {
|
||||
|
||||
class HistogramSet {
|
||||
public:
|
||||
HistogramSet (unsigned, unsigned);
|
||||
|
||||
void nextHistogram (void);
|
||||
void nextHistogram();
|
||||
|
||||
unsigned operator[] (size_t idx) const;
|
||||
|
||||
unsigned nrHistograms (void) const;
|
||||
unsigned nrHistograms() const;
|
||||
|
||||
void reset (void);
|
||||
void reset();
|
||||
|
||||
static vector<Histogram> getHistograms (unsigned, unsigned);
|
||||
static std::vector<Histogram> getHistograms (unsigned, unsigned);
|
||||
|
||||
static unsigned nrHistograms (unsigned, unsigned);
|
||||
|
||||
static size_t findIndex (
|
||||
const Histogram&, const vector<Histogram>&);
|
||||
const Histogram&, const std::vector<Histogram>&);
|
||||
|
||||
static vector<double> getNumAssigns (unsigned, unsigned);
|
||||
|
||||
friend std::ostream& operator<< (ostream &os, const HistogramSet& hs);
|
||||
static std::vector<double> getNumAssigns (unsigned, unsigned);
|
||||
|
||||
private:
|
||||
unsigned maxCount (size_t) const;
|
||||
|
||||
void clearAfter (size_t);
|
||||
|
||||
friend std::ostream& operator<< (std::ostream&, const HistogramSet&);
|
||||
|
||||
unsigned size_;
|
||||
Histogram hist_;
|
||||
|
||||
DISALLOW_COPY_AND_ASSIGN (HistogramSet);
|
||||
};
|
||||
|
||||
#endif // HORUS_HISTOGRAM_H
|
||||
} // namespace Horus
|
||||
|
||||
#endif // YAP_PACKAGES_CLPBN_HORUS_HISTOGRAM_H_
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
#ifndef HORUS_HORUS_H
|
||||
#define HORUS_HORUS_H
|
||||
#ifndef YAP_PACKAGES_CLPBN_HORUS_HORUS_H_
|
||||
#define YAP_PACKAGES_CLPBN_HORUS_HORUS_H_
|
||||
|
||||
#define DISALLOW_COPY_AND_ASSIGN(TypeName) \
|
||||
TypeName(const TypeName&); \
|
||||
@ -14,6 +14,9 @@
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
|
||||
namespace Horus {
|
||||
|
||||
class Var;
|
||||
class Factor;
|
||||
class VarNode;
|
||||
@ -31,19 +34,17 @@ typedef std::vector<unsigned> Ranges;
|
||||
typedef unsigned long long ullong;
|
||||
|
||||
|
||||
enum LiftedSolverType
|
||||
{
|
||||
LVE, // generalized counting first-order variable elimination (GC-FOVE)
|
||||
LBP, // lifted first-order belief propagation
|
||||
LKC // lifted first-order knowledge compilation
|
||||
enum class LiftedSolverType {
|
||||
lveSolver, // generalized counting first-order variable elimination
|
||||
lbpSolver, // lifted first-order belief propagation
|
||||
lkcSolver // lifted first-order knowledge compilation
|
||||
};
|
||||
|
||||
|
||||
enum GroundSolverType
|
||||
{
|
||||
VE, // variable elimination
|
||||
BP, // belief propagation
|
||||
CBP // counting belief propagation
|
||||
enum class GroundSolverType {
|
||||
veSolver, // variable elimination
|
||||
bpSolver, // belief propagation
|
||||
CbpSolver // counting belief propagation
|
||||
};
|
||||
|
||||
|
||||
@ -57,20 +58,22 @@ extern unsigned verbosity;
|
||||
extern LiftedSolverType liftedSolver;
|
||||
extern GroundSolverType groundSolver;
|
||||
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
namespace Constants {
|
||||
|
||||
// show message calculation for belief propagation
|
||||
const bool SHOW_BP_CALCS = false;
|
||||
const bool showBpCalcs = false;
|
||||
|
||||
const int NO_EVIDENCE = -1;
|
||||
const int unobserved = -1;
|
||||
|
||||
// number of digits to show when printing a parameter
|
||||
const unsigned PRECISION = 6;
|
||||
const unsigned precision = 8;
|
||||
|
||||
};
|
||||
}
|
||||
|
||||
#endif // HORUS_HORUS_H
|
||||
} // namespace Horus
|
||||
|
||||
#endif // YAP_PACKAGES_CLPBN_HORUS_HORUS_H_
|
||||
|
||||
|
@ -1,53 +1,61 @@
|
||||
#include <cstdlib>
|
||||
#include <cassert>
|
||||
|
||||
#include <string>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "FactorGraph.h"
|
||||
#include "VarElim.h"
|
||||
#include "BeliefProp.h"
|
||||
#include "CountingBp.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace {
|
||||
|
||||
int readHorusFlags (int, const char* []);
|
||||
void readFactorGraph (FactorGraph&, const char*);
|
||||
VarIds readQueryAndEvidence (FactorGraph&, int, const char* [], int);
|
||||
|
||||
void runSolver (const FactorGraph&, const VarIds&);
|
||||
void readFactorGraph (Horus::FactorGraph&, const char*);
|
||||
|
||||
const string USAGE = "usage: ./hcli [solver=hve|bp|cbp] \
|
||||
Horus::VarIds readQueryAndEvidence (
|
||||
Horus::FactorGraph&, int, const char* [], int);
|
||||
|
||||
void runSolver (const Horus::FactorGraph&, const Horus::VarIds&);
|
||||
|
||||
const std::string usage = "usage: ./hcli [solver=hve|bp|cbp] \
|
||||
[<OPTION>=<VALUE>]... <FILE> [<VAR>|<VAR>=<EVIDENCE>]... " ;
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
int
|
||||
main (int argc, const char* argv[])
|
||||
{
|
||||
if (argc <= 1) {
|
||||
cerr << "Error: no probabilistic graphical model was given." << endl;
|
||||
cerr << USAGE << endl;
|
||||
std::cerr << "Error: no probabilistic graphical model was given." ;
|
||||
std::cerr << std::endl << usage << std::endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
int idx = readHorusFlags (argc, argv);
|
||||
FactorGraph fg;
|
||||
Horus::FactorGraph fg;
|
||||
readFactorGraph (fg, argv[idx]);
|
||||
VarIds queryIds = readQueryAndEvidence (fg, argc, argv, idx + 1);
|
||||
if (FactorGraph::exportToLibDai()) {
|
||||
Horus::VarIds queryIds
|
||||
= readQueryAndEvidence (fg, argc, argv, idx + 1);
|
||||
if (Horus::FactorGraph::exportToLibDai()) {
|
||||
fg.exportToLibDai ("model.fg");
|
||||
}
|
||||
if (FactorGraph::exportToUai()) {
|
||||
if (Horus::FactorGraph::exportToUai()) {
|
||||
fg.exportToUai ("model.uai");
|
||||
}
|
||||
if (FactorGraph::exportGraphViz()) {
|
||||
if (Horus::FactorGraph::exportGraphViz()) {
|
||||
fg.exportToGraphViz ("model.dot");
|
||||
}
|
||||
if (FactorGraph::printFactorGraph()) {
|
||||
if (Horus::FactorGraph::printFactorGraph()) {
|
||||
fg.print();
|
||||
}
|
||||
if (Globals::verbosity > 0) {
|
||||
cout << "factor graph contains " ;
|
||||
cout << fg.nrVarNodes() << " variables and " ;
|
||||
cout << fg.nrFacNodes() << " factors " << endl;
|
||||
if (Horus::Globals::verbosity > 0) {
|
||||
std::cout << "factor graph contains " ;
|
||||
std::cout << fg.nrVarNodes() << " variables and " ;
|
||||
std::cout << fg.nrFacNodes() << " factors " << std::endl;
|
||||
}
|
||||
runSolver (fg, queryIds);
|
||||
return 0;
|
||||
@ -55,29 +63,31 @@ main (int argc, const char* argv[])
|
||||
|
||||
|
||||
|
||||
namespace {
|
||||
|
||||
int
|
||||
readHorusFlags (int argc, const char* argv[])
|
||||
{
|
||||
int i = 1;
|
||||
for (; i < argc; i++) {
|
||||
const string& arg = argv[i];
|
||||
const std::string& arg = argv[i];
|
||||
size_t pos = arg.find ('=');
|
||||
if (pos == std::string::npos) {
|
||||
return i;
|
||||
}
|
||||
string leftArg = arg.substr (0, pos);
|
||||
string rightArg = arg.substr (pos + 1);
|
||||
std::string leftArg = arg.substr (0, pos);
|
||||
std::string rightArg = arg.substr (pos + 1);
|
||||
if (leftArg.empty()) {
|
||||
cerr << "Error: missing left argument." << endl;
|
||||
cerr << USAGE << endl;
|
||||
std::cerr << "Error: missing left argument." << std::endl;
|
||||
std::cerr << usage << std::endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
if (rightArg.empty()) {
|
||||
cerr << "Error: missing right argument." << endl;
|
||||
cerr << USAGE << endl;
|
||||
std::cerr << "Error: missing right argument." << std::endl;
|
||||
std::cerr << usage << std::endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
Util::setHorusFlag (leftArg, rightArg);
|
||||
Horus::Util::setHorusFlag (leftArg, rightArg);
|
||||
}
|
||||
return i + 1;
|
||||
}
|
||||
@ -85,84 +95,84 @@ readHorusFlags (int argc, const char* argv[])
|
||||
|
||||
|
||||
void
|
||||
readFactorGraph (FactorGraph& fg, const char* s)
|
||||
readFactorGraph (Horus::FactorGraph& fg, const char* s)
|
||||
{
|
||||
string fileName (s);
|
||||
string extension = fileName.substr (fileName.find_last_of ('.') + 1);
|
||||
std::string fileName (s);
|
||||
std::string extension = fileName.substr (fileName.find_last_of ('.') + 1);
|
||||
if (extension == "uai") {
|
||||
fg.readFromUaiFormat (fileName.c_str());
|
||||
fg = Horus::FactorGraph::readFromUaiFormat (fileName.c_str());
|
||||
} else if (extension == "fg") {
|
||||
fg.readFromLibDaiFormat (fileName.c_str());
|
||||
fg = Horus::FactorGraph::readFromLibDaiFormat (fileName.c_str());
|
||||
} else {
|
||||
cerr << "Error: the probabilistic graphical model must be " ;
|
||||
cerr << "defined either in a UAI or libDAI file." << endl;
|
||||
std::cerr << "Error: the probabilistic graphical model must be " ;
|
||||
std::cerr << "defined either in a UAI or libDAI file." << std::endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
VarIds
|
||||
Horus::VarIds
|
||||
readQueryAndEvidence (
|
||||
FactorGraph& fg,
|
||||
Horus::FactorGraph& fg,
|
||||
int argc,
|
||||
const char* argv[],
|
||||
int start)
|
||||
{
|
||||
VarIds queryIds;
|
||||
Horus::VarIds queryIds;
|
||||
for (int i = start; i < argc; i++) {
|
||||
const string& arg = argv[i];
|
||||
const std::string& arg = argv[i];
|
||||
if (arg.find ('=') == std::string::npos) {
|
||||
if (Util::isInteger (arg) == false) {
|
||||
cerr << "Error: `" << arg << "' " ;
|
||||
cerr << "is not a variable id." ;
|
||||
cerr << endl;
|
||||
if (Horus::Util::isInteger (arg) == false) {
|
||||
std::cerr << "Error: `" << arg << "' " ;
|
||||
std::cerr << "is not a variable id." ;
|
||||
std::cerr << std::endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
VarId vid = Util::stringToUnsigned (arg);
|
||||
VarNode* queryVar = fg.getVarNode (vid);
|
||||
Horus::VarId vid = Horus::Util::stringToUnsigned (arg);
|
||||
Horus::VarNode* queryVar = fg.getVarNode (vid);
|
||||
if (queryVar == false) {
|
||||
cerr << "Error: unknow variable with id " ;
|
||||
cerr << "`" << vid << "'." << endl;
|
||||
std::cerr << "Error: unknow variable with id " ;
|
||||
std::cerr << "`" << vid << "'." << std::endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
queryIds.push_back (vid);
|
||||
} else {
|
||||
size_t pos = arg.find ('=');
|
||||
string leftArg = arg.substr (0, pos);
|
||||
string rightArg = arg.substr (pos + 1);
|
||||
std::string leftArg = arg.substr (0, pos);
|
||||
std::string rightArg = arg.substr (pos + 1);
|
||||
if (leftArg.empty()) {
|
||||
cerr << "Error: missing left argument." << endl;
|
||||
cerr << USAGE << endl;
|
||||
std::cerr << "Error: missing left argument." << std::endl;
|
||||
std::cerr << usage << std::endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
if (Util::isInteger (leftArg) == false) {
|
||||
cerr << "Error: `" << leftArg << "' " ;
|
||||
cerr << "is not a variable id." << endl ;
|
||||
if (Horus::Util::isInteger (leftArg) == false) {
|
||||
std::cerr << "Error: `" << leftArg << "' " ;
|
||||
std::cerr << "is not a variable id." << std::endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
VarId vid = Util::stringToUnsigned (leftArg);
|
||||
VarNode* observedVar = fg.getVarNode (vid);
|
||||
Horus::VarId vid = Horus::Util::stringToUnsigned (leftArg);
|
||||
Horus::VarNode* observedVar = fg.getVarNode (vid);
|
||||
if (observedVar == false) {
|
||||
cerr << "Error: unknow variable with id " ;
|
||||
cerr << "`" << vid << "'." << endl;
|
||||
std::cerr << "Error: unknow variable with id " ;
|
||||
std::cerr << "`" << vid << "'." << std::endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
if (rightArg.empty()) {
|
||||
cerr << "Error: missing right argument." << endl;
|
||||
cerr << USAGE << endl;
|
||||
std::cerr << "Error: missing right argument." << std::endl;
|
||||
std::cerr << usage << std::endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
if (Util::isInteger (rightArg) == false) {
|
||||
cerr << "Error: `" << rightArg << "' " ;
|
||||
cerr << "is not a state index." << endl ;
|
||||
if (Horus::Util::isInteger (rightArg) == false) {
|
||||
std::cerr << "Error: `" << rightArg << "' " ;
|
||||
std::cerr << "is not a state index." << std::endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
unsigned stateIdx = Util::stringToUnsigned (rightArg);
|
||||
unsigned stateIdx = Horus::Util::stringToUnsigned (rightArg);
|
||||
if (observedVar->isValidState (stateIdx) == false) {
|
||||
cerr << "Error: `" << stateIdx << "' " ;
|
||||
cerr << "is not a valid state index for variable with id " ;
|
||||
cerr << "`" << vid << "'." << endl;
|
||||
std::cerr << "Error: `" << stateIdx << "' " ;
|
||||
std::cerr << "is not a valid state index for variable with id " ;
|
||||
std::cerr << "`" << vid << "'." << std::endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
observedVar->setEvidence (stateIdx);
|
||||
@ -174,25 +184,27 @@ readQueryAndEvidence (
|
||||
|
||||
|
||||
void
|
||||
runSolver (const FactorGraph& fg, const VarIds& queryIds)
|
||||
runSolver (
|
||||
const Horus::FactorGraph& fg,
|
||||
const Horus::VarIds& queryIds)
|
||||
{
|
||||
GroundSolver* solver = 0;
|
||||
switch (Globals::groundSolver) {
|
||||
case GroundSolverType::VE:
|
||||
solver = new VarElim (fg);
|
||||
Horus::GroundSolver* solver = 0;
|
||||
switch (Horus::Globals::groundSolver) {
|
||||
case Horus::GroundSolverType::veSolver:
|
||||
solver = new Horus::VarElim (fg);
|
||||
break;
|
||||
case GroundSolverType::BP:
|
||||
solver = new BeliefProp (fg);
|
||||
case Horus::GroundSolverType::bpSolver:
|
||||
solver = new Horus::BeliefProp (fg);
|
||||
break;
|
||||
case GroundSolverType::CBP:
|
||||
solver = new CountingBp (fg);
|
||||
case Horus::GroundSolverType::CbpSolver:
|
||||
solver = new Horus::CountingBp (fg);
|
||||
break;
|
||||
default:
|
||||
assert (false);
|
||||
}
|
||||
if (Globals::verbosity > 0) {
|
||||
if (Horus::Globals::verbosity > 0) {
|
||||
solver->printSolverFlags();
|
||||
cout << endl;
|
||||
std::cout << std::endl;
|
||||
}
|
||||
if (queryIds.empty()) {
|
||||
solver->printAllPosterioris();
|
||||
@ -202,3 +214,5 @@ runSolver (const FactorGraph& fg, const VarIds& queryIds)
|
||||
delete solver;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
@ -1,7 +1,8 @@
|
||||
#include <cstdlib>
|
||||
#include <cassert>
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include <unordered_map>
|
||||
#include <string>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
@ -20,31 +21,35 @@
|
||||
#include "BayesBall.h"
|
||||
|
||||
|
||||
using namespace std;
|
||||
namespace Horus {
|
||||
|
||||
typedef std::pair<ParfactorList*, ObservedFormulas*> LiftedNetwork;
|
||||
namespace {
|
||||
|
||||
Parfactor* readParfactor (YAP_Term);
|
||||
|
||||
void readLiftedEvidence (YAP_Term, ObservedFormulas&);
|
||||
ObservedFormulas* readLiftedEvidence (YAP_Term);
|
||||
|
||||
vector<unsigned> readUnsignedList (YAP_Term list);
|
||||
std::vector<unsigned> readUnsignedList (YAP_Term);
|
||||
|
||||
Params readParameters (YAP_Term);
|
||||
|
||||
YAP_Term fillAnswersPrologList (vector<Params>& results);
|
||||
YAP_Term fillSolutionList (const std::vector<Params>&);
|
||||
|
||||
}
|
||||
|
||||
typedef std::pair<ParfactorList*, ObservedFormulas*> LiftedNetwork;
|
||||
|
||||
|
||||
|
||||
int
|
||||
createLiftedNetwork (void)
|
||||
createLiftedNetwork()
|
||||
{
|
||||
Parfactors parfactors;
|
||||
YAP_Term parfactorList = YAP_ARG1;
|
||||
while (parfactorList != YAP_TermNil()) {
|
||||
YAP_Term pfTerm = YAP_HeadOfTerm (parfactorList);
|
||||
parfactors.push_back (readParfactor (pfTerm));
|
||||
parfactorList = YAP_TailOfTerm (parfactorList);
|
||||
parfactorList = YAP_TailOfTerm (parfactorList);
|
||||
}
|
||||
|
||||
// LiftedUtils::printSymbolDictionary();
|
||||
@ -52,7 +57,7 @@ createLiftedNetwork (void)
|
||||
Util::printHeader ("INITIAL PARFACTORS");
|
||||
for (size_t i = 0; i < parfactors.size(); i++) {
|
||||
parfactors[i]->print();
|
||||
cout << endl;
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
@ -64,21 +69,20 @@ createLiftedNetwork (void)
|
||||
}
|
||||
|
||||
// read evidence
|
||||
ObservedFormulas* obsFormulas = new ObservedFormulas();
|
||||
readLiftedEvidence (YAP_ARG2, *(obsFormulas));
|
||||
ObservedFormulas* obsFormulas = readLiftedEvidence (YAP_ARG2);
|
||||
|
||||
LiftedNetwork* net = new LiftedNetwork (pfList, obsFormulas);
|
||||
LiftedNetwork* network = new LiftedNetwork (pfList, obsFormulas);
|
||||
|
||||
YAP_Int p = (YAP_Int) (net);
|
||||
YAP_Int p = (YAP_Int) (network);
|
||||
return YAP_Unify (YAP_MkIntTerm (p), YAP_ARG3);
|
||||
}
|
||||
|
||||
|
||||
|
||||
int
|
||||
createGroundNetwork (void)
|
||||
createGroundNetwork()
|
||||
{
|
||||
string factorsType ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG1)));
|
||||
std::string factorsType ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG1)));
|
||||
FactorGraph* fg = new FactorGraph();
|
||||
if (factorsType == "bayes") {
|
||||
fg->setFactorsAsBayesian();
|
||||
@ -121,9 +125,9 @@ createGroundNetwork (void)
|
||||
fg->print();
|
||||
}
|
||||
if (Globals::verbosity > 0) {
|
||||
cout << "factor graph contains " ;
|
||||
cout << fg->nrVarNodes() << " variables and " ;
|
||||
cout << fg->nrFacNodes() << " factors " << endl;
|
||||
std::cout << "factor graph contains " ;
|
||||
std::cout << fg->nrVarNodes() << " variables and " ;
|
||||
std::cout << fg->nrFacNodes() << " factors " << std::endl;
|
||||
}
|
||||
YAP_Int p = (YAP_Int) (fg);
|
||||
return YAP_Unify (YAP_MkIntTerm (p), YAP_ARG4);
|
||||
@ -132,45 +136,46 @@ createGroundNetwork (void)
|
||||
|
||||
|
||||
int
|
||||
runLiftedSolver (void)
|
||||
runLiftedSolver()
|
||||
{
|
||||
LiftedNetwork* network = (LiftedNetwork*) YAP_IntOfTerm (YAP_ARG1);
|
||||
ParfactorList pfListCopy (*network->first);
|
||||
LiftedOperations::absorveEvidence (pfListCopy, *network->second);
|
||||
ParfactorList copy (*network->first);
|
||||
LiftedOperations::absorveEvidence (copy, *network->second);
|
||||
|
||||
LiftedSolver* solver = 0;
|
||||
switch (Globals::liftedSolver) {
|
||||
case LiftedSolverType::LVE: solver = new LiftedVe (pfListCopy); break;
|
||||
case LiftedSolverType::LBP: solver = new LiftedBp (pfListCopy); break;
|
||||
case LiftedSolverType::LKC: solver = new LiftedKc (pfListCopy); break;
|
||||
case LiftedSolverType::lveSolver: solver = new LiftedVe (copy); break;
|
||||
case LiftedSolverType::lbpSolver: solver = new LiftedBp (copy); break;
|
||||
case LiftedSolverType::lkcSolver: solver = new LiftedKc (copy); break;
|
||||
}
|
||||
|
||||
if (Globals::verbosity > 0) {
|
||||
solver->printSolverFlags();
|
||||
cout << endl;
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
YAP_Term taskList = YAP_ARG2;
|
||||
vector<Params> results;
|
||||
std::vector<Params> results;
|
||||
while (taskList != YAP_TermNil()) {
|
||||
Grounds queryVars;
|
||||
YAP_Term jointList = YAP_HeadOfTerm (taskList);
|
||||
while (jointList != YAP_TermNil()) {
|
||||
YAP_Term ground = YAP_HeadOfTerm (jointList);
|
||||
if (YAP_IsAtomTerm (ground)) {
|
||||
string name ((char*) YAP_AtomName (YAP_AtomOfTerm (ground)));
|
||||
std::string name ((char*) YAP_AtomName (YAP_AtomOfTerm (ground)));
|
||||
queryVars.push_back (Ground (LiftedUtils::getSymbol (name)));
|
||||
} else {
|
||||
assert (YAP_IsApplTerm (ground));
|
||||
YAP_Functor yapFunctor = YAP_FunctorOfTerm (ground);
|
||||
string name ((char*) (YAP_AtomName (YAP_NameOfFunctor (yapFunctor))));
|
||||
std::string name ((char*) (YAP_AtomName (
|
||||
YAP_NameOfFunctor (yapFunctor))));
|
||||
unsigned arity = (unsigned) YAP_ArityOfFunctor (yapFunctor);
|
||||
Symbol functor = LiftedUtils::getSymbol (name);
|
||||
Symbols args;
|
||||
for (unsigned i = 1; i <= arity; i++) {
|
||||
YAP_Term ti = YAP_ArgOfTerm (i, ground);
|
||||
assert (YAP_IsAtomTerm (ti));
|
||||
string arg ((char *) YAP_AtomName (YAP_AtomOfTerm (ti)));
|
||||
std::string arg ((char *) YAP_AtomName (YAP_AtomOfTerm (ti)));
|
||||
args.push_back (LiftedUtils::getSymbol (arg));
|
||||
}
|
||||
queryVars.push_back (Ground (functor, args));
|
||||
@ -183,17 +188,17 @@ runLiftedSolver (void)
|
||||
|
||||
delete solver;
|
||||
|
||||
return YAP_Unify (fillAnswersPrologList (results), YAP_ARG3);
|
||||
return YAP_Unify (fillSolutionList (results), YAP_ARG3);
|
||||
}
|
||||
|
||||
|
||||
|
||||
int
|
||||
runGroundSolver (void)
|
||||
runGroundSolver()
|
||||
{
|
||||
FactorGraph* fg = (FactorGraph*) YAP_IntOfTerm (YAP_ARG1);
|
||||
|
||||
vector<VarIds> tasks;
|
||||
std::vector<VarIds> tasks;
|
||||
YAP_Term taskList = YAP_ARG2;
|
||||
while (taskList != YAP_TermNil()) {
|
||||
tasks.push_back (readUnsignedList (YAP_HeadOfTerm (taskList)));
|
||||
@ -213,17 +218,17 @@ runGroundSolver (void)
|
||||
GroundSolver* solver = 0;
|
||||
CountingBp::setFindIdenticalFactorsFlag (false);
|
||||
switch (Globals::groundSolver) {
|
||||
case GroundSolverType::VE: solver = new VarElim (*mfg); break;
|
||||
case GroundSolverType::BP: solver = new BeliefProp (*mfg); break;
|
||||
case GroundSolverType::CBP: solver = new CountingBp (*mfg); break;
|
||||
case GroundSolverType::veSolver: solver = new VarElim (*mfg); break;
|
||||
case GroundSolverType::bpSolver: solver = new BeliefProp (*mfg); break;
|
||||
case GroundSolverType::CbpSolver: solver = new CountingBp (*mfg); break;
|
||||
}
|
||||
|
||||
if (Globals::verbosity > 0) {
|
||||
solver->printSolverFlags();
|
||||
cout << endl;
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
vector<Params> results;
|
||||
std::vector<Params> results;
|
||||
results.reserve (tasks.size());
|
||||
for (size_t i = 0; i < tasks.size(); i++) {
|
||||
results.push_back (solver->solveQuery (tasks[i]));
|
||||
@ -234,19 +239,19 @@ runGroundSolver (void)
|
||||
delete mfg;
|
||||
}
|
||||
|
||||
return YAP_Unify (fillAnswersPrologList (results), YAP_ARG3);
|
||||
return YAP_Unify (fillSolutionList (results), YAP_ARG3);
|
||||
}
|
||||
|
||||
|
||||
|
||||
int
|
||||
setParfactorsParams (void)
|
||||
setParfactorsParams()
|
||||
{
|
||||
LiftedNetwork* network = (LiftedNetwork*) YAP_IntOfTerm (YAP_ARG1);
|
||||
ParfactorList* pfList = network->first;
|
||||
YAP_Term distIdsList = YAP_ARG2;
|
||||
YAP_Term paramsList = YAP_ARG3;
|
||||
unordered_map<unsigned, Params> paramsMap;
|
||||
std::unordered_map<unsigned, Params> paramsMap;
|
||||
while (distIdsList != YAP_TermNil()) {
|
||||
unsigned distId = (unsigned) YAP_IntOfTerm (
|
||||
YAP_HeadOfTerm (distIdsList));
|
||||
@ -267,12 +272,12 @@ setParfactorsParams (void)
|
||||
|
||||
|
||||
int
|
||||
setFactorsParams (void)
|
||||
setFactorsParams()
|
||||
{
|
||||
FactorGraph* fg = (FactorGraph*) YAP_IntOfTerm (YAP_ARG1);
|
||||
YAP_Term distIdsList = YAP_ARG2;
|
||||
YAP_Term paramsList = YAP_ARG3;
|
||||
unordered_map<unsigned, Params> paramsMap;
|
||||
std::unordered_map<unsigned, Params> paramsMap;
|
||||
while (distIdsList != YAP_TermNil()) {
|
||||
unsigned distId = (unsigned) YAP_IntOfTerm (
|
||||
YAP_HeadOfTerm (distIdsList));
|
||||
@ -293,10 +298,10 @@ setFactorsParams (void)
|
||||
|
||||
|
||||
int
|
||||
setVarsInformation (void)
|
||||
setVarsInformation()
|
||||
{
|
||||
Var::clearVarsInfo();
|
||||
vector<string> labels;
|
||||
std::vector<std::string> labels;
|
||||
YAP_Term labelsL = YAP_ARG1;
|
||||
while (labelsL != YAP_TermNil()) {
|
||||
YAP_Atom atom = YAP_AtomOfTerm (YAP_HeadOfTerm (labelsL));
|
||||
@ -323,20 +328,20 @@ setVarsInformation (void)
|
||||
|
||||
|
||||
int
|
||||
setHorusFlag (void)
|
||||
setHorusFlag()
|
||||
{
|
||||
string option ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG1)));
|
||||
string value;
|
||||
std::string option ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG1)));
|
||||
std::string value;
|
||||
if (option == "verbosity") {
|
||||
stringstream ss;
|
||||
std::stringstream ss;
|
||||
ss << (int) YAP_IntOfTerm (YAP_ARG2);
|
||||
ss >> value;
|
||||
} else if (option == "bp_accuracy") {
|
||||
stringstream ss;
|
||||
std::stringstream ss;
|
||||
ss << (float) YAP_FloatOfTerm (YAP_ARG2);
|
||||
ss >> value;
|
||||
} else if (option == "bp_max_iter") {
|
||||
stringstream ss;
|
||||
std::stringstream ss;
|
||||
ss << (int) YAP_IntOfTerm (YAP_ARG2);
|
||||
ss >> value;
|
||||
} else {
|
||||
@ -348,7 +353,7 @@ setHorusFlag (void)
|
||||
|
||||
|
||||
int
|
||||
freeGroundNetwork (void)
|
||||
freeGroundNetwork()
|
||||
{
|
||||
delete (FactorGraph*) YAP_IntOfTerm (YAP_ARG1);
|
||||
return TRUE;
|
||||
@ -357,7 +362,7 @@ freeGroundNetwork (void)
|
||||
|
||||
|
||||
int
|
||||
freeLiftedNetwork (void)
|
||||
freeLiftedNetwork()
|
||||
{
|
||||
LiftedNetwork* network = (LiftedNetwork*) YAP_IntOfTerm (YAP_ARG1);
|
||||
delete network->first;
|
||||
@ -368,6 +373,8 @@ freeLiftedNetwork (void)
|
||||
|
||||
|
||||
|
||||
namespace {
|
||||
|
||||
Parfactor*
|
||||
readParfactor (YAP_Term pfTerm)
|
||||
{
|
||||
@ -386,23 +393,24 @@ readParfactor (YAP_Term pfTerm)
|
||||
// read parametric random vars
|
||||
ProbFormulas formulas;
|
||||
unsigned count = 0;
|
||||
unordered_map<YAP_Term, LogVar> lvMap;
|
||||
std::unordered_map<YAP_Term, LogVar> lvMap;
|
||||
YAP_Term pvList = YAP_ArgOfTerm (2, pfTerm);
|
||||
while (pvList != YAP_TermNil()) {
|
||||
YAP_Term formulaTerm = YAP_HeadOfTerm (pvList);
|
||||
if (YAP_IsAtomTerm (formulaTerm)) {
|
||||
string name ((char*) YAP_AtomName (YAP_AtomOfTerm (formulaTerm)));
|
||||
std::string name ((char*) YAP_AtomName (YAP_AtomOfTerm (formulaTerm)));
|
||||
Symbol functor = LiftedUtils::getSymbol (name);
|
||||
formulas.push_back (ProbFormula (functor, ranges[count]));
|
||||
} else {
|
||||
LogVars logVars;
|
||||
YAP_Functor yapFunctor = YAP_FunctorOfTerm (formulaTerm);
|
||||
string name ((char*) YAP_AtomName (YAP_NameOfFunctor (yapFunctor)));
|
||||
std::string name ((char*) YAP_AtomName (
|
||||
YAP_NameOfFunctor (yapFunctor)));
|
||||
Symbol functor = LiftedUtils::getSymbol (name);
|
||||
unsigned arity = (unsigned) YAP_ArityOfFunctor (yapFunctor);
|
||||
for (unsigned i = 1; i <= arity; i++) {
|
||||
YAP_Term ti = YAP_ArgOfTerm (i, formulaTerm);
|
||||
unordered_map<YAP_Term, LogVar>::iterator it = lvMap.find (ti);
|
||||
std::unordered_map<YAP_Term, LogVar>::iterator it = lvMap.find (ti);
|
||||
if (it != lvMap.end()) {
|
||||
logVars.push_back (it->second);
|
||||
} else {
|
||||
@ -418,7 +426,7 @@ readParfactor (YAP_Term pfTerm)
|
||||
}
|
||||
|
||||
// read the parameters
|
||||
const Params& params = readParameters (YAP_ArgOfTerm (4, pfTerm));
|
||||
Params params = readParameters (YAP_ArgOfTerm (4, pfTerm));
|
||||
|
||||
// read the constraint
|
||||
Tuples tuples;
|
||||
@ -434,10 +442,11 @@ readParfactor (YAP_Term pfTerm)
|
||||
for (unsigned i = 1; i <= arity; i++) {
|
||||
YAP_Term ti = YAP_ArgOfTerm (i, term);
|
||||
if (YAP_IsAtomTerm (ti) == false) {
|
||||
cerr << "Error: the constraint contains free variables." << endl;
|
||||
std::cerr << "Error: the constraint contains free variables." ;
|
||||
std::cerr << std::endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
string name ((char*) YAP_AtomName (YAP_AtomOfTerm (ti)));
|
||||
std::string name ((char*) YAP_AtomName (YAP_AtomOfTerm (ti)));
|
||||
tuple[i - 1] = LiftedUtils::getSymbol (name);
|
||||
}
|
||||
tuples.push_back (tuple);
|
||||
@ -449,55 +458,56 @@ readParfactor (YAP_Term pfTerm)
|
||||
|
||||
|
||||
|
||||
void
|
||||
readLiftedEvidence (
|
||||
YAP_Term observedList,
|
||||
ObservedFormulas& obsFormulas)
|
||||
ObservedFormulas*
|
||||
readLiftedEvidence (YAP_Term observedList)
|
||||
{
|
||||
ObservedFormulas* obsFormulas = new ObservedFormulas();
|
||||
while (observedList != YAP_TermNil()) {
|
||||
YAP_Term pair = YAP_HeadOfTerm (observedList);
|
||||
YAP_Term ground = YAP_ArgOfTerm (1, pair);
|
||||
Symbol functor;
|
||||
Symbols args;
|
||||
if (YAP_IsAtomTerm (ground)) {
|
||||
string name ((char*) YAP_AtomName (YAP_AtomOfTerm (ground)));
|
||||
std::string name ((char*) YAP_AtomName (YAP_AtomOfTerm (ground)));
|
||||
functor = LiftedUtils::getSymbol (name);
|
||||
} else {
|
||||
assert (YAP_IsApplTerm (ground));
|
||||
YAP_Functor yapFunctor = YAP_FunctorOfTerm (ground);
|
||||
string name ((char*) (YAP_AtomName (YAP_NameOfFunctor (yapFunctor))));
|
||||
std::string name ((char*) (YAP_AtomName (
|
||||
YAP_NameOfFunctor (yapFunctor))));
|
||||
functor = LiftedUtils::getSymbol (name);
|
||||
unsigned arity = (unsigned) YAP_ArityOfFunctor (yapFunctor);
|
||||
for (unsigned i = 1; i <= arity; i++) {
|
||||
YAP_Term ti = YAP_ArgOfTerm (i, ground);
|
||||
assert (YAP_IsAtomTerm (ti));
|
||||
string arg ((char *) YAP_AtomName (YAP_AtomOfTerm (ti)));
|
||||
std::string arg ((char *) YAP_AtomName (YAP_AtomOfTerm (ti)));
|
||||
args.push_back (LiftedUtils::getSymbol (arg));
|
||||
}
|
||||
}
|
||||
unsigned evidence = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (2, pair));
|
||||
bool found = false;
|
||||
for (size_t i = 0; i < obsFormulas.size(); i++) {
|
||||
if (obsFormulas[i].functor() == functor &&
|
||||
obsFormulas[i].arity() == args.size() &&
|
||||
obsFormulas[i].evidence() == evidence) {
|
||||
obsFormulas[i].addTuple (args);
|
||||
for (size_t i = 0; i < obsFormulas->size(); i++) {
|
||||
if ((*obsFormulas)[i].functor() == functor &&
|
||||
(*obsFormulas)[i].arity() == args.size() &&
|
||||
(*obsFormulas)[i].evidence() == evidence) {
|
||||
(*obsFormulas)[i].addTuple (args);
|
||||
found = true;
|
||||
}
|
||||
}
|
||||
if (found == false) {
|
||||
obsFormulas.push_back (ObservedFormula (functor, evidence, args));
|
||||
obsFormulas->push_back (ObservedFormula (functor, evidence, args));
|
||||
}
|
||||
observedList = YAP_TailOfTerm (observedList);
|
||||
}
|
||||
return obsFormulas;
|
||||
}
|
||||
|
||||
|
||||
|
||||
vector<unsigned>
|
||||
std::vector<unsigned>
|
||||
readUnsignedList (YAP_Term list)
|
||||
{
|
||||
vector<unsigned> vec;
|
||||
std::vector<unsigned> vec;
|
||||
while (list != YAP_TermNil()) {
|
||||
vec.push_back ((unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (list)));
|
||||
list = YAP_TailOfTerm (list);
|
||||
@ -514,10 +524,11 @@ readParameters (YAP_Term paramL)
|
||||
assert (YAP_IsPairTerm (paramL));
|
||||
while (paramL != YAP_TermNil()) {
|
||||
YAP_Term hd = YAP_HeadOfTerm (paramL);
|
||||
if (YAP_IsFloatTerm(hd))
|
||||
if (YAP_IsFloatTerm (hd)) {
|
||||
params.push_back ((double) YAP_FloatOfTerm (hd));
|
||||
else
|
||||
} else {
|
||||
params.push_back ((double) YAP_IntOfTerm (hd));
|
||||
}
|
||||
paramL = YAP_TailOfTerm (paramL);
|
||||
}
|
||||
if (Globals::logDomain) {
|
||||
@ -529,17 +540,17 @@ readParameters (YAP_Term paramL)
|
||||
|
||||
|
||||
YAP_Term
|
||||
fillAnswersPrologList (vector<Params>& results)
|
||||
fillSolutionList (const std::vector<Params>& results)
|
||||
{
|
||||
YAP_Term list = YAP_TermNil();
|
||||
for (size_t i = results.size(); i-- > 0; ) {
|
||||
const Params& beliefs = results[i];
|
||||
YAP_Term queryBeliefsL = YAP_TermNil();
|
||||
for (size_t j = beliefs.size(); j-- > 0; ) {
|
||||
YAP_Int sl1 = YAP_InitSlot (list);
|
||||
YAP_Int sl = YAP_InitSlot (list);
|
||||
YAP_Term belief = YAP_MkFloatTerm (beliefs[j]);
|
||||
queryBeliefsL = YAP_MkPairTerm (belief, queryBeliefsL);
|
||||
list = YAP_GetFromSlot (sl1);
|
||||
list = YAP_GetFromSlot (sl);
|
||||
YAP_RecoverSlots (1);
|
||||
}
|
||||
list = YAP_MkPairTerm (queryBeliefsL, list);
|
||||
@ -547,10 +558,12 @@ fillAnswersPrologList (vector<Params>& results)
|
||||
return list;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
extern "C" void
|
||||
init_predicates (void)
|
||||
init_predicates()
|
||||
{
|
||||
YAP_UserCPredicate ("cpp_create_lifted_network",
|
||||
createLiftedNetwork, 3);
|
||||
@ -583,3 +596,5 @@ init_predicates (void)
|
||||
freeGroundNetwork, 1);
|
||||
}
|
||||
|
||||
} // namespace Horus
|
||||
|
||||
|
32
packages/CLPBN/horus/Indexer.cpp
Normal file
32
packages/CLPBN/horus/Indexer.cpp
Normal file
@ -0,0 +1,32 @@
|
||||
#include <sstream>
|
||||
#include <iomanip>
|
||||
|
||||
#include "Indexer.h"
|
||||
|
||||
|
||||
namespace Horus {
|
||||
|
||||
std::ostream&
|
||||
operator<< (std::ostream& os, const Indexer& indexer)
|
||||
{
|
||||
os << "(" ;
|
||||
os << std::setw (2) << std::setfill('0') << indexer.index_;
|
||||
os << ") " ;
|
||||
os << indexer.indices_;
|
||||
return os;
|
||||
}
|
||||
|
||||
|
||||
|
||||
std::ostream&
|
||||
operator<< (std::ostream &os, const MapIndexer& indexer)
|
||||
{
|
||||
os << "(" ;
|
||||
os << std::setw (2) << std::setfill('0') << indexer.index_;
|
||||
os << ") " ;
|
||||
os << indexer.indices_;
|
||||
return os;
|
||||
}
|
||||
|
||||
} // namespace Horus
|
||||
|
@ -1,262 +1,343 @@
|
||||
#ifndef HORUS_INDEXER_H
|
||||
#define HORUS_INDEXER_H
|
||||
#ifndef YAP_PACKAGES_CLPBN_HORUS_INDEXER_H_
|
||||
#define YAP_PACKAGES_CLPBN_HORUS_INDEXER_H_
|
||||
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
|
||||
#include <sstream>
|
||||
#include <iomanip>
|
||||
|
||||
#include "Util.h"
|
||||
|
||||
|
||||
class Indexer
|
||||
{
|
||||
namespace Horus {
|
||||
|
||||
class Indexer {
|
||||
public:
|
||||
Indexer (const Ranges& ranges, bool calcOffsets = true)
|
||||
: index_(0), indices_(ranges.size(), 0), ranges_(ranges),
|
||||
size_(Util::sizeExpected (ranges))
|
||||
{
|
||||
if (calcOffsets) {
|
||||
calculateOffsets();
|
||||
}
|
||||
}
|
||||
Indexer (const Ranges& ranges, bool calcOffsets = true);
|
||||
|
||||
void increment (void)
|
||||
{
|
||||
for (size_t i = ranges_.size(); i-- > 0; ) {
|
||||
indices_[i] ++;
|
||||
if (indices_[i] != ranges_[i]) {
|
||||
break;
|
||||
} else {
|
||||
indices_[i] = 0;
|
||||
}
|
||||
}
|
||||
index_ ++;
|
||||
}
|
||||
void increment();
|
||||
|
||||
void incrementDimension (size_t dim)
|
||||
{
|
||||
assert (dim < ranges_.size());
|
||||
assert (ranges_.size() == offsets_.size());
|
||||
assert (indices_[dim] < ranges_[dim]);
|
||||
indices_[dim] ++;
|
||||
index_ += offsets_[dim];
|
||||
}
|
||||
void incrementDimension (size_t dim);
|
||||
|
||||
void incrementExceptDimension (size_t dim)
|
||||
{
|
||||
assert (ranges_.size() == offsets_.size());
|
||||
for (size_t i = ranges_.size(); i-- > 0; ) {
|
||||
if (i != dim) {
|
||||
indices_[i] ++;
|
||||
index_ += offsets_[i];
|
||||
if (indices_[i] != ranges_[i]) {
|
||||
return;
|
||||
} else {
|
||||
indices_[i] = 0;
|
||||
index_ -= offsets_[i] * ranges_[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
index_ = size_;
|
||||
}
|
||||
void incrementExceptDimension (size_t dim);
|
||||
|
||||
Indexer& operator++ (void)
|
||||
{
|
||||
increment();
|
||||
return *this;
|
||||
}
|
||||
Indexer& operator++();
|
||||
|
||||
operator size_t (void) const
|
||||
{
|
||||
return index_;
|
||||
}
|
||||
operator size_t() const;
|
||||
|
||||
unsigned operator[] (size_t dim) const
|
||||
{
|
||||
assert (valid());
|
||||
assert (dim < ranges_.size());
|
||||
return indices_[dim];
|
||||
}
|
||||
unsigned operator[] (size_t dim) const;
|
||||
|
||||
bool valid (void) const
|
||||
{
|
||||
return index_ < size_;
|
||||
}
|
||||
bool valid() const;
|
||||
|
||||
void reset (void)
|
||||
{
|
||||
std::fill (indices_.begin(), indices_.end(), 0);
|
||||
index_ = 0;
|
||||
}
|
||||
void reset();
|
||||
|
||||
void resetDimension (size_t dim)
|
||||
{
|
||||
indices_[dim] = 0;
|
||||
index_ -= offsets_[dim] * ranges_[dim];
|
||||
}
|
||||
void resetDimension (size_t dim);
|
||||
|
||||
size_t size (void) const
|
||||
{
|
||||
return size_ ;
|
||||
}
|
||||
size_t size() const;
|
||||
|
||||
private:
|
||||
void calculateOffsets();
|
||||
|
||||
friend std::ostream& operator<< (std::ostream&, const Indexer&);
|
||||
|
||||
private:
|
||||
void calculateOffsets (void)
|
||||
{
|
||||
size_t prod = 1;
|
||||
offsets_.resize (ranges_.size());
|
||||
for (size_t i = ranges_.size(); i-- > 0; ) {
|
||||
offsets_[i] = prod;
|
||||
prod *= ranges_[i];
|
||||
}
|
||||
}
|
||||
|
||||
size_t index_;
|
||||
Ranges indices_;
|
||||
const Ranges& ranges_;
|
||||
size_t size_;
|
||||
vector<size_t> offsets_;
|
||||
size_t index_;
|
||||
Ranges indices_;
|
||||
const Ranges& ranges_;
|
||||
size_t size_;
|
||||
std::vector<size_t> offsets_;
|
||||
|
||||
DISALLOW_COPY_AND_ASSIGN (Indexer);
|
||||
};
|
||||
|
||||
|
||||
|
||||
inline std::ostream&
|
||||
operator<< (std::ostream& os, const Indexer& indexer)
|
||||
inline
|
||||
Indexer::Indexer (const Ranges& ranges, bool calcOffsets)
|
||||
: index_(0), indices_(ranges.size(), 0), ranges_(ranges),
|
||||
size_(Util::sizeExpected (ranges))
|
||||
{
|
||||
os << "(" ;
|
||||
os << std::setw (2) << std::setfill('0') << indexer.index_;
|
||||
os << ") " ;
|
||||
os << indexer.indices_;
|
||||
return os;
|
||||
if (calcOffsets) {
|
||||
calculateOffsets();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
class MapIndexer
|
||||
inline void
|
||||
Indexer::increment()
|
||||
{
|
||||
for (size_t i = ranges_.size(); i-- > 0; ) {
|
||||
indices_[i] ++;
|
||||
if (indices_[i] != ranges_[i]) {
|
||||
break;
|
||||
} else {
|
||||
indices_[i] = 0;
|
||||
}
|
||||
}
|
||||
index_ ++;
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline void
|
||||
Indexer::incrementDimension (size_t dim)
|
||||
{
|
||||
assert (dim < ranges_.size());
|
||||
assert (ranges_.size() == offsets_.size());
|
||||
assert (indices_[dim] < ranges_[dim]);
|
||||
indices_[dim] ++;
|
||||
index_ += offsets_[dim];
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline void
|
||||
Indexer::incrementExceptDimension (size_t dim)
|
||||
{
|
||||
assert (ranges_.size() == offsets_.size());
|
||||
for (size_t i = ranges_.size(); i-- > 0; ) {
|
||||
if (i != dim) {
|
||||
indices_[i] ++;
|
||||
index_ += offsets_[i];
|
||||
if (indices_[i] != ranges_[i]) {
|
||||
return;
|
||||
} else {
|
||||
indices_[i] = 0;
|
||||
index_ -= offsets_[i] * ranges_[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
index_ = size_;
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline Indexer&
|
||||
Indexer::operator++()
|
||||
{
|
||||
increment();
|
||||
return *this;
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline
|
||||
Indexer::operator size_t() const
|
||||
{
|
||||
return index_;
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline unsigned
|
||||
Indexer::operator[] (size_t dim) const
|
||||
{
|
||||
assert (valid());
|
||||
assert (dim < ranges_.size());
|
||||
return indices_[dim];
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline bool
|
||||
Indexer::valid() const
|
||||
{
|
||||
return index_ < size_;
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline void
|
||||
Indexer::reset()
|
||||
{
|
||||
index_ = 0;
|
||||
std::fill (indices_.begin(), indices_.end(), 0);
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline void
|
||||
Indexer::resetDimension (size_t dim)
|
||||
{
|
||||
indices_[dim] = 0;
|
||||
index_ -= offsets_[dim] * ranges_[dim];
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline size_t
|
||||
Indexer::size() const
|
||||
{
|
||||
return size_ ;
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline void
|
||||
Indexer::calculateOffsets()
|
||||
{
|
||||
size_t prod = 1;
|
||||
offsets_.resize (ranges_.size());
|
||||
for (size_t i = ranges_.size(); i-- > 0; ) {
|
||||
offsets_[i] = prod;
|
||||
prod *= ranges_[i];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
class MapIndexer {
|
||||
public:
|
||||
MapIndexer (const Ranges& ranges, const vector<bool>& mask)
|
||||
: index_(0), indices_(ranges.size(), 0), ranges_(ranges),
|
||||
valid_(true)
|
||||
{
|
||||
size_t prod = 1;
|
||||
offsets_.resize (ranges.size(), 0);
|
||||
for (size_t i = ranges.size(); i-- > 0; ) {
|
||||
if (mask[i]) {
|
||||
offsets_[i] = prod;
|
||||
prod *= ranges[i];
|
||||
}
|
||||
}
|
||||
assert (ranges.size() == mask.size());
|
||||
}
|
||||
MapIndexer (const Ranges& ranges, const std::vector<bool>& mask);
|
||||
|
||||
MapIndexer (const Ranges& ranges, size_t dim)
|
||||
: index_(0), indices_(ranges.size(), 0), ranges_(ranges),
|
||||
valid_(true)
|
||||
{
|
||||
size_t prod = 1;
|
||||
offsets_.resize (ranges.size(), 0);
|
||||
for (size_t i = ranges.size(); i-- > 0; ) {
|
||||
if (i != dim) {
|
||||
offsets_[i] = prod;
|
||||
prod *= ranges[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
MapIndexer (const Ranges& ranges, size_t dim);
|
||||
|
||||
template <typename T>
|
||||
MapIndexer (
|
||||
const vector<T>& allArgs,
|
||||
const Ranges& allRanges,
|
||||
const vector<T>& wantedArgs,
|
||||
const Ranges& wantedRanges)
|
||||
: index_(0), indices_(allArgs.size(), 0), ranges_(allRanges),
|
||||
valid_(true)
|
||||
{
|
||||
size_t prod = 1;
|
||||
vector<size_t> offsets (wantedRanges.size());
|
||||
for (size_t i = wantedRanges.size(); i-- > 0; ) {
|
||||
offsets[i] = prod;
|
||||
prod *= wantedRanges[i];
|
||||
}
|
||||
offsets_.reserve (allArgs.size());
|
||||
for (size_t i = 0; i < allArgs.size(); i++) {
|
||||
size_t idx = Util::indexOf (wantedArgs, allArgs[i]);
|
||||
offsets_.push_back (idx != wantedArgs.size() ? offsets[idx] : 0);
|
||||
}
|
||||
}
|
||||
template <typename T>
|
||||
MapIndexer (
|
||||
const std::vector<T>& allArgs,
|
||||
const Ranges& allRanges,
|
||||
const std::vector<T>& wantedArgs,
|
||||
const Ranges& wantedRanges);
|
||||
|
||||
MapIndexer& operator++ (void)
|
||||
{
|
||||
assert (valid_);
|
||||
for (size_t i = ranges_.size(); i-- > 0; ) {
|
||||
indices_[i] ++;
|
||||
index_ += offsets_[i];
|
||||
if (indices_[i] != ranges_[i]) {
|
||||
return *this;
|
||||
} else {
|
||||
indices_[i] = 0;
|
||||
index_ -= offsets_[i] * ranges_[i];
|
||||
}
|
||||
}
|
||||
valid_ = false;
|
||||
return *this;
|
||||
}
|
||||
MapIndexer& operator++();
|
||||
|
||||
operator size_t (void) const
|
||||
{
|
||||
assert (valid());
|
||||
return index_;
|
||||
}
|
||||
operator size_t() const;
|
||||
|
||||
unsigned operator[] (size_t dim) const
|
||||
{
|
||||
assert (valid());
|
||||
assert (dim < ranges_.size());
|
||||
return indices_[dim];
|
||||
}
|
||||
unsigned operator[] (size_t dim) const;
|
||||
|
||||
bool valid (void) const
|
||||
{
|
||||
return valid_;
|
||||
}
|
||||
bool valid() const;
|
||||
|
||||
void reset (void)
|
||||
{
|
||||
std::fill (indices_.begin(), indices_.end(), 0);
|
||||
index_ = 0;
|
||||
}
|
||||
|
||||
friend std::ostream& operator<< (std::ostream&, const MapIndexer&);
|
||||
void reset();
|
||||
|
||||
private:
|
||||
size_t index_;
|
||||
Ranges indices_;
|
||||
const Ranges& ranges_;
|
||||
bool valid_;
|
||||
vector<size_t> offsets_;
|
||||
friend std::ostream& operator<< (std::ostream&, const MapIndexer&);
|
||||
|
||||
size_t index_;
|
||||
Ranges indices_;
|
||||
const Ranges& ranges_;
|
||||
bool valid_;
|
||||
std::vector<size_t> offsets_;
|
||||
|
||||
DISALLOW_COPY_AND_ASSIGN (MapIndexer);
|
||||
};
|
||||
|
||||
|
||||
|
||||
inline std::ostream&
|
||||
operator<< (std::ostream &os, const MapIndexer& indexer)
|
||||
inline
|
||||
MapIndexer::MapIndexer (
|
||||
const Ranges& ranges,
|
||||
const std::vector<bool>& mask)
|
||||
: index_(0), indices_(ranges.size(), 0), ranges_(ranges),
|
||||
valid_(true)
|
||||
{
|
||||
os << "(" ;
|
||||
os << std::setw (2) << std::setfill('0') << indexer.index_;
|
||||
os << ") " ;
|
||||
os << indexer.indices_;
|
||||
return os;
|
||||
size_t prod = 1;
|
||||
offsets_.resize (ranges.size(), 0);
|
||||
for (size_t i = ranges.size(); i-- > 0; ) {
|
||||
if (mask[i]) {
|
||||
offsets_[i] = prod;
|
||||
prod *= ranges[i];
|
||||
}
|
||||
}
|
||||
assert (ranges.size() == mask.size());
|
||||
}
|
||||
|
||||
|
||||
#endif // HORUS_INDEXER_H
|
||||
|
||||
inline
|
||||
MapIndexer::MapIndexer (const Ranges& ranges, size_t dim)
|
||||
: index_(0), indices_(ranges.size(), 0), ranges_(ranges),
|
||||
valid_(true)
|
||||
{
|
||||
size_t prod = 1;
|
||||
offsets_.resize (ranges.size(), 0);
|
||||
for (size_t i = ranges.size(); i-- > 0; ) {
|
||||
if (i != dim) {
|
||||
offsets_[i] = prod;
|
||||
prod *= ranges[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T> inline
|
||||
MapIndexer::MapIndexer (
|
||||
const std::vector<T>& allArgs,
|
||||
const Ranges& allRanges,
|
||||
const std::vector<T>& wantedArgs,
|
||||
const Ranges& wantedRanges)
|
||||
: index_(0), indices_(allArgs.size(), 0), ranges_(allRanges),
|
||||
valid_(true)
|
||||
{
|
||||
size_t prod = 1;
|
||||
std::vector<size_t> offsets (wantedRanges.size());
|
||||
for (size_t i = wantedRanges.size(); i-- > 0; ) {
|
||||
offsets[i] = prod;
|
||||
prod *= wantedRanges[i];
|
||||
}
|
||||
offsets_.reserve (allArgs.size());
|
||||
for (size_t i = 0; i < allArgs.size(); i++) {
|
||||
size_t idx = Util::indexOf (wantedArgs, allArgs[i]);
|
||||
offsets_.push_back (idx != wantedArgs.size() ? offsets[idx] : 0);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline MapIndexer&
|
||||
MapIndexer::operator++()
|
||||
{
|
||||
assert (valid_);
|
||||
for (size_t i = ranges_.size(); i-- > 0; ) {
|
||||
indices_[i] ++;
|
||||
index_ += offsets_[i];
|
||||
if (indices_[i] != ranges_[i]) {
|
||||
return *this;
|
||||
} else {
|
||||
indices_[i] = 0;
|
||||
index_ -= offsets_[i] * ranges_[i];
|
||||
}
|
||||
}
|
||||
valid_ = false;
|
||||
return *this;
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline
|
||||
MapIndexer::operator size_t() const
|
||||
{
|
||||
assert (valid());
|
||||
return index_;
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline unsigned
|
||||
MapIndexer::operator[] (size_t dim) const
|
||||
{
|
||||
assert (valid());
|
||||
assert (dim < ranges_.size());
|
||||
return indices_[dim];
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline bool
|
||||
MapIndexer::valid() const
|
||||
{
|
||||
return valid_;
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline void
|
||||
MapIndexer::reset()
|
||||
{
|
||||
index_ = 0;
|
||||
std::fill (indices_.begin(), indices_.end(), 0);
|
||||
}
|
||||
|
||||
} // namespace Horus
|
||||
|
||||
#endif // YAP_PACKAGES_CLPBN_HORUS_INDEXER_H_
|
||||
|
||||
|
@ -1,9 +1,15 @@
|
||||
#include <cassert>
|
||||
|
||||
#include <sstream>
|
||||
|
||||
#include "LiftedBp.h"
|
||||
#include "LiftedOperations.h"
|
||||
#include "WeightedBp.h"
|
||||
#include "FactorGraph.h"
|
||||
|
||||
|
||||
namespace Horus {
|
||||
|
||||
LiftedBp::LiftedBp (const ParfactorList& parfactorList)
|
||||
: LiftedSolver (parfactorList)
|
||||
{
|
||||
@ -14,7 +20,7 @@ LiftedBp::LiftedBp (const ParfactorList& parfactorList)
|
||||
|
||||
|
||||
|
||||
LiftedBp::~LiftedBp (void)
|
||||
LiftedBp::~LiftedBp()
|
||||
{
|
||||
delete solver_;
|
||||
delete fg_;
|
||||
@ -27,7 +33,7 @@ LiftedBp::solveQuery (const Grounds& query)
|
||||
{
|
||||
assert (query.empty() == false);
|
||||
Params res;
|
||||
vector<PrvGroup> groups = getQueryGroups (query);
|
||||
std::vector<PrvGroup> groups = getQueryGroups (query);
|
||||
if (query.size() == 1) {
|
||||
res = solver_->getPosterioriOf (groups[0]);
|
||||
} else {
|
||||
@ -58,28 +64,29 @@ LiftedBp::solveQuery (const Grounds& query)
|
||||
|
||||
|
||||
void
|
||||
LiftedBp::printSolverFlags (void) const
|
||||
LiftedBp::printSolverFlags() const
|
||||
{
|
||||
stringstream ss;
|
||||
std::stringstream ss;
|
||||
ss << "lifted bp [" ;
|
||||
ss << "bp_msg_schedule=" ;
|
||||
typedef WeightedBp::MsgSchedule MsgSchedule;
|
||||
switch (WeightedBp::msgSchedule()) {
|
||||
case MsgSchedule::SEQ_FIXED: ss << "seq_fixed"; break;
|
||||
case MsgSchedule::SEQ_RANDOM: ss << "seq_random"; break;
|
||||
case MsgSchedule::PARALLEL: ss << "parallel"; break;
|
||||
case MsgSchedule::MAX_RESIDUAL: ss << "max_residual"; break;
|
||||
case MsgSchedule::seqFixedSch: ss << "seq_fixed"; break;
|
||||
case MsgSchedule::seqRandomSch: ss << "seq_random"; break;
|
||||
case MsgSchedule::parallelSch: ss << "parallel"; break;
|
||||
case MsgSchedule::maxResidualSch: ss << "max_residual"; break;
|
||||
}
|
||||
ss << ",bp_max_iter=" << WeightedBp::maxIterations();
|
||||
ss << ",bp_accuracy=" << WeightedBp::accuracy();
|
||||
ss << ",log_domain=" << Util::toString (Globals::logDomain);
|
||||
ss << "]" ;
|
||||
cout << ss.str() << endl;
|
||||
std::cout << ss.str() << std::endl;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
LiftedBp::refineParfactors (void)
|
||||
LiftedBp::refineParfactors()
|
||||
{
|
||||
pfList_ = parfactorList;
|
||||
while (iterate() == false);
|
||||
@ -93,7 +100,7 @@ LiftedBp::refineParfactors (void)
|
||||
|
||||
|
||||
bool
|
||||
LiftedBp::iterate (void)
|
||||
LiftedBp::iterate()
|
||||
{
|
||||
ParfactorList::iterator it = pfList_.begin();
|
||||
while (it != pfList_.end()) {
|
||||
@ -114,10 +121,10 @@ LiftedBp::iterate (void)
|
||||
|
||||
|
||||
|
||||
vector<PrvGroup>
|
||||
std::vector<PrvGroup>
|
||||
LiftedBp::getQueryGroups (const Grounds& query)
|
||||
{
|
||||
vector<PrvGroup> queryGroups;
|
||||
std::vector<PrvGroup> queryGroups;
|
||||
for (unsigned i = 0; i < query.size(); i++) {
|
||||
ParfactorList::const_iterator it = pfList_.begin();
|
||||
for (; it != pfList_.end(); ++it) {
|
||||
@ -134,12 +141,12 @@ LiftedBp::getQueryGroups (const Grounds& query)
|
||||
|
||||
|
||||
void
|
||||
LiftedBp::createFactorGraph (void)
|
||||
LiftedBp::createFactorGraph()
|
||||
{
|
||||
fg_ = new FactorGraph();
|
||||
ParfactorList::const_iterator it = pfList_.begin();
|
||||
for (; it != pfList_.end(); ++it) {
|
||||
vector<PrvGroup> groups = (*it)->getAllGroups();
|
||||
std::vector<PrvGroup> groups = (*it)->getAllGroups();
|
||||
VarIds varIds;
|
||||
for (size_t i = 0; i < groups.size(); i++) {
|
||||
varIds.push_back (groups[i]);
|
||||
@ -150,10 +157,10 @@ LiftedBp::createFactorGraph (void)
|
||||
|
||||
|
||||
|
||||
vector<vector<unsigned>>
|
||||
LiftedBp::getWeights (void) const
|
||||
std::vector<std::vector<unsigned>>
|
||||
LiftedBp::getWeights() const
|
||||
{
|
||||
vector<vector<unsigned>> weights;
|
||||
std::vector<std::vector<unsigned>> weights;
|
||||
weights.reserve (pfList_.size());
|
||||
ParfactorList::const_iterator it = pfList_.begin();
|
||||
for (; it != pfList_.end(); ++it) {
|
||||
@ -196,7 +203,7 @@ LiftedBp::getJointByConditioning (
|
||||
Grounds obsGrounds = {query[0]};
|
||||
for (size_t i = 1; i < query.size(); i++) {
|
||||
Params newBeliefs;
|
||||
vector<ObservedFormula> obsFs;
|
||||
std::vector<ObservedFormula> obsFs;
|
||||
Ranges obsRanges;
|
||||
for (size_t j = 0; j < obsGrounds.size(); j++) {
|
||||
obsFs.push_back (ObservedFormula (
|
||||
@ -231,3 +238,5 @@ LiftedBp::getJointByConditioning (
|
||||
return prevBeliefs;
|
||||
}
|
||||
|
||||
} // namespace Horus
|
||||
|
||||
|
@ -1,33 +1,38 @@
|
||||
#ifndef HORUS_LIFTEDBP_H
|
||||
#define HORUS_LIFTEDBP_H
|
||||
#ifndef YAP_PACKAGES_CLPBN_HORUS_LIFTEDBP_H_
|
||||
#define YAP_PACKAGES_CLPBN_HORUS_LIFTEDBP_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "LiftedSolver.h"
|
||||
#include "ParfactorList.h"
|
||||
#include "Indexer.h"
|
||||
|
||||
|
||||
namespace Horus {
|
||||
|
||||
class FactorGraph;
|
||||
class WeightedBp;
|
||||
|
||||
class LiftedBp : public LiftedSolver
|
||||
{
|
||||
class LiftedBp : public LiftedSolver{
|
||||
public:
|
||||
LiftedBp (const ParfactorList& pfList);
|
||||
LiftedBp (const ParfactorList& pfList);
|
||||
|
||||
~LiftedBp (void);
|
||||
~LiftedBp();
|
||||
|
||||
Params solveQuery (const Grounds&);
|
||||
Params solveQuery (const Grounds&);
|
||||
|
||||
void printSolverFlags (void) const;
|
||||
void printSolverFlags() const;
|
||||
|
||||
private:
|
||||
void refineParfactors (void);
|
||||
void refineParfactors();
|
||||
|
||||
bool iterate (void);
|
||||
bool iterate();
|
||||
|
||||
vector<PrvGroup> getQueryGroups (const Grounds&);
|
||||
std::vector<PrvGroup> getQueryGroups (const Grounds&);
|
||||
|
||||
void createFactorGraph (void);
|
||||
void createFactorGraph();
|
||||
|
||||
vector<vector<unsigned>> getWeights (void) const;
|
||||
std::vector<std::vector<unsigned>> getWeights() const;
|
||||
|
||||
unsigned rangeOfGround (const Ground&);
|
||||
|
||||
@ -38,8 +43,9 @@ class LiftedBp : public LiftedSolver
|
||||
FactorGraph* fg_;
|
||||
|
||||
DISALLOW_COPY_AND_ASSIGN (LiftedBp);
|
||||
|
||||
};
|
||||
|
||||
#endif // HORUS_LIFTEDBP_H
|
||||
} // namespace Horus
|
||||
|
||||
#endif // YAP_PACKAGES_CLPBN_HORUS_LIFTEDBP_H_
|
||||
|
||||
|
@ -1,11 +1,283 @@
|
||||
#include <cassert>
|
||||
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
#include <string>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
|
||||
#include "LiftedKc.h"
|
||||
#include "LiftedWCNF.h"
|
||||
#include "LiftedOperations.h"
|
||||
#include "Indexer.h"
|
||||
|
||||
|
||||
OrNode::~OrNode (void)
|
||||
namespace Horus {
|
||||
|
||||
enum class CircuitNodeType {
|
||||
orCnt,
|
||||
andCnt,
|
||||
setOrCnt,
|
||||
setAndCnt,
|
||||
incExcCnt,
|
||||
leafCnt,
|
||||
smoothCnt,
|
||||
trueCnt,
|
||||
compilationFailedCnt
|
||||
};
|
||||
|
||||
|
||||
|
||||
class CircuitNode {
|
||||
public:
|
||||
CircuitNode() { }
|
||||
|
||||
virtual ~CircuitNode() { }
|
||||
|
||||
virtual double weight() const = 0;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class OrNode : public CircuitNode {
|
||||
public:
|
||||
OrNode() : CircuitNode(), leftBranch_(0), rightBranch_(0) { }
|
||||
|
||||
~OrNode();
|
||||
|
||||
CircuitNode** leftBranch () { return &leftBranch_; }
|
||||
CircuitNode** rightBranch() { return &rightBranch_; }
|
||||
|
||||
double weight() const;
|
||||
|
||||
private:
|
||||
CircuitNode* leftBranch_;
|
||||
CircuitNode* rightBranch_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class AndNode : public CircuitNode {
|
||||
public:
|
||||
AndNode() : CircuitNode(), leftBranch_(0), rightBranch_(0) { }
|
||||
|
||||
AndNode (CircuitNode* leftBranch, CircuitNode* rightBranch)
|
||||
: CircuitNode(), leftBranch_(leftBranch),
|
||||
rightBranch_(rightBranch) { }
|
||||
|
||||
~AndNode();
|
||||
|
||||
CircuitNode** leftBranch () { return &leftBranch_; }
|
||||
CircuitNode** rightBranch() { return &rightBranch_; }
|
||||
|
||||
double weight() const;
|
||||
|
||||
private:
|
||||
CircuitNode* leftBranch_;
|
||||
CircuitNode* rightBranch_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class SetOrNode : public CircuitNode {
|
||||
public:
|
||||
SetOrNode (unsigned nrGroundings)
|
||||
: CircuitNode(), follow_(0), nrGroundings_(nrGroundings) { }
|
||||
|
||||
~SetOrNode();
|
||||
|
||||
CircuitNode** follow() { return &follow_; }
|
||||
|
||||
static unsigned nrPositives() { return nrPos_; }
|
||||
|
||||
static unsigned nrNegatives() { return nrNeg_; }
|
||||
|
||||
static bool isSet() { return nrPos_ >= 0; }
|
||||
|
||||
double weight() const;
|
||||
|
||||
private:
|
||||
CircuitNode* follow_;
|
||||
unsigned nrGroundings_;
|
||||
static int nrPos_;
|
||||
static int nrNeg_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class SetAndNode : public CircuitNode {
|
||||
public:
|
||||
SetAndNode (unsigned nrGroundings)
|
||||
: CircuitNode(), follow_(0), nrGroundings_(nrGroundings) { }
|
||||
|
||||
~SetAndNode();
|
||||
|
||||
CircuitNode** follow() { return &follow_; }
|
||||
|
||||
double weight() const;
|
||||
|
||||
private:
|
||||
CircuitNode* follow_;
|
||||
unsigned nrGroundings_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class IncExcNode : public CircuitNode {
|
||||
public:
|
||||
IncExcNode()
|
||||
: CircuitNode(), plus1Branch_(0), plus2Branch_(0), minusBranch_(0) { }
|
||||
|
||||
~IncExcNode();
|
||||
|
||||
CircuitNode** plus1Branch() { return &plus1Branch_; }
|
||||
CircuitNode** plus2Branch() { return &plus2Branch_; }
|
||||
CircuitNode** minusBranch() { return &minusBranch_; }
|
||||
|
||||
double weight() const;
|
||||
|
||||
private:
|
||||
CircuitNode* plus1Branch_;
|
||||
CircuitNode* plus2Branch_;
|
||||
CircuitNode* minusBranch_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class LeafNode : public CircuitNode {
|
||||
public:
|
||||
LeafNode (Clause* clause, const LiftedWCNF& lwcnf)
|
||||
: CircuitNode(), clause_(clause), lwcnf_(lwcnf) { }
|
||||
|
||||
~LeafNode();
|
||||
|
||||
const Clause* clause() const { return clause_; }
|
||||
|
||||
Clause* clause() { return clause_; }
|
||||
|
||||
double weight() const;
|
||||
|
||||
private:
|
||||
Clause* clause_;
|
||||
const LiftedWCNF& lwcnf_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class SmoothNode : public CircuitNode {
|
||||
public:
|
||||
SmoothNode (const Clauses& clauses, const LiftedWCNF& lwcnf)
|
||||
: CircuitNode(), clauses_(clauses), lwcnf_(lwcnf) { }
|
||||
|
||||
~SmoothNode();
|
||||
|
||||
const Clauses& clauses() const { return clauses_; }
|
||||
|
||||
Clauses clauses() { return clauses_; }
|
||||
|
||||
double weight() const;
|
||||
|
||||
private:
|
||||
Clauses clauses_;
|
||||
const LiftedWCNF& lwcnf_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class TrueNode : public CircuitNode {
|
||||
public:
|
||||
TrueNode() : CircuitNode() { }
|
||||
|
||||
double weight() const;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class CompilationFailedNode : public CircuitNode {
|
||||
public:
|
||||
CompilationFailedNode() : CircuitNode() { }
|
||||
|
||||
double weight() const;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class LiftedCircuit {
|
||||
public:
|
||||
LiftedCircuit (const LiftedWCNF* lwcnf);
|
||||
|
||||
~LiftedCircuit();
|
||||
|
||||
bool isCompilationSucceeded() const;
|
||||
|
||||
double getWeightedModelCount() const;
|
||||
|
||||
void exportToGraphViz (const char*);
|
||||
|
||||
private:
|
||||
void compile (CircuitNode** follow, Clauses& clauses);
|
||||
|
||||
bool tryUnitPropagation (CircuitNode** follow, Clauses& clauses);
|
||||
|
||||
bool tryIndependence (CircuitNode** follow, Clauses& clauses);
|
||||
|
||||
bool tryShannonDecomp (CircuitNode** follow, Clauses& clauses);
|
||||
|
||||
bool tryInclusionExclusion (CircuitNode** follow, Clauses& clauses);
|
||||
|
||||
bool tryIndepPartialGrounding (CircuitNode** follow, Clauses& clauses);
|
||||
|
||||
bool tryIndepPartialGroundingAux (Clauses& clauses, ConstraintTree& ct,
|
||||
LogVars& rootLogVars);
|
||||
|
||||
bool tryAtomCounting (CircuitNode** follow, Clauses& clauses);
|
||||
|
||||
void shatterCountedLogVars (Clauses& clauses);
|
||||
|
||||
bool shatterCountedLogVarsAux (Clauses& clauses);
|
||||
|
||||
bool shatterCountedLogVarsAux (Clauses& clauses,
|
||||
size_t idx1, size_t idx2);
|
||||
|
||||
bool independentClause (Clause& clause, Clauses& otherClauses) const;
|
||||
|
||||
bool independentLiteral (const Literal& lit,
|
||||
const Literals& otherLits) const;
|
||||
|
||||
LitLvTypesSet smoothCircuit (CircuitNode* node);
|
||||
|
||||
void createSmoothNode (const LitLvTypesSet& lids,
|
||||
CircuitNode** prev);
|
||||
|
||||
std::vector<LogVarTypes> getAllPossibleTypes (unsigned nrLogVars) const;
|
||||
|
||||
bool containsTypes (const LogVarTypes& typesA,
|
||||
const LogVarTypes& typesB) const;
|
||||
|
||||
CircuitNodeType getCircuitNodeType (const CircuitNode* node) const;
|
||||
|
||||
void exportToGraphViz (CircuitNode* node, std::ofstream&);
|
||||
|
||||
void printClauses (CircuitNode* node, std::ofstream&,
|
||||
std::string extraOptions = "");
|
||||
|
||||
std::string escapeNode (const CircuitNode* node) const;
|
||||
|
||||
std::string getExplanationString (CircuitNode* node);
|
||||
|
||||
CircuitNode* root_;
|
||||
const LiftedWCNF* lwcnf_;
|
||||
bool compilationSucceeded_;
|
||||
Clauses backupClauses_;
|
||||
std::unordered_map<CircuitNode*, Clauses> originClausesMap_;
|
||||
std::unordered_map<CircuitNode*, std::string> explanationMap_;
|
||||
|
||||
DISALLOW_COPY_AND_ASSIGN (LiftedCircuit);
|
||||
};
|
||||
|
||||
|
||||
|
||||
OrNode::~OrNode()
|
||||
{
|
||||
delete leftBranch_;
|
||||
delete rightBranch_;
|
||||
@ -14,7 +286,7 @@ OrNode::~OrNode (void)
|
||||
|
||||
|
||||
double
|
||||
OrNode::weight (void) const
|
||||
OrNode::weight() const
|
||||
{
|
||||
double lw = leftBranch_->weight();
|
||||
double rw = rightBranch_->weight();
|
||||
@ -23,7 +295,7 @@ OrNode::weight (void) const
|
||||
|
||||
|
||||
|
||||
AndNode::~AndNode (void)
|
||||
AndNode::~AndNode()
|
||||
{
|
||||
delete leftBranch_;
|
||||
delete rightBranch_;
|
||||
@ -32,7 +304,7 @@ AndNode::~AndNode (void)
|
||||
|
||||
|
||||
double
|
||||
AndNode::weight (void) const
|
||||
AndNode::weight() const
|
||||
{
|
||||
double lw = leftBranch_->weight();
|
||||
double rw = rightBranch_->weight();
|
||||
@ -46,7 +318,7 @@ int SetOrNode::nrNeg_ = -1;
|
||||
|
||||
|
||||
|
||||
SetOrNode::~SetOrNode (void)
|
||||
SetOrNode::~SetOrNode()
|
||||
{
|
||||
delete follow_;
|
||||
}
|
||||
@ -54,7 +326,7 @@ SetOrNode::~SetOrNode (void)
|
||||
|
||||
|
||||
double
|
||||
SetOrNode::weight (void) const
|
||||
SetOrNode::weight() const
|
||||
{
|
||||
double weightSum = LogAware::addIdenty();
|
||||
for (unsigned i = 0; i < nrGroundings_ + 1; i++) {
|
||||
@ -76,7 +348,7 @@ SetOrNode::weight (void) const
|
||||
|
||||
|
||||
|
||||
SetAndNode::~SetAndNode (void)
|
||||
SetAndNode::~SetAndNode()
|
||||
{
|
||||
delete follow_;
|
||||
}
|
||||
@ -84,14 +356,14 @@ SetAndNode::~SetAndNode (void)
|
||||
|
||||
|
||||
double
|
||||
SetAndNode::weight (void) const
|
||||
SetAndNode::weight() const
|
||||
{
|
||||
return LogAware::pow (follow_->weight(), nrGroundings_);
|
||||
}
|
||||
|
||||
|
||||
|
||||
IncExcNode::~IncExcNode (void)
|
||||
IncExcNode::~IncExcNode()
|
||||
{
|
||||
delete plus1Branch_;
|
||||
delete plus2Branch_;
|
||||
@ -101,7 +373,7 @@ IncExcNode::~IncExcNode (void)
|
||||
|
||||
|
||||
double
|
||||
IncExcNode::weight (void) const
|
||||
IncExcNode::weight() const
|
||||
{
|
||||
double w = 0.0;
|
||||
if (Globals::logDomain) {
|
||||
@ -116,7 +388,7 @@ IncExcNode::weight (void) const
|
||||
|
||||
|
||||
|
||||
LeafNode::~LeafNode (void)
|
||||
LeafNode::~LeafNode()
|
||||
{
|
||||
delete clause_;
|
||||
}
|
||||
@ -124,7 +396,7 @@ LeafNode::~LeafNode (void)
|
||||
|
||||
|
||||
double
|
||||
LeafNode::weight (void) const
|
||||
LeafNode::weight() const
|
||||
{
|
||||
assert (clause_->isUnit());
|
||||
if (clause_->posCountedLogVars().empty() == false
|
||||
@ -161,7 +433,7 @@ LeafNode::weight (void) const
|
||||
|
||||
|
||||
|
||||
SmoothNode::~SmoothNode (void)
|
||||
SmoothNode::~SmoothNode()
|
||||
{
|
||||
Clause::deleteClauses (clauses_);
|
||||
}
|
||||
@ -169,7 +441,7 @@ SmoothNode::~SmoothNode (void)
|
||||
|
||||
|
||||
double
|
||||
SmoothNode::weight (void) const
|
||||
SmoothNode::weight() const
|
||||
{
|
||||
Clauses cs = clauses();
|
||||
double totalWeight = LogAware::multIdenty();
|
||||
@ -204,7 +476,7 @@ SmoothNode::weight (void) const
|
||||
|
||||
|
||||
double
|
||||
TrueNode::weight (void) const
|
||||
TrueNode::weight() const
|
||||
{
|
||||
return LogAware::multIdenty();
|
||||
}
|
||||
@ -212,7 +484,7 @@ TrueNode::weight (void) const
|
||||
|
||||
|
||||
double
|
||||
CompilationFailedNode::weight (void) const
|
||||
CompilationFailedNode::weight() const
|
||||
{
|
||||
// weighted model counting in compilation
|
||||
// failed nodes should give NaN
|
||||
@ -234,21 +506,22 @@ LiftedCircuit::LiftedCircuit (const LiftedWCNF* lwcnf)
|
||||
if (Globals::verbosity > 1) {
|
||||
if (compilationSucceeded_) {
|
||||
double wmc = LogAware::exp (getWeightedModelCount());
|
||||
cout << "Weighted model count = " << wmc << endl << endl;
|
||||
std::cout << "Weighted model count = " << wmc;
|
||||
std::cout << std::endl << std::endl;
|
||||
}
|
||||
cout << "Exporting circuit to graphviz (circuit.dot)..." ;
|
||||
cout << endl << endl;
|
||||
std::cout << "Exporting circuit to graphviz (circuit.dot)..." ;
|
||||
std::cout << std::endl << std::endl;
|
||||
exportToGraphViz ("circuit.dot");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
LiftedCircuit::~LiftedCircuit (void)
|
||||
LiftedCircuit::~LiftedCircuit()
|
||||
{
|
||||
delete root_;
|
||||
unordered_map<CircuitNode*, Clauses>::iterator it;
|
||||
it = originClausesMap_.begin();
|
||||
std::unordered_map<CircuitNode*, Clauses>::iterator it
|
||||
= originClausesMap_.begin();
|
||||
while (it != originClausesMap_.end()) {
|
||||
Clause::deleteClauses (it->second);
|
||||
++ it;
|
||||
@ -258,7 +531,7 @@ LiftedCircuit::~LiftedCircuit (void)
|
||||
|
||||
|
||||
bool
|
||||
LiftedCircuit::isCompilationSucceeded (void) const
|
||||
LiftedCircuit::isCompilationSucceeded() const
|
||||
{
|
||||
return compilationSucceeded_;
|
||||
}
|
||||
@ -266,7 +539,7 @@ LiftedCircuit::isCompilationSucceeded (void) const
|
||||
|
||||
|
||||
double
|
||||
LiftedCircuit::getWeightedModelCount (void) const
|
||||
LiftedCircuit::getWeightedModelCount() const
|
||||
{
|
||||
assert (compilationSucceeded_);
|
||||
return root_->weight();
|
||||
@ -277,15 +550,16 @@ LiftedCircuit::getWeightedModelCount (void) const
|
||||
void
|
||||
LiftedCircuit::exportToGraphViz (const char* fileName)
|
||||
{
|
||||
ofstream out (fileName);
|
||||
std::ofstream out (fileName);
|
||||
if (!out.is_open()) {
|
||||
cerr << "Error: couldn't open file '" << fileName << "'." ;
|
||||
std::cerr << "Error: couldn't open file '" << fileName << "'." ;
|
||||
std::cerr << std::endl;
|
||||
return;
|
||||
}
|
||||
out << "digraph {" << endl;
|
||||
out << "ranksep=1" << endl;
|
||||
out << "digraph {" << std::endl;
|
||||
out << "ranksep=1" << std::endl;
|
||||
exportToGraphViz (root_, out);
|
||||
out << "}" << endl;
|
||||
out << "}" << std::endl;
|
||||
out.close();
|
||||
}
|
||||
|
||||
@ -389,7 +663,7 @@ LiftedCircuit::tryUnitPropagation (
|
||||
AndNode* andNode = new AndNode();
|
||||
if (Globals::verbosity > 1) {
|
||||
originClausesMap_[andNode] = backupClauses_;
|
||||
stringstream explanation;
|
||||
std::stringstream explanation;
|
||||
explanation << " UP on " << clauses[i]->literals()[0];
|
||||
explanationMap_[andNode] = explanation.str();
|
||||
}
|
||||
@ -478,7 +752,7 @@ LiftedCircuit::tryShannonDecomp (
|
||||
OrNode* orNode = new OrNode();
|
||||
if (Globals::verbosity > 1) {
|
||||
originClausesMap_[orNode] = backupClauses_;
|
||||
stringstream explanation;
|
||||
std::stringstream explanation;
|
||||
explanation << " SD on " << literals[j];
|
||||
explanationMap_[orNode] = explanation.str();
|
||||
}
|
||||
@ -558,7 +832,7 @@ LiftedCircuit::tryInclusionExclusion (
|
||||
IncExcNode* ieNode = new IncExcNode();
|
||||
if (Globals::verbosity > 1) {
|
||||
originClausesMap_[ieNode] = backupClauses_;
|
||||
stringstream explanation;
|
||||
std::stringstream explanation;
|
||||
explanation << " IncExc on clause nº " << i + 1;
|
||||
explanationMap_[ieNode] = explanation.str();
|
||||
}
|
||||
@ -635,13 +909,13 @@ LiftedCircuit::tryIndepPartialGroundingAux (
|
||||
}
|
||||
}
|
||||
// verifies if the IPG logical vars appear in the same positions
|
||||
unordered_map<LiteralId, size_t> positions;
|
||||
std::unordered_map<LiteralId, size_t> positions;
|
||||
for (size_t i = 0; i < clauses.size(); i++) {
|
||||
const Literals& literals = clauses[i]->literals();
|
||||
for (size_t j = 0; j < literals.size(); j++) {
|
||||
size_t idx = literals[j].indexOfLogVar (rootLogVars[i]);
|
||||
assert (idx != literals[j].nrLogVars());
|
||||
unordered_map<LiteralId, size_t>::iterator it;
|
||||
std::unordered_map<LiteralId, size_t>::iterator it;
|
||||
it = positions.find (literals[j].lid());
|
||||
if (it != positions.end()) {
|
||||
if (it->second != idx) {
|
||||
@ -810,7 +1084,7 @@ LiftedCircuit::smoothCircuit (CircuitNode* node)
|
||||
|
||||
switch (getCircuitNodeType (node)) {
|
||||
|
||||
case CircuitNodeType::OR_NODE: {
|
||||
case CircuitNodeType::orCnt: {
|
||||
OrNode* casted = dynamic_cast<OrNode*>(node);
|
||||
LitLvTypesSet lids1 = smoothCircuit (*casted->leftBranch());
|
||||
LitLvTypesSet lids2 = smoothCircuit (*casted->rightBranch());
|
||||
@ -823,7 +1097,7 @@ LiftedCircuit::smoothCircuit (CircuitNode* node)
|
||||
break;
|
||||
}
|
||||
|
||||
case CircuitNodeType::AND_NODE: {
|
||||
case CircuitNodeType::andCnt: {
|
||||
AndNode* casted = dynamic_cast<AndNode*>(node);
|
||||
LitLvTypesSet lids1 = smoothCircuit (*casted->leftBranch());
|
||||
LitLvTypesSet lids2 = smoothCircuit (*casted->rightBranch());
|
||||
@ -832,17 +1106,18 @@ LiftedCircuit::smoothCircuit (CircuitNode* node)
|
||||
break;
|
||||
}
|
||||
|
||||
case CircuitNodeType::SET_OR_NODE: {
|
||||
case CircuitNodeType::setOrCnt: {
|
||||
SetOrNode* casted = dynamic_cast<SetOrNode*>(node);
|
||||
propagLits = smoothCircuit (*casted->follow());
|
||||
TinySet<pair<LiteralId,unsigned>> litSet;
|
||||
TinySet<std::pair<LiteralId,unsigned>> litSet;
|
||||
for (size_t i = 0; i < propagLits.size(); i++) {
|
||||
litSet.insert (make_pair (propagLits[i].lid(),
|
||||
litSet.insert (std::make_pair (propagLits[i].lid(),
|
||||
propagLits[i].logVarTypes().size()));
|
||||
}
|
||||
LitLvTypesSet missingLids;
|
||||
for (size_t i = 0; i < litSet.size(); i++) {
|
||||
vector<LogVarTypes> allTypes = getAllPossibleTypes (litSet[i].second);
|
||||
std::vector<LogVarTypes> allTypes
|
||||
= getAllPossibleTypes (litSet[i].second);
|
||||
for (size_t j = 0; j < allTypes.size(); j++) {
|
||||
bool typeFound = false;
|
||||
for (size_t k = 0; k < propagLits.size(); k++) {
|
||||
@ -869,13 +1144,13 @@ LiftedCircuit::smoothCircuit (CircuitNode* node)
|
||||
break;
|
||||
}
|
||||
|
||||
case CircuitNodeType::SET_AND_NODE: {
|
||||
case CircuitNodeType::setAndCnt: {
|
||||
SetAndNode* casted = dynamic_cast<SetAndNode*>(node);
|
||||
propagLits = smoothCircuit (*casted->follow());
|
||||
break;
|
||||
}
|
||||
|
||||
case CircuitNodeType::INC_EXC_NODE: {
|
||||
case CircuitNodeType::incExcCnt: {
|
||||
IncExcNode* casted = dynamic_cast<IncExcNode*>(node);
|
||||
LitLvTypesSet lids1 = smoothCircuit (*casted->plus1Branch());
|
||||
LitLvTypesSet lids2 = smoothCircuit (*casted->plus2Branch());
|
||||
@ -888,7 +1163,7 @@ LiftedCircuit::smoothCircuit (CircuitNode* node)
|
||||
break;
|
||||
}
|
||||
|
||||
case CircuitNodeType::LEAF_NODE: {
|
||||
case CircuitNodeType::leafCnt: {
|
||||
LeafNode* casted = dynamic_cast<LeafNode*>(node);
|
||||
propagLits.insert (LitLvTypes (
|
||||
casted->clause()->literals()[0].lid(),
|
||||
@ -911,8 +1186,8 @@ LiftedCircuit::createSmoothNode (
|
||||
{
|
||||
if (missingLits.empty() == false) {
|
||||
if (Globals::verbosity > 1) {
|
||||
unordered_map<CircuitNode*, Clauses>::iterator it;
|
||||
it = originClausesMap_.find (*prev);
|
||||
std::unordered_map<CircuitNode*, Clauses>::iterator it
|
||||
= originClausesMap_.find (*prev);
|
||||
if (it != originClausesMap_.end()) {
|
||||
backupClauses_ = it->second;
|
||||
} else {
|
||||
@ -927,9 +1202,9 @@ LiftedCircuit::createSmoothNode (
|
||||
Clause* c = lwcnf_->createClause (lid);
|
||||
for (size_t j = 0; j < types.size(); j++) {
|
||||
LogVar X = c->literals().front().logVars()[j];
|
||||
if (types[j] == LogVarType::POS_LV) {
|
||||
if (types[j] == LogVarType::posLvt) {
|
||||
c->addPosCountedLogVar (X);
|
||||
} else if (types[j] == LogVarType::NEG_LV) {
|
||||
} else if (types[j] == LogVarType::negLvt) {
|
||||
c->addNegCountedLogVar (X);
|
||||
}
|
||||
}
|
||||
@ -947,15 +1222,15 @@ LiftedCircuit::createSmoothNode (
|
||||
|
||||
|
||||
|
||||
vector<LogVarTypes>
|
||||
std::vector<LogVarTypes>
|
||||
LiftedCircuit::getAllPossibleTypes (unsigned nrLogVars) const
|
||||
{
|
||||
vector<LogVarTypes> res;
|
||||
std::vector<LogVarTypes> res;
|
||||
if (nrLogVars == 0) {
|
||||
// do nothing
|
||||
} else if (nrLogVars == 1) {
|
||||
res.push_back ({ LogVarType::POS_LV });
|
||||
res.push_back ({ LogVarType::NEG_LV });
|
||||
res.push_back ({ LogVarType::posLvt });
|
||||
res.push_back ({ LogVarType::negLvt });
|
||||
} else {
|
||||
Ranges ranges (nrLogVars, 2);
|
||||
Indexer indexer (ranges);
|
||||
@ -963,9 +1238,9 @@ LiftedCircuit::getAllPossibleTypes (unsigned nrLogVars) const
|
||||
LogVarTypes types;
|
||||
for (size_t i = 0; i < nrLogVars; i++) {
|
||||
if (indexer[i] == 0) {
|
||||
types.push_back (LogVarType::POS_LV);
|
||||
types.push_back (LogVarType::posLvt);
|
||||
} else {
|
||||
types.push_back (LogVarType::NEG_LV);
|
||||
types.push_back (LogVarType::negLvt);
|
||||
}
|
||||
}
|
||||
res.push_back (types);
|
||||
@ -983,13 +1258,13 @@ LiftedCircuit::containsTypes (
|
||||
const LogVarTypes& typesB) const
|
||||
{
|
||||
for (size_t i = 0; i < typesA.size(); i++) {
|
||||
if (typesA[i] == LogVarType::FULL_LV) {
|
||||
if (typesA[i] == LogVarType::fullLvt) {
|
||||
|
||||
} else if (typesA[i] == LogVarType::POS_LV
|
||||
&& typesB[i] == LogVarType::POS_LV) {
|
||||
} else if (typesA[i] == LogVarType::posLvt
|
||||
&& typesB[i] == LogVarType::posLvt) {
|
||||
|
||||
} else if (typesA[i] == LogVarType::NEG_LV
|
||||
&& typesB[i] == LogVarType::NEG_LV) {
|
||||
} else if (typesA[i] == LogVarType::negLvt
|
||||
&& typesB[i] == LogVarType::negLvt) {
|
||||
|
||||
} else {
|
||||
return false;
|
||||
@ -1003,25 +1278,25 @@ LiftedCircuit::containsTypes (
|
||||
CircuitNodeType
|
||||
LiftedCircuit::getCircuitNodeType (const CircuitNode* node) const
|
||||
{
|
||||
CircuitNodeType type = CircuitNodeType::OR_NODE;
|
||||
CircuitNodeType type = CircuitNodeType::orCnt;
|
||||
if (dynamic_cast<const OrNode*>(node)) {
|
||||
type = CircuitNodeType::OR_NODE;
|
||||
type = CircuitNodeType::orCnt;
|
||||
} else if (dynamic_cast<const AndNode*>(node)) {
|
||||
type = CircuitNodeType::AND_NODE;
|
||||
type = CircuitNodeType::andCnt;
|
||||
} else if (dynamic_cast<const SetOrNode*>(node)) {
|
||||
type = CircuitNodeType::SET_OR_NODE;
|
||||
type = CircuitNodeType::setOrCnt;
|
||||
} else if (dynamic_cast<const SetAndNode*>(node)) {
|
||||
type = CircuitNodeType::SET_AND_NODE;
|
||||
type = CircuitNodeType::setAndCnt;
|
||||
} else if (dynamic_cast<const IncExcNode*>(node)) {
|
||||
type = CircuitNodeType::INC_EXC_NODE;
|
||||
type = CircuitNodeType::incExcCnt;
|
||||
} else if (dynamic_cast<const LeafNode*>(node)) {
|
||||
type = CircuitNodeType::LEAF_NODE;
|
||||
type = CircuitNodeType::leafCnt;
|
||||
} else if (dynamic_cast<const SmoothNode*>(node)) {
|
||||
type = CircuitNodeType::SMOOTH_NODE;
|
||||
type = CircuitNodeType::smoothCnt;
|
||||
} else if (dynamic_cast<const TrueNode*>(node)) {
|
||||
type = CircuitNodeType::TRUE_NODE;
|
||||
type = CircuitNodeType::trueCnt;
|
||||
} else if (dynamic_cast<const CompilationFailedNode*>(node)) {
|
||||
type = CircuitNodeType::COMPILATION_FAILED_NODE;
|
||||
type = CircuitNodeType::compilationFailedCnt;
|
||||
} else {
|
||||
assert (false);
|
||||
}
|
||||
@ -1031,127 +1306,131 @@ LiftedCircuit::getCircuitNodeType (const CircuitNode* node) const
|
||||
|
||||
|
||||
void
|
||||
LiftedCircuit::exportToGraphViz (CircuitNode* node, ofstream& os)
|
||||
LiftedCircuit::exportToGraphViz (CircuitNode* node, std::ofstream& os)
|
||||
{
|
||||
assert (node);
|
||||
|
||||
static unsigned nrAuxNodes = 0;
|
||||
stringstream ss;
|
||||
std::stringstream ss;
|
||||
ss << "n" << nrAuxNodes;
|
||||
string auxNode = ss.str();
|
||||
std::string auxNode = ss.str();
|
||||
nrAuxNodes ++;
|
||||
string opStyle = "shape=circle,width=0.7,margin=\"0.0,0.0\"," ;
|
||||
std::string opStyle = "shape=circle,width=0.7,margin=\"0.0,0.0\"," ;
|
||||
|
||||
switch (getCircuitNodeType (node)) {
|
||||
|
||||
case OR_NODE: {
|
||||
case CircuitNodeType::orCnt: {
|
||||
OrNode* casted = dynamic_cast<OrNode*>(node);
|
||||
printClauses (casted, os);
|
||||
|
||||
os << auxNode << " [" << opStyle << "label=\"∨\"]" << endl;
|
||||
os << auxNode << " [" << opStyle << "label=\"∨\"]" ;
|
||||
os << std::endl;
|
||||
os << escapeNode (node) << " -> " << auxNode;
|
||||
os << " [label=\"" << getExplanationString (node) << "\"]" ;
|
||||
os << endl;
|
||||
os << std::endl;
|
||||
|
||||
os << auxNode << " -> " ;
|
||||
os << escapeNode (*casted->leftBranch());
|
||||
os << " [label=\" " << (*casted->leftBranch())->weight() << "\"]" ;
|
||||
os << endl;
|
||||
os << std::endl;
|
||||
|
||||
os << auxNode << " -> " ;
|
||||
os << escapeNode (*casted->rightBranch());
|
||||
os << " [label=\" " << (*casted->rightBranch())->weight() << "\"]" ;
|
||||
os << endl;
|
||||
os << std::endl;
|
||||
|
||||
exportToGraphViz (*casted->leftBranch(), os);
|
||||
exportToGraphViz (*casted->rightBranch(), os);
|
||||
break;
|
||||
}
|
||||
|
||||
case AND_NODE: {
|
||||
case CircuitNodeType::andCnt: {
|
||||
AndNode* casted = dynamic_cast<AndNode*>(node);
|
||||
printClauses (casted, os);
|
||||
|
||||
os << auxNode << " [" << opStyle << "label=\"∧\"]" << endl;
|
||||
os << auxNode << " [" << opStyle << "label=\"∧\"]" ;
|
||||
os << std::endl;
|
||||
os << escapeNode (node) << " -> " << auxNode;
|
||||
os << " [label=\"" << getExplanationString (node) << "\"]" ;
|
||||
os << endl;
|
||||
os << std::endl;
|
||||
|
||||
os << auxNode << " -> " ;
|
||||
os << escapeNode (*casted->leftBranch());
|
||||
os << " [label=\" " << (*casted->leftBranch())->weight() << "\"]" ;
|
||||
os << endl;
|
||||
os << std::endl;
|
||||
|
||||
os << auxNode << " -> " ;
|
||||
os << escapeNode (*casted->rightBranch()) << endl;
|
||||
os << escapeNode (*casted->rightBranch());
|
||||
os << " [label=\" " << (*casted->rightBranch())->weight() << "\"]" ;
|
||||
os << endl;
|
||||
os << std::endl;
|
||||
|
||||
exportToGraphViz (*casted->leftBranch(), os);
|
||||
exportToGraphViz (*casted->rightBranch(), os);
|
||||
break;
|
||||
}
|
||||
|
||||
case SET_OR_NODE: {
|
||||
case CircuitNodeType::setOrCnt: {
|
||||
SetOrNode* casted = dynamic_cast<SetOrNode*>(node);
|
||||
printClauses (casted, os);
|
||||
|
||||
os << auxNode << " [" << opStyle << "label=\"∨(X)\"]" << endl;
|
||||
os << auxNode << " [" << opStyle << "label=\"∨(X)\"]" ;
|
||||
os << std::endl;
|
||||
os << escapeNode (node) << " -> " << auxNode;
|
||||
os << " [label=\"" << getExplanationString (node) << "\"]" ;
|
||||
os << endl;
|
||||
os << std::endl;
|
||||
|
||||
os << auxNode << " -> " ;
|
||||
os << escapeNode (*casted->follow());
|
||||
os << " [label=\" " << (*casted->follow())->weight() << "\"]" ;
|
||||
os << endl;
|
||||
os << std::endl;
|
||||
|
||||
exportToGraphViz (*casted->follow(), os);
|
||||
break;
|
||||
}
|
||||
|
||||
case SET_AND_NODE: {
|
||||
case CircuitNodeType::setAndCnt: {
|
||||
SetAndNode* casted = dynamic_cast<SetAndNode*>(node);
|
||||
printClauses (casted, os);
|
||||
|
||||
os << auxNode << " [" << opStyle << "label=\"∧(X)\"]" << endl;
|
||||
os << auxNode << " [" << opStyle << "label=\"∧(X)\"]" ;
|
||||
os << std::endl;
|
||||
os << escapeNode (node) << " -> " << auxNode;
|
||||
os << " [label=\"" << getExplanationString (node) << "\"]" ;
|
||||
os << endl;
|
||||
os << std::endl;
|
||||
|
||||
os << auxNode << " -> " ;
|
||||
os << escapeNode (*casted->follow());
|
||||
os << " [label=\" " << (*casted->follow())->weight() << "\"]" ;
|
||||
os << endl;
|
||||
os << std::endl;
|
||||
|
||||
exportToGraphViz (*casted->follow(), os);
|
||||
break;
|
||||
}
|
||||
|
||||
case INC_EXC_NODE: {
|
||||
case CircuitNodeType::incExcCnt: {
|
||||
IncExcNode* casted = dynamic_cast<IncExcNode*>(node);
|
||||
printClauses (casted, os);
|
||||
|
||||
os << auxNode << " [" << opStyle << "label=\"+ - +\"]" ;
|
||||
os << endl;
|
||||
os << std::endl;
|
||||
os << escapeNode (node) << " -> " << auxNode;
|
||||
os << " [label=\"" << getExplanationString (node) << "\"]" ;
|
||||
os << endl;
|
||||
os << std::endl;
|
||||
|
||||
os << auxNode << " -> " ;
|
||||
os << escapeNode (*casted->plus1Branch());
|
||||
os << " [label=\" " << (*casted->plus1Branch())->weight() << "\"]" ;
|
||||
os << endl;
|
||||
os << std::endl;
|
||||
|
||||
os << auxNode << " -> " ;
|
||||
os << escapeNode (*casted->minusBranch()) << endl;
|
||||
os << escapeNode (*casted->minusBranch()) << std::endl;
|
||||
os << " [label=\" " << (*casted->minusBranch())->weight() << "\"]" ;
|
||||
os << endl;
|
||||
os << std::endl;
|
||||
|
||||
os << auxNode << " -> " ;
|
||||
os << escapeNode (*casted->plus2Branch());
|
||||
os << " [label=\" " << (*casted->plus2Branch())->weight() << "\"]" ;
|
||||
os << endl;
|
||||
os << std::endl;
|
||||
|
||||
exportToGraphViz (*casted->plus1Branch(), os);
|
||||
exportToGraphViz (*casted->plus2Branch(), os);
|
||||
@ -1159,24 +1438,24 @@ LiftedCircuit::exportToGraphViz (CircuitNode* node, ofstream& os)
|
||||
break;
|
||||
}
|
||||
|
||||
case LEAF_NODE: {
|
||||
case CircuitNodeType::leafCnt: {
|
||||
printClauses (node, os, "style=filled,fillcolor=palegreen,");
|
||||
break;
|
||||
}
|
||||
|
||||
case SMOOTH_NODE: {
|
||||
case CircuitNodeType::smoothCnt: {
|
||||
printClauses (node, os, "style=filled,fillcolor=lightblue,");
|
||||
break;
|
||||
}
|
||||
|
||||
case TRUE_NODE: {
|
||||
case CircuitNodeType::trueCnt: {
|
||||
os << escapeNode (node);
|
||||
os << " [shape=box,label=\"⊤\"]" ;
|
||||
os << endl;
|
||||
os << std::endl;
|
||||
break;
|
||||
}
|
||||
|
||||
case COMPILATION_FAILED_NODE: {
|
||||
case CircuitNodeType::compilationFailedCnt: {
|
||||
printClauses (node, os, "style=filled,fillcolor=salmon,");
|
||||
break;
|
||||
}
|
||||
@ -1188,17 +1467,17 @@ LiftedCircuit::exportToGraphViz (CircuitNode* node, ofstream& os)
|
||||
|
||||
|
||||
|
||||
string
|
||||
std::string
|
||||
LiftedCircuit::escapeNode (const CircuitNode* node) const
|
||||
{
|
||||
stringstream ss;
|
||||
std::stringstream ss;
|
||||
ss << "\"" << node << "\"" ;
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
|
||||
|
||||
string
|
||||
std::string
|
||||
LiftedCircuit::getExplanationString (CircuitNode* node)
|
||||
{
|
||||
return Util::contains (explanationMap_, node)
|
||||
@ -1211,15 +1490,15 @@ LiftedCircuit::getExplanationString (CircuitNode* node)
|
||||
void
|
||||
LiftedCircuit::printClauses (
|
||||
CircuitNode* node,
|
||||
ofstream& os,
|
||||
string extraOptions)
|
||||
std::ofstream& os,
|
||||
std::string extraOptions)
|
||||
{
|
||||
Clauses clauses;
|
||||
if (Util::contains (originClausesMap_, node)) {
|
||||
clauses = originClausesMap_[node];
|
||||
} else if (getCircuitNodeType (node) == CircuitNodeType::LEAF_NODE) {
|
||||
} else if (getCircuitNodeType (node) == CircuitNodeType::leafCnt) {
|
||||
clauses = { (dynamic_cast<LeafNode*>(node))->clause() } ;
|
||||
} else if (getCircuitNodeType (node) == CircuitNodeType::SMOOTH_NODE) {
|
||||
} else if (getCircuitNodeType (node) == CircuitNodeType::smoothCnt) {
|
||||
clauses = (dynamic_cast<SmoothNode*>(node))->clauses();
|
||||
}
|
||||
assert (clauses.empty() == false);
|
||||
@ -1230,15 +1509,7 @@ LiftedCircuit::printClauses (
|
||||
os << *clauses[i];
|
||||
}
|
||||
os << "\"]" ;
|
||||
os << endl;
|
||||
}
|
||||
|
||||
|
||||
|
||||
LiftedKc::~LiftedKc (void)
|
||||
{
|
||||
delete lwcnf_;
|
||||
delete circuit_;
|
||||
os << std::endl;
|
||||
}
|
||||
|
||||
|
||||
@ -1246,20 +1517,21 @@ LiftedKc::~LiftedKc (void)
|
||||
Params
|
||||
LiftedKc::solveQuery (const Grounds& query)
|
||||
{
|
||||
pfList_ = parfactorList;
|
||||
LiftedOperations::shatterAgainstQuery (pfList_, query);
|
||||
LiftedOperations::runWeakBayesBall (pfList_, query);
|
||||
lwcnf_ = new LiftedWCNF (pfList_);
|
||||
circuit_ = new LiftedCircuit (lwcnf_);
|
||||
if (circuit_->isCompilationSucceeded() == false) {
|
||||
cerr << "Error: the circuit compilation has failed." << endl;
|
||||
ParfactorList pfList (parfactorList);
|
||||
LiftedOperations::shatterAgainstQuery (pfList, query);
|
||||
LiftedOperations::runWeakBayesBall (pfList, query);
|
||||
LiftedWCNF lwcnf (pfList);
|
||||
LiftedCircuit circuit (&lwcnf);
|
||||
if (circuit.isCompilationSucceeded() == false) {
|
||||
std::cerr << "Error: the circuit compilation has failed." ;
|
||||
std::cerr << std::endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
vector<PrvGroup> groups;
|
||||
std::vector<PrvGroup> groups;
|
||||
Ranges ranges;
|
||||
for (size_t i = 0; i < query.size(); i++) {
|
||||
ParfactorList::const_iterator it = pfList_.begin();
|
||||
while (it != pfList_.end()) {
|
||||
ParfactorList::const_iterator it = pfList.begin();
|
||||
while (it != pfList.end()) {
|
||||
size_t idx = (*it)->indexOfGround (query[i]);
|
||||
if (idx != (*it)->nrArguments()) {
|
||||
groups.push_back ((*it)->argument (idx).group());
|
||||
@ -1274,18 +1546,18 @@ LiftedKc::solveQuery (const Grounds& query)
|
||||
Indexer indexer (ranges);
|
||||
while (indexer.valid()) {
|
||||
for (size_t i = 0; i < groups.size(); i++) {
|
||||
vector<LiteralId> litIds = lwcnf_->prvGroupLiterals (groups[i]);
|
||||
std::vector<LiteralId> litIds = lwcnf.prvGroupLiterals (groups[i]);
|
||||
for (size_t j = 0; j < litIds.size(); j++) {
|
||||
if (indexer[i] == j) {
|
||||
lwcnf_->addWeight (litIds[j], LogAware::one(),
|
||||
lwcnf.addWeight (litIds[j], LogAware::one(),
|
||||
LogAware::one());
|
||||
} else {
|
||||
lwcnf_->addWeight (litIds[j], LogAware::zero(),
|
||||
lwcnf.addWeight (litIds[j], LogAware::zero(),
|
||||
LogAware::one());
|
||||
}
|
||||
}
|
||||
}
|
||||
params.push_back (circuit_->getWeightedModelCount());
|
||||
params.push_back (circuit.getWeightedModelCount());
|
||||
++ indexer;
|
||||
}
|
||||
LogAware::normalize (params);
|
||||
@ -1298,12 +1570,14 @@ LiftedKc::solveQuery (const Grounds& query)
|
||||
|
||||
|
||||
void
|
||||
LiftedKc::printSolverFlags (void) const
|
||||
LiftedKc::printSolverFlags() const
|
||||
{
|
||||
stringstream ss;
|
||||
std::stringstream ss;
|
||||
ss << "lifted kc [" ;
|
||||
ss << "log_domain=" << Util::toString (Globals::logDomain);
|
||||
ss << "]" ;
|
||||
cout << ss.str() << endl;
|
||||
std::cout << ss.str() << std::endl;
|
||||
}
|
||||
|
||||
} // namespace Horus
|
||||
|
||||
|
@ -1,302 +1,26 @@
|
||||
#ifndef HORUS_LIFTEDKC_H
|
||||
#define HORUS_LIFTEDKC_H
|
||||
#ifndef YAP_PACKAGES_CLPBN_HORUS_LIFTEDKC_H_
|
||||
#define YAP_PACKAGES_CLPBN_HORUS_LIFTEDKC_H_
|
||||
|
||||
#include "LiftedSolver.h"
|
||||
#include "LiftedWCNF.h"
|
||||
#include "ParfactorList.h"
|
||||
|
||||
|
||||
enum CircuitNodeType {
|
||||
OR_NODE,
|
||||
AND_NODE,
|
||||
SET_OR_NODE,
|
||||
SET_AND_NODE,
|
||||
INC_EXC_NODE,
|
||||
LEAF_NODE,
|
||||
SMOOTH_NODE,
|
||||
TRUE_NODE,
|
||||
COMPILATION_FAILED_NODE
|
||||
};
|
||||
namespace Horus {
|
||||
|
||||
|
||||
|
||||
class CircuitNode
|
||||
{
|
||||
public:
|
||||
CircuitNode (void) { }
|
||||
|
||||
virtual ~CircuitNode (void) { }
|
||||
|
||||
virtual double weight (void) const = 0;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class OrNode : public CircuitNode
|
||||
{
|
||||
public:
|
||||
OrNode (void) : CircuitNode(), leftBranch_(0), rightBranch_(0) { }
|
||||
|
||||
~OrNode (void);
|
||||
|
||||
CircuitNode** leftBranch (void) { return &leftBranch_; }
|
||||
CircuitNode** rightBranch (void) { return &rightBranch_; }
|
||||
|
||||
double weight (void) const;
|
||||
|
||||
private:
|
||||
CircuitNode* leftBranch_;
|
||||
CircuitNode* rightBranch_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class AndNode : public CircuitNode
|
||||
{
|
||||
public:
|
||||
AndNode (void) : CircuitNode(), leftBranch_(0), rightBranch_(0) { }
|
||||
|
||||
AndNode (CircuitNode* leftBranch, CircuitNode* rightBranch)
|
||||
: CircuitNode(), leftBranch_(leftBranch), rightBranch_(rightBranch) { }
|
||||
|
||||
~AndNode (void);
|
||||
|
||||
CircuitNode** leftBranch (void) { return &leftBranch_; }
|
||||
CircuitNode** rightBranch (void) { return &rightBranch_; }
|
||||
|
||||
double weight (void) const;
|
||||
|
||||
private:
|
||||
CircuitNode* leftBranch_;
|
||||
CircuitNode* rightBranch_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class SetOrNode : public CircuitNode
|
||||
{
|
||||
public:
|
||||
SetOrNode (unsigned nrGroundings)
|
||||
: CircuitNode(), follow_(0), nrGroundings_(nrGroundings) { }
|
||||
|
||||
~SetOrNode (void);
|
||||
|
||||
CircuitNode** follow (void) { return &follow_; }
|
||||
|
||||
static unsigned nrPositives (void) { return nrPos_; }
|
||||
|
||||
static unsigned nrNegatives (void) { return nrNeg_; }
|
||||
|
||||
static bool isSet (void) { return nrPos_ >= 0; }
|
||||
|
||||
double weight (void) const;
|
||||
|
||||
private:
|
||||
CircuitNode* follow_;
|
||||
unsigned nrGroundings_;
|
||||
static int nrPos_;
|
||||
static int nrNeg_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class SetAndNode : public CircuitNode
|
||||
{
|
||||
public:
|
||||
SetAndNode (unsigned nrGroundings)
|
||||
: CircuitNode(), follow_(0), nrGroundings_(nrGroundings) { }
|
||||
|
||||
~SetAndNode (void);
|
||||
|
||||
CircuitNode** follow (void) { return &follow_; }
|
||||
|
||||
double weight (void) const;
|
||||
|
||||
private:
|
||||
CircuitNode* follow_;
|
||||
unsigned nrGroundings_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class IncExcNode : public CircuitNode
|
||||
{
|
||||
public:
|
||||
IncExcNode (void)
|
||||
: CircuitNode(), plus1Branch_(0), plus2Branch_(0), minusBranch_(0) { }
|
||||
|
||||
~IncExcNode (void);
|
||||
|
||||
CircuitNode** plus1Branch (void) { return &plus1Branch_; }
|
||||
CircuitNode** plus2Branch (void) { return &plus2Branch_; }
|
||||
CircuitNode** minusBranch (void) { return &minusBranch_; }
|
||||
|
||||
double weight (void) const;
|
||||
|
||||
private:
|
||||
CircuitNode* plus1Branch_;
|
||||
CircuitNode* plus2Branch_;
|
||||
CircuitNode* minusBranch_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class LeafNode : public CircuitNode
|
||||
{
|
||||
public:
|
||||
LeafNode (Clause* clause, const LiftedWCNF& lwcnf)
|
||||
: CircuitNode(), clause_(clause), lwcnf_(lwcnf) { }
|
||||
|
||||
~LeafNode (void);
|
||||
|
||||
const Clause* clause (void) const { return clause_; }
|
||||
|
||||
Clause* clause (void) { return clause_; }
|
||||
|
||||
double weight (void) const;
|
||||
|
||||
private:
|
||||
Clause* clause_;
|
||||
const LiftedWCNF& lwcnf_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class SmoothNode : public CircuitNode
|
||||
{
|
||||
public:
|
||||
SmoothNode (const Clauses& clauses, const LiftedWCNF& lwcnf)
|
||||
: CircuitNode(), clauses_(clauses), lwcnf_(lwcnf) { }
|
||||
|
||||
~SmoothNode (void);
|
||||
|
||||
const Clauses& clauses (void) const { return clauses_; }
|
||||
|
||||
Clauses clauses (void) { return clauses_; }
|
||||
|
||||
double weight (void) const;
|
||||
|
||||
private:
|
||||
Clauses clauses_;
|
||||
const LiftedWCNF& lwcnf_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class TrueNode : public CircuitNode
|
||||
{
|
||||
public:
|
||||
TrueNode (void) : CircuitNode() { }
|
||||
|
||||
double weight (void) const;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class CompilationFailedNode : public CircuitNode
|
||||
{
|
||||
public:
|
||||
CompilationFailedNode (void) : CircuitNode() { }
|
||||
|
||||
double weight (void) const;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class LiftedCircuit
|
||||
{
|
||||
public:
|
||||
LiftedCircuit (const LiftedWCNF* lwcnf);
|
||||
|
||||
~LiftedCircuit (void);
|
||||
|
||||
bool isCompilationSucceeded (void) const;
|
||||
|
||||
double getWeightedModelCount (void) const;
|
||||
|
||||
void exportToGraphViz (const char*);
|
||||
|
||||
private:
|
||||
void compile (CircuitNode** follow, Clauses& clauses);
|
||||
|
||||
bool tryUnitPropagation (CircuitNode** follow, Clauses& clauses);
|
||||
|
||||
bool tryIndependence (CircuitNode** follow, Clauses& clauses);
|
||||
|
||||
bool tryShannonDecomp (CircuitNode** follow, Clauses& clauses);
|
||||
|
||||
bool tryInclusionExclusion (CircuitNode** follow, Clauses& clauses);
|
||||
|
||||
bool tryIndepPartialGrounding (CircuitNode** follow, Clauses& clauses);
|
||||
|
||||
bool tryIndepPartialGroundingAux (Clauses& clauses, ConstraintTree& ct,
|
||||
LogVars& rootLogVars);
|
||||
|
||||
bool tryAtomCounting (CircuitNode** follow, Clauses& clauses);
|
||||
|
||||
void shatterCountedLogVars (Clauses& clauses);
|
||||
|
||||
bool shatterCountedLogVarsAux (Clauses& clauses);
|
||||
|
||||
bool shatterCountedLogVarsAux (Clauses& clauses, size_t idx1, size_t idx2);
|
||||
|
||||
bool independentClause (Clause& clause, Clauses& otherClauses) const;
|
||||
|
||||
bool independentLiteral (const Literal& lit,
|
||||
const Literals& otherLits) const;
|
||||
|
||||
LitLvTypesSet smoothCircuit (CircuitNode* node);
|
||||
|
||||
void createSmoothNode (const LitLvTypesSet& lids,
|
||||
CircuitNode** prev);
|
||||
|
||||
vector<LogVarTypes> getAllPossibleTypes (unsigned nrLogVars) const;
|
||||
|
||||
bool containsTypes (const LogVarTypes& typesA,
|
||||
const LogVarTypes& typesB) const;
|
||||
|
||||
CircuitNodeType getCircuitNodeType (const CircuitNode* node) const;
|
||||
|
||||
void exportToGraphViz (CircuitNode* node, ofstream&);
|
||||
|
||||
void printClauses (CircuitNode* node, ofstream&,
|
||||
string extraOptions = "");
|
||||
|
||||
string escapeNode (const CircuitNode* node) const;
|
||||
|
||||
string getExplanationString (CircuitNode* node);
|
||||
|
||||
CircuitNode* root_;
|
||||
const LiftedWCNF* lwcnf_;
|
||||
bool compilationSucceeded_;
|
||||
Clauses backupClauses_;
|
||||
unordered_map<CircuitNode*, Clauses> originClausesMap_;
|
||||
unordered_map<CircuitNode*, string> explanationMap_;
|
||||
|
||||
DISALLOW_COPY_AND_ASSIGN (LiftedCircuit);
|
||||
};
|
||||
|
||||
|
||||
|
||||
class LiftedKc : public LiftedSolver
|
||||
{
|
||||
class LiftedKc : public LiftedSolver {
|
||||
public:
|
||||
LiftedKc (const ParfactorList& pfList)
|
||||
: LiftedSolver(pfList) { }
|
||||
|
||||
~LiftedKc (void);
|
||||
|
||||
Params solveQuery (const Grounds&);
|
||||
|
||||
void printSolverFlags (void) const;
|
||||
void printSolverFlags() const;
|
||||
|
||||
private:
|
||||
LiftedWCNF* lwcnf_;
|
||||
LiftedCircuit* circuit_;
|
||||
ParfactorList pfList_;
|
||||
|
||||
DISALLOW_COPY_AND_ASSIGN (LiftedKc);
|
||||
};
|
||||
|
||||
#endif // HORUS_LIFTEDKC_H
|
||||
} // namespace Horus
|
||||
|
||||
#endif // YAP_PACKAGES_CLPBN_HORUS_LIFTEDKC_H_
|
||||
|
||||
|
@ -1,10 +1,22 @@
|
||||
#include <vector>
|
||||
#include <queue>
|
||||
#include <iostream>
|
||||
|
||||
#include "LiftedOperations.h"
|
||||
|
||||
|
||||
namespace Horus {
|
||||
|
||||
namespace LiftedOperations {
|
||||
|
||||
namespace {
|
||||
|
||||
Parfactors absorve (ObservedFormula& obsFormula, Parfactor* g);
|
||||
|
||||
}
|
||||
|
||||
void
|
||||
LiftedOperations::shatterAgainstQuery (
|
||||
ParfactorList& pfList,
|
||||
const Grounds& query)
|
||||
shatterAgainstQuery (ParfactorList& pfList, const Grounds& query)
|
||||
{
|
||||
for (size_t i = 0; i < query.size(); i++) {
|
||||
if (query[i].isAtom()) {
|
||||
@ -35,17 +47,17 @@ LiftedOperations::shatterAgainstQuery (
|
||||
}
|
||||
}
|
||||
if (found == false) {
|
||||
cerr << "Error: could not find a parfactor with ground " ;
|
||||
cerr << "`" << query[i] << "'." << endl;
|
||||
std::cerr << "Error: could not find a parfactor with ground " ;
|
||||
std::cerr << "`" << query[i] << "'." << std::endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
pfList.add (newPfs);
|
||||
}
|
||||
if (Globals::verbosity > 2) {
|
||||
Util::printAsteriskLine();
|
||||
cout << "SHATTERED AGAINST THE QUERY" << endl;
|
||||
std::cout << "SHATTERED AGAINST THE QUERY" << std::endl;
|
||||
for (size_t i = 0; i < query.size(); i++) {
|
||||
cout << " -> " << query[i] << endl;
|
||||
std::cout << " -> " << query[i] << std::endl;
|
||||
}
|
||||
Util::printAsteriskLine();
|
||||
pfList.print();
|
||||
@ -55,12 +67,10 @@ LiftedOperations::shatterAgainstQuery (
|
||||
|
||||
|
||||
void
|
||||
LiftedOperations::runWeakBayesBall (
|
||||
ParfactorList& pfList,
|
||||
const Grounds& query)
|
||||
runWeakBayesBall (ParfactorList& pfList, const Grounds& query)
|
||||
{
|
||||
queue<PrvGroup> todo; // groups to process
|
||||
set<PrvGroup> done; // processed or in queue
|
||||
std::queue<PrvGroup> todo; // groups to process
|
||||
std::set<PrvGroup> done; // processed or in queue
|
||||
for (size_t i = 0; i < query.size(); i++) {
|
||||
ParfactorList::iterator it = pfList.begin();
|
||||
while (it != pfList.end()) {
|
||||
@ -74,14 +84,14 @@ LiftedOperations::runWeakBayesBall (
|
||||
}
|
||||
}
|
||||
|
||||
set<Parfactor*> requiredPfs;
|
||||
std::set<Parfactor*> requiredPfs;
|
||||
while (todo.empty() == false) {
|
||||
PrvGroup group = todo.front();
|
||||
ParfactorList::iterator it = pfList.begin();
|
||||
while (it != pfList.end()) {
|
||||
if (Util::contains (requiredPfs, *it) == false &&
|
||||
(*it)->containsGroup (group)) {
|
||||
vector<PrvGroup> groups = (*it)->getAllGroups();
|
||||
std::vector<PrvGroup> groups = (*it)->getAllGroups();
|
||||
for (size_t i = 0; i < groups.size(); i++) {
|
||||
if (Util::contains (done, groups[i]) == false) {
|
||||
todo.push (groups[i]);
|
||||
@ -116,9 +126,7 @@ LiftedOperations::runWeakBayesBall (
|
||||
|
||||
|
||||
void
|
||||
LiftedOperations::absorveEvidence (
|
||||
ParfactorList& pfList,
|
||||
ObservedFormulas& obsFormulas)
|
||||
absorveEvidence (ParfactorList& pfList, ObservedFormulas& obsFormulas)
|
||||
{
|
||||
for (size_t i = 0; i < obsFormulas.size(); i++) {
|
||||
Parfactors newPfs;
|
||||
@ -143,9 +151,9 @@ LiftedOperations::absorveEvidence (
|
||||
}
|
||||
if (Globals::verbosity > 2 && obsFormulas.empty() == false) {
|
||||
Util::printAsteriskLine();
|
||||
cout << "AFTER EVIDENCE ABSORVED" << endl;
|
||||
std::cout << "AFTER EVIDENCE ABSORVED" << std::endl;
|
||||
for (size_t i = 0; i < obsFormulas.size(); i++) {
|
||||
cout << " -> " << obsFormulas[i] << endl;
|
||||
std::cout << " -> " << obsFormulas[i] << std::endl;
|
||||
}
|
||||
Util::printAsteriskLine();
|
||||
pfList.print();
|
||||
@ -155,9 +163,7 @@ LiftedOperations::absorveEvidence (
|
||||
|
||||
|
||||
Parfactors
|
||||
LiftedOperations::countNormalize (
|
||||
Parfactor* g,
|
||||
const LogVarSet& set)
|
||||
countNormalize (Parfactor* g, const LogVarSet& set)
|
||||
{
|
||||
Parfactors normPfs;
|
||||
if (set.empty()) {
|
||||
@ -174,7 +180,7 @@ LiftedOperations::countNormalize (
|
||||
|
||||
|
||||
Parfactor
|
||||
LiftedOperations::calcGroundMultiplication (Parfactor pf)
|
||||
calcGroundMultiplication (Parfactor pf)
|
||||
{
|
||||
LogVarSet lvs = pf.constr()->logVarSet();
|
||||
lvs -= pf.constr()->singletons();
|
||||
@ -206,10 +212,10 @@ LiftedOperations::calcGroundMultiplication (Parfactor pf)
|
||||
|
||||
|
||||
|
||||
namespace {
|
||||
|
||||
Parfactors
|
||||
LiftedOperations::absorve (
|
||||
ObservedFormula& obsFormula,
|
||||
Parfactor* g)
|
||||
absorve (ObservedFormula& obsFormula, Parfactor* g)
|
||||
{
|
||||
Parfactors absorvedPfs;
|
||||
const ProbFormulas& formulas = g->arguments();
|
||||
@ -269,3 +275,9 @@ LiftedOperations::absorve (
|
||||
return absorvedPfs;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
} // namespace LiftedOperations
|
||||
|
||||
} // namespace Horus
|
||||
|
||||
|
@ -1,29 +1,26 @@
|
||||
#ifndef HORUS_LIFTEDOPERATIONS_H
|
||||
#define HORUS_LIFTEDOPERATIONS_H
|
||||
#ifndef YAP_PACKAGES_CLPBN_HORUS_LIFTEDOPERATIONS_H_
|
||||
#define YAP_PACKAGES_CLPBN_HORUS_LIFTEDOPERATIONS_H_
|
||||
|
||||
#include "ParfactorList.h"
|
||||
|
||||
class LiftedOperations
|
||||
{
|
||||
public:
|
||||
static void shatterAgainstQuery (
|
||||
ParfactorList& pfList, const Grounds& query);
|
||||
|
||||
static void runWeakBayesBall (
|
||||
ParfactorList& pfList, const Grounds&);
|
||||
namespace Horus {
|
||||
|
||||
static void absorveEvidence (
|
||||
ParfactorList& pfList, ObservedFormulas& obsFormulas);
|
||||
namespace LiftedOperations {
|
||||
|
||||
static Parfactors countNormalize (Parfactor*, const LogVarSet&);
|
||||
void shatterAgainstQuery (ParfactorList& pfList, const Grounds& query);
|
||||
|
||||
static Parfactor calcGroundMultiplication (Parfactor pf);
|
||||
void runWeakBayesBall (ParfactorList& pfList, const Grounds& query);
|
||||
|
||||
private:
|
||||
static Parfactors absorve (ObservedFormula&, Parfactor*);
|
||||
void absorveEvidence (ParfactorList& pfList, ObservedFormulas&);
|
||||
|
||||
DISALLOW_COPY_AND_ASSIGN (LiftedOperations);
|
||||
};
|
||||
Parfactors countNormalize (Parfactor*, const LogVarSet&);
|
||||
|
||||
#endif // HORUS_LIFTEDOPERATIONS_H
|
||||
Parfactor calcGroundMultiplication (Parfactor pf);
|
||||
|
||||
} // namespace LiftedOperations
|
||||
|
||||
} // namespace Horus
|
||||
|
||||
#endif // YAP_PACKAGES_CLPBN_HORUS_LIFTEDOPERATIONS_H_
|
||||
|
||||
|
@ -1,14 +1,12 @@
|
||||
#ifndef HORUS_LIFTEDSOLVER_H
|
||||
#define HORUS_LIFTEDSOLVER_H
|
||||
#ifndef YAP_PACKAGES_CLPBN_HORUS_LIFTEDSOLVER_H_
|
||||
#define YAP_PACKAGES_CLPBN_HORUS_LIFTEDSOLVER_H_
|
||||
|
||||
#include "ParfactorList.h"
|
||||
#include "Horus.h"
|
||||
|
||||
|
||||
using namespace std;
|
||||
namespace Horus {
|
||||
|
||||
class LiftedSolver
|
||||
{
|
||||
class LiftedSolver {
|
||||
public:
|
||||
LiftedSolver (const ParfactorList& pfList)
|
||||
: parfactorList(pfList) { }
|
||||
@ -17,7 +15,7 @@ class LiftedSolver
|
||||
|
||||
virtual Params solveQuery (const Grounds& query) = 0;
|
||||
|
||||
virtual void printSolverFlags (void) const = 0;
|
||||
virtual void printSolverFlags() const = 0;
|
||||
|
||||
protected:
|
||||
const ParfactorList& parfactorList;
|
||||
@ -26,5 +24,7 @@ class LiftedSolver
|
||||
DISALLOW_COPY_AND_ASSIGN (LiftedSolver);
|
||||
};
|
||||
|
||||
#endif // HORUS_LIFTEDSOLVER_H
|
||||
} // namespace Horus
|
||||
|
||||
#endif // YAP_PACKAGES_CLPBN_HORUS_LIFTEDSOLVER_H_
|
||||
|
||||
|
@ -1,22 +1,21 @@
|
||||
#include <cassert>
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "LiftedUtils.h"
|
||||
#include "ConstraintTree.h"
|
||||
|
||||
|
||||
namespace Horus {
|
||||
|
||||
namespace LiftedUtils {
|
||||
|
||||
|
||||
unordered_map<string, unsigned> symbolDict;
|
||||
std::unordered_map<std::string, unsigned> symbolDict;
|
||||
|
||||
|
||||
Symbol
|
||||
getSymbol (const string& symbolName)
|
||||
getSymbol (const std::string& symbolName)
|
||||
{
|
||||
unordered_map<string, unsigned>::iterator it
|
||||
std::unordered_map<std::string, unsigned>::iterator it
|
||||
= symbolDict.find (symbolName);
|
||||
if (it != symbolDict.end()) {
|
||||
return it->second;
|
||||
@ -29,12 +28,12 @@ getSymbol (const string& symbolName)
|
||||
|
||||
|
||||
void
|
||||
printSymbolDictionary (void)
|
||||
printSymbolDictionary()
|
||||
{
|
||||
unordered_map<string, unsigned>::const_iterator it
|
||||
std::unordered_map<std::string, unsigned>::const_iterator it
|
||||
= symbolDict.begin();
|
||||
while (it != symbolDict.end()) {
|
||||
cout << it->first << " -> " << it->second << endl;
|
||||
std::cout << it->first << " -> " << it->second << std::endl;
|
||||
++ it;
|
||||
}
|
||||
}
|
||||
@ -43,9 +42,10 @@ printSymbolDictionary (void)
|
||||
|
||||
|
||||
|
||||
ostream& operator<< (ostream &os, const Symbol& s)
|
||||
std::ostream&
|
||||
operator<< (std::ostream& os, const Symbol& s)
|
||||
{
|
||||
unordered_map<string, unsigned>::const_iterator it
|
||||
std::unordered_map<std::string, unsigned>::const_iterator it
|
||||
= LiftedUtils::symbolDict.begin();
|
||||
while (it != LiftedUtils::symbolDict.end() && it->second != s) {
|
||||
++ it;
|
||||
@ -57,9 +57,10 @@ ostream& operator<< (ostream &os, const Symbol& s)
|
||||
|
||||
|
||||
|
||||
ostream& operator<< (ostream &os, const LogVar& X)
|
||||
std::ostream&
|
||||
operator<< (std::ostream& os, const LogVar& X)
|
||||
{
|
||||
const string labels[] = {
|
||||
const std::string labels[] = {
|
||||
"A", "B", "C", "D", "E", "F",
|
||||
"G", "H", "I", "J", "K", "M" };
|
||||
(X >= 12) ? os << "X_" << X.id_ : os << labels[X];
|
||||
@ -68,7 +69,8 @@ ostream& operator<< (ostream &os, const LogVar& X)
|
||||
|
||||
|
||||
|
||||
ostream& operator<< (ostream &os, const Tuple& t)
|
||||
std::ostream&
|
||||
operator<< (std::ostream& os, const Tuple& t)
|
||||
{
|
||||
os << "(" ;
|
||||
for (size_t i = 0; i < t.size(); i++) {
|
||||
@ -80,7 +82,8 @@ ostream& operator<< (ostream &os, const Tuple& t)
|
||||
|
||||
|
||||
|
||||
ostream& operator<< (ostream &os, const Ground& gr)
|
||||
std::ostream&
|
||||
operator<< (std::ostream& os, const Ground& gr)
|
||||
{
|
||||
os << gr.functor();
|
||||
os << "(" ;
|
||||
@ -95,12 +98,12 @@ ostream& operator<< (ostream &os, const Ground& gr)
|
||||
|
||||
|
||||
LogVars
|
||||
Substitution::getDiscardedLogVars (void) const
|
||||
Substitution::getDiscardedLogVars() const
|
||||
{
|
||||
LogVars discardedLvs;
|
||||
set<LogVar> doneLvs;
|
||||
unordered_map<LogVar, LogVar>::const_iterator it;
|
||||
it = subs_.begin();
|
||||
std::set<LogVar> doneLvs;
|
||||
std::unordered_map<LogVar, LogVar>::const_iterator it
|
||||
= subs_.begin();
|
||||
while (it != subs_.end()) {
|
||||
if (Util::contains (doneLvs, it->second)) {
|
||||
discardedLvs.push_back (it->first);
|
||||
@ -114,9 +117,10 @@ Substitution::getDiscardedLogVars (void) const
|
||||
|
||||
|
||||
|
||||
ostream& operator<< (ostream &os, const Substitution& theta)
|
||||
std::ostream&
|
||||
operator<< (std::ostream& os, const Substitution& theta)
|
||||
{
|
||||
unordered_map<LogVar, LogVar>::const_iterator it;
|
||||
std::unordered_map<LogVar, LogVar>::const_iterator it;
|
||||
os << "[" ;
|
||||
it = theta.subs_.begin();
|
||||
while (it != theta.subs_.end()) {
|
||||
@ -128,3 +132,5 @@ ostream& operator<< (ostream &os, const Substitution& theta)
|
||||
return os;
|
||||
}
|
||||
|
||||
} // namespace Horus
|
||||
|
||||
|
@ -1,164 +1,210 @@
|
||||
#ifndef HORUS_LIFTEDUTILS_H
|
||||
#define HORUS_LIFTEDUTILS_H
|
||||
|
||||
#include <string>
|
||||
#ifndef YAP_PACKAGES_CLPBN_HORUS_LIFTEDUTILS_H_
|
||||
#define YAP_PACKAGES_CLPBN_HORUS_LIFTEDUTILS_H_
|
||||
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
#include <string>
|
||||
#include <ostream>
|
||||
|
||||
#include "TinySet.h"
|
||||
#include "Util.h"
|
||||
|
||||
|
||||
using namespace std;
|
||||
namespace Horus {
|
||||
|
||||
|
||||
class Symbol
|
||||
{
|
||||
class Symbol {
|
||||
public:
|
||||
Symbol (void) : id_(Util::maxUnsigned()) { }
|
||||
Symbol() : id_(Util::maxUnsigned()) { }
|
||||
|
||||
Symbol (unsigned id) : id_(id) { }
|
||||
|
||||
operator unsigned (void) const { return id_; }
|
||||
operator unsigned() const { return id_; }
|
||||
|
||||
bool valid (void) const { return id_ != Util::maxUnsigned(); }
|
||||
bool valid() const { return id_ != Util::maxUnsigned(); }
|
||||
|
||||
static Symbol invalid (void) { return Symbol(); }
|
||||
|
||||
friend ostream& operator<< (ostream &os, const Symbol& s);
|
||||
static Symbol invalid() { return Symbol(); }
|
||||
|
||||
private:
|
||||
friend std::ostream& operator<< (std::ostream&, const Symbol&);
|
||||
|
||||
unsigned id_;
|
||||
};
|
||||
|
||||
|
||||
class LogVar
|
||||
{
|
||||
class LogVar {
|
||||
public:
|
||||
LogVar (void) : id_(Util::maxUnsigned()) { }
|
||||
LogVar() : id_(Util::maxUnsigned()) { }
|
||||
|
||||
LogVar (unsigned id) : id_(id) { }
|
||||
|
||||
operator unsigned (void) const { return id_; }
|
||||
operator unsigned() const { return id_; }
|
||||
|
||||
LogVar& operator++ (void)
|
||||
{
|
||||
assert (valid());
|
||||
id_ ++;
|
||||
return *this;
|
||||
}
|
||||
LogVar& operator++();
|
||||
|
||||
bool valid (void) const
|
||||
{
|
||||
return id_ != Util::maxUnsigned();
|
||||
}
|
||||
|
||||
friend ostream& operator<< (ostream &os, const LogVar& X);
|
||||
bool valid() const;
|
||||
|
||||
private:
|
||||
friend std::ostream& operator<< (std::ostream&, const LogVar&);
|
||||
|
||||
unsigned id_;
|
||||
};
|
||||
|
||||
|
||||
namespace std {
|
||||
template <> struct hash<Symbol> {
|
||||
size_t operator() (const Symbol& s) const {
|
||||
return std::hash<unsigned>() (s);
|
||||
}};
|
||||
|
||||
template <> struct hash<LogVar> {
|
||||
size_t operator() (const LogVar& X) const {
|
||||
return std::hash<unsigned>() (X);
|
||||
}};
|
||||
};
|
||||
|
||||
|
||||
typedef vector<Symbol> Symbols;
|
||||
typedef vector<Symbol> Tuple;
|
||||
typedef vector<Tuple> Tuples;
|
||||
typedef vector<LogVar> LogVars;
|
||||
typedef TinySet<Symbol> SymbolSet;
|
||||
typedef TinySet<LogVar> LogVarSet;
|
||||
typedef TinySet<Tuple> TupleSet;
|
||||
|
||||
|
||||
ostream& operator<< (ostream &os, const Tuple& t);
|
||||
|
||||
|
||||
namespace LiftedUtils {
|
||||
Symbol getSymbol (const string&);
|
||||
void printSymbolDictionary (void);
|
||||
inline LogVar&
|
||||
LogVar::operator++()
|
||||
{
|
||||
assert (valid());
|
||||
id_ ++;
|
||||
return *this;
|
||||
}
|
||||
|
||||
|
||||
|
||||
class Ground
|
||||
inline bool
|
||||
LogVar::valid() const
|
||||
{
|
||||
return id_ != Util::maxUnsigned();
|
||||
}
|
||||
|
||||
} // namespace Horus
|
||||
|
||||
|
||||
namespace std {
|
||||
|
||||
template <> struct hash<Horus::Symbol> {
|
||||
size_t operator() (const Horus::Symbol& s) const {
|
||||
return std::hash<unsigned>() (s);
|
||||
}};
|
||||
|
||||
template <> struct hash<Horus::LogVar> {
|
||||
size_t operator() (const Horus::LogVar& X) const {
|
||||
return std::hash<unsigned>() (X);
|
||||
}};
|
||||
|
||||
} // namespace std
|
||||
|
||||
|
||||
namespace Horus {
|
||||
|
||||
typedef std::vector<Symbol> Symbols;
|
||||
typedef std::vector<Symbol> Tuple;
|
||||
typedef std::vector<Tuple> Tuples;
|
||||
typedef std::vector<LogVar> LogVars;
|
||||
typedef TinySet<Symbol> SymbolSet;
|
||||
typedef TinySet<LogVar> LogVarSet;
|
||||
typedef TinySet<Tuple> TupleSet;
|
||||
|
||||
|
||||
std::ostream& operator<< (std::ostream&, const Tuple&);
|
||||
|
||||
|
||||
namespace LiftedUtils {
|
||||
|
||||
Symbol getSymbol (const std::string&);
|
||||
|
||||
void printSymbolDictionary();
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
class Ground {
|
||||
public:
|
||||
Ground (Symbol f) : functor_(f) { }
|
||||
|
||||
Ground (Symbol f, const Symbols& args) : functor_(f), args_(args) { }
|
||||
Ground (Symbol f, const Symbols& args)
|
||||
: functor_(f), args_(args) { }
|
||||
|
||||
Symbol functor (void) const { return functor_; }
|
||||
Symbol functor() const { return functor_; }
|
||||
|
||||
Symbols args (void) const { return args_; }
|
||||
Symbols args() const { return args_; }
|
||||
|
||||
size_t arity (void) const { return args_.size(); }
|
||||
size_t arity() const { return args_.size(); }
|
||||
|
||||
bool isAtom (void) const { return args_.empty(); }
|
||||
|
||||
friend ostream& operator<< (ostream &os, const Ground& gr);
|
||||
bool isAtom() const { return args_.empty(); }
|
||||
|
||||
private:
|
||||
friend std::ostream& operator<< (std::ostream&, const Ground&);
|
||||
|
||||
Symbol functor_;
|
||||
Symbols args_;
|
||||
};
|
||||
|
||||
typedef vector<Ground> Grounds;
|
||||
typedef std::vector<Ground> Grounds;
|
||||
|
||||
|
||||
|
||||
class Substitution
|
||||
{
|
||||
class Substitution {
|
||||
public:
|
||||
void add (LogVar X_old, LogVar X_new)
|
||||
{
|
||||
assert (Util::contains (subs_, X_old) == false);
|
||||
subs_.insert (make_pair (X_old, X_new));
|
||||
}
|
||||
void add (LogVar X_old, LogVar X_new);
|
||||
|
||||
void rename (LogVar X_old, LogVar X_new)
|
||||
{
|
||||
assert (Util::contains (subs_, X_old));
|
||||
subs_.find (X_old)->second = X_new;
|
||||
}
|
||||
void rename (LogVar X_old, LogVar X_new);
|
||||
|
||||
LogVar newNameFor (LogVar X) const
|
||||
{
|
||||
unordered_map<LogVar, LogVar>::const_iterator it;
|
||||
it = subs_.find (X);
|
||||
if (it != subs_.end()) {
|
||||
return subs_.find (X)->second;
|
||||
}
|
||||
return X;
|
||||
}
|
||||
LogVar newNameFor (LogVar X) const;
|
||||
|
||||
bool containsReplacementFor (LogVar X) const
|
||||
{
|
||||
return Util::contains (subs_, X);
|
||||
}
|
||||
bool containsReplacementFor (LogVar X) const;
|
||||
|
||||
size_t nrReplacements (void) const { return subs_.size(); }
|
||||
size_t nrReplacements() const;
|
||||
|
||||
LogVars getDiscardedLogVars (void) const;
|
||||
|
||||
friend ostream& operator<< (ostream &os, const Substitution& theta);
|
||||
LogVars getDiscardedLogVars() const;
|
||||
|
||||
private:
|
||||
unordered_map<LogVar, LogVar> subs_;
|
||||
friend std::ostream& operator<< (
|
||||
std::ostream&, const Substitution&);
|
||||
|
||||
std::unordered_map<LogVar, LogVar> subs_;
|
||||
};
|
||||
|
||||
#endif // HORUS_LIFTEDUTILS_H
|
||||
|
||||
|
||||
|
||||
inline void
|
||||
Substitution::add (LogVar X_old, LogVar X_new)
|
||||
{
|
||||
assert (Util::contains (subs_, X_old) == false);
|
||||
subs_.insert (std::make_pair (X_old, X_new));
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline void
|
||||
Substitution::rename (LogVar X_old, LogVar X_new)
|
||||
{
|
||||
assert (Util::contains (subs_, X_old));
|
||||
subs_.find (X_old)->second = X_new;
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline LogVar
|
||||
Substitution::newNameFor (LogVar X) const
|
||||
{
|
||||
std::unordered_map<LogVar, LogVar>::const_iterator it;
|
||||
it = subs_.find (X);
|
||||
if (it != subs_.end()) {
|
||||
return subs_.find (X)->second;
|
||||
}
|
||||
return X;
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline bool
|
||||
Substitution::containsReplacementFor (LogVar X) const
|
||||
{
|
||||
return Util::contains (subs_, X);
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline size_t
|
||||
Substitution::nrReplacements() const
|
||||
{
|
||||
return subs_.size();
|
||||
}
|
||||
|
||||
} // namespace Horus
|
||||
|
||||
#endif // YAP_PACKAGES_CLPBN_HORUS_LIFTEDUTILS_H_
|
||||
|
||||
|
@ -1,6 +1,12 @@
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
|
||||
#include <vector>
|
||||
#include <set>
|
||||
#include <queue>
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "LiftedVe.h"
|
||||
#include "LiftedOperations.h"
|
||||
@ -8,21 +14,158 @@
|
||||
#include "Util.h"
|
||||
|
||||
|
||||
vector<LiftedOperator*>
|
||||
namespace Horus {
|
||||
|
||||
class LiftedOperator {
|
||||
public:
|
||||
virtual ~LiftedOperator() { }
|
||||
|
||||
virtual double getLogCost() = 0;
|
||||
|
||||
virtual void apply() = 0;
|
||||
|
||||
virtual std::string toString() = 0;
|
||||
|
||||
static std::vector<LiftedOperator*> getValidOps (
|
||||
ParfactorList&, const Grounds&);
|
||||
|
||||
static void printValidOps (ParfactorList&, const Grounds&);
|
||||
|
||||
static std::vector<ParfactorList::iterator> getParfactorsWithGroup (
|
||||
ParfactorList&, PrvGroup group);
|
||||
|
||||
private:
|
||||
DISALLOW_ASSIGN (LiftedOperator);
|
||||
};
|
||||
|
||||
|
||||
|
||||
class ProductOperator : public LiftedOperator {
|
||||
public:
|
||||
ProductOperator (
|
||||
ParfactorList::iterator g1,
|
||||
ParfactorList::iterator g2,
|
||||
ParfactorList& pfList)
|
||||
: g1_(g1), g2_(g2), pfList_(pfList) { }
|
||||
|
||||
double getLogCost();
|
||||
|
||||
void apply();
|
||||
|
||||
static std::vector<ProductOperator*> getValidOps (ParfactorList&);
|
||||
|
||||
std::string toString();
|
||||
|
||||
private:
|
||||
static bool validOp (Parfactor*, Parfactor*);
|
||||
|
||||
ParfactorList::iterator g1_;
|
||||
ParfactorList::iterator g2_;
|
||||
ParfactorList& pfList_;
|
||||
|
||||
DISALLOW_COPY_AND_ASSIGN (ProductOperator);
|
||||
};
|
||||
|
||||
|
||||
|
||||
class SumOutOperator : public LiftedOperator {
|
||||
public:
|
||||
SumOutOperator (PrvGroup group, ParfactorList& pfList)
|
||||
: group_(group), pfList_(pfList) { }
|
||||
|
||||
double getLogCost();
|
||||
|
||||
void apply();
|
||||
|
||||
static std::vector<SumOutOperator*> getValidOps (
|
||||
ParfactorList&, const Grounds&);
|
||||
|
||||
std::string toString();
|
||||
|
||||
private:
|
||||
static bool validOp (PrvGroup, ParfactorList&, const Grounds&);
|
||||
|
||||
static bool isToEliminate (Parfactor*, PrvGroup, const Grounds&);
|
||||
|
||||
PrvGroup group_;
|
||||
ParfactorList& pfList_;
|
||||
|
||||
DISALLOW_COPY_AND_ASSIGN (SumOutOperator);
|
||||
};
|
||||
|
||||
|
||||
|
||||
class CountingOperator : public LiftedOperator {
|
||||
public:
|
||||
CountingOperator (
|
||||
ParfactorList::iterator pfIter,
|
||||
LogVar X,
|
||||
ParfactorList& pfList)
|
||||
: pfIter_(pfIter), X_(X), pfList_(pfList) { }
|
||||
|
||||
double getLogCost();
|
||||
|
||||
void apply();
|
||||
|
||||
static std::vector<CountingOperator*> getValidOps (ParfactorList&);
|
||||
|
||||
std::string toString();
|
||||
|
||||
private:
|
||||
static bool validOp (Parfactor*, LogVar);
|
||||
|
||||
ParfactorList::iterator pfIter_;
|
||||
LogVar X_;
|
||||
ParfactorList& pfList_;
|
||||
|
||||
DISALLOW_COPY_AND_ASSIGN (CountingOperator);
|
||||
};
|
||||
|
||||
|
||||
|
||||
class GroundOperator : public LiftedOperator {
|
||||
public:
|
||||
GroundOperator (
|
||||
PrvGroup group,
|
||||
unsigned lvIndex,
|
||||
ParfactorList& pfList)
|
||||
: group_(group), lvIndex_(lvIndex), pfList_(pfList) { }
|
||||
|
||||
double getLogCost();
|
||||
|
||||
void apply();
|
||||
|
||||
static std::vector<GroundOperator*> getValidOps (ParfactorList&);
|
||||
|
||||
std::string toString();
|
||||
|
||||
private:
|
||||
std::vector<std::pair<PrvGroup, unsigned>> getAffectedFormulas();
|
||||
|
||||
PrvGroup group_;
|
||||
unsigned lvIndex_;
|
||||
ParfactorList& pfList_;
|
||||
|
||||
DISALLOW_COPY_AND_ASSIGN (GroundOperator);
|
||||
};
|
||||
|
||||
|
||||
|
||||
std::vector<LiftedOperator*>
|
||||
LiftedOperator::getValidOps (
|
||||
ParfactorList& pfList,
|
||||
const Grounds& query)
|
||||
{
|
||||
vector<LiftedOperator*> validOps;
|
||||
vector<ProductOperator*> multOps;
|
||||
std::vector<LiftedOperator*> validOps;
|
||||
std::vector<ProductOperator*> multOps;
|
||||
|
||||
multOps = ProductOperator::getValidOps (pfList);
|
||||
validOps.insert (validOps.end(), multOps.begin(), multOps.end());
|
||||
|
||||
if (Globals::verbosity > 1 || multOps.empty()) {
|
||||
vector<SumOutOperator*> sumOutOps;
|
||||
vector<CountingOperator*> countOps;
|
||||
vector<GroundOperator*> groundOps;
|
||||
std::vector<SumOutOperator*> sumOutOps;
|
||||
std::vector<CountingOperator*> countOps;
|
||||
std::vector<GroundOperator*> groundOps;
|
||||
sumOutOps = SumOutOperator::getValidOps (pfList, query);
|
||||
countOps = CountingOperator::getValidOps (pfList);
|
||||
groundOps = GroundOperator::getValidOps (pfList);
|
||||
@ -41,21 +184,21 @@ LiftedOperator::printValidOps (
|
||||
ParfactorList& pfList,
|
||||
const Grounds& query)
|
||||
{
|
||||
vector<LiftedOperator*> validOps;
|
||||
std::vector<LiftedOperator*> validOps;
|
||||
validOps = LiftedOperator::getValidOps (pfList, query);
|
||||
for (size_t i = 0; i < validOps.size(); i++) {
|
||||
cout << "-> " << validOps[i]->toString();
|
||||
std::cout << "-> " << validOps[i]->toString();
|
||||
delete validOps[i];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
vector<ParfactorList::iterator>
|
||||
std::vector<ParfactorList::iterator>
|
||||
LiftedOperator::getParfactorsWithGroup (
|
||||
ParfactorList& pfList, PrvGroup group)
|
||||
{
|
||||
vector<ParfactorList::iterator> iters;
|
||||
std::vector<ParfactorList::iterator> iters;
|
||||
ParfactorList::iterator pflIt = pfList.begin();
|
||||
while (pflIt != pfList.end()) {
|
||||
if ((*pflIt)->containsGroup (group)) {
|
||||
@ -69,7 +212,7 @@ LiftedOperator::getParfactorsWithGroup (
|
||||
|
||||
|
||||
double
|
||||
ProductOperator::getLogCost (void)
|
||||
ProductOperator::getLogCost()
|
||||
{
|
||||
return std::log (0.0);
|
||||
}
|
||||
@ -77,7 +220,7 @@ ProductOperator::getLogCost (void)
|
||||
|
||||
|
||||
void
|
||||
ProductOperator::apply (void)
|
||||
ProductOperator::apply()
|
||||
{
|
||||
Parfactor* g1 = *g1_;
|
||||
Parfactor* g2 = *g2_;
|
||||
@ -89,13 +232,13 @@ ProductOperator::apply (void)
|
||||
|
||||
|
||||
|
||||
vector<ProductOperator*>
|
||||
std::vector<ProductOperator*>
|
||||
ProductOperator::getValidOps (ParfactorList& pfList)
|
||||
{
|
||||
vector<ProductOperator*> validOps;
|
||||
std::vector<ProductOperator*> validOps;
|
||||
ParfactorList::iterator it1 = pfList.begin();
|
||||
ParfactorList::iterator penultimate = -- pfList.end();
|
||||
set<Parfactor*> pfs;
|
||||
std::set<Parfactor*> pfs;
|
||||
while (it1 != penultimate) {
|
||||
if (Util::contains (pfs, *it1)) {
|
||||
++ it1;
|
||||
@ -128,15 +271,15 @@ ProductOperator::getValidOps (ParfactorList& pfList)
|
||||
|
||||
|
||||
|
||||
string
|
||||
ProductOperator::toString (void)
|
||||
std::string
|
||||
ProductOperator::toString()
|
||||
{
|
||||
stringstream ss;
|
||||
std::stringstream ss;
|
||||
ss << "just multiplicate " ;
|
||||
ss << (*g1_)->getAllGroups();
|
||||
ss << " x " ;
|
||||
ss << (*g2_)->getAllGroups();
|
||||
ss << " [cost=" << std::exp (getLogCost()) << "]" << endl;
|
||||
ss << " [cost=" << std::exp (getLogCost()) << "]" << std::endl;
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
@ -168,14 +311,14 @@ ProductOperator::validOp (Parfactor* g1, Parfactor* g2)
|
||||
|
||||
|
||||
double
|
||||
SumOutOperator::getLogCost (void)
|
||||
SumOutOperator::getLogCost()
|
||||
{
|
||||
TinySet<PrvGroup> groupSet;
|
||||
ParfactorList::const_iterator pfIter = pfList_.begin();
|
||||
unsigned nrProdFactors = 0;
|
||||
while (pfIter != pfList_.end()) {
|
||||
if ((*pfIter)->containsGroup (group_)) {
|
||||
vector<PrvGroup> groups = (*pfIter)->getAllGroups();
|
||||
std::vector<PrvGroup> groups = (*pfIter)->getAllGroups();
|
||||
groupSet |= TinySet<PrvGroup> (groups);
|
||||
++ nrProdFactors;
|
||||
}
|
||||
@ -203,9 +346,9 @@ SumOutOperator::getLogCost (void)
|
||||
|
||||
|
||||
void
|
||||
SumOutOperator::apply (void)
|
||||
SumOutOperator::apply()
|
||||
{
|
||||
vector<ParfactorList::iterator> iters;
|
||||
std::vector<ParfactorList::iterator> iters;
|
||||
iters = getParfactorsWithGroup (pfList_, group_);
|
||||
Parfactor* product = *(iters[0]);
|
||||
pfList_.remove (iters[0]);
|
||||
@ -234,13 +377,13 @@ SumOutOperator::apply (void)
|
||||
|
||||
|
||||
|
||||
vector<SumOutOperator*>
|
||||
std::vector<SumOutOperator*>
|
||||
SumOutOperator::getValidOps (
|
||||
ParfactorList& pfList,
|
||||
const Grounds& query)
|
||||
{
|
||||
vector<SumOutOperator*> validOps;
|
||||
set<PrvGroup> allGroups;
|
||||
std::vector<SumOutOperator*> validOps;
|
||||
std::set<PrvGroup> allGroups;
|
||||
ParfactorList::const_iterator it = pfList.begin();
|
||||
while (it != pfList.end()) {
|
||||
const ProbFormulas& formulas = (*it)->arguments();
|
||||
@ -249,7 +392,7 @@ SumOutOperator::getValidOps (
|
||||
}
|
||||
++ it;
|
||||
}
|
||||
set<PrvGroup>::const_iterator groupIt = allGroups.begin();
|
||||
std::set<PrvGroup>::const_iterator groupIt = allGroups.begin();
|
||||
while (groupIt != allGroups.end()) {
|
||||
if (validOp (*groupIt, pfList, query)) {
|
||||
validOps.push_back (new SumOutOperator (*groupIt, pfList));
|
||||
@ -261,18 +404,18 @@ SumOutOperator::getValidOps (
|
||||
|
||||
|
||||
|
||||
string
|
||||
SumOutOperator::toString (void)
|
||||
std::string
|
||||
SumOutOperator::toString()
|
||||
{
|
||||
stringstream ss;
|
||||
vector<ParfactorList::iterator> pfIters;
|
||||
std::stringstream ss;
|
||||
std::vector<ParfactorList::iterator> pfIters;
|
||||
pfIters = getParfactorsWithGroup (pfList_, group_);
|
||||
size_t idx = (*pfIters[0])->indexOfGroup (group_);
|
||||
ProbFormula f = (*pfIters[0])->argument (idx);
|
||||
TupleSet tupleSet = (*pfIters[0])->constr()->tupleSet (f.logVars());
|
||||
ss << "sum out " << f.functor() << "/" << f.arity();
|
||||
ss << "|" << tupleSet << " (group " << group_ << ")";
|
||||
ss << " [cost=" << std::exp (getLogCost()) << "]" << endl;
|
||||
ss << " [cost=" << std::exp (getLogCost()) << "]" << std::endl;
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
@ -284,7 +427,7 @@ SumOutOperator::validOp (
|
||||
ParfactorList& pfList,
|
||||
const Grounds& query)
|
||||
{
|
||||
vector<ParfactorList::iterator> pfIters;
|
||||
std::vector<ParfactorList::iterator> pfIters;
|
||||
pfIters = getParfactorsWithGroup (pfList, group);
|
||||
if (isToEliminate (*pfIters[0], group, query) == false) {
|
||||
return false;
|
||||
@ -335,7 +478,7 @@ SumOutOperator::isToEliminate (
|
||||
|
||||
|
||||
double
|
||||
CountingOperator::getLogCost (void)
|
||||
CountingOperator::getLogCost()
|
||||
{
|
||||
double cost = 0.0;
|
||||
size_t fIdx = (*pfIter_)->indexOfLogVar (X_);
|
||||
@ -370,7 +513,7 @@ CountingOperator::getLogCost (void)
|
||||
|
||||
|
||||
void
|
||||
CountingOperator::apply (void)
|
||||
CountingOperator::apply()
|
||||
{
|
||||
if ((*pfIter_)->constr()->isCountNormalized (X_)) {
|
||||
(*pfIter_)->countConvert (X_);
|
||||
@ -393,10 +536,10 @@ CountingOperator::apply (void)
|
||||
|
||||
|
||||
|
||||
vector<CountingOperator*>
|
||||
std::vector<CountingOperator*>
|
||||
CountingOperator::getValidOps (ParfactorList& pfList)
|
||||
{
|
||||
vector<CountingOperator*> validOps;
|
||||
std::vector<CountingOperator*> validOps;
|
||||
ParfactorList::iterator it = pfList.begin();
|
||||
while (it != pfList.end()) {
|
||||
LogVarSet candidates = (*it)->uncountedLogVars();
|
||||
@ -414,17 +557,17 @@ CountingOperator::getValidOps (ParfactorList& pfList)
|
||||
|
||||
|
||||
|
||||
string
|
||||
CountingOperator::toString (void)
|
||||
std::string
|
||||
CountingOperator::toString()
|
||||
{
|
||||
stringstream ss;
|
||||
std::stringstream ss;
|
||||
ss << "count convert " << X_ << " in " ;
|
||||
ss << (*pfIter_)->getLabel();
|
||||
ss << " [cost=" << std::exp (getLogCost()) << "]" << endl;
|
||||
ss << " [cost=" << std::exp (getLogCost()) << "]" << std::endl;
|
||||
Parfactors pfs = LiftedOperations::countNormalize (*pfIter_, X_);
|
||||
if ((*pfIter_)->constr()->isCountNormalized (X_) == false) {
|
||||
for (size_t i = 0; i < pfs.size(); i++) {
|
||||
ss << " º " << pfs[i]->getLabel() << endl;
|
||||
ss << " º " << pfs[i]->getLabel() << std::endl;
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i < pfs.size(); i++) {
|
||||
@ -455,16 +598,16 @@ CountingOperator::validOp (Parfactor* g, LogVar X)
|
||||
|
||||
|
||||
double
|
||||
GroundOperator::getLogCost (void)
|
||||
GroundOperator::getLogCost()
|
||||
{
|
||||
vector<pair<PrvGroup, unsigned>> affectedFormulas;
|
||||
std::vector<std::pair<PrvGroup, unsigned>> affectedFormulas;
|
||||
affectedFormulas = getAffectedFormulas();
|
||||
// cout << "affected formulas: " ;
|
||||
// std::cout << "affected formulas: " ;
|
||||
// for (size_t i = 0; i < affectedFormulas.size(); i++) {
|
||||
// cout << affectedFormulas[i].first << ":" ;
|
||||
// cout << affectedFormulas[i].second << " " ;
|
||||
// std::cout << affectedFormulas[i].first << ":" ;
|
||||
// std::cout << affectedFormulas[i].second << " " ;
|
||||
// }
|
||||
// cout << "cost =" ;
|
||||
// std::cout << "cost =" ;
|
||||
double totalCost = std::log (0.0);
|
||||
ParfactorList::iterator pflIt = pfList_.begin();
|
||||
while (pflIt != pfList_.end()) {
|
||||
@ -495,20 +638,20 @@ GroundOperator::getLogCost (void)
|
||||
}
|
||||
}
|
||||
if (willBeAffected) {
|
||||
// cout << " + " << std::exp (reps) << "x" << std::exp (pfSize);
|
||||
// std::cout << " + " << std::exp (reps) << "x" << std::exp (pfSize);
|
||||
double pfCost = reps + pfSize;
|
||||
totalCost = Util::logSum (totalCost, pfCost);
|
||||
}
|
||||
++ pflIt;
|
||||
}
|
||||
// cout << endl;
|
||||
// std::cout << std::endl;
|
||||
return totalCost + 3;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
GroundOperator::apply (void)
|
||||
GroundOperator::apply()
|
||||
{
|
||||
ParfactorList::iterator pfIter;
|
||||
pfIter = getParfactorsWithGroup (pfList_, group_).front();
|
||||
@ -537,11 +680,11 @@ GroundOperator::apply (void)
|
||||
|
||||
|
||||
|
||||
vector<GroundOperator*>
|
||||
std::vector<GroundOperator*>
|
||||
GroundOperator::getValidOps (ParfactorList& pfList)
|
||||
{
|
||||
vector<GroundOperator*> validOps;
|
||||
set<PrvGroup> allGroups;
|
||||
std::vector<GroundOperator*> validOps;
|
||||
std::set<PrvGroup> allGroups;
|
||||
ParfactorList::const_iterator it = pfList.begin();
|
||||
while (it != pfList.end()) {
|
||||
const ProbFormulas& formulas = (*it)->arguments();
|
||||
@ -564,18 +707,18 @@ GroundOperator::getValidOps (ParfactorList& pfList)
|
||||
|
||||
|
||||
|
||||
string
|
||||
GroundOperator::toString (void)
|
||||
std::string
|
||||
GroundOperator::toString()
|
||||
{
|
||||
stringstream ss;
|
||||
vector<ParfactorList::iterator> pfIters;
|
||||
std::stringstream ss;
|
||||
std::vector<ParfactorList::iterator> pfIters;
|
||||
pfIters = getParfactorsWithGroup (pfList_, group_);
|
||||
Parfactor* pf = *(getParfactorsWithGroup (pfList_, group_).front());
|
||||
size_t idx = pf->indexOfGroup (group_);
|
||||
ProbFormula f = pf->argument (idx);
|
||||
LogVar lv = f.logVars()[lvIndex_];
|
||||
TupleSet tupleSet = pf->constr()->tupleSet ({lv});
|
||||
string pos = "th";
|
||||
std::string pos = "th";
|
||||
if (lvIndex_ == 0) {
|
||||
pos = "st" ;
|
||||
} else if (lvIndex_ == 1) {
|
||||
@ -586,21 +729,21 @@ GroundOperator::toString (void)
|
||||
ss << "grounding " << lvIndex_ + 1 << pos << " log var in " ;
|
||||
ss << f.functor() << "/" << f.arity();
|
||||
ss << "|" << tupleSet << " (group " << group_ << ")";
|
||||
ss << " [cost=" << std::exp (getLogCost()) << "]" << endl;
|
||||
ss << " [cost=" << std::exp (getLogCost()) << "]" << std::endl;
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
|
||||
|
||||
vector<pair<PrvGroup, unsigned>>
|
||||
GroundOperator::getAffectedFormulas (void)
|
||||
std::vector<std::pair<PrvGroup, unsigned>>
|
||||
GroundOperator::getAffectedFormulas()
|
||||
{
|
||||
vector<pair<PrvGroup, unsigned>> affectedFormulas;
|
||||
affectedFormulas.push_back (make_pair (group_, lvIndex_));
|
||||
queue<pair<PrvGroup, unsigned>> q;
|
||||
q.push (make_pair (group_, lvIndex_));
|
||||
std::vector<std::pair<PrvGroup, unsigned>> affectedFormulas;
|
||||
affectedFormulas.push_back (std::make_pair (group_, lvIndex_));
|
||||
std::queue<std::pair<PrvGroup, unsigned>> q;
|
||||
q.push (std::make_pair (group_, lvIndex_));
|
||||
while (q.empty() == false) {
|
||||
pair<PrvGroup, unsigned> front = q.front();
|
||||
std::pair<PrvGroup, unsigned> front = q.front();
|
||||
ParfactorList::iterator pflIt = pfList_.begin();
|
||||
while (pflIt != pfList_.end()) {
|
||||
size_t idx = (*pflIt)->indexOfGroup (front.first);
|
||||
@ -610,7 +753,7 @@ GroundOperator::getAffectedFormulas (void)
|
||||
const ProbFormulas& fs = (*pflIt)->arguments();
|
||||
for (size_t i = 0; i < fs.size(); i++) {
|
||||
if (i != idx && fs[i].contains (X)) {
|
||||
pair<PrvGroup, unsigned> pair = make_pair (
|
||||
std::pair<PrvGroup, unsigned> pair = std::make_pair (
|
||||
fs[i].group(), fs[i].indexOf (X));
|
||||
if (Util::contains (affectedFormulas, pair) == false) {
|
||||
q.push (pair);
|
||||
@ -645,13 +788,13 @@ LiftedVe::solveQuery (const Grounds& query)
|
||||
|
||||
|
||||
void
|
||||
LiftedVe::printSolverFlags (void) const
|
||||
LiftedVe::printSolverFlags() const
|
||||
{
|
||||
stringstream ss;
|
||||
std::stringstream ss;
|
||||
ss << "lve [" ;
|
||||
ss << "log_domain=" << Util::toString (Globals::logDomain);
|
||||
ss << "]" ;
|
||||
cout << ss.str() << endl;
|
||||
std::cout << ss.str() << std::endl;
|
||||
}
|
||||
|
||||
|
||||
@ -675,9 +818,9 @@ LiftedVe::runSolver (const Grounds& query)
|
||||
break;
|
||||
}
|
||||
if (Globals::verbosity > 1) {
|
||||
cout << "best operation: " << op->toString();
|
||||
std::cout << "best operation: " << op->toString();
|
||||
if (Globals::verbosity > 2) {
|
||||
cout << endl;
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
op->apply();
|
||||
@ -693,8 +836,9 @@ LiftedVe::runSolver (const Grounds& query)
|
||||
}
|
||||
}
|
||||
if (Globals::verbosity > 0) {
|
||||
cout << "largest cost = " << std::exp (largestCost_) << endl;
|
||||
cout << endl;
|
||||
std::cout << "largest cost = " << std::exp (largestCost_);
|
||||
std::cout << std::endl;
|
||||
std::cout << std::endl;
|
||||
}
|
||||
(*pfList_.begin())->simplifyGrounds();
|
||||
(*pfList_.begin())->reorderAccordingGrounds (query);
|
||||
@ -707,7 +851,7 @@ LiftedVe::getBestOperation (const Grounds& query)
|
||||
{
|
||||
double bestCost = 0.0;
|
||||
LiftedOperator* bestOp = 0;
|
||||
vector<LiftedOperator*> validOps;
|
||||
std::vector<LiftedOperator*> validOps;
|
||||
validOps = LiftedOperator::getValidOps (pfList_, query);
|
||||
for (size_t i = 0; i < validOps.size(); i++) {
|
||||
double cost = validOps[i]->getLogCost();
|
||||
@ -727,3 +871,5 @@ LiftedVe::getBestOperation (const Grounds& query)
|
||||
return bestOp;
|
||||
}
|
||||
|
||||
} // namespace Horus
|
||||
|
||||
|
@ -1,157 +1,23 @@
|
||||
#ifndef HORUS_LIFTEDVE_H
|
||||
#define HORUS_LIFTEDVE_H
|
||||
#ifndef YAP_PACKAGES_CLPBN_HORUS_LIFTEDVE_H_
|
||||
#define YAP_PACKAGES_CLPBN_HORUS_LIFTEDVE_H_
|
||||
|
||||
#include "LiftedSolver.h"
|
||||
#include "ParfactorList.h"
|
||||
|
||||
|
||||
class LiftedOperator
|
||||
{
|
||||
public:
|
||||
virtual ~LiftedOperator (void) { }
|
||||
namespace Horus {
|
||||
|
||||
virtual double getLogCost (void) = 0;
|
||||
|
||||
virtual void apply (void) = 0;
|
||||
|
||||
virtual string toString (void) = 0;
|
||||
|
||||
static vector<LiftedOperator*> getValidOps (
|
||||
ParfactorList&, const Grounds&);
|
||||
|
||||
static void printValidOps (ParfactorList&, const Grounds&);
|
||||
|
||||
static vector<ParfactorList::iterator> getParfactorsWithGroup (
|
||||
ParfactorList&, PrvGroup group);
|
||||
|
||||
private:
|
||||
DISALLOW_ASSIGN (LiftedOperator);
|
||||
};
|
||||
class LiftedOperator;
|
||||
|
||||
|
||||
|
||||
class ProductOperator : public LiftedOperator
|
||||
{
|
||||
public:
|
||||
ProductOperator (
|
||||
ParfactorList::iterator g1, ParfactorList::iterator g2,
|
||||
ParfactorList& pfList) : g1_(g1), g2_(g2), pfList_(pfList) { }
|
||||
|
||||
double getLogCost (void);
|
||||
|
||||
void apply (void);
|
||||
|
||||
static vector<ProductOperator*> getValidOps (ParfactorList&);
|
||||
|
||||
string toString (void);
|
||||
|
||||
private:
|
||||
static bool validOp (Parfactor*, Parfactor*);
|
||||
|
||||
ParfactorList::iterator g1_;
|
||||
ParfactorList::iterator g2_;
|
||||
ParfactorList& pfList_;
|
||||
|
||||
DISALLOW_COPY_AND_ASSIGN (ProductOperator);
|
||||
};
|
||||
|
||||
|
||||
|
||||
class SumOutOperator : public LiftedOperator
|
||||
{
|
||||
public:
|
||||
SumOutOperator (PrvGroup group, ParfactorList& pfList)
|
||||
: group_(group), pfList_(pfList) { }
|
||||
|
||||
double getLogCost (void);
|
||||
|
||||
void apply (void);
|
||||
|
||||
static vector<SumOutOperator*> getValidOps (
|
||||
ParfactorList&, const Grounds&);
|
||||
|
||||
string toString (void);
|
||||
|
||||
private:
|
||||
static bool validOp (PrvGroup, ParfactorList&, const Grounds&);
|
||||
|
||||
static bool isToEliminate (Parfactor*, PrvGroup, const Grounds&);
|
||||
|
||||
PrvGroup group_;
|
||||
ParfactorList& pfList_;
|
||||
|
||||
DISALLOW_COPY_AND_ASSIGN (SumOutOperator);
|
||||
};
|
||||
|
||||
|
||||
|
||||
class CountingOperator : public LiftedOperator
|
||||
{
|
||||
public:
|
||||
CountingOperator (
|
||||
ParfactorList::iterator pfIter,
|
||||
LogVar X,
|
||||
ParfactorList& pfList)
|
||||
: pfIter_(pfIter), X_(X), pfList_(pfList) { }
|
||||
|
||||
double getLogCost (void);
|
||||
|
||||
void apply (void);
|
||||
|
||||
static vector<CountingOperator*> getValidOps (ParfactorList&);
|
||||
|
||||
string toString (void);
|
||||
|
||||
private:
|
||||
static bool validOp (Parfactor*, LogVar);
|
||||
|
||||
ParfactorList::iterator pfIter_;
|
||||
LogVar X_;
|
||||
ParfactorList& pfList_;
|
||||
|
||||
DISALLOW_COPY_AND_ASSIGN (CountingOperator);
|
||||
};
|
||||
|
||||
|
||||
|
||||
class GroundOperator : public LiftedOperator
|
||||
{
|
||||
public:
|
||||
GroundOperator (
|
||||
PrvGroup group,
|
||||
unsigned lvIndex,
|
||||
ParfactorList& pfList)
|
||||
: group_(group), lvIndex_(lvIndex), pfList_(pfList) { }
|
||||
|
||||
double getLogCost (void);
|
||||
|
||||
void apply (void);
|
||||
|
||||
static vector<GroundOperator*> getValidOps (ParfactorList&);
|
||||
|
||||
string toString (void);
|
||||
|
||||
private:
|
||||
vector<pair<PrvGroup, unsigned>> getAffectedFormulas (void);
|
||||
|
||||
PrvGroup group_;
|
||||
unsigned lvIndex_;
|
||||
ParfactorList& pfList_;
|
||||
|
||||
DISALLOW_COPY_AND_ASSIGN (GroundOperator);
|
||||
};
|
||||
|
||||
|
||||
|
||||
class LiftedVe : public LiftedSolver
|
||||
{
|
||||
class LiftedVe : public LiftedSolver {
|
||||
public:
|
||||
LiftedVe (const ParfactorList& pfList)
|
||||
: LiftedSolver(pfList) { }
|
||||
|
||||
Params solveQuery (const Grounds&);
|
||||
|
||||
void printSolverFlags (void) const;
|
||||
void printSolverFlags() const;
|
||||
|
||||
private:
|
||||
void runSolver (const Grounds&);
|
||||
@ -164,5 +30,7 @@ class LiftedVe : public LiftedSolver
|
||||
DISALLOW_COPY_AND_ASSIGN (LiftedVe);
|
||||
};
|
||||
|
||||
#endif // HORUS_LIFTEDVE_H
|
||||
} // namespace Horus
|
||||
|
||||
#endif // YAP_PACKAGES_CLPBN_HORUS_LIFTEDVE_H_
|
||||
|
||||
|
@ -1,10 +1,20 @@
|
||||
#include <cassert>
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "LiftedWCNF.h"
|
||||
#include "ParfactorList.h"
|
||||
#include "ConstraintTree.h"
|
||||
#include "Indexer.h"
|
||||
|
||||
|
||||
namespace Horus {
|
||||
|
||||
bool
|
||||
Literal::isGround (ConstraintTree constr, LogVarSet ipgLogVars) const
|
||||
Literal::isGround (
|
||||
ConstraintTree constr,
|
||||
const LogVarSet& ipgLogVars) const
|
||||
{
|
||||
if (logVars_.empty()) {
|
||||
return true;
|
||||
@ -24,13 +34,13 @@ Literal::indexOfLogVar (LogVar X) const
|
||||
|
||||
|
||||
|
||||
string
|
||||
std::string
|
||||
Literal::toString (
|
||||
LogVarSet ipgLogVars,
|
||||
LogVarSet posCountedLvs,
|
||||
LogVarSet negCountedLvs) const
|
||||
{
|
||||
stringstream ss;
|
||||
std::stringstream ss;
|
||||
negated_ ? ss << "¬" : ss << "" ;
|
||||
ss << "λ" ;
|
||||
ss << lid_ ;
|
||||
@ -44,7 +54,7 @@ Literal::toString (
|
||||
ss << "-" << logVars_[i];
|
||||
} else if (ipgLogVars.contains (logVars_[i])) {
|
||||
LogVar X = logVars_[i];
|
||||
const string labels[] = {
|
||||
const std::string labels[] = {
|
||||
"a", "b", "c", "d", "e", "f",
|
||||
"g", "h", "i", "j", "k", "m" };
|
||||
(X >= 12) ? ss << "x_" << X : ss << labels[X];
|
||||
@ -60,7 +70,7 @@ Literal::toString (
|
||||
|
||||
|
||||
std::ostream&
|
||||
operator<< (ostream &os, const Literal& lit)
|
||||
operator<< (std::ostream& os, const Literal& lit)
|
||||
{
|
||||
os << lit.toString();
|
||||
return os;
|
||||
@ -216,7 +226,7 @@ Clause::isIpgLogVar (LogVar X) const
|
||||
|
||||
|
||||
TinySet<LiteralId>
|
||||
Clause::lidSet (void) const
|
||||
Clause::lidSet() const
|
||||
{
|
||||
TinySet<LiteralId> lidSet;
|
||||
for (size_t i = 0; i < literals_.size(); i++) {
|
||||
@ -228,7 +238,7 @@ Clause::lidSet (void) const
|
||||
|
||||
|
||||
LogVarSet
|
||||
Clause::ipgCandidates (void) const
|
||||
Clause::ipgCandidates() const
|
||||
{
|
||||
LogVarSet candidates;
|
||||
LogVarSet allLvs = constr_.logVarSet();
|
||||
@ -259,11 +269,11 @@ Clause::logVarTypes (size_t litIdx) const
|
||||
const LogVars& lvs = literals_[litIdx].logVars();
|
||||
for (size_t i = 0; i < lvs.size(); i++) {
|
||||
if (posCountedLvs_.contains (lvs[i])) {
|
||||
types.push_back (LogVarType::POS_LV);
|
||||
types.push_back (LogVarType::posLvt);
|
||||
} else if (negCountedLvs_.contains (lvs[i])) {
|
||||
types.push_back (LogVarType::NEG_LV);
|
||||
types.push_back (LogVarType::negLvt);
|
||||
} else {
|
||||
types.push_back (LogVarType::FULL_LV);
|
||||
types.push_back (LogVarType::fullLvt);
|
||||
}
|
||||
}
|
||||
return types;
|
||||
@ -320,7 +330,7 @@ void
|
||||
Clause::printClauses (const Clauses& clauses)
|
||||
{
|
||||
for (size_t i = 0; i < clauses.size(); i++) {
|
||||
cout << *clauses[i] << endl;
|
||||
std::cout << *clauses[i] << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
@ -337,7 +347,7 @@ Clause::deleteClauses (Clauses& clauses)
|
||||
|
||||
|
||||
std::ostream&
|
||||
operator<< (ostream &os, const Clause& clause)
|
||||
operator<< (std::ostream& os, const Clause& clause)
|
||||
{
|
||||
for (unsigned i = 0; i < clause.literals_.size(); i++) {
|
||||
if (i != 0) os << " v " ;
|
||||
@ -369,14 +379,14 @@ Clause::getLogVarSetExcluding (size_t idx) const
|
||||
|
||||
|
||||
std::ostream&
|
||||
operator<< (std::ostream &os, const LitLvTypes& lit)
|
||||
operator<< (std::ostream& os, const LitLvTypes& lit)
|
||||
{
|
||||
os << lit.lid_ << "<" ;
|
||||
for (size_t i = 0; i < lit.lvTypes_.size(); i++) {
|
||||
switch (lit.lvTypes_[i]) {
|
||||
case LogVarType::FULL_LV: os << "F" ; break;
|
||||
case LogVarType::POS_LV: os << "P" ; break;
|
||||
case LogVarType::NEG_LV: os << "N" ; break;
|
||||
case LogVarType::fullLvt: os << "F" ; break;
|
||||
case LogVarType::posLvt: os << "P" ; break;
|
||||
case LogVarType::negLvt: os << "N" ; break;
|
||||
}
|
||||
}
|
||||
os << ">" ;
|
||||
@ -385,6 +395,14 @@ operator<< (std::ostream &os, const LitLvTypes& lit)
|
||||
|
||||
|
||||
|
||||
void
|
||||
LitLvTypes::setAllFullLogVars()
|
||||
{
|
||||
std::fill (lvTypes_.begin(), lvTypes_.end(), LogVarType::fullLvt);
|
||||
}
|
||||
|
||||
|
||||
|
||||
LiftedWCNF::LiftedWCNF (const ParfactorList& pfList)
|
||||
: freeLiteralId_(0), pfList_(pfList)
|
||||
{
|
||||
@ -394,7 +412,7 @@ LiftedWCNF::LiftedWCNF (const ParfactorList& pfList)
|
||||
/*
|
||||
// INCLUSION-EXCLUSION TEST
|
||||
clauses_.clear();
|
||||
vector<vector<string>> names = {
|
||||
std::vector<std::vector<string>> names = {
|
||||
{"a1","b1"},{"a2","b2"}
|
||||
};
|
||||
Clause* c1 = new Clause (names);
|
||||
@ -406,7 +424,7 @@ LiftedWCNF::LiftedWCNF (const ParfactorList& pfList)
|
||||
/*
|
||||
// INDEPENDENT PARTIAL GROUND TEST
|
||||
clauses_.clear();
|
||||
vector<vector<string>> names = {
|
||||
std::vector<std::vector<string>> names = {
|
||||
{"a1","b1"},{"a2","b2"}
|
||||
};
|
||||
Clause* c1 = new Clause (names);
|
||||
@ -422,7 +440,7 @@ LiftedWCNF::LiftedWCNF (const ParfactorList& pfList)
|
||||
/*
|
||||
// ATOM-COUNTING TEST
|
||||
clauses_.clear();
|
||||
vector<vector<string>> names = {
|
||||
std::vector<std::vector<string>> names = {
|
||||
{"p1","p1"},{"p1","p2"},{"p1","p3"},
|
||||
{"p2","p1"},{"p2","p2"},{"p2","p3"},
|
||||
{"p3","p1"},{"p3","p2"},{"p3","p3"}
|
||||
@ -438,21 +456,21 @@ LiftedWCNF::LiftedWCNF (const ParfactorList& pfList)
|
||||
*/
|
||||
|
||||
if (Globals::verbosity > 1) {
|
||||
cout << "FORMULA INDICATORS:" << endl;
|
||||
std::cout << "FORMULA INDICATORS:" << std::endl;
|
||||
printFormulaIndicators();
|
||||
cout << endl;
|
||||
cout << "WEIGHTED INDICATORS:" << endl;
|
||||
std::cout << std::endl;
|
||||
std::cout << "WEIGHTED INDICATORS:" << std::endl;
|
||||
printWeights();
|
||||
cout << endl;
|
||||
cout << "CLAUSES:" << endl;
|
||||
std::cout << std::endl;
|
||||
std::cout << "CLAUSES:" << std::endl;
|
||||
printClauses();
|
||||
cout << endl;
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
LiftedWCNF::~LiftedWCNF (void)
|
||||
LiftedWCNF::~LiftedWCNF()
|
||||
{
|
||||
Clause::deleteClauses (clauses_);
|
||||
}
|
||||
@ -462,7 +480,7 @@ LiftedWCNF::~LiftedWCNF (void)
|
||||
void
|
||||
LiftedWCNF::addWeight (LiteralId lid, double posW, double negW)
|
||||
{
|
||||
weights_[lid] = make_pair (posW, negW);
|
||||
weights_[lid] = std::make_pair (posW, negW);
|
||||
}
|
||||
|
||||
|
||||
@ -470,8 +488,8 @@ LiftedWCNF::addWeight (LiteralId lid, double posW, double negW)
|
||||
double
|
||||
LiftedWCNF::posWeight (LiteralId lid) const
|
||||
{
|
||||
unordered_map<LiteralId, std::pair<double,double>>::const_iterator it;
|
||||
it = weights_.find (lid);
|
||||
std::unordered_map<LiteralId, std::pair<double,double>>::const_iterator it
|
||||
= weights_.find (lid);
|
||||
return it != weights_.end() ? it->second.first : LogAware::one();
|
||||
}
|
||||
|
||||
@ -480,14 +498,14 @@ LiftedWCNF::posWeight (LiteralId lid) const
|
||||
double
|
||||
LiftedWCNF::negWeight (LiteralId lid) const
|
||||
{
|
||||
unordered_map<LiteralId, std::pair<double,double>>::const_iterator it;
|
||||
it = weights_.find (lid);
|
||||
std::unordered_map<LiteralId, std::pair<double,double>>::const_iterator it
|
||||
= weights_.find (lid);
|
||||
return it != weights_.end() ? it->second.second : LogAware::one();
|
||||
}
|
||||
|
||||
|
||||
|
||||
vector<LiteralId>
|
||||
std::vector<LiteralId>
|
||||
LiftedWCNF::prvGroupLiterals (PrvGroup prvGroup)
|
||||
{
|
||||
assert (Util::contains (map_, prvGroup));
|
||||
@ -536,9 +554,10 @@ LiftedWCNF::addIndicatorClauses (const ParfactorList& pfList)
|
||||
ConstraintTree tempConstr = (*it)->constr()->projectedCopy(
|
||||
formulas[i].logVars());
|
||||
Clause* clause = new Clause (tempConstr);
|
||||
vector<LiteralId> lids;
|
||||
std::vector<LiteralId> lids;
|
||||
for (size_t j = 0; j < formulas[i].range(); j++) {
|
||||
clause->addLiteral (Literal (freeLiteralId_, formulas[i].logVars()));
|
||||
clause->addLiteral (Literal (
|
||||
freeLiteralId_, formulas[i].logVars()));
|
||||
lids.push_back (freeLiteralId_);
|
||||
freeLiteralId_ ++;
|
||||
}
|
||||
@ -568,7 +587,7 @@ LiftedWCNF::addParameterClauses (const ParfactorList& pfList)
|
||||
ParfactorList::const_iterator it = pfList.begin();
|
||||
while (it != pfList.end()) {
|
||||
Indexer indexer ((*it)->ranges());
|
||||
vector<PrvGroup> groups = (*it)->getAllGroups();
|
||||
std::vector<PrvGroup> groups = (*it)->getAllGroups();
|
||||
while (indexer.valid()) {
|
||||
LiteralId paramVarLid = freeLiteralId_;
|
||||
// λu1 ∧ ... ∧ λun ∧ λxi <=> θxi|u1,...,un
|
||||
@ -606,26 +625,26 @@ LiftedWCNF::addParameterClauses (const ParfactorList& pfList)
|
||||
|
||||
|
||||
void
|
||||
LiftedWCNF::printFormulaIndicators (void) const
|
||||
LiftedWCNF::printFormulaIndicators() const
|
||||
{
|
||||
if (map_.empty()) {
|
||||
return;
|
||||
}
|
||||
set<PrvGroup> allGroups;
|
||||
std::set<PrvGroup> allGroups;
|
||||
ParfactorList::const_iterator it = pfList_.begin();
|
||||
while (it != pfList_.end()) {
|
||||
const ProbFormulas& formulas = (*it)->arguments();
|
||||
for (size_t i = 0; i < formulas.size(); i++) {
|
||||
if (Util::contains (allGroups, formulas[i].group()) == false) {
|
||||
allGroups.insert (formulas[i].group());
|
||||
cout << formulas[i] << " | " ;
|
||||
std::cout << formulas[i] << " | " ;
|
||||
ConstraintTree tempCt = (*it)->constr()->projectedCopy (
|
||||
formulas[i].logVars());
|
||||
cout << tempCt.tupleSet();
|
||||
cout << " indicators => " ;
|
||||
vector<LiteralId> indicators =
|
||||
std::cout << tempCt.tupleSet();
|
||||
std::cout << " indicators => " ;
|
||||
std::vector<LiteralId> indicators =
|
||||
(map_.find (formulas[i].group()))->second;
|
||||
cout << indicators << endl;
|
||||
std::cout << indicators << std::endl;
|
||||
}
|
||||
}
|
||||
++ it;
|
||||
@ -635,14 +654,14 @@ LiftedWCNF::printFormulaIndicators (void) const
|
||||
|
||||
|
||||
void
|
||||
LiftedWCNF::printWeights (void) const
|
||||
LiftedWCNF::printWeights() const
|
||||
{
|
||||
unordered_map<LiteralId, std::pair<double,double>>::const_iterator it;
|
||||
it = weights_.begin();
|
||||
std::unordered_map<LiteralId, std::pair<double,double>>::const_iterator it
|
||||
= weights_.begin();
|
||||
while (it != weights_.end()) {
|
||||
cout << "λ" << it->first << " weights: " ;
|
||||
cout << it->second.first << " " << it->second.second;
|
||||
cout << endl;
|
||||
std::cout << "λ" << it->first << " weights: " ;
|
||||
std::cout << it->second.first << " " << it->second.second;
|
||||
std::cout << std::endl;
|
||||
++ it;
|
||||
}
|
||||
}
|
||||
@ -650,8 +669,10 @@ LiftedWCNF::printWeights (void) const
|
||||
|
||||
|
||||
void
|
||||
LiftedWCNF::printClauses (void) const
|
||||
LiftedWCNF::printClauses() const
|
||||
{
|
||||
Clause::printClauses (clauses_);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
@ -1,90 +1,95 @@
|
||||
#ifndef HORUS_LIFTEDWCNF_H
|
||||
#define HORUS_LIFTEDWCNF_H
|
||||
#ifndef YAP_PACKAGES_CLPBN_HORUS_LIFTEDWCNF_H_
|
||||
#define YAP_PACKAGES_CLPBN_HORUS_LIFTEDWCNF_H_
|
||||
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
#include <string>
|
||||
#include <ostream>
|
||||
|
||||
#include "ParfactorList.h"
|
||||
#include "ConstraintTree.h"
|
||||
#include "ProbFormula.h"
|
||||
#include "LiftedUtils.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
class ConstraintTree;
|
||||
namespace Horus {
|
||||
|
||||
enum LogVarType
|
||||
{
|
||||
FULL_LV,
|
||||
POS_LV,
|
||||
NEG_LV
|
||||
class ParfactorList;
|
||||
|
||||
enum class LogVarType {
|
||||
fullLvt,
|
||||
posLvt,
|
||||
negLvt
|
||||
};
|
||||
|
||||
typedef long LiteralId;
|
||||
typedef vector<LogVarType> LogVarTypes;
|
||||
typedef long LiteralId;
|
||||
typedef std::vector<LogVarType> LogVarTypes;
|
||||
|
||||
|
||||
class Literal
|
||||
{
|
||||
class Literal {
|
||||
public:
|
||||
Literal (LiteralId lid, const LogVars& lvs) :
|
||||
lid_(lid), logVars_(lvs), negated_(false) { }
|
||||
Literal (LiteralId lid, const LogVars& lvs)
|
||||
: lid_(lid), logVars_(lvs), negated_(false) { }
|
||||
|
||||
Literal (const Literal& lit, bool negated) :
|
||||
lid_(lit.lid_), logVars_(lit.logVars_), negated_(negated) { }
|
||||
Literal (const Literal& lit, bool negated)
|
||||
: lid_(lit.lid_), logVars_(lit.logVars_), negated_(negated) { }
|
||||
|
||||
LiteralId lid (void) const { return lid_; }
|
||||
LiteralId lid() const { return lid_; }
|
||||
|
||||
LogVars logVars (void) const { return logVars_; }
|
||||
LogVars logVars() const { return logVars_; }
|
||||
|
||||
size_t nrLogVars (void) const { return logVars_.size(); }
|
||||
size_t nrLogVars() const { return logVars_.size(); }
|
||||
|
||||
LogVarSet logVarSet (void) const { return LogVarSet (logVars_); }
|
||||
LogVarSet logVarSet() const { return LogVarSet (logVars_); }
|
||||
|
||||
void complement (void) { negated_ = !negated_; }
|
||||
void complement() { negated_ = !negated_; }
|
||||
|
||||
bool isPositive (void) const { return negated_ == false; }
|
||||
bool isPositive() const { return negated_ == false; }
|
||||
|
||||
bool isNegative (void) const { return negated_; }
|
||||
bool isNegative() const { return negated_; }
|
||||
|
||||
bool isGround (ConstraintTree constr, LogVarSet ipgLogVars) const;
|
||||
bool isGround (ConstraintTree constr, const LogVarSet& ipgLogVars) const;
|
||||
|
||||
size_t indexOfLogVar (LogVar X) const;
|
||||
|
||||
string toString (LogVarSet ipgLogVars = LogVarSet(),
|
||||
LogVarSet posCountedLvs = LogVarSet(),
|
||||
LogVarSet negCountedLvs = LogVarSet()) const;
|
||||
|
||||
friend std::ostream& operator<< (std::ostream &os, const Literal& lit);
|
||||
std::string toString (
|
||||
LogVarSet ipgLogVars = LogVarSet(),
|
||||
LogVarSet posCountedLvs = LogVarSet(),
|
||||
LogVarSet negCountedLvs = LogVarSet()) const;
|
||||
|
||||
private:
|
||||
friend std::ostream& operator<< (std::ostream&, const Literal&);
|
||||
|
||||
LiteralId lid_;
|
||||
LogVars logVars_;
|
||||
bool negated_;
|
||||
};
|
||||
|
||||
typedef vector<Literal> Literals;
|
||||
typedef std::vector<Literal> Literals;
|
||||
|
||||
|
||||
|
||||
class Clause
|
||||
{
|
||||
class Clause {
|
||||
public:
|
||||
Clause (const ConstraintTree& ct = ConstraintTree({})) : constr_(ct) { }
|
||||
|
||||
Clause (vector<vector<string>> names) : constr_(ConstraintTree (names)) { }
|
||||
Clause (std::vector<std::vector<std::string>> names) :
|
||||
constr_(ConstraintTree (names)) { }
|
||||
|
||||
void addLiteral (const Literal& l) { literals_.push_back (l); }
|
||||
|
||||
const Literals& literals (void) const { return literals_; }
|
||||
const Literals& literals() const { return literals_; }
|
||||
|
||||
Literals& literals (void) { return literals_; }
|
||||
Literals& literals() { return literals_; }
|
||||
|
||||
size_t nrLiterals (void) const { return literals_.size(); }
|
||||
size_t nrLiterals() const { return literals_.size(); }
|
||||
|
||||
const ConstraintTree& constr (void) const { return constr_; }
|
||||
const ConstraintTree& constr() const { return constr_; }
|
||||
|
||||
ConstraintTree constr (void) { return constr_; }
|
||||
ConstraintTree constr() { return constr_; }
|
||||
|
||||
bool isUnit (void) const { return literals_.size() == 1; }
|
||||
bool isUnit() const { return literals_.size() == 1; }
|
||||
|
||||
LogVarSet ipgLogVars (void) const { return ipgLvs_; }
|
||||
LogVarSet ipgLogVars() const { return ipgLvs_; }
|
||||
|
||||
void addIpgLogVar (LogVar X) { ipgLvs_.insert (X); }
|
||||
|
||||
@ -92,13 +97,13 @@ class Clause
|
||||
|
||||
void addNegCountedLogVar (LogVar X) { negCountedLvs_.insert (X); }
|
||||
|
||||
LogVarSet posCountedLogVars (void) const { return posCountedLvs_; }
|
||||
LogVarSet posCountedLogVars() const { return posCountedLvs_; }
|
||||
|
||||
LogVarSet negCountedLogVars (void) const { return negCountedLvs_; }
|
||||
LogVarSet negCountedLogVars() const { return negCountedLvs_; }
|
||||
|
||||
unsigned nrPosCountedLogVars (void) const { return posCountedLvs_.size(); }
|
||||
unsigned nrPosCountedLogVars() const { return posCountedLvs_.size(); }
|
||||
|
||||
unsigned nrNegCountedLogVars (void) const { return negCountedLvs_.size(); }
|
||||
unsigned nrNegCountedLogVars() const { return negCountedLvs_.size(); }
|
||||
|
||||
void addLiteralComplemented (const Literal& lit);
|
||||
|
||||
@ -122,9 +127,9 @@ class Clause
|
||||
|
||||
bool isIpgLogVar (LogVar X) const;
|
||||
|
||||
TinySet<LiteralId> lidSet (void) const;
|
||||
TinySet<LiteralId> lidSet() const;
|
||||
|
||||
LogVarSet ipgCandidates (void) const;
|
||||
LogVarSet ipgCandidates() const;
|
||||
|
||||
LogVarTypes logVarTypes (size_t litIdx) const;
|
||||
|
||||
@ -132,78 +137,78 @@ class Clause
|
||||
|
||||
static bool independentClauses (Clause& c1, Clause& c2);
|
||||
|
||||
static vector<Clause*> copyClauses (const vector<Clause*>& clauses);
|
||||
static std::vector<Clause*> copyClauses (
|
||||
const std::vector<Clause*>& clauses);
|
||||
|
||||
static void printClauses (const vector<Clause*>& clauses);
|
||||
static void printClauses (const std::vector<Clause*>& clauses);
|
||||
|
||||
static void deleteClauses (vector<Clause*>& clauses);
|
||||
|
||||
friend std::ostream& operator<< (ostream &os, const Clause& clause);
|
||||
static void deleteClauses (std::vector<Clause*>& clauses);
|
||||
|
||||
private:
|
||||
LogVarSet getLogVarSetExcluding (size_t idx) const;
|
||||
|
||||
Literals literals_;
|
||||
LogVarSet ipgLvs_;
|
||||
LogVarSet posCountedLvs_;
|
||||
LogVarSet negCountedLvs_;
|
||||
ConstraintTree constr_;
|
||||
friend std::ostream& operator<< (std::ostream&, const Clause&);
|
||||
|
||||
Literals literals_;
|
||||
LogVarSet ipgLvs_;
|
||||
LogVarSet posCountedLvs_;
|
||||
LogVarSet negCountedLvs_;
|
||||
ConstraintTree constr_;
|
||||
|
||||
DISALLOW_ASSIGN (Clause);
|
||||
};
|
||||
|
||||
typedef vector<Clause*> Clauses;
|
||||
typedef std::vector<Clause*> Clauses;
|
||||
|
||||
|
||||
|
||||
class LitLvTypes
|
||||
{
|
||||
class LitLvTypes {
|
||||
public:
|
||||
struct CompareLitLvTypes
|
||||
{
|
||||
bool operator() (
|
||||
const LitLvTypes& types1,
|
||||
const LitLvTypes& types2) const
|
||||
{
|
||||
if (types1.lid_ < types2.lid_) {
|
||||
return true;
|
||||
}
|
||||
if (types1.lid_ == types2.lid_) {
|
||||
return types1.lvTypes_ < types2.lvTypes_;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
LitLvTypes (LiteralId lid, const LogVarTypes& lvTypes) :
|
||||
lid_(lid), lvTypes_(lvTypes) { }
|
||||
|
||||
LiteralId lid (void) const { return lid_; }
|
||||
LiteralId lid() const { return lid_; }
|
||||
|
||||
const LogVarTypes& logVarTypes (void) const { return lvTypes_; }
|
||||
const LogVarTypes& logVarTypes() const { return lvTypes_; }
|
||||
|
||||
void setAllFullLogVars (void) {
|
||||
std::fill (lvTypes_.begin(), lvTypes_.end(), LogVarType::FULL_LV); }
|
||||
|
||||
friend std::ostream& operator<< (std::ostream &os, const LitLvTypes& lit);
|
||||
void setAllFullLogVars();
|
||||
|
||||
private:
|
||||
friend std::ostream& operator<< (std::ostream&, const LitLvTypes&);
|
||||
|
||||
LiteralId lid_;
|
||||
LogVarTypes lvTypes_;
|
||||
};
|
||||
|
||||
typedef TinySet<LitLvTypes,LitLvTypes::CompareLitLvTypes> LitLvTypesSet;
|
||||
|
||||
|
||||
|
||||
class LiftedWCNF
|
||||
struct CmpLitLvTypes
|
||||
{
|
||||
bool operator() (
|
||||
const LitLvTypes& types1,
|
||||
const LitLvTypes& types2) const
|
||||
{
|
||||
if (types1.lid() < types2.lid()) {
|
||||
return true;
|
||||
}
|
||||
if (types1.lid() == types2.lid()){
|
||||
return types1.logVarTypes() < types2.logVarTypes();
|
||||
}
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
typedef TinySet<LitLvTypes, CmpLitLvTypes> LitLvTypesSet;
|
||||
|
||||
|
||||
|
||||
class LiftedWCNF {
|
||||
public:
|
||||
LiftedWCNF (const ParfactorList& pfList);
|
||||
|
||||
~LiftedWCNF (void);
|
||||
~LiftedWCNF();
|
||||
|
||||
const Clauses& clauses (void) const { return clauses_; }
|
||||
const Clauses& clauses() const { return clauses_; }
|
||||
|
||||
void addWeight (LiteralId lid, double posW, double negW);
|
||||
|
||||
@ -211,15 +216,15 @@ class LiftedWCNF
|
||||
|
||||
double negWeight (LiteralId lid) const;
|
||||
|
||||
vector<LiteralId> prvGroupLiterals (PrvGroup prvGroup);
|
||||
std::vector<LiteralId> prvGroupLiterals (PrvGroup prvGroup);
|
||||
|
||||
Clause* createClause (LiteralId lid) const;
|
||||
|
||||
void printFormulaIndicators (void) const;
|
||||
void printFormulaIndicators() const;
|
||||
|
||||
void printWeights (void) const;
|
||||
void printWeights() const;
|
||||
|
||||
void printClauses (void) const;
|
||||
void printClauses() const;
|
||||
|
||||
private:
|
||||
LiteralId getLiteralId (PrvGroup prvGroup, unsigned range);
|
||||
@ -228,14 +233,16 @@ class LiftedWCNF
|
||||
|
||||
void addParameterClauses (const ParfactorList& pfList);
|
||||
|
||||
Clauses clauses_;
|
||||
LiteralId freeLiteralId_;
|
||||
const ParfactorList& pfList_;
|
||||
unordered_map<PrvGroup, vector<LiteralId>> map_;
|
||||
unordered_map<LiteralId, std::pair<double,double>> weights_;
|
||||
Clauses clauses_;
|
||||
LiteralId freeLiteralId_;
|
||||
const ParfactorList& pfList_;
|
||||
std::unordered_map<PrvGroup, std::vector<LiteralId>> map_;
|
||||
std::unordered_map<LiteralId, std::pair<double,double>> weights_;
|
||||
|
||||
DISALLOW_COPY_AND_ASSIGN (LiftedWCNF);
|
||||
};
|
||||
|
||||
#endif // HORUS_LIFTEDWCNF_H
|
||||
} // namespace Horus
|
||||
|
||||
#endif // YAP_PACKAGES_CLPBN_HORUS_LIFTEDWCNF_H_
|
||||
|
||||
|
@ -43,9 +43,9 @@ SO=@SO@
|
||||
#4.1VPATH=@srcdir@:@srcdir@/OPTYap
|
||||
CWD=$(PWD)
|
||||
|
||||
HCLI = $(srcdir)/hcli
|
||||
utestsdir=@srcdir@/unit_tests
|
||||
|
||||
HEADERS = \
|
||||
MAIN_HEADERS = \
|
||||
$(srcdir)/BayesBall.h \
|
||||
$(srcdir)/BayesBallGraph.h \
|
||||
$(srcdir)/BeliefProp.h \
|
||||
@ -54,6 +54,8 @@ HEADERS = \
|
||||
$(srcdir)/ElimGraph.h \
|
||||
$(srcdir)/Factor.h \
|
||||
$(srcdir)/FactorGraph.h \
|
||||
$(srcdir)/GenericFactor.h \
|
||||
$(srcdir)/GroundSolver.h \
|
||||
$(srcdir)/Histogram.h \
|
||||
$(srcdir)/Horus.h \
|
||||
$(srcdir)/Indexer.h \
|
||||
@ -67,14 +69,20 @@ HEADERS = \
|
||||
$(srcdir)/Parfactor.h \
|
||||
$(srcdir)/ParfactorList.h \
|
||||
$(srcdir)/ProbFormula.h \
|
||||
$(srcdir)/GroundSolver.h \
|
||||
$(srcdir)/TinySet.h \
|
||||
$(srcdir)/Util.h \
|
||||
$(srcdir)/Var.h \
|
||||
$(srcdir)/VarElim.h \
|
||||
$(srcdir)/WeightedBp.h
|
||||
|
||||
CPP_SOURCES = \
|
||||
UTESTS_HEADERS = \
|
||||
$(utestsdir)/Common.h
|
||||
|
||||
HEADERS = \
|
||||
$(MAIN_HEADERS) \
|
||||
$(UTESTS_HEADERS)
|
||||
|
||||
MAIN_SOURCES = \
|
||||
$(srcdir)/BayesBall.cpp \
|
||||
$(srcdir)/BayesBallGraph.cpp \
|
||||
$(srcdir)/BeliefProp.cpp \
|
||||
@ -83,9 +91,12 @@ CPP_SOURCES = \
|
||||
$(srcdir)/ElimGraph.cpp \
|
||||
$(srcdir)/Factor.cpp \
|
||||
$(srcdir)/FactorGraph.cpp \
|
||||
$(srcdir)/GenericFactor.cpp \
|
||||
$(srcdir)/GroundSolver.cpp \
|
||||
$(srcdir)/Histogram.cpp \
|
||||
$(srcdir)/HorusCli.cpp \
|
||||
$(srcdir)/HorusYap.cpp \
|
||||
$(srcdir)/Indexer.cpp \
|
||||
$(srcdir)/LiftedBp.cpp \
|
||||
$(srcdir)/LiftedKc.cpp \
|
||||
$(srcdir)/LiftedOperations.cpp \
|
||||
@ -95,12 +106,23 @@ CPP_SOURCES = \
|
||||
$(srcdir)/Parfactor.cpp \
|
||||
$(srcdir)/ParfactorList.cpp \
|
||||
$(srcdir)/ProbFormula.cpp \
|
||||
$(srcdir)/GroundSolver.cpp \
|
||||
$(srcdir)/Util.cpp \
|
||||
$(srcdir)/Var.cpp \
|
||||
$(srcdir)/VarElim.cpp \
|
||||
$(srcdir)/WeightedBp.cpp
|
||||
|
||||
UTESTS_SOURCES = \
|
||||
$(utestsdir)/BeliefPropTest.cpp \
|
||||
$(utestsdir)/Common.cpp \
|
||||
$(utestsdir)/CountingBpTest.cpp \
|
||||
$(utestsdir)/FactorTest.cpp \
|
||||
$(utestsdir)/VarElimTest.cpp \
|
||||
$(utestsdir)/UnitTesting.cpp
|
||||
|
||||
SOURCES = \
|
||||
$(MAIN_SOURCES) \
|
||||
$(UTESTS_SOURCES)
|
||||
|
||||
OBJS = \
|
||||
BayesBall.o \
|
||||
BayesBallGraph.o \
|
||||
@ -110,8 +132,10 @@ OBJS = \
|
||||
ElimGraph.o \
|
||||
Factor.o \
|
||||
FactorGraph.o \
|
||||
GenericFactor.o \
|
||||
GroundSolver.o \
|
||||
Histogram.o \
|
||||
HorusYap.o \
|
||||
Indexer.o \
|
||||
LiftedBp.o \
|
||||
LiftedKc.o \
|
||||
LiftedOperations.o \
|
||||
@ -121,12 +145,15 @@ OBJS = \
|
||||
ProbFormula.o \
|
||||
Parfactor.o \
|
||||
ParfactorList.o \
|
||||
GroundSolver.o \
|
||||
Util.o \
|
||||
Var.o \
|
||||
VarElim.o \
|
||||
WeightedBp.o
|
||||
|
||||
LIB_OBJS = \
|
||||
$(OBJS) \
|
||||
HorusYap.o
|
||||
|
||||
HCLI_OBJS = \
|
||||
BayesBall.o \
|
||||
BayesBallGraph.o \
|
||||
@ -135,51 +162,82 @@ HCLI_OBJS = \
|
||||
ElimGraph.o \
|
||||
Factor.o \
|
||||
FactorGraph.o \
|
||||
HorusCli.o \
|
||||
GenericFactor.o \
|
||||
GroundSolver.o \
|
||||
HorusCli.o \
|
||||
Indexer.o \
|
||||
Util.o \
|
||||
Var.o \
|
||||
VarElim.o \
|
||||
WeightedBp.o
|
||||
|
||||
SOBJS=horus.@SO@
|
||||
UTESTS_OBJS = \
|
||||
$(OBJS) \
|
||||
$(utestsdir)/BeliefPropTest.o \
|
||||
$(utestsdir)/Common.o \
|
||||
$(utestsdir)/CountingBpTest.o \
|
||||
$(utestsdir)/FactorTest.o \
|
||||
$(utestsdir)/VarElimTest.o \
|
||||
$(utestsdir)/UnitTesting.o
|
||||
|
||||
|
||||
all: $(SOBJS) hcli
|
||||
LIB = $(srcdir)/horus.@SO@
|
||||
HCLI = $(srcdir)/hcli
|
||||
UTESTING = $(srcdir)/run_tests
|
||||
|
||||
|
||||
all: $(LIB) $(HCLI)
|
||||
|
||||
|
||||
# Don't require $(UTESTING) by default as we
|
||||
# don't want a hard dependency on CppUnit
|
||||
with_tests: $(LIB) $(HCLI) $(UTESTING)
|
||||
|
||||
|
||||
@DO_SECOND_LD@$(LIB): $(LIB_OBJS)
|
||||
@DO_SECOND_LD@ @SHLIB_CXX_LD@ -o $@ $(LIB_OBJS) @EXTRA_LIBS_FOR_SWIDLLS@
|
||||
|
||||
|
||||
$(HCLI): $(HCLI_OBJS)
|
||||
$(CXX) -o $@ $(HCLI_OBJS)
|
||||
|
||||
|
||||
$(UTESTING): $(UTESTS_OBJS)
|
||||
$(CXX) -o $@ $(UTESTS_OBJS) -lcppunit
|
||||
|
||||
|
||||
# default rule
|
||||
%.o : $(srcdir)/%.cpp
|
||||
$(CXX) -c $(CXXFLAGS) $< -o $@
|
||||
|
||||
|
||||
@DO_SECOND_LD@horus.@SO@: $(OBJS)
|
||||
@DO_SECOND_LD@ @SHLIB_CXX_LD@ -o horus.@SO@ $(OBJS) @EXTRA_LIBS_FOR_SWIDLLS@
|
||||
|
||||
|
||||
hcli: $(HCLI_OBJS)
|
||||
$(CXX) -o $(HCLI) $(HCLI_OBJS)
|
||||
$(CXX) -o $@ -c $(CXXFLAGS) $<
|
||||
|
||||
|
||||
install: all
|
||||
$(INSTALL_PROGRAM) $(SOBJS) $(DESTDIR)$(YAPLIBDIR)
|
||||
$(INSTALL_PROGRAM) $(LIB) $(DESTDIR)$(YAPLIBDIR)
|
||||
$(INSTALL_PROGRAM) $(HCLI) $(DESTDIR)$(BINDIR)
|
||||
|
||||
|
||||
clean:
|
||||
rm -f *.o *~ $(OBJS) $(SOBJS) $(HCLI) *.BAK
|
||||
rm -f $(LIB) $(HCLI) $(UTESTING) *.o *~ $(utestsdir)/*.o $(utestsdir)/*~
|
||||
|
||||
|
||||
erase_dots:
|
||||
rm -f *.dot *.png
|
||||
remove_dots:
|
||||
rm -f *.dot *.png *.svg
|
||||
|
||||
|
||||
depend: $(HEADERS) $(CPP_SOURCES)
|
||||
depend: $(SOURCES) $(HEADERS)
|
||||
-@if test "$(GCC)" = yes; then\
|
||||
$(CC) -std=c++0x -MM -MG $(CFLAGS) -I$(srcdir) -I$(srcdir)/../../../../include -I$(srcdir)/../../../../H $(CPP_SOURCES) >> Makefile;\
|
||||
for F in $(SOURCES); do \
|
||||
D=`dirname $$F`; \
|
||||
B=`basename $$F .cpp`; \
|
||||
$(CXX) $(CXXFLAGS) -MM -MG -MT "$$D/$$B.o" -I$(srcdir)/../../../../H -I$(srcdir)/../../../../include $$F >> Makefile; \
|
||||
done; \
|
||||
else\
|
||||
makedepend -f - -- $(CFLAGS) -I$(srcdir)/../../../../H -I$(srcdir)/../../../../include -- $(CPP_SOURCES) |\
|
||||
sed 's|.*/\([^:]*\):|\1:|' >> Makefile ;\
|
||||
makedepend -- $(CXXFLAGS) -- -I$(srcdir)/../../../../H -I$(srcdir)/../../../../include $(SOURCES); \
|
||||
fi
|
||||
|
||||
|
||||
.PHONY: default all install clean remove_dots depend
|
||||
|
||||
|
||||
# DO NOT DELETE THIS LINE -- make depend depends on it.
|
||||
|
||||
|
@ -1,3 +1,8 @@
|
||||
#include <cassert>
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "Parfactor.h"
|
||||
#include "Histogram.h"
|
||||
#include "Indexer.h"
|
||||
@ -5,6 +10,8 @@
|
||||
#include "Horus.h"
|
||||
|
||||
|
||||
namespace Horus {
|
||||
|
||||
Parfactor::Parfactor (
|
||||
const ProbFormulas& formulas,
|
||||
const Params& params,
|
||||
@ -84,7 +91,7 @@ Parfactor::Parfactor (const Parfactor& g)
|
||||
|
||||
|
||||
|
||||
Parfactor::~Parfactor (void)
|
||||
Parfactor::~Parfactor()
|
||||
{
|
||||
delete constr_;
|
||||
}
|
||||
@ -92,7 +99,7 @@ Parfactor::~Parfactor (void)
|
||||
|
||||
|
||||
LogVarSet
|
||||
Parfactor::countedLogVars (void) const
|
||||
Parfactor::countedLogVars() const
|
||||
{
|
||||
LogVarSet set;
|
||||
for (size_t i = 0; i < args_.size(); i++) {
|
||||
@ -106,7 +113,7 @@ Parfactor::countedLogVars (void) const
|
||||
|
||||
|
||||
LogVarSet
|
||||
Parfactor::uncountedLogVars (void) const
|
||||
Parfactor::uncountedLogVars() const
|
||||
{
|
||||
return constr_->logVarSet() - countedLogVars();
|
||||
}
|
||||
@ -114,7 +121,7 @@ Parfactor::uncountedLogVars (void) const
|
||||
|
||||
|
||||
LogVarSet
|
||||
Parfactor::elimLogVars (void) const
|
||||
Parfactor::elimLogVars() const
|
||||
{
|
||||
LogVarSet requiredToElim = constr_->logVarSet();
|
||||
requiredToElim -= constr_->singletons();
|
||||
@ -149,7 +156,7 @@ Parfactor::sumOutIndex (size_t fIdx)
|
||||
unsigned N = constr_->getConditionalCount (
|
||||
args_[fIdx].countedLogVar());
|
||||
unsigned R = args_[fIdx].range();
|
||||
vector<double> numAssigns = HistogramSet::getNumAssigns (N, R);
|
||||
std::vector<double> numAssigns = HistogramSet::getNumAssigns (N, R);
|
||||
Indexer indexer (ranges_, fIdx);
|
||||
while (indexer.valid()) {
|
||||
if (Globals::logDomain) {
|
||||
@ -171,7 +178,7 @@ Parfactor::sumOutIndex (size_t fIdx)
|
||||
}
|
||||
constr_->remove (excl);
|
||||
|
||||
TFactor<ProbFormula>::sumOutIndex (fIdx);
|
||||
GenericFactor<ProbFormula>::sumOutIndex (fIdx);
|
||||
LogAware::pow (params_, exp);
|
||||
}
|
||||
|
||||
@ -181,7 +188,7 @@ void
|
||||
Parfactor::multiply (Parfactor& g)
|
||||
{
|
||||
alignAndExponentiate (this, &g);
|
||||
TFactor<ProbFormula>::multiply (g);
|
||||
GenericFactor<ProbFormula>::multiply (g);
|
||||
constr_->join (g.constr(), true);
|
||||
simplifyGrounds();
|
||||
assert (constr_->isCartesianProduct (countedLogVars()));
|
||||
@ -224,10 +231,10 @@ Parfactor::countConvert (LogVar X)
|
||||
unsigned N = constr_->getConditionalCount (X);
|
||||
unsigned R = ranges_[fIdx];
|
||||
unsigned H = HistogramSet::nrHistograms (N, R);
|
||||
vector<Histogram> histograms = HistogramSet::getHistograms (N, R);
|
||||
std::vector<Histogram> histograms = HistogramSet::getHistograms (N, R);
|
||||
|
||||
Indexer indexer (ranges_);
|
||||
vector<Params> sumout (params_.size() / R);
|
||||
std::vector<Params> sumout (params_.size() / R);
|
||||
unsigned count = 0;
|
||||
while (indexer.valid()) {
|
||||
sumout[count].reserve (R);
|
||||
@ -279,11 +286,11 @@ Parfactor::expand (LogVar X, LogVar X_new1, LogVar X_new2)
|
||||
unsigned H1 = HistogramSet::nrHistograms (N1, R);
|
||||
unsigned H2 = HistogramSet::nrHistograms (N2, R);
|
||||
|
||||
vector<Histogram> histograms = HistogramSet::getHistograms (N, R);
|
||||
vector<Histogram> histograms1 = HistogramSet::getHistograms (N1, R);
|
||||
vector<Histogram> histograms2 = HistogramSet::getHistograms (N2, R);
|
||||
std::vector<Histogram> histograms = HistogramSet::getHistograms (N, R);
|
||||
std::vector<Histogram> histograms1 = HistogramSet::getHistograms (N1, R);
|
||||
std::vector<Histogram> histograms2 = HistogramSet::getHistograms (N2, R);
|
||||
|
||||
vector<unsigned> sumIndexes;
|
||||
std::vector<unsigned> sumIndexes;
|
||||
sumIndexes.reserve (H1 * H2);
|
||||
for (unsigned i = 0; i < H1; i++) {
|
||||
for (unsigned j = 0; j < H2; j++) {
|
||||
@ -319,16 +326,16 @@ Parfactor::fullExpand (LogVar X)
|
||||
|
||||
unsigned N = constr_->getConditionalCount (X);
|
||||
unsigned R = args_[fIdx].range();
|
||||
vector<Histogram> originHists = HistogramSet::getHistograms (N, R);
|
||||
vector<Histogram> expandHists = HistogramSet::getHistograms (1, R);
|
||||
std::vector<Histogram> originHists = HistogramSet::getHistograms (N, R);
|
||||
std::vector<Histogram> expandHists = HistogramSet::getHistograms (1, R);
|
||||
assert (ranges_[fIdx] == originHists.size());
|
||||
vector<unsigned> sumIndexes;
|
||||
std::vector<unsigned> sumIndexes;
|
||||
sumIndexes.reserve (N * R);
|
||||
|
||||
Ranges expandRanges (N, R);
|
||||
Indexer indexer (expandRanges);
|
||||
while (indexer.valid()) {
|
||||
vector<unsigned> hist (R, 0);
|
||||
std::vector<unsigned> hist (R, 0);
|
||||
for (unsigned n = 0; n < N; n++) {
|
||||
hist += expandHists[indexer[n]];
|
||||
}
|
||||
@ -384,14 +391,14 @@ Parfactor::absorveEvidence (const ProbFormula& formula, unsigned evidence)
|
||||
assert (args_[fIdx].isCounting() == false);
|
||||
assert (constr_->isCountNormalized (excl));
|
||||
LogAware::pow (params_, constr_->getConditionalCount (excl));
|
||||
TFactor<ProbFormula>::absorveEvidence (formula, evidence);
|
||||
GenericFactor<ProbFormula>::absorveEvidence (formula, evidence);
|
||||
constr_->remove (excl);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Parfactor::setNewGroups (void)
|
||||
Parfactor::setNewGroups()
|
||||
{
|
||||
for (size_t i = 0; i < args_.size(); i++) {
|
||||
args_[i].setGroup (ProbFormula::getNewGroup());
|
||||
@ -494,7 +501,7 @@ Parfactor::containsGroup (PrvGroup group) const
|
||||
|
||||
|
||||
bool
|
||||
Parfactor::containsGroups (vector<PrvGroup> groups) const
|
||||
Parfactor::containsGroups (std::vector<PrvGroup> groups) const
|
||||
{
|
||||
for (size_t i = 0; i < groups.size(); i++) {
|
||||
if (containsGroup (groups[i]) == false) {
|
||||
@ -565,10 +572,10 @@ Parfactor::nrFormulasWithGroup (PrvGroup group) const
|
||||
|
||||
|
||||
|
||||
vector<PrvGroup>
|
||||
Parfactor::getAllGroups (void) const
|
||||
std::vector<PrvGroup>
|
||||
Parfactor::getAllGroups() const
|
||||
{
|
||||
vector<PrvGroup> groups (args_.size());
|
||||
std::vector<PrvGroup> groups (args_.size());
|
||||
for (size_t i = 0; i < args_.size(); i++) {
|
||||
groups[i] = args_[i].group();
|
||||
}
|
||||
@ -577,10 +584,10 @@ Parfactor::getAllGroups (void) const
|
||||
|
||||
|
||||
|
||||
string
|
||||
Parfactor::getLabel (void) const
|
||||
std::string
|
||||
Parfactor::getLabel() const
|
||||
{
|
||||
stringstream ss;
|
||||
std::stringstream ss;
|
||||
ss << "phi(" ;
|
||||
for (size_t i = 0; i < args_.size(); i++) {
|
||||
if (i != 0) ss << "," ;
|
||||
@ -598,6 +605,8 @@ Parfactor::getLabel (void) const
|
||||
void
|
||||
Parfactor::print (bool printParams) const
|
||||
{
|
||||
using std::cout;
|
||||
using std::endl;
|
||||
cout << "Formulas: " ;
|
||||
for (size_t i = 0; i < args_.size(); i++) {
|
||||
if (i != 0) cout << ", " ;
|
||||
@ -605,9 +614,10 @@ Parfactor::print (bool printParams) const
|
||||
}
|
||||
cout << endl;
|
||||
if (args_[0].group() != Util::maxUnsigned()) {
|
||||
vector<string> groups;
|
||||
std::vector<std::string> groups;
|
||||
for (size_t i = 0; i < args_.size(); i++) {
|
||||
groups.push_back (string ("g") + Util::toString (args_[i].group()));
|
||||
groups.push_back (std::string ("g")
|
||||
+ Util::toString (args_[i].group()));
|
||||
}
|
||||
cout << "Groups: " << groups << endl;
|
||||
}
|
||||
@ -633,12 +643,12 @@ Parfactor::print (bool printParams) const
|
||||
|
||||
|
||||
void
|
||||
Parfactor::printParameters (void) const
|
||||
Parfactor::printParameters() const
|
||||
{
|
||||
vector<string> jointStrings;
|
||||
std::vector<std::string> jointStrings;
|
||||
Indexer indexer (ranges_);
|
||||
while (indexer.valid()) {
|
||||
stringstream ss;
|
||||
std::stringstream ss;
|
||||
for (size_t i = 0; i < args_.size(); i++) {
|
||||
if (i != 0) ss << ", " ;
|
||||
if (args_[i].isCounting()) {
|
||||
@ -659,22 +669,22 @@ Parfactor::printParameters (void) const
|
||||
++ indexer;
|
||||
}
|
||||
for (size_t i = 0; i < params_.size(); i++) {
|
||||
cout << "f(" << jointStrings[i] << ")" ;
|
||||
cout << " = " << params_[i] << endl;
|
||||
std::cout << "f(" << jointStrings[i] << ")" ;
|
||||
std::cout << " = " << params_[i] << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Parfactor::printProjections (void) const
|
||||
Parfactor::printProjections() const
|
||||
{
|
||||
ConstraintTree copy (*constr_);
|
||||
|
||||
LogVarSet Xs = copy.logVarSet();
|
||||
for (size_t i = 0; i < Xs.size(); i++) {
|
||||
cout << "-> projection of " << Xs[i] << ": " ;
|
||||
cout << copy.tupleSet ({Xs[i]}) << endl;
|
||||
std::cout << "-> projection of " << Xs[i] << ": " ;
|
||||
std::cout << copy.tupleSet ({Xs[i]}) << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
@ -684,12 +694,12 @@ void
|
||||
Parfactor::expandPotential (
|
||||
size_t fIdx,
|
||||
unsigned newRange,
|
||||
const vector<unsigned>& sumIndexes)
|
||||
const std::vector<unsigned>& sumIndexes)
|
||||
{
|
||||
ullong newSize = (params_.size() / ranges_[fIdx]) * newRange;
|
||||
if (newSize > params_.max_size()) {
|
||||
cerr << "Error: an overflow occurred when performing expansion." ;
|
||||
cerr << endl;
|
||||
std::cerr << "Error: an overflow occurred when performing expansion." ;
|
||||
std::cerr << std::endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
|
||||
@ -698,7 +708,7 @@ Parfactor::expandPotential (
|
||||
params_.reserve (newSize);
|
||||
|
||||
size_t prod = 1;
|
||||
vector<size_t> offsets (ranges_.size());
|
||||
std::vector<size_t> offsets (ranges_.size());
|
||||
for (size_t i = ranges_.size(); i-- > 0; ) {
|
||||
offsets[i] = prod;
|
||||
prod *= ranges_[i];
|
||||
@ -706,7 +716,7 @@ Parfactor::expandPotential (
|
||||
|
||||
size_t index = 0;
|
||||
ranges_[fIdx] = newRange;
|
||||
vector<unsigned> indices (ranges_.size(), 0);
|
||||
std::vector<unsigned> indices (ranges_.size(), 0);
|
||||
for (size_t k = 0; k < newSize; k++) {
|
||||
assert (index < backup.size());
|
||||
params_.push_back (backup[index]);
|
||||
@ -759,7 +769,7 @@ Parfactor::simplifyCountingFormulas (size_t fIdx)
|
||||
|
||||
|
||||
void
|
||||
Parfactor::simplifyGrounds (void)
|
||||
Parfactor::simplifyGrounds()
|
||||
{
|
||||
if (args_.size() == 1) {
|
||||
return;
|
||||
@ -872,12 +882,12 @@ Parfactor::alignLogicalVars (Parfactor* g1, Parfactor* g2)
|
||||
std::pair<LogVars, LogVars> res = getAlignLogVars (g1, g2);
|
||||
const LogVars& alignLvs1 = res.first;
|
||||
const LogVars& alignLvs2 = res.second;
|
||||
// cout << "ALIGNING :::::::::::::::::" << endl;
|
||||
// std::cout << "ALIGNING :::::::::::::::::" << std::endl;
|
||||
// g1->print();
|
||||
// cout << "AND" << endl;
|
||||
// g2->print();
|
||||
// cout << "-> align lvs1 = " << alignLvs1 << endl;
|
||||
// cout << "-> align lvs2 = " << alignLvs2 << endl;
|
||||
// std::cout << "-> align lvs1 = " << alignLvs1 << std::endl;
|
||||
// std::cout << "-> align lvs2 = " << alignLvs2 << std::endl;
|
||||
LogVar freeLogVar (0);
|
||||
Substitution theta1, theta2;
|
||||
for (size_t i = 0; i < alignLvs1.size(); i++) {
|
||||
@ -933,9 +943,11 @@ Parfactor::alignLogicalVars (Parfactor* g1, Parfactor* g2)
|
||||
}
|
||||
}
|
||||
|
||||
// cout << "theta1: " << theta1 << endl;
|
||||
// cout << "theta2: " << theta2 << endl;
|
||||
// std::cout << "theta1: " << theta1 << std::endl;
|
||||
// std::cout << "theta2: " << theta2 << std::endl;
|
||||
g1->applySubstitution (theta1);
|
||||
g2->applySubstitution (theta2);
|
||||
}
|
||||
|
||||
} // namespace Horus
|
||||
|
||||
|
@ -1,20 +1,21 @@
|
||||
#ifndef HORUS_PARFACTOR_H
|
||||
#define HORUS_PARFACTOR_H
|
||||
#ifndef YAP_PACKAGES_CLPBN_HORUS_PARFACTOR_H_
|
||||
#define YAP_PACKAGES_CLPBN_HORUS_PARFACTOR_H_
|
||||
|
||||
#include "Factor.h"
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
#include "GenericFactor.h"
|
||||
#include "ProbFormula.h"
|
||||
#include "ConstraintTree.h"
|
||||
#include "LiftedUtils.h"
|
||||
#include "Horus.h"
|
||||
|
||||
|
||||
class Parfactor : public TFactor<ProbFormula>
|
||||
{
|
||||
namespace Horus {
|
||||
|
||||
class Parfactor : public GenericFactor<ProbFormula> {
|
||||
public:
|
||||
Parfactor (
|
||||
const ProbFormulas&,
|
||||
const Params&,
|
||||
const Tuples&,
|
||||
Parfactor (const ProbFormulas&, const Params&, const Tuples&,
|
||||
unsigned distId);
|
||||
|
||||
Parfactor (const Parfactor*, const Tuple&);
|
||||
@ -23,21 +24,21 @@ class Parfactor : public TFactor<ProbFormula>
|
||||
|
||||
Parfactor (const Parfactor&);
|
||||
|
||||
~Parfactor (void);
|
||||
~Parfactor();
|
||||
|
||||
ConstraintTree* constr (void) { return constr_; }
|
||||
ConstraintTree* constr() { return constr_; }
|
||||
|
||||
const ConstraintTree* constr (void) const { return constr_; }
|
||||
const ConstraintTree* constr() const { return constr_; }
|
||||
|
||||
const LogVars& logVars (void) const { return constr_->logVars(); }
|
||||
const LogVars& logVars() const { return constr_->logVars(); }
|
||||
|
||||
const LogVarSet& logVarSet (void) const { return constr_->logVarSet(); }
|
||||
const LogVarSet& logVarSet() const { return constr_->logVarSet(); }
|
||||
|
||||
LogVarSet countedLogVars (void) const;
|
||||
LogVarSet countedLogVars() const;
|
||||
|
||||
LogVarSet uncountedLogVars (void) const;
|
||||
LogVarSet uncountedLogVars() const;
|
||||
|
||||
LogVarSet elimLogVars (void) const;
|
||||
LogVarSet elimLogVars() const;
|
||||
|
||||
LogVarSet exclusiveLogVars (size_t fIdx) const;
|
||||
|
||||
@ -57,7 +58,7 @@ class Parfactor : public TFactor<ProbFormula>
|
||||
|
||||
void absorveEvidence (const ProbFormula&, unsigned);
|
||||
|
||||
void setNewGroups (void);
|
||||
void setNewGroups();
|
||||
|
||||
void applySubstitution (const Substitution&);
|
||||
|
||||
@ -71,7 +72,7 @@ class Parfactor : public TFactor<ProbFormula>
|
||||
|
||||
bool containsGroup (PrvGroup) const;
|
||||
|
||||
bool containsGroups (vector<PrvGroup>) const;
|
||||
bool containsGroups (std::vector<PrvGroup>) const;
|
||||
|
||||
unsigned nrFormulas (LogVar) const;
|
||||
|
||||
@ -81,17 +82,17 @@ class Parfactor : public TFactor<ProbFormula>
|
||||
|
||||
unsigned nrFormulasWithGroup (PrvGroup) const;
|
||||
|
||||
vector<PrvGroup> getAllGroups (void) const;
|
||||
std::vector<PrvGroup> getAllGroups() const;
|
||||
|
||||
void print (bool = false) const;
|
||||
|
||||
void printParameters (void) const;
|
||||
void printParameters() const;
|
||||
|
||||
void printProjections (void) const;
|
||||
void printProjections() const;
|
||||
|
||||
string getLabel (void) const;
|
||||
std::string getLabel() const;
|
||||
|
||||
void simplifyGrounds (void);
|
||||
void simplifyGrounds();
|
||||
|
||||
static bool canMultiply (Parfactor*, Parfactor*);
|
||||
|
||||
@ -104,18 +105,20 @@ class Parfactor : public TFactor<ProbFormula>
|
||||
Parfactor* g1, Parfactor* g2);
|
||||
|
||||
void expandPotential (size_t fIdx, unsigned newRange,
|
||||
const vector<unsigned>& sumIndexes);
|
||||
const std::vector<unsigned>& sumIndexes);
|
||||
|
||||
static void alignAndExponentiate (Parfactor*, Parfactor*);
|
||||
|
||||
static void alignLogicalVars (Parfactor*, Parfactor*);
|
||||
|
||||
ConstraintTree* constr_;
|
||||
ConstraintTree* constr_;
|
||||
|
||||
DISALLOW_ASSIGN (Parfactor);
|
||||
};
|
||||
|
||||
typedef vector<Parfactor*> Parfactors;
|
||||
typedef std::vector<Parfactor*> Parfactors;
|
||||
|
||||
#endif // HORUS_PARFACTOR_H
|
||||
} // namespace Horus
|
||||
|
||||
#endif // YAP_PACKAGES_CLPBN_HORUS_PARFACTOR_H_
|
||||
|
||||
|
@ -1,10 +1,14 @@
|
||||
#include <cassert>
|
||||
|
||||
#include <queue>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ParfactorList.h"
|
||||
|
||||
|
||||
namespace Horus {
|
||||
|
||||
ParfactorList::ParfactorList (const ParfactorList& pfList)
|
||||
{
|
||||
ParfactorList::const_iterator it = pfList.begin();
|
||||
@ -23,7 +27,7 @@ ParfactorList::ParfactorList (const Parfactors& pfs)
|
||||
|
||||
|
||||
|
||||
ParfactorList::~ParfactorList (void)
|
||||
ParfactorList::~ParfactorList()
|
||||
{
|
||||
ParfactorList::const_iterator it = pfList_.begin();
|
||||
while (it != pfList_.end()) {
|
||||
@ -64,27 +68,27 @@ ParfactorList::addShattered (Parfactor* pf)
|
||||
|
||||
|
||||
|
||||
list<Parfactor*>::iterator
|
||||
std::list<Parfactor*>::iterator
|
||||
ParfactorList::insertShattered (
|
||||
list<Parfactor*>::iterator it,
|
||||
std::list<Parfactor*>::iterator it,
|
||||
Parfactor* pf)
|
||||
{
|
||||
return pfList_.insert (it, pf);
|
||||
assert (isAllShattered());
|
||||
return pfList_.insert (it, pf);
|
||||
}
|
||||
|
||||
|
||||
|
||||
list<Parfactor*>::iterator
|
||||
ParfactorList::remove (list<Parfactor*>::iterator it)
|
||||
std::list<Parfactor*>::iterator
|
||||
ParfactorList::remove (std::list<Parfactor*>::iterator it)
|
||||
{
|
||||
return pfList_.erase (it);
|
||||
}
|
||||
|
||||
|
||||
|
||||
list<Parfactor*>::iterator
|
||||
ParfactorList::removeAndDelete (list<Parfactor*>::iterator it)
|
||||
std::list<Parfactor*>::iterator
|
||||
ParfactorList::removeAndDelete (std::list<Parfactor*>::iterator it)
|
||||
{
|
||||
delete *it;
|
||||
return pfList_.erase (it);
|
||||
@ -93,12 +97,12 @@ ParfactorList::removeAndDelete (list<Parfactor*>::iterator it)
|
||||
|
||||
|
||||
bool
|
||||
ParfactorList::isAllShattered (void) const
|
||||
ParfactorList::isAllShattered() const
|
||||
{
|
||||
if (pfList_.size() <= 1) {
|
||||
return true;
|
||||
}
|
||||
vector<Parfactor*> pfs (pfList_.begin(), pfList_.end());
|
||||
Parfactors pfs (pfList_.begin(), pfList_.end());
|
||||
for (size_t i = 0; i < pfs.size(); i++) {
|
||||
assert (isShattered (pfs[i]));
|
||||
}
|
||||
@ -115,13 +119,25 @@ ParfactorList::isAllShattered (void) const
|
||||
|
||||
|
||||
void
|
||||
ParfactorList::print (void) const
|
||||
ParfactorList::print() const
|
||||
{
|
||||
struct sortByParams {
|
||||
bool operator() (const Parfactor* pf1, const Parfactor* pf2)
|
||||
{
|
||||
if (pf1->params().size() < pf2->params().size()) {
|
||||
return true;
|
||||
} else if (pf1->params().size() == pf2->params().size() &&
|
||||
pf1->params() < pf2->params()) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
};
|
||||
Parfactors pfVec (pfList_.begin(), pfList_.end());
|
||||
std::sort (pfVec.begin(), pfVec.end(), sortByParams());
|
||||
for (size_t i = 0; i < pfVec.size(); i++) {
|
||||
pfVec[i]->print();
|
||||
cout << endl;
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
@ -163,8 +179,8 @@ ParfactorList::isShattered (const Parfactor* g) const
|
||||
formulas[i], *(g->constr()),
|
||||
formulas[j], *(g->constr())) == false) {
|
||||
g->print();
|
||||
cout << "-> not identical on positions " ;
|
||||
cout << i << " and " << j << endl;
|
||||
std::cout << "-> not identical on positions " ;
|
||||
std::cout << i << " and " << j << std::endl;
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
@ -172,8 +188,8 @@ ParfactorList::isShattered (const Parfactor* g) const
|
||||
formulas[i], *(g->constr()),
|
||||
formulas[j], *(g->constr())) == false) {
|
||||
g->print();
|
||||
cout << "-> not disjoint on positions " ;
|
||||
cout << i << " and " << j << endl;
|
||||
std::cout << "-> not disjoint on positions " ;
|
||||
std::cout << i << " and " << j << std::endl;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@ -200,9 +216,10 @@ ParfactorList::isShattered (
|
||||
fms1[i], *(g1->constr()),
|
||||
fms2[j], *(g2->constr())) == false) {
|
||||
g1->print();
|
||||
cout << "^" << endl;
|
||||
std::cout << "^" << std::endl;
|
||||
g2->print();
|
||||
cout << "-> not identical on group " << fms1[i].group() << endl;
|
||||
std::cout << "-> not identical on group " ;
|
||||
std::cout << fms1[i].group() << std::endl;
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
@ -210,10 +227,10 @@ ParfactorList::isShattered (
|
||||
fms1[i], *(g1->constr()),
|
||||
fms2[j], *(g2->constr())) == false) {
|
||||
g1->print();
|
||||
cout << "^" << endl;
|
||||
std::cout << "^" << std::endl;
|
||||
g2->print();
|
||||
cout << "-> not disjoint on groups " << fms1[i].group();
|
||||
cout << " and " << fms2[j].group() << endl;
|
||||
std::cout << "-> not disjoint on groups " << fms1[i].group();
|
||||
std::cout << " and " << fms2[j].group() << std::endl;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@ -227,12 +244,12 @@ ParfactorList::isShattered (
|
||||
void
|
||||
ParfactorList::addToShatteredList (Parfactor* g)
|
||||
{
|
||||
queue<Parfactor*> residuals;
|
||||
std::queue<Parfactor*> residuals;
|
||||
residuals.push (g);
|
||||
while (residuals.empty() == false) {
|
||||
Parfactor* pf = residuals.front();
|
||||
bool pfSplitted = false;
|
||||
list<Parfactor*>::iterator pfIter;
|
||||
std::list<Parfactor*>::iterator pfIter;
|
||||
pfIter = pfList_.begin();
|
||||
while (pfIter != pfList_.end()) {
|
||||
std::pair<Parfactors, Parfactors> shattRes;
|
||||
@ -269,7 +286,7 @@ Parfactors
|
||||
ParfactorList::shatterAgainstMySelf (Parfactor* g)
|
||||
{
|
||||
Parfactors pfs;
|
||||
queue<Parfactor*> residuals;
|
||||
std::queue<Parfactor*> residuals;
|
||||
residuals.push (g);
|
||||
bool shattered = true;
|
||||
while (residuals.empty() == false) {
|
||||
@ -325,19 +342,22 @@ ParfactorList::shatterAgainstMySelf (
|
||||
{
|
||||
/*
|
||||
Util::printDashedLine();
|
||||
cout << "-> SHATTERING" << endl;
|
||||
std::cout << "-> SHATTERING" << std::endl;
|
||||
g->print();
|
||||
cout << "-> ON: " << g->argument (fIdx1) << "|" ;
|
||||
cout << g->constr()->tupleSet (g->argument (fIdx1).logVars()) << endl;
|
||||
cout << "-> ON: " << g->argument (fIdx2) << "|" ;
|
||||
cout << g->constr()->tupleSet (g->argument (fIdx2).logVars()) << endl;
|
||||
std::cout << "-> ON: " << g->argument (fIdx1) << "|" ;
|
||||
std::cout << g->constr()->tupleSet (g->argument (fIdx1).logVars());
|
||||
std::cout << std::endl;
|
||||
std::cout << "-> ON: " << g->argument (fIdx2) << "|" ;
|
||||
std::cout << g->constr()->tupleSet (g->argument (fIdx2).logVars())
|
||||
std::cout << std::endl;
|
||||
Util::printDashedLine();
|
||||
*/
|
||||
ProbFormula& f1 = g->argument (fIdx1);
|
||||
ProbFormula& f2 = g->argument (fIdx2);
|
||||
if (f1.isAtom()) {
|
||||
cerr << "Error: a ground occurs twice in the same parfactor." << endl;
|
||||
cerr << endl;
|
||||
std::cerr << "Error: a ground occurs twice in the same parfactor." ;
|
||||
std::cerr << std::endl;
|
||||
std::cerr << std::endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
assert (g->constr()->empty() == false);
|
||||
@ -441,14 +461,14 @@ ParfactorList::shatter (
|
||||
ProbFormula& f2 = g2->argument (fIdx2);
|
||||
/*
|
||||
Util::printDashedLine();
|
||||
cout << "-> SHATTERING" << endl;
|
||||
std::cout << "-> SHATTERING" << std::endl;
|
||||
g1->print();
|
||||
cout << "-> WITH" << endl;
|
||||
std::cout << "-> WITH" << std::endl;
|
||||
g2->print();
|
||||
cout << "-> ON: " << f1 << "|" ;
|
||||
cout << g1->constr()->tupleSet (f1.logVars()) << endl;
|
||||
cout << "-> ON: " << f2 << "|" ;
|
||||
cout << g2->constr()->tupleSet (f2.logVars()) << endl;
|
||||
std::cout << "-> ON: " << f1 << "|" ;
|
||||
std::cout << g1->constr()->tupleSet (f1.logVars()) << std::endl;
|
||||
std::cout << "-> ON: " << f2 << "|" ;
|
||||
std::cout << g2->constr()->tupleSet (f2.logVars()) << std::endl;
|
||||
Util::printDashedLine();
|
||||
*/
|
||||
if (f1.isAtom()) {
|
||||
@ -486,12 +506,12 @@ ParfactorList::shatter (
|
||||
assert (commCt1->tupleSet (f1.logVars()) ==
|
||||
commCt2->tupleSet (f2.logVars()));
|
||||
|
||||
// stringstream ss1; ss1 << "" << count << "_A.dot" ;
|
||||
// stringstream ss2; ss2 << "" << count << "_B.dot" ;
|
||||
// stringstream ss3; ss3 << "" << count << "_A_comm.dot" ;
|
||||
// stringstream ss4; ss4 << "" << count << "_A_excl.dot" ;
|
||||
// stringstream ss5; ss5 << "" << count << "_B_comm.dot" ;
|
||||
// stringstream ss6; ss6 << "" << count << "_B_excl.dot" ;
|
||||
// std::stringstream ss1; ss1 << "" << count << "_A.dot" ;
|
||||
// std::stringstream ss2; ss2 << "" << count << "_B.dot" ;
|
||||
// std::stringstream ss3; ss3 << "" << count << "_A_comm.dot" ;
|
||||
// std::stringstream ss4; ss4 << "" << count << "_A_excl.dot" ;
|
||||
// std::stringstream ss5; ss5 << "" << count << "_B_comm.dot" ;
|
||||
// std::stringstream ss6; ss6 << "" << count << "_B_excl.dot" ;
|
||||
// g1->constr()->exportToGraphViz (ss1.str().c_str(), true);
|
||||
// g2->constr()->exportToGraphViz (ss2.str().c_str(), true);
|
||||
// commCt1->exportToGraphViz (ss3.str().c_str(), true);
|
||||
@ -638,3 +658,5 @@ ParfactorList::disjoint (
|
||||
return (ts1 & ts2).empty();
|
||||
}
|
||||
|
||||
} // namespace Horus
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
#ifndef HORUS_PARFACTORLIST_H
|
||||
#define HORUS_PARFACTORLIST_H
|
||||
#ifndef YAP_PACKAGES_CLPBN_HORUS_PARFACTORLIST_H_
|
||||
#define YAP_PACKAGES_CLPBN_HORUS_PARFACTORLIST_H_
|
||||
|
||||
#include <list>
|
||||
|
||||
@ -7,39 +7,38 @@
|
||||
#include "ProbFormula.h"
|
||||
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace Horus {
|
||||
|
||||
class Parfactor;
|
||||
|
||||
class ParfactorList
|
||||
{
|
||||
|
||||
class ParfactorList {
|
||||
public:
|
||||
ParfactorList (void) { }
|
||||
ParfactorList() { }
|
||||
|
||||
ParfactorList (const ParfactorList&);
|
||||
|
||||
ParfactorList (const Parfactors&);
|
||||
|
||||
~ParfactorList (void);
|
||||
~ParfactorList();
|
||||
|
||||
const list<Parfactor*>& parfactors (void) const { return pfList_; }
|
||||
const std::list<Parfactor*>& parfactors() const { return pfList_; }
|
||||
|
||||
void clear (void) { pfList_.clear(); }
|
||||
void clear() { pfList_.clear(); }
|
||||
|
||||
size_t size (void) const { return pfList_.size(); }
|
||||
size_t size() const { return pfList_.size(); }
|
||||
|
||||
typedef std::list<Parfactor*>::iterator iterator;
|
||||
|
||||
iterator begin (void) { return pfList_.begin(); }
|
||||
iterator begin() { return pfList_.begin(); }
|
||||
|
||||
iterator end (void) { return pfList_.end(); }
|
||||
iterator end() { return pfList_.end(); }
|
||||
|
||||
typedef std::list<Parfactor*>::const_iterator const_iterator;
|
||||
|
||||
const_iterator begin (void) const { return pfList_.begin(); }
|
||||
const_iterator begin() const { return pfList_.begin(); }
|
||||
|
||||
const_iterator end (void) const { return pfList_.end(); }
|
||||
const_iterator end() const { return pfList_.end(); }
|
||||
|
||||
void add (Parfactor* pf);
|
||||
|
||||
@ -47,16 +46,18 @@ class ParfactorList
|
||||
|
||||
void addShattered (Parfactor* pf);
|
||||
|
||||
list<Parfactor*>::iterator insertShattered (
|
||||
list<Parfactor*>::iterator, Parfactor*);
|
||||
std::list<Parfactor*>::iterator insertShattered (
|
||||
std::list<Parfactor*>::iterator, Parfactor*);
|
||||
|
||||
list<Parfactor*>::iterator remove (list<Parfactor*>::iterator);
|
||||
std::list<Parfactor*>::iterator remove (
|
||||
std::list<Parfactor*>::iterator);
|
||||
|
||||
list<Parfactor*>::iterator removeAndDelete (list<Parfactor*>::iterator);
|
||||
std::list<Parfactor*>::iterator removeAndDelete (
|
||||
std::list<Parfactor*>::iterator);
|
||||
|
||||
bool isAllShattered (void) const;
|
||||
bool isAllShattered() const;
|
||||
|
||||
void print (void) const;
|
||||
void print() const;
|
||||
|
||||
ParfactorList& operator= (const ParfactorList& pfList);
|
||||
|
||||
@ -101,22 +102,10 @@ class ParfactorList
|
||||
const ProbFormula&, ConstraintTree,
|
||||
const ProbFormula&, ConstraintTree) const;
|
||||
|
||||
struct sortByParams
|
||||
{
|
||||
inline bool operator() (const Parfactor* pf1, const Parfactor* pf2)
|
||||
{
|
||||
if (pf1->params().size() < pf2->params().size()) {
|
||||
return true;
|
||||
} else if (pf1->params().size() == pf2->params().size() &&
|
||||
pf1->params() < pf2->params()) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
list<Parfactor*> pfList_;
|
||||
std::list<Parfactor*> pfList_;
|
||||
};
|
||||
|
||||
#endif // HORUS_PARFACTORLIST_H
|
||||
} // namespace Horus
|
||||
|
||||
#endif // YAP_PACKAGES_CLPBN_HORUS_PARFACTORLIST_H_
|
||||
|
||||
|
@ -1,6 +1,13 @@
|
||||
#include <cassert>
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "ProbFormula.h"
|
||||
|
||||
|
||||
|
||||
namespace Horus {
|
||||
|
||||
PrvGroup ProbFormula::freeGroup_ = 0;
|
||||
|
||||
|
||||
@ -38,7 +45,7 @@ ProbFormula::indexOf (LogVar X) const
|
||||
|
||||
|
||||
bool
|
||||
ProbFormula::isAtom (void) const
|
||||
ProbFormula::isAtom() const
|
||||
{
|
||||
return logVars_.empty();
|
||||
}
|
||||
@ -46,7 +53,7 @@ ProbFormula::isAtom (void) const
|
||||
|
||||
|
||||
bool
|
||||
ProbFormula::isCounting (void) const
|
||||
ProbFormula::isCounting() const
|
||||
{
|
||||
return countedLogVar_.valid();
|
||||
}
|
||||
@ -54,7 +61,7 @@ ProbFormula::isCounting (void) const
|
||||
|
||||
|
||||
LogVar
|
||||
ProbFormula::countedLogVar (void) const
|
||||
ProbFormula::countedLogVar() const
|
||||
{
|
||||
assert (isCounting());
|
||||
return countedLogVar_;
|
||||
@ -71,7 +78,7 @@ ProbFormula::setCountedLogVar (LogVar lv)
|
||||
|
||||
|
||||
void
|
||||
ProbFormula::clearCountedLogVar (void)
|
||||
ProbFormula::clearCountedLogVar()
|
||||
{
|
||||
countedLogVar_ = LogVar();
|
||||
}
|
||||
@ -93,15 +100,8 @@ ProbFormula::rename (LogVar oldName, LogVar newName)
|
||||
|
||||
|
||||
|
||||
bool operator== (const ProbFormula& f1, const ProbFormula& f2)
|
||||
{
|
||||
return f1.group_ == f2.group_ &&
|
||||
f1.logVars_ == f2.logVars_;
|
||||
}
|
||||
|
||||
|
||||
|
||||
std::ostream& operator<< (ostream &os, const ProbFormula& f)
|
||||
std::ostream&
|
||||
operator<< (std::ostream& os, const ProbFormula& f)
|
||||
{
|
||||
os << f.functor_;
|
||||
if (f.isAtom() == false) {
|
||||
@ -122,7 +122,7 @@ std::ostream& operator<< (ostream &os, const ProbFormula& f)
|
||||
|
||||
|
||||
PrvGroup
|
||||
ProbFormula::getNewGroup (void)
|
||||
ProbFormula::getNewGroup()
|
||||
{
|
||||
freeGroup_ ++;
|
||||
assert (freeGroup_ != std::numeric_limits<PrvGroup>::max());
|
||||
@ -131,7 +131,24 @@ ProbFormula::getNewGroup (void)
|
||||
|
||||
|
||||
|
||||
ostream& operator<< (ostream &os, const ObservedFormula& of)
|
||||
ObservedFormula::ObservedFormula (Symbol f, unsigned a, unsigned ev)
|
||||
: functor_(f), arity_(a), evidence_(ev), constr_(a)
|
||||
{
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
ObservedFormula::ObservedFormula (Symbol f, unsigned ev, const Tuple& tuple)
|
||||
: functor_(f), arity_(tuple.size()), evidence_(ev), constr_(arity_)
|
||||
{
|
||||
constr_.addTuple (tuple);
|
||||
}
|
||||
|
||||
|
||||
|
||||
std::ostream&
|
||||
operator<< (std::ostream& os, const ObservedFormula& of)
|
||||
{
|
||||
os << of.functor_ << "/" << of.arity_;
|
||||
os << "|" << of.constr_.tupleSet();
|
||||
@ -139,3 +156,5 @@ ostream& operator<< (ostream &os, const ObservedFormula& of)
|
||||
return os;
|
||||
}
|
||||
|
||||
} // namespace Horus
|
||||
|
||||
|
@ -1,16 +1,20 @@
|
||||
#ifndef HORUS_PROBFORMULA_H
|
||||
#define HORUS_PROBFORMULA_H
|
||||
#ifndef YAP_PACKAGES_CLPBN_HORUS_PROBFORMULA_H_
|
||||
#define YAP_PACKAGES_CLPBN_HORUS_PROBFORMULA_H_
|
||||
|
||||
#include <vector>
|
||||
#include <ostream>
|
||||
#include <limits>
|
||||
|
||||
#include "ConstraintTree.h"
|
||||
#include "LiftedUtils.h"
|
||||
#include "Horus.h"
|
||||
|
||||
|
||||
namespace Horus {
|
||||
|
||||
typedef unsigned long PrvGroup;
|
||||
|
||||
class ProbFormula
|
||||
{
|
||||
class ProbFormula {
|
||||
public:
|
||||
ProbFormula (Symbol f, const LogVars& lvs, unsigned range)
|
||||
: functor_(f), logVars_(lvs), range_(range),
|
||||
@ -20,19 +24,19 @@ class ProbFormula
|
||||
: functor_(f), range_(r),
|
||||
group_(std::numeric_limits<PrvGroup>::max()) { }
|
||||
|
||||
Symbol functor (void) const { return functor_; }
|
||||
Symbol functor() const { return functor_; }
|
||||
|
||||
unsigned arity (void) const { return logVars_.size(); }
|
||||
unsigned arity() const { return logVars_.size(); }
|
||||
|
||||
unsigned range (void) const { return range_; }
|
||||
unsigned range() const { return range_; }
|
||||
|
||||
LogVars& logVars (void) { return logVars_; }
|
||||
LogVars& logVars() { return logVars_; }
|
||||
|
||||
const LogVars& logVars (void) const { return logVars_; }
|
||||
const LogVars& logVars() const { return logVars_; }
|
||||
|
||||
LogVarSet logVarSet (void) const { return LogVarSet (logVars_); }
|
||||
LogVarSet logVarSet() const { return LogVarSet (logVars_); }
|
||||
|
||||
PrvGroup group (void) const { return group_; }
|
||||
PrvGroup group() const { return group_; }
|
||||
|
||||
void setGroup (PrvGroup g) { group_ = g; }
|
||||
|
||||
@ -44,25 +48,28 @@ class ProbFormula
|
||||
|
||||
size_t indexOf (LogVar) const;
|
||||
|
||||
bool isAtom (void) const;
|
||||
bool isAtom() const;
|
||||
|
||||
bool isCounting (void) const;
|
||||
bool isCounting() const;
|
||||
|
||||
LogVar countedLogVar (void) const;
|
||||
LogVar countedLogVar() const;
|
||||
|
||||
void setCountedLogVar (LogVar);
|
||||
|
||||
void clearCountedLogVar (void);
|
||||
void clearCountedLogVar();
|
||||
|
||||
void rename (LogVar, LogVar);
|
||||
|
||||
static PrvGroup getNewGroup (void);
|
||||
|
||||
friend std::ostream& operator<< (ostream &os, const ProbFormula& f);
|
||||
|
||||
friend bool operator== (const ProbFormula& f1, const ProbFormula& f2);
|
||||
static PrvGroup getNewGroup();
|
||||
|
||||
private:
|
||||
|
||||
friend bool operator== (
|
||||
const ProbFormula& f1, const ProbFormula& f2);
|
||||
|
||||
friend std::ostream& operator<< (
|
||||
std::ostream&, const ProbFormula&);
|
||||
|
||||
Symbol functor_;
|
||||
LogVars logVars_;
|
||||
unsigned range_;
|
||||
@ -71,45 +78,50 @@ class ProbFormula
|
||||
static PrvGroup freeGroup_;
|
||||
};
|
||||
|
||||
typedef vector<ProbFormula> ProbFormulas;
|
||||
typedef std::vector<ProbFormula> ProbFormulas;
|
||||
|
||||
|
||||
class ObservedFormula
|
||||
inline bool
|
||||
operator== (const ProbFormula& f1, const ProbFormula& f2)
|
||||
{
|
||||
return f1.group_ == f2.group_ && f1.logVars_ == f2.logVars_;
|
||||
}
|
||||
|
||||
|
||||
|
||||
class ObservedFormula {
|
||||
public:
|
||||
ObservedFormula (Symbol f, unsigned a, unsigned ev)
|
||||
: functor_(f), arity_(a), evidence_(ev), constr_(a) { }
|
||||
ObservedFormula (Symbol f, unsigned a, unsigned ev);
|
||||
|
||||
ObservedFormula (Symbol f, unsigned ev, const Tuple& tuple)
|
||||
: functor_(f), arity_(tuple.size()), evidence_(ev), constr_(arity_)
|
||||
{
|
||||
constr_.addTuple (tuple);
|
||||
}
|
||||
ObservedFormula (Symbol f, unsigned ev, const Tuple& tuple);
|
||||
|
||||
Symbol functor (void) const { return functor_; }
|
||||
Symbol functor() const { return functor_; }
|
||||
|
||||
unsigned arity (void) const { return arity_; }
|
||||
unsigned arity() const { return arity_; }
|
||||
|
||||
unsigned evidence (void) const { return evidence_; }
|
||||
unsigned evidence() const { return evidence_; }
|
||||
|
||||
void setEvidence (unsigned ev) { evidence_ = ev; }
|
||||
|
||||
ConstraintTree& constr (void) { return constr_; }
|
||||
ConstraintTree& constr() { return constr_; }
|
||||
|
||||
bool isAtom (void) const { return arity_ == 0; }
|
||||
bool isAtom() const { return arity_ == 0; }
|
||||
|
||||
void addTuple (const Tuple& tuple) { constr_.addTuple (tuple); }
|
||||
|
||||
friend ostream& operator<< (ostream &os, const ObservedFormula& of);
|
||||
|
||||
private:
|
||||
friend std::ostream& operator<< (
|
||||
std::ostream&, const ObservedFormula&);
|
||||
|
||||
Symbol functor_;
|
||||
unsigned arity_;
|
||||
unsigned evidence_;
|
||||
ConstraintTree constr_;
|
||||
};
|
||||
|
||||
typedef vector<ObservedFormula> ObservedFormulas;
|
||||
typedef std::vector<ObservedFormula> ObservedFormulas;
|
||||
|
||||
#endif // HORUS_PROBFORMULA_H
|
||||
} // namespace Horus
|
||||
|
||||
#endif // YAP_PACKAGES_CLPBN_HORUS_PROBFORMULA_H_
|
||||
|
||||
|
@ -1,20 +1,18 @@
|
||||
#ifndef HORUS_TINYSET_H
|
||||
#define HORUS_TINYSET_H
|
||||
|
||||
#include <algorithm>
|
||||
#ifndef YAP_PACKAGES_CLPBN_HORUS_TINYSET_H_
|
||||
#define YAP_PACKAGES_CLPBN_HORUS_TINYSET_H_
|
||||
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <ostream>
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace Horus {
|
||||
|
||||
template <typename T, typename Compare = std::less<T>>
|
||||
class TinySet
|
||||
{
|
||||
class TinySet {
|
||||
public:
|
||||
|
||||
typedef typename vector<T>::iterator iterator;
|
||||
typedef typename vector<T>::const_iterator const_iterator;
|
||||
typedef typename std::vector<T>::iterator iterator;
|
||||
typedef typename std::vector<T>::const_iterator const_iterator;
|
||||
|
||||
TinySet (const TinySet& s)
|
||||
: vec_(s.vec_), cmp_(s.cmp_) { }
|
||||
@ -25,190 +23,72 @@ class TinySet
|
||||
TinySet (const T& t, const Compare& cmp = Compare())
|
||||
: vec_(1, t), cmp_(cmp) { }
|
||||
|
||||
TinySet (const vector<T>& elements, const Compare& cmp = Compare())
|
||||
: vec_(elements), cmp_(cmp)
|
||||
{
|
||||
std::sort (begin(), end(), cmp_);
|
||||
iterator it = unique_cmp (begin(), end());
|
||||
vec_.resize (it - begin());
|
||||
}
|
||||
TinySet (const std::vector<T>& elements, const Compare& cmp = Compare());
|
||||
|
||||
iterator insert (const T& t)
|
||||
{
|
||||
iterator it = std::lower_bound (begin(), end(), t, cmp_);
|
||||
if (it == end() || cmp_(t, *it)) {
|
||||
vec_.insert (it, t);
|
||||
}
|
||||
return it;
|
||||
}
|
||||
iterator insert (const T& t);
|
||||
|
||||
void insert_sorted (const T& t)
|
||||
{
|
||||
vec_.push_back (t);
|
||||
assert (consistent());
|
||||
}
|
||||
void insert_sorted (const T& t);
|
||||
|
||||
void remove (const T& t)
|
||||
{
|
||||
iterator it = std::lower_bound (begin(), end(), t, cmp_);
|
||||
if (it != end()) {
|
||||
vec_.erase (it);
|
||||
}
|
||||
}
|
||||
void remove (const T& t);
|
||||
|
||||
const_iterator find (const T& t) const
|
||||
{
|
||||
const_iterator it = std::lower_bound (begin(), end(), t, cmp_);
|
||||
return it == end() || cmp_(t, *it) ? end() : it;
|
||||
}
|
||||
|
||||
iterator find (const T& t)
|
||||
{
|
||||
iterator it = std::lower_bound (begin(), end(), t, cmp_);
|
||||
return it == end() || cmp_(t, *it) ? end() : it;
|
||||
}
|
||||
const_iterator find (const T& t) const;
|
||||
|
||||
iterator find (const T& t);
|
||||
|
||||
/* set union */
|
||||
TinySet operator| (const TinySet& s) const
|
||||
{
|
||||
TinySet res;
|
||||
std::set_union (
|
||||
vec_.begin(), vec_.end(),
|
||||
s.vec_.begin(), s.vec_.end(),
|
||||
std::back_inserter (res.vec_),
|
||||
cmp_);
|
||||
return res;
|
||||
}
|
||||
TinySet operator| (const TinySet& s) const;
|
||||
|
||||
/* set intersection */
|
||||
TinySet operator& (const TinySet& s) const
|
||||
{
|
||||
TinySet res;
|
||||
std::set_intersection (
|
||||
vec_.begin(), vec_.end(),
|
||||
s.vec_.begin(), s.vec_.end(),
|
||||
std::back_inserter (res.vec_),
|
||||
cmp_);
|
||||
return res;
|
||||
}
|
||||
TinySet operator& (const TinySet& s) const;
|
||||
|
||||
/* set difference */
|
||||
TinySet operator- (const TinySet& s) const
|
||||
{
|
||||
TinySet res;
|
||||
std::set_difference (
|
||||
vec_.begin(), vec_.end(),
|
||||
s.vec_.begin(), s.vec_.end(),
|
||||
std::back_inserter (res.vec_),
|
||||
cmp_);
|
||||
return res;
|
||||
}
|
||||
TinySet operator- (const TinySet& s) const;
|
||||
|
||||
TinySet& operator|= (const TinySet& s)
|
||||
{
|
||||
return *this = (*this | s);
|
||||
}
|
||||
TinySet& operator|= (const TinySet& s);
|
||||
|
||||
TinySet& operator&= (const TinySet& s)
|
||||
{
|
||||
return *this = (*this & s);
|
||||
}
|
||||
TinySet& operator&= (const TinySet& s);
|
||||
|
||||
TinySet& operator-= (const TinySet& s)
|
||||
{
|
||||
return *this = (*this - s);
|
||||
}
|
||||
TinySet& operator-= (const TinySet& s);
|
||||
|
||||
bool contains (const T& t) const
|
||||
{
|
||||
return std::binary_search (
|
||||
vec_.begin(), vec_.end(), t, cmp_);
|
||||
}
|
||||
bool contains (const T& t) const;
|
||||
|
||||
bool contains (const TinySet& s) const
|
||||
{
|
||||
return std::includes (
|
||||
vec_.begin(),
|
||||
vec_.end(),
|
||||
s.vec_.begin(),
|
||||
s.vec_.end(),
|
||||
cmp_);
|
||||
}
|
||||
bool contains (const TinySet& s) const;
|
||||
|
||||
bool in (const TinySet& s) const
|
||||
{
|
||||
return std::includes (
|
||||
s.vec_.begin(),
|
||||
s.vec_.end(),
|
||||
vec_.begin(),
|
||||
vec_.end(),
|
||||
cmp_);
|
||||
}
|
||||
bool in (const TinySet& s) const;
|
||||
|
||||
bool intersects (const TinySet& s) const
|
||||
{
|
||||
return (*this & s).size() > 0;
|
||||
}
|
||||
bool intersects (const TinySet& s) const;
|
||||
|
||||
const T& operator[] (typename vector<T>::size_type i) const
|
||||
{
|
||||
return vec_[i];
|
||||
}
|
||||
const T& operator[] (typename std::vector<T>::size_type i) const;
|
||||
|
||||
T& operator[] (typename vector<T>::size_type i)
|
||||
{
|
||||
return vec_[i];
|
||||
}
|
||||
T& operator[] (typename std::vector<T>::size_type i);
|
||||
|
||||
T front (void) const
|
||||
{
|
||||
return vec_.front();
|
||||
}
|
||||
T front() const;
|
||||
|
||||
T& front (void)
|
||||
{
|
||||
return vec_.front();
|
||||
}
|
||||
T& front();
|
||||
|
||||
T back (void) const
|
||||
{
|
||||
return vec_.back();
|
||||
}
|
||||
T back() const;
|
||||
|
||||
T& back (void)
|
||||
{
|
||||
return vec_.back();
|
||||
}
|
||||
T& back();
|
||||
|
||||
const vector<T>& elements (void) const
|
||||
{
|
||||
return vec_;
|
||||
}
|
||||
const std::vector<T>& elements() const;
|
||||
|
||||
bool empty (void) const
|
||||
{
|
||||
return vec_.empty();
|
||||
}
|
||||
bool empty() const;
|
||||
|
||||
typename vector<T>::size_type size (void) const
|
||||
{
|
||||
return vec_.size();
|
||||
}
|
||||
typename std::vector<T>::size_type size() const;
|
||||
|
||||
void clear (void)
|
||||
{
|
||||
vec_.clear();
|
||||
}
|
||||
void clear();
|
||||
|
||||
void reserve (typename vector<T>::size_type size)
|
||||
{
|
||||
vec_.reserve (size);
|
||||
}
|
||||
void reserve (typename std::vector<T>::size_type size);
|
||||
|
||||
iterator begin (void) { return vec_.begin(); }
|
||||
iterator end (void) { return vec_.end(); }
|
||||
const_iterator begin (void) const { return vec_.begin(); }
|
||||
const_iterator end (void) const { return vec_.end(); }
|
||||
iterator begin() { return vec_.begin(); }
|
||||
iterator end () { return vec_.end(); }
|
||||
const_iterator begin() const { return vec_.begin(); }
|
||||
const_iterator end () const { return vec_.end(); }
|
||||
|
||||
private:
|
||||
iterator unique_cmp (iterator first, iterator last);
|
||||
|
||||
bool consistent() const;
|
||||
|
||||
friend bool operator== (const TinySet& s1, const TinySet& s2)
|
||||
{
|
||||
@ -223,7 +103,7 @@ class TinySet
|
||||
friend std::ostream& operator<< (std::ostream& out, const TinySet& s)
|
||||
{
|
||||
out << "{" ;
|
||||
typename vector<T>::size_type i;
|
||||
typename std::vector<T>::size_type i;
|
||||
for (i = 0; i < s.size(); i++) {
|
||||
out << ((i != 0) ? "," : "") << s.vec_[i];
|
||||
}
|
||||
@ -231,35 +111,299 @@ class TinySet
|
||||
return out;
|
||||
}
|
||||
|
||||
private:
|
||||
iterator unique_cmp (iterator first, iterator last)
|
||||
{
|
||||
if (first == last) {
|
||||
return last;
|
||||
}
|
||||
iterator result = first;
|
||||
while (++first != last) {
|
||||
if (cmp_(*result, *first)) {
|
||||
*(++result) = *first;
|
||||
}
|
||||
}
|
||||
return ++result;
|
||||
}
|
||||
|
||||
bool consistent (void) const
|
||||
{
|
||||
typename vector<T>::size_type i;
|
||||
for (i = 0; i < vec_.size() - 1; i++) {
|
||||
if ( ! cmp_(vec_[i], vec_[i + 1])) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
vector<T> vec_;
|
||||
Compare cmp_;
|
||||
std::vector<T> vec_;
|
||||
Compare cmp_;
|
||||
};
|
||||
|
||||
#endif // HORUS_TINYSET_H
|
||||
|
||||
|
||||
template <typename T, typename C> inline
|
||||
TinySet<T,C>::TinySet (const std::vector<T>& elements, const C& cmp)
|
||||
: vec_(elements), cmp_(cmp)
|
||||
{
|
||||
std::sort (begin(), end(), cmp_);
|
||||
iterator it = unique_cmp (begin(), end());
|
||||
vec_.resize (it - begin());
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T, typename C> inline typename TinySet<T,C>::iterator
|
||||
TinySet<T,C>::insert (const T& t)
|
||||
{
|
||||
iterator it = std::lower_bound (begin(), end(), t, cmp_);
|
||||
if (it == end() || cmp_(t, *it)) {
|
||||
vec_.insert (it, t);
|
||||
}
|
||||
return it;
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T, typename C> inline void
|
||||
TinySet<T,C>::insert_sorted (const T& t)
|
||||
{
|
||||
vec_.push_back (t);
|
||||
assert (consistent());
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T, typename C> inline void
|
||||
TinySet<T,C>::remove (const T& t)
|
||||
{
|
||||
iterator it = std::lower_bound (begin(), end(), t, cmp_);
|
||||
if (it != end()) {
|
||||
vec_.erase (it);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T, typename C> inline typename TinySet<T,C>::const_iterator
|
||||
TinySet<T,C>::find (const T& t) const
|
||||
{
|
||||
const_iterator it = std::lower_bound (begin(), end(), t, cmp_);
|
||||
return it == end() || cmp_(t, *it) ? end() : it;
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T, typename C> inline typename TinySet<T,C>::iterator
|
||||
TinySet<T,C>::find (const T& t)
|
||||
{
|
||||
iterator it = std::lower_bound (begin(), end(), t, cmp_);
|
||||
return it == end() || cmp_(t, *it) ? end() : it;
|
||||
}
|
||||
|
||||
|
||||
|
||||
/* set union */
|
||||
template <typename T, typename C> inline TinySet<T,C>
|
||||
TinySet<T,C>::operator| (const TinySet& s) const
|
||||
{
|
||||
TinySet res;
|
||||
std::set_union (
|
||||
vec_.begin(), vec_.end(),
|
||||
s.vec_.begin(), s.vec_.end(),
|
||||
std::back_inserter (res.vec_),
|
||||
cmp_);
|
||||
return res;
|
||||
}
|
||||
|
||||
|
||||
|
||||
/* set intersection */
|
||||
template <typename T, typename C> inline TinySet<T,C>
|
||||
TinySet<T,C>::operator& (const TinySet& s) const
|
||||
{
|
||||
TinySet res;
|
||||
std::set_intersection (
|
||||
vec_.begin(), vec_.end(),
|
||||
s.vec_.begin(), s.vec_.end(),
|
||||
std::back_inserter (res.vec_),
|
||||
cmp_);
|
||||
return res;
|
||||
}
|
||||
|
||||
|
||||
|
||||
/* set difference */
|
||||
template <typename T, typename C> inline TinySet<T,C>
|
||||
TinySet<T,C>::operator- (const TinySet& s) const
|
||||
{
|
||||
TinySet res;
|
||||
std::set_difference (
|
||||
vec_.begin(), vec_.end(),
|
||||
s.vec_.begin(), s.vec_.end(),
|
||||
std::back_inserter (res.vec_),
|
||||
cmp_);
|
||||
return res;
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T, typename C> inline TinySet<T,C>&
|
||||
TinySet<T,C>::operator|= (const TinySet& s)
|
||||
{
|
||||
return *this = (*this | s);
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T, typename C> inline TinySet<T,C>&
|
||||
TinySet<T,C>::operator&= (const TinySet& s)
|
||||
{
|
||||
return *this = (*this & s);
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T, typename C> inline TinySet<T,C>&
|
||||
TinySet<T,C>::operator-= (const TinySet& s)
|
||||
{
|
||||
return *this = (*this - s);
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T, typename C> inline bool
|
||||
TinySet<T,C>::contains (const T& t) const
|
||||
{
|
||||
return std::binary_search (
|
||||
vec_.begin(), vec_.end(), t, cmp_);
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T, typename C> inline bool
|
||||
TinySet<T,C>::contains (const TinySet& s) const
|
||||
{
|
||||
return std::includes (
|
||||
vec_.begin(), vec_.end(),
|
||||
s.vec_.begin(), s.vec_.end(),
|
||||
cmp_);
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T, typename C> inline bool
|
||||
TinySet<T,C>::in (const TinySet& s) const
|
||||
{
|
||||
return std::includes (
|
||||
s.vec_.begin(), s.vec_.end(),
|
||||
vec_.begin(), vec_.end(),
|
||||
cmp_);
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T, typename C> inline bool
|
||||
TinySet<T,C>::intersects (const TinySet& s) const
|
||||
{
|
||||
return (*this & s).size() > 0;
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T, typename C> inline const T&
|
||||
TinySet<T,C>::operator[] (typename std::vector<T>::size_type i) const
|
||||
{
|
||||
return vec_[i];
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T, typename C> inline T&
|
||||
TinySet<T,C>::operator[] (typename std::vector<T>::size_type i)
|
||||
{
|
||||
return vec_[i];
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T, typename C> inline T
|
||||
TinySet<T,C>::front() const
|
||||
{
|
||||
return vec_.front();
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T, typename C> inline T&
|
||||
TinySet<T,C>::front()
|
||||
{
|
||||
return vec_.front();
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T, typename C> inline T
|
||||
TinySet<T,C>::back() const
|
||||
{
|
||||
return vec_.back();
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T, typename C> inline T&
|
||||
TinySet<T,C>::back()
|
||||
{
|
||||
return vec_.back();
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T, typename C> inline const std::vector<T>&
|
||||
TinySet<T,C>::elements() const
|
||||
{
|
||||
return vec_;
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T, typename C> inline bool
|
||||
TinySet<T,C>::empty() const
|
||||
{
|
||||
return vec_.empty();
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T, typename C> inline typename std::vector<T>::size_type
|
||||
TinySet<T,C>::size() const
|
||||
{
|
||||
return vec_.size();
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T, typename C> inline void
|
||||
TinySet<T,C>::clear()
|
||||
{
|
||||
vec_.clear();
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T, typename C> inline void
|
||||
TinySet<T,C>::reserve (typename std::vector<T>::size_type size)
|
||||
{
|
||||
vec_.reserve (size);
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T, typename C> typename TinySet<T,C>::iterator
|
||||
TinySet<T,C>::unique_cmp (iterator first, iterator last)
|
||||
{
|
||||
if (first == last) {
|
||||
return last;
|
||||
}
|
||||
iterator result = first;
|
||||
while (++first != last) {
|
||||
if (cmp_(*result, *first)) {
|
||||
*(++result) = *first;
|
||||
}
|
||||
}
|
||||
return ++result;
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T, typename C> inline bool
|
||||
TinySet<T,C>::consistent() const
|
||||
{
|
||||
typename std::vector<T>::size_type i;
|
||||
for (i = 0; i < vec_.size() - 1; i++) {
|
||||
if ( ! cmp_(vec_[i], vec_[i + 1])) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace Horus
|
||||
|
||||
#endif // YAP_PACKAGES_CLPBN_HORUS_TINYSET_H_
|
||||
|
||||
|
@ -1,27 +1,27 @@
|
||||
#include <fstream>
|
||||
|
||||
#include "Util.h"
|
||||
#include "Indexer.h"
|
||||
#include "ElimGraph.h"
|
||||
#include "BeliefProp.h"
|
||||
|
||||
|
||||
namespace Horus {
|
||||
|
||||
namespace Globals {
|
||||
|
||||
bool logDomain = false;
|
||||
|
||||
unsigned verbosity = 0;
|
||||
|
||||
LiftedSolverType liftedSolver = LiftedSolverType::LVE;
|
||||
LiftedSolverType liftedSolver = LiftedSolverType::lveSolver;
|
||||
|
||||
GroundSolverType groundSolver = GroundSolverType::VE;
|
||||
GroundSolverType groundSolver = GroundSolverType::veSolver;
|
||||
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
|
||||
namespace Util {
|
||||
|
||||
|
||||
template <> std::string
|
||||
toString (const bool& b)
|
||||
{
|
||||
@ -33,14 +33,14 @@ toString (const bool& b)
|
||||
|
||||
|
||||
unsigned
|
||||
stringToUnsigned (string str)
|
||||
stringToUnsigned (std::string str)
|
||||
{
|
||||
int val;
|
||||
stringstream ss;
|
||||
std::stringstream ss;
|
||||
ss << str;
|
||||
ss >> val;
|
||||
if (val < 0) {
|
||||
cerr << "Error: the number readed is negative." << endl;
|
||||
std::cerr << "Error: the number readed is negative." << std::endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
return static_cast<unsigned> (val);
|
||||
@ -49,10 +49,10 @@ stringToUnsigned (string str)
|
||||
|
||||
|
||||
double
|
||||
stringToDouble (string str)
|
||||
stringToDouble (std::string str)
|
||||
{
|
||||
double val;
|
||||
stringstream ss;
|
||||
std::stringstream ss;
|
||||
ss << str;
|
||||
ss >> val;
|
||||
return val;
|
||||
@ -117,7 +117,7 @@ size_t
|
||||
sizeExpected (const Ranges& ranges)
|
||||
{
|
||||
return std::accumulate (ranges.begin(),
|
||||
ranges.end(), 1, multiplies<unsigned>());
|
||||
ranges.end(), 1, std::multiplies<unsigned>());
|
||||
}
|
||||
|
||||
|
||||
@ -136,10 +136,10 @@ nrDigits (int num)
|
||||
|
||||
|
||||
bool
|
||||
isInteger (const string& s)
|
||||
isInteger (const std::string& s)
|
||||
{
|
||||
stringstream ss1 (s);
|
||||
stringstream ss2;
|
||||
std::stringstream ss1 (s);
|
||||
std::stringstream ss2;
|
||||
int integer;
|
||||
ss1 >> integer;
|
||||
ss2 << integer;
|
||||
@ -148,10 +148,10 @@ isInteger (const string& s)
|
||||
|
||||
|
||||
|
||||
string
|
||||
std::string
|
||||
parametersToString (const Params& v, unsigned precision)
|
||||
{
|
||||
stringstream ss;
|
||||
std::stringstream ss;
|
||||
ss.precision (precision);
|
||||
ss << "[" ;
|
||||
for (size_t i = 0; i < v.size(); i++) {
|
||||
@ -164,7 +164,7 @@ parametersToString (const Params& v, unsigned precision)
|
||||
|
||||
|
||||
|
||||
vector<string>
|
||||
std::vector<std::string>
|
||||
getStateLines (const Vars& vars)
|
||||
{
|
||||
Ranges ranges;
|
||||
@ -172,9 +172,9 @@ getStateLines (const Vars& vars)
|
||||
ranges.push_back (vars[i]->range());
|
||||
}
|
||||
Indexer indexer (ranges);
|
||||
vector<string> jointStrings;
|
||||
std::vector<std::string> jointStrings;
|
||||
while (indexer.valid()) {
|
||||
stringstream ss;
|
||||
std::stringstream ss;
|
||||
for (size_t i = 0; i < vars.size(); i++) {
|
||||
if (i != 0) ss << ", " ;
|
||||
ss << vars[i]->label() << "=" ;
|
||||
@ -188,34 +188,42 @@ getStateLines (const Vars& vars)
|
||||
|
||||
|
||||
|
||||
bool invalidValue (string option, string value)
|
||||
bool invalidValue (std::string option, std::string value)
|
||||
{
|
||||
cerr << "Warning: invalid value `" << value << "' " ;
|
||||
cerr << "for `" << option << "'." ;
|
||||
cerr << endl;
|
||||
std::cerr << "Warning: invalid value `" << value << "' " ;
|
||||
std::cerr << "for `" << option << "'." ;
|
||||
std::cerr << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
setHorusFlag (string option, string value)
|
||||
setHorusFlag (std::string option, std::string value)
|
||||
{
|
||||
bool returnVal = true;
|
||||
if (option == "lifted_solver") {
|
||||
if (value == "lve") Globals::liftedSolver = LiftedSolverType::LVE;
|
||||
else if (value == "lbp") Globals::liftedSolver = LiftedSolverType::LBP;
|
||||
else if (value == "lkc") Globals::liftedSolver = LiftedSolverType::LKC;
|
||||
else returnVal = invalidValue (option, value);
|
||||
if (value == "lve")
|
||||
Globals::liftedSolver = LiftedSolverType::lveSolver;
|
||||
else if (value == "lbp")
|
||||
Globals::liftedSolver = LiftedSolverType::lbpSolver;
|
||||
else if (value == "lkc")
|
||||
Globals::liftedSolver = LiftedSolverType::lkcSolver;
|
||||
else
|
||||
returnVal = invalidValue (option, value);
|
||||
|
||||
} else if (option == "ground_solver" || option == "solver") {
|
||||
if (value == "hve") Globals::groundSolver = GroundSolverType::VE;
|
||||
else if (value == "bp") Globals::groundSolver = GroundSolverType::BP;
|
||||
else if (value == "cbp") Globals::groundSolver = GroundSolverType::CBP;
|
||||
else returnVal = invalidValue (option, value);
|
||||
if (value == "hve")
|
||||
Globals::groundSolver = GroundSolverType::veSolver;
|
||||
else if (value == "bp")
|
||||
Globals::groundSolver = GroundSolverType::bpSolver;
|
||||
else if (value == "cbp")
|
||||
Globals::groundSolver = GroundSolverType::CbpSolver;
|
||||
else
|
||||
returnVal = invalidValue (option, value);
|
||||
|
||||
} else if (option == "verbosity") {
|
||||
stringstream ss;
|
||||
std::stringstream ss;
|
||||
ss << value;
|
||||
ss >> Globals::verbosity;
|
||||
|
||||
@ -225,40 +233,42 @@ setHorusFlag (string option, string value)
|
||||
else returnVal = invalidValue (option, value);
|
||||
|
||||
} else if (option == "hve_elim_heuristic") {
|
||||
typedef ElimGraph::ElimHeuristic ElimHeuristic;
|
||||
if (value == "sequential")
|
||||
ElimGraph::setElimHeuristic (ElimHeuristic::SEQUENTIAL);
|
||||
ElimGraph::setElimHeuristic (ElimHeuristic::sequentialEh);
|
||||
else if (value == "min_neighbors")
|
||||
ElimGraph::setElimHeuristic (ElimHeuristic::MIN_NEIGHBORS);
|
||||
ElimGraph::setElimHeuristic (ElimHeuristic::minNeighborsEh);
|
||||
else if (value == "min_weight")
|
||||
ElimGraph::setElimHeuristic (ElimHeuristic::MIN_WEIGHT);
|
||||
ElimGraph::setElimHeuristic (ElimHeuristic::minWeightEh);
|
||||
else if (value == "min_fill")
|
||||
ElimGraph::setElimHeuristic (ElimHeuristic::MIN_FILL);
|
||||
ElimGraph::setElimHeuristic (ElimHeuristic::minFillEh);
|
||||
else if (value == "weighted_min_fill")
|
||||
ElimGraph::setElimHeuristic (ElimHeuristic::WEIGHTED_MIN_FILL);
|
||||
ElimGraph::setElimHeuristic (ElimHeuristic::weightedMinFillEh);
|
||||
else
|
||||
returnVal = invalidValue (option, value);
|
||||
|
||||
} else if (option == "bp_msg_schedule") {
|
||||
typedef BeliefProp::MsgSchedule MsgSchedule;
|
||||
if (value == "seq_fixed")
|
||||
BeliefProp::setMsgSchedule (MsgSchedule::SEQ_FIXED);
|
||||
BeliefProp::setMsgSchedule (MsgSchedule::seqFixedSch);
|
||||
else if (value == "seq_random")
|
||||
BeliefProp::setMsgSchedule (MsgSchedule::SEQ_RANDOM);
|
||||
BeliefProp::setMsgSchedule (MsgSchedule::seqRandomSch);
|
||||
else if (value == "parallel")
|
||||
BeliefProp::setMsgSchedule (MsgSchedule::PARALLEL);
|
||||
BeliefProp::setMsgSchedule (MsgSchedule::parallelSch);
|
||||
else if (value == "max_residual")
|
||||
BeliefProp::setMsgSchedule (MsgSchedule::MAX_RESIDUAL);
|
||||
BeliefProp::setMsgSchedule (MsgSchedule::maxResidualSch);
|
||||
else
|
||||
returnVal = invalidValue (option, value);
|
||||
|
||||
} else if (option == "bp_accuracy") {
|
||||
stringstream ss;
|
||||
std::stringstream ss;
|
||||
double acc;
|
||||
ss << value;
|
||||
ss >> acc;
|
||||
BeliefProp::setAccuracy (acc);
|
||||
|
||||
} else if (option == "bp_max_iter") {
|
||||
stringstream ss;
|
||||
std::stringstream ss;
|
||||
unsigned mi;
|
||||
ss << value;
|
||||
ss >> mi;
|
||||
@ -285,7 +295,7 @@ setHorusFlag (string option, string value)
|
||||
else returnVal = invalidValue (option, value);
|
||||
|
||||
} else {
|
||||
cerr << "Warning: invalid option `" << option << "'" << endl;
|
||||
std::cerr << "Warning: invalid option `" << option << "'" << std::endl;
|
||||
returnVal = false;
|
||||
}
|
||||
return returnVal;
|
||||
@ -294,20 +304,20 @@ setHorusFlag (string option, string value)
|
||||
|
||||
|
||||
void
|
||||
printHeader (string header, std::ostream& os)
|
||||
printHeader (std::string header, std::ostream& os)
|
||||
{
|
||||
printAsteriskLine (os);
|
||||
os << header << endl;
|
||||
os << header << std::endl;
|
||||
printAsteriskLine (os);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
printSubHeader (string header, std::ostream& os)
|
||||
printSubHeader (std::string header, std::ostream& os)
|
||||
{
|
||||
printDashedLine (os);
|
||||
os << header << endl;
|
||||
os << header << std::endl;
|
||||
printDashedLine (os);
|
||||
}
|
||||
|
||||
@ -318,7 +328,7 @@ printAsteriskLine (std::ostream& os)
|
||||
{
|
||||
os << "********************************" ;
|
||||
os << "********************************" ;
|
||||
os << endl;
|
||||
os << std::endl;
|
||||
}
|
||||
|
||||
|
||||
@ -328,11 +338,10 @@ printDashedLine (std::ostream& os)
|
||||
{
|
||||
os << "--------------------------------" ;
|
||||
os << "--------------------------------" ;
|
||||
os << endl;
|
||||
os << std::endl;
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
} // namespace Util
|
||||
|
||||
|
||||
|
||||
@ -362,10 +371,10 @@ getL1Distance (const Params& v1, const Params& v2)
|
||||
double dist = 0.0;
|
||||
if (Globals::logDomain) {
|
||||
dist = std::inner_product (v1.begin(), v1.end(), v2.begin(), 0.0,
|
||||
std::plus<double>(), FuncObject::abs_diff_exp<double>());
|
||||
std::plus<double>(), FuncObj::abs_diff_exp<double>());
|
||||
} else {
|
||||
dist = std::inner_product (v1.begin(), v1.end(), v2.begin(), 0.0,
|
||||
std::plus<double>(), FuncObject::abs_diff<double>());
|
||||
std::plus<double>(), FuncObj::abs_diff<double>());
|
||||
}
|
||||
return dist;
|
||||
}
|
||||
@ -379,10 +388,10 @@ getMaxNorm (const Params& v1, const Params& v2)
|
||||
double max = 0.0;
|
||||
if (Globals::logDomain) {
|
||||
max = std::inner_product (v1.begin(), v1.end(), v2.begin(), 0.0,
|
||||
FuncObject::max<double>(), FuncObject::abs_diff_exp<double>());
|
||||
FuncObj::max<double>(), FuncObj::abs_diff_exp<double>());
|
||||
} else {
|
||||
max = std::inner_product (v1.begin(), v1.end(), v2.begin(), 0.0,
|
||||
FuncObject::max<double>(), FuncObject::abs_diff<double>());
|
||||
FuncObj::max<double>(), FuncObj::abs_diff<double>());
|
||||
}
|
||||
return max;
|
||||
}
|
||||
@ -428,5 +437,8 @@ pow (Params& v, double exp)
|
||||
Globals::logDomain ? v *= exp : v ^= exp;
|
||||
}
|
||||
|
||||
}
|
||||
} // namespace LogAware
|
||||
|
||||
} // namespace Horus
|
||||
|
||||
|
||||
|
@ -1,69 +1,78 @@
|
||||
#ifndef HORUS_UTIL_H
|
||||
#define HORUS_UTIL_H
|
||||
#ifndef YAP_PACKAGES_CLPBN_HORUS_UTIL_H_
|
||||
#define YAP_PACKAGES_CLPBN_HORUS_UTIL_H_
|
||||
|
||||
#include <cmath>
|
||||
#include <cassert>
|
||||
|
||||
#include <algorithm>
|
||||
#include <limits>
|
||||
|
||||
#include <vector>
|
||||
#include <queue>
|
||||
#include <set>
|
||||
#include <unordered_map>
|
||||
|
||||
#include <algorithm>
|
||||
#include <limits>
|
||||
#include <string>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "Horus.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace Horus {
|
||||
|
||||
namespace {
|
||||
|
||||
const double NEG_INF = -std::numeric_limits<double>::infinity();
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
|
||||
namespace Util {
|
||||
|
||||
template <typename T> void addToVector (vector<T>&, const vector<T>&);
|
||||
template <typename T> void
|
||||
addToVector (std::vector<T>&, const std::vector<T>&);
|
||||
|
||||
template <typename T> void addToSet (set<T>&, const vector<T>&);
|
||||
template <typename T> void
|
||||
addToSet (std::set<T>&, const std::vector<T>&);
|
||||
|
||||
template <typename T> void addToQueue (queue<T>&, const vector<T>&);
|
||||
template <typename T> void
|
||||
addToQueue (std::queue<T>&, const std::vector<T>&);
|
||||
|
||||
template <typename T> bool contains (const vector<T>&, const T&);
|
||||
template <typename T> bool
|
||||
contains (const std::vector<T>&, const T&);
|
||||
|
||||
template <typename T> bool contains (const set<T>&, const T&);
|
||||
template <typename T> bool contains
|
||||
(const std::set<T>&, const T&);
|
||||
|
||||
template <typename K, typename V> bool contains (
|
||||
const unordered_map<K, V>&, const K&);
|
||||
template <typename K, typename V> bool
|
||||
contains (const std::unordered_map<K, V>&, const K&);
|
||||
|
||||
template <typename T> size_t indexOf (const vector<T>&, const T&);
|
||||
template <typename T> size_t
|
||||
indexOf (const std::vector<T>&, const T&);
|
||||
|
||||
template <class Operation>
|
||||
void apply_n_times (Params& v1, const Params& v2,
|
||||
unsigned repetitions, Operation);
|
||||
template <class Operation> void
|
||||
apply_n_times (Params& v1, const Params& v2, unsigned reps, Operation);
|
||||
|
||||
template <typename T> void log (vector<T>&);
|
||||
template <typename T> void
|
||||
log (std::vector<T>&);
|
||||
|
||||
template <typename T> void exp (vector<T>&);
|
||||
template <typename T> void
|
||||
exp (std::vector<T>&);
|
||||
|
||||
template <typename T> string elementsToString (
|
||||
const vector<T>& v, string sep = " ");
|
||||
template <typename T> std::string
|
||||
elementsToString (const std::vector<T>& v, std::string sep = " ");
|
||||
|
||||
template <typename T> std::string toString (const T&);
|
||||
template <typename T> std::string
|
||||
toString (const T&);
|
||||
|
||||
template <> std::string toString (const bool&);
|
||||
template <> std::string
|
||||
toString (const bool&);
|
||||
|
||||
double logSum (double, double);
|
||||
|
||||
unsigned maxUnsigned (void);
|
||||
unsigned maxUnsigned();
|
||||
|
||||
unsigned stringToUnsigned (string);
|
||||
unsigned stringToUnsigned (std::string);
|
||||
|
||||
double stringToDouble (string);
|
||||
double stringToDouble (std::string);
|
||||
|
||||
double factorial (unsigned);
|
||||
|
||||
@ -75,28 +84,29 @@ size_t sizeExpected (const Ranges&);
|
||||
|
||||
unsigned nrDigits (int);
|
||||
|
||||
bool isInteger (const string&);
|
||||
bool isInteger (const std::string&);
|
||||
|
||||
string parametersToString (const Params&, unsigned = Constants::PRECISION);
|
||||
std::string parametersToString (
|
||||
const Params&, unsigned = Constants::precision);
|
||||
|
||||
vector<string> getStateLines (const Vars&);
|
||||
std::vector<std::string> getStateLines (const Vars&);
|
||||
|
||||
bool setHorusFlag (string option, string value);
|
||||
bool setHorusFlag (std::string option, std::string value);
|
||||
|
||||
void printHeader (string, std::ostream& os = std::cout);
|
||||
void printHeader (std::string, std::ostream& os = std::cout);
|
||||
|
||||
void printSubHeader (string, std::ostream& os = std::cout);
|
||||
void printSubHeader (std::string, std::ostream& os = std::cout);
|
||||
|
||||
void printAsteriskLine (std::ostream& os = std::cout);
|
||||
|
||||
void printDashedLine (std::ostream& os = std::cout);
|
||||
|
||||
};
|
||||
} // namespace Util
|
||||
|
||||
|
||||
|
||||
template <typename T> void
|
||||
Util::addToVector (vector<T>& v, const vector<T>& elements)
|
||||
Util::addToVector (std::vector<T>& v, const std::vector<T>& elements)
|
||||
{
|
||||
v.insert (v.end(), elements.begin(), elements.end());
|
||||
}
|
||||
@ -104,7 +114,7 @@ Util::addToVector (vector<T>& v, const vector<T>& elements)
|
||||
|
||||
|
||||
template <typename T> void
|
||||
Util::addToSet (set<T>& s, const vector<T>& elements)
|
||||
Util::addToSet (std::set<T>& s, const std::vector<T>& elements)
|
||||
{
|
||||
s.insert (elements.begin(), elements.end());
|
||||
}
|
||||
@ -112,7 +122,7 @@ Util::addToSet (set<T>& s, const vector<T>& elements)
|
||||
|
||||
|
||||
template <typename T> void
|
||||
Util::addToQueue (queue<T>& q, const vector<T>& elements)
|
||||
Util::addToQueue (std::queue<T>& q, const std::vector<T>& elements)
|
||||
{
|
||||
for (size_t i = 0; i < elements.size(); i++) {
|
||||
q.push (elements[i]);
|
||||
@ -122,7 +132,7 @@ Util::addToQueue (queue<T>& q, const vector<T>& elements)
|
||||
|
||||
|
||||
template <typename T> bool
|
||||
Util::contains (const vector<T>& v, const T& e)
|
||||
Util::contains (const std::vector<T>& v, const T& e)
|
||||
{
|
||||
return std::find (v.begin(), v.end(), e) != v.end();
|
||||
}
|
||||
@ -130,7 +140,7 @@ Util::contains (const vector<T>& v, const T& e)
|
||||
|
||||
|
||||
template <typename T> bool
|
||||
Util::contains (const set<T>& s, const T& e)
|
||||
Util::contains (const std::set<T>& s, const T& e)
|
||||
{
|
||||
return s.find (e) != s.end();
|
||||
}
|
||||
@ -138,7 +148,7 @@ Util::contains (const set<T>& s, const T& e)
|
||||
|
||||
|
||||
template <typename K, typename V> bool
|
||||
Util::contains (const unordered_map<K, V>& m, const K& k)
|
||||
Util::contains (const std::unordered_map<K, V>& m, const K& k)
|
||||
{
|
||||
return m.find (k) != m.end();
|
||||
}
|
||||
@ -146,7 +156,7 @@ Util::contains (const unordered_map<K, V>& m, const K& k)
|
||||
|
||||
|
||||
template <typename T> size_t
|
||||
Util::indexOf (const vector<T>& v, const T& e)
|
||||
Util::indexOf (const std::vector<T>& v, const T& e)
|
||||
{
|
||||
return std::distance (v.begin(),
|
||||
std::find (v.begin(), v.end(), e));
|
||||
@ -155,7 +165,10 @@ Util::indexOf (const vector<T>& v, const T& e)
|
||||
|
||||
|
||||
template <class Operation> void
|
||||
Util::apply_n_times (Params& v1, const Params& v2, unsigned repetitions,
|
||||
Util::apply_n_times (
|
||||
Params& v1,
|
||||
const Params& v2,
|
||||
unsigned repetitions,
|
||||
Operation unary_op)
|
||||
{
|
||||
Params::iterator first = v1.begin();
|
||||
@ -174,7 +187,7 @@ Util::apply_n_times (Params& v1, const Params& v2, unsigned repetitions,
|
||||
|
||||
|
||||
template <typename T> void
|
||||
Util::log (vector<T>& v)
|
||||
Util::log (std::vector<T>& v)
|
||||
{
|
||||
std::transform (v.begin(), v.end(), v.begin(), ::log);
|
||||
}
|
||||
@ -182,17 +195,17 @@ Util::log (vector<T>& v)
|
||||
|
||||
|
||||
template <typename T> void
|
||||
Util::exp (vector<T>& v)
|
||||
Util::exp (std::vector<T>& v)
|
||||
{
|
||||
std::transform (v.begin(), v.end(), v.begin(), ::exp);
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T> string
|
||||
Util::elementsToString (const vector<T>& v, string sep)
|
||||
template <typename T> std::string
|
||||
Util::elementsToString (const std::vector<T>& v, std::string sep)
|
||||
{
|
||||
stringstream ss;
|
||||
std::stringstream ss;
|
||||
for (size_t i = 0; i < v.size(); i++) {
|
||||
ss << ((i != 0) ? sep : "") << v[i];
|
||||
}
|
||||
@ -245,7 +258,7 @@ Util::logSum (double x, double y)
|
||||
|
||||
|
||||
inline unsigned
|
||||
Util::maxUnsigned (void)
|
||||
Util::maxUnsigned()
|
||||
{
|
||||
return std::numeric_limits<unsigned>::max();
|
||||
}
|
||||
@ -277,106 +290,106 @@ void pow (Params&, unsigned);
|
||||
|
||||
void pow (Params&, double);
|
||||
|
||||
};
|
||||
} // namespace LogAware
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
void operator+=(std::vector<T>& v, double val)
|
||||
template <typename T> void
|
||||
operator+=(std::vector<T>& v, double val)
|
||||
{
|
||||
std::transform (v.begin(), v.end(), v.begin(),
|
||||
std::bind2nd (plus<double>(), val));
|
||||
std::bind2nd (std::plus<double>(), val));
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
void operator-=(std::vector<T>& v, double val)
|
||||
template <typename T> void
|
||||
operator-=(std::vector<T>& v, double val)
|
||||
{
|
||||
std::transform (v.begin(), v.end(), v.begin(),
|
||||
std::bind2nd (minus<double>(), val));
|
||||
std::bind2nd (std::minus<double>(), val));
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
void operator*=(std::vector<T>& v, double val)
|
||||
template <typename T> void
|
||||
operator*=(std::vector<T>& v, double val)
|
||||
{
|
||||
std::transform (v.begin(), v.end(), v.begin(),
|
||||
std::bind2nd (multiplies<double>(), val));
|
||||
std::bind2nd (std::multiplies<double>(), val));
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
void operator/=(std::vector<T>& v, double val)
|
||||
template <typename T> void
|
||||
operator/=(std::vector<T>& v, double val)
|
||||
{
|
||||
std::transform (v.begin(), v.end(), v.begin(),
|
||||
std::bind2nd (divides<double>(), val));
|
||||
std::bind2nd (std::divides<double>(), val));
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
void operator+=(std::vector<T>& a, const std::vector<T>& b)
|
||||
template <typename T> void
|
||||
operator+=(std::vector<T>& a, const std::vector<T>& b)
|
||||
{
|
||||
assert (a.size() == b.size());
|
||||
std::transform (a.begin(), a.end(), b.begin(), a.begin(),
|
||||
plus<double>());
|
||||
std::plus<double>());
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
void operator-=(std::vector<T>& a, const std::vector<T>& b)
|
||||
template <typename T> void
|
||||
operator-=(std::vector<T>& a, const std::vector<T>& b)
|
||||
{
|
||||
assert (a.size() == b.size());
|
||||
std::transform (a.begin(), a.end(), b.begin(), a.begin(),
|
||||
minus<double>());
|
||||
std::minus<double>());
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
void operator*=(std::vector<T>& a, const std::vector<T>& b)
|
||||
template <typename T> void
|
||||
operator*=(std::vector<T>& a, const std::vector<T>& b)
|
||||
{
|
||||
assert (a.size() == b.size());
|
||||
std::transform (a.begin(), a.end(), b.begin(), a.begin(),
|
||||
multiplies<double>());
|
||||
std::multiplies<double>());
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
void operator/=(std::vector<T>& a, const std::vector<T>& b)
|
||||
template <typename T> void
|
||||
operator/=(std::vector<T>& a, const std::vector<T>& b)
|
||||
{
|
||||
assert (a.size() == b.size());
|
||||
std::transform (a.begin(), a.end(), b.begin(), a.begin(),
|
||||
divides<double>());
|
||||
std::divides<double>());
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
void operator^=(std::vector<T>& v, double exp)
|
||||
template <typename T> void
|
||||
operator^=(std::vector<T>& v, double exp)
|
||||
{
|
||||
std::transform (v.begin(), v.end(), v.begin(),
|
||||
std::bind2nd (ptr_fun<double, double, double> (std::pow), exp));
|
||||
std::bind2nd (std::ptr_fun<double, double, double> (std::pow), exp));
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
void operator^=(std::vector<T>& v, int iexp)
|
||||
template <typename T> void
|
||||
operator^=(std::vector<T>& v, int iexp)
|
||||
{
|
||||
std::transform (v.begin(), v.end(), v.begin(),
|
||||
std::bind2nd (ptr_fun<double, int, double> (std::pow), iexp));
|
||||
std::bind2nd (std::ptr_fun<double, int, double> (std::pow), iexp));
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
std::ostream& operator<< (std::ostream& os, const vector<T>& v)
|
||||
template <typename T> std::ostream&
|
||||
operator<< (std::ostream& os, const std::vector<T>& v)
|
||||
{
|
||||
os << "[" ;
|
||||
os << Util::elementsToString (v, ", ");
|
||||
@ -385,40 +398,33 @@ std::ostream& operator<< (std::ostream& os, const vector<T>& v)
|
||||
}
|
||||
|
||||
|
||||
namespace FuncObject {
|
||||
namespace FuncObj {
|
||||
|
||||
template<typename T>
|
||||
struct max : public std::binary_function<T, T, T>
|
||||
{
|
||||
T operator() (const T& x, const T& y) const
|
||||
{
|
||||
struct max : public std::binary_function<T, T, T> {
|
||||
T operator() (const T& x, const T& y) const {
|
||||
return x < y ? y : x;
|
||||
}
|
||||
};
|
||||
}};
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
struct abs_diff : public std::binary_function<T, T, T>
|
||||
{
|
||||
T operator() (const T& x, const T& y) const
|
||||
{
|
||||
struct abs_diff : public std::binary_function<T, T, T> {
|
||||
T operator() (const T& x, const T& y) const {
|
||||
return std::abs (x - y);
|
||||
}
|
||||
};
|
||||
}};
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
struct abs_diff_exp : public std::binary_function<T, T, T>
|
||||
{
|
||||
T operator() (const T& x, const T& y) const
|
||||
{
|
||||
struct abs_diff_exp : public std::binary_function<T, T, T> {
|
||||
T operator() (const T& x, const T& y) const {
|
||||
return std::abs (std::exp (x) - std::exp (y));
|
||||
}
|
||||
};
|
||||
}};
|
||||
|
||||
}
|
||||
} // namespace FuncObj
|
||||
|
||||
#endif // HORUS_UTIL_H
|
||||
} // namespace Horus
|
||||
|
||||
#endif // YAP_PACKAGES_CLPBN_HORUS_UTIL_H_
|
||||
|
||||
|
@ -3,7 +3,9 @@
|
||||
#include "Var.h"
|
||||
|
||||
|
||||
unordered_map<VarId, VarInfo> Var::varsInfo_;
|
||||
namespace Horus {
|
||||
|
||||
std::unordered_map<VarId, Var::VarInfo> Var::varsInfo_;
|
||||
|
||||
|
||||
Var::Var (const Var* v)
|
||||
@ -45,13 +47,14 @@ Var::setEvidence (int evidence)
|
||||
|
||||
|
||||
|
||||
string
|
||||
Var::label (void) const
|
||||
std::string
|
||||
Var::label() const
|
||||
{
|
||||
if (Var::varsHaveInfo()) {
|
||||
return Var::getVarInfo (varId_).label;
|
||||
assert (Util::contains (varsInfo_, varId_));
|
||||
return varsInfo_.find (varId_)->second.first;
|
||||
}
|
||||
stringstream ss;
|
||||
std::stringstream ss;
|
||||
ss << "x" << varId_;
|
||||
return ss.str();
|
||||
}
|
||||
@ -59,17 +62,46 @@ Var::label (void) const
|
||||
|
||||
|
||||
States
|
||||
Var::states (void) const
|
||||
Var::states() const
|
||||
{
|
||||
if (Var::varsHaveInfo()) {
|
||||
return Var::getVarInfo (varId_).states;
|
||||
assert (Util::contains (varsInfo_, varId_));
|
||||
return varsInfo_.find (varId_)->second.second;
|
||||
}
|
||||
States states;
|
||||
for (unsigned i = 0; i < range_; i++) {
|
||||
stringstream ss;
|
||||
std::stringstream ss;
|
||||
ss << i ;
|
||||
states.push_back (ss.str());
|
||||
}
|
||||
return states;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Var::addVarInfo (
|
||||
VarId vid, std::string label, const States& states)
|
||||
{
|
||||
assert (Util::contains (varsInfo_, vid) == false);
|
||||
varsInfo_.insert (std::make_pair (vid, VarInfo (label, states)));
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
Var::varsHaveInfo()
|
||||
{
|
||||
return varsInfo_.empty() == false;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Var::clearVarsInfo()
|
||||
{
|
||||
varsInfo_.clear();
|
||||
}
|
||||
|
||||
} // namespace Horus
|
||||
|
||||
|
@ -1,102 +1,105 @@
|
||||
#ifndef HORUS_VAR_H
|
||||
#define HORUS_VAR_H
|
||||
#ifndef YAP_PACKAGES_CLPBN_HORUS_VAR_H_
|
||||
#define YAP_PACKAGES_CLPBN_HORUS_VAR_H_
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include <unordered_map>
|
||||
#include <string>
|
||||
|
||||
#include "Util.h"
|
||||
#include "Horus.h"
|
||||
|
||||
|
||||
using namespace std;
|
||||
namespace Horus {
|
||||
|
||||
|
||||
struct VarInfo
|
||||
{
|
||||
VarInfo (string l, const States& sts)
|
||||
: label(l), states(sts) { }
|
||||
string label;
|
||||
States states;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class Var
|
||||
{
|
||||
class Var {
|
||||
public:
|
||||
Var (const Var*);
|
||||
|
||||
Var (VarId, unsigned, int = Constants::NO_EVIDENCE);
|
||||
Var (VarId, unsigned range, int evidence = Constants::unobserved);
|
||||
|
||||
virtual ~Var (void) { };
|
||||
virtual ~Var() { };
|
||||
|
||||
VarId varId (void) const { return varId_; }
|
||||
VarId varId() const { return varId_; }
|
||||
|
||||
unsigned range (void) const { return range_; }
|
||||
unsigned range() const { return range_; }
|
||||
|
||||
int getEvidence (void) const { return evidence_; }
|
||||
int getEvidence() const { return evidence_; }
|
||||
|
||||
size_t getIndex (void) const { return index_; }
|
||||
size_t getIndex() const { return index_; }
|
||||
|
||||
void setIndex (size_t idx) { index_ = idx; }
|
||||
|
||||
bool hasEvidence (void) const
|
||||
{
|
||||
return evidence_ != Constants::NO_EVIDENCE;
|
||||
}
|
||||
bool hasEvidence() const;
|
||||
|
||||
operator size_t (void) const { return index_; }
|
||||
operator size_t() const;
|
||||
|
||||
bool operator== (const Var& var) const
|
||||
{
|
||||
assert (!(varId_ == var.varId() && range_ != var.range()));
|
||||
return varId_ == var.varId();
|
||||
}
|
||||
bool operator== (const Var& var) const;
|
||||
|
||||
bool operator!= (const Var& var) const
|
||||
{
|
||||
return !(*this == var);
|
||||
}
|
||||
bool operator!= (const Var& var) const;
|
||||
|
||||
bool isValidState (int);
|
||||
|
||||
void setEvidence (int);
|
||||
|
||||
string label (void) const;
|
||||
std::string label() const;
|
||||
|
||||
States states (void) const;
|
||||
States states() const;
|
||||
|
||||
static void addVarInfo (
|
||||
VarId vid, string label, const States& states)
|
||||
{
|
||||
assert (Util::contains (varsInfo_, vid) == false);
|
||||
varsInfo_.insert (make_pair (vid, VarInfo (label, states)));
|
||||
}
|
||||
VarId vid, std::string label, const States& states);
|
||||
|
||||
static VarInfo getVarInfo (VarId vid)
|
||||
{
|
||||
assert (Util::contains (varsInfo_, vid));
|
||||
return varsInfo_.find (vid)->second;
|
||||
}
|
||||
static bool varsHaveInfo();
|
||||
|
||||
static bool varsHaveInfo (void)
|
||||
{
|
||||
return varsInfo_.empty() == false;
|
||||
}
|
||||
|
||||
static void clearVarsInfo (void)
|
||||
{
|
||||
varsInfo_.clear();
|
||||
}
|
||||
static void clearVarsInfo();
|
||||
|
||||
private:
|
||||
typedef std::pair<std::string, States> VarInfo;
|
||||
|
||||
VarId varId_;
|
||||
unsigned range_;
|
||||
int evidence_;
|
||||
size_t index_;
|
||||
|
||||
static unordered_map<VarId, VarInfo> varsInfo_;
|
||||
static std::unordered_map<VarId, VarInfo> varsInfo_;
|
||||
|
||||
DISALLOW_COPY_AND_ASSIGN(Var);
|
||||
};
|
||||
|
||||
#endif // HORUS_VAR_H
|
||||
|
||||
|
||||
inline bool
|
||||
Var::hasEvidence() const
|
||||
{
|
||||
return evidence_ != Constants::unobserved;
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline
|
||||
Var::operator size_t() const
|
||||
{
|
||||
return index_;
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline bool
|
||||
Var::operator== (const Var& var) const
|
||||
{
|
||||
assert (!(varId_ == var.varId() && range_ != var.range()));
|
||||
return varId_ == var.varId();
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline bool
|
||||
Var::operator!= (const Var& var) const
|
||||
{
|
||||
return !(*this == var);
|
||||
}
|
||||
|
||||
} // namespace Horus
|
||||
|
||||
#endif // YAP_PACKAGES_CLPBN_HORUS_VAR_H_
|
||||
|
||||
|
@ -1,4 +1,6 @@
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "VarElim.h"
|
||||
#include "ElimGraph.h"
|
||||
@ -6,16 +8,18 @@
|
||||
#include "Util.h"
|
||||
|
||||
|
||||
namespace Horus {
|
||||
|
||||
Params
|
||||
VarElim::solveQuery (VarIds queryVids)
|
||||
{
|
||||
if (Globals::verbosity > 1) {
|
||||
cout << "Solving query on " ;
|
||||
std::cout << "Solving query on " ;
|
||||
for (size_t i = 0; i < queryVids.size(); i++) {
|
||||
if (i != 0) cout << ", " ;
|
||||
cout << fg.getVarNode (queryVids[i])->label();
|
||||
if (i != 0) std::cout << ", " ;
|
||||
std::cout << fg.getVarNode (queryVids[i])->label();
|
||||
}
|
||||
cout << endl;
|
||||
std::cout << std::endl;
|
||||
}
|
||||
totalFactorSize_ = 0;
|
||||
largestFactorSize_ = 0;
|
||||
@ -33,27 +37,28 @@ VarElim::solveQuery (VarIds queryVids)
|
||||
|
||||
|
||||
void
|
||||
VarElim::printSolverFlags (void) const
|
||||
VarElim::printSolverFlags() const
|
||||
{
|
||||
stringstream ss;
|
||||
std::stringstream ss;
|
||||
ss << "variable elimination [" ;
|
||||
ss << "elim_heuristic=" ;
|
||||
typedef ElimGraph::ElimHeuristic ElimHeuristic;
|
||||
switch (ElimGraph::elimHeuristic()) {
|
||||
case ElimHeuristic::SEQUENTIAL: ss << "sequential"; break;
|
||||
case ElimHeuristic::MIN_NEIGHBORS: ss << "min_neighbors"; break;
|
||||
case ElimHeuristic::MIN_WEIGHT: ss << "min_weight"; break;
|
||||
case ElimHeuristic::MIN_FILL: ss << "min_fill"; break;
|
||||
case ElimHeuristic::WEIGHTED_MIN_FILL: ss << "weighted_min_fill"; break;
|
||||
case ElimHeuristic::sequentialEh: ss << "sequential"; break;
|
||||
case ElimHeuristic::minNeighborsEh: ss << "min_neighbors"; break;
|
||||
case ElimHeuristic::minWeightEh: ss << "min_weight"; break;
|
||||
case ElimHeuristic::minFillEh: ss << "min_fill"; break;
|
||||
case ElimHeuristic::weightedMinFillEh: ss << "weighted_min_fill"; break;
|
||||
}
|
||||
ss << ",log_domain=" << Util::toString (Globals::logDomain);
|
||||
ss << "]" ;
|
||||
cout << ss.str() << endl;
|
||||
std::cout << ss.str() << std::endl;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
VarElim::createFactorList (void)
|
||||
VarElim::createFactorList()
|
||||
{
|
||||
const FacNodes& facNodes = fg.facNodes();
|
||||
factorList_.reserve (facNodes.size() * 2);
|
||||
@ -61,7 +66,7 @@ VarElim::createFactorList (void)
|
||||
factorList_.push_back (new Factor (facNodes[i]->factor()));
|
||||
const VarIds& args = facNodes[i]->factor().arguments();
|
||||
for (size_t j = 0; j < args.size(); j++) {
|
||||
unordered_map<VarId, vector<size_t>>::iterator it;
|
||||
std::unordered_map<VarId, std::vector<size_t>>::iterator it;
|
||||
it = varMap_.find (args[j]);
|
||||
if (it != varMap_.end()) {
|
||||
it->second.push_back (i);
|
||||
@ -75,22 +80,22 @@ VarElim::createFactorList (void)
|
||||
|
||||
|
||||
void
|
||||
VarElim::absorveEvidence (void)
|
||||
VarElim::absorveEvidence()
|
||||
{
|
||||
if (Globals::verbosity > 2) {
|
||||
Util::printDashedLine();
|
||||
cout << "(initial factor list)" << endl;
|
||||
std::cout << "(initial factor list)" << std::endl;
|
||||
printActiveFactors();
|
||||
}
|
||||
const VarNodes& varNodes = fg.varNodes();
|
||||
for (size_t i = 0; i < varNodes.size(); i++) {
|
||||
if (varNodes[i]->hasEvidence()) {
|
||||
if (Globals::verbosity > 1) {
|
||||
cout << "-> aborving evidence on ";
|
||||
cout << varNodes[i]->label() << " = " ;
|
||||
cout << varNodes[i]->getEvidence() << endl;
|
||||
std::cout << "-> aborving evidence on ";
|
||||
std::cout << varNodes[i]->label() << " = " ;
|
||||
std::cout << varNodes[i]->getEvidence() << std::endl;
|
||||
}
|
||||
const vector<size_t>& indices = varMap_[varNodes[i]->varId()];
|
||||
const std::vector<size_t>& indices = varMap_[varNodes[i]->varId()];
|
||||
for (size_t j = 0; j < indices.size(); j++) {
|
||||
size_t idx = indices[j];
|
||||
if (factorList_[idx]->nrArguments() > 1) {
|
||||
@ -118,8 +123,8 @@ VarElim::processFactorList (const VarIds& queryVids)
|
||||
Util::printDashedLine();
|
||||
printActiveFactors();
|
||||
}
|
||||
cout << "-> summing out " ;
|
||||
cout << fg.getVarNode (elimOrder[i])->label() << endl;
|
||||
std::cout << "-> summing out " ;
|
||||
std::cout << fg.getVarNode (elimOrder[i])->label() << std::endl;
|
||||
}
|
||||
eliminate (elimOrder[i]);
|
||||
}
|
||||
@ -143,9 +148,9 @@ VarElim::processFactorList (const VarIds& queryVids)
|
||||
result.reorderArguments (unobservedVids);
|
||||
result.normalize();
|
||||
if (Globals::verbosity > 0) {
|
||||
cout << "total factor size: " << totalFactorSize_ << endl;
|
||||
cout << "largest factor size: " << largestFactorSize_ << endl;
|
||||
cout << endl;
|
||||
std::cout << "total factor size: " << totalFactorSize_ << std::endl;
|
||||
std::cout << "largest factor size: " << largestFactorSize_ << std::endl;
|
||||
std::cout << std::endl;
|
||||
}
|
||||
return result.params();
|
||||
}
|
||||
@ -156,7 +161,7 @@ void
|
||||
VarElim::eliminate (VarId vid)
|
||||
{
|
||||
Factor* result = new Factor();
|
||||
const vector<size_t>& indices = varMap_[vid];
|
||||
const std::vector<size_t>& indices = varMap_[vid];
|
||||
for (size_t i = 0; i < indices.size(); i++) {
|
||||
size_t idx = indices[i];
|
||||
if (factorList_[idx]) {
|
||||
@ -173,7 +178,7 @@ VarElim::eliminate (VarId vid)
|
||||
result->sumOut (vid);
|
||||
const VarIds& args = result->arguments();
|
||||
for (size_t i = 0; i < args.size(); i++) {
|
||||
vector<size_t>& indices2 = varMap_[args[i]];
|
||||
std::vector<size_t>& indices2 = varMap_[args[i]];
|
||||
indices2.push_back (factorList_.size());
|
||||
}
|
||||
factorList_.push_back (result);
|
||||
@ -185,14 +190,16 @@ VarElim::eliminate (VarId vid)
|
||||
|
||||
|
||||
void
|
||||
VarElim::printActiveFactors (void)
|
||||
VarElim::printActiveFactors()
|
||||
{
|
||||
for (size_t i = 0; i < factorList_.size(); i++) {
|
||||
if (factorList_[i]) {
|
||||
cout << factorList_[i]->getLabel() << " " ;
|
||||
cout << factorList_[i]->params();
|
||||
cout << endl;
|
||||
std::cout << factorList_[i]->getLabel() << " " ;
|
||||
std::cout << factorList_[i]->params();
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace Horus
|
||||
|
||||
|
@ -1,45 +1,46 @@
|
||||
#ifndef HORUS_VARELIM_H
|
||||
#define HORUS_VARELIM_H
|
||||
#ifndef YAP_PACKAGES_CLPBN_HORUS_VARELIM_H_
|
||||
#define YAP_PACKAGES_CLPBN_HORUS_VARELIM_H_
|
||||
|
||||
#include "unordered_map"
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "GroundSolver.h"
|
||||
#include "FactorGraph.h"
|
||||
#include "Horus.h"
|
||||
|
||||
|
||||
using namespace std;
|
||||
namespace Horus {
|
||||
|
||||
|
||||
class VarElim : public GroundSolver
|
||||
{
|
||||
class VarElim : public GroundSolver {
|
||||
public:
|
||||
VarElim (const FactorGraph& fg) : GroundSolver (fg) { }
|
||||
|
||||
~VarElim (void) { }
|
||||
~VarElim() { }
|
||||
|
||||
Params solveQuery (VarIds);
|
||||
|
||||
void printSolverFlags (void) const;
|
||||
void printSolverFlags() const;
|
||||
|
||||
private:
|
||||
void createFactorList (void);
|
||||
void createFactorList();
|
||||
|
||||
void absorveEvidence (void);
|
||||
void absorveEvidence();
|
||||
|
||||
Params processFactorList (const VarIds&);
|
||||
|
||||
void eliminate (VarId);
|
||||
|
||||
void printActiveFactors (void);
|
||||
void printActiveFactors();
|
||||
|
||||
Factors factorList_;
|
||||
unsigned largestFactorSize_;
|
||||
unsigned totalFactorSize_;
|
||||
unordered_map<VarId, vector<size_t>> varMap_;
|
||||
std::unordered_map<VarId, std::vector<size_t>> varMap_;
|
||||
|
||||
DISALLOW_COPY_AND_ASSIGN (VarElim);
|
||||
};
|
||||
|
||||
#endif // HORUS_VARELIM_H
|
||||
} // namespace Horus
|
||||
|
||||
#endif // YAP_PACKAGES_CLPBN_HORUS_VARELIM_H_
|
||||
|
||||
|
@ -1,7 +1,24 @@
|
||||
#include <cassert>
|
||||
|
||||
#include <iostream>
|
||||
#include <iomanip>
|
||||
|
||||
#include "WeightedBp.h"
|
||||
|
||||
|
||||
WeightedBp::~WeightedBp (void)
|
||||
namespace Horus {
|
||||
|
||||
WeightedBp::WeightedBp (
|
||||
const FactorGraph& fg,
|
||||
const std::vector<std::vector<unsigned>>& weights)
|
||||
: BeliefProp (fg), weights_(weights)
|
||||
{
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
WeightedBp::~WeightedBp()
|
||||
{
|
||||
for (size_t i = 0; i < links_.size(); i++) {
|
||||
delete links_[i];
|
||||
@ -25,7 +42,7 @@ WeightedBp::getPosterioriOf (VarId vid)
|
||||
probs[var->getEvidence()] = LogAware::withEvidence();
|
||||
} else {
|
||||
probs.resize (var->range(), LogAware::multIdenty());
|
||||
const BpLinks& links = ninf(var)->getLinks();
|
||||
const BpLinks& links = getLinks (var);
|
||||
if (Globals::logDomain) {
|
||||
for (size_t i = 0; i < links.size(); i++) {
|
||||
WeightedLink* l = static_cast<WeightedLink*> (links[i]);
|
||||
@ -46,9 +63,24 @@ WeightedBp::getPosterioriOf (VarId vid)
|
||||
|
||||
|
||||
|
||||
void
|
||||
WeightedBp::createLinks (void)
|
||||
WeightedBp::WeightedLink::WeightedLink (
|
||||
FacNode* fn,
|
||||
VarNode* vn,
|
||||
size_t idx,
|
||||
unsigned weight)
|
||||
: BpLink (fn, vn), index_(idx), weight_(weight),
|
||||
pwdMsg_(vn->range(), LogAware::one())
|
||||
{
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
WeightedBp::createLinks()
|
||||
{
|
||||
using std::cout;
|
||||
using std::endl;
|
||||
if (Globals::verbosity > 0) {
|
||||
cout << "compressed factor graph contains " ;
|
||||
cout << fg.nrVarNodes() << " variables and " ;
|
||||
@ -78,7 +110,7 @@ WeightedBp::createLinks (void)
|
||||
|
||||
|
||||
void
|
||||
WeightedBp::maxResidualSchedule (void)
|
||||
WeightedBp::maxResidualSchedule()
|
||||
{
|
||||
if (nIters_ == 1) {
|
||||
for (size_t i = 0; i < links_.size(); i++) {
|
||||
@ -86,7 +118,7 @@ WeightedBp::maxResidualSchedule (void)
|
||||
SortedOrder::iterator it = sortedOrder_.insert (links_[i]);
|
||||
linkMap_.insert (make_pair (links_[i], it));
|
||||
if (Globals::verbosity >= 1) {
|
||||
cout << "calculating " << links_[i]->toString() << endl;
|
||||
std::cout << "calculating " << links_[i]->toString() << std::endl;
|
||||
}
|
||||
}
|
||||
return;
|
||||
@ -94,18 +126,20 @@ WeightedBp::maxResidualSchedule (void)
|
||||
|
||||
for (size_t c = 0; c < links_.size(); c++) {
|
||||
if (Globals::verbosity > 1) {
|
||||
cout << endl << "current residuals:" << endl;
|
||||
std::cout << std::endl << "current residuals:" << std::endl;
|
||||
for (SortedOrder::iterator it = sortedOrder_.begin();
|
||||
it != sortedOrder_.end(); ++it) {
|
||||
cout << " " << setw (30) << left << (*it)->toString();
|
||||
cout << "residual = " << (*it)->residual() << endl;
|
||||
std::cout << " " << std::setw (30) << std::left;
|
||||
std::cout << (*it)->toString();
|
||||
std::cout << "residual = " << (*it)->residual() << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
SortedOrder::iterator it = sortedOrder_.begin();
|
||||
BpLink* link = *it;
|
||||
if (Globals::verbosity >= 1) {
|
||||
cout << "updating " << (*sortedOrder_.begin())->toString() << endl;
|
||||
std::cout << "updating " << (*sortedOrder_.begin())->toString();
|
||||
std::cout << std::endl;
|
||||
}
|
||||
if (link->residual() < accuracy_) {
|
||||
return;
|
||||
@ -118,11 +152,12 @@ WeightedBp::maxResidualSchedule (void)
|
||||
// update the messages that depend on message source --> destin
|
||||
const FacNodes& factorNeighbors = link->varNode()->neighbors();
|
||||
for (size_t i = 0; i < factorNeighbors.size(); i++) {
|
||||
const BpLinks& links = ninf(factorNeighbors[i])->getLinks();
|
||||
const BpLinks& links = getLinks (factorNeighbors[i]);
|
||||
for (size_t j = 0; j < links.size(); j++) {
|
||||
if (links[j]->varNode() != link->varNode()) {
|
||||
if (Globals::verbosity > 1) {
|
||||
cout << " calculating " << links[j]->toString() << endl;
|
||||
std::cout << " calculating " << links[j]->toString();
|
||||
std::cout << std::endl;
|
||||
}
|
||||
calculateMessage (links[j]);
|
||||
BpLinkMap::iterator iter = linkMap_.find (links[j]);
|
||||
@ -133,11 +168,12 @@ WeightedBp::maxResidualSchedule (void)
|
||||
}
|
||||
// in counting bp, the message that a variable X sends to
|
||||
// to a factor F depends on the message that F sent to the X
|
||||
const BpLinks& links = ninf(link->facNode())->getLinks();
|
||||
const BpLinks& links = getLinks (link->facNode());
|
||||
for (size_t i = 0; i < links.size(); i++) {
|
||||
if (links[i]->varNode() != link->varNode()) {
|
||||
if (Globals::verbosity > 1) {
|
||||
cout << " calculating " << links[i]->toString() << endl;
|
||||
std::cout << " calculating " << links[i]->toString();
|
||||
std::cout << std::endl;
|
||||
}
|
||||
calculateMessage (links[i]);
|
||||
BpLinkMap::iterator iter = linkMap_.find (links[i]);
|
||||
@ -156,7 +192,7 @@ WeightedBp::calcFactorToVarMsg (BpLink* _link)
|
||||
WeightedLink* link = static_cast<WeightedLink*> (_link);
|
||||
FacNode* src = link->facNode();
|
||||
const VarNode* dst = link->varNode();
|
||||
const BpLinks& links = ninf(src)->getLinks();
|
||||
const BpLinks& links = getLinks (src);
|
||||
// calculate the product of messages that were sent
|
||||
// to factor `src', except from var `dst'
|
||||
unsigned reps = 1;
|
||||
@ -166,14 +202,14 @@ WeightedBp::calcFactorToVarMsg (BpLink* _link)
|
||||
for (size_t i = links.size(); i-- > 0; ) {
|
||||
const WeightedLink* l = static_cast<const WeightedLink*> (links[i]);
|
||||
if ( ! (l->varNode() == dst && l->index() == link->index())) {
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
cout << " message from " << links[i]->varNode()->label();
|
||||
cout << ": " ;
|
||||
if (Constants::showBpCalcs) {
|
||||
std::cout << " message from " << links[i]->varNode()->label();
|
||||
std::cout << ": " ;
|
||||
}
|
||||
Util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]),
|
||||
reps, std::plus<double>());
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
cout << endl;
|
||||
if (Constants::showBpCalcs) {
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
reps *= links[i]->varNode()->range();
|
||||
@ -182,14 +218,14 @@ WeightedBp::calcFactorToVarMsg (BpLink* _link)
|
||||
for (size_t i = links.size(); i-- > 0; ) {
|
||||
const WeightedLink* l = static_cast<const WeightedLink*> (links[i]);
|
||||
if ( ! (l->varNode() == dst && l->index() == link->index())) {
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
cout << " message from " << links[i]->varNode()->label();
|
||||
cout << ": " ;
|
||||
if (Constants::showBpCalcs) {
|
||||
std::cout << " message from " << links[i]->varNode()->label();
|
||||
std::cout << ": " ;
|
||||
}
|
||||
Util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]),
|
||||
reps, std::multiplies<double>());
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
cout << endl;
|
||||
if (Constants::showBpCalcs) {
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
reps *= links[i]->varNode()->range();
|
||||
@ -203,27 +239,33 @@ WeightedBp::calcFactorToVarMsg (BpLink* _link)
|
||||
} else {
|
||||
result.params() *= src->factor().params();
|
||||
}
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
cout << " message product: " << msgProduct << endl;
|
||||
cout << " original factor: " << src->factor().params() << endl;
|
||||
cout << " factor product: " << result.params() << endl;
|
||||
if (Constants::showBpCalcs) {
|
||||
std::cout << " message product: " ;
|
||||
std::cout << msgProduct << std::endl;
|
||||
std::cout << " original factor: " ;
|
||||
std::cout << src->factor().params() << std::endl;
|
||||
std::cout << " factor product: " ;
|
||||
std::cout << result.params() << std::endl;
|
||||
}
|
||||
result.sumOutAllExceptIndex (link->index());
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
cout << " marginalized: " << result.params() << endl;
|
||||
if (Constants::showBpCalcs) {
|
||||
std::cout << " marginalized: " ;
|
||||
std::cout << result.params() << std::endl;
|
||||
}
|
||||
link->nextMessage() = result.params();
|
||||
LogAware::normalize (link->nextMessage());
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
cout << " curr msg: " << link->message() << endl;
|
||||
cout << " next msg: " << link->nextMessage() << endl;
|
||||
if (Constants::showBpCalcs) {
|
||||
std::cout << " curr msg: " ;
|
||||
std::cout << link->message() << std::endl;
|
||||
std::cout << " next msg: " ;
|
||||
std::cout << link->nextMessage() << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
Params
|
||||
WeightedBp::getVarToFactorMsg (const BpLink* _link) const
|
||||
WeightedBp::getVarToFactorMsg (const BpLink* _link)
|
||||
{
|
||||
const WeightedLink* link = static_cast<const WeightedLink*> (_link);
|
||||
const VarNode* src = link->varNode();
|
||||
@ -232,19 +274,19 @@ WeightedBp::getVarToFactorMsg (const BpLink* _link) const
|
||||
if (src->hasEvidence()) {
|
||||
msg.resize (src->range(), LogAware::noEvidence());
|
||||
double value = link->message()[src->getEvidence()];
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
if (Constants::showBpCalcs) {
|
||||
msg[src->getEvidence()] = value;
|
||||
cout << msg << "^" << link->weight() << "-1" ;
|
||||
std::cout << msg << "^" << link->weight() << "-1" ;
|
||||
}
|
||||
msg[src->getEvidence()] = LogAware::pow (value, link->weight() - 1);
|
||||
} else {
|
||||
msg = link->message();
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
cout << msg << "^" << link->weight() << "-1" ;
|
||||
if (Constants::showBpCalcs) {
|
||||
std::cout << msg << "^" << link->weight() << "-1" ;
|
||||
}
|
||||
LogAware::pow (msg, link->weight() - 1);
|
||||
}
|
||||
const BpLinks& links = ninf(src)->getLinks();
|
||||
const BpLinks& links = getLinks (src);
|
||||
if (Globals::logDomain) {
|
||||
for (size_t i = 0; i < links.size(); i++) {
|
||||
WeightedLink* l = static_cast<WeightedLink*> (links[i]);
|
||||
@ -257,14 +299,14 @@ WeightedBp::getVarToFactorMsg (const BpLink* _link) const
|
||||
WeightedLink* l = static_cast<WeightedLink*> (links[i]);
|
||||
if ( ! (l->facNode() == dst && l->index() == link->index())) {
|
||||
msg *= l->powMessage();
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
cout << " x " << l->nextMessage() << "^" << link->weight();
|
||||
if (Constants::showBpCalcs) {
|
||||
std::cout << " x " << l->nextMessage() << "^" << link->weight();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
cout << " = " << msg;
|
||||
if (Constants::showBpCalcs) {
|
||||
std::cout << " = " << msg;
|
||||
}
|
||||
return msg;
|
||||
}
|
||||
@ -272,8 +314,10 @@ WeightedBp::getVarToFactorMsg (const BpLink* _link) const
|
||||
|
||||
|
||||
void
|
||||
WeightedBp::printLinkInformation (void) const
|
||||
WeightedBp::printLinkInformation() const
|
||||
{
|
||||
using std::cout;
|
||||
using std::endl;
|
||||
for (size_t i = 0; i < links_.size(); i++) {
|
||||
WeightedLink* l = static_cast<WeightedLink*> (links_[i]);
|
||||
cout << l->toString() << ":" << endl;
|
||||
@ -286,3 +330,5 @@ WeightedBp::printLinkInformation (void) const
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace Horus
|
||||
|
||||
|
@ -1,64 +1,69 @@
|
||||
#ifndef HORUS_WEIGHTEDBP_H
|
||||
#define HORUS_WEIGHTEDBP_H
|
||||
#ifndef YAP_PACKAGES_CLPBN_HORUS_WEIGHTEDBP_H_
|
||||
#define YAP_PACKAGES_CLPBN_HORUS_WEIGHTEDBP_H_
|
||||
|
||||
#include "BeliefProp.h"
|
||||
|
||||
class WeightedLink : public BpLink
|
||||
{
|
||||
public:
|
||||
WeightedLink (FacNode* fn, VarNode* vn, size_t idx, unsigned weight)
|
||||
: BpLink (fn, vn), index_(idx), weight_(weight),
|
||||
pwdMsg_(vn->range(), LogAware::one()) { }
|
||||
|
||||
size_t index (void) const { return index_; }
|
||||
namespace Horus {
|
||||
|
||||
unsigned weight (void) const { return weight_; }
|
||||
|
||||
const Params& powMessage (void) const { return pwdMsg_; }
|
||||
|
||||
void updateMessage (void)
|
||||
{
|
||||
pwdMsg_ = *nextMsg_;
|
||||
swap (currMsg_, nextMsg_);
|
||||
LogAware::pow (pwdMsg_, weight_);
|
||||
}
|
||||
|
||||
private:
|
||||
DISALLOW_COPY_AND_ASSIGN (WeightedLink);
|
||||
|
||||
size_t index_;
|
||||
unsigned weight_;
|
||||
Params pwdMsg_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class WeightedBp : public BeliefProp
|
||||
{
|
||||
class WeightedBp : public BeliefProp {
|
||||
public:
|
||||
WeightedBp (const FactorGraph& fg,
|
||||
const vector<vector<unsigned>>& weights)
|
||||
: BeliefProp (fg), weights_(weights) { }
|
||||
const std::vector<std::vector<unsigned>>& weights);
|
||||
|
||||
~WeightedBp (void);
|
||||
~WeightedBp();
|
||||
|
||||
Params getPosterioriOf (VarId);
|
||||
|
||||
private:
|
||||
void createLinks (void);
|
||||
class WeightedLink : public BeliefProp::BpLink {
|
||||
public:
|
||||
WeightedLink (FacNode* fn, VarNode* vn, size_t idx,
|
||||
unsigned weight);
|
||||
|
||||
void maxResidualSchedule (void);
|
||||
size_t index() const { return index_; }
|
||||
|
||||
unsigned weight() const { return weight_; }
|
||||
|
||||
const Params& powMessage() const { return pwdMsg_; }
|
||||
|
||||
void updateMessage();
|
||||
|
||||
private:
|
||||
size_t index_;
|
||||
unsigned weight_;
|
||||
Params pwdMsg_;
|
||||
|
||||
DISALLOW_COPY_AND_ASSIGN (WeightedLink);
|
||||
};
|
||||
|
||||
void createLinks();
|
||||
|
||||
void maxResidualSchedule();
|
||||
|
||||
void calcFactorToVarMsg (BpLink*);
|
||||
|
||||
Params getVarToFactorMsg (const BpLink*) const;
|
||||
Params getVarToFactorMsg (const BpLink*);
|
||||
|
||||
void printLinkInformation (void) const;
|
||||
void printLinkInformation() const;
|
||||
|
||||
vector<vector<unsigned>> weights_;
|
||||
std::vector<std::vector<unsigned>> weights_;
|
||||
|
||||
DISALLOW_COPY_AND_ASSIGN (WeightedBp);
|
||||
};
|
||||
|
||||
#endif // HORUS_WEIGHTEDBP_H
|
||||
|
||||
|
||||
|
||||
inline void
|
||||
WeightedBp::WeightedLink::updateMessage()
|
||||
{
|
||||
pwdMsg_ = *nextMsg_;
|
||||
swap (currMsg_, nextMsg_);
|
||||
LogAware::pow (pwdMsg_, weight_);
|
||||
}
|
||||
|
||||
} // namespace Horus
|
||||
|
||||
#endif // YAP_PACKAGES_CLPBN_HORUS_WEIGHTEDBP_H_
|
||||
|
||||
|
51
packages/CLPBN/horus/unit_tests/BeliefPropTest.cpp
Normal file
51
packages/CLPBN/horus/unit_tests/BeliefPropTest.cpp
Normal file
@ -0,0 +1,51 @@
|
||||
#include "../BeliefProp.h"
|
||||
#include "../FactorGraph.h"
|
||||
#include "Common.h"
|
||||
|
||||
|
||||
namespace Horus {
|
||||
|
||||
namespace UnitTests {
|
||||
|
||||
class BeliefPropTest : public CppUnit::TestFixture {
|
||||
CPPUNIT_TEST_SUITE (BeliefPropTest);
|
||||
CPPUNIT_TEST (testMarginals);
|
||||
CPPUNIT_TEST (testJoint);
|
||||
CPPUNIT_TEST_SUITE_END();
|
||||
public:
|
||||
void testMarginals();
|
||||
void testJoint();
|
||||
};
|
||||
|
||||
|
||||
|
||||
void
|
||||
BeliefPropTest::testMarginals()
|
||||
{
|
||||
FactorGraph fg = FactorGraph::readFromLibDaiFormat (modelFile.c_str());
|
||||
BeliefProp solver (fg);
|
||||
for (unsigned i = 0; i < marginalProbs.size(); i++) {
|
||||
Params params = solver.solveQuery ({i});
|
||||
CPPUNIT_ASSERT (similiar (params, marginalProbs[i]));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BeliefPropTest::testJoint()
|
||||
{
|
||||
FactorGraph fg = FactorGraph::readFromLibDaiFormat (modelFile.c_str());
|
||||
BeliefProp solver (fg);
|
||||
Params params = solver.solveQuery ({0, 4, 6});
|
||||
CPPUNIT_ASSERT (similiar (params, jointProbs));
|
||||
}
|
||||
|
||||
|
||||
|
||||
CPPUNIT_TEST_SUITE_REGISTRATION (BeliefPropTest);
|
||||
|
||||
} // namespace UnitTests
|
||||
|
||||
} // namespace Horus
|
||||
|
95
packages/CLPBN/horus/unit_tests/Common.cpp
Normal file
95
packages/CLPBN/horus/unit_tests/Common.cpp
Normal file
@ -0,0 +1,95 @@
|
||||
#include <cstdlib>
|
||||
#include <cmath>
|
||||
#include <cassert>
|
||||
|
||||
#include <numeric>
|
||||
#include <functional>
|
||||
#include <iostream>
|
||||
|
||||
#include "Common.h"
|
||||
#include "Util.h"
|
||||
|
||||
|
||||
namespace Horus {
|
||||
|
||||
namespace UnitTests {
|
||||
|
||||
const std::string modelFile = "../examples/complex.fg" ;
|
||||
|
||||
|
||||
const std::vector<Params> marginalProbs = {
|
||||
/* marginals x0 = */ {0.5825521, 0.4174479},
|
||||
/* marginals x1 = */ {0.648528, 0.351472},
|
||||
{0.03100852, 0.9689915},
|
||||
{0.04565728, 0.503854, 0.4504888},
|
||||
{0.7713128, 0.03128429, 0.1974029},
|
||||
{0.8771822, 0.1228178},
|
||||
{0.05617282, 0.01509834, 0.9287288},
|
||||
{0.08224711, 0.5698616, 0.047964, 0.2999273},
|
||||
{0.1368483, 0.8631517},
|
||||
/* marginals x9 = */ {0.7529569, 0.2470431}
|
||||
};
|
||||
|
||||
|
||||
const Params jointProbs = {
|
||||
/* P(x0=0, x4=0, x6=0) = */ 0.025463399,
|
||||
/* P(x0=0, x4=0, x6=1) = */ 0.0067233122,
|
||||
/* P(x0=0, x4=0, x6=2) = */ 0.42069289,
|
||||
/* P(x0=0, x4=1, x6=0) = */ 0.0010111473,
|
||||
/* P(x0=0, x4=1, x6=1) = */ 0.00027096982,
|
||||
/* P(x0=0, x4=1, x6=2) = */ 0.016715682,
|
||||
/* P(x0=0, x4=2, x6=0) = */ 0.0062433667,
|
||||
/* P(x0=0, x4=2, x6=1) = */ 0.001828545,
|
||||
/* P(x0=0, x4=2, x6=2) = */ 0.10360283,
|
||||
/* P(x0=1, x4=0, x6=0) = */ 0.017910021,
|
||||
/* P(x0=1, x4=0, x6=1) = */ 0.0046988842,
|
||||
/* P(x0=1, x4=0, x6=2) = */ 0.29582433,
|
||||
/* P(x0=1, x4=1, x6=0) = */ 0.00074648444,
|
||||
/* P(x0=1, x4=1, x6=1) = */ 0.00019991076,
|
||||
/* P(x0=1, x4=1, x6=2) = */ 0.012340097,
|
||||
/* P(x0=1, x4=2, x6=0) = */ 0.0047984062,
|
||||
/* P(x0=1, x4=2, x6=1) = */ 0.0013767189,
|
||||
/* P(x0=1, x4=2, x6=2) = */ 0.079553004
|
||||
};
|
||||
|
||||
|
||||
|
||||
Params
|
||||
generateRandomParams (Ranges ranges)
|
||||
{
|
||||
Params params;
|
||||
unsigned size = std::accumulate (ranges.begin(), ranges.end(),
|
||||
1, std::multiplies<unsigned>());
|
||||
for (unsigned i = 0; i < size; i++) {
|
||||
params.push_back (rand() / double (RAND_MAX));
|
||||
}
|
||||
Horus::LogAware::normalize (params);
|
||||
return params;
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
similiar (double v1, double v2)
|
||||
{
|
||||
const double epsilon = 0.0000001;
|
||||
return std::fabs (v1 - v2) < epsilon;
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
similiar (const Params& p1, const Params& p2)
|
||||
{
|
||||
assert (p1.size() == p2.size());
|
||||
for (size_t i = 0; i < p1.size(); i++) {
|
||||
if (! similiar(p1[i], p2[i])) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace UnitTests
|
||||
|
||||
} // namespace Horus;
|
28
packages/CLPBN/horus/unit_tests/Common.h
Normal file
28
packages/CLPBN/horus/unit_tests/Common.h
Normal file
@ -0,0 +1,28 @@
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
#include "../Horus.h"
|
||||
|
||||
#include <cppunit/extensions/HelperMacros.h>
|
||||
|
||||
|
||||
namespace Horus {
|
||||
|
||||
namespace UnitTests {
|
||||
|
||||
extern const std::string modelFile;
|
||||
|
||||
extern const std::vector<Params> marginalProbs;
|
||||
|
||||
extern const Params jointProbs;
|
||||
|
||||
Params generateRandomParams (Ranges ranges);
|
||||
|
||||
bool similiar (double v1, double v2);
|
||||
|
||||
bool similiar (const Params& p1, const Params& p2);
|
||||
|
||||
} // namespace UnitTests
|
||||
|
||||
} // namespace Horus;
|
||||
|
51
packages/CLPBN/horus/unit_tests/CountingBpTest.cpp
Normal file
51
packages/CLPBN/horus/unit_tests/CountingBpTest.cpp
Normal file
@ -0,0 +1,51 @@
|
||||
#include "../CountingBp.h"
|
||||
#include "../FactorGraph.h"
|
||||
#include "Common.h"
|
||||
|
||||
|
||||
namespace Horus {
|
||||
|
||||
namespace UnitTests {
|
||||
|
||||
class CountingBpTest : public CppUnit::TestFixture {
|
||||
CPPUNIT_TEST_SUITE (CountingBpTest);
|
||||
CPPUNIT_TEST (testMarginals);
|
||||
CPPUNIT_TEST (testJoint);
|
||||
CPPUNIT_TEST_SUITE_END();
|
||||
public:
|
||||
void testMarginals();
|
||||
void testJoint();
|
||||
};
|
||||
|
||||
|
||||
|
||||
void
|
||||
CountingBpTest::testMarginals()
|
||||
{
|
||||
FactorGraph fg = FactorGraph::readFromLibDaiFormat (modelFile.c_str());
|
||||
CountingBp solver (fg);
|
||||
for (unsigned i = 0; i < marginalProbs.size(); i++) {
|
||||
Params params = solver.solveQuery ({i});
|
||||
CPPUNIT_ASSERT (similiar (params, marginalProbs[i]));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
CountingBpTest::testJoint()
|
||||
{
|
||||
FactorGraph fg = FactorGraph::readFromLibDaiFormat (modelFile.c_str());
|
||||
CountingBp solver (fg);
|
||||
Params params = solver.solveQuery ({0, 4, 6});
|
||||
CPPUNIT_ASSERT (similiar (params, jointProbs));
|
||||
}
|
||||
|
||||
|
||||
|
||||
CPPUNIT_TEST_SUITE_REGISTRATION (CountingBpTest);
|
||||
|
||||
} // namespace UnitTests
|
||||
|
||||
} // namespace Horus
|
||||
|
107
packages/CLPBN/horus/unit_tests/FactorTest.cpp
Normal file
107
packages/CLPBN/horus/unit_tests/FactorTest.cpp
Normal file
@ -0,0 +1,107 @@
|
||||
#include <iostream>
|
||||
|
||||
#include "../Factor.h"
|
||||
#include "Common.h"
|
||||
|
||||
|
||||
namespace Horus {
|
||||
|
||||
namespace UnitTests {
|
||||
|
||||
class FactorTest : public CppUnit::TestFixture {
|
||||
CPPUNIT_TEST_SUITE (FactorTest);
|
||||
CPPUNIT_TEST (testSummingOut);
|
||||
CPPUNIT_TEST (testProduct);
|
||||
CPPUNIT_TEST_SUITE_END();
|
||||
public:
|
||||
void testSummingOut();
|
||||
void testProduct();
|
||||
};
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorTest::testSummingOut()
|
||||
{
|
||||
VarIds vids = {0, 1, 2, 3};
|
||||
Ranges ranges = {3, 2, 4, 3};
|
||||
Params params = {
|
||||
0.022757933283133, 0.0106825417145475, 0.0212115929862968,
|
||||
0.0216271252738214, 0.0246935408909929, 0.00535101952882101,
|
||||
0.00908008645423061, 0.0208088234425334, 0.00752400708452212,
|
||||
0.0150052316136527, 0.0129311224551535, 0.0170340535302049,
|
||||
0.00988081654256193, 0.0139063490493519, 0.025792784294836,
|
||||
0.0248167234610076, 0.017219348482278, 0.0194292243637016,
|
||||
0.00383554941557795, 0.0164407987747966, 0.00044152909395022,
|
||||
0.00657900705816833, 0.00371715392294919, 0.0217825142487465,
|
||||
0.00424392333677727, 0.0108602703755316, 0.00351559808401304,
|
||||
0.00294727405145356, 0.0270575932871257, 0.005911864680038,
|
||||
0.0138936584911577, 0.0227288019859002, 0.0165944064071987,
|
||||
0.0080185268930961, 0.0172692026753632, 0.0142012227138332,
|
||||
0.0133695464219171, 0.0263492891422071, 0.00792332157200822,
|
||||
0.0208935535064392, 0.0142677961715013, 0.0208544440271617,
|
||||
0.0108408824522857, 0.0241486127140633, 0.00767406849215521,
|
||||
0.00954694217537661, 0.0218786116033257, 0.0248934169744332,
|
||||
0.00188944195471982, 0.0257141610189036, 0.0142474911774847,
|
||||
0.00233097104867004, 0.00520644350532678, 0.0179646451004339,
|
||||
0.0241134853100298, 0.00945036684210405, 0.00173819089160705,
|
||||
0.000542358809684406, 0.0123976408935576, 0.00170905959437435,
|
||||
0.00645422348972241, 0.0262912993847153, 0.0244378615928878,
|
||||
0.0230486298969212, 0.00722310170606624, 0.0146203396838926,
|
||||
0.0101631280263959, 0.0205926481279833, 0.0138829042417413,
|
||||
0.0180864495984042, 0.0143994770626774, 0.00106397584149748
|
||||
};
|
||||
|
||||
Factor f (vids, ranges, params);
|
||||
double sum = std::accumulate (f.params().begin(), f.params().end(), 0.0);
|
||||
CPPUNIT_ASSERT (similiar (sum, 1.0));
|
||||
|
||||
f.sumOut (0);
|
||||
f.sumOut (3);
|
||||
f.sumOut (2);
|
||||
|
||||
sum = std::accumulate (f.params().begin(), f.params().end(), 0.0);
|
||||
CPPUNIT_ASSERT (similiar (sum, 1.0));
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorTest::testProduct()
|
||||
{
|
||||
VarIds vids1 = {0, 1, 2};
|
||||
Ranges ranges1 = {3, 2, 2};
|
||||
Params params1 = {
|
||||
0.01, 0.02, 0.03, 0.04, 0.05, 0.06,
|
||||
0.07, 0.08, 0.09, 0.10, 0.11, 0.12
|
||||
};
|
||||
|
||||
VarIds vids2 = {1, 3, 0};
|
||||
Ranges ranges2 = {2, 3, 3};
|
||||
Params params2 = {
|
||||
0.15, 0.30, 0.45, 0.60, 0.75, 0.90, 1.20, 1.50, 1.80,
|
||||
0.99, 0.88, 0.77, 0.66, 0.55, 0.44, 0.33, 0.22, 0.11
|
||||
};
|
||||
|
||||
Factor f1 (vids1, ranges1, params1);
|
||||
Factor f2 (vids2, ranges2, params2);
|
||||
f1.multiply (f2);
|
||||
|
||||
Params result = {
|
||||
0.0015, 0.006, 0.012, 0.003, 0.012, 0.024, 0.0297, 0.0198, 0.0099,
|
||||
0.0396, 0.0264, 0.0132, 0.015, 0.0375, 0.075, 0.018, 0.045, 0.09,
|
||||
0.0616, 0.0385, 0.0154, 0.0704, 0.044, 0.0176, 0.0405, 0.081, 0.162,
|
||||
0.045, 0.09, 0.18, 0.0847, 0.0484, 0.0121, 0.0924, 0.0528, 0.0132
|
||||
};
|
||||
|
||||
CPPUNIT_ASSERT (similiar (f1.params(), result));
|
||||
}
|
||||
|
||||
|
||||
|
||||
CPPUNIT_TEST_SUITE_REGISTRATION (FactorTest);
|
||||
|
||||
} // namespace UnitTests
|
||||
|
||||
} // namespace Horus
|
||||
|
15
packages/CLPBN/horus/unit_tests/UnitTesting.cpp
Normal file
15
packages/CLPBN/horus/unit_tests/UnitTesting.cpp
Normal file
@ -0,0 +1,15 @@
|
||||
#include "../Factor.h"
|
||||
|
||||
#include <cppunit/ui/text/TestRunner.h>
|
||||
#include <cppunit/extensions/TestFactoryRegistry.h>
|
||||
|
||||
|
||||
int main()
|
||||
{
|
||||
CppUnit::TextUi::TestRunner runner;
|
||||
CppUnit::TestFactoryRegistry& registry =
|
||||
CppUnit::TestFactoryRegistry::getRegistry();
|
||||
runner.addTest (registry.makeTest());
|
||||
return runner.run() ? 1 : 0;
|
||||
}
|
||||
|
51
packages/CLPBN/horus/unit_tests/VarElimTest.cpp
Normal file
51
packages/CLPBN/horus/unit_tests/VarElimTest.cpp
Normal file
@ -0,0 +1,51 @@
|
||||
#include "../VarElim.h"
|
||||
#include "../FactorGraph.h"
|
||||
#include "Common.h"
|
||||
|
||||
|
||||
namespace Horus {
|
||||
|
||||
namespace UnitTests {
|
||||
|
||||
class VarElimTest : public CppUnit::TestFixture {
|
||||
CPPUNIT_TEST_SUITE (VarElimTest);
|
||||
CPPUNIT_TEST (testMarginals);
|
||||
CPPUNIT_TEST (testJoint);
|
||||
CPPUNIT_TEST_SUITE_END();
|
||||
public:
|
||||
void testMarginals();
|
||||
void testJoint();
|
||||
};
|
||||
|
||||
|
||||
|
||||
void
|
||||
VarElimTest::testMarginals()
|
||||
{
|
||||
FactorGraph fg = FactorGraph::readFromLibDaiFormat (modelFile.c_str());
|
||||
VarElim solver (fg);
|
||||
for (unsigned i = 0; i < marginalProbs.size(); i++) {
|
||||
Params params = solver.solveQuery ({i});
|
||||
CPPUNIT_ASSERT (similiar (params, marginalProbs[i]));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
VarElimTest::testJoint()
|
||||
{
|
||||
FactorGraph fg = FactorGraph::readFromLibDaiFormat (modelFile.c_str());
|
||||
VarElim solver (fg);
|
||||
Params params = solver.solveQuery ({0, 4, 6});
|
||||
CPPUNIT_ASSERT (similiar (params, jointProbs));
|
||||
}
|
||||
|
||||
|
||||
|
||||
CPPUNIT_TEST_SUITE_REGISTRATION (VarElimTest);
|
||||
|
||||
} // namespace UnitTests
|
||||
|
||||
} // namespace Horus
|
||||
|
471
packages/CLPBN/html/index.html
Normal file
471
packages/CLPBN/html/index.html
Normal file
@ -0,0 +1,471 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset=utf-8>
|
||||
<meta name="description" content="Prolog Factor Language Tutorial">
|
||||
<meta name="keywords" content="Graphical Models, Lifted Inference, First Order, Bayesian Networks, Markov Networks, Variable Elimination">
|
||||
<meta name="author" content="Tiago Gomes">
|
||||
<link rel="stylesheet" type="text/css" href="pfl.css">
|
||||
<title>The Prolog Factor Language</title>
|
||||
</head>
|
||||
|
||||
<body id="top">
|
||||
|
||||
<div class="container">
|
||||
|
||||
<div class="header">
|
||||
<div id="leftcolumn"><h1>Prolog Factor Language</h1></div>
|
||||
<div id="rightcolumn">
|
||||
<div>
|
||||
<div class="name">Vítor Costa</div>
|
||||
<div class="email">vsc at gmail.com </div>
|
||||
</div>
|
||||
<div style="padding-top:10px">
|
||||
<div class="name">Tiago Gomes</div>
|
||||
<div class="email">tiago.avv at gmail.com</div>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div style="clear: both"></div>
|
||||
<nav id="menu">
|
||||
<ul>
|
||||
<li><a href="#intro">Introduction</a></li>
|
||||
<li><a href="#installation">Installation</a></li>
|
||||
<li><a href="#language">Language</a></li>
|
||||
<li><a href="#querying">Querying</a></li>
|
||||
<li><a href="#options">Options</a></li>
|
||||
<li><a href="#learning">Learning</a></li>
|
||||
<li><a href="#external_interface">External Interface</a></li>
|
||||
<li><a href="#papers">Papers</a></li>
|
||||
</ul>
|
||||
</nav>
|
||||
</div> <!-- end of header -->
|
||||
|
||||
|
||||
<div class="mainbody">
|
||||
<!--+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-->
|
||||
<!--+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-->
|
||||
<!--+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-->
|
||||
<h2 id="introduction">Introduction</h2>
|
||||
The Prolog Factor Language (PFL) is a language that extends Prolog for providing a syntax to describe first-order probabilistic graphical models. These models can be either directed (bayesian networks) or undirected (markov networks). This language replaces the old one known as CLP(BN).
|
||||
|
||||
<p>The package also includes implementations for a set of well-known inference algorithms for solving probabilistic queries on these models. Both ground and lifted inference methods are support.</p>
|
||||
|
||||
<p><a href="#top" style="font-size:15px">Back to the top</a></p>
|
||||
|
||||
|
||||
|
||||
<!--+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-->
|
||||
<!--+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-->
|
||||
<!--+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-->
|
||||
<h2 id="installation">Installation</h2>
|
||||
PFL is included with the <a href="http://www.dcc.fc.up.pt/~vsc/Yap/">YAP</a> Prolog system. However, there isn't yet a stable release of YAP that includes PFL and you will need to install a development version. To do so, you must have installed the <a href="http://git-scm.com/">Git</a> version control system. The commands to perform a default installation of YAP in your home directory in a Unix-based environment are shown next.
|
||||
<p>
|
||||
<div class=console>
|
||||
<p>$ cd $HOME</p>
|
||||
<p>$ git clone git://yap.git.sourceforge.net/gitroot/yap/yap-6.3</p>
|
||||
<p>$ cd yap-6.3/</p>
|
||||
<p>$ ./configure --enable-clpbn-bp --prefix=$HOME</p>
|
||||
<p>$ make depend & make install</p>
|
||||
</div>
|
||||
|
||||
<p>In case you want to install YAP somewhere else or with different settings, please consult the YAP documentation. From now on, we will assume that the directory <span class=texttt>$HOME ▷ bin</span> (where the binary is) is in your <span class=texttt>$PATH</span> environment variable.</p>
|
||||
|
||||
<p>Once in a while, we will refer to the PFL examples directory. In a default installation, this directory will be located at <span class=texttt> $HOME ▷ share ▷ doc ▷ Yap ▷ packages ▷ examples ▷ CLPBN</span>.</p>
|
||||
|
||||
<p><a href="#top" style="font-size:15px">Back to the top</a></p>
|
||||
|
||||
|
||||
|
||||
<!--+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-->
|
||||
<!--+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-->
|
||||
<!--+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-->
|
||||
<h2 id="language">Language</h2>
|
||||
A first-order probabilistic graphical model is described using parametric factors, commonly known as parfactors. The PFL syntax for a parfactor is
|
||||
|
||||
<div style="text-align:center">
|
||||
<br>
|
||||
<em>Type</em> <em>F</em> ; <em>Phi</em> ; <em>C</em>.
|
||||
<br>
|
||||
<br>
|
||||
</div>
|
||||
|
||||
Where,
|
||||
<ul>
|
||||
<li>
|
||||
<em>Type</em> refers the type of network over which the parfactor is defined. It can be <span class=texttt>bayes</span> for directed networks, or <span class=texttt>markov</span> for undirected ones.
|
||||
<p>
|
||||
</li>
|
||||
|
||||
<li>
|
||||
<em>F</em> is a comma-separated sequence of Prolog terms that will define sets of random variables under the constraint <em>C</em>. If <em>Type</em> is <span class=texttt>bayes</span>, the first term defines the node while the remaining terms define its parents.
|
||||
<p>
|
||||
</li>
|
||||
|
||||
<li>
|
||||
<em>Phi</em> is either a Prolog list of potential values or a Prolog goal that unifies with one. Notice that if <em>Type</em> is <span class=texttt>bayes</span>, this will correspond to the conditional probability table. Domain combinations are implicitly assumed in ascending order, with the first term being the 'most significant' (e.g.
|
||||
<span class="texttt">x<sub>0</sub>y<sub>0</sub></span>,
|
||||
<span class="texttt">x<sub>0</sub>y<sub>1</sub></span>,
|
||||
<span class="texttt">x<sub>0</sub>y<sub>2</sub></span>,
|
||||
<span class="texttt">x<sub>1</sub>y<sub>0</sub></span>,
|
||||
<span class="texttt">x<sub>1</sub>y<sub>1</sub></span>,
|
||||
<span class="texttt">x<sub>1</sub>y<sub>2</sub></span>).
|
||||
<p>
|
||||
</li>
|
||||
|
||||
<li>
|
||||
<em>C</em> is a (possibly empty) list of Prolog goals that will instantiate the logical variables that appear in <em>F</em>, that is, the successful substitutions for the goals in <em>C</em> will be the valid values for the logical variables. This allows the constraint to be defined as any relation (set of tuples) over the logical variables.
|
||||
</li>
|
||||
</ul>
|
||||
|
||||
<IMG style="display:block; margin:auto" src="sprinkler.png" alt="Sprinkler Network">
|
||||
|
||||
<p>Towards a better understanding of the language, next we show the PFL representation for the sprinkler network found in the above figure.</p>
|
||||
|
||||
<div class="pflcode">
|
||||
<pre >
|
||||
:- use_module(library(pfl)).
|
||||
|
||||
bayes cloudy ; cloudy_table ; [].
|
||||
|
||||
bayes sprinkler, cloudy ; sprinkler_table ; [].
|
||||
|
||||
bayes rain, cloudy ; rain_table ; [].
|
||||
|
||||
bayes wet_grass, sprinkler, rain ; wet_grass_table ; [].
|
||||
|
||||
cloudy_table(
|
||||
[ 0.5,
|
||||
0.5 ]).
|
||||
|
||||
sprinkler_table(
|
||||
[ 0.1, 0.5,
|
||||
0.9, 0.5 ]).
|
||||
|
||||
rain_table(
|
||||
[ 0.8, 0.2,
|
||||
0.2, 0.8 ]).
|
||||
|
||||
wet_grass_table(
|
||||
[ 0.99, 0.9, 0.9, 0.0,
|
||||
0.01, 0.1, 0.1, 1.0 ]).
|
||||
</pre>
|
||||
</div>
|
||||
|
||||
<p>In the example, we started by loading the PFL library, then we have defined one factor for each node, and finally we have specified the probabilities for each conditional probability table.</p>
|
||||
|
||||
<p>Notice that this network is fully grounded, as all constraints are empty. Next we present the PFL representation for a well-known markov logic network - the social network model. For convenience, the two main weighted formulas of this model are shown below.</p>
|
||||
|
||||
<div class="pflcode">
|
||||
<pre>
|
||||
1.5 : Smokes(x) => Cancer(x)
|
||||
1.1 : Smokes(x) ^ Friends(x,y) => Smokes(y)
|
||||
</pre>
|
||||
</div>
|
||||
|
||||
<p>Next, we show the PFL representation for this model.</p>
|
||||
|
||||
<div class="pflcode">
|
||||
<pre>
|
||||
:- use_module(library(pfl)).
|
||||
|
||||
person(anna).
|
||||
person(bob).
|
||||
|
||||
markov smokes(X), cancer(X) ;
|
||||
[4.482, 4.482, 1.0, 4.482] ;
|
||||
[person(X)].
|
||||
|
||||
markov friends(X,Y), smokes(X), smokes(Y) ;
|
||||
[3.004, 3.004, 3.004, 3.004, 3.004, 1.0, 1.0, 3.004] ;
|
||||
[person(X), person(Y)].
|
||||
</pre>
|
||||
</div>
|
||||
|
||||
<p>Notice that we have defined the world to be consisted of only two persons, <span class=texttt>anna</span> and <span class=texttt>bob</span>. We can easily add as many persons as we want by inserting in the program a fact like <span class=texttt>person @ 10.</span> . This would automatically create ten persons named <span class=texttt>p1</span>, <span class=texttt>p2</span>, ..., <span class=texttt>p10</span>.</p>
|
||||
|
||||
<p>Unlike other fist-order probabilistic languages, in PFL the logical variables that appear in the terms are not directly typed, and they will be only constrained by the goals that appears in the constraint of the parfactor. This allows the logical variables to be constrained to any relation (set of tuples), and not only pairwise (in)equalities. For instance, the next example defines a network with three ground factors, each defined respectively over the random variables <span class=texttt>p(a,b)</span>, <span class=texttt>p(b,d)</span> and <span class=texttt>p(d,e)</span>.</p>
|
||||
|
||||
<div class="pflcode">
|
||||
<pre>
|
||||
constraint(a,b).
|
||||
constraint(b,d).
|
||||
constraint(d,e).
|
||||
|
||||
markov p(A,B); some_table; [constraint(A,B)].
|
||||
</pre>
|
||||
</div>
|
||||
|
||||
<p>We can easily add static evidence to PFL programs by inserting a fact with the same functor and arguments as the random variable, plus one extra argument with the observed state or value. For instance, suppose that we know that <span class=texttt>anna</span> and <span class=texttt>bob</span> are friends. We can add this knowledge to the program with the following fact: <span class=texttt>friends(anna,bob,t).</span> .</p>
|
||||
|
||||
<p>One last note for the domain of the random variables. By default, all terms instantiate boolean (<span class=texttt>t</span>/<span class=texttt>f</span>) random variables. It is possible to choose a different domain for a term by appending a list of its possible values or states. Next we present a self-explanatory example of how this can be done.</p>
|
||||
|
||||
<div class="pflcode">
|
||||
<pre>
|
||||
bayes professor_ability::[high, medium, low] ; [0.5, 0.4, 0.1].
|
||||
</pre>
|
||||
</div>
|
||||
|
||||
<p>More probabilistic models defined using PFL can be found in the examples directory.</p>
|
||||
|
||||
<p><a href="#top" style="font-size:15px">Back to the top</a></p>
|
||||
|
||||
|
||||
|
||||
<!--+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-->
|
||||
<!--+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-->
|
||||
<!--+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-->
|
||||
<h2 id=querying>Querying</h2>
|
||||
In this section we demonstrate how to use PFL to solve probabilistic queries. We will use the sprinkler network as example.
|
||||
|
||||
<p>Assuming that the current directory is the one where the examples are located, first we load the model with the following command.</p>
|
||||
|
||||
<p class=console>$ yap -l sprinkler.pfl</p>
|
||||
|
||||
<p>Let's suppose that we want to estimate the marginal probability for the <em>WetGrass</em> random variable. To do so, we call the following goal.</p>
|
||||
|
||||
<p class=console>?- wet_grass(X).</p>
|
||||
|
||||
<p>The output of this goal will show the marginal probability for each <em>WetGrass</em> possible state or value, that is, <span class=texttt>t</span> and <span class=texttt>f</span>. Notice that in PFL a random variable is identified by a term with the same functor and arguments plus one extra argument.</p>
|
||||
|
||||
<p>Now let's suppose that we want to estimate the probability for the same random variable, but this time we have evidence that it had rained in the day before. We can estimate this probability without resorting to static evidence with:</p>
|
||||
|
||||
<p class=console>?- wet_grass(X), rain(t).</p>
|
||||
|
||||
<p>PFL also supports calculating joint probability distributions. For instance, we can obtain the joint probability for <em>Sprinkler</em> and <em>Rain</em> with:</p>
|
||||
|
||||
<p class=console>?- sprinkler(X), rain(Y).</p>
|
||||
|
||||
<p><a href="#top" style="font-size:15px">Back to the top</a></p>
|
||||
|
||||
|
||||
|
||||
<!--+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-->
|
||||
<!--+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-->
|
||||
<!--+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-->
|
||||
<h2 id="options">Options</h2>
|
||||
PFL supports both ground and lifted inference methods. The inference algorithm can be chosen by calling <span class=texttt>set_solver/1</span>. The following are supported:
|
||||
|
||||
<ul>
|
||||
<li><span class=texttt>ve</span>, variable elimination (written in Prolog)</li>
|
||||
<li><span class=texttt>hve</span>, variable elimination (written in C++)</li>
|
||||
<li><span class=texttt>jt</span>, junction tree</li>
|
||||
<li><span class=texttt>bdd</span>, binary decision diagrams</li>
|
||||
<li><span class=texttt>bp</span>, belief propagation</li>
|
||||
<li><span class=texttt>cbp</span>, counting belief propagation</li>
|
||||
<li><span class=texttt>gibbs</span>, gibbs sampling</li>
|
||||
<li><span class=texttt>lve</span>, generalized counting first-order variable elimination (GC-FOVE)</li>
|
||||
<li><span class=texttt>lkc</span>, lifted first-order knowledge compilation</li>
|
||||
<li><span class=texttt>lbp</span>, lifted first-order belief propagation</li>
|
||||
</ul>
|
||||
|
||||
<p>For instance, if we want to use belief propagation to solve some probabilistic query, we need to call first:</p>
|
||||
|
||||
<p class=console>?- set_solver(bp).</p>
|
||||
|
||||
<p>It is possible to tweak some parameters of PFL through <span class=texttt>set_pfl_flag/2</span> predicate. The first argument is a option name that identifies the parameter that we want to tweak. The second argument is some possible value for this option. Next we explain the available options in detail.</p>
|
||||
|
||||
<h3><span class=texttt>verbosity</span></h3>
|
||||
This option controls the level of debugging information that will be shown.
|
||||
<ul>
|
||||
<li>Values: a positive integer (default is 0 - no debugging). The higher the number, the more information that will be shown.</li>
|
||||
<li>Affects: <span class=texttt>hve</span>, <span class=texttt>bp</span>, <span class=texttt>cbp</span>, <span class=texttt>lve</span>, <span class=texttt>lkc</span> and <span class=texttt>lbp</span>.</li>
|
||||
</ul>
|
||||
<p>
|
||||
For instance, we can view some basic debugging information by calling the following goal.
|
||||
<p class="console">?- set_pfl_flag(verbosity, 1).</p>
|
||||
|
||||
<h3><span class=texttt>use_logarithms</span></h3>
|
||||
This option controls whether the calculations performed during inference should be done in a logarithm domain or not.
|
||||
<ul>
|
||||
<li>Values: <span class=texttt>true</span> (default) or <span class=texttt>false</span>.</li>
|
||||
<li>Affects: <span class=texttt>hve</span>, <span class=texttt>bp</span>, <span class=texttt>cbp</span>, <span class=texttt>lve</span>, <span class=texttt>lkc</span> and <span class=texttt>lbp</span>.</li>
|
||||
</ul>
|
||||
|
||||
<h3><span class=texttt>hve_elim_heuristic</span></h3>
|
||||
This option allows to choose which elimination heuristic will be used by the <span class=texttt>hve</span>.
|
||||
<ul>
|
||||
<li>Values: <span class=texttt>sequential</span>, <span class=texttt>min_neighbors</span>, <span class=texttt>min_weight</span>, <span class=texttt>min_fill</span> and
|
||||
<br><span class=texttt>weighted_min_fill</span> (default).</li>
|
||||
<li>Affects: <span class=texttt>hve</span>.</li>
|
||||
</ul>
|
||||
<p>An explanation for each of these heuristics can be found in Daphne Koller's book <em>Probabilistic Graphical Models</em>.</p>
|
||||
|
||||
<h3><span class=texttt>bp_max_iter</span></h3>
|
||||
This option establishes a maximum number of iterations. One iteration consists in sending all possible messages.
|
||||
<ul>
|
||||
<li>Values: a positive integer (default is <span class=texttt>1000</span>).</li>
|
||||
<li>Affects: <span class=texttt>bp</span>, <span class=texttt>cbp</span> and <span class=texttt>lbp</span>.</li>
|
||||
</ul>
|
||||
|
||||
|
||||
|
||||
<h3><span class=texttt>bp_accuracy</span></h3>
|
||||
This option allows to control when the message passing should cease. Be the residual of one message the difference (according some metric) between the one sent in the current iteration and the one sent in the previous. If the highest residual is lesser than the given value, the message passing is stopped and the probabilities are calculated using the last messages that were sent.
|
||||
<ul>
|
||||
<li>Values: a float-point number (default is <span class=texttt>0.0001</span>).</li>
|
||||
<li>Affects: <span class=texttt>bp</span>, <span class=texttt>cbp</span> and <span class=texttt>lbp</span>.</li>
|
||||
</ul>
|
||||
|
||||
<h3><span class=texttt>bp_msg_schedule</span></h3>
|
||||
This option allows to control the message sending order.
|
||||
<ul>
|
||||
<li>Values:
|
||||
<ul>
|
||||
<li><span class=texttt>seq_fixed</span> (default), at each iteration, all messages are sent with the same order.<p></li>
|
||||
<li><span class=texttt>seq_random</span>, at each iteration, all messages are sent with a random order.<p></li>
|
||||
<li><span class=texttt>parallel</span>, at each iteration, all messages are calculated using only the values of the previous iteration.<p></li>
|
||||
<li><span class=texttt>max_residual</span>, the next message to be sent is the one with maximum residual (as explained in the paper <em>Residual Belief Propagation: Informed Scheduling for Asynchronous Message Passing</em>).</li>
|
||||
</ul>
|
||||
</li>
|
||||
<li>Affects: <span class=texttt>bp</span>, <span class=texttt>cbp</span> and <span class=texttt>lbp</span>.
|
||||
</li>
|
||||
</ul>
|
||||
|
||||
<h3><span class=texttt>export_libdai</span></h3>
|
||||
This option allows exporting the current model to the <a href="http://cs.ru.nl/~jorism/libDAI/doc/fileformats.html">libDAI</a> file format.
|
||||
<ul>
|
||||
<li>Values: <span class=texttt>true</span> or <span class=texttt>false</span> (default).</li>
|
||||
<li>Affects: <span class=texttt>hve</span>, <span class=texttt>bp</span>, and <span class=texttt>cbp</span>.</li>
|
||||
</ul>
|
||||
|
||||
<h3><span class=texttt>export_uai</span></h3>
|
||||
This option allows exporting the current model to the <a href="http://graphmod.ics.uci.edu/uai08/FileFormat">UAI</a> file format.
|
||||
<ul>
|
||||
<li>Values: <span class=texttt>true</span> or <span class=texttt>false</span> (default).</li>
|
||||
<li>Affects: <span class=texttt>hve</span>, <span class=texttt>bp</span>, and <span class=texttt>cbp</span>.</li>
|
||||
</ul>
|
||||
|
||||
<h3><span class=texttt>export_graphviz</span></h3>
|
||||
This option allows exporting the factor graph's structure into a format that can be parsed by <a href="http://www.graphviz.org/">Graphviz</a>.
|
||||
<ul>
|
||||
<li>Values: <span class=texttt>true</span> or <span class=texttt>false</span> (default).</li>
|
||||
<li>Affects: <span class=texttt>hve</span>, <span class=texttt>bp</span>, and <span class=texttt>cbp</span>.</li>
|
||||
</ul>
|
||||
|
||||
<h3><span class=texttt>print_fg</span></h3>
|
||||
This option allows to print a textual representation of the factor graph.
|
||||
<ul>
|
||||
<li>Values: <span class=texttt>true</span> or <span class=texttt>false</span> (default).</li>
|
||||
<li>Affects: <span class=texttt>hve</span>, <span class=texttt>bp</span>, and <span class=texttt>cbp</span>.</li>
|
||||
</ul>
|
||||
|
||||
<p><a href="#top" style="font-size:15px">Back to the top</a></p>
|
||||
|
||||
|
||||
|
||||
<!--+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-->
|
||||
<!--+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-->
|
||||
<!--+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-->
|
||||
<h2 id="learning">Learning</h2>
|
||||
PFL is capable to learn the parameters for bayesian networks, through an implementation of the expectation-maximization algorithm.
|
||||
|
||||
<p>Next we show an example of parameter learning for the sprinkler network.</p>
|
||||
|
||||
<div class="pflcode">
|
||||
<pre>
|
||||
:- [sprinkler.pfl].
|
||||
|
||||
:- use_module(library(clpbn/learning/em)).
|
||||
|
||||
data(t, t, t, t).
|
||||
data(_, t, _, t).
|
||||
data(t, t, f, f).
|
||||
data(t, t, f, t).
|
||||
data(t, _, _, t).
|
||||
data(t, f, t, t).
|
||||
data(t, t, f, t).
|
||||
data(t, _, f, f).
|
||||
data(t, t, f, f).
|
||||
data(f, f, t, t).
|
||||
|
||||
main :-
|
||||
findall(X, scan_data(X), L),
|
||||
em(L, 0.01, 10, CPTs, LogLik),
|
||||
writeln(LogLik:CPTs).
|
||||
|
||||
scan_data([cloudy(C), sprinkler(S), rain(R), wet_grass(W)]) :-
|
||||
data(C, S, R, W).
|
||||
</pre>
|
||||
</div>
|
||||
|
||||
<p>Parameter learning is done by calling the <span class=texttt>em/5</span> predicate. Its arguments are the following.</p>
|
||||
|
||||
<div style="text-align:center">
|
||||
<br>
|
||||
<span class=texttt>em(+Data, +MaxError, +MaxIters, -CPTs, -LogLik)</span>
|
||||
<br>
|
||||
<br>
|
||||
</div>
|
||||
|
||||
Where,
|
||||
<ul>
|
||||
<li><span class=texttt>Data</span> is a list of samples for the distribution that we want to estimate. Each sample is a list of either observed random variables or unobserved random variables (denoted when its state or value is not instantiated).</li>
|
||||
<li><span class=texttt>MaxError</span> is the maximum error allowed before stopping the EM loop.</li>
|
||||
<li><span class=texttt>MaxIters</span> is the maximum number of iterations for the EM loop.</li>
|
||||
<li><span class=texttt>CPTs</span> is a list with the estimated conditional probability tables.</li>
|
||||
<li><span class=texttt>LogLik</span> is the log-likelihood.</li>
|
||||
</ul>
|
||||
|
||||
<p>It is possible to choose the solver that will be used for the inference part during parameter learning with the <span class=texttt>set_em_solver/1</span> predicate (defaults to <span class=texttt>hve</span>). At the moment, only the following solvers support parameter learning: <span class=texttt>ve</span>, <span class=texttt>hve</span>, <span class=texttt>bdd</span>, <span class=texttt>bp</span> and <span class=texttt>cbp</span>.</p>
|
||||
|
||||
<p>Inside the <span class=texttt>learning</span> directory from the examples directory, one can find more examples of parameter learning.</p>
|
||||
|
||||
<p><a href="#top" style="font-size:15px">Back to the top</a></p>
|
||||
|
||||
|
||||
|
||||
<!--+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-->
|
||||
<!--+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-->
|
||||
<!--+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-->
|
||||
<h2 id="external_interface">External Interface</h2>
|
||||
This package also includes an external command for perform inference over probabilistic graphical models described in formats other than PFL. Currently two are support, the http://cs.ru.nl/ jorism/libDAI/doc/fileformats.htmllibDAI file format, and the http://graphmod.ics.uci.edu/uai08/FileFormatUAI08 file format.
|
||||
|
||||
<p>This command's name is <span class=texttt>hcli</span> and its usage is as follows.</p>
|
||||
|
||||
<p class=console>$ ./hcli [solver=hve|bp|cbp] [<OPTION>=<VALUE>]... <FILE>[<VAR>|<VAR>=<EVIDENCE>]... </p>
|
||||
|
||||
<p>Let's assume that the current directory is the one where the examples are located. We can perform inference in any supported model by passing the file name where the model is defined as argument. Next, we show how to load a model with <span class=texttt>hcli</span>.</p>
|
||||
|
||||
<p class=console>$ ./hcli burglary-alarm.uai</p>
|
||||
|
||||
<p>With the above command, the program will load the model and print the marginal probabilities for all defined random variables. We can view only the marginal probability for some variable with a identifier <em>X</em>, if we pass <em>X</em> as an extra argument following the file name. For instance, the following command will output only the marginal probability for the variable with identifier <em>0</em>.</p>
|
||||
|
||||
<p class=console>$ ./hcli burglary-alarm.uai 0</p>
|
||||
|
||||
<p>If we give more than one variable identifier as argument, the program will output the joint probability for all the passed variables.</p>
|
||||
|
||||
<p>Evidence can be given as a pair containing a variable identifier and its observed state (index), separated by a '=`. For instance, we can introduce knowledge that some variable with identifier <em>0</em> has evidence on its second state as follows.</p>
|
||||
|
||||
<p class=console>$ ./hcli burglary-alarm.uai 0=1</p>
|
||||
|
||||
<p>By default, all probability tasks are resolved using the <span class=texttt>hve</span> solver. It is possible to choose another solver using the <span class=texttt>solver</span> option as follows.</p>
|
||||
|
||||
<p class=console>$ ./hcli solver=bp burglary-alarm.uai</p>
|
||||
|
||||
<p>Notice that only the <span class=texttt>hve</span>, <span class=texttt>bp</span> and <span class=texttt>cbp</span> solvers can be used with <span class=texttt>hcli</span>.</p>
|
||||
|
||||
<p>The options that are available with the <span class=texttt>set_pfl_flag/2</span> predicate can be used in <span class=texttt>hcli</span> too. The syntax is a pair <span class=texttt><Option>=<Value></span> before the model's file name.</p>
|
||||
|
||||
<p><a href="#top" style="font-size:15px">Back to the top</a></p>
|
||||
|
||||
|
||||
|
||||
<!--+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-->
|
||||
<!--+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-->
|
||||
<!--+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-->
|
||||
<h2 id="papers">Papers</h2>
|
||||
<ul>
|
||||
<li><em>Evaluating Inference Algorithms for the Prolog Factor Language.</em></li>
|
||||
</ul>
|
||||
|
||||
<p><a href="#top" style="font-size:15px">Back to the top</a></p>
|
||||
|
||||
</div> <!-- end of mainbody -->
|
||||
|
||||
<div class="footer"></div>
|
||||
|
||||
</div> <!-- end of container -->
|
||||
|
||||
</body>
|
||||
</html>
|
118
packages/CLPBN/html/pfl.css
Normal file
118
packages/CLPBN/html/pfl.css
Normal file
@ -0,0 +1,118 @@
|
||||
html {
|
||||
background: gray
|
||||
}
|
||||
|
||||
body {
|
||||
margin: 10px;
|
||||
}
|
||||
|
||||
.container {
|
||||
width: 900px;
|
||||
margin: auto;
|
||||
padding: 10px 60px;
|
||||
background-color: #ffffff;
|
||||
box-shadow: 0px 2px 7px #292929;
|
||||
border-radius: 10px;
|
||||
}
|
||||
|
||||
.header,
|
||||
.mainbody,
|
||||
.footer {
|
||||
padding: 0px;
|
||||
}
|
||||
|
||||
.header {
|
||||
height: 120px;
|
||||
border-bottom: 1px solid #EEE;
|
||||
background-color: #ffffff;
|
||||
border-top-left-radius: 5px;
|
||||
border-top-right-radius: 5px;
|
||||
}
|
||||
|
||||
#leftcolumn {
|
||||
width: 550px;
|
||||
float: left
|
||||
}
|
||||
|
||||
#rightcolumn {
|
||||
float: right;
|
||||
text-align:right;
|
||||
font-family: sans-serif;
|
||||
font-size:15px;
|
||||
line-height:15px;
|
||||
padding-top:15px;
|
||||
padding-right:0px;
|
||||
margin-right:0px;
|
||||
}
|
||||
|
||||
div.name {
|
||||
color:#2798CA;
|
||||
}
|
||||
|
||||
div.email {
|
||||
color:gray;
|
||||
}
|
||||
|
||||
.mainbody {
|
||||
padding-top:10px;
|
||||
}
|
||||
|
||||
.footer {
|
||||
height: 5px;
|
||||
border-bottom-left-radius: 5px;
|
||||
border-bottom-right-radius: 5px;
|
||||
}
|
||||
|
||||
#menu ul {
|
||||
padding:0px;
|
||||
margin: 20px 0px;
|
||||
background-color:#EDEDED;
|
||||
list-style:none;
|
||||
text-align: center;
|
||||
}
|
||||
|
||||
#menu ul li {
|
||||
display: inline;
|
||||
}
|
||||
|
||||
#menu ul li a {
|
||||
padding: 2px 10px;
|
||||
display: inline-block;
|
||||
color: #333;
|
||||
text-decoration: none;
|
||||
border-bottom:3px solid #EDEDED;
|
||||
}
|
||||
|
||||
#menu ul li a:hover {
|
||||
text-decoration: underline;
|
||||
}
|
||||
|
||||
h2 {
|
||||
padding-top: 20px;
|
||||
padding-bottom: 0px;
|
||||
}
|
||||
|
||||
.pflcode {
|
||||
border-radius: 15px;
|
||||
max-width:550px;
|
||||
padding:5px 20px;
|
||||
margin:auto;
|
||||
margin-top:35px;
|
||||
margin-bottom:35px;
|
||||
background-color: #FAFAD3;
|
||||
font: normal 14px verdana;
|
||||
border: solid 1px #ddd;
|
||||
}
|
||||
|
||||
.console {
|
||||
font-family: monospace;
|
||||
color:white;
|
||||
background-color:black;
|
||||
padding: 5px;
|
||||
border-radius: 8px;
|
||||
}
|
||||
|
||||
.texttt {
|
||||
font-family: monospace
|
||||
}
|
||||
|
BIN
packages/CLPBN/html/sprinkler.png
Normal file
BIN
packages/CLPBN/html/sprinkler.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 31 KiB |
@ -38,7 +38,7 @@ CRACS \& INESC TEC, Faculty of Sciences, University of Porto
|
||||
\thispagestyle{empty}
|
||||
\vspace{5cm}
|
||||
\begin{center}
|
||||
\large Last revision: January 18, 2013
|
||||
\large Last revision: April 12, 2013
|
||||
\end{center}
|
||||
\newpage
|
||||
|
||||
@ -87,7 +87,7 @@ A first-order probabilistic graphical model is described using parametric factor
|
||||
|
||||
$$Type~~F~~;~~Phi~~;~~C.$$
|
||||
|
||||
, where
|
||||
Where,
|
||||
\begin{itemize}
|
||||
\item $Type$ refers the type of network over which the parfactor is defined. It can be \texttt{bayes} for directed networks, or \texttt{markov} for undirected ones.
|
||||
|
||||
@ -274,7 +274,7 @@ PFL also supports calculating joint probability distributions. For instance, we
|
||||
%------------------------------------------------------------------------------
|
||||
%------------------------------------------------------------------------------
|
||||
%------------------------------------------------------------------------------
|
||||
\section{Inference Options}
|
||||
\section{Options}
|
||||
PFL supports both ground and lifted inference methods. The inference algorithm can be chosen by calling \texttt{set\_solver/1}. The following are supported:
|
||||
\begin{itemize}
|
||||
\item \texttt{ve}, variable elimination (written in Prolog)
|
||||
@ -397,7 +397,7 @@ This option allows to print a textual representation of the factor graph.
|
||||
%------------------------------------------------------------------------------
|
||||
%------------------------------------------------------------------------------
|
||||
%------------------------------------------------------------------------------
|
||||
\section{Parameter Learning}
|
||||
\section{Learning}
|
||||
PFL is capable to learn the parameters for bayesian networks, through an implementation of the expectation-maximization algorithm.
|
||||
|
||||
Next we show an example of parameter learning for the sprinkler network.
|
||||
@ -429,7 +429,9 @@ scan_data([cloudy(C), sprinkler(S), rain(R), wet_grass(W)]) :-
|
||||
|
||||
Parameter learning is done by calling the \texttt{em/5} predicate. Its arguments are the following.
|
||||
|
||||
\begin{center}
|
||||
\texttt{em(+Data, +MaxError, +MaxIters, -CPTs, -LogLik)}
|
||||
\end{center}
|
||||
|
||||
Where,
|
||||
\begin{itemize}
|
||||
@ -489,9 +491,9 @@ The options that are available with the \texttt{set\_pfl\_flag/2} predicate can
|
||||
%------------------------------------------------------------------------------
|
||||
%------------------------------------------------------------------------------
|
||||
%------------------------------------------------------------------------------
|
||||
\section{Further Information}
|
||||
Please check the paper \textit{Evaluating Inference Algorithms for the Prolog Factor Language} for further information.
|
||||
|
||||
Any question? Don't hesitate to contact us!
|
||||
\section{Papers}
|
||||
\begin{itemize}
|
||||
\item \textit{Evaluating Inference Algorithms for the Prolog Factor Language}.
|
||||
\end{itemize}
|
||||
|
||||
\end{document}
|
||||
|
Reference in New Issue
Block a user