This commit is contained in:
Vítor Santos Costa
2018-08-31 20:03:00 +01:00
parent e0cc401381
commit f6d8304fcf
9 changed files with 103 additions and 98 deletions

View File

@@ -244,7 +244,7 @@ logger_define_variable(Name,Type) :-
->
write('redefining logger variable '),write(Name),write(' of type '), write(Type0), nl
;
throw(error(variable_redefined(logger_define_variable(Name,Type)))
throw(error(variable_redefined(logger_define_variable(Name,Type))))
).
logger_define_variable(Name,Type) :-
ground(Type),

View File

@@ -263,7 +263,7 @@
% this is a test to determine whether YAP provides the needed trie library
:- initialization(
( predicate_property(trie_disable_hash, imported_from(tries)) ->
( predicate_property(trie_disable_hash, imported_from(M)) ->
trie_disable_hash
; print_message(warning,'The predicate tries:trie_disable_hash/0 does not exist. Please update trie library.')
)

View File

@@ -389,8 +389,6 @@ reset_learning :-
do_learning(Iterations) :-
do_learning(Iterations,-1).
:- spy do_learning/2.
do_learning(Iterations,Epsilon) :-
current_predicate(user:example/4),
!,
@@ -514,32 +512,7 @@ init_learning :-
% if yes, switch to problog_exact
% continuous facts are not supported yet.
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
problog_flag(continuous_facts, true),
!,
problog_flag(init_method,(_,_,_,_,OldCall)),
(
(
continuous_fact(_),
OldCall\=problog_exact_save(_,_,_,_,_)
)
->
(
format_learning(2,'Theory uses continuous facts.~nWill use problog_exact/3 as initalization method.~2n',[]),
set_problog_flag(init_method,(Query,Probability,BDDFile,ProbFile,problog_exact_save(Query,Probability,_Status,BDDFile,ProbFile)))
);
true
),
(
problog_tabled(_)
->
(
format_learning(2,'Theory uses tabling.~nWill use problog_exact/3 as initalization method.~2n',[]),
set_problog_flag(init_method,(Query,Probability,BDDFile,ProbFile,problog_exact_save(Query,Probability,_Status,BDDFile,ProbFile)))
);
true
),
set_default_gradient_method,
succeeds_n_times(user:test_example(_,_,_,_),TestExampleCount),
format_learning(3,'~q test examples~n',[TestExampleCount]),
@@ -597,6 +570,21 @@ init_learning :-
empty_bdd_directory.
set_default_gradient_method :-
problog_flag(continuous_facts, true),
!,
problog_flag(init_method,OldMethod),
format_learning(2,'Theory uses continuous facts.~nWill use problog_exact/3 as initalization method.~2n',[]),
set_problog_flag(init_method,(Query,Probability,BDDFile,ProbFile,problog_exact_save(Query,Probability,_Status,BDDFile,ProbFile))).
set_default_gradient_method :-
problog_tabled(_), problog_flag(fast_proofs,false),
!,
format_learning(2,'Theory uses tabling.~nWill use problog_exact/3 as initalization method.~2n',[]),
set_problog_flag(init_method,(Query,Probability,BDDFile,ProbFile,problog_exact_save(Query,Probability,_Status,BDDFile,ProbFile))).
set_default_gradient_method :-
set_problog_flag(init_method,(Query,1,BDD,
problog_kbest_as_bdd(user:Query,1,BDD))).
%========================================================================
%= This predicate goes over all training and test examples,
%= calls the inference method of ProbLog and stores the resulting
@@ -625,29 +613,23 @@ init_one_query(QueryID,Query,Type) :-
format_learning(3,' Reuse existing BDD ~q~n~n',[QueryID]);
(
b_setval(problog_required_keep_ground_ids,false),
problog_flag(libbdd_init_method,(Query,Bdd,Call)),
rb_new(H0),
problog_flag(init_method,(Query,NOf,Bdd,problog_kbest_as_bdd(Call,1,Bdd))),
strip_module(Call,_,gene(X,Y)),
!,
Bdd = bdd(Dir, Tree, MapList),
% trace,
once(Call),
rb_new(H0),
problog:problog_kbest_as_bdd(user:gene(X,Y),1,Bdd),
maplist_to_hash(MapList, H0, Hash),
Tree \= [],
%put_code(0'.),
writeln(QueryID),
tree_to_grad(Tree, Hash, [], Grad),
recordz(QueryID,bdd(Dir, Grad, MapList),_)
)
),
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% check wether this BDD is similar to another BDD
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
(
problog_flag(check_duplicate_bdds,true)
->
true /* ignore this flag for now */
;
true
),!.
init_one_query(_QueryID,_Query,_Type).
).
init_one_query(_QueryID,_Query,_Type) :-
throw(unsupported_init_method).
@@ -1006,6 +988,7 @@ inv_sigmoid(T,InvSig) :-
save_old_probabilities :-
problog_flag(continous_facts, true),
!,
forall(tunable_fact(FactID,_),
(
continuous_fact(FactID)
@@ -1024,6 +1007,14 @@ save_old_probabilities :-
)
)
).
save_old_probabilities :-
forall(tunable_fact(FactID,_),
(
get_fact_probability(FactID,OldProbability),
atomic_concat(['old_prob_',FactID],Key),
bb_put(Key,OldProbability)
)
).
@@ -1106,7 +1097,8 @@ gradient_descent :-
save_old_probabilities,
update_values,
reset_gradients.
reset_gradients,
compute_gradients(Handle).
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% start set gradient to zero
@@ -1130,7 +1122,14 @@ gradient_descent :-
bb_put(Key,0.0)
)
)
),
).
reset_gradients :-
forall(tunable_fact(FactID,_),
(
atomic_concat(['grad_',FactID],Key),
bb_put(Key,0.0)
)
).
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% stop gradient to zero
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
@@ -1138,6 +1137,7 @@ gradient_descent :-
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% start calculate gradient
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
compute_gradients(Handle) :-
bb_put(mse_train_sum, 0.0),
bb_put(mse_train_min, 0.0),
bb_put(mse_train_max, 0.0),
@@ -1485,18 +1485,18 @@ my_5_min(V1,V2,V3,V4,V5,F1,F2,F3,F4,F5,VMin,FMin) :-
%========================================================================
init_flags :-
prolog_file_name('queries',Queries_Folder), % get absolute file name for './queries'
prolog_file_name('output',Output_Folder), % get absolute file name for './output'
prolog_file_name(queries,Queries_Folder), % get absolute file name for './queries'
prolog_file_name(output,Output_Folder), % get absolute file name for './output'
problog_define_flag(bdd_directory, problog_flag_validate_directory, 'directory for BDD scripts', Queries_Folder,learning_general),
problog_define_flag(output_directory, problog_flag_validate_directory, 'directory for logfiles etc', Output_Folder,learning_general,flags:learning_output_dir_handler),
problog_define_flag(log_frequency, problog_flag_validate_posint, 'log results every nth iteration', 1, learning_general),
problog_define_flag(rebuild_bdds, problog_flag_validate_nonegint, 'rebuild BDDs every nth iteration', 0, learning_general),
problog_define_flag(reuse_initialized_bdds,problog_flag_validate_boolean, 'Reuse BDDs from previous runs',false, learning_general),
problog_define_flag(check_duplicate_bdds,problog_flag_validate_boolean,'Store intermediate results in hash table',true,learning_general),
problog_define_flag(libbdd_init_method,problog_flag_validate_dummy,'ProbLog predicate to search proofs',(Query,Tree,problog:problog_kbest_as_bdd(Query,100,Tree)),learning_general,flags:learning_libdd_init_handler),
problog_define_flag(init_method,problog_flag_validate_dummy,'ProbLog predicate to search proofs',(Query,Tree,problog:problog_kbest_as_bdd(Query,100,Tree)),learning_general,flags:learning_libdd_init_handler),
problog_define_flag(alpha,problog_flag_validate_number,'weight of negative examples (auto=n_p/n_n)',auto,learning_general,flags:auto_handler),
problog_define_flag(sigmoid_slope,problog_flag_validate_posnumber,'slope of sigmoid function',1.0,learning_general),
problog_define_flag(continuous_facts,problog_flag_validate_boolean,'support parameter learning of continuous distributions',1.0,learning_general),
problog_define_flag(continuous_facts,problog_flag_validate_boolean,'support parameter learning of continuous distributions',true,learning_general),
problog_define_flag(learning_rate,problog_flag_validate_posnumber,'Default learning rate (If line_search=false)',examples,learning_line_search,flags:examples_handler),
problog_define_flag(line_search, problog_flag_validate_boolean,'estimate learning rate by line search',false,learning_line_search),
@@ -1504,7 +1504,6 @@ init_flags :-
problog_define_flag(line_search_tau, problog_flag_validate_indomain_0_1_open,'tau value for line search',0.618033988749,learning_line_search),
problog_define_flag(line_search_tolerance,problog_flag_validate_posnumber,'tolerance value for line search',0.05,learning_line_search),
problog_define_flag(line_search_interval, problog_flag_validate_dummy,'interval for line search',(0,100),learning_line_search,flags:linesearch_interval_handler).
init_logger :-
logger_define_variable(iteration, int),
@@ -1527,5 +1526,7 @@ init_logger :-
logger_define_variable(llh_test_queries,float).
:- initialization(init_flags).
:- initialization(init_logger).