em/8 returns the list of updated probabilities of examples

This commit is contained in:
Fabrizio Riguzzi 2013-09-08 17:46:36 +02:00
parent cbf31fcd50
commit b00a5bf7fc
3 changed files with 22 additions and 8 deletions

View File

@ -242,7 +242,12 @@ static double Expectation(DdNode **nodes_ex,int lenNodes)
} }
else else
if (nodes_ex[i]==Cudd_ReadLogicZero(mgr_ex[i])) if (nodes_ex[i]==Cudd_ReadLogicZero(mgr_ex[i]))
{
CLL=CLL+LOGZERO*example_prob[i]; CLL=CLL+LOGZERO*example_prob[i];
nodes_probs_ex[i]=0.0;
}
else
nodes_probs_ex[i]=1.0;
} }
return CLL; return CLL;
} }
@ -266,7 +271,6 @@ static int end(void)
free(probs_ex); free(probs_ex);
free(nVars_ex); free(nVars_ex);
free(boolVars_ex); free(boolVars_ex);
free(nodes_probs_ex);
for (r=0;r<nRules;r++) for (r=0;r<nRules;r++)
{ {
for (i=0;i<rules[r]-1;i++) for (i=0;i<rules[r]-1;i++)
@ -944,8 +948,8 @@ static int randomize(void)
static int EM(void) static int EM(void)
{ {
YAP_Term arg1,arg2,arg3,arg4,arg5,arg6,arg7, YAP_Term arg1,arg2,arg3,arg4,arg5,arg6,arg7,arg8,
out1,out2,nodesTerm,ruleTerm,tail,pair,compoundTerm; out1,out2,out3,nodesTerm,ruleTerm,tail,pair,compoundTerm;
DdNode * node1,**nodes_ex; DdNode * node1,**nodes_ex;
int r,lenNodes,i,iter; int r,lenNodes,i,iter;
long iter1; long iter1;
@ -961,6 +965,7 @@ static int EM(void)
arg5=YAP_ARG5; arg5=YAP_ARG5;
arg6=YAP_ARG6; arg6=YAP_ARG6;
arg7=YAP_ARG7; arg7=YAP_ARG7;
arg8=YAP_ARG8;
nodesTerm=arg1; nodesTerm=arg1;
ea=YAP_FloatOfTerm(arg2); ea=YAP_FloatOfTerm(arg2);
@ -1022,11 +1027,18 @@ static int EM(void)
compoundTerm=YAP_MkPairTerm(ruleTerm,YAP_MkPairTerm(tail,YAP_TermNil())); compoundTerm=YAP_MkPairTerm(ruleTerm,YAP_MkPairTerm(tail,YAP_TermNil()));
out2=YAP_MkPairTerm(compoundTerm,out2); out2=YAP_MkPairTerm(compoundTerm,out2);
} }
out3= YAP_TermNil();
for (i=0;i<lenNodes;i++)
{
out3=YAP_MkPairTerm(YAP_MkFloatTerm(nodes_probs_ex[i]),out3);
}
YAP_Unify(out3,arg8);
out1=YAP_MkFloatTerm(CLL1); out1=YAP_MkFloatTerm(CLL1);
YAP_Unify(out1,arg6); YAP_Unify(out1,arg6);
free(nodes_ex); free(nodes_ex);
free(example_prob); free(example_prob);
free(nodes_probs_ex);
return (YAP_Unify(out2,arg7)); return (YAP_Unify(out2,arg7));
} }
@ -1144,6 +1156,8 @@ static int dag_size(void)
void init_my_predicates() void init_my_predicates()
/* function required by YAP for intitializing the predicates defined by a C function*/ /* function required by YAP for intitializing the predicates defined by a C function*/
{ {
srand(10);
YAP_UserCPredicate("init",init,2); YAP_UserCPredicate("init",init,2);
YAP_UserCPredicate("init_bdd",init_bdd,0); YAP_UserCPredicate("init_bdd",init_bdd,0);
YAP_UserCPredicate("end",end,0); YAP_UserCPredicate("end",end,0);
@ -1159,7 +1173,7 @@ void init_my_predicates()
YAP_UserCPredicate("init_test",init_test,1); YAP_UserCPredicate("init_test",init_test,1);
YAP_UserCPredicate("end_test",end_test,0); YAP_UserCPredicate("end_test",end_test,0);
YAP_UserCPredicate("ret_prob",ret_prob,2); YAP_UserCPredicate("ret_prob",ret_prob,2);
YAP_UserCPredicate("em",EM,7); YAP_UserCPredicate("em",EM,8);
YAP_UserCPredicate("q",Q,4); YAP_UserCPredicate("q",Q,4);
YAP_UserCPredicate("randomize",randomize,0); YAP_UserCPredicate("randomize",randomize,0);
YAP_UserCPredicate("deref",rec_deref,1); YAP_UserCPredicate("deref",rec_deref,1);

View File

@ -419,7 +419,7 @@ random_restarts(N,Nodes,CLL0,CLL,Par0,Par,LE):-
setting(epsilon_em_fraction,ER), setting(epsilon_em_fraction,ER),
length(Nodes,L), length(Nodes,L),
setting(iter,Iter), setting(iter,Iter),
em(Nodes,EA,ER,L,Iter,CLLR,Par1), em(Nodes,EA,ER,L,Iter,CLLR,Par1,_ExP),
setting(verbosity,Ver), setting(verbosity,Ver),
(Ver>2-> (Ver>2->
format("Random_restart: CLL ~f~n",[CLLR]) format("Random_restart: CLL ~f~n",[CLLR])
@ -450,7 +450,7 @@ random_restarts_ref(N,Nodes,CLL0,CLL,Par0,Par,LE):-
setting(epsilon_em_fraction,ER), setting(epsilon_em_fraction,ER),
length(Nodes,L), length(Nodes,L),
setting(iterREF,Iter), setting(iterREF,Iter),
em(Nodes,EA,ER,L,Iter,CLLR,Par1), em(Nodes,EA,ER,L,Iter,CLLR,Par1,_ExP),
setting(verbosity,Ver), setting(verbosity,Ver),
(Ver>2-> (Ver>2->
format("Random_restart: CLL ~f~n",[CLLR]) format("Random_restart: CLL ~f~n",[CLLR])

View File

@ -648,7 +648,7 @@ random_restarts(N,Nodes,CLL0,CLL,Par0,Par,LE):-
setting(epsilon_em_fraction,ER), setting(epsilon_em_fraction,ER),
length(Nodes,L), length(Nodes,L),
setting(iter,Iter), setting(iter,Iter),
em(Nodes,EA,ER,L,Iter,CLLR,Par1), em(Nodes,EA,ER,L,Iter,CLLR,Par1,_ExP),
setting(verbosity,Ver), setting(verbosity,Ver),
(Ver>2-> (Ver>2->
format("Random_restart: CLL ~f~n",[CLLR]) format("Random_restart: CLL ~f~n",[CLLR])
@ -679,7 +679,7 @@ random_restarts_ref(N,Nodes,CLL0,CLL,Par0,Par,LE):-
setting(epsilon_em_fraction,ER), setting(epsilon_em_fraction,ER),
length(Nodes,L), length(Nodes,L),
setting(iterREF,Iter), setting(iterREF,Iter),
em(Nodes,EA,ER,L,Iter,CLLR,Par1), em(Nodes,EA,ER,L,Iter,CLLR,Par1,_ExP),
setting(verbosity,Ver), setting(verbosity,Ver),
(Ver>2-> (Ver>2->
format("Random_restart: CLL ~f~n",[CLLR]) format("Random_restart: CLL ~f~n",[CLLR])