This commit is contained in:
Vitor Santos Costa 2012-06-22 14:51:58 +01:00
commit 5fe052a3ef
50 changed files with 1278 additions and 617 deletions

110
packages/CLPBN/README.txt Normal file
View File

@ -0,0 +1,110 @@
Prolog Factor Language (PFL)
Prolog Factor Language (PFL) is a extension of the Prolog language that
allows a natural representation of this first-order probabilistic models
(either directed or undirected). PFL is also capable of solving probabilistic
queries on this models through the implementation of several inference
techniques: variable elimination, belief propagation, lifted variable
elimination and lifted belief propagation.
Language
-------------------------------------------------------------------------------
A graphical model in PFL is represented using parfactors. A PFL parfactor
has the following four components:
Type ; Formulas ; Phi ; Constraint .
- Type refers the type of the network over which the parfactor is defined.
It can be bayes for directed networks, or markov for undirected ones.
- Formulas is a sequence of Prolog terms that define sets of random variables
under the constraint.
- Phi is either a list of parameters or a call to a Prolog goal that will
unify its last argument with a list of parameters.
- Constraint is a list (possible empty) of Prolog goals that will impose
bindings on the logical variables that appear in the formulas.
The "examples" directory contains some popular graphical models described
using PFL.
Querying
-------------------------------------------------------------------------------
Now we show how to use PFL to solve probabilistic queries. We will
use the burlgary alarm network as an example. First, we load the model:
$ yap -l examples/burglary-alarm.yap
Now let's suppose that we want to estimate the probability of a earthquake
ocurred given that mary called. We can do it with the following query:
?- earthquake(X), mary_calls(t).
Suppose now that we want the joint distribution for john_calls and
mary_calls. We can obtain this with the following query:
?- john_calls(X), mary_calls(Y).
Inference Options
-------------------------------------------------------------------------------
PFL supports both ground and lifted inference. The inference algorithm
can be chosen using the set_solver/1 predicate. The following algorithms
are supported:
- fove: lifted variable elimination with arbitrary constraints (GC-FOVE)
- hve: (ground) variable elimination
- lbp: lifted first-order belief propagation
- cbp: counting belief propagation
- bp: (ground) belief propagation
For example, if we want to use ground variable elimination to solve some
query, we need to call first the following goal:
?- set_solver(hve).
It is possible to tweak several parameters of PFL through the
set_horus_flag/2 predicate. The first argument is a key that
identifies the parameter that we desire to tweak, while the second
is some possible value for this key.
The verbosity key controls the level of log information that will be
printed by the corresponding solver. Its possible values are positive
integers. The bigger the number, more log information will be printed.
For example, to view some basic log information we need to call the
following goal:
?- set_horus_flag(verbosity, 1).
The use_logarithms key controls whether the calculations performed
during inference should be done in the log domain or not. Its values
can be true or false. By default is false.
There are also keys specific to the inference algorithm. For example,
elim_heuristic key controls the elimination heuristic that will be
used by ground variable elimination. The following heuristics are
supported:
- sequential
- min_neighbors
- min_weight
- min_fill
- weighted_min_fill
An explanation of this heuristics can be found in Probabilistic Graphical
Models by Daphne Koller.
The schedule, accuracy and max_iter keys are specific for inference
algorithms based on message passing, namely lbp, cbp and bp.
The key schedule can be used to specify the order in which the messages
are sent in belief propagation. The possible values are:
- seq_fixed: at each iteration, all messages are sent in the same order
- seq_random: at each iteration, the messages are sent with a random order
- parallel: at each iteration, the messages are all calculated using the
values of the previous iteration.
- max_residual: the next message to be sent is the one with maximum residual,
(Residual Belief Propagation:Informed Scheduling for Asynchronous Message
Passing)
The max_iter key sets the maximum number of iterations. One iteration
consists in sending all possible messages. The accuracy key indicate
when we should stop sending messages. If the largest difference between
a message sent in the current iteration and one message sent in the previous
iteration is less that accuracy value given, we terminate belief propagation.

View File

@ -26,6 +26,8 @@ function run_solver
solver_flag=clpbn_horus:set_horus_flag\(schedule,$2\)
elif [ $SOLVER = cbp ]; then
solver_flag=clpbn_horus:set_horus_flag\(schedule,$2\)
elif [ $SOLVER = lbp ]; then
solver_flag=clpbn_horus:set_horus_flag\(schedule,$2\)
else
echo "unknow flag $2"
fi

View File

@ -23,7 +23,7 @@ function run_all_graphs
run_solver city60000 $2
run_solver city65000 $2
run_solver city70000 $2
return
run_solver city75000 $2
run_solver city80000 $2
run_solver city85000 $2

View File

@ -0,0 +1,36 @@
#!/bin/bash
source city.sh
source ../benchs.sh
SOLVER="lbp"
function run_all_graphs
{
write_header $1
run_solver city1000 $2
run_solver city5000 $2
run_solver city10000 $2
run_solver city15000 $2
run_solver city20000 $2
run_solver city25000 $2
run_solver city30000 $2
run_solver city35000 $2
run_solver city40000 $2
run_solver city45000 $2
run_solver city50000 $2
run_solver city55000 $2
run_solver city60000 $2
run_solver city65000 $2
run_solver city70000 $2
run_solver city75000 $2
run_solver city80000 $2
run_solver city85000 $2
run_solver city90000 $2
run_solver city95000 $2
run_solver city100000 $2
}
prepare_new_run
run_all_graphs "lbp(shedule=seq_fixed) " seq_fixed

View File

@ -0,0 +1,30 @@
#!/bin/bash
source cw.sh
source ../benchs.sh
SOLVER="lbp"
function run_all_graphs
{
write_header $1
run_solver p1000w$N_WORKSHOPS $2
run_solver p5000w$N_WORKSHOPS $2
run_solver p10000w$N_WORKSHOPS $2
run_solver p15000w$N_WORKSHOPS $2
run_solver p20000w$N_WORKSHOPS $2
run_solver p25000w$N_WORKSHOPS $2
run_solver p30000w$N_WORKSHOPS $2
run_solver p35000w$N_WORKSHOPS $2
run_solver p40000w$N_WORKSHOPS $2
run_solver p45000w$N_WORKSHOPS $2
run_solver p50000w$N_WORKSHOPS $2
run_solver p55000w$N_WORKSHOPS $2
run_solver p60000w$N_WORKSHOPS $2
run_solver p65000w$N_WORKSHOPS $2
run_solver p70000w$N_WORKSHOPS $2
}
prepare_new_run
run_all_graphs "lbp(shedule=seq_fixed) " seq_fixed

View File

@ -0,0 +1,35 @@
#!/bin/bash
cd workshop_attrs
source hve_tests.sh
source bp_tests.sh
source fove_tests.sh
source lbp_tests.sh
source cbp_tests.sh
cd ..
cd comp_workshops
source hve_tests.sh
source bp_tests.sh
source fove_tests.sh
source lbp_tests.sh
source cbp_tests.sh
cd ..
cd city
source hve_tests.sh
source bp_tests.sh
source fove_tests.sh
source lbp_tests.sh
source cbp_tests.sh
cd ..
cd smokers
source hve_tests.sh
source bp_tests.sh
source fove_tests.sh
source lbp_tests.sh
source cbp_tests.sh
cd ..

View File

@ -0,0 +1,30 @@
#!/bin/bash
source sm.sh
source ../benchs.sh
SOLVER="lbp"
function run_all_graphs
{
write_header $1
run_solver pop100 $2
run_solver pop200 $2
run_solver pop300 $2
run_solver pop400 $2
run_solver pop500 $2
run_solver pop600 $2
run_solver pop700 $2
run_solver pop800 $2
run_solver pop900 $2
run_solver pop1000 $2
run_solver pop1100 $2
run_solver pop1200 $2
run_solver pop1300 $2
run_solver pop1400 $2
run_solver pop1500 $2
}
prepare_new_run
run_all_graphs "lbp(shedule=seq_fixed) " seq_fixed

View File

@ -2,5 +2,6 @@
NETWORK="'../../examples/social_domain2'"
SHORTNAME="sm"
QUERY="smokes(p1,t), smokes(p2,t), friends(p1,p2,X)"
#QUERY="smokes(p1,t), smokes(p2,t), friends(p1,p2,X)"
QUERY="friends(p1,p2,X)"

View File

@ -0,0 +1,34 @@
#!/bin/bash
source sm.sh
source ../benchs.sh
SOLVER="bp"
function run_all_graphs
{
write_header $1
run_solver ev0p$POP $2
run_solver ev5p$POP $2
run_solver ev10p$POP $2
run_solver ev15p$POP $2
run_solver ev20p$POP $2
run_solver ev25p$POP $2
run_solver ev30p$POP $2
run_solver ev35p$POP $2
run_solver ev40p$POP $2
run_solver ev45p$POP $2
run_solver ev50p$POP $2
run_solver ev55p$POP $2
run_solver ev60p$POP $2
run_solver ev65p$POP $2
run_solver ev70p$POP $2
run_solver ev75p$POP $2
run_solver ev80p$POP $2
run_solver ev85p$POP $2
run_solver ev90p$POP $2
}
prepare_new_run
run_all_graphs "bp(shedule=seq_fixed) " seq_fixed

View File

@ -0,0 +1,34 @@
#!/bin/bash
source sm.sh
source ../benchs.sh
SOLVER="cbp"
function run_all_graphs
{
write_header $1
run_solver ev0p$POP $2
run_solver ev5p$POP $2
run_solver ev10p$POP $2
run_solver ev15p$POP $2
run_solver ev20p$POP $2
run_solver ev25p$POP $2
run_solver ev30p$POP $2
run_solver ev35p$POP $2
run_solver ev40p$POP $2
run_solver ev45p$POP $2
run_solver ev50p$POP $2
run_solver ev55p$POP $2
run_solver ev60p$POP $2
run_solver ev65p$POP $2
run_solver ev70p$POP $2
run_solver ev75p$POP $2
run_solver ev80p$POP $2
run_solver ev85p$POP $2
run_solver ev90p$POP $2
}
prepare_new_run
run_all_graphs "cbp(shedule=seq_fixed) " seq_fixed

View File

@ -0,0 +1,35 @@
#!/bin/bash
source sm.sh
source ../benchs.sh
SOLVER="fove"
function run_all_graphs
{
write_header $1
run_solver ev0p$POP $2
run_solver ev5p$POP $2
run_solver ev10p$POP $2
run_solver ev15p$POP $2
run_solver ev20p$POP $2
run_solver ev25p$POP $2
run_solver ev30p$POP $2
run_solver ev35p$POP $2
run_solver ev40p$POP $2
run_solver ev45p$POP $2
run_solver ev50p$POP $2
run_solver ev55p$POP $2
run_solver ev60p$POP $2
run_solver ev65p$POP $2
run_solver ev70p$POP $2
run_solver ev75p$POP $2
run_solver ev80p$POP $2
run_solver ev85p$POP $2
run_solver ev90p$POP $2
}
prepare_new_run
run_all_graphs "fove "

View File

@ -0,0 +1,49 @@
#!/home/tgomes/bin/yap -L --
:- use_module(library(lists)).
:- use_module(library(random)).
:- initialization(main).
main :-
unix(argv(Args)),
nth(1, Args, EV), % percentage of evidence
nth(2, Args, NP), % number of individuals
atomic_concat(['ev', EV, 'p', NP, '.yap'], FileName),
open(FileName, 'write', S),
atom_number(EV, EV2),
atom_number(NP, NP2),
EV3 is EV2 / 100.0,
generate_people(S, NP2, 4),
write(S, '\n'),
write(S, 'query(X) :- '),
generate_evidence(S, NP2, EV3, 4),
write(S, 'friends(p1,p2,X).\n'),
close(S).
generate_people(S, N, Counting) :-
Counting > N, !.
generate_people(S, N, Counting) :-
format(S, 'people(p~w).~n', [Counting]),
Counting1 is Counting + 1,
generate_people(S, N, Counting1).
generate_evidence(S, N, Ev, Counting) :-
Counting > N, !.
generate_evidence(S, N, Ev, Counting) :-
random(X),
(
X < Ev
->
random(Y),
(Y > 0.5 -> Val = t ; Val = f),
format(S, 'smokes(p~w,~w),', [Counting,Val])
;
true
),
Counting1 is Counting + 1,
generate_evidence(S, N, Ev, Counting1).

