Refactor the way we set the inference solver

This commit is contained in:
Tiago Gomes 2012-12-17 23:14:54 +00:00
parent f1499f99f3
commit e1c862ebbe
4 changed files with 116 additions and 82 deletions

View File

@ -3,6 +3,7 @@
[{}/1,
clpbn_flag/2,
set_clpbn_flag/2,
set_solver/1,
set_em_solver/1,
clpbn_flag/3,
clpbn_key/2,
@ -68,9 +69,9 @@
%% ]).
:- use_module('clpbn/pgrammar',
[init_pcg_solver/4,
run_pcg_solver/3,
pcg_init_graph/0
[pcg_init_graph/0,
init_pcg_solver/4,
run_pcg_solver/3
]).
:- use_module('clpbn/horus_ground',
@ -126,66 +127,89 @@
:- dynamic
solver/1,
output/1,
use/1,
em_solver/1,
suppress_attribute_display/1,
parameter_softening/1,
em_solver/1,
use_parfactors/1.
use_parfactors/1,
output/1,
use/1.
:- meta_predicate probability(:,-), conditional_probability(:,:,-).
solver(ve).
em_solver(bp).
%output(xbif(user_error)).
%output(gviz(user_error)).
output(no).
suppress_attribute_display(false).
parameter_softening(m_estimate(10)).
use_parfactors(off).
output(no).
%output(xbif(user_error)).
%output(gviz(user_error)).
ground_solver(ve).
ground_solver(hve).
ground_solver(jt).
ground_solver(bdd).
ground_solver(bp).
ground_solver(cbp).
ground_solver(gibbs).
lifted_solver(lve).
lifted_solver(lkc).
lifted_solver(lbp).
clpbn_flag(Flag,Option) :-
clpbn_flag(Flag, Option) :-
clpbn_flag(Flag, Option, Option).
set_clpbn_flag(Flag,Option) :-
clpbn_flag(Flag, _, Option).
clpbn_flag(output,Before,After) :-
retract(output(Before)),
assert(output(After)).
clpbn_flag(solver,Before,After) :-
retract(solver(Before)),
assert(solver(After)).
clpbn_flag(em_solver,Before,After) :-
retract(em_solver(Before)),
assert(em_solver(After)).
clpbn_flag(bnt_solver,Before,After) :-
retract(bnt:bnt_solver(Before)),
assert(bnt:bnt_solver(After)).
clpbn_flag(bnt_path,Before,After) :-
retract(bnt:bnt_path(Before)),
assert(bnt:bnt_path(After)).
clpbn_flag(bnt_model,Before,After) :-
retract(bnt:bnt_model(Before)),
assert(bnt:bnt_model(After)).
clpbn_flag(suppress_attribute_display,Before,After) :-
retract(suppress_attribute_display(Before)),
assert(suppress_attribute_display(After)).
clpbn_flag(parameter_softening,Before,After) :-
retract(parameter_softening(Before)),
assert(parameter_softening(After)).
clpbn_flag(use_factors,Before,After) :-
retract(use_parfactors(Before)),
assert(use_parfactors(After)).
clpbn_flag(output,Before,After) :-
retract(output(Before)),
assert(output(After)).
set_solver(Solver) :-
set_clpbn_flag(solver,Solver).
set_em_solver(Solver) :-
set_clpbn_flag(em_solver, Solver).
set_clpbn_flag(em_solver,Solver).
{_} :-
solver(none), !.
{Var = Key with Dist} :-
{ Var = Key with Dist } :-
put_atts(El,[key(Key),dist(DistInfo,Parents)]),
dist(Dist, DistInfo, Key, Parents),
add_evidence(Var,Key,DistInfo,El)
@ -209,8 +233,10 @@ init_clpbn_vars(El) :-
create_mutable(El, Mutable),
b_setval(clpbn_qvars, Mutable).
check_constraint(Constraint, _, _, Constraint) :- var(Constraint), !.
check_constraint((A->D), _, _, (A->D)) :- var(A), !.
check_constraint(Constraint, _, _, Constraint) :-
var(Constraint), !.
check_constraint((A->D), _, _, (A->D)) :-
var(A), !.
check_constraint((([A|B].L)->D), Vars, NVars, (([A|B].NL)->D)) :- !,
check_cpt_input_vars(L, Vars, NVars, NL).
check_constraint(Dist, _, _, Dist).
@ -246,7 +272,8 @@ clpbn_marginalise(V, Dist) :-
%
project_attributes(GVars0, _AVars0) :-
use_parfactors(on),
clpbn_flag(solver, Solver), Solver \= fove, !,
clpbn_flag(solver, Solver),
ground_solver(Solver),
generate_network(GVars0, GKeys, Keys, Factors, Evidence),
b_setval(clpbn_query_variables, f(GVars0,Evidence)),
simplify_query(GVars0, GVars),
@ -329,34 +356,27 @@ get_rid_of_ev_vars([V|LVs0],[V|LVs]) :-
get_rid_of_ev_vars(LVs0,LVs).
% do nothing if we don't have query variables to compute.
write_out(_, [], _, _) :- !.
write_out(graphs, _, AVars, _) :-
clpbn2graph(AVars).
write_out(ve, GVars, AVars, DiffVars) :-
ve(GVars, AVars, DiffVars).
write_out(jt, GVars, AVars, DiffVars) :-
jt(GVars, AVars, DiffVars).
write_out(bdd, GVars, AVars, DiffVars) :-
bdd(GVars, AVars, DiffVars).
write_out(bp, _GVars, _AVars, _DiffVars) :-
writeln('interface not supported any longer').
write_out(gibbs, GVars, AVars, DiffVars) :-
gibbs(GVars, AVars, DiffVars).
write_out(bnt, GVars, AVars, DiffVars) :-
do_bnt(GVars, AVars, DiffVars).
write_out(fove, GVars, AVars, DiffVars) :-
call_horus_lifted_solver(GVars, AVars, DiffVars).
% call a solver with keys, not actual variables
call_ground_solver(bp, GVars, GoalKeys, Keys, Factors, Evidence) :- !,
call_horus_ground_solver(GVars, GoalKeys, Keys, Factors, Evidence, _Answ).
call_ground_solver(bdd, GVars, GoalKeys, Keys, Factors, Evidence) :- !,
call_bdd_ground_solver(GVars, GoalKeys, Keys, Factors, Evidence, _Answ).
% Call a solver with keys, not actual variables
call_ground_solver(ve, GVars, GoalKeys, Keys, Factors, Evidence) :- !,
call_ve_ground_solver(GVars, GoalKeys, Keys, Factors, Evidence, _Answ).
call_ground_solver(hve, GVars, GoalKeys, Keys, Factors, Evidence) :- !,
clpbn_horus:set_horus_flag(ground_solver, ve),
call_horus_ground_solver(GVars, GoalKeys, Keys, Factors, Evidence, _Answ).
call_ground_solver(bdd, GVars, GoalKeys, Keys, Factors, Evidence) :- !,
call_bdd_ground_solver(GVars, GoalKeys, Keys, Factors, Evidence, _Answ).
call_ground_solver(bp, GVars, GoalKeys, Keys, Factors, Evidence) :- !,
clpbn_horus:set_horus_flag(ground_solver, bp),
call_horus_ground_solver(GVars, GoalKeys, Keys, Factors, Evidence, _Answ).
call_ground_solver(cbp, GVars, GoalKeys, Keys, Factors, Evidence) :- !,
clpbn_horus:set_horus_flag(ground_solver, cbp),
call_horus_ground_solver(GVars, GoalKeys, Keys, Factors, Evidence, _Answ).
call_ground_solver(Solver, GVars, _GoalKeys, Keys, Factors, Evidence) :-
% traditional solver
% fall back to traditional solver
b_hash_new(Hash0),
foldl(gvar_in_hash, GVars, Hash0, HashI),
foldl(key_to_var, Keys, AllVars, HashI, Hash1),
@ -368,6 +388,44 @@ call_ground_solver(Solver, GVars, _GoalKeys, Keys, Factors, Evidence) :-
write_out(Solver, [GVars], AllVars, _),
assert(use_parfactors(on)).
% do nothing if we don't have query variables to compute.
write_out(_, [], _, _) :- !.
write_out(graphs, _, AVars, _) :- !,
clpbn2graph(AVars).
write_out(ve, GVars, AVars, DiffVars) :- !,
ve(GVars, AVars, DiffVars).
write_out(jt, GVars, AVars, DiffVars) :- !,
jt(GVars, AVars, DiffVars).
write_out(bdd, GVars, AVars, DiffVars) :- !,
bdd(GVars, AVars, DiffVars).
write_out(gibbs, GVars, AVars, DiffVars) :- !,
gibbs(GVars, AVars, DiffVars).
write_out(lve, GVars, AVars, DiffVars) :- !,
clpbn_horus:set_horus_flag(lifted_solver, lve),
call_horus_lifted_solver(GVars, AVars, DiffVars).
write_out(lkc, GVars, AVars, DiffVars) :- !,
clpbn_horus:set_horus_flag(lifted_solver, lkc),
call_horus_lifted_solver(GVars, AVars, DiffVars).
write_out(lbp, GVars, AVars, DiffVars) :- !,
clpbn_horus:set_horus_flag(lifted_solver, lbp),
call_horus_lifted_solver(GVars, AVars, DiffVars).
write_out(bnt, GVars, AVars, DiffVars) :- !,
do_bnt(GVars, AVars, DiffVars).
write_out(Solver, _, _, _) :-
format("Error: solver `~w' is unknown", [Solver]),
fail.
%
% convert a PFL network (without constraints)
% into CLP(BN) for evaluation
@ -472,14 +530,11 @@ bind_clpbn(T, Var, Key, Dist, Parents, []) :- var(T),
;
fail
).
bind_clpbn(_, Var, _, _, _, _, []) :-
use(bnt),
check_if_bnt_done(Var), !.
bind_clpbn(_, Var, _, _, _, _, []) :-
use(ve),
check_if_ve_done(Var), !.
bind_clpbn(_, Var, _, _, _, _, []) :-
use(bp),
use(hve),
check_if_horus_ground_solver_done(Var), !.
bind_clpbn(_, Var, _, _, _, _, []) :-
use(jt),
@ -487,6 +542,15 @@ bind_clpbn(_, Var, _, _, _, _, []) :-
bind_clpbn(_, Var, _, _, _, _, []) :-
use(bdd),
check_if_bdd_done(Var), !.
bind_clpbn(_, Var, _, _, _, _, []) :-
use(bp),
check_if_horus_ground_solver_done(Var), !.
bind_clpbn(_, Var, _, _, _, _, []) :-
use(cbp),
check_if_horus_ground_solver_done(Var), !.
bind_clpbn(_, Var, _, _, _, _, []) :-
use(bnt),
check_if_bnt_done(Var), !.
bind_clpbn(T, Var, Key0, _, _, _, []) :-
get_atts(Var, [key(Key)]), !,
(
@ -501,7 +565,7 @@ fresh_attvar(Var, NVar) :-
put_atts(NVar, LAtts).
% I will now allow two CLPBN variables to be bound together.
%bind_clpbns(Key, Dist, Parents, Key, Dist, Parents).
% bind_clpbns(Key, Dist, Parents, Key, Dist, Parents).
bind_clpbns(Key, Dist, Parents, Key1, Dist1, Parents1) :-
Key == Key1, !,
get_dist(Dist,_Type,_Domain,_Table),
@ -611,19 +675,6 @@ clpbn_finalize_solver(State) :-
finalize_horus_ground_solver(Info).
clpbn_finalize_solver(_State).
ground_solver(ve).
ground_solver(hve).
ground_solver(jt).
ground_solver(bdd).
ground_solver(bp).
ground_solver(cbp).
ground_solver(gibbs).
lifted_solver(lve).
lifted_solver(lkc).
lifted_solver(lbp).
%
% This is a routine to start a solver, called by the learning procedures (ie, em).
%

View File

@ -5,8 +5,7 @@
********************************************************/
:- module(clpbn_horus,
[set_solver/1,
set_horus_flag/1,
[set_horus_flag/1,
cpp_create_lifted_network/3,
cpp_create_ground_network/4,
cpp_set_parfactors_params/2,
@ -35,19 +34,6 @@ warning :-
-> true ; warning.
set_solver(ve) :- !, set_clpbn_flag(solver,ve).
set_solver(bdd) :- !, set_clpbn_flag(solver,bdd).
set_solver(jt) :- !, set_clpbn_flag(solver,jt).
set_solver(gibbs) :- !, set_clpbn_flag(solver,gibbs).
set_solver(lve) :- !, set_clpbn_flag(solver,fove), set_horus_flag(lifted_solver, lve).
set_solver(lbp) :- !, set_clpbn_flag(solver,fove), set_horus_flag(lifted_solver, lbp).
set_solver(lkc) :- !, set_clpbn_flag(solver,fove), set_horus_flag(lifted_solver, lkc).
set_solver(hve) :- !, set_clpbn_flag(solver,bp), set_horus_flag(ground_solver, ve).
set_solver(bp) :- !, set_clpbn_flag(solver,bp), set_horus_flag(ground_solver, bp).
set_solver(cbp) :- !, set_clpbn_flag(solver,bp), set_horus_flag(ground_solver, cbp).
set_solver(S) :- throw(error('unknown solver ', S)).
set_horus_flag(K,V) :- cpp_set_horus_flag(K,V).

View File

@ -20,8 +20,7 @@
cpp_set_factors_params/2,
cpp_run_ground_solver/3,
cpp_set_vars_information/2,
cpp_free_ground_network/1,
set_solver/1
cpp_free_ground_network/1
]).
:- use_module(library('clpbn/dists'),

View File

@ -22,15 +22,13 @@
:- reexport(library(clpbn),
[clpbn_flag/2 as pfl_flag,
set_clpbn_flag/2 as set_pfl_flag,
set_solver/1,
set_em_solver/1,
conditional_probability/3,
pfl_init_solver/5,
pfl_run_solver/3
]).
:- reexport(library(clpbn/horus),
[set_solver/1]).
:- reexport(library(clpbn/aggregates),
[avg_factors/5]).