This repository has been archived on 2023-08-20. You can view files and clone it, but cannot push or open issues or pull requests.
Files
yap-6.3/packages/prism/src/prolog/up/learn.pl
2011-11-10 12:24:47 +00:00

436 lines
16 KiB
Prolog

%% learn/0
%% Run learning in the mode currently stored in the learn_mode flag.  The
%% observed goals are obtained by $pp_learn_core/1 from the data source.
learn :-
    get_prism_flag(learn_mode,LearnMode),
    $pp_learn_main(LearnMode).

%% learn(+Goals)
%% Same as learn/0, except that the observed goals are given explicitly.
learn(Goals) :-
    get_prism_flag(learn_mode,LearnMode),
    $pp_learn_main(LearnMode,Goals).
%% learn_p/0, learn_p(+Goals)
%% Run learning with the mode fixed to `params`; the learn_mode flag is not
%% consulted.  The /1 variant takes the observed goals as its argument.
learn_p :-
$pp_learn_main(params).
learn_p(Goals) :-
$pp_learn_main(params,Goals).
%% learn_h/0, learn_h(+Goals)
%% Run learning with the mode fixed to `hparams`; the learn_mode flag is not
%% consulted.  The /1 variant takes the observed goals as its argument.
learn_h :-
$pp_learn_main(hparams).
learn_h(Goals) :-
$pp_learn_main(hparams,Goals).
%% learn_b/0, learn_b(+Goals)
%% Run learning with the mode fixed to `both`; the learn_mode flag is not
%% consulted.  The /1 variant takes the observed goals as its argument.
learn_b :-
$pp_learn_main(both).
learn_b(Goals) :-
$pp_learn_main(both,Goals).
%% for the parallel version
%% $pp_learn_main(+Mode), $pp_learn_main(+Mode,+Goals)
%% These add no logic of their own: each hands the request to the
%% $pp_learn_core predicate of the same arity through call/1.
%% NOTE(review): the indirection presumably lets the parallel build plug in
%% a different learning core -- confirm against that build.
$pp_learn_main(Mode) :- call($pp_learn_core(Mode)).
$pp_learn_main(Mode,Goals) :- call($pp_learn_core(Mode,Goals)).
%% $pp_learn_data_file(-FileName)
%% Determine the file of observed goals from the data_source flag:
%%   none    -> runtime error (message 1300)
%%   data/1  -> FileName is supplied by the user predicate data/1; a runtime
%%              error (message 1301) is raised when data/1 is not defined
%%   file(F) -> FileName = F
%% When none of the above matches, $pp_raise_unmatched_branches/1 is called.
$pp_learn_data_file(FileName) :-
    get_prism_flag(data_source,Src),
    ( Src == none ->
        $pp_raise_runtime_error($msg(1300),data_source_not_found,
                                $pp_learn_data_file/1)
    ; Src == data/1 ->
        ( current_predicate(data/1) ->
            data(FileName)
        ; $pp_raise_runtime_error($msg(1301),data_source_not_found,
                                  $pp_learn_data_file/1)
        )
    ; Src = file(FileName)
    ; $pp_raise_unmatched_branches($pp_learn_data_file/1)
    ),
    !.
