This commit is contained in:
Vitor Santos Costa 2012-06-22 14:51:58 +01:00
commit 5fe052a3ef
50 changed files with 1278 additions and 617 deletions

110
packages/CLPBN/README.txt Normal file
View File

@ -0,0 +1,110 @@
Prolog Factor Language (PFL)
Prolog Factor Language (PFL) is a extension of the Prolog language that
allows a natural representation of this first-order probabilistic models
(either directed or undirected). PFL is also capable of solving probabilistic
queries on this models through the implementation of several inference
techniques: variable elimination, belief propagation, lifted variable
elimination and lifted belief propagation.
Language
-------------------------------------------------------------------------------
A graphical model in PFL is represented using parfactors. A PFL parfactor
has the following four components:
Type ; Formulas ; Phi ; Constraint .
- Type refers the type of the network over which the parfactor is defined.
It can be bayes for directed networks, or markov for undirected ones.
- Formulas is a sequence of Prolog terms that define sets of random variables
under the constraint.
- Phi is either a list of parameters or a call to a Prolog goal that will
unify its last argument with a list of parameters.
- Constraint is a list (possible empty) of Prolog goals that will impose
bindings on the logical variables that appear in the formulas.
The "examples" directory contains some popular graphical models described
using PFL.
Querying
-------------------------------------------------------------------------------
Now we show how to use PFL to solve probabilistic queries. We will
use the burlgary alarm network as an example. First, we load the model:
$ yap -l examples/burglary-alarm.yap
Now let's suppose that we want to estimate the probability of a earthquake
ocurred given that mary called. We can do it with the following query:
?- earthquake(X), mary_calls(t).
Suppose now that we want the joint distribution for john_calls and
mary_calls. We can obtain this with the following query:
?- john_calls(X), mary_calls(Y).
Inference Options
-------------------------------------------------------------------------------
PFL supports both ground and lifted inference. The inference algorithm
can be chosen using the set_solver/1 predicate. The following algorithms
are supported:
- fove: lifted variable elimination with arbitrary constraints (GC-FOVE)
- hve: (ground) variable elimination
- lbp: lifted first-order belief propagation
- cbp: counting belief propagation
- bp: (ground) belief propagation
For example, if we want to use ground variable elimination to solve some
query, we need to call first the following goal:
?- set_solver(hve).
It is possible to tweak several parameters of PFL through the
set_horus_flag/2 predicate. The first argument is a key that
identifies the parameter that we desire to tweak, while the second
is some possible value for this key.
The verbosity key controls the level of log information that will be
printed by the corresponding solver. Its possible values are positive
integers. The bigger the number, more log information will be printed.
For example, to view some basic log information we need to call the
following goal:
?- set_horus_flag(verbosity, 1).
The use_logarithms key controls whether the calculations performed
during inference should be done in the log domain or not. Its values
can be true or false. By default is false.
There are also keys specific to the inference algorithm. For example,
elim_heuristic key controls the elimination heuristic that will be
used by ground variable elimination. The following heuristics are
supported:
- sequential
- min_neighbors
- min_weight
- min_fill
- weighted_min_fill
An explanation of this heuristics can be found in Probabilistic Graphical
Models by Daphne Koller.
The schedule, accuracy and max_iter keys are specific for inference
algorithms based on message passing, namely lbp, cbp and bp.
The key schedule can be used to specify the order in which the messages
are sent in belief propagation. The possible values are:
- seq_fixed: at each iteration, all messages are sent in the same order
- seq_random: at each iteration, the messages are sent with a random order
- parallel: at each iteration, the messages are all calculated using the
values of the previous iteration.
- max_residual: the next message to be sent is the one with maximum residual,
(Residual Belief Propagation:Informed Scheduling for Asynchronous Message
Passing)
The max_iter key sets the maximum number of iterations. One iteration
consists in sending all possible messages. The accuracy key indicate
when we should stop sending messages. If the largest difference between
a message sent in the current iteration and one message sent in the previous
iteration is less that accuracy value given, we terminate belief propagation.

View File

@ -26,6 +26,8 @@ function run_solver
solver_flag=clpbn_horus:set_horus_flag\(schedule,$2\) solver_flag=clpbn_horus:set_horus_flag\(schedule,$2\)
elif [ $SOLVER = cbp ]; then elif [ $SOLVER = cbp ]; then
solver_flag=clpbn_horus:set_horus_flag\(schedule,$2\) solver_flag=clpbn_horus:set_horus_flag\(schedule,$2\)
elif [ $SOLVER = lbp ]; then
solver_flag=clpbn_horus:set_horus_flag\(schedule,$2\)
else else
echo "unknow flag $2" echo "unknow flag $2"
fi fi

View File

@ -23,7 +23,7 @@ function run_all_graphs
run_solver city60000 $2 run_solver city60000 $2
run_solver city65000 $2 run_solver city65000 $2
run_solver city70000 $2 run_solver city70000 $2
return
run_solver city75000 $2 run_solver city75000 $2
run_solver city80000 $2 run_solver city80000 $2
run_solver city85000 $2 run_solver city85000 $2

View File

@ -0,0 +1,36 @@
#!/bin/bash
source city.sh
source ../benchs.sh
SOLVER="lbp"
function run_all_graphs
{
write_header $1
run_solver city1000 $2
run_solver city5000 $2
run_solver city10000 $2
run_solver city15000 $2
run_solver city20000 $2
run_solver city25000 $2
run_solver city30000 $2
run_solver city35000 $2
run_solver city40000 $2
run_solver city45000 $2
run_solver city50000 $2
run_solver city55000 $2
run_solver city60000 $2
run_solver city65000 $2
run_solver city70000 $2
run_solver city75000 $2
run_solver city80000 $2
run_solver city85000 $2
run_solver city90000 $2
run_solver city95000 $2
run_solver city100000 $2
}
prepare_new_run
run_all_graphs "lbp(shedule=seq_fixed) " seq_fixed

View File

@ -0,0 +1,30 @@
#!/bin/bash
source cw.sh
source ../benchs.sh
SOLVER="lbp"
function run_all_graphs
{
write_header $1
run_solver p1000w$N_WORKSHOPS $2
run_solver p5000w$N_WORKSHOPS $2
run_solver p10000w$N_WORKSHOPS $2
run_solver p15000w$N_WORKSHOPS $2
run_solver p20000w$N_WORKSHOPS $2
run_solver p25000w$N_WORKSHOPS $2
run_solver p30000w$N_WORKSHOPS $2
run_solver p35000w$N_WORKSHOPS $2
run_solver p40000w$N_WORKSHOPS $2
run_solver p45000w$N_WORKSHOPS $2
run_solver p50000w$N_WORKSHOPS $2
run_solver p55000w$N_WORKSHOPS $2
run_solver p60000w$N_WORKSHOPS $2
run_solver p65000w$N_WORKSHOPS $2
run_solver p70000w$N_WORKSHOPS $2
}
prepare_new_run
run_all_graphs "lbp(shedule=seq_fixed) " seq_fixed

View File

@ -0,0 +1,35 @@
#!/bin/bash
cd workshop_attrs
source hve_tests.sh
source bp_tests.sh
source fove_tests.sh
source lbp_tests.sh
source cbp_tests.sh
cd ..
cd comp_workshops
source hve_tests.sh
source bp_tests.sh
source fove_tests.sh
source lbp_tests.sh
source cbp_tests.sh
cd ..
cd city
source hve_tests.sh
source bp_tests.sh
source fove_tests.sh
source lbp_tests.sh
source cbp_tests.sh
cd ..
cd smokers
source hve_tests.sh
source bp_tests.sh
source fove_tests.sh
source lbp_tests.sh
source cbp_tests.sh
cd ..

View File

@ -0,0 +1,30 @@
#!/bin/bash
source sm.sh
source ../benchs.sh
SOLVER="lbp"
function run_all_graphs
{
write_header $1
run_solver pop100 $2
run_solver pop200 $2
run_solver pop300 $2
run_solver pop400 $2
run_solver pop500 $2
run_solver pop600 $2
run_solver pop700 $2
run_solver pop800 $2
run_solver pop900 $2
run_solver pop1000 $2
run_solver pop1100 $2
run_solver pop1200 $2
run_solver pop1300 $2
run_solver pop1400 $2
run_solver pop1500 $2
}
prepare_new_run
run_all_graphs "lbp(shedule=seq_fixed) " seq_fixed

View File

@ -2,5 +2,6 @@
NETWORK="'../../examples/social_domain2'" NETWORK="'../../examples/social_domain2'"
SHORTNAME="sm" SHORTNAME="sm"
QUERY="smokes(p1,t), smokes(p2,t), friends(p1,p2,X)" #QUERY="smokes(p1,t), smokes(p2,t), friends(p1,p2,X)"
QUERY="friends(p1,p2,X)"

View File

@ -0,0 +1,34 @@
#!/bin/bash
source sm.sh
source ../benchs.sh
SOLVER="bp"
function run_all_graphs
{
write_header $1
run_solver ev0p$POP $2
run_solver ev5p$POP $2
run_solver ev10p$POP $2
run_solver ev15p$POP $2
run_solver ev20p$POP $2
run_solver ev25p$POP $2
run_solver ev30p$POP $2
run_solver ev35p$POP $2
run_solver ev40p$POP $2
run_solver ev45p$POP $2
run_solver ev50p$POP $2
run_solver ev55p$POP $2
run_solver ev60p$POP $2
run_solver ev65p$POP $2
run_solver ev70p$POP $2
run_solver ev75p$POP $2
run_solver ev80p$POP $2
run_solver ev85p$POP $2
run_solver ev90p$POP $2
}
prepare_new_run
run_all_graphs "bp(shedule=seq_fixed) " seq_fixed

View File

@ -0,0 +1,34 @@
#!/bin/bash
source sm.sh
source ../benchs.sh
SOLVER="cbp"
function run_all_graphs
{
write_header $1
run_solver ev0p$POP $2
run_solver ev5p$POP $2
run_solver ev10p$POP $2
run_solver ev15p$POP $2
run_solver ev20p$POP $2
run_solver ev25p$POP $2
run_solver ev30p$POP $2
run_solver ev35p$POP $2
run_solver ev40p$POP $2
run_solver ev45p$POP $2
run_solver ev50p$POP $2
run_solver ev55p$POP $2
run_solver ev60p$POP $2
run_solver ev65p$POP $2
run_solver ev70p$POP $2
run_solver ev75p$POP $2
run_solver ev80p$POP $2
run_solver ev85p$POP $2
run_solver ev90p$POP $2
}
prepare_new_run
run_all_graphs "cbp(shedule=seq_fixed) " seq_fixed

View File

@ -0,0 +1,35 @@
#!/bin/bash
source sm.sh
source ../benchs.sh
SOLVER="fove"
function run_all_graphs
{
write_header $1
run_solver ev0p$POP $2
run_solver ev5p$POP $2
run_solver ev10p$POP $2
run_solver ev15p$POP $2
run_solver ev20p$POP $2
run_solver ev25p$POP $2
run_solver ev30p$POP $2
run_solver ev35p$POP $2
run_solver ev40p$POP $2
run_solver ev45p$POP $2
run_solver ev50p$POP $2
run_solver ev55p$POP $2
run_solver ev60p$POP $2
run_solver ev65p$POP $2
run_solver ev70p$POP $2
run_solver ev75p$POP $2
run_solver ev80p$POP $2
run_solver ev85p$POP $2
run_solver ev90p$POP $2
}
prepare_new_run
run_all_graphs "fove "

View File

@ -0,0 +1,49 @@
#!/home/tgomes/bin/yap -L --
:- use_module(library(lists)).
:- use_module(library(random)).
:- initialization(main).
main :-
unix(argv(Args)),
nth(1, Args, EV), % percentage of evidence
nth(2, Args, NP), % number of individuals
atomic_concat(['ev', EV, 'p', NP, '.yap'], FileName),
open(FileName, 'write', S),
atom_number(EV, EV2),
atom_number(NP, NP2),
EV3 is EV2 / 100.0,
generate_people(S, NP2, 4),
write(S, '\n'),
write(S, 'query(X) :- '),
generate_evidence(S, NP2, EV3, 4),
write(S, 'friends(p1,p2,X).\n'),
close(S).
generate_people(S, N, Counting) :-
Counting > N, !.
generate_people(S, N, Counting) :-
format(S, 'people(p~w).~n', [Counting]),
Counting1 is Counting + 1,
generate_people(S, N, Counting1).
generate_evidence(S, N, Ev, Counting) :-
Counting > N, !.
generate_evidence(S, N, Ev, Counting) :-
random(X),
(
X < Ev
->
random(Y),
(Y > 0.5 -> Val = t ; Val = f),
format(S, 'smokes(p~w,~w),', [Counting,Val])
;
true
),
Counting1 is Counting + 1,
generate_evidence(S, N, Ev, Counting1).

View File

@ -0,0 +1,37 @@
#!/bin/bash
source sm.sh
source ../benchs.sh
SOLVER="hve"
function run_all_graphs
{
write_header $1
run_solver ev0p$POP $2
run_solver ev5p$POP $2
run_solver ev10p$POP $2
run_solver ev15p$POP $2
run_solver ev20p$POP $2
run_solver ev25p$POP $2
run_solver ev30p$POP $2
run_solver ev35p$POP $2
run_solver ev40p$POP $2
run_solver ev45p$POP $2
run_solver ev50p$POP $2
run_solver ev55p$POP $2
run_solver ev60p$POP $2
run_solver ev65p$POP $2
run_solver ev70p$POP $2
run_solver ev75p$POP $2
run_solver ev80p$POP $2
run_solver ev85p$POP $2
run_solver ev90p$POP $2
}
prepare_new_run
run_all_graphs "hve(elim_heuristic=min_neighbors) " min_neighbors
#run_all_graphs "hve(elim_heuristic=min_weight) " min_weight
#run_all_graphs "hve(elim_heuristic=min_fill) " min_fill
#run_all_graphs "hve(elim_heuristic=weighted_min_fill) " weighted_min_fill

