This commit is contained in:
Vítor Santos Costa 2019-04-05 08:55:12 +01:00
parent ead29987d6
commit d72770a18c
2 changed files with 19 additions and 17 deletions

View File

@ -825,11 +825,11 @@ user:evaluate(LLH_Training_Queries, X,Grad,N,_,_) :-
go( X,Grad, LLs), go( X,Grad, LLs),
Error, Error,
(writeln(Error), throw(Error) )), (writeln(Error), throw(Error) )),
length(LLs,NN), length(LLs,NN),
V <== array[NN] of LLs, V <== array[NN] of LLs,
LLH_Training_Queries <== sum(V), SLL <== sum(V),
writeln( LLH_Training_Queries). %sum_list( LLs, SLL),
% sum_list( LLs, LLH_Training_Queries), LLH_Training_Queries[0] <== SLL.
test :- test :-
S =.. [f,0-0.9,1-0.8,2-0.6,3-0.7,4-0.5,5-0.4,6-0.7,7-0.2], S =.. [f,0-0.9,1-0.8,2-0.6,3-0.7,4-0.5,5-0.4,6-0.7,7-0.2],
@ -843,7 +843,8 @@ Grad <== array[N] of floats,
LL, LL,
compute_gradient(Grad, X, Slope,LL), compute_gradient(Grad, X, Slope,LL),
LLs LLs
), sum_list( LLs, _LLH_Training_Queries). ), sum_list( LLs, SLL).
@ -862,7 +863,10 @@ compute_gradient( Grad, X, Slope, LL) :-
BDD = bdd(_,_,MapList), BDD = bdd(_,_,MapList),
MapList = [_|_], MapList = [_|_],
bind_maplist(MapList, Slope, X), bind_maplist(MapList, Slope, X),
%writeln(QueryID:MapList),
query_probabilities( BDD, BDDProb), query_probabilities( BDD, BDDProb),
(isnan(BDDProb) -> writeln((nan::QueryID)), fail;true),
writeln(BDDProb),
LL is (BDDProb-QueryProb)*(BDDProb-QueryProb), LL is (BDDProb-QueryProb)*(BDDProb-QueryProb),
forall( forall(
query_gradients(BDD,I,IProb,GradValue), query_gradients(BDD,I,IProb,GradValue),
@ -872,6 +876,7 @@ compute_gradient( Grad, X, Slope, LL) :-
gradient_pair(BDDProb, QueryProb, Grad, GradValue, I, Prob) :- gradient_pair(BDDProb, QueryProb, Grad, GradValue, I, Prob) :-
G0 <== Grad[I], G0 <== Grad[I],
GN is G0-GradValue*Prob*(1-Prob)*2*(QueryProb-BDDProb), GN is G0-GradValue*Prob*(1-Prob)*2*(QueryProb-BDDProb),
(isnan(GN) -> writeln((nan::I)), fail;true),
Grad[I] <== GN. Grad[I] <== GN.
wrap( X, Grad, GradCount) :- wrap( X, Grad, GradCount) :-

View File

@ -42,13 +42,14 @@ static lbfgsfloatval_t evaluate(void *instance, const lbfgsfloatval_t *x,
const lbfgsfloatval_t step) { const lbfgsfloatval_t step) {
YAP_Term call; YAP_Term call;
YAP_Bool result; YAP_Bool result;
lbfgsfloatval_t rc; lbfgsfloatval_t rc=0.0;
YAP_Term v, t1, t12; YAP_Term v, t1, t12;
YAP_Term t[6], t2[2]; YAP_Term t[6], t2[2];
t[0] = v = YAP_MkVarTerm(); YAP_Term t_0 = YAP_MkIntTerm((YAP_Int)&rc);
t1 = YAP_MkIntTerm((YAP_Int)x); t[0] = YAP_MkApplTerm(ffloats, 1, &t_0);
t[1] = YAP_MkApplTerm(ffloats, 1, &t1); YAP_Term t_1 = YAP_MkIntTerm((YAP_Int)x);
t[1] = YAP_MkApplTerm(ffloats, 1, &t_1);
t12 = YAP_MkIntTerm((YAP_Int)g_tmp); t12 = YAP_MkIntTerm((YAP_Int)g_tmp);
t[2] = YAP_MkApplTerm(ffloats, 1, &t12); t[2] = YAP_MkApplTerm(ffloats, 1, &t12);
t[3] = YAP_MkIntTerm(n); t[3] = YAP_MkIntTerm(n);
@ -70,13 +71,9 @@ static lbfgsfloatval_t evaluate(void *instance, const lbfgsfloatval_t *x,
// Goal did not succeed // Goal did not succeed
return FALSE; return FALSE;
} }
YAP_Term o; YAP_ShutdownGoal(true);
if (YAP_IsIntTerm((o = YAP_GetFromSlot(sl))))
rc = YAP_IntOfTerm(o);
else
rc = YAP_FloatOfTerm(o);
YAP_ShutdownGoal(true);
YAP_RecoverSlots(1, sl); YAP_RecoverSlots(1, sl);
fprintf(stderr,"%gxo\n",rc);
return rc; return rc;
} }
@ -124,7 +121,7 @@ static int progress(void *instance, const lbfgsfloatval_t *local_x,
if (YAP_IsIntTerm(o)) { if (YAP_IsIntTerm(o)) {
int v = YAP_IntOfTerm(o); int v = YAP_IntOfTerm(o);
YAP_ShutdownGoal(true); YAP_ShutdownGoal(true);
return (int)v; return (int)v;
} }
YAP_ShutdownGoal(true); YAP_ShutdownGoal(true);