Don't pass around the Solver for EM

This commit is contained in:
Tiago Gomes 2012-12-15 16:11:03 +00:00
parent 81ac6f1913
commit 9ff9be2f49
3 changed files with 31 additions and 26 deletions

View File

@ -6,11 +6,9 @@
clpbn_key/2, clpbn_key/2,
clpbn_init_solver/4, clpbn_init_solver/4,
clpbn_run_solver/3, clpbn_run_solver/3,
pfl_init_solver/6, pfl_init_solver/5,
pfl_run_solver/4, pfl_run_solver/3,
clpbn_finalize_solver/1, clpbn_finalize_solver/1,
clpbn_init_solver/5,
clpbn_run_solver/4,
clpbn_init_graph/1, clpbn_init_graph/1,
probability/2, probability/2,
conditional_probability/3, conditional_probability/3,
@ -589,21 +587,26 @@ clpbn_run_solver(pcg, LVs, LPs, State) :-
% %
% This is a routine to start a solver, called by the learning procedures (ie, em). % This is a routine to start a solver, called by the learning procedures (ie, em).
% %
pfl_init_solver(QueryKeys, AllKeys, Factors, Evidence, State, ve) :-
pfl_init_solver(QueryKeys, AllKeys, Factors, Evidence, State) :-
solver(Solver),
pfl_init_solver(QueryKeys, AllKeys, Factors, Evidence, State, Solver).
pfl_init_solver(QueryKeys, AllKeys, Factors, Evidence, State, ve) :- !,
init_ve_ground_solver(QueryKeys, AllKeys, Factors, Evidence, State). init_ve_ground_solver(QueryKeys, AllKeys, Factors, Evidence, State).
pfl_init_solver(QueryKeys, AllKeys, Factors, Evidence, State, bdd) :- pfl_init_solver(QueryKeys, AllKeys, Factors, Evidence, State, bdd) :- !,
init_bdd_ground_solver(QueryKeys, AllKeys, Factors, Evidence, State). init_bdd_ground_solver(QueryKeys, AllKeys, Factors, Evidence, State).
pfl_init_solver(QueryKeys, AllKeys, Factors, Evidence, State, hve) :- pfl_init_solver(QueryKeys, AllKeys, Factors, Evidence, State, hve) :- !,
clpbn_horus:set_horus_flag(ground_solver, ve), clpbn_horus:set_horus_flag(ground_solver, ve),
init_horus_ground_solver(QueryKeys, AllKeys, Factors, Evidence, State). init_horus_ground_solver(QueryKeys, AllKeys, Factors, Evidence, State).
pfl_init_solver(QueryKeys, AllKeys, Factors, Evidence, State, bp) :- pfl_init_solver(QueryKeys, AllKeys, Factors, Evidence, State, bp) :- !,
clpbn_horus:set_horus_flag(ground_solver, bp), clpbn_horus:set_horus_flag(ground_solver, bp),
init_horus_ground_solver(QueryKeys, AllKeys, Factors, Evidence, State). init_horus_ground_solver(QueryKeys, AllKeys, Factors, Evidence, State).
pfl_init_solver(QueryKeys, AllKeys, Factors, Evidence, State, cbp) :- pfl_init_solver(QueryKeys, AllKeys, Factors, Evidence, State, cbp) :- !,
clpbn_horus:set_horus_flag(ground_solver, cbp), clpbn_horus:set_horus_flag(ground_solver, cbp),
init_horus_ground_solver(QueryKeys, AllKeys, Factors, Evidence, State). init_horus_ground_solver(QueryKeys, AllKeys, Factors, Evidence, State).
@ -612,19 +615,23 @@ pfl_init_solver(_, _, _, _, _, Solver) :-
write(Solver), write(Solver),
write('\' cannot be used for learning'). write('\' cannot be used for learning').
pfl_run_solver(LVs, LPs, State, ve) :- pfl_run_solver(LVs, LPs, State) :-
solver(Solver),
pfl_run_solver(LVs, LPs, State, Solver).
pfl_run_solver(LVs, LPs, State, ve) :- !,
run_ve_ground_solver(LVs, LPs, State). run_ve_ground_solver(LVs, LPs, State).
pfl_run_solver(LVs, LPs, State, bdd) :- pfl_run_solver(LVs, LPs, State, bdd) :- !,
run_bdd_ground_solver(LVs, LPs, State). run_bdd_ground_solver(LVs, LPs, State).
pfl_run_solver(LVs, LPs, State, hve) :- pfl_run_solver(LVs, LPs, State, hve) :- !,
run_horus_ground_solver(LVs, LPs, State). run_horus_ground_solver(LVs, LPs, State).
pfl_run_solver(LVs, LPs, State, bp) :- pfl_run_solver(LVs, LPs, State, bp) :- !,
run_horus_ground_solver(LVs, LPs, State). run_horus_ground_solver(LVs, LPs, State).
pfl_run_solver(LVs, LPs, State, cbp) :- pfl_run_solver(LVs, LPs, State, cbp) :- !,
run_horus_ground_solver(LVs, LPs, State). run_horus_ground_solver(LVs, LPs, State).