View File

@ -0,0 +1,34 @@
#!/bin/bash
source sm.sh
source ../benchs.sh
SOLVER="lbp"
function run_all_graphs
{
write_header $1
run_solver ev0p$POP $2
run_solver ev5p$POP $2
run_solver ev10p$POP $2
run_solver ev15p$POP $2
run_solver ev20p$POP $2
run_solver ev25p$POP $2
run_solver ev30p$POP $2
run_solver ev35p$POP $2
run_solver ev40p$POP $2
run_solver ev45p$POP $2
run_solver ev50p$POP $2
run_solver ev55p$POP $2
run_solver ev60p$POP $2
run_solver ev65p$POP $2
run_solver ev70p$POP $2
run_solver ev75p$POP $2
run_solver ev80p$POP $2
run_solver ev85p$POP $2
run_solver ev90p$POP $2
}
prepare_new_run
run_all_graphs "lbp(shedule=seq_fixed) " seq_fixed

View File

@ -0,0 +1,8 @@
#!/bin/bash
NETWORK="'../../examples/social_domain2'"
SHORTNAME="sm"
QUERY="query(X)"
POP=500

View File

@ -15,8 +15,8 @@ function run_all_graphs
run_solver p20000attrs$N_ATTRS $2 run_solver p20000attrs$N_ATTRS $2
run_solver p25000attrs$N_ATTRS $2 run_solver p25000attrs$N_ATTRS $2
run_solver p30000attrs$N_ATTRS $2 run_solver p30000attrs$N_ATTRS $2
run_solver p35000attrs$N_ATTRS $2
return return
run_solver p35000attrs$N_ATTRS $2
run_solver p40000attrs$N_ATTRS $2 run_solver p40000attrs$N_ATTRS $2
run_solver p45000attrs$N_ATTRS $2 run_solver p45000attrs$N_ATTRS $2
run_solver p50000attrs$N_ATTRS $2 run_solver p50000attrs$N_ATTRS $2

View File

@ -0,0 +1,36 @@
#!/bin/bash
source wa.sh
source ../benchs.sh
SOLVER="lbp"
function run_all_graphs
{
write_header $1
run_solver p1000attrs$N_ATTRS $2
run_solver p5000attrs$N_ATTRS $2
run_solver p10000attrs$N_ATTRS $2
run_solver p15000attrs$N_ATTRS $2
run_solver p20000attrs$N_ATTRS $2
run_solver p25000attrs$N_ATTRS $2
run_solver p30000attrs$N_ATTRS $2
run_solver p35000attrs$N_ATTRS $2
run_solver p40000attrs$N_ATTRS $2
run_solver p45000attrs$N_ATTRS $2
run_solver p50000attrs$N_ATTRS $2
run_solver p55000attrs$N_ATTRS $2
run_solver p60000attrs$N_ATTRS $2
run_solver p65000attrs$N_ATTRS $2
run_solver p70000attrs$N_ATTRS $2
run_solver p75000attrs$N_ATTRS $2
run_solver p80000attrs$N_ATTRS $2
run_solver p85000attrs$N_ATTRS $2
run_solver p90000attrs$N_ATTRS $2
run_solver p95000attrs$N_ATTRS $2
run_solver p100000attrs$N_ATTRS $2
}
prepare_new_run
run_all_graphs "lbp(shedule=seq_fixed) " seq_fixed

View File

@ -15,7 +15,7 @@
cpp_run_ground_solver/3, cpp_run_ground_solver/3,
cpp_set_vars_information/2, cpp_set_vars_information/2,
cpp_set_horus_flag/2, cpp_set_horus_flag/2,
cpp_free_parfactors/1, cpp_free_lifted_network/1,
cpp_free_ground_network/1 cpp_free_ground_network/1
]). ]).

View File

@ -79,7 +79,6 @@ run_solver(ground(Network,Hash), QueryKeys, Solutions) :-
list_of_keys_to_ids(QueryKeys, Hash, QueryIds), list_of_keys_to_ids(QueryKeys, Hash, QueryIds),
%writeln(queryKeys:QueryKeys), writeln(''), %writeln(queryKeys:QueryKeys), writeln(''),
%writeln(queryIds:QueryIds), writeln(''), %writeln(queryIds:QueryIds), writeln(''),
list_of_keys_to_ids(QueryKeys, Hash, QueryIds),
cpp_run_ground_solver(Network, [QueryIds], Solutions). cpp_run_ground_solver(Network, [QueryIds], Solutions).

View File

