Use the avg CPT type.

How an average aggregate is expanded into tables is a solver problem, not an application issue.
This commit is contained in:
Vitor Santos Costa 2008-11-03 16:02:15 +00:00
parent 45df10e86d
commit f6c5d16f63
5 changed files with 91 additions and 54 deletions

View File

@ -1,9 +1,11 @@
%
% generate explicit CPTs
%
:- module(clpbn_aggregates, [ :- module(clpbn_aggregates, [
cpt_average/4, cpt_average/6,
cpt_average/5, cpt_average/7,
cpt_max/4, cpt_max/6,
cpt_min/4 cpt_min/6
]). ]).
:- use_module(library(clpbn), [{}/1]). :- use_module(library(clpbn), [{}/1]).
@ -11,6 +13,7 @@
:- use_module(library(lists), :- use_module(library(lists),
[last/2, [last/2,
sumlist/2, sumlist/2,
sum_list/3,
max_list/2, max_list/2,
min_list/2 min_list/2
]). ]).
@ -22,66 +25,77 @@
:- use_module(dists, [get_dist_domain_size/2]). :- use_module(dists, [get_dist_domain_size/2]).
cpt_average(Vars, Key, Els0, CPT) :- cpt_average(AllVars, Key, Els0, Tab, Vs, NewVs) :-
build_avg_table(Vars, Els0, Key, 1.0, CPT). cpt_average(AllVars, Key, Els0, 1.0, Tab, Vs, NewVs).
cpt_average(Vars, Key, Els0, Softness, CPT) :- % support variables with evidence from domain. This should make everyone's life easier.
build_avg_table(Vars, Els0, Key, Softness, CPT). cpt_average([_|Vars], Key, Els0, Softness, p(Els0, CPT, NewEls), Vs, NewVs) :-
find_evidence(Vars, 0, TotEvidence, RVars),
build_avg_table(RVars, Vars, Els0, Key, TotEvidence, Softness, MAT, NewEls, Vs, NewVs),
matrix_to_list(MAT, CPT).
cpt_max(Vars, Key, Els0, CPT) :- find_evidence([], TotEvidence, TotEvidence, []).
build_max_table(Vars, Els0, Els0, Key, 1.0, CPT). find_evidence([V|Vars], TotEvidence0, TotEvidence, RVars) :-
clpbn:get_atts(V,[evidence(Ev)]), !,
TotEvidenceI is TotEvidence0+Ev,
find_evidence(Vars, TotEvidenceI, TotEvidence, RVars).
find_evidence([V|Vars], TotEvidence0, TotEvidence, [V|RVars]) :-
find_evidence(Vars, TotEvidence0, TotEvidence, RVars).
cpt_min(Vars, Key, Els0, CPT) :- cpt_max([_|Vars], Key, Els0, CPT, Vs, NewVs) :-
build_min_table(Vars, Els0, Els0, Key, 1.0, CPT). build_max_table(Vars, Els0, Els0, Key, 1.0, CPT, Vs, NewVs).
build_avg_table(Vars, Domain, _, Softness, p(Domain, CPT, Vars)) :- cpt_min([_|Vars], Key, Els0, CPT, Vs, NewVs) :-
build_min_table(Vars, Els0, Els0, Key, 1.0, CPT, Vs, NewVs).
build_avg_table(Vars, OVars, Domain, _, TotEvidence, Softness, CPT, Vars, Vs, Vs) :-
length(Domain, SDomain), length(Domain, SDomain),
int_power(Vars, SDomain, 1, TabSize), int_power(Vars, SDomain, 1, TabSize),
TabSize =< 16, TabSize =< 256,
/* case gmp is not there !! */ /* case gmp is not there !! */
TabSize > 0, !, TabSize > 0, !,
average_cpt(Vars, Domain, Softness, CPT). average_cpt(Vars, OVars, Domain, TotEvidence, Softness, CPT).
build_avg_table(Vars, Domain, Key, Softness, p(Domain, CPT, [V1,V2])) :- build_avg_table(Vars, OVars, Domain, Key, TotEvidence, Softness, CPT, [V1,V2], Vs, [V1,V2|NewVs]) :-
length(Vars,L), length(Vars,L),
LL1 is L//2, LL1 is L//2,
LL2 is L-LL1, LL2 is L-LL1,
list_split(LL1, Vars, L1, L2), list_split(LL1, Vars, L1, L2),
Min = 0, Min = 0,
length(Domain,Max1), Max is Max1-1, length(Domain,Max1), Max is Max1-1,
build_intermediate_table(LL1, sum(Min,Max), L1, V1, Key, 1.0, 0, I1), build_intermediate_table(LL1, sum(Min,Max), L1, V1, Key, 1.0, 0, I1, Vs, Vs1),
build_intermediate_table(LL2, sum(Min,Max), L2, V2, Key, 1.0, I1, _), build_intermediate_table(LL2, sum(Min,Max), L2, V2, Key, 1.0, I1, _, Vs1, NewVs),
average_cpt([V1,V2], Domain, Softness, CPT). average_cpt([V1,V2], OVars, Domain, TotEvidence, Softness, CPT).
build_max_table(Vars, Domain, Softness, p(Domain, CPT, Vars)) :- build_max_table(Vars, Domain, Softness, p(Domain, CPT, Vars), Vs, Vs) :-
length(Domain, SDomain), length(Domain, SDomain),
int_power(Vars, SDomain, 1, TabSize), int_power(Vars, SDomain, 1, TabSize),
TabSize =< 16, TabSize =< 16,
/* case gmp is not there !! */ /* case gmp is not there !! */
TabSize > 0, !, TabSize > 0, !,
max_cpt(Vars, Domain, Softness, CPT). max_cpt(Vars, Domain, Softness, CPT).
build_max_table(Vars, Domain, Softness, p(Domain, CPT, [V1,V2])) :- build_max_table(Vars, Domain, Softness, p(Domain, CPT, [V1,V2]), Vs, [V1,V2|NewVs]) :-
length(Vars,L), length(Vars,L),
LL1 is L//2, LL1 is L//2,
LL2 is L-LL1, LL2 is L-LL1,
list_split(LL1, Vars, L1, L2), list_split(LL1, Vars, L1, L2),
build_intermediate_table(LL1, max(Domain,CPT), L1, V1, Key, 1.0, 0, I1), build_intermediate_table(LL1, max(Domain,CPT), L1, V1, Key, 1.0, 0, I1, Vs, Vs1),
build_intermediate_table(LL2, max(Domain,CPT), L2, V2, Key, 1.0, I1, _), build_intermediate_table(LL2, max(Domain,CPT), L2, V2, Key, 1.0, I1, _, Vs1, NewVs),
max_cpt([V1,V2], Domain, Softness, CPT). max_cpt([V1,V2], Domain, Softness, CPT).
build_min_table(Vars, Domain, Softness, p(Domain, CPT, Vars)) :- build_min_table(Vars, Domain, Softness, p(Domain, CPT, Vars), Vs, Vs) :-
length(Domain, SDomain), length(Domain, SDomain),
int_power(Vars, SDomain, 1, TabSize), int_power(Vars, SDomain, 1, TabSize),
TabSize =< 16, TabSize =< 16,
/* case gmp is not there !! */ /* case gmp is not there !! */
TabSize > 0, !, TabSize > 0, !,
min_cpt(Vars, Domain, Softness, CPT). min_cpt(Vars, Domain, Softness, CPT).
build_min_table(Vars, Domain, Softness, p(Domain, CPT, [V1,V2])) :- build_min_table(Vars, Domain, Softness, p(Domain, CPT, [V1,V2]), Vs, [V1,V2|NewVs]) :-
length(Vars,L), length(Vars,L),
LL1 is L//2, LL1 is L//2,
LL2 is L-LL1, LL2 is L-LL1,
list_split(LL1, Vars, L1, L2), list_split(LL1, Vars, L1, L2),
build_intermediate_table(LL1, min(Domain,CPT), L1, V1, Key, 1.0, 0, I1), build_intermediate_table(LL1, min(Domain,CPT), L1, V1, Key, 1.0, 0, I1, Vs, Vs1),
build_intermediate_table(LL2, min(Domain,CPT), L2, V2, Key, 1.0, I1, _), build_intermediate_table(LL2, min(Domain,CPT), L2, V2, Key, 1.0, I1, _, Vs1, NewVs),
min_cpt([V1,V2], Domain, Softness, CPT). min_cpt([V1,V2], Domain, Softness, CPT).
int_power([], _, TabSize, TabSize). int_power([], _, TabSize, TabSize).
@ -89,17 +103,17 @@ int_power([_|L], X, I0, TabSize) :-
I is I0*X, I is I0*X,
int_power(L, X, I, TabSize). int_power(L, X, I, TabSize).
build_intermediate_table(1,_,[V],V, _, _, I, I) :- !. build_intermediate_table(1,_,[V],V, _, _, I, I, Vs, Vs) :- !.
build_intermediate_table(2, Op, [V1,V2], V, Key, Softness, I0, If) :- !, build_intermediate_table(2, Op, [V1,V2], V, Key, Softness, I0, If, Vs, Vs) :- !,
If is I0+1, If is I0+1,
generate_tmp_random(Op, 2, [V1,V2], V, Key, Softness, I0). generate_tmp_random(Op, 2, [V1,V2], V, Key, Softness, I0).
build_intermediate_table(N, Op, L, V, Key, Softness, I0, If) :- build_intermediate_table(N, Op, L, V, Key, Softness, I0, If, Vs, [V1,V2|NewVs]) :-
LL1 is N//2, LL1 is N//2,
LL2 is N-LL1, LL2 is N-LL1,
list_split(LL1, L, L1, L2), list_split(LL1, L, L1, L2),
I1 is I0+1, I1 is I0+1,
build_intermediate_table(LL1, Op, L1, V1, Key, Softness, I1, I2), build_intermediate_table(LL1, Op, L1, V1, Key, Softness, I1, I2, Vs, Vs1),
build_intermediate_table(LL2, Op, L2, V2, Key, Softness, I2, If), build_intermediate_table(LL2, Op, L2, V2, Key, Softness, I2, If, Vs1, NewVs),
generate_tmp_random(Op, N, [V1,V2], V, Key, Softness, I0). generate_tmp_random(Op, N, [V1,V2], V, Key, Softness, I0).
% averages are transformed into sums. % averages are transformed into sums.
@ -129,26 +143,26 @@ list_split(I, [H|L], [H|L1], L2) :-
% generate actual table, instead of trusting the solver % generate actual table, instead of trusting the solver
% %
average_cpt(Vs,Vals,_,CPT) :- average_cpt(Vs, OVars, Vals, Base, _, MCPT) :-
get_ds_lengths(Vs,Lengs), get_ds_lengths(Vs,Lengs),
sumlist(Lengs, Tot), length(OVars, N),
length(Vals,SVals), length(Vals, SVals),
Tot is (N-1)*SVals,
Factor is SVals/Tot, Factor is SVals/Tot,
matrix_new(floats,[SVals|Lengs],MCPT), matrix_new(floats,[SVals|Lengs],MCPT),
fill_in_average(Lengs,Factor,MCPT), fill_in_average(Lengs,Factor,Base,MCPT).
matrix_to_list(MCPT,CPT).
get_ds_lengths([],[]). get_ds_lengths([],[]).
get_ds_lengths([V|Vs],[Sz|Lengs]) :- get_ds_lengths([V|Vs],[Sz|Lengs]) :-
get_vdist_size(V, Sz), get_vdist_size(V, Sz),
get_ds_lengths(Vs,Lengs). get_ds_lengths(Vs,Lengs).
fill_in_average(Lengs,SVals,MCPT) :- fill_in_average(Lengs, SVals, Base, MCPT) :-
generate(Lengs, Case), generate(Lengs, Case),
average(Case, SVals, Val), average(Case, SVals, Base, Val),
matrix_set(MCPT,[Val|Case],1.0), matrix_set(MCPT,[Val|Case],1.0),
fail. fail.
fill_in_average(_,_,_). fill_in_average(_,_,_,_).
generate([], []). generate([], []).
generate([N|Lengs], [C|Case]) :- generate([N|Lengs], [C|Case]) :-
@ -161,8 +175,8 @@ from(I1,M,J) :-
I < M, I < M,
from(I,M,J). from(I,M,J).
average(Case, SVals, Val) :- average(Case, SVals, Base, Val) :-
sumlist(Case, Tot), sum_list(Case, Base, Tot),
Val is integer(round(Tot*SVals)). Val is integer(round(Tot*SVals)).

