fix underflows when computing marginals.

This commit is contained in:
Vitor Santos Costa 2008-11-04 03:55:49 +00:00
parent 0dc7d3492d
commit 13fb5d5156
5 changed files with 44 additions and 3 deletions

View File

@ -37,7 +37,7 @@
matrix_sum_logs_out/3, matrix_sum_logs_out/3,
matrix_sum_logs_out_several/3, matrix_sum_logs_out_several/3,
matrix_op_to_all/4, matrix_op_to_all/4,
matrix_to_exps/1, matrix_to_exps2/1,
matrix_to_logs/1, matrix_to_logs/1,
matrix_set_all_that_disagree/5, matrix_set_all_that_disagree/5,
matrix_to_list/2, matrix_to_list/2,
@ -163,7 +163,7 @@ expand_tabs([V1|Deps1], [S1|Sz1], [V2|Deps2], [S2|Sz2], Map1, Map2, NDeps) :-
). ).
normalise_CPT(MAT,NMAT) :- normalise_CPT(MAT,NMAT) :-
matrix_to_exps(MAT), matrix_to_exps2(MAT),
matrix_sum(MAT, Sum), matrix_sum(MAT, Sum),
matrix_op_to_all(MAT, /, Sum, NMAT). matrix_op_to_all(MAT, /, Sum, NMAT).

View File

@ -102,8 +102,19 @@ run_vel_solver([LVs|MoreLVs], [Ps|MorePs], [NVs0|MoreLVis]) :-
list_from_CPT(Dist, LPs), list_from_CPT(Dist, LPs),
normalise_CPT(Dist,MPs), normalise_CPT(Dist,MPs),
list_from_CPT(MPs, Ps), list_from_CPT(MPs, Ps),
lists:sumlist(Ps,SUM), ((SUM > 0.9 , SUM < 1.1) -> true ; writeln(LPs:Ps), get_els(LVs),writeln('--'), get_els(NVs0), abort),
run_vel_solver(MoreLVs, MorePs, MoreLVis). run_vel_solver(MoreLVs, MorePs, MoreLVis).
get_els([]).
get_els(V.NVs0) :-
clpbn:get_atts(V,[key(K),evidence(El)]), !,
writeln(K:El),
get_els(NVs0).
get_els(V.NVs0) :-
clpbn:get_atts(V,[key(K)]),
writeln(K),
get_els(NVs0).
% %
% just get a list of variables plus associated tables % just get a list of variables plus associated tables
% %

View File

@ -11,7 +11,7 @@ main :-
em(L,0.01,10,CPTs,Lik), em(L,0.01,10,CPTs,Lik),
writeln(Lik:CPTs). writeln(Lik:CPTs).
missing(0.1). missing(0.3).
% miss 30% of the examples. % miss 30% of the examples.
goal(professor_ability(P,V)) :- goal(professor_ability(P,V)) :-

View File

@ -71,6 +71,7 @@ typedef enum {
matrix_agg_cols/3, matrix_agg_cols/3,
matrix_to_logs/1, matrix_to_logs/1,
matrix_to_exps/1, matrix_to_exps/1,
matrix_to_exps2/1,
matrix_to_logs/2, matrix_to_logs/2,
matrix_to_exps/2, matrix_to_exps/2,
matrix_op/4, matrix_op/4,

View File

@ -1100,6 +1100,34 @@ matrix_exp_all(void)
return TRUE; return TRUE;
} }
static int
matrix_exp2_all(void)
{
int *mat;
mat = (int *)YAP_BlobOfTerm(YAP_ARG1);
if (!mat) {
/* Error */
return FALSE;
}
if (mat[MAT_TYPE] == INT_MATRIX) {
return FALSE;
} else {
double *data = matrix_double_data(mat, mat[MAT_NDIMS]);
int i;
double max = data[0];
for (i=1; i< mat[MAT_SIZE]; i++) {
if (data[i] > max) max = data[i];
}
for (i=0; i< mat[MAT_SIZE]; i++) {
data[i] = exp(data[i]-max);
}
}
return TRUE;
}
static int static int
matrix_exp_all2(void) matrix_exp_all2(void)
{ {
@ -2959,6 +2987,7 @@ init_matrix(void)
YAP_UserCPredicate("matrix_column", matrix_column, 3); YAP_UserCPredicate("matrix_column", matrix_column, 3);
YAP_UserCPredicate("matrix_to_logs", matrix_log_all,1); YAP_UserCPredicate("matrix_to_logs", matrix_log_all,1);
YAP_UserCPredicate("matrix_to_exps", matrix_exp_all, 1); YAP_UserCPredicate("matrix_to_exps", matrix_exp_all, 1);
YAP_UserCPredicate("matrix_to_exps2", matrix_exp2_all, 1);
YAP_UserCPredicate("matrix_to_logs", matrix_log_all2,2); YAP_UserCPredicate("matrix_to_logs", matrix_log_all2,2);
YAP_UserCPredicate("matrix_to_exps", matrix_exp_all2, 2); YAP_UserCPredicate("matrix_to_exps", matrix_exp_all2, 2);
YAP_UserCPredicate("matrix_sum_out", matrix_sum_out, 3); YAP_UserCPredicate("matrix_sum_out", matrix_sum_out, 3);