Merge branch 'master' of https://github.com/tacgomes/yap6.3
This commit is contained in:
commit
5fe052a3ef
110
packages/CLPBN/README.txt
Normal file
110
packages/CLPBN/README.txt
Normal file
@ -0,0 +1,110 @@
|
|||||||
|
Prolog Factor Language (PFL)
|
||||||
|
|
||||||
|
Prolog Factor Language (PFL) is a extension of the Prolog language that
|
||||||
|
allows a natural representation of this first-order probabilistic models
|
||||||
|
(either directed or undirected). PFL is also capable of solving probabilistic
|
||||||
|
queries on this models through the implementation of several inference
|
||||||
|
techniques: variable elimination, belief propagation, lifted variable
|
||||||
|
elimination and lifted belief propagation.
|
||||||
|
|
||||||
|
Language
|
||||||
|
-------------------------------------------------------------------------------
|
||||||
|
A graphical model in PFL is represented using parfactors. A PFL parfactor
|
||||||
|
has the following four components:
|
||||||
|
|
||||||
|
Type ; Formulas ; Phi ; Constraint .
|
||||||
|
|
||||||
|
- Type refers the type of the network over which the parfactor is defined.
|
||||||
|
It can be bayes for directed networks, or markov for undirected ones.
|
||||||
|
- Formulas is a sequence of Prolog terms that define sets of random variables
|
||||||
|
under the constraint.
|
||||||
|
- Phi is either a list of parameters or a call to a Prolog goal that will
|
||||||
|
unify its last argument with a list of parameters.
|
||||||
|
- Constraint is a list (possible empty) of Prolog goals that will impose
|
||||||
|
bindings on the logical variables that appear in the formulas.
|
||||||
|
|
||||||
|
The "examples" directory contains some popular graphical models described
|
||||||
|
using PFL.
|
||||||
|
|
||||||
|
Querying
|
||||||
|
-------------------------------------------------------------------------------
|
||||||
|
Now we show how to use PFL to solve probabilistic queries. We will
|
||||||
|
use the burlgary alarm network as an example. First, we load the model:
|
||||||
|
|
||||||
|
$ yap -l examples/burglary-alarm.yap
|
||||||
|
|
||||||
|
Now let's suppose that we want to estimate the probability of a earthquake
|
||||||
|
ocurred given that mary called. We can do it with the following query:
|
||||||
|
|
||||||
|
?- earthquake(X), mary_calls(t).
|
||||||
|
|
||||||
|
Suppose now that we want the joint distribution for john_calls and
|
||||||
|
mary_calls. We can obtain this with the following query:
|
||||||
|
|
||||||
|
?- john_calls(X), mary_calls(Y).
|
||||||
|
|
||||||
|
|
||||||
|
Inference Options
|
||||||
|
-------------------------------------------------------------------------------
|
||||||
|
PFL supports both ground and lifted inference. The inference algorithm
|
||||||
|
can be chosen using the set_solver/1 predicate. The following algorithms
|
||||||
|
are supported:
|
||||||
|
- fove: lifted variable elimination with arbitrary constraints (GC-FOVE)
|
||||||
|
- hve: (ground) variable elimination
|
||||||
|
- lbp: lifted first-order belief propagation
|
||||||
|
- cbp: counting belief propagation
|
||||||
|
- bp: (ground) belief propagation
|
||||||
|
|
||||||
|
For example, if we want to use ground variable elimination to solve some
|
||||||
|
query, we need to call first the following goal:
|
||||||
|
|
||||||
|
?- set_solver(hve).
|
||||||
|
|
||||||
|
It is possible to tweak several parameters of PFL through the
|
||||||
|
set_horus_flag/2 predicate. The first argument is a key that
|
||||||
|
identifies the parameter that we desire to tweak, while the second
|
||||||
|
is some possible value for this key.
|
||||||
|
|
||||||
|
The verbosity key controls the level of log information that will be
|
||||||
|
printed by the corresponding solver. Its possible values are positive
|
||||||
|
integers. The bigger the number, more log information will be printed.
|
||||||
|
For example, to view some basic log information we need to call the
|
||||||
|
following goal:
|
||||||
|
|
||||||
|
?- set_horus_flag(verbosity, 1).
|
||||||
|
|
||||||
|
The use_logarithms key controls whether the calculations performed
|
||||||
|
during inference should be done in the log domain or not. Its values
|
||||||
|
can be true or false. By default is false.
|
||||||
|
|
||||||
|
There are also keys specific to the inference algorithm. For example,
|
||||||
|
elim_heuristic key controls the elimination heuristic that will be
|
||||||
|
used by ground variable elimination. The following heuristics are
|
||||||
|
supported:
|
||||||
|
- sequential
|
||||||
|
- min_neighbors
|
||||||
|
- min_weight
|
||||||
|
- min_fill
|
||||||
|
- weighted_min_fill
|
||||||
|
|
||||||
|
An explanation of this heuristics can be found in Probabilistic Graphical
|
||||||
|
Models by Daphne Koller.
|
||||||
|
|
||||||
|
The schedule, accuracy and max_iter keys are specific for inference
|
||||||
|
algorithms based on message passing, namely lbp, cbp and bp.
|
||||||
|
The key schedule can be used to specify the order in which the messages
|
||||||
|
are sent in belief propagation. The possible values are:
|
||||||
|
- seq_fixed: at each iteration, all messages are sent in the same order
|
||||||
|
- seq_random: at each iteration, the messages are sent with a random order
|
||||||
|
- parallel: at each iteration, the messages are all calculated using the
|
||||||
|
values of the previous iteration.
|
||||||
|
- max_residual: the next message to be sent is the one with maximum residual,
|
||||||
|
(Residual Belief Propagation:Informed Scheduling for Asynchronous Message
|
||||||
|
Passing)
|
||||||
|
|
||||||
|
The max_iter key sets the maximum number of iterations. One iteration
|
||||||
|
consists in sending all possible messages. The accuracy key indicate
|
||||||
|
when we should stop sending messages. If the largest difference between
|
||||||
|
a message sent in the current iteration and one message sent in the previous
|
||||||
|
iteration is less that accuracy value given, we terminate belief propagation.
|
||||||
|
|
@ -26,6 +26,8 @@ function run_solver
|
|||||||
solver_flag=clpbn_horus:set_horus_flag\(schedule,$2\)
|
solver_flag=clpbn_horus:set_horus_flag\(schedule,$2\)
|
||||||
elif [ $SOLVER = cbp ]; then
|
elif [ $SOLVER = cbp ]; then
|
||||||
solver_flag=clpbn_horus:set_horus_flag\(schedule,$2\)
|
solver_flag=clpbn_horus:set_horus_flag\(schedule,$2\)
|
||||||
|
elif [ $SOLVER = lbp ]; then
|
||||||
|
solver_flag=clpbn_horus:set_horus_flag\(schedule,$2\)
|
||||||
else
|
else
|
||||||
echo "unknow flag $2"
|
echo "unknow flag $2"
|
||||||
fi
|
fi
|
||||||
|
@ -23,7 +23,7 @@ function run_all_graphs
|
|||||||
run_solver city60000 $2
|
run_solver city60000 $2
|
||||||
run_solver city65000 $2
|
run_solver city65000 $2
|
||||||
run_solver city70000 $2
|
run_solver city70000 $2
|
||||||
|
return
|
||||||
run_solver city75000 $2
|
run_solver city75000 $2
|
||||||
run_solver city80000 $2
|
run_solver city80000 $2
|
||||||
run_solver city85000 $2
|
run_solver city85000 $2
|
||||||
|
36
packages/CLPBN/benchmarks/city/lbp_tests.sh
Executable file
36
packages/CLPBN/benchmarks/city/lbp_tests.sh
Executable file
@ -0,0 +1,36 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
source city.sh
|
||||||
|
source ../benchs.sh
|
||||||
|
|
||||||
|
SOLVER="lbp"
|
||||||
|
|
||||||
|
function run_all_graphs
|
||||||
|
{
|
||||||
|
write_header $1
|
||||||
|
run_solver city1000 $2
|
||||||
|
run_solver city5000 $2
|
||||||
|
run_solver city10000 $2
|
||||||
|
run_solver city15000 $2
|
||||||
|
run_solver city20000 $2
|
||||||
|
run_solver city25000 $2
|
||||||
|
run_solver city30000 $2
|
||||||
|
run_solver city35000 $2
|
||||||
|
run_solver city40000 $2
|
||||||
|
run_solver city45000 $2
|
||||||
|
run_solver city50000 $2
|
||||||
|
run_solver city55000 $2
|
||||||
|
run_solver city60000 $2
|
||||||
|
run_solver city65000 $2
|
||||||
|
run_solver city70000 $2
|
||||||
|
run_solver city75000 $2
|
||||||
|
run_solver city80000 $2
|
||||||
|
run_solver city85000 $2
|
||||||
|
run_solver city90000 $2
|
||||||
|
run_solver city95000 $2
|
||||||
|
run_solver city100000 $2
|
||||||
|
}
|
||||||
|
|
||||||
|
prepare_new_run
|
||||||
|
run_all_graphs "lbp(shedule=seq_fixed) " seq_fixed
|
||||||
|
|
30
packages/CLPBN/benchmarks/comp_workshops/lbp_tests.sh
Executable file
30
packages/CLPBN/benchmarks/comp_workshops/lbp_tests.sh
Executable file
@ -0,0 +1,30 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
source cw.sh
|
||||||
|
source ../benchs.sh
|
||||||
|
|
||||||
|
SOLVER="lbp"
|
||||||
|
|
||||||
|
function run_all_graphs
|
||||||
|
{
|
||||||
|
write_header $1
|
||||||
|
run_solver p1000w$N_WORKSHOPS $2
|
||||||
|
run_solver p5000w$N_WORKSHOPS $2
|
||||||
|
run_solver p10000w$N_WORKSHOPS $2
|
||||||
|
run_solver p15000w$N_WORKSHOPS $2
|
||||||
|
run_solver p20000w$N_WORKSHOPS $2
|
||||||
|
run_solver p25000w$N_WORKSHOPS $2
|
||||||
|
run_solver p30000w$N_WORKSHOPS $2
|
||||||
|
run_solver p35000w$N_WORKSHOPS $2
|
||||||
|
run_solver p40000w$N_WORKSHOPS $2
|
||||||
|
run_solver p45000w$N_WORKSHOPS $2
|
||||||
|
run_solver p50000w$N_WORKSHOPS $2
|
||||||
|
run_solver p55000w$N_WORKSHOPS $2
|
||||||
|
run_solver p60000w$N_WORKSHOPS $2
|
||||||
|
run_solver p65000w$N_WORKSHOPS $2
|
||||||
|
run_solver p70000w$N_WORKSHOPS $2
|
||||||
|
}
|
||||||
|
|
||||||
|
prepare_new_run
|
||||||
|
run_all_graphs "lbp(shedule=seq_fixed) " seq_fixed
|
||||||
|
|
35
packages/CLPBN/benchmarks/run_all.sh
Executable file
35
packages/CLPBN/benchmarks/run_all.sh
Executable file
@ -0,0 +1,35 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
cd workshop_attrs
|
||||||
|
source hve_tests.sh
|
||||||
|
source bp_tests.sh
|
||||||
|
source fove_tests.sh
|
||||||
|
source lbp_tests.sh
|
||||||
|
source cbp_tests.sh
|
||||||
|
cd ..
|
||||||
|
|
||||||
|
cd comp_workshops
|
||||||
|
source hve_tests.sh
|
||||||
|
source bp_tests.sh
|
||||||
|
source fove_tests.sh
|
||||||
|
source lbp_tests.sh
|
||||||
|
source cbp_tests.sh
|
||||||
|
cd ..
|
||||||
|
|
||||||
|
cd city
|
||||||
|
source hve_tests.sh
|
||||||
|
source bp_tests.sh
|
||||||
|
source fove_tests.sh
|
||||||
|
source lbp_tests.sh
|
||||||
|
source cbp_tests.sh
|
||||||
|
cd ..
|
||||||
|
|
||||||
|
cd smokers
|
||||||
|
source hve_tests.sh
|
||||||
|
source bp_tests.sh
|
||||||
|
source fove_tests.sh
|
||||||
|
source lbp_tests.sh
|
||||||
|
source cbp_tests.sh
|
||||||
|
cd ..
|
||||||
|
|
||||||
|
|
30
packages/CLPBN/benchmarks/smokers/lbp_tests.sh
Executable file
30
packages/CLPBN/benchmarks/smokers/lbp_tests.sh
Executable file
@ -0,0 +1,30 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
source sm.sh
|
||||||
|
source ../benchs.sh
|
||||||
|
|
||||||
|
SOLVER="lbp"
|
||||||
|
|
||||||
|
function run_all_graphs
|
||||||
|
{
|
||||||
|
write_header $1
|
||||||
|
run_solver pop100 $2
|
||||||
|
run_solver pop200 $2
|
||||||
|
run_solver pop300 $2
|
||||||
|
run_solver pop400 $2
|
||||||
|
run_solver pop500 $2
|
||||||
|
run_solver pop600 $2
|
||||||
|
run_solver pop700 $2
|
||||||
|
run_solver pop800 $2
|
||||||
|
run_solver pop900 $2
|
||||||
|
run_solver pop1000 $2
|
||||||
|
run_solver pop1100 $2
|
||||||
|
run_solver pop1200 $2
|
||||||
|
run_solver pop1300 $2
|
||||||
|
run_solver pop1400 $2
|
||||||
|
run_solver pop1500 $2
|
||||||
|
}
|
||||||
|
|
||||||
|
prepare_new_run
|
||||||
|
run_all_graphs "lbp(shedule=seq_fixed) " seq_fixed
|
||||||
|
|
@ -2,5 +2,6 @@
|
|||||||
|
|
||||||
NETWORK="'../../examples/social_domain2'"
|
NETWORK="'../../examples/social_domain2'"
|
||||||
SHORTNAME="sm"
|
SHORTNAME="sm"
|
||||||
QUERY="smokes(p1,t), smokes(p2,t), friends(p1,p2,X)"
|
#QUERY="smokes(p1,t), smokes(p2,t), friends(p1,p2,X)"
|
||||||
|
QUERY="friends(p1,p2,X)"
|
||||||
|
|
||||||
|
34
packages/CLPBN/benchmarks/smokers_evidence/bp_tests.sh
Executable file
34
packages/CLPBN/benchmarks/smokers_evidence/bp_tests.sh
Executable file
@ -0,0 +1,34 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
source sm.sh
|
||||||
|
source ../benchs.sh
|
||||||
|
|
||||||
|
SOLVER="bp"
|
||||||
|
|
||||||
|
function run_all_graphs
|
||||||
|
{
|
||||||
|
write_header $1
|
||||||
|
run_solver ev0p$POP $2
|
||||||
|
run_solver ev5p$POP $2
|
||||||
|
run_solver ev10p$POP $2
|
||||||
|
run_solver ev15p$POP $2
|
||||||
|
run_solver ev20p$POP $2
|
||||||
|
run_solver ev25p$POP $2
|
||||||
|
run_solver ev30p$POP $2
|
||||||
|
run_solver ev35p$POP $2
|
||||||
|
run_solver ev40p$POP $2
|
||||||
|
run_solver ev45p$POP $2
|
||||||
|
run_solver ev50p$POP $2
|
||||||
|
run_solver ev55p$POP $2
|
||||||
|
run_solver ev60p$POP $2
|
||||||
|
run_solver ev65p$POP $2
|
||||||
|
run_solver ev70p$POP $2
|
||||||
|
run_solver ev75p$POP $2
|
||||||
|
run_solver ev80p$POP $2
|
||||||
|
run_solver ev85p$POP $2
|
||||||
|
run_solver ev90p$POP $2
|
||||||
|
}
|
||||||
|
|
||||||
|
prepare_new_run
|
||||||
|
run_all_graphs "bp(shedule=seq_fixed) " seq_fixed
|
||||||
|
|
34
packages/CLPBN/benchmarks/smokers_evidence/cbp_tests.sh
Executable file
34
packages/CLPBN/benchmarks/smokers_evidence/cbp_tests.sh
Executable file
@ -0,0 +1,34 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
source sm.sh
|
||||||
|
source ../benchs.sh
|
||||||
|
|
||||||
|
SOLVER="cbp"
|
||||||
|
|
||||||
|
function run_all_graphs
|
||||||
|
{
|
||||||
|
write_header $1
|
||||||
|
run_solver ev0p$POP $2
|
||||||
|
run_solver ev5p$POP $2
|
||||||
|
run_solver ev10p$POP $2
|
||||||
|
run_solver ev15p$POP $2
|
||||||
|
run_solver ev20p$POP $2
|
||||||
|
run_solver ev25p$POP $2
|
||||||
|
run_solver ev30p$POP $2
|
||||||
|
run_solver ev35p$POP $2
|
||||||
|
run_solver ev40p$POP $2
|
||||||
|
run_solver ev45p$POP $2
|
||||||
|
run_solver ev50p$POP $2
|
||||||
|
run_solver ev55p$POP $2
|
||||||
|
run_solver ev60p$POP $2
|
||||||
|
run_solver ev65p$POP $2
|
||||||
|
run_solver ev70p$POP $2
|
||||||
|
run_solver ev75p$POP $2
|
||||||
|
run_solver ev80p$POP $2
|
||||||
|
run_solver ev85p$POP $2
|
||||||
|
run_solver ev90p$POP $2
|
||||||
|
}
|
||||||
|
|
||||||
|
prepare_new_run
|
||||||
|
run_all_graphs "cbp(shedule=seq_fixed) " seq_fixed
|
||||||
|
|
35
packages/CLPBN/benchmarks/smokers_evidence/fove_tests.sh
Executable file
35
packages/CLPBN/benchmarks/smokers_evidence/fove_tests.sh
Executable file
@ -0,0 +1,35 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
source sm.sh
|
||||||
|
source ../benchs.sh
|
||||||
|
|
||||||
|
SOLVER="fove"
|
||||||
|
|
||||||
|
function run_all_graphs
|
||||||
|
{
|
||||||
|
write_header $1
|
||||||
|
run_solver ev0p$POP $2
|
||||||
|
run_solver ev5p$POP $2
|
||||||
|
run_solver ev10p$POP $2
|
||||||
|
run_solver ev15p$POP $2
|
||||||
|
run_solver ev20p$POP $2
|
||||||
|
run_solver ev25p$POP $2
|
||||||
|
run_solver ev30p$POP $2
|
||||||
|
run_solver ev35p$POP $2
|
||||||
|
run_solver ev40p$POP $2
|
||||||
|
run_solver ev45p$POP $2
|
||||||
|
run_solver ev50p$POP $2
|
||||||
|
run_solver ev55p$POP $2
|
||||||
|
run_solver ev60p$POP $2
|
||||||
|
run_solver ev65p$POP $2
|
||||||
|
run_solver ev70p$POP $2
|
||||||
|
run_solver ev75p$POP $2
|
||||||
|
run_solver ev80p$POP $2
|
||||||
|
run_solver ev85p$POP $2
|
||||||
|
run_solver ev90p$POP $2
|
||||||
|
}
|
||||||
|
|
||||||
|
prepare_new_run
|
||||||
|
run_all_graphs "fove "
|
||||||
|
|
||||||
|
|
49
packages/CLPBN/benchmarks/smokers_evidence/gen_people.sh
Executable file
49
packages/CLPBN/benchmarks/smokers_evidence/gen_people.sh
Executable file
@ -0,0 +1,49 @@
|
|||||||
|
#!/home/tgomes/bin/yap -L --
|
||||||
|
|
||||||
|
:- use_module(library(lists)).
|
||||||
|
:- use_module(library(random)).
|
||||||
|
|
||||||
|
|
||||||
|
:- initialization(main).
|
||||||
|
|
||||||
|
main :-
|
||||||
|
unix(argv(Args)),
|
||||||
|
nth(1, Args, EV), % percentage of evidence
|
||||||
|
nth(2, Args, NP), % number of individuals
|
||||||
|
atomic_concat(['ev', EV, 'p', NP, '.yap'], FileName),
|
||||||
|
open(FileName, 'write', S),
|
||||||
|
atom_number(EV, EV2),
|
||||||
|
atom_number(NP, NP2),
|
||||||
|
EV3 is EV2 / 100.0,
|
||||||
|
generate_people(S, NP2, 4),
|
||||||
|
write(S, '\n'),
|
||||||
|
write(S, 'query(X) :- '),
|
||||||
|
generate_evidence(S, NP2, EV3, 4),
|
||||||
|
write(S, 'friends(p1,p2,X).\n'),
|
||||||
|
close(S).
|
||||||
|
|
||||||
|
|
||||||
|
generate_people(S, N, Counting) :-
|
||||||
|
Counting > N, !.
|
||||||
|
generate_people(S, N, Counting) :-
|
||||||
|
format(S, 'people(p~w).~n', [Counting]),
|
||||||
|
Counting1 is Counting + 1,
|
||||||
|
generate_people(S, N, Counting1).
|
||||||
|
|
||||||
|
|
||||||
|
generate_evidence(S, N, Ev, Counting) :-
|
||||||
|
Counting > N, !.
|
||||||
|
generate_evidence(S, N, Ev, Counting) :-
|
||||||
|
random(X),
|
||||||
|
(
|
||||||
|
X < Ev
|
||||||
|
->
|
||||||
|
random(Y),
|
||||||
|
(Y > 0.5 -> Val = t ; Val = f),
|
||||||
|
format(S, 'smokes(p~w,~w),', [Counting,Val])
|
||||||
|
;
|
||||||
|
true
|
||||||
|
),
|
||||||
|
Counting1 is Counting + 1,
|
||||||
|
generate_evidence(S, N, Ev, Counting1).
|
||||||
|
|
37
packages/CLPBN/benchmarks/smokers_evidence/hve_tests.sh
Executable file
37
packages/CLPBN/benchmarks/smokers_evidence/hve_tests.sh
Executable file
@ -0,0 +1,37 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
source sm.sh
|
||||||
|
source ../benchs.sh
|
||||||
|
|
||||||
|
SOLVER="hve"
|
||||||
|
|
||||||
|
function run_all_graphs
|
||||||
|
{
|
||||||
|
write_header $1
|
||||||
|
run_solver ev0p$POP $2
|
||||||
|
run_solver ev5p$POP $2
|
||||||
|
run_solver ev10p$POP $2
|
||||||
|
run_solver ev15p$POP $2
|
||||||
|
run_solver ev20p$POP $2
|
||||||
|
run_solver ev25p$POP $2
|
||||||
|
run_solver ev30p$POP $2
|
||||||
|
run_solver ev35p$POP $2
|
||||||
|
run_solver ev40p$POP $2
|
||||||
|
run_solver ev45p$POP $2
|
||||||
|
run_solver ev50p$POP $2
|
||||||
|
run_solver ev55p$POP $2
|
||||||
|
run_solver ev60p$POP $2
|
||||||
|
run_solver ev65p$POP $2
|
||||||
|
run_solver ev70p$POP $2
|
||||||
|
run_solver ev75p$POP $2
|
||||||
|
run_solver ev80p$POP $2
|
||||||
|
run_solver ev85p$POP $2
|
||||||
|
run_solver ev90p$POP $2
|
||||||
|
}
|
||||||
|
|
||||||
|
prepare_new_run
|
||||||
|
run_all_graphs "hve(elim_heuristic=min_neighbors) " min_neighbors
|
||||||
|
#run_all_graphs "hve(elim_heuristic=min_weight) " min_weight
|
||||||
|
#run_all_graphs "hve(elim_heuristic=min_fill) " min_fill
|
||||||
|
#run_all_graphs "hve(elim_heuristic=weighted_min_fill) " weighted_min_fill
|
||||||
|
|
34
packages/CLPBN/benchmarks/smokers_evidence/lbp_tests.sh
Executable file
34
packages/CLPBN/benchmarks/smokers_evidence/lbp_tests.sh
Executable file
@ -0,0 +1,34 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
source sm.sh
|
||||||
|
source ../benchs.sh
|
||||||
|
|
||||||
|
SOLVER="lbp"
|
||||||
|
|
||||||
|
function run_all_graphs
|
||||||
|
{
|
||||||
|
write_header $1
|
||||||
|
run_solver ev0p$POP $2
|
||||||
|
run_solver ev5p$POP $2
|
||||||
|
run_solver ev10p$POP $2
|
||||||
|
run_solver ev15p$POP $2
|
||||||
|
run_solver ev20p$POP $2
|
||||||
|
run_solver ev25p$POP $2
|
||||||
|
run_solver ev30p$POP $2
|
||||||
|
run_solver ev35p$POP $2
|
||||||
|
run_solver ev40p$POP $2
|
||||||
|
run_solver ev45p$POP $2
|
||||||
|
run_solver ev50p$POP $2
|
||||||
|
run_solver ev55p$POP $2
|
||||||
|
run_solver ev60p$POP $2
|
||||||
|
run_solver ev65p$POP $2
|
||||||
|
run_solver ev70p$POP $2
|
||||||
|
run_solver ev75p$POP $2
|
||||||
|
run_solver ev80p$POP $2
|
||||||
|
run_solver ev85p$POP $2
|
||||||
|
run_solver ev90p$POP $2
|
||||||
|
}
|
||||||
|
|
||||||
|
prepare_new_run
|
||||||
|
run_all_graphs "lbp(shedule=seq_fixed) " seq_fixed
|
||||||
|
|
8
packages/CLPBN/benchmarks/smokers_evidence/sm.sh
Executable file
8
packages/CLPBN/benchmarks/smokers_evidence/sm.sh
Executable file
@ -0,0 +1,8 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
NETWORK="'../../examples/social_domain2'"
|
||||||
|
SHORTNAME="sm"
|
||||||
|
QUERY="query(X)"
|
||||||
|
|
||||||
|
POP=500
|
||||||
|
|
@ -15,8 +15,8 @@ function run_all_graphs
|
|||||||
run_solver p20000attrs$N_ATTRS $2
|
run_solver p20000attrs$N_ATTRS $2
|
||||||
run_solver p25000attrs$N_ATTRS $2
|
run_solver p25000attrs$N_ATTRS $2
|
||||||
run_solver p30000attrs$N_ATTRS $2
|
run_solver p30000attrs$N_ATTRS $2
|
||||||
run_solver p35000attrs$N_ATTRS $2
|
|
||||||
return
|
return
|
||||||
|
run_solver p35000attrs$N_ATTRS $2
|
||||||
run_solver p40000attrs$N_ATTRS $2
|
run_solver p40000attrs$N_ATTRS $2
|
||||||
run_solver p45000attrs$N_ATTRS $2
|
run_solver p45000attrs$N_ATTRS $2
|
||||||
run_solver p50000attrs$N_ATTRS $2
|
run_solver p50000attrs$N_ATTRS $2
|
||||||
|
36
packages/CLPBN/benchmarks/workshop_attrs/lbp_tests.sh
Executable file
36
packages/CLPBN/benchmarks/workshop_attrs/lbp_tests.sh
Executable file
@ -0,0 +1,36 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
source wa.sh
|
||||||
|
source ../benchs.sh
|
||||||
|
|
||||||
|
SOLVER="lbp"
|
||||||
|
|
||||||
|
function run_all_graphs
|
||||||
|
{
|
||||||
|
write_header $1
|
||||||
|
run_solver p1000attrs$N_ATTRS $2
|
||||||
|
run_solver p5000attrs$N_ATTRS $2
|
||||||
|
run_solver p10000attrs$N_ATTRS $2
|
||||||
|
run_solver p15000attrs$N_ATTRS $2
|
||||||
|
run_solver p20000attrs$N_ATTRS $2
|
||||||
|
run_solver p25000attrs$N_ATTRS $2
|
||||||
|
run_solver p30000attrs$N_ATTRS $2
|
||||||
|
run_solver p35000attrs$N_ATTRS $2
|
||||||
|
run_solver p40000attrs$N_ATTRS $2
|
||||||
|
run_solver p45000attrs$N_ATTRS $2
|
||||||
|
run_solver p50000attrs$N_ATTRS $2
|
||||||
|
run_solver p55000attrs$N_ATTRS $2
|
||||||
|
run_solver p60000attrs$N_ATTRS $2
|
||||||
|
run_solver p65000attrs$N_ATTRS $2
|
||||||
|
run_solver p70000attrs$N_ATTRS $2
|
||||||
|
run_solver p75000attrs$N_ATTRS $2
|
||||||
|
run_solver p80000attrs$N_ATTRS $2
|
||||||
|
run_solver p85000attrs$N_ATTRS $2
|
||||||
|
run_solver p90000attrs$N_ATTRS $2
|
||||||
|
run_solver p95000attrs$N_ATTRS $2
|
||||||
|
run_solver p100000attrs$N_ATTRS $2
|
||||||
|
}
|
||||||
|
|
||||||
|
prepare_new_run
|
||||||
|
run_all_graphs "lbp(shedule=seq_fixed) " seq_fixed
|
||||||
|
|
@ -15,7 +15,7 @@
|
|||||||
cpp_run_ground_solver/3,
|
cpp_run_ground_solver/3,
|
||||||
cpp_set_vars_information/2,
|
cpp_set_vars_information/2,
|
||||||
cpp_set_horus_flag/2,
|
cpp_set_horus_flag/2,
|
||||||
cpp_free_parfactors/1,
|
cpp_free_lifted_network/1,
|
||||||
cpp_free_ground_network/1
|
cpp_free_ground_network/1
|
||||||
]).
|
]).
|
||||||
|
|
||||||
|
@ -79,7 +79,6 @@ run_solver(ground(Network,Hash), QueryKeys, Solutions) :-
|
|||||||
list_of_keys_to_ids(QueryKeys, Hash, QueryIds),
|
list_of_keys_to_ids(QueryKeys, Hash, QueryIds),
|
||||||
%writeln(queryKeys:QueryKeys), writeln(''),
|
%writeln(queryKeys:QueryKeys), writeln(''),
|
||||||
%writeln(queryIds:QueryIds), writeln(''),
|
%writeln(queryIds:QueryIds), writeln(''),
|
||||||
list_of_keys_to_ids(QueryKeys, Hash, QueryIds),
|
|
||||||
cpp_run_ground_solver(Network, [QueryIds], Solutions).
|
cpp_run_ground_solver(Network, [QueryIds], Solutions).
|
||||||
|
|
||||||
|
|
||||||
|
@ -17,7 +17,7 @@
|
|||||||
[cpp_create_lifted_network/3,
|
[cpp_create_lifted_network/3,
|
||||||
cpp_set_parfactors_params/2,
|
cpp_set_parfactors_params/2,
|
||||||
cpp_run_lifted_solver/3,
|
cpp_run_lifted_solver/3,
|
||||||
cpp_free_parfactors/1
|
cpp_free_lifted_network/1
|
||||||
]).
|
]).
|
||||||
|
|
||||||
:- use_module(library('clpbn/display'),
|
:- use_module(library('clpbn/display'),
|
||||||
@ -144,5 +144,5 @@ run_horus_lifted_solver(QueryVarsAtts, Solutions, fove(ParfactorList, DistIds))
|
|||||||
|
|
||||||
|
|
||||||
finalize_horus_lifted_solver(fove(ParfactorList, _)) :-
|
finalize_horus_lifted_solver(fove(ParfactorList, _)) :-
|
||||||
cpp_free_parfactors(ParfactorList).
|
cpp_free_lifted_network(ParfactorList).
|
||||||
|
|
||||||
|
@ -7,17 +7,17 @@
|
|||||||
|
|
||||||
:- yap_flag(write_strings, off).
|
:- yap_flag(write_strings, off).
|
||||||
|
|
||||||
bayes burglary::[b1,b2] ; [0.001, 0.999] ; [].
|
bayes burglary::[t,f] ; [0.001, 0.999] ; [].
|
||||||
|
|
||||||
bayes earthquake::[e1,e2] ; [0.002, 0.998]; [].
|
bayes earthquake::[t,f] ; [0.002, 0.998]; [].
|
||||||
|
|
||||||
bayes alarm::[a1,a2], burglary, earthquake ;
|
bayes alarm::[t,f], burglary, earthquake ;
|
||||||
[0.95, 0.94, 0.29, 0.001, 0.05, 0.06, 0.71, 0.999] ;
|
[0.95, 0.94, 0.29, 0.001, 0.05, 0.06, 0.71, 0.999] ;
|
||||||
[].
|
[].
|
||||||
|
|
||||||
bayes john_calls::[j1,j2], alarm ; [0.9, 0.05, 0.1, 0.95] ; [].
|
bayes john_calls::[t,f], alarm ; [0.9, 0.05, 0.1, 0.95] ; [].
|
||||||
|
|
||||||
bayes mary_calls::[m1,m2], alarm ; [0.7, 0.01, 0.3, 0.99] ; [].
|
bayes mary_calls::[t,f], alarm ; [0.7, 0.01, 0.3, 0.99] ; [].
|
||||||
|
|
||||||
% ?- john_calls(J), mary_calls(m1).
|
% ?- john_calls(J), mary_calls(t).
|
||||||
|
|
||||||
|
@ -16,13 +16,13 @@ BayesBall::getMinimalFactorGraph (const VarIds& queryIds)
|
|||||||
Scheduling scheduling;
|
Scheduling scheduling;
|
||||||
for (size_t i = 0; i < queryIds.size(); i++) {
|
for (size_t i = 0; i < queryIds.size(); i++) {
|
||||||
assert (dag_.getNode (queryIds[i]));
|
assert (dag_.getNode (queryIds[i]));
|
||||||
DAGraphNode* n = dag_.getNode (queryIds[i]);
|
BBNode* n = dag_.getNode (queryIds[i]);
|
||||||
scheduling.push (ScheduleInfo (n, false, true));
|
scheduling.push (ScheduleInfo (n, false, true));
|
||||||
}
|
}
|
||||||
|
|
||||||
while (!scheduling.empty()) {
|
while (!scheduling.empty()) {
|
||||||
ScheduleInfo& sch = scheduling.front();
|
ScheduleInfo& sch = scheduling.front();
|
||||||
DAGraphNode* n = sch.node;
|
BBNode* n = sch.node;
|
||||||
n->setAsVisited();
|
n->setAsVisited();
|
||||||
if (n->hasEvidence() == false && sch.visitedFromChild) {
|
if (n->hasEvidence() == false && sch.visitedFromChild) {
|
||||||
if (n->isMarkedOnTop() == false) {
|
if (n->isMarkedOnTop() == false) {
|
||||||
@ -59,7 +59,7 @@ BayesBall::constructGraph (FactorGraph* fg) const
|
|||||||
{
|
{
|
||||||
const FacNodes& facNodes = fg_.facNodes();
|
const FacNodes& facNodes = fg_.facNodes();
|
||||||
for (size_t i = 0; i < facNodes.size(); i++) {
|
for (size_t i = 0; i < facNodes.size(); i++) {
|
||||||
const DAGraphNode* n = dag_.getNode (
|
const BBNode* n = dag_.getNode (
|
||||||
facNodes[i]->factor().argument (0));
|
facNodes[i]->factor().argument (0));
|
||||||
if (n->isMarkedOnTop()) {
|
if (n->isMarkedOnTop()) {
|
||||||
fg->addFactor (facNodes[i]->factor());
|
fg->addFactor (facNodes[i]->factor());
|
||||||
|
@ -7,7 +7,7 @@
|
|||||||
#include <map>
|
#include <map>
|
||||||
|
|
||||||
#include "FactorGraph.h"
|
#include "FactorGraph.h"
|
||||||
#include "BayesNet.h"
|
#include "BayesBallGraph.h"
|
||||||
#include "Horus.h"
|
#include "Horus.h"
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
@ -15,10 +15,10 @@ using namespace std;
|
|||||||
|
|
||||||
struct ScheduleInfo
|
struct ScheduleInfo
|
||||||
{
|
{
|
||||||
ScheduleInfo (DAGraphNode* n, bool vfp, bool vfc) :
|
ScheduleInfo (BBNode* n, bool vfp, bool vfc) :
|
||||||
node(n), visitedFromParent(vfp), visitedFromChild(vfc) { }
|
node(n), visitedFromParent(vfp), visitedFromChild(vfc) { }
|
||||||
|
|
||||||
DAGraphNode* node;
|
BBNode* node;
|
||||||
bool visitedFromParent;
|
bool visitedFromParent;
|
||||||
bool visitedFromChild;
|
bool visitedFromChild;
|
||||||
};
|
};
|
||||||
@ -48,22 +48,22 @@ class BayesBall
|
|||||||
|
|
||||||
void constructGraph (FactorGraph* fg) const;
|
void constructGraph (FactorGraph* fg) const;
|
||||||
|
|
||||||
void scheduleParents (const DAGraphNode* n, Scheduling& sch) const;
|
void scheduleParents (const BBNode* n, Scheduling& sch) const;
|
||||||
|
|
||||||
void scheduleChilds (const DAGraphNode* n, Scheduling& sch) const;
|
void scheduleChilds (const BBNode* n, Scheduling& sch) const;
|
||||||
|
|
||||||
FactorGraph& fg_;
|
FactorGraph& fg_;
|
||||||
|
|
||||||
DAGraph& dag_;
|
BayesBallGraph& dag_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
inline void
|
inline void
|
||||||
BayesBall::scheduleParents (const DAGraphNode* n, Scheduling& sch) const
|
BayesBall::scheduleParents (const BBNode* n, Scheduling& sch) const
|
||||||
{
|
{
|
||||||
const vector<DAGraphNode*>& ps = n->parents();
|
const vector<BBNode*>& ps = n->parents();
|
||||||
for (vector<DAGraphNode*>::const_iterator it = ps.begin();
|
for (vector<BBNode*>::const_iterator it = ps.begin();
|
||||||
it != ps.end(); ++it) {
|
it != ps.end(); ++it) {
|
||||||
sch.push (ScheduleInfo (*it, false, true));
|
sch.push (ScheduleInfo (*it, false, true));
|
||||||
}
|
}
|
||||||
@ -72,10 +72,10 @@ BayesBall::scheduleParents (const DAGraphNode* n, Scheduling& sch) const
|
|||||||
|
|
||||||
|
|
||||||
inline void
|
inline void
|
||||||
BayesBall::scheduleChilds (const DAGraphNode* n, Scheduling& sch) const
|
BayesBall::scheduleChilds (const BBNode* n, Scheduling& sch) const
|
||||||
{
|
{
|
||||||
const vector<DAGraphNode*>& cs = n->childs();
|
const vector<BBNode*>& cs = n->childs();
|
||||||
for (vector<DAGraphNode*>::const_iterator it = cs.begin();
|
for (vector<BBNode*>::const_iterator it = cs.begin();
|
||||||
it != cs.end(); ++it) {
|
it != cs.end(); ++it) {
|
||||||
sch.push (ScheduleInfo (*it, true, false));
|
sch.push (ScheduleInfo (*it, true, false));
|
||||||
}
|
}
|
||||||
|
@ -5,12 +5,12 @@
|
|||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
|
||||||
#include "BayesNet.h"
|
#include "BayesBallGraph.h"
|
||||||
#include "Util.h"
|
#include "Util.h"
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
DAGraph::addNode (DAGraphNode* n)
|
BayesBallGraph::addNode (BBNode* n)
|
||||||
{
|
{
|
||||||
assert (Util::contains (varMap_, n->varId()) == false);
|
assert (Util::contains (varMap_, n->varId()) == false);
|
||||||
nodes_.push_back (n);
|
nodes_.push_back (n);
|
||||||
@ -20,10 +20,10 @@ DAGraph::addNode (DAGraphNode* n)
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
DAGraph::addEdge (VarId vid1, VarId vid2)
|
BayesBallGraph::addEdge (VarId vid1, VarId vid2)
|
||||||
{
|
{
|
||||||
unordered_map<VarId, DAGraphNode*>::iterator it1;
|
unordered_map<VarId, BBNode*>::iterator it1;
|
||||||
unordered_map<VarId, DAGraphNode*>::iterator it2;
|
unordered_map<VarId, BBNode*>::iterator it2;
|
||||||
it1 = varMap_.find (vid1);
|
it1 = varMap_.find (vid1);
|
||||||
it2 = varMap_.find (vid2);
|
it2 = varMap_.find (vid2);
|
||||||
assert (it1 != varMap_.end());
|
assert (it1 != varMap_.end());
|
||||||
@ -34,20 +34,20 @@ DAGraph::addEdge (VarId vid1, VarId vid2)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
const DAGraphNode*
|
const BBNode*
|
||||||
DAGraph::getNode (VarId vid) const
|
BayesBallGraph::getNode (VarId vid) const
|
||||||
{
|
{
|
||||||
unordered_map<VarId, DAGraphNode*>::const_iterator it;
|
unordered_map<VarId, BBNode*>::const_iterator it;
|
||||||
it = varMap_.find (vid);
|
it = varMap_.find (vid);
|
||||||
return it != varMap_.end() ? it->second : 0;
|
return it != varMap_.end() ? it->second : 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
DAGraphNode*
|
BBNode*
|
||||||
DAGraph::getNode (VarId vid)
|
BayesBallGraph::getNode (VarId vid)
|
||||||
{
|
{
|
||||||
unordered_map<VarId, DAGraphNode*>::const_iterator it;
|
unordered_map<VarId, BBNode*>::const_iterator it;
|
||||||
it = varMap_.find (vid);
|
it = varMap_.find (vid);
|
||||||
return it != varMap_.end() ? it->second : 0;
|
return it != varMap_.end() ? it->second : 0;
|
||||||
}
|
}
|
||||||
@ -55,7 +55,7 @@ DAGraph::getNode (VarId vid)
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
DAGraph::setIndexes (void)
|
BayesBallGraph::setIndexes (void)
|
||||||
{
|
{
|
||||||
for (size_t i = 0; i < nodes_.size(); i++) {
|
for (size_t i = 0; i < nodes_.size(); i++) {
|
||||||
nodes_[i]->setIndex (i);
|
nodes_[i]->setIndex (i);
|
||||||
@ -65,7 +65,7 @@ DAGraph::setIndexes (void)
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
DAGraph::clear (void)
|
BayesBallGraph::clear (void)
|
||||||
{
|
{
|
||||||
for (size_t i = 0; i < nodes_.size(); i++) {
|
for (size_t i = 0; i < nodes_.size(); i++) {
|
||||||
nodes_[i]->clear();
|
nodes_[i]->clear();
|
||||||
@ -75,12 +75,12 @@ DAGraph::clear (void)
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
DAGraph::exportToGraphViz (const char* fileName)
|
BayesBallGraph::exportToGraphViz (const char* fileName)
|
||||||
{
|
{
|
||||||
ofstream out (fileName);
|
ofstream out (fileName);
|
||||||
if (!out.is_open()) {
|
if (!out.is_open()) {
|
||||||
cerr << "error: cannot open file to write at " ;
|
cerr << "error: cannot open file to write at " ;
|
||||||
cerr << "DAGraph::exportToDotFile()" << endl;
|
cerr << "BayesBallGraph::exportToDotFile()" << endl;
|
||||||
abort();
|
abort();
|
||||||
}
|
}
|
||||||
out << "digraph {" << endl;
|
out << "digraph {" << endl;
|
||||||
@ -95,7 +95,7 @@ DAGraph::exportToGraphViz (const char* fileName)
|
|||||||
out << "]" << endl;
|
out << "]" << endl;
|
||||||
}
|
}
|
||||||
for (size_t i = 0; i < nodes_.size(); i++) {
|
for (size_t i = 0; i < nodes_.size(); i++) {
|
||||||
const vector<DAGraphNode*>& childs = nodes_[i]->childs();
|
const vector<BBNode*>& childs = nodes_[i]->childs();
|
||||||
for (size_t j = 0; j < childs.size(); j++) {
|
for (size_t j = 0; j < childs.size(); j++) {
|
||||||
out << nodes_[i]->varId() << " -> " << childs[j]->varId();
|
out << nodes_[i]->varId() << " -> " << childs[j]->varId();
|
||||||
out << " [style=bold]" << endl ;
|
out << " [style=bold]" << endl ;
|
@ -1,5 +1,5 @@
|
|||||||
#ifndef HORUS_BAYESNET_H
|
#ifndef HORUS_BAYESBALLGRAPH_H
|
||||||
#define HORUS_BAYESNET_H
|
#define HORUS_BAYESBALLGRAPH_H
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <queue>
|
#include <queue>
|
||||||
@ -9,29 +9,25 @@
|
|||||||
#include "Var.h"
|
#include "Var.h"
|
||||||
#include "Horus.h"
|
#include "Horus.h"
|
||||||
|
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
|
class BBNode : public Var
|
||||||
class Var;
|
|
||||||
|
|
||||||
class DAGraphNode : public Var
|
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
DAGraphNode (Var* v) : Var (v) , visited_(false),
|
BBNode (Var* v) : Var (v) , visited_(false),
|
||||||
markedOnTop_(false), markedOnBottom_(false) { }
|
markedOnTop_(false), markedOnBottom_(false) { }
|
||||||
|
|
||||||
const vector<DAGraphNode*>& childs (void) const { return childs_; }
|
const vector<BBNode*>& childs (void) const { return childs_; }
|
||||||
|
|
||||||
vector<DAGraphNode*>& childs (void) { return childs_; }
|
vector<BBNode*>& childs (void) { return childs_; }
|
||||||
|
|
||||||
const vector<DAGraphNode*>& parents (void) const { return parents_; }
|
const vector<BBNode*>& parents (void) const { return parents_; }
|
||||||
|
|
||||||
vector<DAGraphNode*>& parents (void) { return parents_; }
|
vector<BBNode*>& parents (void) { return parents_; }
|
||||||
|
|
||||||
void addParent (DAGraphNode* p) { parents_.push_back (p); }
|
void addParent (BBNode* p) { parents_.push_back (p); }
|
||||||
|
|
||||||
void addChild (DAGraphNode* c) { childs_.push_back (c); }
|
void addChild (BBNode* c) { childs_.push_back (c); }
|
||||||
|
|
||||||
bool isVisited (void) const { return visited_; }
|
bool isVisited (void) const { return visited_; }
|
||||||
|
|
||||||
@ -52,23 +48,23 @@ class DAGraphNode : public Var
|
|||||||
bool markedOnTop_;
|
bool markedOnTop_;
|
||||||
bool markedOnBottom_;
|
bool markedOnBottom_;
|
||||||
|
|
||||||
vector<DAGraphNode*> childs_;
|
vector<BBNode*> childs_;
|
||||||
vector<DAGraphNode*> parents_;
|
vector<BBNode*> parents_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
class DAGraph
|
class BayesBallGraph
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
DAGraph (void) { }
|
BayesBallGraph (void) { }
|
||||||
|
|
||||||
void addNode (DAGraphNode* n);
|
void addNode (BBNode* n);
|
||||||
|
|
||||||
void addEdge (VarId vid1, VarId vid2);
|
void addEdge (VarId vid1, VarId vid2);
|
||||||
|
|
||||||
const DAGraphNode* getNode (VarId vid) const;
|
const BBNode* getNode (VarId vid) const;
|
||||||
|
|
||||||
DAGraphNode* getNode (VarId vid);
|
BBNode* getNode (VarId vid);
|
||||||
|
|
||||||
bool empty (void) const { return nodes_.empty(); }
|
bool empty (void) const { return nodes_.empty(); }
|
||||||
|
|
||||||
@ -79,10 +75,10 @@ class DAGraph
|
|||||||
void exportToGraphViz (const char*);
|
void exportToGraphViz (const char*);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
vector<DAGraphNode*> nodes_;
|
vector<BBNode*> nodes_;
|
||||||
|
|
||||||
unordered_map<VarId, DAGraphNode*> varMap_;
|
unordered_map<VarId, BBNode*> varMap_;
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif // HORUS_BAYESNET_H
|
#endif // HORUS_BAYESBALLGRAPH_H
|
||||||
|
|
@ -5,21 +5,21 @@
|
|||||||
|
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
|
|
||||||
#include "BpSolver.h"
|
#include "BeliefProp.h"
|
||||||
#include "FactorGraph.h"
|
#include "FactorGraph.h"
|
||||||
#include "Factor.h"
|
#include "Factor.h"
|
||||||
#include "Indexer.h"
|
#include "Indexer.h"
|
||||||
#include "Horus.h"
|
#include "Horus.h"
|
||||||
|
|
||||||
|
|
||||||
BpSolver::BpSolver (const FactorGraph& fg) : Solver (fg)
|
BeliefProp::BeliefProp (const FactorGraph& fg) : Solver (fg)
|
||||||
{
|
{
|
||||||
runned_ = false;
|
runned_ = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
BpSolver::~BpSolver (void)
|
BeliefProp::~BeliefProp (void)
|
||||||
{
|
{
|
||||||
for (size_t i = 0; i < varsI_.size(); i++) {
|
for (size_t i = 0; i < varsI_.size(); i++) {
|
||||||
delete varsI_[i];
|
delete varsI_[i];
|
||||||
@ -35,7 +35,7 @@ BpSolver::~BpSolver (void)
|
|||||||
|
|
||||||
|
|
||||||
Params
|
Params
|
||||||
BpSolver::solveQuery (VarIds queryVids)
|
BeliefProp::solveQuery (VarIds queryVids)
|
||||||
{
|
{
|
||||||
assert (queryVids.empty() == false);
|
assert (queryVids.empty() == false);
|
||||||
return queryVids.size() == 1
|
return queryVids.size() == 1
|
||||||
@ -46,7 +46,7 @@ BpSolver::solveQuery (VarIds queryVids)
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
BpSolver::printSolverFlags (void) const
|
BeliefProp::printSolverFlags (void) const
|
||||||
{
|
{
|
||||||
stringstream ss;
|
stringstream ss;
|
||||||
ss << "belief propagation [" ;
|
ss << "belief propagation [" ;
|
||||||
@ -68,7 +68,7 @@ BpSolver::printSolverFlags (void) const
|
|||||||
|
|
||||||
|
|
||||||
Params
|
Params
|
||||||
BpSolver::getPosterioriOf (VarId vid)
|
BeliefProp::getPosterioriOf (VarId vid)
|
||||||
{
|
{
|
||||||
if (runned_ == false) {
|
if (runned_ == false) {
|
||||||
runSolver();
|
runSolver();
|
||||||
@ -101,7 +101,7 @@ BpSolver::getPosterioriOf (VarId vid)
|
|||||||
|
|
||||||
|
|
||||||
Params
|
Params
|
||||||
BpSolver::getJointDistributionOf (const VarIds& jointVarIds)
|
BeliefProp::getJointDistributionOf (const VarIds& jointVarIds)
|
||||||
{
|
{
|
||||||
if (runned_ == false) {
|
if (runned_ == false) {
|
||||||
runSolver();
|
runSolver();
|
||||||
@ -117,9 +117,23 @@ BpSolver::getJointDistributionOf (const VarIds& jointVarIds)
|
|||||||
}
|
}
|
||||||
if (idx == facNodes.size()) {
|
if (idx == facNodes.size()) {
|
||||||
return getJointByConditioning (jointVarIds);
|
return getJointByConditioning (jointVarIds);
|
||||||
} else {
|
}
|
||||||
Factor res (facNodes[idx]->factor());
|
return getFactorJoint (idx, jointVarIds);
|
||||||
const BpLinks& links = ninf(facNodes[idx])->getLinks();
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Params
|
||||||
|
BeliefProp::getFactorJoint (
|
||||||
|
size_t fnIdx,
|
||||||
|
const VarIds& jointVarIds)
|
||||||
|
{
|
||||||
|
if (runned_ == false) {
|
||||||
|
runSolver();
|
||||||
|
}
|
||||||
|
FacNode* fn = fg.facNodes()[fnIdx];
|
||||||
|
Factor res (fn->factor());
|
||||||
|
const BpLinks& links = ninf(fn)->getLinks();
|
||||||
for (size_t i = 0; i < links.size(); i++) {
|
for (size_t i = 0; i < links.size(); i++) {
|
||||||
Factor msg ({links[i]->varNode()->varId()},
|
Factor msg ({links[i]->varNode()->varId()},
|
||||||
{links[i]->varNode()->range()},
|
{links[i]->varNode()->range()},
|
||||||
@ -135,12 +149,11 @@ BpSolver::getJointDistributionOf (const VarIds& jointVarIds)
|
|||||||
}
|
}
|
||||||
return jointDist;
|
return jointDist;
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
BpSolver::runSolver (void)
|
BeliefProp::runSolver (void)
|
||||||
{
|
{
|
||||||
initializeSolver();
|
initializeSolver();
|
||||||
nIters_ = 0;
|
nIters_ = 0;
|
||||||
@ -173,7 +186,7 @@ BpSolver::runSolver (void)
|
|||||||
}
|
}
|
||||||
if (Globals::verbosity > 0) {
|
if (Globals::verbosity > 0) {
|
||||||
if (nIters_ < BpOptions::maxIter) {
|
if (nIters_ < BpOptions::maxIter) {
|
||||||
cout << "Sum-Product converged in " ;
|
cout << "Belief propagation converged in " ;
|
||||||
cout << nIters_ << " iterations" << endl;
|
cout << nIters_ << " iterations" << endl;
|
||||||
} else {
|
} else {
|
||||||
cout << "The maximum number of iterations was hit, terminating..." ;
|
cout << "The maximum number of iterations was hit, terminating..." ;
|
||||||
@ -187,7 +200,7 @@ BpSolver::runSolver (void)
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
BpSolver::createLinks (void)
|
BeliefProp::createLinks (void)
|
||||||
{
|
{
|
||||||
const FacNodes& facNodes = fg.facNodes();
|
const FacNodes& facNodes = fg.facNodes();
|
||||||
for (size_t i = 0; i < facNodes.size(); i++) {
|
for (size_t i = 0; i < facNodes.size(); i++) {
|
||||||
@ -201,7 +214,7 @@ BpSolver::createLinks (void)
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
BpSolver::maxResidualSchedule (void)
|
BeliefProp::maxResidualSchedule (void)
|
||||||
{
|
{
|
||||||
if (nIters_ == 1) {
|
if (nIters_ == 1) {
|
||||||
for (size_t i = 0; i < links_.size(); i++) {
|
for (size_t i = 0; i < links_.size(); i++) {
|
||||||
@ -256,7 +269,7 @@ BpSolver::maxResidualSchedule (void)
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
BpSolver::calcFactorToVarMsg (BpLink* link)
|
BeliefProp::calcFactorToVarMsg (BpLink* link)
|
||||||
{
|
{
|
||||||
FacNode* src = link->facNode();
|
FacNode* src = link->facNode();
|
||||||
const VarNode* dst = link->varNode();
|
const VarNode* dst = link->varNode();
|
||||||
@ -320,7 +333,7 @@ BpSolver::calcFactorToVarMsg (BpLink* link)
|
|||||||
|
|
||||||
|
|
||||||
Params
|
Params
|
||||||
BpSolver::getVarToFactorMsg (const BpLink* link) const
|
BeliefProp::getVarToFactorMsg (const BpLink* link) const
|
||||||
{
|
{
|
||||||
const VarNode* src = link->varNode();
|
const VarNode* src = link->varNode();
|
||||||
Params msg;
|
Params msg;
|
||||||
@ -361,61 +374,15 @@ BpSolver::getVarToFactorMsg (const BpLink* link) const
|
|||||||
|
|
||||||
|
|
||||||
Params
|
Params
|
||||||
BpSolver::getJointByConditioning (const VarIds& jointVarIds) const
|
BeliefProp::getJointByConditioning (const VarIds& jointVarIds) const
|
||||||
{
|
{
|
||||||
VarNodes jointVars;
|
return Solver::getJointByConditioning (GroundSolver::BP, fg, jointVarIds);
|
||||||
for (size_t i = 0; i < jointVarIds.size(); i++) {
|
|
||||||
assert (fg.getVarNode (jointVarIds[i]));
|
|
||||||
jointVars.push_back (fg.getVarNode (jointVarIds[i]));
|
|
||||||
}
|
|
||||||
|
|
||||||
FactorGraph* tempFg = new FactorGraph (fg);
|
|
||||||
BpSolver solver (*tempFg);
|
|
||||||
solver.runSolver();
|
|
||||||
Params prevBeliefs = solver.getPosterioriOf (jointVarIds[0]);
|
|
||||||
|
|
||||||
VarIds observedVids = {jointVars[0]->varId()};
|
|
||||||
|
|
||||||
for (size_t i = 1; i < jointVarIds.size(); i++) {
|
|
||||||
assert (jointVars[i]->hasEvidence() == false);
|
|
||||||
Params newBeliefs;
|
|
||||||
Vars observedVars;
|
|
||||||
Ranges observedRanges;
|
|
||||||
for (size_t j = 0; j < observedVids.size(); j++) {
|
|
||||||
observedVars.push_back (tempFg->getVarNode (observedVids[j]));
|
|
||||||
observedRanges.push_back (observedVars.back()->range());
|
|
||||||
}
|
|
||||||
Indexer indexer (observedRanges, false);
|
|
||||||
while (indexer.valid()) {
|
|
||||||
for (size_t j = 0; j < observedVars.size(); j++) {
|
|
||||||
observedVars[j]->setEvidence (indexer[j]);
|
|
||||||
}
|
|
||||||
BpSolver solver (*tempFg);
|
|
||||||
solver.runSolver();
|
|
||||||
Params beliefs = solver.getPosterioriOf (jointVarIds[i]);
|
|
||||||
for (size_t k = 0; k < beliefs.size(); k++) {
|
|
||||||
newBeliefs.push_back (beliefs[k]);
|
|
||||||
}
|
|
||||||
++ indexer;
|
|
||||||
}
|
|
||||||
|
|
||||||
int count = -1;
|
|
||||||
for (size_t j = 0; j < newBeliefs.size(); j++) {
|
|
||||||
if (j % jointVars[i]->range() == 0) {
|
|
||||||
count ++;
|
|
||||||
}
|
|
||||||
newBeliefs[j] *= prevBeliefs[count];
|
|
||||||
}
|
|
||||||
prevBeliefs = newBeliefs;
|
|
||||||
observedVids.push_back (jointVars[i]->varId());
|
|
||||||
}
|
|
||||||
return prevBeliefs;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
BpSolver::initializeSolver (void)
|
BeliefProp::initializeSolver (void)
|
||||||
{
|
{
|
||||||
const VarNodes& varNodes = fg.varNodes();
|
const VarNodes& varNodes = fg.varNodes();
|
||||||
varsI_.reserve (varNodes.size());
|
varsI_.reserve (varNodes.size());
|
||||||
@ -439,7 +406,7 @@ BpSolver::initializeSolver (void)
|
|||||||
|
|
||||||
|
|
||||||
bool
|
bool
|
||||||
BpSolver::converged (void)
|
BeliefProp::converged (void)
|
||||||
{
|
{
|
||||||
if (links_.size() == 0) {
|
if (links_.size() == 0) {
|
||||||
return true;
|
return true;
|
||||||
@ -487,7 +454,7 @@ BpSolver::converged (void)
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
BpSolver::printLinkInformation (void) const
|
BeliefProp::printLinkInformation (void) const
|
||||||
{
|
{
|
||||||
for (size_t i = 0; i < links_.size(); i++) {
|
for (size_t i = 0; i < links_.size(); i++) {
|
||||||
BpLink* l = links_[i];
|
BpLink* l = links_[i];
|
@ -1,5 +1,5 @@
|
|||||||
#ifndef HORUS_BPSOLVER_H
|
#ifndef HORUS_BELIEFPROP_H
|
||||||
#define HORUS_BPSOLVER_H
|
#define HORUS_BELIEFPROP_H
|
||||||
|
|
||||||
#include <set>
|
#include <set>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
@ -83,12 +83,12 @@ class SPNodeInfo
|
|||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
class BpSolver : public Solver
|
class BeliefProp : public Solver
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
BpSolver (const FactorGraph&);
|
BeliefProp (const FactorGraph&);
|
||||||
|
|
||||||
virtual ~BpSolver (void);
|
virtual ~BeliefProp (void);
|
||||||
|
|
||||||
Params solveQuery (VarIds);
|
Params solveQuery (VarIds);
|
||||||
|
|
||||||
@ -111,6 +111,10 @@ class BpSolver : public Solver
|
|||||||
|
|
||||||
virtual Params getJointByConditioning (const VarIds&) const;
|
virtual Params getJointByConditioning (const VarIds&) const;
|
||||||
|
|
||||||
|
public:
|
||||||
|
Params getFactorJoint (size_t fnIdx, const VarIds&);
|
||||||
|
|
||||||
|
protected:
|
||||||
SPNodeInfo* ninf (const VarNode* var) const
|
SPNodeInfo* ninf (const VarNode* var) const
|
||||||
{
|
{
|
||||||
return varsI_[var->getIndex()];
|
return varsI_[var->getIndex()];
|
||||||
@ -180,5 +184,5 @@ class BpSolver : public Solver
|
|||||||
virtual void printLinkInformation (void) const;
|
virtual void printLinkInformation (void) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif // HORUS_BPSOLVER_H
|
#endif // HORUS_BELIEFPROP_H
|
||||||
|
|
@ -1,23 +1,23 @@
|
|||||||
#include "CbpSolver.h"
|
#include "CountingBp.h"
|
||||||
#include "WeightedBpSolver.h"
|
#include "WeightedBp.h"
|
||||||
|
|
||||||
|
|
||||||
bool CbpSolver::checkForIdenticalFactors = true;
|
bool CountingBp::checkForIdenticalFactors = true;
|
||||||
|
|
||||||
|
|
||||||
CbpSolver::CbpSolver (const FactorGraph& fg)
|
CountingBp::CountingBp (const FactorGraph& fg)
|
||||||
: Solver (fg), freeColor_(0)
|
: Solver (fg), freeColor_(0)
|
||||||
{
|
{
|
||||||
findIdenticalFactors();
|
findIdenticalFactors();
|
||||||
setInitialColors();
|
setInitialColors();
|
||||||
createGroups();
|
createGroups();
|
||||||
compressedFg_ = getCompressedFactorGraph();
|
compressedFg_ = getCompressedFactorGraph();
|
||||||
solver_ = new WeightedBpSolver (*compressedFg_, getWeights());
|
solver_ = new WeightedBp (*compressedFg_, getWeights());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
CbpSolver::~CbpSolver (void)
|
CountingBp::~CountingBp (void)
|
||||||
{
|
{
|
||||||
delete solver_;
|
delete solver_;
|
||||||
delete compressedFg_;
|
delete compressedFg_;
|
||||||
@ -32,7 +32,7 @@ CbpSolver::~CbpSolver (void)
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
CbpSolver::printSolverFlags (void) const
|
CountingBp::printSolverFlags (void) const
|
||||||
{
|
{
|
||||||
stringstream ss;
|
stringstream ss;
|
||||||
ss << "counting bp [" ;
|
ss << "counting bp [" ;
|
||||||
@ -48,7 +48,7 @@ CbpSolver::printSolverFlags (void) const
|
|||||||
ss << ",accuracy=" << BpOptions::accuracy;
|
ss << ",accuracy=" << BpOptions::accuracy;
|
||||||
ss << ",log_domain=" << Util::toString (Globals::logDomain);
|
ss << ",log_domain=" << Util::toString (Globals::logDomain);
|
||||||
ss << ",chkif=" <<
|
ss << ",chkif=" <<
|
||||||
Util::toString (CbpSolver::checkForIdenticalFactors);
|
Util::toString (CountingBp::checkForIdenticalFactors);
|
||||||
ss << "]" ;
|
ss << "]" ;
|
||||||
cout << ss.str() << endl;
|
cout << ss.str() << endl;
|
||||||
}
|
}
|
||||||
@ -56,7 +56,7 @@ CbpSolver::printSolverFlags (void) const
|
|||||||
|
|
||||||
|
|
||||||
Params
|
Params
|
||||||
CbpSolver::solveQuery (VarIds queryVids)
|
CountingBp::solveQuery (VarIds queryVids)
|
||||||
{
|
{
|
||||||
assert (queryVids.empty() == false);
|
assert (queryVids.empty() == false);
|
||||||
Params res;
|
Params res;
|
||||||
@ -74,16 +74,15 @@ CbpSolver::solveQuery (VarIds queryVids)
|
|||||||
cout << endl;
|
cout << endl;
|
||||||
}
|
}
|
||||||
if (idx == facNodes.size()) {
|
if (idx == facNodes.size()) {
|
||||||
cerr << "error: only joint distributions on variables of some " ;
|
res = Solver::getJointByConditioning (
|
||||||
cerr << "clique are supported with the current solver" ;
|
GroundSolver::CBP, fg, queryVids);
|
||||||
cerr << endl;
|
} else {
|
||||||
exit (1);
|
VarIds reprArgs;
|
||||||
}
|
|
||||||
VarIds representatives;
|
|
||||||
for (size_t i = 0; i < queryVids.size(); i++) {
|
for (size_t i = 0; i < queryVids.size(); i++) {
|
||||||
representatives.push_back (getRepresentative (queryVids[i]));
|
reprArgs.push_back (getRepresentative (queryVids[i]));
|
||||||
|
}
|
||||||
|
res = solver_->getFactorJoint (idx, reprArgs);
|
||||||
}
|
}
|
||||||
res = solver_->getJointDistributionOf (representatives);
|
|
||||||
}
|
}
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
@ -91,7 +90,7 @@ CbpSolver::solveQuery (VarIds queryVids)
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
CbpSolver::findIdenticalFactors()
|
CountingBp::findIdenticalFactors()
|
||||||
{
|
{
|
||||||
const FacNodes& facNodes = fg.facNodes();
|
const FacNodes& facNodes = fg.facNodes();
|
||||||
if (checkForIdenticalFactors == false ||
|
if (checkForIdenticalFactors == false ||
|
||||||
@ -126,7 +125,7 @@ CbpSolver::findIdenticalFactors()
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
CbpSolver::setInitialColors (void)
|
CountingBp::setInitialColors (void)
|
||||||
{
|
{
|
||||||
varColors_.resize (fg.nrVarNodes());
|
varColors_.resize (fg.nrVarNodes());
|
||||||
facColors_.resize (fg.nrFacNodes());
|
facColors_.resize (fg.nrFacNodes());
|
||||||
@ -165,7 +164,7 @@ CbpSolver::setInitialColors (void)
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
CbpSolver::createGroups (void)
|
CountingBp::createGroups (void)
|
||||||
{
|
{
|
||||||
VarSignMap varGroups;
|
VarSignMap varGroups;
|
||||||
FacSignMap facGroups;
|
FacSignMap facGroups;
|
||||||
@ -227,7 +226,7 @@ CbpSolver::createGroups (void)
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
CbpSolver::createClusters (
|
CountingBp::createClusters (
|
||||||
const VarSignMap& varGroups,
|
const VarSignMap& varGroups,
|
||||||
const FacSignMap& facGroups)
|
const FacSignMap& facGroups)
|
||||||
{
|
{
|
||||||
@ -260,7 +259,7 @@ CbpSolver::createClusters (
|
|||||||
|
|
||||||
|
|
||||||
VarSignature
|
VarSignature
|
||||||
CbpSolver::getSignature (const VarNode* varNode)
|
CountingBp::getSignature (const VarNode* varNode)
|
||||||
{
|
{
|
||||||
const FacNodes& neighs = varNode->neighbors();
|
const FacNodes& neighs = varNode->neighbors();
|
||||||
VarSignature sign;
|
VarSignature sign;
|
||||||
@ -278,7 +277,7 @@ CbpSolver::getSignature (const VarNode* varNode)
|
|||||||
|
|
||||||
|
|
||||||
FacSignature
|
FacSignature
|
||||||
CbpSolver::getSignature (const FacNode* facNode)
|
CountingBp::getSignature (const FacNode* facNode)
|
||||||
{
|
{
|
||||||
const VarNodes& neighs = facNode->neighbors();
|
const VarNodes& neighs = facNode->neighbors();
|
||||||
FacSignature sign;
|
FacSignature sign;
|
||||||
@ -292,8 +291,31 @@ CbpSolver::getSignature (const FacNode* facNode)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
VarId
|
||||||
|
CountingBp::getRepresentative (VarId vid)
|
||||||
|
{
|
||||||
|
assert (Util::contains (vid2VarCluster_, vid));
|
||||||
|
VarCluster* vc = vid2VarCluster_.find (vid)->second;
|
||||||
|
return vc->representative()->varId();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
FacNode*
|
||||||
|
CountingBp::getRepresentative (FacNode* fn)
|
||||||
|
{
|
||||||
|
for (size_t i = 0; i < facClusters_.size(); i++) {
|
||||||
|
if (Util::contains (facClusters_[i]->members(), fn)) {
|
||||||
|
return facClusters_[i]->representative();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
FactorGraph*
|
FactorGraph*
|
||||||
CbpSolver::getCompressedFactorGraph (void)
|
CountingBp::getCompressedFactorGraph (void)
|
||||||
{
|
{
|
||||||
FactorGraph* fg = new FactorGraph();
|
FactorGraph* fg = new FactorGraph();
|
||||||
for (size_t i = 0; i < varClusters_.size(); i++) {
|
for (size_t i = 0; i < varClusters_.size(); i++) {
|
||||||
@ -322,7 +344,7 @@ CbpSolver::getCompressedFactorGraph (void)
|
|||||||
|
|
||||||
|
|
||||||
vector<vector<unsigned>>
|
vector<vector<unsigned>>
|
||||||
CbpSolver::getWeights (void) const
|
CountingBp::getWeights (void) const
|
||||||
{
|
{
|
||||||
vector<vector<unsigned>> weights;
|
vector<vector<unsigned>> weights;
|
||||||
weights.reserve (facClusters_.size());
|
weights.reserve (facClusters_.size());
|
||||||
@ -341,7 +363,7 @@ CbpSolver::getWeights (void) const
|
|||||||
|
|
||||||
|
|
||||||
unsigned
|
unsigned
|
||||||
CbpSolver::getWeight (
|
CountingBp::getWeight (
|
||||||
const FacCluster* fc,
|
const FacCluster* fc,
|
||||||
const VarCluster* vc,
|
const VarCluster* vc,
|
||||||
size_t index) const
|
size_t index) const
|
||||||
@ -364,7 +386,7 @@ CbpSolver::getWeight (
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
CbpSolver::printGroups (
|
CountingBp::printGroups (
|
||||||
const VarSignMap& varGroups,
|
const VarSignMap& varGroups,
|
||||||
const FacSignMap& facGroups) const
|
const FacSignMap& facGroups) const
|
||||||
{
|
{
|
@ -1,5 +1,5 @@
|
|||||||
#ifndef HORUS_CBPSOLVER_H
|
#ifndef HORUS_COUNTINGBP_H
|
||||||
#define HORUS_CBPSOLVER_H
|
#define HORUS_COUNTINGBP_H
|
||||||
|
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
|
|
||||||
@ -12,7 +12,7 @@ class VarCluster;
|
|||||||
class FacCluster;
|
class FacCluster;
|
||||||
class VarSignHash;
|
class VarSignHash;
|
||||||
class FacSignHash;
|
class FacSignHash;
|
||||||
class WeightedBpSolver;
|
class WeightedBp;
|
||||||
|
|
||||||
typedef long Color;
|
typedef long Color;
|
||||||
typedef vector<Color> Colors;
|
typedef vector<Color> Colors;
|
||||||
@ -100,12 +100,12 @@ class FacCluster
|
|||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
class CbpSolver : public Solver
|
class CountingBp : public Solver
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
CbpSolver (const FactorGraph& fg);
|
CountingBp (const FactorGraph& fg);
|
||||||
|
|
||||||
~CbpSolver (void);
|
~CountingBp (void);
|
||||||
|
|
||||||
void printSolverFlags (void) const;
|
void printSolverFlags (void) const;
|
||||||
|
|
||||||
@ -154,12 +154,9 @@ class CbpSolver : public Solver
|
|||||||
|
|
||||||
void printGroups (const VarSignMap&, const FacSignMap&) const;
|
void printGroups (const VarSignMap&, const FacSignMap&) const;
|
||||||
|
|
||||||
VarId getRepresentative (VarId vid)
|
VarId getRepresentative (VarId vid);
|
||||||
{
|
|
||||||
assert (Util::contains (vid2VarCluster_, vid));
|
FacNode* getRepresentative (FacNode*);
|
||||||
VarCluster* vc = vid2VarCluster_.find (vid)->second;
|
|
||||||
return vc->representative()->varId();
|
|
||||||
}
|
|
||||||
|
|
||||||
FactorGraph* getCompressedFactorGraph (void);
|
FactorGraph* getCompressedFactorGraph (void);
|
||||||
|
|
||||||
@ -176,8 +173,8 @@ class CbpSolver : public Solver
|
|||||||
FacClusters facClusters_;
|
FacClusters facClusters_;
|
||||||
VarId2VarCluster vid2VarCluster_;
|
VarId2VarCluster vid2VarCluster_;
|
||||||
const FactorGraph* compressedFg_;
|
const FactorGraph* compressedFg_;
|
||||||
WeightedBpSolver* solver_;
|
WeightedBp* solver_;
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif // HORUS_CBPSOLVER_H
|
#endif // HORUS_COUNTINGBP_H
|
||||||
|
|
@ -8,7 +8,6 @@
|
|||||||
|
|
||||||
#include "FactorGraph.h"
|
#include "FactorGraph.h"
|
||||||
#include "Factor.h"
|
#include "Factor.h"
|
||||||
#include "BayesNet.h"
|
|
||||||
#include "BayesBall.h"
|
#include "BayesBall.h"
|
||||||
#include "Util.h"
|
#include "Util.h"
|
||||||
|
|
||||||
@ -236,13 +235,13 @@ FactorGraph::isTree (void) const
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
DAGraph&
|
BayesBallGraph&
|
||||||
FactorGraph::getStructure (void)
|
FactorGraph::getStructure (void)
|
||||||
{
|
{
|
||||||
assert (bayesFactors_);
|
assert (bayesFactors_);
|
||||||
if (structure_.empty()) {
|
if (structure_.empty()) {
|
||||||
for (size_t i = 0; i < varNodes_.size(); i++) {
|
for (size_t i = 0; i < varNodes_.size(); i++) {
|
||||||
structure_.addNode (new DAGraphNode (varNodes_[i]));
|
structure_.addNode (new BBNode (varNodes_[i]));
|
||||||
}
|
}
|
||||||
for (size_t i = 0; i < facNodes_.size(); i++) {
|
for (size_t i = 0; i < facNodes_.size(); i++) {
|
||||||
const VarIds& vids = facNodes_[i]->factor().arguments();
|
const VarIds& vids = facNodes_[i]->factor().arguments();
|
||||||
|
@ -4,7 +4,7 @@
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "Factor.h"
|
#include "Factor.h"
|
||||||
#include "BayesNet.h"
|
#include "BayesBallGraph.h"
|
||||||
#include "Horus.h"
|
#include "Horus.h"
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
@ -103,7 +103,7 @@ class FactorGraph
|
|||||||
|
|
||||||
bool isTree (void) const;
|
bool isTree (void) const;
|
||||||
|
|
||||||
DAGraph& getStructure (void);
|
BayesBallGraph& getStructure (void);
|
||||||
|
|
||||||
void print (void) const;
|
void print (void) const;
|
||||||
|
|
||||||
@ -129,7 +129,7 @@ class FactorGraph
|
|||||||
VarNodes varNodes_;
|
VarNodes varNodes_;
|
||||||
FacNodes facNodes_;
|
FacNodes facNodes_;
|
||||||
|
|
||||||
DAGraph structure_;
|
BayesBallGraph structure_;
|
||||||
bool bayesFactors_;
|
bool bayesFactors_;
|
||||||
|
|
||||||
typedef unordered_map<unsigned, VarNode*> VarMap;
|
typedef unordered_map<unsigned, VarNode*> VarMap;
|
||||||
|
@ -28,14 +28,14 @@ typedef vector<unsigned> Ranges;
|
|||||||
typedef unsigned long long ullong;
|
typedef unsigned long long ullong;
|
||||||
|
|
||||||
|
|
||||||
enum LiftedSolvers
|
enum LiftedSolver
|
||||||
{
|
{
|
||||||
FOVE, // first order variable elimination
|
FOVE, // first order variable elimination
|
||||||
LBP, // lifted belief propagation
|
LBP, // lifted belief propagation
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
enum GroundSolvers
|
enum GroundSolver
|
||||||
{
|
{
|
||||||
VE, // variable elimination
|
VE, // variable elimination
|
||||||
BP, // belief propagation
|
BP, // belief propagation
|
||||||
@ -50,8 +50,8 @@ extern bool logDomain;
|
|||||||
// level of debug information
|
// level of debug information
|
||||||
extern unsigned verbosity;
|
extern unsigned verbosity;
|
||||||
|
|
||||||
extern LiftedSolvers liftedSolver;
|
extern LiftedSolver liftedSolver;
|
||||||
extern GroundSolvers groundSolver;
|
extern GroundSolver groundSolver;
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -4,9 +4,9 @@
|
|||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
|
||||||
#include "FactorGraph.h"
|
#include "FactorGraph.h"
|
||||||
#include "VarElimSolver.h"
|
#include "VarElim.h"
|
||||||
#include "BpSolver.h"
|
#include "BeliefProp.h"
|
||||||
#include "CbpSolver.h"
|
#include "CountingBp.h"
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
@ -162,14 +162,14 @@ runSolver (const FactorGraph& fg, const VarIds& queryIds)
|
|||||||
{
|
{
|
||||||
Solver* solver = 0;
|
Solver* solver = 0;
|
||||||
switch (Globals::groundSolver) {
|
switch (Globals::groundSolver) {
|
||||||
case GroundSolvers::VE:
|
case GroundSolver::VE:
|
||||||
solver = new VarElimSolver (fg);
|
solver = new VarElim (fg);
|
||||||
break;
|
break;
|
||||||
case GroundSolvers::BP:
|
case GroundSolver::BP:
|
||||||
solver = new BpSolver (fg);
|
solver = new BeliefProp (fg);
|
||||||
break;
|
break;
|
||||||
case GroundSolvers::CBP:
|
case GroundSolver::CBP:
|
||||||
solver = new CbpSolver (fg);
|
solver = new CountingBp (fg);
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
assert (false);
|
assert (false);
|
||||||
|
@ -9,21 +9,19 @@
|
|||||||
|
|
||||||
#include "ParfactorList.h"
|
#include "ParfactorList.h"
|
||||||
#include "FactorGraph.h"
|
#include "FactorGraph.h"
|
||||||
#include "FoveSolver.h"
|
#include "LiftedVe.h"
|
||||||
#include "VarElimSolver.h"
|
#include "VarElim.h"
|
||||||
#include "LiftedBpSolver.h"
|
#include "LiftedBp.h"
|
||||||
#include "BpSolver.h"
|
#include "CountingBp.h"
|
||||||
#include "CbpSolver.h"
|
#include "BeliefProp.h"
|
||||||
#include "ElimGraph.h"
|
#include "ElimGraph.h"
|
||||||
#include "BayesBall.h"
|
#include "BayesBall.h"
|
||||||
|
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
|
|
||||||
typedef std::pair<ParfactorList*, ObservedFormulas*> LiftedNetwork;
|
typedef std::pair<ParfactorList*, ObservedFormulas*> LiftedNetwork;
|
||||||
|
|
||||||
|
|
||||||
Params readParameters (YAP_Term);
|
Params readParameters (YAP_Term);
|
||||||
|
|
||||||
vector<unsigned> readUnsignedList (YAP_Term);
|
vector<unsigned> readUnsignedList (YAP_Term);
|
||||||
@ -32,14 +30,6 @@ void readLiftedEvidence (YAP_Term, ObservedFormulas&);
|
|||||||
|
|
||||||
Parfactor* readParfactor (YAP_Term);
|
Parfactor* readParfactor (YAP_Term);
|
||||||
|
|
||||||
void runVeSolver (FactorGraph* fg, const vector<VarIds>& tasks,
|
|
||||||
vector<Params>& results);
|
|
||||||
|
|
||||||
void runBpSolver (FactorGraph* fg, const vector<VarIds>& tasks,
|
|
||||||
vector<Params>& results);
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
vector<unsigned>
|
vector<unsigned>
|
||||||
readUnsignedList (YAP_Term list)
|
readUnsignedList (YAP_Term list)
|
||||||
@ -54,7 +44,8 @@ readUnsignedList (YAP_Term list)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
int createLiftedNetwork (void)
|
int
|
||||||
|
createLiftedNetwork (void)
|
||||||
{
|
{
|
||||||
Parfactors parfactors;
|
Parfactors parfactors;
|
||||||
YAP_Term parfactorList = YAP_ARG1;
|
YAP_Term parfactorList = YAP_ARG1;
|
||||||
@ -91,7 +82,8 @@ int createLiftedNetwork (void)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
Parfactor* readParfactor (YAP_Term pfTerm)
|
Parfactor*
|
||||||
|
readParfactor (YAP_Term pfTerm)
|
||||||
{
|
{
|
||||||
// read dist id
|
// read dist id
|
||||||
unsigned distId = YAP_IntOfTerm (YAP_ArgOfTerm (1, pfTerm));
|
unsigned distId = YAP_IntOfTerm (YAP_ArgOfTerm (1, pfTerm));
|
||||||
@ -171,7 +163,8 @@ Parfactor* readParfactor (YAP_Term pfTerm)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
void readLiftedEvidence (
|
void
|
||||||
|
readLiftedEvidence (
|
||||||
YAP_Term observedList,
|
YAP_Term observedList,
|
||||||
ObservedFormulas& obsFormulas)
|
ObservedFormulas& obsFormulas)
|
||||||
{
|
{
|
||||||
@ -237,7 +230,6 @@ createGroundNetwork (void)
|
|||||||
fg->addFactor (Factor (varIds, ranges, params, distId));
|
fg->addFactor (Factor (varIds, ranges, params, distId));
|
||||||
factorList = YAP_TailOfTerm (factorList);
|
factorList = YAP_TailOfTerm (factorList);
|
||||||
}
|
}
|
||||||
|
|
||||||
unsigned nrObservedVars = 0;
|
unsigned nrObservedVars = 0;
|
||||||
YAP_Term evidenceList = YAP_ARG3;
|
YAP_Term evidenceList = YAP_ARG3;
|
||||||
while (evidenceList != YAP_TermNil()) {
|
while (evidenceList != YAP_TermNil()) {
|
||||||
@ -285,7 +277,7 @@ runLiftedSolver (void)
|
|||||||
YAP_Term taskList = YAP_ARG2;
|
YAP_Term taskList = YAP_ARG2;
|
||||||
vector<Params> results;
|
vector<Params> results;
|
||||||
ParfactorList pfListCopy (*network->first);
|
ParfactorList pfListCopy (*network->first);
|
||||||
FoveSolver::absorveEvidence (pfListCopy, *network->second);
|
LiftedVe::absorveEvidence (pfListCopy, *network->second);
|
||||||
while (taskList != YAP_TermNil()) {
|
while (taskList != YAP_TermNil()) {
|
||||||
Grounds queryVars;
|
Grounds queryVars;
|
||||||
YAP_Term jointList = YAP_HeadOfTerm (taskList);
|
YAP_Term jointList = YAP_HeadOfTerm (taskList);
|
||||||
@ -311,15 +303,15 @@ runLiftedSolver (void)
|
|||||||
}
|
}
|
||||||
jointList = YAP_TailOfTerm (jointList);
|
jointList = YAP_TailOfTerm (jointList);
|
||||||
}
|
}
|
||||||
if (Globals::liftedSolver == LiftedSolvers::FOVE) {
|
if (Globals::liftedSolver == LiftedSolver::FOVE) {
|
||||||
FoveSolver solver (pfListCopy);
|
LiftedVe solver (pfListCopy);
|
||||||
if (Globals::verbosity > 0 && taskList == YAP_ARG2) {
|
if (Globals::verbosity > 0 && taskList == YAP_ARG2) {
|
||||||
solver.printSolverFlags();
|
solver.printSolverFlags();
|
||||||
cout << endl;
|
cout << endl;
|
||||||
}
|
}
|
||||||
results.push_back (solver.solveQuery (queryVars));
|
results.push_back (solver.solveQuery (queryVars));
|
||||||
} else if (Globals::liftedSolver == LiftedSolvers::LBP) {
|
} else if (Globals::liftedSolver == LiftedSolver::LBP) {
|
||||||
LiftedBpSolver solver (pfListCopy);
|
LiftedBp solver (pfListCopy);
|
||||||
if (Globals::verbosity > 0 && taskList == YAP_ARG2) {
|
if (Globals::verbosity > 0 && taskList == YAP_ARG2) {
|
||||||
solver.printSolverFlags();
|
solver.printSolverFlags();
|
||||||
cout << endl;
|
cout << endl;
|
||||||
@ -361,11 +353,42 @@ runGroundSolver (void)
|
|||||||
taskList = YAP_TailOfTerm (taskList);
|
taskList = YAP_TailOfTerm (taskList);
|
||||||
}
|
}
|
||||||
|
|
||||||
vector<Params> results;
|
std::set<VarId> vids;
|
||||||
if (Globals::groundSolver == GroundSolvers::VE) {
|
for (size_t i = 0; i < tasks.size(); i++) {
|
||||||
runVeSolver (fg, tasks, results);
|
Util::addToSet (vids, tasks[i]);
|
||||||
|
}
|
||||||
|
Solver* solver = 0;
|
||||||
|
FactorGraph* mfg = fg;
|
||||||
|
if (fg->bayesianFactors()) {
|
||||||
|
mfg = BayesBall::getMinimalFactorGraph (
|
||||||
|
*fg, VarIds (vids.begin(), vids.end()));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (Globals::groundSolver == GroundSolver::VE) {
|
||||||
|
solver = new VarElim (*mfg);
|
||||||
|
} else if (Globals::groundSolver == GroundSolver::BP) {
|
||||||
|
solver = new BeliefProp (*mfg);
|
||||||
|
} else if (Globals::groundSolver == GroundSolver::CBP) {
|
||||||
|
CountingBp::checkForIdenticalFactors = false;
|
||||||
|
solver = new CountingBp (*mfg);
|
||||||
} else {
|
} else {
|
||||||
runBpSolver (fg, tasks, results);
|
assert (false);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (Globals::verbosity > 0) {
|
||||||
|
solver->printSolverFlags();
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
vector<Params> results;
|
||||||
|
results.reserve (tasks.size());
|
||||||
|
for (size_t i = 0; i < tasks.size(); i++) {
|
||||||
|
results.push_back (solver->solveQuery (tasks[i]));
|
||||||
|
}
|
||||||
|
|
||||||
|
delete solver;
|
||||||
|
if (fg->bayesianFactors()) {
|
||||||
|
delete mfg;
|
||||||
}
|
}
|
||||||
|
|
||||||
YAP_Term list = YAP_TermNil();
|
YAP_Term list = YAP_TermNil();
|
||||||
@ -386,72 +409,6 @@ runGroundSolver (void)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
void runVeSolver (
|
|
||||||
FactorGraph* fg,
|
|
||||||
const vector<VarIds>& tasks,
|
|
||||||
vector<Params>& results)
|
|
||||||
{
|
|
||||||
results.reserve (tasks.size());
|
|
||||||
for (size_t i = 0; i < tasks.size(); i++) {
|
|
||||||
FactorGraph* mfg = fg;
|
|
||||||
if (fg->bayesianFactors()) {
|
|
||||||
// mfg = BayesBall::getMinimalFactorGraph (*fg, tasks[i]);
|
|
||||||
}
|
|
||||||
// VarElimSolver solver (*mfg);
|
|
||||||
VarElimSolver solver (*fg); //FIXME
|
|
||||||
if (Globals::verbosity > 0 && i == 0) {
|
|
||||||
solver.printSolverFlags();
|
|
||||||
cout << endl;
|
|
||||||
}
|
|
||||||
results.push_back (solver.solveQuery (tasks[i]));
|
|
||||||
if (fg->bayesianFactors()) {
|
|
||||||
// delete mfg;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void runBpSolver (
|
|
||||||
FactorGraph* fg,
|
|
||||||
const vector<VarIds>& tasks,
|
|
||||||
vector<Params>& results)
|
|
||||||
{
|
|
||||||
std::set<VarId> vids;
|
|
||||||
for (size_t i = 0; i < tasks.size(); i++) {
|
|
||||||
Util::addToSet (vids, tasks[i]);
|
|
||||||
}
|
|
||||||
Solver* solver = 0;
|
|
||||||
FactorGraph* mfg = fg;
|
|
||||||
if (fg->bayesianFactors()) {
|
|
||||||
//mfg = BayesBall::getMinimalFactorGraph (
|
|
||||||
// *fg, VarIds (vids.begin(),vids.end()));
|
|
||||||
}
|
|
||||||
if (Globals::groundSolver == GroundSolvers::BP) {
|
|
||||||
solver = new BpSolver (*fg); // FIXME
|
|
||||||
} else if (Globals::groundSolver == GroundSolvers::CBP) {
|
|
||||||
CbpSolver::checkForIdenticalFactors = false;
|
|
||||||
solver = new CbpSolver (*fg); // FIXME
|
|
||||||
} else {
|
|
||||||
cerr << "error: unknow solver" << endl;
|
|
||||||
abort();
|
|
||||||
}
|
|
||||||
if (Globals::verbosity > 0) {
|
|
||||||
solver->printSolverFlags();
|
|
||||||
cout << endl;
|
|
||||||
}
|
|
||||||
results.reserve (tasks.size());
|
|
||||||
for (size_t i = 0; i < tasks.size(); i++) {
|
|
||||||
results.push_back (solver->solveQuery (tasks[i]));
|
|
||||||
}
|
|
||||||
if (fg->bayesianFactors()) {
|
|
||||||
//delete mfg;
|
|
||||||
}
|
|
||||||
delete solver;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
int
|
int
|
||||||
setParfactorsParams (void)
|
setParfactorsParams (void)
|
||||||
{
|
{
|
||||||
@ -567,7 +524,7 @@ freeGroundNetwork (void)
|
|||||||
|
|
||||||
|
|
||||||
int
|
int
|
||||||
freeParfactors (void)
|
freeLiftedNetwork (void)
|
||||||
{
|
{
|
||||||
LiftedNetwork* network = (LiftedNetwork*) YAP_IntOfTerm (YAP_ARG1);
|
LiftedNetwork* network = (LiftedNetwork*) YAP_IntOfTerm (YAP_ARG1);
|
||||||
delete network->first;
|
delete network->first;
|
||||||
@ -589,7 +546,7 @@ init_predicates (void)
|
|||||||
YAP_UserCPredicate ("cpp_cpp_set_factors_params", setFactorsParams, 2);
|
YAP_UserCPredicate ("cpp_cpp_set_factors_params", setFactorsParams, 2);
|
||||||
YAP_UserCPredicate ("cpp_set_vars_information", setVarsInformation, 2);
|
YAP_UserCPredicate ("cpp_set_vars_information", setVarsInformation, 2);
|
||||||
YAP_UserCPredicate ("cpp_set_horus_flag", setHorusFlag, 2);
|
YAP_UserCPredicate ("cpp_set_horus_flag", setHorusFlag, 2);
|
||||||
YAP_UserCPredicate ("cpp_free_parfactors", freeParfactors, 1);
|
YAP_UserCPredicate ("cpp_free_lifted_network", freeLiftedNetwork, 1);
|
||||||
YAP_UserCPredicate ("cpp_free_ground_network", freeGroundNetwork, 1);
|
YAP_UserCPredicate ("cpp_free_ground_network", freeGroundNetwork, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
232
packages/CLPBN/horus/LiftedBp.cpp
Normal file
232
packages/CLPBN/horus/LiftedBp.cpp
Normal file
@ -0,0 +1,232 @@
|
|||||||
|
#include "LiftedBp.h"
|
||||||
|
#include "WeightedBp.h"
|
||||||
|
#include "FactorGraph.h"
|
||||||
|
#include "LiftedVe.h"
|
||||||
|
|
||||||
|
|
||||||
|
LiftedBp::LiftedBp (const ParfactorList& pfList)
|
||||||
|
: pfList_(pfList)
|
||||||
|
{
|
||||||
|
refineParfactors();
|
||||||
|
solver_ = new WeightedBp (*getFactorGraph(), getWeights());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
LiftedBp::~LiftedBp (void)
|
||||||
|
{
|
||||||
|
delete solver_;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Params
|
||||||
|
LiftedBp::solveQuery (const Grounds& query)
|
||||||
|
{
|
||||||
|
assert (query.empty() == false);
|
||||||
|
Params res;
|
||||||
|
vector<PrvGroup> groups = getQueryGroups (query);
|
||||||
|
if (query.size() == 1) {
|
||||||
|
res = solver_->getPosterioriOf (groups[0]);
|
||||||
|
} else {
|
||||||
|
ParfactorList::iterator it = pfList_.begin();
|
||||||
|
size_t idx = pfList_.size();
|
||||||
|
size_t count = 0;
|
||||||
|
while (it != pfList_.end()) {
|
||||||
|
if ((*it)->containsGrounds (query)) {
|
||||||
|
idx = count;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
++ it;
|
||||||
|
++ count;
|
||||||
|
}
|
||||||
|
if (idx == pfList_.size()) {
|
||||||
|
res = getJointByConditioning (pfList_, query);
|
||||||
|
} else {
|
||||||
|
VarIds queryVids;
|
||||||
|
for (unsigned i = 0; i < groups.size(); i++) {
|
||||||
|
queryVids.push_back (groups[i]);
|
||||||
|
}
|
||||||
|
res = solver_->getFactorJoint (idx, queryVids);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
LiftedBp::printSolverFlags (void) const
|
||||||
|
{
|
||||||
|
stringstream ss;
|
||||||
|
ss << "lifted bp [" ;
|
||||||
|
ss << "schedule=" ;
|
||||||
|
typedef BpOptions::Schedule Sch;
|
||||||
|
switch (BpOptions::schedule) {
|
||||||
|
case Sch::SEQ_FIXED: ss << "seq_fixed"; break;
|
||||||
|
case Sch::SEQ_RANDOM: ss << "seq_random"; break;
|
||||||
|
case Sch::PARALLEL: ss << "parallel"; break;
|
||||||
|
case Sch::MAX_RESIDUAL: ss << "max_residual"; break;
|
||||||
|
}
|
||||||
|
ss << ",max_iter=" << BpOptions::maxIter;
|
||||||
|
ss << ",accuracy=" << BpOptions::accuracy;
|
||||||
|
ss << ",log_domain=" << Util::toString (Globals::logDomain);
|
||||||
|
ss << "]" ;
|
||||||
|
cout << ss.str() << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
LiftedBp::refineParfactors (void)
|
||||||
|
{
|
||||||
|
while (iterate() == false);
|
||||||
|
|
||||||
|
if (Globals::verbosity > 2) {
|
||||||
|
Util::printHeader ("AFTER REFINEMENT");
|
||||||
|
pfList_.print();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
LiftedBp::iterate (void)
|
||||||
|
{
|
||||||
|
ParfactorList::iterator it = pfList_.begin();
|
||||||
|
while (it != pfList_.end()) {
|
||||||
|
const ProbFormulas& args = (*it)->arguments();
|
||||||
|
for (size_t i = 0; i < args.size(); i++) {
|
||||||
|
LogVarSet lvs = (*it)->logVarSet() - args[i].logVars();
|
||||||
|
if ((*it)->constr()->isCountNormalized (lvs) == false) {
|
||||||
|
Parfactors pfs = LiftedVe::countNormalize (*it, lvs);
|
||||||
|
it = pfList_.removeAndDelete (it);
|
||||||
|
pfList_.add (pfs);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
++ it;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
vector<PrvGroup>
|
||||||
|
LiftedBp::getQueryGroups (const Grounds& query)
|
||||||
|
{
|
||||||
|
vector<PrvGroup> queryGroups;
|
||||||
|
for (unsigned i = 0; i < query.size(); i++) {
|
||||||
|
ParfactorList::const_iterator it = pfList_.begin();
|
||||||
|
for (; it != pfList_.end(); ++it) {
|
||||||
|
if ((*it)->containsGround (query[i])) {
|
||||||
|
queryGroups.push_back ((*it)->findGroup (query[i]));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert (queryGroups.size() == query.size());
|
||||||
|
return queryGroups;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
FactorGraph*
|
||||||
|
LiftedBp::getFactorGraph (void)
|
||||||
|
{
|
||||||
|
FactorGraph* fg = new FactorGraph();
|
||||||
|
ParfactorList::const_iterator it = pfList_.begin();
|
||||||
|
for (; it != pfList_.end(); ++it) {
|
||||||
|
vector<PrvGroup> groups = (*it)->getAllGroups();
|
||||||
|
VarIds varIds;
|
||||||
|
for (size_t i = 0; i < groups.size(); i++) {
|
||||||
|
varIds.push_back (groups[i]);
|
||||||
|
}
|
||||||
|
fg->addFactor (Factor (varIds, (*it)->ranges(), (*it)->params()));
|
||||||
|
}
|
||||||
|
return fg;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
vector<vector<unsigned>>
|
||||||
|
LiftedBp::getWeights (void) const
|
||||||
|
{
|
||||||
|
vector<vector<unsigned>> weights;
|
||||||
|
weights.reserve (pfList_.size());
|
||||||
|
ParfactorList::const_iterator it = pfList_.begin();
|
||||||
|
for (; it != pfList_.end(); ++it) {
|
||||||
|
const ProbFormulas& args = (*it)->arguments();
|
||||||
|
weights.push_back ({ });
|
||||||
|
weights.back().reserve (args.size());
|
||||||
|
for (size_t i = 0; i < args.size(); i++) {
|
||||||
|
LogVarSet lvs = (*it)->logVarSet() - args[i].logVars();
|
||||||
|
weights.back().push_back ((*it)->constr()->getConditionalCount (lvs));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return weights;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
unsigned
|
||||||
|
LiftedBp::rangeOfGround (const Ground& gr)
|
||||||
|
{
|
||||||
|
ParfactorList::iterator it = pfList_.begin();
|
||||||
|
while (it != pfList_.end()) {
|
||||||
|
if ((*it)->containsGround (gr)) {
|
||||||
|
PrvGroup prvGroup = (*it)->findGroup (gr);
|
||||||
|
return (*it)->range ((*it)->indexOfGroup (prvGroup));
|
||||||
|
}
|
||||||
|
++ it;
|
||||||
|
}
|
||||||
|
return std::numeric_limits<unsigned>::max();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Params
|
||||||
|
LiftedBp::getJointByConditioning (
|
||||||
|
const ParfactorList& pfList,
|
||||||
|
const Grounds& grounds)
|
||||||
|
{
|
||||||
|
LiftedBp solver (pfList);
|
||||||
|
Params prevBeliefs = solver.solveQuery ({grounds[0]});
|
||||||
|
Grounds obsGrounds = {grounds[0]};
|
||||||
|
for (size_t i = 1; i < grounds.size(); i++) {
|
||||||
|
Params newBeliefs;
|
||||||
|
vector<ObservedFormula> obsFs;
|
||||||
|
Ranges obsRanges;
|
||||||
|
for (size_t j = 0; j < obsGrounds.size(); j++) {
|
||||||
|
obsFs.push_back (ObservedFormula (
|
||||||
|
obsGrounds[j].functor(), 0, obsGrounds[j].args()));
|
||||||
|
obsRanges.push_back (rangeOfGround (obsGrounds[j]));
|
||||||
|
}
|
||||||
|
Indexer indexer (obsRanges, false);
|
||||||
|
while (indexer.valid()) {
|
||||||
|
for (size_t j = 0; j < obsFs.size(); j++) {
|
||||||
|
obsFs[j].setEvidence (indexer[j]);
|
||||||
|
}
|
||||||
|
ParfactorList tempPfList (pfList);
|
||||||
|
LiftedVe::absorveEvidence (tempPfList, obsFs);
|
||||||
|
LiftedBp solver (tempPfList);
|
||||||
|
Params beliefs = solver.solveQuery ({grounds[i]});
|
||||||
|
for (size_t k = 0; k < beliefs.size(); k++) {
|
||||||
|
newBeliefs.push_back (beliefs[k]);
|
||||||
|
}
|
||||||
|
++ indexer;
|
||||||
|
}
|
||||||
|
int count = -1;
|
||||||
|
unsigned range = rangeOfGround (grounds[i]);
|
||||||
|
for (size_t j = 0; j < newBeliefs.size(); j++) {
|
||||||
|
if (j % range == 0) {
|
||||||
|
count ++;
|
||||||
|
}
|
||||||
|
newBeliefs[j] *= prevBeliefs[count];
|
||||||
|
}
|
||||||
|
prevBeliefs = newBeliefs;
|
||||||
|
obsGrounds.push_back (grounds[i]);
|
||||||
|
}
|
||||||
|
return prevBeliefs;
|
||||||
|
}
|
||||||
|
|
@ -1,15 +1,17 @@
|
|||||||
#ifndef HORUS_LIFTEDBPSOLVER_H
|
#ifndef HORUS_LIFTEDBP_H
|
||||||
#define HORUS_LIFTEDBPSOLVER_H
|
#define HORUS_LIFTEDBP_H
|
||||||
|
|
||||||
#include "ParfactorList.h"
|
#include "ParfactorList.h"
|
||||||
|
|
||||||
class FactorGraph;
|
class FactorGraph;
|
||||||
class WeightedBpSolver;
|
class WeightedBp;
|
||||||
|
|
||||||
class LiftedBpSolver
|
class LiftedBp
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
LiftedBpSolver (const ParfactorList& pfList);
|
LiftedBp (const ParfactorList& pfList);
|
||||||
|
|
||||||
|
~LiftedBp (void);
|
||||||
|
|
||||||
Params solveQuery (const Grounds&);
|
Params solveQuery (const Grounds&);
|
||||||
|
|
||||||
@ -26,9 +28,13 @@ class LiftedBpSolver
|
|||||||
|
|
||||||
vector<vector<unsigned>> getWeights (void) const;
|
vector<vector<unsigned>> getWeights (void) const;
|
||||||
|
|
||||||
|
unsigned rangeOfGround (const Ground&);
|
||||||
|
|
||||||
|
Params getJointByConditioning (const ParfactorList&, const Grounds&);
|
||||||
|
|
||||||
ParfactorList pfList_;
|
ParfactorList pfList_;
|
||||||
WeightedBpSolver* solver_;
|
WeightedBp* solver_;
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif // HORUS_LIFTEDBPSOLVER_H
|
#endif // HORUS_LIFTEDBP_H
|
@ -1,148 +0,0 @@
|
|||||||
#include "LiftedBpSolver.h"
|
|
||||||
#include "WeightedBpSolver.h"
|
|
||||||
#include "FactorGraph.h"
|
|
||||||
#include "FoveSolver.h"
|
|
||||||
|
|
||||||
|
|
||||||
LiftedBpSolver::LiftedBpSolver (const ParfactorList& pfList)
|
|
||||||
: pfList_(pfList)
|
|
||||||
{
|
|
||||||
refineParfactors();
|
|
||||||
solver_ = new WeightedBpSolver (*getFactorGraph(), getWeights());
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Params
|
|
||||||
LiftedBpSolver::solveQuery (const Grounds& query)
|
|
||||||
{
|
|
||||||
assert (query.empty() == false);
|
|
||||||
Params res;
|
|
||||||
vector<PrvGroup> groups = getQueryGroups (query);
|
|
||||||
if (query.size() == 1) {
|
|
||||||
res = solver_->getPosterioriOf (groups[0]);
|
|
||||||
} else {
|
|
||||||
VarIds queryVids;
|
|
||||||
for (unsigned i = 0; i < groups.size(); i++) {
|
|
||||||
queryVids.push_back (groups[i]);
|
|
||||||
}
|
|
||||||
res = solver_->getJointDistributionOf (queryVids);
|
|
||||||
}
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
LiftedBpSolver::printSolverFlags (void) const
|
|
||||||
{
|
|
||||||
stringstream ss;
|
|
||||||
ss << "lifted bp [" ;
|
|
||||||
ss << "schedule=" ;
|
|
||||||
typedef BpOptions::Schedule Sch;
|
|
||||||
switch (BpOptions::schedule) {
|
|
||||||
case Sch::SEQ_FIXED: ss << "seq_fixed"; break;
|
|
||||||
case Sch::SEQ_RANDOM: ss << "seq_random"; break;
|
|
||||||
case Sch::PARALLEL: ss << "parallel"; break;
|
|
||||||
case Sch::MAX_RESIDUAL: ss << "max_residual"; break;
|
|
||||||
}
|
|
||||||
ss << ",max_iter=" << BpOptions::maxIter;
|
|
||||||
ss << ",accuracy=" << BpOptions::accuracy;
|
|
||||||
ss << ",log_domain=" << Util::toString (Globals::logDomain);
|
|
||||||
ss << "]" ;
|
|
||||||
cout << ss.str() << endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
LiftedBpSolver::refineParfactors (void)
|
|
||||||
{
|
|
||||||
while (iterate() == false);
|
|
||||||
|
|
||||||
if (Globals::verbosity > 2) {
|
|
||||||
Util::printHeader ("AFTER REFINEMENT");
|
|
||||||
pfList_.print();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
bool
|
|
||||||
LiftedBpSolver::iterate (void)
|
|
||||||
{
|
|
||||||
ParfactorList::iterator it = pfList_.begin();
|
|
||||||
while (it != pfList_.end()) {
|
|
||||||
const ProbFormulas& args = (*it)->arguments();
|
|
||||||
for (size_t i = 0; i < args.size(); i++) {
|
|
||||||
LogVarSet lvs = (*it)->logVarSet() - args[i].logVars();
|
|
||||||
if ((*it)->constr()->isCountNormalized (lvs) == false) {
|
|
||||||
Parfactors pfs = FoveSolver::countNormalize (*it, lvs);
|
|
||||||
it = pfList_.removeAndDelete (it);
|
|
||||||
pfList_.add (pfs);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
++ it;
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
vector<PrvGroup>
|
|
||||||
LiftedBpSolver::getQueryGroups (const Grounds& query)
|
|
||||||
{
|
|
||||||
vector<PrvGroup> queryGroups;
|
|
||||||
for (unsigned i = 0; i < query.size(); i++) {
|
|
||||||
ParfactorList::const_iterator it = pfList_.begin();
|
|
||||||
for (; it != pfList_.end(); ++it) {
|
|
||||||
if ((*it)->containsGround (query[i])) {
|
|
||||||
queryGroups.push_back ((*it)->findGroup (query[i]));
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
assert (queryGroups.size() == query.size());
|
|
||||||
return queryGroups;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
FactorGraph*
|
|
||||||
LiftedBpSolver::getFactorGraph (void)
|
|
||||||
{
|
|
||||||
FactorGraph* fg = new FactorGraph();
|
|
||||||
ParfactorList::const_iterator it = pfList_.begin();
|
|
||||||
for (; it != pfList_.end(); ++it) {
|
|
||||||
vector<PrvGroup> groups = (*it)->getAllGroups();
|
|
||||||
VarIds varIds;
|
|
||||||
for (size_t i = 0; i < groups.size(); i++) {
|
|
||||||
varIds.push_back (groups[i]);
|
|
||||||
}
|
|
||||||
fg->addFactor (Factor (varIds, (*it)->ranges(), (*it)->params()));
|
|
||||||
}
|
|
||||||
return fg;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
vector<vector<unsigned>>
|
|
||||||
LiftedBpSolver::getWeights (void) const
|
|
||||||
{
|
|
||||||
vector<vector<unsigned>> weights;
|
|
||||||
weights.reserve (pfList_.size());
|
|
||||||
ParfactorList::const_iterator it = pfList_.begin();
|
|
||||||
for (; it != pfList_.end(); ++it) {
|
|
||||||
const ProbFormulas& args = (*it)->arguments();
|
|
||||||
weights.push_back ({ });
|
|
||||||
weights.back().reserve (args.size());
|
|
||||||
for (size_t i = 0; i < args.size(); i++) {
|
|
||||||
LogVarSet lvs = (*it)->logVarSet() - args[i].logVars();
|
|
||||||
weights.back().push_back ((*it)->constr()->getConditionalCount (lvs));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return weights;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
@ -1,8 +1,7 @@
|
|||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <set>
|
#include <set>
|
||||||
|
|
||||||
#include "FoveSolver.h"
|
#include "LiftedVe.h"
|
||||||
#include "Histogram.h"
|
#include "Histogram.h"
|
||||||
#include "Util.h"
|
#include "Util.h"
|
||||||
|
|
||||||
@ -222,7 +221,7 @@ SumOutOperator::apply (void)
|
|||||||
product->sumOutIndex (fIdx);
|
product->sumOutIndex (fIdx);
|
||||||
pfList_.addShattered (product);
|
pfList_.addShattered (product);
|
||||||
} else {
|
} else {
|
||||||
Parfactors pfs = FoveSolver::countNormalize (product, excl);
|
Parfactors pfs = LiftedVe::countNormalize (product, excl);
|
||||||
for (size_t i = 0; i < pfs.size(); i++) {
|
for (size_t i = 0; i < pfs.size(); i++) {
|
||||||
pfs[i]->sumOutIndex (fIdx);
|
pfs[i]->sumOutIndex (fIdx);
|
||||||
pfList_.add (pfs[i]);
|
pfList_.add (pfs[i]);
|
||||||
@ -376,7 +375,7 @@ CountingOperator::apply (void)
|
|||||||
} else {
|
} else {
|
||||||
Parfactor* pf = *pfIter_;
|
Parfactor* pf = *pfIter_;
|
||||||
pfList_.remove (pfIter_);
|
pfList_.remove (pfIter_);
|
||||||
Parfactors pfs = FoveSolver::countNormalize (pf, X_);
|
Parfactors pfs = LiftedVe::countNormalize (pf, X_);
|
||||||
for (size_t i = 0; i < pfs.size(); i++) {
|
for (size_t i = 0; i < pfs.size(); i++) {
|
||||||
unsigned condCount = pfs[i]->constr()->getConditionalCount (X_);
|
unsigned condCount = pfs[i]->constr()->getConditionalCount (X_);
|
||||||
bool cartProduct = pfs[i]->constr()->isCartesianProduct (
|
bool cartProduct = pfs[i]->constr()->isCartesianProduct (
|
||||||
@ -420,7 +419,7 @@ CountingOperator::toString (void)
|
|||||||
ss << "count convert " << X_ << " in " ;
|
ss << "count convert " << X_ << " in " ;
|
||||||
ss << (*pfIter_)->getLabel();
|
ss << (*pfIter_)->getLabel();
|
||||||
ss << " [cost=" << std::exp (getLogCost()) << "]" << endl;
|
ss << " [cost=" << std::exp (getLogCost()) << "]" << endl;
|
||||||
Parfactors pfs = FoveSolver::countNormalize (*pfIter_, X_);
|
Parfactors pfs = LiftedVe::countNormalize (*pfIter_, X_);
|
||||||
if ((*pfIter_)->constr()->isCountNormalized (X_) == false) {
|
if ((*pfIter_)->constr()->isCountNormalized (X_) == false) {
|
||||||
for (size_t i = 0; i < pfs.size(); i++) {
|
for (size_t i = 0; i < pfs.size(); i++) {
|
||||||
ss << " º " << pfs[i]->getLabel() << endl;
|
ss << " º " << pfs[i]->getLabel() << endl;
|
||||||
@ -501,7 +500,7 @@ GroundOperator::getLogCost (void)
|
|||||||
++ pflIt;
|
++ pflIt;
|
||||||
}
|
}
|
||||||
// cout << endl;
|
// cout << endl;
|
||||||
return totalCost;
|
return totalCost + 3;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -610,7 +609,7 @@ GroundOperator::getAffectedFormulas (void)
|
|||||||
LogVar X = f.logVars()[front.second];
|
LogVar X = f.logVars()[front.second];
|
||||||
const ProbFormulas& fs = (*pflIt)->arguments();
|
const ProbFormulas& fs = (*pflIt)->arguments();
|
||||||
for (size_t i = 0; i < fs.size(); i++) {
|
for (size_t i = 0; i < fs.size(); i++) {
|
||||||
if ((int)i != idx && fs[i].contains (X)) {
|
if (i != idx && fs[i].contains (X)) {
|
||||||
pair<PrvGroup, unsigned> pair = make_pair (
|
pair<PrvGroup, unsigned> pair = make_pair (
|
||||||
fs[i].group(), fs[i].indexOf (X));
|
fs[i].group(), fs[i].indexOf (X));
|
||||||
if (Util::contains (affectedFormulas, pair) == false) {
|
if (Util::contains (affectedFormulas, pair) == false) {
|
||||||
@ -630,7 +629,7 @@ GroundOperator::getAffectedFormulas (void)
|
|||||||
|
|
||||||
|
|
||||||
Params
|
Params
|
||||||
FoveSolver::solveQuery (const Grounds& query)
|
LiftedVe::solveQuery (const Grounds& query)
|
||||||
{
|
{
|
||||||
assert (query.empty() == false);
|
assert (query.empty() == false);
|
||||||
runSolver (query);
|
runSolver (query);
|
||||||
@ -645,7 +644,7 @@ FoveSolver::solveQuery (const Grounds& query)
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
FoveSolver::printSolverFlags (void) const
|
LiftedVe::printSolverFlags (void) const
|
||||||
{
|
{
|
||||||
stringstream ss;
|
stringstream ss;
|
||||||
ss << "fove [" ;
|
ss << "fove [" ;
|
||||||
@ -657,7 +656,7 @@ FoveSolver::printSolverFlags (void) const
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
FoveSolver::absorveEvidence (
|
LiftedVe::absorveEvidence (
|
||||||
ParfactorList& pfList,
|
ParfactorList& pfList,
|
||||||
ObservedFormulas& obsFormulas)
|
ObservedFormulas& obsFormulas)
|
||||||
{
|
{
|
||||||
@ -696,7 +695,7 @@ FoveSolver::absorveEvidence (
|
|||||||
|
|
||||||
|
|
||||||
Parfactors
|
Parfactors
|
||||||
FoveSolver::countNormalize (
|
LiftedVe::countNormalize (
|
||||||
Parfactor* g,
|
Parfactor* g,
|
||||||
const LogVarSet& set)
|
const LogVarSet& set)
|
||||||
{
|
{
|
||||||
@ -715,7 +714,7 @@ FoveSolver::countNormalize (
|
|||||||
|
|
||||||
|
|
||||||
Parfactor
|
Parfactor
|
||||||
FoveSolver::calcGroundMultiplication (Parfactor pf)
|
LiftedVe::calcGroundMultiplication (Parfactor pf)
|
||||||
{
|
{
|
||||||
LogVarSet lvs = pf.constr()->logVarSet();
|
LogVarSet lvs = pf.constr()->logVarSet();
|
||||||
lvs -= pf.constr()->singletons();
|
lvs -= pf.constr()->singletons();
|
||||||
@ -748,7 +747,7 @@ FoveSolver::calcGroundMultiplication (Parfactor pf)
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
FoveSolver::runSolver (const Grounds& query)
|
LiftedVe::runSolver (const Grounds& query)
|
||||||
{
|
{
|
||||||
largestCost_ = std::log (0);
|
largestCost_ = std::log (0);
|
||||||
shatterAgainstQuery (query);
|
shatterAgainstQuery (query);
|
||||||
@ -794,7 +793,7 @@ FoveSolver::runSolver (const Grounds& query)
|
|||||||
|
|
||||||
|
|
||||||
LiftedOperator*
|
LiftedOperator*
|
||||||
FoveSolver::getBestOperation (const Grounds& query)
|
LiftedVe::getBestOperation (const Grounds& query)
|
||||||
{
|
{
|
||||||
double bestCost = 0.0;
|
double bestCost = 0.0;
|
||||||
LiftedOperator* bestOp = 0;
|
LiftedOperator* bestOp = 0;
|
||||||
@ -821,7 +820,7 @@ FoveSolver::getBestOperation (const Grounds& query)
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
FoveSolver::runWeakBayesBall (const Grounds& query)
|
LiftedVe::runWeakBayesBall (const Grounds& query)
|
||||||
{
|
{
|
||||||
queue<PrvGroup> todo; // groups to process
|
queue<PrvGroup> todo; // groups to process
|
||||||
set<PrvGroup> done; // processed or in queue
|
set<PrvGroup> done; // processed or in queue
|
||||||
@ -880,7 +879,7 @@ FoveSolver::runWeakBayesBall (const Grounds& query)
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
FoveSolver::shatterAgainstQuery (const Grounds& query)
|
LiftedVe::shatterAgainstQuery (const Grounds& query)
|
||||||
{
|
{
|
||||||
for (size_t i = 0; i < query.size(); i++) {
|
for (size_t i = 0; i < query.size(); i++) {
|
||||||
if (query[i].isAtom()) {
|
if (query[i].isAtom()) {
|
||||||
@ -931,7 +930,7 @@ FoveSolver::shatterAgainstQuery (const Grounds& query)
|
|||||||
|
|
||||||
|
|
||||||
Parfactors
|
Parfactors
|
||||||
FoveSolver::absorve (
|
LiftedVe::absorve (
|
||||||
ObservedFormula& obsFormula,
|
ObservedFormula& obsFormula,
|
||||||
Parfactor* g)
|
Parfactor* g)
|
||||||
{
|
{
|
@ -1,5 +1,5 @@
|
|||||||
#ifndef HORUS_FOVESOLVER_H
|
#ifndef HORUS_LIFTEDVE_H
|
||||||
#define HORUS_FOVESOLVER_H
|
#define HORUS_LIFTEDVE_H
|
||||||
|
|
||||||
|
|
||||||
#include "ParfactorList.h"
|
#include "ParfactorList.h"
|
||||||
@ -130,10 +130,10 @@ class GroundOperator : public LiftedOperator
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
class FoveSolver
|
class LiftedVe
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
FoveSolver (const ParfactorList& pfList) : pfList_(pfList) { }
|
LiftedVe (const ParfactorList& pfList) : pfList_(pfList) { }
|
||||||
|
|
||||||
Params solveQuery (const Grounds&);
|
Params solveQuery (const Grounds&);
|
||||||
|
|
||||||
@ -162,5 +162,5 @@ class FoveSolver
|
|||||||
double largestCost_;
|
double largestCost_;
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif // HORUS_FOVESOLVER_H
|
#endif // HORUS_LIFTEDVE_H
|
||||||
|
|
@ -23,10 +23,10 @@ CC=@CC@
|
|||||||
CXX=@CXX@
|
CXX=@CXX@
|
||||||
|
|
||||||
# normal
|
# normal
|
||||||
#CXXFLAGS= -std=c++0x @SHLIB_CXXFLAGS@ $(YAP_EXTRAS) $(DEFS) -D_YAP_NOT_INSTALLED_=1 -I$(srcdir) -I../../.. -I$(srcdir)/../../../include @CPPFLAGS@ -DNDEBUG
|
CXXFLAGS= -std=c++0x @SHLIB_CXXFLAGS@ $(YAP_EXTRAS) $(DEFS) -D_YAP_NOT_INSTALLED_=1 -I$(srcdir) -I../../.. -I$(srcdir)/../../../include @CPPFLAGS@ -DNDEBUG
|
||||||
|
|
||||||
# debug
|
# debug
|
||||||
CXXFLAGS= -std=c++0x @SHLIB_CXXFLAGS@ $(YAP_EXTRAS) $(DEFS) -D_YAP_NOT_INSTALLED_=1 -I$(srcdir) -I../../.. -I$(srcdir)/../../../include @CPPFLAGS@ -g -O0 -Wextra
|
#CXXFLAGS= -std=c++0x @SHLIB_CXXFLAGS@ $(YAP_EXTRAS) $(DEFS) -D_YAP_NOT_INSTALLED_=1 -I$(srcdir) -I../../.. -I$(srcdir)/../../../include @CPPFLAGS@ -g -O0 -Wextra
|
||||||
|
|
||||||
|
|
||||||
#
|
#
|
||||||
@ -45,98 +45,91 @@ CWD=$(PWD)
|
|||||||
|
|
||||||
|
|
||||||
HEADERS = \
|
HEADERS = \
|
||||||
$(srcdir)/BayesNet.h \
|
|
||||||
$(srcdir)/BayesBall.h \
|
$(srcdir)/BayesBall.h \
|
||||||
$(srcdir)/ElimGraph.h \
|
$(srcdir)/BayesBallGraph.h \
|
||||||
$(srcdir)/FactorGraph.h \
|
$(srcdir)/BeliefProp.h \
|
||||||
$(srcdir)/Factor.h \
|
|
||||||
$(srcdir)/ConstraintTree.h \
|
$(srcdir)/ConstraintTree.h \
|
||||||
$(srcdir)/Solver.h \
|
$(srcdir)/CountingBp.h \
|
||||||
$(srcdir)/VarElimSolver.h \
|
$(srcdir)/ElimGraph.h \
|
||||||
$(srcdir)/BpSolver.h \
|
$(srcdir)/Factor.h \
|
||||||
$(srcdir)/CbpSolver.h \
|
$(srcdir)/FactorGraph.h \
|
||||||
$(srcdir)/FoveSolver.h \
|
|
||||||
$(srcdir)/Var.h \
|
|
||||||
$(srcdir)/Indexer.h \
|
|
||||||
$(srcdir)/Parfactor.h \
|
|
||||||
$(srcdir)/ProbFormula.h \
|
|
||||||
$(srcdir)/Histogram.h \
|
$(srcdir)/Histogram.h \
|
||||||
$(srcdir)/ParfactorList.h \
|
$(srcdir)/Horus.h \
|
||||||
|
$(srcdir)/Indexer.h \
|
||||||
|
$(srcdir)/LiftedBp.h \
|
||||||
$(srcdir)/LiftedUtils.h \
|
$(srcdir)/LiftedUtils.h \
|
||||||
|
$(srcdir)/LiftedVe.h \
|
||||||
|
$(srcdir)/Parfactor.h \
|
||||||
|
$(srcdir)/ParfactorList.h \
|
||||||
|
$(srcdir)/ProbFormula.h \
|
||||||
|
$(srcdir)/Solver.h \
|
||||||
$(srcdir)/TinySet.h \
|
$(srcdir)/TinySet.h \
|
||||||
$(srcdir)/LiftedBpSolver.h \
|
|
||||||
$(srcdir)/WeightedBpSolver.h \
|
|
||||||
$(srcdir)/Util.h \
|
$(srcdir)/Util.h \
|
||||||
$(srcdir)/Horus.h
|
$(srcdir)/Var.h \
|
||||||
|
$(srcdir)/VarElim.h \
|
||||||
|
$(srcdir)/WeightedBp.h
|
||||||
|
|
||||||
CPP_SOURCES = \
|
CPP_SOURCES = \
|
||||||
$(srcdir)/BayesNet.cpp \
|
|
||||||
$(srcdir)/BayesBall.cpp \
|
$(srcdir)/BayesBall.cpp \
|
||||||
$(srcdir)/ElimGraph.cpp \
|
$(srcdir)/BayesBallGraph.cpp \
|
||||||
$(srcdir)/FactorGraph.cpp \
|
$(srcdir)/BeliefProp.cpp \
|
||||||
$(srcdir)/Factor.cpp \
|
|
||||||
$(srcdir)/ConstraintTree.cpp \
|
$(srcdir)/ConstraintTree.cpp \
|
||||||
$(srcdir)/Var.cpp \
|
$(srcdir)/CountingBp.cpp \
|
||||||
$(srcdir)/Solver.cpp \
|
$(srcdir)/ElimGraph.cpp \
|
||||||
$(srcdir)/VarElimSolver.cpp \
|
$(srcdir)/Factor.cpp \
|
||||||
$(srcdir)/BpSolver.cpp \
|
$(srcdir)/FactorGraph.cpp \
|
||||||
$(srcdir)/CbpSolver.cpp \
|
|
||||||
$(srcdir)/FoveSolver.cpp \
|
|
||||||
$(srcdir)/Parfactor.cpp \
|
|
||||||
$(srcdir)/ProbFormula.cpp \
|
|
||||||
$(srcdir)/Histogram.cpp \
|
$(srcdir)/Histogram.cpp \
|
||||||
$(srcdir)/ParfactorList.cpp \
|
$(srcdir)/HorusCli.cpp \
|
||||||
$(srcdir)/LiftedUtils.cpp \
|
|
||||||
$(srcdir)/Util.cpp \
|
|
||||||
$(srcdir)/LiftedBpSolver.cpp \
|
|
||||||
$(srcdir)/WeightedBpSolver.cpp \
|
|
||||||
$(srcdir)/HorusYap.cpp \
|
$(srcdir)/HorusYap.cpp \
|
||||||
$(srcdir)/HorusCli.cpp
|
$(srcdir)/LiftedBp.cpp \
|
||||||
|
$(srcdir)/LiftedUtils.cpp \
|
||||||
|
$(srcdir)/LiftedVe.cpp \
|
||||||
|
$(srcdir)/Parfactor.cpp \
|
||||||
|
$(srcdir)/ParfactorList.cpp \
|
||||||
|
$(srcdir)/ProbFormula.cpp \
|
||||||
|
$(srcdir)/Solver.cpp \
|
||||||
|
$(srcdir)/Util.cpp \
|
||||||
|
$(srcdir)/Var.cpp \
|
||||||
|
$(srcdir)/VarElim.cpp \
|
||||||
|
$(srcdir)/WeightedBp.cpp \
|
||||||
|
|
||||||
OBJS = \
|
OBJS = \
|
||||||
BayesNet.o \
|
|
||||||
BayesBall.o \
|
BayesBall.o \
|
||||||
ElimGraph.o \
|
BayesBallGraph.o \
|
||||||
FactorGraph.o \
|
BeliefProp.o \
|
||||||
Factor.o \
|
|
||||||
ConstraintTree.o \
|
ConstraintTree.o \
|
||||||
Var.o \
|
CountingBp.o \
|
||||||
Solver.o \
|
ElimGraph.o \
|
||||||
VarElimSolver.o \
|
Factor.o \
|
||||||
BpSolver.o \
|
FactorGraph.o \
|
||||||
CbpSolver.o \
|
|
||||||
FoveSolver.o \
|
|
||||||
Parfactor.o \
|
|
||||||
ProbFormula.o \
|
|
||||||
Histogram.o \
|
Histogram.o \
|
||||||
ParfactorList.o \
|
HorusYap.o \
|
||||||
|
LiftedBp.o \
|
||||||
LiftedUtils.o \
|
LiftedUtils.o \
|
||||||
|
LiftedVe.o \
|
||||||
|
ProbFormula.o \
|
||||||
|
Parfactor.o \
|
||||||
|
ParfactorList.o \
|
||||||
|
Solver.o \
|
||||||
Util.o \
|
Util.o \
|
||||||
LiftedBpSolver.o \
|
Var.o \
|
||||||
WeightedBpSolver.o \
|
VarElim.o \
|
||||||
HorusYap.o
|
WeightedBp.o
|
||||||
|
|
||||||
HCLI_OBJS = \
|
HCLI_OBJS = \
|
||||||
BayesNet.o \
|
|
||||||
BayesBall.o \
|
BayesBall.o \
|
||||||
|
BayesBallGraph.o \
|
||||||
|
BeliefProp.o \
|
||||||
|
CountingBp.o \
|
||||||
ElimGraph.o \
|
ElimGraph.o \
|
||||||
FactorGraph.o \
|
|
||||||
Factor.o \
|
Factor.o \
|
||||||
ConstraintTree.o \
|
FactorGraph.o \
|
||||||
Var.o \
|
HorusCli.o \
|
||||||
Solver.o \
|
Solver.o \
|
||||||
VarElimSolver.o \
|
|
||||||
BpSolver.o \
|
|
||||||
CbpSolver.o \
|
|
||||||
FoveSolver.o \
|
|
||||||
Parfactor.o \
|
|
||||||
ProbFormula.o \
|
|
||||||
Histogram.o \
|
|
||||||
ParfactorList.o \
|
|
||||||
WeightedBpSolver.o \
|
|
||||||
LiftedUtils.o \
|
|
||||||
Util.o \
|
Util.o \
|
||||||
HorusCli.o
|
Var.o \
|
||||||
|
VarElim.o \
|
||||||
|
WeightedBp.o
|
||||||
|
|
||||||
SOBJS=horus.@SO@
|
SOBJS=horus.@SO@
|
||||||
|
|
||||||
|
@ -402,21 +402,32 @@ Parfactor::applySubstitution (const Substitution& theta)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
PrvGroup
|
size_t
|
||||||
Parfactor::findGroup (const Ground& ground) const
|
Parfactor::indexOfGround (const Ground& ground) const
|
||||||
{
|
{
|
||||||
PrvGroup group = numeric_limits<PrvGroup>::max();
|
size_t idx = args_.size();
|
||||||
for (size_t i = 0; i < args_.size(); i++) {
|
for (size_t i = 0; i < args_.size(); i++) {
|
||||||
if (args_[i].functor() == ground.functor() &&
|
if (args_[i].functor() == ground.functor() &&
|
||||||
args_[i].arity() == ground.arity()) {
|
args_[i].arity() == ground.arity()) {
|
||||||
constr_->moveToTop (args_[i].logVars());
|
constr_->moveToTop (args_[i].logVars());
|
||||||
if (constr_->containsTuple (ground.args())) {
|
if (constr_->containsTuple (ground.args())) {
|
||||||
group = args_[i].group();
|
idx = i;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return group;
|
return idx;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
PrvGroup
|
||||||
|
Parfactor::findGroup (const Ground& ground) const
|
||||||
|
{
|
||||||
|
size_t idx = indexOfGround (ground);
|
||||||
|
return idx == args_.size()
|
||||||
|
? numeric_limits<PrvGroup>::max()
|
||||||
|
: args_[idx].group();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -429,6 +440,30 @@ Parfactor::containsGround (const Ground& ground) const
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
Parfactor::containsGrounds (const Grounds& grounds) const
|
||||||
|
{
|
||||||
|
Tuple tuple;
|
||||||
|
LogVars tupleLvs;
|
||||||
|
for (size_t i = 0; i < grounds.size(); i++) {
|
||||||
|
size_t idx = indexOfGround (grounds[i]);
|
||||||
|
if (idx == args_.size()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
LogVars lvs = args_[idx].logVars();
|
||||||
|
for (size_t j = 0; j < lvs.size(); j++) {
|
||||||
|
if (Util::contains (tupleLvs, lvs[j]) == false) {
|
||||||
|
tuple.push_back (grounds[i].args()[j]);
|
||||||
|
tupleLvs.push_back (lvs[j]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
constr_->moveToTop (tupleLvs);
|
||||||
|
return constr_->containsTuple (tuple);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
bool
|
bool
|
||||||
Parfactor::containsGroup (PrvGroup group) const
|
Parfactor::containsGroup (PrvGroup group) const
|
||||||
{
|
{
|
||||||
@ -442,6 +477,19 @@ Parfactor::containsGroup (PrvGroup group) const
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
Parfactor::containsGroups (vector<PrvGroup> groups) const
|
||||||
|
{
|
||||||
|
for (size_t i = 0; i < groups.size(); i++) {
|
||||||
|
if (containsGroup (groups[i]) == false) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
unsigned
|
unsigned
|
||||||
Parfactor::nrFormulas (LogVar X) const
|
Parfactor::nrFormulas (LogVar X) const
|
||||||
{
|
{
|
||||||
|
@ -64,12 +64,18 @@ class Parfactor : public TFactor<ProbFormula>
|
|||||||
|
|
||||||
void applySubstitution (const Substitution&);
|
void applySubstitution (const Substitution&);
|
||||||
|
|
||||||
|
size_t indexOfGround (const Ground&) const;
|
||||||
|
|
||||||
PrvGroup findGroup (const Ground&) const;
|
PrvGroup findGroup (const Ground&) const;
|
||||||
|
|
||||||
bool containsGround (const Ground&) const;
|
bool containsGround (const Ground&) const;
|
||||||
|
|
||||||
|
bool containsGrounds (const Grounds&) const;
|
||||||
|
|
||||||
bool containsGroup (PrvGroup) const;
|
bool containsGroup (PrvGroup) const;
|
||||||
|
|
||||||
|
bool containsGroups (vector<PrvGroup>) const;
|
||||||
|
|
||||||
unsigned nrFormulas (LogVar) const;
|
unsigned nrFormulas (LogVar) const;
|
||||||
|
|
||||||
int indexOfLogVar (LogVar) const;
|
int indexOfLogVar (LogVar) const;
|
||||||
|
@ -91,6 +91,8 @@ class ObservedFormula
|
|||||||
|
|
||||||
unsigned evidence (void) const { return evidence_; }
|
unsigned evidence (void) const { return evidence_; }
|
||||||
|
|
||||||
|
void setEvidence (unsigned ev) { evidence_ = ev; }
|
||||||
|
|
||||||
ConstraintTree& constr (void) { return constr_; }
|
ConstraintTree& constr (void) { return constr_; }
|
||||||
|
|
||||||
bool isAtom (void) const { return arity_ == 0; }
|
bool isAtom (void) const { return arity_ == 0; }
|
||||||
|
@ -1,5 +1,8 @@
|
|||||||
#include "Solver.h"
|
#include "Solver.h"
|
||||||
#include "Util.h"
|
#include "Util.h"
|
||||||
|
#include "BeliefProp.h"
|
||||||
|
#include "CountingBp.h"
|
||||||
|
#include "VarElim.h"
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
@ -38,3 +41,67 @@ Solver::printAllPosterioris (void)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Params
|
||||||
|
Solver::getJointByConditioning (
|
||||||
|
GroundSolver solverType,
|
||||||
|
FactorGraph fg,
|
||||||
|
const VarIds& jointVarIds) const
|
||||||
|
{
|
||||||
|
VarNodes jointVars;
|
||||||
|
for (size_t i = 0; i < jointVarIds.size(); i++) {
|
||||||
|
assert (fg.getVarNode (jointVarIds[i]));
|
||||||
|
jointVars.push_back (fg.getVarNode (jointVarIds[i]));
|
||||||
|
}
|
||||||
|
|
||||||
|
Solver* solver = 0;
|
||||||
|
switch (solverType) {
|
||||||
|
case GroundSolver::BP: solver = new BeliefProp (fg); break;
|
||||||
|
case GroundSolver::CBP: solver = new CountingBp (fg); break;
|
||||||
|
case GroundSolver::VE: solver = new VarElim (fg); break;
|
||||||
|
}
|
||||||
|
Params prevBeliefs = solver->solveQuery ({jointVarIds[0]});
|
||||||
|
VarIds observedVids = {jointVars[0]->varId()};
|
||||||
|
|
||||||
|
for (size_t i = 1; i < jointVarIds.size(); i++) {
|
||||||
|
assert (jointVars[i]->hasEvidence() == false);
|
||||||
|
Params newBeliefs;
|
||||||
|
Vars observedVars;
|
||||||
|
Ranges observedRanges;
|
||||||
|
for (size_t j = 0; j < observedVids.size(); j++) {
|
||||||
|
observedVars.push_back (fg.getVarNode (observedVids[j]));
|
||||||
|
observedRanges.push_back (observedVars.back()->range());
|
||||||
|
}
|
||||||
|
Indexer indexer (observedRanges, false);
|
||||||
|
while (indexer.valid()) {
|
||||||
|
for (size_t j = 0; j < observedVars.size(); j++) {
|
||||||
|
observedVars[j]->setEvidence (indexer[j]);
|
||||||
|
}
|
||||||
|
delete solver;
|
||||||
|
switch (solverType) {
|
||||||
|
case GroundSolver::BP: solver = new BeliefProp (fg); break;
|
||||||
|
case GroundSolver::CBP: solver = new CountingBp (fg); break;
|
||||||
|
case GroundSolver::VE: solver = new VarElim (fg); break;
|
||||||
|
}
|
||||||
|
Params beliefs = solver->solveQuery ({jointVarIds[i]});
|
||||||
|
for (size_t k = 0; k < beliefs.size(); k++) {
|
||||||
|
newBeliefs.push_back (beliefs[k]);
|
||||||
|
}
|
||||||
|
++ indexer;
|
||||||
|
}
|
||||||
|
|
||||||
|
int count = -1;
|
||||||
|
for (size_t j = 0; j < newBeliefs.size(); j++) {
|
||||||
|
if (j % jointVars[i]->range() == 0) {
|
||||||
|
count ++;
|
||||||
|
}
|
||||||
|
newBeliefs[j] *= prevBeliefs[count];
|
||||||
|
}
|
||||||
|
prevBeliefs = newBeliefs;
|
||||||
|
observedVids.push_back (jointVars[i]->varId());
|
||||||
|
}
|
||||||
|
delete solver;
|
||||||
|
return prevBeliefs;
|
||||||
|
}
|
||||||
|
|
||||||
|
@ -3,8 +3,9 @@
|
|||||||
|
|
||||||
#include <iomanip>
|
#include <iomanip>
|
||||||
|
|
||||||
#include "Var.h"
|
|
||||||
#include "FactorGraph.h"
|
#include "FactorGraph.h"
|
||||||
|
#include "Var.h"
|
||||||
|
#include "Horus.h"
|
||||||
|
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
@ -24,6 +25,9 @@ class Solver
|
|||||||
|
|
||||||
void printAllPosterioris (void);
|
void printAllPosterioris (void);
|
||||||
|
|
||||||
|
Params getJointByConditioning (GroundSolver,
|
||||||
|
FactorGraph, const VarIds& jointVarIds) const;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
const FactorGraph& fg;
|
const FactorGraph& fg;
|
||||||
};
|
};
|
||||||
|
@ -13,9 +13,9 @@ bool logDomain = false;
|
|||||||
|
|
||||||
unsigned verbosity = 0;
|
unsigned verbosity = 0;
|
||||||
|
|
||||||
LiftedSolvers liftedSolver = LiftedSolvers::FOVE;
|
LiftedSolver liftedSolver = LiftedSolver::FOVE;
|
||||||
|
|
||||||
GroundSolvers groundSolver = GroundSolvers::VE;
|
GroundSolver groundSolver = GroundSolver::VE;
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -211,9 +211,9 @@ setHorusFlag (string key, string value)
|
|||||||
ss >> Globals::verbosity;
|
ss >> Globals::verbosity;
|
||||||
} else if (key == "lifted_solver") {
|
} else if (key == "lifted_solver") {
|
||||||
if ( value == "fove") {
|
if ( value == "fove") {
|
||||||
Globals::liftedSolver = LiftedSolvers::FOVE;
|
Globals::liftedSolver = LiftedSolver::FOVE;
|
||||||
} else if (value == "lbp") {
|
} else if (value == "lbp") {
|
||||||
Globals::liftedSolver = LiftedSolvers::LBP;
|
Globals::liftedSolver = LiftedSolver::LBP;
|
||||||
} else {
|
} else {
|
||||||
cerr << "warning: invalid value `" << value << "' " ;
|
cerr << "warning: invalid value `" << value << "' " ;
|
||||||
cerr << "for `" << key << "'" << endl;
|
cerr << "for `" << key << "'" << endl;
|
||||||
@ -221,11 +221,11 @@ setHorusFlag (string key, string value)
|
|||||||
}
|
}
|
||||||
} else if (key == "ground_solver") {
|
} else if (key == "ground_solver") {
|
||||||
if ( value == "ve") {
|
if ( value == "ve") {
|
||||||
Globals::groundSolver = GroundSolvers::VE;
|
Globals::groundSolver = GroundSolver::VE;
|
||||||
} else if (value == "bp") {
|
} else if (value == "bp") {
|
||||||
Globals::groundSolver = GroundSolvers::BP;
|
Globals::groundSolver = GroundSolver::BP;
|
||||||
} else if (value == "cbp") {
|
} else if (value == "cbp") {
|
||||||
Globals::groundSolver = GroundSolvers::CBP;
|
Globals::groundSolver = GroundSolver::CBP;
|
||||||
} else {
|
} else {
|
||||||
cerr << "warning: invalid value `" << value << "' " ;
|
cerr << "warning: invalid value `" << value << "' " ;
|
||||||
cerr << "for `" << key << "'" << endl;
|
cerr << "for `" << key << "'" << endl;
|
||||||
|
@ -1,12 +1,12 @@
|
|||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
|
||||||
#include "VarElimSolver.h"
|
#include "VarElim.h"
|
||||||
#include "ElimGraph.h"
|
#include "ElimGraph.h"
|
||||||
#include "Factor.h"
|
#include "Factor.h"
|
||||||
#include "Util.h"
|
#include "Util.h"
|
||||||
|
|
||||||
|
|
||||||
VarElimSolver::~VarElimSolver (void)
|
VarElim::~VarElim (void)
|
||||||
{
|
{
|
||||||
delete factorList_.back();
|
delete factorList_.back();
|
||||||
}
|
}
|
||||||
@ -14,7 +14,7 @@ VarElimSolver::~VarElimSolver (void)
|
|||||||
|
|
||||||
|
|
||||||
Params
|
Params
|
||||||
VarElimSolver::solveQuery (VarIds queryVids)
|
VarElim::solveQuery (VarIds queryVids)
|
||||||
{
|
{
|
||||||
if (Globals::verbosity > 1) {
|
if (Globals::verbosity > 1) {
|
||||||
cout << "Solving query on " ;
|
cout << "Solving query on " ;
|
||||||
@ -41,7 +41,7 @@ VarElimSolver::solveQuery (VarIds queryVids)
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
VarElimSolver::printSolverFlags (void) const
|
VarElim::printSolverFlags (void) const
|
||||||
{
|
{
|
||||||
stringstream ss;
|
stringstream ss;
|
||||||
ss << "variable elimination [" ;
|
ss << "variable elimination [" ;
|
||||||
@ -62,7 +62,7 @@ VarElimSolver::printSolverFlags (void) const
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
VarElimSolver::createFactorList (void)
|
VarElim::createFactorList (void)
|
||||||
{
|
{
|
||||||
const FacNodes& facNodes = fg.facNodes();
|
const FacNodes& facNodes = fg.facNodes();
|
||||||
factorList_.reserve (facNodes.size() * 2);
|
factorList_.reserve (facNodes.size() * 2);
|
||||||
@ -84,7 +84,7 @@ VarElimSolver::createFactorList (void)
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
VarElimSolver::absorveEvidence (void)
|
VarElim::absorveEvidence (void)
|
||||||
{
|
{
|
||||||
if (Globals::verbosity > 2) {
|
if (Globals::verbosity > 2) {
|
||||||
Util::printDashedLine();
|
Util::printDashedLine();
|
||||||
@ -117,7 +117,7 @@ VarElimSolver::absorveEvidence (void)
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
VarElimSolver::findEliminationOrder (const VarIds& vids)
|
VarElim::findEliminationOrder (const VarIds& vids)
|
||||||
{
|
{
|
||||||
elimOrder_ = ElimGraph::getEliminationOrder (factorList_, vids);
|
elimOrder_ = ElimGraph::getEliminationOrder (factorList_, vids);
|
||||||
}
|
}
|
||||||
@ -125,7 +125,7 @@ VarElimSolver::findEliminationOrder (const VarIds& vids)
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
VarElimSolver::processFactorList (const VarIds& vids)
|
VarElim::processFactorList (const VarIds& vids)
|
||||||
{
|
{
|
||||||
totalFactorSize_ = 0;
|
totalFactorSize_ = 0;
|
||||||
largestFactorSize_ = 0;
|
largestFactorSize_ = 0;
|
||||||
@ -170,7 +170,7 @@ VarElimSolver::processFactorList (const VarIds& vids)
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
VarElimSolver::eliminate (VarId elimVar)
|
VarElim::eliminate (VarId elimVar)
|
||||||
{
|
{
|
||||||
Factor* result = 0;
|
Factor* result = 0;
|
||||||
vector<size_t>& idxs = varFactors_.find (elimVar)->second;
|
vector<size_t>& idxs = varFactors_.find (elimVar)->second;
|
||||||
@ -205,7 +205,7 @@ VarElimSolver::eliminate (VarId elimVar)
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
VarElimSolver::printActiveFactors (void)
|
VarElim::printActiveFactors (void)
|
||||||
{
|
{
|
||||||
for (size_t i = 0; i < factorList_.size(); i++) {
|
for (size_t i = 0; i < factorList_.size(); i++) {
|
||||||
if (factorList_[i] != 0) {
|
if (factorList_[i] != 0) {
|
@ -1,5 +1,5 @@
|
|||||||
#ifndef HORUS_VARELIMSOLVER_H
|
#ifndef HORUS_VARELIM_H
|
||||||
#define HORUS_VARELIMSOLVER_H
|
#define HORUS_VARELIM_H
|
||||||
|
|
||||||
#include "unordered_map"
|
#include "unordered_map"
|
||||||
|
|
||||||
@ -11,12 +11,12 @@
|
|||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
|
|
||||||
class VarElimSolver : public Solver
|
class VarElim : public Solver
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
VarElimSolver (const FactorGraph& fg) : Solver (fg) { }
|
VarElim (const FactorGraph& fg) : Solver (fg) { }
|
||||||
|
|
||||||
~VarElimSolver (void);
|
~VarElim (void);
|
||||||
|
|
||||||
Params solveQuery (VarIds);
|
Params solveQuery (VarIds);
|
||||||
|
|
||||||
@ -42,5 +42,5 @@ class VarElimSolver : public Solver
|
|||||||
unordered_map<VarId, vector<size_t>> varFactors_;
|
unordered_map<VarId, vector<size_t>> varFactors_;
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif // HORUS_VARELIMSOLVER_H
|
#endif // HORUS_VARELIM_H
|
||||||
|
|
@ -1,7 +1,7 @@
|
|||||||
#include "WeightedBpSolver.h"
|
#include "WeightedBp.h"
|
||||||
|
|
||||||
|
|
||||||
WeightedBpSolver::~WeightedBpSolver (void)
|
WeightedBp::~WeightedBp (void)
|
||||||
{
|
{
|
||||||
for (size_t i = 0; i < links_.size(); i++) {
|
for (size_t i = 0; i < links_.size(); i++) {
|
||||||
delete links_[i];
|
delete links_[i];
|
||||||
@ -12,7 +12,7 @@ WeightedBpSolver::~WeightedBpSolver (void)
|
|||||||
|
|
||||||
|
|
||||||
Params
|
Params
|
||||||
WeightedBpSolver::getPosterioriOf (VarId vid)
|
WeightedBp::getPosterioriOf (VarId vid)
|
||||||
{
|
{
|
||||||
if (runned_ == false) {
|
if (runned_ == false) {
|
||||||
runSolver();
|
runSolver();
|
||||||
@ -47,7 +47,7 @@ WeightedBpSolver::getPosterioriOf (VarId vid)
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
WeightedBpSolver::createLinks (void)
|
WeightedBp::createLinks (void)
|
||||||
{
|
{
|
||||||
if (Globals::verbosity > 0) {
|
if (Globals::verbosity > 0) {
|
||||||
cout << "compressed factor graph contains " ;
|
cout << "compressed factor graph contains " ;
|
||||||
@ -78,7 +78,7 @@ WeightedBpSolver::createLinks (void)
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
WeightedBpSolver::maxResidualSchedule (void)
|
WeightedBp::maxResidualSchedule (void)
|
||||||
{
|
{
|
||||||
if (nIters_ == 1) {
|
if (nIters_ == 1) {
|
||||||
for (size_t i = 0; i < links_.size(); i++) {
|
for (size_t i = 0; i < links_.size(); i++) {
|
||||||
@ -151,7 +151,7 @@ WeightedBpSolver::maxResidualSchedule (void)
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
WeightedBpSolver::calcFactorToVarMsg (BpLink* _link)
|
WeightedBp::calcFactorToVarMsg (BpLink* _link)
|
||||||
{
|
{
|
||||||
WeightedLink* link = static_cast<WeightedLink*> (_link);
|
WeightedLink* link = static_cast<WeightedLink*> (_link);
|
||||||
FacNode* src = link->facNode();
|
FacNode* src = link->facNode();
|
||||||
@ -223,7 +223,7 @@ WeightedBpSolver::calcFactorToVarMsg (BpLink* _link)
|
|||||||
|
|
||||||
|
|
||||||
Params
|
Params
|
||||||
WeightedBpSolver::getVarToFactorMsg (const BpLink* _link) const
|
WeightedBp::getVarToFactorMsg (const BpLink* _link) const
|
||||||
{
|
{
|
||||||
const WeightedLink* link = static_cast<const WeightedLink*> (_link);
|
const WeightedLink* link = static_cast<const WeightedLink*> (_link);
|
||||||
const VarNode* src = link->varNode();
|
const VarNode* src = link->varNode();
|
||||||
@ -272,7 +272,7 @@ WeightedBpSolver::getVarToFactorMsg (const BpLink* _link) const
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
WeightedBpSolver::printLinkInformation (void) const
|
WeightedBp::printLinkInformation (void) const
|
||||||
{
|
{
|
||||||
for (size_t i = 0; i < links_.size(); i++) {
|
for (size_t i = 0; i < links_.size(); i++) {
|
||||||
WeightedLink* l = static_cast<WeightedLink*> (links_[i]);
|
WeightedLink* l = static_cast<WeightedLink*> (links_[i]);
|
@ -1,7 +1,7 @@
|
|||||||
#ifndef HORUS_WEIGHTEDBPSOLVER_H
|
#ifndef HORUS_WEIGHTEDBP_H
|
||||||
#define HORUS_WEIGHTEDBPSOLVER_H
|
#define HORUS_WEIGHTEDBP_H
|
||||||
|
|
||||||
#include "BpSolver.h"
|
#include "BeliefProp.h"
|
||||||
|
|
||||||
class WeightedLink : public BpLink
|
class WeightedLink : public BpLink
|
||||||
{
|
{
|
||||||
@ -31,14 +31,14 @@ class WeightedLink : public BpLink
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
class WeightedBpSolver : public BpSolver
|
class WeightedBp : public BeliefProp
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
WeightedBpSolver (const FactorGraph& fg,
|
WeightedBp (const FactorGraph& fg,
|
||||||
const vector<vector<unsigned>>& weights)
|
const vector<vector<unsigned>>& weights)
|
||||||
: BpSolver (fg), weights_(weights) { }
|
: BeliefProp (fg), weights_(weights) { }
|
||||||
|
|
||||||
~WeightedBpSolver (void);
|
~WeightedBp (void);
|
||||||
|
|
||||||
Params getPosterioriOf (VarId);
|
Params getPosterioriOf (VarId);
|
||||||
|
|
||||||
@ -57,5 +57,5 @@ class WeightedBpSolver : public BpSolver
|
|||||||
vector<vector<unsigned>> weights_;
|
vector<vector<unsigned>> weights_;
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif // HORUS_WEIGHTEDBPSOLVER_H
|
#endif // HORUS_WEIGHTEDBP_H
|
||||||
|
|
Reference in New Issue
Block a user