View File

@ -68,6 +68,9 @@ where Id is the id,
DSize is the domain size, DSize is the domain size,
Type is Type is
tab for tabular tab for tabular
avg for average
max for maximum
min for minimum
trans for HMMs trans for HMMs
continuous continuous
Domain is Domain is
@ -98,6 +101,9 @@ dist(V, Id, Key, Parents) :-
dist(V, Id, Key, Parents) :- dist(V, Id, Key, Parents) :-
var(Key), !, var(Key), !,
when(Key, dist(V, Id, Key, Parents)). when(Key, dist(V, Id, Key, Parents)).
dist(avg(Domain, Parents), avg(Domain), _, Parents).
dist(max(Domain, Parents), max(Domain), _, Parents).
dist(min(Domain, Parents), min(Domain), _, Parents).
dist(p(Type, CPT), Id, Key, FParents) :- dist(p(Type, CPT), Id, Key, FParents) :-
copy_structure(Key, Key0), copy_structure(Key, Key0),
distribution(Type, CPT, Id, Key0, [], FParents). distribution(Type, CPT, Id, Key0, [], FParents).
@ -207,6 +213,8 @@ get_dsizes([P|Parents], [Sz|Sizes], Sizes0) :-
get_dist_params(Id, Parms) :- get_dist_params(Id, Parms) :-
recorded(clpbn_dist_db, db(Id, _, Parms, _, _, _, _), _). recorded(clpbn_dist_db, db(Id, _, Parms, _, _, _, _), _).
get_dist_domain_size(avg(D,_), DSize) :- !,
length(D, DSize).
get_dist_domain_size(Id, DSize) :- get_dist_domain_size(Id, DSize) :-
recorded(clpbn_dist_db, db(Id, _, _, _, _, _, DSize), _). recorded(clpbn_dist_db, db(Id, _, _, _, _, _, DSize), _).

