Merge branch 'master' of https://github.com/tacgomes/yap6.3
This commit is contained in:
commit
5fe052a3ef
110
packages/CLPBN/README.txt
Normal file
110
packages/CLPBN/README.txt
Normal file
@ -0,0 +1,110 @@
|
||||
Prolog Factor Language (PFL)
|
||||
|
||||
Prolog Factor Language (PFL) is a extension of the Prolog language that
|
||||
allows a natural representation of this first-order probabilistic models
|
||||
(either directed or undirected). PFL is also capable of solving probabilistic
|
||||
queries on this models through the implementation of several inference
|
||||
techniques: variable elimination, belief propagation, lifted variable
|
||||
elimination and lifted belief propagation.
|
||||
|
||||
Language
|
||||
-------------------------------------------------------------------------------
|
||||
A graphical model in PFL is represented using parfactors. A PFL parfactor
|
||||
has the following four components:
|
||||
|
||||
Type ; Formulas ; Phi ; Constraint .
|
||||
|
||||
- Type refers the type of the network over which the parfactor is defined.
|
||||
It can be bayes for directed networks, or markov for undirected ones.
|
||||
- Formulas is a sequence of Prolog terms that define sets of random variables
|
||||
under the constraint.
|
||||
- Phi is either a list of parameters or a call to a Prolog goal that will
|
||||
unify its last argument with a list of parameters.
|
||||
- Constraint is a list (possible empty) of Prolog goals that will impose
|
||||
bindings on the logical variables that appear in the formulas.
|
||||
|
||||
The "examples" directory contains some popular graphical models described
|
||||
using PFL.
|
||||
|
||||
Querying
|
||||
-------------------------------------------------------------------------------
|
||||
Now we show how to use PFL to solve probabilistic queries. We will
|
||||
use the burlgary alarm network as an example. First, we load the model:
|
||||
|
||||
$ yap -l examples/burglary-alarm.yap
|
||||
|
||||
Now let's suppose that we want to estimate the probability of a earthquake
|
||||
ocurred given that mary called. We can do it with the following query:
|
||||
|
||||
?- earthquake(X), mary_calls(t).
|
||||
|
||||
Suppose now that we want the joint distribution for john_calls and
|
||||
mary_calls. We can obtain this with the following query:
|
||||
|
||||
?- john_calls(X), mary_calls(Y).
|
||||
|
||||
|
||||
Inference Options
|
||||
-------------------------------------------------------------------------------
|
||||
PFL supports both ground and lifted inference. The inference algorithm
|
||||
can be chosen using the set_solver/1 predicate. The following algorithms
|
||||
are supported:
|
||||
- fove: lifted variable elimination with arbitrary constraints (GC-FOVE)
|
||||
- hve: (ground) variable elimination
|
||||
- lbp: lifted first-order belief propagation
|
||||
- cbp: counting belief propagation
|
||||
- bp: (ground) belief propagation
|
||||
|
||||
For example, if we want to use ground variable elimination to solve some
|
||||
query, we need to call first the following goal:
|
||||
|
||||
?- set_solver(hve).
|
||||
|
||||
It is possible to tweak several parameters of PFL through the
|
||||
set_horus_flag/2 predicate. The first argument is a key that
|
||||
identifies the parameter that we desire to tweak, while the second
|
||||
is some possible value for this key.
|
||||
|
||||
The verbosity key controls the level of log information that will be
|
||||
printed by the corresponding solver. Its possible values are positive
|
||||
integers. The bigger the number, more log information will be printed.
|
||||
For example, to view some basic log information we need to call the
|
||||
following goal:
|
||||
|
||||
?- set_horus_flag(verbosity, 1).
|
||||
|
||||
The use_logarithms key controls whether the calculations performed
|
||||
during inference should be done in the log domain or not. Its values
|
||||
can be true or false. By default is false.
|
||||
|
||||
There are also keys specific to the inference algorithm. For example,
|
||||
elim_heuristic key controls the elimination heuristic that will be
|
||||
used by ground variable elimination. The following heuristics are
|
||||
supported:
|
||||
- sequential
|
||||
- min_neighbors
|
||||
- min_weight
|
||||
- min_fill
|
||||
- weighted_min_fill
|
||||
|
||||
An explanation of this heuristics can be found in Probabilistic Graphical
|
||||
Models by Daphne Koller.
|
||||
|
||||
The schedule, accuracy and max_iter keys are specific for inference
|
||||
algorithms based on message passing, namely lbp, cbp and bp.
|
||||
The key schedule can be used to specify the order in which the messages
|
||||
are sent in belief propagation. The possible values are:
|
||||
- seq_fixed: at each iteration, all messages are sent in the same order
|
||||
- seq_random: at each iteration, the messages are sent with a random order
|
||||
- parallel: at each iteration, the messages are all calculated using the
|
||||
values of the previous iteration.
|
||||
- max_residual: the next message to be sent is the one with maximum residual,
|
||||
(Residual Belief Propagation:Informed Scheduling for Asynchronous Message
|
||||
Passing)
|
||||
|
||||
The max_iter key sets the maximum number of iterations. One iteration
|
||||
consists in sending all possible messages. The accuracy key indicate
|
||||
when we should stop sending messages. If the largest difference between
|
||||
a message sent in the current iteration and one message sent in the previous
|
||||
iteration is less that accuracy value given, we terminate belief propagation.
|
||||
|
@ -26,6 +26,8 @@ function run_solver
|
||||
solver_flag=clpbn_horus:set_horus_flag\(schedule,$2\)
|
||||
elif [ $SOLVER = cbp ]; then
|
||||
solver_flag=clpbn_horus:set_horus_flag\(schedule,$2\)
|
||||
elif [ $SOLVER = lbp ]; then
|
||||
solver_flag=clpbn_horus:set_horus_flag\(schedule,$2\)
|
||||
else
|
||||
echo "unknow flag $2"
|
||||
fi
|
||||
|
@ -23,7 +23,7 @@ function run_all_graphs
|
||||
run_solver city60000 $2
|
||||
run_solver city65000 $2
|
||||
run_solver city70000 $2
|
||||
|
||||
return
|
||||
run_solver city75000 $2
|
||||
run_solver city80000 $2
|
||||
run_solver city85000 $2
|
||||
|
36
packages/CLPBN/benchmarks/city/lbp_tests.sh
Executable file
36
packages/CLPBN/benchmarks/city/lbp_tests.sh
Executable file
@ -0,0 +1,36 @@
|
||||
#!/bin/bash
|
||||
|
||||
source city.sh
|
||||
source ../benchs.sh
|
||||
|
||||
SOLVER="lbp"
|
||||
|
||||
function run_all_graphs
|
||||
{
|
||||
write_header $1
|
||||
run_solver city1000 $2
|
||||
run_solver city5000 $2
|
||||
run_solver city10000 $2
|
||||
run_solver city15000 $2
|
||||
run_solver city20000 $2
|
||||
run_solver city25000 $2
|
||||
run_solver city30000 $2
|
||||
run_solver city35000 $2
|
||||
run_solver city40000 $2
|
||||
run_solver city45000 $2
|
||||
run_solver city50000 $2
|
||||
run_solver city55000 $2
|
||||
run_solver city60000 $2
|
||||
run_solver city65000 $2
|
||||
run_solver city70000 $2
|
||||
run_solver city75000 $2
|
||||
run_solver city80000 $2
|
||||
run_solver city85000 $2
|
||||
run_solver city90000 $2
|
||||
run_solver city95000 $2
|
||||
run_solver city100000 $2
|
||||
}
|
||||
|
||||
prepare_new_run
|
||||
run_all_graphs "lbp(shedule=seq_fixed) " seq_fixed
|
||||
|
30
packages/CLPBN/benchmarks/comp_workshops/lbp_tests.sh
Executable file
30
packages/CLPBN/benchmarks/comp_workshops/lbp_tests.sh
Executable file
@ -0,0 +1,30 @@
|
||||
#!/bin/bash
|
||||
|
||||
source cw.sh
|
||||
source ../benchs.sh
|
||||
|
||||
SOLVER="lbp"
|
||||
|
||||
function run_all_graphs
|
||||
{
|
||||
write_header $1
|
||||
run_solver p1000w$N_WORKSHOPS $2
|
||||
run_solver p5000w$N_WORKSHOPS $2
|
||||
run_solver p10000w$N_WORKSHOPS $2
|
||||
run_solver p15000w$N_WORKSHOPS $2
|
||||
run_solver p20000w$N_WORKSHOPS $2
|
||||
run_solver p25000w$N_WORKSHOPS $2
|
||||
run_solver p30000w$N_WORKSHOPS $2
|
||||
run_solver p35000w$N_WORKSHOPS $2
|
||||
run_solver p40000w$N_WORKSHOPS $2
|
||||
run_solver p45000w$N_WORKSHOPS $2
|
||||
run_solver p50000w$N_WORKSHOPS $2
|
||||
run_solver p55000w$N_WORKSHOPS $2
|
||||
run_solver p60000w$N_WORKSHOPS $2
|
||||
run_solver p65000w$N_WORKSHOPS $2
|
||||
run_solver p70000w$N_WORKSHOPS $2
|
||||
}
|
||||
|
||||
prepare_new_run
|
||||
run_all_graphs "lbp(shedule=seq_fixed) " seq_fixed
|
||||
|
35
packages/CLPBN/benchmarks/run_all.sh
Executable file
35
packages/CLPBN/benchmarks/run_all.sh
Executable file
@ -0,0 +1,35 @@
|
||||
#!/bin/bash
|
||||
|
||||
cd workshop_attrs
|
||||
source hve_tests.sh
|
||||
source bp_tests.sh
|
||||
source fove_tests.sh
|
||||
source lbp_tests.sh
|
||||
source cbp_tests.sh
|
||||
cd ..
|
||||
|
||||
cd comp_workshops
|
||||
source hve_tests.sh
|
||||
source bp_tests.sh
|
||||
source fove_tests.sh
|
||||
source lbp_tests.sh
|
||||
source cbp_tests.sh
|
||||
cd ..
|
||||
|
||||
cd city
|
||||
source hve_tests.sh
|
||||
source bp_tests.sh
|
||||
source fove_tests.sh
|
||||
source lbp_tests.sh
|
||||
source cbp_tests.sh
|
||||
cd ..
|
||||
|
||||
cd smokers
|
||||
source hve_tests.sh
|
||||
source bp_tests.sh
|
||||
source fove_tests.sh
|
||||
source lbp_tests.sh
|
||||
source cbp_tests.sh
|
||||
cd ..
|
||||
|
||||
|
30
packages/CLPBN/benchmarks/smokers/lbp_tests.sh
Executable file
30
packages/CLPBN/benchmarks/smokers/lbp_tests.sh
Executable file
@ -0,0 +1,30 @@
|
||||
#!/bin/bash
|
||||
|
||||
source sm.sh
|
||||
source ../benchs.sh
|
||||
|
||||
SOLVER="lbp"
|
||||
|
||||
function run_all_graphs
|
||||
{
|
||||
write_header $1
|
||||
run_solver pop100 $2
|
||||
run_solver pop200 $2
|
||||
run_solver pop300 $2
|
||||
run_solver pop400 $2
|
||||
run_solver pop500 $2
|
||||
run_solver pop600 $2
|
||||
run_solver pop700 $2
|
||||
run_solver pop800 $2
|
||||
run_solver pop900 $2
|
||||
run_solver pop1000 $2
|
||||
run_solver pop1100 $2
|
||||
run_solver pop1200 $2
|
||||
run_solver pop1300 $2
|
||||
run_solver pop1400 $2
|
||||
run_solver pop1500 $2
|
||||
}
|
||||
|
||||
prepare_new_run
|
||||
run_all_graphs "lbp(shedule=seq_fixed) " seq_fixed
|
||||
|
@ -2,5 +2,6 @@
|
||||
|
||||
NETWORK="'../../examples/social_domain2'"
|
||||
SHORTNAME="sm"
|
||||
QUERY="smokes(p1,t), smokes(p2,t), friends(p1,p2,X)"
|
||||
#QUERY="smokes(p1,t), smokes(p2,t), friends(p1,p2,X)"
|
||||
QUERY="friends(p1,p2,X)"
|
||||
|
||||
|
34
packages/CLPBN/benchmarks/smokers_evidence/bp_tests.sh
Executable file
34
packages/CLPBN/benchmarks/smokers_evidence/bp_tests.sh
Executable file
@ -0,0 +1,34 @@
|
||||
#!/bin/bash
|
||||
|
||||
source sm.sh
|
||||
source ../benchs.sh
|
||||
|
||||
SOLVER="bp"
|
||||
|
||||
function run_all_graphs
|
||||
{
|
||||
write_header $1
|
||||
run_solver ev0p$POP $2
|
||||
run_solver ev5p$POP $2
|
||||
run_solver ev10p$POP $2
|
||||
run_solver ev15p$POP $2
|
||||
run_solver ev20p$POP $2
|
||||
run_solver ev25p$POP $2
|
||||
run_solver ev30p$POP $2
|
||||
run_solver ev35p$POP $2
|
||||
run_solver ev40p$POP $2
|
||||
run_solver ev45p$POP $2
|
||||
run_solver ev50p$POP $2
|
||||
run_solver ev55p$POP $2
|
||||
run_solver ev60p$POP $2
|
||||
run_solver ev65p$POP $2
|
||||
run_solver ev70p$POP $2
|
||||
run_solver ev75p$POP $2
|
||||
run_solver ev80p$POP $2
|
||||
run_solver ev85p$POP $2
|
||||
run_solver ev90p$POP $2
|
||||
}
|
||||
|
||||
prepare_new_run
|
||||
run_all_graphs "bp(shedule=seq_fixed) " seq_fixed
|
||||
|
34
packages/CLPBN/benchmarks/smokers_evidence/cbp_tests.sh
Executable file
34
packages/CLPBN/benchmarks/smokers_evidence/cbp_tests.sh
Executable file
@ -0,0 +1,34 @@
|
||||
#!/bin/bash
|
||||
|
||||
source sm.sh
|
||||
source ../benchs.sh
|
||||
|
||||
SOLVER="cbp"
|
||||
|
||||
function run_all_graphs
|
||||
{
|
||||
write_header $1
|
||||
run_solver ev0p$POP $2
|
||||
run_solver ev5p$POP $2
|
||||
run_solver ev10p$POP $2
|
||||
run_solver ev15p$POP $2
|
||||
run_solver ev20p$POP $2
|
||||
run_solver ev25p$POP $2
|
||||
run_solver ev30p$POP $2
|
||||
run_solver ev35p$POP $2
|
||||
run_solver ev40p$POP $2
|
||||
run_solver ev45p$POP $2
|
||||
run_solver ev50p$POP $2
|
||||
run_solver ev55p$POP $2
|
||||
run_solver ev60p$POP $2
|
||||
run_solver ev65p$POP $2
|
||||
run_solver ev70p$POP $2
|
||||
run_solver ev75p$POP $2
|
||||
run_solver ev80p$POP $2
|
||||
run_solver ev85p$POP $2
|
||||
run_solver ev90p$POP $2
|
||||
}
|
||||
|
||||
prepare_new_run
|
||||
run_all_graphs "cbp(shedule=seq_fixed) " seq_fixed
|
||||
|
35
packages/CLPBN/benchmarks/smokers_evidence/fove_tests.sh
Executable file
35
packages/CLPBN/benchmarks/smokers_evidence/fove_tests.sh
Executable file
@ -0,0 +1,35 @@
|
||||
#!/bin/bash
|
||||
|
||||
source sm.sh
|
||||
source ../benchs.sh
|
||||
|
||||
SOLVER="fove"
|
||||
|
||||
function run_all_graphs
|
||||
{
|
||||
write_header $1
|
||||
run_solver ev0p$POP $2
|
||||
run_solver ev5p$POP $2
|
||||
run_solver ev10p$POP $2
|
||||
run_solver ev15p$POP $2
|
||||
run_solver ev20p$POP $2
|
||||
run_solver ev25p$POP $2
|
||||
run_solver ev30p$POP $2
|
||||
run_solver ev35p$POP $2
|
||||
run_solver ev40p$POP $2
|
||||
run_solver ev45p$POP $2
|
||||
run_solver ev50p$POP $2
|
||||
run_solver ev55p$POP $2
|
||||
run_solver ev60p$POP $2
|
||||
run_solver ev65p$POP $2
|
||||
run_solver ev70p$POP $2
|
||||
run_solver ev75p$POP $2
|
||||
run_solver ev80p$POP $2
|
||||
run_solver ev85p$POP $2
|
||||
run_solver ev90p$POP $2
|
||||
}
|
||||
|
||||
prepare_new_run
|
||||
run_all_graphs "fove "
|
||||
|
||||
|
49
packages/CLPBN/benchmarks/smokers_evidence/gen_people.sh
Executable file
49
packages/CLPBN/benchmarks/smokers_evidence/gen_people.sh
Executable file
@ -0,0 +1,49 @@
|
||||
#!/home/tgomes/bin/yap -L --
|
||||
|
||||
:- use_module(library(lists)).
|
||||
:- use_module(library(random)).
|
||||
|
||||
|
||||
:- initialization(main).
|
||||
|
||||
main :-
|
||||
unix(argv(Args)),
|
||||
nth(1, Args, EV), % percentage of evidence
|
||||
nth(2, Args, NP), % number of individuals
|
||||
atomic_concat(['ev', EV, 'p', NP, '.yap'], FileName),
|
||||
open(FileName, 'write', S),
|
||||
atom_number(EV, EV2),
|
||||
atom_number(NP, NP2),
|
||||
EV3 is EV2 / 100.0,
|
||||
generate_people(S, NP2, 4),
|
||||
write(S, '\n'),
|
||||
write(S, 'query(X) :- '),
|
||||
generate_evidence(S, NP2, EV3, 4),
|
||||
write(S, 'friends(p1,p2,X).\n'),
|
||||
close(S).
|
||||
|
||||
|
||||
generate_people(S, N, Counting) :-
|
||||
Counting > N, !.
|
||||
generate_people(S, N, Counting) :-
|
||||
format(S, 'people(p~w).~n', [Counting]),
|
||||
Counting1 is Counting + 1,
|
||||
generate_people(S, N, Counting1).
|
||||
|
||||
|
||||
generate_evidence(S, N, Ev, Counting) :-
|
||||
Counting > N, !.
|
||||
generate_evidence(S, N, Ev, Counting) :-
|
||||
random(X),
|
||||
(
|
||||
X < Ev
|
||||
->
|
||||
random(Y),
|
||||
(Y > 0.5 -> Val = t ; Val = f),
|
||||
format(S, 'smokes(p~w,~w),', [Counting,Val])
|
||||
;
|
||||
true
|
||||
),
|
||||
Counting1 is Counting + 1,
|
||||
generate_evidence(S, N, Ev, Counting1).
|
||||
|
37
packages/CLPBN/benchmarks/smokers_evidence/hve_tests.sh
Executable file
37
packages/CLPBN/benchmarks/smokers_evidence/hve_tests.sh
Executable file
@ -0,0 +1,37 @@
|
||||
#!/bin/bash
|
||||
|
||||
source sm.sh
|
||||
source ../benchs.sh
|
||||
|
||||
SOLVER="hve"
|
||||
|
||||
function run_all_graphs
|
||||
{
|
||||
write_header $1
|
||||
run_solver ev0p$POP $2
|
||||
run_solver ev5p$POP $2
|
||||
run_solver ev10p$POP $2
|
||||
run_solver ev15p$POP $2
|
||||
run_solver ev20p$POP $2
|
||||
run_solver ev25p$POP $2
|
||||
run_solver ev30p$POP $2
|
||||
run_solver ev35p$POP $2
|
||||
run_solver ev40p$POP $2
|
||||
run_solver ev45p$POP $2
|
||||
run_solver ev50p$POP $2
|
||||
run_solver ev55p$POP $2
|
||||
run_solver ev60p$POP $2
|
||||
run_solver ev65p$POP $2
|
||||
run_solver ev70p$POP $2
|
||||
run_solver ev75p$POP $2
|
||||
run_solver ev80p$POP $2
|
||||
run_solver ev85p$POP $2
|
||||
run_solver ev90p$POP $2
|
||||
}
|
||||
|
||||
prepare_new_run
|
||||
run_all_graphs "hve(elim_heuristic=min_neighbors) " min_neighbors
|
||||
#run_all_graphs "hve(elim_heuristic=min_weight) " min_weight
|
||||
#run_all_graphs "hve(elim_heuristic=min_fill) " min_fill
|
||||
#run_all_graphs "hve(elim_heuristic=weighted_min_fill) " weighted_min_fill
|
||||
|
34
packages/CLPBN/benchmarks/smokers_evidence/lbp_tests.sh
Executable file
34
packages/CLPBN/benchmarks/smokers_evidence/lbp_tests.sh
Executable file
@ -0,0 +1,34 @@
|
||||
#!/bin/bash
|
||||
|
||||
source sm.sh
|
||||
source ../benchs.sh
|
||||
|
||||
SOLVER="lbp"
|
||||
|
||||
function run_all_graphs
|
||||
{
|
||||
write_header $1
|
||||
run_solver ev0p$POP $2
|
||||
run_solver ev5p$POP $2
|
||||
run_solver ev10p$POP $2
|
||||
run_solver ev15p$POP $2
|
||||
run_solver ev20p$POP $2
|
||||
run_solver ev25p$POP $2
|
||||
run_solver ev30p$POP $2
|
||||
run_solver ev35p$POP $2
|
||||
run_solver ev40p$POP $2
|
||||
run_solver ev45p$POP $2
|
||||
run_solver ev50p$POP $2
|
||||
run_solver ev55p$POP $2
|
||||
run_solver ev60p$POP $2
|
||||
run_solver ev65p$POP $2
|
||||
run_solver ev70p$POP $2
|
||||
run_solver ev75p$POP $2
|
||||
run_solver ev80p$POP $2
|
||||
run_solver ev85p$POP $2
|
||||
run_solver ev90p$POP $2
|
||||
}
|
||||
|
||||
prepare_new_run
|
||||
run_all_graphs "lbp(shedule=seq_fixed) " seq_fixed
|
||||
|
8
packages/CLPBN/benchmarks/smokers_evidence/sm.sh
Executable file
8
packages/CLPBN/benchmarks/smokers_evidence/sm.sh
Executable file
@ -0,0 +1,8 @@
|
||||
#!/bin/bash
|
||||
|
||||
NETWORK="'../../examples/social_domain2'"
|
||||
SHORTNAME="sm"
|
||||
QUERY="query(X)"
|
||||
|
||||
POP=500
|
||||
|
@ -15,8 +15,8 @@ function run_all_graphs
|
||||
run_solver p20000attrs$N_ATTRS $2
|
||||
run_solver p25000attrs$N_ATTRS $2
|
||||
run_solver p30000attrs$N_ATTRS $2
|
||||
run_solver p35000attrs$N_ATTRS $2
|
||||
return
|
||||
run_solver p35000attrs$N_ATTRS $2
|
||||
run_solver p40000attrs$N_ATTRS $2
|
||||
run_solver p45000attrs$N_ATTRS $2
|
||||
run_solver p50000attrs$N_ATTRS $2
|
||||
|
36
packages/CLPBN/benchmarks/workshop_attrs/lbp_tests.sh
Executable file
36
packages/CLPBN/benchmarks/workshop_attrs/lbp_tests.sh
Executable file
@ -0,0 +1,36 @@
|
||||
#!/bin/bash
|
||||
|
||||
source wa.sh
|
||||
source ../benchs.sh
|
||||
|
||||
SOLVER="lbp"
|
||||
|
||||
function run_all_graphs
|
||||
{
|
||||
write_header $1
|
||||
run_solver p1000attrs$N_ATTRS $2
|
||||
run_solver p5000attrs$N_ATTRS $2
|
||||
run_solver p10000attrs$N_ATTRS $2
|
||||
run_solver p15000attrs$N_ATTRS $2
|
||||
run_solver p20000attrs$N_ATTRS $2
|
||||
run_solver p25000attrs$N_ATTRS $2
|
||||
run_solver p30000attrs$N_ATTRS $2
|
||||
run_solver p35000attrs$N_ATTRS $2
|
||||
run_solver p40000attrs$N_ATTRS $2
|
||||
run_solver p45000attrs$N_ATTRS $2
|
||||
run_solver p50000attrs$N_ATTRS $2
|
||||
run_solver p55000attrs$N_ATTRS $2
|
||||
run_solver p60000attrs$N_ATTRS $2
|
||||
run_solver p65000attrs$N_ATTRS $2
|
||||
run_solver p70000attrs$N_ATTRS $2
|
||||
run_solver p75000attrs$N_ATTRS $2
|
||||
run_solver p80000attrs$N_ATTRS $2
|
||||
run_solver p85000attrs$N_ATTRS $2
|
||||
run_solver p90000attrs$N_ATTRS $2
|
||||
run_solver p95000attrs$N_ATTRS $2
|
||||
run_solver p100000attrs$N_ATTRS $2
|
||||
}
|
||||
|
||||
prepare_new_run
|
||||
run_all_graphs "lbp(shedule=seq_fixed) " seq_fixed
|
||||
|
@ -15,7 +15,7 @@
|
||||
cpp_run_ground_solver/3,
|
||||
cpp_set_vars_information/2,
|
||||
cpp_set_horus_flag/2,
|
||||
cpp_free_parfactors/1,
|
||||
cpp_free_lifted_network/1,
|
||||
cpp_free_ground_network/1
|
||||
]).
|
||||
|
||||
|
@ -79,7 +79,6 @@ run_solver(ground(Network,Hash), QueryKeys, Solutions) :-
|
||||
list_of_keys_to_ids(QueryKeys, Hash, QueryIds),
|
||||
%writeln(queryKeys:QueryKeys), writeln(''),
|
||||
%writeln(queryIds:QueryIds), writeln(''),
|
||||
list_of_keys_to_ids(QueryKeys, Hash, QueryIds),
|
||||
cpp_run_ground_solver(Network, [QueryIds], Solutions).
|
||||
|
||||
|
||||
|
@ -17,7 +17,7 @@
|
||||
[cpp_create_lifted_network/3,
|
||||
cpp_set_parfactors_params/2,
|
||||
cpp_run_lifted_solver/3,
|
||||
cpp_free_parfactors/1
|
||||
cpp_free_lifted_network/1
|
||||
]).
|
||||
|
||||
:- use_module(library('clpbn/display'),
|
||||
@ -144,5 +144,5 @@ run_horus_lifted_solver(QueryVarsAtts, Solutions, fove(ParfactorList, DistIds))
|
||||
|
||||
|
||||
finalize_horus_lifted_solver(fove(ParfactorList, _)) :-
|
||||
cpp_free_parfactors(ParfactorList).
|
||||
cpp_free_lifted_network(ParfactorList).
|
||||
|
||||
|
@ -7,17 +7,17 @@
|
||||
|
||||
:- yap_flag(write_strings, off).
|
||||
|
||||
bayes burglary::[b1,b2] ; [0.001, 0.999] ; [].
|
||||
bayes burglary::[t,f] ; [0.001, 0.999] ; [].
|
||||
|
||||
bayes earthquake::[e1,e2] ; [0.002, 0.998]; [].
|
||||
bayes earthquake::[t,f] ; [0.002, 0.998]; [].
|
||||
|
||||
bayes alarm::[a1,a2], burglary, earthquake ;
|
||||
bayes alarm::[t,f], burglary, earthquake ;
|
||||
[0.95, 0.94, 0.29, 0.001, 0.05, 0.06, 0.71, 0.999] ;
|
||||
[].
|
||||
|
||||
bayes john_calls::[j1,j2], alarm ; [0.9, 0.05, 0.1, 0.95] ; [].
|
||||
bayes john_calls::[t,f], alarm ; [0.9, 0.05, 0.1, 0.95] ; [].
|
||||
|
||||
bayes mary_calls::[m1,m2], alarm ; [0.7, 0.01, 0.3, 0.99] ; [].
|
||||
bayes mary_calls::[t,f], alarm ; [0.7, 0.01, 0.3, 0.99] ; [].
|
||||
|
||||
% ?- john_calls(J), mary_calls(m1).
|
||||
% ?- john_calls(J), mary_calls(t).
|
||||
|
||||
|
@ -16,13 +16,13 @@ BayesBall::getMinimalFactorGraph (const VarIds& queryIds)
|
||||
Scheduling scheduling;
|
||||
for (size_t i = 0; i < queryIds.size(); i++) {
|
||||
assert (dag_.getNode (queryIds[i]));
|
||||
DAGraphNode* n = dag_.getNode (queryIds[i]);
|
||||
BBNode* n = dag_.getNode (queryIds[i]);
|
||||
scheduling.push (ScheduleInfo (n, false, true));
|
||||
}
|
||||
|
||||
while (!scheduling.empty()) {
|
||||
ScheduleInfo& sch = scheduling.front();
|
||||
DAGraphNode* n = sch.node;
|
||||
BBNode* n = sch.node;
|
||||
n->setAsVisited();
|
||||
if (n->hasEvidence() == false && sch.visitedFromChild) {
|
||||
if (n->isMarkedOnTop() == false) {
|
||||
@ -59,7 +59,7 @@ BayesBall::constructGraph (FactorGraph* fg) const
|
||||
{
|
||||
const FacNodes& facNodes = fg_.facNodes();
|
||||
for (size_t i = 0; i < facNodes.size(); i++) {
|
||||
const DAGraphNode* n = dag_.getNode (
|
||||
const BBNode* n = dag_.getNode (
|
||||
facNodes[i]->factor().argument (0));
|
||||
if (n->isMarkedOnTop()) {
|
||||
fg->addFactor (facNodes[i]->factor());
|
||||
|
@ -7,7 +7,7 @@
|
||||
#include <map>
|
||||
|
||||
#include "FactorGraph.h"
|
||||
#include "BayesNet.h"
|
||||
#include "BayesBallGraph.h"
|
||||
#include "Horus.h"
|
||||
|
||||
using namespace std;
|
||||
@ -15,12 +15,12 @@ using namespace std;
|
||||
|
||||
struct ScheduleInfo
|
||||
{
|
||||
ScheduleInfo (DAGraphNode* n, bool vfp, bool vfc) :
|
||||
ScheduleInfo (BBNode* n, bool vfp, bool vfc) :
|
||||
node(n), visitedFromParent(vfp), visitedFromChild(vfc) { }
|
||||
|
||||
DAGraphNode* node;
|
||||
bool visitedFromParent;
|
||||
bool visitedFromChild;
|
||||
BBNode* node;
|
||||
bool visitedFromParent;
|
||||
bool visitedFromChild;
|
||||
};
|
||||
|
||||
|
||||
@ -30,40 +30,40 @@ typedef queue<ScheduleInfo, list<ScheduleInfo>> Scheduling;
|
||||
class BayesBall
|
||||
{
|
||||
public:
|
||||
BayesBall (FactorGraph& fg)
|
||||
: fg_(fg) , dag_(fg.getStructure())
|
||||
{
|
||||
dag_.clear();
|
||||
}
|
||||
BayesBall (FactorGraph& fg)
|
||||
: fg_(fg) , dag_(fg.getStructure())
|
||||
{
|
||||
dag_.clear();
|
||||
}
|
||||
|
||||
FactorGraph* getMinimalFactorGraph (const VarIds&);
|
||||
FactorGraph* getMinimalFactorGraph (const VarIds&);
|
||||
|
||||
static FactorGraph* getMinimalFactorGraph (FactorGraph& fg, VarIds vids)
|
||||
{
|
||||
BayesBall bb (fg);
|
||||
return bb.getMinimalFactorGraph (vids);
|
||||
}
|
||||
static FactorGraph* getMinimalFactorGraph (FactorGraph& fg, VarIds vids)
|
||||
{
|
||||
BayesBall bb (fg);
|
||||
return bb.getMinimalFactorGraph (vids);
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
void constructGraph (FactorGraph* fg) const;
|
||||
|
||||
void scheduleParents (const DAGraphNode* n, Scheduling& sch) const;
|
||||
void scheduleParents (const BBNode* n, Scheduling& sch) const;
|
||||
|
||||
void scheduleChilds (const DAGraphNode* n, Scheduling& sch) const;
|
||||
void scheduleChilds (const BBNode* n, Scheduling& sch) const;
|
||||
|
||||
FactorGraph& fg_;
|
||||
|
||||
DAGraph& dag_;
|
||||
BayesBallGraph& dag_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
inline void
|
||||
BayesBall::scheduleParents (const DAGraphNode* n, Scheduling& sch) const
|
||||
BayesBall::scheduleParents (const BBNode* n, Scheduling& sch) const
|
||||
{
|
||||
const vector<DAGraphNode*>& ps = n->parents();
|
||||
for (vector<DAGraphNode*>::const_iterator it = ps.begin();
|
||||
const vector<BBNode*>& ps = n->parents();
|
||||
for (vector<BBNode*>::const_iterator it = ps.begin();
|
||||
it != ps.end(); ++it) {
|
||||
sch.push (ScheduleInfo (*it, false, true));
|
||||
}
|
||||
@ -72,10 +72,10 @@ BayesBall::scheduleParents (const DAGraphNode* n, Scheduling& sch) const
|
||||
|
||||
|
||||
inline void
|
||||
BayesBall::scheduleChilds (const DAGraphNode* n, Scheduling& sch) const
|
||||
BayesBall::scheduleChilds (const BBNode* n, Scheduling& sch) const
|
||||
{
|
||||
const vector<DAGraphNode*>& cs = n->childs();
|
||||
for (vector<DAGraphNode*>::const_iterator it = cs.begin();
|
||||
const vector<BBNode*>& cs = n->childs();
|
||||
for (vector<BBNode*>::const_iterator it = cs.begin();
|
||||
it != cs.end(); ++it) {
|
||||
sch.push (ScheduleInfo (*it, true, false));
|
||||
}
|
||||
|
@ -5,12 +5,12 @@
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
|
||||
#include "BayesNet.h"
|
||||
#include "BayesBallGraph.h"
|
||||
#include "Util.h"
|
||||
|
||||
|
||||
void
|
||||
DAGraph::addNode (DAGraphNode* n)
|
||||
BayesBallGraph::addNode (BBNode* n)
|
||||
{
|
||||
assert (Util::contains (varMap_, n->varId()) == false);
|
||||
nodes_.push_back (n);
|
||||
@ -20,10 +20,10 @@ DAGraph::addNode (DAGraphNode* n)
|
||||
|
||||
|
||||
void
|
||||
DAGraph::addEdge (VarId vid1, VarId vid2)
|
||||
BayesBallGraph::addEdge (VarId vid1, VarId vid2)
|
||||
{
|
||||
unordered_map<VarId, DAGraphNode*>::iterator it1;
|
||||
unordered_map<VarId, DAGraphNode*>::iterator it2;
|
||||
unordered_map<VarId, BBNode*>::iterator it1;
|
||||
unordered_map<VarId, BBNode*>::iterator it2;
|
||||
it1 = varMap_.find (vid1);
|
||||
it2 = varMap_.find (vid2);
|
||||
assert (it1 != varMap_.end());
|
||||
@ -34,20 +34,20 @@ DAGraph::addEdge (VarId vid1, VarId vid2)
|
||||
|
||||
|
||||
|
||||
const DAGraphNode*
|
||||
DAGraph::getNode (VarId vid) const
|
||||
const BBNode*
|
||||
BayesBallGraph::getNode (VarId vid) const
|
||||
{
|
||||
unordered_map<VarId, DAGraphNode*>::const_iterator it;
|
||||
unordered_map<VarId, BBNode*>::const_iterator it;
|
||||
it = varMap_.find (vid);
|
||||
return it != varMap_.end() ? it->second : 0;
|
||||
}
|
||||
|
||||
|
||||
|
||||
DAGraphNode*
|
||||
DAGraph::getNode (VarId vid)
|
||||
BBNode*
|
||||
BayesBallGraph::getNode (VarId vid)
|
||||
{
|
||||
unordered_map<VarId, DAGraphNode*>::const_iterator it;
|
||||
unordered_map<VarId, BBNode*>::const_iterator it;
|
||||
it = varMap_.find (vid);
|
||||
return it != varMap_.end() ? it->second : 0;
|
||||
}
|
||||
@ -55,7 +55,7 @@ DAGraph::getNode (VarId vid)
|
||||
|
||||
|
||||
void
|
||||
DAGraph::setIndexes (void)
|
||||
BayesBallGraph::setIndexes (void)
|
||||
{
|
||||
for (size_t i = 0; i < nodes_.size(); i++) {
|
||||
nodes_[i]->setIndex (i);
|
||||
@ -65,7 +65,7 @@ DAGraph::setIndexes (void)
|
||||
|
||||
|
||||
void
|
||||
DAGraph::clear (void)
|
||||
BayesBallGraph::clear (void)
|
||||
{
|
||||
for (size_t i = 0; i < nodes_.size(); i++) {
|
||||
nodes_[i]->clear();
|
||||
@ -75,12 +75,12 @@ DAGraph::clear (void)
|
||||
|
||||
|
||||
void
|
||||
DAGraph::exportToGraphViz (const char* fileName)
|
||||
BayesBallGraph::exportToGraphViz (const char* fileName)
|
||||
{
|
||||
ofstream out (fileName);
|
||||
if (!out.is_open()) {
|
||||
cerr << "error: cannot open file to write at " ;
|
||||
cerr << "DAGraph::exportToDotFile()" << endl;
|
||||
cerr << "BayesBallGraph::exportToDotFile()" << endl;
|
||||
abort();
|
||||
}
|
||||
out << "digraph {" << endl;
|
||||
@ -95,7 +95,7 @@ DAGraph::exportToGraphViz (const char* fileName)
|
||||
out << "]" << endl;
|
||||
}
|
||||
for (size_t i = 0; i < nodes_.size(); i++) {
|
||||
const vector<DAGraphNode*>& childs = nodes_[i]->childs();
|
||||
const vector<BBNode*>& childs = nodes_[i]->childs();
|
||||
for (size_t j = 0; j < childs.size(); j++) {
|
||||
out << nodes_[i]->varId() << " -> " << childs[j]->varId();
|
||||
out << " [style=bold]" << endl ;
|
@ -1,5 +1,5 @@
|
||||
#ifndef HORUS_BAYESNET_H
|
||||
#define HORUS_BAYESNET_H
|
||||
#ifndef HORUS_BAYESBALLGRAPH_H
|
||||
#define HORUS_BAYESBALLGRAPH_H
|
||||
|
||||
#include <vector>
|
||||
#include <queue>
|
||||
@ -9,29 +9,25 @@
|
||||
#include "Var.h"
|
||||
#include "Horus.h"
|
||||
|
||||
|
||||
using namespace std;
|
||||
|
||||
|
||||
class Var;
|
||||
|
||||
class DAGraphNode : public Var
|
||||
class BBNode : public Var
|
||||
{
|
||||
public:
|
||||
DAGraphNode (Var* v) : Var (v) , visited_(false),
|
||||
BBNode (Var* v) : Var (v) , visited_(false),
|
||||
markedOnTop_(false), markedOnBottom_(false) { }
|
||||
|
||||
const vector<DAGraphNode*>& childs (void) const { return childs_; }
|
||||
const vector<BBNode*>& childs (void) const { return childs_; }
|
||||
|
||||
vector<DAGraphNode*>& childs (void) { return childs_; }
|
||||
vector<BBNode*>& childs (void) { return childs_; }
|
||||
|
||||
const vector<DAGraphNode*>& parents (void) const { return parents_; }
|
||||
const vector<BBNode*>& parents (void) const { return parents_; }
|
||||
|
||||
vector<DAGraphNode*>& parents (void) { return parents_; }
|
||||
vector<BBNode*>& parents (void) { return parents_; }
|
||||
|
||||
void addParent (DAGraphNode* p) { parents_.push_back (p); }
|
||||
void addParent (BBNode* p) { parents_.push_back (p); }
|
||||
|
||||
void addChild (DAGraphNode* c) { childs_.push_back (c); }
|
||||
void addChild (BBNode* c) { childs_.push_back (c); }
|
||||
|
||||
bool isVisited (void) const { return visited_; }
|
||||
|
||||
@ -52,23 +48,23 @@ class DAGraphNode : public Var
|
||||
bool markedOnTop_;
|
||||
bool markedOnBottom_;
|
||||
|
||||
vector<DAGraphNode*> childs_;
|
||||
vector<DAGraphNode*> parents_;
|
||||
vector<BBNode*> childs_;
|
||||
vector<BBNode*> parents_;
|
||||
};
|
||||
|
||||
|
||||
class DAGraph
|
||||
class BayesBallGraph
|
||||
{
|
||||
public:
|
||||
DAGraph (void) { }
|
||||
BayesBallGraph (void) { }
|
||||
|
||||
void addNode (DAGraphNode* n);
|
||||
void addNode (BBNode* n);
|
||||
|
||||
void addEdge (VarId vid1, VarId vid2);
|
||||
|
||||
const DAGraphNode* getNode (VarId vid) const;
|
||||
const BBNode* getNode (VarId vid) const;
|
||||
|
||||
DAGraphNode* getNode (VarId vid);
|
||||
BBNode* getNode (VarId vid);
|
||||
|
||||
bool empty (void) const { return nodes_.empty(); }
|
||||
|
||||
@ -79,10 +75,10 @@ class DAGraph
|
||||
void exportToGraphViz (const char*);
|
||||
|
||||
private:
|
||||
vector<DAGraphNode*> nodes_;
|
||||
vector<BBNode*> nodes_;
|
||||
|
||||
unordered_map<VarId, DAGraphNode*> varMap_;
|
||||
unordered_map<VarId, BBNode*> varMap_;
|
||||
};
|
||||
|
||||
#endif // HORUS_BAYESNET_H
|
||||
#endif // HORUS_BAYESBALLGRAPH_H
|
||||
|
@ -5,21 +5,21 @@
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "BpSolver.h"
|
||||
#include "BeliefProp.h"
|
||||
#include "FactorGraph.h"
|
||||
#include "Factor.h"
|
||||
#include "Indexer.h"
|
||||
#include "Horus.h"
|
||||
|
||||
|
||||
BpSolver::BpSolver (const FactorGraph& fg) : Solver (fg)
|
||||
BeliefProp::BeliefProp (const FactorGraph& fg) : Solver (fg)
|
||||
{
|
||||
runned_ = false;
|
||||
}
|
||||
|
||||
|
||||
|
||||
BpSolver::~BpSolver (void)
|
||||
BeliefProp::~BeliefProp (void)
|
||||
{
|
||||
for (size_t i = 0; i < varsI_.size(); i++) {
|
||||
delete varsI_[i];
|
||||
@ -35,7 +35,7 @@ BpSolver::~BpSolver (void)
|
||||
|
||||
|
||||
Params
|
||||
BpSolver::solveQuery (VarIds queryVids)
|
||||
BeliefProp::solveQuery (VarIds queryVids)
|
||||
{
|
||||
assert (queryVids.empty() == false);
|
||||
return queryVids.size() == 1
|
||||
@ -46,7 +46,7 @@ BpSolver::solveQuery (VarIds queryVids)
|
||||
|
||||
|
||||
void
|
||||
BpSolver::printSolverFlags (void) const
|
||||
BeliefProp::printSolverFlags (void) const
|
||||
{
|
||||
stringstream ss;
|
||||
ss << "belief propagation [" ;
|
||||
@ -68,7 +68,7 @@ BpSolver::printSolverFlags (void) const
|
||||
|
||||
|
||||
Params
|
||||
BpSolver::getPosterioriOf (VarId vid)
|
||||
BeliefProp::getPosterioriOf (VarId vid)
|
||||
{
|
||||
if (runned_ == false) {
|
||||
runSolver();
|
||||
@ -101,7 +101,7 @@ BpSolver::getPosterioriOf (VarId vid)
|
||||
|
||||
|
||||
Params
|
||||
BpSolver::getJointDistributionOf (const VarIds& jointVarIds)
|
||||
BeliefProp::getJointDistributionOf (const VarIds& jointVarIds)
|
||||
{
|
||||
if (runned_ == false) {
|
||||
runSolver();
|
||||
@ -117,30 +117,43 @@ BpSolver::getJointDistributionOf (const VarIds& jointVarIds)
|
||||
}
|
||||
if (idx == facNodes.size()) {
|
||||
return getJointByConditioning (jointVarIds);
|
||||
} else {
|
||||
Factor res (facNodes[idx]->factor());
|
||||
const BpLinks& links = ninf(facNodes[idx])->getLinks();
|
||||
for (size_t i = 0; i < links.size(); i++) {
|
||||
Factor msg ({links[i]->varNode()->varId()},
|
||||
{links[i]->varNode()->range()},
|
||||
getVarToFactorMsg (links[i]));
|
||||
res.multiply (msg);
|
||||
}
|
||||
res.sumOutAllExcept (jointVarIds);
|
||||
res.reorderArguments (jointVarIds);
|
||||
res.normalize();
|
||||
Params jointDist = res.params();
|
||||
if (Globals::logDomain) {
|
||||
Util::exp (jointDist);
|
||||
}
|
||||
return jointDist;
|
||||
}
|
||||
return getFactorJoint (idx, jointVarIds);
|
||||
}
|
||||
|
||||
|
||||
|
||||
Params
|
||||
BeliefProp::getFactorJoint (
|
||||
size_t fnIdx,
|
||||
const VarIds& jointVarIds)
|
||||
{
|
||||
if (runned_ == false) {
|
||||
runSolver();
|
||||
}
|
||||
FacNode* fn = fg.facNodes()[fnIdx];
|
||||
Factor res (fn->factor());
|
||||
const BpLinks& links = ninf(fn)->getLinks();
|
||||
for (size_t i = 0; i < links.size(); i++) {
|
||||
Factor msg ({links[i]->varNode()->varId()},
|
||||
{links[i]->varNode()->range()},
|
||||
getVarToFactorMsg (links[i]));
|
||||
res.multiply (msg);
|
||||
}
|
||||
res.sumOutAllExcept (jointVarIds);
|
||||
res.reorderArguments (jointVarIds);
|
||||
res.normalize();
|
||||
Params jointDist = res.params();
|
||||
if (Globals::logDomain) {
|
||||
Util::exp (jointDist);
|
||||
}
|
||||
return jointDist;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BpSolver::runSolver (void)
|
||||
BeliefProp::runSolver (void)
|
||||
{
|
||||
initializeSolver();
|
||||
nIters_ = 0;
|
||||
@ -173,7 +186,7 @@ BpSolver::runSolver (void)
|
||||
}
|
||||
if (Globals::verbosity > 0) {
|
||||
if (nIters_ < BpOptions::maxIter) {
|
||||
cout << "Sum-Product converged in " ;
|
||||
cout << "Belief propagation converged in " ;
|
||||
cout << nIters_ << " iterations" << endl;
|
||||
} else {
|
||||
cout << "The maximum number of iterations was hit, terminating..." ;
|
||||
@ -187,7 +200,7 @@ BpSolver::runSolver (void)
|
||||
|
||||
|
||||
void
|
||||
BpSolver::createLinks (void)
|
||||
BeliefProp::createLinks (void)
|
||||
{
|
||||
const FacNodes& facNodes = fg.facNodes();
|
||||
for (size_t i = 0; i < facNodes.size(); i++) {
|
||||
@ -201,7 +214,7 @@ BpSolver::createLinks (void)
|
||||
|
||||
|
||||
void
|
||||
BpSolver::maxResidualSchedule (void)
|
||||
BeliefProp::maxResidualSchedule (void)
|
||||
{
|
||||
if (nIters_ == 1) {
|
||||
for (size_t i = 0; i < links_.size(); i++) {
|
||||
@ -256,7 +269,7 @@ BpSolver::maxResidualSchedule (void)
|
||||
|
||||
|
||||
void
|
||||
BpSolver::calcFactorToVarMsg (BpLink* link)
|
||||
BeliefProp::calcFactorToVarMsg (BpLink* link)
|
||||
{
|
||||
FacNode* src = link->facNode();
|
||||
const VarNode* dst = link->varNode();
|
||||
@ -320,7 +333,7 @@ BpSolver::calcFactorToVarMsg (BpLink* link)
|
||||
|
||||
|
||||
Params
|
||||
BpSolver::getVarToFactorMsg (const BpLink* link) const
|
||||
BeliefProp::getVarToFactorMsg (const BpLink* link) const
|
||||
{
|
||||
const VarNode* src = link->varNode();
|
||||
Params msg;
|
||||
@ -361,61 +374,15 @@ BpSolver::getVarToFactorMsg (const BpLink* link) const
|
||||
|
||||
|
||||
Params
|
||||
BpSolver::getJointByConditioning (const VarIds& jointVarIds) const
|
||||
BeliefProp::getJointByConditioning (const VarIds& jointVarIds) const
|
||||
{
|
||||
VarNodes jointVars;
|
||||
for (size_t i = 0; i < jointVarIds.size(); i++) {
|
||||
assert (fg.getVarNode (jointVarIds[i]));
|
||||
jointVars.push_back (fg.getVarNode (jointVarIds[i]));
|
||||
}
|
||||
|
||||
FactorGraph* tempFg = new FactorGraph (fg);
|
||||
BpSolver solver (*tempFg);
|
||||
solver.runSolver();
|
||||
Params prevBeliefs = solver.getPosterioriOf (jointVarIds[0]);
|
||||
|
||||
VarIds observedVids = {jointVars[0]->varId()};
|
||||
|
||||
for (size_t i = 1; i < jointVarIds.size(); i++) {
|
||||
assert (jointVars[i]->hasEvidence() == false);
|
||||
Params newBeliefs;
|
||||
Vars observedVars;
|
||||
Ranges observedRanges;
|
||||
for (size_t j = 0; j < observedVids.size(); j++) {
|
||||
observedVars.push_back (tempFg->getVarNode (observedVids[j]));
|
||||
observedRanges.push_back (observedVars.back()->range());
|
||||
}
|
||||
Indexer indexer (observedRanges, false);
|
||||
while (indexer.valid()) {
|
||||
for (size_t j = 0; j < observedVars.size(); j++) {
|
||||
observedVars[j]->setEvidence (indexer[j]);
|
||||
}
|
||||
BpSolver solver (*tempFg);
|
||||
solver.runSolver();
|
||||
Params beliefs = solver.getPosterioriOf (jointVarIds[i]);
|
||||
for (size_t k = 0; k < beliefs.size(); k++) {
|
||||
newBeliefs.push_back (beliefs[k]);
|
||||
}
|
||||
++ indexer;
|
||||
}
|
||||
|
||||
int count = -1;
|
||||
for (size_t j = 0; j < newBeliefs.size(); j++) {
|
||||
if (j % jointVars[i]->range() == 0) {
|
||||
count ++;
|
||||
}
|
||||
newBeliefs[j] *= prevBeliefs[count];
|
||||
}
|
||||
prevBeliefs = newBeliefs;
|
||||
observedVids.push_back (jointVars[i]->varId());
|
||||
}
|
||||
return prevBeliefs;
|
||||
return Solver::getJointByConditioning (GroundSolver::BP, fg, jointVarIds);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BpSolver::initializeSolver (void)
|
||||
BeliefProp::initializeSolver (void)
|
||||
{
|
||||
const VarNodes& varNodes = fg.varNodes();
|
||||
varsI_.reserve (varNodes.size());
|
||||
@ -439,7 +406,7 @@ BpSolver::initializeSolver (void)
|
||||
|
||||
|
||||
bool
|
||||
BpSolver::converged (void)
|
||||
BeliefProp::converged (void)
|
||||
{
|
||||
if (links_.size() == 0) {
|
||||
return true;
|
||||
@ -487,7 +454,7 @@ BpSolver::converged (void)
|
||||
|
||||
|
||||
void
|
||||
BpSolver::printLinkInformation (void) const
|
||||
BeliefProp::printLinkInformation (void) const
|
||||
{
|
||||
for (size_t i = 0; i < links_.size(); i++) {
|
||||
BpLink* l = links_[i];
|
@ -1,5 +1,5 @@
|
||||
#ifndef HORUS_BPSOLVER_H
|
||||
#define HORUS_BPSOLVER_H
|
||||
#ifndef HORUS_BELIEFPROP_H
|
||||
#define HORUS_BELIEFPROP_H
|
||||
|
||||
#include <set>
|
||||
#include <vector>
|
||||
@ -83,12 +83,12 @@ class SPNodeInfo
|
||||
};
|
||||
|
||||
|
||||
class BpSolver : public Solver
|
||||
class BeliefProp : public Solver
|
||||
{
|
||||
public:
|
||||
BpSolver (const FactorGraph&);
|
||||
BeliefProp (const FactorGraph&);
|
||||
|
||||
virtual ~BpSolver (void);
|
||||
virtual ~BeliefProp (void);
|
||||
|
||||
Params solveQuery (VarIds);
|
||||
|
||||
@ -111,6 +111,10 @@ class BpSolver : public Solver
|
||||
|
||||
virtual Params getJointByConditioning (const VarIds&) const;
|
||||
|
||||
public:
|
||||
Params getFactorJoint (size_t fnIdx, const VarIds&);
|
||||
|
||||
protected:
|
||||
SPNodeInfo* ninf (const VarNode* var) const
|
||||
{
|
||||
return varsI_[var->getIndex()];
|
||||
@ -180,5 +184,5 @@ class BpSolver : public Solver
|
||||
virtual void printLinkInformation (void) const;
|
||||
};
|
||||
|
||||
#endif // HORUS_BPSOLVER_H
|
||||
#endif // HORUS_BELIEFPROP_H
|
||||
|
@ -1,23 +1,23 @@
|
||||
#include "CbpSolver.h"
|
||||
#include "WeightedBpSolver.h"
|
||||
#include "CountingBp.h"
|
||||
#include "WeightedBp.h"
|
||||
|
||||
|
||||
bool CbpSolver::checkForIdenticalFactors = true;
|
||||
bool CountingBp::checkForIdenticalFactors = true;
|
||||
|
||||
|
||||
CbpSolver::CbpSolver (const FactorGraph& fg)
|
||||
CountingBp::CountingBp (const FactorGraph& fg)
|
||||
: Solver (fg), freeColor_(0)
|
||||
{
|
||||
findIdenticalFactors();
|
||||
setInitialColors();
|
||||
createGroups();
|
||||
compressedFg_ = getCompressedFactorGraph();
|
||||
solver_ = new WeightedBpSolver (*compressedFg_, getWeights());
|
||||
solver_ = new WeightedBp (*compressedFg_, getWeights());
|
||||
}
|
||||
|
||||
|
||||
|
||||
CbpSolver::~CbpSolver (void)
|
||||
CountingBp::~CountingBp (void)
|
||||
{
|
||||
delete solver_;
|
||||
delete compressedFg_;
|
||||
@ -32,7 +32,7 @@ CbpSolver::~CbpSolver (void)
|
||||
|
||||
|
||||
void
|
||||
CbpSolver::printSolverFlags (void) const
|
||||
CountingBp::printSolverFlags (void) const
|
||||
{
|
||||
stringstream ss;
|
||||
ss << "counting bp [" ;
|
||||
@ -48,7 +48,7 @@ CbpSolver::printSolverFlags (void) const
|
||||
ss << ",accuracy=" << BpOptions::accuracy;
|
||||
ss << ",log_domain=" << Util::toString (Globals::logDomain);
|
||||
ss << ",chkif=" <<
|
||||
Util::toString (CbpSolver::checkForIdenticalFactors);
|
||||
Util::toString (CountingBp::checkForIdenticalFactors);
|
||||
ss << "]" ;
|
||||
cout << ss.str() << endl;
|
||||
}
|
||||
@ -56,7 +56,7 @@ CbpSolver::printSolverFlags (void) const
|
||||
|
||||
|
||||
Params
|
||||
CbpSolver::solveQuery (VarIds queryVids)
|
||||
CountingBp::solveQuery (VarIds queryVids)
|
||||
{
|
||||
assert (queryVids.empty() == false);
|
||||
Params res;
|
||||
@ -74,16 +74,15 @@ CbpSolver::solveQuery (VarIds queryVids)
|
||||
cout << endl;
|
||||
}
|
||||
if (idx == facNodes.size()) {
|
||||
cerr << "error: only joint distributions on variables of some " ;
|
||||
cerr << "clique are supported with the current solver" ;
|
||||
cerr << endl;
|
||||
exit (1);
|
||||
res = Solver::getJointByConditioning (
|
||||
GroundSolver::CBP, fg, queryVids);
|
||||
} else {
|
||||
VarIds reprArgs;
|
||||
for (size_t i = 0; i < queryVids.size(); i++) {
|
||||
reprArgs.push_back (getRepresentative (queryVids[i]));
|
||||
}
|
||||
res = solver_->getFactorJoint (idx, reprArgs);
|
||||
}
|
||||
VarIds representatives;
|
||||
for (size_t i = 0; i < queryVids.size(); i++) {
|
||||
representatives.push_back (getRepresentative (queryVids[i]));
|
||||
}
|
||||
res = solver_->getJointDistributionOf (representatives);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
@ -91,7 +90,7 @@ CbpSolver::solveQuery (VarIds queryVids)
|
||||
|
||||
|
||||
void
|
||||
CbpSolver::findIdenticalFactors()
|
||||
CountingBp::findIdenticalFactors()
|
||||
{
|
||||
const FacNodes& facNodes = fg.facNodes();
|
||||
if (checkForIdenticalFactors == false ||
|
||||
@ -126,7 +125,7 @@ CbpSolver::findIdenticalFactors()
|
||||
|
||||
|
||||
void
|
||||
CbpSolver::setInitialColors (void)
|
||||
CountingBp::setInitialColors (void)
|
||||
{
|
||||
varColors_.resize (fg.nrVarNodes());
|
||||
facColors_.resize (fg.nrFacNodes());
|
||||
@ -165,7 +164,7 @@ CbpSolver::setInitialColors (void)
|
||||
|
||||
|
||||
void
|
||||
CbpSolver::createGroups (void)
|
||||
CountingBp::createGroups (void)
|
||||
{
|
||||
VarSignMap varGroups;
|
||||
FacSignMap facGroups;
|
||||
@ -227,7 +226,7 @@ CbpSolver::createGroups (void)
|
||||
|
||||
|
||||
void
|
||||
CbpSolver::createClusters (
|
||||
CountingBp::createClusters (
|
||||
const VarSignMap& varGroups,
|
||||
const FacSignMap& facGroups)
|
||||
{
|
||||
@ -260,7 +259,7 @@ CbpSolver::createClusters (
|
||||
|
||||
|
||||
VarSignature
|
||||
CbpSolver::getSignature (const VarNode* varNode)
|
||||
CountingBp::getSignature (const VarNode* varNode)
|
||||
{
|
||||
const FacNodes& neighs = varNode->neighbors();
|
||||
VarSignature sign;
|
||||
@ -278,7 +277,7 @@ CbpSolver::getSignature (const VarNode* varNode)
|
||||
|
||||
|
||||
FacSignature
|
||||
CbpSolver::getSignature (const FacNode* facNode)
|
||||
CountingBp::getSignature (const FacNode* facNode)
|
||||
{
|
||||
const VarNodes& neighs = facNode->neighbors();
|
||||
FacSignature sign;
|
||||
@ -292,8 +291,31 @@ CbpSolver::getSignature (const FacNode* facNode)
|
||||
|
||||
|
||||
|
||||
VarId
|
||||
CountingBp::getRepresentative (VarId vid)
|
||||
{
|
||||
assert (Util::contains (vid2VarCluster_, vid));
|
||||
VarCluster* vc = vid2VarCluster_.find (vid)->second;
|
||||
return vc->representative()->varId();
|
||||
}
|
||||
|
||||
|
||||
|
||||
FacNode*
|
||||
CountingBp::getRepresentative (FacNode* fn)
|
||||
{
|
||||
for (size_t i = 0; i < facClusters_.size(); i++) {
|
||||
if (Util::contains (facClusters_[i]->members(), fn)) {
|
||||
return facClusters_[i]->representative();
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
|
||||
FactorGraph*
|
||||
CbpSolver::getCompressedFactorGraph (void)
|
||||
CountingBp::getCompressedFactorGraph (void)
|
||||
{
|
||||
FactorGraph* fg = new FactorGraph();
|
||||
for (size_t i = 0; i < varClusters_.size(); i++) {
|
||||
@ -322,7 +344,7 @@ CbpSolver::getCompressedFactorGraph (void)
|
||||
|
||||
|
||||
vector<vector<unsigned>>
|
||||
CbpSolver::getWeights (void) const
|
||||
CountingBp::getWeights (void) const
|
||||
{
|
||||
vector<vector<unsigned>> weights;
|
||||
weights.reserve (facClusters_.size());
|
||||
@ -341,7 +363,7 @@ CbpSolver::getWeights (void) const
|
||||
|
||||
|
||||
unsigned
|
||||
CbpSolver::getWeight (
|
||||
CountingBp::getWeight (
|
||||
const FacCluster* fc,
|
||||
const VarCluster* vc,
|
||||
size_t index) const
|
||||
@ -364,7 +386,7 @@ CbpSolver::getWeight (
|
||||
|
||||
|
||||
void
|
||||
CbpSolver::printGroups (
|
||||
CountingBp::printGroups (
|
||||
const VarSignMap& varGroups,
|
||||
const FacSignMap& facGroups) const
|
||||
{
|
@ -1,5 +1,5 @@
|
||||
#ifndef HORUS_CBPSOLVER_H
|
||||
#define HORUS_CBPSOLVER_H
|
||||
#ifndef HORUS_COUNTINGBP_H
|
||||
#define HORUS_COUNTINGBP_H
|
||||
|
||||
#include <unordered_map>
|
||||
|
||||
@ -12,7 +12,7 @@ class VarCluster;
|
||||
class FacCluster;
|
||||
class VarSignHash;
|
||||
class FacSignHash;
|
||||
class WeightedBpSolver;
|
||||
class WeightedBp;
|
||||
|
||||
typedef long Color;
|
||||
typedef vector<Color> Colors;
|
||||
@ -100,12 +100,12 @@ class FacCluster
|
||||
};
|
||||
|
||||
|
||||
class CbpSolver : public Solver
|
||||
class CountingBp : public Solver
|
||||
{
|
||||
public:
|
||||
CbpSolver (const FactorGraph& fg);
|
||||
CountingBp (const FactorGraph& fg);
|
||||
|
||||
~CbpSolver (void);
|
||||
~CountingBp (void);
|
||||
|
||||
void printSolverFlags (void) const;
|
||||
|
||||
@ -154,12 +154,9 @@ class CbpSolver : public Solver
|
||||
|
||||
void printGroups (const VarSignMap&, const FacSignMap&) const;
|
||||
|
||||
VarId getRepresentative (VarId vid)
|
||||
{
|
||||
assert (Util::contains (vid2VarCluster_, vid));
|
||||
VarCluster* vc = vid2VarCluster_.find (vid)->second;
|
||||
return vc->representative()->varId();
|
||||
}
|
||||
VarId getRepresentative (VarId vid);
|
||||
|
||||
FacNode* getRepresentative (FacNode*);
|
||||
|
||||
FactorGraph* getCompressedFactorGraph (void);
|
||||
|
||||
@ -176,8 +173,8 @@ class CbpSolver : public Solver
|
||||
FacClusters facClusters_;
|
||||
VarId2VarCluster vid2VarCluster_;
|
||||
const FactorGraph* compressedFg_;
|
||||
WeightedBpSolver* solver_;
|
||||
WeightedBp* solver_;
|
||||
};
|
||||
|
||||
#endif // HORUS_CBPSOLVER_H
|
||||
#endif // HORUS_COUNTINGBP_H
|
||||
|
@ -8,7 +8,6 @@
|
||||
|
||||
#include "FactorGraph.h"
|
||||
#include "Factor.h"
|
||||
#include "BayesNet.h"
|
||||
#include "BayesBall.h"
|
||||
#include "Util.h"
|
||||
|
||||
@ -236,13 +235,13 @@ FactorGraph::isTree (void) const
|
||||
|
||||
|
||||
|
||||
DAGraph&
|
||||
BayesBallGraph&
|
||||
FactorGraph::getStructure (void)
|
||||
{
|
||||
assert (bayesFactors_);
|
||||
if (structure_.empty()) {
|
||||
for (size_t i = 0; i < varNodes_.size(); i++) {
|
||||
structure_.addNode (new DAGraphNode (varNodes_[i]));
|
||||
structure_.addNode (new BBNode (varNodes_[i]));
|
||||
}
|
||||
for (size_t i = 0; i < facNodes_.size(); i++) {
|
||||
const VarIds& vids = facNodes_[i]->factor().arguments();
|
||||
|
@ -4,7 +4,7 @@
|
||||
#include <vector>
|
||||
|
||||
#include "Factor.h"
|
||||
#include "BayesNet.h"
|
||||
#include "BayesBallGraph.h"
|
||||
#include "Horus.h"
|
||||
|
||||
using namespace std;
|
||||
@ -103,7 +103,7 @@ class FactorGraph
|
||||
|
||||
bool isTree (void) const;
|
||||
|
||||
DAGraph& getStructure (void);
|
||||
BayesBallGraph& getStructure (void);
|
||||
|
||||
void print (void) const;
|
||||
|
||||
@ -129,7 +129,7 @@ class FactorGraph
|
||||
VarNodes varNodes_;
|
||||
FacNodes facNodes_;
|
||||
|
||||
DAGraph structure_;
|
||||
BayesBallGraph structure_;
|
||||
bool bayesFactors_;
|
||||
|
||||
typedef unordered_map<unsigned, VarNode*> VarMap;
|
||||
|
@ -28,14 +28,14 @@ typedef vector<unsigned> Ranges;
|
||||
typedef unsigned long long ullong;
|
||||
|
||||
|
||||
enum LiftedSolvers
|
||||
enum LiftedSolver
|
||||
{
|
||||
FOVE, // first order variable elimination
|
||||
LBP, // lifted belief propagation
|
||||
};
|
||||
|
||||
|
||||
enum GroundSolvers
|
||||
enum GroundSolver
|
||||
{
|
||||
VE, // variable elimination
|
||||
BP, // belief propagation
|
||||
@ -50,8 +50,8 @@ extern bool logDomain;
|
||||
// level of debug information
|
||||
extern unsigned verbosity;
|
||||
|
||||
extern LiftedSolvers liftedSolver;
|
||||
extern GroundSolvers groundSolver;
|
||||
extern LiftedSolver liftedSolver;
|
||||
extern GroundSolver groundSolver;
|
||||
|
||||
};
|
||||
|
||||
|
@ -4,9 +4,9 @@
|
||||
#include <sstream>
|
||||
|
||||
#include "FactorGraph.h"
|
||||
#include "VarElimSolver.h"
|
||||
#include "BpSolver.h"
|
||||
#include "CbpSolver.h"
|
||||
#include "VarElim.h"
|
||||
#include "BeliefProp.h"
|
||||
#include "CountingBp.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
@ -162,14 +162,14 @@ runSolver (const FactorGraph& fg, const VarIds& queryIds)
|
||||
{
|
||||
Solver* solver = 0;
|
||||
switch (Globals::groundSolver) {
|
||||
case GroundSolvers::VE:
|
||||
solver = new VarElimSolver (fg);
|
||||
case GroundSolver::VE:
|
||||
solver = new VarElim (fg);
|
||||
break;
|
||||
case GroundSolvers::BP:
|
||||
solver = new BpSolver (fg);
|
||||
case GroundSolver::BP:
|
||||
solver = new BeliefProp (fg);
|
||||
break;
|
||||
case GroundSolvers::CBP:
|
||||
solver = new CbpSolver (fg);
|
||||
case GroundSolver::CBP:
|
||||
solver = new CountingBp (fg);
|
||||
break;
|
||||
default:
|
||||
assert (false);
|
||||
|
@ -9,21 +9,19 @@
|
||||
|
||||
#include "ParfactorList.h"
|
||||
#include "FactorGraph.h"
|
||||
#include "FoveSolver.h"
|
||||
#include "VarElimSolver.h"
|
||||
#include "LiftedBpSolver.h"
|
||||
#include "BpSolver.h"
|
||||
#include "CbpSolver.h"
|
||||
#include "LiftedVe.h"
|
||||
#include "VarElim.h"
|
||||
#include "LiftedBp.h"
|
||||
#include "CountingBp.h"
|
||||
#include "BeliefProp.h"
|
||||
#include "ElimGraph.h"
|
||||
#include "BayesBall.h"
|
||||
|
||||
|
||||
using namespace std;
|
||||
|
||||
|
||||
typedef std::pair<ParfactorList*, ObservedFormulas*> LiftedNetwork;
|
||||
|
||||
|
||||
Params readParameters (YAP_Term);
|
||||
|
||||
vector<unsigned> readUnsignedList (YAP_Term);
|
||||
@ -32,14 +30,6 @@ void readLiftedEvidence (YAP_Term, ObservedFormulas&);
|
||||
|
||||
Parfactor* readParfactor (YAP_Term);
|
||||
|
||||
void runVeSolver (FactorGraph* fg, const vector<VarIds>& tasks,
|
||||
vector<Params>& results);
|
||||
|
||||
void runBpSolver (FactorGraph* fg, const vector<VarIds>& tasks,
|
||||
vector<Params>& results);
|
||||
|
||||
|
||||
|
||||
|
||||
vector<unsigned>
|
||||
readUnsignedList (YAP_Term list)
|
||||
@ -54,7 +44,8 @@ readUnsignedList (YAP_Term list)
|
||||
|
||||
|
||||
|
||||
int createLiftedNetwork (void)
|
||||
int
|
||||
createLiftedNetwork (void)
|
||||
{
|
||||
Parfactors parfactors;
|
||||
YAP_Term parfactorList = YAP_ARG1;
|
||||
@ -91,7 +82,8 @@ int createLiftedNetwork (void)
|
||||
|
||||
|
||||
|
||||
Parfactor* readParfactor (YAP_Term pfTerm)
|
||||
Parfactor*
|
||||
readParfactor (YAP_Term pfTerm)
|
||||
{
|
||||
// read dist id
|
||||
unsigned distId = YAP_IntOfTerm (YAP_ArgOfTerm (1, pfTerm));
|
||||
@ -171,7 +163,8 @@ Parfactor* readParfactor (YAP_Term pfTerm)
|
||||
|
||||
|
||||
|
||||
void readLiftedEvidence (
|
||||
void
|
||||
readLiftedEvidence (
|
||||
YAP_Term observedList,
|
||||
ObservedFormulas& obsFormulas)
|
||||
{
|
||||
@ -237,7 +230,6 @@ createGroundNetwork (void)
|
||||
fg->addFactor (Factor (varIds, ranges, params, distId));
|
||||
factorList = YAP_TailOfTerm (factorList);
|
||||
}
|
||||
|
||||
unsigned nrObservedVars = 0;
|
||||
YAP_Term evidenceList = YAP_ARG3;
|
||||
while (evidenceList != YAP_TermNil()) {
|
||||
@ -285,7 +277,7 @@ runLiftedSolver (void)
|
||||
YAP_Term taskList = YAP_ARG2;
|
||||
vector<Params> results;
|
||||
ParfactorList pfListCopy (*network->first);
|
||||
FoveSolver::absorveEvidence (pfListCopy, *network->second);
|
||||
LiftedVe::absorveEvidence (pfListCopy, *network->second);
|
||||
while (taskList != YAP_TermNil()) {
|
||||
Grounds queryVars;
|
||||
YAP_Term jointList = YAP_HeadOfTerm (taskList);
|
||||
@ -311,15 +303,15 @@ runLiftedSolver (void)
|
||||
}
|
||||
jointList = YAP_TailOfTerm (jointList);
|
||||
}
|
||||
if (Globals::liftedSolver == LiftedSolvers::FOVE) {
|
||||
FoveSolver solver (pfListCopy);
|
||||
if (Globals::liftedSolver == LiftedSolver::FOVE) {
|
||||
LiftedVe solver (pfListCopy);
|
||||
if (Globals::verbosity > 0 && taskList == YAP_ARG2) {
|
||||
solver.printSolverFlags();
|
||||
cout << endl;
|
||||
}
|
||||
results.push_back (solver.solveQuery (queryVars));
|
||||
} else if (Globals::liftedSolver == LiftedSolvers::LBP) {
|
||||
LiftedBpSolver solver (pfListCopy);
|
||||
} else if (Globals::liftedSolver == LiftedSolver::LBP) {
|
||||
LiftedBp solver (pfListCopy);
|
||||
if (Globals::verbosity > 0 && taskList == YAP_ARG2) {
|
||||
solver.printSolverFlags();
|
||||
cout << endl;
|
||||
@ -361,11 +353,42 @@ runGroundSolver (void)
|
||||
taskList = YAP_TailOfTerm (taskList);
|
||||
}
|
||||
|
||||
vector<Params> results;
|
||||
if (Globals::groundSolver == GroundSolvers::VE) {
|
||||
runVeSolver (fg, tasks, results);
|
||||
std::set<VarId> vids;
|
||||
for (size_t i = 0; i < tasks.size(); i++) {
|
||||
Util::addToSet (vids, tasks[i]);
|
||||
}
|
||||
Solver* solver = 0;
|
||||
FactorGraph* mfg = fg;
|
||||
if (fg->bayesianFactors()) {
|
||||
mfg = BayesBall::getMinimalFactorGraph (
|
||||
*fg, VarIds (vids.begin(), vids.end()));
|
||||
}
|
||||
|
||||
if (Globals::groundSolver == GroundSolver::VE) {
|
||||
solver = new VarElim (*mfg);
|
||||
} else if (Globals::groundSolver == GroundSolver::BP) {
|
||||
solver = new BeliefProp (*mfg);
|
||||
} else if (Globals::groundSolver == GroundSolver::CBP) {
|
||||
CountingBp::checkForIdenticalFactors = false;
|
||||
solver = new CountingBp (*mfg);
|
||||
} else {
|
||||
runBpSolver (fg, tasks, results);
|
||||
assert (false);
|
||||
}
|
||||
|
||||
if (Globals::verbosity > 0) {
|
||||
solver->printSolverFlags();
|
||||
cout << endl;
|
||||
}
|
||||
|
||||
vector<Params> results;
|
||||
results.reserve (tasks.size());
|
||||
for (size_t i = 0; i < tasks.size(); i++) {
|
||||
results.push_back (solver->solveQuery (tasks[i]));
|
||||
}
|
||||
|
||||
delete solver;
|
||||
if (fg->bayesianFactors()) {
|
||||
delete mfg;
|
||||
}
|
||||
|
||||
YAP_Term list = YAP_TermNil();
|
||||
@ -386,72 +409,6 @@ runGroundSolver (void)
|
||||
|
||||
|
||||
|
||||
void runVeSolver (
|
||||
FactorGraph* fg,
|
||||
const vector<VarIds>& tasks,
|
||||
vector<Params>& results)
|
||||
{
|
||||
results.reserve (tasks.size());
|
||||
for (size_t i = 0; i < tasks.size(); i++) {
|
||||
FactorGraph* mfg = fg;
|
||||
if (fg->bayesianFactors()) {
|
||||
// mfg = BayesBall::getMinimalFactorGraph (*fg, tasks[i]);
|
||||
}
|
||||
// VarElimSolver solver (*mfg);
|
||||
VarElimSolver solver (*fg); //FIXME
|
||||
if (Globals::verbosity > 0 && i == 0) {
|
||||
solver.printSolverFlags();
|
||||
cout << endl;
|
||||
}
|
||||
results.push_back (solver.solveQuery (tasks[i]));
|
||||
if (fg->bayesianFactors()) {
|
||||
// delete mfg;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void runBpSolver (
|
||||
FactorGraph* fg,
|
||||
const vector<VarIds>& tasks,
|
||||
vector<Params>& results)
|
||||
{
|
||||
std::set<VarId> vids;
|
||||
for (size_t i = 0; i < tasks.size(); i++) {
|
||||
Util::addToSet (vids, tasks[i]);
|
||||
}
|
||||
Solver* solver = 0;
|
||||
FactorGraph* mfg = fg;
|
||||
if (fg->bayesianFactors()) {
|
||||
//mfg = BayesBall::getMinimalFactorGraph (
|
||||
// *fg, VarIds (vids.begin(),vids.end()));
|
||||
}
|
||||
if (Globals::groundSolver == GroundSolvers::BP) {
|
||||
solver = new BpSolver (*fg); // FIXME
|
||||
} else if (Globals::groundSolver == GroundSolvers::CBP) {
|
||||
CbpSolver::checkForIdenticalFactors = false;
|
||||
solver = new CbpSolver (*fg); // FIXME
|
||||
} else {
|
||||
cerr << "error: unknow solver" << endl;
|
||||
abort();
|
||||
}
|
||||
if (Globals::verbosity > 0) {
|
||||
solver->printSolverFlags();
|
||||
cout << endl;
|
||||
}
|
||||
results.reserve (tasks.size());
|
||||
for (size_t i = 0; i < tasks.size(); i++) {
|
||||
results.push_back (solver->solveQuery (tasks[i]));
|
||||
}
|
||||
if (fg->bayesianFactors()) {
|
||||
//delete mfg;
|
||||
}
|
||||
delete solver;
|
||||
}
|
||||
|
||||
|
||||
|
||||
int
|
||||
setParfactorsParams (void)
|
||||
{
|
||||
@ -567,7 +524,7 @@ freeGroundNetwork (void)
|
||||
|
||||
|
||||
int
|
||||
freeParfactors (void)
|
||||
freeLiftedNetwork (void)
|
||||
{
|
||||
LiftedNetwork* network = (LiftedNetwork*) YAP_IntOfTerm (YAP_ARG1);
|
||||
delete network->first;
|
||||
@ -589,7 +546,7 @@ init_predicates (void)
|
||||
YAP_UserCPredicate ("cpp_cpp_set_factors_params", setFactorsParams, 2);
|
||||
YAP_UserCPredicate ("cpp_set_vars_information", setVarsInformation, 2);
|
||||
YAP_UserCPredicate ("cpp_set_horus_flag", setHorusFlag, 2);
|
||||
YAP_UserCPredicate ("cpp_free_parfactors", freeParfactors, 1);
|
||||
YAP_UserCPredicate ("cpp_free_lifted_network", freeLiftedNetwork, 1);
|
||||
YAP_UserCPredicate ("cpp_free_ground_network", freeGroundNetwork, 1);
|
||||
}
|
||||
|
||||
|
232
packages/CLPBN/horus/LiftedBp.cpp
Normal file
232
packages/CLPBN/horus/LiftedBp.cpp
Normal file
@ -0,0 +1,232 @@
|
||||
#include "LiftedBp.h"
|
||||
#include "WeightedBp.h"
|
||||
#include "FactorGraph.h"
|
||||
#include "LiftedVe.h"
|
||||
|
||||
|
||||
LiftedBp::LiftedBp (const ParfactorList& pfList)
|
||||
: pfList_(pfList)
|
||||
{
|
||||
refineParfactors();
|
||||
solver_ = new WeightedBp (*getFactorGraph(), getWeights());
|
||||
}
|
||||
|
||||
|
||||
|
||||
LiftedBp::~LiftedBp (void)
|
||||
{
|
||||
delete solver_;
|
||||
}
|
||||
|
||||
|
||||
|
||||
Params
|
||||
LiftedBp::solveQuery (const Grounds& query)
|
||||
{
|
||||
assert (query.empty() == false);
|
||||
Params res;
|
||||
vector<PrvGroup> groups = getQueryGroups (query);
|
||||
if (query.size() == 1) {
|
||||
res = solver_->getPosterioriOf (groups[0]);
|
||||
} else {
|
||||
ParfactorList::iterator it = pfList_.begin();
|
||||
size_t idx = pfList_.size();
|
||||
size_t count = 0;
|
||||
while (it != pfList_.end()) {
|
||||
if ((*it)->containsGrounds (query)) {
|
||||
idx = count;
|
||||
break;
|
||||
}
|
||||
++ it;
|
||||
++ count;
|
||||
}
|
||||
if (idx == pfList_.size()) {
|
||||
res = getJointByConditioning (pfList_, query);
|
||||
} else {
|
||||
VarIds queryVids;
|
||||
for (unsigned i = 0; i < groups.size(); i++) {
|
||||
queryVids.push_back (groups[i]);
|
||||
}
|
||||
res = solver_->getFactorJoint (idx, queryVids);
|
||||
}
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
LiftedBp::printSolverFlags (void) const
|
||||
{
|
||||
stringstream ss;
|
||||
ss << "lifted bp [" ;
|
||||
ss << "schedule=" ;
|
||||
typedef BpOptions::Schedule Sch;
|
||||
switch (BpOptions::schedule) {
|
||||
case Sch::SEQ_FIXED: ss << "seq_fixed"; break;
|
||||
case Sch::SEQ_RANDOM: ss << "seq_random"; break;
|
||||
case Sch::PARALLEL: ss << "parallel"; break;
|
||||
case Sch::MAX_RESIDUAL: ss << "max_residual"; break;
|
||||
}
|
||||
ss << ",max_iter=" << BpOptions::maxIter;
|
||||
ss << ",accuracy=" << BpOptions::accuracy;
|
||||
ss << ",log_domain=" << Util::toString (Globals::logDomain);
|
||||
ss << "]" ;
|
||||
cout << ss.str() << endl;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
LiftedBp::refineParfactors (void)
|
||||
{
|
||||
while (iterate() == false);
|
||||
|
||||
if (Globals::verbosity > 2) {
|
||||
Util::printHeader ("AFTER REFINEMENT");
|
||||
pfList_.print();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
LiftedBp::iterate (void)
|
||||
{
|
||||
ParfactorList::iterator it = pfList_.begin();
|
||||
while (it != pfList_.end()) {
|
||||
const ProbFormulas& args = (*it)->arguments();
|
||||
for (size_t i = 0; i < args.size(); i++) {
|
||||
LogVarSet lvs = (*it)->logVarSet() - args[i].logVars();
|
||||
if ((*it)->constr()->isCountNormalized (lvs) == false) {
|
||||
Parfactors pfs = LiftedVe::countNormalize (*it, lvs);
|
||||
it = pfList_.removeAndDelete (it);
|
||||
pfList_.add (pfs);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
++ it;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
|
||||
vector<PrvGroup>
|
||||
LiftedBp::getQueryGroups (const Grounds& query)
|
||||
{
|
||||
vector<PrvGroup> queryGroups;
|
||||
for (unsigned i = 0; i < query.size(); i++) {
|
||||
ParfactorList::const_iterator it = pfList_.begin();
|
||||
for (; it != pfList_.end(); ++it) {
|
||||
if ((*it)->containsGround (query[i])) {
|
||||
queryGroups.push_back ((*it)->findGroup (query[i]));
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
assert (queryGroups.size() == query.size());
|
||||
return queryGroups;
|
||||
}
|
||||
|
||||
|
||||
|
||||
FactorGraph*
|
||||
LiftedBp::getFactorGraph (void)
|
||||
{
|
||||
FactorGraph* fg = new FactorGraph();
|
||||
ParfactorList::const_iterator it = pfList_.begin();
|
||||
for (; it != pfList_.end(); ++it) {
|
||||
vector<PrvGroup> groups = (*it)->getAllGroups();
|
||||
VarIds varIds;
|
||||
for (size_t i = 0; i < groups.size(); i++) {
|
||||
varIds.push_back (groups[i]);
|
||||
}
|
||||
fg->addFactor (Factor (varIds, (*it)->ranges(), (*it)->params()));
|
||||
}
|
||||
return fg;
|
||||
}
|
||||
|
||||
|
||||
|
||||
vector<vector<unsigned>>
|
||||
LiftedBp::getWeights (void) const
|
||||
{
|
||||
vector<vector<unsigned>> weights;
|
||||
weights.reserve (pfList_.size());
|
||||
ParfactorList::const_iterator it = pfList_.begin();
|
||||
for (; it != pfList_.end(); ++it) {
|
||||
const ProbFormulas& args = (*it)->arguments();
|
||||
weights.push_back ({ });
|
||||
weights.back().reserve (args.size());
|
||||
for (size_t i = 0; i < args.size(); i++) {
|
||||
LogVarSet lvs = (*it)->logVarSet() - args[i].logVars();
|
||||
weights.back().push_back ((*it)->constr()->getConditionalCount (lvs));
|
||||
}
|
||||
}
|
||||
return weights;
|
||||
}
|
||||
|
||||
|
||||
|
||||
unsigned
|
||||
LiftedBp::rangeOfGround (const Ground& gr)
|
||||
{
|
||||
ParfactorList::iterator it = pfList_.begin();
|
||||
while (it != pfList_.end()) {
|
||||
if ((*it)->containsGround (gr)) {
|
||||
PrvGroup prvGroup = (*it)->findGroup (gr);
|
||||
return (*it)->range ((*it)->indexOfGroup (prvGroup));
|
||||
}
|
||||
++ it;
|
||||
}
|
||||
return std::numeric_limits<unsigned>::max();
|
||||
}
|
||||
|
||||
|
||||
|
||||
Params
|
||||
LiftedBp::getJointByConditioning (
|
||||
const ParfactorList& pfList,
|
||||
const Grounds& grounds)
|
||||
{
|
||||
LiftedBp solver (pfList);
|
||||
Params prevBeliefs = solver.solveQuery ({grounds[0]});
|
||||
Grounds obsGrounds = {grounds[0]};
|
||||
for (size_t i = 1; i < grounds.size(); i++) {
|
||||
Params newBeliefs;
|
||||
vector<ObservedFormula> obsFs;
|
||||
Ranges obsRanges;
|
||||
for (size_t j = 0; j < obsGrounds.size(); j++) {
|
||||
obsFs.push_back (ObservedFormula (
|
||||
obsGrounds[j].functor(), 0, obsGrounds[j].args()));
|
||||
obsRanges.push_back (rangeOfGround (obsGrounds[j]));
|
||||
}
|
||||
Indexer indexer (obsRanges, false);
|
||||
while (indexer.valid()) {
|
||||
for (size_t j = 0; j < obsFs.size(); j++) {
|
||||
obsFs[j].setEvidence (indexer[j]);
|
||||
}
|
||||
ParfactorList tempPfList (pfList);
|
||||
LiftedVe::absorveEvidence (tempPfList, obsFs);
|
||||
LiftedBp solver (tempPfList);
|
||||
Params beliefs = solver.solveQuery ({grounds[i]});
|
||||
for (size_t k = 0; k < beliefs.size(); k++) {
|
||||
newBeliefs.push_back (beliefs[k]);
|
||||
}
|
||||
++ indexer;
|
||||
}
|
||||
int count = -1;
|
||||
unsigned range = rangeOfGround (grounds[i]);
|
||||
for (size_t j = 0; j < newBeliefs.size(); j++) {
|
||||
if (j % range == 0) {
|
||||
count ++;
|
||||
}
|
||||
newBeliefs[j] *= prevBeliefs[count];
|
||||
}
|
||||
prevBeliefs = newBeliefs;
|
||||
obsGrounds.push_back (grounds[i]);
|
||||
}
|
||||
return prevBeliefs;
|
||||
}
|
||||
|
@ -1,15 +1,17 @@
|
||||
#ifndef HORUS_LIFTEDBPSOLVER_H
|
||||
#define HORUS_LIFTEDBPSOLVER_H
|
||||
#ifndef HORUS_LIFTEDBP_H
|
||||
#define HORUS_LIFTEDBP_H
|
||||
|
||||
#include "ParfactorList.h"
|
||||
|
||||
class FactorGraph;
|
||||
class WeightedBpSolver;
|
||||
class WeightedBp;
|
||||
|
||||
class LiftedBpSolver
|
||||
class LiftedBp
|
||||
{
|
||||
public:
|
||||
LiftedBpSolver (const ParfactorList& pfList);
|
||||
LiftedBp (const ParfactorList& pfList);
|
||||
|
||||
~LiftedBp (void);
|
||||
|
||||
Params solveQuery (const Grounds&);
|
||||
|
||||
@ -25,10 +27,14 @@ class LiftedBpSolver
|
||||
FactorGraph* getFactorGraph (void);
|
||||
|
||||
vector<vector<unsigned>> getWeights (void) const;
|
||||
|
||||
unsigned rangeOfGround (const Ground&);
|
||||
|
||||
ParfactorList pfList_;
|
||||
WeightedBpSolver* solver_;
|
||||
Params getJointByConditioning (const ParfactorList&, const Grounds&);
|
||||
|
||||
ParfactorList pfList_;
|
||||
WeightedBp* solver_;
|
||||
|
||||
};
|
||||
|
||||
#endif // HORUS_LIFTEDBPSOLVER_H
|
||||
#endif // HORUS_LIFTEDBP_H
|
@ -1,148 +0,0 @@
|
||||
#include "LiftedBpSolver.h"
|
||||
#include "WeightedBpSolver.h"
|
||||
#include "FactorGraph.h"
|
||||
#include "FoveSolver.h"
|
||||
|
||||
|
||||
LiftedBpSolver::LiftedBpSolver (const ParfactorList& pfList)
|
||||
: pfList_(pfList)
|
||||
{
|
||||
refineParfactors();
|
||||
solver_ = new WeightedBpSolver (*getFactorGraph(), getWeights());
|
||||
}
|
||||
|
||||
|
||||
|
||||
Params
|
||||
LiftedBpSolver::solveQuery (const Grounds& query)
|
||||
{
|
||||
assert (query.empty() == false);
|
||||
Params res;
|
||||
vector<PrvGroup> groups = getQueryGroups (query);
|
||||
if (query.size() == 1) {
|
||||
res = solver_->getPosterioriOf (groups[0]);
|
||||
} else {
|
||||
VarIds queryVids;
|
||||
for (unsigned i = 0; i < groups.size(); i++) {
|
||||
queryVids.push_back (groups[i]);
|
||||
}
|
||||
res = solver_->getJointDistributionOf (queryVids);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
LiftedBpSolver::printSolverFlags (void) const
|
||||
{
|
||||
stringstream ss;
|
||||
ss << "lifted bp [" ;
|
||||
ss << "schedule=" ;
|
||||
typedef BpOptions::Schedule Sch;
|
||||
switch (BpOptions::schedule) {
|
||||
case Sch::SEQ_FIXED: ss << "seq_fixed"; break;
|
||||
case Sch::SEQ_RANDOM: ss << "seq_random"; break;
|
||||
case Sch::PARALLEL: ss << "parallel"; break;
|
||||
case Sch::MAX_RESIDUAL: ss << "max_residual"; break;
|
||||
}
|
||||
ss << ",max_iter=" << BpOptions::maxIter;
|
||||
ss << ",accuracy=" << BpOptions::accuracy;
|
||||
ss << ",log_domain=" << Util::toString (Globals::logDomain);
|
||||
ss << "]" ;
|
||||
cout << ss.str() << endl;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
LiftedBpSolver::refineParfactors (void)
|
||||
{
|
||||
while (iterate() == false);
|
||||
|
||||
if (Globals::verbosity > 2) {
|
||||
Util::printHeader ("AFTER REFINEMENT");
|
||||
pfList_.print();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
LiftedBpSolver::iterate (void)
|
||||
{
|
||||
ParfactorList::iterator it = pfList_.begin();
|
||||
while (it != pfList_.end()) {
|
||||
const ProbFormulas& args = (*it)->arguments();
|
||||
for (size_t i = 0; i < args.size(); i++) {
|
||||
LogVarSet lvs = (*it)->logVarSet() - args[i].logVars();
|
||||
if ((*it)->constr()->isCountNormalized (lvs) == false) {
|
||||
Parfactors pfs = FoveSolver::countNormalize (*it, lvs);
|
||||
it = pfList_.removeAndDelete (it);
|
||||
pfList_.add (pfs);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
++ it;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
|
||||
vector<PrvGroup>
|
||||
LiftedBpSolver::getQueryGroups (const Grounds& query)
|
||||
{
|
||||
vector<PrvGroup> queryGroups;
|
||||
for (unsigned i = 0; i < query.size(); i++) {
|
||||
ParfactorList::const_iterator it = pfList_.begin();
|
||||
for (; it != pfList_.end(); ++it) {
|
||||
if ((*it)->containsGround (query[i])) {
|
||||
queryGroups.push_back ((*it)->findGroup (query[i]));
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
assert (queryGroups.size() == query.size());
|
||||
return queryGroups;
|
||||
}
|
||||
|
||||
|
||||
|
||||
FactorGraph*
|
||||
LiftedBpSolver::getFactorGraph (void)
|
||||
{
|
||||
FactorGraph* fg = new FactorGraph();
|
||||
ParfactorList::const_iterator it = pfList_.begin();
|
||||
for (; it != pfList_.end(); ++it) {
|
||||
vector<PrvGroup> groups = (*it)->getAllGroups();
|
||||
VarIds varIds;
|
||||
for (size_t i = 0; i < groups.size(); i++) {
|
||||
varIds.push_back (groups[i]);
|
||||
}
|
||||
fg->addFactor (Factor (varIds, (*it)->ranges(), (*it)->params()));
|
||||
}
|
||||
return fg;
|
||||
}
|
||||
|
||||
|
||||
|
||||
vector<vector<unsigned>>
|
||||
LiftedBpSolver::getWeights (void) const
|
||||
{
|
||||
vector<vector<unsigned>> weights;
|
||||
weights.reserve (pfList_.size());
|
||||
ParfactorList::const_iterator it = pfList_.begin();
|
||||
for (; it != pfList_.end(); ++it) {
|
||||
const ProbFormulas& args = (*it)->arguments();
|
||||
weights.push_back ({ });
|
||||
weights.back().reserve (args.size());
|
||||
for (size_t i = 0; i < args.size(); i++) {
|
||||
LogVarSet lvs = (*it)->logVarSet() - args[i].logVars();
|
||||
weights.back().push_back ((*it)->constr()->getConditionalCount (lvs));
|
||||
}
|
||||
}
|
||||
return weights;
|
||||
}
|
||||
|
||||
|
@ -1,8 +1,7 @@
|
||||
|
||||
#include <algorithm>
|
||||
#include <set>
|
||||
|
||||
#include "FoveSolver.h"
|
||||
#include "LiftedVe.h"
|
||||
#include "Histogram.h"
|
||||
#include "Util.h"
|
||||
|
||||
@ -12,7 +11,7 @@ LiftedOperator::getValidOps (
|
||||
ParfactorList& pfList,
|
||||
const Grounds& query)
|
||||
{
|
||||
vector<LiftedOperator*> validOps;
|
||||
vector<LiftedOperator*> validOps;
|
||||
vector<ProductOperator*> multOps;
|
||||
|
||||
multOps = ProductOperator::getValidOps (pfList);
|
||||
@ -222,7 +221,7 @@ SumOutOperator::apply (void)
|
||||
product->sumOutIndex (fIdx);
|
||||
pfList_.addShattered (product);
|
||||
} else {
|
||||
Parfactors pfs = FoveSolver::countNormalize (product, excl);
|
||||
Parfactors pfs = LiftedVe::countNormalize (product, excl);
|
||||
for (size_t i = 0; i < pfs.size(); i++) {
|
||||
pfs[i]->sumOutIndex (fIdx);
|
||||
pfList_.add (pfs[i]);
|
||||
@ -376,7 +375,7 @@ CountingOperator::apply (void)
|
||||
} else {
|
||||
Parfactor* pf = *pfIter_;
|
||||
pfList_.remove (pfIter_);
|
||||
Parfactors pfs = FoveSolver::countNormalize (pf, X_);
|
||||
Parfactors pfs = LiftedVe::countNormalize (pf, X_);
|
||||
for (size_t i = 0; i < pfs.size(); i++) {
|
||||
unsigned condCount = pfs[i]->constr()->getConditionalCount (X_);
|
||||
bool cartProduct = pfs[i]->constr()->isCartesianProduct (
|
||||
@ -420,7 +419,7 @@ CountingOperator::toString (void)
|
||||
ss << "count convert " << X_ << " in " ;
|
||||
ss << (*pfIter_)->getLabel();
|
||||
ss << " [cost=" << std::exp (getLogCost()) << "]" << endl;
|
||||
Parfactors pfs = FoveSolver::countNormalize (*pfIter_, X_);
|
||||
Parfactors pfs = LiftedVe::countNormalize (*pfIter_, X_);
|
||||
if ((*pfIter_)->constr()->isCountNormalized (X_) == false) {
|
||||
for (size_t i = 0; i < pfs.size(); i++) {
|
||||
ss << " º " << pfs[i]->getLabel() << endl;
|
||||
@ -501,7 +500,7 @@ GroundOperator::getLogCost (void)
|
||||
++ pflIt;
|
||||
}
|
||||
// cout << endl;
|
||||
return totalCost;
|
||||
return totalCost + 3;
|
||||
}
|
||||
|
||||
|
||||
@ -610,7 +609,7 @@ GroundOperator::getAffectedFormulas (void)
|
||||
LogVar X = f.logVars()[front.second];
|
||||
const ProbFormulas& fs = (*pflIt)->arguments();
|
||||
for (size_t i = 0; i < fs.size(); i++) {
|
||||
if ((int)i != idx && fs[i].contains (X)) {
|
||||
if (i != idx && fs[i].contains (X)) {
|
||||
pair<PrvGroup, unsigned> pair = make_pair (
|
||||
fs[i].group(), fs[i].indexOf (X));
|
||||
if (Util::contains (affectedFormulas, pair) == false) {
|
||||
@ -630,7 +629,7 @@ GroundOperator::getAffectedFormulas (void)
|
||||
|
||||
|
||||
Params
|
||||
FoveSolver::solveQuery (const Grounds& query)
|
||||
LiftedVe::solveQuery (const Grounds& query)
|
||||
{
|
||||
assert (query.empty() == false);
|
||||
runSolver (query);
|
||||
@ -645,7 +644,7 @@ FoveSolver::solveQuery (const Grounds& query)
|
||||
|
||||
|
||||
void
|
||||
FoveSolver::printSolverFlags (void) const
|
||||
LiftedVe::printSolverFlags (void) const
|
||||
{
|
||||
stringstream ss;
|
||||
ss << "fove [" ;
|
||||
@ -657,7 +656,7 @@ FoveSolver::printSolverFlags (void) const
|
||||
|
||||
|
||||
void
|
||||
FoveSolver::absorveEvidence (
|
||||
LiftedVe::absorveEvidence (
|
||||
ParfactorList& pfList,
|
||||
ObservedFormulas& obsFormulas)
|
||||
{
|
||||
@ -696,7 +695,7 @@ FoveSolver::absorveEvidence (
|
||||
|
||||
|
||||
Parfactors
|
||||
FoveSolver::countNormalize (
|
||||
LiftedVe::countNormalize (
|
||||
Parfactor* g,
|
||||
const LogVarSet& set)
|
||||
{
|
||||
@ -715,7 +714,7 @@ FoveSolver::countNormalize (
|
||||
|
||||
|
||||
Parfactor
|
||||
FoveSolver::calcGroundMultiplication (Parfactor pf)
|
||||
LiftedVe::calcGroundMultiplication (Parfactor pf)
|
||||
{
|
||||
LogVarSet lvs = pf.constr()->logVarSet();
|
||||
lvs -= pf.constr()->singletons();
|
||||
@ -748,7 +747,7 @@ FoveSolver::calcGroundMultiplication (Parfactor pf)
|
||||
|
||||
|
||||
void
|
||||
FoveSolver::runSolver (const Grounds& query)
|
||||
LiftedVe::runSolver (const Grounds& query)
|
||||
{
|
||||
largestCost_ = std::log (0);
|
||||
shatterAgainstQuery (query);
|
||||
@ -794,7 +793,7 @@ FoveSolver::runSolver (const Grounds& query)
|
||||
|
||||
|
||||
LiftedOperator*
|
||||
FoveSolver::getBestOperation (const Grounds& query)
|
||||
LiftedVe::getBestOperation (const Grounds& query)
|
||||
{
|
||||
double bestCost = 0.0;
|
||||
LiftedOperator* bestOp = 0;
|
||||
@ -821,7 +820,7 @@ FoveSolver::getBestOperation (const Grounds& query)
|
||||
|
||||
|
||||
void
|
||||
FoveSolver::runWeakBayesBall (const Grounds& query)
|
||||
LiftedVe::runWeakBayesBall (const Grounds& query)
|
||||
{
|
||||
queue<PrvGroup> todo; // groups to process
|
||||
set<PrvGroup> done; // processed or in queue
|
||||
@ -880,7 +879,7 @@ FoveSolver::runWeakBayesBall (const Grounds& query)
|
||||
|
||||
|
||||
void
|
||||
FoveSolver::shatterAgainstQuery (const Grounds& query)
|
||||
LiftedVe::shatterAgainstQuery (const Grounds& query)
|
||||
{
|
||||
for (size_t i = 0; i < query.size(); i++) {
|
||||
if (query[i].isAtom()) {
|
||||
@ -931,7 +930,7 @@ FoveSolver::shatterAgainstQuery (const Grounds& query)
|
||||
|
||||
|
||||
Parfactors
|
||||
FoveSolver::absorve (
|
||||
LiftedVe::absorve (
|
||||
ObservedFormula& obsFormula,
|
||||
Parfactor* g)
|
||||
{
|
@ -1,5 +1,5 @@
|
||||
#ifndef HORUS_FOVESOLVER_H
|
||||
#define HORUS_FOVESOLVER_H
|
||||
#ifndef HORUS_LIFTEDVE_H
|
||||
#define HORUS_LIFTEDVE_H
|
||||
|
||||
|
||||
#include "ParfactorList.h"
|
||||
@ -130,10 +130,10 @@ class GroundOperator : public LiftedOperator
|
||||
|
||||
|
||||
|
||||
class FoveSolver
|
||||
class LiftedVe
|
||||
{
|
||||
public:
|
||||
FoveSolver (const ParfactorList& pfList) : pfList_(pfList) { }
|
||||
LiftedVe (const ParfactorList& pfList) : pfList_(pfList) { }
|
||||
|
||||
Params solveQuery (const Grounds&);
|
||||
|
||||
@ -162,5 +162,5 @@ class FoveSolver
|
||||
double largestCost_;
|
||||
};
|
||||
|
||||
#endif // HORUS_FOVESOLVER_H
|
||||
#endif // HORUS_LIFTEDVE_H
|
||||
|
@ -23,10 +23,10 @@ CC=@CC@
|
||||
CXX=@CXX@
|
||||
|
||||
# normal
|
||||
#CXXFLAGS= -std=c++0x @SHLIB_CXXFLAGS@ $(YAP_EXTRAS) $(DEFS) -D_YAP_NOT_INSTALLED_=1 -I$(srcdir) -I../../.. -I$(srcdir)/../../../include @CPPFLAGS@ -DNDEBUG
|
||||
CXXFLAGS= -std=c++0x @SHLIB_CXXFLAGS@ $(YAP_EXTRAS) $(DEFS) -D_YAP_NOT_INSTALLED_=1 -I$(srcdir) -I../../.. -I$(srcdir)/../../../include @CPPFLAGS@ -DNDEBUG
|
||||
|
||||
# debug
|
||||
CXXFLAGS= -std=c++0x @SHLIB_CXXFLAGS@ $(YAP_EXTRAS) $(DEFS) -D_YAP_NOT_INSTALLED_=1 -I$(srcdir) -I../../.. -I$(srcdir)/../../../include @CPPFLAGS@ -g -O0 -Wextra
|
||||
#CXXFLAGS= -std=c++0x @SHLIB_CXXFLAGS@ $(YAP_EXTRAS) $(DEFS) -D_YAP_NOT_INSTALLED_=1 -I$(srcdir) -I../../.. -I$(srcdir)/../../../include @CPPFLAGS@ -g -O0 -Wextra
|
||||
|
||||
|
||||
#
|
||||
@ -45,98 +45,91 @@ CWD=$(PWD)
|
||||
|
||||
|
||||
HEADERS = \
|
||||
$(srcdir)/BayesNet.h \
|
||||
$(srcdir)/BayesBall.h \
|
||||
$(srcdir)/ElimGraph.h \
|
||||
$(srcdir)/FactorGraph.h \
|
||||
$(srcdir)/Factor.h \
|
||||
$(srcdir)/BayesBallGraph.h \
|
||||
$(srcdir)/BeliefProp.h \
|
||||
$(srcdir)/ConstraintTree.h \
|
||||
$(srcdir)/Solver.h \
|
||||
$(srcdir)/VarElimSolver.h \
|
||||
$(srcdir)/BpSolver.h \
|
||||
$(srcdir)/CbpSolver.h \
|
||||
$(srcdir)/FoveSolver.h \
|
||||
$(srcdir)/Var.h \
|
||||
$(srcdir)/Indexer.h \
|
||||
$(srcdir)/Parfactor.h \
|
||||
$(srcdir)/ProbFormula.h \
|
||||
$(srcdir)/CountingBp.h \
|
||||
$(srcdir)/ElimGraph.h \
|
||||
$(srcdir)/Factor.h \
|
||||
$(srcdir)/FactorGraph.h \
|
||||
$(srcdir)/Histogram.h \
|
||||
$(srcdir)/ParfactorList.h \
|
||||
$(srcdir)/Horus.h \
|
||||
$(srcdir)/Indexer.h \
|
||||
$(srcdir)/LiftedBp.h \
|
||||
$(srcdir)/LiftedUtils.h \
|
||||
$(srcdir)/LiftedVe.h \
|
||||
$(srcdir)/Parfactor.h \
|
||||
$(srcdir)/ParfactorList.h \
|
||||
$(srcdir)/ProbFormula.h \
|
||||
$(srcdir)/Solver.h \
|
||||
$(srcdir)/TinySet.h \
|
||||
$(srcdir)/LiftedBpSolver.h \
|
||||
$(srcdir)/WeightedBpSolver.h \
|
||||
$(srcdir)/Util.h \
|
||||
$(srcdir)/Horus.h
|
||||
$(srcdir)/Var.h \
|
||||
$(srcdir)/VarElim.h \
|
||||
$(srcdir)/WeightedBp.h
|
||||
|
||||
CPP_SOURCES = \
|
||||
$(srcdir)/BayesNet.cpp \
|
||||
$(srcdir)/BayesBall.cpp \
|
||||
$(srcdir)/ElimGraph.cpp \
|
||||
$(srcdir)/FactorGraph.cpp \
|
||||
$(srcdir)/Factor.cpp \
|
||||
$(srcdir)/BayesBallGraph.cpp \
|
||||
$(srcdir)/BeliefProp.cpp \
|
||||
$(srcdir)/ConstraintTree.cpp \
|
||||
$(srcdir)/Var.cpp \
|
||||
$(srcdir)/Solver.cpp \
|
||||
$(srcdir)/VarElimSolver.cpp \
|
||||
$(srcdir)/BpSolver.cpp \
|
||||
$(srcdir)/CbpSolver.cpp \
|
||||
$(srcdir)/FoveSolver.cpp \
|
||||
$(srcdir)/Parfactor.cpp \
|
||||
$(srcdir)/ProbFormula.cpp \
|
||||
$(srcdir)/CountingBp.cpp \
|
||||
$(srcdir)/ElimGraph.cpp \
|
||||
$(srcdir)/Factor.cpp \
|
||||
$(srcdir)/FactorGraph.cpp \
|
||||
$(srcdir)/Histogram.cpp \
|
||||
$(srcdir)/ParfactorList.cpp \
|
||||
$(srcdir)/LiftedUtils.cpp \
|
||||
$(srcdir)/Util.cpp \
|
||||
$(srcdir)/LiftedBpSolver.cpp \
|
||||
$(srcdir)/WeightedBpSolver.cpp \
|
||||
$(srcdir)/HorusCli.cpp \
|
||||
$(srcdir)/HorusYap.cpp \
|
||||
$(srcdir)/HorusCli.cpp
|
||||
$(srcdir)/LiftedBp.cpp \
|
||||
$(srcdir)/LiftedUtils.cpp \
|
||||
$(srcdir)/LiftedVe.cpp \
|
||||
$(srcdir)/Parfactor.cpp \
|
||||
$(srcdir)/ParfactorList.cpp \
|
||||
$(srcdir)/ProbFormula.cpp \
|
||||
$(srcdir)/Solver.cpp \
|
||||
$(srcdir)/Util.cpp \
|
||||
$(srcdir)/Var.cpp \
|
||||
$(srcdir)/VarElim.cpp \
|
||||
$(srcdir)/WeightedBp.cpp \
|
||||
|
||||
OBJS = \
|
||||
BayesNet.o \
|
||||
BayesBall.o \
|
||||
ElimGraph.o \
|
||||
FactorGraph.o \
|
||||
Factor.o \
|
||||
BayesBallGraph.o \
|
||||
BeliefProp.o \
|
||||
ConstraintTree.o \
|
||||
Var.o \
|
||||
Solver.o \
|
||||
VarElimSolver.o \
|
||||
BpSolver.o \
|
||||
CbpSolver.o \
|
||||
FoveSolver.o \
|
||||
Parfactor.o \
|
||||
ProbFormula.o \
|
||||
CountingBp.o \
|
||||
ElimGraph.o \
|
||||
Factor.o \
|
||||
FactorGraph.o \
|
||||
Histogram.o \
|
||||
ParfactorList.o \
|
||||
HorusYap.o \
|
||||
LiftedBp.o \
|
||||
LiftedUtils.o \
|
||||
LiftedVe.o \
|
||||
ProbFormula.o \
|
||||
Parfactor.o \
|
||||
ParfactorList.o \
|
||||
Solver.o \
|
||||
Util.o \
|
||||
LiftedBpSolver.o \
|
||||
WeightedBpSolver.o \
|
||||
HorusYap.o
|
||||
Var.o \
|
||||
VarElim.o \
|
||||
WeightedBp.o
|
||||
|
||||
HCLI_OBJS = \
|
||||
BayesNet.o \
|
||||
BayesBall.o \
|
||||
BayesBallGraph.o \
|
||||
BeliefProp.o \
|
||||
CountingBp.o \
|
||||
ElimGraph.o \
|
||||
FactorGraph.o \
|
||||
Factor.o \
|
||||
ConstraintTree.o \
|
||||
Var.o \
|
||||
FactorGraph.o \
|
||||
HorusCli.o \
|
||||
Solver.o \
|
||||
VarElimSolver.o \
|
||||
BpSolver.o \
|
||||
CbpSolver.o \
|
||||
FoveSolver.o \
|
||||
Parfactor.o \
|
||||
ProbFormula.o \
|
||||
Histogram.o \
|
||||
ParfactorList.o \
|
||||
WeightedBpSolver.o \
|
||||
LiftedUtils.o \
|
||||
Util.o \
|
||||
HorusCli.o
|
||||
Var.o \
|
||||
VarElim.o \
|
||||
WeightedBp.o
|
||||
|
||||
SOBJS=horus.@SO@
|
||||
|
||||
|
@ -402,21 +402,32 @@ Parfactor::applySubstitution (const Substitution& theta)
|
||||
|
||||
|
||||
|
||||
PrvGroup
|
||||
Parfactor::findGroup (const Ground& ground) const
|
||||
size_t
|
||||
Parfactor::indexOfGround (const Ground& ground) const
|
||||
{
|
||||
PrvGroup group = numeric_limits<PrvGroup>::max();
|
||||
size_t idx = args_.size();
|
||||
for (size_t i = 0; i < args_.size(); i++) {
|
||||
if (args_[i].functor() == ground.functor() &&
|
||||
args_[i].arity() == ground.arity()) {
|
||||
constr_->moveToTop (args_[i].logVars());
|
||||
if (constr_->containsTuple (ground.args())) {
|
||||
group = args_[i].group();
|
||||
idx = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
return group;
|
||||
return idx;
|
||||
}
|
||||
|
||||
|
||||
|
||||
PrvGroup
|
||||
Parfactor::findGroup (const Ground& ground) const
|
||||
{
|
||||
size_t idx = indexOfGround (ground);
|
||||
return idx == args_.size()
|
||||
? numeric_limits<PrvGroup>::max()
|
||||
: args_[idx].group();
|
||||
}
|
||||
|
||||
|
||||
@ -429,6 +440,30 @@ Parfactor::containsGround (const Ground& ground) const
|
||||
|
||||
|
||||
|
||||
bool
|
||||
Parfactor::containsGrounds (const Grounds& grounds) const
|
||||
{
|
||||
Tuple tuple;
|
||||
LogVars tupleLvs;
|
||||
for (size_t i = 0; i < grounds.size(); i++) {
|
||||
size_t idx = indexOfGround (grounds[i]);
|
||||
if (idx == args_.size()) {
|
||||
return false;
|
||||
}
|
||||
LogVars lvs = args_[idx].logVars();
|
||||
for (size_t j = 0; j < lvs.size(); j++) {
|
||||
if (Util::contains (tupleLvs, lvs[j]) == false) {
|
||||
tuple.push_back (grounds[i].args()[j]);
|
||||
tupleLvs.push_back (lvs[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
constr_->moveToTop (tupleLvs);
|
||||
return constr_->containsTuple (tuple);
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
Parfactor::containsGroup (PrvGroup group) const
|
||||
{
|
||||
@ -442,6 +477,19 @@ Parfactor::containsGroup (PrvGroup group) const
|
||||
|
||||
|
||||
|
||||
bool
|
||||
Parfactor::containsGroups (vector<PrvGroup> groups) const
|
||||
{
|
||||
for (size_t i = 0; i < groups.size(); i++) {
|
||||
if (containsGroup (groups[i]) == false) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
|
||||
unsigned
|
||||
Parfactor::nrFormulas (LogVar X) const
|
||||
{
|
||||
|
@ -64,11 +64,17 @@ class Parfactor : public TFactor<ProbFormula>
|
||||
|
||||
void applySubstitution (const Substitution&);
|
||||
|
||||
size_t indexOfGround (const Ground&) const;
|
||||
|
||||
PrvGroup findGroup (const Ground&) const;
|
||||
|
||||
bool containsGround (const Ground&) const;
|
||||
|
||||
bool containsGrounds (const Grounds&) const;
|
||||
|
||||
bool containsGroup (PrvGroup) const;
|
||||
|
||||
bool containsGroups (vector<PrvGroup>) const;
|
||||
|
||||
unsigned nrFormulas (LogVar) const;
|
||||
|
||||
|
@ -91,6 +91,8 @@ class ObservedFormula
|
||||
|
||||
unsigned evidence (void) const { return evidence_; }
|
||||
|
||||
void setEvidence (unsigned ev) { evidence_ = ev; }
|
||||
|
||||
ConstraintTree& constr (void) { return constr_; }
|
||||
|
||||
bool isAtom (void) const { return arity_ == 0; }
|
||||
|
@ -1,5 +1,8 @@
|
||||
#include "Solver.h"
|
||||
#include "Util.h"
|
||||
#include "BeliefProp.h"
|
||||
#include "CountingBp.h"
|
||||
#include "VarElim.h"
|
||||
|
||||
|
||||
void
|
||||
@ -38,3 +41,67 @@ Solver::printAllPosterioris (void)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
Params
|
||||
Solver::getJointByConditioning (
|
||||
GroundSolver solverType,
|
||||
FactorGraph fg,
|
||||
const VarIds& jointVarIds) const
|
||||
{
|
||||
VarNodes jointVars;
|
||||
for (size_t i = 0; i < jointVarIds.size(); i++) {
|
||||
assert (fg.getVarNode (jointVarIds[i]));
|
||||
jointVars.push_back (fg.getVarNode (jointVarIds[i]));
|
||||
}
|
||||
|
||||
Solver* solver = 0;
|
||||
switch (solverType) {
|
||||
case GroundSolver::BP: solver = new BeliefProp (fg); break;
|
||||
case GroundSolver::CBP: solver = new CountingBp (fg); break;
|
||||
case GroundSolver::VE: solver = new VarElim (fg); break;
|
||||
}
|
||||
Params prevBeliefs = solver->solveQuery ({jointVarIds[0]});
|
||||
VarIds observedVids = {jointVars[0]->varId()};
|
||||
|
||||
for (size_t i = 1; i < jointVarIds.size(); i++) {
|
||||
assert (jointVars[i]->hasEvidence() == false);
|
||||
Params newBeliefs;
|
||||
Vars observedVars;
|
||||
Ranges observedRanges;
|
||||
for (size_t j = 0; j < observedVids.size(); j++) {
|
||||
observedVars.push_back (fg.getVarNode (observedVids[j]));
|
||||
observedRanges.push_back (observedVars.back()->range());
|
||||
}
|
||||
Indexer indexer (observedRanges, false);
|
||||
while (indexer.valid()) {
|
||||
for (size_t j = 0; j < observedVars.size(); j++) {
|
||||
observedVars[j]->setEvidence (indexer[j]);
|
||||
}
|
||||
delete solver;
|
||||
switch (solverType) {
|
||||
case GroundSolver::BP: solver = new BeliefProp (fg); break;
|
||||
case GroundSolver::CBP: solver = new CountingBp (fg); break;
|
||||
case GroundSolver::VE: solver = new VarElim (fg); break;
|
||||
}
|
||||
Params beliefs = solver->solveQuery ({jointVarIds[i]});
|
||||
for (size_t k = 0; k < beliefs.size(); k++) {
|
||||
newBeliefs.push_back (beliefs[k]);
|
||||
}
|
||||
++ indexer;
|
||||
}
|
||||
|
||||
int count = -1;
|
||||
for (size_t j = 0; j < newBeliefs.size(); j++) {
|
||||
if (j % jointVars[i]->range() == 0) {
|
||||
count ++;
|
||||
}
|
||||
newBeliefs[j] *= prevBeliefs[count];
|
||||
}
|
||||
prevBeliefs = newBeliefs;
|
||||
observedVids.push_back (jointVars[i]->varId());
|
||||
}
|
||||
delete solver;
|
||||
return prevBeliefs;
|
||||
}
|
||||
|
||||
|
@ -3,8 +3,9 @@
|
||||
|
||||
#include <iomanip>
|
||||
|
||||
#include "Var.h"
|
||||
#include "FactorGraph.h"
|
||||
#include "Var.h"
|
||||
#include "Horus.h"
|
||||
|
||||
|
||||
using namespace std;
|
||||
@ -23,6 +24,9 @@ class Solver
|
||||
void printAnswer (const VarIds& vids);
|
||||
|
||||
void printAllPosterioris (void);
|
||||
|
||||
Params getJointByConditioning (GroundSolver,
|
||||
FactorGraph, const VarIds& jointVarIds) const;
|
||||
|
||||
protected:
|
||||
const FactorGraph& fg;
|
||||
|
@ -13,9 +13,9 @@ bool logDomain = false;
|
||||
|
||||
unsigned verbosity = 0;
|
||||
|
||||
LiftedSolvers liftedSolver = LiftedSolvers::FOVE;
|
||||
LiftedSolver liftedSolver = LiftedSolver::FOVE;
|
||||
|
||||
GroundSolvers groundSolver = GroundSolvers::VE;
|
||||
GroundSolver groundSolver = GroundSolver::VE;
|
||||
|
||||
};
|
||||
|
||||
@ -211,9 +211,9 @@ setHorusFlag (string key, string value)
|
||||
ss >> Globals::verbosity;
|
||||
} else if (key == "lifted_solver") {
|
||||
if ( value == "fove") {
|
||||
Globals::liftedSolver = LiftedSolvers::FOVE;
|
||||
Globals::liftedSolver = LiftedSolver::FOVE;
|
||||
} else if (value == "lbp") {
|
||||
Globals::liftedSolver = LiftedSolvers::LBP;
|
||||
Globals::liftedSolver = LiftedSolver::LBP;
|
||||
} else {
|
||||
cerr << "warning: invalid value `" << value << "' " ;
|
||||
cerr << "for `" << key << "'" << endl;
|
||||
@ -221,11 +221,11 @@ setHorusFlag (string key, string value)
|
||||
}
|
||||
} else if (key == "ground_solver") {
|
||||
if ( value == "ve") {
|
||||
Globals::groundSolver = GroundSolvers::VE;
|
||||
Globals::groundSolver = GroundSolver::VE;
|
||||
} else if (value == "bp") {
|
||||
Globals::groundSolver = GroundSolvers::BP;
|
||||
Globals::groundSolver = GroundSolver::BP;
|
||||
} else if (value == "cbp") {
|
||||
Globals::groundSolver = GroundSolvers::CBP;
|
||||
Globals::groundSolver = GroundSolver::CBP;
|
||||
} else {
|
||||
cerr << "warning: invalid value `" << value << "' " ;
|
||||
cerr << "for `" << key << "'" << endl;
|
||||
|
@ -1,12 +1,12 @@
|
||||
#include <algorithm>
|
||||
|
||||
#include "VarElimSolver.h"
|
||||
#include "VarElim.h"
|
||||
#include "ElimGraph.h"
|
||||
#include "Factor.h"
|
||||
#include "Util.h"
|
||||
|
||||
|
||||
VarElimSolver::~VarElimSolver (void)
|
||||
VarElim::~VarElim (void)
|
||||
{
|
||||
delete factorList_.back();
|
||||
}
|
||||
@ -14,7 +14,7 @@ VarElimSolver::~VarElimSolver (void)
|
||||
|
||||
|
||||
Params
|
||||
VarElimSolver::solveQuery (VarIds queryVids)
|
||||
VarElim::solveQuery (VarIds queryVids)
|
||||
{
|
||||
if (Globals::verbosity > 1) {
|
||||
cout << "Solving query on " ;
|
||||
@ -41,7 +41,7 @@ VarElimSolver::solveQuery (VarIds queryVids)
|
||||
|
||||
|
||||
void
|
||||
VarElimSolver::printSolverFlags (void) const
|
||||
VarElim::printSolverFlags (void) const
|
||||
{
|
||||
stringstream ss;
|
||||
ss << "variable elimination [" ;
|
||||
@ -62,7 +62,7 @@ VarElimSolver::printSolverFlags (void) const
|
||||
|
||||
|
||||
void
|
||||
VarElimSolver::createFactorList (void)
|
||||
VarElim::createFactorList (void)
|
||||
{
|
||||
const FacNodes& facNodes = fg.facNodes();
|
||||
factorList_.reserve (facNodes.size() * 2);
|
||||
@ -84,7 +84,7 @@ VarElimSolver::createFactorList (void)
|
||||
|
||||
|
||||
void
|
||||
VarElimSolver::absorveEvidence (void)
|
||||
VarElim::absorveEvidence (void)
|
||||
{
|
||||
if (Globals::verbosity > 2) {
|
||||
Util::printDashedLine();
|
||||
@ -117,7 +117,7 @@ VarElimSolver::absorveEvidence (void)
|
||||
|
||||
|
||||
void
|
||||
VarElimSolver::findEliminationOrder (const VarIds& vids)
|
||||
VarElim::findEliminationOrder (const VarIds& vids)
|
||||
{
|
||||
elimOrder_ = ElimGraph::getEliminationOrder (factorList_, vids);
|
||||
}
|
||||
@ -125,7 +125,7 @@ VarElimSolver::findEliminationOrder (const VarIds& vids)
|
||||
|
||||
|
||||
void
|
||||
VarElimSolver::processFactorList (const VarIds& vids)
|
||||
VarElim::processFactorList (const VarIds& vids)
|
||||
{
|
||||
totalFactorSize_ = 0;
|
||||
largestFactorSize_ = 0;
|
||||
@ -170,7 +170,7 @@ VarElimSolver::processFactorList (const VarIds& vids)
|
||||
|
||||
|
||||
void
|
||||
VarElimSolver::eliminate (VarId elimVar)
|
||||
VarElim::eliminate (VarId elimVar)
|
||||
{
|
||||
Factor* result = 0;
|
||||
vector<size_t>& idxs = varFactors_.find (elimVar)->second;
|
||||
@ -205,7 +205,7 @@ VarElimSolver::eliminate (VarId elimVar)
|
||||
|
||||
|
||||
void
|
||||
VarElimSolver::printActiveFactors (void)
|
||||
VarElim::printActiveFactors (void)
|
||||
{
|
||||
for (size_t i = 0; i < factorList_.size(); i++) {
|
||||
if (factorList_[i] != 0) {
|
@ -1,5 +1,5 @@
|
||||
#ifndef HORUS_VARELIMSOLVER_H
|
||||
#define HORUS_VARELIMSOLVER_H
|
||||
#ifndef HORUS_VARELIM_H
|
||||
#define HORUS_VARELIM_H
|
||||
|
||||
#include "unordered_map"
|
||||
|
||||
@ -11,12 +11,12 @@
|
||||
using namespace std;
|
||||
|
||||
|
||||
class VarElimSolver : public Solver
|
||||
class VarElim : public Solver
|
||||
{
|
||||
public:
|
||||
VarElimSolver (const FactorGraph& fg) : Solver (fg) { }
|
||||
VarElim (const FactorGraph& fg) : Solver (fg) { }
|
||||
|
||||
~VarElimSolver (void);
|
||||
~VarElim (void);
|
||||
|
||||
Params solveQuery (VarIds);
|
||||
|
||||
@ -42,5 +42,5 @@ class VarElimSolver : public Solver
|
||||
unordered_map<VarId, vector<size_t>> varFactors_;
|
||||
};
|
||||
|
||||
#endif // HORUS_VARELIMSOLVER_H
|
||||
#endif // HORUS_VARELIM_H
|
||||
|
@ -1,7 +1,7 @@
|
||||
#include "WeightedBpSolver.h"
|
||||
#include "WeightedBp.h"
|
||||
|
||||
|
||||
WeightedBpSolver::~WeightedBpSolver (void)
|
||||
WeightedBp::~WeightedBp (void)
|
||||
{
|
||||
for (size_t i = 0; i < links_.size(); i++) {
|
||||
delete links_[i];
|
||||
@ -12,7 +12,7 @@ WeightedBpSolver::~WeightedBpSolver (void)
|
||||
|
||||
|
||||
Params
|
||||
WeightedBpSolver::getPosterioriOf (VarId vid)
|
||||
WeightedBp::getPosterioriOf (VarId vid)
|
||||
{
|
||||
if (runned_ == false) {
|
||||
runSolver();
|
||||
@ -47,7 +47,7 @@ WeightedBpSolver::getPosterioriOf (VarId vid)
|
||||
|
||||
|
||||
void
|
||||
WeightedBpSolver::createLinks (void)
|
||||
WeightedBp::createLinks (void)
|
||||
{
|
||||
if (Globals::verbosity > 0) {
|
||||
cout << "compressed factor graph contains " ;
|
||||
@ -78,7 +78,7 @@ WeightedBpSolver::createLinks (void)
|
||||
|
||||
|
||||
void
|
||||
WeightedBpSolver::maxResidualSchedule (void)
|
||||
WeightedBp::maxResidualSchedule (void)
|
||||
{
|
||||
if (nIters_ == 1) {
|
||||
for (size_t i = 0; i < links_.size(); i++) {
|
||||
@ -151,7 +151,7 @@ WeightedBpSolver::maxResidualSchedule (void)
|
||||
|
||||
|
||||
void
|
||||
WeightedBpSolver::calcFactorToVarMsg (BpLink* _link)
|
||||
WeightedBp::calcFactorToVarMsg (BpLink* _link)
|
||||
{
|
||||
WeightedLink* link = static_cast<WeightedLink*> (_link);
|
||||
FacNode* src = link->facNode();
|
||||
@ -223,7 +223,7 @@ WeightedBpSolver::calcFactorToVarMsg (BpLink* _link)
|
||||
|
||||
|
||||
Params
|
||||
WeightedBpSolver::getVarToFactorMsg (const BpLink* _link) const
|
||||
WeightedBp::getVarToFactorMsg (const BpLink* _link) const
|
||||
{
|
||||
const WeightedLink* link = static_cast<const WeightedLink*> (_link);
|
||||
const VarNode* src = link->varNode();
|
||||
@ -272,7 +272,7 @@ WeightedBpSolver::getVarToFactorMsg (const BpLink* _link) const
|
||||
|
||||
|
||||
void
|
||||
WeightedBpSolver::printLinkInformation (void) const
|
||||
WeightedBp::printLinkInformation (void) const
|
||||
{
|
||||
for (size_t i = 0; i < links_.size(); i++) {
|
||||
WeightedLink* l = static_cast<WeightedLink*> (links_[i]);
|
@ -1,7 +1,7 @@
|
||||
#ifndef HORUS_WEIGHTEDBPSOLVER_H
|
||||
#define HORUS_WEIGHTEDBPSOLVER_H
|
||||
#ifndef HORUS_WEIGHTEDBP_H
|
||||
#define HORUS_WEIGHTEDBP_H
|
||||
|
||||
#include "BpSolver.h"
|
||||
#include "BeliefProp.h"
|
||||
|
||||
class WeightedLink : public BpLink
|
||||
{
|
||||
@ -31,14 +31,14 @@ class WeightedLink : public BpLink
|
||||
|
||||
|
||||
|
||||
class WeightedBpSolver : public BpSolver
|
||||
class WeightedBp : public BeliefProp
|
||||
{
|
||||
public:
|
||||
WeightedBpSolver (const FactorGraph& fg,
|
||||
WeightedBp (const FactorGraph& fg,
|
||||
const vector<vector<unsigned>>& weights)
|
||||
: BpSolver (fg), weights_(weights) { }
|
||||
: BeliefProp (fg), weights_(weights) { }
|
||||
|
||||
~WeightedBpSolver (void);
|
||||
~WeightedBp (void);
|
||||
|
||||
Params getPosterioriOf (VarId);
|
||||
|
||||
@ -57,5 +57,5 @@ class WeightedBpSolver : public BpSolver
|
||||
vector<vector<unsigned>> weights_;
|
||||
};
|
||||
|
||||
#endif // HORUS_WEIGHTEDBPSOLVER_H
|
||||
#endif // HORUS_WEIGHTEDBP_H
|
||||
|
Reference in New Issue
Block a user