more fixes to improve aggregate performance:

handle evidence on children and root nodes.
make graphviz call on top-level.
This commit is contained in:
Vitor Santos Costa 2008-11-04 00:46:17 +00:00
parent f6c5d16f63
commit 8a9d66d001
7 changed files with 67 additions and 24 deletions

View File

@ -77,14 +77,17 @@
sort_vars_by_key/3
]).
:- use_module('clpbn/graphviz',
[clpbn2gviz/4]).
:- dynamic solver/1,output/1,use/1,suppress_attribute_display/1, parameter_softening/1, em_solver/1.
solver(jt).
em_solver(vel).
%output(xbif(user_error)).
%output(gviz(user_error)).
output(no).
output(gviz(user_error)).
%output(no).
suppress_attribute_display(false).
parameter_softening(laplace).
@ -166,6 +169,8 @@ project_attributes(GVars, AVars) :-
clpbn_vars(AVars, DiffVars, AllVars),
get_clpbn_vars(GVars,CLPBNGVars0),
simplify_query_vars(CLPBNGVars0, CLPBNGVars),
(output(xbif(XBifStream)) -> clpbn2xbif(XBifStream,vel,AllVars) ; true),
(output(gviz(XBifStream)) -> clpbn2gviz(XBifStream,sort,AllVars,GVars) ; true),
(
Solver = graphs
->

View File

@ -15,7 +15,8 @@
sumlist/2,
sum_list/3,
max_list/2,
min_list/2
min_list/2,
nth0/3
]).
:- use_module(library(matrix),
@ -23,17 +24,22 @@
matrix_to_list/2,
matrix_set/3]).
:- use_module(library('clpbn/matrix_cpt_utils'),
[normalise_CPT_on_lines/3]).
:- use_module(dists, [get_dist_domain_size/2]).
cpt_average(AllVars, Key, Els0, Tab, Vs, NewVs) :-
cpt_average(AllVars, Key, Els0, 1.0, Tab, Vs, NewVs).
% support variables with evidence from domain. This should make everyone's life easier.
cpt_average([_|Vars], Key, Els0, Softness, p(Els0, CPT, NewEls), Vs, NewVs) :-
cpt_average([Ev|Vars], Key, Els0, Softness, p(Els0, CPT, NewParents), Vs, NewVs) :-
find_evidence(Vars, 0, TotEvidence, RVars),
build_avg_table(RVars, Vars, Els0, Key, TotEvidence, Softness, MAT, NewEls, Vs, NewVs),
matrix_to_list(MAT, CPT).
build_avg_table(RVars, Vars, Els0, Key, TotEvidence, Softness, MAT0, NewParents0, Vs, IVs),
include_qevidence(Ev, MAT0, MAT, NewParents0, NewParents, Vs, IVs, NewVs),
matrix_to_list(MAT, CPT), writeln(NewParents: Vs: NewVs: CPT).
% find all fixed kids, this simplifies significantly the function.
find_evidence([], TotEvidence, TotEvidence, []).
find_evidence([V|Vars], TotEvidence0, TotEvidence, RVars) :-
clpbn:get_atts(V,[evidence(Ev)]), !,
@ -139,6 +145,35 @@ list_split(I, [H|L], [H|L1], L2) :-
I1 is I-1,
list_split(I1, L, L1, L2).
%
% if we have evidence, we need to check if we are always consistent, never consistent, or can be consistent
%
include_qevidence(V, MAT0, MAT, NewParents0, NewParents, Vs, IVs, NewVs) :-
clpbn:get_atts(V,[evidence(Ev)]), !,
normalise_CPT_on_lines(MAT0, MAT1, L1),
check_consistency(L1, Ev, MAT0, MAT1, L1, MAT, NewParents0, NewParents, Vs, IVs, NewVs).
include_qevidence(_, MAT, MAT, NewParents, NewParents, _, Vs, Vs).
check_consistency(L1, Ev, MAT0, MAT1, L1, MAT, NewParents0, NewParents, Vs, IVs, NewVs) :-
sumlist(L1, Tot),
nth0(Ev, L1, Val),
(Val == Tot ->
writeln(Ev:L1:Val:1),
MAT1 = MAT,
NewParents = [],
Vs = NewVs
;
Val == 0.0 ->
writeln(Ev:L1:Val:2),
throw(error(domain_error(incompatible_evidence),evidence(Ev)))
;
writeln(Ev:L1:Val:3),
MAT0 = MAT,
NewParents = NewParents0,
IVs = NewVs
).
%
% generate actual table, instead of trusting the solver
%
@ -147,19 +182,17 @@ average_cpt(Vs, OVars, Vals, Base, _, MCPT) :-
get_ds_lengths(Vs,Lengs),
length(OVars, N),
length(Vals, SVals),
Tot is (N-1)*SVals,
Factor is SVals/Tot,
matrix_new(floats,[SVals|Lengs],MCPT),
fill_in_average(Lengs,Factor,Base,MCPT).
fill_in_average(Lengs,N,Base,MCPT).
get_ds_lengths([],[]).
get_ds_lengths([V|Vs],[Sz|Lengs]) :-
get_vdist_size(V, Sz),
get_ds_lengths(Vs,Lengs).
fill_in_average(Lengs, SVals, Base, MCPT) :-
fill_in_average(Lengs, N, Base, MCPT) :-
generate(Lengs, Case),
average(Case, SVals, Base, Val),
average(Case, N, Base, Val),
matrix_set(MCPT,[Val|Case],1.0),
fail.
fill_in_average(_,_,_,_).
@ -175,9 +208,9 @@ from(I1,M,J) :-
I < M,
from(I,M,J).
average(Case, SVals, Base, Val) :-
average(Case, N, Base, Val) :-
sum_list(Case, Base, Tot),
Val is integer(round(Tot*SVals)).
Val is integer(round(Tot/N)).
sum_cpt(Vs,Vals,_,CPT) :-

