This commit is contained in:
Vitor Santos Costa 2013-04-16 00:18:07 +01:00
commit eeb53aef71
69 changed files with 5787 additions and 3602 deletions

View File

@ -0,0 +1,137 @@
10
2
0 1
2 2
4
0 1.02
1 0.87
2 0.88
3 0.45
4
1 2 3 4
2 2 3 3
36
0 0.11
1 1.11
2 0.41
3 0.12
4 0.1
5 0.17
6 1.21
7 1.1
8 0.11
9 0.41
10 0.8
11 0.71
12 0.14
13 0.24
14 0.54
15 1.4
16 0.23
17 0.24
18 0.65
19 0.05
20 0.32
21 0.12
22 0.99
23 0.69
24 0.29
25 1.29
26 0.15
27 1.24
28 0.42
29 0.124
30 0.67
31 0.078
32 0.14
33 0.55
34 0.45
35 0.1
3
2 5 6
2 2 3
12
0 0.15
1 0.55
2 2.21
3 5.71
4 0.44
5 0.14
6 0.5
7 1.75
8 1.29
9 3.29
10 0.36
11 1.56
2
7 2
4 2
8
0 0.11
1 0.59
2 0.15
3 0.124
4 0.41
5 2.11
6 1.06
7 0.929
1
3
3
3
0 0.1
1 0.58
2 0.74
1
4
3
3
0 3.2
1 0.28
2 1.24
2
8 4
2 3
6
0 0.19
1 3.1
2 0.49
3 1.5
4 2.1
5 2.8
1
5
2
2
0 0.74
1 0.14
1
6
3
3
0 0.032
1 0.028
2 0.24
2
9 7
2 4
8
0 0.61
1 0.61
2 1.4
3 0.24
4 0.09
5 0.19
6 1.4
7 0.6

View File

@ -0,0 +1,79 @@
:- use_module(library(pfl)).
%:- set_solver(ve).
%:- set_solver(hve).
%:- set_solver(jt).
%:- set_solver(bdd).
%:- set_solver(bp).
%:- set_solver(cbp).
%:- set_solver(gibbs).
%:- set_solver(lve).
%:- set_solver(lkc).
%:- set_solver(lbp).
/*
v01 v02
\ /
\ /
\ /
v03 v04 v05
/ \ | / \
/ \ | / \
/ \ | / \
v06 v07 v08
| |
| |
| |
v09 v10
*/
markov v01::[a,b] ; table1 ; [].
markov v02::[a,b,c] ; table2 ; [].
markov v03::[a,b], v01, v02 ; table3 ; [].
markov v04::[a,b,c] ; table4 ; [].
markov v05::[a,b,c] ; table5 ; [].
markov v06::[a,b,c,d], v03 ; table6 ; [].
markov v07::[a,b], v03, v04, v05 ; table7 ; [].
markov v08::[a,b], v05 ; table8 ; [].
markov v09::[a,b], v06 ; table9 ; [].
markov v10::[a,b], v07 ; table10 ; [].
table1([ 0.74, 0.14 ]).
table2([ 0.032, 0.028, 0.24 ]).
table3([
0.15, 0.44, 1.29, 2.21, 0.5, 0.36,
0.55, 0.14, 3.29, 5.71, 1.75, 1.56
]).
table4([ 0.1, 0.58, 0.74 ]).
table5([ 3.2, 0.28, 1.24 ]).
table6([ 0.11, 0.41, 0.59, 2.11, 0.15, 1.06, 0.124, 0.929 ]).
table7([
0.11, 0.14, 0.29, 0.1, 0.23, 0.42, 0.11, 0.32, 0.14,
0.41, 0.54, 0.15, 1.21, 0.65, 0.67, 0.8, 0.99, 0.45,
1.11, 0.24, 1.29, 0.17, 0.24, 0.124, 0.41, 0.12, 0.55,
0.12, 1.4, 1.24, 1.1, 0.05, 0.078, 0.71, 0.69, 0.1
]).
table8([ 0.19, 0.49, 2.1, 3.1, 1.5, 2.8 ]).
table9([ 0.61, 1.4, 0.09, 1.4, 0.61, 0.24, 0.19, 0.6 ]).
table10([ 1.02, 0.88, 0.87, 0.45 ]).

View File

@ -3,6 +3,25 @@
#include "BayesBall.h" #include "BayesBall.h"
namespace Horus {
BayesBall::BayesBall (FactorGraph& fg)
: fg_(fg) , dag_(fg.getStructure())
{
dag_.clear();
}
FactorGraph*
BayesBall::getMinimalFactorGraph (FactorGraph& fg, VarIds vids)
{
BayesBall bb (fg);
return bb.getMinimalFactorGraph (vids);
}
FactorGraph* FactorGraph*
BayesBall::getMinimalFactorGraph (const VarIds& queryIds) BayesBall::getMinimalFactorGraph (const VarIds& queryIds)
{ {
@ -19,22 +38,22 @@ BayesBall::getMinimalFactorGraph (const VarIds& queryIds)
BBNode* n = sch.node; BBNode* n = sch.node;
n->setAsVisited(); n->setAsVisited();
if (n->hasEvidence() == false && sch.visitedFromChild) { if (n->hasEvidence() == false && sch.visitedFromChild) {
if (n->isMarkedOnTop() == false) { if (n->isMarkedAbove() == false) {
n->markOnTop(); n->markAbove();
scheduleParents (n, scheduling); scheduleParents (n, scheduling);
} }
if (n->isMarkedOnBottom() == false) { if (n->isMarkedBelow() == false) {
n->markOnBottom(); n->markBelow();
scheduleChilds (n, scheduling); scheduleChilds (n, scheduling);
} }
} }
if (sch.visitedFromParent) { if (sch.visitedFromParent) {
if (n->hasEvidence() && n->isMarkedOnTop() == false) { if (n->hasEvidence() && n->isMarkedAbove() == false) {
n->markOnTop(); n->markAbove();
scheduleParents (n, scheduling); scheduleParents (n, scheduling);
} }
if (n->hasEvidence() == false && n->isMarkedOnBottom() == false) { if (n->hasEvidence() == false && n->isMarkedBelow() == false) {
n->markOnBottom(); n->markBelow();
scheduleChilds (n, scheduling); scheduleChilds (n, scheduling);
} }
} }
@ -55,7 +74,7 @@ BayesBall::constructGraph (FactorGraph* fg) const
for (size_t i = 0; i < facNodes.size(); i++) { for (size_t i = 0; i < facNodes.size(); i++) {
const BBNode* n = dag_.getNode ( const BBNode* n = dag_.getNode (
facNodes[i]->factor().argument (0)); facNodes[i]->factor().argument (0));
if (n->isMarkedOnTop()) { if (n->isMarkedAbove()) {
fg->addFactor (facNodes[i]->factor()); fg->addFactor (facNodes[i]->factor());
} else if (n->hasEvidence() && n->isVisited()) { } else if (n->hasEvidence() && n->isVisited()) {
VarIds varIds = { facNodes[i]->factor().argument (0) }; VarIds varIds = { facNodes[i]->factor().argument (0) };
@ -76,3 +95,5 @@ BayesBall::constructGraph (FactorGraph* fg) const
} }
} }
} // namespace Horus

View File

@ -1,5 +1,5 @@
#ifndef HORUS_BAYESBALL_H #ifndef YAP_PACKAGES_CLPBN_HORUS_BAYESBALL_H_
#define HORUS_BAYESBALL_H #define YAP_PACKAGES_CLPBN_HORUS_BAYESBALL_H_
#include <vector> #include <vector>
#include <queue> #include <queue>
@ -9,11 +9,19 @@
#include "BayesBallGraph.h" #include "BayesBallGraph.h"
#include "Horus.h" #include "Horus.h"
using namespace std;
namespace Horus {
struct ScheduleInfo class BayesBall {
{ public:
BayesBall (FactorGraph& fg);
FactorGraph* getMinimalFactorGraph (const VarIds&);
static FactorGraph* getMinimalFactorGraph (FactorGraph& fg, VarIds vids);
private:
struct ScheduleInfo {
ScheduleInfo (BBNode* n, bool vfp, bool vfc) ScheduleInfo (BBNode* n, bool vfp, bool vfc)
: node(n), visitedFromParent(vfp), visitedFromChild(vfc) { } : node(n), visitedFromParent(vfp), visitedFromChild(vfc) { }
@ -22,28 +30,7 @@ struct ScheduleInfo
bool visitedFromChild; bool visitedFromChild;
}; };
typedef std::queue<ScheduleInfo, std::list<ScheduleInfo>> Scheduling;
typedef queue<ScheduleInfo, list<ScheduleInfo>> Scheduling;
class BayesBall
{
public:
BayesBall (FactorGraph& fg)
: fg_(fg) , dag_(fg.getStructure())
{
dag_.clear();
}
FactorGraph* getMinimalFactorGraph (const VarIds&);
static FactorGraph* getMinimalFactorGraph (FactorGraph& fg, VarIds vids)
{
BayesBall bb (fg);
return bb.getMinimalFactorGraph (vids);
}
private:
void constructGraph (FactorGraph* fg) const; void constructGraph (FactorGraph* fg) const;
@ -52,7 +39,6 @@ class BayesBall
void scheduleChilds (const BBNode* n, Scheduling& sch) const; void scheduleChilds (const BBNode* n, Scheduling& sch) const;
FactorGraph& fg_; FactorGraph& fg_;
BayesBallGraph& dag_; BayesBallGraph& dag_;
}; };
@ -61,8 +47,8 @@ class BayesBall
inline void inline void
BayesBall::scheduleParents (const BBNode* n, Scheduling& sch) const BayesBall::scheduleParents (const BBNode* n, Scheduling& sch) const
{ {
const vector<BBNode*>& ps = n->parents(); const std::vector<BBNode*>& ps = n->parents();
for (vector<BBNode*>::const_iterator it = ps.begin(); for (std::vector<BBNode*>::const_iterator it = ps.begin();
it != ps.end(); ++it) { it != ps.end(); ++it) {
sch.push (ScheduleInfo (*it, false, true)); sch.push (ScheduleInfo (*it, false, true));
} }
@ -73,12 +59,14 @@ BayesBall::scheduleParents (const BBNode* n, Scheduling& sch) const
inline void inline void
BayesBall::scheduleChilds (const BBNode* n, Scheduling& sch) const BayesBall::scheduleChilds (const BBNode* n, Scheduling& sch) const
{ {
const vector<BBNode*>& cs = n->childs(); const std::vector<BBNode*>& cs = n->childs();
for (vector<BBNode*>::const_iterator it = cs.begin(); for (std::vector<BBNode*>::const_iterator it = cs.begin();
it != cs.end(); ++it) { it != cs.end(); ++it) {
sch.push (ScheduleInfo (*it, true, false)); sch.push (ScheduleInfo (*it, true, false));
} }
} }
#endif // HORUS_BAYESBALL_H } // namespace Horus
#endif // YAP_PACKAGES_CLPBN_HORUS_BAYESBALL_H_

View File

@ -1,14 +1,15 @@
#include <cstdlib>
#include <cassert> #include <cassert>
#include <iostream> #include <iostream>
#include <sstream>
#include <fstream> #include <fstream>
#include <sstream>
#include "BayesBallGraph.h" #include "BayesBallGraph.h"
#include "Util.h" #include "Util.h"
namespace Horus {
void void
BayesBallGraph::addNode (BBNode* n) BayesBallGraph::addNode (BBNode* n)
{ {
@ -22,8 +23,8 @@ BayesBallGraph::addNode (BBNode* n)
void void
BayesBallGraph::addEdge (VarId vid1, VarId vid2) BayesBallGraph::addEdge (VarId vid1, VarId vid2)
{ {
unordered_map<VarId, BBNode*>::iterator it1; std::unordered_map<VarId, BBNode*>::iterator it1;
unordered_map<VarId, BBNode*>::iterator it2; std::unordered_map<VarId, BBNode*>::iterator it2;
it1 = varMap_.find (vid1); it1 = varMap_.find (vid1);
it2 = varMap_.find (vid2); it2 = varMap_.find (vid2);
assert (it1 != varMap_.end()); assert (it1 != varMap_.end());
@ -37,7 +38,7 @@ BayesBallGraph::addEdge (VarId vid1, VarId vid2)
const BBNode* const BBNode*
BayesBallGraph::getNode (VarId vid) const BayesBallGraph::getNode (VarId vid) const
{ {
unordered_map<VarId, BBNode*>::const_iterator it; std::unordered_map<VarId, BBNode*>::const_iterator it;
it = varMap_.find (vid); it = varMap_.find (vid);
return it != varMap_.end() ? it->second : 0; return it != varMap_.end() ? it->second : 0;
} }
@ -47,7 +48,7 @@ BayesBallGraph::getNode (VarId vid) const
BBNode* BBNode*
BayesBallGraph::getNode (VarId vid) BayesBallGraph::getNode (VarId vid)
{ {
unordered_map<VarId, BBNode*>::const_iterator it; std::unordered_map<VarId, BBNode*>::const_iterator it;
it = varMap_.find (vid); it = varMap_.find (vid);
return it != varMap_.end() ? it->second : 0; return it != varMap_.end() ? it->second : 0;
} }
@ -55,7 +56,7 @@ BayesBallGraph::getNode (VarId vid)
void void
BayesBallGraph::setIndexes (void) BayesBallGraph::setIndexes()
{ {
for (size_t i = 0; i < nodes_.size(); i++) { for (size_t i = 0; i < nodes_.size(); i++) {
nodes_[i]->setIndex (i); nodes_[i]->setIndex (i);
@ -65,7 +66,7 @@ BayesBallGraph::setIndexes (void)
void void
BayesBallGraph::clear (void) BayesBallGraph::clear()
{ {
for (size_t i = 0; i < nodes_.size(); i++) { for (size_t i = 0; i < nodes_.size(); i++) {
nodes_[i]->clear(); nodes_[i]->clear();
@ -77,13 +78,14 @@ BayesBallGraph::clear (void)
void void
BayesBallGraph::exportToGraphViz (const char* fileName) BayesBallGraph::exportToGraphViz (const char* fileName)
{ {
ofstream out (fileName); std::ofstream out (fileName);
if (!out.is_open()) { if (!out.is_open()) {
cerr << "Error: couldn't open file '" << fileName << "'." ; std::cerr << "Error: couldn't open file '" << fileName << "'." ;
std::cerr << std::endl;
return; return;
} }
out << "digraph {" << endl; out << "digraph {" << std::endl;
out << "ranksep=1" << endl; out << "ranksep=1" << std::endl;
for (size_t i = 0; i < nodes_.size(); i++) { for (size_t i = 0; i < nodes_.size(); i++) {
out << nodes_[i]->varId() ; out << nodes_[i]->varId() ;
out << " [" ; out << " [" ;
@ -91,16 +93,18 @@ BayesBallGraph::exportToGraphViz (const char* fileName)
if (nodes_[i]->hasEvidence()) { if (nodes_[i]->hasEvidence()) {
out << ",style=filled, fillcolor=yellow" ; out << ",style=filled, fillcolor=yellow" ;
} }
out << "]" << endl; out << "]" << std::endl;
} }
for (size_t i = 0; i < nodes_.size(); i++) { for (size_t i = 0; i < nodes_.size(); i++) {
const vector<BBNode*>& childs = nodes_[i]->childs(); const std::vector<BBNode*>& childs = nodes_[i]->childs();
for (size_t j = 0; j < childs.size(); j++) { for (size_t j = 0; j < childs.size(); j++) {
out << nodes_[i]->varId() << " -> " << childs[j]->varId(); out << nodes_[i]->varId() << " -> " << childs[j]->varId();
out << " [style=bold]" << endl ; out << " [style=bold]" << std::endl;
} }
} }
out << "}" << endl; out << "}" << std::endl;
out.close(); out.close();
} }
} // namespace Horus

View File

@ -1,5 +1,5 @@
#ifndef HORUS_BAYESBALLGRAPH_H #ifndef YAP_PACKAGES_CLPBN_HORUS_BAYESBALLGRAPH_H_
#define HORUS_BAYESBALLGRAPH_H #define YAP_PACKAGES_CLPBN_HORUS_BAYESBALLGRAPH_H_
#include <vector> #include <vector>
#include <unordered_map> #include <unordered_map>
@ -7,54 +7,55 @@
#include "Var.h" #include "Var.h"
#include "Horus.h" #include "Horus.h"
using namespace std;
class BBNode : public Var namespace Horus {
{
class BBNode : public Var {
public: public:
BBNode (Var* v) : Var (v), visited_(false), BBNode (Var* v) : Var (v), visited_(false),
markedOnTop_(false), markedOnBottom_(false) { } markedAbove_(false), markedBelow_(false) { }
const vector<BBNode*>& childs (void) const { return childs_; } const std::vector<BBNode*>& childs() const { return childs_; }
vector<BBNode*>& childs (void) { return childs_; } std::vector<BBNode*>& childs() { return childs_; }
const vector<BBNode*>& parents (void) const { return parents_; } const std::vector<BBNode*>& parents() const { return parents_; }
vector<BBNode*>& parents (void) { return parents_; } std::vector<BBNode*>& parents() { return parents_; }
void addParent (BBNode* p) { parents_.push_back (p); } void addParent (BBNode* p) { parents_.push_back (p); }
void addChild (BBNode* c) { childs_.push_back (c); } void addChild (BBNode* c) { childs_.push_back (c); }
bool isVisited (void) const { return visited_; } bool isVisited() const { return visited_; }
void setAsVisited (void) { visited_ = true; } void setAsVisited() { visited_ = true; }
bool isMarkedOnTop (void) const { return markedOnTop_; } bool isMarkedAbove() const { return markedAbove_; }
void markOnTop (void) { markedOnTop_ = true; } void markAbove() { markedAbove_ = true; }
bool isMarkedOnBottom (void) const { return markedOnBottom_; } bool isMarkedBelow() const { return markedBelow_; }
void markOnBottom (void) { markedOnBottom_ = true; } void markBelow() { markedBelow_ = true; }
void clear (void) { visited_ = markedOnTop_ = markedOnBottom_ = false; } void clear() { visited_ = markedAbove_ = markedBelow_ = false; }
private: private:
bool visited_; bool visited_;
bool markedOnTop_; bool markedAbove_;
bool markedOnBottom_; bool markedBelow_;
vector<BBNode*> childs_; std::vector<BBNode*> childs_;
vector<BBNode*> parents_; std::vector<BBNode*> parents_;
}; };
class BayesBallGraph class BayesBallGraph {
{
public: public:
BayesBallGraph (void) { } BayesBallGraph() { }
bool empty() const { return nodes_.empty(); }
void addNode (BBNode* n); void addNode (BBNode* n);
@ -64,19 +65,18 @@ class BayesBallGraph
BBNode* getNode (VarId vid); BBNode* getNode (VarId vid);
bool empty (void) const { return nodes_.empty(); } void setIndexes();
void setIndexes (void); void clear();
void clear (void);
void exportToGraphViz (const char*); void exportToGraphViz (const char*);
private: private:
vector<BBNode*> nodes_; std::vector<BBNode*> nodes_;
std::unordered_map<VarId, BBNode*> varMap_;
unordered_map<VarId, BBNode*> varMap_;
}; };
#endif // HORUS_BAYESBALLGRAPH_H } // namespace Horus
#endif // YAP_PACKAGES_CLPBN_HORUS_BAYESBALLGRAPH_H_

View File

@ -1,37 +1,39 @@
#include <cassert> #include <cassert>
#include <algorithm> #include <algorithm>
#include <iostream> #include <iostream>
#include <iomanip>
#include <sstream>
#include "BeliefProp.h" #include "BeliefProp.h"
#include "Indexer.h" #include "Indexer.h"
#include "Horus.h" #include "Horus.h"
namespace Horus {
double BeliefProp::accuracy_ = 0.0001; double BeliefProp::accuracy_ = 0.0001;
unsigned BeliefProp::maxIter_ = 1000; unsigned BeliefProp::maxIter_ = 1000;
MsgSchedule BeliefProp::schedule_ = MsgSchedule::SEQ_FIXED;
BeliefProp::MsgSchedule BeliefProp::schedule_ =
MsgSchedule::seqFixedSch;
BeliefProp::BeliefProp (const FactorGraph& fg) : GroundSolver (fg)
BeliefProp::BeliefProp (const FactorGraph& fg)
: GroundSolver (fg), nIters_(0), runned_(false)
{ {
runned_ = false;
} }
BeliefProp::~BeliefProp (void) BeliefProp::~BeliefProp()
{ {
for (size_t i = 0; i < varsI_.size(); i++) {
delete varsI_[i];
}
for (size_t i = 0; i < facsI_.size(); i++) {
delete facsI_[i];
}
for (size_t i = 0; i < links_.size(); i++) { for (size_t i = 0; i < links_.size(); i++) {
delete links_[i]; delete links_[i];
} }
links_.clear();
} }
@ -48,22 +50,22 @@ BeliefProp::solveQuery (VarIds queryVids)
void void
BeliefProp::printSolverFlags (void) const BeliefProp::printSolverFlags() const
{ {
stringstream ss; std::stringstream ss;
ss << "belief propagation [" ; ss << "belief propagation [" ;
ss << "bp_msg_schedule=" ; ss << "bp_msg_schedule=" ;
switch (schedule_) { switch (schedule_) {
case MsgSchedule::SEQ_FIXED: ss << "seq_fixed"; break; case MsgSchedule::seqFixedSch: ss << "seq_fixed"; break;
case MsgSchedule::SEQ_RANDOM: ss << "seq_random"; break; case MsgSchedule::seqRandomSch: ss << "seq_random"; break;
case MsgSchedule::PARALLEL: ss << "parallel"; break; case MsgSchedule::parallelSch: ss << "parallel"; break;
case MsgSchedule::MAX_RESIDUAL: ss << "max_residual"; break; case MsgSchedule::maxResidualSch: ss << "max_residual"; break;
} }
ss << ",bp_max_iter=" << Util::toString (maxIter_); ss << ",bp_max_iter=" << Util::toString (maxIter_);
ss << ",bp_accuracy=" << Util::toString (accuracy_); ss << ",bp_accuracy=" << Util::toString (accuracy_);
ss << ",log_domain=" << Util::toString (Globals::logDomain); ss << ",log_domain=" << Util::toString (Globals::logDomain);
ss << "]" ; ss << "]" ;
cout << ss.str() << endl; std::cout << ss.str() << std::endl;
} }
@ -82,7 +84,7 @@ BeliefProp::getPosterioriOf (VarId vid)
probs[var->getEvidence()] = LogAware::withEvidence(); probs[var->getEvidence()] = LogAware::withEvidence();
} else { } else {
probs.resize (var->range(), LogAware::multIdenty()); probs.resize (var->range(), LogAware::multIdenty());
const BpLinks& links = ninf(var)->getLinks(); const BpLinks& links = getLinks (var);
if (Globals::logDomain) { if (Globals::logDomain) {
for (size_t i = 0; i < links.size(); i++) { for (size_t i = 0; i < links.size(); i++) {
probs += links[i]->message(); probs += links[i]->message();
@ -133,7 +135,7 @@ BeliefProp::getFactorJoint (
runSolver(); runSolver();
} }
Factor res (fn->factor()); Factor res (fn->factor());
const BpLinks& links = ninf(fn)->getLinks(); const BpLinks& links = getLinks( fn);
for (size_t i = 0; i < links.size(); i++) { for (size_t i = 0; i < links.size(); i++) {
Factor msg ({links[i]->varNode()->varId()}, Factor msg ({links[i]->varNode()->varId()},
{links[i]->varNode()->range()}, {links[i]->varNode()->range()},
@ -152,26 +154,119 @@ BeliefProp::getFactorJoint (
BeliefProp::BpLink::BpLink (FacNode* fn, VarNode* vn)
{
fac_ = fn;
var_ = vn;
v1_.resize (vn->range(), LogAware::log (1.0 / vn->range()));
v2_.resize (vn->range(), LogAware::log (1.0 / vn->range()));
currMsg_ = &v1_;
nextMsg_ = &v2_;
residual_ = 0.0;
}
void void
BeliefProp::runSolver (void) BeliefProp::BpLink::clearResidual()
{
residual_ = 0.0;
}
void
BeliefProp::BpLink::updateResidual()
{
residual_ = LogAware::getMaxNorm (v1_, v2_);
}
void
BeliefProp::BpLink::updateMessage()
{
swap (currMsg_, nextMsg_);
}
std::string
BeliefProp::BpLink::toString() const
{
std::stringstream ss;
ss << fac_->getLabel();
ss << " -- " ;
ss << var_->label();
return ss.str();
}
void
BeliefProp::calculateAndUpdateMessage (BpLink* link, bool calcResidual)
{
if (Globals::verbosity > 2) {
std::cout << "calculating & updating " << link->toString();
std::cout << std::endl;
}
calcFactorToVarMsg (link);
if (calcResidual) {
link->updateResidual();
}
link->updateMessage();
}
void
BeliefProp::calculateMessage (BpLink* link, bool calcResidual)
{
if (Globals::verbosity > 2) {
std::cout << "calculating " << link->toString();
std::cout << std::endl;
}
calcFactorToVarMsg (link);
if (calcResidual) {
link->updateResidual();
}
}
void
BeliefProp::updateMessage (BpLink* link)
{
link->updateMessage();
if (Globals::verbosity > 2) {
std::cout << "updating " << link->toString();
std::cout << std::endl;
}
}
void
BeliefProp::runSolver()
{ {
initializeSolver(); initializeSolver();
nIters_ = 0; nIters_ = 0;
while (!converged() && nIters_ < maxIter_) { while (!converged() && nIters_ < maxIter_) {
nIters_ ++; nIters_ ++;
if (Globals::verbosity > 1) { if (Globals::verbosity > 1) {
Util::printHeader (string ("Iteration ") + Util::toString (nIters_)); Util::printHeader (std::string ("Iteration ")
+ Util::toString (nIters_));
} }
switch (schedule_) { switch (schedule_) {
case MsgSchedule::SEQ_RANDOM: case MsgSchedule::seqRandomSch:
std::random_shuffle (links_.begin(), links_.end()); std::random_shuffle (links_.begin(), links_.end());
// no break // no break
case MsgSchedule::SEQ_FIXED: case MsgSchedule::seqFixedSch:
for (size_t i = 0; i < links_.size(); i++) { for (size_t i = 0; i < links_.size(); i++) {
calculateAndUpdateMessage (links_[i]); calculateAndUpdateMessage (links_[i]);
} }
break; break;
case MsgSchedule::PARALLEL: case MsgSchedule::parallelSch:
for (size_t i = 0; i < links_.size(); i++) { for (size_t i = 0; i < links_.size(); i++) {
calculateMessage (links_[i]); calculateMessage (links_[i]);
} }
@ -179,20 +274,21 @@ BeliefProp::runSolver (void)
updateMessage(links_[i]); updateMessage(links_[i]);
} }
break; break;
case MsgSchedule::MAX_RESIDUAL: case MsgSchedule::maxResidualSch:
maxResidualSchedule(); maxResidualSchedule();
break; break;
} }
} }
if (Globals::verbosity > 0) { if (Globals::verbosity > 0) {
if (nIters_ < maxIter_) { if (nIters_ < maxIter_) {
cout << "Belief propagation converged in " ; std::cout << "Belief propagation converged in " ;
cout << nIters_ << " iterations" << endl; std::cout << nIters_ << " iterations" << std::endl;
} else { } else {
cout << "The maximum number of iterations was hit, terminating..." ; std::cout << "The maximum number of iterations was hit," ;
cout << endl; std::cout << " terminating..." ;
std::cout << std::endl;
} }
cout << endl; std::cout << std::endl;
} }
runned_ = true; runned_ = true;
} }
@ -200,7 +296,7 @@ BeliefProp::runSolver (void)
void void
BeliefProp::createLinks (void) BeliefProp::createLinks()
{ {
const FacNodes& facNodes = fg.facNodes(); const FacNodes& facNodes = fg.facNodes();
for (size_t i = 0; i < facNodes.size(); i++) { for (size_t i = 0; i < facNodes.size(); i++) {
@ -214,7 +310,7 @@ BeliefProp::createLinks (void)
void void
BeliefProp::maxResidualSchedule (void) BeliefProp::maxResidualSchedule()
{ {
if (nIters_ == 1) { if (nIters_ == 1) {
for (size_t i = 0; i < links_.size(); i++) { for (size_t i = 0; i < links_.size(); i++) {
@ -227,11 +323,13 @@ BeliefProp::maxResidualSchedule (void)
for (size_t c = 0; c < links_.size(); c++) { for (size_t c = 0; c < links_.size(); c++) {
if (Globals::verbosity > 1) { if (Globals::verbosity > 1) {
cout << "current residuals:" << endl; std::cout << "current residuals:" << std::endl;
for (SortedOrder::iterator it = sortedOrder_.begin(); for (SortedOrder::iterator it = sortedOrder_.begin();
it != sortedOrder_.end(); ++it) { it != sortedOrder_.end(); ++it) {
cout << " " << setw (30) << left << (*it)->toString(); std::cout << " " << std::setw (30) << std::left;
cout << "residual = " << (*it)->residual() << endl; std::cout << (*it)->toString();
std::cout << "residual = " << (*it)->residual();
std::cout << std::endl;
} }
} }
@ -249,7 +347,7 @@ BeliefProp::maxResidualSchedule (void)
const FacNodes& factorNeighbors = link->varNode()->neighbors(); const FacNodes& factorNeighbors = link->varNode()->neighbors();
for (size_t i = 0; i < factorNeighbors.size(); i++) { for (size_t i = 0; i < factorNeighbors.size(); i++) {
if (factorNeighbors[i] != link->facNode()) { if (factorNeighbors[i] != link->facNode()) {
const BpLinks& links = ninf(factorNeighbors[i])->getLinks(); const BpLinks& links = getLinks (factorNeighbors[i]);
for (size_t j = 0; j < links.size(); j++) { for (size_t j = 0; j < links.size(); j++) {
if (links[j]->varNode() != link->varNode()) { if (links[j]->varNode() != link->varNode()) {
calculateMessage (links[j]); calculateMessage (links[j]);
@ -273,7 +371,7 @@ BeliefProp::calcFactorToVarMsg (BpLink* link)
{ {
FacNode* src = link->facNode(); FacNode* src = link->facNode();
const VarNode* dst = link->varNode(); const VarNode* dst = link->varNode();
const BpLinks& links = ninf(src)->getLinks(); const BpLinks& links = getLinks (src);
// calculate the product of messages that were sent // calculate the product of messages that were sent
// to factor `src', except from var `dst' // to factor `src', except from var `dst'
unsigned reps = 1; unsigned reps = 1;
@ -282,14 +380,14 @@ BeliefProp::calcFactorToVarMsg (BpLink* link)
if (Globals::logDomain) { if (Globals::logDomain) {
for (size_t i = links.size(); i-- > 0; ) { for (size_t i = links.size(); i-- > 0; ) {
if (links[i]->varNode() != dst) { if (links[i]->varNode() != dst) {
if (Constants::SHOW_BP_CALCS) { if (Constants::showBpCalcs) {
cout << " message from " << links[i]->varNode()->label(); std::cout << " message from " << links[i]->varNode()->label();
cout << ": " ; std::cout << ": " ;
} }
Util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]), Util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]),
reps, std::plus<double>()); reps, std::plus<double>());
if (Constants::SHOW_BP_CALCS) { if (Constants::showBpCalcs) {
cout << endl; std::cout << std::endl;
} }
} }
reps *= links[i]->varNode()->range(); reps *= links[i]->varNode()->range();
@ -297,14 +395,14 @@ BeliefProp::calcFactorToVarMsg (BpLink* link)
} else { } else {
for (size_t i = links.size(); i-- > 0; ) { for (size_t i = links.size(); i-- > 0; ) {
if (links[i]->varNode() != dst) { if (links[i]->varNode() != dst) {
if (Constants::SHOW_BP_CALCS) { if (Constants::showBpCalcs) {
cout << " message from " << links[i]->varNode()->label(); std::cout << " message from " << links[i]->varNode()->label();
cout << ": " ; std::cout << ": " ;
} }
Util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]), Util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]),
reps, std::multiplies<double>()); reps, std::multiplies<double>());
if (Constants::SHOW_BP_CALCS) { if (Constants::showBpCalcs) {
cout << endl; std::cout << std::endl;
} }
} }
reps *= links[i]->varNode()->range(); reps *= links[i]->varNode()->range();
@ -313,27 +411,28 @@ BeliefProp::calcFactorToVarMsg (BpLink* link)
Factor result (src->factor().arguments(), Factor result (src->factor().arguments(),
src->factor().ranges(), msgProduct); src->factor().ranges(), msgProduct);
result.multiply (src->factor()); result.multiply (src->factor());
if (Constants::SHOW_BP_CALCS) { if (Constants::showBpCalcs) {
cout << " message product: " << msgProduct << endl; std::cout << " message product: " << msgProduct << std::endl;
cout << " original factor: " << src->factor().params() << endl; std::cout << " original factor: " << src->factor().params();
cout << " factor product: " << result.params() << endl; std::cout << std::endl;
std::cout << " factor product: " << result.params() << std::endl;
} }
result.sumOutAllExcept (dst->varId()); result.sumOutAllExcept (dst->varId());
if (Constants::SHOW_BP_CALCS) { if (Constants::showBpCalcs) {
cout << " marginalized: " << result.params() << endl; std::cout << " marginalized: " << result.params() << std::endl;
} }
link->nextMessage() = result.params(); link->nextMessage() = result.params();
LogAware::normalize (link->nextMessage()); LogAware::normalize (link->nextMessage());
if (Constants::SHOW_BP_CALCS) { if (Constants::showBpCalcs) {
cout << " curr msg: " << link->message() << endl; std::cout << " curr msg: " << link->message() << std::endl;
cout << " next msg: " << link->nextMessage() << endl; std::cout << " next msg: " << link->nextMessage() << std::endl;
} }
} }
Params Params
BeliefProp::getVarToFactorMsg (const BpLink* link) const BeliefProp::getVarToFactorMsg (const BpLink* link)
{ {
const VarNode* src = link->varNode(); const VarNode* src = link->varNode();
Params msg; Params msg;
@ -343,18 +442,18 @@ BeliefProp::getVarToFactorMsg (const BpLink* link) const
} else { } else {
msg.resize (src->range(), LogAware::one()); msg.resize (src->range(), LogAware::one());
} }
if (Constants::SHOW_BP_CALCS) { if (Constants::showBpCalcs) {
cout << msg; std::cout << msg;
} }
BpLinks::const_iterator it; BpLinks::const_iterator it;
const BpLinks& links = ninf (src)->getLinks(); const BpLinks& links = getLinks (src);
if (Globals::logDomain) { if (Globals::logDomain) {
for (it = links.begin(); it != links.end(); ++it) { for (it = links.begin(); it != links.end(); ++it) {
if (*it != link) { if (*it != link) {
msg += (*it)->message(); msg += (*it)->message();
} }
if (Constants::SHOW_BP_CALCS) { if (Constants::showBpCalcs) {
cout << " x " << (*it)->message(); std::cout << " x " << (*it)->message();
} }
} }
} else { } else {
@ -362,13 +461,13 @@ BeliefProp::getVarToFactorMsg (const BpLink* link) const
if (*it != link) { if (*it != link) {
msg *= (*it)->message(); msg *= (*it)->message();
} }
if (Constants::SHOW_BP_CALCS) { if (Constants::showBpCalcs) {
cout << " x " << (*it)->message(); std::cout << " x " << (*it)->message();
} }
} }
} }
if (Constants::SHOW_BP_CALCS) { if (Constants::showBpCalcs) {
cout << " = " << msg; std::cout << " = " << msg;
} }
return msg; return msg;
} }
@ -379,37 +478,37 @@ Params
BeliefProp::getJointByConditioning (const VarIds& jointVarIds) const BeliefProp::getJointByConditioning (const VarIds& jointVarIds) const
{ {
return GroundSolver::getJointByConditioning ( return GroundSolver::getJointByConditioning (
GroundSolverType::BP, fg, jointVarIds); GroundSolverType::bpSolver, fg, jointVarIds);
} }
void void
BeliefProp::initializeSolver (void) BeliefProp::initializeSolver()
{ {
const VarNodes& varNodes = fg.varNodes(); const VarNodes& varNodes = fg.varNodes();
varsI_.reserve (varNodes.size()); varsLinks_.reserve (varNodes.size());
for (size_t i = 0; i < varNodes.size(); i++) { for (size_t i = 0; i < varNodes.size(); i++) {
varsI_.push_back (new SPNodeInfo()); varsLinks_.push_back (BpLinks());
} }
const FacNodes& facNodes = fg.facNodes(); const FacNodes& facNodes = fg.facNodes();
facsI_.reserve (facNodes.size()); facsLinks_.reserve (facNodes.size());
for (size_t i = 0; i < facNodes.size(); i++) { for (size_t i = 0; i < facNodes.size(); i++) {
facsI_.push_back (new SPNodeInfo()); facsLinks_.push_back (BpLinks());
} }
createLinks(); createLinks();
for (size_t i = 0; i < links_.size(); i++) { for (size_t i = 0; i < links_.size(); i++) {
FacNode* src = links_[i]->facNode(); FacNode* src = links_[i]->facNode();
VarNode* dst = links_[i]->varNode(); VarNode* dst = links_[i]->varNode();
ninf (dst)->addBpLink (links_[i]); getLinks (dst).push_back (links_[i]);
ninf (src)->addBpLink (links_[i]); getLinks (src).push_back (links_[i]);
} }
} }
bool bool
BeliefProp::converged (void) BeliefProp::converged()
{ {
if (links_.empty()) { if (links_.empty()) {
return true; return true;
@ -418,16 +517,16 @@ BeliefProp::converged (void)
return false; return false;
} }
if (Globals::verbosity > 2) { if (Globals::verbosity > 2) {
cout << endl; std::cout << std::endl;
} }
if (nIters_ == 1) { if (nIters_ == 1) {
if (Globals::verbosity > 1) { if (Globals::verbosity > 1) {
cout << "no residuals" << endl << endl; std::cout << "no residuals" << std::endl << std::endl;
} }
return false; return false;
} }
bool converged = true; bool converged = true;
if (schedule_ == MsgSchedule::MAX_RESIDUAL) { if (schedule_ == MsgSchedule::maxResidualSch) {
double maxResidual = (*(sortedOrder_.begin()))->residual(); double maxResidual = (*(sortedOrder_.begin()))->residual();
if (maxResidual > accuracy_) { if (maxResidual > accuracy_) {
converged = false; converged = false;
@ -438,7 +537,8 @@ BeliefProp::converged (void)
for (size_t i = 0; i < links_.size(); i++) { for (size_t i = 0; i < links_.size(); i++) {
double residual = links_[i]->residual(); double residual = links_[i]->residual();
if (Globals::verbosity > 1) { if (Globals::verbosity > 1) {
cout << links_[i]->toString() + " residual = " << residual << endl; std::cout << links_[i]->toString() + " residual = " << residual;
std::cout << std::endl;
} }
if (residual > accuracy_) { if (residual > accuracy_) {
converged = false; converged = false;
@ -448,7 +548,7 @@ BeliefProp::converged (void)
} }
} }
if (Globals::verbosity > 1) { if (Globals::verbosity > 1) {
cout << endl; std::cout << std::endl;
} }
} }
return converged; return converged;
@ -457,8 +557,10 @@ BeliefProp::converged (void)
void void
BeliefProp::printLinkInformation (void) const BeliefProp::printLinkInformation() const
{ {
using std::cout;
using std::endl;
for (size_t i = 0; i < links_.size(); i++) { for (size_t i = 0; i < links_.size(); i++) {
BpLink* l = links_[i]; BpLink* l = links_[i];
cout << l->toString() << ":" << endl; cout << l->toString() << ":" << endl;
@ -470,3 +572,5 @@ BeliefProp::printLinkInformation (void) const
} }
} }
} // namespace Horus

View File

@ -1,72 +1,78 @@
#ifndef HORUS_BELIEFPROP_H #ifndef YAP_PACKAGES_CLPBN_HORUS_BELIEFPROP_H_
#define HORUS_BELIEFPROP_H #define YAP_PACKAGES_CLPBN_HORUS_BELIEFPROP_H_
#include <set>
#include <vector> #include <vector>
#include <set>
#include <sstream> #include <string>
#include "GroundSolver.h" #include "GroundSolver.h"
#include "FactorGraph.h" #include "FactorGraph.h"
using namespace std; namespace Horus {
class BeliefProp : public GroundSolver {
private:
class SPNodeInfo;
enum MsgSchedule { public:
SEQ_FIXED, enum class MsgSchedule {
SEQ_RANDOM, seqFixedSch,
PARALLEL, seqRandomSch,
MAX_RESIDUAL parallelSch,
maxResidualSch
}; };
BeliefProp (const FactorGraph&);
class BpLink virtual ~BeliefProp();
{
Params solveQuery (VarIds);
virtual void printSolverFlags() const;
virtual Params getPosterioriOf (VarId);
virtual Params getJointDistributionOf (const VarIds&);
Params getFactorJoint (FacNode* fn, const VarIds&);
static double accuracy() { return accuracy_; }
static void setAccuracy (double acc) { accuracy_ = acc; }
static unsigned maxIterations() { return maxIter_; }
static void setMaxIterations (unsigned mi) { maxIter_ = mi; }
static MsgSchedule msgSchedule() { return schedule_; }
static void setMsgSchedule (MsgSchedule sch) { schedule_ = sch; }
protected:
class BpLink {
public: public:
BpLink (FacNode* fn, VarNode* vn) BpLink (FacNode* fn, VarNode* vn);
{
fac_ = fn;
var_ = vn;
v1_.resize (vn->range(), LogAware::log (1.0 / vn->range()));
v2_.resize (vn->range(), LogAware::log (1.0 / vn->range()));
currMsg_ = &v1_;
nextMsg_ = &v2_;
residual_ = 0.0;
}
virtual ~BpLink (void) { }; virtual ~BpLink() { };
FacNode* facNode (void) const { return fac_; } FacNode* facNode() const { return fac_; }
VarNode* varNode (void) const { return var_; } VarNode* varNode() const { return var_; }
const Params& message (void) const { return *currMsg_; } const Params& message() const { return *currMsg_; }
Params& nextMessage (void) { return *nextMsg_; } Params& nextMessage() { return *nextMsg_; }
double residual (void) const { return residual_; } double residual() const { return residual_; }
void clearResidual (void) { residual_ = 0.0; } void clearResidual();
void updateResidual (void) void updateResidual();
{
residual_ = LogAware::getMaxNorm (v1_, v2_);
}
virtual void updateMessage (void) virtual void updateMessage();
{
swap (currMsg_, nextMsg_);
}
string toString (void) const std::string toString() const;
{
stringstream ss;
ss << fac_->getLabel();
ss << " -- " ;
ss << var_->label();
return ss.str();
}
protected: protected:
FacNode* fac_; FacNode* fac_;
@ -81,137 +87,78 @@ class BpLink
DISALLOW_COPY_AND_ASSIGN (BpLink); DISALLOW_COPY_AND_ASSIGN (BpLink);
}; };
typedef vector<BpLink*> BpLinks; struct CmpResidual {
bool operator() (const BpLink* l1, const BpLink* l2) {
return l1->residual() > l2->residual();
}};
typedef std::vector<BeliefProp::BpLink*> BpLinks;
typedef std::multiset<BpLink*, CmpResidual> SortedOrder;
typedef std::unordered_map<BpLink*, SortedOrder::iterator> BpLinkMap;
class SPNodeInfo BpLinks& getLinks (const VarNode* var);
{
public:
SPNodeInfo (void) { }
void addBpLink (BpLink* link) { links_.push_back (link); }
const BpLinks& getLinks (void) { return links_; }
private:
BpLinks links_;
DISALLOW_COPY_AND_ASSIGN (SPNodeInfo);
};
BpLinks& getLinks (const FacNode* fac);
class BeliefProp : public GroundSolver void calculateAndUpdateMessage (BpLink* link, bool calcResidual = true);
{
public:
BeliefProp (const FactorGraph&);
virtual ~BeliefProp (void); void calculateMessage (BpLink* link, bool calcResidual = true);
Params solveQuery (VarIds); void updateMessage (BpLink* link);
virtual void printSolverFlags (void) const; void runSolver();
virtual Params getPosterioriOf (VarId); virtual void createLinks();
virtual Params getJointDistributionOf (const VarIds&); virtual void maxResidualSchedule();
Params getFactorJoint (FacNode* fn, const VarIds&);
static double accuracy (void) { return accuracy_; }
static void setAccuracy (double acc) { accuracy_ = acc; }
static unsigned maxIterations (void) { return maxIter_; }
static void setMaxIterations (unsigned mi) { maxIter_ = mi; }
static MsgSchedule msgSchedule (void) { return schedule_; }
static void setMsgSchedule (MsgSchedule sch) { schedule_ = sch; }
protected:
SPNodeInfo* ninf (const VarNode* var) const
{
return varsI_[var->getIndex()];
}
SPNodeInfo* ninf (const FacNode* fac) const
{
return facsI_[fac->getIndex()];
}
void calculateAndUpdateMessage (BpLink* link, bool calcResidual = true)
{
if (Globals::verbosity > 2) {
cout << "calculating & updating " << link->toString() << endl;
}
calcFactorToVarMsg (link);
if (calcResidual) {
link->updateResidual();
}
link->updateMessage();
}
void calculateMessage (BpLink* link, bool calcResidual = true)
{
if (Globals::verbosity > 2) {
cout << "calculating " << link->toString() << endl;
}
calcFactorToVarMsg (link);
if (calcResidual) {
link->updateResidual();
}
}
void updateMessage (BpLink* link)
{
link->updateMessage();
if (Globals::verbosity > 2) {
cout << "updating " << link->toString() << endl;
}
}
struct CompareResidual
{
inline bool operator() (const BpLink* link1, const BpLink* link2)
{
return link1->residual() > link2->residual();
}
};
void runSolver (void);
virtual void createLinks (void);
virtual void maxResidualSchedule (void);
virtual void calcFactorToVarMsg (BpLink*); virtual void calcFactorToVarMsg (BpLink*);
virtual Params getVarToFactorMsg (const BpLink*) const; virtual Params getVarToFactorMsg (const BpLink*);
virtual Params getJointByConditioning (const VarIds&) const; virtual Params getJointByConditioning (const VarIds&) const;
BpLinks links_; BpLinks links_;
unsigned nIters_; unsigned nIters_;
vector<SPNodeInfo*> varsI_;
vector<SPNodeInfo*> facsI_;
bool runned_; bool runned_;
typedef multiset<BpLink*, CompareResidual> SortedOrder;
SortedOrder sortedOrder_; SortedOrder sortedOrder_;
typedef unordered_map<BpLink*, SortedOrder::iterator> BpLinkMap;
BpLinkMap linkMap_; BpLinkMap linkMap_;
static double accuracy_; static double accuracy_;
static unsigned maxIter_;
static MsgSchedule schedule_;
private: private:
void initializeSolver (void); void initializeSolver();
bool converged (void); bool converged();
virtual void printLinkInformation (void) const; virtual void printLinkInformation() const;
std::vector<BpLinks> varsLinks_;
std::vector<BpLinks> facsLinks_;
static unsigned maxIter_;
static MsgSchedule schedule_;
DISALLOW_COPY_AND_ASSIGN (BeliefProp); DISALLOW_COPY_AND_ASSIGN (BeliefProp);
}; };
#endif // HORUS_BELIEFPROP_H
inline BeliefProp::BpLinks&
BeliefProp::getLinks (const VarNode* var)
{
return varsLinks_[var->getIndex()];
}
inline BeliefProp::BpLinks&
BeliefProp::getLinks (const FacNode* fac)
{
return facsLinks_[fac->getIndex()];
}
} // namespace Horus
#endif // YAP_PACKAGES_CLPBN_HORUS_BELIEFPROP_H_

View File

@ -1,11 +1,88 @@
#include <queue> #include <queue>
#include <iostream>
#include <ostream>
#include <fstream> #include <fstream>
#include "ConstraintTree.h" #include "ConstraintTree.h"
#include "Util.h" #include "Util.h"
namespace Horus {
class CTNode {
public:
CTNode (const CTNode& n, const CTChilds& chs = CTChilds())
: symbol_(n.symbol()), childs_(chs), level_(n.level()) { }
CTNode (Symbol s, unsigned l, const CTChilds& chs = CTChilds())
: symbol_(s), childs_(chs), level_(l) { }
unsigned level() const { return level_; }
void setLevel (unsigned level) { level_ = level; }
Symbol symbol() const { return symbol_; }
void setSymbol (Symbol s) { symbol_ = s; }
CTChilds& childs() { return childs_; }
const CTChilds& childs() const { return childs_; }
size_t nrChilds() const { return childs_.size(); }
bool isRoot() const { return level_ == 0; }
bool isLeaf() const { return childs_.empty(); }
CTChilds::iterator findSymbol (Symbol symb);
void mergeSubtree (CTNode*, bool = true);
void removeChild (CTNode*);
void removeChilds();
void removeAndDeleteChild (CTNode*);
void removeAndDeleteAllChilds();
SymbolSet childSymbols() const;
static CTNode* copySubtree (const CTNode*);
static void deleteSubtree (CTNode*);
private:
void updateChildLevels (CTNode*, unsigned);
Symbol symbol_;
CTChilds childs_;
unsigned level_;
DISALLOW_ASSIGN (CTNode);
};
inline CTChilds::iterator
CTNode::findSymbol (Symbol symb)
{
CTNode tmp (symb, 0);
return childs_.find (&tmp);
}
inline bool
CmpSymbol::operator() (const CTNode* n1, const CTNode* n2) const
{
return n1->symbol() < n2->symbol();
}
void void
CTNode::mergeSubtree (CTNode* n, bool updateLevels) CTNode::mergeSubtree (CTNode* n, bool updateLevels)
{ {
@ -38,7 +115,7 @@ CTNode::removeChild (CTNode* child)
void void
CTNode::removeChilds (void) CTNode::removeChilds()
{ {
childs_.clear(); childs_.clear();
} }
@ -55,7 +132,7 @@ CTNode::removeAndDeleteChild (CTNode* child)
void void
CTNode::removeAndDeleteAllChilds (void) CTNode::removeAndDeleteAllChilds()
{ {
for (CTChilds::const_iterator chIt = childs_.begin(); for (CTChilds::const_iterator chIt = childs_.begin();
chIt != childs_.end(); ++ chIt) { chIt != childs_.end(); ++ chIt) {
@ -67,7 +144,7 @@ CTNode::removeAndDeleteAllChilds (void)
SymbolSet SymbolSet
CTNode::childSymbols (void) const CTNode::childSymbols() const
{ {
SymbolSet symbols; SymbolSet symbols;
for (CTChilds::const_iterator chIt = childs_.begin(); for (CTChilds::const_iterator chIt = childs_.begin();
@ -106,14 +183,14 @@ CTNode::copySubtree (const CTNode* root1)
return new CTNode (*root1); return new CTNode (*root1);
} }
CTNode* root2 = new CTNode (*root1); CTNode* root2 = new CTNode (*root1);
typedef pair<const CTNode*, CTNode*> StackPair; typedef std::pair<const CTNode*, CTNode*> StackPair;
vector<StackPair> stack = { StackPair (root1, root2) }; std::vector<StackPair> stack = { StackPair (root1, root2) };
while (stack.empty() == false) { while (stack.empty() == false) {
const CTNode* n1 = stack.back().first; const CTNode* n1 = stack.back().first;
CTNode* n2 = stack.back().second; CTNode* n2 = stack.back().second;
stack.pop_back(); stack.pop_back();
// cout << "n2 childs: " << n2->childs(); // std::cout << "n2 childs: " << n2->childs();
// cout << "n1 childs: " << n1->childs(); // std::cout << "n1 childs: " << n1->childs();
n2->childs().reserve (n1->nrChilds()); n2->childs().reserve (n1->nrChilds());
stack.reserve (n1->nrChilds()); stack.reserve (n1->nrChilds());
for (CTChilds::const_iterator chIt = n1->childs().begin(); for (CTChilds::const_iterator chIt = n1->childs().begin();
@ -144,7 +221,8 @@ CTNode::deleteSubtree (CTNode* n)
ostream& operator<< (ostream &out, const CTNode& n) std::ostream&
operator<< (std::ostream& out, const CTNode& n)
{ {
out << "(" << n.level() << ") " ; out << "(" << n.level() << ") " ;
out << n.symbol(); out << n.symbol();
@ -187,7 +265,8 @@ ConstraintTree::ConstraintTree (
ConstraintTree::ConstraintTree (vector<vector<string>> names) ConstraintTree::ConstraintTree (
std::vector<std::vector<std::string>> names)
{ {
assert (names.empty() == false); assert (names.empty() == false);
assert (names.front().empty() == false); assert (names.front().empty() == false);
@ -216,13 +295,33 @@ ConstraintTree::ConstraintTree (const ConstraintTree& ct)
ConstraintTree::~ConstraintTree (void) ConstraintTree::ConstraintTree (
const CTChilds& rootChilds,
const LogVars& logVars)
: root_(new CTNode (Symbol (0), unsigned (0), rootChilds)),
logVars_(logVars),
logVarSet_(logVars)
{
}
ConstraintTree::~ConstraintTree()
{ {
CTNode::deleteSubtree (root_); CTNode::deleteSubtree (root_);
} }
bool
ConstraintTree::empty() const
{
return root_->childs().empty();
}
void void
ConstraintTree::addTuple (const Tuple& tuple) ConstraintTree::addTuple (const Tuple& tuple)
{ {
@ -448,7 +547,7 @@ ConstraintTree::ConstraintTree::isSingleton (LogVar X)
LogVarSet LogVarSet
ConstraintTree::singletons (void) ConstraintTree::singletons()
{ {
LogVarSet singletons; LogVarSet singletons;
for (size_t i = 0; i < logVars_.size(); i++) { for (size_t i = 0; i < logVars_.size(); i++) {
@ -491,7 +590,7 @@ ConstraintTree::tupleSet (const LogVars& originalLvs)
getTuples (root_, Tuples(), stopLevel, tuples, CTNodes() = {}); getTuples (root_, Tuples(), stopLevel, tuples, CTNodes() = {});
if (originalLvs.size() != uniqueLvs.size()) { if (originalLvs.size() != uniqueLvs.size()) {
vector<size_t> indexes; std::vector<size_t> indexes;
indexes.reserve (originalLvs.size()); indexes.reserve (originalLvs.size());
for (size_t i = 0; i < originalLvs.size(); i++) { for (size_t i = 0; i < originalLvs.size(); i++) {
indexes.push_back (Util::indexOf (uniqueLvs, originalLvs[i])); indexes.push_back (Util::indexOf (uniqueLvs, originalLvs[i]));
@ -519,21 +618,22 @@ ConstraintTree::exportToGraphViz (
const char* fileName, const char* fileName,
bool showLogVars) const bool showLogVars) const
{ {
ofstream out (fileName); std::ofstream out (fileName);
if (!out.is_open()) { if (!out.is_open()) {
cerr << "Error: couldn't open file '" << fileName << "'." ; std::cerr << "Error: couldn't open file '" << fileName << "'." ;
std::cerr << std::endl;
return; return;
} }
out << "digraph {" << endl; out << "digraph {" << std::endl;
ConstraintTree copy (*this); ConstraintTree copy (*this);
copy.moveToTop (copy.logVarSet_.elements()); copy.moveToTop (copy.logVarSet_.elements());
CTNodes nodes = getNodesBelow (copy.root_); CTNodes nodes = getNodesBelow (copy.root_);
out << "\"" << copy.root_ << "\"" << " [label=\"R\"]" << endl; out << "\"" << copy.root_ << "\"" << " [label=\"R\"]" << std::endl;
for (CTNodes::const_iterator it = ++ nodes.begin(); for (CTNodes::const_iterator it = ++ nodes.begin();
it != nodes.end(); ++ it) { it != nodes.end(); ++ it) {
out << "\"" << *it << "\""; out << "\"" << *it << "\"";
out << " [label=\"" << **it << "\"]" ; out << " [label=\"" << **it << "\"]" ;
out << endl; out << std::endl;
} }
for (CTNodes::const_iterator it = nodes.begin(); for (CTNodes::const_iterator it = nodes.begin();
it != nodes.end(); ++ it) { it != nodes.end(); ++ it) {
@ -542,24 +642,24 @@ ConstraintTree::exportToGraphViz (
chIt != childs.end(); ++ chIt) { chIt != childs.end(); ++ chIt) {
out << "\"" << *it << "\"" ; out << "\"" << *it << "\"" ;
out << " -> " ; out << " -> " ;
out << "\"" << *chIt << "\"" << endl ; out << "\"" << *chIt << "\"" << std::endl ;
} }
} }
if (showLogVars) { if (showLogVars) {
out << "Root [label=\"\", shape=plaintext]" << endl; out << "Root [label=\"\", shape=plaintext]" << std::endl;
for (size_t i = 0; i < copy.logVars_.size(); i++) { for (size_t i = 0; i < copy.logVars_.size(); i++) {
out << copy.logVars_[i] << " [label=" ; out << copy.logVars_[i] << " [label=" ;
out << copy.logVars_[i] << ", " ; out << copy.logVars_[i] << ", " ;
out << "shape=plaintext, fontsize=14]" << endl; out << "shape=plaintext, fontsize=14]" << std::endl;
} }
out << "Root -> " << copy.logVars_[0]; out << "Root -> " << copy.logVars_[0];
out << " [style=invis]" << endl; out << " [style=invis]" << std::endl;
for (size_t i = 0; i < copy.logVars_.size() - 1; i++) { for (size_t i = 0; i < copy.logVars_.size() - 1; i++) {
out << copy.logVars_[i] << " -> " << copy.logVars_[i + 1]; out << copy.logVars_[i] << " -> " << copy.logVars_[i + 1];
out << " [style=invis]" << endl; out << " [style=invis]" << std::endl;
} }
} }
out << "}" << endl; out << "}" <<std::endl;
out.close(); out.close();
} }
@ -690,9 +790,9 @@ ConstraintTree::split (
split (root_, ct->root(), commChilds, exclChilds, stopLevel); split (root_, ct->root(), commChilds, exclChilds, stopLevel);
ConstraintTree* commCt = new ConstraintTree (commChilds, logVars_); ConstraintTree* commCt = new ConstraintTree (commChilds, logVars_);
ConstraintTree* exclCt = new ConstraintTree (exclChilds, logVars_); ConstraintTree* exclCt = new ConstraintTree (exclChilds, logVars_);
// cout << commCt->tupleSet() << " + " ; // std::cout << commCt->tupleSet() << " + " ;
// cout << exclCt->tupleSet() << " = " ; // std::cout << exclCt->tupleSet() << " = " ;
// cout << tupleSet() << endl; // std::cout << tupleSet() << std::endl;
assert ((commCt->tupleSet() | exclCt->tupleSet()) == tupleSet()); assert ((commCt->tupleSet() | exclCt->tupleSet()) == tupleSet());
assert ((exclCt->tupleSet (stopLevel) & ct->tupleSet (stopLevel)).empty()); assert ((exclCt->tupleSet (stopLevel) & ct->tupleSet (stopLevel)).empty());
return {commCt, exclCt}; return {commCt, exclCt};
@ -710,20 +810,20 @@ ConstraintTree::countNormalize (const LogVarSet& Ys)
} }
moveToTop (Zs.elements()); moveToTop (Zs.elements());
ConstraintTrees cts; ConstraintTrees cts;
unordered_map<unsigned, ConstraintTree*> countMap; std::unordered_map<unsigned, ConstraintTree*> countMap;
unsigned stopLevel = getLevel (Zs.back()); unsigned stopLevel = getLevel (Zs.back());
const CTChilds& childs = root_->childs(); const CTChilds& childs = root_->childs();
for (CTChilds::const_iterator chIt = childs.begin(); for (CTChilds::const_iterator chIt = childs.begin();
chIt != childs.end(); ++ chIt) { chIt != childs.end(); ++ chIt) {
const vector<pair<CTNode*, unsigned>>& res = const std::vector<std::pair<CTNode*, unsigned>>& res =
countNormalize (*chIt, stopLevel); countNormalize (*chIt, stopLevel);
for (size_t j = 0; j < res.size(); j++) { for (size_t j = 0; j < res.size(); j++) {
unordered_map<unsigned, ConstraintTree*>::iterator it std::unordered_map<unsigned, ConstraintTree*>::iterator it
= countMap.find (res[j].second); = countMap.find (res[j].second);
if (it == countMap.end()) { if (it == countMap.end()) {
ConstraintTree* newCt = new ConstraintTree (logVars_); ConstraintTree* newCt = new ConstraintTree (logVars_);
it = countMap.insert (make_pair (res[j].second, newCt)).first; it = countMap.insert (std::make_pair (res[j].second, newCt)).first;
cts.push_back (newCt); cts.push_back (newCt);
} }
it->second->root_->mergeSubtree (res[j].first); it->second->root_->mergeSubtree (res[j].first);
@ -743,31 +843,31 @@ ConstraintTree::jointCountNormalize (
LogVar X_new2) LogVar X_new2)
{ {
unsigned N = getConditionalCount (X); unsigned N = getConditionalCount (X);
// cout << "My tuples: " << tupleSet() << endl; // std::cout << "My tuples: " << tupleSet() << std::endl;
// cout << "CommCt tuples: " << commCt->tupleSet() << endl; // std::cout << "CommCt tuples: " << commCt->tupleSet() << std::endl;
// cout << "ExclCt tuples: " << exclCt->tupleSet() << endl; // std::cout << "ExclCt tuples: " << exclCt->tupleSet() << std::endl;
// cout << "Counted Lv: " << X << endl; // std::cout << "Counted Lv: " << X << std::endl;
// cout << "X_new1: " << X_new1 << endl; // std::cout << "X_new1: " << X_new1 << std::endl;
// cout << "X_new2: " << X_new2 << endl; // std::cout << "X_new2: " << X_new2 << std::endl;
// cout << "Original N: " << N << endl; // std::cout << "Original N: " << N << std::endl;
// cout << endl; // std::cout << endl;
ConstraintTrees normCts1 = commCt->countNormalize (X); ConstraintTrees normCts1 = commCt->countNormalize (X);
vector<unsigned> counts1 (normCts1.size()); std::vector<unsigned> counts1 (normCts1.size());
for (size_t i = 0; i < normCts1.size(); i++) { for (size_t i = 0; i < normCts1.size(); i++) {
counts1[i] = normCts1[i]->getConditionalCount (X); counts1[i] = normCts1[i]->getConditionalCount (X);
// cout << "normCts1[" << i << "] #" << counts1[i] ; // std::cout << "normCts1[" << i << "] #" << counts1[i] ;
// cout << " " << normCts1[i]->tupleSet() << endl; // std::cout << " " << normCts1[i]->tupleSet() << std::endl;
} }
ConstraintTrees normCts2 = exclCt->countNormalize (X); ConstraintTrees normCts2 = exclCt->countNormalize (X);
vector<unsigned> counts2 (normCts2.size()); std::vector<unsigned> counts2 (normCts2.size());
for (size_t i = 0; i < normCts2.size(); i++) { for (size_t i = 0; i < normCts2.size(); i++) {
counts2[i] = normCts2[i]->getConditionalCount (X); counts2[i] = normCts2[i]->getConditionalCount (X);
// cout << "normCts2[" << i << "] #" << counts2[i] ; // std::cout << "normCts2[" << i << "] #" << counts2[i] ;
// cout << " " << normCts2[i]->tupleSet() << endl; // std::cout << " " << normCts2[i]->tupleSet() << std::endl;
} }
// cout << endl; // std::cout << std::endl;
ConstraintTree* excl1 = 0; ConstraintTree* excl1 = 0;
for (size_t i = 0; i < normCts1.size(); i++) { for (size_t i = 0; i < normCts1.size(); i++) {
@ -775,7 +875,7 @@ ConstraintTree::jointCountNormalize (
excl1 = normCts1[i]; excl1 = normCts1[i];
normCts1.erase (normCts1.begin() + i); normCts1.erase (normCts1.begin() + i);
counts1.erase (counts1.begin() + i); counts1.erase (counts1.begin() + i);
// cout << "joint-count(" << N << ",0)" << endl; // std::cout << "joint-count(" << N << ",0)" << std::endl;
break; break;
} }
} }
@ -786,7 +886,7 @@ ConstraintTree::jointCountNormalize (
excl2 = normCts2[i]; excl2 = normCts2[i];
normCts2.erase (normCts2.begin() + i); normCts2.erase (normCts2.begin() + i);
counts2.erase (counts2.begin() + i); counts2.erase (counts2.begin() + i);
// cout << "joint-count(0," << N << ")" << endl; // std::cout << "joint-count(0," << N << ")" << std::endl;
break; break;
} }
} }
@ -794,8 +894,8 @@ ConstraintTree::jointCountNormalize (
for (size_t i = 0; i < normCts1.size(); i++) { for (size_t i = 0; i < normCts1.size(); i++) {
unsigned j; unsigned j;
for (j = 0; counts1[i] + counts2[j] != N; j++) ; for (j = 0; counts1[i] + counts2[j] != N; j++) ;
// cout << "joint-count(" << counts1[i] ; // std::cout << "joint-count(" << counts1[i] ;
// cout << "," << counts2[j] << ")" << endl; // std::cout << "," << counts2[j] << ")" << std::endl;
const CTChilds& childs = normCts2[j]->root_->childs(); const CTChilds& childs = normCts2[j]->root_->childs();
for (CTChilds::const_iterator chIt = childs.begin(); for (CTChilds::const_iterator chIt = childs.begin();
chIt != childs.end(); ++ chIt) { chIt != childs.end(); ++ chIt) {
@ -930,7 +1030,7 @@ CTNodes
ConstraintTree::getNodesBelow (CTNode* fromHere) const ConstraintTree::getNodesBelow (CTNode* fromHere) const
{ {
CTNodes nodes; CTNodes nodes;
queue<CTNode*> queue; std::queue<CTNode*> queue;
queue.push (fromHere); queue.push (fromHere);
while (queue.empty() == false) { while (queue.empty() == false) {
CTNode* node = queue.front(); CTNode* node = queue.front();
@ -1016,7 +1116,7 @@ ConstraintTree::swapLogVar (LogVar X)
{ {
size_t pos = Util::indexOf (logVars_, X); size_t pos = Util::indexOf (logVars_, X);
assert (pos != logVars_.size()); assert (pos != logVars_.size());
const CTNodes& nodes = getNodesAtLevel (pos); CTNodes nodes = getNodesAtLevel (pos);
for (CTNodes::const_iterator nodeIt = nodes.begin(); for (CTNodes::const_iterator nodeIt = nodes.begin();
nodeIt != nodes.end(); ++ nodeIt) { nodeIt != nodes.end(); ++ nodeIt) {
CTChilds childsCopy = (*nodeIt)->childs(); CTChilds childsCopy = (*nodeIt)->childs();
@ -1098,7 +1198,7 @@ ConstraintTree::getTuples (
unsigned unsigned
ConstraintTree::size (void) const ConstraintTree::size() const
{ {
return countTuples (root_); return countTuples (root_);
} }
@ -1114,26 +1214,26 @@ ConstraintTree::nrSymbols (LogVar X)
vector<pair<CTNode*, unsigned>> std::vector<std::pair<CTNode*, unsigned>>
ConstraintTree::countNormalize ( ConstraintTree::countNormalize (
const CTNode* n, const CTNode* n,
unsigned stopLevel) unsigned stopLevel)
{ {
if (n->level() == stopLevel) { if (n->level() == stopLevel) {
return vector<pair<CTNode*, unsigned>>() = { return std::vector<std::pair<CTNode*, unsigned>>() = {
make_pair (CTNode::copySubtree (n), countTuples (n)) std::make_pair (CTNode::copySubtree (n), countTuples (n))
}; };
} }
vector<pair<CTNode*, unsigned>> res; std::vector<std::pair<CTNode*, unsigned>> res;
const CTChilds& childs = n->childs(); const CTChilds& childs = n->childs();
for (CTChilds::const_iterator chIt = childs.begin(); for (CTChilds::const_iterator chIt = childs.begin();
chIt != childs.end(); ++ chIt) { chIt != childs.end(); ++ chIt) {
const vector<pair<CTNode*, unsigned>>& lowerRes = const std::vector<std::pair<CTNode*, unsigned>>& lowerRes =
countNormalize (*chIt, stopLevel); countNormalize (*chIt, stopLevel);
for (size_t j = 0; j < lowerRes.size(); j++) { for (size_t j = 0; j < lowerRes.size(); j++) {
CTNode* newNode = new CTNode (*n); CTNode* newNode = new CTNode (*n);
newNode->mergeSubtree (lowerRes[j].first); newNode->mergeSubtree (lowerRes[j].first);
res.push_back (make_pair (newNode, lowerRes[j].second)); res.push_back (std::make_pair (newNode, lowerRes[j].second));
} }
} }
return res; return res;
@ -1172,3 +1272,5 @@ ConstraintTree::split (
} }
} }
} // namespace Horus

View File

@ -1,104 +1,35 @@
#ifndef HORUS_CONSTRAINTTREE_H #ifndef YAP_PACKAGES_CLPBN_HORUS_CONSTRAINTTREE_H_
#define HORUS_CONSTRAINTTREE_H #define YAP_PACKAGES_CLPBN_HORUS_CONSTRAINTTREE_H_
#include <cassert> #include <cassert>
#include <algorithm>
#include <iostream> #include <vector>
#include <sstream> #include <algorithm>
#include <string>
#include "TinySet.h" #include "TinySet.h"
#include "LiftedUtils.h" #include "LiftedUtils.h"
using namespace std;
namespace Horus {
class CTNode; class CTNode;
typedef vector<CTNode*> CTNodes;
class ConstraintTree; class ConstraintTree;
typedef vector<ConstraintTree*> ConstraintTrees;
class CTNode typedef std::vector<CTNode*> CTNodes;
{ typedef std::vector<ConstraintTree*> ConstraintTrees;
public:
struct CompareSymbol
{ struct CmpSymbol {
bool operator() (const CTNode* n1, const CTNode* n2) const bool operator() (const CTNode* n1, const CTNode* n2) const;
{
return n1->symbol() < n2->symbol();
}
}; };
private:
typedef TinySet<CTNode*, CompareSymbol> CTChilds_;
public: typedef TinySet<CTNode*, CmpSymbol> CTChilds;
CTNode (const CTNode& n, const CTChilds_& chs = CTChilds_())
: symbol_(n.symbol()), childs_(chs), level_(n.level()) { }
CTNode (Symbol s, unsigned l, const CTChilds_& chs = CTChilds_())
: symbol_(s), childs_(chs), level_(l) { }
unsigned level (void) const { return level_; }
void setLevel (unsigned level) { level_ = level; }
Symbol symbol (void) const { return symbol_; }
void setSymbol (const Symbol s) { symbol_ = s; }
CTChilds_& childs (void) { return childs_; }
const CTChilds_& childs (void) const { return childs_; }
size_t nrChilds (void) const { return childs_.size(); }
bool isRoot (void) const { return level_ == 0; }
bool isLeaf (void) const { return childs_.empty(); }
CTChilds_::iterator findSymbol (Symbol symb)
{
CTNode tmp (symb, 0);
return childs_.find (&tmp);
}
void mergeSubtree (CTNode*, bool = true);
void removeChild (CTNode*);
void removeChilds (void);
void removeAndDeleteChild (CTNode*);
void removeAndDeleteAllChilds (void);
SymbolSet childSymbols (void) const;
static CTNode* copySubtree (const CTNode*);
static void deleteSubtree (CTNode*);
private:
void updateChildLevels (CTNode*, unsigned);
Symbol symbol_;
CTChilds_ childs_;
unsigned level_;
DISALLOW_ASSIGN (CTNode);
};
ostream& operator<< (ostream &out, const CTNode&);
typedef TinySet<CTNode*, CTNode::CompareSymbol> CTChilds; class ConstraintTree {
class ConstraintTree
{
public: public:
ConstraintTree (unsigned); ConstraintTree (unsigned);
@ -106,38 +37,23 @@ class ConstraintTree
ConstraintTree (const LogVars&, const Tuples&); ConstraintTree (const LogVars&, const Tuples&);
ConstraintTree (vector<vector<string>> names); ConstraintTree (std::vector<std::vector<std::string>> names);
ConstraintTree (const ConstraintTree&); ConstraintTree (const ConstraintTree&);
ConstraintTree (const CTChilds& rootChilds, const LogVars& logVars) ConstraintTree (const CTChilds& rootChilds, const LogVars& logVars);
: root_(new CTNode (0, 0, rootChilds)),
logVars_(logVars),
logVarSet_(logVars) { }
~ConstraintTree (void); ~ConstraintTree();
CTNode* root (void) const { return root_; } CTNode* root() const { return root_; }
bool empty (void) const { return root_->childs().empty(); } bool empty() const;
const LogVars& logVars (void) const const LogVars& logVars() const;
{
assert (LogVarSet (logVars_) == logVarSet_);
return logVars_;
}
const LogVarSet& logVarSet (void) const const LogVarSet& logVarSet() const;
{
assert (LogVarSet (logVars_) == logVarSet_);
return logVarSet_;
}
size_t nrLogVars (void) const size_t nrLogVars() const;
{
return logVars_.size();
assert (LogVarSet (logVars_) == logVarSet_);
}
void addTuple (const Tuple&); void addTuple (const Tuple&);
@ -163,13 +79,13 @@ class ConstraintTree
bool isSingleton (LogVar); bool isSingleton (LogVar);
LogVarSet singletons (void); LogVarSet singletons();
TupleSet tupleSet (unsigned = 0) const; TupleSet tupleSet (unsigned = 0) const;
TupleSet tupleSet (const LogVars&); TupleSet tupleSet (const LogVars&);
unsigned size (void) const; unsigned size() const;
unsigned nrSymbols (LogVar); unsigned nrSymbols (LogVar);
@ -218,11 +134,10 @@ class ConstraintTree
void getTuples (CTNode*, Tuples, unsigned, Tuples&, CTNodes&) const; void getTuples (CTNode*, Tuples, unsigned, Tuples&, CTNodes&) const;
vector<std::pair<CTNode*, unsigned>> countNormalize ( std::vector<std::pair<CTNode*, unsigned>> countNormalize (
const CTNode*, unsigned); const CTNode*, unsigned);
static void split ( static void split (CTNode*, CTNode*, CTChilds&, CTChilds&, unsigned);
CTNode*, CTNode*, CTChilds&, CTChilds&, unsigned);
CTNode* root_; CTNode* root_;
LogVars logVars_; LogVars logVars_;
@ -230,5 +145,33 @@ class ConstraintTree
}; };
#endif // HORUS_CONSTRAINTTREE_H
inline const LogVars&
ConstraintTree::logVars() const
{
assert (LogVarSet (logVars_) == logVarSet_);
return logVars_;
}
inline const LogVarSet&
ConstraintTree::logVarSet() const
{
assert (LogVarSet (logVars_) == logVarSet_);
return logVarSet_;
}
inline size_t
ConstraintTree::nrLogVars() const
{
assert (LogVarSet (logVars_) == logVarSet_);
return logVars_.size();
}
} // namespace Horus
#endif // YAP_PACKAGES_CLPBN_HORUS_CONSTRAINTTREE_H_

View File

@ -1,7 +1,62 @@
#include <cassert>
#include <iostream>
#include <sstream>
#include "CountingBp.h" #include "CountingBp.h"
#include "WeightedBp.h" #include "WeightedBp.h"
namespace Horus {
class VarCluster {
public:
VarCluster (const VarNodes& vs) : members_(vs) { }
const VarNode* first() const { return members_.front(); }
const VarNodes& members() const { return members_; }
VarNode* representative() const { return repr_; }
void setRepresentative (VarNode* vn) { repr_ = vn; }
private:
VarNodes members_;
VarNode* repr_;
DISALLOW_COPY_AND_ASSIGN (VarCluster);
};
class FacCluster {
private:
typedef std::vector<VarCluster*> VarClusters;
public:
FacCluster (const FacNodes& fcs, const VarClusters& vcs)
: members_(fcs), varClusters_(vcs) { }
const FacNode* first() const { return members_.front(); }
const FacNodes& members() const { return members_; }
FacNode* representative() const { return repr_; }
void setRepresentative (FacNode* fn) { repr_ = fn; }
VarClusters& varClusters() { return varClusters_; }
FacNodes members_;
FacNode* repr_;
VarClusters varClusters_;
DISALLOW_COPY_AND_ASSIGN (FacCluster);
};
bool CountingBp::fif_ = true; bool CountingBp::fif_ = true;
@ -17,7 +72,7 @@ CountingBp::CountingBp (const FactorGraph& fg)
CountingBp::~CountingBp (void) CountingBp::~CountingBp()
{ {
delete solver_; delete solver_;
delete compressedFg_; delete compressedFg_;
@ -32,23 +87,24 @@ CountingBp::~CountingBp (void)
void void
CountingBp::printSolverFlags (void) const CountingBp::printSolverFlags() const
{ {
stringstream ss; std::stringstream ss;
ss << "counting bp [" ; ss << "counting bp [" ;
ss << "bp_msg_schedule=" ; ss << "bp_msg_schedule=" ;
typedef WeightedBp::MsgSchedule MsgSchedule;
switch (WeightedBp::msgSchedule()) { switch (WeightedBp::msgSchedule()) {
case MsgSchedule::SEQ_FIXED: ss << "seq_fixed"; break; case MsgSchedule::seqFixedSch: ss << "seq_fixed"; break;
case MsgSchedule::SEQ_RANDOM: ss << "seq_random"; break; case MsgSchedule::seqRandomSch: ss << "seq_random"; break;
case MsgSchedule::PARALLEL: ss << "parallel"; break; case MsgSchedule::parallelSch: ss << "parallel"; break;
case MsgSchedule::MAX_RESIDUAL: ss << "max_residual"; break; case MsgSchedule::maxResidualSch: ss << "max_residual"; break;
} }
ss << ",bp_max_iter=" << WeightedBp::maxIterations(); ss << ",bp_max_iter=" << WeightedBp::maxIterations();
ss << ",bp_accuracy=" << WeightedBp::accuracy(); ss << ",bp_accuracy=" << WeightedBp::accuracy();
ss << ",log_domain=" << Util::toString (Globals::logDomain); ss << ",log_domain=" << Util::toString (Globals::logDomain);
ss << ",fif=" << Util::toString (CountingBp::fif_); ss << ",fif=" << Util::toString (CountingBp::fif_);
ss << "]" ; ss << "]" ;
cout << ss.str() << endl; std::cout << ss.str() << std::endl;
} }
@ -69,11 +125,10 @@ CountingBp::solveQuery (VarIds queryVids)
idx = i; idx = i;
break; break;
} }
cout << endl;
} }
if (idx == facNodes.size()) { if (idx == facNodes.size()) {
res = GroundSolver::getJointByConditioning ( res = GroundSolver::getJointByConditioning (
GroundSolverType::CBP, fg, queryVids); GroundSolverType::CbpSolver, fg, queryVids);
} else { } else {
VarIds reprArgs; VarIds reprArgs;
for (size_t i = 0; i < queryVids.size(); i++) { for (size_t i = 0; i < queryVids.size(); i++) {
@ -124,7 +179,7 @@ CountingBp::findIdenticalFactors()
void void
CountingBp::setInitialColors (void) CountingBp::setInitialColors()
{ {
varColors_.resize (fg.nrVarNodes()); varColors_.resize (fg.nrVarNodes());
facColors_.resize (fg.nrFacNodes()); facColors_.resize (fg.nrFacNodes());
@ -135,7 +190,7 @@ CountingBp::setInitialColors (void)
unsigned range = varNodes[i]->range(); unsigned range = varNodes[i]->range();
VarColorMap::iterator it = colorMap.find (range); VarColorMap::iterator it = colorMap.find (range);
if (it == colorMap.end()) { if (it == colorMap.end()) {
it = colorMap.insert (make_pair ( it = colorMap.insert (std::make_pair (
range, Colors (range + 1, -1))).first; range, Colors (range + 1, -1))).first;
} }
unsigned idx = varNodes[i]->hasEvidence() unsigned idx = varNodes[i]->hasEvidence()
@ -154,7 +209,8 @@ CountingBp::setInitialColors (void)
unsigned distId = facNodes[i]->factor().distId(); unsigned distId = facNodes[i]->factor().distId();
DistColorMap::iterator it = distColors.find (distId); DistColorMap::iterator it = distColors.find (distId);
if (it == distColors.end()) { if (it == distColors.end()) {
it = distColors.insert (make_pair (distId, getNewColor())).first; it = distColors.insert (std::make_pair (
distId, getNewColor())).first;
} }
setColor (facNodes[i], it->second); setColor (facNodes[i], it->second);
} }
@ -163,7 +219,7 @@ CountingBp::setInitialColors (void)
void void
CountingBp::createGroups (void) CountingBp::createGroups()
{ {
VarSignMap varGroups; VarSignMap varGroups;
FacSignMap facGroups; FacSignMap facGroups;
@ -179,10 +235,11 @@ CountingBp::createGroups (void)
size_t prevVarGroupsSize = varGroups.size(); size_t prevVarGroupsSize = varGroups.size();
varGroups.clear(); varGroups.clear();
for (size_t i = 0; i < varNodes.size(); i++) { for (size_t i = 0; i < varNodes.size(); i++) {
const VarSignature& signature = getSignature (varNodes[i]); VarSignature signature = getSignature (varNodes[i]);
VarSignMap::iterator it = varGroups.find (signature); VarSignMap::iterator it = varGroups.find (signature);
if (it == varGroups.end()) { if (it == varGroups.end()) {
it = varGroups.insert (make_pair (signature, VarNodes())).first; it = varGroups.insert (std::make_pair (
signature, VarNodes())).first;
} }
it->second.push_back (varNodes[i]); it->second.push_back (varNodes[i]);
} }
@ -199,10 +256,11 @@ CountingBp::createGroups (void)
facGroups.clear(); facGroups.clear();
// set a new color to the factors with the same signature // set a new color to the factors with the same signature
for (size_t i = 0; i < facNodes.size(); i++) { for (size_t i = 0; i < facNodes.size(); i++) {
const FacSignature& signature = getSignature (facNodes[i]); FacSignature signature = getSignature (facNodes[i]);
FacSignMap::iterator it = facGroups.find (signature); FacSignMap::iterator it = facGroups.find (signature);
if (it == facGroups.end()) { if (it == facGroups.end()) {
it = facGroups.insert (make_pair (signature, FacNodes())).first; it = facGroups.insert (std::make_pair (
signature, FacNodes())).first;
} }
it->second.push_back (facNodes[i]); it->second.push_back (facNodes[i]);
} }
@ -235,7 +293,8 @@ CountingBp::createClusters (
const VarNodes& groupVars = it->second; const VarNodes& groupVars = it->second;
VarCluster* vc = new VarCluster (groupVars); VarCluster* vc = new VarCluster (groupVars);
for (size_t i = 0; i < groupVars.size(); i++) { for (size_t i = 0; i < groupVars.size(); i++) {
varClusterMap_.insert (make_pair (groupVars[i]->varId(), vc)); varClusterMap_.insert (std::make_pair (
groupVars[i]->varId(), vc));
} }
varClusters_.push_back (vc); varClusters_.push_back (vc);
} }
@ -257,29 +316,29 @@ CountingBp::createClusters (
VarSignature CountingBp::VarSignature
CountingBp::getSignature (const VarNode* varNode) CountingBp::getSignature (const VarNode* varNode)
{ {
const FacNodes& neighs = varNode->neighbors();
VarSignature sign; VarSignature sign;
const FacNodes& neighs = varNode->neighbors();
sign.reserve (neighs.size() + 1); sign.reserve (neighs.size() + 1);
for (size_t i = 0; i < neighs.size(); i++) { for (size_t i = 0; i < neighs.size(); i++) {
sign.push_back (make_pair ( sign.push_back (std::make_pair (
getColor (neighs[i]), getColor (neighs[i]),
neighs[i]->factor().indexOf (varNode->varId()))); neighs[i]->factor().indexOf (varNode->varId())));
} }
std::sort (sign.begin(), sign.end()); std::sort (sign.begin(), sign.end());
sign.push_back (make_pair (getColor (varNode), 0)); sign.push_back (std::make_pair (getColor (varNode), 0));
return sign; return sign;
} }
FacSignature CountingBp::FacSignature
CountingBp::getSignature (const FacNode* facNode) CountingBp::getSignature (const FacNode* facNode)
{ {
const VarNodes& neighs = facNode->neighbors();
FacSignature sign; FacSignature sign;
const VarNodes& neighs = facNode->neighbors();
sign.reserve (neighs.size() + 1); sign.reserve (neighs.size() + 1);
for (size_t i = 0; i < neighs.size(); i++) { for (size_t i = 0; i < neighs.size(); i++) {
sign.push_back (getColor (neighs[i])); sign.push_back (getColor (neighs[i]));
@ -314,7 +373,7 @@ CountingBp::getRepresentative (FacNode* fn)
FactorGraph* FactorGraph*
CountingBp::getCompressedFactorGraph (void) CountingBp::getCompressedFactorGraph()
{ {
FactorGraph* fg = new FactorGraph(); FactorGraph* fg = new FactorGraph();
for (size_t i = 0; i < varClusters_.size(); i++) { for (size_t i = 0; i < varClusters_.size(); i++) {
@ -342,10 +401,10 @@ CountingBp::getCompressedFactorGraph (void)
vector<vector<unsigned>> std::vector<std::vector<unsigned>>
CountingBp::getWeights (void) const CountingBp::getWeights() const
{ {
vector<vector<unsigned>> weights; std::vector<std::vector<unsigned>> weights;
weights.reserve (facClusters_.size()); weights.reserve (facClusters_.size());
for (size_t i = 0; i < facClusters_.size(); i++) { for (size_t i = 0; i < facClusters_.size(); i++) {
const VarClusters& neighs = facClusters_[i]->varClusters(); const VarClusters& neighs = facClusters_[i]->varClusters();
@ -390,32 +449,34 @@ CountingBp::printGroups (
const FacSignMap& facGroups) const const FacSignMap& facGroups) const
{ {
unsigned count = 1; unsigned count = 1;
cout << "variable groups:" << endl; std::cout << "variable groups:" << std::endl;
for (VarSignMap::const_iterator it = varGroups.begin(); for (VarSignMap::const_iterator it = varGroups.begin();
it != varGroups.end(); ++it) { it != varGroups.end(); ++it) {
const VarNodes& groupMembers = it->second; const VarNodes& groupMembers = it->second;
if (groupMembers.size() > 0) { if (groupMembers.size() > 0) {
cout << count << ": " ; std::cout << count << ": " ;
for (size_t i = 0; i < groupMembers.size(); i++) { for (size_t i = 0; i < groupMembers.size(); i++) {
cout << groupMembers[i]->label() << " " ; std::cout << groupMembers[i]->label() << " " ;
} }
count ++; count ++;
cout << endl; std::cout << std::endl;
} }
} }
count = 1; count = 1;
cout << endl << "factor groups:" << endl; std::cout << std::endl << "factor groups:" << std::endl;
for (FacSignMap::const_iterator it = facGroups.begin(); for (FacSignMap::const_iterator it = facGroups.begin();
it != facGroups.end(); ++it) { it != facGroups.end(); ++it) {
const FacNodes& groupMembers = it->second; const FacNodes& groupMembers = it->second;
if (groupMembers.size() > 0) { if (groupMembers.size() > 0) {
cout << ++count << ": " ; std::cout << ++count << ": " ;
for (size_t i = 0; i < groupMembers.size(); i++) { for (size_t i = 0; i < groupMembers.size(); i++) {
cout << groupMembers[i]->getLabel() << " " ; std::cout << groupMembers[i]->getLabel() << " " ;
} }
count ++; count ++;
cout << endl; std::cout << std::endl;
} }
} }
} }
} // namespace Horus

View File

@ -1,155 +1,99 @@
#ifndef HORUS_COUNTINGBP_H #ifndef YAP_PACKAGES_CLPBN_HORUS_COUNTINGBP_H_
#define HORUS_COUNTINGBP_H #define YAP_PACKAGES_CLPBN_HORUS_COUNTINGBP_H_
#include <vector>
#include <unordered_map> #include <unordered_map>
#include "GroundSolver.h" #include "GroundSolver.h"
#include "FactorGraph.h" #include "FactorGraph.h"
#include "Horus.h" #include "Horus.h"
namespace Horus {
class VarCluster; class VarCluster;
class FacCluster; class FacCluster;
class WeightedBp; class WeightedBp;
typedef long Color;
typedef vector<Color> Colors;
typedef vector<std::pair<Color,unsigned>> VarSignature;
typedef vector<Color> FacSignature;
typedef unordered_map<unsigned, Color> DistColorMap; template <class T> inline size_t
typedef unordered_map<unsigned, Colors> VarColorMap; hash_combine (size_t seed, const T& v)
typedef unordered_map<VarSignature, VarNodes> VarSignMap;
typedef unordered_map<FacSignature, FacNodes> FacSignMap;
typedef unordered_map<VarId, VarCluster*> VarClusterMap;
typedef vector<VarCluster*> VarClusters;
typedef vector<FacCluster*> FacClusters;
template <class T>
inline size_t hash_combine (size_t seed, const T& v)
{ {
return seed ^ (hash<T>()(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2)); return seed ^ (std::hash<T>()(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2));
} }
} // namespace Horus
namespace std { namespace std {
template <typename T1, typename T2> struct hash<std::pair<T1,T2>>
{ template <typename T1, typename T2> struct hash<std::pair<T1,T2>> {
size_t operator() (const std::pair<T1,T2>& p) const size_t operator() (const std::pair<T1,T2>& p) const {
{ return Horus::hash_combine (std::hash<T1>()(p.first), p.second);
return hash_combine (std::hash<T1>()(p.first), p.second); }};
}
};
template <typename T> struct hash<std::vector<T>> template <typename T> struct hash<std::vector<T>>
{ {
size_t operator() (const std::vector<T>& vec) const size_t operator() (const std::vector<T>& vec) const
{ {
size_t h = 0; size_t h = 0;
typename vector<T>::const_iterator first = vec.begin(); typename std::vector<T>::const_iterator first = vec.begin();
typename vector<T>::const_iterator last = vec.end(); typename std::vector<T>::const_iterator last = vec.end();
for (; first != last; ++first) { for (; first != last; ++first) {
h = hash_combine (h, *first); h = Horus::hash_combine (h, *first);
} }
return h; return h;
} }
}; };
}
} // namespace std
class VarCluster namespace Horus {
{
public:
VarCluster (const VarNodes& vs) : members_(vs) { }
const VarNode* first (void) const { return members_.front(); } class CountingBp : public GroundSolver {
const VarNodes& members (void) const { return members_; }
VarNode* representative (void) const { return repr_; }
void setRepresentative (VarNode* vn) { repr_ = vn; }
private:
VarNodes members_;
VarNode* repr_;
DISALLOW_COPY_AND_ASSIGN (VarCluster);
};
class FacCluster
{
public:
FacCluster (const FacNodes& fcs, const VarClusters& vcs)
: members_(fcs), varClusters_(vcs) { }
const FacNode* first (void) const { return members_.front(); }
const FacNodes& members (void) const { return members_; }
FacNode* representative (void) const { return repr_; }
void setRepresentative (FacNode* fn) { repr_ = fn; }
VarClusters& varClusters (void) { return varClusters_; }
private:
FacNodes members_;
FacNode* repr_;
VarClusters varClusters_;
DISALLOW_COPY_AND_ASSIGN (FacCluster);
};
class CountingBp : public GroundSolver
{
public: public:
CountingBp (const FactorGraph& fg); CountingBp (const FactorGraph& fg);
~CountingBp (void); ~CountingBp();
void printSolverFlags (void) const; void printSolverFlags() const;
Params solveQuery (VarIds); Params solveQuery (VarIds);
static void setFindIdenticalFactorsFlag (bool fif) { fif_ = fif; } static void setFindIdenticalFactorsFlag (bool fif) { fif_ = fif; }
private: private:
Color getNewColor (void) typedef long Color;
{ typedef std::vector<Color> Colors;
++ freeColor_;
return freeColor_ - 1;
}
Color getColor (const VarNode* vn) const typedef std::vector<std::pair<Color,unsigned>> VarSignature;
{ typedef std::vector<Color> FacSignature;
return varColors_[vn->getIndex()];
}
Color getColor (const FacNode* fn) const typedef std::vector<VarCluster*> VarClusters;
{ typedef std::vector<FacCluster*> FacClusters;
return facColors_[fn->getIndex()];
}
void setColor (const VarNode* vn, Color c) typedef std::unordered_map<unsigned, Color> DistColorMap;
{ typedef std::unordered_map<unsigned, Colors> VarColorMap;
varColors_[vn->getIndex()] = c; typedef std::unordered_map<VarSignature, VarNodes> VarSignMap;
} typedef std::unordered_map<FacSignature, FacNodes> FacSignMap;
typedef std::unordered_map<VarId, VarCluster*> VarClusterMap;
void setColor (const FacNode* fn, Color c) Color getNewColor();
{
facColors_[fn->getIndex()] = c;
}
void findIdenticalFactors (void); Color getColor (const VarNode* vn) const;
void setInitialColors (void); Color getColor (const FacNode* fn) const;
void createGroups (void); void setColor (const VarNode* vn, Color c);
void setColor (const FacNode* fn, Color c);
void findIdenticalFactors();
void setInitialColors();
void createGroups();
void createClusters (const VarSignMap&, const FacSignMap&); void createClusters (const VarSignMap&, const FacSignMap&);
@ -163,12 +107,12 @@ class CountingBp : public GroundSolver
FacNode* getRepresentative (FacNode*); FacNode* getRepresentative (FacNode*);
FactorGraph* getCompressedFactorGraph (void); FactorGraph* getCompressedFactorGraph();
vector<vector<unsigned>> getWeights (void) const; std::vector<std::vector<unsigned>> getWeights() const;
unsigned getWeight (const FacCluster*, unsigned getWeight (const FacCluster*, const VarCluster*,
const VarCluster*, size_t index) const; size_t index) const;
Color freeColor_; Color freeColor_;
Colors varColors_; Colors varColors_;
@ -184,5 +128,48 @@ class CountingBp : public GroundSolver
DISALLOW_COPY_AND_ASSIGN (CountingBp); DISALLOW_COPY_AND_ASSIGN (CountingBp);
}; };
#endif // HORUS_COUNTINGBP_H
inline CountingBp::Color
CountingBp::getNewColor()
{
++ freeColor_;
return freeColor_ - 1;
}
inline CountingBp::Color
CountingBp::getColor (const VarNode* vn) const
{
return varColors_[vn->getIndex()];
}
inline CountingBp::Color
CountingBp::getColor (const FacNode* fn) const
{
return facColors_[fn->getIndex()];
}
inline void
CountingBp::setColor (const VarNode* vn, CountingBp::Color c)
{
varColors_[vn->getIndex()] = c;
}
inline void
CountingBp::setColor (const FacNode* fn, Color c)
{
facColors_[fn->getIndex()] = c;
}
} // namespace Horus
#endif // YAP_PACKAGES_CLPBN_HORUS_COUNTINGBP_H_

View File

@ -1,25 +1,30 @@
#include <iostream>
#include <fstream> #include <fstream>
#include "ElimGraph.h" #include "ElimGraph.h"
ElimHeuristic ElimGraph::elimHeuristic_ = MIN_NEIGHBORS;
namespace Horus {
ElimGraph::ElimHeuristic ElimGraph::elimHeuristic_ =
ElimHeuristic::minNeighborsEh;
ElimGraph::ElimGraph (const vector<Factor*>& factors) ElimGraph::ElimGraph (const std::vector<Factor*>& factors)
{ {
for (size_t i = 0; i < factors.size(); i++) { for (size_t i = 0; i < factors.size(); i++) {
if (factors[i]) { if (factors[i]) {
const VarIds& args = factors[i]->arguments(); const VarIds& args = factors[i]->arguments();
for (size_t j = 0; j < args.size() - 1; j++) { for (size_t j = 0; j < args.size() - 1; j++) {
EgNode* n1 = getEgNode (args[j]); EGNode* n1 = getEGNode (args[j]);
if (!n1) { if (!n1) {
n1 = new EgNode (args[j], factors[i]->range (j)); n1 = new EGNode (args[j], factors[i]->range (j));
addNode (n1); addNode (n1);
} }
for (size_t k = j + 1; k < args.size(); k++) { for (size_t k = j + 1; k < args.size(); k++) {
EgNode* n2 = getEgNode (args[k]); EGNode* n2 = getEGNode (args[k]);
if (!n2) { if (!n2) {
n2 = new EgNode (args[k], factors[i]->range (k)); n2 = new EGNode (args[k], factors[i]->range (k));
addNode (n2); addNode (n2);
} }
if (!neighbors (n1, n2)) { if (!neighbors (n1, n2)) {
@ -27,8 +32,8 @@ ElimGraph::ElimGraph (const vector<Factor*>& factors)
} }
} }
} }
if (args.size() == 1 && !getEgNode (args[0])) { if (args.size() == 1 && !getEGNode (args[0])) {
addNode (new EgNode (args[0], factors[i]->range (0))); addNode (new EGNode (args[0], factors[i]->range (0)));
} }
} }
} }
@ -36,7 +41,7 @@ ElimGraph::ElimGraph (const vector<Factor*>& factors)
ElimGraph::~ElimGraph (void) ElimGraph::~ElimGraph()
{ {
for (size_t i = 0; i < nodes_.size(); i++) { for (size_t i = 0; i < nodes_.size(); i++) {
delete nodes_[i]; delete nodes_[i];
@ -57,7 +62,7 @@ ElimGraph::getEliminatingOrder (const VarIds& excludedVids)
} }
size_t nrVarsToEliminate = nodes_.size() - excludedVids.size(); size_t nrVarsToEliminate = nodes_.size() - excludedVids.size();
for (size_t i = 0; i < nrVarsToEliminate; i++) { for (size_t i = 0; i < nrVarsToEliminate; i++) {
EgNode* node = getLowestCostNode(); EGNode* node = getLowestCostNode();
unmarked_.remove (node); unmarked_.remove (node);
const EGNeighs& neighs = node->neighbors(); const EGNeighs& neighs = node->neighbors();
for (size_t j = 0; j < neighs.size(); j++) { for (size_t j = 0; j < neighs.size(); j++) {
@ -72,15 +77,15 @@ ElimGraph::getEliminatingOrder (const VarIds& excludedVids)
void void
ElimGraph::print (void) const ElimGraph::print() const
{ {
for (size_t i = 0; i < nodes_.size(); i++) { for (size_t i = 0; i < nodes_.size(); i++) {
cout << "node " << nodes_[i]->label() << " neighs:" ; std::cout << "node " << nodes_[i]->label() << " neighs:" ;
EGNeighs neighs = nodes_[i]->neighbors(); EGNeighs neighs = nodes_[i]->neighbors();
for (size_t j = 0; j < neighs.size(); j++) { for (size_t j = 0; j < neighs.size(); j++) {
cout << " " << neighs[j]->label(); std::cout << " " << neighs[j]->label();
} }
cout << endl; std::cout << std::endl;
} }
} }
@ -92,25 +97,27 @@ ElimGraph::exportToGraphViz (
bool showNeighborless, bool showNeighborless,
const VarIds& highlightVarIds) const const VarIds& highlightVarIds) const
{ {
ofstream out (fileName); std::ofstream out (fileName);
if (!out.is_open()) { if (!out.is_open()) {
cerr << "Error: couldn't open file '" << fileName << "'." ; std::cerr << "Error: couldn't open file '" << fileName << "'." ;
std::cerr << std::endl;
return; return;
} }
out << "strict graph {" << endl; out << "strict graph {" << std::endl;
for (size_t i = 0; i < nodes_.size(); i++) { for (size_t i = 0; i < nodes_.size(); i++) {
if (showNeighborless || nodes_[i]->neighbors().empty() == false) { if (showNeighborless || nodes_[i]->neighbors().empty() == false) {
out << '"' << nodes_[i]->label() << '"' << endl; out << '"' << nodes_[i]->label() << '"' << std::endl;
} }
} }
for (size_t i = 0; i < highlightVarIds.size(); i++) { for (size_t i = 0; i < highlightVarIds.size(); i++) {
EgNode* node =getEgNode (highlightVarIds[i]); EGNode* node =getEGNode (highlightVarIds[i]);
if (node) { if (node) {
out << '"' << node->label() << '"' ; out << '"' << node->label() << '"' ;
out << " [shape=box3d]" << endl; out << " [shape=box3d]" << std::endl;
} else { } else {
cerr << "Error: invalid variable id: " << highlightVarIds[i] << "." ; std::cerr << "Error: invalid variable id: " ;
cerr << endl; std::cerr << highlightVarIds[i] << "." ;
std::cerr << std::endl;
exit (EXIT_FAILURE); exit (EXIT_FAILURE);
} }
} }
@ -118,10 +125,10 @@ ElimGraph::exportToGraphViz (
EGNeighs neighs = nodes_[i]->neighbors(); EGNeighs neighs = nodes_[i]->neighbors();
for (size_t j = 0; j < neighs.size(); j++) { for (size_t j = 0; j < neighs.size(); j++) {
out << '"' << nodes_[i]->label() << '"' << " -- " ; out << '"' << nodes_[i]->label() << '"' << " -- " ;
out << '"' << neighs[j]->label() << '"' << endl; out << '"' << neighs[j]->label() << '"' << std::endl;
} }
} }
out << "}" << endl; out << "}" << std::endl;
out.close(); out.close();
} }
@ -132,7 +139,7 @@ ElimGraph::getEliminationOrder (
const Factors& factors, const Factors& factors,
VarIds excludedVids) VarIds excludedVids)
{ {
if (elimHeuristic_ == ElimHeuristic::SEQUENTIAL) { if (elimHeuristic_ == ElimHeuristic::sequentialEh) {
VarIds allVids; VarIds allVids;
Factors::const_iterator first = factors.begin(); Factors::const_iterator first = factors.begin();
Factors::const_iterator end = factors.end(); Factors::const_iterator end = factors.end();
@ -150,33 +157,33 @@ ElimGraph::getEliminationOrder (
void void
ElimGraph::addNode (EgNode* n) ElimGraph::addNode (EGNode* n)
{ {
nodes_.push_back (n); nodes_.push_back (n);
n->setIndex (nodes_.size() - 1); n->setIndex (nodes_.size() - 1);
varMap_.insert (make_pair (n->varId(), n)); varMap_.insert (std::make_pair (n->varId(), n));
} }
EgNode* ElimGraph::EGNode*
ElimGraph::getEgNode (VarId vid) const ElimGraph::getEGNode (VarId vid) const
{ {
unordered_map<VarId, EgNode*>::const_iterator it; std::unordered_map<VarId, EGNode*>::const_iterator it;
it = varMap_.find (vid); it = varMap_.find (vid);
return (it != varMap_.end()) ? it->second : 0; return (it != varMap_.end()) ? it->second : 0;
} }
EgNode* ElimGraph::EGNode*
ElimGraph::getLowestCostNode (void) const ElimGraph::getLowestCostNode() const
{ {
EgNode* bestNode = 0; EGNode* bestNode = 0;
unsigned minCost = Util::maxUnsigned(); unsigned minCost = Util::maxUnsigned();
EGNeighs::const_iterator it; EGNeighs::const_iterator it;
switch (elimHeuristic_) { switch (elimHeuristic_) {
case MIN_NEIGHBORS: { case ElimHeuristic::minNeighborsEh: {
for (it = unmarked_.begin(); it != unmarked_.end(); ++ it) { for (it = unmarked_.begin(); it != unmarked_.end(); ++ it) {
unsigned cost = getNeighborsCost (*it); unsigned cost = getNeighborsCost (*it);
if (cost < minCost) { if (cost < minCost) {
@ -185,7 +192,7 @@ ElimGraph::getLowestCostNode (void) const
} }
}} }}
break; break;
case MIN_WEIGHT: { case ElimHeuristic::minWeightEh: {
for (it = unmarked_.begin(); it != unmarked_.end(); ++ it) { for (it = unmarked_.begin(); it != unmarked_.end(); ++ it) {
unsigned cost = getWeightCost (*it); unsigned cost = getWeightCost (*it);
if (cost < minCost) { if (cost < minCost) {
@ -194,7 +201,7 @@ ElimGraph::getLowestCostNode (void) const
} }
}} }}
break; break;
case MIN_FILL: { case ElimHeuristic::minFillEh: {
for (it = unmarked_.begin(); it != unmarked_.end(); ++ it) { for (it = unmarked_.begin(); it != unmarked_.end(); ++ it) {
unsigned cost = getFillCost (*it); unsigned cost = getFillCost (*it);
if (cost < minCost) { if (cost < minCost) {
@ -203,7 +210,7 @@ ElimGraph::getLowestCostNode (void) const
} }
}} }}
break; break;
case WEIGHTED_MIN_FILL: { case ElimHeuristic::weightedMinFillEh: {
for (it = unmarked_.begin(); it != unmarked_.end(); ++ it) { for (it = unmarked_.begin(); it != unmarked_.end(); ++ it) {
unsigned cost = getWeightedFillCost (*it); unsigned cost = getWeightedFillCost (*it);
if (cost < minCost) { if (cost < minCost) {
@ -222,7 +229,7 @@ ElimGraph::getLowestCostNode (void) const
void void
ElimGraph::connectAllNeighbors (const EgNode* n) ElimGraph::connectAllNeighbors (const EGNode* n)
{ {
const EGNeighs& neighs = n->neighbors(); const EGNeighs& neighs = n->neighbors();
if (neighs.size() > 0) { if (neighs.size() > 0) {
@ -236,3 +243,5 @@ ElimGraph::connectAllNeighbors (const EgNode* n)
} }
} }
} // namespace Horus

View File

@ -1,81 +1,121 @@
#ifndef HORUS_ELIMGRAPH_H #ifndef YAP_PACKAGES_CLPBN_HORUS_ELIMGRAPH_H_
#define HORUS_ELIMGRAPH_H #define YAP_PACKAGES_CLPBN_HORUS_ELIMGRAPH_H_
#include "unordered_map" #include <cassert>
#include <vector>
#include <unordered_map>
#include "FactorGraph.h" #include "FactorGraph.h"
#include "TinySet.h" #include "TinySet.h"
#include "Horus.h" #include "Horus.h"
using namespace std;
enum ElimHeuristic namespace Horus {
{
SEQUENTIAL, class ElimGraph {
MIN_NEIGHBORS, public:
MIN_WEIGHT, enum class ElimHeuristic {
MIN_FILL, sequentialEh,
WEIGHTED_MIN_FILL minNeighborsEh,
minWeightEh,
minFillEh,
weightedMinFillEh
}; };
class EgNode;
typedef TinySet<EgNode*> EGNeighs;
class EgNode : public Var
{
public:
EgNode (VarId vid, unsigned range) : Var (vid, range) { }
void addNeighbor (EgNode* n) { neighs_.insert (n); }
void removeNeighbor (EgNode* n) { neighs_.remove (n); }
bool isNeighbor (EgNode* n) const { return neighs_.contains (n); }
const EGNeighs& neighbors (void) const { return neighs_; }
private:
EGNeighs neighs_;
};
class ElimGraph
{
public:
ElimGraph (const Factors&); ElimGraph (const Factors&);
~ElimGraph (void); ~ElimGraph();
VarIds getEliminatingOrder (const VarIds&); VarIds getEliminatingOrder (const VarIds&);
void print (void) const; void print() const;
void exportToGraphViz (const char*, bool = true, void exportToGraphViz (const char*, bool = true,
const VarIds& = VarIds()) const; const VarIds& = VarIds()) const;
static VarIds getEliminationOrder (const Factors&, VarIds); static VarIds getEliminationOrder (const Factors&, VarIds);
static ElimHeuristic elimHeuristic (void) { return elimHeuristic_; } static ElimHeuristic elimHeuristic() { return elimHeuristic_; }
static void setElimHeuristic (ElimHeuristic eh) { elimHeuristic_ = eh; } static void setElimHeuristic (ElimHeuristic eh) { elimHeuristic_ = eh; }
private: private:
void addEdge (EgNode* n1, EgNode* n2) class EGNode;
typedef TinySet<EGNode*> EGNeighs;
class EGNode : public Var {
public:
EGNode (VarId vid, unsigned range) : Var (vid, range) { }
void addNeighbor (EGNode* n) { neighs_.insert (n); }
void removeNeighbor (EGNode* n) { neighs_.remove (n); }
bool isNeighbor (EGNode* n) const { return neighs_.contains (n); }
const EGNeighs& neighbors() const { return neighs_; }
private:
EGNeighs neighs_;
};
void addEdge (EGNode* n1, EGNode* n2);
unsigned getNeighborsCost (const EGNode* n) const;
unsigned getWeightCost (const EGNode* n) const;
unsigned getFillCost (const EGNode* n) const;
unsigned getWeightedFillCost (const EGNode* n) const;
bool neighbors (EGNode* n1, EGNode* n2) const;
void addNode (EGNode*);
EGNode* getEGNode (VarId) const;
EGNode* getLowestCostNode() const;
void connectAllNeighbors (const EGNode*);
std::vector<EGNode*> nodes_;
EGNeighs unmarked_;
std::unordered_map<VarId, EGNode*> varMap_;
static ElimHeuristic elimHeuristic_;
DISALLOW_COPY_AND_ASSIGN (ElimGraph);
};
/* Profiling shows that we should inline the following functions */
inline void
ElimGraph::addEdge (EGNode* n1, EGNode* n2)
{ {
assert (n1 != n2); assert (n1 != n2);
n1->addNeighbor (n2); n1->addNeighbor (n2);
n2->addNeighbor (n1); n2->addNeighbor (n1);
} }
unsigned getNeighborsCost (const EgNode* n) const
inline unsigned
ElimGraph::getNeighborsCost (const EGNode* n) const
{ {
return n->neighbors().size(); return n->neighbors().size();
} }
unsigned getWeightCost (const EgNode* n) const
inline unsigned
ElimGraph::getWeightCost (const EGNode* n) const
{ {
unsigned cost = 1; unsigned cost = 1;
const EGNeighs& neighs = n->neighbors(); const EGNeighs& neighs = n->neighbors();
@ -85,7 +125,10 @@ class ElimGraph
return cost; return cost;
} }
unsigned getFillCost (const EgNode* n) const
inline unsigned
ElimGraph::getFillCost (const EGNode* n) const
{ {
unsigned cost = 0; unsigned cost = 0;
const EGNeighs& neighs = n->neighbors(); const EGNeighs& neighs = n->neighbors();
@ -101,7 +144,10 @@ class ElimGraph
return cost; return cost;
} }
unsigned getWeightedFillCost (const EgNode* n) const
inline unsigned
ElimGraph::getWeightedFillCost (const EGNode* n) const
{ {
unsigned cost = 0; unsigned cost = 0;
const EGNeighs& neighs = n->neighbors(); const EGNeighs& neighs = n->neighbors();
@ -117,27 +163,15 @@ class ElimGraph
return cost; return cost;
} }
bool neighbors (EgNode* n1, EgNode* n2) const
inline bool
ElimGraph::neighbors (EGNode* n1, EGNode* n2) const
{ {
return n1->isNeighbor (n2); return n1->isNeighbor (n2);
} }
void addNode (EgNode*); } // namespace Horus
EgNode* getEgNode (VarId) const; #endif // YAP_PACKAGES_CLPBN_HORUS_ELIMGRAPH_H_
EgNode* getLowestCostNode (void) const;
void connectAllNeighbors (const EgNode*);
vector<EgNode*> nodes_;
TinySet<EgNode*> unmarked_;
unordered_map<VarId, EgNode*> varMap_;
static ElimHeuristic elimHeuristic_;
DISALLOW_COPY_AND_ASSIGN (ElimGraph);
};
#endif // HORUS_ELIMGRAPH_H

View File

@ -1,21 +1,15 @@
#include <cstdlib>
#include <cassert> #include <cassert>
#include <algorithm> #include <algorithm>
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
#include "Factor.h" #include "Factor.h"
#include "Indexer.h"
#include "Var.h" #include "Var.h"
Factor::Factor (const Factor& g) namespace Horus {
{
clone (g);
}
Factor::Factor ( Factor::Factor (
const VarIds& vids, const VarIds& vids,
@ -77,7 +71,7 @@ Factor::sumOutAllExcept (VarId vid)
void void
Factor::sumOutAllExcept (const VarIds& vids) Factor::sumOutAllExcept (const VarIds& vids)
{ {
vector<bool> mask (args_.size(), false); std::vector<bool> mask (args_.size(), false);
for (unsigned i = 0; i < vids.size(); i++) { for (unsigned i = 0; i < vids.size(); i++) {
assert (indexOf (vids[i]) != args_.size()); assert (indexOf (vids[i]) != args_.size());
mask[indexOf (vids[i])] = true; mask[indexOf (vids[i])] = true;
@ -91,28 +85,30 @@ void
Factor::sumOutAllExceptIndex (size_t idx) Factor::sumOutAllExceptIndex (size_t idx)
{ {
assert (idx < args_.size()); assert (idx < args_.size());
vector<bool> mask (args_.size(), false); std::vector<bool> mask (args_.size(), false);
mask[idx] = true; mask[idx] = true;
sumOutArgs (mask); sumOutArgs (mask);
} }
void
Factor::multiply (Factor& g) Factor&
Factor::multiply (const Factor& g)
{ {
if (args_.empty()) { if (args_.empty()) {
clone (g); operator= (g);
} else { } else {
TFactor<VarId>::multiply (g); GenericFactor<VarId>::multiply (g);
} }
return *this;
} }
string std::string
Factor::getLabel (void) const Factor::getLabel() const
{ {
stringstream ss; std::stringstream ss;
ss << "f(" ; ss << "f(" ;
for (size_t i = 0; i < args_.size(); i++) { for (size_t i = 0; i < args_.size(); i++) {
if (i != 0) ss << "," ; if (i != 0) ss << "," ;
@ -125,19 +121,19 @@ Factor::getLabel (void) const
void void
Factor::print (void) const Factor::print() const
{ {
Vars vars; Vars vars;
for (size_t i = 0; i < args_.size(); i++) { for (size_t i = 0; i < args_.size(); i++) {
vars.push_back (new Var (args_[i], ranges_[i])); vars.push_back (new Var (args_[i], ranges_[i]));
} }
vector<string> jointStrings = Util::getStateLines (vars); std::vector<std::string> jointStrings = Util::getStateLines (vars);
for (size_t i = 0; i < params_.size(); i++) { for (size_t i = 0; i < params_.size(); i++) {
// cout << "[" << distId_ << "] " ; // cout << "[" << distId_ << "] " ;
cout << "f(" << jointStrings[i] << ")" ; std::cout << "f(" << jointStrings[i] << ")" ;
cout << " = " << params_[i] << endl; std::cout << " = " << params_[i] << std::endl;
} }
cout << endl; std::cout << std::endl;
for (size_t i = 0; i < vars.size(); i++) { for (size_t i = 0; i < vars.size(); i++) {
delete vars[i]; delete vars[i];
} }
@ -146,8 +142,9 @@ Factor::print (void) const
void void
Factor::sumOutFirstVariable (void) Factor::sumOutFirstVariable()
{ {
assert (ranges_.front() == 2);
size_t sep = params_.size() / 2; size_t sep = params_.size() / 2;
if (Globals::logDomain) { if (Globals::logDomain) {
std::transform ( std::transform (
@ -169,19 +166,21 @@ Factor::sumOutFirstVariable (void)
void void
Factor::sumOutLastVariable (void) Factor::sumOutLastVariable()
{ {
assert (ranges_.back() == 2);
Params::iterator first1 = params_.begin(); Params::iterator first1 = params_.begin();
Params::iterator first2 = params_.begin(); Params::iterator first2 = params_.begin();
Params::iterator last = params_.end(); Params::iterator last = params_.end();
if (Globals::logDomain) { if (Globals::logDomain) {
while (first2 != last) { while (first2 != last) {
// the arguments can be swaped, but that is ok double tmp = *first2++;
*first1++ = Util::logSum (*first2++, *first2++); *first1++ = Util::logSum (tmp, *first2++);
} }
} else { } else {
while (first2 != last) { while (first2 != last) {
*first1++ = (*first2++) + (*first2++); *first1 = *first2++;
*first1++ += *first2++;
} }
} }
params_.resize (params_.size() / 2); params_.resize (params_.size() / 2);
@ -192,7 +191,7 @@ Factor::sumOutLastVariable (void)
void void
Factor::sumOutArgs (const vector<bool>& mask) Factor::sumOutArgs (const std::vector<bool>& mask)
{ {
assert (mask.size() == args_.size()); assert (mask.size() == args_.size());
size_t new_size = 1; size_t new_size = 1;
@ -224,14 +223,5 @@ Factor::sumOutArgs (const vector<bool>& mask)
params_ = newps; params_ = newps;
} }
} // namespace Horus
void
Factor::clone (const Factor& g)
{
args_ = g.arguments();
ranges_ = g.ranges();
params_ = g.params();
distId_ = g.distId();
}

View File

@ -1,262 +1,20 @@
#ifndef HORUS_FACTOR_H #ifndef YAP_PACKAGES_CLPBN_HORUS_FACTOR_H_
#define HORUS_FACTOR_H #define YAP_PACKAGES_CLPBN_HORUS_FACTOR_H_
#include <cassert>
#include <vector> #include <vector>
#include <string>
#include "Indexer.h" #include "GenericFactor.h"
#include "Util.h" #include "Util.h"
using namespace std; namespace Horus {
class Factor : public GenericFactor<VarId> {
template <typename T>
class TFactor
{
public: public:
const vector<T>& arguments (void) const { return args_; } Factor() { }
vector<T>& arguments (void) { return args_; }
const Ranges& ranges (void) const { return ranges_; }
const Params& params (void) const { return params_; }
Params& params (void) { return params_; }
size_t nrArguments (void) const { return args_.size(); }
size_t size (void) const { return params_.size(); }
unsigned distId (void) const { return distId_; }
void setDistId (unsigned id) { distId_ = id; }
void normalize (void) { LogAware::normalize (params_); }
void randomize (void)
{
for (size_t i = 0; i < params_.size(); ++i) {
params_[i] = (double) std::rand() / RAND_MAX;
}
}
void setParams (const Params& newParams)
{
params_ = newParams;
assert (params_.size() == Util::sizeExpected (ranges_));
}
size_t indexOf (const T& t) const
{
return Util::indexOf (args_, t);
}
const T& argument (size_t idx) const
{
assert (idx < args_.size());
return args_[idx];
}
T& argument (size_t idx)
{
assert (idx < args_.size());
return args_[idx];
}
unsigned range (size_t idx) const
{
assert (idx < ranges_.size());
return ranges_[idx];
}
void multiply (TFactor<T>& g)
{
if (args_ == g.arguments()) {
// optimization
Globals::logDomain
? params_ += g.params()
: params_ *= g.params();
return;
}
unsigned range_prod = 1;
bool share_arguments = false;
const vector<T>& g_args = g.arguments();
const Ranges& g_ranges = g.ranges();
const Params& g_params = g.params();
for (size_t i = 0; i < g_args.size(); i++) {
size_t idx = indexOf (g_args[i]);
if (idx == args_.size()) {
range_prod *= g_ranges[i];
args_.push_back (g_args[i]);
ranges_.push_back (g_ranges[i]);
} else {
share_arguments = true;
}
}
if (share_arguments == false) {
// optimization
cartesianProduct (g_params.begin(), g_params.end());
} else {
extend (range_prod);
Params::iterator it = params_.begin();
MapIndexer indexer (args_, ranges_, g_args, g_ranges);
if (Globals::logDomain) {
for (; indexer.valid(); ++it, ++indexer) {
*it += g_params[indexer];
}
} else {
for (; indexer.valid(); ++it, ++indexer) {
*it *= g_params[indexer];
}
}
}
}
void sumOutIndex (size_t idx)
{
assert (idx < args_.size());
assert (args_.size() > 1);
size_t new_size = params_.size() / ranges_[idx];
Params newps (new_size, LogAware::addIdenty());
Params::const_iterator first = params_.begin();
Params::const_iterator last = params_.end();
MapIndexer indexer (ranges_, idx);
if (Globals::logDomain) {
for (; first != last; ++indexer) {
newps[indexer] = Util::logSum (newps[indexer], *first++);
}
} else {
for (; first != last; ++indexer) {
newps[indexer] += *first++;
}
}
params_ = newps;
args_.erase (args_.begin() + idx);
ranges_.erase (ranges_.begin() + idx);
}
void absorveEvidence (const T& arg, unsigned obsIdx)
{
size_t idx = indexOf (arg);
assert (idx != args_.size());
assert (obsIdx < ranges_[idx]);
Params newps;
newps.reserve (params_.size() / ranges_[idx]);
Indexer indexer (ranges_);
for (unsigned i = 0; i < obsIdx; ++i) {
indexer.incrementDimension (idx);
}
while (indexer.valid()) {
newps.push_back (params_[indexer]);
indexer.incrementExceptDimension (idx);
}
params_ = newps;
args_.erase (args_.begin() + idx);
ranges_.erase (ranges_.begin() + idx);
}
void reorderArguments (const vector<T> new_args)
{
assert (new_args.size() == args_.size());
if (new_args == args_) {
return; // already on the desired order
}
Ranges new_ranges;
for (size_t i = 0; i < new_args.size(); i++) {
size_t idx = indexOf (new_args[i]);
assert (idx != args_.size());
new_ranges.push_back (ranges_[idx]);
}
Params newps;
newps.reserve (params_.size());
MapIndexer indexer (new_args, new_ranges, args_, ranges_);
for (; indexer.valid(); ++indexer) {
newps.push_back (params_[indexer]);
}
params_ = newps;
args_ = new_args;
ranges_ = new_ranges;
}
bool contains (const T& arg) const
{
return Util::contains (args_, arg);
}
bool contains (const vector<T>& args) const
{
for (size_t i = 0; i < args.size(); i++) {
if (contains (args[i]) == false) {
return false;
}
}
return true;
}
double& operator[] (size_t idx)
{
assert (idx < params_.size());
return params_[idx];
}
protected:
vector<T> args_;
Ranges ranges_;
Params params_;
unsigned distId_;
private:
void extend (unsigned range_prod)
{
Params backup = params_;
params_.clear();
params_.reserve (backup.size() * range_prod);
Params::const_iterator first = backup.begin();
Params::const_iterator last = backup.end();
for (; first != last; ++first) {
for (unsigned reps = 0; reps < range_prod; ++reps) {
params_.push_back (*first);
}
}
}
void cartesianProduct (
Params::const_iterator first2,
Params::const_iterator last2)
{
Params backup = params_;
params_.clear();
params_.reserve (params_.size() * (last2 - first2));
Params::const_iterator first1 = backup.begin();
Params::const_iterator last1 = backup.end();
Params::const_iterator tmp;
if (Globals::logDomain) {
for (; first1 != last1; ++first1) {
for (tmp = first2; tmp != last2; ++tmp) {
params_.push_back ((*first1) + (*tmp));
}
}
} else {
for (; first1 != last1; ++first1) {
for (tmp = first2; tmp != last2; ++tmp) {
params_.push_back ((*first1) * (*tmp));
}
}
}
}
};
class Factor : public TFactor<VarId>
{
public:
Factor (void) { }
Factor (const Factor&);
Factor (const VarIds&, const Ranges&, const Params&, Factor (const VarIds&, const Ranges&, const Params&,
unsigned = Util::maxUnsigned()); unsigned = Util::maxUnsigned());
@ -272,23 +30,21 @@ class Factor : public TFactor<VarId>
void sumOutAllExceptIndex (size_t idx); void sumOutAllExceptIndex (size_t idx);
void multiply (Factor&); Factor& multiply (const Factor&);
string getLabel (void) const; std::string getLabel() const;
void print (void) const; void print() const;
private: private:
void sumOutFirstVariable (void); void sumOutFirstVariable();
void sumOutLastVariable (void); void sumOutLastVariable();
void sumOutArgs (const vector<bool>& mask); void sumOutArgs (const std::vector<bool>& mask);
void clone (const Factor& f);
DISALLOW_ASSIGN (Factor);
}; };
#endif // HORUS_FACTOR_H } // namespace Horus
#endif // YAP_PACKAGES_CLPBN_HORUS_FACTOR_H_

View File

@ -1,17 +1,15 @@
#include <cassert>
#include <algorithm> #include <algorithm>
#include <set>
#include <vector>
#include <iostream> #include <iostream>
#include <sstream>
#include <fstream>
#include "FactorGraph.h" #include "FactorGraph.h"
#include "BayesBall.h" #include "BayesBall.h"
#include "Util.h" #include "Util.h"
namespace Horus {
bool FactorGraph::exportLd_ = false; bool FactorGraph::exportLd_ = false;
bool FactorGraph::exportUai_ = false; bool FactorGraph::exportUai_ = false;
bool FactorGraph::exportGv_ = false; bool FactorGraph::exportGv_ = false;
@ -20,25 +18,12 @@ bool FactorGraph::printFg_ = false;
FactorGraph::FactorGraph (const FactorGraph& fg) FactorGraph::FactorGraph (const FactorGraph& fg)
{ {
const VarNodes& varNodes = fg.varNodes(); clone (fg);
for (size_t i = 0; i < varNodes.size(); i++) {
addVarNode (new VarNode (varNodes[i]));
}
const FacNodes& facNodes = fg.facNodes();
for (size_t i = 0; i < facNodes.size(); i++) {
FacNode* facNode = new FacNode (facNodes[i]->factor());
addFacNode (facNode);
const VarNodes& neighs = facNodes[i]->neighbors();
for (size_t j = 0; j < neighs.size(); j++) {
addEdge (varNodes_[neighs[j]->getIndex()], facNode);
}
}
bayesFactors_ = fg.bayesianFactors();
} }
FactorGraph::~FactorGraph (void) FactorGraph::~FactorGraph()
{ {
for (size_t i = 0; i < varNodes_.size(); i++) { for (size_t i = 0; i < varNodes_.size(); i++) {
delete varNodes_[i]; delete varNodes_[i];
@ -50,152 +35,6 @@ FactorGraph::~FactorGraph (void)
void
FactorGraph::readFromUaiFormat (const char* fileName)
{
std::ifstream is (fileName);
if (!is.is_open()) {
cerr << "Error: couldn't open file '" << fileName << "'." ;
exit (EXIT_FAILURE);
}
ignoreLines (is);
string line;
getline (is, line);
if (line == "BAYES") {
bayesFactors_ = true;
} else if (line == "MARKOV") {
bayesFactors_ = false;
} else {
cerr << "Error: the type of network is missing." << endl;
exit (EXIT_FAILURE);
}
// read the number of vars
ignoreLines (is);
unsigned nrVars;
is >> nrVars;
// read the range of each var
ignoreLines (is);
Ranges ranges (nrVars);
for (unsigned i = 0; i < nrVars; i++) {
is >> ranges[i];
}
unsigned nrFactors;
unsigned nrArgs;
unsigned vid;
is >> nrFactors;
vector<VarIds> allVarIds;
vector<Ranges> allRanges;
for (unsigned i = 0; i < nrFactors; i++) {
ignoreLines (is);
is >> nrArgs;
allVarIds.push_back ({ });
allRanges.push_back ({ });
for (unsigned j = 0; j < nrArgs; j++) {
is >> vid;
if (vid >= ranges.size()) {
cerr << "Error: invalid variable identifier `" << vid << "'. " ;
cerr << "Identifiers must be between 0 and " << ranges.size() - 1 ;
cerr << "." << endl;
exit (EXIT_FAILURE);
}
allVarIds.back().push_back (vid);
allRanges.back().push_back (ranges[vid]);
}
}
// read the parameters
unsigned nrParams;
for (unsigned i = 0; i < nrFactors; i++) {
ignoreLines (is);
is >> nrParams;
if (nrParams != Util::sizeExpected (allRanges[i])) {
cerr << "Error: invalid number of parameters for factor nº " << i ;
cerr << ", " << Util::sizeExpected (allRanges[i]);
cerr << " expected, " << nrParams << " given." << endl;
exit (EXIT_FAILURE);
}
Params params (nrParams);
for (unsigned j = 0; j < nrParams; j++) {
is >> params[j];
}
if (Globals::logDomain) {
Util::log (params);
}
Factor f (allVarIds[i], allRanges[i], params);
if (bayesFactors_ && allVarIds[i].size() > 1) {
// In this format the child is the last variable,
// move it to be the first
std::swap (allVarIds[i].front(), allVarIds[i].back());
f.reorderArguments (allVarIds[i]);
}
addFactor (f);
}
is.close();
}
void
FactorGraph::readFromLibDaiFormat (const char* fileName)
{
std::ifstream is (fileName);
if (!is.is_open()) {
cerr << "Error: couldn't open file '" << fileName << "'." ;
exit (EXIT_FAILURE);
}
ignoreLines (is);
unsigned nrFactors;
unsigned nrArgs;
VarId vid;
is >> nrFactors;
for (unsigned i = 0; i < nrFactors; i++) {
ignoreLines (is);
// read the factor arguments
is >> nrArgs;
VarIds vids;
for (unsigned j = 0; j < nrArgs; j++) {
ignoreLines (is);
is >> vid;
vids.push_back (vid);
}
// read ranges
Ranges ranges (nrArgs);
for (unsigned j = 0; j < nrArgs; j++) {
ignoreLines (is);
is >> ranges[j];
VarNode* var = getVarNode (vids[j]);
if (var && ranges[j] != var->range()) {
cerr << "Error: variable `" << vids[j] << "' appears in two or " ;
cerr << "more factors with a different range." << endl;
}
}
// read parameters
ignoreLines (is);
unsigned nNonzeros;
is >> nNonzeros;
Params params (Util::sizeExpected (ranges), 0);
for (unsigned j = 0; j < nNonzeros; j++) {
ignoreLines (is);
unsigned index;
is >> index;
ignoreLines (is);
double val;
is >> val;
params[index] = val;
}
if (Globals::logDomain) {
Util::log (params);
}
std::reverse (vids.begin(), vids.end());
Factor f (vids, ranges, params);
std::reverse (vids.begin(), vids.end());
f.reorderArguments (vids);
addFactor (f);
}
is.close();
}
void void
FactorGraph::addFactor (const Factor& factor) FactorGraph::addFactor (const Factor& factor)
{ {
@ -221,7 +60,7 @@ FactorGraph::addVarNode (VarNode* vn)
{ {
varNodes_.push_back (vn); varNodes_.push_back (vn);
vn->setIndex (varNodes_.size() - 1); vn->setIndex (varNodes_.size() - 1);
varMap_.insert (make_pair (vn->varId(), vn)); varMap_.insert (std::make_pair (vn->varId(), vn));
} }
@ -245,7 +84,7 @@ FactorGraph::addEdge (VarNode* vn, FacNode* fn)
bool bool
FactorGraph::isTree (void) const FactorGraph::isTree() const
{ {
return !containsCycle(); return !containsCycle();
} }
@ -253,7 +92,7 @@ FactorGraph::isTree (void) const
BayesBallGraph& BayesBallGraph&
FactorGraph::getStructure (void) FactorGraph::getStructure()
{ {
assert (bayesFactors_); assert (bayesFactors_);
if (structure_.empty()) { if (structure_.empty()) {
@ -273,8 +112,10 @@ FactorGraph::getStructure (void)
void void
FactorGraph::print (void) const FactorGraph::print() const
{ {
using std::cout;
using std::endl;
for (size_t i = 0; i < varNodes_.size(); i++) { for (size_t i = 0; i < varNodes_.size(); i++) {
cout << "var id = " << varNodes_[i]->varId() << endl; cout << "var id = " << varNodes_[i]->varId() << endl;
cout << "label = " << varNodes_[i]->label() << endl; cout << "label = " << varNodes_[i]->label() << endl;
@ -296,28 +137,29 @@ FactorGraph::print (void) const
void void
FactorGraph::exportToLibDai (const char* fileName) const FactorGraph::exportToLibDai (const char* fileName) const
{ {
ofstream out (fileName); std::ofstream out (fileName);
if (!out.is_open()) { if (!out.is_open()) {
cerr << "Error: couldn't open file '" << fileName << "'." ; std::cerr << "Error: couldn't open file '" << fileName << "'." ;
std::cerr << std::endl;
return; return;
} }
out << facNodes_.size() << endl << endl; out << facNodes_.size() << std::endl << std::endl;
for (size_t i = 0; i < facNodes_.size(); i++) { for (size_t i = 0; i < facNodes_.size(); i++) {
Factor f (facNodes_[i]->factor()); Factor f (facNodes_[i]->factor());
out << f.nrArguments() << endl; out << f.nrArguments() << std::endl;
out << Util::elementsToString (f.arguments()) << endl; out << Util::elementsToString (f.arguments()) << std::endl;
out << Util::elementsToString (f.ranges()) << endl; out << Util::elementsToString (f.ranges()) << std::endl;
VarIds args = f.arguments(); VarIds args = f.arguments();
std::reverse (args.begin(), args.end()); std::reverse (args.begin(), args.end());
f.reorderArguments (args); f.reorderArguments (args);
if (Globals::logDomain) { if (Globals::logDomain) {
Util::exp (f.params()); Util::exp (f.params());
} }
out << f.size() << endl; out << f.size() << std::endl;
for (size_t j = 0; j < f.size(); j++) { for (size_t j = 0; j < f.size(); j++) {
out << j << " " << f[j] << endl; out << j << " " << f[j] << std::endl;
} }
out << endl; out << std::endl;
} }
out.close(); out.close();
} }
@ -327,28 +169,30 @@ FactorGraph::exportToLibDai (const char* fileName) const
void void
FactorGraph::exportToUai (const char* fileName) const FactorGraph::exportToUai (const char* fileName) const
{ {
ofstream out (fileName); std::ofstream out (fileName);
if (!out.is_open()) { if (!out.is_open()) {
cerr << "Error: couldn't open file '" << fileName << "'." ; std::cerr << "Error: couldn't open file '" << fileName << "'." ;
std::cerr << std::endl;
return; return;
} }
out << (bayesFactors_ ? "BAYES" : "MARKOV") ; out << (bayesFactors_ ? "BAYES" : "MARKOV") ;
out << endl << endl; out << std::endl << std::endl;
out << varNodes_.size() << endl; out << varNodes_.size() << std::endl;
VarNodes sortedVns = varNodes_; VarNodes sortedVns = varNodes_;
std::sort (sortedVns.begin(), sortedVns.end(), sortByVarId()); std::sort (sortedVns.begin(), sortedVns.end(), sortByVarId());
for (size_t i = 0; i < sortedVns.size(); i++) { for (size_t i = 0; i < sortedVns.size(); i++) {
out << ((i != 0) ? " " : "") << sortedVns[i]->range(); out << ((i != 0) ? " " : "") << sortedVns[i]->range();
} }
out << endl << facNodes_.size() << endl; out << std::endl << facNodes_.size() << std::endl;
for (size_t i = 0; i < facNodes_.size(); i++) { for (size_t i = 0; i < facNodes_.size(); i++) {
VarIds args = facNodes_[i]->factor().arguments(); VarIds args = facNodes_[i]->factor().arguments();
if (bayesFactors_) { if (bayesFactors_) {
std::swap (args.front(), args.back()); std::swap (args.front(), args.back());
} }
out << args.size() << " " << Util::elementsToString (args) << endl; out << args.size() << " " << Util::elementsToString (args);
out << std::endl;
} }
out << endl; out << std::endl;
for (size_t i = 0; i < facNodes_.size(); i++) { for (size_t i = 0; i < facNodes_.size(); i++) {
Factor f = facNodes_[i]->factor(); Factor f = facNodes_[i]->factor();
if (bayesFactors_) { if (bayesFactors_) {
@ -360,8 +204,9 @@ FactorGraph::exportToUai (const char* fileName) const
if (Globals::logDomain) { if (Globals::logDomain) {
Util::exp (params); Util::exp (params);
} }
out << params.size() << endl << " " ; out << params.size() << std::endl << " " ;
out << Util::elementsToString (params) << endl << endl; out << Util::elementsToString (params);
out << std::endl << std::endl;
} }
out.close(); out.close();
} }
@ -371,53 +216,239 @@ FactorGraph::exportToUai (const char* fileName) const
void void
FactorGraph::exportToGraphViz (const char* fileName) const FactorGraph::exportToGraphViz (const char* fileName) const
{ {
ofstream out (fileName); std::ofstream out (fileName);
if (!out.is_open()) { if (!out.is_open()) {
cerr << "Error: couldn't open file '" << fileName << "'." ; std::cerr << "Error: couldn't open file '" << fileName << "'." ;
std::cerr << std::endl;
return; return;
} }
out << "graph \"" << fileName << "\" {" << endl; out << "graph \"" << fileName << "\" {" << std::endl;
for (size_t i = 0; i < varNodes_.size(); i++) { for (size_t i = 0; i < varNodes_.size(); i++) {
if (varNodes_[i]->hasEvidence()) { if (varNodes_[i]->hasEvidence()) {
out << '"' << varNodes_[i]->label() << '"' ; out << '"' << varNodes_[i]->label() << '"' ;
out << " [style=filled, fillcolor=yellow]" << endl; out << " [style=filled, fillcolor=yellow]" << std::endl;
} }
} }
for (size_t i = 0; i < facNodes_.size(); i++) { for (size_t i = 0; i < facNodes_.size(); i++) {
out << '"' << facNodes_[i]->getLabel() << '"' ; out << '"' << facNodes_[i]->getLabel() << '"' ;
out << " [label=\"" << facNodes_[i]->getLabel(); out << " [label=\"" << facNodes_[i]->getLabel();
out << "\"" << ", shape=box]" << endl; out << "\"" << ", shape=box]" << std::endl;
} }
for (size_t i = 0; i < facNodes_.size(); i++) { for (size_t i = 0; i < facNodes_.size(); i++) {
const VarNodes& myVars = facNodes_[i]->neighbors(); const VarNodes& myVars = facNodes_[i]->neighbors();
for (size_t j = 0; j < myVars.size(); j++) { for (size_t j = 0; j < myVars.size(); j++) {
out << '"' << facNodes_[i]->getLabel() << '"' ; out << '"' << facNodes_[i]->getLabel() << '"' ;
out << " -- " ; out << " -- " ;
out << '"' << myVars[j]->label() << '"' << endl; out << '"' << myVars[j]->label() << '"' << std::endl;
} }
} }
out << "}" << endl; out << "}" << std::endl;
out.close(); out.close();
} }
void FactorGraph&
FactorGraph::ignoreLines (std::ifstream& is) const FactorGraph::operator= (const FactorGraph& fg)
{ {
string ignoreStr; if (this != &fg) {
while (is.peek() == '#' || is.peek() == '\n') { for (size_t i = 0; i < varNodes_.size(); i++) {
getline (is, ignoreStr); delete varNodes_[i];
} }
varNodes_.clear();
for (size_t i = 0; i < facNodes_.size(); i++) {
delete facNodes_[i];
}
facNodes_.clear();
varMap_.clear();
clone (fg);
}
return *this;
}
FactorGraph
FactorGraph::readFromUaiFormat (const char* fileName)
{
std::ifstream is (fileName);
if (!is.is_open()) {
std::cerr << "Error: couldn't open file '" << fileName << "'." ;
std::cerr << std::endl;
exit (EXIT_FAILURE);
}
FactorGraph fg;
ignoreLines (is);
std::string line;
getline (is, line);
if (line == "BAYES") {
fg.bayesFactors_ = true;
} else if (line == "MARKOV") {
fg.bayesFactors_ = false;
} else {
std::cerr << "Error: the type of network is missing." << std::endl;
exit (EXIT_FAILURE);
}
// read the number of vars
ignoreLines (is);
unsigned nrVars;
is >> nrVars;
// read the range of each var
ignoreLines (is);
Ranges ranges (nrVars);
for (unsigned i = 0; i < nrVars; i++) {
is >> ranges[i];
}
unsigned nrFactors;
unsigned nrArgs;
unsigned vid;
is >> nrFactors;
std::vector<VarIds> allVarIds;
std::vector<Ranges> allRanges;
for (unsigned i = 0; i < nrFactors; i++) {
ignoreLines (is);
is >> nrArgs;
allVarIds.push_back ({ });
allRanges.push_back ({ });
for (unsigned j = 0; j < nrArgs; j++) {
is >> vid;
if (vid >= ranges.size()) {
std::cerr << "Error: invalid variable identifier `" << vid << "'" ;
std::cerr << ". Identifiers must be between 0 and " ;
std::cerr << ranges.size() - 1 << "." << std::endl;
exit (EXIT_FAILURE);
}
allVarIds.back().push_back (vid);
allRanges.back().push_back (ranges[vid]);
}
}
// read the parameters
unsigned nrParams;
for (unsigned i = 0; i < nrFactors; i++) {
ignoreLines (is);
is >> nrParams;
if (nrParams != Util::sizeExpected (allRanges[i])) {
std::cerr << "Error: invalid number of parameters for factor nº " ;
std::cerr << i << ", " << Util::sizeExpected (allRanges[i]);
std::cerr << " expected, " << nrParams << " given." << std::endl;
exit (EXIT_FAILURE);
}
Params params (nrParams);
for (unsigned j = 0; j < nrParams; j++) {
is >> params[j];
}
if (Globals::logDomain) {
Util::log (params);
}
Factor f (allVarIds[i], allRanges[i], params);
if (fg.bayesFactors_ && allVarIds[i].size() > 1) {
// In this format the child is the last variable,
// move it to be the first
std::swap (allVarIds[i].front(), allVarIds[i].back());
f.reorderArguments (allVarIds[i]);
}
fg.addFactor (f);
}
is.close();
return fg;
}
FactorGraph
FactorGraph::readFromLibDaiFormat (const char* fileName)
{
std::ifstream is (fileName);
if (!is.is_open()) {
std::cerr << "Error: couldn't open file '" << fileName << "'." ;
std::cerr << std::endl;
exit (EXIT_FAILURE);
}
FactorGraph fg;
ignoreLines (is);
unsigned nrFactors;
unsigned nrArgs;
VarId vid;
is >> nrFactors;
for (unsigned i = 0; i < nrFactors; i++) {
ignoreLines (is);
// read the factor arguments
is >> nrArgs;
VarIds vids;
for (unsigned j = 0; j < nrArgs; j++) {
ignoreLines (is);
is >> vid;
vids.push_back (vid);
}
// read ranges
Ranges ranges (nrArgs);
for (unsigned j = 0; j < nrArgs; j++) {
ignoreLines (is);
is >> ranges[j];
VarNode* var = fg.getVarNode (vids[j]);
if (var && ranges[j] != var->range()) {
std::cerr << "Error: variable `" << vids[j] << "' appears" ;
std::cerr << " in two or more factors with a different range." ;
std::cerr << std::endl;
exit (EXIT_FAILURE);
}
}
// read parameters
ignoreLines (is);
unsigned nNonzeros;
is >> nNonzeros;
Params params (Util::sizeExpected (ranges), 0);
for (unsigned j = 0; j < nNonzeros; j++) {
ignoreLines (is);
unsigned index;
is >> index;
ignoreLines (is);
double val;
is >> val;
params[index] = val;
}
if (Globals::logDomain) {
Util::log (params);
}
std::reverse (vids.begin(), vids.end());
std::reverse (ranges.begin(), ranges.end());
Factor f (vids, ranges, params);
std::reverse (vids.begin(), vids.end());
f.reorderArguments (vids);
fg.addFactor (f);
}
is.close();
return fg;
}
void
FactorGraph::clone (const FactorGraph& fg)
{
const VarNodes& varNodes = fg.varNodes();
for (size_t i = 0; i < varNodes.size(); i++) {
addVarNode (new VarNode (varNodes[i]));
}
const FacNodes& facNodes = fg.facNodes();
for (size_t i = 0; i < facNodes.size(); i++) {
FacNode* facNode = new FacNode (facNodes[i]->factor());
addFacNode (facNode);
const VarNodes& neighs = facNodes[i]->neighbors();
for (size_t j = 0; j < neighs.size(); j++) {
addEdge (varNodes_[neighs[j]->getIndex()], facNode);
}
}
bayesFactors_ = fg.bayesianFactors();
} }
bool bool
FactorGraph::containsCycle (void) const FactorGraph::containsCycle() const
{ {
vector<bool> visitedVars (varNodes_.size(), false); std::vector<bool> visitedVars (varNodes_.size(), false);
vector<bool> visitedFactors (facNodes_.size(), false); std::vector<bool> visitedFactors (facNodes_.size(), false);
for (size_t i = 0; i < varNodes_.size(); i++) { for (size_t i = 0; i < varNodes_.size(); i++) {
int v = varNodes_[i]->getIndex(); int v = varNodes_[i]->getIndex();
if (!visitedVars[v]) { if (!visitedVars[v]) {
@ -435,8 +466,8 @@ bool
FactorGraph::containsCycle ( FactorGraph::containsCycle (
const VarNode* v, const VarNode* v,
const FacNode* p, const FacNode* p,
vector<bool>& visitedVars, std::vector<bool>& visitedVars,
vector<bool>& visitedFactors) const std::vector<bool>& visitedFactors) const
{ {
visitedVars[v->getIndex()] = true; visitedVars[v->getIndex()] = true;
const FacNodes& adjacencies = v->neighbors(); const FacNodes& adjacencies = v->neighbors();
@ -460,8 +491,8 @@ bool
FactorGraph::containsCycle ( FactorGraph::containsCycle (
const FacNode* v, const FacNode* v,
const VarNode* p, const VarNode* p,
vector<bool>& visitedVars, std::vector<bool>& visitedVars,
vector<bool>& visitedFactors) const std::vector<bool>& visitedFactors) const
{ {
visitedFactors[v->getIndex()] = true; visitedFactors[v->getIndex()] = true;
const VarNodes& adjacencies = v->neighbors(); const VarNodes& adjacencies = v->neighbors();
@ -479,3 +510,16 @@ FactorGraph::containsCycle (
return false; // no cycle detected in this component return false; // no cycle detected in this component
} }
void
FactorGraph::ignoreLines (std::ifstream& is)
{
std::string ignoreStr;
while (is.peek() == '#' || is.peek() == '\n') {
getline (is, ignoreStr);
}
}
} // namespace Horus

View File

@ -1,28 +1,32 @@
#ifndef HORUS_FACTORGRAPH_H #ifndef YAP_PACKAGES_CLPBN_HORUS_FACTORGRAPH_H_
#define HORUS_FACTORGRAPH_H #define YAP_PACKAGES_CLPBN_HORUS_FACTORGRAPH_H_
#include <vector> #include <vector>
#include <unordered_map>
#include <string>
#include <fstream>
#include "Factor.h" #include "Factor.h"
#include "BayesBallGraph.h" #include "BayesBallGraph.h"
#include "Horus.h" #include "Horus.h"
using namespace std;
namespace Horus {
class FacNode; class FacNode;
class VarNode : public Var
{ class VarNode : public Var {
public: public:
VarNode (VarId varId, unsigned nrStates, VarNode (VarId varId, unsigned nrStates,
int evidence = Constants::NO_EVIDENCE) int evidence = Constants::unobserved)
: Var (varId, nrStates, evidence) { } : Var (varId, nrStates, evidence) { }
VarNode (const Var* v) : Var (v) { } VarNode (const Var* v) : Var (v) { }
void addNeighbor (FacNode* fn) { neighs_.push_back (fn); } void addNeighbor (FacNode* fn) { neighs_.push_back (fn); }
const FacNodes& neighbors (void) const { return neighs_; } const FacNodes& neighbors() const { return neighs_; }
private: private:
FacNodes neighs_; FacNodes neighs_;
@ -32,24 +36,23 @@ class VarNode : public Var
class FacNode class FacNode {
{
public: public:
FacNode (const Factor& f) : factor_(f), index_(-1) { } FacNode (const Factor& f) : factor_(f), index_(-1) { }
const Factor& factor (void) const { return factor_; } const Factor& factor() const { return factor_; }
Factor& factor (void) { return factor_; } Factor& factor() { return factor_; }
void addNeighbor (VarNode* vn) { neighs_.push_back (vn); } void addNeighbor (VarNode* vn) { neighs_.push_back (vn); }
const VarNodes& neighbors (void) const { return neighs_; } const VarNodes& neighbors() const { return neighs_; }
size_t getIndex (void) const { return index_; } size_t getIndex() const { return index_; }
void setIndex (size_t index) { index_ = index; } void setIndex (size_t index) { index_ = index; }
string getLabel (void) { return factor_.getLabel(); } std::string getLabel() { return factor_.getLabel(); }
private: private:
VarNodes neighs_; VarNodes neighs_;
@ -61,36 +64,27 @@ class FacNode
class FactorGraph class FactorGraph {
{
public: public:
FactorGraph (void) : bayesFactors_(false) { } FactorGraph() : bayesFactors_(false) { }
FactorGraph (const FactorGraph&); FactorGraph (const FactorGraph&);
~FactorGraph (void); ~FactorGraph();
const VarNodes& varNodes (void) const { return varNodes_; } const VarNodes& varNodes() const { return varNodes_; }
const FacNodes& facNodes (void) const { return facNodes_; } const FacNodes& facNodes() const { return facNodes_; }
void setFactorsAsBayesian (void) { bayesFactors_ = true; } void setFactorsAsBayesian() { bayesFactors_ = true; }
bool bayesianFactors (void) const { return bayesFactors_; } bool bayesianFactors() const { return bayesFactors_; }
size_t nrVarNodes (void) const { return varNodes_.size(); } size_t nrVarNodes() const { return varNodes_.size(); }
size_t nrFacNodes (void) const { return facNodes_.size(); } size_t nrFacNodes() const { return facNodes_.size(); }
VarNode* getVarNode (VarId vid) const VarNode* getVarNode (VarId vid) const;
{
VarMap::const_iterator it = varMap_.find (vid);
return it != varMap_.end() ? it->second : 0;
}
void readFromUaiFormat (const char*);
void readFromLibDaiFormat (const char*);
void addFactor (const Factor& factor); void addFactor (const Factor& factor);
@ -100,11 +94,11 @@ class FactorGraph
void addEdge (VarNode*, FacNode*); void addEdge (VarNode*, FacNode*);
bool isTree (void) const; bool isTree() const;
BayesBallGraph& getStructure (void); BayesBallGraph& getStructure();
void print (void) const; void print() const;
void exportToLibDai (const char*) const; void exportToLibDai (const char*) const;
@ -112,67 +106,80 @@ class FactorGraph
void exportToGraphViz (const char*) const; void exportToGraphViz (const char*) const;
static bool exportToLibDai (void) { return exportLd_; } FactorGraph& operator= (const FactorGraph&);
static bool exportToUai (void) { return exportUai_; } static FactorGraph readFromUaiFormat (const char*);
static bool exportGraphViz (void) { return exportGv_; } static FactorGraph readFromLibDaiFormat (const char*);
static bool printFactorGraph (void) { return printFg_; } static bool exportToLibDai() { return exportLd_; }
static void enableExportToLibDai (void) { exportLd_ = true; } static bool exportToUai() { return exportUai_; }
static void disableExportToLibDai (void) { exportLd_ = false; } static bool exportGraphViz() { return exportGv_; }
static void enableExportToUai (void) { exportUai_ = true; } static bool printFactorGraph() { return printFg_; }
static void disableExportToUai (void) { exportUai_ = false; } static void enableExportToLibDai() { exportLd_ = true; }
static void enableExportToGraphViz (void) { exportGv_ = true; } static void disableExportToLibDai() { exportLd_ = false; }
static void disableExportToGraphViz (void) { exportGv_ = false; } static void enableExportToUai() { exportUai_ = true; }
static void enablePrintFactorGraph (void) { printFg_ = true; } static void disableExportToUai() { exportUai_ = false; }
static void disablePrintFactorGraph (void) { printFg_ = false; } static void enableExportToGraphViz() { exportGv_ = true; }
static void disableExportToGraphViz() { exportGv_ = false; }
static void enablePrintFactorGraph() { printFg_ = true; }
static void disablePrintFactorGraph() { printFg_ = false; }
private: private:
void ignoreLines (std::ifstream&) const; typedef std::unordered_map<unsigned, VarNode*> VarMap;
bool containsCycle (void) const; void clone (const FactorGraph& fg);
bool containsCycle() const;
bool containsCycle (const VarNode*, const FacNode*, bool containsCycle (const VarNode*, const FacNode*,
vector<bool>&, vector<bool>&) const; std::vector<bool>&, std::vector<bool>&) const;
bool containsCycle (const FacNode*, const VarNode*, bool containsCycle (const FacNode*, const VarNode*,
vector<bool>&, vector<bool>&) const; std::vector<bool>&, std::vector<bool>&) const;
static void ignoreLines (std::ifstream&);
VarNodes varNodes_; VarNodes varNodes_;
FacNodes facNodes_; FacNodes facNodes_;
VarMap varMap_;
BayesBallGraph structure_; BayesBallGraph structure_;
bool bayesFactors_; bool bayesFactors_;
typedef unordered_map<unsigned, VarNode*> VarMap;
VarMap varMap_;
static bool exportLd_; static bool exportLd_;
static bool exportUai_; static bool exportUai_;
static bool exportGv_; static bool exportGv_;
static bool printFg_; static bool printFg_;
DISALLOW_ASSIGN (FactorGraph);
}; };
struct sortByVarId inline VarNode*
FactorGraph::getVarNode (VarId vid) const
{ {
VarMap::const_iterator it = varMap_.find (vid);
return it != varMap_.end() ? it->second : 0;
}
struct sortByVarId {
bool operator()(VarNode* vn1, VarNode* vn2) { bool operator()(VarNode* vn1, VarNode* vn2) {
return vn1->varId() < vn2->varId(); return vn1->varId() < vn2->varId();
} }};
};
} // namespace Horus
#endif // HORUS_FACTORGRAPH_H #endif // YAP_PACKAGES_CLPBN_HORUS_FACTORGRAPH_H_

View File

@ -0,0 +1,256 @@
#include <cassert>
#include "GenericFactor.h"
#include "ProbFormula.h"
#include "Indexer.h"
namespace Horus {
template <typename T> const T&
GenericFactor<T>::argument (size_t idx) const
{
assert (idx < args_.size());
return args_[idx];
}
template <typename T> T&
GenericFactor<T>::argument (size_t idx)
{
assert (idx < args_.size());
return args_[idx];
}
template <typename T> unsigned
GenericFactor<T>::range (size_t idx) const
{
assert (idx < ranges_.size());
return ranges_[idx];
}
template <typename T> bool
GenericFactor<T>::contains (const T& arg) const
{
return Util::contains (args_, arg);
}
template <typename T> bool
GenericFactor<T>::contains (const std::vector<T>& args) const
{
for (size_t i = 0; i < args.size(); i++) {
if (contains (args[i]) == false) {
return false;
}
}
return true;
}
template <typename T> void
GenericFactor<T>::setParams (const Params& newParams)
{
params_ = newParams;
assert (params_.size() == Util::sizeExpected (ranges_));
}
template <typename T> double
GenericFactor<T>::operator[] (size_t idx) const
{
assert (idx < params_.size());
return params_[idx];
}
template <typename T> double&
GenericFactor<T>::operator[] (size_t idx)
{
assert (idx < params_.size());
return params_[idx];
}
template <typename T> GenericFactor<T>&
GenericFactor<T>::multiply (const GenericFactor<T>& g)
{
if (args_ == g.arguments()) {
// optimization
Globals::logDomain
? params_ += g.params()
: params_ *= g.params();
return *this;
}
unsigned range_prod = 1;
bool share_arguments = false;
const std::vector<T>& g_args = g.arguments();
const Ranges& g_ranges = g.ranges();
const Params& g_params = g.params();
for (size_t i = 0; i < g_args.size(); i++) {
size_t idx = indexOf (g_args[i]);
if (idx == args_.size()) {
range_prod *= g_ranges[i];
args_.push_back (g_args[i]);
ranges_.push_back (g_ranges[i]);
} else {
share_arguments = true;
}
}
if (share_arguments == false) {
// optimization
cartesianProduct (g_params.begin(), g_params.end());
} else {
extend (range_prod);
Params::iterator it = params_.begin();
MapIndexer indexer (args_, ranges_, g_args, g_ranges);
if (Globals::logDomain) {
for (; indexer.valid(); ++it, ++indexer) {
*it += g_params[indexer];
}
} else {
for (; indexer.valid(); ++it, ++indexer) {
*it *= g_params[indexer];
}
}
}
return *this;
}
template <typename T> void
GenericFactor<T>::sumOutIndex (size_t idx)
{
assert (idx < args_.size());
assert (args_.size() > 1);
size_t new_size = params_.size() / ranges_[idx];
Params newps (new_size, LogAware::addIdenty());
Params::const_iterator first = params_.begin();
Params::const_iterator last = params_.end();
MapIndexer indexer (ranges_, idx);
if (Globals::logDomain) {
for (; first != last; ++indexer) {
newps[indexer] = Util::logSum (newps[indexer], *first++);
}
} else {
for (; first != last; ++indexer) {
newps[indexer] += *first++;
}
}
params_ = newps;
args_.erase (args_.begin() + idx);
ranges_.erase (ranges_.begin() + idx);
}
template <typename T> void
GenericFactor<T>::absorveEvidence (const T& arg, unsigned obsIdx)
{
size_t idx = indexOf (arg);
assert (idx != args_.size());
assert (obsIdx < ranges_[idx]);
Params newps;
newps.reserve (params_.size() / ranges_[idx]);
Indexer indexer (ranges_);
for (unsigned i = 0; i < obsIdx; ++i) {
indexer.incrementDimension (idx);
}
while (indexer.valid()) {
newps.push_back (params_[indexer]);
indexer.incrementExceptDimension (idx);
}
params_ = newps;
args_.erase (args_.begin() + idx);
ranges_.erase (ranges_.begin() + idx);
}
template <typename T> void
GenericFactor<T>::reorderArguments (const std::vector<T>& new_args)
{
assert (new_args.size() == args_.size());
if (new_args == args_) {
return; // already on the desired order
}
Ranges new_ranges;
for (size_t i = 0; i < new_args.size(); i++) {
size_t idx = indexOf (new_args[i]);
assert (idx != args_.size());
new_ranges.push_back (ranges_[idx]);
}
Params newps;
newps.reserve (params_.size());
MapIndexer indexer (new_args, new_ranges, args_, ranges_);
for (; indexer.valid(); ++indexer) {
newps.push_back (params_[indexer]);
}
params_ = newps;
args_ = new_args;
ranges_ = new_ranges;
}
template <typename T> void
GenericFactor<T>::extend (unsigned range_prod)
{
Params backup = params_;
params_.clear();
params_.reserve (backup.size() * range_prod);
Params::const_iterator first = backup.begin();
Params::const_iterator last = backup.end();
for (; first != last; ++first) {
for (unsigned reps = 0; reps < range_prod; ++reps) {
params_.push_back (*first);
}
}
}
template <typename T> void
GenericFactor<T>::cartesianProduct (
Params::const_iterator first2,
Params::const_iterator last2)
{
Params backup = params_;
params_.clear();
params_.reserve (params_.size() * (last2 - first2));
Params::const_iterator first1 = backup.begin();
Params::const_iterator last1 = backup.end();
Params::const_iterator tmp;
if (Globals::logDomain) {
for (; first1 != last1; ++first1) {
for (tmp = first2; tmp != last2; ++tmp) {
params_.push_back ((*first1) + (*tmp));
}
}
} else {
for (; first1 != last1; ++first1) {
for (tmp = first2; tmp != last2; ++tmp) {
params_.push_back ((*first1) * (*tmp));
}
}
}
}
template class GenericFactor<VarId>;
template class GenericFactor<ProbFormula>;
} // namespace Horus

View File

@ -0,0 +1,76 @@
#ifndef YAP_PACKAGES_CLPBN_HORUS_GENERICFACTOR_H_
#define YAP_PACKAGES_CLPBN_HORUS_GENERICFACTOR_H_
#include <vector>
#include "Util.h"
namespace Horus {
template <typename T>
class GenericFactor {
public:
const std::vector<T>& arguments() const { return args_; }
std::vector<T>& arguments() { return args_; }
const Ranges& ranges() const { return ranges_; }
const Params& params() const { return params_; }
Params& params() { return params_; }
size_t nrArguments() const { return args_.size(); }
size_t size() const { return params_.size(); }
unsigned distId() const { return distId_; }
void setDistId (unsigned id) { distId_ = id; }
void normalize() { LogAware::normalize (params_); }
size_t indexOf (const T& t) const { return Util::indexOf (args_, t); }
const T& argument (size_t idx) const;
T& argument (size_t idx);
unsigned range (size_t idx) const;
bool contains (const T& arg) const;
bool contains (const std::vector<T>& args) const;
void setParams (const Params& newParams);
double operator[] (size_t idx) const;
double& operator[] (size_t idx);
GenericFactor<T>& multiply (const GenericFactor<T>& g);
void sumOutIndex (size_t idx);
void absorveEvidence (const T& arg, unsigned obsIdx);
void reorderArguments (const std::vector<T>& new_args);
protected:
std::vector<T> args_;
Ranges ranges_;
Params params_;
unsigned distId_;
private:
void extend (unsigned range_prod);
void cartesianProduct (
Params::const_iterator first2, Params::const_iterator last2);
};
} // namespace Horus
#endif // YAP_PACKAGES_CLPBN_HORUS_GENERICFACTOR_H_

View File

@ -1,10 +1,20 @@
#include <cassert>
#include <vector>
#include <string>
#include <iostream>
#include <iomanip>
#include "GroundSolver.h" #include "GroundSolver.h"
#include "VarElim.h" #include "VarElim.h"
#include "BeliefProp.h" #include "BeliefProp.h"
#include "CountingBp.h" #include "CountingBp.h"
#include "Indexer.h"
#include "Util.h" #include "Util.h"
namespace Horus {
void void
GroundSolver::printAnswer (const VarIds& vids) GroundSolver::printAnswer (const VarIds& vids)
{ {
@ -19,20 +29,21 @@ GroundSolver::printAnswer (const VarIds& vids)
} }
if (unobservedVids.empty() == false) { if (unobservedVids.empty() == false) {
Params res = solveQuery (unobservedVids); Params res = solveQuery (unobservedVids);
vector<string> stateLines = Util::getStateLines (unobservedVars); std::vector<std::string> stateLines =
Util::getStateLines (unobservedVars);
for (size_t i = 0; i < res.size(); i++) { for (size_t i = 0; i < res.size(); i++) {
cout << "P(" << stateLines[i] << ") = " ; std::cout << "P(" << stateLines[i] << ") = " ;
cout << std::setprecision (Constants::PRECISION) << res[i]; std::cout << std::setprecision (Constants::precision) << res[i];
cout << endl; std::cout << std::endl;
} }
cout << endl; std::cout << std::endl;
} }
} }
void void
GroundSolver::printAllPosterioris (void) GroundSolver::printAllPosterioris()
{ {
VarNodes vars = fg.varNodes(); VarNodes vars = fg.varNodes();
std::sort (vars.begin(), vars.end(), sortByVarId()); std::sort (vars.begin(), vars.end(), sortByVarId());
@ -57,9 +68,9 @@ GroundSolver::getJointByConditioning (
GroundSolver* solver = 0; GroundSolver* solver = 0;
switch (solverType) { switch (solverType) {
case GroundSolverType::BP: solver = new BeliefProp (fg); break; case GroundSolverType::bpSolver: solver = new BeliefProp (fg); break;
case GroundSolverType::CBP: solver = new CountingBp (fg); break; case GroundSolverType::CbpSolver: solver = new CountingBp (fg); break;
case GroundSolverType::VE: solver = new VarElim (fg); break; case GroundSolverType::veSolver: solver = new VarElim (fg); break;
} }
Params prevBeliefs = solver->solveQuery ({jointVarIds[0]}); Params prevBeliefs = solver->solveQuery ({jointVarIds[0]});
VarIds observedVids = {jointVars[0]->varId()}; VarIds observedVids = {jointVars[0]->varId()};
@ -80,9 +91,9 @@ GroundSolver::getJointByConditioning (
} }
delete solver; delete solver;
switch (solverType) { switch (solverType) {
case GroundSolverType::BP: solver = new BeliefProp (fg); break; case GroundSolverType::bpSolver: solver = new BeliefProp (fg); break;
case GroundSolverType::CBP: solver = new CountingBp (fg); break; case GroundSolverType::CbpSolver: solver = new CountingBp (fg); break;
case GroundSolverType::VE: solver = new VarElim (fg); break; case GroundSolverType::veSolver: solver = new VarElim (fg); break;
} }
Params beliefs = solver->solveQuery ({jointVarIds[i]}); Params beliefs = solver->solveQuery ({jointVarIds[i]});
for (size_t k = 0; k < beliefs.size(); k++) { for (size_t k = 0; k < beliefs.size(); k++) {
@ -105,3 +116,5 @@ GroundSolver::getJointByConditioning (
return prevBeliefs; return prevBeliefs;
} }
} // namespace Horus

View File

@ -1,16 +1,13 @@
#ifndef HORUS_GROUNDSOLVER_H #ifndef YAP_PACKAGES_CLPBN_HORUS_GROUNDSOLVER_H_
#define HORUS_GROUNDSOLVER_H #define YAP_PACKAGES_CLPBN_HORUS_GROUNDSOLVER_H_
#include <iomanip>
#include "FactorGraph.h" #include "FactorGraph.h"
#include "Horus.h" #include "Horus.h"
using namespace std; namespace Horus {
class GroundSolver class GroundSolver {
{
public: public:
GroundSolver (const FactorGraph& factorGraph) : fg(factorGraph) { } GroundSolver (const FactorGraph& factorGraph) : fg(factorGraph) { }
@ -18,11 +15,11 @@ class GroundSolver
virtual Params solveQuery (VarIds queryVids) = 0; virtual Params solveQuery (VarIds queryVids) = 0;
virtual void printSolverFlags (void) const = 0; virtual void printSolverFlags() const = 0;
void printAnswer (const VarIds& vids); void printAnswer (const VarIds& vids);
void printAllPosterioris (void); void printAllPosterioris();
static Params getJointByConditioning (GroundSolverType, static Params getJointByConditioning (GroundSolverType,
FactorGraph, const VarIds& jointVarIds); FactorGraph, const VarIds& jointVarIds);
@ -30,8 +27,11 @@ class GroundSolver
protected: protected:
const FactorGraph& fg; const FactorGraph& fg;
private:
DISALLOW_COPY_AND_ASSIGN (GroundSolver); DISALLOW_COPY_AND_ASSIGN (GroundSolver);
}; };
#endif // HORUS_GROUNDSOLVER_H } // namespace Horus
#endif // YAP_PACKAGES_CLPBN_HORUS_GROUNDSOLVER_H_

View File

@ -7,6 +7,8 @@
#include "Util.h" #include "Util.h"
namespace Horus {
HistogramSet::HistogramSet (unsigned size, unsigned range) HistogramSet::HistogramSet (unsigned size, unsigned range)
{ {
size_ = size; size_ = size;
@ -17,7 +19,7 @@ HistogramSet::HistogramSet (unsigned size, unsigned range)
void void
HistogramSet::nextHistogram (void) HistogramSet::nextHistogram()
{ {
for (size_t i = hist_.size() - 1; i-- > 0; ) { for (size_t i = hist_.size() - 1; i-- > 0; ) {
if (hist_[i] > 0) { if (hist_[i] > 0) {
@ -43,7 +45,7 @@ HistogramSet::operator[] (size_t idx) const
unsigned unsigned
HistogramSet::nrHistograms (void) const HistogramSet::nrHistograms() const
{ {
return HistogramSet::nrHistograms (size_, hist_.size()); return HistogramSet::nrHistograms (size_, hist_.size());
} }
@ -51,7 +53,7 @@ HistogramSet::nrHistograms (void) const
void void
HistogramSet::reset (void) HistogramSet::reset()
{ {
std::fill (hist_.begin() + 1, hist_.end(), 0); std::fill (hist_.begin() + 1, hist_.end(), 0);
hist_[0] = size_; hist_[0] = size_;
@ -59,12 +61,12 @@ HistogramSet::reset (void)
vector<Histogram> std::vector<Histogram>
HistogramSet::getHistograms (unsigned N, unsigned R) HistogramSet::getHistograms (unsigned N, unsigned R)
{ {
HistogramSet hs (N, R); HistogramSet hs (N, R);
unsigned H = hs.nrHistograms(); unsigned H = hs.nrHistograms();
vector<Histogram> histograms; std::vector<Histogram> histograms;
histograms.reserve (H); histograms.reserve (H);
for (unsigned i = 0; i < H; i++) { for (unsigned i = 0; i < H; i++) {
histograms.push_back (hs.hist_); histograms.push_back (hs.hist_);
@ -86,9 +88,9 @@ HistogramSet::nrHistograms (unsigned N, unsigned R)
size_t size_t
HistogramSet::findIndex ( HistogramSet::findIndex (
const Histogram& h, const Histogram& h,
const vector<Histogram>& hists) const std::vector<Histogram>& hists)
{ {
vector<Histogram>::const_iterator it = std::lower_bound ( std::vector<Histogram>::const_iterator it = std::lower_bound (
hists.begin(), hists.end(), h, std::greater<Histogram>()); hists.begin(), hists.end(), h, std::greater<Histogram>());
assert (it != hists.end() && *it == h); assert (it != hists.end() && *it == h);
return std::distance (hists.begin(), it); return std::distance (hists.begin(), it);
@ -96,13 +98,13 @@ HistogramSet::findIndex (
vector<double> std::vector<double>
HistogramSet::getNumAssigns (unsigned N, unsigned R) HistogramSet::getNumAssigns (unsigned N, unsigned R)
{ {
HistogramSet hs (N, R); HistogramSet hs (N, R);
double N_fac = Util::logFactorial (N); double N_fac = Util::logFactorial (N);
unsigned H = hs.nrHistograms(); unsigned H = hs.nrHistograms();
vector<double> numAssigns; std::vector<double> numAssigns;
numAssigns.reserve (H); numAssigns.reserve (H);
for (unsigned h = 0; h < H; h++) { for (unsigned h = 0; h < H; h++) {
double prod = 0.0; double prod = 0.0;
@ -118,14 +120,6 @@ HistogramSet::getNumAssigns (unsigned N, unsigned R)
ostream& operator<< (ostream &os, const HistogramSet& hs)
{
os << "#" << hs.hist_;
return os;
}
unsigned unsigned
HistogramSet::maxCount (size_t idx) const HistogramSet::maxCount (size_t idx) const
{ {
@ -144,3 +138,14 @@ HistogramSet::clearAfter (size_t idx)
std::fill (hist_.begin() + idx + 1, hist_.end(), 0); std::fill (hist_.begin() + idx + 1, hist_.end(), 0);
} }
std::ostream&
operator<< (std::ostream& os, const HistogramSet& hs)
{
os << "#" << hs.hist_;
return os;
}
} // namespace Horus

View File

@ -1,50 +1,51 @@
#ifndef HORUS_HISTOGRAM_H #ifndef YAP_PACKAGES_CLPBN_HORUS_HISTOGRAM_H_
#define HORUS_HISTOGRAM_H #define YAP_PACKAGES_CLPBN_HORUS_HISTOGRAM_H_
#include <vector> #include <vector>
#include <ostream> #include <ostream>
#include "Horus.h" #include "Horus.h"
using namespace std; typedef std::vector<unsigned> Histogram;
typedef vector<unsigned> Histogram;
class HistogramSet namespace Horus {
{
class HistogramSet {
public: public:
HistogramSet (unsigned, unsigned); HistogramSet (unsigned, unsigned);
void nextHistogram (void); void nextHistogram();
unsigned operator[] (size_t idx) const; unsigned operator[] (size_t idx) const;
unsigned nrHistograms (void) const; unsigned nrHistograms() const;
void reset (void); void reset();
static vector<Histogram> getHistograms (unsigned, unsigned); static std::vector<Histogram> getHistograms (unsigned, unsigned);
static unsigned nrHistograms (unsigned, unsigned); static unsigned nrHistograms (unsigned, unsigned);
static size_t findIndex ( static size_t findIndex (
const Histogram&, const vector<Histogram>&); const Histogram&, const std::vector<Histogram>&);
static vector<double> getNumAssigns (unsigned, unsigned); static std::vector<double> getNumAssigns (unsigned, unsigned);
friend std::ostream& operator<< (ostream &os, const HistogramSet& hs);
private: private:
unsigned maxCount (size_t) const; unsigned maxCount (size_t) const;
void clearAfter (size_t); void clearAfter (size_t);
friend std::ostream& operator<< (std::ostream&, const HistogramSet&);
unsigned size_; unsigned size_;
Histogram hist_; Histogram hist_;
DISALLOW_COPY_AND_ASSIGN (HistogramSet); DISALLOW_COPY_AND_ASSIGN (HistogramSet);
}; };
#endif // HORUS_HISTOGRAM_H } // namespace Horus
#endif // YAP_PACKAGES_CLPBN_HORUS_HISTOGRAM_H_

View File

@ -1,5 +1,5 @@
#ifndef HORUS_HORUS_H #ifndef YAP_PACKAGES_CLPBN_HORUS_HORUS_H_
#define HORUS_HORUS_H #define YAP_PACKAGES_CLPBN_HORUS_HORUS_H_
#define DISALLOW_COPY_AND_ASSIGN(TypeName) \ #define DISALLOW_COPY_AND_ASSIGN(TypeName) \
TypeName(const TypeName&); \ TypeName(const TypeName&); \
@ -14,6 +14,9 @@
#include <vector> #include <vector>
#include <string> #include <string>
namespace Horus {
class Var; class Var;
class Factor; class Factor;
class VarNode; class VarNode;
@ -31,19 +34,17 @@ typedef std::vector<unsigned> Ranges;
typedef unsigned long long ullong; typedef unsigned long long ullong;
enum LiftedSolverType enum class LiftedSolverType {
{ lveSolver, // generalized counting first-order variable elimination
LVE, // generalized counting first-order variable elimination (GC-FOVE) lbpSolver, // lifted first-order belief propagation
LBP, // lifted first-order belief propagation lkcSolver // lifted first-order knowledge compilation
LKC // lifted first-order knowledge compilation
}; };
enum GroundSolverType enum class GroundSolverType {
{ veSolver, // variable elimination
VE, // variable elimination bpSolver, // belief propagation
BP, // belief propagation CbpSolver // counting belief propagation
CBP // counting belief propagation
}; };
@ -57,20 +58,22 @@ extern unsigned verbosity;
extern LiftedSolverType liftedSolver; extern LiftedSolverType liftedSolver;
extern GroundSolverType groundSolver; extern GroundSolverType groundSolver;
}; }
namespace Constants { namespace Constants {
// show message calculation for belief propagation // show message calculation for belief propagation
const bool SHOW_BP_CALCS = false; const bool showBpCalcs = false;
const int NO_EVIDENCE = -1; const int unobserved = -1;
// number of digits to show when printing a parameter // number of digits to show when printing a parameter
const unsigned PRECISION = 6; const unsigned precision = 8;
}; }
#endif // HORUS_HORUS_H } // namespace Horus
#endif // YAP_PACKAGES_CLPBN_HORUS_HORUS_H_

View File

@ -1,53 +1,61 @@
#include <cstdlib> #include <cassert>
#include <string>
#include <iostream> #include <iostream>
#include <sstream>
#include "FactorGraph.h" #include "FactorGraph.h"
#include "VarElim.h" #include "VarElim.h"
#include "BeliefProp.h" #include "BeliefProp.h"
#include "CountingBp.h" #include "CountingBp.h"
using namespace std;
namespace {
int readHorusFlags (int, const char* []); int readHorusFlags (int, const char* []);
void readFactorGraph (FactorGraph&, const char*);
VarIds readQueryAndEvidence (FactorGraph&, int, const char* [], int);
void runSolver (const FactorGraph&, const VarIds&); void readFactorGraph (Horus::FactorGraph&, const char*);
const string USAGE = "usage: ./hcli [solver=hve|bp|cbp] \ Horus::VarIds readQueryAndEvidence (
Horus::FactorGraph&, int, const char* [], int);
void runSolver (const Horus::FactorGraph&, const Horus::VarIds&);
const std::string usage = "usage: ./hcli [solver=hve|bp|cbp] \
[<OPTION>=<VALUE>]... <FILE> [<VAR>|<VAR>=<EVIDENCE>]... " ; [<OPTION>=<VALUE>]... <FILE> [<VAR>|<VAR>=<EVIDENCE>]... " ;
}
int int
main (int argc, const char* argv[]) main (int argc, const char* argv[])
{ {
if (argc <= 1) { if (argc <= 1) {
cerr << "Error: no probabilistic graphical model was given." << endl; std::cerr << "Error: no probabilistic graphical model was given." ;
cerr << USAGE << endl; std::cerr << std::endl << usage << std::endl;
exit (EXIT_FAILURE); exit (EXIT_FAILURE);
} }
int idx = readHorusFlags (argc, argv); int idx = readHorusFlags (argc, argv);
FactorGraph fg; Horus::FactorGraph fg;
readFactorGraph (fg, argv[idx]); readFactorGraph (fg, argv[idx]);
VarIds queryIds = readQueryAndEvidence (fg, argc, argv, idx + 1); Horus::VarIds queryIds
if (FactorGraph::exportToLibDai()) { = readQueryAndEvidence (fg, argc, argv, idx + 1);
if (Horus::FactorGraph::exportToLibDai()) {
fg.exportToLibDai ("model.fg"); fg.exportToLibDai ("model.fg");
} }
if (FactorGraph::exportToUai()) { if (Horus::FactorGraph::exportToUai()) {
fg.exportToUai ("model.uai"); fg.exportToUai ("model.uai");
} }
if (FactorGraph::exportGraphViz()) { if (Horus::FactorGraph::exportGraphViz()) {
fg.exportToGraphViz ("model.dot"); fg.exportToGraphViz ("model.dot");
} }
if (FactorGraph::printFactorGraph()) { if (Horus::FactorGraph::printFactorGraph()) {
fg.print(); fg.print();
} }
if (Globals::verbosity > 0) { if (Horus::Globals::verbosity > 0) {
cout << "factor graph contains " ; std::cout << "factor graph contains " ;
cout << fg.nrVarNodes() << " variables and " ; std::cout << fg.nrVarNodes() << " variables and " ;
cout << fg.nrFacNodes() << " factors " << endl; std::cout << fg.nrFacNodes() << " factors " << std::endl;
} }
runSolver (fg, queryIds); runSolver (fg, queryIds);
return 0; return 0;
@ -55,29 +63,31 @@ main (int argc, const char* argv[])
namespace {
int int
readHorusFlags (int argc, const char* argv[]) readHorusFlags (int argc, const char* argv[])
{ {
int i = 1; int i = 1;
for (; i < argc; i++) { for (; i < argc; i++) {
const string& arg = argv[i]; const std::string& arg = argv[i];
size_t pos = arg.find ('='); size_t pos = arg.find ('=');
if (pos == std::string::npos) { if (pos == std::string::npos) {
return i; return i;
} }
string leftArg = arg.substr (0, pos); std::string leftArg = arg.substr (0, pos);
string rightArg = arg.substr (pos + 1); std::string rightArg = arg.substr (pos + 1);
if (leftArg.empty()) { if (leftArg.empty()) {
cerr << "Error: missing left argument." << endl; std::cerr << "Error: missing left argument." << std::endl;
cerr << USAGE << endl; std::cerr << usage << std::endl;
exit (EXIT_FAILURE); exit (EXIT_FAILURE);
} }
if (rightArg.empty()) { if (rightArg.empty()) {
cerr << "Error: missing right argument." << endl; std::cerr << "Error: missing right argument." << std::endl;
cerr << USAGE << endl; std::cerr << usage << std::endl;
exit (EXIT_FAILURE); exit (EXIT_FAILURE);
} }
Util::setHorusFlag (leftArg, rightArg); Horus::Util::setHorusFlag (leftArg, rightArg);
} }
return i + 1; return i + 1;
} }
@ -85,84 +95,84 @@ readHorusFlags (int argc, const char* argv[])
void void
readFactorGraph (FactorGraph& fg, const char* s) readFactorGraph (Horus::FactorGraph& fg, const char* s)
{ {
string fileName (s); std::string fileName (s);
string extension = fileName.substr (fileName.find_last_of ('.') + 1); std::string extension = fileName.substr (fileName.find_last_of ('.') + 1);
if (extension == "uai") { if (extension == "uai") {
fg.readFromUaiFormat (fileName.c_str()); fg = Horus::FactorGraph::readFromUaiFormat (fileName.c_str());
} else if (extension == "fg") { } else if (extension == "fg") {
fg.readFromLibDaiFormat (fileName.c_str()); fg = Horus::FactorGraph::readFromLibDaiFormat (fileName.c_str());
} else { } else {
cerr << "Error: the probabilistic graphical model must be " ; std::cerr << "Error: the probabilistic graphical model must be " ;
cerr << "defined either in a UAI or libDAI file." << endl; std::cerr << "defined either in a UAI or libDAI file." << std::endl;
exit (EXIT_FAILURE); exit (EXIT_FAILURE);
} }
} }
VarIds Horus::VarIds
readQueryAndEvidence ( readQueryAndEvidence (
FactorGraph& fg, Horus::FactorGraph& fg,
int argc, int argc,
const char* argv[], const char* argv[],
int start) int start)
{ {
VarIds queryIds; Horus::VarIds queryIds;
for (int i = start; i < argc; i++) { for (int i = start; i < argc; i++) {
const string& arg = argv[i]; const std::string& arg = argv[i];
if (arg.find ('=') == std::string::npos) { if (arg.find ('=') == std::string::npos) {
if (Util::isInteger (arg) == false) { if (Horus::Util::isInteger (arg) == false) {
cerr << "Error: `" << arg << "' " ; std::cerr << "Error: `" << arg << "' " ;
cerr << "is not a variable id." ; std::cerr << "is not a variable id." ;
cerr << endl; std::cerr << std::endl;
exit (EXIT_FAILURE); exit (EXIT_FAILURE);
} }
VarId vid = Util::stringToUnsigned (arg); Horus::VarId vid = Horus::Util::stringToUnsigned (arg);
VarNode* queryVar = fg.getVarNode (vid); Horus::VarNode* queryVar = fg.getVarNode (vid);
if (queryVar == false) { if (queryVar == false) {
cerr << "Error: unknow variable with id " ; std::cerr << "Error: unknow variable with id " ;
cerr << "`" << vid << "'." << endl; std::cerr << "`" << vid << "'." << std::endl;
exit (EXIT_FAILURE); exit (EXIT_FAILURE);
} }
queryIds.push_back (vid); queryIds.push_back (vid);
} else { } else {
size_t pos = arg.find ('='); size_t pos = arg.find ('=');
string leftArg = arg.substr (0, pos); std::string leftArg = arg.substr (0, pos);
string rightArg = arg.substr (pos + 1); std::string rightArg = arg.substr (pos + 1);
if (leftArg.empty()) { if (leftArg.empty()) {
cerr << "Error: missing left argument." << endl; std::cerr << "Error: missing left argument." << std::endl;
cerr << USAGE << endl; std::cerr << usage << std::endl;
exit (EXIT_FAILURE); exit (EXIT_FAILURE);
} }
if (Util::isInteger (leftArg) == false) { if (Horus::Util::isInteger (leftArg) == false) {
cerr << "Error: `" << leftArg << "' " ; std::cerr << "Error: `" << leftArg << "' " ;
cerr << "is not a variable id." << endl ; std::cerr << "is not a variable id." << std::endl;
exit (EXIT_FAILURE); exit (EXIT_FAILURE);
} }
VarId vid = Util::stringToUnsigned (leftArg); Horus::VarId vid = Horus::Util::stringToUnsigned (leftArg);
VarNode* observedVar = fg.getVarNode (vid); Horus::VarNode* observedVar = fg.getVarNode (vid);
if (observedVar == false) { if (observedVar == false) {
cerr << "Error: unknow variable with id " ; std::cerr << "Error: unknow variable with id " ;
cerr << "`" << vid << "'." << endl; std::cerr << "`" << vid << "'." << std::endl;
exit (EXIT_FAILURE); exit (EXIT_FAILURE);
} }
if (rightArg.empty()) { if (rightArg.empty()) {
cerr << "Error: missing right argument." << endl; std::cerr << "Error: missing right argument." << std::endl;
cerr << USAGE << endl; std::cerr << usage << std::endl;
exit (EXIT_FAILURE); exit (EXIT_FAILURE);
} }
if (Util::isInteger (rightArg) == false) { if (Horus::Util::isInteger (rightArg) == false) {
cerr << "Error: `" << rightArg << "' " ; std::cerr << "Error: `" << rightArg << "' " ;
cerr << "is not a state index." << endl ; std::cerr << "is not a state index." << std::endl;
exit (EXIT_FAILURE); exit (EXIT_FAILURE);
} }
unsigned stateIdx = Util::stringToUnsigned (rightArg); unsigned stateIdx = Horus::Util::stringToUnsigned (rightArg);
if (observedVar->isValidState (stateIdx) == false) { if (observedVar->isValidState (stateIdx) == false) {
cerr << "Error: `" << stateIdx << "' " ; std::cerr << "Error: `" << stateIdx << "' " ;
cerr << "is not a valid state index for variable with id " ; std::cerr << "is not a valid state index for variable with id " ;
cerr << "`" << vid << "'." << endl; std::cerr << "`" << vid << "'." << std::endl;
exit (EXIT_FAILURE); exit (EXIT_FAILURE);
} }
observedVar->setEvidence (stateIdx); observedVar->setEvidence (stateIdx);
@ -174,25 +184,27 @@ readQueryAndEvidence (
void void
runSolver (const FactorGraph& fg, const VarIds& queryIds) runSolver (
const Horus::FactorGraph& fg,
const Horus::VarIds& queryIds)
{ {
GroundSolver* solver = 0; Horus::GroundSolver* solver = 0;
switch (Globals::groundSolver) { switch (Horus::Globals::groundSolver) {
case GroundSolverType::VE: case Horus::GroundSolverType::veSolver:
solver = new VarElim (fg); solver = new Horus::VarElim (fg);
break; break;
case GroundSolverType::BP: case Horus::GroundSolverType::bpSolver:
solver = new BeliefProp (fg); solver = new Horus::BeliefProp (fg);
break; break;
case GroundSolverType::CBP: case Horus::GroundSolverType::CbpSolver:
solver = new CountingBp (fg); solver = new Horus::CountingBp (fg);
break; break;
default: default:
assert (false); assert (false);
} }
if (Globals::verbosity > 0) { if (Horus::Globals::verbosity > 0) {
solver->printSolverFlags(); solver->printSolverFlags();
cout << endl; std::cout << std::endl;
} }
if (queryIds.empty()) { if (queryIds.empty()) {
solver->printAllPosterioris(); solver->printAllPosterioris();
@ -202,3 +214,5 @@ runSolver (const FactorGraph& fg, const VarIds& queryIds)
delete solver; delete solver;
} }
}

View File

@ -1,7 +1,8 @@
#include <cstdlib> #include <cassert>
#include <vector> #include <vector>
#include <unordered_map>
#include <string>
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
@ -20,24 +21,28 @@
#include "BayesBall.h" #include "BayesBall.h"
using namespace std; namespace Horus {
typedef std::pair<ParfactorList*, ObservedFormulas*> LiftedNetwork; namespace {
Parfactor* readParfactor (YAP_Term); Parfactor* readParfactor (YAP_Term);
void readLiftedEvidence (YAP_Term, ObservedFormulas&); ObservedFormulas* readLiftedEvidence (YAP_Term);
vector<unsigned> readUnsignedList (YAP_Term list); std::vector<unsigned> readUnsignedList (YAP_Term);
Params readParameters (YAP_Term); Params readParameters (YAP_Term);
YAP_Term fillAnswersPrologList (vector<Params>& results); YAP_Term fillSolutionList (const std::vector<Params>&);
}
typedef std::pair<ParfactorList*, ObservedFormulas*> LiftedNetwork;
int int
createLiftedNetwork (void) createLiftedNetwork()
{ {
Parfactors parfactors; Parfactors parfactors;
YAP_Term parfactorList = YAP_ARG1; YAP_Term parfactorList = YAP_ARG1;
@ -52,7 +57,7 @@ createLiftedNetwork (void)
Util::printHeader ("INITIAL PARFACTORS"); Util::printHeader ("INITIAL PARFACTORS");
for (size_t i = 0; i < parfactors.size(); i++) { for (size_t i = 0; i < parfactors.size(); i++) {
parfactors[i]->print(); parfactors[i]->print();
cout << endl; std::cout << std::endl;
} }
} }
@ -64,21 +69,20 @@ createLiftedNetwork (void)
} }
// read evidence // read evidence
ObservedFormulas* obsFormulas = new ObservedFormulas(); ObservedFormulas* obsFormulas = readLiftedEvidence (YAP_ARG2);
readLiftedEvidence (YAP_ARG2, *(obsFormulas));
LiftedNetwork* net = new LiftedNetwork (pfList, obsFormulas); LiftedNetwork* network = new LiftedNetwork (pfList, obsFormulas);
YAP_Int p = (YAP_Int) (net); YAP_Int p = (YAP_Int) (network);
return YAP_Unify (YAP_MkIntTerm (p), YAP_ARG3); return YAP_Unify (YAP_MkIntTerm (p), YAP_ARG3);
} }
int int
createGroundNetwork (void) createGroundNetwork()
{ {
string factorsType ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG1))); std::string factorsType ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG1)));
FactorGraph* fg = new FactorGraph(); FactorGraph* fg = new FactorGraph();
if (factorsType == "bayes") { if (factorsType == "bayes") {
fg->setFactorsAsBayesian(); fg->setFactorsAsBayesian();
@ -121,9 +125,9 @@ createGroundNetwork (void)
fg->print(); fg->print();
} }
if (Globals::verbosity > 0) { if (Globals::verbosity > 0) {
cout << "factor graph contains " ; std::cout << "factor graph contains " ;
cout << fg->nrVarNodes() << " variables and " ; std::cout << fg->nrVarNodes() << " variables and " ;
cout << fg->nrFacNodes() << " factors " << endl; std::cout << fg->nrFacNodes() << " factors " << std::endl;
} }
YAP_Int p = (YAP_Int) (fg); YAP_Int p = (YAP_Int) (fg);
return YAP_Unify (YAP_MkIntTerm (p), YAP_ARG4); return YAP_Unify (YAP_MkIntTerm (p), YAP_ARG4);
@ -132,45 +136,46 @@ createGroundNetwork (void)
int int
runLiftedSolver (void) runLiftedSolver()
{ {
LiftedNetwork* network = (LiftedNetwork*) YAP_IntOfTerm (YAP_ARG1); LiftedNetwork* network = (LiftedNetwork*) YAP_IntOfTerm (YAP_ARG1);
ParfactorList pfListCopy (*network->first); ParfactorList copy (*network->first);
LiftedOperations::absorveEvidence (pfListCopy, *network->second); LiftedOperations::absorveEvidence (copy, *network->second);
LiftedSolver* solver = 0; LiftedSolver* solver = 0;
switch (Globals::liftedSolver) { switch (Globals::liftedSolver) {
case LiftedSolverType::LVE: solver = new LiftedVe (pfListCopy); break; case LiftedSolverType::lveSolver: solver = new LiftedVe (copy); break;
case LiftedSolverType::LBP: solver = new LiftedBp (pfListCopy); break; case LiftedSolverType::lbpSolver: solver = new LiftedBp (copy); break;
case LiftedSolverType::LKC: solver = new LiftedKc (pfListCopy); break; case LiftedSolverType::lkcSolver: solver = new LiftedKc (copy); break;
} }
if (Globals::verbosity > 0) { if (Globals::verbosity > 0) {
solver->printSolverFlags(); solver->printSolverFlags();
cout << endl; std::cout << std::endl;
} }
YAP_Term taskList = YAP_ARG2; YAP_Term taskList = YAP_ARG2;
vector<Params> results; std::vector<Params> results;
while (taskList != YAP_TermNil()) { while (taskList != YAP_TermNil()) {
Grounds queryVars; Grounds queryVars;
YAP_Term jointList = YAP_HeadOfTerm (taskList); YAP_Term jointList = YAP_HeadOfTerm (taskList);
while (jointList != YAP_TermNil()) { while (jointList != YAP_TermNil()) {
YAP_Term ground = YAP_HeadOfTerm (jointList); YAP_Term ground = YAP_HeadOfTerm (jointList);
if (YAP_IsAtomTerm (ground)) { if (YAP_IsAtomTerm (ground)) {
string name ((char*) YAP_AtomName (YAP_AtomOfTerm (ground))); std::string name ((char*) YAP_AtomName (YAP_AtomOfTerm (ground)));
queryVars.push_back (Ground (LiftedUtils::getSymbol (name))); queryVars.push_back (Ground (LiftedUtils::getSymbol (name)));
} else { } else {
assert (YAP_IsApplTerm (ground)); assert (YAP_IsApplTerm (ground));
YAP_Functor yapFunctor = YAP_FunctorOfTerm (ground); YAP_Functor yapFunctor = YAP_FunctorOfTerm (ground);
string name ((char*) (YAP_AtomName (YAP_NameOfFunctor (yapFunctor)))); std::string name ((char*) (YAP_AtomName (
YAP_NameOfFunctor (yapFunctor))));
unsigned arity = (unsigned) YAP_ArityOfFunctor (yapFunctor); unsigned arity = (unsigned) YAP_ArityOfFunctor (yapFunctor);
Symbol functor = LiftedUtils::getSymbol (name); Symbol functor = LiftedUtils::getSymbol (name);
Symbols args; Symbols args;
for (unsigned i = 1; i <= arity; i++) { for (unsigned i = 1; i <= arity; i++) {
YAP_Term ti = YAP_ArgOfTerm (i, ground); YAP_Term ti = YAP_ArgOfTerm (i, ground);
assert (YAP_IsAtomTerm (ti)); assert (YAP_IsAtomTerm (ti));
string arg ((char *) YAP_AtomName (YAP_AtomOfTerm (ti))); std::string arg ((char *) YAP_AtomName (YAP_AtomOfTerm (ti)));
args.push_back (LiftedUtils::getSymbol (arg)); args.push_back (LiftedUtils::getSymbol (arg));
} }
queryVars.push_back (Ground (functor, args)); queryVars.push_back (Ground (functor, args));
@ -183,17 +188,17 @@ runLiftedSolver (void)
delete solver; delete solver;
return YAP_Unify (fillAnswersPrologList (results), YAP_ARG3); return YAP_Unify (fillSolutionList (results), YAP_ARG3);
} }
int int
runGroundSolver (void) runGroundSolver()
{ {
FactorGraph* fg = (FactorGraph*) YAP_IntOfTerm (YAP_ARG1); FactorGraph* fg = (FactorGraph*) YAP_IntOfTerm (YAP_ARG1);
vector<VarIds> tasks; std::vector<VarIds> tasks;
YAP_Term taskList = YAP_ARG2; YAP_Term taskList = YAP_ARG2;
while (taskList != YAP_TermNil()) { while (taskList != YAP_TermNil()) {
tasks.push_back (readUnsignedList (YAP_HeadOfTerm (taskList))); tasks.push_back (readUnsignedList (YAP_HeadOfTerm (taskList)));
@ -213,17 +218,17 @@ runGroundSolver (void)
GroundSolver* solver = 0; GroundSolver* solver = 0;
CountingBp::setFindIdenticalFactorsFlag (false); CountingBp::setFindIdenticalFactorsFlag (false);
switch (Globals::groundSolver) { switch (Globals::groundSolver) {
case GroundSolverType::VE: solver = new VarElim (*mfg); break; case GroundSolverType::veSolver: solver = new VarElim (*mfg); break;
case GroundSolverType::BP: solver = new BeliefProp (*mfg); break; case GroundSolverType::bpSolver: solver = new BeliefProp (*mfg); break;
case GroundSolverType::CBP: solver = new CountingBp (*mfg); break; case GroundSolverType::CbpSolver: solver = new CountingBp (*mfg); break;
} }
if (Globals::verbosity > 0) { if (Globals::verbosity > 0) {
solver->printSolverFlags(); solver->printSolverFlags();
cout << endl; std::cout << std::endl;
} }
vector<Params> results; std::vector<Params> results;
results.reserve (tasks.size()); results.reserve (tasks.size());
for (size_t i = 0; i < tasks.size(); i++) { for (size_t i = 0; i < tasks.size(); i++) {
results.push_back (solver->solveQuery (tasks[i])); results.push_back (solver->solveQuery (tasks[i]));
@ -234,19 +239,19 @@ runGroundSolver (void)
delete mfg; delete mfg;
} }
return YAP_Unify (fillAnswersPrologList (results), YAP_ARG3); return YAP_Unify (fillSolutionList (results), YAP_ARG3);
} }
int int
setParfactorsParams (void) setParfactorsParams()
{ {
LiftedNetwork* network = (LiftedNetwork*) YAP_IntOfTerm (YAP_ARG1); LiftedNetwork* network = (LiftedNetwork*) YAP_IntOfTerm (YAP_ARG1);
ParfactorList* pfList = network->first; ParfactorList* pfList = network->first;
YAP_Term distIdsList = YAP_ARG2; YAP_Term distIdsList = YAP_ARG2;
YAP_Term paramsList = YAP_ARG3; YAP_Term paramsList = YAP_ARG3;
unordered_map<unsigned, Params> paramsMap; std::unordered_map<unsigned, Params> paramsMap;
while (distIdsList != YAP_TermNil()) { while (distIdsList != YAP_TermNil()) {
unsigned distId = (unsigned) YAP_IntOfTerm ( unsigned distId = (unsigned) YAP_IntOfTerm (
YAP_HeadOfTerm (distIdsList)); YAP_HeadOfTerm (distIdsList));
@ -267,12 +272,12 @@ setParfactorsParams (void)
int int
setFactorsParams (void) setFactorsParams()
{ {
FactorGraph* fg = (FactorGraph*) YAP_IntOfTerm (YAP_ARG1); FactorGraph* fg = (FactorGraph*) YAP_IntOfTerm (YAP_ARG1);
YAP_Term distIdsList = YAP_ARG2; YAP_Term distIdsList = YAP_ARG2;
YAP_Term paramsList = YAP_ARG3; YAP_Term paramsList = YAP_ARG3;
unordered_map<unsigned, Params> paramsMap; std::unordered_map<unsigned, Params> paramsMap;
while (distIdsList != YAP_TermNil()) { while (distIdsList != YAP_TermNil()) {
unsigned distId = (unsigned) YAP_IntOfTerm ( unsigned distId = (unsigned) YAP_IntOfTerm (
YAP_HeadOfTerm (distIdsList)); YAP_HeadOfTerm (distIdsList));
@ -293,10 +298,10 @@ setFactorsParams (void)
int int
setVarsInformation (void) setVarsInformation()
{ {
Var::clearVarsInfo(); Var::clearVarsInfo();
vector<string> labels; std::vector<std::string> labels;
YAP_Term labelsL = YAP_ARG1; YAP_Term labelsL = YAP_ARG1;
while (labelsL != YAP_TermNil()) { while (labelsL != YAP_TermNil()) {
YAP_Atom atom = YAP_AtomOfTerm (YAP_HeadOfTerm (labelsL)); YAP_Atom atom = YAP_AtomOfTerm (YAP_HeadOfTerm (labelsL));
@ -323,20 +328,20 @@ setVarsInformation (void)
int int
setHorusFlag (void) setHorusFlag()
{ {
string option ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG1))); std::string option ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG1)));
string value; std::string value;
if (option == "verbosity") { if (option == "verbosity") {
stringstream ss; std::stringstream ss;
ss << (int) YAP_IntOfTerm (YAP_ARG2); ss << (int) YAP_IntOfTerm (YAP_ARG2);
ss >> value; ss >> value;
} else if (option == "bp_accuracy") { } else if (option == "bp_accuracy") {
stringstream ss; std::stringstream ss;
ss << (float) YAP_FloatOfTerm (YAP_ARG2); ss << (float) YAP_FloatOfTerm (YAP_ARG2);
ss >> value; ss >> value;
} else if (option == "bp_max_iter") { } else if (option == "bp_max_iter") {
stringstream ss; std::stringstream ss;
ss << (int) YAP_IntOfTerm (YAP_ARG2); ss << (int) YAP_IntOfTerm (YAP_ARG2);
ss >> value; ss >> value;
} else { } else {
@ -348,7 +353,7 @@ setHorusFlag (void)
int int
freeGroundNetwork (void) freeGroundNetwork()
{ {
delete (FactorGraph*) YAP_IntOfTerm (YAP_ARG1); delete (FactorGraph*) YAP_IntOfTerm (YAP_ARG1);
return TRUE; return TRUE;
@ -357,7 +362,7 @@ freeGroundNetwork (void)
int int
freeLiftedNetwork (void) freeLiftedNetwork()
{ {
LiftedNetwork* network = (LiftedNetwork*) YAP_IntOfTerm (YAP_ARG1); LiftedNetwork* network = (LiftedNetwork*) YAP_IntOfTerm (YAP_ARG1);
delete network->first; delete network->first;
@ -368,6 +373,8 @@ freeLiftedNetwork (void)
namespace {
Parfactor* Parfactor*
readParfactor (YAP_Term pfTerm) readParfactor (YAP_Term pfTerm)
{ {
@ -386,23 +393,24 @@ readParfactor (YAP_Term pfTerm)
// read parametric random vars // read parametric random vars
ProbFormulas formulas; ProbFormulas formulas;
unsigned count = 0; unsigned count = 0;
unordered_map<YAP_Term, LogVar> lvMap; std::unordered_map<YAP_Term, LogVar> lvMap;
YAP_Term pvList = YAP_ArgOfTerm (2, pfTerm); YAP_Term pvList = YAP_ArgOfTerm (2, pfTerm);
while (pvList != YAP_TermNil()) { while (pvList != YAP_TermNil()) {
YAP_Term formulaTerm = YAP_HeadOfTerm (pvList); YAP_Term formulaTerm = YAP_HeadOfTerm (pvList);
if (YAP_IsAtomTerm (formulaTerm)) { if (YAP_IsAtomTerm (formulaTerm)) {
string name ((char*) YAP_AtomName (YAP_AtomOfTerm (formulaTerm))); std::string name ((char*) YAP_AtomName (YAP_AtomOfTerm (formulaTerm)));
Symbol functor = LiftedUtils::getSymbol (name); Symbol functor = LiftedUtils::getSymbol (name);
formulas.push_back (ProbFormula (functor, ranges[count])); formulas.push_back (ProbFormula (functor, ranges[count]));
} else { } else {
LogVars logVars; LogVars logVars;
YAP_Functor yapFunctor = YAP_FunctorOfTerm (formulaTerm); YAP_Functor yapFunctor = YAP_FunctorOfTerm (formulaTerm);
string name ((char*) YAP_AtomName (YAP_NameOfFunctor (yapFunctor))); std::string name ((char*) YAP_AtomName (
YAP_NameOfFunctor (yapFunctor)));
Symbol functor = LiftedUtils::getSymbol (name); Symbol functor = LiftedUtils::getSymbol (name);
unsigned arity = (unsigned) YAP_ArityOfFunctor (yapFunctor); unsigned arity = (unsigned) YAP_ArityOfFunctor (yapFunctor);
for (unsigned i = 1; i <= arity; i++) { for (unsigned i = 1; i <= arity; i++) {
YAP_Term ti = YAP_ArgOfTerm (i, formulaTerm); YAP_Term ti = YAP_ArgOfTerm (i, formulaTerm);
unordered_map<YAP_Term, LogVar>::iterator it = lvMap.find (ti); std::unordered_map<YAP_Term, LogVar>::iterator it = lvMap.find (ti);
if (it != lvMap.end()) { if (it != lvMap.end()) {
logVars.push_back (it->second); logVars.push_back (it->second);
} else { } else {
@ -418,7 +426,7 @@ readParfactor (YAP_Term pfTerm)
} }
// read the parameters // read the parameters
const Params& params = readParameters (YAP_ArgOfTerm (4, pfTerm)); Params params = readParameters (YAP_ArgOfTerm (4, pfTerm));
// read the constraint // read the constraint
Tuples tuples; Tuples tuples;
@ -434,10 +442,11 @@ readParfactor (YAP_Term pfTerm)
for (unsigned i = 1; i <= arity; i++) { for (unsigned i = 1; i <= arity; i++) {
YAP_Term ti = YAP_ArgOfTerm (i, term); YAP_Term ti = YAP_ArgOfTerm (i, term);
if (YAP_IsAtomTerm (ti) == false) { if (YAP_IsAtomTerm (ti) == false) {
cerr << "Error: the constraint contains free variables." << endl; std::cerr << "Error: the constraint contains free variables." ;
std::cerr << std::endl;
exit (EXIT_FAILURE); exit (EXIT_FAILURE);
} }
string name ((char*) YAP_AtomName (YAP_AtomOfTerm (ti))); std::string name ((char*) YAP_AtomName (YAP_AtomOfTerm (ti)));
tuple[i - 1] = LiftedUtils::getSymbol (name); tuple[i - 1] = LiftedUtils::getSymbol (name);
} }
tuples.push_back (tuple); tuples.push_back (tuple);
@ -449,55 +458,56 @@ readParfactor (YAP_Term pfTerm)
void ObservedFormulas*
readLiftedEvidence ( readLiftedEvidence (YAP_Term observedList)
YAP_Term observedList,
ObservedFormulas& obsFormulas)
{ {
ObservedFormulas* obsFormulas = new ObservedFormulas();
while (observedList != YAP_TermNil()) { while (observedList != YAP_TermNil()) {
YAP_Term pair = YAP_HeadOfTerm (observedList); YAP_Term pair = YAP_HeadOfTerm (observedList);
YAP_Term ground = YAP_ArgOfTerm (1, pair); YAP_Term ground = YAP_ArgOfTerm (1, pair);
Symbol functor; Symbol functor;
Symbols args; Symbols args;
if (YAP_IsAtomTerm (ground)) { if (YAP_IsAtomTerm (ground)) {
string name ((char*) YAP_AtomName (YAP_AtomOfTerm (ground))); std::string name ((char*) YAP_AtomName (YAP_AtomOfTerm (ground)));
functor = LiftedUtils::getSymbol (name); functor = LiftedUtils::getSymbol (name);
} else { } else {
assert (YAP_IsApplTerm (ground)); assert (YAP_IsApplTerm (ground));
YAP_Functor yapFunctor = YAP_FunctorOfTerm (ground); YAP_Functor yapFunctor = YAP_FunctorOfTerm (ground);
string name ((char*) (YAP_AtomName (YAP_NameOfFunctor (yapFunctor)))); std::string name ((char*) (YAP_AtomName (
YAP_NameOfFunctor (yapFunctor))));
functor = LiftedUtils::getSymbol (name); functor = LiftedUtils::getSymbol (name);
unsigned arity = (unsigned) YAP_ArityOfFunctor (yapFunctor); unsigned arity = (unsigned) YAP_ArityOfFunctor (yapFunctor);
for (unsigned i = 1; i <= arity; i++) { for (unsigned i = 1; i <= arity; i++) {
YAP_Term ti = YAP_ArgOfTerm (i, ground); YAP_Term ti = YAP_ArgOfTerm (i, ground);
assert (YAP_IsAtomTerm (ti)); assert (YAP_IsAtomTerm (ti));
string arg ((char *) YAP_AtomName (YAP_AtomOfTerm (ti))); std::string arg ((char *) YAP_AtomName (YAP_AtomOfTerm (ti)));
args.push_back (LiftedUtils::getSymbol (arg)); args.push_back (LiftedUtils::getSymbol (arg));
} }
} }
unsigned evidence = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (2, pair)); unsigned evidence = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (2, pair));
bool found = false; bool found = false;
for (size_t i = 0; i < obsFormulas.size(); i++) { for (size_t i = 0; i < obsFormulas->size(); i++) {
if (obsFormulas[i].functor() == functor && if ((*obsFormulas)[i].functor() == functor &&
obsFormulas[i].arity() == args.size() && (*obsFormulas)[i].arity() == args.size() &&
obsFormulas[i].evidence() == evidence) { (*obsFormulas)[i].evidence() == evidence) {
obsFormulas[i].addTuple (args); (*obsFormulas)[i].addTuple (args);
found = true; found = true;
} }
} }
if (found == false) { if (found == false) {
obsFormulas.push_back (ObservedFormula (functor, evidence, args)); obsFormulas->push_back (ObservedFormula (functor, evidence, args));
} }
observedList = YAP_TailOfTerm (observedList); observedList = YAP_TailOfTerm (observedList);
} }
return obsFormulas;
} }
vector<unsigned> std::vector<unsigned>
readUnsignedList (YAP_Term list) readUnsignedList (YAP_Term list)
{ {
vector<unsigned> vec; std::vector<unsigned> vec;
while (list != YAP_TermNil()) { while (list != YAP_TermNil()) {
vec.push_back ((unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (list))); vec.push_back ((unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (list)));
list = YAP_TailOfTerm (list); list = YAP_TailOfTerm (list);
@ -514,10 +524,11 @@ readParameters (YAP_Term paramL)
assert (YAP_IsPairTerm (paramL)); assert (YAP_IsPairTerm (paramL));
while (paramL != YAP_TermNil()) { while (paramL != YAP_TermNil()) {
YAP_Term hd = YAP_HeadOfTerm (paramL); YAP_Term hd = YAP_HeadOfTerm (paramL);
if (YAP_IsFloatTerm(hd)) if (YAP_IsFloatTerm (hd)) {
params.push_back ((double) YAP_FloatOfTerm (hd)); params.push_back ((double) YAP_FloatOfTerm (hd));
else } else {
params.push_back ((double) YAP_IntOfTerm (hd)); params.push_back ((double) YAP_IntOfTerm (hd));
}
paramL = YAP_TailOfTerm (paramL); paramL = YAP_TailOfTerm (paramL);
} }
if (Globals::logDomain) { if (Globals::logDomain) {
@ -529,17 +540,17 @@ readParameters (YAP_Term paramL)
YAP_Term YAP_Term
fillAnswersPrologList (vector<Params>& results) fillSolutionList (const std::vector<Params>& results)
{ {
YAP_Term list = YAP_TermNil(); YAP_Term list = YAP_TermNil();
for (size_t i = results.size(); i-- > 0; ) { for (size_t i = results.size(); i-- > 0; ) {
const Params& beliefs = results[i]; const Params& beliefs = results[i];
YAP_Term queryBeliefsL = YAP_TermNil(); YAP_Term queryBeliefsL = YAP_TermNil();
for (size_t j = beliefs.size(); j-- > 0; ) { for (size_t j = beliefs.size(); j-- > 0; ) {
YAP_Int sl1 = YAP_InitSlot (list); YAP_Int sl = YAP_InitSlot (list);
YAP_Term belief = YAP_MkFloatTerm (beliefs[j]); YAP_Term belief = YAP_MkFloatTerm (beliefs[j]);
queryBeliefsL = YAP_MkPairTerm (belief, queryBeliefsL); queryBeliefsL = YAP_MkPairTerm (belief, queryBeliefsL);
list = YAP_GetFromSlot (sl1); list = YAP_GetFromSlot (sl);
YAP_RecoverSlots (1); YAP_RecoverSlots (1);
} }
list = YAP_MkPairTerm (queryBeliefsL, list); list = YAP_MkPairTerm (queryBeliefsL, list);
@ -547,10 +558,12 @@ fillAnswersPrologList (vector<Params>& results)
return list; return list;
} }
}
extern "C" void extern "C" void
init_predicates (void) init_predicates()
{ {
YAP_UserCPredicate ("cpp_create_lifted_network", YAP_UserCPredicate ("cpp_create_lifted_network",
createLiftedNetwork, 3); createLiftedNetwork, 3);
@ -583,3 +596,5 @@ init_predicates (void)
freeGroundNetwork, 1); freeGroundNetwork, 1);
} }
} // namespace Horus

View File

@ -0,0 +1,32 @@
#include <sstream>
#include <iomanip>
#include "Indexer.h"
namespace Horus {
std::ostream&
operator<< (std::ostream& os, const Indexer& indexer)
{
os << "(" ;
os << std::setw (2) << std::setfill('0') << indexer.index_;
os << ") " ;
os << indexer.indices_;
return os;
}
std::ostream&
operator<< (std::ostream &os, const MapIndexer& indexer)
{
os << "(" ;
os << std::setw (2) << std::setfill('0') << indexer.index_;
os << ") " ;
os << indexer.indices_;
return os;
}
} // namespace Horus

View File

@ -1,19 +1,57 @@
#ifndef HORUS_INDEXER_H #ifndef YAP_PACKAGES_CLPBN_HORUS_INDEXER_H_
#define HORUS_INDEXER_H #define YAP_PACKAGES_CLPBN_HORUS_INDEXER_H_
#include <vector>
#include <algorithm> #include <algorithm>
#include <numeric> #include <numeric>
#include <sstream>
#include <iomanip>
#include "Util.h" #include "Util.h"
class Indexer namespace Horus {
{
class Indexer {
public: public:
Indexer (const Ranges& ranges, bool calcOffsets = true) Indexer (const Ranges& ranges, bool calcOffsets = true);
void increment();
void incrementDimension (size_t dim);
void incrementExceptDimension (size_t dim);
Indexer& operator++();
operator size_t() const;
unsigned operator[] (size_t dim) const;
bool valid() const;
void reset();
void resetDimension (size_t dim);
size_t size() const;
private:
void calculateOffsets();
friend std::ostream& operator<< (std::ostream&, const Indexer&);
size_t index_;
Ranges indices_;
const Ranges& ranges_;
size_t size_;
std::vector<size_t> offsets_;
DISALLOW_COPY_AND_ASSIGN (Indexer);
};
inline
Indexer::Indexer (const Ranges& ranges, bool calcOffsets)
: index_(0), indices_(ranges.size(), 0), ranges_(ranges), : index_(0), indices_(ranges.size(), 0), ranges_(ranges),
size_(Util::sizeExpected (ranges)) size_(Util::sizeExpected (ranges))
{ {
@ -22,7 +60,10 @@ class Indexer
} }
} }
void increment (void)
inline void
Indexer::increment()
{ {
for (size_t i = ranges_.size(); i-- > 0; ) { for (size_t i = ranges_.size(); i-- > 0; ) {
indices_[i] ++; indices_[i] ++;
@ -35,7 +76,10 @@ class Indexer
index_ ++; index_ ++;
} }
void incrementDimension (size_t dim)
inline void
Indexer::incrementDimension (size_t dim)
{ {
assert (dim < ranges_.size()); assert (dim < ranges_.size());
assert (ranges_.size() == offsets_.size()); assert (ranges_.size() == offsets_.size());
@ -44,7 +88,10 @@ class Indexer
index_ += offsets_[dim]; index_ += offsets_[dim];
} }
void incrementExceptDimension (size_t dim)
inline void
Indexer::incrementExceptDimension (size_t dim)
{ {
assert (ranges_.size() == offsets_.size()); assert (ranges_.size() == offsets_.size());
for (size_t i = ranges_.size(); i-- > 0; ) { for (size_t i = ranges_.size(); i-- > 0; ) {
@ -62,50 +109,71 @@ class Indexer
index_ = size_; index_ = size_;
} }
Indexer& operator++ (void)
inline Indexer&
Indexer::operator++()
{ {
increment(); increment();
return *this; return *this;
} }
operator size_t (void) const
inline
Indexer::operator size_t() const
{ {
return index_; return index_;
} }
unsigned operator[] (size_t dim) const
inline unsigned
Indexer::operator[] (size_t dim) const
{ {
assert (valid()); assert (valid());
assert (dim < ranges_.size()); assert (dim < ranges_.size());
return indices_[dim]; return indices_[dim];
} }
bool valid (void) const
inline bool
Indexer::valid() const
{ {
return index_ < size_; return index_ < size_;
} }
void reset (void)
inline void
Indexer::reset()
{ {
std::fill (indices_.begin(), indices_.end(), 0);
index_ = 0; index_ = 0;
std::fill (indices_.begin(), indices_.end(), 0);
} }
void resetDimension (size_t dim)
inline void
Indexer::resetDimension (size_t dim)
{ {
indices_[dim] = 0; indices_[dim] = 0;
index_ -= offsets_[dim] * ranges_[dim]; index_ -= offsets_[dim] * ranges_[dim];
} }
size_t size (void) const
inline size_t
Indexer::size() const
{ {
return size_ ; return size_ ;
} }
friend std::ostream& operator<< (std::ostream&, const Indexer&);
private:
void calculateOffsets (void) inline void
Indexer::calculateOffsets()
{ {
size_t prod = 1; size_t prod = 1;
offsets_.resize (ranges_.size()); offsets_.resize (ranges_.size());
@ -115,33 +183,49 @@ class Indexer
} }
} }
class MapIndexer {
public:
MapIndexer (const Ranges& ranges, const std::vector<bool>& mask);
MapIndexer (const Ranges& ranges, size_t dim);
template <typename T>
MapIndexer (
const std::vector<T>& allArgs,
const Ranges& allRanges,
const std::vector<T>& wantedArgs,
const Ranges& wantedRanges);
MapIndexer& operator++();
operator size_t() const;
unsigned operator[] (size_t dim) const;
bool valid() const;
void reset();
private:
friend std::ostream& operator<< (std::ostream&, const MapIndexer&);
size_t index_; size_t index_;
Ranges indices_; Ranges indices_;
const Ranges& ranges_; const Ranges& ranges_;
size_t size_; bool valid_;
vector<size_t> offsets_; std::vector<size_t> offsets_;
DISALLOW_COPY_AND_ASSIGN (Indexer); DISALLOW_COPY_AND_ASSIGN (MapIndexer);
}; };
inline std::ostream& inline
operator<< (std::ostream& os, const Indexer& indexer) MapIndexer::MapIndexer (
{ const Ranges& ranges,
os << "(" ; const std::vector<bool>& mask)
os << std::setw (2) << std::setfill('0') << indexer.index_;
os << ") " ;
os << indexer.indices_;
return os;
}
class MapIndexer
{
public:
MapIndexer (const Ranges& ranges, const vector<bool>& mask)
: index_(0), indices_(ranges.size(), 0), ranges_(ranges), : index_(0), indices_(ranges.size(), 0), ranges_(ranges),
valid_(true) valid_(true)
{ {
@ -156,7 +240,10 @@ class MapIndexer
assert (ranges.size() == mask.size()); assert (ranges.size() == mask.size());
} }
MapIndexer (const Ranges& ranges, size_t dim)
inline
MapIndexer::MapIndexer (const Ranges& ranges, size_t dim)
: index_(0), indices_(ranges.size(), 0), ranges_(ranges), : index_(0), indices_(ranges.size(), 0), ranges_(ranges),
valid_(true) valid_(true)
{ {
@ -170,17 +257,19 @@ class MapIndexer
} }
} }
template <typename T>
MapIndexer (
const vector<T>& allArgs, template <typename T> inline
MapIndexer::MapIndexer (
const std::vector<T>& allArgs,
const Ranges& allRanges, const Ranges& allRanges,
const vector<T>& wantedArgs, const std::vector<T>& wantedArgs,
const Ranges& wantedRanges) const Ranges& wantedRanges)
: index_(0), indices_(allArgs.size(), 0), ranges_(allRanges), : index_(0), indices_(allArgs.size(), 0), ranges_(allRanges),
valid_(true) valid_(true)
{ {
size_t prod = 1; size_t prod = 1;
vector<size_t> offsets (wantedRanges.size()); std::vector<size_t> offsets (wantedRanges.size());
for (size_t i = wantedRanges.size(); i-- > 0; ) { for (size_t i = wantedRanges.size(); i-- > 0; ) {
offsets[i] = prod; offsets[i] = prod;
prod *= wantedRanges[i]; prod *= wantedRanges[i];
@ -192,7 +281,10 @@ class MapIndexer
} }
} }
MapIndexer& operator++ (void)
inline MapIndexer&
MapIndexer::operator++()
{ {
assert (valid_); assert (valid_);
for (size_t i = ranges_.size(); i-- > 0; ) { for (size_t i = ranges_.size(); i-- > 0; ) {
@ -209,54 +301,43 @@ class MapIndexer
return *this; return *this;
} }
operator size_t (void) const
inline
MapIndexer::operator size_t() const
{ {
assert (valid()); assert (valid());
return index_; return index_;
} }
unsigned operator[] (size_t dim) const
inline unsigned
MapIndexer::operator[] (size_t dim) const
{ {
assert (valid()); assert (valid());
assert (dim < ranges_.size()); assert (dim < ranges_.size());
return indices_[dim]; return indices_[dim];
} }
bool valid (void) const
inline bool
MapIndexer::valid() const
{ {
return valid_; return valid_;
} }
void reset (void)
inline void
MapIndexer::reset()
{ {
std::fill (indices_.begin(), indices_.end(), 0);
index_ = 0; index_ = 0;
std::fill (indices_.begin(), indices_.end(), 0);
} }
friend std::ostream& operator<< (std::ostream&, const MapIndexer&); } // namespace Horus
private: #endif // YAP_PACKAGES_CLPBN_HORUS_INDEXER_H_
size_t index_;
Ranges indices_;
const Ranges& ranges_;
bool valid_;
vector<size_t> offsets_;
DISALLOW_COPY_AND_ASSIGN (MapIndexer);
};
inline std::ostream&
operator<< (std::ostream &os, const MapIndexer& indexer)
{
os << "(" ;
os << std::setw (2) << std::setfill('0') << indexer.index_;
os << ") " ;
os << indexer.indices_;
return os;
}
#endif // HORUS_INDEXER_H

View File

@ -1,9 +1,15 @@
#include <cassert>
#include <sstream>
#include "LiftedBp.h" #include "LiftedBp.h"
#include "LiftedOperations.h" #include "LiftedOperations.h"
#include "WeightedBp.h" #include "WeightedBp.h"
#include "FactorGraph.h" #include "FactorGraph.h"
namespace Horus {
LiftedBp::LiftedBp (const ParfactorList& parfactorList) LiftedBp::LiftedBp (const ParfactorList& parfactorList)
: LiftedSolver (parfactorList) : LiftedSolver (parfactorList)
{ {
@ -14,7 +20,7 @@ LiftedBp::LiftedBp (const ParfactorList& parfactorList)
LiftedBp::~LiftedBp (void) LiftedBp::~LiftedBp()
{ {
delete solver_; delete solver_;
delete fg_; delete fg_;
@ -27,7 +33,7 @@ LiftedBp::solveQuery (const Grounds& query)
{ {
assert (query.empty() == false); assert (query.empty() == false);
Params res; Params res;
vector<PrvGroup> groups = getQueryGroups (query); std::vector<PrvGroup> groups = getQueryGroups (query);
if (query.size() == 1) { if (query.size() == 1) {
res = solver_->getPosterioriOf (groups[0]); res = solver_->getPosterioriOf (groups[0]);
} else { } else {
@ -58,28 +64,29 @@ LiftedBp::solveQuery (const Grounds& query)
void void
LiftedBp::printSolverFlags (void) const LiftedBp::printSolverFlags() const
{ {
stringstream ss; std::stringstream ss;
ss << "lifted bp [" ; ss << "lifted bp [" ;
ss << "bp_msg_schedule=" ; ss << "bp_msg_schedule=" ;
typedef WeightedBp::MsgSchedule MsgSchedule;
switch (WeightedBp::msgSchedule()) { switch (WeightedBp::msgSchedule()) {
case MsgSchedule::SEQ_FIXED: ss << "seq_fixed"; break; case MsgSchedule::seqFixedSch: ss << "seq_fixed"; break;
case MsgSchedule::SEQ_RANDOM: ss << "seq_random"; break; case MsgSchedule::seqRandomSch: ss << "seq_random"; break;
case MsgSchedule::PARALLEL: ss << "parallel"; break; case MsgSchedule::parallelSch: ss << "parallel"; break;
case MsgSchedule::MAX_RESIDUAL: ss << "max_residual"; break; case MsgSchedule::maxResidualSch: ss << "max_residual"; break;
} }
ss << ",bp_max_iter=" << WeightedBp::maxIterations(); ss << ",bp_max_iter=" << WeightedBp::maxIterations();
ss << ",bp_accuracy=" << WeightedBp::accuracy(); ss << ",bp_accuracy=" << WeightedBp::accuracy();
ss << ",log_domain=" << Util::toString (Globals::logDomain); ss << ",log_domain=" << Util::toString (Globals::logDomain);
ss << "]" ; ss << "]" ;
cout << ss.str() << endl; std::cout << ss.str() << std::endl;
} }
void void
LiftedBp::refineParfactors (void) LiftedBp::refineParfactors()
{ {
pfList_ = parfactorList; pfList_ = parfactorList;
while (iterate() == false); while (iterate() == false);
@ -93,7 +100,7 @@ LiftedBp::refineParfactors (void)
bool bool
LiftedBp::iterate (void) LiftedBp::iterate()
{ {
ParfactorList::iterator it = pfList_.begin(); ParfactorList::iterator it = pfList_.begin();
while (it != pfList_.end()) { while (it != pfList_.end()) {
@ -114,10 +121,10 @@ LiftedBp::iterate (void)
vector<PrvGroup> std::vector<PrvGroup>
LiftedBp::getQueryGroups (const Grounds& query) LiftedBp::getQueryGroups (const Grounds& query)
{ {
vector<PrvGroup> queryGroups; std::vector<PrvGroup> queryGroups;
for (unsigned i = 0; i < query.size(); i++) { for (unsigned i = 0; i < query.size(); i++) {
ParfactorList::const_iterator it = pfList_.begin(); ParfactorList::const_iterator it = pfList_.begin();
for (; it != pfList_.end(); ++it) { for (; it != pfList_.end(); ++it) {
@ -134,12 +141,12 @@ LiftedBp::getQueryGroups (const Grounds& query)
void void
LiftedBp::createFactorGraph (void) LiftedBp::createFactorGraph()
{ {
fg_ = new FactorGraph(); fg_ = new FactorGraph();
ParfactorList::const_iterator it = pfList_.begin(); ParfactorList::const_iterator it = pfList_.begin();
for (; it != pfList_.end(); ++it) { for (; it != pfList_.end(); ++it) {
vector<PrvGroup> groups = (*it)->getAllGroups(); std::vector<PrvGroup> groups = (*it)->getAllGroups();
VarIds varIds; VarIds varIds;
for (size_t i = 0; i < groups.size(); i++) { for (size_t i = 0; i < groups.size(); i++) {
varIds.push_back (groups[i]); varIds.push_back (groups[i]);
@ -150,10 +157,10 @@ LiftedBp::createFactorGraph (void)
vector<vector<unsigned>> std::vector<std::vector<unsigned>>
LiftedBp::getWeights (void) const LiftedBp::getWeights() const
{ {
vector<vector<unsigned>> weights; std::vector<std::vector<unsigned>> weights;
weights.reserve (pfList_.size()); weights.reserve (pfList_.size());
ParfactorList::const_iterator it = pfList_.begin(); ParfactorList::const_iterator it = pfList_.begin();
for (; it != pfList_.end(); ++it) { for (; it != pfList_.end(); ++it) {
@ -196,7 +203,7 @@ LiftedBp::getJointByConditioning (
Grounds obsGrounds = {query[0]}; Grounds obsGrounds = {query[0]};
for (size_t i = 1; i < query.size(); i++) { for (size_t i = 1; i < query.size(); i++) {
Params newBeliefs; Params newBeliefs;
vector<ObservedFormula> obsFs; std::vector<ObservedFormula> obsFs;
Ranges obsRanges; Ranges obsRanges;
for (size_t j = 0; j < obsGrounds.size(); j++) { for (size_t j = 0; j < obsGrounds.size(); j++) {
obsFs.push_back (ObservedFormula ( obsFs.push_back (ObservedFormula (
@ -231,3 +238,5 @@ LiftedBp::getJointByConditioning (
return prevBeliefs; return prevBeliefs;
} }
} // namespace Horus

View File

@ -1,33 +1,38 @@
#ifndef HORUS_LIFTEDBP_H #ifndef YAP_PACKAGES_CLPBN_HORUS_LIFTEDBP_H_
#define HORUS_LIFTEDBP_H #define YAP_PACKAGES_CLPBN_HORUS_LIFTEDBP_H_
#include <vector>
#include "LiftedSolver.h" #include "LiftedSolver.h"
#include "ParfactorList.h" #include "ParfactorList.h"
#include "Indexer.h"
namespace Horus {
class FactorGraph; class FactorGraph;
class WeightedBp; class WeightedBp;
class LiftedBp : public LiftedSolver class LiftedBp : public LiftedSolver{
{
public: public:
LiftedBp (const ParfactorList& pfList); LiftedBp (const ParfactorList& pfList);
~LiftedBp (void); ~LiftedBp();
Params solveQuery (const Grounds&); Params solveQuery (const Grounds&);
void printSolverFlags (void) const; void printSolverFlags() const;
private: private:
void refineParfactors (void); void refineParfactors();
bool iterate (void); bool iterate();
vector<PrvGroup> getQueryGroups (const Grounds&); std::vector<PrvGroup> getQueryGroups (const Grounds&);
void createFactorGraph (void); void createFactorGraph();
vector<vector<unsigned>> getWeights (void) const; std::vector<std::vector<unsigned>> getWeights() const;
unsigned rangeOfGround (const Ground&); unsigned rangeOfGround (const Ground&);
@ -38,8 +43,9 @@ class LiftedBp : public LiftedSolver
FactorGraph* fg_; FactorGraph* fg_;
DISALLOW_COPY_AND_ASSIGN (LiftedBp); DISALLOW_COPY_AND_ASSIGN (LiftedBp);
}; };
#endif // HORUS_LIFTEDBP_H } // namespace Horus
#endif // YAP_PACKAGES_CLPBN_HORUS_LIFTEDBP_H_

View File

@ -1,11 +1,283 @@
#include <cassert>
#include <vector>
#include <unordered_map>
#include <string>
#include <fstream> #include <fstream>
#include <iostream>
#include "LiftedKc.h" #include "LiftedKc.h"
#include "LiftedWCNF.h"
#include "LiftedOperations.h" #include "LiftedOperations.h"
#include "Indexer.h" #include "Indexer.h"
OrNode::~OrNode (void) namespace Horus {
enum class CircuitNodeType {
orCnt,
andCnt,
setOrCnt,
setAndCnt,
incExcCnt,
leafCnt,
smoothCnt,
trueCnt,
compilationFailedCnt
};
class CircuitNode {
public:
CircuitNode() { }
virtual ~CircuitNode() { }
virtual double weight() const = 0;
};
class OrNode : public CircuitNode {
public:
OrNode() : CircuitNode(), leftBranch_(0), rightBranch_(0) { }
~OrNode();
CircuitNode** leftBranch () { return &leftBranch_; }
CircuitNode** rightBranch() { return &rightBranch_; }
double weight() const;
private:
CircuitNode* leftBranch_;
CircuitNode* rightBranch_;
};
class AndNode : public CircuitNode {
public:
AndNode() : CircuitNode(), leftBranch_(0), rightBranch_(0) { }
AndNode (CircuitNode* leftBranch, CircuitNode* rightBranch)
: CircuitNode(), leftBranch_(leftBranch),
rightBranch_(rightBranch) { }
~AndNode();
CircuitNode** leftBranch () { return &leftBranch_; }
CircuitNode** rightBranch() { return &rightBranch_; }
double weight() const;
private:
CircuitNode* leftBranch_;
CircuitNode* rightBranch_;
};
class SetOrNode : public CircuitNode {
public:
SetOrNode (unsigned nrGroundings)
: CircuitNode(), follow_(0), nrGroundings_(nrGroundings) { }
~SetOrNode();
CircuitNode** follow() { return &follow_; }
static unsigned nrPositives() { return nrPos_; }
static unsigned nrNegatives() { return nrNeg_; }
static bool isSet() { return nrPos_ >= 0; }
double weight() const;
private:
CircuitNode* follow_;
unsigned nrGroundings_;
static int nrPos_;
static int nrNeg_;
};
class SetAndNode : public CircuitNode {
public:
SetAndNode (unsigned nrGroundings)
: CircuitNode(), follow_(0), nrGroundings_(nrGroundings) { }
~SetAndNode();
CircuitNode** follow() { return &follow_; }
double weight() const;
private:
CircuitNode* follow_;
unsigned nrGroundings_;
};
class IncExcNode : public CircuitNode {
public:
IncExcNode()
: CircuitNode(), plus1Branch_(0), plus2Branch_(0), minusBranch_(0) { }
~IncExcNode();
CircuitNode** plus1Branch() { return &plus1Branch_; }
CircuitNode** plus2Branch() { return &plus2Branch_; }
CircuitNode** minusBranch() { return &minusBranch_; }
double weight() const;
private:
CircuitNode* plus1Branch_;
CircuitNode* plus2Branch_;
CircuitNode* minusBranch_;
};
class LeafNode : public CircuitNode {
public:
LeafNode (Clause* clause, const LiftedWCNF& lwcnf)
: CircuitNode(), clause_(clause), lwcnf_(lwcnf) { }
~LeafNode();
const Clause* clause() const { return clause_; }
Clause* clause() { return clause_; }
double weight() const;
private:
Clause* clause_;
const LiftedWCNF& lwcnf_;
};
class SmoothNode : public CircuitNode {
public:
SmoothNode (const Clauses& clauses, const LiftedWCNF& lwcnf)
: CircuitNode(), clauses_(clauses), lwcnf_(lwcnf) { }
~SmoothNode();
const Clauses& clauses() const { return clauses_; }
Clauses clauses() { return clauses_; }
double weight() const;
private:
Clauses clauses_;
const LiftedWCNF& lwcnf_;
};
class TrueNode : public CircuitNode {
public:
TrueNode() : CircuitNode() { }
double weight() const;
};
class CompilationFailedNode : public CircuitNode {
public:
CompilationFailedNode() : CircuitNode() { }
double weight() const;
};
class LiftedCircuit {
public:
LiftedCircuit (const LiftedWCNF* lwcnf);
~LiftedCircuit();
bool isCompilationSucceeded() const;
double getWeightedModelCount() const;
void exportToGraphViz (const char*);
private:
void compile (CircuitNode** follow, Clauses& clauses);
bool tryUnitPropagation (CircuitNode** follow, Clauses& clauses);
bool tryIndependence (CircuitNode** follow, Clauses& clauses);
bool tryShannonDecomp (CircuitNode** follow, Clauses& clauses);
bool tryInclusionExclusion (CircuitNode** follow, Clauses& clauses);
bool tryIndepPartialGrounding (CircuitNode** follow, Clauses& clauses);
bool tryIndepPartialGroundingAux (Clauses& clauses, ConstraintTree& ct,
LogVars& rootLogVars);
bool tryAtomCounting (CircuitNode** follow, Clauses& clauses);
void shatterCountedLogVars (Clauses& clauses);
bool shatterCountedLogVarsAux (Clauses& clauses);
bool shatterCountedLogVarsAux (Clauses& clauses,
size_t idx1, size_t idx2);
bool independentClause (Clause& clause, Clauses& otherClauses) const;
bool independentLiteral (const Literal& lit,
const Literals& otherLits) const;
LitLvTypesSet smoothCircuit (CircuitNode* node);
void createSmoothNode (const LitLvTypesSet& lids,
CircuitNode** prev);
std::vector<LogVarTypes> getAllPossibleTypes (unsigned nrLogVars) const;
bool containsTypes (const LogVarTypes& typesA,
const LogVarTypes& typesB) const;
CircuitNodeType getCircuitNodeType (const CircuitNode* node) const;
void exportToGraphViz (CircuitNode* node, std::ofstream&);
void printClauses (CircuitNode* node, std::ofstream&,
std::string extraOptions = "");
std::string escapeNode (const CircuitNode* node) const;
std::string getExplanationString (CircuitNode* node);
CircuitNode* root_;
const LiftedWCNF* lwcnf_;
bool compilationSucceeded_;
Clauses backupClauses_;
std::unordered_map<CircuitNode*, Clauses> originClausesMap_;
std::unordered_map<CircuitNode*, std::string> explanationMap_;
DISALLOW_COPY_AND_ASSIGN (LiftedCircuit);
};
OrNode::~OrNode()
{ {
delete leftBranch_; delete leftBranch_;
delete rightBranch_; delete rightBranch_;
@ -14,7 +286,7 @@ OrNode::~OrNode (void)
double double
OrNode::weight (void) const OrNode::weight() const
{ {
double lw = leftBranch_->weight(); double lw = leftBranch_->weight();
double rw = rightBranch_->weight(); double rw = rightBranch_->weight();
@ -23,7 +295,7 @@ OrNode::weight (void) const
AndNode::~AndNode (void) AndNode::~AndNode()
{ {
delete leftBranch_; delete leftBranch_;
delete rightBranch_; delete rightBranch_;
@ -32,7 +304,7 @@ AndNode::~AndNode (void)
double double
AndNode::weight (void) const AndNode::weight() const
{ {
double lw = leftBranch_->weight(); double lw = leftBranch_->weight();
double rw = rightBranch_->weight(); double rw = rightBranch_->weight();
@ -46,7 +318,7 @@ int SetOrNode::nrNeg_ = -1;
SetOrNode::~SetOrNode (void) SetOrNode::~SetOrNode()
{ {
delete follow_; delete follow_;
} }
@ -54,7 +326,7 @@ SetOrNode::~SetOrNode (void)
double double
SetOrNode::weight (void) const SetOrNode::weight() const
{ {
double weightSum = LogAware::addIdenty(); double weightSum = LogAware::addIdenty();
for (unsigned i = 0; i < nrGroundings_ + 1; i++) { for (unsigned i = 0; i < nrGroundings_ + 1; i++) {
@ -76,7 +348,7 @@ SetOrNode::weight (void) const
SetAndNode::~SetAndNode (void) SetAndNode::~SetAndNode()
{ {
delete follow_; delete follow_;
} }
@ -84,14 +356,14 @@ SetAndNode::~SetAndNode (void)
double double
SetAndNode::weight (void) const SetAndNode::weight() const
{ {
return LogAware::pow (follow_->weight(), nrGroundings_); return LogAware::pow (follow_->weight(), nrGroundings_);
} }
IncExcNode::~IncExcNode (void) IncExcNode::~IncExcNode()
{ {
delete plus1Branch_; delete plus1Branch_;
delete plus2Branch_; delete plus2Branch_;
@ -101,7 +373,7 @@ IncExcNode::~IncExcNode (void)
double double
IncExcNode::weight (void) const IncExcNode::weight() const
{ {
double w = 0.0; double w = 0.0;
if (Globals::logDomain) { if (Globals::logDomain) {
@ -116,7 +388,7 @@ IncExcNode::weight (void) const
LeafNode::~LeafNode (void) LeafNode::~LeafNode()
{ {
delete clause_; delete clause_;
} }
@ -124,7 +396,7 @@ LeafNode::~LeafNode (void)
double double
LeafNode::weight (void) const LeafNode::weight() const
{ {
assert (clause_->isUnit()); assert (clause_->isUnit());
if (clause_->posCountedLogVars().empty() == false if (clause_->posCountedLogVars().empty() == false
@ -161,7 +433,7 @@ LeafNode::weight (void) const
SmoothNode::~SmoothNode (void) SmoothNode::~SmoothNode()
{ {
Clause::deleteClauses (clauses_); Clause::deleteClauses (clauses_);
} }
@ -169,7 +441,7 @@ SmoothNode::~SmoothNode (void)
double double
SmoothNode::weight (void) const SmoothNode::weight() const
{ {
Clauses cs = clauses(); Clauses cs = clauses();
double totalWeight = LogAware::multIdenty(); double totalWeight = LogAware::multIdenty();
@ -204,7 +476,7 @@ SmoothNode::weight (void) const
double double
TrueNode::weight (void) const TrueNode::weight() const
{ {
return LogAware::multIdenty(); return LogAware::multIdenty();
} }
@ -212,7 +484,7 @@ TrueNode::weight (void) const
double double
CompilationFailedNode::weight (void) const CompilationFailedNode::weight() const
{ {
// weighted model counting in compilation // weighted model counting in compilation
// failed nodes should give NaN // failed nodes should give NaN
@ -234,21 +506,22 @@ LiftedCircuit::LiftedCircuit (const LiftedWCNF* lwcnf)
if (Globals::verbosity > 1) { if (Globals::verbosity > 1) {
if (compilationSucceeded_) { if (compilationSucceeded_) {
double wmc = LogAware::exp (getWeightedModelCount()); double wmc = LogAware::exp (getWeightedModelCount());
cout << "Weighted model count = " << wmc << endl << endl; std::cout << "Weighted model count = " << wmc;
std::cout << std::endl << std::endl;
} }
cout << "Exporting circuit to graphviz (circuit.dot)..." ; std::cout << "Exporting circuit to graphviz (circuit.dot)..." ;
cout << endl << endl; std::cout << std::endl << std::endl;
exportToGraphViz ("circuit.dot"); exportToGraphViz ("circuit.dot");
} }
} }
LiftedCircuit::~LiftedCircuit (void) LiftedCircuit::~LiftedCircuit()
{ {
delete root_; delete root_;
unordered_map<CircuitNode*, Clauses>::iterator it; std::unordered_map<CircuitNode*, Clauses>::iterator it
it = originClausesMap_.begin(); = originClausesMap_.begin();
while (it != originClausesMap_.end()) { while (it != originClausesMap_.end()) {
Clause::deleteClauses (it->second); Clause::deleteClauses (it->second);
++ it; ++ it;
@ -258,7 +531,7 @@ LiftedCircuit::~LiftedCircuit (void)
bool bool
LiftedCircuit::isCompilationSucceeded (void) const LiftedCircuit::isCompilationSucceeded() const
{ {
return compilationSucceeded_; return compilationSucceeded_;
} }
@ -266,7 +539,7 @@ LiftedCircuit::isCompilationSucceeded (void) const
double double
LiftedCircuit::getWeightedModelCount (void) const LiftedCircuit::getWeightedModelCount() const
{ {
assert (compilationSucceeded_); assert (compilationSucceeded_);
return root_->weight(); return root_->weight();
@ -277,15 +550,16 @@ LiftedCircuit::getWeightedModelCount (void) const
void void
LiftedCircuit::exportToGraphViz (const char* fileName) LiftedCircuit::exportToGraphViz (const char* fileName)
{ {
ofstream out (fileName); std::ofstream out (fileName);
if (!out.is_open()) { if (!out.is_open()) {
cerr << "Error: couldn't open file '" << fileName << "'." ; std::cerr << "Error: couldn't open file '" << fileName << "'." ;
std::cerr << std::endl;
return; return;
} }
out << "digraph {" << endl; out << "digraph {" << std::endl;
out << "ranksep=1" << endl; out << "ranksep=1" << std::endl;
exportToGraphViz (root_, out); exportToGraphViz (root_, out);
out << "}" << endl; out << "}" << std::endl;
out.close(); out.close();
} }
@ -389,7 +663,7 @@ LiftedCircuit::tryUnitPropagation (
AndNode* andNode = new AndNode(); AndNode* andNode = new AndNode();
if (Globals::verbosity > 1) { if (Globals::verbosity > 1) {
originClausesMap_[andNode] = backupClauses_; originClausesMap_[andNode] = backupClauses_;
stringstream explanation; std::stringstream explanation;
explanation << " UP on " << clauses[i]->literals()[0]; explanation << " UP on " << clauses[i]->literals()[0];
explanationMap_[andNode] = explanation.str(); explanationMap_[andNode] = explanation.str();
} }
@ -478,7 +752,7 @@ LiftedCircuit::tryShannonDecomp (
OrNode* orNode = new OrNode(); OrNode* orNode = new OrNode();
if (Globals::verbosity > 1) { if (Globals::verbosity > 1) {
originClausesMap_[orNode] = backupClauses_; originClausesMap_[orNode] = backupClauses_;
stringstream explanation; std::stringstream explanation;
explanation << " SD on " << literals[j]; explanation << " SD on " << literals[j];
explanationMap_[orNode] = explanation.str(); explanationMap_[orNode] = explanation.str();
} }
@ -558,7 +832,7 @@ LiftedCircuit::tryInclusionExclusion (
IncExcNode* ieNode = new IncExcNode(); IncExcNode* ieNode = new IncExcNode();
if (Globals::verbosity > 1) { if (Globals::verbosity > 1) {
originClausesMap_[ieNode] = backupClauses_; originClausesMap_[ieNode] = backupClauses_;
stringstream explanation; std::stringstream explanation;
explanation << " IncExc on clause nº " << i + 1; explanation << " IncExc on clause nº " << i + 1;
explanationMap_[ieNode] = explanation.str(); explanationMap_[ieNode] = explanation.str();
} }
@ -635,13 +909,13 @@ LiftedCircuit::tryIndepPartialGroundingAux (
} }
} }
// verifies if the IPG logical vars appear in the same positions // verifies if the IPG logical vars appear in the same positions
unordered_map<LiteralId, size_t> positions; std::unordered_map<LiteralId, size_t> positions;
for (size_t i = 0; i < clauses.size(); i++) { for (size_t i = 0; i < clauses.size(); i++) {
const Literals& literals = clauses[i]->literals(); const Literals& literals = clauses[i]->literals();
for (size_t j = 0; j < literals.size(); j++) { for (size_t j = 0; j < literals.size(); j++) {
size_t idx = literals[j].indexOfLogVar (rootLogVars[i]); size_t idx = literals[j].indexOfLogVar (rootLogVars[i]);
assert (idx != literals[j].nrLogVars()); assert (idx != literals[j].nrLogVars());
unordered_map<LiteralId, size_t>::iterator it; std::unordered_map<LiteralId, size_t>::iterator it;
it = positions.find (literals[j].lid()); it = positions.find (literals[j].lid());
if (it != positions.end()) { if (it != positions.end()) {
if (it->second != idx) { if (it->second != idx) {
@ -810,7 +1084,7 @@ LiftedCircuit::smoothCircuit (CircuitNode* node)
switch (getCircuitNodeType (node)) { switch (getCircuitNodeType (node)) {
case CircuitNodeType::OR_NODE: { case CircuitNodeType::orCnt: {
OrNode* casted = dynamic_cast<OrNode*>(node); OrNode* casted = dynamic_cast<OrNode*>(node);
LitLvTypesSet lids1 = smoothCircuit (*casted->leftBranch()); LitLvTypesSet lids1 = smoothCircuit (*casted->leftBranch());
LitLvTypesSet lids2 = smoothCircuit (*casted->rightBranch()); LitLvTypesSet lids2 = smoothCircuit (*casted->rightBranch());
@ -823,7 +1097,7 @@ LiftedCircuit::smoothCircuit (CircuitNode* node)
break; break;
} }
case CircuitNodeType::AND_NODE: { case CircuitNodeType::andCnt: {
AndNode* casted = dynamic_cast<AndNode*>(node); AndNode* casted = dynamic_cast<AndNode*>(node);
LitLvTypesSet lids1 = smoothCircuit (*casted->leftBranch()); LitLvTypesSet lids1 = smoothCircuit (*casted->leftBranch());
LitLvTypesSet lids2 = smoothCircuit (*casted->rightBranch()); LitLvTypesSet lids2 = smoothCircuit (*casted->rightBranch());
@ -832,17 +1106,18 @@ LiftedCircuit::smoothCircuit (CircuitNode* node)
break; break;
} }
case CircuitNodeType::SET_OR_NODE: { case CircuitNodeType::setOrCnt: {
SetOrNode* casted = dynamic_cast<SetOrNode*>(node); SetOrNode* casted = dynamic_cast<SetOrNode*>(node);
propagLits = smoothCircuit (*casted->follow()); propagLits = smoothCircuit (*casted->follow());
TinySet<pair<LiteralId,unsigned>> litSet; TinySet<std::pair<LiteralId,unsigned>> litSet;
for (size_t i = 0; i < propagLits.size(); i++) { for (size_t i = 0; i < propagLits.size(); i++) {
litSet.insert (make_pair (propagLits[i].lid(), litSet.insert (std::make_pair (propagLits[i].lid(),
propagLits[i].logVarTypes().size())); propagLits[i].logVarTypes().size()));
} }
LitLvTypesSet missingLids; LitLvTypesSet missingLids;
for (size_t i = 0; i < litSet.size(); i++) { for (size_t i = 0; i < litSet.size(); i++) {
vector<LogVarTypes> allTypes = getAllPossibleTypes (litSet[i].second); std::vector<LogVarTypes> allTypes
= getAllPossibleTypes (litSet[i].second);
for (size_t j = 0; j < allTypes.size(); j++) { for (size_t j = 0; j < allTypes.size(); j++) {
bool typeFound = false; bool typeFound = false;
for (size_t k = 0; k < propagLits.size(); k++) { for (size_t k = 0; k < propagLits.size(); k++) {
@ -869,13 +1144,13 @@ LiftedCircuit::smoothCircuit (CircuitNode* node)
break; break;
} }
case CircuitNodeType::SET_AND_NODE: { case CircuitNodeType::setAndCnt: {
SetAndNode* casted = dynamic_cast<SetAndNode*>(node); SetAndNode* casted = dynamic_cast<SetAndNode*>(node);
propagLits = smoothCircuit (*casted->follow()); propagLits = smoothCircuit (*casted->follow());
break; break;
} }
case CircuitNodeType::INC_EXC_NODE: { case CircuitNodeType::incExcCnt: {
IncExcNode* casted = dynamic_cast<IncExcNode*>(node); IncExcNode* casted = dynamic_cast<IncExcNode*>(node);
LitLvTypesSet lids1 = smoothCircuit (*casted->plus1Branch()); LitLvTypesSet lids1 = smoothCircuit (*casted->plus1Branch());
LitLvTypesSet lids2 = smoothCircuit (*casted->plus2Branch()); LitLvTypesSet lids2 = smoothCircuit (*casted->plus2Branch());
@ -888,7 +1163,7 @@ LiftedCircuit::smoothCircuit (CircuitNode* node)
break; break;
} }
case CircuitNodeType::LEAF_NODE: { case CircuitNodeType::leafCnt: {
LeafNode* casted = dynamic_cast<LeafNode*>(node); LeafNode* casted = dynamic_cast<LeafNode*>(node);
propagLits.insert (LitLvTypes ( propagLits.insert (LitLvTypes (
casted->clause()->literals()[0].lid(), casted->clause()->literals()[0].lid(),
@ -911,8 +1186,8 @@ LiftedCircuit::createSmoothNode (
{ {
if (missingLits.empty() == false) { if (missingLits.empty() == false) {
if (Globals::verbosity > 1) { if (Globals::verbosity > 1) {
unordered_map<CircuitNode*, Clauses>::iterator it; std::unordered_map<CircuitNode*, Clauses>::iterator it
it = originClausesMap_.find (*prev); = originClausesMap_.find (*prev);
if (it != originClausesMap_.end()) { if (it != originClausesMap_.end()) {
backupClauses_ = it->second; backupClauses_ = it->second;
} else { } else {
@ -927,9 +1202,9 @@ LiftedCircuit::createSmoothNode (
Clause* c = lwcnf_->createClause (lid); Clause* c = lwcnf_->createClause (lid);
for (size_t j = 0; j < types.size(); j++) { for (size_t j = 0; j < types.size(); j++) {
LogVar X = c->literals().front().logVars()[j]; LogVar X = c->literals().front().logVars()[j];
if (types[j] == LogVarType::POS_LV) { if (types[j] == LogVarType::posLvt) {
c->addPosCountedLogVar (X); c->addPosCountedLogVar (X);
} else if (types[j] == LogVarType::NEG_LV) { } else if (types[j] == LogVarType::negLvt) {
c->addNegCountedLogVar (X); c->addNegCountedLogVar (X);
} }
} }
@ -947,15 +1222,15 @@ LiftedCircuit::createSmoothNode (
vector<LogVarTypes> std::vector<LogVarTypes>
LiftedCircuit::getAllPossibleTypes (unsigned nrLogVars) const LiftedCircuit::getAllPossibleTypes (unsigned nrLogVars) const
{ {
vector<LogVarTypes> res; std::vector<LogVarTypes> res;
if (nrLogVars == 0) { if (nrLogVars == 0) {
// do nothing // do nothing
} else if (nrLogVars == 1) { } else if (nrLogVars == 1) {
res.push_back ({ LogVarType::POS_LV }); res.push_back ({ LogVarType::posLvt });
res.push_back ({ LogVarType::NEG_LV }); res.push_back ({ LogVarType::negLvt });
} else { } else {
Ranges ranges (nrLogVars, 2); Ranges ranges (nrLogVars, 2);
Indexer indexer (ranges); Indexer indexer (ranges);
@ -963,9 +1238,9 @@ LiftedCircuit::getAllPossibleTypes (unsigned nrLogVars) const
LogVarTypes types; LogVarTypes types;
for (size_t i = 0; i < nrLogVars; i++) { for (size_t i = 0; i < nrLogVars; i++) {
if (indexer[i] == 0) { if (indexer[i] == 0) {
types.push_back (LogVarType::POS_LV); types.push_back (LogVarType::posLvt);
} else { } else {
types.push_back (LogVarType::NEG_LV); types.push_back (LogVarType::negLvt);
} }
} }
res.push_back (types); res.push_back (types);
@ -983,13 +1258,13 @@ LiftedCircuit::containsTypes (
const LogVarTypes& typesB) const const LogVarTypes& typesB) const
{ {
for (size_t i = 0; i < typesA.size(); i++) { for (size_t i = 0; i < typesA.size(); i++) {
if (typesA[i] == LogVarType::FULL_LV) { if (typesA[i] == LogVarType::fullLvt) {
} else if (typesA[i] == LogVarType::POS_LV } else if (typesA[i] == LogVarType::posLvt
&& typesB[i] == LogVarType::POS_LV) { && typesB[i] == LogVarType::posLvt) {
} else if (typesA[i] == LogVarType::NEG_LV } else if (typesA[i] == LogVarType::negLvt
&& typesB[i] == LogVarType::NEG_LV) { && typesB[i] == LogVarType::negLvt) {
} else { } else {
return false; return false;
@ -1003,25 +1278,25 @@ LiftedCircuit::containsTypes (
CircuitNodeType CircuitNodeType
LiftedCircuit::getCircuitNodeType (const CircuitNode* node) const LiftedCircuit::getCircuitNodeType (const CircuitNode* node) const
{ {
CircuitNodeType type = CircuitNodeType::OR_NODE; CircuitNodeType type = CircuitNodeType::orCnt;
if (dynamic_cast<const OrNode*>(node)) { if (dynamic_cast<const OrNode*>(node)) {
type = CircuitNodeType::OR_NODE; type = CircuitNodeType::orCnt;
} else if (dynamic_cast<const AndNode*>(node)) { } else if (dynamic_cast<const AndNode*>(node)) {
type = CircuitNodeType::AND_NODE; type = CircuitNodeType::andCnt;
} else if (dynamic_cast<const SetOrNode*>(node)) { } else if (dynamic_cast<const SetOrNode*>(node)) {
type = CircuitNodeType::SET_OR_NODE; type = CircuitNodeType::setOrCnt;
} else if (dynamic_cast<const SetAndNode*>(node)) { } else if (dynamic_cast<const SetAndNode*>(node)) {
type = CircuitNodeType::SET_AND_NODE; type = CircuitNodeType::setAndCnt;
} else if (dynamic_cast<const IncExcNode*>(node)) { } else if (dynamic_cast<const IncExcNode*>(node)) {
type = CircuitNodeType::INC_EXC_NODE; type = CircuitNodeType::incExcCnt;
} else if (dynamic_cast<const LeafNode*>(node)) { } else if (dynamic_cast<const LeafNode*>(node)) {
type = CircuitNodeType::LEAF_NODE; type = CircuitNodeType::leafCnt;
} else if (dynamic_cast<const SmoothNode*>(node)) { } else if (dynamic_cast<const SmoothNode*>(node)) {
type = CircuitNodeType::SMOOTH_NODE; type = CircuitNodeType::smoothCnt;
} else if (dynamic_cast<const TrueNode*>(node)) { } else if (dynamic_cast<const TrueNode*>(node)) {
type = CircuitNodeType::TRUE_NODE; type = CircuitNodeType::trueCnt;
} else if (dynamic_cast<const CompilationFailedNode*>(node)) { } else if (dynamic_cast<const CompilationFailedNode*>(node)) {
type = CircuitNodeType::COMPILATION_FAILED_NODE; type = CircuitNodeType::compilationFailedCnt;
} else { } else {
assert (false); assert (false);
} }
@ -1031,127 +1306,131 @@ LiftedCircuit::getCircuitNodeType (const CircuitNode* node) const
void void
LiftedCircuit::exportToGraphViz (CircuitNode* node, ofstream& os) LiftedCircuit::exportToGraphViz (CircuitNode* node, std::ofstream& os)
{ {
assert (node); assert (node);
static unsigned nrAuxNodes = 0; static unsigned nrAuxNodes = 0;
stringstream ss; std::stringstream ss;
ss << "n" << nrAuxNodes; ss << "n" << nrAuxNodes;
string auxNode = ss.str(); std::string auxNode = ss.str();
nrAuxNodes ++; nrAuxNodes ++;
string opStyle = "shape=circle,width=0.7,margin=\"0.0,0.0\"," ; std::string opStyle = "shape=circle,width=0.7,margin=\"0.0,0.0\"," ;
switch (getCircuitNodeType (node)) { switch (getCircuitNodeType (node)) {
case OR_NODE: { case CircuitNodeType::orCnt: {
OrNode* casted = dynamic_cast<OrNode*>(node); OrNode* casted = dynamic_cast<OrNode*>(node);
printClauses (casted, os); printClauses (casted, os);
os << auxNode << " [" << opStyle << "label=\"\"]" << endl; os << auxNode << " [" << opStyle << "label=\"\"]" ;
os << std::endl;
os << escapeNode (node) << " -> " << auxNode; os << escapeNode (node) << " -> " << auxNode;
os << " [label=\"" << getExplanationString (node) << "\"]" ; os << " [label=\"" << getExplanationString (node) << "\"]" ;
os << endl; os << std::endl;
os << auxNode << " -> " ; os << auxNode << " -> " ;
os << escapeNode (*casted->leftBranch()); os << escapeNode (*casted->leftBranch());
os << " [label=\" " << (*casted->leftBranch())->weight() << "\"]" ; os << " [label=\" " << (*casted->leftBranch())->weight() << "\"]" ;
os << endl; os << std::endl;
os << auxNode << " -> " ; os << auxNode << " -> " ;
os << escapeNode (*casted->rightBranch()); os << escapeNode (*casted->rightBranch());
os << " [label=\" " << (*casted->rightBranch())->weight() << "\"]" ; os << " [label=\" " << (*casted->rightBranch())->weight() << "\"]" ;
os << endl; os << std::endl;
exportToGraphViz (*casted->leftBranch(), os); exportToGraphViz (*casted->leftBranch(), os);
exportToGraphViz (*casted->rightBranch(), os); exportToGraphViz (*casted->rightBranch(), os);
break; break;
} }
case AND_NODE: { case CircuitNodeType::andCnt: {
AndNode* casted = dynamic_cast<AndNode*>(node); AndNode* casted = dynamic_cast<AndNode*>(node);
printClauses (casted, os); printClauses (casted, os);
os << auxNode << " [" << opStyle << "label=\"\"]" << endl; os << auxNode << " [" << opStyle << "label=\"\"]" ;
os << std::endl;
os << escapeNode (node) << " -> " << auxNode; os << escapeNode (node) << " -> " << auxNode;
os << " [label=\"" << getExplanationString (node) << "\"]" ; os << " [label=\"" << getExplanationString (node) << "\"]" ;
os << endl; os << std::endl;
os << auxNode << " -> " ; os << auxNode << " -> " ;
os << escapeNode (*casted->leftBranch()); os << escapeNode (*casted->leftBranch());
os << " [label=\" " << (*casted->leftBranch())->weight() << "\"]" ; os << " [label=\" " << (*casted->leftBranch())->weight() << "\"]" ;
os << endl; os << std::endl;
os << auxNode << " -> " ; os << auxNode << " -> " ;
os << escapeNode (*casted->rightBranch()) << endl; os << escapeNode (*casted->rightBranch());
os << " [label=\" " << (*casted->rightBranch())->weight() << "\"]" ; os << " [label=\" " << (*casted->rightBranch())->weight() << "\"]" ;
os << endl; os << std::endl;
exportToGraphViz (*casted->leftBranch(), os); exportToGraphViz (*casted->leftBranch(), os);
exportToGraphViz (*casted->rightBranch(), os); exportToGraphViz (*casted->rightBranch(), os);
break; break;
} }
case SET_OR_NODE: { case CircuitNodeType::setOrCnt: {
SetOrNode* casted = dynamic_cast<SetOrNode*>(node); SetOrNode* casted = dynamic_cast<SetOrNode*>(node);
printClauses (casted, os); printClauses (casted, os);
os << auxNode << " [" << opStyle << "label=\"(X)\"]" << endl; os << auxNode << " [" << opStyle << "label=\"(X)\"]" ;
os << std::endl;
os << escapeNode (node) << " -> " << auxNode; os << escapeNode (node) << " -> " << auxNode;
os << " [label=\"" << getExplanationString (node) << "\"]" ; os << " [label=\"" << getExplanationString (node) << "\"]" ;
os << endl; os << std::endl;
os << auxNode << " -> " ; os << auxNode << " -> " ;
os << escapeNode (*casted->follow()); os << escapeNode (*casted->follow());
os << " [label=\" " << (*casted->follow())->weight() << "\"]" ; os << " [label=\" " << (*casted->follow())->weight() << "\"]" ;
os << endl; os << std::endl;
exportToGraphViz (*casted->follow(), os); exportToGraphViz (*casted->follow(), os);
break; break;
} }
case SET_AND_NODE: { case CircuitNodeType::setAndCnt: {
SetAndNode* casted = dynamic_cast<SetAndNode*>(node); SetAndNode* casted = dynamic_cast<SetAndNode*>(node);
printClauses (casted, os); printClauses (casted, os);
os << auxNode << " [" << opStyle << "label=\"∧(X)\"]" << endl; os << auxNode << " [" << opStyle << "label=\"∧(X)\"]" ;
os << std::endl;
os << escapeNode (node) << " -> " << auxNode; os << escapeNode (node) << " -> " << auxNode;
os << " [label=\"" << getExplanationString (node) << "\"]" ; os << " [label=\"" << getExplanationString (node) << "\"]" ;
os << endl; os << std::endl;
os << auxNode << " -> " ; os << auxNode << " -> " ;
os << escapeNode (*casted->follow()); os << escapeNode (*casted->follow());
os << " [label=\" " << (*casted->follow())->weight() << "\"]" ; os << " [label=\" " << (*casted->follow())->weight() << "\"]" ;
os << endl; os << std::endl;
exportToGraphViz (*casted->follow(), os); exportToGraphViz (*casted->follow(), os);
break; break;
} }
case INC_EXC_NODE: { case CircuitNodeType::incExcCnt: {
IncExcNode* casted = dynamic_cast<IncExcNode*>(node); IncExcNode* casted = dynamic_cast<IncExcNode*>(node);
printClauses (casted, os); printClauses (casted, os);
os << auxNode << " [" << opStyle << "label=\"+ - +\"]" ; os << auxNode << " [" << opStyle << "label=\"+ - +\"]" ;
os << endl; os << std::endl;
os << escapeNode (node) << " -> " << auxNode; os << escapeNode (node) << " -> " << auxNode;
os << " [label=\"" << getExplanationString (node) << "\"]" ; os << " [label=\"" << getExplanationString (node) << "\"]" ;
os << endl; os << std::endl;
os << auxNode << " -> " ; os << auxNode << " -> " ;
os << escapeNode (*casted->plus1Branch()); os << escapeNode (*casted->plus1Branch());
os << " [label=\" " << (*casted->plus1Branch())->weight() << "\"]" ; os << " [label=\" " << (*casted->plus1Branch())->weight() << "\"]" ;
os << endl; os << std::endl;
os << auxNode << " -> " ; os << auxNode << " -> " ;
os << escapeNode (*casted->minusBranch()) << endl; os << escapeNode (*casted->minusBranch()) << std::endl;
os << " [label=\" " << (*casted->minusBranch())->weight() << "\"]" ; os << " [label=\" " << (*casted->minusBranch())->weight() << "\"]" ;
os << endl; os << std::endl;
os << auxNode << " -> " ; os << auxNode << " -> " ;
os << escapeNode (*casted->plus2Branch()); os << escapeNode (*casted->plus2Branch());
os << " [label=\" " << (*casted->plus2Branch())->weight() << "\"]" ; os << " [label=\" " << (*casted->plus2Branch())->weight() << "\"]" ;
os << endl; os << std::endl;
exportToGraphViz (*casted->plus1Branch(), os); exportToGraphViz (*casted->plus1Branch(), os);
exportToGraphViz (*casted->plus2Branch(), os); exportToGraphViz (*casted->plus2Branch(), os);
@ -1159,24 +1438,24 @@ LiftedCircuit::exportToGraphViz (CircuitNode* node, ofstream& os)
break; break;
} }
case LEAF_NODE: { case CircuitNodeType::leafCnt: {
printClauses (node, os, "style=filled,fillcolor=palegreen,"); printClauses (node, os, "style=filled,fillcolor=palegreen,");
break; break;
} }
case SMOOTH_NODE: { case CircuitNodeType::smoothCnt: {
printClauses (node, os, "style=filled,fillcolor=lightblue,"); printClauses (node, os, "style=filled,fillcolor=lightblue,");
break; break;
} }
case TRUE_NODE: { case CircuitNodeType::trueCnt: {
os << escapeNode (node); os << escapeNode (node);
os << " [shape=box,label=\"\"]" ; os << " [shape=box,label=\"\"]" ;
os << endl; os << std::endl;
break; break;
} }
case COMPILATION_FAILED_NODE: { case CircuitNodeType::compilationFailedCnt: {
printClauses (node, os, "style=filled,fillcolor=salmon,"); printClauses (node, os, "style=filled,fillcolor=salmon,");
break; break;
} }
@ -1188,17 +1467,17 @@ LiftedCircuit::exportToGraphViz (CircuitNode* node, ofstream& os)
string std::string
LiftedCircuit::escapeNode (const CircuitNode* node) const LiftedCircuit::escapeNode (const CircuitNode* node) const
{ {
stringstream ss; std::stringstream ss;
ss << "\"" << node << "\"" ; ss << "\"" << node << "\"" ;
return ss.str(); return ss.str();
} }
string std::string
LiftedCircuit::getExplanationString (CircuitNode* node) LiftedCircuit::getExplanationString (CircuitNode* node)
{ {
return Util::contains (explanationMap_, node) return Util::contains (explanationMap_, node)
@ -1211,15 +1490,15 @@ LiftedCircuit::getExplanationString (CircuitNode* node)
void void
LiftedCircuit::printClauses ( LiftedCircuit::printClauses (
CircuitNode* node, CircuitNode* node,
ofstream& os, std::ofstream& os,
string extraOptions) std::string extraOptions)
{ {
Clauses clauses; Clauses clauses;
if (Util::contains (originClausesMap_, node)) { if (Util::contains (originClausesMap_, node)) {
clauses = originClausesMap_[node]; clauses = originClausesMap_[node];
} else if (getCircuitNodeType (node) == CircuitNodeType::LEAF_NODE) { } else if (getCircuitNodeType (node) == CircuitNodeType::leafCnt) {
clauses = { (dynamic_cast<LeafNode*>(node))->clause() } ; clauses = { (dynamic_cast<LeafNode*>(node))->clause() } ;
} else if (getCircuitNodeType (node) == CircuitNodeType::SMOOTH_NODE) { } else if (getCircuitNodeType (node) == CircuitNodeType::smoothCnt) {
clauses = (dynamic_cast<SmoothNode*>(node))->clauses(); clauses = (dynamic_cast<SmoothNode*>(node))->clauses();
} }
assert (clauses.empty() == false); assert (clauses.empty() == false);
@ -1230,15 +1509,7 @@ LiftedCircuit::printClauses (
os << *clauses[i]; os << *clauses[i];
} }
os << "\"]" ; os << "\"]" ;
os << endl; os << std::endl;
}
LiftedKc::~LiftedKc (void)
{
delete lwcnf_;
delete circuit_;
} }
@ -1246,20 +1517,21 @@ LiftedKc::~LiftedKc (void)
Params Params
LiftedKc::solveQuery (const Grounds& query) LiftedKc::solveQuery (const Grounds& query)
{ {
pfList_ = parfactorList; ParfactorList pfList (parfactorList);
LiftedOperations::shatterAgainstQuery (pfList_, query); LiftedOperations::shatterAgainstQuery (pfList, query);
LiftedOperations::runWeakBayesBall (pfList_, query); LiftedOperations::runWeakBayesBall (pfList, query);
lwcnf_ = new LiftedWCNF (pfList_); LiftedWCNF lwcnf (pfList);
circuit_ = new LiftedCircuit (lwcnf_); LiftedCircuit circuit (&lwcnf);
if (circuit_->isCompilationSucceeded() == false) { if (circuit.isCompilationSucceeded() == false) {
cerr << "Error: the circuit compilation has failed." << endl; std::cerr << "Error: the circuit compilation has failed." ;
std::cerr << std::endl;
exit (EXIT_FAILURE); exit (EXIT_FAILURE);
} }
vector<PrvGroup> groups; std::vector<PrvGroup> groups;
Ranges ranges; Ranges ranges;
for (size_t i = 0; i < query.size(); i++) { for (size_t i = 0; i < query.size(); i++) {
ParfactorList::const_iterator it = pfList_.begin(); ParfactorList::const_iterator it = pfList.begin();
while (it != pfList_.end()) { while (it != pfList.end()) {
size_t idx = (*it)->indexOfGround (query[i]); size_t idx = (*it)->indexOfGround (query[i]);
if (idx != (*it)->nrArguments()) { if (idx != (*it)->nrArguments()) {
groups.push_back ((*it)->argument (idx).group()); groups.push_back ((*it)->argument (idx).group());
@ -1274,18 +1546,18 @@ LiftedKc::solveQuery (const Grounds& query)
Indexer indexer (ranges); Indexer indexer (ranges);
while (indexer.valid()) { while (indexer.valid()) {
for (size_t i = 0; i < groups.size(); i++) { for (size_t i = 0; i < groups.size(); i++) {
vector<LiteralId> litIds = lwcnf_->prvGroupLiterals (groups[i]); std::vector<LiteralId> litIds = lwcnf.prvGroupLiterals (groups[i]);
for (size_t j = 0; j < litIds.size(); j++) { for (size_t j = 0; j < litIds.size(); j++) {
if (indexer[i] == j) { if (indexer[i] == j) {
lwcnf_->addWeight (litIds[j], LogAware::one(), lwcnf.addWeight (litIds[j], LogAware::one(),
LogAware::one()); LogAware::one());
} else { } else {
lwcnf_->addWeight (litIds[j], LogAware::zero(), lwcnf.addWeight (litIds[j], LogAware::zero(),
LogAware::one()); LogAware::one());
} }
} }
} }
params.push_back (circuit_->getWeightedModelCount()); params.push_back (circuit.getWeightedModelCount());
++ indexer; ++ indexer;
} }
LogAware::normalize (params); LogAware::normalize (params);
@ -1298,12 +1570,14 @@ LiftedKc::solveQuery (const Grounds& query)
void void
LiftedKc::printSolverFlags (void) const LiftedKc::printSolverFlags() const
{ {
stringstream ss; std::stringstream ss;
ss << "lifted kc [" ; ss << "lifted kc [" ;
ss << "log_domain=" << Util::toString (Globals::logDomain); ss << "log_domain=" << Util::toString (Globals::logDomain);
ss << "]" ; ss << "]" ;
cout << ss.str() << endl; std::cout << ss.str() << std::endl;
} }
} // namespace Horus

View File

@ -1,302 +1,26 @@
#ifndef HORUS_LIFTEDKC_H #ifndef YAP_PACKAGES_CLPBN_HORUS_LIFTEDKC_H_
#define HORUS_LIFTEDKC_H #define YAP_PACKAGES_CLPBN_HORUS_LIFTEDKC_H_
#include "LiftedSolver.h" #include "LiftedSolver.h"
#include "LiftedWCNF.h"
#include "ParfactorList.h" #include "ParfactorList.h"
enum CircuitNodeType { namespace Horus {
OR_NODE,
AND_NODE,
SET_OR_NODE,
SET_AND_NODE,
INC_EXC_NODE,
LEAF_NODE,
SMOOTH_NODE,
TRUE_NODE,
COMPILATION_FAILED_NODE
};
class LiftedKc : public LiftedSolver {
class CircuitNode
{
public:
CircuitNode (void) { }
virtual ~CircuitNode (void) { }
virtual double weight (void) const = 0;
};
class OrNode : public CircuitNode
{
public:
OrNode (void) : CircuitNode(), leftBranch_(0), rightBranch_(0) { }
~OrNode (void);
CircuitNode** leftBranch (void) { return &leftBranch_; }
CircuitNode** rightBranch (void) { return &rightBranch_; }
double weight (void) const;
private:
CircuitNode* leftBranch_;
CircuitNode* rightBranch_;
};
class AndNode : public CircuitNode
{
public:
AndNode (void) : CircuitNode(), leftBranch_(0), rightBranch_(0) { }
AndNode (CircuitNode* leftBranch, CircuitNode* rightBranch)
: CircuitNode(), leftBranch_(leftBranch), rightBranch_(rightBranch) { }
~AndNode (void);
CircuitNode** leftBranch (void) { return &leftBranch_; }
CircuitNode** rightBranch (void) { return &rightBranch_; }
double weight (void) const;
private:
CircuitNode* leftBranch_;
CircuitNode* rightBranch_;
};
class SetOrNode : public CircuitNode
{
public:
SetOrNode (unsigned nrGroundings)
: CircuitNode(), follow_(0), nrGroundings_(nrGroundings) { }
~SetOrNode (void);
CircuitNode** follow (void) { return &follow_; }
static unsigned nrPositives (void) { return nrPos_; }
static unsigned nrNegatives (void) { return nrNeg_; }
static bool isSet (void) { return nrPos_ >= 0; }
double weight (void) const;
private:
CircuitNode* follow_;
unsigned nrGroundings_;
static int nrPos_;
static int nrNeg_;
};
class SetAndNode : public CircuitNode
{
public:
SetAndNode (unsigned nrGroundings)
: CircuitNode(), follow_(0), nrGroundings_(nrGroundings) { }
~SetAndNode (void);
CircuitNode** follow (void) { return &follow_; }
double weight (void) const;
private:
CircuitNode* follow_;
unsigned nrGroundings_;
};
class IncExcNode : public CircuitNode
{
public:
IncExcNode (void)
: CircuitNode(), plus1Branch_(0), plus2Branch_(0), minusBranch_(0) { }
~IncExcNode (void);
CircuitNode** plus1Branch (void) { return &plus1Branch_; }
CircuitNode** plus2Branch (void) { return &plus2Branch_; }
CircuitNode** minusBranch (void) { return &minusBranch_; }
double weight (void) const;
private:
CircuitNode* plus1Branch_;
CircuitNode* plus2Branch_;
CircuitNode* minusBranch_;
};
class LeafNode : public CircuitNode
{
public:
LeafNode (Clause* clause, const LiftedWCNF& lwcnf)
: CircuitNode(), clause_(clause), lwcnf_(lwcnf) { }
~LeafNode (void);
const Clause* clause (void) const { return clause_; }
Clause* clause (void) { return clause_; }
double weight (void) const;
private:
Clause* clause_;
const LiftedWCNF& lwcnf_;
};
class SmoothNode : public CircuitNode
{
public:
SmoothNode (const Clauses& clauses, const LiftedWCNF& lwcnf)
: CircuitNode(), clauses_(clauses), lwcnf_(lwcnf) { }
~SmoothNode (void);
const Clauses& clauses (void) const { return clauses_; }
Clauses clauses (void) { return clauses_; }
double weight (void) const;
private:
Clauses clauses_;
const LiftedWCNF& lwcnf_;
};
class TrueNode : public CircuitNode
{
public:
TrueNode (void) : CircuitNode() { }
double weight (void) const;
};
class CompilationFailedNode : public CircuitNode
{
public:
CompilationFailedNode (void) : CircuitNode() { }
double weight (void) const;
};
class LiftedCircuit
{
public:
LiftedCircuit (const LiftedWCNF* lwcnf);
~LiftedCircuit (void);
bool isCompilationSucceeded (void) const;
double getWeightedModelCount (void) const;
void exportToGraphViz (const char*);
private:
void compile (CircuitNode** follow, Clauses& clauses);
bool tryUnitPropagation (CircuitNode** follow, Clauses& clauses);
bool tryIndependence (CircuitNode** follow, Clauses& clauses);
bool tryShannonDecomp (CircuitNode** follow, Clauses& clauses);
bool tryInclusionExclusion (CircuitNode** follow, Clauses& clauses);
bool tryIndepPartialGrounding (CircuitNode** follow, Clauses& clauses);
bool tryIndepPartialGroundingAux (Clauses& clauses, ConstraintTree& ct,
LogVars& rootLogVars);
bool tryAtomCounting (CircuitNode** follow, Clauses& clauses);
void shatterCountedLogVars (Clauses& clauses);
bool shatterCountedLogVarsAux (Clauses& clauses);
bool shatterCountedLogVarsAux (Clauses& clauses, size_t idx1, size_t idx2);
bool independentClause (Clause& clause, Clauses& otherClauses) const;
bool independentLiteral (const Literal& lit,
const Literals& otherLits) const;
LitLvTypesSet smoothCircuit (CircuitNode* node);
void createSmoothNode (const LitLvTypesSet& lids,
CircuitNode** prev);
vector<LogVarTypes> getAllPossibleTypes (unsigned nrLogVars) const;
bool containsTypes (const LogVarTypes& typesA,
const LogVarTypes& typesB) const;
CircuitNodeType getCircuitNodeType (const CircuitNode* node) const;
void exportToGraphViz (CircuitNode* node, ofstream&);
void printClauses (CircuitNode* node, ofstream&,
string extraOptions = "");
string escapeNode (const CircuitNode* node) const;
string getExplanationString (CircuitNode* node);
CircuitNode* root_;
const LiftedWCNF* lwcnf_;
bool compilationSucceeded_;
Clauses backupClauses_;
unordered_map<CircuitNode*, Clauses> originClausesMap_;
unordered_map<CircuitNode*, string> explanationMap_;
DISALLOW_COPY_AND_ASSIGN (LiftedCircuit);
};
class LiftedKc : public LiftedSolver
{
public: public:
LiftedKc (const ParfactorList& pfList) LiftedKc (const ParfactorList& pfList)
: LiftedSolver(pfList) { } : LiftedSolver(pfList) { }
~LiftedKc (void);
Params solveQuery (const Grounds&); Params solveQuery (const Grounds&);
void printSolverFlags (void) const; void printSolverFlags() const;
private: private:
LiftedWCNF* lwcnf_;
LiftedCircuit* circuit_;
ParfactorList pfList_;
DISALLOW_COPY_AND_ASSIGN (LiftedKc); DISALLOW_COPY_AND_ASSIGN (LiftedKc);
}; };
#endif // HORUS_LIFTEDKC_H } // namespace Horus
#endif // YAP_PACKAGES_CLPBN_HORUS_LIFTEDKC_H_

View File

@ -1,10 +1,22 @@
#include <vector>
#include <queue>
#include <iostream>
#include "LiftedOperations.h" #include "LiftedOperations.h"
namespace Horus {
namespace LiftedOperations {
namespace {
Parfactors absorve (ObservedFormula& obsFormula, Parfactor* g);
}
void void
LiftedOperations::shatterAgainstQuery ( shatterAgainstQuery (ParfactorList& pfList, const Grounds& query)
ParfactorList& pfList,
const Grounds& query)
{ {
for (size_t i = 0; i < query.size(); i++) { for (size_t i = 0; i < query.size(); i++) {
if (query[i].isAtom()) { if (query[i].isAtom()) {
@ -35,17 +47,17 @@ LiftedOperations::shatterAgainstQuery (
} }
} }
if (found == false) { if (found == false) {
cerr << "Error: could not find a parfactor with ground " ; std::cerr << "Error: could not find a parfactor with ground " ;
cerr << "`" << query[i] << "'." << endl; std::cerr << "`" << query[i] << "'." << std::endl;
exit (EXIT_FAILURE); exit (EXIT_FAILURE);
} }
pfList.add (newPfs); pfList.add (newPfs);
} }
if (Globals::verbosity > 2) { if (Globals::verbosity > 2) {
Util::printAsteriskLine(); Util::printAsteriskLine();
cout << "SHATTERED AGAINST THE QUERY" << endl; std::cout << "SHATTERED AGAINST THE QUERY" << std::endl;
for (size_t i = 0; i < query.size(); i++) { for (size_t i = 0; i < query.size(); i++) {
cout << " -> " << query[i] << endl; std::cout << " -> " << query[i] << std::endl;
} }
Util::printAsteriskLine(); Util::printAsteriskLine();
pfList.print(); pfList.print();
@ -55,12 +67,10 @@ LiftedOperations::shatterAgainstQuery (
void void
LiftedOperations::runWeakBayesBall ( runWeakBayesBall (ParfactorList& pfList, const Grounds& query)
ParfactorList& pfList,
const Grounds& query)
{ {
queue<PrvGroup> todo; // groups to process std::queue<PrvGroup> todo; // groups to process
set<PrvGroup> done; // processed or in queue std::set<PrvGroup> done; // processed or in queue
for (size_t i = 0; i < query.size(); i++) { for (size_t i = 0; i < query.size(); i++) {
ParfactorList::iterator it = pfList.begin(); ParfactorList::iterator it = pfList.begin();
while (it != pfList.end()) { while (it != pfList.end()) {
@ -74,14 +84,14 @@ LiftedOperations::runWeakBayesBall (
} }
} }
set<Parfactor*> requiredPfs; std::set<Parfactor*> requiredPfs;
while (todo.empty() == false) { while (todo.empty() == false) {
PrvGroup group = todo.front(); PrvGroup group = todo.front();
ParfactorList::iterator it = pfList.begin(); ParfactorList::iterator it = pfList.begin();
while (it != pfList.end()) { while (it != pfList.end()) {
if (Util::contains (requiredPfs, *it) == false && if (Util::contains (requiredPfs, *it) == false &&
(*it)->containsGroup (group)) { (*it)->containsGroup (group)) {
vector<PrvGroup> groups = (*it)->getAllGroups(); std::vector<PrvGroup> groups = (*it)->getAllGroups();
for (size_t i = 0; i < groups.size(); i++) { for (size_t i = 0; i < groups.size(); i++) {
if (Util::contains (done, groups[i]) == false) { if (Util::contains (done, groups[i]) == false) {
todo.push (groups[i]); todo.push (groups[i]);
@ -116,9 +126,7 @@ LiftedOperations::runWeakBayesBall (
void void
LiftedOperations::absorveEvidence ( absorveEvidence (ParfactorList& pfList, ObservedFormulas& obsFormulas)
ParfactorList& pfList,
ObservedFormulas& obsFormulas)
{ {
for (size_t i = 0; i < obsFormulas.size(); i++) { for (size_t i = 0; i < obsFormulas.size(); i++) {
Parfactors newPfs; Parfactors newPfs;
@ -143,9 +151,9 @@ LiftedOperations::absorveEvidence (
} }
if (Globals::verbosity > 2 && obsFormulas.empty() == false) { if (Globals::verbosity > 2 && obsFormulas.empty() == false) {
Util::printAsteriskLine(); Util::printAsteriskLine();
cout << "AFTER EVIDENCE ABSORVED" << endl; std::cout << "AFTER EVIDENCE ABSORVED" << std::endl;
for (size_t i = 0; i < obsFormulas.size(); i++) { for (size_t i = 0; i < obsFormulas.size(); i++) {
cout << " -> " << obsFormulas[i] << endl; std::cout << " -> " << obsFormulas[i] << std::endl;
} }
Util::printAsteriskLine(); Util::printAsteriskLine();
pfList.print(); pfList.print();
@ -155,9 +163,7 @@ LiftedOperations::absorveEvidence (
Parfactors Parfactors
LiftedOperations::countNormalize ( countNormalize (Parfactor* g, const LogVarSet& set)
Parfactor* g,
const LogVarSet& set)
{ {
Parfactors normPfs; Parfactors normPfs;
if (set.empty()) { if (set.empty()) {
@ -174,7 +180,7 @@ LiftedOperations::countNormalize (
Parfactor Parfactor
LiftedOperations::calcGroundMultiplication (Parfactor pf) calcGroundMultiplication (Parfactor pf)
{ {
LogVarSet lvs = pf.constr()->logVarSet(); LogVarSet lvs = pf.constr()->logVarSet();
lvs -= pf.constr()->singletons(); lvs -= pf.constr()->singletons();
@ -206,10 +212,10 @@ LiftedOperations::calcGroundMultiplication (Parfactor pf)
namespace {
Parfactors Parfactors
LiftedOperations::absorve ( absorve (ObservedFormula& obsFormula, Parfactor* g)
ObservedFormula& obsFormula,
Parfactor* g)
{ {
Parfactors absorvedPfs; Parfactors absorvedPfs;
const ProbFormulas& formulas = g->arguments(); const ProbFormulas& formulas = g->arguments();
@ -269,3 +275,9 @@ LiftedOperations::absorve (
return absorvedPfs; return absorvedPfs;
} }
}
} // namespace LiftedOperations
} // namespace Horus

View File

@ -1,29 +1,26 @@
#ifndef HORUS_LIFTEDOPERATIONS_H #ifndef YAP_PACKAGES_CLPBN_HORUS_LIFTEDOPERATIONS_H_
#define HORUS_LIFTEDOPERATIONS_H #define YAP_PACKAGES_CLPBN_HORUS_LIFTEDOPERATIONS_H_
#include "ParfactorList.h" #include "ParfactorList.h"
class LiftedOperations
{
public:
static void shatterAgainstQuery (
ParfactorList& pfList, const Grounds& query);
static void runWeakBayesBall ( namespace Horus {
ParfactorList& pfList, const Grounds&);
static void absorveEvidence ( namespace LiftedOperations {
ParfactorList& pfList, ObservedFormulas& obsFormulas);
static Parfactors countNormalize (Parfactor*, const LogVarSet&); void shatterAgainstQuery (ParfactorList& pfList, const Grounds& query);
static Parfactor calcGroundMultiplication (Parfactor pf); void runWeakBayesBall (ParfactorList& pfList, const Grounds& query);
private: void absorveEvidence (ParfactorList& pfList, ObservedFormulas&);
static Parfactors absorve (ObservedFormula&, Parfactor*);
DISALLOW_COPY_AND_ASSIGN (LiftedOperations); Parfactors countNormalize (Parfactor*, const LogVarSet&);
};
#endif // HORUS_LIFTEDOPERATIONS_H Parfactor calcGroundMultiplication (Parfactor pf);
} // namespace LiftedOperations
} // namespace Horus
#endif // YAP_PACKAGES_CLPBN_HORUS_LIFTEDOPERATIONS_H_

View File

@ -1,14 +1,12 @@
#ifndef HORUS_LIFTEDSOLVER_H #ifndef YAP_PACKAGES_CLPBN_HORUS_LIFTEDSOLVER_H_
#define HORUS_LIFTEDSOLVER_H #define YAP_PACKAGES_CLPBN_HORUS_LIFTEDSOLVER_H_
#include "ParfactorList.h" #include "ParfactorList.h"
#include "Horus.h"
using namespace std; namespace Horus {
class LiftedSolver class LiftedSolver {
{
public: public:
LiftedSolver (const ParfactorList& pfList) LiftedSolver (const ParfactorList& pfList)
: parfactorList(pfList) { } : parfactorList(pfList) { }
@ -17,7 +15,7 @@ class LiftedSolver
virtual Params solveQuery (const Grounds& query) = 0; virtual Params solveQuery (const Grounds& query) = 0;
virtual void printSolverFlags (void) const = 0; virtual void printSolverFlags() const = 0;
protected: protected:
const ParfactorList& parfactorList; const ParfactorList& parfactorList;
@ -26,5 +24,7 @@ class LiftedSolver
DISALLOW_COPY_AND_ASSIGN (LiftedSolver); DISALLOW_COPY_AND_ASSIGN (LiftedSolver);
}; };
#endif // HORUS_LIFTEDSOLVER_H } // namespace Horus
#endif // YAP_PACKAGES_CLPBN_HORUS_LIFTEDSOLVER_H_

View File

@ -1,22 +1,21 @@
#include <cassert> #include <cassert>
#include <iostream> #include <iostream>
#include <sstream>
#include "LiftedUtils.h" #include "LiftedUtils.h"
#include "ConstraintTree.h"
namespace Horus {
namespace LiftedUtils { namespace LiftedUtils {
std::unordered_map<std::string, unsigned> symbolDict;
unordered_map<string, unsigned> symbolDict;
Symbol Symbol
getSymbol (const string& symbolName) getSymbol (const std::string& symbolName)
{ {
unordered_map<string, unsigned>::iterator it std::unordered_map<std::string, unsigned>::iterator it
= symbolDict.find (symbolName); = symbolDict.find (symbolName);
if (it != symbolDict.end()) { if (it != symbolDict.end()) {
return it->second; return it->second;
@ -29,12 +28,12 @@ getSymbol (const string& symbolName)
void void
printSymbolDictionary (void) printSymbolDictionary()
{ {
unordered_map<string, unsigned>::const_iterator it std::unordered_map<std::string, unsigned>::const_iterator it
= symbolDict.begin(); = symbolDict.begin();
while (it != symbolDict.end()) { while (it != symbolDict.end()) {
cout << it->first << " -> " << it->second << endl; std::cout << it->first << " -> " << it->second << std::endl;
++ it; ++ it;
} }
} }
@ -43,9 +42,10 @@ printSymbolDictionary (void)
ostream& operator<< (ostream &os, const Symbol& s) std::ostream&
operator<< (std::ostream& os, const Symbol& s)
{ {
unordered_map<string, unsigned>::const_iterator it std::unordered_map<std::string, unsigned>::const_iterator it
= LiftedUtils::symbolDict.begin(); = LiftedUtils::symbolDict.begin();
while (it != LiftedUtils::symbolDict.end() && it->second != s) { while (it != LiftedUtils::symbolDict.end() && it->second != s) {
++ it; ++ it;
@ -57,9 +57,10 @@ ostream& operator<< (ostream &os, const Symbol& s)
ostream& operator<< (ostream &os, const LogVar& X) std::ostream&
operator<< (std::ostream& os, const LogVar& X)
{ {
const string labels[] = { const std::string labels[] = {
"A", "B", "C", "D", "E", "F", "A", "B", "C", "D", "E", "F",
"G", "H", "I", "J", "K", "M" }; "G", "H", "I", "J", "K", "M" };
(X >= 12) ? os << "X_" << X.id_ : os << labels[X]; (X >= 12) ? os << "X_" << X.id_ : os << labels[X];
@ -68,7 +69,8 @@ ostream& operator<< (ostream &os, const LogVar& X)
ostream& operator<< (ostream &os, const Tuple& t) std::ostream&
operator<< (std::ostream& os, const Tuple& t)
{ {
os << "(" ; os << "(" ;
for (size_t i = 0; i < t.size(); i++) { for (size_t i = 0; i < t.size(); i++) {
@ -80,7 +82,8 @@ ostream& operator<< (ostream &os, const Tuple& t)
ostream& operator<< (ostream &os, const Ground& gr) std::ostream&
operator<< (std::ostream& os, const Ground& gr)
{ {
os << gr.functor(); os << gr.functor();
os << "(" ; os << "(" ;
@ -95,12 +98,12 @@ ostream& operator<< (ostream &os, const Ground& gr)
LogVars LogVars
Substitution::getDiscardedLogVars (void) const Substitution::getDiscardedLogVars() const
{ {
LogVars discardedLvs; LogVars discardedLvs;
set<LogVar> doneLvs; std::set<LogVar> doneLvs;
unordered_map<LogVar, LogVar>::const_iterator it; std::unordered_map<LogVar, LogVar>::const_iterator it
it = subs_.begin(); = subs_.begin();
while (it != subs_.end()) { while (it != subs_.end()) {
if (Util::contains (doneLvs, it->second)) { if (Util::contains (doneLvs, it->second)) {
discardedLvs.push_back (it->first); discardedLvs.push_back (it->first);
@ -114,9 +117,10 @@ Substitution::getDiscardedLogVars (void) const
ostream& operator<< (ostream &os, const Substitution& theta) std::ostream&
operator<< (std::ostream& os, const Substitution& theta)
{ {
unordered_map<LogVar, LogVar>::const_iterator it; std::unordered_map<LogVar, LogVar>::const_iterator it;
os << "[" ; os << "[" ;
it = theta.subs_.begin(); it = theta.subs_.begin();
while (it != theta.subs_.end()) { while (it != theta.subs_.end()) {
@ -128,3 +132,5 @@ ostream& operator<< (ostream &os, const Substitution& theta)
return os; return os;
} }
} // namespace Horus

View File

@ -1,142 +1,186 @@
#ifndef HORUS_LIFTEDUTILS_H #ifndef YAP_PACKAGES_CLPBN_HORUS_LIFTEDUTILS_H_
#define HORUS_LIFTEDUTILS_H #define YAP_PACKAGES_CLPBN_HORUS_LIFTEDUTILS_H_
#include <string>
#include <vector> #include <vector>
#include <unordered_map> #include <unordered_map>
#include <string>
#include <ostream>
#include "TinySet.h" #include "TinySet.h"
#include "Util.h" #include "Util.h"
using namespace std; namespace Horus {
class Symbol {
class Symbol
{
public: public:
Symbol (void) : id_(Util::maxUnsigned()) { } Symbol() : id_(Util::maxUnsigned()) { }
Symbol (unsigned id) : id_(id) { } Symbol (unsigned id) : id_(id) { }
operator unsigned (void) const { return id_; } operator unsigned() const { return id_; }
bool valid (void) const { return id_ != Util::maxUnsigned(); } bool valid() const { return id_ != Util::maxUnsigned(); }
static Symbol invalid (void) { return Symbol(); } static Symbol invalid() { return Symbol(); }
friend ostream& operator<< (ostream &os, const Symbol& s);
private: private:
friend std::ostream& operator<< (std::ostream&, const Symbol&);
unsigned id_; unsigned id_;
}; };
class LogVar class LogVar {
{
public: public:
LogVar (void) : id_(Util::maxUnsigned()) { } LogVar() : id_(Util::maxUnsigned()) { }
LogVar (unsigned id) : id_(id) { } LogVar (unsigned id) : id_(id) { }
operator unsigned (void) const { return id_; } operator unsigned() const { return id_; }
LogVar& operator++ (void) LogVar& operator++();
bool valid() const;
private:
friend std::ostream& operator<< (std::ostream&, const LogVar&);
unsigned id_;
};
inline LogVar&
LogVar::operator++()
{ {
assert (valid()); assert (valid());
id_ ++; id_ ++;
return *this; return *this;
} }
bool valid (void) const
inline bool
LogVar::valid() const
{ {
return id_ != Util::maxUnsigned(); return id_ != Util::maxUnsigned();
} }
friend ostream& operator<< (ostream &os, const LogVar& X); } // namespace Horus
private:
unsigned id_;
};
namespace std { namespace std {
template <> struct hash<Symbol> {
size_t operator() (const Symbol& s) const { template <> struct hash<Horus::Symbol> {
size_t operator() (const Horus::Symbol& s) const {
return std::hash<unsigned>() (s); return std::hash<unsigned>() (s);
}}; }};
template <> struct hash<LogVar> { template <> struct hash<Horus::LogVar> {
size_t operator() (const LogVar& X) const { size_t operator() (const Horus::LogVar& X) const {
return std::hash<unsigned>() (X); return std::hash<unsigned>() (X);
}}; }};
};
} // namespace std
typedef vector<Symbol> Symbols; namespace Horus {
typedef vector<Symbol> Tuple;
typedef vector<Tuple> Tuples; typedef std::vector<Symbol> Symbols;
typedef vector<LogVar> LogVars; typedef std::vector<Symbol> Tuple;
typedef std::vector<Tuple> Tuples;
typedef std::vector<LogVar> LogVars;
typedef TinySet<Symbol> SymbolSet; typedef TinySet<Symbol> SymbolSet;
typedef TinySet<LogVar> LogVarSet; typedef TinySet<LogVar> LogVarSet;
typedef TinySet<Tuple> TupleSet; typedef TinySet<Tuple> TupleSet;
ostream& operator<< (ostream &os, const Tuple& t); std::ostream& operator<< (std::ostream&, const Tuple&);
namespace LiftedUtils { namespace LiftedUtils {
Symbol getSymbol (const string&);
void printSymbolDictionary (void); Symbol getSymbol (const std::string&);
void printSymbolDictionary();
} }
class Ground class Ground {
{
public: public:
Ground (Symbol f) : functor_(f) { } Ground (Symbol f) : functor_(f) { }
Ground (Symbol f, const Symbols& args) : functor_(f), args_(args) { } Ground (Symbol f, const Symbols& args)
: functor_(f), args_(args) { }
Symbol functor (void) const { return functor_; } Symbol functor() const { return functor_; }
Symbols args (void) const { return args_; } Symbols args() const { return args_; }
size_t arity (void) const { return args_.size(); } size_t arity() const { return args_.size(); }
bool isAtom (void) const { return args_.empty(); } bool isAtom() const { return args_.empty(); }
friend ostream& operator<< (ostream &os, const Ground& gr);
private: private:
friend std::ostream& operator<< (std::ostream&, const Ground&);
Symbol functor_; Symbol functor_;
Symbols args_; Symbols args_;
}; };
typedef vector<Ground> Grounds; typedef std::vector<Ground> Grounds;
class Substitution class Substitution {
{
public: public:
void add (LogVar X_old, LogVar X_new) void add (LogVar X_old, LogVar X_new);
void rename (LogVar X_old, LogVar X_new);
LogVar newNameFor (LogVar X) const;
bool containsReplacementFor (LogVar X) const;
size_t nrReplacements() const;
LogVars getDiscardedLogVars() const;
private:
friend std::ostream& operator<< (
std::ostream&, const Substitution&);
std::unordered_map<LogVar, LogVar> subs_;
};
inline void
Substitution::add (LogVar X_old, LogVar X_new)
{ {
assert (Util::contains (subs_, X_old) == false); assert (Util::contains (subs_, X_old) == false);
subs_.insert (make_pair (X_old, X_new)); subs_.insert (std::make_pair (X_old, X_new));
} }
void rename (LogVar X_old, LogVar X_new)
inline void
Substitution::rename (LogVar X_old, LogVar X_new)
{ {
assert (Util::contains (subs_, X_old)); assert (Util::contains (subs_, X_old));
subs_.find (X_old)->second = X_new; subs_.find (X_old)->second = X_new;
} }
LogVar newNameFor (LogVar X) const
inline LogVar
Substitution::newNameFor (LogVar X) const
{ {
unordered_map<LogVar, LogVar>::const_iterator it; std::unordered_map<LogVar, LogVar>::const_iterator it;
it = subs_.find (X); it = subs_.find (X);
if (it != subs_.end()) { if (it != subs_.end()) {
return subs_.find (X)->second; return subs_.find (X)->second;
@ -144,21 +188,23 @@ class Substitution
return X; return X;
} }
bool containsReplacementFor (LogVar X) const
inline bool
Substitution::containsReplacementFor (LogVar X) const
{ {
return Util::contains (subs_, X); return Util::contains (subs_, X);
} }
size_t nrReplacements (void) const { return subs_.size(); }
LogVars getDiscardedLogVars (void) const;
friend ostream& operator<< (ostream &os, const Substitution& theta); inline size_t
Substitution::nrReplacements() const
{
return subs_.size();
}
private: } // namespace Horus
unordered_map<LogVar, LogVar> subs_;
}; #endif // YAP_PACKAGES_CLPBN_HORUS_LIFTEDUTILS_H_
#endif // HORUS_LIFTEDUTILS_H

View File

@ -1,6 +1,12 @@
#include <algorithm> #include <cassert>
#include <vector>
#include <set> #include <set>
#include <queue>
#include <algorithm>
#include <string>
#include <iostream>
#include <sstream>
#include "LiftedVe.h" #include "LiftedVe.h"
#include "LiftedOperations.h" #include "LiftedOperations.h"
@ -8,21 +14,158 @@
#include "Util.h" #include "Util.h"
vector<LiftedOperator*> namespace Horus {
class LiftedOperator {
public:
virtual ~LiftedOperator() { }
virtual double getLogCost() = 0;
virtual void apply() = 0;
virtual std::string toString() = 0;
static std::vector<LiftedOperator*> getValidOps (
ParfactorList&, const Grounds&);
static void printValidOps (ParfactorList&, const Grounds&);
static std::vector<ParfactorList::iterator> getParfactorsWithGroup (
ParfactorList&, PrvGroup group);
private:
DISALLOW_ASSIGN (LiftedOperator);
};
class ProductOperator : public LiftedOperator {
public:
ProductOperator (
ParfactorList::iterator g1,
ParfactorList::iterator g2,
ParfactorList& pfList)
: g1_(g1), g2_(g2), pfList_(pfList) { }
double getLogCost();
void apply();
static std::vector<ProductOperator*> getValidOps (ParfactorList&);
std::string toString();
private:
static bool validOp (Parfactor*, Parfactor*);
ParfactorList::iterator g1_;
ParfactorList::iterator g2_;
ParfactorList& pfList_;
DISALLOW_COPY_AND_ASSIGN (ProductOperator);
};
class SumOutOperator : public LiftedOperator {
public:
SumOutOperator (PrvGroup group, ParfactorList& pfList)
: group_(group), pfList_(pfList) { }
double getLogCost();
void apply();
static std::vector<SumOutOperator*> getValidOps (
ParfactorList&, const Grounds&);
std::string toString();
private:
static bool validOp (PrvGroup, ParfactorList&, const Grounds&);
static bool isToEliminate (Parfactor*, PrvGroup, const Grounds&);
PrvGroup group_;
ParfactorList& pfList_;
DISALLOW_COPY_AND_ASSIGN (SumOutOperator);
};
class CountingOperator : public LiftedOperator {
public:
CountingOperator (
ParfactorList::iterator pfIter,
LogVar X,
ParfactorList& pfList)
: pfIter_(pfIter), X_(X), pfList_(pfList) { }
double getLogCost();
void apply();
static std::vector<CountingOperator*> getValidOps (ParfactorList&);
std::string toString();
private:
static bool validOp (Parfactor*, LogVar);
ParfactorList::iterator pfIter_;
LogVar X_;
ParfactorList& pfList_;
DISALLOW_COPY_AND_ASSIGN (CountingOperator);
};
class GroundOperator : public LiftedOperator {
public:
GroundOperator (
PrvGroup group,
unsigned lvIndex,
ParfactorList& pfList)
: group_(group), lvIndex_(lvIndex), pfList_(pfList) { }
double getLogCost();
void apply();
static std::vector<GroundOperator*> getValidOps (ParfactorList&);
std::string toString();
private:
std::vector<std::pair<PrvGroup, unsigned>> getAffectedFormulas();
PrvGroup group_;
unsigned lvIndex_;
ParfactorList& pfList_;
DISALLOW_COPY_AND_ASSIGN (GroundOperator);
};
std::vector<LiftedOperator*>
LiftedOperator::getValidOps ( LiftedOperator::getValidOps (
ParfactorList& pfList, ParfactorList& pfList,
const Grounds& query) const Grounds& query)
{ {
vector<LiftedOperator*> validOps; std::vector<LiftedOperator*> validOps;
vector<ProductOperator*> multOps; std::vector<ProductOperator*> multOps;
multOps = ProductOperator::getValidOps (pfList); multOps = ProductOperator::getValidOps (pfList);
validOps.insert (validOps.end(), multOps.begin(), multOps.end()); validOps.insert (validOps.end(), multOps.begin(), multOps.end());
if (Globals::verbosity > 1 || multOps.empty()) { if (Globals::verbosity > 1 || multOps.empty()) {
vector<SumOutOperator*> sumOutOps; std::vector<SumOutOperator*> sumOutOps;
vector<CountingOperator*> countOps; std::vector<CountingOperator*> countOps;
vector<GroundOperator*> groundOps; std::vector<GroundOperator*> groundOps;
sumOutOps = SumOutOperator::getValidOps (pfList, query); sumOutOps = SumOutOperator::getValidOps (pfList, query);
countOps = CountingOperator::getValidOps (pfList); countOps = CountingOperator::getValidOps (pfList);
groundOps = GroundOperator::getValidOps (pfList); groundOps = GroundOperator::getValidOps (pfList);
@ -41,21 +184,21 @@ LiftedOperator::printValidOps (
ParfactorList& pfList, ParfactorList& pfList,
const Grounds& query) const Grounds& query)
{ {
vector<LiftedOperator*> validOps; std::vector<LiftedOperator*> validOps;
validOps = LiftedOperator::getValidOps (pfList, query); validOps = LiftedOperator::getValidOps (pfList, query);
for (size_t i = 0; i < validOps.size(); i++) { for (size_t i = 0; i < validOps.size(); i++) {
cout << "-> " << validOps[i]->toString(); std::cout << "-> " << validOps[i]->toString();
delete validOps[i]; delete validOps[i];
} }
} }
vector<ParfactorList::iterator> std::vector<ParfactorList::iterator>
LiftedOperator::getParfactorsWithGroup ( LiftedOperator::getParfactorsWithGroup (
ParfactorList& pfList, PrvGroup group) ParfactorList& pfList, PrvGroup group)
{ {
vector<ParfactorList::iterator> iters; std::vector<ParfactorList::iterator> iters;
ParfactorList::iterator pflIt = pfList.begin(); ParfactorList::iterator pflIt = pfList.begin();
while (pflIt != pfList.end()) { while (pflIt != pfList.end()) {
if ((*pflIt)->containsGroup (group)) { if ((*pflIt)->containsGroup (group)) {
@ -69,7 +212,7 @@ LiftedOperator::getParfactorsWithGroup (
double double
ProductOperator::getLogCost (void) ProductOperator::getLogCost()
{ {
return std::log (0.0); return std::log (0.0);
} }
@ -77,7 +220,7 @@ ProductOperator::getLogCost (void)
void void
ProductOperator::apply (void) ProductOperator::apply()
{ {
Parfactor* g1 = *g1_; Parfactor* g1 = *g1_;
Parfactor* g2 = *g2_; Parfactor* g2 = *g2_;
@ -89,13 +232,13 @@ ProductOperator::apply (void)
vector<ProductOperator*> std::vector<ProductOperator*>
ProductOperator::getValidOps (ParfactorList& pfList) ProductOperator::getValidOps (ParfactorList& pfList)
{ {
vector<ProductOperator*> validOps; std::vector<ProductOperator*> validOps;
ParfactorList::iterator it1 = pfList.begin(); ParfactorList::iterator it1 = pfList.begin();
ParfactorList::iterator penultimate = -- pfList.end(); ParfactorList::iterator penultimate = -- pfList.end();
set<Parfactor*> pfs; std::set<Parfactor*> pfs;
while (it1 != penultimate) { while (it1 != penultimate) {
if (Util::contains (pfs, *it1)) { if (Util::contains (pfs, *it1)) {
++ it1; ++ it1;
@ -128,15 +271,15 @@ ProductOperator::getValidOps (ParfactorList& pfList)
string std::string
ProductOperator::toString (void) ProductOperator::toString()
{ {
stringstream ss; std::stringstream ss;
ss << "just multiplicate " ; ss << "just multiplicate " ;
ss << (*g1_)->getAllGroups(); ss << (*g1_)->getAllGroups();
ss << " x " ; ss << " x " ;
ss << (*g2_)->getAllGroups(); ss << (*g2_)->getAllGroups();
ss << " [cost=" << std::exp (getLogCost()) << "]" << endl; ss << " [cost=" << std::exp (getLogCost()) << "]" << std::endl;
return ss.str(); return ss.str();
} }
@ -168,14 +311,14 @@ ProductOperator::validOp (Parfactor* g1, Parfactor* g2)
double double
SumOutOperator::getLogCost (void) SumOutOperator::getLogCost()
{ {
TinySet<PrvGroup> groupSet; TinySet<PrvGroup> groupSet;
ParfactorList::const_iterator pfIter = pfList_.begin(); ParfactorList::const_iterator pfIter = pfList_.begin();
unsigned nrProdFactors = 0; unsigned nrProdFactors = 0;
while (pfIter != pfList_.end()) { while (pfIter != pfList_.end()) {
if ((*pfIter)->containsGroup (group_)) { if ((*pfIter)->containsGroup (group_)) {
vector<PrvGroup> groups = (*pfIter)->getAllGroups(); std::vector<PrvGroup> groups = (*pfIter)->getAllGroups();
groupSet |= TinySet<PrvGroup> (groups); groupSet |= TinySet<PrvGroup> (groups);
++ nrProdFactors; ++ nrProdFactors;
} }
@ -203,9 +346,9 @@ SumOutOperator::getLogCost (void)
void void
SumOutOperator::apply (void) SumOutOperator::apply()
{ {
vector<ParfactorList::iterator> iters; std::vector<ParfactorList::iterator> iters;
iters = getParfactorsWithGroup (pfList_, group_); iters = getParfactorsWithGroup (pfList_, group_);
Parfactor* product = *(iters[0]); Parfactor* product = *(iters[0]);
pfList_.remove (iters[0]); pfList_.remove (iters[0]);
@ -234,13 +377,13 @@ SumOutOperator::apply (void)
vector<SumOutOperator*> std::vector<SumOutOperator*>
SumOutOperator::getValidOps ( SumOutOperator::getValidOps (
ParfactorList& pfList, ParfactorList& pfList,
const Grounds& query) const Grounds& query)
{ {
vector<SumOutOperator*> validOps; std::vector<SumOutOperator*> validOps;
set<PrvGroup> allGroups; std::set<PrvGroup> allGroups;
ParfactorList::const_iterator it = pfList.begin(); ParfactorList::const_iterator it = pfList.begin();
while (it != pfList.end()) { while (it != pfList.end()) {
const ProbFormulas& formulas = (*it)->arguments(); const ProbFormulas& formulas = (*it)->arguments();
@ -249,7 +392,7 @@ SumOutOperator::getValidOps (
} }
++ it; ++ it;
} }
set<PrvGroup>::const_iterator groupIt = allGroups.begin(); std::set<PrvGroup>::const_iterator groupIt = allGroups.begin();
while (groupIt != allGroups.end()) { while (groupIt != allGroups.end()) {
if (validOp (*groupIt, pfList, query)) { if (validOp (*groupIt, pfList, query)) {
validOps.push_back (new SumOutOperator (*groupIt, pfList)); validOps.push_back (new SumOutOperator (*groupIt, pfList));
@ -261,18 +404,18 @@ SumOutOperator::getValidOps (
string std::string
SumOutOperator::toString (void) SumOutOperator::toString()
{ {
stringstream ss; std::stringstream ss;
vector<ParfactorList::iterator> pfIters; std::vector<ParfactorList::iterator> pfIters;
pfIters = getParfactorsWithGroup (pfList_, group_); pfIters = getParfactorsWithGroup (pfList_, group_);
size_t idx = (*pfIters[0])->indexOfGroup (group_); size_t idx = (*pfIters[0])->indexOfGroup (group_);
ProbFormula f = (*pfIters[0])->argument (idx); ProbFormula f = (*pfIters[0])->argument (idx);
TupleSet tupleSet = (*pfIters[0])->constr()->tupleSet (f.logVars()); TupleSet tupleSet = (*pfIters[0])->constr()->tupleSet (f.logVars());
ss << "sum out " << f.functor() << "/" << f.arity(); ss << "sum out " << f.functor() << "/" << f.arity();
ss << "|" << tupleSet << " (group " << group_ << ")"; ss << "|" << tupleSet << " (group " << group_ << ")";
ss << " [cost=" << std::exp (getLogCost()) << "]" << endl; ss << " [cost=" << std::exp (getLogCost()) << "]" << std::endl;
return ss.str(); return ss.str();
} }
@ -284,7 +427,7 @@ SumOutOperator::validOp (
ParfactorList& pfList, ParfactorList& pfList,
const Grounds& query) const Grounds& query)
{ {
vector<ParfactorList::iterator> pfIters; std::vector<ParfactorList::iterator> pfIters;
pfIters = getParfactorsWithGroup (pfList, group); pfIters = getParfactorsWithGroup (pfList, group);
if (isToEliminate (*pfIters[0], group, query) == false) { if (isToEliminate (*pfIters[0], group, query) == false) {
return false; return false;
@ -335,7 +478,7 @@ SumOutOperator::isToEliminate (
double double
CountingOperator::getLogCost (void) CountingOperator::getLogCost()
{ {
double cost = 0.0; double cost = 0.0;
size_t fIdx = (*pfIter_)->indexOfLogVar (X_); size_t fIdx = (*pfIter_)->indexOfLogVar (X_);
@ -370,7 +513,7 @@ CountingOperator::getLogCost (void)
void void
CountingOperator::apply (void) CountingOperator::apply()
{ {
if ((*pfIter_)->constr()->isCountNormalized (X_)) { if ((*pfIter_)->constr()->isCountNormalized (X_)) {
(*pfIter_)->countConvert (X_); (*pfIter_)->countConvert (X_);
@ -393,10 +536,10 @@ CountingOperator::apply (void)
vector<CountingOperator*> std::vector<CountingOperator*>
CountingOperator::getValidOps (ParfactorList& pfList) CountingOperator::getValidOps (ParfactorList& pfList)
{ {
vector<CountingOperator*> validOps; std::vector<CountingOperator*> validOps;
ParfactorList::iterator it = pfList.begin(); ParfactorList::iterator it = pfList.begin();
while (it != pfList.end()) { while (it != pfList.end()) {
LogVarSet candidates = (*it)->uncountedLogVars(); LogVarSet candidates = (*it)->uncountedLogVars();
@ -414,17 +557,17 @@ CountingOperator::getValidOps (ParfactorList& pfList)
string std::string
CountingOperator::toString (void) CountingOperator::toString()
{ {
stringstream ss; std::stringstream ss;
ss << "count convert " << X_ << " in " ; ss << "count convert " << X_ << " in " ;
ss << (*pfIter_)->getLabel(); ss << (*pfIter_)->getLabel();
ss << " [cost=" << std::exp (getLogCost()) << "]" << endl; ss << " [cost=" << std::exp (getLogCost()) << "]" << std::endl;
Parfactors pfs = LiftedOperations::countNormalize (*pfIter_, X_); Parfactors pfs = LiftedOperations::countNormalize (*pfIter_, X_);
if ((*pfIter_)->constr()->isCountNormalized (X_) == false) { if ((*pfIter_)->constr()->isCountNormalized (X_) == false) {
for (size_t i = 0; i < pfs.size(); i++) { for (size_t i = 0; i < pfs.size(); i++) {
ss << " º " << pfs[i]->getLabel() << endl; ss << " º " << pfs[i]->getLabel() << std::endl;
} }
} }
for (size_t i = 0; i < pfs.size(); i++) { for (size_t i = 0; i < pfs.size(); i++) {
@ -455,16 +598,16 @@ CountingOperator::validOp (Parfactor* g, LogVar X)
double double
GroundOperator::getLogCost (void) GroundOperator::getLogCost()
{ {
vector<pair<PrvGroup, unsigned>> affectedFormulas; std::vector<std::pair<PrvGroup, unsigned>> affectedFormulas;
affectedFormulas = getAffectedFormulas(); affectedFormulas = getAffectedFormulas();
// cout << "affected formulas: " ; // std::cout << "affected formulas: " ;
// for (size_t i = 0; i < affectedFormulas.size(); i++) { // for (size_t i = 0; i < affectedFormulas.size(); i++) {
// cout << affectedFormulas[i].first << ":" ; // std::cout << affectedFormulas[i].first << ":" ;
// cout << affectedFormulas[i].second << " " ; // std::cout << affectedFormulas[i].second << " " ;
// } // }
// cout << "cost =" ; // std::cout << "cost =" ;
double totalCost = std::log (0.0); double totalCost = std::log (0.0);
ParfactorList::iterator pflIt = pfList_.begin(); ParfactorList::iterator pflIt = pfList_.begin();
while (pflIt != pfList_.end()) { while (pflIt != pfList_.end()) {
@ -495,20 +638,20 @@ GroundOperator::getLogCost (void)
} }
} }
if (willBeAffected) { if (willBeAffected) {
// cout << " + " << std::exp (reps) << "x" << std::exp (pfSize); // std::cout << " + " << std::exp (reps) << "x" << std::exp (pfSize);
double pfCost = reps + pfSize; double pfCost = reps + pfSize;
totalCost = Util::logSum (totalCost, pfCost); totalCost = Util::logSum (totalCost, pfCost);
} }
++ pflIt; ++ pflIt;
} }
// cout << endl; // std::cout << std::endl;
return totalCost + 3; return totalCost + 3;
} }
void void
GroundOperator::apply (void) GroundOperator::apply()
{ {
ParfactorList::iterator pfIter; ParfactorList::iterator pfIter;
pfIter = getParfactorsWithGroup (pfList_, group_).front(); pfIter = getParfactorsWithGroup (pfList_, group_).front();
@ -537,11 +680,11 @@ GroundOperator::apply (void)
vector<GroundOperator*> std::vector<GroundOperator*>
GroundOperator::getValidOps (ParfactorList& pfList) GroundOperator::getValidOps (ParfactorList& pfList)
{ {
vector<GroundOperator*> validOps; std::vector<GroundOperator*> validOps;
set<PrvGroup> allGroups; std::set<PrvGroup> allGroups;
ParfactorList::const_iterator it = pfList.begin(); ParfactorList::const_iterator it = pfList.begin();
while (it != pfList.end()) { while (it != pfList.end()) {
const ProbFormulas& formulas = (*it)->arguments(); const ProbFormulas& formulas = (*it)->arguments();
@ -564,18 +707,18 @@ GroundOperator::getValidOps (ParfactorList& pfList)
string std::string
GroundOperator::toString (void) GroundOperator::toString()
{ {
stringstream ss; std::stringstream ss;
vector<ParfactorList::iterator> pfIters; std::vector<ParfactorList::iterator> pfIters;
pfIters = getParfactorsWithGroup (pfList_, group_); pfIters = getParfactorsWithGroup (pfList_, group_);
Parfactor* pf = *(getParfactorsWithGroup (pfList_, group_).front()); Parfactor* pf = *(getParfactorsWithGroup (pfList_, group_).front());
size_t idx = pf->indexOfGroup (group_); size_t idx = pf->indexOfGroup (group_);
ProbFormula f = pf->argument (idx); ProbFormula f = pf->argument (idx);
LogVar lv = f.logVars()[lvIndex_]; LogVar lv = f.logVars()[lvIndex_];
TupleSet tupleSet = pf->constr()->tupleSet ({lv}); TupleSet tupleSet = pf->constr()->tupleSet ({lv});
string pos = "th"; std::string pos = "th";
if (lvIndex_ == 0) { if (lvIndex_ == 0) {
pos = "st" ; pos = "st" ;
} else if (lvIndex_ == 1) { } else if (lvIndex_ == 1) {
@ -586,21 +729,21 @@ GroundOperator::toString (void)
ss << "grounding " << lvIndex_ + 1 << pos << " log var in " ; ss << "grounding " << lvIndex_ + 1 << pos << " log var in " ;
ss << f.functor() << "/" << f.arity(); ss << f.functor() << "/" << f.arity();
ss << "|" << tupleSet << " (group " << group_ << ")"; ss << "|" << tupleSet << " (group " << group_ << ")";
ss << " [cost=" << std::exp (getLogCost()) << "]" << endl; ss << " [cost=" << std::exp (getLogCost()) << "]" << std::endl;
return ss.str(); return ss.str();
} }
vector<pair<PrvGroup, unsigned>> std::vector<std::pair<PrvGroup, unsigned>>
GroundOperator::getAffectedFormulas (void) GroundOperator::getAffectedFormulas()
{ {
vector<pair<PrvGroup, unsigned>> affectedFormulas; std::vector<std::pair<PrvGroup, unsigned>> affectedFormulas;
affectedFormulas.push_back (make_pair (group_, lvIndex_)); affectedFormulas.push_back (std::make_pair (group_, lvIndex_));
queue<pair<PrvGroup, unsigned>> q; std::queue<std::pair<PrvGroup, unsigned>> q;
q.push (make_pair (group_, lvIndex_)); q.push (std::make_pair (group_, lvIndex_));
while (q.empty() == false) { while (q.empty() == false) {
pair<PrvGroup, unsigned> front = q.front(); std::pair<PrvGroup, unsigned> front = q.front();
ParfactorList::iterator pflIt = pfList_.begin(); ParfactorList::iterator pflIt = pfList_.begin();
while (pflIt != pfList_.end()) { while (pflIt != pfList_.end()) {
size_t idx = (*pflIt)->indexOfGroup (front.first); size_t idx = (*pflIt)->indexOfGroup (front.first);
@ -610,7 +753,7 @@ GroundOperator::getAffectedFormulas (void)
const ProbFormulas& fs = (*pflIt)->arguments(); const ProbFormulas& fs = (*pflIt)->arguments();
for (size_t i = 0; i < fs.size(); i++) { for (size_t i = 0; i < fs.size(); i++) {
if (i != idx && fs[i].contains (X)) { if (i != idx && fs[i].contains (X)) {
pair<PrvGroup, unsigned> pair = make_pair ( std::pair<PrvGroup, unsigned> pair = std::make_pair (
fs[i].group(), fs[i].indexOf (X)); fs[i].group(), fs[i].indexOf (X));
if (Util::contains (affectedFormulas, pair) == false) { if (Util::contains (affectedFormulas, pair) == false) {
q.push (pair); q.push (pair);
@ -645,13 +788,13 @@ LiftedVe::solveQuery (const Grounds& query)
void void
LiftedVe::printSolverFlags (void) const LiftedVe::printSolverFlags() const
{ {
stringstream ss; std::stringstream ss;
ss << "lve [" ; ss << "lve [" ;
ss << "log_domain=" << Util::toString (Globals::logDomain); ss << "log_domain=" << Util::toString (Globals::logDomain);
ss << "]" ; ss << "]" ;
cout << ss.str() << endl; std::cout << ss.str() << std::endl;
} }
@ -675,9 +818,9 @@ LiftedVe::runSolver (const Grounds& query)
break; break;
} }
if (Globals::verbosity > 1) { if (Globals::verbosity > 1) {
cout << "best operation: " << op->toString(); std::cout << "best operation: " << op->toString();
if (Globals::verbosity > 2) { if (Globals::verbosity > 2) {
cout << endl; std::cout << std::endl;
} }
} }
op->apply(); op->apply();
@ -693,8 +836,9 @@ LiftedVe::runSolver (const Grounds& query)
} }
} }
if (Globals::verbosity > 0) { if (Globals::verbosity > 0) {
cout << "largest cost = " << std::exp (largestCost_) << endl; std::cout << "largest cost = " << std::exp (largestCost_);
cout << endl; std::cout << std::endl;
std::cout << std::endl;
} }
(*pfList_.begin())->simplifyGrounds(); (*pfList_.begin())->simplifyGrounds();
(*pfList_.begin())->reorderAccordingGrounds (query); (*pfList_.begin())->reorderAccordingGrounds (query);
@ -707,7 +851,7 @@ LiftedVe::getBestOperation (const Grounds& query)
{ {
double bestCost = 0.0; double bestCost = 0.0;
LiftedOperator* bestOp = 0; LiftedOperator* bestOp = 0;
vector<LiftedOperator*> validOps; std::vector<LiftedOperator*> validOps;
validOps = LiftedOperator::getValidOps (pfList_, query); validOps = LiftedOperator::getValidOps (pfList_, query);
for (size_t i = 0; i < validOps.size(); i++) { for (size_t i = 0; i < validOps.size(); i++) {
double cost = validOps[i]->getLogCost(); double cost = validOps[i]->getLogCost();
@ -727,3 +871,5 @@ LiftedVe::getBestOperation (const Grounds& query)
return bestOp; return bestOp;
} }
} // namespace Horus

View File

@ -1,157 +1,23 @@
#ifndef HORUS_LIFTEDVE_H #ifndef YAP_PACKAGES_CLPBN_HORUS_LIFTEDVE_H_
#define HORUS_LIFTEDVE_H #define YAP_PACKAGES_CLPBN_HORUS_LIFTEDVE_H_
#include "LiftedSolver.h" #include "LiftedSolver.h"
#include "ParfactorList.h" #include "ParfactorList.h"
class LiftedOperator namespace Horus {
{
public:
virtual ~LiftedOperator (void) { }
virtual double getLogCost (void) = 0; class LiftedOperator;
virtual void apply (void) = 0;
virtual string toString (void) = 0;
static vector<LiftedOperator*> getValidOps (
ParfactorList&, const Grounds&);
static void printValidOps (ParfactorList&, const Grounds&);
static vector<ParfactorList::iterator> getParfactorsWithGroup (
ParfactorList&, PrvGroup group);
private:
DISALLOW_ASSIGN (LiftedOperator);
};
class LiftedVe : public LiftedSolver {
class ProductOperator : public LiftedOperator
{
public:
ProductOperator (
ParfactorList::iterator g1, ParfactorList::iterator g2,
ParfactorList& pfList) : g1_(g1), g2_(g2), pfList_(pfList) { }
double getLogCost (void);
void apply (void);
static vector<ProductOperator*> getValidOps (ParfactorList&);
string toString (void);
private:
static bool validOp (Parfactor*, Parfactor*);
ParfactorList::iterator g1_;
ParfactorList::iterator g2_;
ParfactorList& pfList_;
DISALLOW_COPY_AND_ASSIGN (ProductOperator);
};
class SumOutOperator : public LiftedOperator
{
public:
SumOutOperator (PrvGroup group, ParfactorList& pfList)
: group_(group), pfList_(pfList) { }
double getLogCost (void);
void apply (void);
static vector<SumOutOperator*> getValidOps (
ParfactorList&, const Grounds&);
string toString (void);
private:
static bool validOp (PrvGroup, ParfactorList&, const Grounds&);
static bool isToEliminate (Parfactor*, PrvGroup, const Grounds&);
PrvGroup group_;
ParfactorList& pfList_;
DISALLOW_COPY_AND_ASSIGN (SumOutOperator);
};
class CountingOperator : public LiftedOperator
{
public:
CountingOperator (
ParfactorList::iterator pfIter,
LogVar X,
ParfactorList& pfList)
: pfIter_(pfIter), X_(X), pfList_(pfList) { }
double getLogCost (void);
void apply (void);
static vector<CountingOperator*> getValidOps (ParfactorList&);
string toString (void);
private:
static bool validOp (Parfactor*, LogVar);
ParfactorList::iterator pfIter_;
LogVar X_;
ParfactorList& pfList_;
DISALLOW_COPY_AND_ASSIGN (CountingOperator);
};
class GroundOperator : public LiftedOperator
{
public:
GroundOperator (
PrvGroup group,
unsigned lvIndex,
ParfactorList& pfList)
: group_(group), lvIndex_(lvIndex), pfList_(pfList) { }
double getLogCost (void);
void apply (void);
static vector<GroundOperator*> getValidOps (ParfactorList&);
string toString (void);
private:
vector<pair<PrvGroup, unsigned>> getAffectedFormulas (void);
PrvGroup group_;
unsigned lvIndex_;
ParfactorList& pfList_;
DISALLOW_COPY_AND_ASSIGN (GroundOperator);
};
class LiftedVe : public LiftedSolver
{
public: public:
LiftedVe (const ParfactorList& pfList) LiftedVe (const ParfactorList& pfList)
: LiftedSolver(pfList) { } : LiftedSolver(pfList) { }
Params solveQuery (const Grounds&); Params solveQuery (const Grounds&);
void printSolverFlags (void) const; void printSolverFlags() const;
private: private:
void runSolver (const Grounds&); void runSolver (const Grounds&);
@ -164,5 +30,7 @@ class LiftedVe : public LiftedSolver
DISALLOW_COPY_AND_ASSIGN (LiftedVe); DISALLOW_COPY_AND_ASSIGN (LiftedVe);
}; };
#endif // HORUS_LIFTEDVE_H } // namespace Horus
#endif // YAP_PACKAGES_CLPBN_HORUS_LIFTEDVE_H_

View File

@ -1,10 +1,20 @@
#include <cassert>
#include <iostream>
#include <sstream>
#include "LiftedWCNF.h" #include "LiftedWCNF.h"
#include "ParfactorList.h"
#include "ConstraintTree.h" #include "ConstraintTree.h"
#include "Indexer.h" #include "Indexer.h"
namespace Horus {
bool bool
Literal::isGround (ConstraintTree constr, LogVarSet ipgLogVars) const Literal::isGround (
ConstraintTree constr,
const LogVarSet& ipgLogVars) const
{ {
if (logVars_.empty()) { if (logVars_.empty()) {
return true; return true;
@ -24,13 +34,13 @@ Literal::indexOfLogVar (LogVar X) const
string std::string
Literal::toString ( Literal::toString (
LogVarSet ipgLogVars, LogVarSet ipgLogVars,
LogVarSet posCountedLvs, LogVarSet posCountedLvs,
LogVarSet negCountedLvs) const LogVarSet negCountedLvs) const
{ {
stringstream ss; std::stringstream ss;
negated_ ? ss << "¬" : ss << "" ; negated_ ? ss << "¬" : ss << "" ;
ss << "λ" ; ss << "λ" ;
ss << lid_ ; ss << lid_ ;
@ -44,7 +54,7 @@ Literal::toString (
ss << "-" << logVars_[i]; ss << "-" << logVars_[i];
} else if (ipgLogVars.contains (logVars_[i])) { } else if (ipgLogVars.contains (logVars_[i])) {
LogVar X = logVars_[i]; LogVar X = logVars_[i];
const string labels[] = { const std::string labels[] = {
"a", "b", "c", "d", "e", "f", "a", "b", "c", "d", "e", "f",
"g", "h", "i", "j", "k", "m" }; "g", "h", "i", "j", "k", "m" };
(X >= 12) ? ss << "x_" << X : ss << labels[X]; (X >= 12) ? ss << "x_" << X : ss << labels[X];
@ -60,7 +70,7 @@ Literal::toString (
std::ostream& std::ostream&
operator<< (ostream &os, const Literal& lit) operator<< (std::ostream& os, const Literal& lit)
{ {
os << lit.toString(); os << lit.toString();
return os; return os;
@ -216,7 +226,7 @@ Clause::isIpgLogVar (LogVar X) const
TinySet<LiteralId> TinySet<LiteralId>
Clause::lidSet (void) const Clause::lidSet() const
{ {
TinySet<LiteralId> lidSet; TinySet<LiteralId> lidSet;
for (size_t i = 0; i < literals_.size(); i++) { for (size_t i = 0; i < literals_.size(); i++) {
@ -228,7 +238,7 @@ Clause::lidSet (void) const
LogVarSet LogVarSet
Clause::ipgCandidates (void) const Clause::ipgCandidates() const
{ {
LogVarSet candidates; LogVarSet candidates;
LogVarSet allLvs = constr_.logVarSet(); LogVarSet allLvs = constr_.logVarSet();
@ -259,11 +269,11 @@ Clause::logVarTypes (size_t litIdx) const
const LogVars& lvs = literals_[litIdx].logVars(); const LogVars& lvs = literals_[litIdx].logVars();
for (size_t i = 0; i < lvs.size(); i++) { for (size_t i = 0; i < lvs.size(); i++) {
if (posCountedLvs_.contains (lvs[i])) { if (posCountedLvs_.contains (lvs[i])) {
types.push_back (LogVarType::POS_LV); types.push_back (LogVarType::posLvt);
} else if (negCountedLvs_.contains (lvs[i])) { } else if (negCountedLvs_.contains (lvs[i])) {
types.push_back (LogVarType::NEG_LV); types.push_back (LogVarType::negLvt);
} else { } else {
types.push_back (LogVarType::FULL_LV); types.push_back (LogVarType::fullLvt);
} }
} }
return types; return types;
@ -320,7 +330,7 @@ void
Clause::printClauses (const Clauses& clauses) Clause::printClauses (const Clauses& clauses)
{ {
for (size_t i = 0; i < clauses.size(); i++) { for (size_t i = 0; i < clauses.size(); i++) {
cout << *clauses[i] << endl; std::cout << *clauses[i] << std::endl;
} }
} }
@ -337,7 +347,7 @@ Clause::deleteClauses (Clauses& clauses)
std::ostream& std::ostream&
operator<< (ostream &os, const Clause& clause) operator<< (std::ostream& os, const Clause& clause)
{ {
for (unsigned i = 0; i < clause.literals_.size(); i++) { for (unsigned i = 0; i < clause.literals_.size(); i++) {
if (i != 0) os << " v " ; if (i != 0) os << " v " ;
@ -374,9 +384,9 @@ operator<< (std::ostream &os, const LitLvTypes& lit)
os << lit.lid_ << "<" ; os << lit.lid_ << "<" ;
for (size_t i = 0; i < lit.lvTypes_.size(); i++) { for (size_t i = 0; i < lit.lvTypes_.size(); i++) {
switch (lit.lvTypes_[i]) { switch (lit.lvTypes_[i]) {
case LogVarType::FULL_LV: os << "F" ; break; case LogVarType::fullLvt: os << "F" ; break;
case LogVarType::POS_LV: os << "P" ; break; case LogVarType::posLvt: os << "P" ; break;
case LogVarType::NEG_LV: os << "N" ; break; case LogVarType::negLvt: os << "N" ; break;
} }
} }
os << ">" ; os << ">" ;
@ -385,6 +395,14 @@ operator<< (std::ostream &os, const LitLvTypes& lit)
void
LitLvTypes::setAllFullLogVars()
{
std::fill (lvTypes_.begin(), lvTypes_.end(), LogVarType::fullLvt);
}
LiftedWCNF::LiftedWCNF (const ParfactorList& pfList) LiftedWCNF::LiftedWCNF (const ParfactorList& pfList)
: freeLiteralId_(0), pfList_(pfList) : freeLiteralId_(0), pfList_(pfList)
{ {
@ -394,7 +412,7 @@ LiftedWCNF::LiftedWCNF (const ParfactorList& pfList)
/* /*
// INCLUSION-EXCLUSION TEST // INCLUSION-EXCLUSION TEST
clauses_.clear(); clauses_.clear();
vector<vector<string>> names = { std::vector<std::vector<string>> names = {
{"a1","b1"},{"a2","b2"} {"a1","b1"},{"a2","b2"}
}; };
Clause* c1 = new Clause (names); Clause* c1 = new Clause (names);
@ -406,7 +424,7 @@ LiftedWCNF::LiftedWCNF (const ParfactorList& pfList)
/* /*
// INDEPENDENT PARTIAL GROUND TEST // INDEPENDENT PARTIAL GROUND TEST
clauses_.clear(); clauses_.clear();
vector<vector<string>> names = { std::vector<std::vector<string>> names = {
{"a1","b1"},{"a2","b2"} {"a1","b1"},{"a2","b2"}
}; };
Clause* c1 = new Clause (names); Clause* c1 = new Clause (names);
@ -422,7 +440,7 @@ LiftedWCNF::LiftedWCNF (const ParfactorList& pfList)
/* /*
// ATOM-COUNTING TEST // ATOM-COUNTING TEST
clauses_.clear(); clauses_.clear();
vector<vector<string>> names = { std::vector<std::vector<string>> names = {
{"p1","p1"},{"p1","p2"},{"p1","p3"}, {"p1","p1"},{"p1","p2"},{"p1","p3"},
{"p2","p1"},{"p2","p2"},{"p2","p3"}, {"p2","p1"},{"p2","p2"},{"p2","p3"},
{"p3","p1"},{"p3","p2"},{"p3","p3"} {"p3","p1"},{"p3","p2"},{"p3","p3"}
@ -438,21 +456,21 @@ LiftedWCNF::LiftedWCNF (const ParfactorList& pfList)
*/ */
if (Globals::verbosity > 1) { if (Globals::verbosity > 1) {
cout << "FORMULA INDICATORS:" << endl; std::cout << "FORMULA INDICATORS:" << std::endl;
printFormulaIndicators(); printFormulaIndicators();
cout << endl; std::cout << std::endl;
cout << "WEIGHTED INDICATORS:" << endl; std::cout << "WEIGHTED INDICATORS:" << std::endl;
printWeights(); printWeights();
cout << endl; std::cout << std::endl;
cout << "CLAUSES:" << endl; std::cout << "CLAUSES:" << std::endl;
printClauses(); printClauses();
cout << endl; std::cout << std::endl;
} }
} }
LiftedWCNF::~LiftedWCNF (void) LiftedWCNF::~LiftedWCNF()
{ {
Clause::deleteClauses (clauses_); Clause::deleteClauses (clauses_);
} }
@ -462,7 +480,7 @@ LiftedWCNF::~LiftedWCNF (void)
void void
LiftedWCNF::addWeight (LiteralId lid, double posW, double negW) LiftedWCNF::addWeight (LiteralId lid, double posW, double negW)
{ {
weights_[lid] = make_pair (posW, negW); weights_[lid] = std::make_pair (posW, negW);
} }
@ -470,8 +488,8 @@ LiftedWCNF::addWeight (LiteralId lid, double posW, double negW)
double double
LiftedWCNF::posWeight (LiteralId lid) const LiftedWCNF::posWeight (LiteralId lid) const
{ {
unordered_map<LiteralId, std::pair<double,double>>::const_iterator it; std::unordered_map<LiteralId, std::pair<double,double>>::const_iterator it
it = weights_.find (lid); = weights_.find (lid);
return it != weights_.end() ? it->second.first : LogAware::one(); return it != weights_.end() ? it->second.first : LogAware::one();
} }
@ -480,14 +498,14 @@ LiftedWCNF::posWeight (LiteralId lid) const
double double
LiftedWCNF::negWeight (LiteralId lid) const LiftedWCNF::negWeight (LiteralId lid) const
{ {
unordered_map<LiteralId, std::pair<double,double>>::const_iterator it; std::unordered_map<LiteralId, std::pair<double,double>>::const_iterator it
it = weights_.find (lid); = weights_.find (lid);
return it != weights_.end() ? it->second.second : LogAware::one(); return it != weights_.end() ? it->second.second : LogAware::one();
} }
vector<LiteralId> std::vector<LiteralId>
LiftedWCNF::prvGroupLiterals (PrvGroup prvGroup) LiftedWCNF::prvGroupLiterals (PrvGroup prvGroup)
{ {
assert (Util::contains (map_, prvGroup)); assert (Util::contains (map_, prvGroup));
@ -536,9 +554,10 @@ LiftedWCNF::addIndicatorClauses (const ParfactorList& pfList)
ConstraintTree tempConstr = (*it)->constr()->projectedCopy( ConstraintTree tempConstr = (*it)->constr()->projectedCopy(
formulas[i].logVars()); formulas[i].logVars());
Clause* clause = new Clause (tempConstr); Clause* clause = new Clause (tempConstr);
vector<LiteralId> lids; std::vector<LiteralId> lids;
for (size_t j = 0; j < formulas[i].range(); j++) { for (size_t j = 0; j < formulas[i].range(); j++) {
clause->addLiteral (Literal (freeLiteralId_, formulas[i].logVars())); clause->addLiteral (Literal (
freeLiteralId_, formulas[i].logVars()));
lids.push_back (freeLiteralId_); lids.push_back (freeLiteralId_);
freeLiteralId_ ++; freeLiteralId_ ++;
} }
@ -568,7 +587,7 @@ LiftedWCNF::addParameterClauses (const ParfactorList& pfList)
ParfactorList::const_iterator it = pfList.begin(); ParfactorList::const_iterator it = pfList.begin();
while (it != pfList.end()) { while (it != pfList.end()) {
Indexer indexer ((*it)->ranges()); Indexer indexer ((*it)->ranges());
vector<PrvGroup> groups = (*it)->getAllGroups(); std::vector<PrvGroup> groups = (*it)->getAllGroups();
while (indexer.valid()) { while (indexer.valid()) {
LiteralId paramVarLid = freeLiteralId_; LiteralId paramVarLid = freeLiteralId_;
// λu1 ∧ ... ∧ λun ∧ λxi <=> θxi|u1,...,un // λu1 ∧ ... ∧ λun ∧ λxi <=> θxi|u1,...,un
@ -606,26 +625,26 @@ LiftedWCNF::addParameterClauses (const ParfactorList& pfList)
void void
LiftedWCNF::printFormulaIndicators (void) const LiftedWCNF::printFormulaIndicators() const
{ {
if (map_.empty()) { if (map_.empty()) {
return; return;
} }
set<PrvGroup> allGroups; std::set<PrvGroup> allGroups;
ParfactorList::const_iterator it = pfList_.begin(); ParfactorList::const_iterator it = pfList_.begin();
while (it != pfList_.end()) { while (it != pfList_.end()) {
const ProbFormulas& formulas = (*it)->arguments(); const ProbFormulas& formulas = (*it)->arguments();
for (size_t i = 0; i < formulas.size(); i++) { for (size_t i = 0; i < formulas.size(); i++) {
if (Util::contains (allGroups, formulas[i].group()) == false) { if (Util::contains (allGroups, formulas[i].group()) == false) {
allGroups.insert (formulas[i].group()); allGroups.insert (formulas[i].group());
cout << formulas[i] << " | " ; std::cout << formulas[i] << " | " ;
ConstraintTree tempCt = (*it)->constr()->projectedCopy ( ConstraintTree tempCt = (*it)->constr()->projectedCopy (
formulas[i].logVars()); formulas[i].logVars());
cout << tempCt.tupleSet(); std::cout << tempCt.tupleSet();
cout << " indicators => " ; std::cout << " indicators => " ;
vector<LiteralId> indicators = std::vector<LiteralId> indicators =
(map_.find (formulas[i].group()))->second; (map_.find (formulas[i].group()))->second;
cout << indicators << endl; std::cout << indicators << std::endl;
} }
} }
++ it; ++ it;
@ -635,14 +654,14 @@ LiftedWCNF::printFormulaIndicators (void) const
void void
LiftedWCNF::printWeights (void) const LiftedWCNF::printWeights() const
{ {
unordered_map<LiteralId, std::pair<double,double>>::const_iterator it; std::unordered_map<LiteralId, std::pair<double,double>>::const_iterator it
it = weights_.begin(); = weights_.begin();
while (it != weights_.end()) { while (it != weights_.end()) {
cout << "λ" << it->first << " weights: " ; std::cout << "λ" << it->first << " weights: " ;
cout << it->second.first << " " << it->second.second; std::cout << it->second.first << " " << it->second.second;
cout << endl; std::cout << std::endl;
++ it; ++ it;
} }
} }
@ -650,8 +669,10 @@ LiftedWCNF::printWeights (void) const
void void
LiftedWCNF::printClauses (void) const LiftedWCNF::printClauses() const
{ {
Clause::printClauses (clauses_); Clause::printClauses (clauses_);
} }
}

View File

@ -1,90 +1,95 @@
#ifndef HORUS_LIFTEDWCNF_H #ifndef YAP_PACKAGES_CLPBN_HORUS_LIFTEDWCNF_H_
#define HORUS_LIFTEDWCNF_H #define YAP_PACKAGES_CLPBN_HORUS_LIFTEDWCNF_H_
#include <vector>
#include <unordered_map> #include <unordered_map>
#include <string>
#include <ostream>
#include "ParfactorList.h" #include "ConstraintTree.h"
#include "ProbFormula.h"
#include "LiftedUtils.h"
using namespace std;
class ConstraintTree; namespace Horus {
enum LogVarType class ParfactorList;
{
FULL_LV, enum class LogVarType {
POS_LV, fullLvt,
NEG_LV posLvt,
negLvt
}; };
typedef long LiteralId; typedef long LiteralId;
typedef vector<LogVarType> LogVarTypes; typedef std::vector<LogVarType> LogVarTypes;
class Literal class Literal {
{
public: public:
Literal (LiteralId lid, const LogVars& lvs) : Literal (LiteralId lid, const LogVars& lvs)
lid_(lid), logVars_(lvs), negated_(false) { } : lid_(lid), logVars_(lvs), negated_(false) { }
Literal (const Literal& lit, bool negated) : Literal (const Literal& lit, bool negated)
lid_(lit.lid_), logVars_(lit.logVars_), negated_(negated) { } : lid_(lit.lid_), logVars_(lit.logVars_), negated_(negated) { }
LiteralId lid (void) const { return lid_; } LiteralId lid() const { return lid_; }
LogVars logVars (void) const { return logVars_; } LogVars logVars() const { return logVars_; }
size_t nrLogVars (void) const { return logVars_.size(); } size_t nrLogVars() const { return logVars_.size(); }
LogVarSet logVarSet (void) const { return LogVarSet (logVars_); } LogVarSet logVarSet() const { return LogVarSet (logVars_); }
void complement (void) { negated_ = !negated_; } void complement() { negated_ = !negated_; }
bool isPositive (void) const { return negated_ == false; } bool isPositive() const { return negated_ == false; }
bool isNegative (void) const { return negated_; } bool isNegative() const { return negated_; }
bool isGround (ConstraintTree constr, LogVarSet ipgLogVars) const; bool isGround (ConstraintTree constr, const LogVarSet& ipgLogVars) const;
size_t indexOfLogVar (LogVar X) const; size_t indexOfLogVar (LogVar X) const;
string toString (LogVarSet ipgLogVars = LogVarSet(), std::string toString (
LogVarSet ipgLogVars = LogVarSet(),
LogVarSet posCountedLvs = LogVarSet(), LogVarSet posCountedLvs = LogVarSet(),
LogVarSet negCountedLvs = LogVarSet()) const; LogVarSet negCountedLvs = LogVarSet()) const;
friend std::ostream& operator<< (std::ostream &os, const Literal& lit);
private: private:
friend std::ostream& operator<< (std::ostream&, const Literal&);
LiteralId lid_; LiteralId lid_;
LogVars logVars_; LogVars logVars_;
bool negated_; bool negated_;
}; };
typedef vector<Literal> Literals; typedef std::vector<Literal> Literals;
class Clause class Clause {
{
public: public:
Clause (const ConstraintTree& ct = ConstraintTree({})) : constr_(ct) { } Clause (const ConstraintTree& ct = ConstraintTree({})) : constr_(ct) { }
Clause (vector<vector<string>> names) : constr_(ConstraintTree (names)) { } Clause (std::vector<std::vector<std::string>> names) :
constr_(ConstraintTree (names)) { }
void addLiteral (const Literal& l) { literals_.push_back (l); } void addLiteral (const Literal& l) { literals_.push_back (l); }
const Literals& literals (void) const { return literals_; } const Literals& literals() const { return literals_; }
Literals& literals (void) { return literals_; } Literals& literals() { return literals_; }
size_t nrLiterals (void) const { return literals_.size(); } size_t nrLiterals() const { return literals_.size(); }
const ConstraintTree& constr (void) const { return constr_; } const ConstraintTree& constr() const { return constr_; }
ConstraintTree constr (void) { return constr_; } ConstraintTree constr() { return constr_; }
bool isUnit (void) const { return literals_.size() == 1; } bool isUnit() const { return literals_.size() == 1; }
LogVarSet ipgLogVars (void) const { return ipgLvs_; } LogVarSet ipgLogVars() const { return ipgLvs_; }
void addIpgLogVar (LogVar X) { ipgLvs_.insert (X); } void addIpgLogVar (LogVar X) { ipgLvs_.insert (X); }
@ -92,13 +97,13 @@ class Clause
void addNegCountedLogVar (LogVar X) { negCountedLvs_.insert (X); } void addNegCountedLogVar (LogVar X) { negCountedLvs_.insert (X); }
LogVarSet posCountedLogVars (void) const { return posCountedLvs_; } LogVarSet posCountedLogVars() const { return posCountedLvs_; }
LogVarSet negCountedLogVars (void) const { return negCountedLvs_; } LogVarSet negCountedLogVars() const { return negCountedLvs_; }
unsigned nrPosCountedLogVars (void) const { return posCountedLvs_.size(); } unsigned nrPosCountedLogVars() const { return posCountedLvs_.size(); }
unsigned nrNegCountedLogVars (void) const { return negCountedLvs_.size(); } unsigned nrNegCountedLogVars() const { return negCountedLvs_.size(); }
void addLiteralComplemented (const Literal& lit); void addLiteralComplemented (const Literal& lit);
@ -122,9 +127,9 @@ class Clause
bool isIpgLogVar (LogVar X) const; bool isIpgLogVar (LogVar X) const;
TinySet<LiteralId> lidSet (void) const; TinySet<LiteralId> lidSet() const;
LogVarSet ipgCandidates (void) const; LogVarSet ipgCandidates() const;
LogVarTypes logVarTypes (size_t litIdx) const; LogVarTypes logVarTypes (size_t litIdx) const;
@ -132,17 +137,18 @@ class Clause
static bool independentClauses (Clause& c1, Clause& c2); static bool independentClauses (Clause& c1, Clause& c2);
static vector<Clause*> copyClauses (const vector<Clause*>& clauses); static std::vector<Clause*> copyClauses (
const std::vector<Clause*>& clauses);
static void printClauses (const vector<Clause*>& clauses); static void printClauses (const std::vector<Clause*>& clauses);
static void deleteClauses (vector<Clause*>& clauses); static void deleteClauses (std::vector<Clause*>& clauses);
friend std::ostream& operator<< (ostream &os, const Clause& clause);
private: private:
LogVarSet getLogVarSetExcluding (size_t idx) const; LogVarSet getLogVarSetExcluding (size_t idx) const;
friend std::ostream& operator<< (std::ostream&, const Clause&);
Literals literals_; Literals literals_;
LogVarSet ipgLvs_; LogVarSet ipgLvs_;
LogVarSet posCountedLvs_; LogVarSet posCountedLvs_;
@ -152,58 +158,57 @@ class Clause
DISALLOW_ASSIGN (Clause); DISALLOW_ASSIGN (Clause);
}; };
typedef vector<Clause*> Clauses; typedef std::vector<Clause*> Clauses;
class LitLvTypes class LitLvTypes {
{
public: public:
struct CompareLitLvTypes LitLvTypes (LiteralId lid, const LogVarTypes& lvTypes) :
lid_(lid), lvTypes_(lvTypes) { }
LiteralId lid() const { return lid_; }
const LogVarTypes& logVarTypes() const { return lvTypes_; }
void setAllFullLogVars();
private:
friend std::ostream& operator<< (std::ostream&, const LitLvTypes&);
LiteralId lid_;
LogVarTypes lvTypes_;
};
struct CmpLitLvTypes
{ {
bool operator() ( bool operator() (
const LitLvTypes& types1, const LitLvTypes& types1,
const LitLvTypes& types2) const const LitLvTypes& types2) const
{ {
if (types1.lid_ < types2.lid_) { if (types1.lid() < types2.lid()) {
return true; return true;
} }
if (types1.lid_ == types2.lid_) { if (types1.lid() == types2.lid()){
return types1.lvTypes_ < types2.lvTypes_; return types1.logVarTypes() < types2.logVarTypes();
} }
return false; return false;
} }
}; };
LitLvTypes (LiteralId lid, const LogVarTypes& lvTypes) : typedef TinySet<LitLvTypes, CmpLitLvTypes> LitLvTypesSet;
lid_(lid), lvTypes_(lvTypes) { }
LiteralId lid (void) const { return lid_; }
const LogVarTypes& logVarTypes (void) const { return lvTypes_; }
void setAllFullLogVars (void) {
std::fill (lvTypes_.begin(), lvTypes_.end(), LogVarType::FULL_LV); }
friend std::ostream& operator<< (std::ostream &os, const LitLvTypes& lit);
private:
LiteralId lid_;
LogVarTypes lvTypes_;
};
typedef TinySet<LitLvTypes,LitLvTypes::CompareLitLvTypes> LitLvTypesSet;
class LiftedWCNF class LiftedWCNF {
{
public: public:
LiftedWCNF (const ParfactorList& pfList); LiftedWCNF (const ParfactorList& pfList);
~LiftedWCNF (void); ~LiftedWCNF();
const Clauses& clauses (void) const { return clauses_; } const Clauses& clauses() const { return clauses_; }
void addWeight (LiteralId lid, double posW, double negW); void addWeight (LiteralId lid, double posW, double negW);
@ -211,15 +216,15 @@ class LiftedWCNF
double negWeight (LiteralId lid) const; double negWeight (LiteralId lid) const;
vector<LiteralId> prvGroupLiterals (PrvGroup prvGroup); std::vector<LiteralId> prvGroupLiterals (PrvGroup prvGroup);
Clause* createClause (LiteralId lid) const; Clause* createClause (LiteralId lid) const;
void printFormulaIndicators (void) const; void printFormulaIndicators() const;
void printWeights (void) const; void printWeights() const;
void printClauses (void) const; void printClauses() const;
private: private:
LiteralId getLiteralId (PrvGroup prvGroup, unsigned range); LiteralId getLiteralId (PrvGroup prvGroup, unsigned range);
@ -231,11 +236,13 @@ class LiftedWCNF
Clauses clauses_; Clauses clauses_;
LiteralId freeLiteralId_; LiteralId freeLiteralId_;
const ParfactorList& pfList_; const ParfactorList& pfList_;
unordered_map<PrvGroup, vector<LiteralId>> map_; std::unordered_map<PrvGroup, std::vector<LiteralId>> map_;
unordered_map<LiteralId, std::pair<double,double>> weights_; std::unordered_map<LiteralId, std::pair<double,double>> weights_;
DISALLOW_COPY_AND_ASSIGN (LiftedWCNF); DISALLOW_COPY_AND_ASSIGN (LiftedWCNF);
}; };
#endif // HORUS_LIFTEDWCNF_H } // namespace Horus
#endif // YAP_PACKAGES_CLPBN_HORUS_LIFTEDWCNF_H_

View File

@ -43,9 +43,9 @@ SO=@SO@
#4.1VPATH=@srcdir@:@srcdir@/OPTYap #4.1VPATH=@srcdir@:@srcdir@/OPTYap
CWD=$(PWD) CWD=$(PWD)
HCLI = $(srcdir)/hcli utestsdir=@srcdir@/unit_tests
HEADERS = \ MAIN_HEADERS = \
$(srcdir)/BayesBall.h \ $(srcdir)/BayesBall.h \
$(srcdir)/BayesBallGraph.h \ $(srcdir)/BayesBallGraph.h \
$(srcdir)/BeliefProp.h \ $(srcdir)/BeliefProp.h \
@ -54,6 +54,8 @@ HEADERS = \
$(srcdir)/ElimGraph.h \ $(srcdir)/ElimGraph.h \
$(srcdir)/Factor.h \ $(srcdir)/Factor.h \
$(srcdir)/FactorGraph.h \ $(srcdir)/FactorGraph.h \
$(srcdir)/GenericFactor.h \
$(srcdir)/GroundSolver.h \
$(srcdir)/Histogram.h \ $(srcdir)/Histogram.h \
$(srcdir)/Horus.h \ $(srcdir)/Horus.h \
$(srcdir)/Indexer.h \ $(srcdir)/Indexer.h \
@ -67,14 +69,20 @@ HEADERS = \
$(srcdir)/Parfactor.h \ $(srcdir)/Parfactor.h \
$(srcdir)/ParfactorList.h \ $(srcdir)/ParfactorList.h \
$(srcdir)/ProbFormula.h \ $(srcdir)/ProbFormula.h \
$(srcdir)/GroundSolver.h \
$(srcdir)/TinySet.h \ $(srcdir)/TinySet.h \
$(srcdir)/Util.h \ $(srcdir)/Util.h \
$(srcdir)/Var.h \ $(srcdir)/Var.h \
$(srcdir)/VarElim.h \ $(srcdir)/VarElim.h \
$(srcdir)/WeightedBp.h $(srcdir)/WeightedBp.h
CPP_SOURCES = \ UTESTS_HEADERS = \
$(utestsdir)/Common.h
HEADERS = \
$(MAIN_HEADERS) \
$(UTESTS_HEADERS)
MAIN_SOURCES = \
$(srcdir)/BayesBall.cpp \ $(srcdir)/BayesBall.cpp \
$(srcdir)/BayesBallGraph.cpp \ $(srcdir)/BayesBallGraph.cpp \
$(srcdir)/BeliefProp.cpp \ $(srcdir)/BeliefProp.cpp \
@ -83,9 +91,12 @@ CPP_SOURCES = \
$(srcdir)/ElimGraph.cpp \ $(srcdir)/ElimGraph.cpp \
$(srcdir)/Factor.cpp \ $(srcdir)/Factor.cpp \
$(srcdir)/FactorGraph.cpp \ $(srcdir)/FactorGraph.cpp \
$(srcdir)/GenericFactor.cpp \
$(srcdir)/GroundSolver.cpp \
$(srcdir)/Histogram.cpp \ $(srcdir)/Histogram.cpp \
$(srcdir)/HorusCli.cpp \ $(srcdir)/HorusCli.cpp \
$(srcdir)/HorusYap.cpp \ $(srcdir)/HorusYap.cpp \
$(srcdir)/Indexer.cpp \
$(srcdir)/LiftedBp.cpp \ $(srcdir)/LiftedBp.cpp \
$(srcdir)/LiftedKc.cpp \ $(srcdir)/LiftedKc.cpp \
$(srcdir)/LiftedOperations.cpp \ $(srcdir)/LiftedOperations.cpp \
@ -95,12 +106,23 @@ CPP_SOURCES = \
$(srcdir)/Parfactor.cpp \ $(srcdir)/Parfactor.cpp \
$(srcdir)/ParfactorList.cpp \ $(srcdir)/ParfactorList.cpp \
$(srcdir)/ProbFormula.cpp \ $(srcdir)/ProbFormula.cpp \
$(srcdir)/GroundSolver.cpp \
$(srcdir)/Util.cpp \ $(srcdir)/Util.cpp \
$(srcdir)/Var.cpp \ $(srcdir)/Var.cpp \
$(srcdir)/VarElim.cpp \ $(srcdir)/VarElim.cpp \
$(srcdir)/WeightedBp.cpp $(srcdir)/WeightedBp.cpp
UTESTS_SOURCES = \
$(utestsdir)/BeliefPropTest.cpp \
$(utestsdir)/Common.cpp \
$(utestsdir)/CountingBpTest.cpp \
$(utestsdir)/FactorTest.cpp \
$(utestsdir)/VarElimTest.cpp \
$(utestsdir)/UnitTesting.cpp
SOURCES = \
$(MAIN_SOURCES) \
$(UTESTS_SOURCES)
OBJS = \ OBJS = \
BayesBall.o \ BayesBall.o \
BayesBallGraph.o \ BayesBallGraph.o \
@ -110,8 +132,10 @@ OBJS = \
ElimGraph.o \ ElimGraph.o \
Factor.o \ Factor.o \
FactorGraph.o \ FactorGraph.o \
GenericFactor.o \
GroundSolver.o \
Histogram.o \ Histogram.o \
HorusYap.o \ Indexer.o \
LiftedBp.o \ LiftedBp.o \
LiftedKc.o \ LiftedKc.o \
LiftedOperations.o \ LiftedOperations.o \
@ -121,12 +145,15 @@ OBJS = \
ProbFormula.o \ ProbFormula.o \
Parfactor.o \ Parfactor.o \
ParfactorList.o \ ParfactorList.o \
GroundSolver.o \
Util.o \ Util.o \
Var.o \ Var.o \
VarElim.o \ VarElim.o \
WeightedBp.o WeightedBp.o
LIB_OBJS = \
$(OBJS) \
HorusYap.o
HCLI_OBJS = \ HCLI_OBJS = \
BayesBall.o \ BayesBall.o \
BayesBallGraph.o \ BayesBallGraph.o \
@ -135,51 +162,82 @@ HCLI_OBJS = \
ElimGraph.o \ ElimGraph.o \
Factor.o \ Factor.o \
FactorGraph.o \ FactorGraph.o \
HorusCli.o \ GenericFactor.o \
GroundSolver.o \ GroundSolver.o \
HorusCli.o \
Indexer.o \
Util.o \ Util.o \
Var.o \ Var.o \
VarElim.o \ VarElim.o \
WeightedBp.o WeightedBp.o
SOBJS=horus.@SO@ UTESTS_OBJS = \
$(OBJS) \
$(utestsdir)/BeliefPropTest.o \
$(utestsdir)/Common.o \
$(utestsdir)/CountingBpTest.o \
$(utestsdir)/FactorTest.o \
$(utestsdir)/VarElimTest.o \
$(utestsdir)/UnitTesting.o
all: $(SOBJS) hcli LIB = $(srcdir)/horus.@SO@
HCLI = $(srcdir)/hcli
UTESTING = $(srcdir)/run_tests
all: $(LIB) $(HCLI)
# Don't require $(UTESTING) by default as we
# don't want a hard dependency on CppUnit
with_tests: $(LIB) $(HCLI) $(UTESTING)
@DO_SECOND_LD@$(LIB): $(LIB_OBJS)
@DO_SECOND_LD@ @SHLIB_CXX_LD@ -o $@ $(LIB_OBJS) @EXTRA_LIBS_FOR_SWIDLLS@
$(HCLI): $(HCLI_OBJS)
$(CXX) -o $@ $(HCLI_OBJS)
$(UTESTING): $(UTESTS_OBJS)
$(CXX) -o $@ $(UTESTS_OBJS) -lcppunit
# default rule # default rule
%.o : $(srcdir)/%.cpp %.o : $(srcdir)/%.cpp
$(CXX) -c $(CXXFLAGS) $< -o $@ $(CXX) -o $@ -c $(CXXFLAGS) $<
@DO_SECOND_LD@horus.@SO@: $(OBJS)
@DO_SECOND_LD@ @SHLIB_CXX_LD@ -o horus.@SO@ $(OBJS) @EXTRA_LIBS_FOR_SWIDLLS@
hcli: $(HCLI_OBJS)
$(CXX) -o $(HCLI) $(HCLI_OBJS)
install: all install: all
$(INSTALL_PROGRAM) $(SOBJS) $(DESTDIR)$(YAPLIBDIR) $(INSTALL_PROGRAM) $(LIB) $(DESTDIR)$(YAPLIBDIR)
$(INSTALL_PROGRAM) $(HCLI) $(DESTDIR)$(BINDIR) $(INSTALL_PROGRAM) $(HCLI) $(DESTDIR)$(BINDIR)
clean: clean:
rm -f *.o *~ $(OBJS) $(SOBJS) $(HCLI) *.BAK rm -f $(LIB) $(HCLI) $(UTESTING) *.o *~ $(utestsdir)/*.o $(utestsdir)/*~
erase_dots: remove_dots:
rm -f *.dot *.png rm -f *.dot *.png *.svg
depend: $(HEADERS) $(CPP_SOURCES) depend: $(SOURCES) $(HEADERS)
-@if test "$(GCC)" = yes; then\ -@if test "$(GCC)" = yes; then\
$(CC) -std=c++0x -MM -MG $(CFLAGS) -I$(srcdir) -I$(srcdir)/../../../../include -I$(srcdir)/../../../../H $(CPP_SOURCES) >> Makefile;\ for F in $(SOURCES); do \
D=`dirname $$F`; \
B=`basename $$F .cpp`; \
$(CXX) $(CXXFLAGS) -MM -MG -MT "$$D/$$B.o" -I$(srcdir)/../../../../H -I$(srcdir)/../../../../include $$F >> Makefile; \
done; \
else\ else\
makedepend -f - -- $(CFLAGS) -I$(srcdir)/../../../../H -I$(srcdir)/../../../../include -- $(CPP_SOURCES) |\ makedepend -- $(CXXFLAGS) -- -I$(srcdir)/../../../../H -I$(srcdir)/../../../../include $(SOURCES); \
sed 's|.*/\([^:]*\):|\1:|' >> Makefile ;\
fi fi
.PHONY: default all install clean remove_dots depend
# DO NOT DELETE THIS LINE -- make depend depends on it. # DO NOT DELETE THIS LINE -- make depend depends on it.

View File

@ -1,3 +1,8 @@
#include <cassert>
#include <iostream>
#include <sstream>
#include "Parfactor.h" #include "Parfactor.h"
#include "Histogram.h" #include "Histogram.h"
#include "Indexer.h" #include "Indexer.h"
@ -5,6 +10,8 @@
#include "Horus.h" #include "Horus.h"
namespace Horus {
Parfactor::Parfactor ( Parfactor::Parfactor (
const ProbFormulas& formulas, const ProbFormulas& formulas,
const Params& params, const Params& params,
@ -84,7 +91,7 @@ Parfactor::Parfactor (const Parfactor& g)
Parfactor::~Parfactor (void) Parfactor::~Parfactor()
{ {
delete constr_; delete constr_;
} }
@ -92,7 +99,7 @@ Parfactor::~Parfactor (void)
LogVarSet LogVarSet
Parfactor::countedLogVars (void) const Parfactor::countedLogVars() const
{ {
LogVarSet set; LogVarSet set;
for (size_t i = 0; i < args_.size(); i++) { for (size_t i = 0; i < args_.size(); i++) {
@ -106,7 +113,7 @@ Parfactor::countedLogVars (void) const
LogVarSet LogVarSet
Parfactor::uncountedLogVars (void) const Parfactor::uncountedLogVars() const
{ {
return constr_->logVarSet() - countedLogVars(); return constr_->logVarSet() - countedLogVars();
} }
@ -114,7 +121,7 @@ Parfactor::uncountedLogVars (void) const
LogVarSet LogVarSet
Parfactor::elimLogVars (void) const Parfactor::elimLogVars() const
{ {
LogVarSet requiredToElim = constr_->logVarSet(); LogVarSet requiredToElim = constr_->logVarSet();
requiredToElim -= constr_->singletons(); requiredToElim -= constr_->singletons();
@ -149,7 +156,7 @@ Parfactor::sumOutIndex (size_t fIdx)
unsigned N = constr_->getConditionalCount ( unsigned N = constr_->getConditionalCount (
args_[fIdx].countedLogVar()); args_[fIdx].countedLogVar());
unsigned R = args_[fIdx].range(); unsigned R = args_[fIdx].range();
vector<double> numAssigns = HistogramSet::getNumAssigns (N, R); std::vector<double> numAssigns = HistogramSet::getNumAssigns (N, R);
Indexer indexer (ranges_, fIdx); Indexer indexer (ranges_, fIdx);
while (indexer.valid()) { while (indexer.valid()) {
if (Globals::logDomain) { if (Globals::logDomain) {
@ -171,7 +178,7 @@ Parfactor::sumOutIndex (size_t fIdx)
} }
constr_->remove (excl); constr_->remove (excl);
TFactor<ProbFormula>::sumOutIndex (fIdx); GenericFactor<ProbFormula>::sumOutIndex (fIdx);
LogAware::pow (params_, exp); LogAware::pow (params_, exp);
} }
@ -181,7 +188,7 @@ void
Parfactor::multiply (Parfactor& g) Parfactor::multiply (Parfactor& g)
{ {
alignAndExponentiate (this, &g); alignAndExponentiate (this, &g);
TFactor<ProbFormula>::multiply (g); GenericFactor<ProbFormula>::multiply (g);
constr_->join (g.constr(), true); constr_->join (g.constr(), true);
simplifyGrounds(); simplifyGrounds();
assert (constr_->isCartesianProduct (countedLogVars())); assert (constr_->isCartesianProduct (countedLogVars()));
@ -224,10 +231,10 @@ Parfactor::countConvert (LogVar X)
unsigned N = constr_->getConditionalCount (X); unsigned N = constr_->getConditionalCount (X);
unsigned R = ranges_[fIdx]; unsigned R = ranges_[fIdx];
unsigned H = HistogramSet::nrHistograms (N, R); unsigned H = HistogramSet::nrHistograms (N, R);
vector<Histogram> histograms = HistogramSet::getHistograms (N, R); std::vector<Histogram> histograms = HistogramSet::getHistograms (N, R);
Indexer indexer (ranges_); Indexer indexer (ranges_);
vector<Params> sumout (params_.size() / R); std::vector<Params> sumout (params_.size() / R);
unsigned count = 0; unsigned count = 0;
while (indexer.valid()) { while (indexer.valid()) {
sumout[count].reserve (R); sumout[count].reserve (R);
@ -279,11 +286,11 @@ Parfactor::expand (LogVar X, LogVar X_new1, LogVar X_new2)
unsigned H1 = HistogramSet::nrHistograms (N1, R); unsigned H1 = HistogramSet::nrHistograms (N1, R);
unsigned H2 = HistogramSet::nrHistograms (N2, R); unsigned H2 = HistogramSet::nrHistograms (N2, R);
vector<Histogram> histograms = HistogramSet::getHistograms (N, R); std::vector<Histogram> histograms = HistogramSet::getHistograms (N, R);
vector<Histogram> histograms1 = HistogramSet::getHistograms (N1, R); std::vector<Histogram> histograms1 = HistogramSet::getHistograms (N1, R);
vector<Histogram> histograms2 = HistogramSet::getHistograms (N2, R); std::vector<Histogram> histograms2 = HistogramSet::getHistograms (N2, R);
vector<unsigned> sumIndexes; std::vector<unsigned> sumIndexes;
sumIndexes.reserve (H1 * H2); sumIndexes.reserve (H1 * H2);
for (unsigned i = 0; i < H1; i++) { for (unsigned i = 0; i < H1; i++) {
for (unsigned j = 0; j < H2; j++) { for (unsigned j = 0; j < H2; j++) {
@ -319,16 +326,16 @@ Parfactor::fullExpand (LogVar X)
unsigned N = constr_->getConditionalCount (X); unsigned N = constr_->getConditionalCount (X);
unsigned R = args_[fIdx].range(); unsigned R = args_[fIdx].range();
vector<Histogram> originHists = HistogramSet::getHistograms (N, R); std::vector<Histogram> originHists = HistogramSet::getHistograms (N, R);
vector<Histogram> expandHists = HistogramSet::getHistograms (1, R); std::vector<Histogram> expandHists = HistogramSet::getHistograms (1, R);
assert (ranges_[fIdx] == originHists.size()); assert (ranges_[fIdx] == originHists.size());
vector<unsigned> sumIndexes; std::vector<unsigned> sumIndexes;
sumIndexes.reserve (N * R); sumIndexes.reserve (N * R);
Ranges expandRanges (N, R); Ranges expandRanges (N, R);
Indexer indexer (expandRanges); Indexer indexer (expandRanges);
while (indexer.valid()) { while (indexer.valid()) {
vector<unsigned> hist (R, 0); std::vector<unsigned> hist (R, 0);
for (unsigned n = 0; n < N; n++) { for (unsigned n = 0; n < N; n++) {
hist += expandHists[indexer[n]]; hist += expandHists[indexer[n]];
} }
@ -384,14 +391,14 @@ Parfactor::absorveEvidence (const ProbFormula& formula, unsigned evidence)
assert (args_[fIdx].isCounting() == false); assert (args_[fIdx].isCounting() == false);
assert (constr_->isCountNormalized (excl)); assert (constr_->isCountNormalized (excl));
LogAware::pow (params_, constr_->getConditionalCount (excl)); LogAware::pow (params_, constr_->getConditionalCount (excl));
TFactor<ProbFormula>::absorveEvidence (formula, evidence); GenericFactor<ProbFormula>::absorveEvidence (formula, evidence);
constr_->remove (excl); constr_->remove (excl);
} }
void void
Parfactor::setNewGroups (void) Parfactor::setNewGroups()
{ {
for (size_t i = 0; i < args_.size(); i++) { for (size_t i = 0; i < args_.size(); i++) {
args_[i].setGroup (ProbFormula::getNewGroup()); args_[i].setGroup (ProbFormula::getNewGroup());
@ -494,7 +501,7 @@ Parfactor::containsGroup (PrvGroup group) const
bool bool
Parfactor::containsGroups (vector<PrvGroup> groups) const Parfactor::containsGroups (std::vector<PrvGroup> groups) const
{ {
for (size_t i = 0; i < groups.size(); i++) { for (size_t i = 0; i < groups.size(); i++) {
if (containsGroup (groups[i]) == false) { if (containsGroup (groups[i]) == false) {
@ -565,10 +572,10 @@ Parfactor::nrFormulasWithGroup (PrvGroup group) const
vector<PrvGroup> std::vector<PrvGroup>
Parfactor::getAllGroups (void) const Parfactor::getAllGroups() const
{ {
vector<PrvGroup> groups (args_.size()); std::vector<PrvGroup> groups (args_.size());
for (size_t i = 0; i < args_.size(); i++) { for (size_t i = 0; i < args_.size(); i++) {
groups[i] = args_[i].group(); groups[i] = args_[i].group();
} }
@ -577,10 +584,10 @@ Parfactor::getAllGroups (void) const
string std::string
Parfactor::getLabel (void) const Parfactor::getLabel() const
{ {
stringstream ss; std::stringstream ss;
ss << "phi(" ; ss << "phi(" ;
for (size_t i = 0; i < args_.size(); i++) { for (size_t i = 0; i < args_.size(); i++) {
if (i != 0) ss << "," ; if (i != 0) ss << "," ;
@ -598,6 +605,8 @@ Parfactor::getLabel (void) const
void void
Parfactor::print (bool printParams) const Parfactor::print (bool printParams) const
{ {
using std::cout;
using std::endl;
cout << "Formulas: " ; cout << "Formulas: " ;
for (size_t i = 0; i < args_.size(); i++) { for (size_t i = 0; i < args_.size(); i++) {
if (i != 0) cout << ", " ; if (i != 0) cout << ", " ;
@ -605,9 +614,10 @@ Parfactor::print (bool printParams) const
} }
cout << endl; cout << endl;
if (args_[0].group() != Util::maxUnsigned()) { if (args_[0].group() != Util::maxUnsigned()) {
vector<string> groups; std::vector<std::string> groups;
for (size_t i = 0; i < args_.size(); i++) { for (size_t i = 0; i < args_.size(); i++) {
groups.push_back (string ("g") + Util::toString (args_[i].group())); groups.push_back (std::string ("g")
+ Util::toString (args_[i].group()));
} }
cout << "Groups: " << groups << endl; cout << "Groups: " << groups << endl;
} }
@ -633,12 +643,12 @@ Parfactor::print (bool printParams) const
void void
Parfactor::printParameters (void) const Parfactor::printParameters() const
{ {
vector<string> jointStrings; std::vector<std::string> jointStrings;
Indexer indexer (ranges_); Indexer indexer (ranges_);
while (indexer.valid()) { while (indexer.valid()) {
stringstream ss; std::stringstream ss;
for (size_t i = 0; i < args_.size(); i++) { for (size_t i = 0; i < args_.size(); i++) {
if (i != 0) ss << ", " ; if (i != 0) ss << ", " ;
if (args_[i].isCounting()) { if (args_[i].isCounting()) {
@ -659,22 +669,22 @@ Parfactor::printParameters (void) const
++ indexer; ++ indexer;
} }
for (size_t i = 0; i < params_.size(); i++) { for (size_t i = 0; i < params_.size(); i++) {
cout << "f(" << jointStrings[i] << ")" ; std::cout << "f(" << jointStrings[i] << ")" ;
cout << " = " << params_[i] << endl; std::cout << " = " << params_[i] << std::endl;
} }
} }
void void
Parfactor::printProjections (void) const Parfactor::printProjections() const
{ {
ConstraintTree copy (*constr_); ConstraintTree copy (*constr_);
LogVarSet Xs = copy.logVarSet(); LogVarSet Xs = copy.logVarSet();
for (size_t i = 0; i < Xs.size(); i++) { for (size_t i = 0; i < Xs.size(); i++) {
cout << "-> projection of " << Xs[i] << ": " ; std::cout << "-> projection of " << Xs[i] << ": " ;
cout << copy.tupleSet ({Xs[i]}) << endl; std::cout << copy.tupleSet ({Xs[i]}) << std::endl;
} }
} }
@ -684,12 +694,12 @@ void
Parfactor::expandPotential ( Parfactor::expandPotential (
size_t fIdx, size_t fIdx,
unsigned newRange, unsigned newRange,
const vector<unsigned>& sumIndexes) const std::vector<unsigned>& sumIndexes)
{ {
ullong newSize = (params_.size() / ranges_[fIdx]) * newRange; ullong newSize = (params_.size() / ranges_[fIdx]) * newRange;
if (newSize > params_.max_size()) { if (newSize > params_.max_size()) {
cerr << "Error: an overflow occurred when performing expansion." ; std::cerr << "Error: an overflow occurred when performing expansion." ;
cerr << endl; std::cerr << std::endl;
exit (EXIT_FAILURE); exit (EXIT_FAILURE);
} }
@ -698,7 +708,7 @@ Parfactor::expandPotential (
params_.reserve (newSize); params_.reserve (newSize);
size_t prod = 1; size_t prod = 1;
vector<size_t> offsets (ranges_.size()); std::vector<size_t> offsets (ranges_.size());
for (size_t i = ranges_.size(); i-- > 0; ) { for (size_t i = ranges_.size(); i-- > 0; ) {
offsets[i] = prod; offsets[i] = prod;
prod *= ranges_[i]; prod *= ranges_[i];
@ -706,7 +716,7 @@ Parfactor::expandPotential (
size_t index = 0; size_t index = 0;
ranges_[fIdx] = newRange; ranges_[fIdx] = newRange;
vector<unsigned> indices (ranges_.size(), 0); std::vector<unsigned> indices (ranges_.size(), 0);
for (size_t k = 0; k < newSize; k++) { for (size_t k = 0; k < newSize; k++) {
assert (index < backup.size()); assert (index < backup.size());
params_.push_back (backup[index]); params_.push_back (backup[index]);
@ -759,7 +769,7 @@ Parfactor::simplifyCountingFormulas (size_t fIdx)
void void
Parfactor::simplifyGrounds (void) Parfactor::simplifyGrounds()
{ {
if (args_.size() == 1) { if (args_.size() == 1) {
return; return;
@ -872,12 +882,12 @@ Parfactor::alignLogicalVars (Parfactor* g1, Parfactor* g2)
std::pair<LogVars, LogVars> res = getAlignLogVars (g1, g2); std::pair<LogVars, LogVars> res = getAlignLogVars (g1, g2);
const LogVars& alignLvs1 = res.first; const LogVars& alignLvs1 = res.first;
const LogVars& alignLvs2 = res.second; const LogVars& alignLvs2 = res.second;
// cout << "ALIGNING :::::::::::::::::" << endl; // std::cout << "ALIGNING :::::::::::::::::" << std::endl;
// g1->print(); // g1->print();
// cout << "AND" << endl; // cout << "AND" << endl;
// g2->print(); // g2->print();
// cout << "-> align lvs1 = " << alignLvs1 << endl; // std::cout << "-> align lvs1 = " << alignLvs1 << std::endl;
// cout << "-> align lvs2 = " << alignLvs2 << endl; // std::cout << "-> align lvs2 = " << alignLvs2 << std::endl;
LogVar freeLogVar (0); LogVar freeLogVar (0);
Substitution theta1, theta2; Substitution theta1, theta2;
for (size_t i = 0; i < alignLvs1.size(); i++) { for (size_t i = 0; i < alignLvs1.size(); i++) {
@ -933,9 +943,11 @@ Parfactor::alignLogicalVars (Parfactor* g1, Parfactor* g2)
} }
} }
// cout << "theta1: " << theta1 << endl; // std::cout << "theta1: " << theta1 << std::endl;
// cout << "theta2: " << theta2 << endl; // std::cout << "theta2: " << theta2 << std::endl;
g1->applySubstitution (theta1); g1->applySubstitution (theta1);
g2->applySubstitution (theta2); g2->applySubstitution (theta2);
} }
} // namespace Horus

View File

@ -1,20 +1,21 @@
#ifndef HORUS_PARFACTOR_H #ifndef YAP_PACKAGES_CLPBN_HORUS_PARFACTOR_H_
#define HORUS_PARFACTOR_H #define YAP_PACKAGES_CLPBN_HORUS_PARFACTOR_H_
#include "Factor.h" #include <vector>
#include <string>
#include "GenericFactor.h"
#include "ProbFormula.h" #include "ProbFormula.h"
#include "ConstraintTree.h" #include "ConstraintTree.h"
#include "LiftedUtils.h" #include "LiftedUtils.h"
#include "Horus.h" #include "Horus.h"
class Parfactor : public TFactor<ProbFormula> namespace Horus {
{
class Parfactor : public GenericFactor<ProbFormula> {
public: public:
Parfactor ( Parfactor (const ProbFormulas&, const Params&, const Tuples&,
const ProbFormulas&,
const Params&,
const Tuples&,
unsigned distId); unsigned distId);
Parfactor (const Parfactor*, const Tuple&); Parfactor (const Parfactor*, const Tuple&);
@ -23,21 +24,21 @@ class Parfactor : public TFactor<ProbFormula>
Parfactor (const Parfactor&); Parfactor (const Parfactor&);
~Parfactor (void); ~Parfactor();
ConstraintTree* constr (void) { return constr_; } ConstraintTree* constr() { return constr_; }
const ConstraintTree* constr (void) const { return constr_; } const ConstraintTree* constr() const { return constr_; }
const LogVars& logVars (void) const { return constr_->logVars(); } const LogVars& logVars() const { return constr_->logVars(); }
const LogVarSet& logVarSet (void) const { return constr_->logVarSet(); } const LogVarSet& logVarSet() const { return constr_->logVarSet(); }
LogVarSet countedLogVars (void) const; LogVarSet countedLogVars() const;
LogVarSet uncountedLogVars (void) const; LogVarSet uncountedLogVars() const;
LogVarSet elimLogVars (void) const; LogVarSet elimLogVars() const;
LogVarSet exclusiveLogVars (size_t fIdx) const; LogVarSet exclusiveLogVars (size_t fIdx) const;
@ -57,7 +58,7 @@ class Parfactor : public TFactor<ProbFormula>
void absorveEvidence (const ProbFormula&, unsigned); void absorveEvidence (const ProbFormula&, unsigned);
void setNewGroups (void); void setNewGroups();
void applySubstitution (const Substitution&); void applySubstitution (const Substitution&);
@ -71,7 +72,7 @@ class Parfactor : public TFactor<ProbFormula>
bool containsGroup (PrvGroup) const; bool containsGroup (PrvGroup) const;
bool containsGroups (vector<PrvGroup>) const; bool containsGroups (std::vector<PrvGroup>) const;
unsigned nrFormulas (LogVar) const; unsigned nrFormulas (LogVar) const;
@ -81,17 +82,17 @@ class Parfactor : public TFactor<ProbFormula>
unsigned nrFormulasWithGroup (PrvGroup) const; unsigned nrFormulasWithGroup (PrvGroup) const;
vector<PrvGroup> getAllGroups (void) const; std::vector<PrvGroup> getAllGroups() const;
void print (bool = false) const; void print (bool = false) const;
void printParameters (void) const; void printParameters() const;
void printProjections (void) const; void printProjections() const;
string getLabel (void) const; std::string getLabel() const;
void simplifyGrounds (void); void simplifyGrounds();
static bool canMultiply (Parfactor*, Parfactor*); static bool canMultiply (Parfactor*, Parfactor*);
@ -104,7 +105,7 @@ class Parfactor : public TFactor<ProbFormula>
Parfactor* g1, Parfactor* g2); Parfactor* g1, Parfactor* g2);
void expandPotential (size_t fIdx, unsigned newRange, void expandPotential (size_t fIdx, unsigned newRange,
const vector<unsigned>& sumIndexes); const std::vector<unsigned>& sumIndexes);
static void alignAndExponentiate (Parfactor*, Parfactor*); static void alignAndExponentiate (Parfactor*, Parfactor*);
@ -115,7 +116,9 @@ class Parfactor : public TFactor<ProbFormula>
DISALLOW_ASSIGN (Parfactor); DISALLOW_ASSIGN (Parfactor);
}; };
typedef vector<Parfactor*> Parfactors; typedef std::vector<Parfactor*> Parfactors;
#endif // HORUS_PARFACTOR_H } // namespace Horus
#endif // YAP_PACKAGES_CLPBN_HORUS_PARFACTOR_H_

View File

@ -1,10 +1,14 @@
#include <cassert> #include <cassert>
#include <queue> #include <queue>
#include <iostream>
#include <sstream>
#include "ParfactorList.h" #include "ParfactorList.h"
namespace Horus {
ParfactorList::ParfactorList (const ParfactorList& pfList) ParfactorList::ParfactorList (const ParfactorList& pfList)
{ {
ParfactorList::const_iterator it = pfList.begin(); ParfactorList::const_iterator it = pfList.begin();
@ -23,7 +27,7 @@ ParfactorList::ParfactorList (const Parfactors& pfs)
ParfactorList::~ParfactorList (void) ParfactorList::~ParfactorList()
{ {
ParfactorList::const_iterator it = pfList_.begin(); ParfactorList::const_iterator it = pfList_.begin();
while (it != pfList_.end()) { while (it != pfList_.end()) {
@ -64,27 +68,27 @@ ParfactorList::addShattered (Parfactor* pf)
list<Parfactor*>::iterator std::list<Parfactor*>::iterator
ParfactorList::insertShattered ( ParfactorList::insertShattered (
list<Parfactor*>::iterator it, std::list<Parfactor*>::iterator it,
Parfactor* pf) Parfactor* pf)
{ {
return pfList_.insert (it, pf);
assert (isAllShattered()); assert (isAllShattered());
return pfList_.insert (it, pf);
} }
list<Parfactor*>::iterator std::list<Parfactor*>::iterator
ParfactorList::remove (list<Parfactor*>::iterator it) ParfactorList::remove (std::list<Parfactor*>::iterator it)
{ {
return pfList_.erase (it); return pfList_.erase (it);
} }
list<Parfactor*>::iterator std::list<Parfactor*>::iterator
ParfactorList::removeAndDelete (list<Parfactor*>::iterator it) ParfactorList::removeAndDelete (std::list<Parfactor*>::iterator it)
{ {
delete *it; delete *it;
return pfList_.erase (it); return pfList_.erase (it);
@ -93,12 +97,12 @@ ParfactorList::removeAndDelete (list<Parfactor*>::iterator it)
bool bool
ParfactorList::isAllShattered (void) const ParfactorList::isAllShattered() const
{ {
if (pfList_.size() <= 1) { if (pfList_.size() <= 1) {
return true; return true;
} }
vector<Parfactor*> pfs (pfList_.begin(), pfList_.end()); Parfactors pfs (pfList_.begin(), pfList_.end());
for (size_t i = 0; i < pfs.size(); i++) { for (size_t i = 0; i < pfs.size(); i++) {
assert (isShattered (pfs[i])); assert (isShattered (pfs[i]));
} }
@ -115,13 +119,25 @@ ParfactorList::isAllShattered (void) const
void void
ParfactorList::print (void) const ParfactorList::print() const
{ {
struct sortByParams {
bool operator() (const Parfactor* pf1, const Parfactor* pf2)
{
if (pf1->params().size() < pf2->params().size()) {
return true;
} else if (pf1->params().size() == pf2->params().size() &&
pf1->params() < pf2->params()) {
return true;
}
return false;
}
};
Parfactors pfVec (pfList_.begin(), pfList_.end()); Parfactors pfVec (pfList_.begin(), pfList_.end());
std::sort (pfVec.begin(), pfVec.end(), sortByParams()); std::sort (pfVec.begin(), pfVec.end(), sortByParams());
for (size_t i = 0; i < pfVec.size(); i++) { for (size_t i = 0; i < pfVec.size(); i++) {
pfVec[i]->print(); pfVec[i]->print();
cout << endl; std::cout << std::endl;
} }
} }
@ -163,8 +179,8 @@ ParfactorList::isShattered (const Parfactor* g) const
formulas[i], *(g->constr()), formulas[i], *(g->constr()),
formulas[j], *(g->constr())) == false) { formulas[j], *(g->constr())) == false) {
g->print(); g->print();
cout << "-> not identical on positions " ; std::cout << "-> not identical on positions " ;
cout << i << " and " << j << endl; std::cout << i << " and " << j << std::endl;
return false; return false;
} }
} else { } else {
@ -172,8 +188,8 @@ ParfactorList::isShattered (const Parfactor* g) const
formulas[i], *(g->constr()), formulas[i], *(g->constr()),
formulas[j], *(g->constr())) == false) { formulas[j], *(g->constr())) == false) {
g->print(); g->print();
cout << "-> not disjoint on positions " ; std::cout << "-> not disjoint on positions " ;
cout << i << " and " << j << endl; std::cout << i << " and " << j << std::endl;
return false; return false;
} }
} }
@ -200,9 +216,10 @@ ParfactorList::isShattered (
fms1[i], *(g1->constr()), fms1[i], *(g1->constr()),
fms2[j], *(g2->constr())) == false) { fms2[j], *(g2->constr())) == false) {
g1->print(); g1->print();
cout << "^" << endl; std::cout << "^" << std::endl;
g2->print(); g2->print();
cout << "-> not identical on group " << fms1[i].group() << endl; std::cout << "-> not identical on group " ;
std::cout << fms1[i].group() << std::endl;
return false; return false;
} }
} else { } else {
@ -210,10 +227,10 @@ ParfactorList::isShattered (
fms1[i], *(g1->constr()), fms1[i], *(g1->constr()),
fms2[j], *(g2->constr())) == false) { fms2[j], *(g2->constr())) == false) {
g1->print(); g1->print();
cout << "^" << endl; std::cout << "^" << std::endl;
g2->print(); g2->print();
cout << "-> not disjoint on groups " << fms1[i].group(); std::cout << "-> not disjoint on groups " << fms1[i].group();
cout << " and " << fms2[j].group() << endl; std::cout << " and " << fms2[j].group() << std::endl;
return false; return false;
} }
} }
@ -227,12 +244,12 @@ ParfactorList::isShattered (
void void
ParfactorList::addToShatteredList (Parfactor* g) ParfactorList::addToShatteredList (Parfactor* g)
{ {
queue<Parfactor*> residuals; std::queue<Parfactor*> residuals;
residuals.push (g); residuals.push (g);
while (residuals.empty() == false) { while (residuals.empty() == false) {
Parfactor* pf = residuals.front(); Parfactor* pf = residuals.front();
bool pfSplitted = false; bool pfSplitted = false;
list<Parfactor*>::iterator pfIter; std::list<Parfactor*>::iterator pfIter;
pfIter = pfList_.begin(); pfIter = pfList_.begin();
while (pfIter != pfList_.end()) { while (pfIter != pfList_.end()) {
std::pair<Parfactors, Parfactors> shattRes; std::pair<Parfactors, Parfactors> shattRes;
@ -269,7 +286,7 @@ Parfactors
ParfactorList::shatterAgainstMySelf (Parfactor* g) ParfactorList::shatterAgainstMySelf (Parfactor* g)
{ {
Parfactors pfs; Parfactors pfs;
queue<Parfactor*> residuals; std::queue<Parfactor*> residuals;
residuals.push (g); residuals.push (g);
bool shattered = true; bool shattered = true;
while (residuals.empty() == false) { while (residuals.empty() == false) {
@ -325,19 +342,22 @@ ParfactorList::shatterAgainstMySelf (
{ {
/* /*
Util::printDashedLine(); Util::printDashedLine();
cout << "-> SHATTERING" << endl; std::cout << "-> SHATTERING" << std::endl;
g->print(); g->print();
cout << "-> ON: " << g->argument (fIdx1) << "|" ; std::cout << "-> ON: " << g->argument (fIdx1) << "|" ;
cout << g->constr()->tupleSet (g->argument (fIdx1).logVars()) << endl; std::cout << g->constr()->tupleSet (g->argument (fIdx1).logVars());
cout << "-> ON: " << g->argument (fIdx2) << "|" ; std::cout << std::endl;
cout << g->constr()->tupleSet (g->argument (fIdx2).logVars()) << endl; std::cout << "-> ON: " << g->argument (fIdx2) << "|" ;
std::cout << g->constr()->tupleSet (g->argument (fIdx2).logVars())
std::cout << std::endl;
Util::printDashedLine(); Util::printDashedLine();
*/ */
ProbFormula& f1 = g->argument (fIdx1); ProbFormula& f1 = g->argument (fIdx1);
ProbFormula& f2 = g->argument (fIdx2); ProbFormula& f2 = g->argument (fIdx2);
if (f1.isAtom()) { if (f1.isAtom()) {
cerr << "Error: a ground occurs twice in the same parfactor." << endl; std::cerr << "Error: a ground occurs twice in the same parfactor." ;
cerr << endl; std::cerr << std::endl;
std::cerr << std::endl;
exit (EXIT_FAILURE); exit (EXIT_FAILURE);
} }
assert (g->constr()->empty() == false); assert (g->constr()->empty() == false);
@ -441,14 +461,14 @@ ParfactorList::shatter (
ProbFormula& f2 = g2->argument (fIdx2); ProbFormula& f2 = g2->argument (fIdx2);
/* /*
Util::printDashedLine(); Util::printDashedLine();
cout << "-> SHATTERING" << endl; std::cout << "-> SHATTERING" << std::endl;
g1->print(); g1->print();
cout << "-> WITH" << endl; std::cout << "-> WITH" << std::endl;
g2->print(); g2->print();
cout << "-> ON: " << f1 << "|" ; std::cout << "-> ON: " << f1 << "|" ;
cout << g1->constr()->tupleSet (f1.logVars()) << endl; std::cout << g1->constr()->tupleSet (f1.logVars()) << std::endl;
cout << "-> ON: " << f2 << "|" ; std::cout << "-> ON: " << f2 << "|" ;
cout << g2->constr()->tupleSet (f2.logVars()) << endl; std::cout << g2->constr()->tupleSet (f2.logVars()) << std::endl;
Util::printDashedLine(); Util::printDashedLine();
*/ */
if (f1.isAtom()) { if (f1.isAtom()) {
@ -486,12 +506,12 @@ ParfactorList::shatter (
assert (commCt1->tupleSet (f1.logVars()) == assert (commCt1->tupleSet (f1.logVars()) ==
commCt2->tupleSet (f2.logVars())); commCt2->tupleSet (f2.logVars()));
// stringstream ss1; ss1 << "" << count << "_A.dot" ; // std::stringstream ss1; ss1 << "" << count << "_A.dot" ;
// stringstream ss2; ss2 << "" << count << "_B.dot" ; // std::stringstream ss2; ss2 << "" << count << "_B.dot" ;
// stringstream ss3; ss3 << "" << count << "_A_comm.dot" ; // std::stringstream ss3; ss3 << "" << count << "_A_comm.dot" ;
// stringstream ss4; ss4 << "" << count << "_A_excl.dot" ; // std::stringstream ss4; ss4 << "" << count << "_A_excl.dot" ;
// stringstream ss5; ss5 << "" << count << "_B_comm.dot" ; // std::stringstream ss5; ss5 << "" << count << "_B_comm.dot" ;
// stringstream ss6; ss6 << "" << count << "_B_excl.dot" ; // std::stringstream ss6; ss6 << "" << count << "_B_excl.dot" ;
// g1->constr()->exportToGraphViz (ss1.str().c_str(), true); // g1->constr()->exportToGraphViz (ss1.str().c_str(), true);
// g2->constr()->exportToGraphViz (ss2.str().c_str(), true); // g2->constr()->exportToGraphViz (ss2.str().c_str(), true);
// commCt1->exportToGraphViz (ss3.str().c_str(), true); // commCt1->exportToGraphViz (ss3.str().c_str(), true);
@ -638,3 +658,5 @@ ParfactorList::disjoint (
return (ts1 & ts2).empty(); return (ts1 & ts2).empty();
} }
} // namespace Horus

View File

@ -1,5 +1,5 @@
#ifndef HORUS_PARFACTORLIST_H #ifndef YAP_PACKAGES_CLPBN_HORUS_PARFACTORLIST_H_
#define HORUS_PARFACTORLIST_H #define YAP_PACKAGES_CLPBN_HORUS_PARFACTORLIST_H_
#include <list> #include <list>
@ -7,39 +7,38 @@
#include "ProbFormula.h" #include "ProbFormula.h"
using namespace std; namespace Horus {
class Parfactor; class Parfactor;
class ParfactorList
{ class ParfactorList {
public: public:
ParfactorList (void) { } ParfactorList() { }
ParfactorList (const ParfactorList&); ParfactorList (const ParfactorList&);
ParfactorList (const Parfactors&); ParfactorList (const Parfactors&);
~ParfactorList (void); ~ParfactorList();
const list<Parfactor*>& parfactors (void) const { return pfList_; } const std::list<Parfactor*>& parfactors() const { return pfList_; }
void clear (void) { pfList_.clear(); } void clear() { pfList_.clear(); }
size_t size (void) const { return pfList_.size(); } size_t size() const { return pfList_.size(); }
typedef std::list<Parfactor*>::iterator iterator; typedef std::list<Parfactor*>::iterator iterator;
iterator begin (void) { return pfList_.begin(); } iterator begin() { return pfList_.begin(); }
iterator end (void) { return pfList_.end(); } iterator end() { return pfList_.end(); }
typedef std::list<Parfactor*>::const_iterator const_iterator; typedef std::list<Parfactor*>::const_iterator const_iterator;
const_iterator begin (void) const { return pfList_.begin(); } const_iterator begin() const { return pfList_.begin(); }
const_iterator end (void) const { return pfList_.end(); } const_iterator end() const { return pfList_.end(); }
void add (Parfactor* pf); void add (Parfactor* pf);
@ -47,16 +46,18 @@ class ParfactorList
void addShattered (Parfactor* pf); void addShattered (Parfactor* pf);
list<Parfactor*>::iterator insertShattered ( std::list<Parfactor*>::iterator insertShattered (
list<Parfactor*>::iterator, Parfactor*); std::list<Parfactor*>::iterator, Parfactor*);
list<Parfactor*>::iterator remove (list<Parfactor*>::iterator); std::list<Parfactor*>::iterator remove (
std::list<Parfactor*>::iterator);
list<Parfactor*>::iterator removeAndDelete (list<Parfactor*>::iterator); std::list<Parfactor*>::iterator removeAndDelete (
std::list<Parfactor*>::iterator);
bool isAllShattered (void) const; bool isAllShattered() const;
void print (void) const; void print() const;
ParfactorList& operator= (const ParfactorList& pfList); ParfactorList& operator= (const ParfactorList& pfList);
@ -101,22 +102,10 @@ class ParfactorList
const ProbFormula&, ConstraintTree, const ProbFormula&, ConstraintTree,
const ProbFormula&, ConstraintTree) const; const ProbFormula&, ConstraintTree) const;
struct sortByParams std::list<Parfactor*> pfList_;
{
inline bool operator() (const Parfactor* pf1, const Parfactor* pf2)
{
if (pf1->params().size() < pf2->params().size()) {
return true;
} else if (pf1->params().size() == pf2->params().size() &&
pf1->params() < pf2->params()) {
return true;
}
return false;
}
}; };
list<Parfactor*> pfList_; } // namespace Horus
};
#endif // HORUS_PARFACTORLIST_H #endif // YAP_PACKAGES_CLPBN_HORUS_PARFACTORLIST_H_

View File

@ -1,6 +1,13 @@
#include <cassert>
#include <iostream>
#include "ProbFormula.h" #include "ProbFormula.h"
namespace Horus {
PrvGroup ProbFormula::freeGroup_ = 0; PrvGroup ProbFormula::freeGroup_ = 0;
@ -38,7 +45,7 @@ ProbFormula::indexOf (LogVar X) const
bool bool
ProbFormula::isAtom (void) const ProbFormula::isAtom() const
{ {
return logVars_.empty(); return logVars_.empty();
} }
@ -46,7 +53,7 @@ ProbFormula::isAtom (void) const
bool bool
ProbFormula::isCounting (void) const ProbFormula::isCounting() const
{ {
return countedLogVar_.valid(); return countedLogVar_.valid();
} }
@ -54,7 +61,7 @@ ProbFormula::isCounting (void) const
LogVar LogVar
ProbFormula::countedLogVar (void) const ProbFormula::countedLogVar() const
{ {
assert (isCounting()); assert (isCounting());
return countedLogVar_; return countedLogVar_;
@ -71,7 +78,7 @@ ProbFormula::setCountedLogVar (LogVar lv)
void void
ProbFormula::clearCountedLogVar (void) ProbFormula::clearCountedLogVar()
{ {
countedLogVar_ = LogVar(); countedLogVar_ = LogVar();
} }
@ -93,15 +100,8 @@ ProbFormula::rename (LogVar oldName, LogVar newName)
bool operator== (const ProbFormula& f1, const ProbFormula& f2) std::ostream&
{ operator<< (std::ostream& os, const ProbFormula& f)
return f1.group_ == f2.group_ &&
f1.logVars_ == f2.logVars_;
}
std::ostream& operator<< (ostream &os, const ProbFormula& f)
{ {
os << f.functor_; os << f.functor_;
if (f.isAtom() == false) { if (f.isAtom() == false) {
@ -122,7 +122,7 @@ std::ostream& operator<< (ostream &os, const ProbFormula& f)
PrvGroup PrvGroup
ProbFormula::getNewGroup (void) ProbFormula::getNewGroup()
{ {
freeGroup_ ++; freeGroup_ ++;
assert (freeGroup_ != std::numeric_limits<PrvGroup>::max()); assert (freeGroup_ != std::numeric_limits<PrvGroup>::max());
@ -131,7 +131,24 @@ ProbFormula::getNewGroup (void)
ostream& operator<< (ostream &os, const ObservedFormula& of) ObservedFormula::ObservedFormula (Symbol f, unsigned a, unsigned ev)
: functor_(f), arity_(a), evidence_(ev), constr_(a)
{
}
ObservedFormula::ObservedFormula (Symbol f, unsigned ev, const Tuple& tuple)
: functor_(f), arity_(tuple.size()), evidence_(ev), constr_(arity_)
{
constr_.addTuple (tuple);
}
std::ostream&
operator<< (std::ostream& os, const ObservedFormula& of)
{ {
os << of.functor_ << "/" << of.arity_; os << of.functor_ << "/" << of.arity_;
os << "|" << of.constr_.tupleSet(); os << "|" << of.constr_.tupleSet();
@ -139,3 +156,5 @@ ostream& operator<< (ostream &os, const ObservedFormula& of)
return os; return os;
} }
} // namespace Horus

View File

@ -1,16 +1,20 @@
#ifndef HORUS_PROBFORMULA_H #ifndef YAP_PACKAGES_CLPBN_HORUS_PROBFORMULA_H_
#define HORUS_PROBFORMULA_H #define YAP_PACKAGES_CLPBN_HORUS_PROBFORMULA_H_
#include <vector>
#include <ostream>
#include <limits> #include <limits>
#include "ConstraintTree.h" #include "ConstraintTree.h"
#include "LiftedUtils.h" #include "LiftedUtils.h"
#include "Horus.h" #include "Horus.h"
namespace Horus {
typedef unsigned long PrvGroup; typedef unsigned long PrvGroup;
class ProbFormula class ProbFormula {
{
public: public:
ProbFormula (Symbol f, const LogVars& lvs, unsigned range) ProbFormula (Symbol f, const LogVars& lvs, unsigned range)
: functor_(f), logVars_(lvs), range_(range), : functor_(f), logVars_(lvs), range_(range),
@ -20,19 +24,19 @@ class ProbFormula
: functor_(f), range_(r), : functor_(f), range_(r),
group_(std::numeric_limits<PrvGroup>::max()) { } group_(std::numeric_limits<PrvGroup>::max()) { }
Symbol functor (void) const { return functor_; } Symbol functor() const { return functor_; }
unsigned arity (void) const { return logVars_.size(); } unsigned arity() const { return logVars_.size(); }
unsigned range (void) const { return range_; } unsigned range() const { return range_; }
LogVars& logVars (void) { return logVars_; } LogVars& logVars() { return logVars_; }
const LogVars& logVars (void) const { return logVars_; } const LogVars& logVars() const { return logVars_; }
LogVarSet logVarSet (void) const { return LogVarSet (logVars_); } LogVarSet logVarSet() const { return LogVarSet (logVars_); }
PrvGroup group (void) const { return group_; } PrvGroup group() const { return group_; }
void setGroup (PrvGroup g) { group_ = g; } void setGroup (PrvGroup g) { group_ = g; }
@ -44,25 +48,28 @@ class ProbFormula
size_t indexOf (LogVar) const; size_t indexOf (LogVar) const;
bool isAtom (void) const; bool isAtom() const;
bool isCounting (void) const; bool isCounting() const;
LogVar countedLogVar (void) const; LogVar countedLogVar() const;
void setCountedLogVar (LogVar); void setCountedLogVar (LogVar);
void clearCountedLogVar (void); void clearCountedLogVar();
void rename (LogVar, LogVar); void rename (LogVar, LogVar);
static PrvGroup getNewGroup (void); static PrvGroup getNewGroup();
friend std::ostream& operator<< (ostream &os, const ProbFormula& f);
friend bool operator== (const ProbFormula& f1, const ProbFormula& f2);
private: private:
friend bool operator== (
const ProbFormula& f1, const ProbFormula& f2);
friend std::ostream& operator<< (
std::ostream&, const ProbFormula&);
Symbol functor_; Symbol functor_;
LogVars logVars_; LogVars logVars_;
unsigned range_; unsigned range_;
@ -71,45 +78,50 @@ class ProbFormula
static PrvGroup freeGroup_; static PrvGroup freeGroup_;
}; };
typedef vector<ProbFormula> ProbFormulas; typedef std::vector<ProbFormula> ProbFormulas;
class ObservedFormula inline bool
operator== (const ProbFormula& f1, const ProbFormula& f2)
{ {
public: return f1.group_ == f2.group_ && f1.logVars_ == f2.logVars_;
ObservedFormula (Symbol f, unsigned a, unsigned ev)
: functor_(f), arity_(a), evidence_(ev), constr_(a) { }
ObservedFormula (Symbol f, unsigned ev, const Tuple& tuple)
: functor_(f), arity_(tuple.size()), evidence_(ev), constr_(arity_)
{
constr_.addTuple (tuple);
} }
Symbol functor (void) const { return functor_; }
unsigned arity (void) const { return arity_; }
unsigned evidence (void) const { return evidence_; } class ObservedFormula {
public:
ObservedFormula (Symbol f, unsigned a, unsigned ev);
ObservedFormula (Symbol f, unsigned ev, const Tuple& tuple);
Symbol functor() const { return functor_; }
unsigned arity() const { return arity_; }
unsigned evidence() const { return evidence_; }
void setEvidence (unsigned ev) { evidence_ = ev; } void setEvidence (unsigned ev) { evidence_ = ev; }
ConstraintTree& constr (void) { return constr_; } ConstraintTree& constr() { return constr_; }
bool isAtom (void) const { return arity_ == 0; } bool isAtom() const { return arity_ == 0; }
void addTuple (const Tuple& tuple) { constr_.addTuple (tuple); } void addTuple (const Tuple& tuple) { constr_.addTuple (tuple); }
friend ostream& operator<< (ostream &os, const ObservedFormula& of);
private: private:
friend std::ostream& operator<< (
std::ostream&, const ObservedFormula&);
Symbol functor_; Symbol functor_;
unsigned arity_; unsigned arity_;
unsigned evidence_; unsigned evidence_;
ConstraintTree constr_; ConstraintTree constr_;
}; };
typedef vector<ObservedFormula> ObservedFormulas; typedef std::vector<ObservedFormula> ObservedFormulas;
#endif // HORUS_PROBFORMULA_H } // namespace Horus
#endif // YAP_PACKAGES_CLPBN_HORUS_PROBFORMULA_H_

View File

@ -1,20 +1,18 @@
#ifndef HORUS_TINYSET_H #ifndef YAP_PACKAGES_CLPBN_HORUS_TINYSET_H_
#define HORUS_TINYSET_H #define YAP_PACKAGES_CLPBN_HORUS_TINYSET_H_
#include <algorithm>
#include <vector> #include <vector>
#include <algorithm>
#include <ostream>
using namespace std;
namespace Horus {
template <typename T, typename Compare = std::less<T>> template <typename T, typename Compare = std::less<T>>
class TinySet class TinySet {
{
public: public:
typedef typename std::vector<T>::iterator iterator;
typedef typename vector<T>::iterator iterator; typedef typename std::vector<T>::const_iterator const_iterator;
typedef typename vector<T>::const_iterator const_iterator;
TinySet (const TinySet& s) TinySet (const TinySet& s)
: vec_(s.vec_), cmp_(s.cmp_) { } : vec_(s.vec_), cmp_(s.cmp_) { }
@ -25,190 +23,72 @@ class TinySet
TinySet (const T& t, const Compare& cmp = Compare()) TinySet (const T& t, const Compare& cmp = Compare())
: vec_(1, t), cmp_(cmp) { } : vec_(1, t), cmp_(cmp) { }
TinySet (const vector<T>& elements, const Compare& cmp = Compare()) TinySet (const std::vector<T>& elements, const Compare& cmp = Compare());
: vec_(elements), cmp_(cmp)
{
std::sort (begin(), end(), cmp_);
iterator it = unique_cmp (begin(), end());
vec_.resize (it - begin());
}
iterator insert (const T& t) iterator insert (const T& t);
{
iterator it = std::lower_bound (begin(), end(), t, cmp_);
if (it == end() || cmp_(t, *it)) {
vec_.insert (it, t);
}
return it;
}
void insert_sorted (const T& t) void insert_sorted (const T& t);
{
vec_.push_back (t);
assert (consistent());
}
void remove (const T& t) void remove (const T& t);
{
iterator it = std::lower_bound (begin(), end(), t, cmp_);
if (it != end()) {
vec_.erase (it);
}
}
const_iterator find (const T& t) const const_iterator find (const T& t) const;
{
const_iterator it = std::lower_bound (begin(), end(), t, cmp_);
return it == end() || cmp_(t, *it) ? end() : it;
}
iterator find (const T& t) iterator find (const T& t);
{
iterator it = std::lower_bound (begin(), end(), t, cmp_);
return it == end() || cmp_(t, *it) ? end() : it;
}
/* set union */ /* set union */
TinySet operator| (const TinySet& s) const TinySet operator| (const TinySet& s) const;
{
TinySet res;
std::set_union (
vec_.begin(), vec_.end(),
s.vec_.begin(), s.vec_.end(),
std::back_inserter (res.vec_),
cmp_);
return res;
}
/* set intersection */ /* set intersection */
TinySet operator& (const TinySet& s) const TinySet operator& (const TinySet& s) const;
{
TinySet res;
std::set_intersection (
vec_.begin(), vec_.end(),
s.vec_.begin(), s.vec_.end(),
std::back_inserter (res.vec_),
cmp_);
return res;
}
/* set difference */ /* set difference */
TinySet operator- (const TinySet& s) const TinySet operator- (const TinySet& s) const;
{
TinySet res;
std::set_difference (
vec_.begin(), vec_.end(),
s.vec_.begin(), s.vec_.end(),
std::back_inserter (res.vec_),
cmp_);
return res;
}
TinySet& operator|= (const TinySet& s) TinySet& operator|= (const TinySet& s);
{
return *this = (*this | s);
}
TinySet& operator&= (const TinySet& s) TinySet& operator&= (const TinySet& s);
{
return *this = (*this & s);
}
TinySet& operator-= (const TinySet& s) TinySet& operator-= (const TinySet& s);
{
return *this = (*this - s);
}
bool contains (const T& t) const bool contains (const T& t) const;
{
return std::binary_search (
vec_.begin(), vec_.end(), t, cmp_);
}
bool contains (const TinySet& s) const bool contains (const TinySet& s) const;
{
return std::includes (
vec_.begin(),
vec_.end(),
s.vec_.begin(),
s.vec_.end(),
cmp_);
}
bool in (const TinySet& s) const bool in (const TinySet& s) const;
{
return std::includes (
s.vec_.begin(),
s.vec_.end(),
vec_.begin(),
vec_.end(),
cmp_);
}
bool intersects (const TinySet& s) const bool intersects (const TinySet& s) const;
{
return (*this & s).size() > 0;
}
const T& operator[] (typename vector<T>::size_type i) const const T& operator[] (typename std::vector<T>::size_type i) const;
{
return vec_[i];
}
T& operator[] (typename vector<T>::size_type i) T& operator[] (typename std::vector<T>::size_type i);
{
return vec_[i];
}
T front (void) const T front() const;
{
return vec_.front();
}
T& front (void) T& front();
{
return vec_.front();
}
T back (void) const T back() const;
{
return vec_.back();
}
T& back (void) T& back();
{
return vec_.back();
}
const vector<T>& elements (void) const const std::vector<T>& elements() const;
{
return vec_;
}
bool empty (void) const bool empty() const;
{
return vec_.empty();
}
typename vector<T>::size_type size (void) const typename std::vector<T>::size_type size() const;
{
return vec_.size();
}
void clear (void) void clear();
{
vec_.clear();
}
void reserve (typename vector<T>::size_type size) void reserve (typename std::vector<T>::size_type size);
{
vec_.reserve (size);
}
iterator begin (void) { return vec_.begin(); } iterator begin() { return vec_.begin(); }
iterator end (void) { return vec_.end(); } iterator end () { return vec_.end(); }
const_iterator begin (void) const { return vec_.begin(); } const_iterator begin() const { return vec_.begin(); }
const_iterator end (void) const { return vec_.end(); } const_iterator end () const { return vec_.end(); }
private:
iterator unique_cmp (iterator first, iterator last);
bool consistent() const;
friend bool operator== (const TinySet& s1, const TinySet& s2) friend bool operator== (const TinySet& s1, const TinySet& s2)
{ {
@ -223,7 +103,7 @@ class TinySet
friend std::ostream& operator<< (std::ostream& out, const TinySet& s) friend std::ostream& operator<< (std::ostream& out, const TinySet& s)
{ {
out << "{" ; out << "{" ;
typename vector<T>::size_type i; typename std::vector<T>::size_type i;
for (i = 0; i < s.size(); i++) { for (i = 0; i < s.size(); i++) {
out << ((i != 0) ? "," : "") << s.vec_[i]; out << ((i != 0) ? "," : "") << s.vec_[i];
} }
@ -231,8 +111,271 @@ class TinySet
return out; return out;
} }
private: std::vector<T> vec_;
iterator unique_cmp (iterator first, iterator last) Compare cmp_;
};
template <typename T, typename C> inline
TinySet<T,C>::TinySet (const std::vector<T>& elements, const C& cmp)
: vec_(elements), cmp_(cmp)
{
std::sort (begin(), end(), cmp_);
iterator it = unique_cmp (begin(), end());
vec_.resize (it - begin());
}
template <typename T, typename C> inline typename TinySet<T,C>::iterator
TinySet<T,C>::insert (const T& t)
{
iterator it = std::lower_bound (begin(), end(), t, cmp_);
if (it == end() || cmp_(t, *it)) {
vec_.insert (it, t);
}
return it;
}
template <typename T, typename C> inline void
TinySet<T,C>::insert_sorted (const T& t)
{
vec_.push_back (t);
assert (consistent());
}
template <typename T, typename C> inline void
TinySet<T,C>::remove (const T& t)
{
iterator it = std::lower_bound (begin(), end(), t, cmp_);
if (it != end()) {
vec_.erase (it);
}
}
template <typename T, typename C> inline typename TinySet<T,C>::const_iterator
TinySet<T,C>::find (const T& t) const
{
const_iterator it = std::lower_bound (begin(), end(), t, cmp_);
return it == end() || cmp_(t, *it) ? end() : it;
}
template <typename T, typename C> inline typename TinySet<T,C>::iterator
TinySet<T,C>::find (const T& t)
{
iterator it = std::lower_bound (begin(), end(), t, cmp_);
return it == end() || cmp_(t, *it) ? end() : it;
}
/* set union */
template <typename T, typename C> inline TinySet<T,C>
TinySet<T,C>::operator| (const TinySet& s) const
{
TinySet res;
std::set_union (
vec_.begin(), vec_.end(),
s.vec_.begin(), s.vec_.end(),
std::back_inserter (res.vec_),
cmp_);
return res;
}
/* set intersection */
template <typename T, typename C> inline TinySet<T,C>
TinySet<T,C>::operator& (const TinySet& s) const
{
TinySet res;
std::set_intersection (
vec_.begin(), vec_.end(),
s.vec_.begin(), s.vec_.end(),
std::back_inserter (res.vec_),
cmp_);
return res;
}
/* set difference */
template <typename T, typename C> inline TinySet<T,C>
TinySet<T,C>::operator- (const TinySet& s) const
{
TinySet res;
std::set_difference (
vec_.begin(), vec_.end(),
s.vec_.begin(), s.vec_.end(),
std::back_inserter (res.vec_),
cmp_);
return res;
}
template <typename T, typename C> inline TinySet<T,C>&
TinySet<T,C>::operator|= (const TinySet& s)
{
return *this = (*this | s);
}
template <typename T, typename C> inline TinySet<T,C>&
TinySet<T,C>::operator&= (const TinySet& s)
{
return *this = (*this & s);
}
template <typename T, typename C> inline TinySet<T,C>&
TinySet<T,C>::operator-= (const TinySet& s)
{
return *this = (*this - s);
}
template <typename T, typename C> inline bool
TinySet<T,C>::contains (const T& t) const
{
return std::binary_search (
vec_.begin(), vec_.end(), t, cmp_);
}
template <typename T, typename C> inline bool
TinySet<T,C>::contains (const TinySet& s) const
{
return std::includes (
vec_.begin(), vec_.end(),
s.vec_.begin(), s.vec_.end(),
cmp_);
}
template <typename T, typename C> inline bool
TinySet<T,C>::in (const TinySet& s) const
{
return std::includes (
s.vec_.begin(), s.vec_.end(),
vec_.begin(), vec_.end(),
cmp_);
}
template <typename T, typename C> inline bool
TinySet<T,C>::intersects (const TinySet& s) const
{
return (*this & s).size() > 0;
}
template <typename T, typename C> inline const T&
TinySet<T,C>::operator[] (typename std::vector<T>::size_type i) const
{
return vec_[i];
}
template <typename T, typename C> inline T&
TinySet<T,C>::operator[] (typename std::vector<T>::size_type i)
{
return vec_[i];
}
template <typename T, typename C> inline T
TinySet<T,C>::front() const
{
return vec_.front();
}
template <typename T, typename C> inline T&
TinySet<T,C>::front()
{
return vec_.front();
}
template <typename T, typename C> inline T
TinySet<T,C>::back() const
{
return vec_.back();
}
template <typename T, typename C> inline T&
TinySet<T,C>::back()
{
return vec_.back();
}
template <typename T, typename C> inline const std::vector<T>&
TinySet<T,C>::elements() const
{
return vec_;
}
template <typename T, typename C> inline bool
TinySet<T,C>::empty() const
{
return vec_.empty();
}
template <typename T, typename C> inline typename std::vector<T>::size_type
TinySet<T,C>::size() const
{
return vec_.size();
}
template <typename T, typename C> inline void
TinySet<T,C>::clear()
{
vec_.clear();
}
template <typename T, typename C> inline void
TinySet<T,C>::reserve (typename std::vector<T>::size_type size)
{
vec_.reserve (size);
}
template <typename T, typename C> typename TinySet<T,C>::iterator
TinySet<T,C>::unique_cmp (iterator first, iterator last)
{ {
if (first == last) { if (first == last) {
return last; return last;
@ -246,9 +389,12 @@ class TinySet
return ++result; return ++result;
} }
bool consistent (void) const
template <typename T, typename C> inline bool
TinySet<T,C>::consistent() const
{ {
typename vector<T>::size_type i; typename std::vector<T>::size_type i;
for (i = 0; i < vec_.size() - 1; i++) { for (i = 0; i < vec_.size() - 1; i++) {
if ( ! cmp_(vec_[i], vec_[i + 1])) { if ( ! cmp_(vec_[i], vec_[i + 1])) {
return false; return false;
@ -257,9 +403,7 @@ class TinySet
return true; return true;
} }
vector<T> vec_; } // namespace Horus
Compare cmp_;
};
#endif // HORUS_TINYSET_H #endif // YAP_PACKAGES_CLPBN_HORUS_TINYSET_H_

View File

@ -1,27 +1,27 @@
#include <fstream>
#include "Util.h" #include "Util.h"
#include "Indexer.h" #include "Indexer.h"
#include "ElimGraph.h" #include "ElimGraph.h"
#include "BeliefProp.h" #include "BeliefProp.h"
namespace Horus {
namespace Globals { namespace Globals {
bool logDomain = false; bool logDomain = false;
unsigned verbosity = 0; unsigned verbosity = 0;
LiftedSolverType liftedSolver = LiftedSolverType::LVE; LiftedSolverType liftedSolver = LiftedSolverType::lveSolver;
GroundSolverType groundSolver = GroundSolverType::VE; GroundSolverType groundSolver = GroundSolverType::veSolver;
}; }
namespace Util { namespace Util {
template <> std::string template <> std::string
toString (const bool& b) toString (const bool& b)
{ {
@ -33,14 +33,14 @@ toString (const bool& b)
unsigned unsigned
stringToUnsigned (string str) stringToUnsigned (std::string str)
{ {
int val; int val;
stringstream ss; std::stringstream ss;
ss << str; ss << str;
ss >> val; ss >> val;
if (val < 0) { if (val < 0) {
cerr << "Error: the number readed is negative." << endl; std::cerr << "Error: the number readed is negative." << std::endl;
exit (EXIT_FAILURE); exit (EXIT_FAILURE);
} }
return static_cast<unsigned> (val); return static_cast<unsigned> (val);
@ -49,10 +49,10 @@ stringToUnsigned (string str)
double double
stringToDouble (string str) stringToDouble (std::string str)
{ {
double val; double val;
stringstream ss; std::stringstream ss;
ss << str; ss << str;
ss >> val; ss >> val;
return val; return val;
@ -117,7 +117,7 @@ size_t
sizeExpected (const Ranges& ranges) sizeExpected (const Ranges& ranges)
{ {
return std::accumulate (ranges.begin(), return std::accumulate (ranges.begin(),
ranges.end(), 1, multiplies<unsigned>()); ranges.end(), 1, std::multiplies<unsigned>());
} }
@ -136,10 +136,10 @@ nrDigits (int num)
bool bool
isInteger (const string& s) isInteger (const std::string& s)
{ {
stringstream ss1 (s); std::stringstream ss1 (s);
stringstream ss2; std::stringstream ss2;
int integer; int integer;
ss1 >> integer; ss1 >> integer;
ss2 << integer; ss2 << integer;
@ -148,10 +148,10 @@ isInteger (const string& s)
string std::string
parametersToString (const Params& v, unsigned precision) parametersToString (const Params& v, unsigned precision)
{ {
stringstream ss; std::stringstream ss;
ss.precision (precision); ss.precision (precision);
ss << "[" ; ss << "[" ;
for (size_t i = 0; i < v.size(); i++) { for (size_t i = 0; i < v.size(); i++) {
@ -164,7 +164,7 @@ parametersToString (const Params& v, unsigned precision)
vector<string> std::vector<std::string>
getStateLines (const Vars& vars) getStateLines (const Vars& vars)
{ {
Ranges ranges; Ranges ranges;
@ -172,9 +172,9 @@ getStateLines (const Vars& vars)
ranges.push_back (vars[i]->range()); ranges.push_back (vars[i]->range());
} }
Indexer indexer (ranges); Indexer indexer (ranges);
vector<string> jointStrings; std::vector<std::string> jointStrings;
while (indexer.valid()) { while (indexer.valid()) {
stringstream ss; std::stringstream ss;
for (size_t i = 0; i < vars.size(); i++) { for (size_t i = 0; i < vars.size(); i++) {
if (i != 0) ss << ", " ; if (i != 0) ss << ", " ;
ss << vars[i]->label() << "=" ; ss << vars[i]->label() << "=" ;
@ -188,34 +188,42 @@ getStateLines (const Vars& vars)
bool invalidValue (string option, string value) bool invalidValue (std::string option, std::string value)
{ {
cerr << "Warning: invalid value `" << value << "' " ; std::cerr << "Warning: invalid value `" << value << "' " ;
cerr << "for `" << option << "'." ; std::cerr << "for `" << option << "'." ;
cerr << endl; std::cerr << std::endl;
return false; return false;
} }
bool bool
setHorusFlag (string option, string value) setHorusFlag (std::string option, std::string value)
{ {
bool returnVal = true; bool returnVal = true;
if (option == "lifted_solver") { if (option == "lifted_solver") {
if (value == "lve") Globals::liftedSolver = LiftedSolverType::LVE; if (value == "lve")
else if (value == "lbp") Globals::liftedSolver = LiftedSolverType::LBP; Globals::liftedSolver = LiftedSolverType::lveSolver;
else if (value == "lkc") Globals::liftedSolver = LiftedSolverType::LKC; else if (value == "lbp")
else returnVal = invalidValue (option, value); Globals::liftedSolver = LiftedSolverType::lbpSolver;
else if (value == "lkc")
Globals::liftedSolver = LiftedSolverType::lkcSolver;
else
returnVal = invalidValue (option, value);
} else if (option == "ground_solver" || option == "solver") { } else if (option == "ground_solver" || option == "solver") {
if (value == "hve") Globals::groundSolver = GroundSolverType::VE; if (value == "hve")
else if (value == "bp") Globals::groundSolver = GroundSolverType::BP; Globals::groundSolver = GroundSolverType::veSolver;
else if (value == "cbp") Globals::groundSolver = GroundSolverType::CBP; else if (value == "bp")
else returnVal = invalidValue (option, value); Globals::groundSolver = GroundSolverType::bpSolver;
else if (value == "cbp")
Globals::groundSolver = GroundSolverType::CbpSolver;
else
returnVal = invalidValue (option, value);
} else if (option == "verbosity") { } else if (option == "verbosity") {
stringstream ss; std::stringstream ss;
ss << value; ss << value;
ss >> Globals::verbosity; ss >> Globals::verbosity;
@ -225,40 +233,42 @@ setHorusFlag (string option, string value)
else returnVal = invalidValue (option, value); else returnVal = invalidValue (option, value);
} else if (option == "hve_elim_heuristic") { } else if (option == "hve_elim_heuristic") {
typedef ElimGraph::ElimHeuristic ElimHeuristic;
if (value == "sequential") if (value == "sequential")
ElimGraph::setElimHeuristic (ElimHeuristic::SEQUENTIAL); ElimGraph::setElimHeuristic (ElimHeuristic::sequentialEh);
else if (value == "min_neighbors") else if (value == "min_neighbors")
ElimGraph::setElimHeuristic (ElimHeuristic::MIN_NEIGHBORS); ElimGraph::setElimHeuristic (ElimHeuristic::minNeighborsEh);
else if (value == "min_weight") else if (value == "min_weight")
ElimGraph::setElimHeuristic (ElimHeuristic::MIN_WEIGHT); ElimGraph::setElimHeuristic (ElimHeuristic::minWeightEh);
else if (value == "min_fill") else if (value == "min_fill")
ElimGraph::setElimHeuristic (ElimHeuristic::MIN_FILL); ElimGraph::setElimHeuristic (ElimHeuristic::minFillEh);
else if (value == "weighted_min_fill") else if (value == "weighted_min_fill")
ElimGraph::setElimHeuristic (ElimHeuristic::WEIGHTED_MIN_FILL); ElimGraph::setElimHeuristic (ElimHeuristic::weightedMinFillEh);
else else
returnVal = invalidValue (option, value); returnVal = invalidValue (option, value);
} else if (option == "bp_msg_schedule") { } else if (option == "bp_msg_schedule") {
typedef BeliefProp::MsgSchedule MsgSchedule;
if (value == "seq_fixed") if (value == "seq_fixed")
BeliefProp::setMsgSchedule (MsgSchedule::SEQ_FIXED); BeliefProp::setMsgSchedule (MsgSchedule::seqFixedSch);
else if (value == "seq_random") else if (value == "seq_random")
BeliefProp::setMsgSchedule (MsgSchedule::SEQ_RANDOM); BeliefProp::setMsgSchedule (MsgSchedule::seqRandomSch);
else if (value == "parallel") else if (value == "parallel")
BeliefProp::setMsgSchedule (MsgSchedule::PARALLEL); BeliefProp::setMsgSchedule (MsgSchedule::parallelSch);
else if (value == "max_residual") else if (value == "max_residual")
BeliefProp::setMsgSchedule (MsgSchedule::MAX_RESIDUAL); BeliefProp::setMsgSchedule (MsgSchedule::maxResidualSch);
else else
returnVal = invalidValue (option, value); returnVal = invalidValue (option, value);
} else if (option == "bp_accuracy") { } else if (option == "bp_accuracy") {
stringstream ss; std::stringstream ss;
double acc; double acc;
ss << value; ss << value;
ss >> acc; ss >> acc;
BeliefProp::setAccuracy (acc); BeliefProp::setAccuracy (acc);
} else if (option == "bp_max_iter") { } else if (option == "bp_max_iter") {
stringstream ss; std::stringstream ss;
unsigned mi; unsigned mi;
ss << value; ss << value;
ss >> mi; ss >> mi;
@ -285,7 +295,7 @@ setHorusFlag (string option, string value)
else returnVal = invalidValue (option, value); else returnVal = invalidValue (option, value);
} else { } else {
cerr << "Warning: invalid option `" << option << "'" << endl; std::cerr << "Warning: invalid option `" << option << "'" << std::endl;
returnVal = false; returnVal = false;
} }
return returnVal; return returnVal;
@ -294,20 +304,20 @@ setHorusFlag (string option, string value)
void void
printHeader (string header, std::ostream& os) printHeader (std::string header, std::ostream& os)
{ {
printAsteriskLine (os); printAsteriskLine (os);
os << header << endl; os << header << std::endl;
printAsteriskLine (os); printAsteriskLine (os);
} }
void void
printSubHeader (string header, std::ostream& os) printSubHeader (std::string header, std::ostream& os)
{ {
printDashedLine (os); printDashedLine (os);
os << header << endl; os << header << std::endl;
printDashedLine (os); printDashedLine (os);
} }
@ -318,7 +328,7 @@ printAsteriskLine (std::ostream& os)
{ {
os << "********************************" ; os << "********************************" ;
os << "********************************" ; os << "********************************" ;
os << endl; os << std::endl;
} }
@ -328,11 +338,10 @@ printDashedLine (std::ostream& os)
{ {
os << "--------------------------------" ; os << "--------------------------------" ;
os << "--------------------------------" ; os << "--------------------------------" ;
os << endl; os << std::endl;
} }
} // namespace Util
}
@ -362,10 +371,10 @@ getL1Distance (const Params& v1, const Params& v2)
double dist = 0.0; double dist = 0.0;
if (Globals::logDomain) { if (Globals::logDomain) {
dist = std::inner_product (v1.begin(), v1.end(), v2.begin(), 0.0, dist = std::inner_product (v1.begin(), v1.end(), v2.begin(), 0.0,
std::plus<double>(), FuncObject::abs_diff_exp<double>()); std::plus<double>(), FuncObj::abs_diff_exp<double>());
} else { } else {
dist = std::inner_product (v1.begin(), v1.end(), v2.begin(), 0.0, dist = std::inner_product (v1.begin(), v1.end(), v2.begin(), 0.0,
std::plus<double>(), FuncObject::abs_diff<double>()); std::plus<double>(), FuncObj::abs_diff<double>());
} }
return dist; return dist;
} }
@ -379,10 +388,10 @@ getMaxNorm (const Params& v1, const Params& v2)
double max = 0.0; double max = 0.0;
if (Globals::logDomain) { if (Globals::logDomain) {
max = std::inner_product (v1.begin(), v1.end(), v2.begin(), 0.0, max = std::inner_product (v1.begin(), v1.end(), v2.begin(), 0.0,
FuncObject::max<double>(), FuncObject::abs_diff_exp<double>()); FuncObj::max<double>(), FuncObj::abs_diff_exp<double>());
} else { } else {
max = std::inner_product (v1.begin(), v1.end(), v2.begin(), 0.0, max = std::inner_product (v1.begin(), v1.end(), v2.begin(), 0.0,
FuncObject::max<double>(), FuncObject::abs_diff<double>()); FuncObj::max<double>(), FuncObj::abs_diff<double>());
} }
return max; return max;
} }
@ -428,5 +437,8 @@ pow (Params& v, double exp)
Globals::logDomain ? v *= exp : v ^= exp; Globals::logDomain ? v *= exp : v ^= exp;
} }
} } // namespace LogAware
} // namespace Horus

View File

@ -1,69 +1,78 @@
#ifndef HORUS_UTIL_H #ifndef YAP_PACKAGES_CLPBN_HORUS_UTIL_H_
#define HORUS_UTIL_H #define YAP_PACKAGES_CLPBN_HORUS_UTIL_H_
#include <cmath> #include <cmath>
#include <cassert> #include <cassert>
#include <algorithm>
#include <limits>
#include <vector> #include <vector>
#include <queue> #include <queue>
#include <set> #include <set>
#include <unordered_map> #include <unordered_map>
#include <algorithm>
#include <limits>
#include <string>
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
#include "Horus.h" #include "Horus.h"
using namespace std; namespace Horus {
namespace { namespace {
const double NEG_INF = -std::numeric_limits<double>::infinity(); const double NEG_INF = -std::numeric_limits<double>::infinity();
};
}
namespace Util { namespace Util {
template <typename T> void addToVector (vector<T>&, const vector<T>&); template <typename T> void
addToVector (std::vector<T>&, const std::vector<T>&);
template <typename T> void addToSet (set<T>&, const vector<T>&); template <typename T> void
addToSet (std::set<T>&, const std::vector<T>&);
template <typename T> void addToQueue (queue<T>&, const vector<T>&); template <typename T> void
addToQueue (std::queue<T>&, const std::vector<T>&);
template <typename T> bool contains (const vector<T>&, const T&); template <typename T> bool
contains (const std::vector<T>&, const T&);
template <typename T> bool contains (const set<T>&, const T&); template <typename T> bool contains
(const std::set<T>&, const T&);
template <typename K, typename V> bool contains ( template <typename K, typename V> bool
const unordered_map<K, V>&, const K&); contains (const std::unordered_map<K, V>&, const K&);
template <typename T> size_t indexOf (const vector<T>&, const T&); template <typename T> size_t
indexOf (const std::vector<T>&, const T&);
template <class Operation> template <class Operation> void
void apply_n_times (Params& v1, const Params& v2, apply_n_times (Params& v1, const Params& v2, unsigned reps, Operation);
unsigned repetitions, Operation);
template <typename T> void log (vector<T>&); template <typename T> void
log (std::vector<T>&);
template <typename T> void exp (vector<T>&); template <typename T> void
exp (std::vector<T>&);
template <typename T> string elementsToString ( template <typename T> std::string
const vector<T>& v, string sep = " "); elementsToString (const std::vector<T>& v, std::string sep = " ");
template <typename T> std::string toString (const T&); template <typename T> std::string
toString (const T&);
template <> std::string toString (const bool&); template <> std::string
toString (const bool&);
double logSum (double, double); double logSum (double, double);
unsigned maxUnsigned (void); unsigned maxUnsigned();
unsigned stringToUnsigned (string); unsigned stringToUnsigned (std::string);
double stringToDouble (string); double stringToDouble (std::string);
double factorial (unsigned); double factorial (unsigned);
@ -75,28 +84,29 @@ size_t sizeExpected (const Ranges&);
unsigned nrDigits (int); unsigned nrDigits (int);
bool isInteger (const string&); bool isInteger (const std::string&);
string parametersToString (const Params&, unsigned = Constants::PRECISION); std::string parametersToString (
const Params&, unsigned = Constants::precision);
vector<string> getStateLines (const Vars&); std::vector<std::string> getStateLines (const Vars&);
bool setHorusFlag (string option, string value); bool setHorusFlag (std::string option, std::string value);
void printHeader (string, std::ostream& os = std::cout); void printHeader (std::string, std::ostream& os = std::cout);
void printSubHeader (string, std::ostream& os = std::cout); void printSubHeader (std::string, std::ostream& os = std::cout);
void printAsteriskLine (std::ostream& os = std::cout); void printAsteriskLine (std::ostream& os = std::cout);
void printDashedLine (std::ostream& os = std::cout); void printDashedLine (std::ostream& os = std::cout);
}; } // namespace Util
template <typename T> void template <typename T> void
Util::addToVector (vector<T>& v, const vector<T>& elements) Util::addToVector (std::vector<T>& v, const std::vector<T>& elements)
{ {
v.insert (v.end(), elements.begin(), elements.end()); v.insert (v.end(), elements.begin(), elements.end());
} }
@ -104,7 +114,7 @@ Util::addToVector (vector<T>& v, const vector<T>& elements)
template <typename T> void template <typename T> void
Util::addToSet (set<T>& s, const vector<T>& elements) Util::addToSet (std::set<T>& s, const std::vector<T>& elements)
{ {
s.insert (elements.begin(), elements.end()); s.insert (elements.begin(), elements.end());
} }
@ -112,7 +122,7 @@ Util::addToSet (set<T>& s, const vector<T>& elements)
template <typename T> void template <typename T> void
Util::addToQueue (queue<T>& q, const vector<T>& elements) Util::addToQueue (std::queue<T>& q, const std::vector<T>& elements)
{ {
for (size_t i = 0; i < elements.size(); i++) { for (size_t i = 0; i < elements.size(); i++) {
q.push (elements[i]); q.push (elements[i]);
@ -122,7 +132,7 @@ Util::addToQueue (queue<T>& q, const vector<T>& elements)
template <typename T> bool template <typename T> bool
Util::contains (const vector<T>& v, const T& e) Util::contains (const std::vector<T>& v, const T& e)
{ {
return std::find (v.begin(), v.end(), e) != v.end(); return std::find (v.begin(), v.end(), e) != v.end();
} }
@ -130,7 +140,7 @@ Util::contains (const vector<T>& v, const T& e)
template <typename T> bool template <typename T> bool
Util::contains (const set<T>& s, const T& e) Util::contains (const std::set<T>& s, const T& e)
{ {
return s.find (e) != s.end(); return s.find (e) != s.end();
} }
@ -138,7 +148,7 @@ Util::contains (const set<T>& s, const T& e)
template <typename K, typename V> bool template <typename K, typename V> bool
Util::contains (const unordered_map<K, V>& m, const K& k) Util::contains (const std::unordered_map<K, V>& m, const K& k)
{ {
return m.find (k) != m.end(); return m.find (k) != m.end();
} }
@ -146,7 +156,7 @@ Util::contains (const unordered_map<K, V>& m, const K& k)
template <typename T> size_t template <typename T> size_t
Util::indexOf (const vector<T>& v, const T& e) Util::indexOf (const std::vector<T>& v, const T& e)
{ {
return std::distance (v.begin(), return std::distance (v.begin(),
std::find (v.begin(), v.end(), e)); std::find (v.begin(), v.end(), e));
@ -155,7 +165,10 @@ Util::indexOf (const vector<T>& v, const T& e)
template <class Operation> void template <class Operation> void
Util::apply_n_times (Params& v1, const Params& v2, unsigned repetitions, Util::apply_n_times (
Params& v1,
const Params& v2,
unsigned repetitions,
Operation unary_op) Operation unary_op)
{ {
Params::iterator first = v1.begin(); Params::iterator first = v1.begin();
@ -174,7 +187,7 @@ Util::apply_n_times (Params& v1, const Params& v2, unsigned repetitions,
template <typename T> void template <typename T> void
Util::log (vector<T>& v) Util::log (std::vector<T>& v)
{ {
std::transform (v.begin(), v.end(), v.begin(), ::log); std::transform (v.begin(), v.end(), v.begin(), ::log);
} }
@ -182,17 +195,17 @@ Util::log (vector<T>& v)
template <typename T> void template <typename T> void
Util::exp (vector<T>& v) Util::exp (std::vector<T>& v)
{ {
std::transform (v.begin(), v.end(), v.begin(), ::exp); std::transform (v.begin(), v.end(), v.begin(), ::exp);
} }
template <typename T> string template <typename T> std::string
Util::elementsToString (const vector<T>& v, string sep) Util::elementsToString (const std::vector<T>& v, std::string sep)
{ {
stringstream ss; std::stringstream ss;
for (size_t i = 0; i < v.size(); i++) { for (size_t i = 0; i < v.size(); i++) {
ss << ((i != 0) ? sep : "") << v[i]; ss << ((i != 0) ? sep : "") << v[i];
} }
@ -245,7 +258,7 @@ Util::logSum (double x, double y)
inline unsigned inline unsigned
Util::maxUnsigned (void) Util::maxUnsigned()
{ {
return std::numeric_limits<unsigned>::max(); return std::numeric_limits<unsigned>::max();
} }
@ -277,106 +290,106 @@ void pow (Params&, unsigned);
void pow (Params&, double); void pow (Params&, double);
}; } // namespace LogAware
template <typename T> template <typename T> void
void operator+=(std::vector<T>& v, double val) operator+=(std::vector<T>& v, double val)
{ {
std::transform (v.begin(), v.end(), v.begin(), std::transform (v.begin(), v.end(), v.begin(),
std::bind2nd (plus<double>(), val)); std::bind2nd (std::plus<double>(), val));
} }
template <typename T> template <typename T> void
void operator-=(std::vector<T>& v, double val) operator-=(std::vector<T>& v, double val)
{ {
std::transform (v.begin(), v.end(), v.begin(), std::transform (v.begin(), v.end(), v.begin(),
std::bind2nd (minus<double>(), val)); std::bind2nd (std::minus<double>(), val));
} }
template <typename T> template <typename T> void
void operator*=(std::vector<T>& v, double val) operator*=(std::vector<T>& v, double val)
{ {
std::transform (v.begin(), v.end(), v.begin(), std::transform (v.begin(), v.end(), v.begin(),
std::bind2nd (multiplies<double>(), val)); std::bind2nd (std::multiplies<double>(), val));
} }
template <typename T> template <typename T> void
void operator/=(std::vector<T>& v, double val) operator/=(std::vector<T>& v, double val)
{ {
std::transform (v.begin(), v.end(), v.begin(), std::transform (v.begin(), v.end(), v.begin(),
std::bind2nd (divides<double>(), val)); std::bind2nd (std::divides<double>(), val));
} }
template <typename T> template <typename T> void
void operator+=(std::vector<T>& a, const std::vector<T>& b) operator+=(std::vector<T>& a, const std::vector<T>& b)
{ {
assert (a.size() == b.size()); assert (a.size() == b.size());
std::transform (a.begin(), a.end(), b.begin(), a.begin(), std::transform (a.begin(), a.end(), b.begin(), a.begin(),
plus<double>()); std::plus<double>());
} }
template <typename T> template <typename T> void
void operator-=(std::vector<T>& a, const std::vector<T>& b) operator-=(std::vector<T>& a, const std::vector<T>& b)
{ {
assert (a.size() == b.size()); assert (a.size() == b.size());
std::transform (a.begin(), a.end(), b.begin(), a.begin(), std::transform (a.begin(), a.end(), b.begin(), a.begin(),
minus<double>()); std::minus<double>());
} }
template <typename T> template <typename T> void
void operator*=(std::vector<T>& a, const std::vector<T>& b) operator*=(std::vector<T>& a, const std::vector<T>& b)
{ {
assert (a.size() == b.size()); assert (a.size() == b.size());
std::transform (a.begin(), a.end(), b.begin(), a.begin(), std::transform (a.begin(), a.end(), b.begin(), a.begin(),
multiplies<double>()); std::multiplies<double>());
} }
template <typename T> template <typename T> void
void operator/=(std::vector<T>& a, const std::vector<T>& b) operator/=(std::vector<T>& a, const std::vector<T>& b)
{ {
assert (a.size() == b.size()); assert (a.size() == b.size());
std::transform (a.begin(), a.end(), b.begin(), a.begin(), std::transform (a.begin(), a.end(), b.begin(), a.begin(),
divides<double>()); std::divides<double>());
} }
template <typename T> template <typename T> void
void operator^=(std::vector<T>& v, double exp) operator^=(std::vector<T>& v, double exp)
{ {
std::transform (v.begin(), v.end(), v.begin(), std::transform (v.begin(), v.end(), v.begin(),
std::bind2nd (ptr_fun<double, double, double> (std::pow), exp)); std::bind2nd (std::ptr_fun<double, double, double> (std::pow), exp));
} }
template <typename T> template <typename T> void
void operator^=(std::vector<T>& v, int iexp) operator^=(std::vector<T>& v, int iexp)
{ {
std::transform (v.begin(), v.end(), v.begin(), std::transform (v.begin(), v.end(), v.begin(),
std::bind2nd (ptr_fun<double, int, double> (std::pow), iexp)); std::bind2nd (std::ptr_fun<double, int, double> (std::pow), iexp));
} }
template <typename T> template <typename T> std::ostream&
std::ostream& operator<< (std::ostream& os, const vector<T>& v) operator<< (std::ostream& os, const std::vector<T>& v)
{ {
os << "[" ; os << "[" ;
os << Util::elementsToString (v, ", "); os << Util::elementsToString (v, ", ");
@ -385,40 +398,33 @@ std::ostream& operator<< (std::ostream& os, const vector<T>& v)
} }
namespace FuncObject { namespace FuncObj {
template<typename T> template<typename T>
struct max : public std::binary_function<T, T, T> struct max : public std::binary_function<T, T, T> {
{ T operator() (const T& x, const T& y) const {
T operator() (const T& x, const T& y) const
{
return x < y ? y : x; return x < y ? y : x;
} }};
};
template <typename T> template <typename T>
struct abs_diff : public std::binary_function<T, T, T> struct abs_diff : public std::binary_function<T, T, T> {
{ T operator() (const T& x, const T& y) const {
T operator() (const T& x, const T& y) const
{
return std::abs (x - y); return std::abs (x - y);
} }};
};
template <typename T> template <typename T>
struct abs_diff_exp : public std::binary_function<T, T, T> struct abs_diff_exp : public std::binary_function<T, T, T> {
{ T operator() (const T& x, const T& y) const {
T operator() (const T& x, const T& y) const
{
return std::abs (std::exp (x) - std::exp (y)); return std::abs (std::exp (x) - std::exp (y));
} }};
};
} } // namespace FuncObj
#endif // HORUS_UTIL_H } // namespace Horus
#endif // YAP_PACKAGES_CLPBN_HORUS_UTIL_H_

View File

@ -3,7 +3,9 @@
#include "Var.h" #include "Var.h"
unordered_map<VarId, VarInfo> Var::varsInfo_; namespace Horus {
std::unordered_map<VarId, Var::VarInfo> Var::varsInfo_;
Var::Var (const Var* v) Var::Var (const Var* v)
@ -45,13 +47,14 @@ Var::setEvidence (int evidence)
string std::string
Var::label (void) const Var::label() const
{ {
if (Var::varsHaveInfo()) { if (Var::varsHaveInfo()) {
return Var::getVarInfo (varId_).label; assert (Util::contains (varsInfo_, varId_));
return varsInfo_.find (varId_)->second.first;
} }
stringstream ss; std::stringstream ss;
ss << "x" << varId_; ss << "x" << varId_;
return ss.str(); return ss.str();
} }
@ -59,17 +62,46 @@ Var::label (void) const
States States
Var::states (void) const Var::states() const
{ {
if (Var::varsHaveInfo()) { if (Var::varsHaveInfo()) {
return Var::getVarInfo (varId_).states; assert (Util::contains (varsInfo_, varId_));
return varsInfo_.find (varId_)->second.second;
} }
States states; States states;
for (unsigned i = 0; i < range_; i++) { for (unsigned i = 0; i < range_; i++) {
stringstream ss; std::stringstream ss;
ss << i ; ss << i ;
states.push_back (ss.str()); states.push_back (ss.str());
} }
return states; return states;
} }
void
Var::addVarInfo (
VarId vid, std::string label, const States& states)
{
assert (Util::contains (varsInfo_, vid) == false);
varsInfo_.insert (std::make_pair (vid, VarInfo (label, states)));
}
bool
Var::varsHaveInfo()
{
return varsInfo_.empty() == false;
}
void
Var::clearVarsInfo()
{
varsInfo_.clear();
}
} // namespace Horus

View File

@ -1,102 +1,105 @@
#ifndef HORUS_VAR_H #ifndef YAP_PACKAGES_CLPBN_HORUS_VAR_H_
#define HORUS_VAR_H #define YAP_PACKAGES_CLPBN_HORUS_VAR_H_
#include <cassert> #include <cassert>
#include <unordered_map>
#include <string>
#include "Util.h" #include "Util.h"
#include "Horus.h" #include "Horus.h"
using namespace std; namespace Horus {
class Var {
struct VarInfo
{
VarInfo (string l, const States& sts)
: label(l), states(sts) { }
string label;
States states;
};
class Var
{
public: public:
Var (const Var*); Var (const Var*);
Var (VarId, unsigned, int = Constants::NO_EVIDENCE); Var (VarId, unsigned range, int evidence = Constants::unobserved);
virtual ~Var (void) { }; virtual ~Var() { };
VarId varId (void) const { return varId_; } VarId varId() const { return varId_; }
unsigned range (void) const { return range_; } unsigned range() const { return range_; }
int getEvidence (void) const { return evidence_; } int getEvidence() const { return evidence_; }
size_t getIndex (void) const { return index_; } size_t getIndex() const { return index_; }
void setIndex (size_t idx) { index_ = idx; } void setIndex (size_t idx) { index_ = idx; }
bool hasEvidence (void) const bool hasEvidence() const;
{
return evidence_ != Constants::NO_EVIDENCE;
}
operator size_t (void) const { return index_; } operator size_t() const;
bool operator== (const Var& var) const bool operator== (const Var& var) const;
{
assert (!(varId_ == var.varId() && range_ != var.range()));
return varId_ == var.varId();
}
bool operator!= (const Var& var) const bool operator!= (const Var& var) const;
{
return !(*this == var);
}
bool isValidState (int); bool isValidState (int);
void setEvidence (int); void setEvidence (int);
string label (void) const; std::string label() const;
States states (void) const; States states() const;
static void addVarInfo ( static void addVarInfo (
VarId vid, string label, const States& states) VarId vid, std::string label, const States& states);
{
assert (Util::contains (varsInfo_, vid) == false);
varsInfo_.insert (make_pair (vid, VarInfo (label, states)));
}
static VarInfo getVarInfo (VarId vid) static bool varsHaveInfo();
{
assert (Util::contains (varsInfo_, vid));
return varsInfo_.find (vid)->second;
}
static bool varsHaveInfo (void) static void clearVarsInfo();
{
return varsInfo_.empty() == false;
}
static void clearVarsInfo (void)
{
varsInfo_.clear();
}
private: private:
typedef std::pair<std::string, States> VarInfo;
VarId varId_; VarId varId_;
unsigned range_; unsigned range_;
int evidence_; int evidence_;
size_t index_; size_t index_;
static unordered_map<VarId, VarInfo> varsInfo_; static std::unordered_map<VarId, VarInfo> varsInfo_;
DISALLOW_COPY_AND_ASSIGN(Var);
}; };
#endif // HORUS_VAR_H
inline bool
Var::hasEvidence() const
{
return evidence_ != Constants::unobserved;
}
inline
Var::operator size_t() const
{
return index_;
}
inline bool
Var::operator== (const Var& var) const
{
assert (!(varId_ == var.varId() && range_ != var.range()));
return varId_ == var.varId();
}
inline bool
Var::operator!= (const Var& var) const
{
return !(*this == var);
}
} // namespace Horus
#endif // YAP_PACKAGES_CLPBN_HORUS_VAR_H_

View File

@ -1,4 +1,6 @@
#include <algorithm> #include <algorithm>
#include <iostream>
#include <sstream>
#include "VarElim.h" #include "VarElim.h"
#include "ElimGraph.h" #include "ElimGraph.h"
@ -6,16 +8,18 @@
#include "Util.h" #include "Util.h"
namespace Horus {
Params Params
VarElim::solveQuery (VarIds queryVids) VarElim::solveQuery (VarIds queryVids)
{ {
if (Globals::verbosity > 1) { if (Globals::verbosity > 1) {
cout << "Solving query on " ; std::cout << "Solving query on " ;
for (size_t i = 0; i < queryVids.size(); i++) { for (size_t i = 0; i < queryVids.size(); i++) {
if (i != 0) cout << ", " ; if (i != 0) std::cout << ", " ;
cout << fg.getVarNode (queryVids[i])->label(); std::cout << fg.getVarNode (queryVids[i])->label();
} }
cout << endl; std::cout << std::endl;
} }
totalFactorSize_ = 0; totalFactorSize_ = 0;
largestFactorSize_ = 0; largestFactorSize_ = 0;
@ -33,27 +37,28 @@ VarElim::solveQuery (VarIds queryVids)
void void
VarElim::printSolverFlags (void) const VarElim::printSolverFlags() const
{ {
stringstream ss; std::stringstream ss;
ss << "variable elimination [" ; ss << "variable elimination [" ;
ss << "elim_heuristic=" ; ss << "elim_heuristic=" ;
typedef ElimGraph::ElimHeuristic ElimHeuristic;
switch (ElimGraph::elimHeuristic()) { switch (ElimGraph::elimHeuristic()) {
case ElimHeuristic::SEQUENTIAL: ss << "sequential"; break; case ElimHeuristic::sequentialEh: ss << "sequential"; break;
case ElimHeuristic::MIN_NEIGHBORS: ss << "min_neighbors"; break; case ElimHeuristic::minNeighborsEh: ss << "min_neighbors"; break;
case ElimHeuristic::MIN_WEIGHT: ss << "min_weight"; break; case ElimHeuristic::minWeightEh: ss << "min_weight"; break;
case ElimHeuristic::MIN_FILL: ss << "min_fill"; break; case ElimHeuristic::minFillEh: ss << "min_fill"; break;
case ElimHeuristic::WEIGHTED_MIN_FILL: ss << "weighted_min_fill"; break; case ElimHeuristic::weightedMinFillEh: ss << "weighted_min_fill"; break;
} }
ss << ",log_domain=" << Util::toString (Globals::logDomain); ss << ",log_domain=" << Util::toString (Globals::logDomain);
ss << "]" ; ss << "]" ;
cout << ss.str() << endl; std::cout << ss.str() << std::endl;
} }
void void
VarElim::createFactorList (void) VarElim::createFactorList()
{ {
const FacNodes& facNodes = fg.facNodes(); const FacNodes& facNodes = fg.facNodes();
factorList_.reserve (facNodes.size() * 2); factorList_.reserve (facNodes.size() * 2);
@ -61,7 +66,7 @@ VarElim::createFactorList (void)
factorList_.push_back (new Factor (facNodes[i]->factor())); factorList_.push_back (new Factor (facNodes[i]->factor()));
const VarIds& args = facNodes[i]->factor().arguments(); const VarIds& args = facNodes[i]->factor().arguments();
for (size_t j = 0; j < args.size(); j++) { for (size_t j = 0; j < args.size(); j++) {
unordered_map<VarId, vector<size_t>>::iterator it; std::unordered_map<VarId, std::vector<size_t>>::iterator it;
it = varMap_.find (args[j]); it = varMap_.find (args[j]);
if (it != varMap_.end()) { if (it != varMap_.end()) {
it->second.push_back (i); it->second.push_back (i);
@ -75,22 +80,22 @@ VarElim::createFactorList (void)
void void
VarElim::absorveEvidence (void) VarElim::absorveEvidence()
{ {
if (Globals::verbosity > 2) { if (Globals::verbosity > 2) {
Util::printDashedLine(); Util::printDashedLine();
cout << "(initial factor list)" << endl; std::cout << "(initial factor list)" << std::endl;
printActiveFactors(); printActiveFactors();
} }
const VarNodes& varNodes = fg.varNodes(); const VarNodes& varNodes = fg.varNodes();
for (size_t i = 0; i < varNodes.size(); i++) { for (size_t i = 0; i < varNodes.size(); i++) {
if (varNodes[i]->hasEvidence()) { if (varNodes[i]->hasEvidence()) {
if (Globals::verbosity > 1) { if (Globals::verbosity > 1) {
cout << "-> aborving evidence on "; std::cout << "-> aborving evidence on ";
cout << varNodes[i]->label() << " = " ; std::cout << varNodes[i]->label() << " = " ;
cout << varNodes[i]->getEvidence() << endl; std::cout << varNodes[i]->getEvidence() << std::endl;
} }
const vector<size_t>& indices = varMap_[varNodes[i]->varId()]; const std::vector<size_t>& indices = varMap_[varNodes[i]->varId()];
for (size_t j = 0; j < indices.size(); j++) { for (size_t j = 0; j < indices.size(); j++) {
size_t idx = indices[j]; size_t idx = indices[j];
if (factorList_[idx]->nrArguments() > 1) { if (factorList_[idx]->nrArguments() > 1) {
@ -118,8 +123,8 @@ VarElim::processFactorList (const VarIds& queryVids)
Util::printDashedLine(); Util::printDashedLine();
printActiveFactors(); printActiveFactors();
} }
cout << "-> summing out " ; std::cout << "-> summing out " ;
cout << fg.getVarNode (elimOrder[i])->label() << endl; std::cout << fg.getVarNode (elimOrder[i])->label() << std::endl;
} }
eliminate (elimOrder[i]); eliminate (elimOrder[i]);
} }
@ -143,9 +148,9 @@ VarElim::processFactorList (const VarIds& queryVids)
result.reorderArguments (unobservedVids); result.reorderArguments (unobservedVids);
result.normalize(); result.normalize();
if (Globals::verbosity > 0) { if (Globals::verbosity > 0) {
cout << "total factor size: " << totalFactorSize_ << endl; std::cout << "total factor size: " << totalFactorSize_ << std::endl;
cout << "largest factor size: " << largestFactorSize_ << endl; std::cout << "largest factor size: " << largestFactorSize_ << std::endl;
cout << endl; std::cout << std::endl;
} }
return result.params(); return result.params();
} }
@ -156,7 +161,7 @@ void
VarElim::eliminate (VarId vid) VarElim::eliminate (VarId vid)
{ {
Factor* result = new Factor(); Factor* result = new Factor();
const vector<size_t>& indices = varMap_[vid]; const std::vector<size_t>& indices = varMap_[vid];
for (size_t i = 0; i < indices.size(); i++) { for (size_t i = 0; i < indices.size(); i++) {
size_t idx = indices[i]; size_t idx = indices[i];
if (factorList_[idx]) { if (factorList_[idx]) {
@ -173,7 +178,7 @@ VarElim::eliminate (VarId vid)
result->sumOut (vid); result->sumOut (vid);
const VarIds& args = result->arguments(); const VarIds& args = result->arguments();
for (size_t i = 0; i < args.size(); i++) { for (size_t i = 0; i < args.size(); i++) {
vector<size_t>& indices2 = varMap_[args[i]]; std::vector<size_t>& indices2 = varMap_[args[i]];
indices2.push_back (factorList_.size()); indices2.push_back (factorList_.size());
} }
factorList_.push_back (result); factorList_.push_back (result);
@ -185,14 +190,16 @@ VarElim::eliminate (VarId vid)
void void
VarElim::printActiveFactors (void) VarElim::printActiveFactors()
{ {
for (size_t i = 0; i < factorList_.size(); i++) { for (size_t i = 0; i < factorList_.size(); i++) {
if (factorList_[i]) { if (factorList_[i]) {
cout << factorList_[i]->getLabel() << " " ; std::cout << factorList_[i]->getLabel() << " " ;
cout << factorList_[i]->params(); std::cout << factorList_[i]->params();
cout << endl; std::cout << std::endl;
} }
} }
} }
} // namespace Horus

View File

@ -1,45 +1,46 @@
#ifndef HORUS_VARELIM_H #ifndef YAP_PACKAGES_CLPBN_HORUS_VARELIM_H_
#define HORUS_VARELIM_H #define YAP_PACKAGES_CLPBN_HORUS_VARELIM_H_
#include "unordered_map" #include <vector>
#include <unordered_map>
#include "GroundSolver.h" #include "GroundSolver.h"
#include "FactorGraph.h" #include "FactorGraph.h"
#include "Horus.h" #include "Horus.h"
using namespace std; namespace Horus {
class VarElim : public GroundSolver {
class VarElim : public GroundSolver
{
public: public:
VarElim (const FactorGraph& fg) : GroundSolver (fg) { } VarElim (const FactorGraph& fg) : GroundSolver (fg) { }
~VarElim (void) { } ~VarElim() { }
Params solveQuery (VarIds); Params solveQuery (VarIds);
void printSolverFlags (void) const; void printSolverFlags() const;
private: private:
void createFactorList (void); void createFactorList();
void absorveEvidence (void); void absorveEvidence();
Params processFactorList (const VarIds&); Params processFactorList (const VarIds&);
void eliminate (VarId); void eliminate (VarId);
void printActiveFactors (void); void printActiveFactors();
Factors factorList_; Factors factorList_;
unsigned largestFactorSize_; unsigned largestFactorSize_;
unsigned totalFactorSize_; unsigned totalFactorSize_;
unordered_map<VarId, vector<size_t>> varMap_; std::unordered_map<VarId, std::vector<size_t>> varMap_;
DISALLOW_COPY_AND_ASSIGN (VarElim); DISALLOW_COPY_AND_ASSIGN (VarElim);
}; };
#endif // HORUS_VARELIM_H } // namespace Horus
#endif // YAP_PACKAGES_CLPBN_HORUS_VARELIM_H_

View File

@ -1,7 +1,24 @@
#include <cassert>
#include <iostream>
#include <iomanip>
#include "WeightedBp.h" #include "WeightedBp.h"
WeightedBp::~WeightedBp (void) namespace Horus {
WeightedBp::WeightedBp (
const FactorGraph& fg,
const std::vector<std::vector<unsigned>>& weights)
: BeliefProp (fg), weights_(weights)
{
}
WeightedBp::~WeightedBp()
{ {
for (size_t i = 0; i < links_.size(); i++) { for (size_t i = 0; i < links_.size(); i++) {
delete links_[i]; delete links_[i];
@ -25,7 +42,7 @@ WeightedBp::getPosterioriOf (VarId vid)
probs[var->getEvidence()] = LogAware::withEvidence(); probs[var->getEvidence()] = LogAware::withEvidence();
} else { } else {
probs.resize (var->range(), LogAware::multIdenty()); probs.resize (var->range(), LogAware::multIdenty());
const BpLinks& links = ninf(var)->getLinks(); const BpLinks& links = getLinks (var);
if (Globals::logDomain) { if (Globals::logDomain) {
for (size_t i = 0; i < links.size(); i++) { for (size_t i = 0; i < links.size(); i++) {
WeightedLink* l = static_cast<WeightedLink*> (links[i]); WeightedLink* l = static_cast<WeightedLink*> (links[i]);
@ -46,9 +63,24 @@ WeightedBp::getPosterioriOf (VarId vid)
void WeightedBp::WeightedLink::WeightedLink (
WeightedBp::createLinks (void) FacNode* fn,
VarNode* vn,
size_t idx,
unsigned weight)
: BpLink (fn, vn), index_(idx), weight_(weight),
pwdMsg_(vn->range(), LogAware::one())
{ {
}
void
WeightedBp::createLinks()
{
using std::cout;
using std::endl;
if (Globals::verbosity > 0) { if (Globals::verbosity > 0) {
cout << "compressed factor graph contains " ; cout << "compressed factor graph contains " ;
cout << fg.nrVarNodes() << " variables and " ; cout << fg.nrVarNodes() << " variables and " ;
@ -78,7 +110,7 @@ WeightedBp::createLinks (void)
void void
WeightedBp::maxResidualSchedule (void) WeightedBp::maxResidualSchedule()
{ {
if (nIters_ == 1) { if (nIters_ == 1) {
for (size_t i = 0; i < links_.size(); i++) { for (size_t i = 0; i < links_.size(); i++) {
@ -86,7 +118,7 @@ WeightedBp::maxResidualSchedule (void)
SortedOrder::iterator it = sortedOrder_.insert (links_[i]); SortedOrder::iterator it = sortedOrder_.insert (links_[i]);
linkMap_.insert (make_pair (links_[i], it)); linkMap_.insert (make_pair (links_[i], it));
if (Globals::verbosity >= 1) { if (Globals::verbosity >= 1) {
cout << "calculating " << links_[i]->toString() << endl; std::cout << "calculating " << links_[i]->toString() << std::endl;
} }
} }
return; return;
@ -94,18 +126,20 @@ WeightedBp::maxResidualSchedule (void)
for (size_t c = 0; c < links_.size(); c++) { for (size_t c = 0; c < links_.size(); c++) {
if (Globals::verbosity > 1) { if (Globals::verbosity > 1) {
cout << endl << "current residuals:" << endl; std::cout << std::endl << "current residuals:" << std::endl;
for (SortedOrder::iterator it = sortedOrder_.begin(); for (SortedOrder::iterator it = sortedOrder_.begin();
it != sortedOrder_.end(); ++it) { it != sortedOrder_.end(); ++it) {
cout << " " << setw (30) << left << (*it)->toString(); std::cout << " " << std::setw (30) << std::left;
cout << "residual = " << (*it)->residual() << endl; std::cout << (*it)->toString();
std::cout << "residual = " << (*it)->residual() << std::endl;
} }
} }
SortedOrder::iterator it = sortedOrder_.begin(); SortedOrder::iterator it = sortedOrder_.begin();
BpLink* link = *it; BpLink* link = *it;
if (Globals::verbosity >= 1) { if (Globals::verbosity >= 1) {
cout << "updating " << (*sortedOrder_.begin())->toString() << endl; std::cout << "updating " << (*sortedOrder_.begin())->toString();
std::cout << std::endl;
} }
if (link->residual() < accuracy_) { if (link->residual() < accuracy_) {
return; return;
@ -118,11 +152,12 @@ WeightedBp::maxResidualSchedule (void)
// update the messages that depend on message source --> destin // update the messages that depend on message source --> destin
const FacNodes& factorNeighbors = link->varNode()->neighbors(); const FacNodes& factorNeighbors = link->varNode()->neighbors();
for (size_t i = 0; i < factorNeighbors.size(); i++) { for (size_t i = 0; i < factorNeighbors.size(); i++) {
const BpLinks& links = ninf(factorNeighbors[i])->getLinks(); const BpLinks& links = getLinks (factorNeighbors[i]);
for (size_t j = 0; j < links.size(); j++) { for (size_t j = 0; j < links.size(); j++) {
if (links[j]->varNode() != link->varNode()) { if (links[j]->varNode() != link->varNode()) {
if (Globals::verbosity > 1) { if (Globals::verbosity > 1) {
cout << " calculating " << links[j]->toString() << endl; std::cout << " calculating " << links[j]->toString();
std::cout << std::endl;
} }
calculateMessage (links[j]); calculateMessage (links[j]);
BpLinkMap::iterator iter = linkMap_.find (links[j]); BpLinkMap::iterator iter = linkMap_.find (links[j]);
@ -133,11 +168,12 @@ WeightedBp::maxResidualSchedule (void)
} }
// in counting bp, the message that a variable X sends to // in counting bp, the message that a variable X sends to
// to a factor F depends on the message that F sent to the X // to a factor F depends on the message that F sent to the X
const BpLinks& links = ninf(link->facNode())->getLinks(); const BpLinks& links = getLinks (link->facNode());
for (size_t i = 0; i < links.size(); i++) { for (size_t i = 0; i < links.size(); i++) {
if (links[i]->varNode() != link->varNode()) { if (links[i]->varNode() != link->varNode()) {
if (Globals::verbosity > 1) { if (Globals::verbosity > 1) {
cout << " calculating " << links[i]->toString() << endl; std::cout << " calculating " << links[i]->toString();
std::cout << std::endl;
} }
calculateMessage (links[i]); calculateMessage (links[i]);
BpLinkMap::iterator iter = linkMap_.find (links[i]); BpLinkMap::iterator iter = linkMap_.find (links[i]);
@ -156,7 +192,7 @@ WeightedBp::calcFactorToVarMsg (BpLink* _link)
WeightedLink* link = static_cast<WeightedLink*> (_link); WeightedLink* link = static_cast<WeightedLink*> (_link);
FacNode* src = link->facNode(); FacNode* src = link->facNode();
const VarNode* dst = link->varNode(); const VarNode* dst = link->varNode();
const BpLinks& links = ninf(src)->getLinks(); const BpLinks& links = getLinks (src);
// calculate the product of messages that were sent // calculate the product of messages that were sent
// to factor `src', except from var `dst' // to factor `src', except from var `dst'
unsigned reps = 1; unsigned reps = 1;
@ -166,14 +202,14 @@ WeightedBp::calcFactorToVarMsg (BpLink* _link)
for (size_t i = links.size(); i-- > 0; ) { for (size_t i = links.size(); i-- > 0; ) {
const WeightedLink* l = static_cast<const WeightedLink*> (links[i]); const WeightedLink* l = static_cast<const WeightedLink*> (links[i]);
if ( ! (l->varNode() == dst && l->index() == link->index())) { if ( ! (l->varNode() == dst && l->index() == link->index())) {
if (Constants::SHOW_BP_CALCS) { if (Constants::showBpCalcs) {
cout << " message from " << links[i]->varNode()->label(); std::cout << " message from " << links[i]->varNode()->label();
cout << ": " ; std::cout << ": " ;
} }
Util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]), Util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]),
reps, std::plus<double>()); reps, std::plus<double>());
if (Constants::SHOW_BP_CALCS) { if (Constants::showBpCalcs) {
cout << endl; std::cout << std::endl;
} }
} }
reps *= links[i]->varNode()->range(); reps *= links[i]->varNode()->range();
@ -182,14 +218,14 @@ WeightedBp::calcFactorToVarMsg (BpLink* _link)
for (size_t i = links.size(); i-- > 0; ) { for (size_t i = links.size(); i-- > 0; ) {
const WeightedLink* l = static_cast<const WeightedLink*> (links[i]); const WeightedLink* l = static_cast<const WeightedLink*> (links[i]);
if ( ! (l->varNode() == dst && l->index() == link->index())) { if ( ! (l->varNode() == dst && l->index() == link->index())) {
if (Constants::SHOW_BP_CALCS) { if (Constants::showBpCalcs) {
cout << " message from " << links[i]->varNode()->label(); std::cout << " message from " << links[i]->varNode()->label();
cout << ": " ; std::cout << ": " ;
} }
Util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]), Util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]),
reps, std::multiplies<double>()); reps, std::multiplies<double>());
if (Constants::SHOW_BP_CALCS) { if (Constants::showBpCalcs) {
cout << endl; std::cout << std::endl;
} }
} }
reps *= links[i]->varNode()->range(); reps *= links[i]->varNode()->range();
@ -203,27 +239,33 @@ WeightedBp::calcFactorToVarMsg (BpLink* _link)
} else { } else {
result.params() *= src->factor().params(); result.params() *= src->factor().params();
} }
if (Constants::SHOW_BP_CALCS) { if (Constants::showBpCalcs) {
cout << " message product: " << msgProduct << endl; std::cout << " message product: " ;
cout << " original factor: " << src->factor().params() << endl; std::cout << msgProduct << std::endl;
cout << " factor product: " << result.params() << endl; std::cout << " original factor: " ;
std::cout << src->factor().params() << std::endl;
std::cout << " factor product: " ;
std::cout << result.params() << std::endl;
} }
result.sumOutAllExceptIndex (link->index()); result.sumOutAllExceptIndex (link->index());
if (Constants::SHOW_BP_CALCS) { if (Constants::showBpCalcs) {
cout << " marginalized: " << result.params() << endl; std::cout << " marginalized: " ;
std::cout << result.params() << std::endl;
} }
link->nextMessage() = result.params(); link->nextMessage() = result.params();
LogAware::normalize (link->nextMessage()); LogAware::normalize (link->nextMessage());
if (Constants::SHOW_BP_CALCS) { if (Constants::showBpCalcs) {
cout << " curr msg: " << link->message() << endl; std::cout << " curr msg: " ;
cout << " next msg: " << link->nextMessage() << endl; std::cout << link->message() << std::endl;
std::cout << " next msg: " ;
std::cout << link->nextMessage() << std::endl;
} }
} }
Params Params
WeightedBp::getVarToFactorMsg (const BpLink* _link) const WeightedBp::getVarToFactorMsg (const BpLink* _link)
{ {
const WeightedLink* link = static_cast<const WeightedLink*> (_link); const WeightedLink* link = static_cast<const WeightedLink*> (_link);
const VarNode* src = link->varNode(); const VarNode* src = link->varNode();
@ -232,19 +274,19 @@ WeightedBp::getVarToFactorMsg (const BpLink* _link) const
if (src->hasEvidence()) { if (src->hasEvidence()) {
msg.resize (src->range(), LogAware::noEvidence()); msg.resize (src->range(), LogAware::noEvidence());
double value = link->message()[src->getEvidence()]; double value = link->message()[src->getEvidence()];
if (Constants::SHOW_BP_CALCS) { if (Constants::showBpCalcs) {
msg[src->getEvidence()] = value; msg[src->getEvidence()] = value;
cout << msg << "^" << link->weight() << "-1" ; std::cout << msg << "^" << link->weight() << "-1" ;
} }
msg[src->getEvidence()] = LogAware::pow (value, link->weight() - 1); msg[src->getEvidence()] = LogAware::pow (value, link->weight() - 1);
} else { } else {
msg = link->message(); msg = link->message();
if (Constants::SHOW_BP_CALCS) { if (Constants::showBpCalcs) {
cout << msg << "^" << link->weight() << "-1" ; std::cout << msg << "^" << link->weight() << "-1" ;
} }
LogAware::pow (msg, link->weight() - 1); LogAware::pow (msg, link->weight() - 1);
} }
const BpLinks& links = ninf(src)->getLinks(); const BpLinks& links = getLinks (src);
if (Globals::logDomain) { if (Globals::logDomain) {
for (size_t i = 0; i < links.size(); i++) { for (size_t i = 0; i < links.size(); i++) {
WeightedLink* l = static_cast<WeightedLink*> (links[i]); WeightedLink* l = static_cast<WeightedLink*> (links[i]);
@ -257,14 +299,14 @@ WeightedBp::getVarToFactorMsg (const BpLink* _link) const
WeightedLink* l = static_cast<WeightedLink*> (links[i]); WeightedLink* l = static_cast<WeightedLink*> (links[i]);
if ( ! (l->facNode() == dst && l->index() == link->index())) { if ( ! (l->facNode() == dst && l->index() == link->index())) {
msg *= l->powMessage(); msg *= l->powMessage();
if (Constants::SHOW_BP_CALCS) { if (Constants::showBpCalcs) {
cout << " x " << l->nextMessage() << "^" << link->weight(); std::cout << " x " << l->nextMessage() << "^" << link->weight();
} }
} }
} }
} }
if (Constants::SHOW_BP_CALCS) { if (Constants::showBpCalcs) {
cout << " = " << msg; std::cout << " = " << msg;
} }
return msg; return msg;
} }
@ -272,8 +314,10 @@ WeightedBp::getVarToFactorMsg (const BpLink* _link) const
void void
WeightedBp::printLinkInformation (void) const WeightedBp::printLinkInformation() const
{ {
using std::cout;
using std::endl;
for (size_t i = 0; i < links_.size(); i++) { for (size_t i = 0; i < links_.size(); i++) {
WeightedLink* l = static_cast<WeightedLink*> (links_[i]); WeightedLink* l = static_cast<WeightedLink*> (links_[i]);
cout << l->toString() << ":" << endl; cout << l->toString() << ":" << endl;
@ -286,3 +330,5 @@ WeightedBp::printLinkInformation (void) const
} }
} }
} // namespace Horus

View File

@ -1,64 +1,69 @@
#ifndef HORUS_WEIGHTEDBP_H #ifndef YAP_PACKAGES_CLPBN_HORUS_WEIGHTEDBP_H_
#define HORUS_WEIGHTEDBP_H #define YAP_PACKAGES_CLPBN_HORUS_WEIGHTEDBP_H_
#include "BeliefProp.h" #include "BeliefProp.h"
class WeightedLink : public BpLink
{ namespace Horus {
class WeightedBp : public BeliefProp {
public: public:
WeightedLink (FacNode* fn, VarNode* vn, size_t idx, unsigned weight) WeightedBp (const FactorGraph& fg,
: BpLink (fn, vn), index_(idx), weight_(weight), const std::vector<std::vector<unsigned>>& weights);
pwdMsg_(vn->range(), LogAware::one()) { }
size_t index (void) const { return index_; } ~WeightedBp();
unsigned weight (void) const { return weight_; } Params getPosterioriOf (VarId);
const Params& powMessage (void) const { return pwdMsg_; } private:
class WeightedLink : public BeliefProp::BpLink {
public:
WeightedLink (FacNode* fn, VarNode* vn, size_t idx,
unsigned weight);
void updateMessage (void) size_t index() const { return index_; }
unsigned weight() const { return weight_; }
const Params& powMessage() const { return pwdMsg_; }
void updateMessage();
private:
size_t index_;
unsigned weight_;
Params pwdMsg_;
DISALLOW_COPY_AND_ASSIGN (WeightedLink);
};
void createLinks();
void maxResidualSchedule();
void calcFactorToVarMsg (BpLink*);
Params getVarToFactorMsg (const BpLink*);
void printLinkInformation() const;
std::vector<std::vector<unsigned>> weights_;
DISALLOW_COPY_AND_ASSIGN (WeightedBp);
};
inline void
WeightedBp::WeightedLink::updateMessage()
{ {
pwdMsg_ = *nextMsg_; pwdMsg_ = *nextMsg_;
swap (currMsg_, nextMsg_); swap (currMsg_, nextMsg_);
LogAware::pow (pwdMsg_, weight_); LogAware::pow (pwdMsg_, weight_);
} }
private: } // namespace Horus
DISALLOW_COPY_AND_ASSIGN (WeightedLink);
size_t index_; #endif // YAP_PACKAGES_CLPBN_HORUS_WEIGHTEDBP_H_
unsigned weight_;
Params pwdMsg_;
};
class WeightedBp : public BeliefProp
{
public:
WeightedBp (const FactorGraph& fg,
const vector<vector<unsigned>>& weights)
: BeliefProp (fg), weights_(weights) { }
~WeightedBp (void);
Params getPosterioriOf (VarId);
private:
void createLinks (void);
void maxResidualSchedule (void);
void calcFactorToVarMsg (BpLink*);
Params getVarToFactorMsg (const BpLink*) const;
void printLinkInformation (void) const;
vector<vector<unsigned>> weights_;
DISALLOW_COPY_AND_ASSIGN (WeightedBp);
};
#endif // HORUS_WEIGHTEDBP_H

View File

@ -0,0 +1,51 @@
#include "../BeliefProp.h"
#include "../FactorGraph.h"
#include "Common.h"
namespace Horus {
namespace UnitTests {
class BeliefPropTest : public CppUnit::TestFixture {
CPPUNIT_TEST_SUITE (BeliefPropTest);
CPPUNIT_TEST (testMarginals);
CPPUNIT_TEST (testJoint);
CPPUNIT_TEST_SUITE_END();
public:
void testMarginals();
void testJoint();
};
void
BeliefPropTest::testMarginals()
{
FactorGraph fg = FactorGraph::readFromLibDaiFormat (modelFile.c_str());
BeliefProp solver (fg);
for (unsigned i = 0; i < marginalProbs.size(); i++) {
Params params = solver.solveQuery ({i});
CPPUNIT_ASSERT (similiar (params, marginalProbs[i]));
}
}
void
BeliefPropTest::testJoint()
{
FactorGraph fg = FactorGraph::readFromLibDaiFormat (modelFile.c_str());
BeliefProp solver (fg);
Params params = solver.solveQuery ({0, 4, 6});
CPPUNIT_ASSERT (similiar (params, jointProbs));
}
CPPUNIT_TEST_SUITE_REGISTRATION (BeliefPropTest);
} // namespace UnitTests
} // namespace Horus

View File

@ -0,0 +1,95 @@
#include <cstdlib>
#include <cmath>
#include <cassert>
#include <numeric>
#include <functional>
#include <iostream>
#include "Common.h"
#include "Util.h"
namespace Horus {
namespace UnitTests {
const std::string modelFile = "../examples/complex.fg" ;
const std::vector<Params> marginalProbs = {
/* marginals x0 = */ {0.5825521, 0.4174479},
/* marginals x1 = */ {0.648528, 0.351472},
{0.03100852, 0.9689915},
{0.04565728, 0.503854, 0.4504888},
{0.7713128, 0.03128429, 0.1974029},
{0.8771822, 0.1228178},
{0.05617282, 0.01509834, 0.9287288},
{0.08224711, 0.5698616, 0.047964, 0.2999273},
{0.1368483, 0.8631517},
/* marginals x9 = */ {0.7529569, 0.2470431}
};
const Params jointProbs = {
/* P(x0=0, x4=0, x6=0) = */ 0.025463399,
/* P(x0=0, x4=0, x6=1) = */ 0.0067233122,
/* P(x0=0, x4=0, x6=2) = */ 0.42069289,
/* P(x0=0, x4=1, x6=0) = */ 0.0010111473,
/* P(x0=0, x4=1, x6=1) = */ 0.00027096982,
/* P(x0=0, x4=1, x6=2) = */ 0.016715682,
/* P(x0=0, x4=2, x6=0) = */ 0.0062433667,
/* P(x0=0, x4=2, x6=1) = */ 0.001828545,
/* P(x0=0, x4=2, x6=2) = */ 0.10360283,
/* P(x0=1, x4=0, x6=0) = */ 0.017910021,
/* P(x0=1, x4=0, x6=1) = */ 0.0046988842,
/* P(x0=1, x4=0, x6=2) = */ 0.29582433,
/* P(x0=1, x4=1, x6=0) = */ 0.00074648444,
/* P(x0=1, x4=1, x6=1) = */ 0.00019991076,
/* P(x0=1, x4=1, x6=2) = */ 0.012340097,
/* P(x0=1, x4=2, x6=0) = */ 0.0047984062,
/* P(x0=1, x4=2, x6=1) = */ 0.0013767189,
/* P(x0=1, x4=2, x6=2) = */ 0.079553004
};
Params
generateRandomParams (Ranges ranges)
{
Params params;
unsigned size = std::accumulate (ranges.begin(), ranges.end(),
1, std::multiplies<unsigned>());
for (unsigned i = 0; i < size; i++) {
params.push_back (rand() / double (RAND_MAX));
}
Horus::LogAware::normalize (params);
return params;
}
bool
similiar (double v1, double v2)
{
const double epsilon = 0.0000001;
return std::fabs (v1 - v2) < epsilon;
}
bool
similiar (const Params& p1, const Params& p2)
{
assert (p1.size() == p2.size());
for (size_t i = 0; i < p1.size(); i++) {
if (! similiar(p1[i], p2[i])) {
return false;
}
}
return true;
}
} // namespace UnitTests
} // namespace Horus;

View File

@ -0,0 +1,28 @@
#include <vector>
#include <string>
#include "../Horus.h"
#include <cppunit/extensions/HelperMacros.h>
namespace Horus {
namespace UnitTests {
extern const std::string modelFile;
extern const std::vector<Params> marginalProbs;
extern const Params jointProbs;
Params generateRandomParams (Ranges ranges);
bool similiar (double v1, double v2);
bool similiar (const Params& p1, const Params& p2);
} // namespace UnitTests
} // namespace Horus;

View File

@ -0,0 +1,51 @@
#include "../CountingBp.h"
#include "../FactorGraph.h"
#include "Common.h"
namespace Horus {
namespace UnitTests {
class CountingBpTest : public CppUnit::TestFixture {
CPPUNIT_TEST_SUITE (CountingBpTest);
CPPUNIT_TEST (testMarginals);
CPPUNIT_TEST (testJoint);
CPPUNIT_TEST_SUITE_END();
public:
void testMarginals();
void testJoint();
};
void
CountingBpTest::testMarginals()
{
FactorGraph fg = FactorGraph::readFromLibDaiFormat (modelFile.c_str());
CountingBp solver (fg);
for (unsigned i = 0; i < marginalProbs.size(); i++) {
Params params = solver.solveQuery ({i});
CPPUNIT_ASSERT (similiar (params, marginalProbs[i]));
}
}
void
CountingBpTest::testJoint()
{
FactorGraph fg = FactorGraph::readFromLibDaiFormat (modelFile.c_str());
CountingBp solver (fg);
Params params = solver.solveQuery ({0, 4, 6});
CPPUNIT_ASSERT (similiar (params, jointProbs));
}
CPPUNIT_TEST_SUITE_REGISTRATION (CountingBpTest);
} // namespace UnitTests
} // namespace Horus

View File

@ -0,0 +1,107 @@
#include <iostream>
#include "../Factor.h"
#include "Common.h"
namespace Horus {
namespace UnitTests {
class FactorTest : public CppUnit::TestFixture {
CPPUNIT_TEST_SUITE (FactorTest);
CPPUNIT_TEST (testSummingOut);
CPPUNIT_TEST (testProduct);
CPPUNIT_TEST_SUITE_END();
public:
void testSummingOut();
void testProduct();
};
void
FactorTest::testSummingOut()
{
VarIds vids = {0, 1, 2, 3};
Ranges ranges = {3, 2, 4, 3};
Params params = {
0.022757933283133, 0.0106825417145475, 0.0212115929862968,
0.0216271252738214, 0.0246935408909929, 0.00535101952882101,
0.00908008645423061, 0.0208088234425334, 0.00752400708452212,
0.0150052316136527, 0.0129311224551535, 0.0170340535302049,
0.00988081654256193, 0.0139063490493519, 0.025792784294836,
0.0248167234610076, 0.017219348482278, 0.0194292243637016,
0.00383554941557795, 0.0164407987747966, 0.00044152909395022,
0.00657900705816833, 0.00371715392294919, 0.0217825142487465,
0.00424392333677727, 0.0108602703755316, 0.00351559808401304,
0.00294727405145356, 0.0270575932871257, 0.005911864680038,
0.0138936584911577, 0.0227288019859002, 0.0165944064071987,
0.0080185268930961, 0.0172692026753632, 0.0142012227138332,
0.0133695464219171, 0.0263492891422071, 0.00792332157200822,
0.0208935535064392, 0.0142677961715013, 0.0208544440271617,
0.0108408824522857, 0.0241486127140633, 0.00767406849215521,
0.00954694217537661, 0.0218786116033257, 0.0248934169744332,
0.00188944195471982, 0.0257141610189036, 0.0142474911774847,
0.00233097104867004, 0.00520644350532678, 0.0179646451004339,
0.0241134853100298, 0.00945036684210405, 0.00173819089160705,
0.000542358809684406, 0.0123976408935576, 0.00170905959437435,
0.00645422348972241, 0.0262912993847153, 0.0244378615928878,
0.0230486298969212, 0.00722310170606624, 0.0146203396838926,
0.0101631280263959, 0.0205926481279833, 0.0138829042417413,
0.0180864495984042, 0.0143994770626774, 0.00106397584149748
};
Factor f (vids, ranges, params);
double sum = std::accumulate (f.params().begin(), f.params().end(), 0.0);
CPPUNIT_ASSERT (similiar (sum, 1.0));
f.sumOut (0);
f.sumOut (3);
f.sumOut (2);
sum = std::accumulate (f.params().begin(), f.params().end(), 0.0);
CPPUNIT_ASSERT (similiar (sum, 1.0));
}
void
FactorTest::testProduct()
{
VarIds vids1 = {0, 1, 2};
Ranges ranges1 = {3, 2, 2};
Params params1 = {
0.01, 0.02, 0.03, 0.04, 0.05, 0.06,
0.07, 0.08, 0.09, 0.10, 0.11, 0.12
};
VarIds vids2 = {1, 3, 0};
Ranges ranges2 = {2, 3, 3};
Params params2 = {
0.15, 0.30, 0.45, 0.60, 0.75, 0.90, 1.20, 1.50, 1.80,
0.99, 0.88, 0.77, 0.66, 0.55, 0.44, 0.33, 0.22, 0.11
};
Factor f1 (vids1, ranges1, params1);
Factor f2 (vids2, ranges2, params2);
f1.multiply (f2);
Params result = {
0.0015, 0.006, 0.012, 0.003, 0.012, 0.024, 0.0297, 0.0198, 0.0099,
0.0396, 0.0264, 0.0132, 0.015, 0.0375, 0.075, 0.018, 0.045, 0.09,
0.0616, 0.0385, 0.0154, 0.0704, 0.044, 0.0176, 0.0405, 0.081, 0.162,
0.045, 0.09, 0.18, 0.0847, 0.0484, 0.0121, 0.0924, 0.0528, 0.0132
};
CPPUNIT_ASSERT (similiar (f1.params(), result));
}
CPPUNIT_TEST_SUITE_REGISTRATION (FactorTest);
} // namespace UnitTests
} // namespace Horus

View File

@ -0,0 +1,15 @@
#include "../Factor.h"
#include <cppunit/ui/text/TestRunner.h>
#include <cppunit/extensions/TestFactoryRegistry.h>
int main()
{
CppUnit::TextUi::TestRunner runner;
CppUnit::TestFactoryRegistry& registry =
CppUnit::TestFactoryRegistry::getRegistry();
runner.addTest (registry.makeTest());
return runner.run() ? 1 : 0;
}

View File

@ -0,0 +1,51 @@
#include "../VarElim.h"
#include "../FactorGraph.h"
#include "Common.h"
namespace Horus {
namespace UnitTests {
class VarElimTest : public CppUnit::TestFixture {
CPPUNIT_TEST_SUITE (VarElimTest);
CPPUNIT_TEST (testMarginals);
CPPUNIT_TEST (testJoint);
CPPUNIT_TEST_SUITE_END();
public:
void testMarginals();
void testJoint();
};
void
VarElimTest::testMarginals()
{
FactorGraph fg = FactorGraph::readFromLibDaiFormat (modelFile.c_str());
VarElim solver (fg);
for (unsigned i = 0; i < marginalProbs.size(); i++) {
Params params = solver.solveQuery ({i});
CPPUNIT_ASSERT (similiar (params, marginalProbs[i]));
}
}
void
VarElimTest::testJoint()
{
FactorGraph fg = FactorGraph::readFromLibDaiFormat (modelFile.c_str());
VarElim solver (fg);
Params params = solver.solveQuery ({0, 4, 6});
CPPUNIT_ASSERT (similiar (params, jointProbs));
}
CPPUNIT_TEST_SUITE_REGISTRATION (VarElimTest);
} // namespace UnitTests
} // namespace Horus

View File

@ -0,0 +1,471 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset=utf-8>
<meta name="description" content="Prolog Factor Language Tutorial">
<meta name="keywords" content="Graphical Models, Lifted Inference, First Order, Bayesian Networks, Markov Networks, Variable Elimination">
<meta name="author" content="Tiago Gomes">
<link rel="stylesheet" type="text/css" href="pfl.css">
<title>The Prolog Factor Language</title>
</head>
<body id="top">
<div class="container">
<div class="header">
<div id="leftcolumn"><h1>Prolog Factor Language</h1></div>
<div id="rightcolumn">
<div>
<div class="name">Vítor Costa</div>
<div class="email">vsc at gmail.com </div>
</div>
<div style="padding-top:10px">
<div class="name">Tiago Gomes</div>
<div class="email">tiago.avv at gmail.com</div>
</div>
</div>
<div style="clear: both"></div>
<nav id="menu">
<ul>
<li><a href="#intro">Introduction</a></li>
<li><a href="#installation">Installation</a></li>
<li><a href="#language">Language</a></li>
<li><a href="#querying">Querying</a></li>
<li><a href="#options">Options</a></li>
<li><a href="#learning">Learning</a></li>
<li><a href="#external_interface">External Interface</a></li>
<li><a href="#papers">Papers</a></li>
</ul>
</nav>
</div> <!-- end of header -->
<div class="mainbody">
<!--+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-->
<!--+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-->
<!--+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-->
<h2 id="introduction">Introduction</h2>
The Prolog Factor Language (PFL) is a language that extends Prolog for providing a syntax to describe first-order probabilistic graphical models. These models can be either directed (bayesian networks) or undirected (markov networks). This language replaces the old one known as CLP(BN).
<p>The package also includes implementations for a set of well-known inference algorithms for solving probabilistic queries on these models. Both ground and lifted inference methods are support.</p>
<p><a href="#top" style="font-size:15px">Back to the top</a></p>
<!--+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-->
<!--+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-->
<!--+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-->
<h2 id="installation">Installation</h2>
PFL is included with the <a href="http://www.dcc.fc.up.pt/~vsc/Yap/">YAP</a> Prolog system. However, there isn't yet a stable release of YAP that includes PFL and you will need to install a development version. To do so, you must have installed the <a href="http://git-scm.com/">Git</a> version control system. The commands to perform a default installation of YAP in your home directory in a Unix-based environment are shown next.
<p>
<div class=console>
<p>$ cd $HOME</p>
<p>$ git clone git://yap.git.sourceforge.net/gitroot/yap/yap-6.3</p>
<p>$ cd yap-6.3/</p>
<p>$ ./configure --enable-clpbn-bp --prefix=$HOME</p>
<p>$ make depend &amp; make install</p>
</div>
<p>In case you want to install YAP somewhere else or with different settings, please consult the YAP documentation. From now on, we will assume that the directory <span class=texttt>$HOME &#x25B7; bin</span> (where the binary is) is in your <span class=texttt>$PATH</span> environment variable.</p>
<p>Once in a while, we will refer to the PFL examples directory. In a default installation, this directory will be located at <span class=texttt> $HOME &#x25B7; share &#x25B7; doc &#x25B7; Yap &#x25B7; packages &#x25B7; examples &#x25B7; CLPBN</span>.</p>
<p><a href="#top" style="font-size:15px">Back to the top</a></p>
<!--+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-->
<!--+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-->
<!--+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-->
<h2 id="language">Language</h2>
A first-order probabilistic graphical model is described using parametric factors, commonly known as parfactors. The PFL syntax for a parfactor is
<div style="text-align:center">
<br>
<em>Type</em> &nbsp; <em>F</em> &nbsp; ; &nbsp; <em>Phi</em> &nbsp; ; &nbsp; <em>C</em>.
<br>
<br>
</div>
Where,
<ul>
<li>
<em>Type</em> refers the type of network over which the parfactor is defined. It can be <span class=texttt>bayes</span> for directed networks, or <span class=texttt>markov</span> for undirected ones.
<p>
</li>
<li>
<em>F</em> is a comma-separated sequence of Prolog terms that will define sets of random variables under the constraint <em>C</em>. If <em>Type</em> is <span class=texttt>bayes</span>, the first term defines the node while the remaining terms define its parents.
<p>
</li>
<li>
<em>Phi</em> is either a Prolog list of potential values or a Prolog goal that unifies with one. Notice that if <em>Type</em> is <span class=texttt>bayes</span>, this will correspond to the conditional probability table. Domain combinations are implicitly assumed in ascending order, with the first term being the 'most significant' (e.g.
<span class="texttt">x<sub>0</sub>y<sub>0</sub></span>,
<span class="texttt">x<sub>0</sub>y<sub>1</sub></span>,
<span class="texttt">x<sub>0</sub>y<sub>2</sub></span>,
<span class="texttt">x<sub>1</sub>y<sub>0</sub></span>,
<span class="texttt">x<sub>1</sub>y<sub>1</sub></span>,
<span class="texttt">x<sub>1</sub>y<sub>2</sub></span>).
<p>
</li>
<li>
<em>C</em> is a (possibly empty) list of Prolog goals that will instantiate the logical variables that appear in <em>F</em>, that is, the successful substitutions for the goals in <em>C</em> will be the valid values for the logical variables. This allows the constraint to be defined as any relation (set of tuples) over the logical variables.
</li>
</ul>
<IMG style="display:block; margin:auto" src="sprinkler.png" alt="Sprinkler Network">
<p>Towards a better understanding of the language, next we show the PFL representation for the sprinkler network found in the above figure.</p>
<div class="pflcode">
<pre >
:- use_module(library(pfl)).
bayes cloudy ; cloudy_table ; [].
bayes sprinkler, cloudy ; sprinkler_table ; [].
bayes rain, cloudy ; rain_table ; [].
bayes wet_grass, sprinkler, rain ; wet_grass_table ; [].
cloudy_table(
[ 0.5,
0.5 ]).
sprinkler_table(
[ 0.1, 0.5,
0.9, 0.5 ]).
rain_table(
[ 0.8, 0.2,
0.2, 0.8 ]).
wet_grass_table(
[ 0.99, 0.9, 0.9, 0.0,
0.01, 0.1, 0.1, 1.0 ]).
</pre>
</div>
<p>In the example, we started by loading the PFL library, then we have defined one factor for each node, and finally we have specified the probabilities for each conditional probability table.</p>
<p>Notice that this network is fully grounded, as all constraints are empty. Next we present the PFL representation for a well-known markov logic network - the social network model. For convenience, the two main weighted formulas of this model are shown below.</p>
<div class="pflcode">
<pre>
1.5 : Smokes(x) => Cancer(x)
1.1 : Smokes(x) ^ Friends(x,y) => Smokes(y)
</pre>
</div>
<p>Next, we show the PFL representation for this model.</p>
<div class="pflcode">
<pre>
:- use_module(library(pfl)).
person(anna).
person(bob).
markov smokes(X), cancer(X) ;
[4.482, 4.482, 1.0, 4.482] ;
[person(X)].
markov friends(X,Y), smokes(X), smokes(Y) ;
[3.004, 3.004, 3.004, 3.004, 3.004, 1.0, 1.0, 3.004] ;
[person(X), person(Y)].
</pre>
</div>
<p>Notice that we have defined the world to be consisted of only two persons, <span class=texttt>anna</span> and <span class=texttt>bob</span>. We can easily add as many persons as we want by inserting in the program a fact like <span class=texttt>person @ 10.</span>&nbsp;. This would automatically create ten persons named <span class=texttt>p1</span>, <span class=texttt>p2</span>, ..., <span class=texttt>p10</span>.</p>
<p>Unlike other fist-order probabilistic languages, in PFL the logical variables that appear in the terms are not directly typed, and they will be only constrained by the goals that appears in the constraint of the parfactor. This allows the logical variables to be constrained to any relation (set of tuples), and not only pairwise (in)equalities. For instance, the next example defines a network with three ground factors, each defined respectively over the random variables <span class=texttt>p(a,b)</span>, <span class=texttt>p(b,d)</span> and <span class=texttt>p(d,e)</span>.</p>
<div class="pflcode">
<pre>
constraint(a,b).
constraint(b,d).
constraint(d,e).
markov p(A,B); some_table; [constraint(A,B)].
</pre>
</div>
<p>We can easily add static evidence to PFL programs by inserting a fact with the same functor and arguments as the random variable, plus one extra argument with the observed state or value. For instance, suppose that we know that <span class=texttt>anna</span> and <span class=texttt>bob</span> are friends. We can add this knowledge to the program with the following fact: <span class=texttt>friends(anna,bob,t).</span>&nbsp;.</p>
<p>One last note for the domain of the random variables. By default, all terms instantiate boolean (<span class=texttt>t</span>/<span class=texttt>f</span>) random variables. It is possible to choose a different domain for a term by appending a list of its possible values or states. Next we present a self-explanatory example of how this can be done.</p>
<div class="pflcode">
<pre>
bayes professor_ability::[high, medium, low] ; [0.5, 0.4, 0.1].
</pre>
</div>
<p>More probabilistic models defined using PFL can be found in the examples directory.</p>
<p><a href="#top" style="font-size:15px">Back to the top</a></p>
<!--+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-->
<!--+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-->
<!--+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-->
<h2 id=querying>Querying</h2>
In this section we demonstrate how to use PFL to solve probabilistic queries. We will use the sprinkler network as example.
<p>Assuming that the current directory is the one where the examples are located, first we load the model with the following command.</p>
<p class=console>$ yap -l sprinkler.pfl</p>
<p>Let's suppose that we want to estimate the marginal probability for the <em>WetGrass</em> random variable. To do so, we call the following goal.</p>
<p class=console>?- wet_grass(X).</p>
<p>The output of this goal will show the marginal probability for each <em>WetGrass</em> possible state or value, that is, <span class=texttt>t</span> and <span class=texttt>f</span>. Notice that in PFL a random variable is identified by a term with the same functor and arguments plus one extra argument.</p>
<p>Now let's suppose that we want to estimate the probability for the same random variable, but this time we have evidence that it had rained in the day before. We can estimate this probability without resorting to static evidence with:</p>
<p class=console>?- wet_grass(X), rain(t).</p>
<p>PFL also supports calculating joint probability distributions. For instance, we can obtain the joint probability for <em>Sprinkler</em> and <em>Rain</em> with:</p>
<p class=console>?- sprinkler(X), rain(Y).</p>
<p><a href="#top" style="font-size:15px">Back to the top</a></p>
<!--+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-->
<!--+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-->
<!--+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-->
<h2 id="options">Options</h2>
PFL supports both ground and lifted inference methods. The inference algorithm can be chosen by calling <span class=texttt>set_solver/1</span>. The following are supported:
<ul>
<li><span class=texttt>ve</span>, variable elimination (written in Prolog)</li>
<li><span class=texttt>hve</span>, variable elimination (written in C++)</li>
<li><span class=texttt>jt</span>, junction tree</li>
<li><span class=texttt>bdd</span>, binary decision diagrams</li>
<li><span class=texttt>bp</span>, belief propagation</li>
<li><span class=texttt>cbp</span>, counting belief propagation</li>
<li><span class=texttt>gibbs</span>, gibbs sampling</li>
<li><span class=texttt>lve</span>, generalized counting first-order variable elimination (GC-FOVE)</li>
<li><span class=texttt>lkc</span>, lifted first-order knowledge compilation</li>
<li><span class=texttt>lbp</span>, lifted first-order belief propagation</li>
</ul>
<p>For instance, if we want to use belief propagation to solve some probabilistic query, we need to call first:</p>
<p class=console>?- set_solver(bp).</p>
<p>It is possible to tweak some parameters of PFL through <span class=texttt>set_pfl_flag/2</span> predicate. The first argument is a option name that identifies the parameter that we want to tweak. The second argument is some possible value for this option. Next we explain the available options in detail.</p>
<h3><span class=texttt>verbosity</span></h3>
This option controls the level of debugging information that will be shown.
<ul>
<li>Values: a positive integer (default is 0 - no debugging). The higher the number, the more information that will be shown.</li>
<li>Affects: <span class=texttt>hve</span>, <span class=texttt>bp</span>, <span class=texttt>cbp</span>, <span class=texttt>lve</span>, <span class=texttt>lkc</span> and <span class=texttt>lbp</span>.</li>
</ul>
<p>
For instance, we can view some basic debugging information by calling the following goal.
<p class="console">?- set_pfl_flag(verbosity, 1).</p>
<h3><span class=texttt>use_logarithms</span></h3>
This option controls whether the calculations performed during inference should be done in a logarithm domain or not.
<ul>
<li>Values: <span class=texttt>true</span> (default) or <span class=texttt>false</span>.</li>
<li>Affects: <span class=texttt>hve</span>, <span class=texttt>bp</span>, <span class=texttt>cbp</span>, <span class=texttt>lve</span>, <span class=texttt>lkc</span> and <span class=texttt>lbp</span>.</li>
</ul>
<h3><span class=texttt>hve_elim_heuristic</span></h3>
This option allows to choose which elimination heuristic will be used by the <span class=texttt>hve</span>.
<ul>
<li>Values: <span class=texttt>sequential</span>, <span class=texttt>min_neighbors</span>, <span class=texttt>min_weight</span>, <span class=texttt>min_fill</span> and
<br><span class=texttt>weighted_min_fill</span> (default).</li>
<li>Affects: <span class=texttt>hve</span>.</li>
</ul>
<p>An explanation for each of these heuristics can be found in Daphne Koller's book <em>Probabilistic Graphical Models</em>.</p>
<h3><span class=texttt>bp_max_iter</span></h3>
This option establishes a maximum number of iterations. One iteration consists in sending all possible messages.
<ul>
<li>Values: a positive integer (default is <span class=texttt>1000</span>).</li>
<li>Affects: <span class=texttt>bp</span>, <span class=texttt>cbp</span> and <span class=texttt>lbp</span>.</li>
</ul>
<h3><span class=texttt>bp_accuracy</span></h3>
This option allows to control when the message passing should cease. Be the residual of one message the difference (according some metric) between the one sent in the current iteration and the one sent in the previous. If the highest residual is lesser than the given value, the message passing is stopped and the probabilities are calculated using the last messages that were sent.
<ul>
<li>Values: a float-point number (default is <span class=texttt>0.0001</span>).</li>
<li>Affects: <span class=texttt>bp</span>, <span class=texttt>cbp</span> and <span class=texttt>lbp</span>.</li>
</ul>
<h3><span class=texttt>bp_msg_schedule</span></h3>
This option allows to control the message sending order.
<ul>
<li>Values:
<ul>
<li><span class=texttt>seq_fixed</span> (default), at each iteration, all messages are sent with the same order.<p></li>
<li><span class=texttt>seq_random</span>, at each iteration, all messages are sent with a random order.<p></li>
<li><span class=texttt>parallel</span>, at each iteration, all messages are calculated using only the values of the previous iteration.<p></li>
<li><span class=texttt>max_residual</span>, the next message to be sent is the one with maximum residual (as explained in the paper <em>Residual Belief Propagation: Informed Scheduling for Asynchronous Message Passing</em>).</li>
</ul>
</li>
<li>Affects: <span class=texttt>bp</span>, <span class=texttt>cbp</span> and <span class=texttt>lbp</span>.
</li>
</ul>
<h3><span class=texttt>export_libdai</span></h3>
This option allows exporting the current model to the <a href="http://cs.ru.nl/~jorism/libDAI/doc/fileformats.html">libDAI</a> file format.
<ul>
<li>Values: <span class=texttt>true</span> or <span class=texttt>false</span> (default).</li>
<li>Affects: <span class=texttt>hve</span>, <span class=texttt>bp</span>, and <span class=texttt>cbp</span>.</li>
</ul>
<h3><span class=texttt>export_uai</span></h3>
This option allows exporting the current model to the <a href="http://graphmod.ics.uci.edu/uai08/FileFormat">UAI</a> file format.
<ul>
<li>Values: <span class=texttt>true</span> or <span class=texttt>false</span> (default).</li>
<li>Affects: <span class=texttt>hve</span>, <span class=texttt>bp</span>, and <span class=texttt>cbp</span>.</li>
</ul>
<h3><span class=texttt>export_graphviz</span></h3>
This option allows exporting the factor graph's structure into a format that can be parsed by <a href="http://www.graphviz.org/">Graphviz</a>.
<ul>
<li>Values: <span class=texttt>true</span> or <span class=texttt>false</span> (default).</li>
<li>Affects: <span class=texttt>hve</span>, <span class=texttt>bp</span>, and <span class=texttt>cbp</span>.</li>
</ul>
<h3><span class=texttt>print_fg</span></h3>
This option allows to print a textual representation of the factor graph.
<ul>
<li>Values: <span class=texttt>true</span> or <span class=texttt>false</span> (default).</li>
<li>Affects: <span class=texttt>hve</span>, <span class=texttt>bp</span>, and <span class=texttt>cbp</span>.</li>
</ul>
<p><a href="#top" style="font-size:15px">Back to the top</a></p>
<!--+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-->
<!--+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-->
<!--+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-->
<h2 id="learning">Learning</h2>
PFL is capable to learn the parameters for bayesian networks, through an implementation of the expectation-maximization algorithm.
<p>Next we show an example of parameter learning for the sprinkler network.</p>
<div class="pflcode">
<pre>
:- [sprinkler.pfl].
:- use_module(library(clpbn/learning/em)).
data(t, t, t, t).
data(_, t, _, t).
data(t, t, f, f).
data(t, t, f, t).
data(t, _, _, t).
data(t, f, t, t).
data(t, t, f, t).
data(t, _, f, f).
data(t, t, f, f).
data(f, f, t, t).
main :-
findall(X, scan_data(X), L),
em(L, 0.01, 10, CPTs, LogLik),
writeln(LogLik:CPTs).
scan_data([cloudy(C), sprinkler(S), rain(R), wet_grass(W)]) :-
data(C, S, R, W).
</pre>
</div>
<p>Parameter learning is done by calling the <span class=texttt>em/5</span> predicate. Its arguments are the following.</p>
<div style="text-align:center">
<br>
<span class=texttt>em(+Data, +MaxError, +MaxIters, -CPTs, -LogLik)</span>
<br>
<br>
</div>
Where,
<ul>
<li><span class=texttt>Data</span> is a list of samples for the distribution that we want to estimate. Each sample is a list of either observed random variables or unobserved random variables (denoted when its state or value is not instantiated).</li>
<li><span class=texttt>MaxError</span> is the maximum error allowed before stopping the EM loop.</li>
<li><span class=texttt>MaxIters</span> is the maximum number of iterations for the EM loop.</li>
<li><span class=texttt>CPTs</span> is a list with the estimated conditional probability tables.</li>
<li><span class=texttt>LogLik</span> is the log-likelihood.</li>
</ul>
<p>It is possible to choose the solver that will be used for the inference part during parameter learning with the <span class=texttt>set_em_solver/1</span> predicate (defaults to <span class=texttt>hve</span>). At the moment, only the following solvers support parameter learning: <span class=texttt>ve</span>, <span class=texttt>hve</span>, <span class=texttt>bdd</span>, <span class=texttt>bp</span> and <span class=texttt>cbp</span>.</p>
<p>Inside the <span class=texttt>learning</span> directory from the examples directory, one can find more examples of parameter learning.</p>
<p><a href="#top" style="font-size:15px">Back to the top</a></p>
<!--+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-->
<!--+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-->
<!--+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-->
<h2 id="external_interface">External Interface</h2>
This package also includes an external command for perform inference over probabilistic graphical models described in formats other than PFL. Currently two are support, the http://cs.ru.nl/&nbsp;jorism/libDAI/doc/fileformats.htmllibDAI file format, and the http://graphmod.ics.uci.edu/uai08/FileFormatUAI08 file format.
<p>This command's name is <span class=texttt>hcli</span> and its usage is as follows.</p>
<p class=console>$ ./hcli [solver=hve|bp|cbp] [&lt;OPTION&gt;=&lt;VALUE&gt;]... &lt;FILE&gt;[&lt;VAR&gt;|&lt;VAR&gt;=&lt;EVIDENCE&gt;]... </p>
<p>Let's assume that the current directory is the one where the examples are located. We can perform inference in any supported model by passing the file name where the model is defined as argument. Next, we show how to load a model with <span class=texttt>hcli</span>.</p>
<p class=console>$ ./hcli burglary-alarm.uai</p>
<p>With the above command, the program will load the model and print the marginal probabilities for all defined random variables. We can view only the marginal probability for some variable with a identifier <em>X</em>, if we pass <em>X</em> as an extra argument following the file name. For instance, the following command will output only the marginal probability for the variable with identifier <em>0</em>.</p>
<p class=console>$ ./hcli burglary-alarm.uai 0</p>
<p>If we give more than one variable identifier as argument, the program will output the joint probability for all the passed variables.</p>
<p>Evidence can be given as a pair containing a variable identifier and its observed state (index), separated by a '=`. For instance, we can introduce knowledge that some variable with identifier <em>0</em> has evidence on its second state as follows.</p>
<p class=console>$ ./hcli burglary-alarm.uai 0=1</p>
<p>By default, all probability tasks are resolved using the <span class=texttt>hve</span> solver. It is possible to choose another solver using the <span class=texttt>solver</span> option as follows.</p>
<p class=console>$ ./hcli solver=bp burglary-alarm.uai</p>
<p>Notice that only the <span class=texttt>hve</span>, <span class=texttt>bp</span> and <span class=texttt>cbp</span> solvers can be used with <span class=texttt>hcli</span>.</p>
<p>The options that are available with the <span class=texttt>set_pfl_flag/2</span> predicate can be used in <span class=texttt>hcli</span> too. The syntax is a pair <span class=texttt>&lt;Option&gt;=&lt;Value&gt;</span> before the model's file name.</p>
<p><a href="#top" style="font-size:15px">Back to the top</a></p>
<!--+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-->
<!--+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-->
<!--+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-->
<h2 id="papers">Papers</h2>
<ul>
<li><em>Evaluating Inference Algorithms for the Prolog Factor Language.</em></li>
</ul>
<p><a href="#top" style="font-size:15px">Back to the top</a></p>
</div> <!-- end of mainbody -->
<div class="footer"></div>
</div> <!-- end of container -->
</body>
</html>

118
packages/CLPBN/html/pfl.css Normal file
View File

@ -0,0 +1,118 @@
html {
background: gray
}
body {
margin: 10px;
}
.container {
width: 900px;
margin: auto;
padding: 10px 60px;
background-color: #ffffff;
box-shadow: 0px 2px 7px #292929;
border-radius: 10px;
}
.header,
.mainbody,
.footer {
padding: 0px;
}
.header {
height: 120px;
border-bottom: 1px solid #EEE;
background-color: #ffffff;
border-top-left-radius: 5px;
border-top-right-radius: 5px;
}
#leftcolumn {
width: 550px;
float: left
}
#rightcolumn {
float: right;
text-align:right;
font-family: sans-serif;
font-size:15px;
line-height:15px;
padding-top:15px;
padding-right:0px;
margin-right:0px;
}
div.name {
color:#2798CA;
}
div.email {
color:gray;
}
.mainbody {
padding-top:10px;
}
.footer {
height: 5px;
border-bottom-left-radius: 5px;
border-bottom-right-radius: 5px;
}
#menu ul {
padding:0px;
margin: 20px 0px;
background-color:#EDEDED;
list-style:none;
text-align: center;
}
#menu ul li {
display: inline;
}
#menu ul li a {
padding: 2px 10px;
display: inline-block;
color: #333;
text-decoration: none;
border-bottom:3px solid #EDEDED;
}
#menu ul li a:hover {
text-decoration: underline;
}
h2 {
padding-top: 20px;
padding-bottom: 0px;
}
.pflcode {
border-radius: 15px;
max-width:550px;
padding:5px 20px;
margin:auto;
margin-top:35px;
margin-bottom:35px;
background-color: #FAFAD3;
font: normal 14px verdana;
border: solid 1px #ddd;
}
.console {
font-family: monospace;
color:white;
background-color:black;
padding: 5px;
border-radius: 8px;
}
.texttt {
font-family: monospace
}

Binary file not shown.

After

Width:  |  Height:  |  Size: 31 KiB

View File

@ -38,7 +38,7 @@ CRACS \& INESC TEC, Faculty of Sciences, University of Porto
\thispagestyle{empty} \thispagestyle{empty}
\vspace{5cm} \vspace{5cm}
\begin{center} \begin{center}
\large Last revision: January 18, 2013 \large Last revision: April 12, 2013
\end{center} \end{center}
\newpage \newpage
@ -87,7 +87,7 @@ A first-order probabilistic graphical model is described using parametric factor
$$Type~~F~~;~~Phi~~;~~C.$$ $$Type~~F~~;~~Phi~~;~~C.$$
, where Where,
\begin{itemize} \begin{itemize}
\item $Type$ refers the type of network over which the parfactor is defined. It can be \texttt{bayes} for directed networks, or \texttt{markov} for undirected ones. \item $Type$ refers the type of network over which the parfactor is defined. It can be \texttt{bayes} for directed networks, or \texttt{markov} for undirected ones.
@ -274,7 +274,7 @@ PFL also supports calculating joint probability distributions. For instance, we
%------------------------------------------------------------------------------ %------------------------------------------------------------------------------
%------------------------------------------------------------------------------ %------------------------------------------------------------------------------
%------------------------------------------------------------------------------ %------------------------------------------------------------------------------
\section{Inference Options} \section{Options}
PFL supports both ground and lifted inference methods. The inference algorithm can be chosen by calling \texttt{set\_solver/1}. The following are supported: PFL supports both ground and lifted inference methods. The inference algorithm can be chosen by calling \texttt{set\_solver/1}. The following are supported:
\begin{itemize} \begin{itemize}
\item \texttt{ve}, variable elimination (written in Prolog) \item \texttt{ve}, variable elimination (written in Prolog)
@ -397,7 +397,7 @@ This option allows to print a textual representation of the factor graph.
%------------------------------------------------------------------------------ %------------------------------------------------------------------------------
%------------------------------------------------------------------------------ %------------------------------------------------------------------------------
%------------------------------------------------------------------------------ %------------------------------------------------------------------------------
\section{Parameter Learning} \section{Learning}
PFL is capable to learn the parameters for bayesian networks, through an implementation of the expectation-maximization algorithm. PFL is capable to learn the parameters for bayesian networks, through an implementation of the expectation-maximization algorithm.
Next we show an example of parameter learning for the sprinkler network. Next we show an example of parameter learning for the sprinkler network.
@ -429,7 +429,9 @@ scan_data([cloudy(C), sprinkler(S), rain(R), wet_grass(W)]) :-
Parameter learning is done by calling the \texttt{em/5} predicate. Its arguments are the following. Parameter learning is done by calling the \texttt{em/5} predicate. Its arguments are the following.
\begin{center}
\texttt{em(+Data, +MaxError, +MaxIters, -CPTs, -LogLik)} \texttt{em(+Data, +MaxError, +MaxIters, -CPTs, -LogLik)}
\end{center}
Where, Where,
\begin{itemize} \begin{itemize}
@ -489,9 +491,9 @@ The options that are available with the \texttt{set\_pfl\_flag/2} predicate can
%------------------------------------------------------------------------------ %------------------------------------------------------------------------------
%------------------------------------------------------------------------------ %------------------------------------------------------------------------------
%------------------------------------------------------------------------------ %------------------------------------------------------------------------------
\section{Further Information} \section{Papers}
Please check the paper \textit{Evaluating Inference Algorithms for the Prolog Factor Language} for further information. \begin{itemize}
\item \textit{Evaluating Inference Algorithms for the Prolog Factor Language}.
Any question? Don't hesitate to contact us! \end{itemize}
\end{document} \end{document}