%% $pp_learn_check_goals(+Goals)
%% Validate the observed goals before learning:
%%   1. the list as a whole must pass $pp_require_observed_data/3,
%%   2. every element must pass $pp_learn_check_goals1/1,
%%   3. a runtime error (message 1305) is raised when the daem flag is on
%%      and the goal `failure' occurs in the list.
$pp_learn_check_goals(Goals) :-
    $pp_require_observed_data(Goals,$msg(1302),$pp_learn_core/1),
    $pp_learn_check_goals1(Goals),
    ( get_prism_flag(daem,on), membchk(failure,Goals) ->
        $pp_raise_runtime_error($msg(1305),daem_with_failure,
                                $pp_learn_core/1)
    ; true
    ).
%% $pp_learn_check_goals1(+Items)
%% Check the observed items one by one.  An item is either a plain goal or a
%% goal with a count, written goal(G,N), count(G,N) or (N times G).  A count
%% must satisfy $pp_require_positive_integer/3, and the goal itself must
%% satisfy $pp_require_tabled_probabilistic_atom/3.
$pp_learn_check_goals1([]).
$pp_learn_check_goals1([Item|Rest]) :-
    ( ( Item = goal(Goal,N) ; Item = count(Goal,N) ; Item = (N times Goal) ) ->
        $pp_require_positive_integer(N,$msg(1306),$pp_learn_core/1)
    ; Goal = Item
    ),
    $pp_require_tabled_probabilistic_atom(Goal,$msg(1303),$pp_learn_core/1),
    !,
    $pp_learn_check_goals1(Rest).
%% $pp_learn_core(+Mode)
%% Learning without an explicit goal list: resolve the data file, read the
%% observed goals from it with load_clauses/3, commit, and continue with
%% $pp_learn_core/2.
$pp_learn_core(Mode) :-
    $pp_learn_data_file(DataFile),
    load_clauses(DataFile,ObservedGoals,[]),
    !,
    $pp_learn_core(Mode,ObservedGoals).
%% $pp_learn_core(+Mode,+Goals)
%% Complete learning run for the observed Goals in the given Mode (params,
%% hparams or both -- the modes $pp_em/2 has clauses for).  The steps are:
%% validate the goals, reset per-run state, translate the goals into
%% goal-count pairs, search for explanations, export the switches, run the
%% EM routine, import the results into the database and record/print the
%% statistics.  CPU time stamps are taken around the whole run, around the
%% explanation search and around the EM routine.
$pp_learn_core(Mode,Goals) :-
$pp_learn_check_goals(Goals),
%% Four message settings; wherever one is tested below, 0 means "print
%% nothing".  MsgE is only forwarded to $pc_set_em_message/1.
$pp_learn_message(MsgS,MsgE,MsgT,MsgM),
$pc_set_em_message(MsgE),
cputime(Start),
$pp_learn_clean_info,
$pp_learn_reset_hparams(Mode),
%% GoalCountPairs is a list of goal(G,Count); AllGoals lists the same goals
%% without their counts (see $pp_trans_count_pairs/3).
$pp_trans_goals(Goals,GoalCountPairs,AllGoals),!,
global_set($pg_observed_facts,GoalCountPairs),
cputime(StartExpl),
global_set($pg_num_goals,0),
$pp_find_explanations(AllGoals),!,
$pp_print_num_goals(MsgS),
cputime(EndExpl),
% vsc statistics(table,[TableSpace,_]),
TableSpace = 0, % not supported in YAP (it should be).
( MsgM == 0 -> true
; format("Exporting switch information to the EM routine ... ",[])
),
flush_output,
$pp_collect_init_switches(Sws),
$pc_export_sw_info(Sws),
( MsgM == 0 -> true ; format("done~n",[]) ),
%% Replace goals by their ids; the counters start at 0 and the index of the
%% `failure' goal starts at -1 (it stays -1 when there is no such goal).
$pp_observed_facts(GoalCountPairs,GidCountPairs,
0,Len,0,NGoals,-1,FailRootIndex),
$pc_prism_prepare(GidCountPairs,Len,NGoals,FailRootIndex),
cputime(StartEM),
$pp_em(Mode,Output),
cputime(EndEM),
%% Fetch the learned switch data and store it in the database.
$pc_import_occ_switches(NewSws,NSwitches,NSwVals),
$pp_decode_update_switches(Mode,NewSws),
$pc_import_graph_stats(NSubgraphs,NGoalNodes,NSwNodes,AvgShared),
$pp_delete_tmp_out,
cputime(End),
$pp_assert_graph_stats(NSubgraphs,NGoalNodes,NSwNodes,AvgShared),
%% The final argument (1000) is the divisor applied to the cputime
%% differences.  NOTE(review): this assumes cputime/1 reports
%% milliseconds -- confirm for the host Prolog system.
$pp_assert_learn_stats(Mode,Output,NSwitches,NSwVals,TableSpace,
Start,End,StartExpl,EndExpl,StartEM,EndEM,1000),
( MsgT == 0 -> true ; $pp_print_learn_stats_message ),
( MsgM == 0 -> true ; $pp_print_learn_end_message(Mode) ),!.
%% $pp_learn_clean_info
%% Reset the per-run state before a new learning run: empty the dummy goal
%% table, drop the stored graph statistics and learning statistics, then
%% call the two table initialisation helpers (defined elsewhere).
%% NOTE(review): the relative order of the two $pp_init_tables_* calls is
%% presumably significant -- keep it unless verified otherwise.
$pp_learn_clean_info :-
$pp_clean_dummy_goal_table,
$pp_clean_graph_stats,
$pp_clean_learn_stats,
$pp_init_tables_aux,
$pp_init_tables_if_necessary,!.
%% $pp_learn_reset_hparams(+Mode)
%% For every mode other than `params', call set_sw_all_a/1 with an unbound
%% argument when the reset_hparams flag is on.  In all other situations
%% this is a no-op.
$pp_learn_reset_hparams(Mode) :-
    ( Mode \== params,
      get_prism_flag(reset_hparams,on) ->
        set_sw_all_a(_)
    ; true
    ).
%% $pp_print_num_goals(+MsgS)
%% Unless MsgS is 0, print the value of the global $pg_num_goals in
%% parentheses followed by a newline, and flush the output.
$pp_print_num_goals(MsgS) :-
    ( MsgS \== 0 ->
        global_get($pg_num_goals,NumGoals),
        format("(~w)~n",[NumGoals]),
        flush_output
    ; true
    ).
%% $pp_em(+Mode,-Output)
%% Run the EM routine that belongs to Mode and pack its results in a list:
%%   params  -> $pc_prism_em/6,      Output has six elements
%%   hparams -> $pc_prism_vbem/2,    Output has two elements
%%   both    -> $pc_prism_both_em/2, Output has two elements
$pp_em(params,Output) :-
    $pc_prism_em(Iters,LogPost,LogLike,BIC,CS,Smooth),
    Output = [Iters,LogPost,LogLike,BIC,CS,Smooth].
$pp_em(hparams,Output) :-
    $pc_prism_vbem(ItersVB,FreeEnergy),
    Output = [ItersVB,FreeEnergy].
$pp_em(both,Output) :-
    $pc_prism_both_em(ItersVB,FreeEnergy),
    Output = [ItersVB,FreeEnergy].
%% $pp_assert_graph_stats(+Subgraphs,+GoalNodes,+SwitchNodes,+SharedAvg)
%% Store the graph statistics as $ps_* facts.  The total node count is the
%% sum of the goal node count and the switch node count.
$pp_assert_graph_stats(Subgraphs,GoalNodes,SwitchNodes,SharedAvg) :-
    TotalNodes is GoalNodes + SwitchNodes,
    assertz($ps_num_subgraphs(Subgraphs)),
    assertz($ps_num_nodes(TotalNodes)),
    assertz($ps_num_goal_nodes(GoalNodes)),
    assertz($ps_num_switch_nodes(SwitchNodes)),
    assertz($ps_avg_shared(SharedAvg)),
    !.
%% $pp_assert_learn_stats(+Mode,+Output,+NSwitches,+NSwVals,+TableSpace,
%%                        +Start,+End,+StartExpl,+EndExpl,+StartEM,+EndEM,
%%                        +UnitsPerSec)
%% Store the learning statistics as $ps_* facts: the switch counts, the
%% table space (only when it is an integer), the three elapsed times (each
%% a time stamp difference divided by UnitsPerSec), and finally the
%% mode-specific values taken from Output.
$pp_assert_learn_stats(Mode,Output,NSwitches,NSwVals,TableSpace,
                       Start,End,StartExpl,EndExpl,StartEM,EndEM,
                       UnitsPerSec) :-
    assertz($ps_num_switches(NSwitches)),
    assertz($ps_num_switch_values(NSwVals)),
    ( integer(TableSpace) ->
        assertz($ps_learn_table_space(TableSpace))
    ; true
    ),
    TotalSecs is (End - Start) / UnitsPerSec,
    assertz($ps_learn_time(TotalSecs)),
    SearchSecs is (EndExpl - StartExpl) / UnitsPerSec,
    assertz($ps_learn_search_time(SearchSecs)),
    EMSecs is (EndEM - StartEM) / UnitsPerSec,
    assertz($ps_em_time(EMSecs)),
    $pp_assert_learn_stats_sub(Mode,Output),
    !.
%% $pp_assert_learn_stats_sub(+Mode,+Output)
%% Store the mode-specific results of the EM routine.
%%   params : iteration count, log-likelihood and BIC score are always
%%            stored; the log posterior and the CS score only when the
%%            last element of Output is greater than 0.
%%   hparams, both : iteration count (VB) and free energy.
$pp_assert_learn_stats_sub(params,Output) :-
    Output = [Iters,LogPost,LogLike,BIC,CS,Smooth],
    assertz($ps_num_iterations(Iters)),
    ( Smooth > 0 -> assertz($ps_log_post(LogPost)) ; true ),
    assertz($ps_log_likelihood(LogLike)),
    assertz($ps_bic_score(BIC)),
    ( Smooth > 0 -> assertz($ps_cs_score(CS)) ; true ),
    !.
$pp_assert_learn_stats_sub(hparams,Output) :-
    Output = [ItersVB,FreeEnergy],
    assertz($ps_num_iterations_vb(ItersVB)),
    assertz($ps_free_energy(FreeEnergy)),
    !.
$pp_assert_learn_stats_sub(both,Output) :-
    Output = [ItersVB,FreeEnergy],
    assertz($ps_num_iterations_vb(ItersVB)),
    assertz($ps_free_energy(FreeEnergy)),
    !.
%% $pp_print_learn_stats_message
%% Print the header line and then every available statistic.  The second
%% goal is a failure-driven loop: `fail' forces backtracking through all
%% alternatives of $pp_print_learn_stats_message_sub/0, and `; true' makes
%% the predicate succeed once they are exhausted.
$pp_print_learn_stats_message :-
format("Statistics on learning:~n",[]),
( $pp_print_learn_stats_message_sub,fail ; true ),!.
%% $pp_print_learn_stats_message_sub
%% One alternative per statistic: each branch looks up a stored $ps_* fact
%% and prints it.  A branch whose fact is absent simply fails.  The log
%% likelihood is printed only when no log posterior has been stored.
%% Meant to be enumerated by backtracking.
$pp_print_learn_stats_message_sub :-
    ( $ps_num_nodes(Nodes),
      format("~tGraph size: ~w~n",[Nodes])
    ; $ps_num_switches(Switches),
      format("~tNumber of switches: ~w~n",[Switches])
    ; $ps_num_switch_values(SwVals),
      format("~tNumber of switch instances: ~w~n",[SwVals])
    ; $ps_num_iterations_vb(ItersVB),
      format("~tNumber of iterations: ~w~n",[ItersVB])
    ; $ps_num_iterations(Iters),
      format("~tNumber of iterations: ~w~n",[Iters])
    ; $ps_free_energy(FreeEnergy),
      format("~tFinal variational free energy: ~9f~n",[FreeEnergy])
    ; $ps_log_post(LogPost),
      format("~tFinal log of a posteriori prob: ~9f~n",[LogPost])
    ; $ps_log_likelihood(LogLike), \+ $ps_log_post(_),
      format("~tFinal log likelihood: ~9f~n",[LogLike])
    ; $ps_learn_time(TotalSecs),
      format("~tTotal learning time: ~3f seconds~n",[TotalSecs])
    ; $ps_learn_search_time(SearchSecs),
      format("~tExplanation search time: ~3f seconds~n",[SearchSecs])
    ; $ps_learn_table_space(Bytes),
      format("~tTotal table space used: ~w bytes~n",[Bytes])
    ).
%% $pp_print_learn_end_message(+Mode)
%% Print the closing hint that names the predicates showing the learned
%% distributions for Mode.  Fails, printing nothing, when Mode is none of
%% params, hparams and both.
$pp_print_learn_end_message(Mode) :-
    ( Mode == params ->
        Msg = "Type show_sw to show the probability distributions.~n"
    ; Mode == hparams ->
        Msg = "Type show_sw_a/show_sw_d to show the probability distributions.~n"
    ; Mode == both ->
        Msg = "Type show_sw_pa/show_sw_pd to show the probability distributions.~n"
    ),
    format(Msg,[]).
%% $pp_clean_graph_stats
%% Remove every graph statistic stored by $pp_assert_graph_stats/4.
$pp_clean_graph_stats :-
    retractall($ps_num_subgraphs(_)),
    retractall($ps_num_nodes(_)),
    retractall($ps_num_goal_nodes(_)),
    retractall($ps_num_switch_nodes(_)),
    retractall($ps_avg_shared(_)),
    !.
%% $pp_clean_learn_stats
%% Remove every learning statistic stored by $pp_assert_learn_stats/12 and
%% $pp_assert_learn_stats_sub/2.  The removals are independent of each
%% other and are grouped here by kind.
$pp_clean_learn_stats :-
    % switch counts and table space
    retractall($ps_num_switches(_)),
    retractall($ps_num_switch_values(_)),
    retractall($ps_learn_table_space(_)),
    % timings
    retractall($ps_learn_time(_)),
    retractall($ps_learn_search_time(_)),
    retractall($ps_em_time(_)),
    % results of the `params' mode
    retractall($ps_num_iterations(_)),
    retractall($ps_log_post(_)),
    retractall($ps_log_likelihood(_)),
    retractall($ps_bic_score(_)),
    retractall($ps_cs_score(_)),
    % results of the `hparams' and `both' modes
    retractall($ps_num_iterations_vb(_)),
    retractall($ps_free_energy(_)),
    !.
