Some ProbLog improvements related with tabling efficiency, more to come soon

This commit is contained in:
Theofrastos Mantadelis
2010-11-03 19:22:11 +01:00
parent 362ecc2f16
commit c804d105b6
8 changed files with 1025 additions and 589 deletions

View File

@@ -2,8 +2,8 @@
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%
% $Date: 2010-10-05 16:52:13 +0200 (Tue, 05 Oct 2010) $
% $Revision: 4869 $
% $Date: 2010-10-20 18:06:47 +0200 (Wed, 20 Oct 2010) $
% $Revision: 4969 $
%
% This file is part of ProbLog
% http://dtai.cs.kuleuven.be/problog
@@ -225,6 +225,7 @@
:- use_module('problog/os').
:- use_module('problog/print_learning').
:- use_module('problog/utils_learning').
:- use_module('problog/utils').
% used to indicate the state of the system
:- dynamic(values_correct/0).
@@ -549,11 +550,12 @@ init_learning :-
!.
init_learning :-
check_examples,
empty_output_directory,
logger_write_header,
format_learning(1,'Initializing everything~n',[]),
empty_output_directory,
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Delete the BDDs from the previous run if they should
@@ -1031,6 +1033,7 @@ mse_testset :-
(
format_learning(2,'MSE_Test ',[]),
update_values,
bb_put(llh_test_queries,0.0),
findall(SquaredError,
(user:test_example(QueryID,_Query,QueryProb,Type),
once(update_query(QueryID,'+',probability)),
@@ -1041,7 +1044,10 @@ mse_testset :-
->
SquaredError is (CurrentProb-QueryProb)**2;
SquaredError = 0.0
)
),
bb_get(llh_test_queries,Old_LLH_Test_Queries),
New_LLH_Test_Queries is Old_LLH_Test_Queries+log(CurrentProb),
bb_put(llh_test_queries,New_LLH_Test_Queries)
),
AllSquaredErrors),
@@ -1050,10 +1056,12 @@ mse_testset :-
min_list(AllSquaredErrors,MinError),
max_list(AllSquaredErrors,MaxError),
MSE is SumAllSquaredErrors/Length,
bb_delete(llh_test_queries,LLH_Test_Queries),
logger_set_variable(mse_testset,MSE),
logger_set_variable(mse_min_testset,MinError),
logger_set_variable(mse_max_testset,MaxError),
logger_set_variable(llh_test_queries,LLH_Test_Queries),
format_learning(2,' (~8f)~n',[MSE])
); true
).
@@ -1232,6 +1240,7 @@ gradient_descent :-
bb_put(mse_train_sum, 0.0),
bb_put(mse_train_min, 0.0),
bb_put(mse_train_max, 0.0),
bb_put(llh_training_queries, 0.0),
problog_flag(alpha,Alpha),
logger_set_variable(alpha,Alpha),
@@ -1267,12 +1276,15 @@ gradient_descent :-
bb_get(mse_train_sum,Old_MSE_Train_Sum),
bb_get(mse_train_min,Old_MSE_Train_Min),
bb_get(mse_train_max,Old_MSE_Train_Max),
bb_get(llh_training_queries,Old_LLH_Training_Queries),
New_MSE_Train_Sum is Old_MSE_Train_Sum+Squared_Error,
New_MSE_Train_Min is min(Old_MSE_Train_Min,Squared_Error),
New_MSE_Train_Max is max(Old_MSE_Train_Max,Squared_Error),
New_LLH_Training_Queries is Old_LLH_Training_Queries+log(BDDProb),
bb_put(mse_train_sum,New_MSE_Train_Sum),
bb_put(mse_train_min,New_MSE_Train_Min),
bb_put(mse_train_max,New_MSE_Train_Max),
bb_put(llh_training_queries,New_LLH_Training_Queries),
@@ -1368,11 +1380,13 @@ gradient_descent :-
bb_delete(mse_train_sum,MSE_Train_Sum),
bb_delete(mse_train_min,MSE_Train_Min),
bb_delete(mse_train_max,MSE_Train_Max),
bb_delete(llh_training_queries,LLH_Training_Queries),
MSE is MSE_Train_Sum/Example_Count,
logger_set_variable(mse_trainingset,MSE),
logger_set_variable(mse_min_trainingset,MSE_Train_Min),
logger_set_variable(mse_max_trainingset,MSE_Train_Max),
logger_set_variable(llh_training_queries,LLH_Training_Queries),
format_learning(2,'~n',[]),
@@ -1670,7 +1684,9 @@ init_logger :-
logger_define_variable(ground_truth_mindiff,float),
logger_define_variable(ground_truth_maxdiff,float),
logger_define_variable(learning_rate,float),
logger_define_variable(alpha,float).
logger_define_variable(alpha,float),
logger_define_variable(llh_training_queries,float),
logger_define_variable(llh_test_queries,float).
:- initialization(init_flags).
:- initialization(init_logger).