View File

@ -0,0 +1,37 @@
#!/bin/bash
source sm.sh
source ../benchs.sh
SOLVER="hve"
function run_all_graphs
{
write_header $1
run_solver ev0p$POP $2
run_solver ev5p$POP $2
run_solver ev10p$POP $2
run_solver ev15p$POP $2
run_solver ev20p$POP $2
run_solver ev25p$POP $2
run_solver ev30p$POP $2
run_solver ev35p$POP $2
run_solver ev40p$POP $2
run_solver ev45p$POP $2
run_solver ev50p$POP $2
run_solver ev55p$POP $2
run_solver ev60p$POP $2
run_solver ev65p$POP $2
run_solver ev70p$POP $2
run_solver ev75p$POP $2
run_solver ev80p$POP $2
run_solver ev85p$POP $2
run_solver ev90p$POP $2
}
prepare_new_run
run_all_graphs "hve(elim_heuristic=min_neighbors) " min_neighbors
#run_all_graphs "hve(elim_heuristic=min_weight) " min_weight
#run_all_graphs "hve(elim_heuristic=min_fill) " min_fill
#run_all_graphs "hve(elim_heuristic=weighted_min_fill) " weighted_min_fill

View File

@ -0,0 +1,34 @@
#!/bin/bash
source sm.sh
source ../benchs.sh
SOLVER="lbp"
function run_all_graphs
{
write_header $1
run_solver ev0p$POP $2
run_solver ev5p$POP $2
run_solver ev10p$POP $2
run_solver ev15p$POP $2
run_solver ev20p$POP $2
run_solver ev25p$POP $2
run_solver ev30p$POP $2
run_solver ev35p$POP $2
run_solver ev40p$POP $2
run_solver ev45p$POP $2
run_solver ev50p$POP $2
run_solver ev55p$POP $2
run_solver ev60p$POP $2
run_solver ev65p$POP $2
run_solver ev70p$POP $2
run_solver ev75p$POP $2
run_solver ev80p$POP $2
run_solver ev85p$POP $2
run_solver ev90p$POP $2
}
prepare_new_run
run_all_graphs "lbp(shedule=seq_fixed) " seq_fixed

View File

@ -0,0 +1,8 @@
#!/bin/bash
NETWORK="'../../examples/social_domain2'"
SHORTNAME="sm"
QUERY="query(X)"
POP=500

View File

@ -15,8 +15,8 @@ function run_all_graphs
run_solver p20000attrs$N_ATTRS $2
run_solver p25000attrs$N_ATTRS $2
run_solver p30000attrs$N_ATTRS $2
run_solver p35000attrs$N_ATTRS $2
return
run_solver p35000attrs$N_ATTRS $2
run_solver p40000attrs$N_ATTRS $2
run_solver p45000attrs$N_ATTRS $2
run_solver p50000attrs$N_ATTRS $2

View File

@ -0,0 +1,36 @@
#!/bin/bash
source wa.sh
source ../benchs.sh
SOLVER="lbp"
function run_all_graphs
{
write_header $1
run_solver p1000attrs$N_ATTRS $2
run_solver p5000attrs$N_ATTRS $2
run_solver p10000attrs$N_ATTRS $2
run_solver p15000attrs$N_ATTRS $2
run_solver p20000attrs$N_ATTRS $2
run_solver p25000attrs$N_ATTRS $2
run_solver p30000attrs$N_ATTRS $2
run_solver p35000attrs$N_ATTRS $2
run_solver p40000attrs$N_ATTRS $2
run_solver p45000attrs$N_ATTRS $2
run_solver p50000attrs$N_ATTRS $2
run_solver p55000attrs$N_ATTRS $2
run_solver p60000attrs$N_ATTRS $2
run_solver p65000attrs$N_ATTRS $2
run_solver p70000attrs$N_ATTRS $2
run_solver p75000attrs$N_ATTRS $2
run_solver p80000attrs$N_ATTRS $2
run_solver p85000attrs$N_ATTRS $2
run_solver p90000attrs$N_ATTRS $2
run_solver p95000attrs$N_ATTRS $2
run_solver p100000attrs$N_ATTRS $2
}
prepare_new_run
run_all_graphs "lbp(shedule=seq_fixed) " seq_fixed

View File

@ -15,7 +15,7 @@
cpp_run_ground_solver/3,
cpp_set_vars_information/2,
cpp_set_horus_flag/2,
cpp_free_parfactors/1,
cpp_free_lifted_network/1,
cpp_free_ground_network/1
]).

View File

@ -79,7 +79,6 @@ run_solver(ground(Network,Hash), QueryKeys, Solutions) :-
list_of_keys_to_ids(QueryKeys, Hash, QueryIds),
%writeln(queryKeys:QueryKeys), writeln(''),
%writeln(queryIds:QueryIds), writeln(''),
list_of_keys_to_ids(QueryKeys, Hash, QueryIds),
cpp_run_ground_solver(Network, [QueryIds], Solutions).

View File