%% $pp_collect_init_switches(-Sws)
%% Collect one sw/6 term for each switch id from 0 up to (but excluding)
%% the count reported by $pc_prism_sw_count/1.
$pp_collect_init_switches(Sws) :-
    $pc_prism_sw_count(Count),
    $pp_collect_init_switches(0,Count,Sws).
%% $pp_collect_init_switches(+Sid,+Count,-Collected)
%% Worker for $pp_collect_init_switches/1.  For every switch id in
%% Sid..Count-1 build
%%     sw(Sid,InstanceIds,Probs,Deltas,FixedP,FixedH)
%% where FixedP / FixedH are 1 when $pd_fixed_parameters/1 /
%% $pd_fixed_hyperparameters/1 holds for the switch, and 0 otherwise.
$pp_collect_init_switches(Sid,Count,Collected) :-
    Sid >= Count,
    !,
    Collected = [].
$pp_collect_init_switches(Sid,Count,Collected) :-
    $pc_prism_sw_term(Sid,Sw),
    Collected = [sw(Sid,InsIds,Probs,Deltas,FixedP,FixedH)|Rest],
    $pp_get_parameters(Sw,Values,Probs),
    !,
    $pp_get_hyperparameters(Sw,Values,_,Deltas),
    !,
    ( $pd_fixed_parameters(Sw) -> FixedP = 1 ; FixedP = 0 ),
    ( $pd_fixed_hyperparameters(Sw) -> FixedH = 1 ; FixedH = 0 ),
    $pp_collect_sw_ins_ids(Sw,Values,InsIds),
    NextSid is Sid + 1,
    !,
    $pp_collect_init_switches(NextSid,Count,Rest).