@ -17,7 +17,7 @@
[cpp_create_lifted_network/3, [cpp_create_lifted_network/3,
cpp_set_parfactors_params/2, cpp_set_parfactors_params/2,
cpp_run_lifted_solver/3, cpp_run_lifted_solver/3,
cpp_free_parfactors/1 cpp_free_lifted_network/1
]). ]).
:- use_module(library('clpbn/display'), :- use_module(library('clpbn/display'),
@ -144,5 +144,5 @@ run_horus_lifted_solver(QueryVarsAtts, Solutions, fove(ParfactorList, DistIds))
finalize_horus_lifted_solver(fove(ParfactorList, _)) :- finalize_horus_lifted_solver(fove(ParfactorList, _)) :-
cpp_free_parfactors(ParfactorList). cpp_free_lifted_network(ParfactorList).

View File

@ -7,17 +7,17 @@
:- yap_flag(write_strings, off). :- yap_flag(write_strings, off).
bayes burglary::[b1,b2] ; [0.001, 0.999] ; []. bayes burglary::[t,f] ; [0.001, 0.999] ; [].
bayes earthquake::[e1,e2] ; [0.002, 0.998]; []. bayes earthquake::[t,f] ; [0.002, 0.998]; [].
bayes alarm::[a1,a2], burglary, earthquake ; bayes alarm::[t,f], burglary, earthquake ;
[0.95, 0.94, 0.29, 0.001, 0.05, 0.06, 0.71, 0.999] ; [0.95, 0.94, 0.29, 0.001, 0.05, 0.06, 0.71, 0.999] ;
[]. [].
bayes john_calls::[j1,j2], alarm ; [0.9, 0.05, 0.1, 0.95] ; []. bayes john_calls::[t,f], alarm ; [0.9, 0.05, 0.1, 0.95] ; [].
bayes mary_calls::[m1,m2], alarm ; [0.7, 0.01, 0.3, 0.99] ; []. bayes mary_calls::[t,f], alarm ; [0.7, 0.01, 0.3, 0.99] ; [].
% ?- john_calls(J), mary_calls(m1). % ?- john_calls(J), mary_calls(t).

View File

@ -16,13 +16,13 @@ BayesBall::getMinimalFactorGraph (const VarIds& queryIds)
Scheduling scheduling; Scheduling scheduling;
for (size_t i = 0; i < queryIds.size(); i++) { for (size_t i = 0; i < queryIds.size(); i++) {
assert (dag_.getNode (queryIds[i])); assert (dag_.getNode (queryIds[i]));
DAGraphNode* n = dag_.getNode (queryIds[i]); BBNode* n = dag_.getNode (queryIds[i]);
scheduling.push (ScheduleInfo (n, false, true)); scheduling.push (ScheduleInfo (n, false, true));
} }
while (!scheduling.empty()) { while (!scheduling.empty()) {
ScheduleInfo& sch = scheduling.front(); ScheduleInfo& sch = scheduling.front();
DAGraphNode* n = sch.node; BBNode* n = sch.node;
n->setAsVisited(); n->setAsVisited();
if (n->hasEvidence() == false && sch.visitedFromChild) { if (n->hasEvidence() == false && sch.visitedFromChild) {
if (n->isMarkedOnTop() == false) { if (n->isMarkedOnTop() == false) {
@ -59,7 +59,7 @@ BayesBall::constructGraph (FactorGraph* fg) const
{ {
const FacNodes& facNodes = fg_.facNodes(); const FacNodes& facNodes = fg_.facNodes();
for (size_t i = 0; i < facNodes.size(); i++) { for (size_t i = 0; i < facNodes.size(); i++) {
const DAGraphNode* n = dag_.getNode ( const BBNode* n = dag_.getNode (
facNodes[i]->factor().argument (0)); facNodes[i]->factor().argument (0));
if (n->isMarkedOnTop()) { if (n->isMarkedOnTop()) {
fg->addFactor (facNodes[i]->factor()); fg->addFactor (facNodes[i]->factor());

View File

@ -7,7 +7,7 @@
#include <map> #include <map>
#include "FactorGraph.h" #include "FactorGraph.h"
#include "BayesNet.h" #include "BayesBallGraph.h"
#include "Horus.h" #include "Horus.h"
using namespace std; using namespace std;
@ -15,10 +15,10 @@ using namespace std;
struct ScheduleInfo struct ScheduleInfo
{ {
ScheduleInfo (DAGraphNode* n, bool vfp, bool vfc) : ScheduleInfo (BBNode* n, bool vfp, bool vfc) :
node(n), visitedFromParent(vfp), visitedFromChild(vfc) { } node(n), visitedFromParent(vfp), visitedFromChild(vfc) { }
DAGraphNode* node; BBNode* node;
bool visitedFromParent; bool visitedFromParent;
bool visitedFromChild; bool visitedFromChild;
}; };
@ -48,22 +48,22 @@ class BayesBall
void constructGraph (FactorGraph* fg) const; void constructGraph (FactorGraph* fg) const;
void scheduleParents (const DAGraphNode* n, Scheduling& sch) const; void scheduleParents (const BBNode* n, Scheduling& sch) const;
void scheduleChilds (const DAGraphNode* n, Scheduling& sch) const; void scheduleChilds (const BBNode* n, Scheduling& sch) const;
FactorGraph& fg_; FactorGraph& fg_;
DAGraph& dag_; BayesBallGraph& dag_;
}; };
inline void inline void
BayesBall::scheduleParents (const DAGraphNode* n, Scheduling& sch) const BayesBall::scheduleParents (const BBNode* n, Scheduling& sch) const
{ {
const vector<DAGraphNode*>& ps = n->parents(); const vector<BBNode*>& ps = n->parents();
for (vector<DAGraphNode*>::const_iterator it = ps.begin(); for (vector<BBNode*>::const_iterator it = ps.begin();
it != ps.end(); ++it) { it != ps.end(); ++it) {
sch.push (ScheduleInfo (*it, false, true)); sch.push (ScheduleInfo (*it, false, true));
} }
@ -72,10 +72,10 @@ BayesBall::scheduleParents (const DAGraphNode* n, Scheduling& sch) const
inline void inline void
BayesBall::scheduleChilds (const DAGraphNode* n, Scheduling& sch) const BayesBall::scheduleChilds (const BBNode* n, Scheduling& sch) const
{ {
const vector<DAGraphNode*>& cs = n->childs(); const vector<BBNode*>& cs = n->childs();
for (vector<DAGraphNode*>::const_iterator it = cs.begin(); for (vector<BBNode*>::const_iterator it = cs.begin();
it != cs.end(); ++it) { it != cs.end(); ++it) {
sch.push (ScheduleInfo (*it, true, false)); sch.push (ScheduleInfo (*it, true, false));
} }

View File

@ -5,12 +5,12 @@
#include <fstream> #include <fstream>
#include <sstream> #include <sstream>
#include "BayesNet.h" #include "BayesBallGraph.h"
#include "Util.h" #include "Util.h"
void void
DAGraph::addNode (DAGraphNode* n) BayesBallGraph::addNode (BBNode* n)
{ {
assert (Util::contains (varMap_, n->varId()) == false); assert (Util::contains (varMap_, n->varId()) == false);
nodes_.push_back (n); nodes_.push_back (n);
@ -20,10 +20,10 @@ DAGraph::addNode (DAGraphNode* n)
void void
DAGraph::addEdge (VarId vid1, VarId vid2) BayesBallGraph::addEdge (VarId vid1, VarId vid2)
{ {
unordered_map<VarId, DAGraphNode*>::iterator it1; unordered_map<VarId, BBNode*>::iterator it1;
unordered_map<VarId, DAGraphNode*>::iterator it2; unordered_map<VarId, BBNode*>::iterator it2;
it1 = varMap_.find (vid1); it1 = varMap_.find (vid1);
it2 = varMap_.find (vid2); it2 = varMap_.find (vid2);
assert (it1 != varMap_.end()); assert (it1 != varMap_.end());
@ -34,20 +34,20 @@ DAGraph::addEdge (VarId vid1, VarId vid2)
const DAGraphNode* const BBNode*
DAGraph::getNode (VarId vid) const BayesBallGraph::getNode (VarId vid) const
{ {
unordered_map<VarId, DAGraphNode*>::const_iterator it; unordered_map<VarId, BBNode*>::const_iterator it;
it = varMap_.find (vid); it = varMap_.find (vid);
return it != varMap_.end() ? it->second : 0; return it != varMap_.end() ? it->second : 0;
} }
DAGraphNode* BBNode*
DAGraph::getNode (VarId vid) BayesBallGraph::getNode (VarId vid)
{ {
unordered_map<VarId, DAGraphNode*>::const_iterator it; unordered_map<VarId, BBNode*>::const_iterator it;
it = varMap_.find (vid); it = varMap_.find (vid);
return it != varMap_.end() ? it->second : 0; return it != varMap_.end() ? it->second : 0;
} }
@ -55,7 +55,7 @@ DAGraph::getNode (VarId vid)
void void
DAGraph::setIndexes (void) BayesBallGraph::setIndexes (void)
{ {
for (size_t i = 0; i < nodes_.size(); i++) { for (size_t i = 0; i < nodes_.size(); i++) {
nodes_[i]->setIndex (i); nodes_[i]->setIndex (i);
@ -65,7 +65,7 @@ DAGraph::setIndexes (void)
void void
DAGraph::clear (void) BayesBallGraph::clear (void)
{ {
for (size_t i = 0; i < nodes_.size(); i++) { for (size_t i = 0; i < nodes_.size(); i++) {
nodes_[i]->clear(); nodes_[i]->clear();
@ -75,12 +75,12 @@ DAGraph::clear (void)
void void
DAGraph::exportToGraphViz (const char* fileName) BayesBallGraph::exportToGraphViz (const char* fileName)
{ {
ofstream out (fileName); ofstream out (fileName);
if (!out.is_open()) { if (!out.is_open()) {
cerr << "error: cannot open file to write at " ; cerr << "error: cannot open file to write at " ;
cerr << "DAGraph::exportToDotFile()" << endl; cerr << "BayesBallGraph::exportToDotFile()" << endl;
abort(); abort();
} }
out << "digraph {" << endl; out << "digraph {" << endl;
@ -95,7 +95,7 @@ DAGraph::exportToGraphViz (const char* fileName)
out << "]" << endl; out << "]" << endl;
} }
for (size_t i = 0; i < nodes_.size(); i++) { for (size_t i = 0; i < nodes_.size(); i++) {
const vector<DAGraphNode*>& childs = nodes_[i]->childs(); const vector<BBNode*>& childs = nodes_[i]->childs();
for (size_t j = 0; j < childs.size(); j++) { for (size_t j = 0; j < childs.size(); j++) {
out << nodes_[i]->varId() << " -> " << childs[j]->varId(); out << nodes_[i]->varId() << " -> " << childs[j]->varId();
out << " [style=bold]" << endl ; out << " [style=bold]" << endl ;

View File

@ -1,5 +1,5 @@
#ifndef HORUS_BAYESNET_H #ifndef HORUS_BAYESBALLGRAPH_H
#define HORUS_BAYESNET_H #define HORUS_BAYESBALLGRAPH_H
#include <vector> #include <vector>
#include <queue> #include <queue>
@ -9,29 +9,25 @@
#include "Var.h" #include "Var.h"
#include "Horus.h" #include "Horus.h"
using namespace std; using namespace std;
class BBNode : public Var
class Var;
class DAGraphNode : public Var
{ {
public: public:
DAGraphNode (Var* v) : Var (v) , visited_(false), BBNode (Var* v) : Var (v) , visited_(false),
markedOnTop_(false), markedOnBottom_(false) { } markedOnTop_(false), markedOnBottom_(false) { }
const vector<DAGraphNode*>& childs (void) const { return childs_; } const vector<BBNode*>& childs (void) const { return childs_; }
vector<DAGraphNode*>& childs (void) { return childs_; } vector<BBNode*>& childs (void) { return childs_; }
const vector<DAGraphNode*>& parents (void) const { return parents_; } const vector<BBNode*>& parents (void) const { return parents_; }
vector<DAGraphNode*>& parents (void) { return parents_; } vector<BBNode*>& parents (void) { return parents_; }
void addParent (DAGraphNode* p) { parents_.push_back (p); } void addParent (BBNode* p) { parents_.push_back (p); }
void addChild (DAGraphNode* c) { childs_.push_back (c); } void addChild (BBNode* c) { childs_.push_back (c); }
bool isVisited (void) const { return visited_; } bool isVisited (void) const { return visited_; }
@ -52,23 +48,23 @@ class DAGraphNode : public Var
bool markedOnTop_; bool markedOnTop_;
bool markedOnBottom_; bool markedOnBottom_;
vector<DAGraphNode*> childs_; vector<BBNode*> childs_;
vector<DAGraphNode*> parents_; vector<BBNode*> parents_;
}; };
class DAGraph class BayesBallGraph
{ {
public: public:
DAGraph (void) { } BayesBallGraph (void) { }
void addNode (DAGraphNode* n); void addNode (BBNode* n);
void addEdge (VarId vid1, VarId vid2); void addEdge (VarId vid1, VarId vid2);
const DAGraphNode* getNode (VarId vid) const; const BBNode* getNode (VarId vid) const;
DAGraphNode* getNode (VarId vid); BBNode* getNode (VarId vid);
bool empty (void) const { return nodes_.empty(); } bool empty (void) const { return nodes_.empty(); }
@ -79,10 +75,10 @@ class DAGraph
void exportToGraphViz (const char*); void exportToGraphViz (const char*);
private: private:
vector<DAGraphNode*> nodes_; vector<BBNode*> nodes_;
unordered_map<VarId, DAGraphNode*> varMap_; unordered_map<VarId, BBNode*> varMap_;
}; };
#endif // HORUS_BAYESNET_H #endif // HORUS_BAYESBALLGRAPH_H

View File

@ -5,21 +5,21 @@
#include <iostream> #include <iostream>
#include "BpSolver.h" #include "BeliefProp.h"
#include "FactorGraph.h" #include "FactorGraph.h"
#include "Factor.h" #include "Factor.h"
#include "Indexer.h" #include "Indexer.h"
#include "Horus.h" #include "Horus.h"
BpSolver::BpSolver (const FactorGraph& fg) : Solver (fg) BeliefProp::BeliefProp (const FactorGraph& fg) : Solver (fg)
{ {
runned_ = false; runned_ = false;
} }
BpSolver::~BpSolver (void) BeliefProp::~BeliefProp (void)
{ {
for (size_t i = 0; i < varsI_.size(); i++) { for (size_t i = 0; i < varsI_.size(); i++) {
delete varsI_[i]; delete varsI_[i];
@ -35,7 +35,7 @@ BpSolver::~BpSolver (void)
Params Params
BpSolver::solveQuery (VarIds queryVids) BeliefProp::solveQuery (VarIds queryVids)
{ {
assert (queryVids.empty() == false); assert (queryVids.empty() == false);
return queryVids.size() == 1 return queryVids.size() == 1
@ -46,7 +46,7 @@ BpSolver::solveQuery (VarIds queryVids)
void void
BpSolver::printSolverFlags (void) const BeliefProp::printSolverFlags (void) const
{ {
stringstream ss; stringstream ss;
ss << "belief propagation [" ; ss << "belief propagation [" ;
@ -68,7 +68,7 @@ BpSolver::printSolverFlags (void) const
Params Params
BpSolver::getPosterioriOf (VarId vid) BeliefProp::getPosterioriOf (VarId vid)
{ {
if (runned_ == false) { if (runned_ == false) {
runSolver(); runSolver();
@ -101,7 +101,7 @@ BpSolver::getPosterioriOf (VarId vid)
Params Params
BpSolver::getJointDistributionOf (const VarIds& jointVarIds) BeliefProp::getJointDistributionOf (const VarIds& jointVarIds)
{ {
if (runned_ == false) { if (runned_ == false) {
runSolver(); runSolver();
@ -117,9 +117,23 @@ BpSolver::getJointDistributionOf (const VarIds& jointVarIds)
} }
if (idx == facNodes.size()) { if (idx == facNodes.size()) {
return getJointByConditioning (jointVarIds); return getJointByConditioning (jointVarIds);
} else { }
Factor res (facNodes[idx]->factor()); return getFactorJoint (idx, jointVarIds);
const BpLinks& links = ninf(facNodes[idx])->getLinks(); }
Params
BeliefProp::getFactorJoint (
size_t fnIdx,
const VarIds& jointVarIds)
{
if (runned_ == false) {
runSolver();
}
FacNode* fn = fg.facNodes()[fnIdx];
Factor res (fn->factor());
const BpLinks& links = ninf(fn)->getLinks();
for (size_t i = 0; i < links.size(); i++) { for (size_t i = 0; i < links.size(); i++) {
Factor msg ({links[i]->varNode()->varId()}, Factor msg ({links[i]->varNode()->varId()},
{links[i]->varNode()->range()}, {links[i]->varNode()->range()},
@ -135,12 +149,11 @@ BpSolver::getJointDistributionOf (const VarIds& jointVarIds)
} }
return jointDist; return jointDist;
} }
}
void void
BpSolver::runSolver (void) BeliefProp::runSolver (void)
{ {
initializeSolver(); initializeSolver();
nIters_ = 0; nIters_ = 0;
@ -173,7 +186,7 @@ BpSolver::runSolver (void)
} }
if (Globals::verbosity > 0) { if (Globals::verbosity > 0) {
if (nIters_ < BpOptions::maxIter) { if (nIters_ < BpOptions::maxIter) {
cout << "Sum-Product converged in " ; cout << "Belief propagation converged in " ;
cout << nIters_ << " iterations" << endl; cout << nIters_ << " iterations" << endl;
} else { } else {
cout << "The maximum number of iterations was hit, terminating..." ; cout << "The maximum number of iterations was hit, terminating..." ;
@ -187,7 +200,7 @@ BpSolver::runSolver (void)
void void
BpSolver::createLinks (void) BeliefProp::createLinks (void)
{ {
const FacNodes& facNodes = fg.facNodes(); const FacNodes& facNodes = fg.facNodes();
for (size_t i = 0; i < facNodes.size(); i++) { for (size_t i = 0; i < facNodes.size(); i++) {
@ -201,7 +214,7 @@ BpSolver::createLinks (void)
void void
BpSolver::maxResidualSchedule (void) BeliefProp::maxResidualSchedule (void)
{ {
if (nIters_ == 1) { if (nIters_ == 1) {
for (size_t i = 0; i < links_.size(); i++) { for (size_t i = 0; i < links_.size(); i++) {
@ -256,7 +269,7 @@ BpSolver::maxResidualSchedule (void)
void void
BpSolver::calcFactorToVarMsg (BpLink* link) BeliefProp::calcFactorToVarMsg (BpLink* link)
{ {
FacNode* src = link->facNode(); FacNode* src = link->facNode();
const VarNode* dst = link->varNode(); const VarNode* dst = link->varNode();
@ -320,7 +333,7 @@ BpSolver::calcFactorToVarMsg (BpLink* link)
Params Params
BpSolver::getVarToFactorMsg (const BpLink* link) const BeliefProp::getVarToFactorMsg (const BpLink* link) const
{ {
const VarNode* src = link->varNode(); const VarNode* src = link->varNode();
Params msg; Params msg;
@ -361,61 +374,15 @@ BpSolver::getVarToFactorMsg (const BpLink* link) const
Params Params
BpSolver::getJointByConditioning (const VarIds& jointVarIds) const BeliefProp::getJointByConditioning (const VarIds& jointVarIds) const
{ {
VarNodes jointVars; return Solver::getJointByConditioning (GroundSolver::BP, fg, jointVarIds);
for (size_t i = 0; i < jointVarIds.size(); i++) {
assert (fg.getVarNode (jointVarIds[i]));
jointVars.push_back (fg.getVarNode (jointVarIds[i]));
}
FactorGraph* tempFg = new FactorGraph (fg);
BpSolver solver (*tempFg);
solver.runSolver();
Params prevBeliefs = solver.getPosterioriOf (jointVarIds[0]);
VarIds observedVids = {jointVars[0]->varId()};
for (size_t i = 1; i < jointVarIds.size(); i++) {
assert (jointVars[i]->hasEvidence() == false);
Params newBeliefs;
Vars observedVars;
Ranges observedRanges;
for (size_t j = 0; j < observedVids.size(); j++) {
observedVars.push_back (tempFg->getVarNode (observedVids[j]));
observedRanges.push_back (observedVars.back()->range());
}
Indexer indexer (observedRanges, false);
while (indexer.valid()) {
for (size_t j = 0; j < observedVars.size(); j++) {
observedVars[j]->setEvidence (indexer[j]);
}
BpSolver solver (*tempFg);
solver.runSolver();
Params beliefs = solver.getPosterioriOf (jointVarIds[i]);
for (size_t k = 0; k < beliefs.size(); k++) {
newBeliefs.push_back (beliefs[k]);
}
++ indexer;
}
int count = -1;
for (size_t j = 0; j < newBeliefs.size(); j++) {
if (j % jointVars[i]->range() == 0) {
count ++;
}
newBeliefs[j] *= prevBeliefs[count];
}
prevBeliefs = newBeliefs;
observedVids.push_back (jointVars[i]->varId());
}
return prevBeliefs;
} }
void void
BpSolver::initializeSolver (void) BeliefProp::initializeSolver (void)
{ {
const VarNodes& varNodes = fg.varNodes(); const VarNodes& varNodes = fg.varNodes();
varsI_.reserve (varNodes.size()); varsI_.reserve (varNodes.size());
@ -439,7 +406,7 @@ BpSolver::initializeSolver (void)
bool bool
BpSolver::converged (void) BeliefProp::converged (void)
{ {
if (links_.size() == 0) { if (links_.size() == 0) {
return true; return true;
@ -487,7 +454,7 @@ BpSolver::converged (void)
void void
BpSolver::printLinkInformation (void) const BeliefProp::printLinkInformation (void) const
{ {
for (size_t i = 0; i < links_.size(); i++) { for (size_t i = 0; i < links_.size(); i++) {
BpLink* l = links_[i]; BpLink* l = links_[i];

View File

@ -1,5 +1,5 @@
#ifndef HORUS_BPSOLVER_H #ifndef HORUS_BELIEFPROP_H
#define HORUS_BPSOLVER_H #define HORUS_BELIEFPROP_H
#include <set> #include <set>
#include <vector> #include <vector>
@ -83,12 +83,12 @@ class SPNodeInfo
}; };
class BpSolver : public Solver class BeliefProp : public Solver
{ {
public: public:
BpSolver (const FactorGraph&); BeliefProp (const FactorGraph&);
virtual ~BpSolver (void); virtual ~BeliefProp (void);
Params solveQuery (VarIds); Params solveQuery (VarIds);
@ -111,6 +111,10 @@ class BpSolver : public Solver
virtual Params getJointByConditioning (const VarIds&) const; virtual Params getJointByConditioning (const VarIds&) const;
public:
Params getFactorJoint (size_t fnIdx, const VarIds&);
protected:
SPNodeInfo* ninf (const VarNode* var) const SPNodeInfo* ninf (const VarNode* var) const
{ {
return varsI_[var->getIndex()]; return varsI_[var->getIndex()];
@ -180,5 +184,5 @@ class BpSolver : public Solver
virtual void printLinkInformation (void) const; virtual void printLinkInformation (void) const;
}; };
#endif // HORUS_BPSOLVER_H #endif // HORUS_BELIEFPROP_H

View File

@ -1,23 +1,23 @@
#include "CbpSolver.h" #include "CountingBp.h"
#include "WeightedBpSolver.h" #include "WeightedBp.h"
bool CbpSolver::checkForIdenticalFactors = true; bool CountingBp::checkForIdenticalFactors = true;
CbpSolver::CbpSolver (const FactorGraph& fg) CountingBp::CountingBp (const FactorGraph& fg)
: Solver (fg), freeColor_(0) : Solver (fg), freeColor_(0)
{ {
findIdenticalFactors(); findIdenticalFactors();
setInitialColors(); setInitialColors();
createGroups(); createGroups();
compressedFg_ = getCompressedFactorGraph(); compressedFg_ = getCompressedFactorGraph();
solver_ = new WeightedBpSolver (*compressedFg_, getWeights()); solver_ = new WeightedBp (*compressedFg_, getWeights());
} }
CbpSolver::~CbpSolver (void) CountingBp::~CountingBp (void)
{ {
delete solver_; delete solver_;
delete compressedFg_; delete compressedFg_;
@ -32,7 +32,7 @@ CbpSolver::~CbpSolver (void)
void void
CbpSolver::printSolverFlags (void) const CountingBp::printSolverFlags (void) const
{ {
stringstream ss; stringstream ss;
ss << "counting bp [" ; ss << "counting bp [" ;
@ -48,7 +48,7 @@ CbpSolver::printSolverFlags (void) const
ss << ",accuracy=" << BpOptions::accuracy; ss << ",accuracy=" << BpOptions::accuracy;
ss << ",log_domain=" << Util::toString (Globals::logDomain); ss << ",log_domain=" << Util::toString (Globals::logDomain);
ss << ",chkif=" << ss << ",chkif=" <<
Util::toString (CbpSolver::checkForIdenticalFactors); Util::toString (CountingBp::checkForIdenticalFactors);
ss << "]" ; ss << "]" ;
cout << ss.str() << endl; cout << ss.str() << endl;
} }
@ -56,7 +56,7 @@ CbpSolver::printSolverFlags (void) const
Params Params
CbpSolver::solveQuery (VarIds queryVids) CountingBp::solveQuery (VarIds queryVids)
{ {
assert (queryVids.empty() == false); assert (queryVids.empty() == false);
Params res; Params res;
@ -74,16 +74,15 @@ CbpSolver::solveQuery (VarIds queryVids)
cout << endl; cout << endl;
} }
if (idx == facNodes.size()) { if (idx == facNodes.size()) {
cerr << "error: only joint distributions on variables of some " ; res = Solver::getJointByConditioning (
cerr << "clique are supported with the current solver" ; GroundSolver::CBP, fg, queryVids);
cerr << endl; } else {
exit (1); VarIds reprArgs;
}
VarIds representatives;
for (size_t i = 0; i < queryVids.size(); i++) { for (size_t i = 0; i < queryVids.size(); i++) {
representatives.push_back (getRepresentative (queryVids[i])); reprArgs.push_back (getRepresentative (queryVids[i]));
}
res = solver_->getFactorJoint (idx, reprArgs);
} }
res = solver_->getJointDistributionOf (representatives);
} }
return res; return res;
} }
@ -91,7 +90,7 @@ CbpSolver::solveQuery (VarIds queryVids)
void void
CbpSolver::findIdenticalFactors() CountingBp::findIdenticalFactors()
{ {
const FacNodes& facNodes = fg.facNodes(); const FacNodes& facNodes = fg.facNodes();
if (checkForIdenticalFactors == false || if (checkForIdenticalFactors == false ||
@ -126,7 +125,7 @@ CbpSolver::findIdenticalFactors()
void void
CbpSolver::setInitialColors (void) CountingBp::setInitialColors (void)
{ {
varColors_.resize (fg.nrVarNodes()); varColors_.resize (fg.nrVarNodes());
facColors_.resize (fg.nrFacNodes()); facColors_.resize (fg.nrFacNodes());
@ -165,7 +164,7 @@ CbpSolver::setInitialColors (void)
void void
CbpSolver::createGroups (void) CountingBp::createGroups (void)
{ {
VarSignMap varGroups; VarSignMap varGroups;
FacSignMap facGroups; FacSignMap facGroups;
@ -227,7 +226,7 @@ CbpSolver::createGroups (void)
void void
CbpSolver::createClusters ( CountingBp::createClusters (
const VarSignMap& varGroups, const VarSignMap& varGroups,
const FacSignMap& facGroups) const FacSignMap& facGroups)
{ {
@ -260,7 +259,7 @@ CbpSolver::createClusters (
VarSignature VarSignature
CbpSolver::getSignature (const VarNode* varNode) CountingBp::getSignature (const VarNode* varNode)
{ {
const FacNodes& neighs = varNode->neighbors(); const FacNodes& neighs = varNode->neighbors();
VarSignature sign; VarSignature sign;
@ -278,7 +277,7 @@ CbpSolver::getSignature (const VarNode* varNode)
FacSignature FacSignature
CbpSolver::getSignature (const FacNode* facNode) CountingBp::getSignature (const FacNode* facNode)
{ {
const VarNodes& neighs = facNode->neighbors(); const VarNodes& neighs = facNode->neighbors();
FacSignature sign; FacSignature sign;
@ -292,8 +291,31 @@ CbpSolver::getSignature (const FacNode* facNode)
VarId
CountingBp::getRepresentative (VarId vid)
{
assert (Util::contains (vid2VarCluster_, vid));
VarCluster* vc = vid2VarCluster_.find (vid)->second;
return vc->representative()->varId();
}
FacNode*
CountingBp::getRepresentative (FacNode* fn)
{
for (size_t i = 0; i < facClusters_.size(); i++) {
if (Util::contains (facClusters_[i]->members(), fn)) {
return facClusters_[i]->representative();
}
}
return 0;
}
FactorGraph* FactorGraph*
CbpSolver::getCompressedFactorGraph (void) CountingBp::getCompressedFactorGraph (void)
{ {
FactorGraph* fg = new FactorGraph(); FactorGraph* fg = new FactorGraph();
for (size_t i = 0; i < varClusters_.size(); i++) { for (size_t i = 0; i < varClusters_.size(); i++) {
@ -322,7 +344,7 @@ CbpSolver::getCompressedFactorGraph (void)
vector<vector<unsigned>> vector<vector<unsigned>>
CbpSolver::getWeights (void) const CountingBp::getWeights (void) const
{ {
vector<vector<unsigned>> weights; vector<vector<unsigned>> weights;
weights.reserve (facClusters_.size()); weights.reserve (facClusters_.size());
@ -341,7 +363,7 @@ CbpSolver::getWeights (void) const
unsigned unsigned
CbpSolver::getWeight ( CountingBp::getWeight (
const FacCluster* fc, const FacCluster* fc,
const VarCluster* vc, const VarCluster* vc,
size_t index) const size_t index) const
@ -364,7 +386,7 @@ CbpSolver::getWeight (
void void
CbpSolver::printGroups ( CountingBp::printGroups (
const VarSignMap& varGroups, const VarSignMap& varGroups,
const FacSignMap& facGroups) const const FacSignMap& facGroups) const
{ {

View File

@ -1,5 +1,5 @@
#ifndef HORUS_CBPSOLVER_H #ifndef HORUS_COUNTINGBP_H
#define HORUS_CBPSOLVER_H #define HORUS_COUNTINGBP_H
#include <unordered_map> #include <unordered_map>
@ -12,7 +12,7 @@ class VarCluster;
class FacCluster; class FacCluster;
class VarSignHash; class VarSignHash;
class FacSignHash; class FacSignHash;
class WeightedBpSolver; class WeightedBp;
typedef long Color; typedef long Color;
typedef vector<Color> Colors; typedef vector<Color> Colors;
@ -100,12 +100,12 @@ class FacCluster
}; };
class CbpSolver : public Solver class CountingBp : public Solver
{ {
public: public:
CbpSolver (const FactorGraph& fg); CountingBp (const FactorGraph& fg);
~CbpSolver (void); ~CountingBp (void);
void printSolverFlags (void) const; void printSolverFlags (void) const;
@ -154,12 +154,9 @@ class CbpSolver : public Solver
void printGroups (const VarSignMap&, const FacSignMap&) const; void printGroups (const VarSignMap&, const FacSignMap&) const;
VarId getRepresentative (VarId vid) VarId getRepresentative (VarId vid);
{
assert (Util::contains (vid2VarCluster_, vid)); FacNode* getRepresentative (FacNode*);
VarCluster* vc = vid2VarCluster_.find (vid)->second;
return vc->representative()->varId();
}
FactorGraph* getCompressedFactorGraph (void); FactorGraph* getCompressedFactorGraph (void);
@ -176,8 +173,8 @@ class CbpSolver : public Solver
FacClusters facClusters_; FacClusters facClusters_;
VarId2VarCluster vid2VarCluster_; VarId2VarCluster vid2VarCluster_;
const FactorGraph* compressedFg_; const FactorGraph* compressedFg_;
WeightedBpSolver* solver_; WeightedBp* solver_;
}; };
#endif // HORUS_CBPSOLVER_H #endif // HORUS_COUNTINGBP_H

View File

@ -8,7 +8,6 @@
#include "FactorGraph.h" #include "FactorGraph.h"
#include "Factor.h" #include "Factor.h"
#include "BayesNet.h"
#include "BayesBall.h" #include "BayesBall.h"
#include "Util.h" #include "Util.h"
@ -236,13 +235,13 @@ FactorGraph::isTree (void) const
DAGraph& BayesBallGraph&
FactorGraph::getStructure (void) FactorGraph::getStructure (void)
{ {
assert (bayesFactors_); assert (bayesFactors_);
if (structure_.empty()) { if (structure_.empty()) {
for (size_t i = 0; i < varNodes_.size(); i++) { for (size_t i = 0; i < varNodes_.size(); i++) {
structure_.addNode (new DAGraphNode (varNodes_[i])); structure_.addNode (new BBNode (varNodes_[i]));
} }
for (size_t i = 0; i < facNodes_.size(); i++) { for (size_t i = 0; i < facNodes_.size(); i++) {
const VarIds& vids = facNodes_[i]->factor().arguments(); const VarIds& vids = facNodes_[i]->factor().arguments();

View File

@ -4,7 +4,7 @@
#include <vector> #include <vector>
#include "Factor.h" #include "Factor.h"
#include "BayesNet.h" #include "BayesBallGraph.h"
#include "Horus.h" #include "Horus.h"
using namespace std; using namespace std;
@ -103,7 +103,7 @@ class FactorGraph
bool isTree (void) const; bool isTree (void) const;
DAGraph& getStructure (void); BayesBallGraph& getStructure (void);
void print (void) const; void print (void) const;
@ -129,7 +129,7 @@ class FactorGraph
VarNodes varNodes_; VarNodes varNodes_;
FacNodes facNodes_; FacNodes facNodes_;
DAGraph structure_; BayesBallGraph structure_;
bool bayesFactors_; bool bayesFactors_;
typedef unordered_map<unsigned, VarNode*> VarMap; typedef unordered_map<unsigned, VarNode*> VarMap;

View File

@ -28,14 +28,14 @@ typedef vector<unsigned> Ranges;
typedef unsigned long long ullong; typedef unsigned long long ullong;
enum LiftedSolvers enum LiftedSolver
{ {
FOVE, // first order variable elimination FOVE, // first order variable elimination
LBP, // lifted belief propagation LBP, // lifted belief propagation
}; };
enum GroundSolvers enum GroundSolver
{ {
VE, // variable elimination VE, // variable elimination
BP, // belief propagation BP, // belief propagation
@ -50,8 +50,8 @@ extern bool logDomain;
// level of debug information // level of debug information
extern unsigned verbosity; extern unsigned verbosity;
extern LiftedSolvers liftedSolver; extern LiftedSolver liftedSolver;
extern GroundSolvers groundSolver; extern GroundSolver groundSolver;
}; };

View File

@ -4,9 +4,9 @@
#include <sstream> #include <sstream>
#include "FactorGraph.h" #include "FactorGraph.h"
#include "VarElimSolver.h" #include "VarElim.h"
#include "BpSolver.h" #include "BeliefProp.h"
#include "CbpSolver.h" #include "CountingBp.h"
using namespace std; using namespace std;
@ -162,14 +162,14 @@ runSolver (const FactorGraph& fg, const VarIds& queryIds)
{ {
Solver* solver = 0; Solver* solver = 0;
switch (Globals::groundSolver) { switch (Globals::groundSolver) {
case GroundSolvers::VE: case GroundSolver::VE:
solver = new VarElimSolver (fg); solver = new VarElim (fg);
break; break;
case GroundSolvers::BP: case GroundSolver::BP:
solver = new BpSolver (fg); solver = new BeliefProp (fg);
break; break;
case GroundSolvers::CBP: case GroundSolver::CBP:
solver = new CbpSolver (fg); solver = new CountingBp (fg);
break; break;
default: default:
assert (false); assert (false);

View File

@ -9,21 +9,19 @@
#include "ParfactorList.h" #include "ParfactorList.h"
#include "FactorGraph.h" #include "FactorGraph.h"
#include "FoveSolver.h" #include "LiftedVe.h"
#include "VarElimSolver.h" #include "VarElim.h"
#include "LiftedBpSolver.h" #include "LiftedBp.h"
#include "BpSolver.h" #include "CountingBp.h"
#include "CbpSolver.h" #include "BeliefProp.h"
#include "ElimGraph.h" #include "ElimGraph.h"
#include "BayesBall.h" #include "BayesBall.h"
using namespace std; using namespace std;
typedef std::pair<ParfactorList*, ObservedFormulas*> LiftedNetwork; typedef std::pair<ParfactorList*, ObservedFormulas*> LiftedNetwork;
Params readParameters (YAP_Term); Params readParameters (YAP_Term);
vector<unsigned> readUnsignedList (YAP_Term); vector<unsigned> readUnsignedList (YAP_Term);
@ -32,14 +30,6 @@ void readLiftedEvidence (YAP_Term, ObservedFormulas&);
Parfactor* readParfactor (YAP_Term); Parfactor* readParfactor (YAP_Term);
void runVeSolver (FactorGraph* fg, const vector<VarIds>& tasks,
vector<Params>& results);
void runBpSolver (FactorGraph* fg, const vector<VarIds>& tasks,
vector<Params>& results);
vector<unsigned> vector<unsigned>
readUnsignedList (YAP_Term list) readUnsignedList (YAP_Term list)
@ -54,7 +44,8 @@ readUnsignedList (YAP_Term list)
int createLiftedNetwork (void) int
createLiftedNetwork (void)
{ {
Parfactors parfactors; Parfactors parfactors;
YAP_Term parfactorList = YAP_ARG1; YAP_Term parfactorList = YAP_ARG1;
@ -91,7 +82,8 @@ int createLiftedNetwork (void)
Parfactor* readParfactor (YAP_Term pfTerm) Parfactor*
readParfactor (YAP_Term pfTerm)
{ {
// read dist id // read dist id
unsigned distId = YAP_IntOfTerm (YAP_ArgOfTerm (1, pfTerm)); unsigned distId = YAP_IntOfTerm (YAP_ArgOfTerm (1, pfTerm));
@ -171,7 +163,8 @@ Parfactor* readParfactor (YAP_Term pfTerm)
void readLiftedEvidence ( void
readLiftedEvidence (
YAP_Term observedList, YAP_Term observedList,
ObservedFormulas& obsFormulas) ObservedFormulas& obsFormulas)
{ {
@ -237,7 +230,6 @@ createGroundNetwork (void)
fg->addFactor (Factor (varIds, ranges, params, distId)); fg->addFactor (Factor (varIds, ranges, params, distId));
factorList = YAP_TailOfTerm (factorList); factorList = YAP_TailOfTerm (factorList);
} }
unsigned nrObservedVars = 0; unsigned nrObservedVars = 0;
YAP_Term evidenceList = YAP_ARG3; YAP_Term evidenceList = YAP_ARG3;
while (evidenceList != YAP_TermNil()) { while (evidenceList != YAP_TermNil()) {
@ -285,7 +277,7 @@ runLiftedSolver (void)
YAP_Term taskList = YAP_ARG2; YAP_Term taskList = YAP_ARG2;
vector<Params> results; vector<Params> results;
ParfactorList pfListCopy (*network->first); ParfactorList pfListCopy (*network->first);
FoveSolver::absorveEvidence (pfListCopy, *network->second); LiftedVe::absorveEvidence (pfListCopy, *network->second);
while (taskList != YAP_TermNil()) { while (taskList != YAP_TermNil()) {
Grounds queryVars; Grounds queryVars;
YAP_Term jointList = YAP_HeadOfTerm (taskList); YAP_Term jointList = YAP_HeadOfTerm (taskList);
@ -311,15 +303,15 @@ runLiftedSolver (void)
} }
jointList = YAP_TailOfTerm (jointList); jointList = YAP_TailOfTerm (jointList);
} }
if (Globals::liftedSolver == LiftedSolvers::FOVE) { if (Globals::liftedSolver == LiftedSolver::FOVE) {
FoveSolver solver (pfListCopy); LiftedVe solver (pfListCopy);
if (Globals::verbosity > 0 && taskList == YAP_ARG2) { if (Globals::verbosity > 0 && taskList == YAP_ARG2) {
solver.printSolverFlags(); solver.printSolverFlags();
cout << endl; cout << endl;
} }
results.push_back (solver.solveQuery (queryVars)); results.push_back (solver.solveQuery (queryVars));
} else if (Globals::liftedSolver == LiftedSolvers::LBP) { } else if (Globals::liftedSolver == LiftedSolver::LBP) {
LiftedBpSolver solver (pfListCopy); LiftedBp solver (pfListCopy);
if (Globals::verbosity > 0 && taskList == YAP_ARG2) { if (Globals::verbosity > 0 && taskList == YAP_ARG2) {
solver.printSolverFlags(); solver.printSolverFlags();
cout << endl; cout << endl;
@ -361,11 +353,42 @@ runGroundSolver (void)
taskList = YAP_TailOfTerm (taskList); taskList = YAP_TailOfTerm (taskList);
} }
vector<Params> results; std::set<VarId> vids;
if (Globals::groundSolver == GroundSolvers::VE) { for (size_t i = 0; i < tasks.size(); i++) {
runVeSolver (fg, tasks, results); Util::addToSet (vids, tasks[i]);
}
Solver* solver = 0;
FactorGraph* mfg = fg;
if (fg->bayesianFactors()) {
mfg = BayesBall::getMinimalFactorGraph (
*fg, VarIds (vids.begin(), vids.end()));
}
if (Globals::groundSolver == GroundSolver::VE) {
solver = new VarElim (*mfg);
} else if (Globals::groundSolver == GroundSolver::BP) {
solver = new BeliefProp (*mfg);
} else if (Globals::groundSolver == GroundSolver::CBP) {
CountingBp::checkForIdenticalFactors = false;
solver = new CountingBp (*mfg);
} else { } else {
runBpSolver (fg, tasks, results); assert (false);
}
if (Globals::verbosity > 0) {
solver->printSolverFlags();
cout << endl;
}
vector<Params> results;
results.reserve (tasks.size());
for (size_t i = 0; i < tasks.size(); i++) {
results.push_back (solver->solveQuery (tasks[i]));
}
delete solver;
if (fg->bayesianFactors()) {
delete mfg;
} }
YAP_Term list = YAP_TermNil(); YAP_Term list = YAP_TermNil();
@ -386,72 +409,6 @@ runGroundSolver (void)
void runVeSolver (
FactorGraph* fg,
const vector<VarIds>& tasks,
vector<Params>& results)
{
results.reserve (tasks.size());
for (size_t i = 0; i < tasks.size(); i++) {
FactorGraph* mfg = fg;
if (fg->bayesianFactors()) {
// mfg = BayesBall::getMinimalFactorGraph (*fg, tasks[i]);
}
// VarElimSolver solver (*mfg);
VarElimSolver solver (*fg); //FIXME
if (Globals::verbosity > 0 && i == 0) {
solver.printSolverFlags();
cout << endl;
}
results.push_back (solver.solveQuery (tasks[i]));
if (fg->bayesianFactors()) {
// delete mfg;
}
}
}
void runBpSolver (
FactorGraph* fg,
const vector<VarIds>& tasks,
vector<Params>& results)
{
std::set<VarId> vids;
for (size_t i = 0; i < tasks.size(); i++) {
Util::addToSet (vids, tasks[i]);
}
Solver* solver = 0;
FactorGraph* mfg = fg;
if (fg->bayesianFactors()) {
//mfg = BayesBall::getMinimalFactorGraph (
// *fg, VarIds (vids.begin(),vids.end()));
}
if (Globals::groundSolver == GroundSolvers::BP) {
solver = new BpSolver (*fg); // FIXME
} else if (Globals::groundSolver == GroundSolvers::CBP) {
CbpSolver::checkForIdenticalFactors = false;
solver = new CbpSolver (*fg); // FIXME
} else {
cerr << "error: unknow solver" << endl;
abort();
}
if (Globals::verbosity > 0) {
solver->printSolverFlags();
cout << endl;
}
results.reserve (tasks.size());
for (size_t i = 0; i < tasks.size(); i++) {
results.push_back (solver->solveQuery (tasks[i]));
}
if (fg->bayesianFactors()) {
//delete mfg;
}
delete solver;
}
int int
setParfactorsParams (void) setParfactorsParams (void)
{ {
@ -567,7 +524,7 @@ freeGroundNetwork (void)
int int
freeParfactors (void) freeLiftedNetwork (void)
{ {
LiftedNetwork* network = (LiftedNetwork*) YAP_IntOfTerm (YAP_ARG1); LiftedNetwork* network = (LiftedNetwork*) YAP_IntOfTerm (YAP_ARG1);
delete network->first; delete network->first;
@ -589,7 +546,7 @@ init_predicates (void)
YAP_UserCPredicate ("cpp_cpp_set_factors_params", setFactorsParams, 2); YAP_UserCPredicate ("cpp_cpp_set_factors_params", setFactorsParams, 2);
YAP_UserCPredicate ("cpp_set_vars_information", setVarsInformation, 2); YAP_UserCPredicate ("cpp_set_vars_information", setVarsInformation, 2);
YAP_UserCPredicate ("cpp_set_horus_flag", setHorusFlag, 2); YAP_UserCPredicate ("cpp_set_horus_flag", setHorusFlag, 2);
YAP_UserCPredicate ("cpp_free_parfactors", freeParfactors, 1); YAP_UserCPredicate ("cpp_free_lifted_network", freeLiftedNetwork, 1);
YAP_UserCPredicate ("cpp_free_ground_network", freeGroundNetwork, 1); YAP_UserCPredicate ("cpp_free_ground_network", freeGroundNetwork, 1);
} }

View File

@ -0,0 +1,232 @@
#include "LiftedBp.h"
#include "WeightedBp.h"
#include "FactorGraph.h"
#include "LiftedVe.h"
LiftedBp::LiftedBp (const ParfactorList& pfList)
: pfList_(pfList)
{
refineParfactors();
solver_ = new WeightedBp (*getFactorGraph(), getWeights());
}
LiftedBp::~LiftedBp (void)
{
delete solver_;
}
Params
LiftedBp::solveQuery (const Grounds& query)
{
assert (query.empty() == false);
Params res;
vector<PrvGroup> groups = getQueryGroups (query);
if (query.size() == 1) {
res = solver_->getPosterioriOf (groups[0]);
} else {
ParfactorList::iterator it = pfList_.begin();
size_t idx = pfList_.size();
size_t count = 0;
while (it != pfList_.end()) {
if ((*it)->containsGrounds (query)) {
idx = count;
break;
}
++ it;
++ count;
}
if (idx == pfList_.size()) {
res = getJointByConditioning (pfList_, query);
} else {
VarIds queryVids;
for (unsigned i = 0; i < groups.size(); i++) {
queryVids.push_back (groups[i]);
}
res = solver_->getFactorJoint (idx, queryVids);
}
}
return res;
}
void
LiftedBp::printSolverFlags (void) const
{
stringstream ss;
ss << "lifted bp [" ;
ss << "schedule=" ;
typedef BpOptions::Schedule Sch;
switch (BpOptions::schedule) {
case Sch::SEQ_FIXED: ss << "seq_fixed"; break;
case Sch::SEQ_RANDOM: ss << "seq_random"; break;
case Sch::PARALLEL: ss << "parallel"; break;
case Sch::MAX_RESIDUAL: ss << "max_residual"; break;
}
ss << ",max_iter=" << BpOptions::maxIter;
ss << ",accuracy=" << BpOptions::accuracy;
ss << ",log_domain=" << Util::toString (Globals::logDomain);
ss << "]" ;
cout << ss.str() << endl;
}
void
LiftedBp::refineParfactors (void)
{
while (iterate() == false);
if (Globals::verbosity > 2) {
Util::printHeader ("AFTER REFINEMENT");
pfList_.print();
}
}
bool
LiftedBp::iterate (void)
{
ParfactorList::iterator it = pfList_.begin();
while (it != pfList_.end()) {
const ProbFormulas& args = (*it)->arguments();
for (size_t i = 0; i < args.size(); i++) {
LogVarSet lvs = (*it)->logVarSet() - args[i].logVars();
if ((*it)->constr()->isCountNormalized (lvs) == false) {
Parfactors pfs = LiftedVe::countNormalize (*it, lvs);
it = pfList_.removeAndDelete (it);
pfList_.add (pfs);
return false;
}
}
++ it;
}
return true;
}
vector<PrvGroup>
LiftedBp::getQueryGroups (const Grounds& query)
{
vector<PrvGroup> queryGroups;
for (unsigned i = 0; i < query.size(); i++) {
ParfactorList::const_iterator it = pfList_.begin();
for (; it != pfList_.end(); ++it) {
if ((*it)->containsGround (query[i])) {
queryGroups.push_back ((*it)->findGroup (query[i]));
break;
}
}
}
assert (queryGroups.size() == query.size());
return queryGroups;
}
FactorGraph*
LiftedBp::getFactorGraph (void)
{
FactorGraph* fg = new FactorGraph();
ParfactorList::const_iterator it = pfList_.begin();
for (; it != pfList_.end(); ++it) {
vector<PrvGroup> groups = (*it)->getAllGroups();
VarIds varIds;
for (size_t i = 0; i < groups.size(); i++) {
varIds.push_back (groups[i]);
}
fg->addFactor (Factor (varIds, (*it)->ranges(), (*it)->params()));
}
return fg;
}
vector<vector<unsigned>>
LiftedBp::getWeights (void) const
{
vector<vector<unsigned>> weights;
weights.reserve (pfList_.size());
ParfactorList::const_iterator it = pfList_.begin();
for (; it != pfList_.end(); ++it) {
const ProbFormulas& args = (*it)->arguments();
weights.push_back ({ });
weights.back().reserve (args.size());
for (size_t i = 0; i < args.size(); i++) {
LogVarSet lvs = (*it)->logVarSet() - args[i].logVars();
weights.back().push_back ((*it)->constr()->getConditionalCount (lvs));
}
}
return weights;
}
unsigned
LiftedBp::rangeOfGround (const Ground& gr)
{
ParfactorList::iterator it = pfList_.begin();
while (it != pfList_.end()) {
if ((*it)->containsGround (gr)) {
PrvGroup prvGroup = (*it)->findGroup (gr);
return (*it)->range ((*it)->indexOfGroup (prvGroup));
}
++ it;
}
return std::numeric_limits<unsigned>::max();
}
Params
LiftedBp::getJointByConditioning (
const ParfactorList& pfList,
const Grounds& grounds)
{
LiftedBp solver (pfList);
Params prevBeliefs = solver.solveQuery ({grounds[0]});
Grounds obsGrounds = {grounds[0]};
for (size_t i = 1; i < grounds.size(); i++) {
Params newBeliefs;
vector<ObservedFormula> obsFs;
Ranges obsRanges;
for (size_t j = 0; j < obsGrounds.size(); j++) {
obsFs.push_back (ObservedFormula (
obsGrounds[j].functor(), 0, obsGrounds[j].args()));
obsRanges.push_back (rangeOfGround (obsGrounds[j]));
}
Indexer indexer (obsRanges, false);
while (indexer.valid()) {
for (size_t j = 0; j < obsFs.size(); j++) {
obsFs[j].setEvidence (indexer[j]);
}
ParfactorList tempPfList (pfList);
LiftedVe::absorveEvidence (tempPfList, obsFs);
LiftedBp solver (tempPfList);
Params beliefs = solver.solveQuery ({grounds[i]});
for (size_t k = 0; k < beliefs.size(); k++) {
newBeliefs.push_back (beliefs[k]);
}
++ indexer;
}
int count = -1;
unsigned range = rangeOfGround (grounds[i]);
for (size_t j = 0; j < newBeliefs.size(); j++) {
if (j % range == 0) {
count ++;
}
newBeliefs[j] *= prevBeliefs[count];
}
prevBeliefs = newBeliefs;
obsGrounds.push_back (grounds[i]);
}
return prevBeliefs;
}

View File

@ -1,15 +1,17 @@
#ifndef HORUS_LIFTEDBPSOLVER_H #ifndef HORUS_LIFTEDBP_H
#define HORUS_LIFTEDBPSOLVER_H #define HORUS_LIFTEDBP_H
#include "ParfactorList.h" #include "ParfactorList.h"
class FactorGraph; class FactorGraph;
class WeightedBpSolver; class WeightedBp;
class LiftedBpSolver class LiftedBp
{ {
public: public:
LiftedBpSolver (const ParfactorList& pfList); LiftedBp (const ParfactorList& pfList);
~LiftedBp (void);
Params solveQuery (const Grounds&); Params solveQuery (const Grounds&);
@ -26,9 +28,13 @@ class LiftedBpSolver
vector<vector<unsigned>> getWeights (void) const; vector<vector<unsigned>> getWeights (void) const;
unsigned rangeOfGround (const Ground&);
Params getJointByConditioning (const ParfactorList&, const Grounds&);
ParfactorList pfList_; ParfactorList pfList_;
WeightedBpSolver* solver_; WeightedBp* solver_;
}; };
#endif // HORUS_LIFTEDBPSOLVER_H #endif // HORUS_LIFTEDBP_H

View File

@ -1,148 +0,0 @@
#include "LiftedBpSolver.h"
#include "WeightedBpSolver.h"
#include "FactorGraph.h"
#include "FoveSolver.h"
LiftedBpSolver::LiftedBpSolver (const ParfactorList& pfList)
: pfList_(pfList)
{
refineParfactors();
solver_ = new WeightedBpSolver (*getFactorGraph(), getWeights());
}
Params
LiftedBpSolver::solveQuery (const Grounds& query)
{
assert (query.empty() == false);
Params res;
vector<PrvGroup> groups = getQueryGroups (query);
if (query.size() == 1) {
res = solver_->getPosterioriOf (groups[0]);
} else {
VarIds queryVids;
for (unsigned i = 0; i < groups.size(); i++) {
queryVids.push_back (groups[i]);
}
res = solver_->getJointDistributionOf (queryVids);
}
return res;
}
void
LiftedBpSolver::printSolverFlags (void) const
{
stringstream ss;
ss << "lifted bp [" ;
ss << "schedule=" ;
typedef BpOptions::Schedule Sch;
switch (BpOptions::schedule) {
case Sch::SEQ_FIXED: ss << "seq_fixed"; break;
case Sch::SEQ_RANDOM: ss << "seq_random"; break;
case Sch::PARALLEL: ss << "parallel"; break;
case Sch::MAX_RESIDUAL: ss << "max_residual"; break;
}
ss << ",max_iter=" << BpOptions::maxIter;
ss << ",accuracy=" << BpOptions::accuracy;
ss << ",log_domain=" << Util::toString (Globals::logDomain);
ss << "]" ;
cout << ss.str() << endl;
}
void
LiftedBpSolver::refineParfactors (void)
{
while (iterate() == false);
if (Globals::verbosity > 2) {
Util::printHeader ("AFTER REFINEMENT");
pfList_.print();
}
}
bool
LiftedBpSolver::iterate (void)
{
ParfactorList::iterator it = pfList_.begin();
while (it != pfList_.end()) {
const ProbFormulas& args = (*it)->arguments();
for (size_t i = 0; i < args.size(); i++) {
LogVarSet lvs = (*it)->logVarSet() - args[i].logVars();
if ((*it)->constr()->isCountNormalized (lvs) == false) {
Parfactors pfs = FoveSolver::countNormalize (*it, lvs);
it = pfList_.removeAndDelete (it);
pfList_.add (pfs);
return false;
}
}
++ it;
}
return true;
}
vector<PrvGroup>
LiftedBpSolver::getQueryGroups (const Grounds& query)
{
vector<PrvGroup> queryGroups;
for (unsigned i = 0; i < query.size(); i++) {
ParfactorList::const_iterator it = pfList_.begin();
for (; it != pfList_.end(); ++it) {
if ((*it)->containsGround (query[i])) {
queryGroups.push_back ((*it)->findGroup (query[i]));
break;
}
}
}
assert (queryGroups.size() == query.size());
return queryGroups;
}
FactorGraph*
LiftedBpSolver::getFactorGraph (void)
{
FactorGraph* fg = new FactorGraph();
ParfactorList::const_iterator it = pfList_.begin();
for (; it != pfList_.end(); ++it) {
vector<PrvGroup> groups = (*it)->getAllGroups();
VarIds varIds;
for (size_t i = 0; i < groups.size(); i++) {
varIds.push_back (groups[i]);
}
fg->addFactor (Factor (varIds, (*it)->ranges(), (*it)->params()));
}
return fg;
}
vector<vector<unsigned>>
LiftedBpSolver::getWeights (void) const
{
vector<vector<unsigned>> weights;
weights.reserve (pfList_.size());
ParfactorList::const_iterator it = pfList_.begin();
for (; it != pfList_.end(); ++it) {
const ProbFormulas& args = (*it)->arguments();
weights.push_back ({ });
weights.back().reserve (args.size());
for (size_t i = 0; i < args.size(); i++) {
LogVarSet lvs = (*it)->logVarSet() - args[i].logVars();
weights.back().push_back ((*it)->constr()->getConditionalCount (lvs));
}
}
return weights;
}

View File

@ -1,8 +1,7 @@
#include <algorithm> #include <algorithm>
#include <set> #include <set>
#include "FoveSolver.h" #include "LiftedVe.h"
#include "Histogram.h" #include "Histogram.h"
#include "Util.h" #include "Util.h"
@ -222,7 +221,7 @@ SumOutOperator::apply (void)
product->sumOutIndex (fIdx); product->sumOutIndex (fIdx);
pfList_.addShattered (product); pfList_.addShattered (product);
} else { } else {
Parfactors pfs = FoveSolver::countNormalize (product, excl); Parfactors pfs = LiftedVe::countNormalize (product, excl);
for (size_t i = 0; i < pfs.size(); i++) { for (size_t i = 0; i < pfs.size(); i++) {
pfs[i]->sumOutIndex (fIdx); pfs[i]->sumOutIndex (fIdx);
pfList_.add (pfs[i]); pfList_.add (pfs[i]);
@ -376,7 +375,7 @@ CountingOperator::apply (void)
} else { } else {
Parfactor* pf = *pfIter_; Parfactor* pf = *pfIter_;
pfList_.remove (pfIter_); pfList_.remove (pfIter_);
Parfactors pfs = FoveSolver::countNormalize (pf, X_); Parfactors pfs = LiftedVe::countNormalize (pf, X_);
for (size_t i = 0; i < pfs.size(); i++) { for (size_t i = 0; i < pfs.size(); i++) {
unsigned condCount = pfs[i]->constr()->getConditionalCount (X_); unsigned condCount = pfs[i]->constr()->getConditionalCount (X_);
bool cartProduct = pfs[i]->constr()->isCartesianProduct ( bool cartProduct = pfs[i]->constr()->isCartesianProduct (
@ -420,7 +419,7 @@ CountingOperator::toString (void)
ss << "count convert " << X_ << " in " ; ss << "count convert " << X_ << " in " ;
ss << (*pfIter_)->getLabel(); ss << (*pfIter_)->getLabel();
ss << " [cost=" << std::exp (getLogCost()) << "]" << endl; ss << " [cost=" << std::exp (getLogCost()) << "]" << endl;
Parfactors pfs = FoveSolver::countNormalize (*pfIter_, X_); Parfactors pfs = LiftedVe::countNormalize (*pfIter_, X_);
if ((*pfIter_)->constr()->isCountNormalized (X_) == false) { if ((*pfIter_)->constr()->isCountNormalized (X_) == false) {
for (size_t i = 0; i < pfs.size(); i++) { for (size_t i = 0; i < pfs.size(); i++) {
ss << " º " << pfs[i]->getLabel() << endl; ss << " º " << pfs[i]->getLabel() << endl;
@ -501,7 +500,7 @@ GroundOperator::getLogCost (void)
++ pflIt; ++ pflIt;
} }
// cout << endl; // cout << endl;
return totalCost; return totalCost + 3;
} }
@ -610,7 +609,7 @@ GroundOperator::getAffectedFormulas (void)
LogVar X = f.logVars()[front.second]; LogVar X = f.logVars()[front.second];
const ProbFormulas& fs = (*pflIt)->arguments(); const ProbFormulas& fs = (*pflIt)->arguments();
for (size_t i = 0; i < fs.size(); i++) { for (size_t i = 0; i < fs.size(); i++) {
if ((int)i != idx && fs[i].contains (X)) { if (i != idx && fs[i].contains (X)) {
pair<PrvGroup, unsigned> pair = make_pair ( pair<PrvGroup, unsigned> pair = make_pair (
fs[i].group(), fs[i].indexOf (X)); fs[i].group(), fs[i].indexOf (X));
if (Util::contains (affectedFormulas, pair) == false) { if (Util::contains (affectedFormulas, pair) == false) {
@ -630,7 +629,7 @@ GroundOperator::getAffectedFormulas (void)
Params Params
FoveSolver::solveQuery (const Grounds& query) LiftedVe::solveQuery (const Grounds& query)
{ {
assert (query.empty() == false); assert (query.empty() == false);
runSolver (query); runSolver (query);
@ -645,7 +644,7 @@ FoveSolver::solveQuery (const Grounds& query)
void void
FoveSolver::printSolverFlags (void) const LiftedVe::printSolverFlags (void) const
{ {
stringstream ss; stringstream ss;
ss << "fove [" ; ss << "fove [" ;
@ -657,7 +656,7 @@ FoveSolver::printSolverFlags (void) const
void void
FoveSolver::absorveEvidence ( LiftedVe::absorveEvidence (
ParfactorList& pfList, ParfactorList& pfList,
ObservedFormulas& obsFormulas) ObservedFormulas& obsFormulas)
{ {
@ -696,7 +695,7 @@ FoveSolver::absorveEvidence (
Parfactors Parfactors
FoveSolver::countNormalize ( LiftedVe::countNormalize (
Parfactor* g, Parfactor* g,
const LogVarSet& set) const LogVarSet& set)
{ {
@ -715,7 +714,7 @@ FoveSolver::countNormalize (
Parfactor Parfactor
FoveSolver::calcGroundMultiplication (Parfactor pf) LiftedVe::calcGroundMultiplication (Parfactor pf)
{ {
LogVarSet lvs = pf.constr()->logVarSet(); LogVarSet lvs = pf.constr()->logVarSet();
lvs -= pf.constr()->singletons(); lvs -= pf.constr()->singletons();
@ -748,7 +747,7 @@ FoveSolver::calcGroundMultiplication (Parfactor pf)
void void
FoveSolver::runSolver (const Grounds& query) LiftedVe::runSolver (const Grounds& query)
{ {
largestCost_ = std::log (0); largestCost_ = std::log (0);
shatterAgainstQuery (query); shatterAgainstQuery (query);
@ -794,7 +793,7 @@ FoveSolver::runSolver (const Grounds& query)
LiftedOperator* LiftedOperator*
FoveSolver::getBestOperation (const Grounds& query) LiftedVe::getBestOperation (const Grounds& query)
{ {
double bestCost = 0.0; double bestCost = 0.0;
LiftedOperator* bestOp = 0; LiftedOperator* bestOp = 0;
@ -821,7 +820,7 @@ FoveSolver::getBestOperation (const Grounds& query)
void void
FoveSolver::runWeakBayesBall (const Grounds& query) LiftedVe::runWeakBayesBall (const Grounds& query)
{ {
queue<PrvGroup> todo; // groups to process queue<PrvGroup> todo; // groups to process
set<PrvGroup> done; // processed or in queue set<PrvGroup> done; // processed or in queue
@ -880,7 +879,7 @@ FoveSolver::runWeakBayesBall (const Grounds& query)
void void
FoveSolver::shatterAgainstQuery (const Grounds& query) LiftedVe::shatterAgainstQuery (const Grounds& query)
{ {
for (size_t i = 0; i < query.size(); i++) { for (size_t i = 0; i < query.size(); i++) {
if (query[i].isAtom()) { if (query[i].isAtom()) {
@ -931,7 +930,7 @@ FoveSolver::shatterAgainstQuery (const Grounds& query)
Parfactors Parfactors
FoveSolver::absorve ( LiftedVe::absorve (
ObservedFormula& obsFormula, ObservedFormula& obsFormula,
Parfactor* g) Parfactor* g)
{ {

View File

@ -1,5 +1,5 @@
#ifndef HORUS_FOVESOLVER_H #ifndef HORUS_LIFTEDVE_H
#define HORUS_FOVESOLVER_H #define HORUS_LIFTEDVE_H
#include "ParfactorList.h" #include "ParfactorList.h"
@ -130,10 +130,10 @@ class GroundOperator : public LiftedOperator
class FoveSolver class LiftedVe
{ {
public: public:
FoveSolver (const ParfactorList& pfList) : pfList_(pfList) { } LiftedVe (const ParfactorList& pfList) : pfList_(pfList) { }
Params solveQuery (const Grounds&); Params solveQuery (const Grounds&);
@ -162,5 +162,5 @@ class FoveSolver
double largestCost_; double largestCost_;
}; };
#endif // HORUS_FOVESOLVER_H #endif // HORUS_LIFTEDVE_H

View File

@ -23,10 +23,10 @@ CC=@CC@
CXX=@CXX@ CXX=@CXX@
# normal # normal
#CXXFLAGS= -std=c++0x @SHLIB_CXXFLAGS@ $(YAP_EXTRAS) $(DEFS) -D_YAP_NOT_INSTALLED_=1 -I$(srcdir) -I../../.. -I$(srcdir)/../../../include @CPPFLAGS@ -DNDEBUG CXXFLAGS= -std=c++0x @SHLIB_CXXFLAGS@ $(YAP_EXTRAS) $(DEFS) -D_YAP_NOT_INSTALLED_=1 -I$(srcdir) -I../../.. -I$(srcdir)/../../../include @CPPFLAGS@ -DNDEBUG
# debug # debug
CXXFLAGS= -std=c++0x @SHLIB_CXXFLAGS@ $(YAP_EXTRAS) $(DEFS) -D_YAP_NOT_INSTALLED_=1 -I$(srcdir) -I../../.. -I$(srcdir)/../../../include @CPPFLAGS@ -g -O0 -Wextra #CXXFLAGS= -std=c++0x @SHLIB_CXXFLAGS@ $(YAP_EXTRAS) $(DEFS) -D_YAP_NOT_INSTALLED_=1 -I$(srcdir) -I../../.. -I$(srcdir)/../../../include @CPPFLAGS@ -g -O0 -Wextra
# #
@ -45,98 +45,91 @@ CWD=$(PWD)
HEADERS = \ HEADERS = \
$(srcdir)/BayesNet.h \
$(srcdir)/BayesBall.h \ $(srcdir)/BayesBall.h \
$(srcdir)/ElimGraph.h \ $(srcdir)/BayesBallGraph.h \
$(srcdir)/FactorGraph.h \ $(srcdir)/BeliefProp.h \
$(srcdir)/Factor.h \
$(srcdir)/ConstraintTree.h \ $(srcdir)/ConstraintTree.h \
$(srcdir)/Solver.h \ $(srcdir)/CountingBp.h \
$(srcdir)/VarElimSolver.h \ $(srcdir)/ElimGraph.h \
$(srcdir)/BpSolver.h \ $(srcdir)/Factor.h \
$(srcdir)/CbpSolver.h \ $(srcdir)/FactorGraph.h \
$(srcdir)/FoveSolver.h \
$(srcdir)/Var.h \
$(srcdir)/Indexer.h \
$(srcdir)/Parfactor.h \
$(srcdir)/ProbFormula.h \
$(srcdir)/Histogram.h \ $(srcdir)/Histogram.h \
$(srcdir)/ParfactorList.h \ $(srcdir)/Horus.h \
$(srcdir)/Indexer.h \
$(srcdir)/LiftedBp.h \
$(srcdir)/LiftedUtils.h \ $(srcdir)/LiftedUtils.h \
$(srcdir)/LiftedVe.h \
$(srcdir)/Parfactor.h \
$(srcdir)/ParfactorList.h \
$(srcdir)/ProbFormula.h \
$(srcdir)/Solver.h \
$(srcdir)/TinySet.h \ $(srcdir)/TinySet.h \
$(srcdir)/LiftedBpSolver.h \
$(srcdir)/WeightedBpSolver.h \
$(srcdir)/Util.h \ $(srcdir)/Util.h \
$(srcdir)/Horus.h $(srcdir)/Var.h \
$(srcdir)/VarElim.h \
$(srcdir)/WeightedBp.h
CPP_SOURCES = \ CPP_SOURCES = \
$(srcdir)/BayesNet.cpp \
$(srcdir)/BayesBall.cpp \ $(srcdir)/BayesBall.cpp \
$(srcdir)/ElimGraph.cpp \ $(srcdir)/BayesBallGraph.cpp \
$(srcdir)/FactorGraph.cpp \ $(srcdir)/BeliefProp.cpp \
$(srcdir)/Factor.cpp \
$(srcdir)/ConstraintTree.cpp \ $(srcdir)/ConstraintTree.cpp \
$(srcdir)/Var.cpp \ $(srcdir)/CountingBp.cpp \
$(srcdir)/Solver.cpp \ $(srcdir)/ElimGraph.cpp \
$(srcdir)/VarElimSolver.cpp \ $(srcdir)/Factor.cpp \
$(srcdir)/BpSolver.cpp \ $(srcdir)/FactorGraph.cpp \
$(srcdir)/CbpSolver.cpp \
$(srcdir)/FoveSolver.cpp \
$(srcdir)/Parfactor.cpp \
$(srcdir)/ProbFormula.cpp \
$(srcdir)/Histogram.cpp \ $(srcdir)/Histogram.cpp \
$(srcdir)/ParfactorList.cpp \ $(srcdir)/HorusCli.cpp \
$(srcdir)/LiftedUtils.cpp \
$(srcdir)/Util.cpp \
$(srcdir)/LiftedBpSolver.cpp \
$(srcdir)/WeightedBpSolver.cpp \
$(srcdir)/HorusYap.cpp \ $(srcdir)/HorusYap.cpp \
$(srcdir)/HorusCli.cpp $(srcdir)/LiftedBp.cpp \
$(srcdir)/LiftedUtils.cpp \
$(srcdir)/LiftedVe.cpp \
$(srcdir)/Parfactor.cpp \
$(srcdir)/ParfactorList.cpp \
$(srcdir)/ProbFormula.cpp \
$(srcdir)/Solver.cpp \
$(srcdir)/Util.cpp \
$(srcdir)/Var.cpp \
$(srcdir)/VarElim.cpp \
$(srcdir)/WeightedBp.cpp \
OBJS = \ OBJS = \
BayesNet.o \
BayesBall.o \ BayesBall.o \
ElimGraph.o \ BayesBallGraph.o \
FactorGraph.o \ BeliefProp.o \
Factor.o \
ConstraintTree.o \ ConstraintTree.o \
Var.o \ CountingBp.o \
Solver.o \ ElimGraph.o \
VarElimSolver.o \ Factor.o \
BpSolver.o \ FactorGraph.o \
CbpSolver.o \
FoveSolver.o \
Parfactor.o \
ProbFormula.o \
Histogram.o \ Histogram.o \
ParfactorList.o \ HorusYap.o \
LiftedBp.o \
LiftedUtils.o \ LiftedUtils.o \
LiftedVe.o \
ProbFormula.o \
Parfactor.o \
ParfactorList.o \
Solver.o \
Util.o \ Util.o \
LiftedBpSolver.o \ Var.o \
WeightedBpSolver.o \ VarElim.o \
HorusYap.o WeightedBp.o
HCLI_OBJS = \ HCLI_OBJS = \
BayesNet.o \
BayesBall.o \ BayesBall.o \
BayesBallGraph.o \
BeliefProp.o \
CountingBp.o \
ElimGraph.o \ ElimGraph.o \
FactorGraph.o \
Factor.o \ Factor.o \
ConstraintTree.o \ FactorGraph.o \
Var.o \ HorusCli.o \
Solver.o \ Solver.o \
VarElimSolver.o \
BpSolver.o \
CbpSolver.o \
FoveSolver.o \
Parfactor.o \
ProbFormula.o \
Histogram.o \
ParfactorList.o \
WeightedBpSolver.o \
LiftedUtils.o \
Util.o \ Util.o \
HorusCli.o Var.o \
VarElim.o \
WeightedBp.o
SOBJS=horus.@SO@ SOBJS=horus.@SO@

View File

@ -402,21 +402,32 @@ Parfactor::applySubstitution (const Substitution& theta)
PrvGroup size_t
Parfactor::findGroup (const Ground& ground) const Parfactor::indexOfGround (const Ground& ground) const
{ {
PrvGroup group = numeric_limits<PrvGroup>::max(); size_t idx = args_.size();
for (size_t i = 0; i < args_.size(); i++) { for (size_t i = 0; i < args_.size(); i++) {
if (args_[i].functor() == ground.functor() && if (args_[i].functor() == ground.functor() &&
args_[i].arity() == ground.arity()) { args_[i].arity() == ground.arity()) {
constr_->moveToTop (args_[i].logVars()); constr_->moveToTop (args_[i].logVars());
if (constr_->containsTuple (ground.args())) { if (constr_->containsTuple (ground.args())) {
group = args_[i].group(); idx = i;
break; break;
} }
} }
} }
return group; return idx;
}
PrvGroup
Parfactor::findGroup (const Ground& ground) const
{
size_t idx = indexOfGround (ground);
return idx == args_.size()
? numeric_limits<PrvGroup>::max()
: args_[idx].group();
} }
@ -429,6 +440,30 @@ Parfactor::containsGround (const Ground& ground) const
bool
Parfactor::containsGrounds (const Grounds& grounds) const
{
Tuple tuple;
LogVars tupleLvs;
for (size_t i = 0; i < grounds.size(); i++) {
size_t idx = indexOfGround (grounds[i]);
if (idx == args_.size()) {
return false;
}
LogVars lvs = args_[idx].logVars();
for (size_t j = 0; j < lvs.size(); j++) {
if (Util::contains (tupleLvs, lvs[j]) == false) {
tuple.push_back (grounds[i].args()[j]);
tupleLvs.push_back (lvs[j]);
}
}
}
constr_->moveToTop (tupleLvs);
return constr_->containsTuple (tuple);
}
bool bool
Parfactor::containsGroup (PrvGroup group) const Parfactor::containsGroup (PrvGroup group) const
{ {
@ -442,6 +477,19 @@ Parfactor::containsGroup (PrvGroup group) const
bool
Parfactor::containsGroups (vector<PrvGroup> groups) const
{
for (size_t i = 0; i < groups.size(); i++) {
if (containsGroup (groups[i]) == false) {
return false;
}
}
return true;
}
unsigned unsigned
Parfactor::nrFormulas (LogVar X) const Parfactor::nrFormulas (LogVar X) const
{ {

View File

@ -64,12 +64,18 @@ class Parfactor : public TFactor<ProbFormula>
void applySubstitution (const Substitution&); void applySubstitution (const Substitution&);
size_t indexOfGround (const Ground&) const;
PrvGroup findGroup (const Ground&) const; PrvGroup findGroup (const Ground&) const;
bool containsGround (const Ground&) const; bool containsGround (const Ground&) const;
bool containsGrounds (const Grounds&) const;
bool containsGroup (PrvGroup) const; bool containsGroup (PrvGroup) const;
bool containsGroups (vector<PrvGroup>) const;
unsigned nrFormulas (LogVar) const; unsigned nrFormulas (LogVar) const;
int indexOfLogVar (LogVar) const; int indexOfLogVar (LogVar) const;

View File

@ -91,6 +91,8 @@ class ObservedFormula
unsigned evidence (void) const { return evidence_; } unsigned evidence (void) const { return evidence_; }
void setEvidence (unsigned ev) { evidence_ = ev; }
ConstraintTree& constr (void) { return constr_; } ConstraintTree& constr (void) { return constr_; }
bool isAtom (void) const { return arity_ == 0; } bool isAtom (void) const { return arity_ == 0; }

View File

@ -1,5 +1,8 @@
#include "Solver.h" #include "Solver.h"
#include "Util.h" #include "Util.h"
#include "BeliefProp.h"
#include "CountingBp.h"
#include "VarElim.h"
void void
@ -38,3 +41,67 @@ Solver::printAllPosterioris (void)
} }
} }
Params
Solver::getJointByConditioning (
GroundSolver solverType,
FactorGraph fg,
const VarIds& jointVarIds) const
{
VarNodes jointVars;
for (size_t i = 0; i < jointVarIds.size(); i++) {
assert (fg.getVarNode (jointVarIds[i]));
jointVars.push_back (fg.getVarNode (jointVarIds[i]));
}
Solver* solver = 0;
switch (solverType) {
case GroundSolver::BP: solver = new BeliefProp (fg); break;
case GroundSolver::CBP: solver = new CountingBp (fg); break;
case GroundSolver::VE: solver = new VarElim (fg); break;
}
Params prevBeliefs = solver->solveQuery ({jointVarIds[0]});
VarIds observedVids = {jointVars[0]->varId()};
for (size_t i = 1; i < jointVarIds.size(); i++) {
assert (jointVars[i]->hasEvidence() == false);
Params newBeliefs;
Vars observedVars;
Ranges observedRanges;
for (size_t j = 0; j < observedVids.size(); j++) {
observedVars.push_back (fg.getVarNode (observedVids[j]));
observedRanges.push_back (observedVars.back()->range());
}
Indexer indexer (observedRanges, false);
while (indexer.valid()) {
for (size_t j = 0; j < observedVars.size(); j++) {
observedVars[j]->setEvidence (indexer[j]);
}
delete solver;
switch (solverType) {
case GroundSolver::BP: solver = new BeliefProp (fg); break;
case GroundSolver::CBP: solver = new CountingBp (fg); break;
case GroundSolver::VE: solver = new VarElim (fg); break;
}
Params beliefs = solver->solveQuery ({jointVarIds[i]});
for (size_t k = 0; k < beliefs.size(); k++) {
newBeliefs.push_back (beliefs[k]);
}
++ indexer;
}
int count = -1;
for (size_t j = 0; j < newBeliefs.size(); j++) {
if (j % jointVars[i]->range() == 0) {
count ++;
}
newBeliefs[j] *= prevBeliefs[count];
}
prevBeliefs = newBeliefs;
observedVids.push_back (jointVars[i]->varId());
}
delete solver;
return prevBeliefs;
}

View File

@ -3,8 +3,9 @@
#include <iomanip> #include <iomanip>
#include "Var.h"
#include "FactorGraph.h" #include "FactorGraph.h"
#include "Var.h"
#include "Horus.h"
using namespace std; using namespace std;
@ -24,6 +25,9 @@ class Solver
void printAllPosterioris (void); void printAllPosterioris (void);
Params getJointByConditioning (GroundSolver,
FactorGraph, const VarIds& jointVarIds) const;
protected: protected:
const FactorGraph& fg; const FactorGraph& fg;
}; };

View File

@ -13,9 +13,9 @@ bool logDomain = false;
unsigned verbosity = 0; unsigned verbosity = 0;
LiftedSolvers liftedSolver = LiftedSolvers::FOVE; LiftedSolver liftedSolver = LiftedSolver::FOVE;
GroundSolvers groundSolver = GroundSolvers::VE; GroundSolver groundSolver = GroundSolver::VE;
}; };
@ -211,9 +211,9 @@ setHorusFlag (string key, string value)
ss >> Globals::verbosity; ss >> Globals::verbosity;
} else if (key == "lifted_solver") { } else if (key == "lifted_solver") {
if ( value == "fove") { if ( value == "fove") {
Globals::liftedSolver = LiftedSolvers::FOVE; Globals::liftedSolver = LiftedSolver::FOVE;
} else if (value == "lbp") { } else if (value == "lbp") {
Globals::liftedSolver = LiftedSolvers::LBP; Globals::liftedSolver = LiftedSolver::LBP;
} else { } else {
cerr << "warning: invalid value `" << value << "' " ; cerr << "warning: invalid value `" << value << "' " ;
cerr << "for `" << key << "'" << endl; cerr << "for `" << key << "'" << endl;
@ -221,11 +221,11 @@ setHorusFlag (string key, string value)
} }
} else if (key == "ground_solver") { } else if (key == "ground_solver") {
if ( value == "ve") { if ( value == "ve") {
Globals::groundSolver = GroundSolvers::VE; Globals::groundSolver = GroundSolver::VE;
} else if (value == "bp") { } else if (value == "bp") {
Globals::groundSolver = GroundSolvers::BP; Globals::groundSolver = GroundSolver::BP;
} else if (value == "cbp") { } else if (value == "cbp") {
Globals::groundSolver = GroundSolvers::CBP; Globals::groundSolver = GroundSolver::CBP;
} else { } else {
cerr << "warning: invalid value `" << value << "' " ; cerr << "warning: invalid value `" << value << "' " ;
cerr << "for `" << key << "'" << endl; cerr << "for `" << key << "'" << endl;

View File

@ -1,12 +1,12 @@
#include <algorithm> #include <algorithm>
#include "VarElimSolver.h" #include "VarElim.h"
#include "ElimGraph.h" #include "ElimGraph.h"
#include "Factor.h" #include "Factor.h"
#include "Util.h" #include "Util.h"
VarElimSolver::~VarElimSolver (void) VarElim::~VarElim (void)
{ {
delete factorList_.back(); delete factorList_.back();
} }
@ -14,7 +14,7 @@ VarElimSolver::~VarElimSolver (void)
Params Params
VarElimSolver::solveQuery (VarIds queryVids) VarElim::solveQuery (VarIds queryVids)
{ {
if (Globals::verbosity > 1) { if (Globals::verbosity > 1) {
cout << "Solving query on " ; cout << "Solving query on " ;
@ -41,7 +41,7 @@ VarElimSolver::solveQuery (VarIds queryVids)
void void
VarElimSolver::printSolverFlags (void) const VarElim::printSolverFlags (void) const
{ {
stringstream ss; stringstream ss;
ss << "variable elimination [" ; ss << "variable elimination [" ;
@ -62,7 +62,7 @@ VarElimSolver::printSolverFlags (void) const
void void
VarElimSolver::createFactorList (void) VarElim::createFactorList (void)
{ {
const FacNodes& facNodes = fg.facNodes(); const FacNodes& facNodes = fg.facNodes();
factorList_.reserve (facNodes.size() * 2); factorList_.reserve (facNodes.size() * 2);
@ -84,7 +84,7 @@ VarElimSolver::createFactorList (void)
void void
VarElimSolver::absorveEvidence (void) VarElim::absorveEvidence (void)
{ {
if (Globals::verbosity > 2) { if (Globals::verbosity > 2) {
Util::printDashedLine(); Util::printDashedLine();
@ -117,7 +117,7 @@ VarElimSolver::absorveEvidence (void)
void void
VarElimSolver::findEliminationOrder (const VarIds& vids) VarElim::findEliminationOrder (const VarIds& vids)
{ {
elimOrder_ = ElimGraph::getEliminationOrder (factorList_, vids); elimOrder_ = ElimGraph::getEliminationOrder (factorList_, vids);
} }
@ -125,7 +125,7 @@ VarElimSolver::findEliminationOrder (const VarIds& vids)
void void
VarElimSolver::processFactorList (const VarIds& vids) VarElim::processFactorList (const VarIds& vids)
{ {
totalFactorSize_ = 0; totalFactorSize_ = 0;
largestFactorSize_ = 0; largestFactorSize_ = 0;
@ -170,7 +170,7 @@ VarElimSolver::processFactorList (const VarIds& vids)
void void
VarElimSolver::eliminate (VarId elimVar) VarElim::eliminate (VarId elimVar)
{ {
Factor* result = 0; Factor* result = 0;
vector<size_t>& idxs = varFactors_.find (elimVar)->second; vector<size_t>& idxs = varFactors_.find (elimVar)->second;
@ -205,7 +205,7 @@ VarElimSolver::eliminate (VarId elimVar)
void void
VarElimSolver::printActiveFactors (void) VarElim::printActiveFactors (void)
{ {
for (size_t i = 0; i < factorList_.size(); i++) { for (size_t i = 0; i < factorList_.size(); i++) {
if (factorList_[i] != 0) { if (factorList_[i] != 0) {

View File

@ -1,5 +1,5 @@
#ifndef HORUS_VARELIMSOLVER_H #ifndef HORUS_VARELIM_H
#define HORUS_VARELIMSOLVER_H #define HORUS_VARELIM_H
#include "unordered_map" #include "unordered_map"
@ -11,12 +11,12 @@
using namespace std; using namespace std;
class VarElimSolver : public Solver class VarElim : public Solver
{ {
public: public:
VarElimSolver (const FactorGraph& fg) : Solver (fg) { } VarElim (const FactorGraph& fg) : Solver (fg) { }
~VarElimSolver (void); ~VarElim (void);
Params solveQuery (VarIds); Params solveQuery (VarIds);
@ -42,5 +42,5 @@ class VarElimSolver : public Solver
unordered_map<VarId, vector<size_t>> varFactors_; unordered_map<VarId, vector<size_t>> varFactors_;
}; };
#endif // HORUS_VARELIMSOLVER_H #endif // HORUS_VARELIM_H

View File

@ -1,7 +1,7 @@
#include "WeightedBpSolver.h" #include "WeightedBp.h"
WeightedBpSolver::~WeightedBpSolver (void) WeightedBp::~WeightedBp (void)
{ {
for (size_t i = 0; i < links_.size(); i++) { for (size_t i = 0; i < links_.size(); i++) {
delete links_[i]; delete links_[i];
@ -12,7 +12,7 @@ WeightedBpSolver::~WeightedBpSolver (void)
Params Params
WeightedBpSolver::getPosterioriOf (VarId vid) WeightedBp::getPosterioriOf (VarId vid)
{ {
if (runned_ == false) { if (runned_ == false) {
runSolver(); runSolver();
@ -47,7 +47,7 @@ WeightedBpSolver::getPosterioriOf (VarId vid)
void void
WeightedBpSolver::createLinks (void) WeightedBp::createLinks (void)
{ {
if (Globals::verbosity > 0) { if (Globals::verbosity > 0) {
cout << "compressed factor graph contains " ; cout << "compressed factor graph contains " ;
@ -78,7 +78,7 @@ WeightedBpSolver::createLinks (void)
void void
WeightedBpSolver::maxResidualSchedule (void) WeightedBp::maxResidualSchedule (void)
{ {
if (nIters_ == 1) { if (nIters_ == 1) {
for (size_t i = 0; i < links_.size(); i++) { for (size_t i = 0; i < links_.size(); i++) {
@ -151,7 +151,7 @@ WeightedBpSolver::maxResidualSchedule (void)
void void
WeightedBpSolver::calcFactorToVarMsg (BpLink* _link) WeightedBp::calcFactorToVarMsg (BpLink* _link)
{ {
WeightedLink* link = static_cast<WeightedLink*> (_link); WeightedLink* link = static_cast<WeightedLink*> (_link);
FacNode* src = link->facNode(); FacNode* src = link->facNode();
@ -223,7 +223,7 @@ WeightedBpSolver::calcFactorToVarMsg (BpLink* _link)
Params Params
WeightedBpSolver::getVarToFactorMsg (const BpLink* _link) const WeightedBp::getVarToFactorMsg (const BpLink* _link) const
{ {
const WeightedLink* link = static_cast<const WeightedLink*> (_link); const WeightedLink* link = static_cast<const WeightedLink*> (_link);
const VarNode* src = link->varNode(); const VarNode* src = link->varNode();
@ -272,7 +272,7 @@ WeightedBpSolver::getVarToFactorMsg (const BpLink* _link) const
void void
WeightedBpSolver::printLinkInformation (void) const WeightedBp::printLinkInformation (void) const
{ {
for (size_t i = 0; i < links_.size(); i++) { for (size_t i = 0; i < links_.size(); i++) {
WeightedLink* l = static_cast<WeightedLink*> (links_[i]); WeightedLink* l = static_cast<WeightedLink*> (links_[i]);

View File

@ -1,7 +1,7 @@
#ifndef HORUS_WEIGHTEDBPSOLVER_H #ifndef HORUS_WEIGHTEDBP_H
#define HORUS_WEIGHTEDBPSOLVER_H #define HORUS_WEIGHTEDBP_H
#include "BpSolver.h" #include "BeliefProp.h"
class WeightedLink : public BpLink class WeightedLink : public BpLink
{ {
@ -31,14 +31,14 @@ class WeightedLink : public BpLink
class WeightedBpSolver : public BpSolver class WeightedBp : public BeliefProp
{ {
public: public:
WeightedBpSolver (const FactorGraph& fg, WeightedBp (const FactorGraph& fg,
const vector<vector<unsigned>>& weights) const vector<vector<unsigned>>& weights)
: BpSolver (fg), weights_(weights) { } : BeliefProp (fg), weights_(weights) { }
~WeightedBpSolver (void); ~WeightedBp (void);
Params getPosterioriOf (VarId); Params getPosterioriOf (VarId);
@ -57,5 +57,5 @@ class WeightedBpSolver : public BpSolver
vector<vector<unsigned>> weights_; vector<vector<unsigned>> weights_;
}; };
#endif // HORUS_WEIGHTEDBPSOLVER_H #endif // HORUS_WEIGHTEDBP_H