View File

@ -15,10 +15,10 @@
:- use_module(library(clpbn), :- use_module(library(clpbn),
[clpbn_init_graph/1, [clpbn_init_graph/1,
clpbn_init_solver/5, clpbn_init_solver/4,
clpbn_run_solver/4, clpbn_run_solver/3,
pfl_init_solver/6, pfl_init_solver/5,
pfl_run_solver/4, pfl_run_solver/3,
clpbn_finalize_solver/1, clpbn_finalize_solver/1,
conditional_probability/3, conditional_probability/3,
clpbn_flag/2]). clpbn_flag/2]).
@ -110,7 +110,7 @@ setup_em_network(Items, Solver, state( AllDists, AllDistInstances, MargKeys, Sol
% get the EM CPT connections info from the factors % get the EM CPT connections info from the factors
generate_dists(Factors, EList, AllDists, AllDistInstances, MargKeys), generate_dists(Factors, EList, AllDists, AllDistInstances, MargKeys),
% setup solver, if necessary % setup solver, if necessary
pfl_init_solver(MargKeys, Keys, Factors, EList, SolverState, Solver). pfl_init_solver(MargKeys, Keys, Factors, EList, SolverState).
setup_em_network(Items, Solver, state( AllDists, AllDistInstances, MargVars, SolverVars)) :- setup_em_network(Items, Solver, state( AllDists, AllDistInstances, MargVars, SolverVars)) :-
% create the ground network % create the ground network
call_run_all(Items), call_run_all(Items),
@ -121,7 +121,7 @@ setup_em_network(Items, Solver, state( AllDists, AllDistInstances, MargVars, Sol
% remove variables that do not have to do with this query. % remove variables that do not have to do with this query.
different_dists(AllVars, AllDists, AllDistInstances, MargVars), different_dists(AllVars, AllDists, AllDistInstances, MargVars),
% setup solver by doing parameter independent work. % setup solver by doing parameter independent work.
clpbn_init_solver(Solver, MargVars, AllVars, _, SolverVars). clpbn_init_solver(MargVars, AllVars, _, SolverVars).
run_examples(user:Exs, Keys, Factors, EList) :- run_examples(user:Exs, Keys, Factors, EList) :-
Exs = [_:_|_], !, Exs = [_:_|_], !,
@ -297,11 +297,9 @@ compact_mvars([X|MargVars], [X|CMVars]) :- !,
estimate(state(_, _, Margs, SolverState), LPs) :- estimate(state(_, _, Margs, SolverState), LPs) :-
clpbn:use_parfactors(on), !, clpbn:use_parfactors(on), !,
clpbn_flag(em_solver, Solver), pfl_run_solver(Margs, LPs, SolverState).
pfl_run_solver(Margs, LPs, SolverState, Solver).
estimate(state(_, _, Margs, SolverState), LPs) :- estimate(state(_, _, Margs, SolverState), LPs) :-
clpbn_flag(em_solver, Solver), clpbn_run_solver(Margs, LPs, SolverState).
clpbn_run_solver(Solver, Margs, LPs, SolverState).
maximise(state(_,DistInstances,MargVars,_), Tables, LPs, Likelihood) :- maximise(state(_,DistInstances,MargVars,_), Tables, LPs, Likelihood) :-
rb_new(MDistTable0), rb_new(MDistTable0),

View File

@ -23,8 +23,8 @@
[clpbn_flag/2 as pfl_flag, [clpbn_flag/2 as pfl_flag,
set_clpbn_flag/2 as set_pfl_flag, set_clpbn_flag/2 as set_pfl_flag,
conditional_probability/3, conditional_probability/3,
pfl_init_solver/6, pfl_init_solver/5,
pfl_run_solver/4]). pfl_run_solver/3]).
:- reexport(library(clpbn/horus), :- reexport(library(clpbn/horus),
[set_solver/1]). [set_solver/1]).