update ProbLog

Vitor Santos Costa
2009-03-06 09:53:09 +00:00
parent afd979a246
commit f01fd0fbee
12 changed files with 643 additions and 3647 deletions


@@ -3,7 +3,7 @@
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Parameter Learning for ProbLog
%
% 27.10.2008
% 28.11.2008
% bernd.gutmann@cs.kuleuven.be
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
@@ -61,6 +61,11 @@
:- dynamic log_frequency/1.
:- dynamic alpha/1.
:- dynamic sigmoid_slope/1.
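% flags controlling the golden-section line search introduced below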
:- dynamic line_search/1.
:- dynamic line_search_tolerance/1.
:- dynamic line_search_tau/1.
:- dynamic line_search_never_stop/1.
:- dynamic line_search_interval/2.
%==========================================================================
@@ -84,6 +89,7 @@ set_learning_flag(init_method,(Query,Probability,BDDFile,ProbFile,Call)) :-
set_learning_flag(rebuild_bdds,Flag) :-
(Flag=true;Flag=false),
!,
retractall(rebuild_bdds(_)),
assert(rebuild_bdds(Flag)).
@@ -95,11 +101,13 @@ set_learning_flag(rebuild_bdds_it,Flag) :-
set_learning_flag(reuse_initialized_bdds,Flag) :-
(Flag=true;Flag=false),
!,
retractall(reuse_initialized_bdds(_)),
assert(reuse_initialized_bdds(Flag)).
set_learning_flag(learning_rate,V) :-
(V=examples -> true;(number(V),V>=0)),
!,
retractall(learning_rate(_)),
assert(learning_rate(V)).
@@ -112,6 +120,7 @@ set_learning_flag(probability_initializer,(FactID,Probability,Query)) :-
set_learning_flag(check_duplicate_bdds,Flag) :-
(Flag=true;Flag=false),
!,
retractall(check_duplicate_bdds(_)),
assert(check_duplicate_bdds(Flag)).
@@ -160,6 +169,34 @@ set_learning_flag(sigmoid_slope,Slope) :-
assert(sigmoid_slope(Slope)).
set_learning_flag(line_search,Flag) :-
(Flag=true;Flag=false),
!,
retractall(line_search(_)),
assert(line_search(Flag)).
set_learning_flag(line_search_tolerance,Number) :-
number(Number),
Number>0,
retractall(line_search_tolerance(_)),
assert(line_search_tolerance(Number)).
set_learning_flag(line_search_interval,(L,R)) :-
number(L),
number(R),
L<R,
retractall(line_search_interval(_,_)),
assert(line_search_interval(L,R)).
set_learning_flag(line_search_tau,Number) :-
number(Number),
Number>0,
retractall(line_search_tau(_)),
assert(line_search_tau(Number)).
set_learning_flag(line_search_never_stop,Flag) :-
(Flag=true;Flag=false),
!,
retractall(line_search_never_stop(_)),
assert(line_search_never_stop(Flag)).
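% usage sketch with illustrative values (any numbers L<R and a positive
% tolerance are accepted by the clauses above):
% ?- set_learning_flag(line_search,true),
%    set_learning_flag(line_search_interval,(0,50)),
%    set_learning_flag(line_search_tolerance,0.01).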
%========================================================================
%= store the facts with the learned probabilities to a file
%= if F is a variable, a filename based on the current iteration is used
@@ -358,7 +395,7 @@ do_learning_intern(Iterations,Epsilon) :-
assert(current_iteration(CurrentIteration)),
EndIteration is OldIteration+Iterations,
format(' Iteration ~d of ~d~n',[CurrentIteration,EndIteration]),
format('~n Iteration ~d of ~d~n',[CurrentIteration,EndIteration]),
logger_set_variable(iteration,CurrentIteration),
logger_start_timer(duration),
@@ -668,10 +705,14 @@ random_probability(_FactID,Probability) :-
%========================================================================
update_values :-
update_values(all).
update_values(_) :-
values_correct,
!.
update_values :-
update_values(What_To_Update) :-
\+ values_correct,
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
@@ -688,12 +729,17 @@ update_values :-
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
open(Input_Filename,'write',Handle),
( % go over all tunable facts
( % go over all probabilistic facts
get_fact_probability(ID,Prob),
inv_sigmoid(Prob,Value),
format(Handle,'@x~q~n~10f~n',[ID,Value]),
(
non_ground_fact(ID)
->
format(Handle,'@x~q_*~n~10f~n',[ID,Value]);
format(Handle,'@x~q~n~10f~n',[ID,Value])
),
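% non-ground facts are written with a '_*' suffix (assumed to be the
% marker ProblogBDD uses to recognize non-ground facts in input.txt)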
fail; % go to next tunable fact
fail; % go to next probabilistic fact
true
),
@@ -710,7 +756,7 @@ update_values :-
( % go over all training examples
current_predicate(user:example/3),
user:example(QueryID,_Query,_QueryProb),
once(call_bdd_tool(QueryID,'.')),
once(call_bdd_tool(QueryID,'.',What_To_Update)),
fail; % go to next training example
true
),
@@ -723,13 +769,16 @@ update_values :-
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% start update values for all test examples
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
( % go over all test examples
current_predicate(user:test_example/3),
user:test_example(QueryID,_Query,_QueryProb),
once(call_bdd_tool(QueryID,'+')),
fail; % go to next test example
true
),
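% test-set values are only needed on a full update; a line search
% (What_To_Update=probabilities) skips them: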
( What_To_Update = all
->
( % go over all test examples
current_predicate(user:test_example/3),
user:test_example(QueryID,_Query,_QueryProb),
once(call_bdd_tool(QueryID,'+',all)),
fail; % go to next test example
true
); true
),
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% stop update values for all test examples
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
@@ -748,7 +797,7 @@ update_values :-
%========================================================================
call_bdd_tool(QueryID,Symbol) :-
call_bdd_tool(QueryID,Symbol,What_To_Update) :-
output_directory(Output_Directory),
query_directory(Query_Directory),
(
@@ -759,15 +808,16 @@ call_bdd_tool(QueryID,Symbol) :-
(
sigmoid_slope(Slope),
problog_dir(PD),
(What_To_Update=all -> Method='g' ; Method='l'),
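% -m selects the ProblogBDD mode: 'g' also computes gradients,
% 'l' is assumed to compute probabilities only, which is all the
% line search needs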
atomic_concat([PD,
'/ProblogBDD -i "',
'/ProblogBDD -i "',
Output_Directory,
'input.txt',
'" -l "',
Query_Directory,
'query_',
QueryID,
'" -m g -id ',
'" -m ',Method,' -id ',
QueryID,
' -sl ',Slope,
' > "',
@@ -894,6 +944,25 @@ ground_truth_difference :-
%= -Float
%========================================================================
mse_trainingset_only_for_linesearch(MSE) :-
(
current_predicate(user:example/3)
->
(
update_values(probabilities),
findall(SquaredError,
(user:example(QueryID,_Query,QueryProb),
query_probability(QueryID,CurrentProb),
SquaredError is (CurrentProb-QueryProb)**2),
AllSquaredErrors),
length(AllSquaredErrors,Length),
sum_list(AllSquaredErrors,SumAllSquaredErrors),
MSE is SumAllSquaredErrors/Length
); true
),
retractall(values_correct).
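% retract values_correct here: the line search keeps moving the fact
% probabilities, so the cached values must be recomputed next time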
% calculate the mse of the training data
mse_trainingset :-
(
@@ -947,7 +1016,6 @@ mse_testset :-
%========================================================================
%= Calculates the sigmoid function and its inverse, respectively
%= warning: applying inv_sigmoid to 0.0 or 1.0 will yield +/-inf
@@ -988,8 +1056,64 @@ secure_probability(Prob,Prob_Secure) :-
%= probabilities of the examples have to be recalculated
%========================================================================
save_old_probabilities :-
( % go over all tunable facts
tunable_fact(FactID,_),
get_fact_probability(FactID,OldProbability),
atomic_concat(['old_prob_',FactID],Key),
bb_put(Key,OldProbability),
fail; % go to next tunable fact
true
).
forget_old_values :-
( % go over all tunable facts
tunable_fact(FactID,_),
atomic_concat(['old_prob_',FactID],Key),
atomic_concat(['grad_',FactID],Key2),
bb_delete(Key,_),
bb_delete(Key2,_),
fail; % go to next tunable fact
true
).
add_gradient(Learning_Rate) :-
( % go over all tunable facts
tunable_fact(FactID,_),
atomic_concat(['old_prob_',FactID],Key),
atomic_concat(['grad_',FactID],Key2),
bb_get(Key,OldProbability),
bb_get(Key2,GradValue),
inv_sigmoid(OldProbability,OldValue),
NewValue is OldValue - Learning_Rate*GradValue,
sigmoid(NewValue,NewProbability),
% Prevent "inf" by avoiding values too close to 0.0/1.0
secure_probability(NewProbability,NewProbabilityS),
set_fact_probability(FactID,NewProbabilityS),
fail; % go to next tunable fact
true
),
retractall(values_correct).
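% add_gradient/1 works in sigmoid space: with p = sigmoid(x) the step is
% x' = x - Learning_Rate*grad, p' = sigmoid(x'). Because it always starts
% from the saved old probabilities, calls with different learning rates
% are independent of each other, which is what the line search relies on.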
simulate :-
L = [0.6,1.0,2.0,3.0,10,50,100,200,300],
findall((X,Y),(member(X,L),line_search_evaluate_point(X,Y)),List),
write(List),nl.
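% simulate/0 is a debugging aid: it evaluates the line-search objective
% at a few fixed learning rates and prints the (rate,MSE) pairs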
gradient_descent :-
save_old_probabilities,
update_values,
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
@@ -998,7 +1122,8 @@ gradient_descent :-
( % go over all tunable facts
tunable_fact(FactID,_),
bb_put(FactID,0.0),
atomic_concat(['grad_',FactID],Key),
bb_put(Key,0.0),
fail; % go to next tunable fact
true
@@ -1029,15 +1154,18 @@ gradient_descent :-
( % go over all tunable facts
tunable_fact(FactID,_),
(
query_gradient(QueryID,FactID,GradValue)
->
true;
GradValue=0.0
),
bb_get(FactID,OldValue),
atomic_concat(['grad_',FactID],Key),
% if the following query fails, the fact is not used
% in the proof of QueryID; its gradient is 0.0 and
% would not contribute to NewValue either way.
% DON'T FORGET THIS IF YOU CHANGE SOMETHING HERE!
query_gradient(QueryID,FactID,GradValue),
bb_get(Key,OldValue),
NewValue is OldValue + Y*GradValue,
bb_put(FactID,NewValue),
bb_put(Key,NewValue),
fail; % go to next fact
true
@@ -1054,7 +1182,7 @@ gradient_descent :-
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% start statistics on gradient
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
findall(V,(tunable_fact(FactID,_),bb_get(FactID,V)),GradientValues),
findall(V,(tunable_fact(FactID,_),atomic_concat(['grad_',FactID],Key),bb_get(Key,V)),GradientValues),
sum_list(GradientValues,GradSum),
max_list(GradientValues,GradMax),
@@ -1068,39 +1196,236 @@ gradient_descent :-
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% stop statistics on gradient
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% start add gradient to current probabilities
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
learning_rate(LearningRate),
( % go over all tunable facts
tunable_fact(FactID,_),
get_fact_probability(FactID,OldProbability),
bb_delete(FactID,GradValue),
inv_sigmoid(OldProbability,OldValue),
NewValue is OldValue - LearningRate*GradValue,
sigmoid(NewValue,NewProbability),
% Prevent "inf" by avoiding values too close to 0.0/1.0
secure_probability(NewProbability,NewProbabilityS),
set_fact_probability(FactID,NewProbabilityS),
fail; % go to next tunable fact
true
(
line_search(false)
->
learning_rate(LearningRate);
lineSearch(LearningRate,_)
),
format('learning rate = ~12f~n',[LearningRate]),
add_gradient(LearningRate),
logger_set_variable(learning_rate,LearningRate),
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% stop add gradient to current probabilities
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
!,
forget_old_values.
% we're done, mark the current values as outdated
retractall(values_correct).
%========================================================================
%=
%=
%========================================================================
line_search_evaluate_point(Learning_Rate,MSE) :-
add_gradient(Learning_Rate),
mse_trainingset_only_for_linesearch(MSE).
lineSearch(Final_X,Final_Value) :-
% Get Parameters for line search
line_search_tolerance(Tol),
line_search_tau(Tau),
line_search_interval(A,B),
format(' Running line search in interval (~5f,~5f)~n',[A,B]),
% init values
Acc is Tol * (B-A),
InitRight is A + Tau*(B-A),
InitLeft is A + B - InitRight,
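% with Tau = (sqrt(5)-1)/2 this is a golden-section search: each
% iteration keeps one interior point and evaluates only one new one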
line_search_evaluate_point(A,Value_A),
line_search_evaluate_point(B,Value_B),
line_search_evaluate_point(InitRight,Value_InitRight),
line_search_evaluate_point(InitLeft,Value_InitLeft),
bb_put(line_search_a,A),
bb_put(line_search_b,B),
bb_put(line_search_left,InitLeft),
bb_put(line_search_right,InitRight),
bb_put(line_search_value_a,Value_A),
bb_put(line_search_value_b,Value_B),
bb_put(line_search_value_left,Value_InitLeft),
bb_put(line_search_value_right,Value_InitRight),
bb_put(line_search_iteration,1),
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%% BEGIN BACKTRACKING
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
(
repeat,
bb_get(line_search_iteration,Iteration),
bb_get(line_search_a,Ak),
bb_get(line_search_b,Bk),
bb_get(line_search_left,Left),
bb_get(line_search_right,Right),
bb_get(line_search_value_a,Fl),
bb_get(line_search_value_b,Fr),
bb_get(line_search_value_left,FLeft),
bb_get(line_search_value_right,FRight),
write(lineSearch(Iteration,Ak,Fl,Bk,Fr,Left,FLeft,Right,FRight)),nl,
(
% check for infinity: if either value is +inf, fall through
% to the else branch and shrink towards the left
( FLeft >= FRight, \+ FLeft = (+inf), \+ FRight = (+inf) )
->
(
AkNew=Left,
FlNew=FLeft,
LeftNew=Right,
FLeftNew=FRight,
RightNew is AkNew + Bk - LeftNew,
line_search_evaluate_point(RightNew,FRightNew),
BkNew=Bk,
FrNew=Fr
);
(
BkNew=Right,
FrNew=FRight,
RightNew=Left,
FRightNew=FLeft,
LeftNew is Ak + BkNew - RightNew,
line_search_evaluate_point(LeftNew,FLeftNew),
AkNew=Ak,
FlNew=Fl
)
),
Next_Iteration is Iteration + 1,
ActAcc is BkNew - AkNew,
bb_put(line_search_iteration,Next_Iteration),
bb_put(line_search_a,AkNew),
bb_put(line_search_b,BkNew),
bb_put(line_search_left,LeftNew),
bb_put(line_search_right,RightNew),
bb_put(line_search_value_a,FlNew),
bb_put(line_search_value_b,FrNew),
bb_put(line_search_value_left,FLeftNew),
bb_put(line_search_value_right,FRightNew),
% is the search interval smaller than the tolerance level?
ActAcc < Acc,
% apparently it is, so get me out of here and
% cut away the choice point from repeat
!
),
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%% END BACKTRACKING
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% clean up the blackboard mess
bb_delete(line_search_iteration,_),
bb_delete(line_search_a,_),
bb_delete(line_search_b,_),
bb_delete(line_search_left,_),
bb_delete(line_search_right,_),
bb_delete(line_search_value_a,_),
bb_delete(line_search_value_b,_),
bb_delete(line_search_value_left,_),
bb_delete(line_search_value_right,_),
% it does no harm to also check the value in the middle
% of the current search interval
Middle is (AkNew + BkNew) / 2.0,
line_search_evaluate_point(Middle,Value_Middle),
% return the optimal value
my_5_min(Value_Middle,FlNew,FrNew,FLeftNew,FRightNew,
Middle,AkNew,BkNew,LeftNew,RightNew,
Optimal_Value,Optimal_X),
line_search_postcheck(Optimal_Value,Optimal_X,Final_Value,Final_X).
line_search_postcheck(V,X,V,X) :-
X>0,
!.
line_search_postcheck(V,X,V,X) :-
line_search_never_stop(false),
!.
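% otherwise probe geometrically shrinking offsets from the left border
% of the interval until a finite likelihood is found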
line_search_postcheck(_,_, LLH, FinalPosition) :-
line_search_tolerance(Tolerance),
line_search_interval(Left,Right),
Offset is (Right - Left) * Tolerance,
bb_put(line_search_offset,Offset),
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
(
repeat,
bb_get(line_search_offset,OldOffset),
NewOffset is OldOffset * Tolerance,
bb_put(line_search_offset,NewOffset),
Position is Left + NewOffset,
set_linesearch_weights_calc_llh(Position,LLH),
bb_put(line_search_llh,LLH),
write(logAtom(lineSearchPostCheck(Position,LLH))),nl,
\+ LLH = (+inf),
!
), % cut away choice point from repeat
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
bb_delete(line_search_llh,LLH),
bb_delete(line_search_offset,FinalOffset),
FinalPosition is Left + FinalOffset.
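% my_5_min/12 returns the smallest of the five candidate values together
% with the position it was attained at (a tournament of pairwise minima)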
my_5_min(V1,V2,V3,V4,V5,F1,F2,F3,F4,F5,VMin,FMin) :-
(
V1<V2
->
(VTemp1=V1,FTemp1=F1);
(VTemp1=V2,FTemp1=F2)
),
(
V3<V4
->
(VTemp2=V3,FTemp2=F3);
(VTemp2=V4,FTemp2=F4)
),
(
VTemp1<VTemp2
->
(VTemp3=VTemp1,FTemp3=FTemp1);
(VTemp3=VTemp2,FTemp3=FTemp2)
),
(
VTemp3<V5
->
(VMin=VTemp3,FMin=FTemp3);
(VMin=V5,FMin=F5)
).
%========================================================================
@@ -1123,6 +1448,11 @@ global_initialize :-
set_learning_flag(sigmoid_slope,1.0), % 1.0 gives standard sigmoid
set_learning_flag(init_method,(Query,Probability,BDDFile,ProbFile,
problog_kbest_save(Query,10,Probability,_Status,BDDFile,ProbFile))),
set_learning_flag(line_search,false),
set_learning_flag(line_search_never_stop,true),
set_learning_flag(line_search_tau,0.618033988749895),
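% tau = (sqrt(5)-1)/2, the golden-section ratio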
set_learning_flag(line_search_tolerance,0.05),
set_learning_flag(line_search_interval,(0,100)),
logger_define_variable(iteration, int),
logger_define_variable(duration,time),
@@ -1137,7 +1467,8 @@ global_initialize :-
logger_define_variable(gradient_max,float),
logger_define_variable(ground_truth_diff,float),
logger_define_variable(ground_truth_mindiff,float),
logger_define_variable(ground_truth_maxdiff,float).
logger_define_variable(ground_truth_maxdiff,float),
logger_define_variable(learning_rate,float).
%========================================================================
%=