This commit is contained in:
Vitor Santos Costa 2019-04-01 13:40:17 +01:00
parent b24df86cb0
commit 9156b90b66
3 changed files with 38 additions and 29 deletions

View File

@ -79,10 +79,7 @@ bind_maplist([Node-(Node-Pr)|MapList], Slope, X) :-
get_prob(Node, Prob) :-
get_fact_probability(Node,Prob).
gradient(QueryID, l, Slope) :-
probability( QueryID, Slope, Prob),
assert(query_probability_intern(QueryID,Prob)),
fail.
gradient(_QueryID, l, _).
/* query_probability(21,6.775948e-01). */

View File

@ -17,7 +17,7 @@
:- use_module('../problog_lbfgs').
:- if(true).
:- if(false).
:- use_module('kbgraph').

View File

@ -238,7 +238,7 @@
:- dynamic(values_correct/0).
:- dynamic(learning_initialized/0).
:- dynamic(current_iteration/1).
:- dynamic(solver_iteration/1).
:- dynamic(solver_iterations/2).
:- dynamic(example_count/1).
:- dynamic(query_probability_intern/2).
%:- dynamic(query_gradient_intern/4).
@ -263,28 +263,15 @@ user:test_example(A,B,C,=) :-
user:test_example(A,B,C),
\+ user:problog_discard_example(B).
solver_iteration(0).
solver_iterations(0,0).
%========================================================================
%= store the facts with the learned probabilities to a file
%========================================================================
save_model(X):-
problog_flag(sigmoid_slope,Slope),
current_iteration(Iteration),
solver_iteration(LBFGSIteration),
Id is Iteration*100+LBFGSIteration,
create_factprobs_file_name(Id,Filename),
retractall( query_probability_intern(_,_)),
forall(
user:example(QueryID,_Query,_QueryProb),
(recorded(QueryID,BDD,_),
BDD = bdd(_,_,MapList),
bind_maplist(MapList, Slope, X),
query_probabilities( BDD, BDDProb),
assert( query_probability_intern(QueryID,BDDProb)))
),
export_facts(Filename).
save_model:-
current_iteration(Id),
create_factprobs_file_name(Id,Filename), export_facts(Filename).
@ -900,19 +887,44 @@ wrap( _X, _Grad, _GradCount).
user:progress(FX,_X,_G, _X_Norm,_G_Norm,_Step,_N,_CurrentIteration,_Ls,-1) :-
FX < 0, !,
format('stopped on bad FX=~4f~n',[FX]).
user:progress(FX,X,_G,X_Norm,G_Norm,Step,_N, LBFGSIteration,Ls,0) :-
user:progress(FX,X,G,X_Norm,G_Norm,Step,_N, LBFGSIteration,Ls,0) :-
problog_flag(sigmoid_slope,Slope),
forall(
tunable_fact(FactID,_GroundTruth), set_tunable(FactID,Slope,X)),
save_state(X, Slope, G),
logger_set_variable(mse_trainingset, FX),
retractall(solver_iterations(_)),
assert(solver_iterations(LBFGSIteration)),
save_model(X),
(retract(solver_iterations(SI,_)) -> true ; SI = 0),
(retract(current_iteration(TI)) -> true ; TI = 0),
SI1 is SI+1,
TI1 is TI+1,
assert(current_iteration(TI1)),
assert(solver_iterations(SI1,LBFGSIteration)),
save_model,
X0 <== X[0], sigmoid(X0,Slope,P0),
X1 <== X[1], sigmoid(X1,Slope,P1),
format('~d. Iteration : (x0,x1)=(~4f,~4f) f(X)=~4f |X|=~4f |X\'|=~4f Step=~4f Ls=~4f~n',[LBFGSIteration,P0,P1,FX,X_Norm,G_Norm,Step,Ls]).
save_state(X,Slope,_Grad) :-
tunable_fact(FactID,_GroundTruth),
set_tunable(FactID,Slope,X),
fail.
save_state(X, Slope, _) :-
user:example(QueryID,_Query,_QueryProb),
recorded(QueryID,BDD,_),
BDD = bdd(_,_,MapList),
bind_maplist(MapList, Slope, X),
query_probabilities( BDD, BDDProb),
assert( query_probability_intern(QueryID,BDDProb)),
fail.
save_state(X, Slope, _) :-
user:test_example(QueryID,_Query,_QueryProb),
recorded(QueryID,BDD,_),
BDD = bdd(_,_,MapList),
bind_maplist(MapList, Slope, X),
query_probabilities( BDD, BDDProb),
assert( query_probability_intern(QueryID,BDDProb)),
fail.
save_state(_X, _Slope, _).
%========================================================================
%= initialize the logger module and set the flags for learning
%= don't change anything here! use set_problog_flag/2 instead