View File

@ -47,8 +47,7 @@ course_professor(Key, PKey) :-
course_rating(CKey, Rat) :- course_rating(CKey, Rat) :-
setof(Sat, RKey^(registration_course(RKey,CKey), registration_satisfaction(RKey,Sat)), Sats), setof(Sat, RKey^(registration_course(RKey,CKey), registration_satisfaction(RKey,Sat)), Sats),
build_rating_table(Sats, rating(CKey), Table), { Rat = rating(CKey) with avg([h,m,l],Sats) }.
{ Rat = rating(CKey) with Table }.
course_difficulty(Key, Dif) :- course_difficulty(Key, Dif) :-
dif_table(Key, Dist), dif_table(Key, Dist),
@ -64,8 +63,7 @@ student_intelligence(Key, Int) :-
student_ranking(Key, Rank) :- student_ranking(Key, Rank) :-
setof(Grade, CKey^(registration_student(CKey,Key), setof(Grade, CKey^(registration_student(CKey,Key),
registration_grade(CKey, Grade)), Grades), registration_grade(CKey, Grade)), Grades),
build_grades_table(Grades, ranking(Key), GradesTable), { Rank = ranking(Key) with avg([a,b,c,d],Grades) }.
{ Rank = ranking(Key) with GradesTable }.
:- ensure_loaded(tables). :- ensure_loaded(tables).