%% $pp_collect_sw_ins_ids(+Sw,+Values,-Ids)
%% Map every value V of switch Sw to the id that
%% $pc_prism_sw_ins_id_get/2 returns for msw(Sw,V), keeping the order.
$pp_collect_sw_ins_ids(_,[],[]).
$pp_collect_sw_ins_ids(Sw,[Value|Values],[Id|Ids]) :-
    $pc_prism_sw_ins_id_get(msw(Sw,Value),Id),
    !,
    $pp_collect_sw_ins_ids(Sw,Values,Ids).
%% $pp_decode_update_switches(+Mode,+NewSws)
%% Dispatch on Mode to the routine that writes the imported switch data
%% back to the database: _p for params, _h for hparams, _b for both.
$pp_decode_update_switches(params,NewSws) :-
    $pp_decode_update_switches_p(NewSws).
$pp_decode_update_switches(hparams,NewSws) :-
    $pp_decode_update_switches_h(NewSws).
$pp_decode_update_switches(both,NewSws) :-
    $pp_decode_update_switches_b(NewSws).
%% $pp_decode_update_switches_p(+Sws)
%% Store the result of parameter learning.  For every sw(_,SwInstances):
%%   - recover the switch name from the first instance and turn the
%%     instances into (Value,Prob,Delta,Expect) tuples,
%%   - line the tuples up with the declared value order (get_values1/2),
%%   - replace the stored $pd_parameters/3 fact of the switch,
%%   - drop both kinds of stored expectations and store the new
%%     $pd_expectations/3 fact.
%% retractall/1 is used for the removals so that no stale fact for the
%% switch can survive, even if more than one was present.
$pp_decode_update_switches_p([]).
$pp_decode_update_switches_p([sw(_,SwInstances)|Sws]) :-
    $pp_decode_switch_name(SwInstances,Sw),
    $pp_decode_switch_instances(SwInstances,Updates),
    get_values1(Sw,Values),
    $pp_separate_updates(Values,Probs,_Deltas,Es,Updates),
    retractall($pd_parameters(Sw,_,_)),
    assert($pd_parameters(Sw,Values,Probs)),
    retractall($pd_expectations(Sw,_,_)),
    retractall($pd_hyperexpectations(Sw,_,_)),
    assert($pd_expectations(Sw,Values,Es)),
    !,
    $pp_decode_update_switches_p(Sws).