View File

@ -227,6 +227,8 @@ get_dist_key(Id, Key) :-
get_dist_nparams(Id, NParms) :-
recorded(clpbn_dist_db, db(Id, _, _, _, _, NParms, _), _).
get_evidence_position(El, avg(Domain), Pos) :- !,
nth0(Pos, Domain, El), !.
get_evidence_position(El, Id, Pos) :-
recorded(clpbn_dist_db, db(Id, _, _, _, Domain, _, _), _),
nth0(Pos, Domain, El), !.

View File

@ -65,8 +65,6 @@
%
gibbs(LVs,Vs0,AllDiffs) :-
init_gibbs_solver(LVs, Vs0, AllDiffs, Vs),
(clpbn:output(xbif(XBifStream)) -> clpbn2xbif(XBifStream,vel,Vs) ; true),
(clpbn:output(gviz(XBifStream)) -> clpbn2gviz(XBifStream,vel,Vs,LVs) ; true),
run_gibbs_solver(LVs, LPs, Vs),
clpbn_bind_vals(LVs,LPs,AllDiffs),
clean_up.

View File

@ -1,4 +1,4 @@
:- module(gviz, [clpbn2gviz/4]).
:- module(clpbn_gviz, [clpbn2gviz/4]).
clpbn2gviz(Stream, Name, Network, Output) :-
format(Stream, 'digraph ~w {

View File

@ -16,7 +16,8 @@
column_from_possibly_deterministic_CPT/3,
multiply_possibly_deterministic_factors/3,
random_CPT/2,
uniform_CPT/2]).
uniform_CPT/2,
normalise_CPT_on_lines/3]).
:- use_module(dists,
[get_dist_domain_size/2,
@ -41,6 +42,7 @@
matrix_set_all_that_disagree/5,
matrix_to_list/2,
matrix_agg_lines/3,
matrix_agg_cols/3,
matrix_op_to_lines/4,
matrix_column/3]).
@ -252,4 +254,9 @@ uniform_CPT(Dims, M) :-
matrix_new_set(floats,Dims,1.0,M1),
normalise_possibly_deterministic_CPT(M1, M).
normalise_CPT_on_lines(MAT0, MAT2, L1) :-
matrix_agg_cols(MAT0, +, MAT1),
matrix_sum(MAT1, SUM),
matrix_op_to_all(MAT1, /, SUM, MAT2),
matrix:matrix_to_list(MAT2,L1).

View File

@ -83,10 +83,7 @@ init_vel_solver(Qs, Vs0, _, LVis) :-
% LVi will have a list of CLPBN variables
% Tables0 will have the full data on each variable
init_influences(Vs1, G, RG),
init_vel_solver_for_questions(Qs, G, RG, Vs0F, LVis),
term_variables(Vs0F, Vs),
(clpbn:output(xbif(XBifStream)) -> clpbn2xbif(XBifStream,vel,Vs) ; true),
(clpbn:output(gviz(XBifStream)) -> clpbn2gviz(XBifStream,vel,Vs,_) ; true).
init_vel_solver_for_questions(Qs, G, RG, _, LVis).
check_for_special_vars([], []).
check_for_special_vars([V|Vs0], [V|Vs1]) :-
@ -211,7 +208,7 @@ process(LV0, _, Out) :-
find_best([], V, _TF, V, _, [], _).
%:-
% clpbn:get_atts(V,[key(K)]), write(chosen:K:TF), nl.
% clpbn:get_atts(V,[key(K)]), writeln(chosen:K:_TF).
% root_with_single_child
%find_best([var(V,I,_,_,[],Ev,[Dep],K)|LV], _, _, V, [Dep], LVF, Inputs) :- !.
find_best([var(V,I,Sz,Vals,Parents,Ev,Deps,K)|LV], _, Threshold, VF, NWorktables, LVF, Inputs) :-
@ -226,8 +223,9 @@ find_best([var(V,I,Sz,Vals,Parents,Ev,Deps,K)|LV], _, Threshold, VF, NWorktables
find_best([V|LV], V0, Threshold, VF, WorkTables, [V|LVF], Inputs) :-
find_best(LV, V0, Threshold, VF, WorkTables, LVF, Inputs).
multiply_tables([Table], Table) :- !.
multiply_tables([Table], Table) :- !. %, Table = tab(T,D,S),matrix:matrix_to_list(T,L),writeln(D:S:L).
multiply_tables([TAB1, TAB2| Tables], Out) :-
%TAB1 = tab(T,_,_),matrix:matrix_to_list(T,L),writeln(doing:L),
multiply_CPTs(TAB1, TAB2, TAB, _),
multiply_tables([TAB| Tables], Out).