View File

@ -31,12 +31,12 @@
:- use_module(library('clpbn/dists'), :- use_module(library('clpbn/dists'),
[ [
dist/4,
get_dist_domain_size/2, get_dist_domain_size/2,
get_dist_matrix/5]). get_dist_matrix/5]).
:- use_module(library('clpbn/utils'), [ :- use_module(library('clpbn/utils'), [
clpbn_not_var_member/2, clpbn_not_var_member/2]).
check_for_hidden_vars/3]).
:- use_module(library('clpbn/display'), [ :- use_module(library('clpbn/display'), [
clpbn_bind_vals/3]). clpbn_bind_vals/3]).
@ -60,6 +60,10 @@
append/3 append/3
]). ]).
:- use_module(library('clpbn/aggregates'),
[cpt_average/6]).
check_if_vel_done(Var) :- check_if_vel_done(Var) :-
get_atts(Var, [size(_)]), !. get_atts(Var, [size(_)]), !.
@ -70,14 +74,12 @@ vel([[]],_,_) :- !.
vel([LVs],Vs0,AllDiffs) :- vel([LVs],Vs0,AllDiffs) :-
init_vel_solver([LVs], Vs0, AllDiffs, State), init_vel_solver([LVs], Vs0, AllDiffs, State),
% variable elimination proper % variable elimination proper
run_vel_solver([LVs], [Ps], State), run_vel_solver([LVs], [LPs], State),
% from array to list
list_from_CPT(Ps, LPs),
% bind Probs back to variables so that they can be output. % bind Probs back to variables so that they can be output.
clpbn_bind_vals([LVs],[LPs],AllDiffs). clpbn_bind_vals([LVs],[LPs],AllDiffs).
init_vel_solver(Qs, Vs0, _, LVis) :- init_vel_solver(Qs, Vs0, _, LVis) :-
check_for_hidden_vars(Vs0, Vs0, Vs1), check_for_special_vars(Vs0, Vs1),
% LVi will have a list of CLPBN variables % LVi will have a list of CLPBN variables
% Tables0 will have the full data on each variable % Tables0 will have the full data on each variable
init_influences(Vs1, G, RG), init_influences(Vs1, G, RG),
@ -86,6 +88,21 @@ init_vel_solver(Qs, Vs0, _, LVis) :-
(clpbn:output(xbif(XBifStream)) -> clpbn2xbif(XBifStream,vel,Vs) ; true), (clpbn:output(xbif(XBifStream)) -> clpbn2xbif(XBifStream,vel,Vs) ; true),
(clpbn:output(gviz(XBifStream)) -> clpbn2gviz(XBifStream,vel,Vs,_) ; true). (clpbn:output(gviz(XBifStream)) -> clpbn2gviz(XBifStream,vel,Vs,_) ; true).
check_for_special_vars([], []).
check_for_special_vars([V|Vs0], [V|Vs1]) :-
clpbn:get_atts(V, [key(K), dist(Id,Parents)]), !,
simplify_dist(Id, V, K, Parents, Vs0, Vs00),
check_for_special_vars(Vs00, Vs1).
check_for_special_vars([_|Vs0], Vs1) :-
check_for_special_vars(Vs0, Vs1).
% transform aggregate distribution into tree
simplify_dist(avg(Domain), V, Key, Parents, Vs0, VsF) :- !,
cpt_average([V|Parents], Key, Domain, NewDist, Vs0, VsF),
dist(NewDist, Id, Key, ParentsF),
clpbn:put_atts(V, [dist(Id,ParentsF)]).
simplify_dist(_, _, _, _, Vs0, Vs0).
init_vel_solver_for_questions([], _, _, [], []). init_vel_solver_for_questions([], _, _, [], []).
init_vel_solver_for_questions([Vs|MVs], G, RG, [NVs|MNVs0], [NVs|LVis]) :- init_vel_solver_for_questions([Vs|MVs], G, RG, [NVs|MNVs0], [NVs|LVis]) :-
influences(Vs, _, NVs0, G, RG), influences(Vs, _, NVs0, G, RG),

View File

@ -11,7 +11,7 @@ main :-
em(L,0.01,10,CPTs,Lik), em(L,0.01,10,CPTs,Lik),
writeln(Lik:CPTs). writeln(Lik:CPTs).
missing(0.3). missing(0.1).
% fraction of the examples that go missing (given by missing/1 above).
goal(professor_ability(P,V)) :- goal(professor_ability(P,V)) :-