%% $pp_decode_update_switches_h(+Sws)
%% Store the result of hyperparameter learning.  For every
%% sw(_,SwInstances):
%%   - recover the switch name from the first instance and turn the
%%     instances into (Value,Prob,Delta,Expect) tuples,
%%   - line the tuples up with the declared value order (get_values1/2),
%%   - replace the stored $pd_hyperparameters/4 fact, with the alphas
%%     obtained from the deltas by $pp_delta_to_alpha/2,
%%   - drop both kinds of stored expectations and store the new
%%     $pd_hyperexpectations/3 fact.
%% retractall/1 is used for the removals so that no stale fact for the
%% switch can survive, even if more than one was present.
$pp_decode_update_switches_h([]).
$pp_decode_update_switches_h([sw(_,SwInstances)|Sws]) :-
    $pp_decode_switch_name(SwInstances,Sw),
    $pp_decode_switch_instances(SwInstances,Updates),
    get_values1(Sw,Values),
    $pp_separate_updates(Values,_Probs,Deltas,Es,Updates),
    retractall($pd_hyperparameters(Sw,_,_,_)),
    $pp_delta_to_alpha(Deltas,Alphas),
    assert($pd_hyperparameters(Sw,Values,Alphas,Deltas)),
    retractall($pd_expectations(Sw,_,_)),
    retractall($pd_hyperexpectations(Sw,_,_)),
    assert($pd_hyperexpectations(Sw,Values,Es)),
    !,
    $pp_decode_update_switches_h(Sws).
