bfgs
This commit is contained in:
parent
b24df86cb0
commit
9156b90b66
@ -79,10 +79,7 @@ bind_maplist([Node-(Node-Pr)|MapList], Slope, X) :-
|
||||
get_prob(Node, Prob) :-
|
||||
get_fact_probability(Node,Prob).
|
||||
|
||||
gradient(QueryID, l, Slope) :-
|
||||
probability( QueryID, Slope, Prob),
|
||||
assert(query_probability_intern(QueryID,Prob)),
|
||||
fail.
|
||||
|
||||
gradient(_QueryID, l, _).
|
||||
|
||||
/* query_probability(21,6.775948e-01). */
|
||||
|
@ -17,7 +17,7 @@
|
||||
:- use_module('../problog_lbfgs').
|
||||
|
||||
|
||||
:- if(true).
|
||||
:- if(false).
|
||||
|
||||
:- use_module('kbgraph').
|
||||
|
||||
|
@ -238,7 +238,7 @@
|
||||
:- dynamic(values_correct/0).
|
||||
:- dynamic(learning_initialized/0).
|
||||
:- dynamic(current_iteration/1).
|
||||
:- dynamic(solver_iteration/1).
|
||||
:- dynamic(solver_iterations/2).
|
||||
:- dynamic(example_count/1).
|
||||
:- dynamic(query_probability_intern/2).
|
||||
%:- dynamic(query_gradient_intern/4).
|
||||
@ -263,28 +263,15 @@ user:test_example(A,B,C,=) :-
|
||||
user:test_example(A,B,C),
|
||||
\+ user:problog_discard_example(B).
|
||||
|
||||
solver_iteration(0).
|
||||
solver_iterations(0,0).
|
||||
|
||||
%========================================================================
|
||||
%= store the facts with the learned probabilities to a file
|
||||
%========================================================================
|
||||
|
||||
save_model(X):-
|
||||
problog_flag(sigmoid_slope,Slope),
|
||||
current_iteration(Iteration),
|
||||
solver_iteration(LBFGSIteration),
|
||||
Id is Iteration*100+LBFGSIteration,
|
||||
create_factprobs_file_name(Id,Filename),
|
||||
retractall( query_probability_intern(_,_)),
|
||||
forall(
|
||||
user:example(QueryID,_Query,_QueryProb),
|
||||
(recorded(QueryID,BDD,_),
|
||||
BDD = bdd(_,_,MapList),
|
||||
bind_maplist(MapList, Slope, X),
|
||||
query_probabilities( BDD, BDDProb),
|
||||
assert( query_probability_intern(QueryID,BDDProb)))
|
||||
),
|
||||
export_facts(Filename).
|
||||
save_model:-
|
||||
current_iteration(Id),
|
||||
create_factprobs_file_name(Id,Filename), export_facts(Filename).
|
||||
|
||||
|
||||
|
||||
@ -900,19 +887,44 @@ wrap( _X, _Grad, _GradCount).
|
||||
user:progress(FX,_X,_G, _X_Norm,_G_Norm,_Step,_N,_CurrentIteration,_Ls,-1) :-
|
||||
FX < 0, !,
|
||||
format('stopped on bad FX=~4f~n',[FX]).
|
||||
user:progress(FX,X,_G,X_Norm,G_Norm,Step,_N, LBFGSIteration,Ls,0) :-
|
||||
user:progress(FX,X,G,X_Norm,G_Norm,Step,_N, LBFGSIteration,Ls,0) :-
|
||||
problog_flag(sigmoid_slope,Slope),
|
||||
forall(
|
||||
tunable_fact(FactID,_GroundTruth), set_tunable(FactID,Slope,X)),
|
||||
save_state(X, Slope, G),
|
||||
logger_set_variable(mse_trainingset, FX),
|
||||
retractall(solver_iterations(_)),
|
||||
assert(solver_iterations(LBFGSIteration)),
|
||||
save_model(X),
|
||||
(retract(solver_iterations(SI,_)) -> true ; SI = 0),
|
||||
(retract(current_iteration(TI)) -> true ; TI = 0),
|
||||
SI1 is SI+1,
|
||||
TI1 is TI+1,
|
||||
assert(current_iteration(TI1)),
|
||||
assert(solver_iterations(SI1,LBFGSIteration)),
|
||||
save_model,
|
||||
X0 <== X[0], sigmoid(X0,Slope,P0),
|
||||
X1 <== X[1], sigmoid(X1,Slope,P1),
|
||||
format('~d. Iteration : (x0,x1)=(~4f,~4f) f(X)=~4f |X|=~4f |X\'|=~4f Step=~4f Ls=~4f~n',[LBFGSIteration,P0,P1,FX,X_Norm,G_Norm,Step,Ls]).
|
||||
|
||||
|
||||
save_state(X,Slope,_Grad) :-
|
||||
tunable_fact(FactID,_GroundTruth),
|
||||
set_tunable(FactID,Slope,X),
|
||||
fail.
|
||||
save_state(X, Slope, _) :-
|
||||
user:example(QueryID,_Query,_QueryProb),
|
||||
recorded(QueryID,BDD,_),
|
||||
BDD = bdd(_,_,MapList),
|
||||
bind_maplist(MapList, Slope, X),
|
||||
query_probabilities( BDD, BDDProb),
|
||||
assert( query_probability_intern(QueryID,BDDProb)),
|
||||
fail.
|
||||
save_state(X, Slope, _) :-
|
||||
user:test_example(QueryID,_Query,_QueryProb),
|
||||
recorded(QueryID,BDD,_),
|
||||
BDD = bdd(_,_,MapList),
|
||||
bind_maplist(MapList, Slope, X),
|
||||
query_probabilities( BDD, BDDProb),
|
||||
assert( query_probability_intern(QueryID,BDDProb)),
|
||||
fail.
|
||||
save_state(_X, _Slope, _).
|
||||
|
||||
%========================================================================
|
||||
%= initialize the logger module and set the flags for learning
|
||||
%= don't change anything here! use set_problog_flag/2 instead
|
||||
|
Reference in New Issue
Block a user