make EM work with PFL and BP.

This commit is contained in:
Costa Vitor 2012-08-15 16:01:45 -05:00
parent 020692635b
commit a76f4f34d5
8 changed files with 170 additions and 85 deletions

View File

@ -94,7 +94,8 @@ CLPBN_HMMER_EXAMPLES= \
$(CLPBN_EXDIR)/HMMer/score.yap
CLPBN_LEARNING_EXAMPLES= \
$(CLPBN_EXDIR)/learning/school_params.yap \
$(CLPBN_EXDIR)/learning/profz_params.pfl \
$(CLPBN_EXDIR)/learning/school_params.pfl \
$(CLPBN_EXDIR)/learning/sprinkler_params.yap \
$(CLPBN_EXDIR)/learning/train.yap

View File

@ -7,9 +7,10 @@
% [S \= s2])
:- module(clpbn_ground_factors, [
generate_networks/5,
generate_network/5]).
:- module(pfl_ground_factors, [
generate_network/5,
f/3
]).
:- use_module(library(bhash), [
b_hash_new/1,
@ -30,13 +31,13 @@
:- use_module(library(clpbn/dists), [
dist/4]).
:- dynamic currently_defined/1, f/4.
:- dynamic currently_defined/1, queue/1, f/4.
%
% as you add query vars the network grows
% until you reach the last variable.
%
generate_networks(QueryVars, QueryKeys, Keys, Factors, EList) :-
generate_network(QueryVars, QueryKeys, Keys, Factors, EList) :-
init_global_search,
attributes:all_attvars(AVars),
b_hash_new(Evidence0),
@ -44,14 +45,16 @@ generate_networks(QueryVars, QueryKeys, Keys, Factors, EList) :-
b_hash_to_list(Evidence, EList0), list_to_evlist(EList0, EList),
run_through_evidence(EList),
run_through_queries(QueryVars, QueryKeys, Evidence),
propagate,
collect(Keys, Factors).
%
% clean global stateq
%
init_global_search :-
retractall(queue(_)),
retractall(currently_defined(_)),
retractall(f(_,_,_,_)).
retractall(f(_,_,_)).
list_to_evlist([], []).
list_to_evlist([K-E|EList0], [K=E|EList]) :-
@ -90,17 +93,6 @@ run_through_queries([QVars|QueryVars], [GKs|GKeys], E) :-
run_through_queries(QueryVars, GKeys, E).
run_through_queries([], [], _).
generate_network(QueryVars0, QueryKeys, Keys, Factors, EList) :-
init_global_search,
attributes:all_attvars(AVars),
b_hash_new(Evidence0),
include_evidence(AVars, Evidence0, Evidence),
b_hash_to_list(Evidence, EList0), list_to_evlist(EList0, EList),
run_through_evidence(EList),
run_through_query(QueryVars0, QueryKeys, Evidence),
collect(Keys,Factors),
writeln(gn:Keys:QueryKeys:Factors:EList).
run_through_query([], [], _).
run_through_query([V|QueryVars], QueryKeys, Evidence) :-
clpbn:get_atts(V,[key(K)]),
@ -108,16 +100,16 @@ run_through_query([V|QueryVars], QueryKeys, Evidence) :-
run_through_query(QueryVars, QueryKeys, Evidence).
run_through_query([V|QueryVars], [K|QueryKeys], Evidence) :-
clpbn:get_atts(V,[key(K)]),
( find_factors(K), fail ; true ),
queue_in(K),
run_through_query(QueryVars, QueryKeys, Evidence).
collect(Keys, Factors) :-
findall(K, currently_defined(K), Keys),
findall(f(FType,FId,FKeys,FCPT), f(FType,FId,FKeys,FCPT), Factors).
findall(f(FType,FId,FKeys), f(FType,FId,FKeys), Factors).
run_through_evidence([]).
run_through_evidence([K=_|_]) :-
find_factors(K),
queue_in(K),
fail.
run_through_evidence([_|Ev]) :-
run_through_evidence(Ev).
@ -141,26 +133,48 @@ initialize_evidence([]).
initialize_evidence([V|EVars]) :-
clpbn:get_atts(V, [key(K)]),
ground(K),
assert(currently_defined(K)),
queue_in(K),
initialize_evidence(EVars).
%
% gets key K, and collects factors that define it
find_factors(K) :-
queue_in(K) :-
queue(K), !.
queue_in(K) :-
%writeln(+K),
assert(queue(K)).
propagate :-
retract(queue(K)),!,
do_propagate(K).
propagate.
do_propagate(K) :-
%writeln(-K),
\+ currently_defined(K),
( ground(K) -> assert(currently_defined(K)) ; true),
defined_in_factor(K, ParFactor),
add_factor(ParFactor, Ks),
(
defined_in_factor(K, ParFactor),
add_factor(ParFactor, Ks)
*->
true
;
throw(error(no_defining_factor(K)))
)
,
member(K1, Ks),
\+ currently_defined(K1),
find_factors(K1).
queue_in(K1),
fail.
do_propagate(K) :-
propagate.
add_factor(factor(Type, Id, Ks, _, Phi, Constraints), Ks) :-
F = f(Type, Id, Ks, CPT),
( is_list(Phi) -> CPT = Phi ; call(user:Phi, CPT) ),
run(Constraints),
\+ f(Type, Id, Ks, CPT),
assert(F).
run(Constraints), !,
\+ f(Type, Id, Ks),
assert(f(Type, Id, Ks)).
run([Goal|Goals]) :-
call(user:Goal),

View File

@ -31,7 +31,7 @@
]).
:- use_module(library('clpbn/ground_factors'),
[generate_networks/5
[generate_network/5
]).
:- use_module(library('clpbn/display'),
@ -66,7 +66,7 @@ call_horus_ground_solver_for_probabilities(QueryKeys, _AllKeys, Factors, Evidenc
keys_to_ids(AllKeys, 0, Id1, Hash0, Hash1),
get_factors_type(Factors, Type),
evidence_to_ids(Evidence, Hash1, Hash2, Id1, Id2, EvidenceIds),
%writeln(evidence:Evidence:EvidenceIds),
%writeln(evidence:Evidence:EvidenceIds),
factors_to_ids(Factors, Hash2, Hash3, Id2, Id3, FactorIds),
%writeln(queryKeys:QueryKeys), writeln(''),
%% writeln(type:Type), writeln(''),
@ -74,10 +74,10 @@ call_horus_ground_solver_for_probabilities(QueryKeys, _AllKeys, Factors, Evidenc
sort(AllKeys,SKeys), %% writeln(allSortedKeys:SKeys), writeln(''),
keys_to_ids(SKeys, Id3, Id4, Hash3, Hash4),
%b_hash:b_hash_to_list(Hash1,_L4), writeln(h1:_L4),
writeln(factors:Factors), writeln(''),
writeln(factorIds:FactorIds), writeln(''),
writeln(evidence:Evidence), writeln(''),
writeln(evidenceIds:EvidenceIds), writeln(''),
%writeln(factors:Factors), writeln(''),
%writeln(factorIds:FactorIds), writeln(''),
%writeln(evidence:Evidence), writeln(''),
%writeln(evidenceIds:EvidenceIds), writeln(''),
cpp_create_ground_network(Type, FactorIds, EvidenceIds, Network),
get_vars_information(AllKeys, StatesNames),
terms_to_atoms(AllKeys, KeysAtoms),
@ -119,8 +119,8 @@ keys_to_ids([Key|AllKeys], I0, I, Hash0, Hash) :-
get_factors_type([f(bayes, _, _, _)|_], bayes) :- ! .
get_factors_type([f(markov, _, _, _)|_], markov) :- ! .
get_factors_type([f(bayes, _, _)|_], bayes) :- ! .
get_factors_type([f(markov, _, _)|_], markov) :- ! .
list_of_keys_to_ids([], H, H, I, I, []).
@ -138,8 +138,9 @@ list_of_keys_to_ids([Key|QueryKeys], Hash0, Hash, I0, I, [I0|QueryIds]) :-
factors_to_ids([], H, H, I, I, []).
factors_to_ids([f(_, DistId, Keys, CPT)|Fs], Hash0, Hash, I0, I, [f(Ids, Ranges, CPT, DistId)|NFs]) :-
factors_to_ids([f(_, DistId, Keys)|Fs], Hash0, Hash, I0, I, [f(Ids, Ranges, CPT, DistId)|NFs]) :-
list_of_keys_to_ids(Keys, Hash0, Hash1, I0, I1, Ids),
pfl:get_pfl_parameters(DistId, CPT),
get_ranges(Keys, Ranges),
factors_to_ids(Fs, Hash1, Hash, I1, I, NFs).
@ -180,17 +181,18 @@ finalize_horus_ground_solver(bp(Network, _)) :-
% QVars: all query variables?
%
%
init_horus_ground_solver(QueryVars, _AllVars, _, horus(GKeys, Keys, Factors, Evidence)) :-
trace,
generate_networks(QueryVars, GKeys, Keys, Factors, Evidence), !.
% writeln(qvs:QueryVars),
% writeln(Keys), writeln(Factors), !.
init_horus_ground_solver(QueryVars, _AllVars, Ground, horus(GKeys, Keys, Factors, Evidence)) :-
(
var(GKeys) ->
Ground = ground(GKeys, Keys, Factors, Evidence)
;
generate_network(QueryVars, GKeys, Keys, Factors, Evidence)
).
%
% just call horus solver.
%
run_horus_ground_solver(_QueryVars, Solutions, horus(GKeys, Keys, Factors, Evidence) ) :- !,
trace,
call_horus_ground_solver_for_probabilities(GKeys, Keys, Factors, Evidence, Solutions).
%bp([[]],_,_) :- !.

View File

@ -16,20 +16,21 @@
% with \phi defined by abi_table(X) and whose domain and constraints
% is obtained from professor/1.
%
bayes abi(K)::[h,m,l] ; abi_table ; [professor(K)].
bayes pop(K)::[h,m,l], abi(K) ; pop_table ; [professor(K)].
bayes grade(C,S)::[a,b,c,d], int(S), diff(C) ; grade_table ; [registration(_,C,S)].
bayes sat(C,S,P)::[h,m,l], abi(P), grade(C,S) ; sat_table ; [reg_sat(C,S,P)].
bayes rat(C) :: [h,m,l], avg(Sats) ; avg ; [course_rating(C, Sats)].
bayes diff(C) :: [h,m,l] ; diff_table ; [course(C,_)].
bayes int(S) :: [h,m,l] ; int_table ; [student(S)].
bayes grade(C,S)::[a,b,c,d], int(S), diff(C) ; grade_table ; [registration(_,C,S)].
bayes satisfaction(C,S)::[h,m,l], abi(P), grade(C,S) ; sat_table ; [reg_satisfaction(C,S,P)].
bayes rat(C) :: [h,m,l], avg(Sats) ; avg ; [course_rating(C, Sats)].
bayes rank(S) :: [a,b,c,d], avg(Grades) ; avg ; [student_ranking(S,Grades)].
@ -37,14 +38,14 @@ grade(Key, Grade) :-
registration(Key, CKey, SKey),
grade(CKey, SKey, Grade).
reg_sat(CKey, SKey, PKey) :-
reg_satisfaction(CKey, SKey, PKey) :-
registration(_Key, CKey, SKey),
course(CKey, PKey).
course_rating(CKey, Sats) :-
course(CKey, _),
setof(sat(CKey,SKey,PKey),
reg_sat(CKey, SKey, PKey),
setof(satisfaction(CKey,SKey,PKey),
reg_satisfaction(CKey, SKey, PKey),
Sats).
student_ranking(SKey, Grades) :-
@ -53,13 +54,31 @@ student_ranking(SKey, Grades) :-
:- ensure_loaded(tables).
% convert to longer names
professor_ability(P,A) :- abi(P, A).
professor_popularity(P,A) :- pop(P, A).
registration_grade(R,A) :-
registration(R,C,S),
grade(C,S,A).
registration_satisfaction(R,A) :-
registration(R,C,S),
satisfaction(C,S,A).
student_intelligence(P,A) :- int(P, A).
course_difficulty(P,A) :- diff(P, A).
%
% evidence
%
abi(p0, h).
%abi(p0, h).
pop(p1, m).
pop(p2, h).
%pop(p1, m).
%pop(p2, h).
% Query
% ?- abi(p0, X).

View File

@ -7,7 +7,7 @@ total_students(256).
*/
:- use_module(library(clpbn)).
:- use_module(library(pfl)).
:- source.
@ -17,7 +17,7 @@ total_students(256).
:- yap_flag(write_strings,on).
:- ensure_loaded(schema).
:- ensure_loaded(parschema).
:- ensure_loaded(school32_data).

View File

@ -12,16 +12,17 @@ abi_table([0.3,0.3,0.4]).
pop_table([0.3,0.3,0.4,0.3,0.3,0.4,0.3,0.3,0.4]).
goal_list([abi(p0,h),
goal_list([/*abi(p0,h),
abi(p1,m),
abi(p2,m),
abi(p3,m),
abi(p3,m),*/
abi(p4,l),
pop(p5,h),
abi(p5,_),
abi(p6,_),
pop(p7,_)]).
professor(p0).
professor(p1).
professor(p2).
professor(p3).

View File

@ -24,6 +24,15 @@
randomise_all_dists/0,
uniformise_all_dists/0]).
:- use_module(library(clpbn/ground_factors),
[generate_network/5,
f/3]).
:- use_module(library(bhash), [
b_hash_new/1,
b_hash_lookup/3,
b_hash_insert/4]).
:- use_module(library('clpbn/learning/learn_utils'),
[run_all/1,
clpbn_vars/2,
@ -61,9 +70,11 @@ em(_, _, _, Tables, Likelihood) :-
retract(em_found(Tables, Likelihood)).
handle_em(error(repeated_parents)) :-
handle_em(error(repeated_parents)) :- !,
assert(em_found(_, -inf)),
fail.
handle_em(Error) :-
throw(Error).
% This gets you an initial configuration. If there is a lot of evidence
% tables may be filled in close to optimal, otherwise they may be
@ -74,7 +85,7 @@ handle_em(error(repeated_parents)) :-
% it includes the list of variables without evidence,
% the list of distributions for which we want to compute parameters,
% and more detailed info on distributions, namely with a list of all instances for the distribution.
init_em(Items, state( AllDists, AllDistInstances, MargVars, SolverVars)) :-
init_em(Items, State) :-
clpbn_flag(em_solver, Solver),
% only used for PCGs
clpbn_init_graph(Solver),
@ -83,12 +94,27 @@ init_em(Items, state( AllDists, AllDistInstances, MargVars, SolverVars)) :-
% randomise_all_dists,
% set initial values for distributions
uniformise_all_dists,
% get all variablews to marginalise
setup_em_network(Solver, State).
setup_em_network(Solver, state( AllDists, AllDistInstances, MargVars, SolverState)) :-
clpbn:use_parfactors(on), !,
% get all variables to marginalise
attributes:all_attvars(AllVars0),
% and order them
sort_vars_by_key(AllVars0,AllVars,[]),
% no, we are in trouble because we don't know the network yet.
% get the ground network
generate_network([AllVars], _, Keys, Factors, EList),
% get the EM CPT connections info from the factors
generate_dists(Factors, EList, AllDists, AllDistInstances, MargVars),
% setup solver, if necessary
clpbn_init_solver(Solver, MargVars, _AllVars, ground(MargVars, Keys, Factors, EList), SolverState).
setup_em_network(Solver, state( AllDists, AllDistInstances, MargVars, SolverVars)) :-
% get all variables to marginalise
attributes:all_attvars(AllVars0),
% and order them
sort_vars_by_key(AllVars0,AllVars,[]),
% remove variables that do not have to do with this query.
% check_for_hidden_vars(AllVars1, AllVars1, AllVars),
different_dists(AllVars, AllDists, AllDistInstances, MargVars),
% setup solver by doing parameter independent work.
clpbn_init_solver(Solver, MargVars, AllVars, _, SolverVars).
@ -97,7 +123,8 @@ init_em(Items, state( AllDists, AllDistInstances, MargVars, SolverVars)) :-
em_loop(Its, Likelihood0, State, MaxError, MaxIts, LikelihoodF, FTables) :-
estimate(State, LPs),
maximise(State, Tables, LPs, Likelihood),
writeln(iteration:Its:Likelihood:Its:Likelihood0:Tables),
ltables(Tables, F0Tables),
writeln(iteration:Its:Likelihood:Its:Likelihood0:F0Tables),
(
(
abs((Likelihood - Likelihood0)/Likelihood) < MaxError
@ -118,6 +145,41 @@ ltables([Id-T|Tables], [Key-LTable|FTables]) :-
get_dist_key(Id, Key),
ltables(Tables, FTables).
generate_dists(Factors, EList, AllDists, AllInfo, MargVars) :-
b_hash_new(Ev0),
elist_to_hash(EList, Ev0, Ev),
process_factors(Factors, Ev, Dists0),
sort(Dists0, Dists1),
group(Dists1, AllDists, AllInfo, MargVars0, []),
sort(MargVars0, MargVars).
elist_to_hash([], Ev, Ev).
elist_to_hash([K=V|EList], Ev0, Ev) :-
b_hash_insert(Ev0, K, V, Evi),
elist_to_hash(EList, Evi, Ev).
process_factors([], _Ev, []).
process_factors([f(bayes,Id,Ks)|Factors], Ev, [i(Id, Ks, Cases, NonEvs)|AllDistInstances]) :-
fetch_evidence(Ks, Ev, CompactCases, NonEvs),
uncompact_cases(CompactCases, Cases),
process_factors(Factors, Ev, AllDistInstances).
fetch_evidence([], _Ev, [], []).
fetch_evidence([K|Ks], Ev, [E|CompactCases], NonEvs) :-
b_hash_lookup(K, E, Ev), !,
fetch_evidence(Ks, Ev, CompactCases, NonEvs).
fetch_evidence([K|Ks], Ev, [Ns|CompactCases], [K|NonEvs]) :-
pfl:skolem(K,D),
domain_to_numbers(D,0,Ns),
fetch_evidence(Ks, Ev, CompactCases, NonEvs).
domain_to_numbers([],_,[]).
domain_to_numbers([_|D],I0,[I0|Ns]) :-
I is I0+1,
domain_to_numbers(D,I,Ns).
% collect the different dists we are going to learn next.
different_dists(AllVars, AllDists, AllInfo, MargVars) :-
all_dists(AllVars, AllVars, Dists0),
@ -134,14 +196,6 @@ different_dists(AllVars, AllDists, AllInfo, MargVars) :-
% Hiddens will be C
%
all_dists([], _, []).
all_dists([V|AllVars], AllVars0, [i(Id, [V|Parents], Cases, Hiddens)|Dists]) :-
clpbn:use_parfactors(on), !,
clpbn:get_atts(V, [key(K)]),
pfl:factor(bayes,Id,[K|PKeys],_,_,_),
find_variables(PKeys, AllVars0, Parents),
generate_hidden_cases([V|Parents], CompactCases, Hiddens),
uncompact_cases(CompactCases, Cases),
all_dists(AllVars, AllVars0, Dists).
all_dists([V|AllVars], AllVars0, [i(Id, [V|Parents], Cases, Hiddens)|Dists]) :-
% V is an instance of Id
clpbn:get_atts(V, [dist(Id,Parents)]),
@ -181,18 +235,12 @@ generate_hidden_cases([], [], []).
generate_hidden_cases([V|Parents], [P|Cases], Hiddens) :-
clpbn:get_atts(V, [evidence(P)]), !,
generate_hidden_cases(Parents, Cases, Hiddens).
generate_hidden_cases([V|Parents], [Cases|MoreCases], [V|Hiddens]) :-
clpbn:use_parfactors(on), !,
clpbn:get_atts(V, [key(K)]),
pfl:skolem(K,D), length(D,Sz),
gen_cases(0, Sz, Cases),
generate_hidden_cases(Parents, MoreCases, Hiddens).
generate_hidden_cases([V|Parents], [Cases|MoreCases], [V|Hiddens]) :-
clpbn:get_atts(V, [dist(Id,_)]),
get_dist_domain_size(Id, Sz),
gen_cases(0, Sz, Cases),
generate_hidden_cases(Parents, MoreCases, Hiddens).
gen_cases(Sz, Sz, []) :- !.
gen_cases(I, Sz, [I|Cases]) :-
I1 is I+1,