%% $pp_decode_update_switches_b(+Sws)
%% Store the result of learning in mode `both'.  For every
%% sw(_,SwInstances):
%%   - recover the switch name from the first instance and turn the
%%     instances into (Value,Prob,Delta,Expect) tuples,
%%   - line the tuples up with the declared value order (get_values1/2),
%%   - replace the stored $pd_parameters/3 fact,
%%   - replace the stored $pd_hyperparameters/4 fact, with the alphas
%%     obtained from the deltas by $pp_delta_to_alpha/2,
%%   - drop both kinds of stored expectations and store the new
%%     $pd_hyperexpectations/3 fact.
%% retractall/1 is used for the removals so that no stale fact for the
%% switch can survive, even if more than one was present.
$pp_decode_update_switches_b([]).
$pp_decode_update_switches_b([sw(_,SwInstances)|Sws]) :-
    $pp_decode_switch_name(SwInstances,Sw),
    $pp_decode_switch_instances(SwInstances,Updates),
    get_values1(Sw,Values),
    $pp_separate_updates(Values,Probs,Deltas,Es,Updates),
    retractall($pd_parameters(Sw,_,_)),
    assert($pd_parameters(Sw,Values,Probs)),
    retractall($pd_hyperparameters(Sw,_,_,_)),
    $pp_delta_to_alpha(Deltas,Alphas),
    assert($pd_hyperparameters(Sw,Values,Alphas,Deltas)),
    retractall($pd_hyperexpectations(Sw,_,_)),
    retractall($pd_expectations(Sw,_,_)),
    assert($pd_hyperexpectations(Sw,Values,Es)),
    !,
    $pp_decode_update_switches_b(Sws).
%% $pp_decode_switch_name(+SwInstances,-Sw)
%% The switch name is read off the first instance only: its id is mapped
%% back to a term msw(Sw,_) by $pc_prism_sw_ins_term/2.  Fails on an empty
%% instance list.
$pp_decode_switch_name([sw_ins(FirstId,_,_,_)|_],Sw) :-
    $pc_prism_sw_ins_term(FirstId,msw(Sw,_)).
%% $pp_decode_switch_instances(+SwInstances,-Updates)
%% Turn every sw_ins(Id,Prob,Delta,Expect) into the tuple
%% (Value,Prob,Delta,Expect), where Value is the switch value that
%% $pc_prism_sw_ins_term/2 associates with Id.
$pp_decode_switch_instances([],[]).
$pp_decode_switch_instances([sw_ins(Id,Prob,Delta,Expect)|Instances],
                            [(Value,Prob,Delta,Expect)|Tuples]) :-
    $pc_prism_sw_ins_term(Id,msw(_,Value)),
    !,
    $pp_decode_switch_instances(Instances,Tuples).
%% $pp_separate_updates(+Values,-Probs,-Deltas,-Expects,+Updates)
%% Split the (Value,Prob,Delta,Expect) tuples in Updates into three lists
%% that follow the order of Values.  Only the first tuple found for a
%% value is used; the predicate fails when a value has no tuple.
$pp_separate_updates([],[],[],[],_).
$pp_separate_updates([Value|Values],[P|Ps],[D|Ds],[X|Xs],Updates) :-
    member((Value,P,D,X),Updates),
    !,
    $pp_separate_updates(Values,Ps,Ds,Xs,Updates).
%% [NOTE] Non-ground goals have already been replaced by dummy goals, so all
%% goals are ground here.
%% $pp_observed_facts(+GoalCountPairs,-GidCountPairs,
%%                    +Pos0,-Pos,+Total0,-Total,+FailIdx0,-FailIdx)
%% Replace the goal in every goal(Goal,Count) by its goal id while threading
%% three accumulators through the list:
%%   Pos     - number of pairs processed so far,
%%   Total   - sum of the counts of all goals other than `failure',
%%   FailIdx - position of the `failure' goal (the incoming value is kept
%%             when there is none).
%% A goal for which no id can be obtained is reported through
%% $pp_raise_unexpected_failure/1.
$pp_observed_facts([],[],Pos,Pos,Total,Total,FailIdx,FailIdx).
$pp_observed_facts([goal(Goal,Count)|Rest],GidPairs,
                   Pos0,Pos,Total0,Total,FailIdx0,FailIdx) :-
    % the id lookup fails if the goal is ground but has no proof
    ( $pc_prism_goal_id_get(Goal,Gid) ->
        ( Goal == failure ->
            Total1 = Total0,
            FailIdx1 = Pos0
        ; Total1 is Total0 + Count,
          FailIdx1 = FailIdx0
        ),
        GidPairs = [goal(Gid,Count)|GidPairs1],
        Pos1 is Pos0 + 1
    ; $pp_raise_unexpected_failure($pp_observed_facts/8)
    ),
    !,
    $pp_observed_facts(Rest,GidPairs1,
                       Pos1,Pos,Total1,Total,FailIdx1,FailIdx).
%% Assumption: for any pair of terms F and F' (F's variant), hash codes for
%% F and F' are equal.
%%
%% For convenience on implementation of parallel learning, $pp_trans_goals/3
%% is (internally) split into two predicates $pp_build_count_pairs/2 and
%% $pp_trans_count_pairs/3.
%%
%% The order of goal-count pairs may differ at every run due to the way
%% hashtables are implemented.
%% $pp_trans_goals(+Goals,-GoalCountPairs,-AllGoals)
%% Two-step translation of the observed goals: first count the goals
%% ($pp_build_count_pairs/2), then turn each Goal=Count pair into a
%% goal(G,Count) term and collect the goals G in AllGoals
%% ($pp_trans_count_pairs/3).
$pp_trans_goals(Goals,GoalCountPairs,AllGoals) :-
    $pp_build_count_pairs(Goals,CountPairs),
    $pp_trans_count_pairs(CountPairs,GoalCountPairs,AllGoals).
%% $pp_build_count_pairs(+Goals,-Pairs)
%% Accumulate the goals and their counts in a fresh hashtable, convert the
%% table to a list and sort that list.
$pp_build_count_pairs(Goals,Pairs) :-
    new_hashtable(Counts),
    $pp_count_goals(Goals,Counts),
    hashtable_to_list(Counts,Unsorted),
    sort(Unsorted,Pairs).
%% $pp_count_goals(+Items,+Table)
%% Add every observed item to the hashtable Table, which maps a goal to its
%% accumulated count.  An item is goal(G,N), count(G,N), (N times G), or a
%% plain goal, which counts once.  A non-ground goal is stored as a copy
%% made with copy_term/2.
$pp_count_goals([],_).
$pp_count_goals([Item|Rest],Table) :-
    ( Item = goal(Goal,N) -> true
    ; Item = count(Goal,N) -> true
    ; Item = (N times Goal) -> true
    ; Goal = Item, N = 1
    ),
    ( ground(Goal) -> Key = Goal
    ; copy_term(Goal,Key)
    ),
    ( $pp_hashtable_get(Table,Key,Old) ->
        New is Old + N,
        $pp_hashtable_put(Table,Key,New)
    ; $pp_hashtable_put(Table,Key,N)
    ),
    !,
    $pp_count_goals(Rest,Table).
%% $pp_trans_count_pairs(+Pairs,-GoalCountPairs,-AllGoals)
%% For every Goal=Count pair obtain the goal to be used in its place from
%% $pp_build_dummy_goal/2, and emit it both as goal(G,Count) and, without
%% the count, in AllGoals.
$pp_trans_count_pairs([],[],[]).
$pp_trans_count_pairs([Goal=Count|Rest],CountPairs,Targets) :-
    $pp_build_dummy_goal(Goal,Dummy),
    CountPairs = [goal(Dummy,Count)|CountPairs1],
    Targets = [Dummy|Targets1],
    !,
    $pp_trans_count_pairs(Rest,CountPairs1,Targets1).
%% $pp_build_dummy_goal(+Goal,-DummyGoal)
%% Decide which goal stands in for an observed Goal during learning.
%%   - Goal is msw(I,V): a dummy goal is always created.
%%   - Goal is ground (and not an msw): Goal itself is used.
%%   - otherwise (non-ground): a dummy goal is created.
%% Each dummy goal is recorded with its original in the dummy goal table
%% and defined, through consult_preds/2, as a tabled predicate with one
%% clause.
$pp_build_dummy_goal(Goal,DummyGoal) :-
( Goal = msw(I,V) ->
%% The clause body works on the switch name and value themselves when they
%% are ground, and on copies of them otherwise.
( ground(I) -> I = ICp ; copy_term(I,ICp) ),
( ground(V) -> V = VCp ; copy_term(V,VCp) ),
$pp_create_dummy_goal(DummyGoal),
$pp_assert_dummy_goal(DummyGoal,Goal),
%% Generated clause: explain the switch, register the dummy goal and record
%% a path with no subgoals and the single switch Sid.
Clause = (DummyGoal :- $prism_expl_msw(ICp,VCp,Sid),
$pc_prism_goal_id_register(DummyGoal,Hid),
$prism_eg_path(Hid,[],[Sid])),
Prog = [pred(DummyGoal,0,_,_,tabled(_,_,_,_),[Clause]),
pred($damon_load,0,_,_,_,[($damon_load:-true)])],
consult_preds([],Prog)
; ground(Goal) ->
DummyGoal = Goal % don't create dummy goals (wrappers) for
; % ground goals to save memory.
$pp_create_dummy_goal(DummyGoal),
$pp_assert_dummy_goal(DummyGoal,Goal),
%% The body is the translated goal when $pp_trans_one_goal/2 succeeds;
%% otherwise it is a call to $pp_expl_interp_goal/11 on the goal.
( $pp_trans_one_goal(Goal,CompGoal) -> BodyGoal = CompGoal
; BodyGoal = (savecp(CP),Depth=0,
$pp_expl_interp_goal(Goal,Depth,CP,[],_,[],_,[],_,[],_))
),
%% Generated clause: run the body, register both the original goal and the
%% dummy goal, and record a path from the dummy goal to the original one
%% with no switches.
Clause = (DummyGoal:-BodyGoal,
$pc_prism_goal_id_register(Goal,GId),
$pc_prism_goal_id_register(DummyGoal,HId),
$prism_eg_path(HId,[GId],[])),
Prog = [pred(DummyGoal,0,_Mode,_Delay,tabled(_,_,_,_),[Clause]),
pred($damon_load,0,_,_,_,[($damon_load:-true)])],
consult_preds([],Prog)
),!.
%% $pp_assert_dummy_goal(+DummyGoal,+Original)
%% Record, at the end of the dummy goal table, that DummyGoal stands for
%% the goal Original.
$pp_assert_dummy_goal(DummyGoal,Original) :-
    assertz($pd_dummy_goal_table(DummyGoal,Original)),
    !.
%% $pp_clean_dummy_goal_table
%% Remove every entry of the dummy goal table, i.e. all
%% $pd_dummy_goal_table/2 facts added by $pp_assert_dummy_goal/2.
$pp_clean_dummy_goal_table :-
retractall($pd_dummy_goal_table(_,_)).
%%----------------------------------------
% just make a simple check
%% $pp_require_observed_data(+Data,+MsgID,+Source)
%% Succeed when Data passes $pp_test_observed_data/1.  Otherwise hand
%% over to $pp_raise_on_require/4, naming $pp_error_observed_data as the
%% predicate that classifies the error.
$pp_require_observed_data(Data,MsgID,Source) :-
    ( $pp_test_observed_data(Data) -> true
    ; $pp_raise_on_require([Data],MsgID,Source,$pp_error_observed_data)
    ).
%% $pp_test_observed_data(+Data)
%% Simple sanity check on the observed data: Data must be bound, must not
%% unify with the one-element list [failure], and must be a non-empty
%% list cell.  No bindings are left behind by the [failure] test.
$pp_test_observed_data(Data) :-
    nonvar(Data),
    \+ Data = [failure],
    Data = [_|_].
%% $pp_error_observed_data(+Data,-Error)
%% Classify why Data was rejected as observed data.  The error reported by
%% $pp_error_nonvar/2 takes precedence; otherwise a domain error is
%% produced when Data unifies with [failure] or is not a non-empty list
%% cell.
$pp_error_observed_data(Data,Error) :-
    $pp_error_nonvar(Data,Error),
    !.
$pp_error_observed_data(Data,domain_error(observed_data,Data)) :-
    ( Data = [failure] ; Data \= [_|_] ),
    !.