more fixes to use matrices in gibbs sampling.

This commit is contained in:
Vitor Santos Costa 2008-10-31 09:41:52 +00:00
parent 6408ae52ac
commit e66e7c86bc

View File

@ -28,15 +28,19 @@
:- use_module(library(ordsets), :- use_module(library(ordsets),
[ord_subtract/3]). [ord_subtract/3]).
:- use_module(library('clpbn/discrete_utils'), [ :- use_module(library('clpbn/matrix_cpt_utils'), [
project_from_CPT/3, project_from_CPT/3,
reorder_CPT/5]). reorder_CPT/5,
multiply_possibly_deterministic_factors/3,
row_from_possibly_deterministic_CPT/3,
normalise_possibly_deterministic_CPT/2,
list_from_CPT/2]).
:- use_module(library('clpbn/utils'), [ :- use_module(library('clpbn/utils'), [
check_for_hidden_vars/3]). check_for_hidden_vars/3]).
:- use_module(library('clpbn/dists'), [ :- use_module(library('clpbn/dists'), [
get_dist/4, get_possibly_deterministic_dist_matrix/5,
get_dist_domain_size/2]). get_dist_domain_size/2]).
:- use_module(library('clpbn/topsort'), [ :- use_module(library('clpbn/topsort'), [
@ -47,14 +51,13 @@
:- dynamic gibbs_params/3. :- dynamic gibbs_params/3.
:- dynamic implicit/1. :- dynamic explicit/1.
% arguments: % arguments:
% %
% list of output variables % list of output variables
% list of attributed variables % list of attributed variables
% %
gibbs([[]],_,_) :- !.
gibbs(LVs,Vs0,AllDiffs) :- gibbs(LVs,Vs0,AllDiffs) :-
init_gibbs_solver(LVs, Vs0, Vs), init_gibbs_solver(LVs, Vs0, Vs),
(clpbn:output(xbif(XBifStream)) -> clpbn2xbif(XBifStream,vel,Vs) ; true), (clpbn:output(xbif(XBifStream)) -> clpbn2xbif(XBifStream,vel,Vs) ; true),
@ -104,7 +107,8 @@ graph_representation([],_,_,_,[]).
graph_representation([V|Vs], Graph, I0, Keys, TGraph) :- graph_representation([V|Vs], Graph, I0, Keys, TGraph) :-
clpbn:get_atts(V,[evidence(_)]), !, clpbn:get_atts(V,[evidence(_)]), !,
clpbn:get_atts(V, [dist(Id,Parents)]), clpbn:get_atts(V, [dist(Id,Parents)]),
get_dist(Id, _, Vals, Table), get_possibly_deterministic_dist_matrix(Id, Parents, _, Vals, Table),
matrix:matrix_to_list(Table,T),writeln(T),
get_sizes(Parents, Szs), get_sizes(Parents, Szs),
length(Vals,Sz), length(Vals,Sz),
project_evidence_out([V|Parents],[V|Parents],Table,[Sz|Szs],Variables,NewTable), project_evidence_out([V|Parents],[V|Parents],Table,[Sz|Szs],Variables,NewTable),
@ -114,7 +118,7 @@ graph_representation([V|Vs], Graph, I0, Keys, TGraph) :-
graph_representation([V|Vs], Graph, I0, Keys, [I-IParents|TGraph]) :- graph_representation([V|Vs], Graph, I0, Keys, [I-IParents|TGraph]) :-
I is I0+1, I is I0+1,
clpbn:get_atts(V, [dist(Id,Parents)]), clpbn:get_atts(V, [dist(Id,Parents)]),
get_dist(Id, _, Vals, Table), get_possibly_deterministic_dist_matrix(Id, Parents, _, Vals, Table),
get_sizes(Parents, Szs), get_sizes(Parents, Szs),
length(Vals,Sz), length(Vals,Sz),
project_evidence_out([V|Parents],[V|Parents],Table,[Sz|Szs],Variables,NewTable), project_evidence_out([V|Parents],[V|Parents],Table,[Sz|Szs],Variables,NewTable),
@ -152,10 +156,8 @@ parent_indices([V|Parents], Keys, [I|IParents]) :-
project_evidence_out([],Deps,Table,_,Deps,Table). project_evidence_out([],Deps,Table,_,Deps,Table).
project_evidence_out([V|Parents],Deps,Table,Szs,NewDeps,NewTable) :- project_evidence_out([V|Parents],Deps,Table,Szs,NewDeps,NewTable) :-
clpbn:get_atts(V,[evidence(_)]), !, clpbn:get_atts(V,[evidence(_)]), !,
NTab =.. [t|Table], project_from_CPT(V,tab(Table,Deps,Szs),tab(ITable,IDeps,ISzs)),
project_from_CPT(V,tab(NTab,Deps,Szs),tab(ITable,IDeps,ISzs)), project_evidence_out(Parents,IDeps,ITable,ISzs,NewDeps,NewTable).
ITable =.. [_|LITable],
project_evidence_out(Parents,IDeps,LITable,ISzs,NewDeps,NewTable).
project_evidence_out([_Par|Parents],Deps,Table,Szs,NewDeps,NewTable) :- project_evidence_out([_Par|Parents],Deps,Table,Szs,NewDeps,NewTable) :-
project_evidence_out(Parents,Deps,Table,Szs,NewDeps,NewTable). project_evidence_out(Parents,Deps,Table,Szs,NewDeps,NewTable).
@ -188,10 +190,6 @@ vars2indices([V|Parents],Keys,[I-V|IParents]) :-
rb_lookup(V, I, Keys), rb_lookup(V, I, Keys),
vars2indices(Parents,Keys,IParents). vars2indices(Parents,Keys,IParents).
compact_table(NewTable, RepTable) :-
NewTable = [_|_], !,
RepTable =.. [t|NewTable].
% %
% This is the really cool bit. % This is the really cool bit.
% %
@ -202,7 +200,6 @@ compile_graph(Graph) :-
compile_vars([],_). compile_vars([],_).
compile_vars([var(_,I,_,Vals,Sz,VarSlot,Parents,_,_)|VarsInfo],Graph) compile_vars([var(_,I,_,Vals,Sz,VarSlot,Parents,_,_)|VarsInfo],Graph)
:- :-
compile_var(I,Vals,Sz,VarSlot,Parents,Graph), compile_var(I,Vals,Sz,VarSlot,Parents,Graph),
compile_vars(VarsInfo,Graph). compile_vars(VarsInfo,Graph).
@ -211,7 +208,7 @@ compile_var(I,Vals,Sz,VarSlot,Parents,Graph) :-
mult_list(Sizes,1,TotSize), mult_list(Sizes,1,TotSize),
compile_var(TotSize,I,Vals,Sz,VarSlot,Parents,Sizes,Graph). compile_var(TotSize,I,Vals,Sz,VarSlot,Parents,Sizes,Graph).
fetch_all_parents([],_,Parents,Parents,Sizes,Sizes). fetch_all_parents([],_,Parents,Parents,Sizes,Sizes) :- !.
fetch_all_parents([tabular(_,_,Ps)|CPTs],Graph,Parents0,ParentsF,Sizes0,SizesF) :- fetch_all_parents([tabular(_,_,Ps)|CPTs],Graph,Parents0,ParentsF,Sizes0,SizesF) :-
merge_these_parents(Ps,Graph,Parents0,ParentsI,Sizes0,SizesI), merge_these_parents(Ps,Graph,Parents0,ParentsI,Sizes0,SizesI),
fetch_all_parents(CPTs,Graph,ParentsI,ParentsF,SizesI,SizesF). fetch_all_parents(CPTs,Graph,ParentsI,ParentsF,SizesI,SizesF).
@ -241,16 +238,24 @@ mult_list([Sz|Sizes],Mult0,Mult) :-
% compile node as set of facts, faster execution % compile node as set of facts, faster execution
compile_var(TotSize,I,_Vals,Sz,CPTs,Parents,_Sizes,Graph) :- compile_var(TotSize,I,_Vals,Sz,CPTs,Parents,_Sizes,Graph) :-
TotSize < 1024*64, TotSize > 0, !, TotSize < 1024*64, TotSize > 0, !,
writeln(I), (I=55->assert(a); retractall(a)),
multiply_all(I,Parents,CPTs,Sz,Graph). multiply_all(I,Parents,CPTs,Sz,Graph).
compile_var(_,I,_,_,_,_,_,_) :- % do it dynamically
assert(implicit(I)). compile_var(_,_,_,_,_,_,_,_).
multiply_all(I,Parents,CPTs,Sz,Graph) :- multiply_all(I,Parents,CPTs,Sz,Graph) :-
markov_blanket_instance(Parents,Graph,Values), markov_blanket_instance(Parents,Graph,Values),
multiply_all(CPTs,Sz,Graph,Probs), (
store_mblanket(I,Values,Probs), multiply_all(CPTs,Graph,Probs)
, (a->writeln(Probs);true)
->
store_mblanket(I,Values,Probs)
;
throw(error(domain_error(bayesian_domain),gibbs_cpt(I,Parents,Values,Sz)))
),
fail. fail.
multiply_all(_,_,_,_,_). multiply_all(I,_,_,_,_) :-
assert(explicit(I)).
% note: what matters is how this predicate instantiates the temp % note: what matters is how this predicate instantiates the temp
% slot in the graph! % slot in the graph!
@ -269,56 +274,35 @@ fetch_val([_|Vals],I0,Pos) :-
:- dynamic a/0. :- dynamic a/0.
multiply_all(CPTs,Size,Graph,Probs) :- multiply_all([tabular(Table,_,Parents)|CPTs],Graph,Probs) :-
init_factors(Size,Factors0), fetch_parents(Parents, Graph, Vals),
mult_factors(CPTs,Size,Graph,Factors0,Factors), row_from_possibly_deterministic_CPT(Table,Vals,Probs0),
normalise_factors(Factors,Probs). (a -> list_from_CPT(Probs0,LProbs0), writeln(s:LProbs0) ; true),
multiply_more(CPTs,Graph,Probs0,Probs).
init_factors(0,[]) :- !. fetch_parents([], _, []).
init_factors(I0,[0.0|Factors]) :- fetch_parents([P|Parents], Graph, [Val|Vals]) :-
I is I0-1, arg(P,Graph,var(_,_,Val,_,_,_,_,_,_)),
init_factors(I,Factors). fetch_parents(Parents, Graph, Vals).
mult_factors([],_,_,Factors,Factors). multiply_more([],_,Probs0,LProbs) :-
mult_factors([tabular(Table,_,Parents)|CPTs],Size,Graph,Factors0,Factors) :- normalise_possibly_deterministic_CPT(Probs0, Probs),
functor(Table,_,CPTSize), list_from_CPT(Probs, LProbs0),
Off is CPTSize//Size, (a -> writeln(e:LProbs0) ; true),
factor(Parents,Table,Graph,0,Off,Indx0), accumulate_up_list(LProbs0, 0.0, LProbs).
Indx is Indx0+1, multiply_more([tabular(Table,_,Parents)|CPTs],Graph,Probs0,Probs) :-
mult_with_probs(Factors0,Indx,Off,Table,FactorsI), fetch_parents(Parents, Graph, Vals),
mult_factors(CPTs,Size,Graph,FactorsI,Factors). row_from_possibly_deterministic_CPT(Table, Vals, P0),
(a -> list_from_CPT(P0, L0), list_from_CPT(Probs0, LI), writeln(m:LI:L0) ; true),
multiply_possibly_deterministic_factors(Probs0, P0, ProbsI),
multiply_more(CPTs,Graph,ProbsI,Probs).
accumulate_up_list([], _, []).
accumulate_up_list([P|LProbs], P0, [P1|L]) :-
P1 is P0+P,
accumulate_up_list(LProbs, P1, L).
factor([],_,_,Arg,_,Arg).
factor([I|Parents],Table,Graph,Pos0,Weight0,Pos) :-
arg(I,Graph,var(_,I,CurPos,_,Sz,_,_,_,_)),
NWeight is Weight0 // Sz,
PosI is Pos0+(NWeight*CurPos),
factor(Parents,Table,Graph,PosI,NWeight,Pos).
mult_with_probs([],_,_,_,[]).
mult_with_probs([F0|Factors0],Indx,Off,Table,[F|Factors]) :-
arg(Indx,Table,P1),
F is F0+log(P1),
Indx1 is Indx+Off,
mult_with_probs(Factors0,Indx1,Off,Table,Factors).
normalise_factors(Factors,Probs) :-
max_list(Factors,Max),
logs2list(Factors,Max,NFactors),
normalise_factors(NFactors,0,_,Probs,_).
logs2list([],_,[]).
logs2list([Log|Factors],Max,[P|NFactors]) :-
P is exp(Log+Max),
logs2list(Factors,Max,NFactors).
normalise_factors([],Sum,Sum,[],1.0) :- Sum > 0.0.
normalise_factors([F|Factors],S0,S,[P0|Probs],PF) :-
Si is S0+F,
normalise_factors(Factors,Si,S,Probs,P0),
PF is P0-F/S.
store_mblanket(I,Values,Probs) :- store_mblanket(I,Values,Probs) :-
recordz(mblanket,m(I,Values,Probs),_). recordz(mblanket,m(I,Values,Probs),_).
@ -356,41 +340,11 @@ init_chain(VarOrder,Len,Graph,Chain) :-
gen_sample([],_,_) :- !. gen_sample([],_,_) :- !.
gen_sample([I|Vs],Graph,Chain) :- gen_sample([I|Vs],Graph,Chain) :-
arg(I,Graph,var(_,I,_,_,Sz,_,_,Table,IPars)), arg(I,Graph,var(_,I,_,_,Sz,_,_,_,_)),
functor(Table,_,CPTSize), Pos is integer(random*Sz),
Off is CPTSize//Sz,
iparents_pos_sz(IPars, Chain, IPos, Graph, ISz),
R is random,
project(IPos, ISz, Table,0,Off,Indx0),
Indx is Indx0+1,
fetch_from_dist(Table,R,Indx,Off,0,Pos),
arg(I,Chain,Pos), arg(I,Chain,Pos),
gen_sample(Vs,Graph,Chain). gen_sample(Vs,Graph,Chain).
project([],[],_,Arg,_,Arg).
project([CurPos|Parents],[Sz|Sizes],Table,Pos0,Weight0,Pos) :-
NWeight is Weight0 // Sz,
PosI is Pos0+(NWeight*CurPos),
project(Parents,Sizes,Table,PosI,NWeight,Pos).
fetch_from_dist(Table,R,Indx,Off,IPos,Pos) :-
arg(Indx,Table,P),
( P >= R ->
Pos = IPos
;
NR is R-P,
NIndx is Indx+Off,
NPos is IPos+1,
fetch_from_dist(Table,NR,NIndx,Off,NPos,Pos)
).
iparents_pos_sz([], _, [], _, []).
iparents_pos_sz([I|IPars], Chain, [P|IPos], Graph, [Sz|Sizes]) :-
arg(I,Chain,P),
arg(I,Graph, var(_,I,_,_,Sz,_,_,_,_)),
iparents_pos_sz(IPars, Chain, IPos, Graph, Sizes).
init_estimates(0,_,_,[]) :- !. init_estimates(0,_,_,[]) :- !.
init_estimates(NChains,OutputVars,Graph,[Est|Est0]) :- init_estimates(NChains,OutputVars,Graph,[Est|Est0]) :-
@ -425,7 +379,7 @@ gen_e0(Sz,[0|E0L]) :-
process_chains(0,_,F,F,_,_,Est,Est) :- !. process_chains(0,_,F,F,_,_,Est,Est) :- !.
process_chains(ToDo,VarOrder,End,Start,Graph,Len,Est0,Estf) :- process_chains(ToDo,VarOrder,End,Start,Graph,Len,Est0,Estf) :-
%format('ToDo = ~d~n',[ToDo]), format('ToDo = ~d~n',[ToDo]),
process_chains(Start,VarOrder,Int,Graph,Len,Est0,Esti), process_chains(Start,VarOrder,Int,Graph,Len,Est0,Esti),
% (ToDo mod 100 =:= 1 -> statistics,cvt2problist(Esti, Probs), Int =[S|_], format('did ~d: ~w~n ~w~n',[ToDo,Probs,S]) ; true), % (ToDo mod 100 =:= 1 -> statistics,cvt2problist(Esti, Probs), Int =[S|_], format('did ~d: ~w~n ~w~n',[ToDo,Probs,S]) ; true),
ToDo1 is ToDo-1, ToDo1 is ToDo-1,
@ -446,24 +400,26 @@ do_sample([I|VarOrder],Sample,Sample0,Graph) :-
do_sample(VarOrder,Sample,Sample0,Graph). do_sample(VarOrder,Sample,Sample0,Graph).
do_var(I,Sample,Sample0,Graph) :- do_var(I,Sample,Sample0,Graph) :-
( implicit(I) -> ( explicit(I) ->
arg(I,Graph,var(_,_,_,_,Sz,CPTs,Parents,_,_)),
fetch_parents(Parents,I,Sample,Sample0,Bindings),
multiply_all_in_context(Parents,Bindings,CPTs,Sz,Graph,Vals)
;
arg(I,Graph,var(_,_,_,_,_,_,Parents,_,_)), arg(I,Graph,var(_,_,_,_,_,_,Parents,_,_)),
fetch_parents(Parents,I,Sample,Sample0,Args), fetch_parents(Parents,I,Sample,Sample0,Args),
recorded(mblanket,m(I,Args,Vals),_) recorded(mblanket,m(I,Args,Vals),_)
;
arg(I,Graph,var(_,_,_,_,_,CPTs,Parents,_,_)),
fetch_parents(Parents,I,Sample,Sample0,Bindings),
CPTs=[tabular(T,_,_)|_], matrix:matrix_dims(T,Dims), writeln(I:1:Bindings:Dims),
multiply_all_in_context(Parents,Bindings,CPTs,Graph,Vals)
), ),
X is random, X is random,
writeln(I:X:Vals),
pick_new_value(Vals,X,0,Val), pick_new_value(Vals,X,0,Val),
arg(I,Sample,Val). arg(I,Sample,Val).
multiply_all_in_context(Parents,Args,CPTs,Sz,Graph,Vals) :- multiply_all_in_context(Parents,Args,CPTs,Graph,Vals) :-
set_pos(Parents,Args,Graph), set_pos(Parents,Args,Graph),
multiply_all(CPTs,Sz,Graph,Vals), multiply_all(CPTs,Graph,Vals),
assert(mall(Vals)), fail. assert(mall(Vals)), fail.
multiply_all_in_context(_,_,_,_,_,Vals) :- multiply_all_in_context(_,_,_,_,Vals) :-
retract(mall(Vals)). retract(mall(Vals)).
set_pos([],[],_). set_pos([],[],_).
@ -524,7 +480,7 @@ clean_up :-
eraseall(mblanket), eraseall(mblanket),
fail. fail.
clean_up :- clean_up :-
retractall(implicit(_)), retractall(explicit(_)),
fail. fail.
clean_up. clean_up.