@ -17,7 +17,7 @@
[cpp_create_lifted_network/3,
cpp_set_parfactors_params/2,
cpp_run_lifted_solver/3,
cpp_free_parfactors/1
cpp_free_lifted_network/1
]).
:- use_module(library('clpbn/display'),
@ -144,5 +144,5 @@ run_horus_lifted_solver(QueryVarsAtts, Solutions, fove(ParfactorList, DistIds))
finalize_horus_lifted_solver(fove(ParfactorList, _)) :-
cpp_free_parfactors(ParfactorList).
cpp_free_lifted_network(ParfactorList).

View File

@ -7,17 +7,17 @@
:- yap_flag(write_strings, off).
bayes burglary::[b1,b2] ; [0.001, 0.999] ; [].
bayes burglary::[t,f] ; [0.001, 0.999] ; [].
bayes earthquake::[e1,e2] ; [0.002, 0.998]; [].
bayes earthquake::[t,f] ; [0.002, 0.998]; [].
bayes alarm::[a1,a2], burglary, earthquake ;
bayes alarm::[t,f], burglary, earthquake ;
[0.95, 0.94, 0.29, 0.001, 0.05, 0.06, 0.71, 0.999] ;
[].
bayes john_calls::[j1,j2], alarm ; [0.9, 0.05, 0.1, 0.95] ; [].
bayes john_calls::[t,f], alarm ; [0.9, 0.05, 0.1, 0.95] ; [].
bayes mary_calls::[m1,m2], alarm ; [0.7, 0.01, 0.3, 0.99] ; [].
bayes mary_calls::[t,f], alarm ; [0.7, 0.01, 0.3, 0.99] ; [].
% ?- john_calls(J), mary_calls(m1).
% ?- john_calls(J), mary_calls(t).

View File

@ -16,13 +16,13 @@ BayesBall::getMinimalFactorGraph (const VarIds& queryIds)
Scheduling scheduling;
for (size_t i = 0; i < queryIds.size(); i++) {
assert (dag_.getNode (queryIds[i]));
DAGraphNode* n = dag_.getNode (queryIds[i]);
BBNode* n = dag_.getNode (queryIds[i]);
scheduling.push (ScheduleInfo (n, false, true));
}
while (!scheduling.empty()) {
ScheduleInfo& sch = scheduling.front();
DAGraphNode* n = sch.node;
BBNode* n = sch.node;
n->setAsVisited();
if (n->hasEvidence() == false && sch.visitedFromChild) {
if (n->isMarkedOnTop() == false) {
@ -59,7 +59,7 @@ BayesBall::constructGraph (FactorGraph* fg) const
{
const FacNodes& facNodes = fg_.facNodes();
for (size_t i = 0; i < facNodes.size(); i++) {
const DAGraphNode* n = dag_.getNode (
const BBNode* n = dag_.getNode (
facNodes[i]->factor().argument (0));
if (n->isMarkedOnTop()) {
fg->addFactor (facNodes[i]->factor());

View File

@ -7,7 +7,7 @@
#include <map>
#include "FactorGraph.h"
#include "BayesNet.h"
#include "BayesBallGraph.h"
#include "Horus.h"
using namespace std;
@ -15,12 +15,12 @@ using namespace std;
struct ScheduleInfo
{
ScheduleInfo (DAGraphNode* n, bool vfp, bool vfc) :
ScheduleInfo (BBNode* n, bool vfp, bool vfc) :
node(n), visitedFromParent(vfp), visitedFromChild(vfc) { }
DAGraphNode* node;
bool visitedFromParent;
bool visitedFromChild;
BBNode* node;
bool visitedFromParent;
bool visitedFromChild;
};
@ -30,40 +30,40 @@ typedef queue<ScheduleInfo, list<ScheduleInfo>> Scheduling;
class BayesBall
{
public:
BayesBall (FactorGraph& fg)
: fg_(fg) , dag_(fg.getStructure())
{
dag_.clear();
}
BayesBall (FactorGraph& fg)
: fg_(fg) , dag_(fg.getStructure())
{
dag_.clear();
}
FactorGraph* getMinimalFactorGraph (const VarIds&);
FactorGraph* getMinimalFactorGraph (const VarIds&);
static FactorGraph* getMinimalFactorGraph (FactorGraph& fg, VarIds vids)
{
BayesBall bb (fg);
return bb.getMinimalFactorGraph (vids);
}
static FactorGraph* getMinimalFactorGraph (FactorGraph& fg, VarIds vids)
{
BayesBall bb (fg);
return bb.getMinimalFactorGraph (vids);
}
private:
void constructGraph (FactorGraph* fg) const;
void scheduleParents (const DAGraphNode* n, Scheduling& sch) const;
void scheduleParents (const BBNode* n, Scheduling& sch) const;
void scheduleChilds (const DAGraphNode* n, Scheduling& sch) const;
void scheduleChilds (const BBNode* n, Scheduling& sch) const;
FactorGraph& fg_;
DAGraph& dag_;
BayesBallGraph& dag_;
};
inline void
BayesBall::scheduleParents (const DAGraphNode* n, Scheduling& sch) const
BayesBall::scheduleParents (const BBNode* n, Scheduling& sch) const
{
const vector<DAGraphNode*>& ps = n->parents();
for (vector<DAGraphNode*>::const_iterator it = ps.begin();
const vector<BBNode*>& ps = n->parents();
for (vector<BBNode*>::const_iterator it = ps.begin();
it != ps.end(); ++it) {
sch.push (ScheduleInfo (*it, false, true));
}
@ -72,10 +72,10 @@ BayesBall::scheduleParents (const DAGraphNode* n, Scheduling& sch) const
inline void
BayesBall::scheduleChilds (const DAGraphNode* n, Scheduling& sch) const
BayesBall::scheduleChilds (const BBNode* n, Scheduling& sch) const
{
const vector<DAGraphNode*>& cs = n->childs();
for (vector<DAGraphNode*>::const_iterator it = cs.begin();
const vector<BBNode*>& cs = n->childs();
for (vector<BBNode*>::const_iterator it = cs.begin();
it != cs.end(); ++it) {
sch.push (ScheduleInfo (*it, true, false));
}

View File

@ -5,12 +5,12 @@
#include <fstream>
#include <sstream>
#include "BayesNet.h"
#include "BayesBallGraph.h"
#include "Util.h"
void
DAGraph::addNode (DAGraphNode* n)
BayesBallGraph::addNode (BBNode* n)
{
assert (Util::contains (varMap_, n->varId()) == false);
nodes_.push_back (n);
@ -20,10 +20,10 @@ DAGraph::addNode (DAGraphNode* n)
void
DAGraph::addEdge (VarId vid1, VarId vid2)
BayesBallGraph::addEdge (VarId vid1, VarId vid2)
{
unordered_map<VarId, DAGraphNode*>::iterator it1;
unordered_map<VarId, DAGraphNode*>::iterator it2;
unordered_map<VarId, BBNode*>::iterator it1;
unordered_map<VarId, BBNode*>::iterator it2;
it1 = varMap_.find (vid1);
it2 = varMap_.find (vid2);
assert (it1 != varMap_.end());
@ -34,20 +34,20 @@ DAGraph::addEdge (VarId vid1, VarId vid2)
const DAGraphNode*
DAGraph::getNode (VarId vid) const
const BBNode*
BayesBallGraph::getNode (VarId vid) const
{
unordered_map<VarId, DAGraphNode*>::const_iterator it;
unordered_map<VarId, BBNode*>::const_iterator it;
it = varMap_.find (vid);
return it != varMap_.end() ? it->second : 0;
}
DAGraphNode*
DAGraph::getNode (VarId vid)
BBNode*
BayesBallGraph::getNode (VarId vid)
{
unordered_map<VarId, DAGraphNode*>::const_iterator it;
unordered_map<VarId, BBNode*>::const_iterator it;
it = varMap_.find (vid);
return it != varMap_.end() ? it->second : 0;
}
@ -55,7 +55,7 @@ DAGraph::getNode (VarId vid)
void
DAGraph::setIndexes (void)
BayesBallGraph::setIndexes (void)
{
for (size_t i = 0; i < nodes_.size(); i++) {
nodes_[i]->setIndex (i);
@ -65,7 +65,7 @@ DAGraph::setIndexes (void)
void
DAGraph::clear (void)
BayesBallGraph::clear (void)
{
for (size_t i = 0; i < nodes_.size(); i++) {
nodes_[i]->clear();
@ -75,12 +75,12 @@ DAGraph::clear (void)
void
DAGraph::exportToGraphViz (const char* fileName)
BayesBallGraph::exportToGraphViz (const char* fileName)
{
ofstream out (fileName);
if (!out.is_open()) {
cerr << "error: cannot open file to write at " ;
cerr << "DAGraph::exportToDotFile()" << endl;
cerr << "BayesBallGraph::exportToDotFile()" << endl;
abort();
}
out << "digraph {" << endl;
@ -95,7 +95,7 @@ DAGraph::exportToGraphViz (const char* fileName)
out << "]" << endl;
}
for (size_t i = 0; i < nodes_.size(); i++) {
const vector<DAGraphNode*>& childs = nodes_[i]->childs();
const vector<BBNode*>& childs = nodes_[i]->childs();
for (size_t j = 0; j < childs.size(); j++) {
out << nodes_[i]->varId() << " -> " << childs[j]->varId();
out << " [style=bold]" << endl ;

View File

@ -1,5 +1,5 @@
#ifndef HORUS_BAYESNET_H
#define HORUS_BAYESNET_H
#ifndef HORUS_BAYESBALLGRAPH_H
#define HORUS_BAYESBALLGRAPH_H
#include <vector>
#include <queue>
@ -9,29 +9,25 @@
#include "Var.h"
#include "Horus.h"
using namespace std;
class Var;
class DAGraphNode : public Var
class BBNode : public Var
{
public:
DAGraphNode (Var* v) : Var (v) , visited_(false),
BBNode (Var* v) : Var (v) , visited_(false),
markedOnTop_(false), markedOnBottom_(false) { }
const vector<DAGraphNode*>& childs (void) const { return childs_; }
const vector<BBNode*>& childs (void) const { return childs_; }
vector<DAGraphNode*>& childs (void) { return childs_; }
vector<BBNode*>& childs (void) { return childs_; }
const vector<DAGraphNode*>& parents (void) const { return parents_; }
const vector<BBNode*>& parents (void) const { return parents_; }
vector<DAGraphNode*>& parents (void) { return parents_; }
vector<BBNode*>& parents (void) { return parents_; }
void addParent (DAGraphNode* p) { parents_.push_back (p); }
void addParent (BBNode* p) { parents_.push_back (p); }
void addChild (DAGraphNode* c) { childs_.push_back (c); }
void addChild (BBNode* c) { childs_.push_back (c); }
bool isVisited (void) const { return visited_; }
@ -52,23 +48,23 @@ class DAGraphNode : public Var
bool markedOnTop_;
bool markedOnBottom_;
vector<DAGraphNode*> childs_;
vector<DAGraphNode*> parents_;
vector<BBNode*> childs_;
vector<BBNode*> parents_;
};
class DAGraph
class BayesBallGraph
{
public:
DAGraph (void) { }
BayesBallGraph (void) { }
void addNode (DAGraphNode* n);
void addNode (BBNode* n);
void addEdge (VarId vid1, VarId vid2);
const DAGraphNode* getNode (VarId vid) const;
const BBNode* getNode (VarId vid) const;
DAGraphNode* getNode (VarId vid);
BBNode* getNode (VarId vid);
bool empty (void) const { return nodes_.empty(); }
@ -79,10 +75,10 @@ class DAGraph
void exportToGraphViz (const char*);
private:
vector<DAGraphNode*> nodes_;
vector<BBNode*> nodes_;
unordered_map<VarId, DAGraphNode*> varMap_;
unordered_map<VarId, BBNode*> varMap_;
};
#endif // HORUS_BAYESNET_H
#endif // HORUS_BAYESBALLGRAPH_H

View File

@ -5,21 +5,21 @@
#include <iostream>
#include "BpSolver.h"
#include "BeliefProp.h"
#include "FactorGraph.h"
#include "Factor.h"
#include "Indexer.h"
#include "Horus.h"
BpSolver::BpSolver (const FactorGraph& fg) : Solver (fg)
BeliefProp::BeliefProp (const FactorGraph& fg) : Solver (fg)
{
runned_ = false;
}
BpSolver::~BpSolver (void)
BeliefProp::~BeliefProp (void)
{
for (size_t i = 0; i < varsI_.size(); i++) {
delete varsI_[i];
@ -35,7 +35,7 @@ BpSolver::~BpSolver (void)
Params
BpSolver::solveQuery (VarIds queryVids)
BeliefProp::solveQuery (VarIds queryVids)
{
assert (queryVids.empty() == false);
return queryVids.size() == 1
@ -46,7 +46,7 @@ BpSolver::solveQuery (VarIds queryVids)
void
BpSolver::printSolverFlags (void) const
BeliefProp::printSolverFlags (void) const
{
stringstream ss;
ss << "belief propagation [" ;
@ -68,7 +68,7 @@ BpSolver::printSolverFlags (void) const
Params
BpSolver::getPosterioriOf (VarId vid)
BeliefProp::getPosterioriOf (VarId vid)
{
if (runned_ == false) {
runSolver();
@ -101,7 +101,7 @@ BpSolver::getPosterioriOf (VarId vid)
Params
BpSolver::getJointDistributionOf (const VarIds& jointVarIds)
BeliefProp::getJointDistributionOf (const VarIds& jointVarIds)
{
if (runned_ == false) {
runSolver();
@ -117,30 +117,43 @@ BpSolver::getJointDistributionOf (const VarIds& jointVarIds)
}
if (idx == facNodes.size()) {
return getJointByConditioning (jointVarIds);
} else {
Factor res (facNodes[idx]->factor());
const BpLinks& links = ninf(facNodes[idx])->getLinks();
for (size_t i = 0; i < links.size(); i++) {
Factor msg ({links[i]->varNode()->varId()},
{links[i]->varNode()->range()},
getVarToFactorMsg (links[i]));
res.multiply (msg);
}
res.sumOutAllExcept (jointVarIds);
res.reorderArguments (jointVarIds);
res.normalize();
Params jointDist = res.params();
if (Globals::logDomain) {
Util::exp (jointDist);
}
return jointDist;
}
return getFactorJoint (idx, jointVarIds);
}
Params
BeliefProp::getFactorJoint (
size_t fnIdx,
const VarIds& jointVarIds)
{
if (runned_ == false) {
runSolver();
}
FacNode* fn = fg.facNodes()[fnIdx];
Factor res (fn->factor());
const BpLinks& links = ninf(fn)->getLinks();
for (size_t i = 0; i < links.size(); i++) {
Factor msg ({links[i]->varNode()->varId()},
{links[i]->varNode()->range()},
getVarToFactorMsg (links[i]));
res.multiply (msg);
}
res.sumOutAllExcept (jointVarIds);
res.reorderArguments (jointVarIds);
res.normalize();
Params jointDist = res.params();
if (Globals::logDomain) {
Util::exp (jointDist);
}
return jointDist;
}
void
BpSolver::runSolver (void)
BeliefProp::runSolver (void)
{
initializeSolver();
nIters_ = 0;
@ -173,7 +186,7 @@ BpSolver::runSolver (void)
}
if (Globals::verbosity > 0) {
if (nIters_ < BpOptions::maxIter) {
cout << "Sum-Product converged in " ;
cout << "Belief propagation converged in " ;
cout << nIters_ << " iterations" << endl;
} else {
cout << "The maximum number of iterations was hit, terminating..." ;
@ -187,7 +200,7 @@ BpSolver::runSolver (void)
void
BpSolver::createLinks (void)
BeliefProp::createLinks (void)
{
const FacNodes& facNodes = fg.facNodes();
for (size_t i = 0; i < facNodes.size(); i++) {
@ -201,7 +214,7 @@ BpSolver::createLinks (void)
void
BpSolver::maxResidualSchedule (void)
BeliefProp::maxResidualSchedule (void)
{
if (nIters_ == 1) {
for (size_t i = 0; i < links_.size(); i++) {
@ -256,7 +269,7 @@ BpSolver::maxResidualSchedule (void)
void
BpSolver::calcFactorToVarMsg (BpLink* link)
BeliefProp::calcFactorToVarMsg (BpLink* link)
{
FacNode* src = link->facNode();
const VarNode* dst = link->varNode();
@ -320,7 +333,7 @@ BpSolver::calcFactorToVarMsg (BpLink* link)
Params
BpSolver::getVarToFactorMsg (const BpLink* link) const
BeliefProp::getVarToFactorMsg (const BpLink* link) const
{
const VarNode* src = link->varNode();
Params msg;
@ -361,61 +374,15 @@ BpSolver::getVarToFactorMsg (const BpLink* link) const
Params
BpSolver::getJointByConditioning (const VarIds& jointVarIds) const
BeliefProp::getJointByConditioning (const VarIds& jointVarIds) const
{
VarNodes jointVars;
for (size_t i = 0; i < jointVarIds.size(); i++) {
assert (fg.getVarNode (jointVarIds[i]));
jointVars.push_back (fg.getVarNode (jointVarIds[i]));
}
FactorGraph* tempFg = new FactorGraph (fg);
BpSolver solver (*tempFg);
solver.runSolver();
Params prevBeliefs = solver.getPosterioriOf (jointVarIds[0]);
VarIds observedVids = {jointVars[0]->varId()};
for (size_t i = 1; i < jointVarIds.size(); i++) {
assert (jointVars[i]->hasEvidence() == false);
Params newBeliefs;
Vars observedVars;
Ranges observedRanges;
for (size_t j = 0; j < observedVids.size(); j++) {
observedVars.push_back (tempFg->getVarNode (observedVids[j]));
observedRanges.push_back (observedVars.back()->range());
}
Indexer indexer (observedRanges, false);
while (indexer.valid()) {
for (size_t j = 0; j < observedVars.size(); j++) {
observedVars[j]->setEvidence (indexer[j]);
}
BpSolver solver (*tempFg);
solver.runSolver();
Params beliefs = solver.getPosterioriOf (jointVarIds[i]);
for (size_t k = 0; k < beliefs.size(); k++) {
newBeliefs.push_back (beliefs[k]);
}
++ indexer;
}
int count = -1;
for (size_t j = 0; j < newBeliefs.size(); j++) {
if (j % jointVars[i]->range() == 0) {
count ++;
}
newBeliefs[j] *= prevBeliefs[count];
}
prevBeliefs = newBeliefs;
observedVids.push_back (jointVars[i]->varId());
}
return prevBeliefs;
return Solver::getJointByConditioning (GroundSolver::BP, fg, jointVarIds);
}
void
BpSolver::initializeSolver (void)
BeliefProp::initializeSolver (void)
{
const VarNodes& varNodes = fg.varNodes();
varsI_.reserve (varNodes.size());
@ -439,7 +406,7 @@ BpSolver::initializeSolver (void)
bool
BpSolver::converged (void)
BeliefProp::converged (void)
{
if (links_.size() == 0) {
return true;
@ -487,7 +454,7 @@ BpSolver::converged (void)
void
BpSolver::printLinkInformation (void) const
BeliefProp::printLinkInformation (void) const
{
for (size_t i = 0; i < links_.size(); i++) {
BpLink* l = links_[i];

View File

@ -1,5 +1,5 @@
#ifndef HORUS_BPSOLVER_H
#define HORUS_BPSOLVER_H
#ifndef HORUS_BELIEFPROP_H
#define HORUS_BELIEFPROP_H
#include <set>
#include <vector>
@ -83,12 +83,12 @@ class SPNodeInfo
};
class BpSolver : public Solver
class BeliefProp : public Solver
{
public:
BpSolver (const FactorGraph&);
BeliefProp (const FactorGraph&);
virtual ~BpSolver (void);
virtual ~BeliefProp (void);
Params solveQuery (VarIds);
@ -111,6 +111,10 @@ class BpSolver : public Solver
virtual Params getJointByConditioning (const VarIds&) const;
public:
Params getFactorJoint (size_t fnIdx, const VarIds&);
protected:
SPNodeInfo* ninf (const VarNode* var) const
{
return varsI_[var->getIndex()];
@ -180,5 +184,5 @@ class BpSolver : public Solver
virtual void printLinkInformation (void) const;
};
#endif // HORUS_BPSOLVER_H
#endif // HORUS_BELIEFPROP_H

View File

@ -1,23 +1,23 @@
#include "CbpSolver.h"
#include "WeightedBpSolver.h"
#include "CountingBp.h"
#include "WeightedBp.h"
bool CbpSolver::checkForIdenticalFactors = true;
bool CountingBp::checkForIdenticalFactors = true;
CbpSolver::CbpSolver (const FactorGraph& fg)
CountingBp::CountingBp (const FactorGraph& fg)
: Solver (fg), freeColor_(0)
{
findIdenticalFactors();
setInitialColors();
createGroups();
compressedFg_ = getCompressedFactorGraph();
solver_ = new WeightedBpSolver (*compressedFg_, getWeights());
solver_ = new WeightedBp (*compressedFg_, getWeights());
}
CbpSolver::~CbpSolver (void)
CountingBp::~CountingBp (void)
{
delete solver_;
delete compressedFg_;
@ -32,7 +32,7 @@ CbpSolver::~CbpSolver (void)
void
CbpSolver::printSolverFlags (void) const
CountingBp::printSolverFlags (void) const
{
stringstream ss;
ss << "counting bp [" ;
@ -48,7 +48,7 @@ CbpSolver::printSolverFlags (void) const
ss << ",accuracy=" << BpOptions::accuracy;
ss << ",log_domain=" << Util::toString (Globals::logDomain);
ss << ",chkif=" <<
Util::toString (CbpSolver::checkForIdenticalFactors);
Util::toString (CountingBp::checkForIdenticalFactors);
ss << "]" ;
cout << ss.str() << endl;
}
@ -56,7 +56,7 @@ CbpSolver::printSolverFlags (void) const
Params
CbpSolver::solveQuery (VarIds queryVids)
CountingBp::solveQuery (VarIds queryVids)
{
assert (queryVids.empty() == false);
Params res;
@ -74,16 +74,15 @@ CbpSolver::solveQuery (VarIds queryVids)
cout << endl;
}
if (idx == facNodes.size()) {
cerr << "error: only joint distributions on variables of some " ;
cerr << "clique are supported with the current solver" ;
cerr << endl;
exit (1);
res = Solver::getJointByConditioning (
GroundSolver::CBP, fg, queryVids);
} else {
VarIds reprArgs;
for (size_t i = 0; i < queryVids.size(); i++) {
reprArgs.push_back (getRepresentative (queryVids[i]));
}
res = solver_->getFactorJoint (idx, reprArgs);
}
VarIds representatives;
for (size_t i = 0; i < queryVids.size(); i++) {
representatives.push_back (getRepresentative (queryVids[i]));
}
res = solver_->getJointDistributionOf (representatives);
}
return res;
}
@ -91,7 +90,7 @@ CbpSolver::solveQuery (VarIds queryVids)
void
CbpSolver::findIdenticalFactors()
CountingBp::findIdenticalFactors()
{
const FacNodes& facNodes = fg.facNodes();
if (checkForIdenticalFactors == false ||
@ -126,7 +125,7 @@ CbpSolver::findIdenticalFactors()
void
CbpSolver::setInitialColors (void)
CountingBp::setInitialColors (void)
{
varColors_.resize (fg.nrVarNodes());
facColors_.resize (fg.nrFacNodes());
@ -165,7 +164,7 @@ CbpSolver::setInitialColors (void)
void
CbpSolver::createGroups (void)
CountingBp::createGroups (void)
{
VarSignMap varGroups;
FacSignMap facGroups;
@ -227,7 +226,7 @@ CbpSolver::createGroups (void)
void
CbpSolver::createClusters (
CountingBp::createClusters (
const VarSignMap& varGroups,
const FacSignMap& facGroups)
{
@ -260,7 +259,7 @@ CbpSolver::createClusters (
VarSignature
CbpSolver::getSignature (const VarNode* varNode)
CountingBp::getSignature (const VarNode* varNode)
{
const FacNodes& neighs = varNode->neighbors();
VarSignature sign;
@ -278,7 +277,7 @@ CbpSolver::getSignature (const VarNode* varNode)
FacSignature
CbpSolver::getSignature (const FacNode* facNode)
CountingBp::getSignature (const FacNode* facNode)
{
const VarNodes& neighs = facNode->neighbors();
FacSignature sign;
@ -292,8 +291,31 @@ CbpSolver::getSignature (const FacNode* facNode)
VarId
CountingBp::getRepresentative (VarId vid)
{
assert (Util::contains (vid2VarCluster_, vid));
VarCluster* vc = vid2VarCluster_.find (vid)->second;
return vc->representative()->varId();
}
FacNode*
CountingBp::getRepresentative (FacNode* fn)
{
for (size_t i = 0; i < facClusters_.size(); i++) {
if (Util::contains (facClusters_[i]->members(), fn)) {
return facClusters_[i]->representative();
}
}
return 0;
}
FactorGraph*
CbpSolver::getCompressedFactorGraph (void)
CountingBp::getCompressedFactorGraph (void)
{
FactorGraph* fg = new FactorGraph();
for (size_t i = 0; i < varClusters_.size(); i++) {
@ -322,7 +344,7 @@ CbpSolver::getCompressedFactorGraph (void)
vector<vector<unsigned>>
CbpSolver::getWeights (void) const
CountingBp::getWeights (void) const
{
vector<vector<unsigned>> weights;
weights.reserve (facClusters_.size());
@ -341,7 +363,7 @@ CbpSolver::getWeights (void) const
unsigned
CbpSolver::getWeight (
CountingBp::getWeight (
const FacCluster* fc,
const VarCluster* vc,
size_t index) const
@ -364,7 +386,7 @@ CbpSolver::getWeight (
void
CbpSolver::printGroups (
CountingBp::printGroups (
const VarSignMap& varGroups,
const FacSignMap& facGroups) const
{

View File

@ -1,5 +1,5 @@
#ifndef HORUS_CBPSOLVER_H
#define HORUS_CBPSOLVER_H
#ifndef HORUS_COUNTINGBP_H
#define HORUS_COUNTINGBP_H
#include <unordered_map>
@ -12,7 +12,7 @@ class VarCluster;
class FacCluster;
class VarSignHash;
class FacSignHash;
class WeightedBpSolver;
class WeightedBp;
typedef long Color;
typedef vector<Color> Colors;
@ -100,12 +100,12 @@ class FacCluster
};
class CbpSolver : public Solver
class CountingBp : public Solver
{
public:
CbpSolver (const FactorGraph& fg);
CountingBp (const FactorGraph& fg);
~CbpSolver (void);
~CountingBp (void);
void printSolverFlags (void) const;
@ -154,12 +154,9 @@ class CbpSolver : public Solver
void printGroups (const VarSignMap&, const FacSignMap&) const;
VarId getRepresentative (VarId vid)
{
assert (Util::contains (vid2VarCluster_, vid));
VarCluster* vc = vid2VarCluster_.find (vid)->second;
return vc->representative()->varId();
}
VarId getRepresentative (VarId vid);
FacNode* getRepresentative (FacNode*);
FactorGraph* getCompressedFactorGraph (void);
@ -176,8 +173,8 @@ class CbpSolver : public Solver
FacClusters facClusters_;
VarId2VarCluster vid2VarCluster_;
const FactorGraph* compressedFg_;
WeightedBpSolver* solver_;
WeightedBp* solver_;
};
#endif // HORUS_CBPSOLVER_H
#endif // HORUS_COUNTINGBP_H

View File

@ -8,7 +8,6 @@
#include "FactorGraph.h"
#include "Factor.h"
#include "BayesNet.h"
#include "BayesBall.h"
#include "Util.h"
@ -236,13 +235,13 @@ FactorGraph::isTree (void) const
DAGraph&
BayesBallGraph&
FactorGraph::getStructure (void)
{
assert (bayesFactors_);
if (structure_.empty()) {
for (size_t i = 0; i < varNodes_.size(); i++) {
structure_.addNode (new DAGraphNode (varNodes_[i]));
structure_.addNode (new BBNode (varNodes_[i]));
}
for (size_t i = 0; i < facNodes_.size(); i++) {
const VarIds& vids = facNodes_[i]->factor().arguments();

View File

@ -4,7 +4,7 @@
#include <vector>
#include "Factor.h"
#include "BayesNet.h"
#include "BayesBallGraph.h"
#include "Horus.h"
using namespace std;
@ -103,7 +103,7 @@ class FactorGraph
bool isTree (void) const;
DAGraph& getStructure (void);
BayesBallGraph& getStructure (void);
void print (void) const;
@ -129,7 +129,7 @@ class FactorGraph
VarNodes varNodes_;
FacNodes facNodes_;
DAGraph structure_;
BayesBallGraph structure_;
bool bayesFactors_;
typedef unordered_map<unsigned, VarNode*> VarMap;

View File

@ -28,14 +28,14 @@ typedef vector<unsigned> Ranges;
typedef unsigned long long ullong;
enum LiftedSolvers
enum LiftedSolver
{
FOVE, // first order variable elimination
LBP, // lifted belief propagation
};
enum GroundSolvers
enum GroundSolver
{
VE, // variable elimination
BP, // belief propagation
@ -50,8 +50,8 @@ extern bool logDomain;
// level of debug information
extern unsigned verbosity;
extern LiftedSolvers liftedSolver;
extern GroundSolvers groundSolver;
extern LiftedSolver liftedSolver;
extern GroundSolver groundSolver;
};

View File

@ -4,9 +4,9 @@
#include <sstream>
#include "FactorGraph.h"
#include "VarElimSolver.h"
#include "BpSolver.h"
#include "CbpSolver.h"
#include "VarElim.h"
#include "BeliefProp.h"
#include "CountingBp.h"
using namespace std;
@ -162,14 +162,14 @@ runSolver (const FactorGraph& fg, const VarIds& queryIds)
{
Solver* solver = 0;
switch (Globals::groundSolver) {
case GroundSolvers::VE:
solver = new VarElimSolver (fg);
case GroundSolver::VE:
solver = new VarElim (fg);
break;
case GroundSolvers::BP:
solver = new BpSolver (fg);
case GroundSolver::BP:
solver = new BeliefProp (fg);
break;
case GroundSolvers::CBP:
solver = new CbpSolver (fg);
case GroundSolver::CBP:
solver = new CountingBp (fg);
break;
default:
assert (false);

View File

@ -9,21 +9,19 @@
#include "ParfactorList.h"
#include "FactorGraph.h"
#include "FoveSolver.h"
#include "VarElimSolver.h"
#include "LiftedBpSolver.h"
#include "BpSolver.h"
#include "CbpSolver.h"
#include "LiftedVe.h"
#include "VarElim.h"
#include "LiftedBp.h"
#include "CountingBp.h"
#include "BeliefProp.h"
#include "ElimGraph.h"
#include "BayesBall.h"
using namespace std;
typedef std::pair<ParfactorList*, ObservedFormulas*> LiftedNetwork;
Params readParameters (YAP_Term);
vector<unsigned> readUnsignedList (YAP_Term);
@ -32,14 +30,6 @@ void readLiftedEvidence (YAP_Term, ObservedFormulas&);
Parfactor* readParfactor (YAP_Term);
void runVeSolver (FactorGraph* fg, const vector<VarIds>& tasks,
vector<Params>& results);
void runBpSolver (FactorGraph* fg, const vector<VarIds>& tasks,
vector<Params>& results);
vector<unsigned>
readUnsignedList (YAP_Term list)
@ -54,7 +44,8 @@ readUnsignedList (YAP_Term list)
int createLiftedNetwork (void)
int
createLiftedNetwork (void)
{
Parfactors parfactors;
YAP_Term parfactorList = YAP_ARG1;
@ -91,7 +82,8 @@ int createLiftedNetwork (void)
Parfactor* readParfactor (YAP_Term pfTerm)
Parfactor*
readParfactor (YAP_Term pfTerm)
{
// read dist id
unsigned distId = YAP_IntOfTerm (YAP_ArgOfTerm (1, pfTerm));
@ -171,7 +163,8 @@ Parfactor* readParfactor (YAP_Term pfTerm)
void readLiftedEvidence (
void
readLiftedEvidence (
YAP_Term observedList,
ObservedFormulas& obsFormulas)
{
@ -237,7 +230,6 @@ createGroundNetwork (void)
fg->addFactor (Factor (varIds, ranges, params, distId));
factorList = YAP_TailOfTerm (factorList);
}
unsigned nrObservedVars = 0;
YAP_Term evidenceList = YAP_ARG3;
while (evidenceList != YAP_TermNil()) {
@ -285,7 +277,7 @@ runLiftedSolver (void)
YAP_Term taskList = YAP_ARG2;
vector<Params> results;
ParfactorList pfListCopy (*network->first);
FoveSolver::absorveEvidence (pfListCopy, *network->second);
LiftedVe::absorveEvidence (pfListCopy, *network->second);
while (taskList != YAP_TermNil()) {
Grounds queryVars;
YAP_Term jointList = YAP_HeadOfTerm (taskList);
@ -311,15 +303,15 @@ runLiftedSolver (void)
}
jointList = YAP_TailOfTerm (jointList);
}
if (Globals::liftedSolver == LiftedSolvers::FOVE) {
FoveSolver solver (pfListCopy);
if (Globals::liftedSolver == LiftedSolver::FOVE) {
LiftedVe solver (pfListCopy);
if (Globals::verbosity > 0 && taskList == YAP_ARG2) {
solver.printSolverFlags();
cout << endl;
}
results.push_back (solver.solveQuery (queryVars));
} else if (Globals::liftedSolver == LiftedSolvers::LBP) {
LiftedBpSolver solver (pfListCopy);
} else if (Globals::liftedSolver == LiftedSolver::LBP) {
LiftedBp solver (pfListCopy);
if (Globals::verbosity > 0 && taskList == YAP_ARG2) {
solver.printSolverFlags();
cout << endl;
@ -361,11 +353,42 @@ runGroundSolver (void)
taskList = YAP_TailOfTerm (taskList);
}
vector<Params> results;
if (Globals::groundSolver == GroundSolvers::VE) {
runVeSolver (fg, tasks, results);
std::set<VarId> vids;
for (size_t i = 0; i < tasks.size(); i++) {
Util::addToSet (vids, tasks[i]);
}
Solver* solver = 0;
FactorGraph* mfg = fg;
if (fg->bayesianFactors()) {
mfg = BayesBall::getMinimalFactorGraph (
*fg, VarIds (vids.begin(), vids.end()));
}
if (Globals::groundSolver == GroundSolver::VE) {
solver = new VarElim (*mfg);
} else if (Globals::groundSolver == GroundSolver::BP) {
solver = new BeliefProp (*mfg);
} else if (Globals::groundSolver == GroundSolver::CBP) {
CountingBp::checkForIdenticalFactors = false;
solver = new CountingBp (*mfg);
} else {
runBpSolver (fg, tasks, results);
assert (false);
}
if (Globals::verbosity > 0) {
solver->printSolverFlags();
cout << endl;
}
vector<Params> results;
results.reserve (tasks.size());
for (size_t i = 0; i < tasks.size(); i++) {
results.push_back (solver->solveQuery (tasks[i]));
}
delete solver;
if (fg->bayesianFactors()) {
delete mfg;
}
YAP_Term list = YAP_TermNil();
@ -386,72 +409,6 @@ runGroundSolver (void)
void runVeSolver (
FactorGraph* fg,
const vector<VarIds>& tasks,
vector<Params>& results)
{
results.reserve (tasks.size());
for (size_t i = 0; i < tasks.size(); i++) {
FactorGraph* mfg = fg;
if (fg->bayesianFactors()) {
// mfg = BayesBall::getMinimalFactorGraph (*fg, tasks[i]);
}
// VarElimSolver solver (*mfg);
VarElimSolver solver (*fg); //FIXME
if (Globals::verbosity > 0 && i == 0) {
solver.printSolverFlags();
cout << endl;
}
results.push_back (solver.solveQuery (tasks[i]));
if (fg->bayesianFactors()) {
// delete mfg;
}
}
}
void runBpSolver (
FactorGraph* fg,
const vector<VarIds>& tasks,
vector<Params>& results)
{
std::set<VarId> vids;
for (size_t i = 0; i < tasks.size(); i++) {
Util::addToSet (vids, tasks[i]);
}
Solver* solver = 0;
FactorGraph* mfg = fg;
if (fg->bayesianFactors()) {
//mfg = BayesBall::getMinimalFactorGraph (
// *fg, VarIds (vids.begin(),vids.end()));
}
if (Globals::groundSolver == GroundSolvers::BP) {
solver = new BpSolver (*fg); // FIXME
} else if (Globals::groundSolver == GroundSolvers::CBP) {
CbpSolver::checkForIdenticalFactors = false;
solver = new CbpSolver (*fg); // FIXME
} else {
cerr << "error: unknow solver" << endl;
abort();
}
if (Globals::verbosity > 0) {
solver->printSolverFlags();
cout << endl;
}
results.reserve (tasks.size());
for (size_t i = 0; i < tasks.size(); i++) {
results.push_back (solver->solveQuery (tasks[i]));
}
if (fg->bayesianFactors()) {
//delete mfg;
}
delete solver;
}
int
setParfactorsParams (void)
{
@ -567,7 +524,7 @@ freeGroundNetwork (void)
int
freeParfactors (void)
freeLiftedNetwork (void)
{
LiftedNetwork* network = (LiftedNetwork*) YAP_IntOfTerm (YAP_ARG1);
delete network->first;
@ -589,7 +546,7 @@ init_predicates (void)
YAP_UserCPredicate ("cpp_cpp_set_factors_params", setFactorsParams, 2);
YAP_UserCPredicate ("cpp_set_vars_information", setVarsInformation, 2);
YAP_UserCPredicate ("cpp_set_horus_flag", setHorusFlag, 2);
YAP_UserCPredicate ("cpp_free_parfactors", freeParfactors, 1);
YAP_UserCPredicate ("cpp_free_lifted_network", freeLiftedNetwork, 1);
YAP_UserCPredicate ("cpp_free_ground_network", freeGroundNetwork, 1);
}

View File

@ -0,0 +1,232 @@
#include "LiftedBp.h"
#include "WeightedBp.h"
#include "FactorGraph.h"
#include "LiftedVe.h"
LiftedBp::LiftedBp (const ParfactorList& pfList)
: pfList_(pfList)
{
refineParfactors();
solver_ = new WeightedBp (*getFactorGraph(), getWeights());
}
LiftedBp::~LiftedBp (void)
{
delete solver_;
}
Params
LiftedBp::solveQuery (const Grounds& query)
{
assert (query.empty() == false);
Params res;
vector<PrvGroup> groups = getQueryGroups (query);
if (query.size() == 1) {
res = solver_->getPosterioriOf (groups[0]);
} else {
ParfactorList::iterator it = pfList_.begin();
size_t idx = pfList_.size();
size_t count = 0;
while (it != pfList_.end()) {
if ((*it)->containsGrounds (query)) {
idx = count;
break;
}
++ it;
++ count;
}
if (idx == pfList_.size()) {
res = getJointByConditioning (pfList_, query);
} else {
VarIds queryVids;
for (unsigned i = 0; i < groups.size(); i++) {
queryVids.push_back (groups[i]);
}
res = solver_->getFactorJoint (idx, queryVids);
}
}
return res;
}
void
LiftedBp::printSolverFlags (void) const
{
stringstream ss;
ss << "lifted bp [" ;
ss << "schedule=" ;
typedef BpOptions::Schedule Sch;
switch (BpOptions::schedule) {
case Sch::SEQ_FIXED: ss << "seq_fixed"; break;
case Sch::SEQ_RANDOM: ss << "seq_random"; break;
case Sch::PARALLEL: ss << "parallel"; break;
case Sch::MAX_RESIDUAL: ss << "max_residual"; break;
}
ss << ",max_iter=" << BpOptions::maxIter;
ss << ",accuracy=" << BpOptions::accuracy;
ss << ",log_domain=" << Util::toString (Globals::logDomain);
ss << "]" ;
cout << ss.str() << endl;
}
void
LiftedBp::refineParfactors (void)
{
while (iterate() == false);
if (Globals::verbosity > 2) {
Util::printHeader ("AFTER REFINEMENT");
pfList_.print();
}
}
bool
LiftedBp::iterate (void)
{
ParfactorList::iterator it = pfList_.begin();
while (it != pfList_.end()) {
const ProbFormulas& args = (*it)->arguments();
for (size_t i = 0; i < args.size(); i++) {
LogVarSet lvs = (*it)->logVarSet() - args[i].logVars();
if ((*it)->constr()->isCountNormalized (lvs) == false) {
Parfactors pfs = LiftedVe::countNormalize (*it, lvs);
it = pfList_.removeAndDelete (it);
pfList_.add (pfs);
return false;
}
}
++ it;
}
return true;
}
vector<PrvGroup>
LiftedBp::getQueryGroups (const Grounds& query)
{
vector<PrvGroup> queryGroups;
for (unsigned i = 0; i < query.size(); i++) {
ParfactorList::const_iterator it = pfList_.begin();
for (; it != pfList_.end(); ++it) {
if ((*it)->containsGround (query[i])) {
queryGroups.push_back ((*it)->findGroup (query[i]));
break;
}
}
}
assert (queryGroups.size() == query.size());
return queryGroups;
}
FactorGraph*
LiftedBp::getFactorGraph (void)
{
FactorGraph* fg = new FactorGraph();
ParfactorList::const_iterator it = pfList_.begin();
for (; it != pfList_.end(); ++it) {
vector<PrvGroup> groups = (*it)->getAllGroups();
VarIds varIds;
for (size_t i = 0; i < groups.size(); i++) {
varIds.push_back (groups[i]);
}
fg->addFactor (Factor (varIds, (*it)->ranges(), (*it)->params()));
}
return fg;
}
vector<vector<unsigned>>
LiftedBp::getWeights (void) const
{
vector<vector<unsigned>> weights;
weights.reserve (pfList_.size());
ParfactorList::const_iterator it = pfList_.begin();
for (; it != pfList_.end(); ++it) {
const ProbFormulas& args = (*it)->arguments();
weights.push_back ({ });
weights.back().reserve (args.size());
for (size_t i = 0; i < args.size(); i++) {
LogVarSet lvs = (*it)->logVarSet() - args[i].logVars();
weights.back().push_back ((*it)->constr()->getConditionalCount (lvs));
}
}
return weights;
}
unsigned
LiftedBp::rangeOfGround (const Ground& gr)
{
ParfactorList::iterator it = pfList_.begin();
while (it != pfList_.end()) {
if ((*it)->containsGround (gr)) {
PrvGroup prvGroup = (*it)->findGroup (gr);
return (*it)->range ((*it)->indexOfGroup (prvGroup));
}
++ it;
}
return std::numeric_limits<unsigned>::max();
}
Params
LiftedBp::getJointByConditioning (
const ParfactorList& pfList,
const Grounds& grounds)
{
LiftedBp solver (pfList);
Params prevBeliefs = solver.solveQuery ({grounds[0]});
Grounds obsGrounds = {grounds[0]};
for (size_t i = 1; i < grounds.size(); i++) {
Params newBeliefs;
vector<ObservedFormula> obsFs;
Ranges obsRanges;
for (size_t j = 0; j < obsGrounds.size(); j++) {
obsFs.push_back (ObservedFormula (
obsGrounds[j].functor(), 0, obsGrounds[j].args()));
obsRanges.push_back (rangeOfGround (obsGrounds[j]));
}
Indexer indexer (obsRanges, false);
while (indexer.valid()) {
for (size_t j = 0; j < obsFs.size(); j++) {
obsFs[j].setEvidence (indexer[j]);
}
ParfactorList tempPfList (pfList);
LiftedVe::absorveEvidence (tempPfList, obsFs);
LiftedBp solver (tempPfList);
Params beliefs = solver.solveQuery ({grounds[i]});
for (size_t k = 0; k < beliefs.size(); k++) {
newBeliefs.push_back (beliefs[k]);
}
++ indexer;
}
int count = -1;
unsigned range = rangeOfGround (grounds[i]);
for (size_t j = 0; j < newBeliefs.size(); j++) {
if (j % range == 0) {
count ++;
}
newBeliefs[j] *= prevBeliefs[count];
}
prevBeliefs = newBeliefs;
obsGrounds.push_back (grounds[i]);
}
return prevBeliefs;
}

View File

@ -1,15 +1,17 @@
#ifndef HORUS_LIFTEDBPSOLVER_H
#define HORUS_LIFTEDBPSOLVER_H
#ifndef HORUS_LIFTEDBP_H
#define HORUS_LIFTEDBP_H
#include "ParfactorList.h"
class FactorGraph;
class WeightedBpSolver;
class WeightedBp;
class LiftedBpSolver
class LiftedBp
{
public:
LiftedBpSolver (const ParfactorList& pfList);
LiftedBp (const ParfactorList& pfList);
~LiftedBp (void);
Params solveQuery (const Grounds&);
@ -25,10 +27,14 @@ class LiftedBpSolver
FactorGraph* getFactorGraph (void);
vector<vector<unsigned>> getWeights (void) const;
unsigned rangeOfGround (const Ground&);
ParfactorList pfList_;
WeightedBpSolver* solver_;
Params getJointByConditioning (const ParfactorList&, const Grounds&);
ParfactorList pfList_;
WeightedBp* solver_;
};
#endif // HORUS_LIFTEDBPSOLVER_H
#endif // HORUS_LIFTEDBP_H

View File

@ -1,148 +0,0 @@
#include "LiftedBpSolver.h"
#include "WeightedBpSolver.h"
#include "FactorGraph.h"
#include "FoveSolver.h"
LiftedBpSolver::LiftedBpSolver (const ParfactorList& pfList)
: pfList_(pfList)
{
refineParfactors();
solver_ = new WeightedBpSolver (*getFactorGraph(), getWeights());
}
Params
LiftedBpSolver::solveQuery (const Grounds& query)
{
assert (query.empty() == false);
Params res;
vector<PrvGroup> groups = getQueryGroups (query);
if (query.size() == 1) {
res = solver_->getPosterioriOf (groups[0]);
} else {
VarIds queryVids;
for (unsigned i = 0; i < groups.size(); i++) {
queryVids.push_back (groups[i]);
}
res = solver_->getJointDistributionOf (queryVids);
}
return res;
}
void
LiftedBpSolver::printSolverFlags (void) const
{
stringstream ss;
ss << "lifted bp [" ;
ss << "schedule=" ;
typedef BpOptions::Schedule Sch;
switch (BpOptions::schedule) {
case Sch::SEQ_FIXED: ss << "seq_fixed"; break;
case Sch::SEQ_RANDOM: ss << "seq_random"; break;
case Sch::PARALLEL: ss << "parallel"; break;
case Sch::MAX_RESIDUAL: ss << "max_residual"; break;
}
ss << ",max_iter=" << BpOptions::maxIter;
ss << ",accuracy=" << BpOptions::accuracy;
ss << ",log_domain=" << Util::toString (Globals::logDomain);
ss << "]" ;
cout << ss.str() << endl;
}
void
LiftedBpSolver::refineParfactors (void)
{
while (iterate() == false);
if (Globals::verbosity > 2) {
Util::printHeader ("AFTER REFINEMENT");
pfList_.print();
}
}
bool
LiftedBpSolver::iterate (void)
{
ParfactorList::iterator it = pfList_.begin();
while (it != pfList_.end()) {
const ProbFormulas& args = (*it)->arguments();
for (size_t i = 0; i < args.size(); i++) {
LogVarSet lvs = (*it)->logVarSet() - args[i].logVars();
if ((*it)->constr()->isCountNormalized (lvs) == false) {
Parfactors pfs = FoveSolver::countNormalize (*it, lvs);
it = pfList_.removeAndDelete (it);
pfList_.add (pfs);
return false;
}
}
++ it;
}
return true;
}
vector<PrvGroup>
LiftedBpSolver::getQueryGroups (const Grounds& query)
{
vector<PrvGroup> queryGroups;
for (unsigned i = 0; i < query.size(); i++) {
ParfactorList::const_iterator it = pfList_.begin();
for (; it != pfList_.end(); ++it) {
if ((*it)->containsGround (query[i])) {
queryGroups.push_back ((*it)->findGroup (query[i]));
break;
}
}
}
assert (queryGroups.size() == query.size());
return queryGroups;
}
FactorGraph*
LiftedBpSolver::getFactorGraph (void)
{
FactorGraph* fg = new FactorGraph();
ParfactorList::const_iterator it = pfList_.begin();
for (; it != pfList_.end(); ++it) {
vector<PrvGroup> groups = (*it)->getAllGroups();
VarIds varIds;
for (size_t i = 0; i < groups.size(); i++) {
varIds.push_back (groups[i]);
}
fg->addFactor (Factor (varIds, (*it)->ranges(), (*it)->params()));
}
return fg;
}
vector<vector<unsigned>>
LiftedBpSolver::getWeights (void) const
{
vector<vector<unsigned>> weights;
weights.reserve (pfList_.size());
ParfactorList::const_iterator it = pfList_.begin();
for (; it != pfList_.end(); ++it) {
const ProbFormulas& args = (*it)->arguments();
weights.push_back ({ });
weights.back().reserve (args.size());
for (size_t i = 0; i < args.size(); i++) {
LogVarSet lvs = (*it)->logVarSet() - args[i].logVars();
weights.back().push_back ((*it)->constr()->getConditionalCount (lvs));
}
}
return weights;
}

View File

@ -1,8 +1,7 @@
#include <algorithm>
#include <set>
#include "FoveSolver.h"
#include "LiftedVe.h"
#include "Histogram.h"
#include "Util.h"
@ -12,7 +11,7 @@ LiftedOperator::getValidOps (
ParfactorList& pfList,
const Grounds& query)
{
vector<LiftedOperator*> validOps;
vector<LiftedOperator*> validOps;
vector<ProductOperator*> multOps;
multOps = ProductOperator::getValidOps (pfList);
@ -222,7 +221,7 @@ SumOutOperator::apply (void)
product->sumOutIndex (fIdx);
pfList_.addShattered (product);
} else {
Parfactors pfs = FoveSolver::countNormalize (product, excl);
Parfactors pfs = LiftedVe::countNormalize (product, excl);
for (size_t i = 0; i < pfs.size(); i++) {
pfs[i]->sumOutIndex (fIdx);
pfList_.add (pfs[i]);
@ -376,7 +375,7 @@ CountingOperator::apply (void)
} else {
Parfactor* pf = *pfIter_;
pfList_.remove (pfIter_);
Parfactors pfs = FoveSolver::countNormalize (pf, X_);
Parfactors pfs = LiftedVe::countNormalize (pf, X_);
for (size_t i = 0; i < pfs.size(); i++) {
unsigned condCount = pfs[i]->constr()->getConditionalCount (X_);
bool cartProduct = pfs[i]->constr()->isCartesianProduct (
@ -420,7 +419,7 @@ CountingOperator::toString (void)
ss << "count convert " << X_ << " in " ;
ss << (*pfIter_)->getLabel();
ss << " [cost=" << std::exp (getLogCost()) << "]" << endl;
Parfactors pfs = FoveSolver::countNormalize (*pfIter_, X_);
Parfactors pfs = LiftedVe::countNormalize (*pfIter_, X_);
if ((*pfIter_)->constr()->isCountNormalized (X_) == false) {
for (size_t i = 0; i < pfs.size(); i++) {
ss << " º " << pfs[i]->getLabel() << endl;
@ -501,7 +500,7 @@ GroundOperator::getLogCost (void)
++ pflIt;
}
// cout << endl;
return totalCost;
return totalCost + 3;
}
@ -610,7 +609,7 @@ GroundOperator::getAffectedFormulas (void)
LogVar X = f.logVars()[front.second];
const ProbFormulas& fs = (*pflIt)->arguments();
for (size_t i = 0; i < fs.size(); i++) {
if ((int)i != idx && fs[i].contains (X)) {
if (i != idx && fs[i].contains (X)) {
pair<PrvGroup, unsigned> pair = make_pair (
fs[i].group(), fs[i].indexOf (X));
if (Util::contains (affectedFormulas, pair) == false) {
@ -630,7 +629,7 @@ GroundOperator::getAffectedFormulas (void)
Params
FoveSolver::solveQuery (const Grounds& query)
LiftedVe::solveQuery (const Grounds& query)
{
assert (query.empty() == false);
runSolver (query);
@ -645,7 +644,7 @@ FoveSolver::solveQuery (const Grounds& query)
void
FoveSolver::printSolverFlags (void) const
LiftedVe::printSolverFlags (void) const
{
stringstream ss;
ss << "fove [" ;
@ -657,7 +656,7 @@ FoveSolver::printSolverFlags (void) const
void
FoveSolver::absorveEvidence (
LiftedVe::absorveEvidence (
ParfactorList& pfList,
ObservedFormulas& obsFormulas)
{
@ -696,7 +695,7 @@ FoveSolver::absorveEvidence (
Parfactors
FoveSolver::countNormalize (
LiftedVe::countNormalize (
Parfactor* g,
const LogVarSet& set)
{
@ -715,7 +714,7 @@ FoveSolver::countNormalize (
Parfactor
FoveSolver::calcGroundMultiplication (Parfactor pf)
LiftedVe::calcGroundMultiplication (Parfactor pf)
{
LogVarSet lvs = pf.constr()->logVarSet();
lvs -= pf.constr()->singletons();
@ -748,7 +747,7 @@ FoveSolver::calcGroundMultiplication (Parfactor pf)
void
FoveSolver::runSolver (const Grounds& query)
LiftedVe::runSolver (const Grounds& query)
{
largestCost_ = std::log (0);
shatterAgainstQuery (query);
@ -794,7 +793,7 @@ FoveSolver::runSolver (const Grounds& query)
LiftedOperator*
FoveSolver::getBestOperation (const Grounds& query)
LiftedVe::getBestOperation (const Grounds& query)
{
double bestCost = 0.0;
LiftedOperator* bestOp = 0;
@ -821,7 +820,7 @@ FoveSolver::getBestOperation (const Grounds& query)
void
FoveSolver::runWeakBayesBall (const Grounds& query)
LiftedVe::runWeakBayesBall (const Grounds& query)
{
queue<PrvGroup> todo; // groups to process
set<PrvGroup> done; // processed or in queue
@ -880,7 +879,7 @@ FoveSolver::runWeakBayesBall (const Grounds& query)
void
FoveSolver::shatterAgainstQuery (const Grounds& query)
LiftedVe::shatterAgainstQuery (const Grounds& query)
{
for (size_t i = 0; i < query.size(); i++) {
if (query[i].isAtom()) {
@ -931,7 +930,7 @@ FoveSolver::shatterAgainstQuery (const Grounds& query)
Parfactors
FoveSolver::absorve (
LiftedVe::absorve (
ObservedFormula& obsFormula,
Parfactor* g)
{

View File

@ -1,5 +1,5 @@
#ifndef HORUS_FOVESOLVER_H
#define HORUS_FOVESOLVER_H
#ifndef HORUS_LIFTEDVE_H
#define HORUS_LIFTEDVE_H
#include "ParfactorList.h"
@ -130,10 +130,10 @@ class GroundOperator : public LiftedOperator
class FoveSolver
class LiftedVe
{
public:
FoveSolver (const ParfactorList& pfList) : pfList_(pfList) { }
LiftedVe (const ParfactorList& pfList) : pfList_(pfList) { }
Params solveQuery (const Grounds&);
@ -162,5 +162,5 @@ class FoveSolver
double largestCost_;
};
#endif // HORUS_FOVESOLVER_H
#endif // HORUS_LIFTEDVE_H

View File

@ -23,10 +23,10 @@ CC=@CC@
CXX=@CXX@
# normal
#CXXFLAGS= -std=c++0x @SHLIB_CXXFLAGS@ $(YAP_EXTRAS) $(DEFS) -D_YAP_NOT_INSTALLED_=1 -I$(srcdir) -I../../.. -I$(srcdir)/../../../include @CPPFLAGS@ -DNDEBUG
CXXFLAGS= -std=c++0x @SHLIB_CXXFLAGS@ $(YAP_EXTRAS) $(DEFS) -D_YAP_NOT_INSTALLED_=1 -I$(srcdir) -I../../.. -I$(srcdir)/../../../include @CPPFLAGS@ -DNDEBUG
# debug
CXXFLAGS= -std=c++0x @SHLIB_CXXFLAGS@ $(YAP_EXTRAS) $(DEFS) -D_YAP_NOT_INSTALLED_=1 -I$(srcdir) -I../../.. -I$(srcdir)/../../../include @CPPFLAGS@ -g -O0 -Wextra
#CXXFLAGS= -std=c++0x @SHLIB_CXXFLAGS@ $(YAP_EXTRAS) $(DEFS) -D_YAP_NOT_INSTALLED_=1 -I$(srcdir) -I../../.. -I$(srcdir)/../../../include @CPPFLAGS@ -g -O0 -Wextra
#
@ -45,98 +45,91 @@ CWD=$(PWD)
HEADERS = \
$(srcdir)/BayesNet.h \
$(srcdir)/BayesBall.h \
$(srcdir)/ElimGraph.h \
$(srcdir)/FactorGraph.h \
$(srcdir)/Factor.h \
$(srcdir)/BayesBallGraph.h \
$(srcdir)/BeliefProp.h \
$(srcdir)/ConstraintTree.h \
$(srcdir)/Solver.h \
$(srcdir)/VarElimSolver.h \
$(srcdir)/BpSolver.h \
$(srcdir)/CbpSolver.h \
$(srcdir)/FoveSolver.h \
$(srcdir)/Var.h \
$(srcdir)/Indexer.h \
$(srcdir)/Parfactor.h \
$(srcdir)/ProbFormula.h \
$(srcdir)/CountingBp.h \
$(srcdir)/ElimGraph.h \
$(srcdir)/Factor.h \
$(srcdir)/FactorGraph.h \
$(srcdir)/Histogram.h \
$(srcdir)/ParfactorList.h \
$(srcdir)/Horus.h \
$(srcdir)/Indexer.h \
$(srcdir)/LiftedBp.h \
$(srcdir)/LiftedUtils.h \
$(srcdir)/LiftedVe.h \
$(srcdir)/Parfactor.h \
$(srcdir)/ParfactorList.h \
$(srcdir)/ProbFormula.h \
$(srcdir)/Solver.h \
$(srcdir)/TinySet.h \
$(srcdir)/LiftedBpSolver.h \
$(srcdir)/WeightedBpSolver.h \
$(srcdir)/Util.h \
$(srcdir)/Horus.h
$(srcdir)/Var.h \
$(srcdir)/VarElim.h \
$(srcdir)/WeightedBp.h
CPP_SOURCES = \
$(srcdir)/BayesNet.cpp \
$(srcdir)/BayesBall.cpp \
$(srcdir)/ElimGraph.cpp \
$(srcdir)/FactorGraph.cpp \
$(srcdir)/Factor.cpp \
$(srcdir)/BayesBallGraph.cpp \
$(srcdir)/BeliefProp.cpp \
$(srcdir)/ConstraintTree.cpp \
$(srcdir)/Var.cpp \
$(srcdir)/Solver.cpp \
$(srcdir)/VarElimSolver.cpp \
$(srcdir)/BpSolver.cpp \
$(srcdir)/CbpSolver.cpp \
$(srcdir)/FoveSolver.cpp \
$(srcdir)/Parfactor.cpp \
$(srcdir)/ProbFormula.cpp \
$(srcdir)/CountingBp.cpp \
$(srcdir)/ElimGraph.cpp \
$(srcdir)/Factor.cpp \
$(srcdir)/FactorGraph.cpp \
$(srcdir)/Histogram.cpp \
$(srcdir)/ParfactorList.cpp \
$(srcdir)/LiftedUtils.cpp \
$(srcdir)/Util.cpp \
$(srcdir)/LiftedBpSolver.cpp \
$(srcdir)/WeightedBpSolver.cpp \
$(srcdir)/HorusCli.cpp \
$(srcdir)/HorusYap.cpp \
$(srcdir)/HorusCli.cpp
$(srcdir)/LiftedBp.cpp \
$(srcdir)/LiftedUtils.cpp \
$(srcdir)/LiftedVe.cpp \
$(srcdir)/Parfactor.cpp \
$(srcdir)/ParfactorList.cpp \
$(srcdir)/ProbFormula.cpp \
$(srcdir)/Solver.cpp \
$(srcdir)/Util.cpp \
$(srcdir)/Var.cpp \
$(srcdir)/VarElim.cpp \
$(srcdir)/WeightedBp.cpp \
OBJS = \
BayesNet.o \
BayesBall.o \
ElimGraph.o \
FactorGraph.o \
Factor.o \
BayesBallGraph.o \
BeliefProp.o \
ConstraintTree.o \
Var.o \
Solver.o \
VarElimSolver.o \
BpSolver.o \
CbpSolver.o \
FoveSolver.o \
Parfactor.o \
ProbFormula.o \
CountingBp.o \
ElimGraph.o \
Factor.o \
FactorGraph.o \
Histogram.o \
ParfactorList.o \
HorusYap.o \
LiftedBp.o \
LiftedUtils.o \
LiftedVe.o \
ProbFormula.o \
Parfactor.o \
ParfactorList.o \
Solver.o \
Util.o \
LiftedBpSolver.o \
WeightedBpSolver.o \
HorusYap.o
Var.o \
VarElim.o \
WeightedBp.o
HCLI_OBJS = \
BayesNet.o \
BayesBall.o \
BayesBallGraph.o \
BeliefProp.o \
CountingBp.o \
ElimGraph.o \
FactorGraph.o \
Factor.o \
ConstraintTree.o \
Var.o \
FactorGraph.o \
HorusCli.o \
Solver.o \
VarElimSolver.o \
BpSolver.o \
CbpSolver.o \
FoveSolver.o \
Parfactor.o \
ProbFormula.o \
Histogram.o \
ParfactorList.o \
WeightedBpSolver.o \
LiftedUtils.o \
Util.o \
HorusCli.o
Var.o \
VarElim.o \
WeightedBp.o
SOBJS=horus.@SO@

View File

@ -402,21 +402,32 @@ Parfactor::applySubstitution (const Substitution& theta)
PrvGroup
Parfactor::findGroup (const Ground& ground) const
size_t
Parfactor::indexOfGround (const Ground& ground) const
{
PrvGroup group = numeric_limits<PrvGroup>::max();
size_t idx = args_.size();
for (size_t i = 0; i < args_.size(); i++) {
if (args_[i].functor() == ground.functor() &&
args_[i].arity() == ground.arity()) {
constr_->moveToTop (args_[i].logVars());
if (constr_->containsTuple (ground.args())) {
group = args_[i].group();
idx = i;
break;
}
}
}
return group;
return idx;
}
PrvGroup
Parfactor::findGroup (const Ground& ground) const
{
size_t idx = indexOfGround (ground);
return idx == args_.size()
? numeric_limits<PrvGroup>::max()
: args_[idx].group();
}
@ -429,6 +440,30 @@ Parfactor::containsGround (const Ground& ground) const
bool
Parfactor::containsGrounds (const Grounds& grounds) const
{
Tuple tuple;
LogVars tupleLvs;
for (size_t i = 0; i < grounds.size(); i++) {
size_t idx = indexOfGround (grounds[i]);
if (idx == args_.size()) {
return false;
}
LogVars lvs = args_[idx].logVars();
for (size_t j = 0; j < lvs.size(); j++) {
if (Util::contains (tupleLvs, lvs[j]) == false) {
tuple.push_back (grounds[i].args()[j]);
tupleLvs.push_back (lvs[j]);
}
}
}
constr_->moveToTop (tupleLvs);
return constr_->containsTuple (tuple);
}
bool
Parfactor::containsGroup (PrvGroup group) const
{
@ -442,6 +477,19 @@ Parfactor::containsGroup (PrvGroup group) const
bool
Parfactor::containsGroups (vector<PrvGroup> groups) const
{
for (size_t i = 0; i < groups.size(); i++) {
if (containsGroup (groups[i]) == false) {
return false;
}
}
return true;
}
unsigned
Parfactor::nrFormulas (LogVar X) const
{

View File

@ -64,11 +64,17 @@ class Parfactor : public TFactor<ProbFormula>
void applySubstitution (const Substitution&);
size_t indexOfGround (const Ground&) const;
PrvGroup findGroup (const Ground&) const;
bool containsGround (const Ground&) const;
bool containsGrounds (const Grounds&) const;
bool containsGroup (PrvGroup) const;
bool containsGroups (vector<PrvGroup>) const;
unsigned nrFormulas (LogVar) const;

View File

@ -91,6 +91,8 @@ class ObservedFormula
unsigned evidence (void) const { return evidence_; }
void setEvidence (unsigned ev) { evidence_ = ev; }
ConstraintTree& constr (void) { return constr_; }
bool isAtom (void) const { return arity_ == 0; }

View File

@ -1,5 +1,8 @@
#include "Solver.h"
#include "Util.h"
#include "BeliefProp.h"
#include "CountingBp.h"
#include "VarElim.h"
void
@ -38,3 +41,67 @@ Solver::printAllPosterioris (void)
}
}
Params
Solver::getJointByConditioning (
GroundSolver solverType,
FactorGraph fg,
const VarIds& jointVarIds) const
{
VarNodes jointVars;
for (size_t i = 0; i < jointVarIds.size(); i++) {
assert (fg.getVarNode (jointVarIds[i]));
jointVars.push_back (fg.getVarNode (jointVarIds[i]));
}
Solver* solver = 0;
switch (solverType) {
case GroundSolver::BP: solver = new BeliefProp (fg); break;
case GroundSolver::CBP: solver = new CountingBp (fg); break;
case GroundSolver::VE: solver = new VarElim (fg); break;
}
Params prevBeliefs = solver->solveQuery ({jointVarIds[0]});
VarIds observedVids = {jointVars[0]->varId()};
for (size_t i = 1; i < jointVarIds.size(); i++) {
assert (jointVars[i]->hasEvidence() == false);
Params newBeliefs;
Vars observedVars;
Ranges observedRanges;
for (size_t j = 0; j < observedVids.size(); j++) {
observedVars.push_back (fg.getVarNode (observedVids[j]));
observedRanges.push_back (observedVars.back()->range());
}
Indexer indexer (observedRanges, false);
while (indexer.valid()) {
for (size_t j = 0; j < observedVars.size(); j++) {
observedVars[j]->setEvidence (indexer[j]);
}
delete solver;
switch (solverType) {
case GroundSolver::BP: solver = new BeliefProp (fg); break;
case GroundSolver::CBP: solver = new CountingBp (fg); break;
case GroundSolver::VE: solver = new VarElim (fg); break;
}
Params beliefs = solver->solveQuery ({jointVarIds[i]});
for (size_t k = 0; k < beliefs.size(); k++) {
newBeliefs.push_back (beliefs[k]);
}
++ indexer;
}
int count = -1;
for (size_t j = 0; j < newBeliefs.size(); j++) {
if (j % jointVars[i]->range() == 0) {
count ++;
}
newBeliefs[j] *= prevBeliefs[count];
}
prevBeliefs = newBeliefs;
observedVids.push_back (jointVars[i]->varId());
}
delete solver;
return prevBeliefs;
}

View File

@ -3,8 +3,9 @@
#include <iomanip>
#include "Var.h"
#include "FactorGraph.h"
#include "Var.h"
#include "Horus.h"
using namespace std;
@ -23,6 +24,9 @@ class Solver
void printAnswer (const VarIds& vids);
void printAllPosterioris (void);
Params getJointByConditioning (GroundSolver,
FactorGraph, const VarIds& jointVarIds) const;
protected:
const FactorGraph& fg;

View File

@ -13,9 +13,9 @@ bool logDomain = false;
unsigned verbosity = 0;
LiftedSolvers liftedSolver = LiftedSolvers::FOVE;
LiftedSolver liftedSolver = LiftedSolver::FOVE;
GroundSolvers groundSolver = GroundSolvers::VE;
GroundSolver groundSolver = GroundSolver::VE;
};
@ -211,9 +211,9 @@ setHorusFlag (string key, string value)
ss >> Globals::verbosity;
} else if (key == "lifted_solver") {
if ( value == "fove") {
Globals::liftedSolver = LiftedSolvers::FOVE;
Globals::liftedSolver = LiftedSolver::FOVE;
} else if (value == "lbp") {
Globals::liftedSolver = LiftedSolvers::LBP;
Globals::liftedSolver = LiftedSolver::LBP;
} else {
cerr << "warning: invalid value `" << value << "' " ;
cerr << "for `" << key << "'" << endl;
@ -221,11 +221,11 @@ setHorusFlag (string key, string value)
}
} else if (key == "ground_solver") {
if ( value == "ve") {
Globals::groundSolver = GroundSolvers::VE;
Globals::groundSolver = GroundSolver::VE;
} else if (value == "bp") {
Globals::groundSolver = GroundSolvers::BP;
Globals::groundSolver = GroundSolver::BP;
} else if (value == "cbp") {
Globals::groundSolver = GroundSolvers::CBP;
Globals::groundSolver = GroundSolver::CBP;
} else {
cerr << "warning: invalid value `" << value << "' " ;
cerr << "for `" << key << "'" << endl;

View File

@ -1,12 +1,12 @@
#include <algorithm>
#include "VarElimSolver.h"
#include "VarElim.h"
#include "ElimGraph.h"
#include "Factor.h"
#include "Util.h"
VarElimSolver::~VarElimSolver (void)
VarElim::~VarElim (void)
{
delete factorList_.back();
}
@ -14,7 +14,7 @@ VarElimSolver::~VarElimSolver (void)
Params
VarElimSolver::solveQuery (VarIds queryVids)
VarElim::solveQuery (VarIds queryVids)
{
if (Globals::verbosity > 1) {
cout << "Solving query on " ;
@ -41,7 +41,7 @@ VarElimSolver::solveQuery (VarIds queryVids)
void
VarElimSolver::printSolverFlags (void) const
VarElim::printSolverFlags (void) const
{
stringstream ss;
ss << "variable elimination [" ;
@ -62,7 +62,7 @@ VarElimSolver::printSolverFlags (void) const
void
VarElimSolver::createFactorList (void)
VarElim::createFactorList (void)
{
const FacNodes& facNodes = fg.facNodes();
factorList_.reserve (facNodes.size() * 2);
@ -84,7 +84,7 @@ VarElimSolver::createFactorList (void)
void
VarElimSolver::absorveEvidence (void)
VarElim::absorveEvidence (void)
{
if (Globals::verbosity > 2) {
Util::printDashedLine();
@ -117,7 +117,7 @@ VarElimSolver::absorveEvidence (void)
void
VarElimSolver::findEliminationOrder (const VarIds& vids)
VarElim::findEliminationOrder (const VarIds& vids)
{
elimOrder_ = ElimGraph::getEliminationOrder (factorList_, vids);
}
@ -125,7 +125,7 @@ VarElimSolver::findEliminationOrder (const VarIds& vids)
void
VarElimSolver::processFactorList (const VarIds& vids)
VarElim::processFactorList (const VarIds& vids)
{
totalFactorSize_ = 0;
largestFactorSize_ = 0;
@ -170,7 +170,7 @@ VarElimSolver::processFactorList (const VarIds& vids)
void
VarElimSolver::eliminate (VarId elimVar)
VarElim::eliminate (VarId elimVar)
{
Factor* result = 0;
vector<size_t>& idxs = varFactors_.find (elimVar)->second;
@ -205,7 +205,7 @@ VarElimSolver::eliminate (VarId elimVar)
void
VarElimSolver::printActiveFactors (void)
VarElim::printActiveFactors (void)
{
for (size_t i = 0; i < factorList_.size(); i++) {
if (factorList_[i] != 0) {

View File

@ -1,5 +1,5 @@
#ifndef HORUS_VARELIMSOLVER_H
#define HORUS_VARELIMSOLVER_H
#ifndef HORUS_VARELIM_H
#define HORUS_VARELIM_H
#include "unordered_map"
@ -11,12 +11,12 @@
using namespace std;
class VarElimSolver : public Solver
class VarElim : public Solver
{
public:
VarElimSolver (const FactorGraph& fg) : Solver (fg) { }
VarElim (const FactorGraph& fg) : Solver (fg) { }
~VarElimSolver (void);
~VarElim (void);
Params solveQuery (VarIds);
@ -42,5 +42,5 @@ class VarElimSolver : public Solver
unordered_map<VarId, vector<size_t>> varFactors_;
};
#endif // HORUS_VARELIMSOLVER_H
#endif // HORUS_VARELIM_H

View File

@ -1,7 +1,7 @@
#include "WeightedBpSolver.h"
#include "WeightedBp.h"
WeightedBpSolver::~WeightedBpSolver (void)
WeightedBp::~WeightedBp (void)
{
for (size_t i = 0; i < links_.size(); i++) {
delete links_[i];
@ -12,7 +12,7 @@ WeightedBpSolver::~WeightedBpSolver (void)
Params
WeightedBpSolver::getPosterioriOf (VarId vid)
WeightedBp::getPosterioriOf (VarId vid)
{
if (runned_ == false) {
runSolver();
@ -47,7 +47,7 @@ WeightedBpSolver::getPosterioriOf (VarId vid)
void
WeightedBpSolver::createLinks (void)
WeightedBp::createLinks (void)
{
if (Globals::verbosity > 0) {
cout << "compressed factor graph contains " ;
@ -78,7 +78,7 @@ WeightedBpSolver::createLinks (void)
void
WeightedBpSolver::maxResidualSchedule (void)
WeightedBp::maxResidualSchedule (void)
{
if (nIters_ == 1) {
for (size_t i = 0; i < links_.size(); i++) {
@ -151,7 +151,7 @@ WeightedBpSolver::maxResidualSchedule (void)
void
WeightedBpSolver::calcFactorToVarMsg (BpLink* _link)
WeightedBp::calcFactorToVarMsg (BpLink* _link)
{
WeightedLink* link = static_cast<WeightedLink*> (_link);
FacNode* src = link->facNode();
@ -223,7 +223,7 @@ WeightedBpSolver::calcFactorToVarMsg (BpLink* _link)
Params
WeightedBpSolver::getVarToFactorMsg (const BpLink* _link) const
WeightedBp::getVarToFactorMsg (const BpLink* _link) const
{
const WeightedLink* link = static_cast<const WeightedLink*> (_link);
const VarNode* src = link->varNode();
@ -272,7 +272,7 @@ WeightedBpSolver::getVarToFactorMsg (const BpLink* _link) const
void
WeightedBpSolver::printLinkInformation (void) const
WeightedBp::printLinkInformation (void) const
{
for (size_t i = 0; i < links_.size(); i++) {
WeightedLink* l = static_cast<WeightedLink*> (links_[i]);

View File

@ -1,7 +1,7 @@
#ifndef HORUS_WEIGHTEDBPSOLVER_H
#define HORUS_WEIGHTEDBPSOLVER_H
#ifndef HORUS_WEIGHTEDBP_H
#define HORUS_WEIGHTEDBP_H
#include "BpSolver.h"
#include "BeliefProp.h"
class WeightedLink : public BpLink
{
@ -31,14 +31,14 @@ class WeightedLink : public BpLink
class WeightedBpSolver : public BpSolver
class WeightedBp : public BeliefProp
{
public:
WeightedBpSolver (const FactorGraph& fg,
WeightedBp (const FactorGraph& fg,
const vector<vector<unsigned>>& weights)
: BpSolver (fg), weights_(weights) { }
: BeliefProp (fg), weights_(weights) { }
~WeightedBpSolver (void);
~WeightedBp (void);
Params getPosterioriOf (VarId);
@ -57,5 +57,5 @@ class WeightedBpSolver : public BpSolver
vector<vector<unsigned>> weights_;
};
#endif // HORUS_WEIGHTEDBPSOLVER_H
#endif // HORUS_WEIGHTEDBP_H