Merge branch 'master' of /cygdrive/z/vitor/Yap/yap-6.3
This commit is contained in:
commit
4fe1833ece
@ -2081,8 +2081,13 @@ Yap_absmi(int inp)
|
|||||||
goto failloop;
|
goto failloop;
|
||||||
} else
|
} else
|
||||||
#endif /* FROZEN_STACKS */
|
#endif /* FROZEN_STACKS */
|
||||||
if (IN_BETWEEN(H0,pt1,H) && IsAttVar(pt1))
|
if (IN_BETWEEN(H0,pt1,H)) {
|
||||||
goto failloop;
|
if (IsAttVar(pt1)) {
|
||||||
|
goto failloop;
|
||||||
|
} else if (*pt1 == (CELL)FunctorBigInt) {
|
||||||
|
Yap_CleanOpaqueVariable(pt1);
|
||||||
|
}
|
||||||
|
}
|
||||||
#ifdef FROZEN_STACKS /* TRAIL */
|
#ifdef FROZEN_STACKS /* TRAIL */
|
||||||
/* don't reset frozen variables */
|
/* don't reset frozen variables */
|
||||||
if (pt0 < TR_FZ)
|
if (pt0 < TR_FZ)
|
||||||
|
14
C/amasm.c
14
C/amasm.c
@ -252,7 +252,7 @@ STATIC_PROTO(yamop *a_if, (op_numbers, union clause_obj *, int, yamop *, int, st
|
|||||||
STATIC_PROTO(yamop *a_cut, (clause_info *,yamop *, int, struct intermediates *));
|
STATIC_PROTO(yamop *a_cut, (clause_info *,yamop *, int, struct intermediates *));
|
||||||
#ifdef YAPOR
|
#ifdef YAPOR
|
||||||
STATIC_PROTO(yamop *a_try, (op_numbers, CELL, CELL, int, int, yamop *, int, struct intermediates *));
|
STATIC_PROTO(yamop *a_try, (op_numbers, CELL, CELL, int, int, yamop *, int, struct intermediates *));
|
||||||
STATIC_PROTO(yamop *a_either, (op_numbers, CELL, CELL, int, int, yamop *, int, struct intermediates *));
|
STATIC_PROTO(yamop *a_either, (op_numbers, CELL, CELL, int, yamop *, int, struct intermediates *));
|
||||||
#else
|
#else
|
||||||
STATIC_PROTO(yamop *a_try, (op_numbers, CELL, CELL, yamop *, int, struct intermediates *));
|
STATIC_PROTO(yamop *a_try, (op_numbers, CELL, CELL, yamop *, int, struct intermediates *));
|
||||||
STATIC_PROTO(yamop *a_either, (op_numbers, CELL, CELL, yamop *, int, struct intermediates *));
|
STATIC_PROTO(yamop *a_either, (op_numbers, CELL, CELL, yamop *, int, struct intermediates *));
|
||||||
@ -2155,7 +2155,7 @@ a_try(op_numbers opcode, CELL lab, CELL opr, int nofalts, int hascut, yamop *cod
|
|||||||
#endif
|
#endif
|
||||||
#ifdef YAPOR
|
#ifdef YAPOR
|
||||||
INIT_YAMOP_LTT(code_p, nofalts);
|
INIT_YAMOP_LTT(code_p, nofalts);
|
||||||
if (hascut)
|
if (cip->clause_has_cut)
|
||||||
PUT_YAMOP_CUT(code_p);
|
PUT_YAMOP_CUT(code_p);
|
||||||
if (ap->PredFlags & SequentialPredFlag)
|
if (ap->PredFlags & SequentialPredFlag)
|
||||||
PUT_YAMOP_SEQ(code_p);
|
PUT_YAMOP_SEQ(code_p);
|
||||||
@ -2167,7 +2167,7 @@ a_try(op_numbers opcode, CELL lab, CELL opr, int nofalts, int hascut, yamop *cod
|
|||||||
|
|
||||||
static yamop *
|
static yamop *
|
||||||
#ifdef YAPOR
|
#ifdef YAPOR
|
||||||
a_either(op_numbers opcode, CELL opr, CELL lab, int nofalts, int hascut, yamop *code_p, int pass_no, struct intermediates *cip)
|
a_either(op_numbers opcode, CELL opr, CELL lab, int nofalts, yamop *code_p, int pass_no, struct intermediates *cip)
|
||||||
#else
|
#else
|
||||||
a_either(op_numbers opcode, CELL opr, CELL lab, yamop *code_p, int pass_no, struct intermediates *cip)
|
a_either(op_numbers opcode, CELL opr, CELL lab, yamop *code_p, int pass_no, struct intermediates *cip)
|
||||||
#endif /* YAPOR */
|
#endif /* YAPOR */
|
||||||
@ -2179,7 +2179,7 @@ a_either(op_numbers opcode, CELL opr, CELL lab, yamop *code_p, int pass_no, stru
|
|||||||
code_p->u.Osblp.p0 = cip->CurrentPred;
|
code_p->u.Osblp.p0 = cip->CurrentPred;
|
||||||
#ifdef YAPOR
|
#ifdef YAPOR
|
||||||
INIT_YAMOP_LTT(code_p, nofalts);
|
INIT_YAMOP_LTT(code_p, nofalts);
|
||||||
if (hascut)
|
if (cip->clause_has_cut)
|
||||||
PUT_YAMOP_CUT(code_p);
|
PUT_YAMOP_CUT(code_p);
|
||||||
if (cip->CurrentPred->PredFlags & SequentialPredFlag)
|
if (cip->CurrentPred->PredFlags & SequentialPredFlag)
|
||||||
PUT_YAMOP_SEQ(code_p);
|
PUT_YAMOP_SEQ(code_p);
|
||||||
@ -3583,7 +3583,7 @@ do_pass(int pass_no, yamop **entry_codep, int assembling, int *clause_has_blobsp
|
|||||||
}
|
}
|
||||||
code_p = a_either(_either,
|
code_p = a_either(_either,
|
||||||
-Signed(RealEnvSize) - CELLSIZE * cip->cpc->rnd2,
|
-Signed(RealEnvSize) - CELLSIZE * cip->cpc->rnd2,
|
||||||
Unsigned(cip->code_addr) + cip->label_offset[cip->cpc->rnd1], 0, 0, code_p, pass_no, cip);
|
Unsigned(cip->code_addr) + cip->label_offset[cip->cpc->rnd1], 0, code_p, pass_no, cip);
|
||||||
#else
|
#else
|
||||||
code_p = a_either(_either,
|
code_p = a_either(_either,
|
||||||
-Signed(RealEnvSize) - CELLSIZE * cip->cpc->rnd2,
|
-Signed(RealEnvSize) - CELLSIZE * cip->cpc->rnd2,
|
||||||
@ -3596,7 +3596,7 @@ do_pass(int pass_no, yamop **entry_codep, int assembling, int *clause_has_blobsp
|
|||||||
either_inst[either_cont++] = code_p;
|
either_inst[either_cont++] = code_p;
|
||||||
code_p = a_either(_or_else,
|
code_p = a_either(_or_else,
|
||||||
-Signed(RealEnvSize) - CELLSIZE * cip->cpc->rnd2,
|
-Signed(RealEnvSize) - CELLSIZE * cip->cpc->rnd2,
|
||||||
Unsigned(cip->code_addr) + cip->label_offset[cip->cpc->rnd1], 0, 0, code_p, pass_no, cip);
|
Unsigned(cip->code_addr) + cip->label_offset[cip->cpc->rnd1], 0, code_p, pass_no, cip);
|
||||||
#else
|
#else
|
||||||
code_p = a_either(_or_else,
|
code_p = a_either(_or_else,
|
||||||
-Signed(RealEnvSize) - CELLSIZE * cip->cpc->rnd2,
|
-Signed(RealEnvSize) - CELLSIZE * cip->cpc->rnd2,
|
||||||
@ -3608,7 +3608,7 @@ do_pass(int pass_no, yamop **entry_codep, int assembling, int *clause_has_blobsp
|
|||||||
#ifdef YAPOR
|
#ifdef YAPOR
|
||||||
if (pass_no)
|
if (pass_no)
|
||||||
either_inst[either_cont++] = code_p;
|
either_inst[either_cont++] = code_p;
|
||||||
code_p = a_either(_or_last, 0, 0, 0, 0, code_p, pass_no, cip);
|
code_p = a_either(_or_last, 0, 0, 0, code_p, pass_no, cip);
|
||||||
if (pass_no) {
|
if (pass_no) {
|
||||||
int cont = 1;
|
int cont = 1;
|
||||||
do {
|
do {
|
||||||
|
57
C/bignum.c
57
C/bignum.c
@ -25,9 +25,10 @@ static char SccsId[] = "%W% %G%";
|
|||||||
#include <string.h>
|
#include <string.h>
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#include "YapHeap.h"
|
||||||
|
|
||||||
#ifdef USE_GMP
|
#ifdef USE_GMP
|
||||||
|
|
||||||
#include "YapHeap.h"
|
|
||||||
#include "eval.h"
|
#include "eval.h"
|
||||||
#include "alloc.h"
|
#include "alloc.h"
|
||||||
|
|
||||||
@ -59,6 +60,7 @@ Yap_MkBigIntTerm(MP_INT *big)
|
|||||||
return AbsAppl(ret);
|
return AbsAppl(ret);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
MP_INT *
|
MP_INT *
|
||||||
Yap_BigIntOfTerm(Term t)
|
Yap_BigIntOfTerm(Term t)
|
||||||
{
|
{
|
||||||
@ -127,9 +129,60 @@ Yap_RatTermToApplTerm(Term t)
|
|||||||
return Yap_MkApplTerm(FunctorRDiv,2,ts);
|
return Yap_MkApplTerm(FunctorRDiv,2,ts);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
Term
|
||||||
|
Yap_AllocExternalDataInStack(CELL tag, size_t bytes)
|
||||||
|
{
|
||||||
|
CACHE_REGS
|
||||||
|
Int nlimbs;
|
||||||
|
MP_INT *dst = (MP_INT *)(H+2);
|
||||||
|
CELL *ret = H;
|
||||||
|
|
||||||
|
nlimbs = ALIGN_YAPTYPE(bytes,CELL)/CellSize;
|
||||||
|
if (nlimbs > (ASP-ret)-1024) {
|
||||||
|
return TermNil;
|
||||||
|
}
|
||||||
|
H[0] = (CELL)FunctorBigInt;
|
||||||
|
H[1] = tag;
|
||||||
|
dst->_mp_size = 0;
|
||||||
|
dst->_mp_alloc = nlimbs;
|
||||||
|
H = (CELL *)(dst+1)+nlimbs;
|
||||||
|
H[0] = EndSpecials;
|
||||||
|
H++;
|
||||||
|
if (tag != EXTERNAL_BLOB) {
|
||||||
|
TrailTerm(TR) = AbsPair(ret);
|
||||||
|
TR++;
|
||||||
|
}
|
||||||
|
return AbsAppl(ret);
|
||||||
|
}
|
||||||
|
|
||||||
|
int Yap_CleanOpaqueVariable(CELL *pt)
|
||||||
|
{
|
||||||
|
CELL blob_info, blob_tag;
|
||||||
|
MP_INT *blobp;
|
||||||
|
#ifdef DEBUG
|
||||||
|
/* sanity checking */
|
||||||
|
if (pt[0] != (CELL)FunctorBigInt) {
|
||||||
|
Yap_Error(SYSTEM_ERROR, TermNil, "CleanOpaqueVariable bad call");
|
||||||
|
return FALSE;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
blob_tag = pt[1];
|
||||||
|
if (blob_tag < USER_BLOB_START ||
|
||||||
|
blob_tag >= USER_BLOB_END) {
|
||||||
|
Yap_Error(SYSTEM_ERROR, AbsAppl(pt), "clean opaque: bad blob with tag " UInt_FORMAT ,blob_tag);
|
||||||
|
return FALSE;
|
||||||
|
}
|
||||||
|
blob_info = blob_tag - USER_BLOB_START;
|
||||||
|
if (!GLOBAL_OpaqueHandlers)
|
||||||
|
return FALSE;
|
||||||
|
blobp = (MP_INT *)(pt+2);
|
||||||
|
if (!GLOBAL_OpaqueHandlers[blob_info].fail_handler)
|
||||||
|
return TRUE;
|
||||||
|
return (GLOBAL_OpaqueHandlers[blob_info].fail_handler)((void *)(blobp+1));
|
||||||
|
}
|
||||||
|
|
||||||
Term
|
Term
|
||||||
Yap_MkULLIntTerm(YAP_ULONG_LONG n)
|
Yap_MkULLIntTerm(YAP_ULONG_LONG n)
|
||||||
{
|
{
|
||||||
|
@ -390,6 +390,8 @@ X_API Bool STD_PROTO(YAP_IsDbRefTerm,(Term));
|
|||||||
X_API Bool STD_PROTO(YAP_IsAtomTerm,(Term));
|
X_API Bool STD_PROTO(YAP_IsAtomTerm,(Term));
|
||||||
X_API Bool STD_PROTO(YAP_IsPairTerm,(Term));
|
X_API Bool STD_PROTO(YAP_IsPairTerm,(Term));
|
||||||
X_API Bool STD_PROTO(YAP_IsApplTerm,(Term));
|
X_API Bool STD_PROTO(YAP_IsApplTerm,(Term));
|
||||||
|
X_API Bool STD_PROTO(YAP_IsExternalDataInStackTerm,(Term));
|
||||||
|
X_API Bool STD_PROTO(YAP_IsOpaqueObjectTerm,(Term, int));
|
||||||
X_API Term STD_PROTO(YAP_MkIntTerm,(Int));
|
X_API Term STD_PROTO(YAP_MkIntTerm,(Int));
|
||||||
X_API Term STD_PROTO(YAP_MkBigNumTerm,(void *));
|
X_API Term STD_PROTO(YAP_MkBigNumTerm,(void *));
|
||||||
X_API Term STD_PROTO(YAP_MkRationalTerm,(void *));
|
X_API Term STD_PROTO(YAP_MkRationalTerm,(void *));
|
||||||
@ -463,7 +465,7 @@ X_API IOSTREAM *STD_PROTO(YAP_TermToStream,(Term));
|
|||||||
X_API IOSTREAM *STD_PROTO(YAP_InitConsult,(int, char *));
|
X_API IOSTREAM *STD_PROTO(YAP_InitConsult,(int, char *));
|
||||||
X_API void STD_PROTO(YAP_EndConsult,(IOSTREAM *));
|
X_API void STD_PROTO(YAP_EndConsult,(IOSTREAM *));
|
||||||
X_API Term STD_PROTO(YAP_Read, (IOSTREAM *));
|
X_API Term STD_PROTO(YAP_Read, (IOSTREAM *));
|
||||||
X_API void STD_PROTO(YAP_Write, (Term, int (*)(wchar_t), int));
|
X_API void STD_PROTO(YAP_Write, (Term, IOSTREAM *, int));
|
||||||
X_API Term STD_PROTO(YAP_CopyTerm, (Term));
|
X_API Term STD_PROTO(YAP_CopyTerm, (Term));
|
||||||
X_API Term STD_PROTO(YAP_WriteBuffer, (Term, char *, unsigned int, int));
|
X_API Term STD_PROTO(YAP_WriteBuffer, (Term, char *, unsigned int, int));
|
||||||
X_API char *STD_PROTO(YAP_CompileClause, (Term));
|
X_API char *STD_PROTO(YAP_CompileClause, (Term));
|
||||||
@ -536,13 +538,11 @@ X_API Term STD_PROTO(YAP_ModuleUser,(void));
|
|||||||
X_API Int STD_PROTO(YAP_NumberOfClausesForPredicate,(PredEntry *));
|
X_API Int STD_PROTO(YAP_NumberOfClausesForPredicate,(PredEntry *));
|
||||||
X_API int STD_PROTO(YAP_MaxOpPriority,(Atom, Term));
|
X_API int STD_PROTO(YAP_MaxOpPriority,(Atom, Term));
|
||||||
X_API int STD_PROTO(YAP_OpInfo,(Atom, Term, int, int *, int *));
|
X_API int STD_PROTO(YAP_OpInfo,(Atom, Term, int, int *, int *));
|
||||||
|
X_API Term STD_PROTO(YAP_AllocExternalDataInStack,(size_t));
|
||||||
static int (*do_putcf)(wchar_t);
|
X_API void *STD_PROTO(YAP_ExternalDataInStackFromTerm,(Term));
|
||||||
|
X_API int STD_PROTO(YAP_NewOpaqueType,(void *));
|
||||||
static int do_yap_putc(int streamno,wchar_t ch) {
|
X_API Term STD_PROTO(YAP_NewOpaqueObject,(int, size_t));
|
||||||
do_putcf(ch);
|
X_API void *STD_PROTO(YAP_OpaqueObjectFromTerm,(Term));
|
||||||
return(ch);
|
|
||||||
}
|
|
||||||
|
|
||||||
static int
|
static int
|
||||||
dogc(void)
|
dogc(void)
|
||||||
@ -1677,7 +1677,6 @@ YAP_ExecuteOnCut(PredEntry *pe, CPredicate exec_code, struct cut_c_str *top)
|
|||||||
if (pe->PredFlags & CArgsPredFlag) {
|
if (pe->PredFlags & CArgsPredFlag) {
|
||||||
val = execute_cargs_back(pe, exec_code, ctx PASS_REGS);
|
val = execute_cargs_back(pe, exec_code, ctx PASS_REGS);
|
||||||
} else {
|
} else {
|
||||||
fprintf(stderr,"ctx=%p\n",ctx);
|
|
||||||
val = ((codev)(args-LCL0,0,ctx));
|
val = ((codev)(args-LCL0,0,ctx));
|
||||||
}
|
}
|
||||||
/* make sure we clean up the frames left by the user */
|
/* make sure we clean up the frames left by the user */
|
||||||
@ -2314,6 +2313,65 @@ YAP_RunGoal(Term t)
|
|||||||
return(out);
|
return(out);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
X_API Term
|
||||||
|
YAP_AllocExternalDataInStack(size_t bytes)
|
||||||
|
{
|
||||||
|
Term t = Yap_AllocExternalDataInStack(EXTERNAL_BLOB, bytes);
|
||||||
|
if (t == TermNil)
|
||||||
|
return 0L;
|
||||||
|
return t;
|
||||||
|
}
|
||||||
|
|
||||||
|
X_API Bool
|
||||||
|
YAP_IsExternalDataInStackTerm(Term t)
|
||||||
|
{
|
||||||
|
return IsExternalBlobTerm(t, EXTERNAL_BLOB);
|
||||||
|
}
|
||||||
|
|
||||||
|
X_API void *
|
||||||
|
YAP_ExternalDataInStackFromTerm(Term t)
|
||||||
|
{
|
||||||
|
return ExternalBlobFromTerm (t);
|
||||||
|
}
|
||||||
|
|
||||||
|
int YAP_NewOpaqueType(void *f)
|
||||||
|
{
|
||||||
|
int i;
|
||||||
|
if (!GLOBAL_OpaqueHandlers) {
|
||||||
|
GLOBAL_OpaqueHandlers = malloc(sizeof(opaque_handler_t)*(USER_BLOB_END-USER_BLOB_START));
|
||||||
|
if (!GLOBAL_OpaqueHandlers) {
|
||||||
|
/* no room */
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
} else if (GLOBAL_OpaqueHandlersCount == USER_BLOB_END-USER_BLOB_START) {
|
||||||
|
/* all types used */
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
i = GLOBAL_OpaqueHandlersCount++;
|
||||||
|
memcpy(GLOBAL_OpaqueHandlers+i,f,sizeof(opaque_handler_t));
|
||||||
|
return i+USER_BLOB_START;
|
||||||
|
}
|
||||||
|
|
||||||
|
Term YAP_NewOpaqueObject(int tag, size_t bytes)
|
||||||
|
{
|
||||||
|
Term t = Yap_AllocExternalDataInStack((CELL)tag, bytes);
|
||||||
|
if (t == TermNil)
|
||||||
|
return 0L;
|
||||||
|
return t;
|
||||||
|
}
|
||||||
|
|
||||||
|
X_API Bool
|
||||||
|
YAP_IsOpaqueObjectTerm(Term t, int tag)
|
||||||
|
{
|
||||||
|
return IsExternalBlobTerm(t, (CELL)tag);
|
||||||
|
}
|
||||||
|
|
||||||
|
X_API void *
|
||||||
|
YAP_OpaqueObjectFromTerm(Term t)
|
||||||
|
{
|
||||||
|
return ExternalBlobFromTerm (t);
|
||||||
|
}
|
||||||
|
|
||||||
X_API Term
|
X_API Term
|
||||||
YAP_RunGoalOnce(Term t)
|
YAP_RunGoalOnce(Term t)
|
||||||
{
|
{
|
||||||
@ -2369,7 +2427,6 @@ YAP_RestartGoal(void)
|
|||||||
BACKUP_MACHINE_REGS();
|
BACKUP_MACHINE_REGS();
|
||||||
if (LOCAL_AllowRestart) {
|
if (LOCAL_AllowRestart) {
|
||||||
P = (yamop *)FAILCODE;
|
P = (yamop *)FAILCODE;
|
||||||
do_putcf = myputc;
|
|
||||||
LOCAL_PrologMode = UserMode;
|
LOCAL_PrologMode = UserMode;
|
||||||
out = Yap_exec_absmi(TRUE);
|
out = Yap_exec_absmi(TRUE);
|
||||||
LOCAL_PrologMode = UserCCallMode;
|
LOCAL_PrologMode = UserCCallMode;
|
||||||
@ -2589,12 +2646,11 @@ YAP_Read(IOSTREAM *inp)
|
|||||||
}
|
}
|
||||||
|
|
||||||
X_API void
|
X_API void
|
||||||
YAP_Write(Term t, int (*myputc)(wchar_t), int flags)
|
YAP_Write(Term t, IOSTREAM *stream, int flags)
|
||||||
{
|
{
|
||||||
BACKUP_MACHINE_REGS();
|
BACKUP_MACHINE_REGS();
|
||||||
|
|
||||||
do_putcf = myputc; /* */
|
Yap_dowrite (t, stream, flags, 1200);
|
||||||
Yap_plwrite (t, do_yap_putc, flags, 1200);
|
|
||||||
|
|
||||||
RECOVER_MACHINE_REGS();
|
RECOVER_MACHINE_REGS();
|
||||||
}
|
}
|
||||||
@ -2774,6 +2830,7 @@ YAP_Init(YAP_init_args *yap_init)
|
|||||||
Yap_init_yapor_global_local_memory();
|
Yap_init_yapor_global_local_memory();
|
||||||
LOCAL = REMOTE(0);
|
LOCAL = REMOTE(0);
|
||||||
#endif /* YAPOR_COPY || YAPOR_COW || YAPOR_SBA */
|
#endif /* YAPOR_COPY || YAPOR_COW || YAPOR_SBA */
|
||||||
|
GLOBAL_PrologShouldHandleInterrupts = yap_init->PrologShouldHandleInterrupts;
|
||||||
Yap_InitSysbits(); /* init signal handling and time, required by later functions */
|
Yap_InitSysbits(); /* init signal handling and time, required by later functions */
|
||||||
GLOBAL_argv = yap_init->Argv;
|
GLOBAL_argv = yap_init->Argv;
|
||||||
GLOBAL_argc = yap_init->Argc;
|
GLOBAL_argc = yap_init->Argc;
|
||||||
@ -2821,7 +2878,6 @@ YAP_Init(YAP_init_args *yap_init)
|
|||||||
} else {
|
} else {
|
||||||
Heap = yap_init->HeapSize;
|
Heap = yap_init->HeapSize;
|
||||||
}
|
}
|
||||||
GLOBAL_PrologShouldHandleInterrupts = yap_init->PrologShouldHandleInterrupts;
|
|
||||||
Yap_InitWorkspace(Heap, Stack, Trail, Atts,
|
Yap_InitWorkspace(Heap, Stack, Trail, Atts,
|
||||||
yap_init->MaxTableSpaceSize,
|
yap_init->MaxTableSpaceSize,
|
||||||
yap_init->NumberWorkers,
|
yap_init->NumberWorkers,
|
||||||
|
@ -735,7 +735,7 @@ c_arg(Int argno, Term t, unsigned int arity, unsigned int level, compiler_struct
|
|||||||
} else if (IsPairTerm(t)) {
|
} else if (IsPairTerm(t)) {
|
||||||
cglobs->space_used += 2;
|
cglobs->space_used += 2;
|
||||||
if (optimizer_on && level < 6) {
|
if (optimizer_on && level < 6) {
|
||||||
#if !defined(THREADS)
|
#if !defined(THREADS) && !defined(YAPOR)
|
||||||
/* discard code sharing because we cannot write on shared stuff */
|
/* discard code sharing because we cannot write on shared stuff */
|
||||||
if (!(cglobs->cint.CurrentPred->PredFlags & (DynamicPredFlag|LogUpdatePredFlag))) {
|
if (!(cglobs->cint.CurrentPred->PredFlags & (DynamicPredFlag|LogUpdatePredFlag))) {
|
||||||
if (try_store_as_dbterm(t, argno, arity, level, cglobs))
|
if (try_store_as_dbterm(t, argno, arity, level, cglobs))
|
||||||
|
4
C/exec.c
4
C/exec.c
@ -961,7 +961,7 @@ exec_absmi(int top USES_REGS)
|
|||||||
restore_H();
|
restore_H();
|
||||||
/* set stack */
|
/* set stack */
|
||||||
ASP = (CELL *)PROTECT_FROZEN_B(B);
|
ASP = (CELL *)PROTECT_FROZEN_B(B);
|
||||||
Yap_StartSlots( PASS_REGS1 );
|
Yap_PopSlots();
|
||||||
LOCK(LOCAL_SignalLock);
|
LOCK(LOCAL_SignalLock);
|
||||||
/* forget any signals active, we're reborne */
|
/* forget any signals active, we're reborne */
|
||||||
LOCAL_ActiveSignals = 0;
|
LOCAL_ActiveSignals = 0;
|
||||||
@ -991,9 +991,9 @@ exec_absmi(int top USES_REGS)
|
|||||||
LOCAL_PrologMode = UserMode;
|
LOCAL_PrologMode = UserMode;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
Yap_CloseSlots( PASS_REGS1 );
|
||||||
LOCAL_PrologMode = UserMode;
|
LOCAL_PrologMode = UserMode;
|
||||||
}
|
}
|
||||||
Yap_CloseSlots( PASS_REGS1 );
|
|
||||||
YENV = ASP;
|
YENV = ASP;
|
||||||
YENV[E_CB] = Unsigned (B);
|
YENV[E_CB] = Unsigned (B);
|
||||||
out = Yap_absmi(0);
|
out = Yap_absmi(0);
|
||||||
|
4
C/grow.c
4
C/grow.c
@ -399,7 +399,7 @@ AdjustTrail(int adjusting_heap, int thread_copying USES_REGS)
|
|||||||
#if defined(YAPOR_THREADS)
|
#if defined(YAPOR_THREADS)
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
/* moving the trail is simple */
|
/* moving the trail is simple, yeaahhh! */
|
||||||
while (ptt != tr_base) {
|
while (ptt != tr_base) {
|
||||||
register CELL reg = TrailTerm(ptt-1);
|
register CELL reg = TrailTerm(ptt-1);
|
||||||
#ifdef FROZEN_STACKS
|
#ifdef FROZEN_STACKS
|
||||||
@ -420,8 +420,6 @@ AdjustTrail(int adjusting_heap, int thread_copying USES_REGS)
|
|||||||
} else if (IsPairTerm(reg)) {
|
} else if (IsPairTerm(reg)) {
|
||||||
TrailTerm(ptt) = AdjustPair(reg PASS_REGS);
|
TrailTerm(ptt) = AdjustPair(reg PASS_REGS);
|
||||||
#ifdef MULTI_ASSIGNMENT_VARIABLES /* does not work with new structures */
|
#ifdef MULTI_ASSIGNMENT_VARIABLES /* does not work with new structures */
|
||||||
/* check it whether we are protecting a
|
|
||||||
multi-assignment */
|
|
||||||
} else if (IsApplTerm(reg)) {
|
} else if (IsApplTerm(reg)) {
|
||||||
TrailTerm(ptt) = AdjustAppl(reg PASS_REGS);
|
TrailTerm(ptt) = AdjustAppl(reg PASS_REGS);
|
||||||
#endif
|
#endif
|
||||||
|
49
C/heapgc.c
49
C/heapgc.c
@ -1325,7 +1325,7 @@ mark_variable(CELL_PTR current USES_REGS)
|
|||||||
sz++;
|
sz++;
|
||||||
#if DEBUG
|
#if DEBUG
|
||||||
if (next[sz] != EndSpecials) {
|
if (next[sz] != EndSpecials) {
|
||||||
fprintf(stderr,"[ Error: could not find EndSpecials at blob %p type %lx ]\n", next, next[1]);
|
fprintf(stderr,"[ Error: could not find EndSpecials at blob %p type " UInt_FORMAT " ]\n", next, next[1]);
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
MARK(next+sz);
|
MARK(next+sz);
|
||||||
@ -1658,14 +1658,23 @@ mark_trail(tr_fr_ptr trail_ptr, tr_fr_ptr trail_base, CELL *gc_H, choiceptr gc_B
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
} else if (IsPairTerm(trail_cell)) {
|
} else if (IsPairTerm(trail_cell)) {
|
||||||
/* can safely ignore this */
|
/* cannot safely ignore this */
|
||||||
CELL *cptr = RepPair(trail_cell);
|
CELL *cptr = RepPair(trail_cell);
|
||||||
if (IN_BETWEEN(LOCAL_GlobalBase,cptr,H) &&
|
if (IN_BETWEEN(LOCAL_GlobalBase,cptr,H)) {
|
||||||
GlobalIsAttVar(cptr)) {
|
if (GlobalIsAttVar(cptr)) {
|
||||||
TrailTerm(trail_base) = (CELL)cptr;
|
TrailTerm(trail_base) = (CELL)cptr;
|
||||||
mark_external_reference(&TrailTerm(trail_base) PASS_REGS);
|
mark_external_reference(&TrailTerm(trail_base) PASS_REGS);
|
||||||
TrailTerm(trail_base) = trail_cell;
|
TrailTerm(trail_base) = trail_cell;
|
||||||
}
|
} else if (*cptr == (CELL)FunctorBigInt) {
|
||||||
|
TrailTerm(trail_base) = AbsAppl(cptr);
|
||||||
|
mark_external_reference(&TrailTerm(trail_base) PASS_REGS);
|
||||||
|
TrailTerm(trail_base) = trail_cell;
|
||||||
|
}
|
||||||
|
#ifdef DEBUG
|
||||||
|
else
|
||||||
|
fprintf(GLOBAL_stderr,"OOPS in GC: weird trail entry at %p:" UInt_FORMAT "\n", &TrailTerm(trail_base), (CELL)cptr);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
}
|
}
|
||||||
#if MULTI_ASSIGNMENT_VARIABLES
|
#if MULTI_ASSIGNMENT_VARIABLES
|
||||||
else {
|
else {
|
||||||
@ -1904,11 +1913,11 @@ mark_choicepoints(register choiceptr gc_B, tr_fr_ptr saved_TR, int very_verbose
|
|||||||
PredEntry *pe = Yap_PredForChoicePt(gc_B);
|
PredEntry *pe = Yap_PredForChoicePt(gc_B);
|
||||||
#if defined(ANALYST) || defined(DEBUG)
|
#if defined(ANALYST) || defined(DEBUG)
|
||||||
if (pe == NULL) {
|
if (pe == NULL) {
|
||||||
fprintf(GLOBAL_stderr,"%% marked %ld (%s)\n", LOCAL_total_marked, Yap_op_names[opnum]);
|
fprintf(GLOBAL_stderr,"%% marked " UInt_FORMAT " (%s)\n", LOCAL_total_marked, Yap_op_names[opnum]);
|
||||||
} else if (pe->ArityOfPE) {
|
} else if (pe->ArityOfPE) {
|
||||||
fprintf(GLOBAL_stderr,"%% %s/%d marked %ld (%s)\n", RepAtom(NameOfFunctor(pe->FunctorOfPred))->StrOfAE, pe->ArityOfPE, LOCAL_total_marked, Yap_op_names[opnum]);
|
fprintf(GLOBAL_stderr,"%% %s/%d marked " UInt_FORMAT " (%s)\n", RepAtom(NameOfFunctor(pe->FunctorOfPred))->StrOfAE, pe->ArityOfPE, LOCAL_total_marked, Yap_op_names[opnum]);
|
||||||
} else {
|
} else {
|
||||||
fprintf(GLOBAL_stderr,"%% %s marked %ld (%s)\n", RepAtom((Atom)(pe->FunctorOfPred))->StrOfAE, LOCAL_total_marked, Yap_op_names[opnum]);
|
fprintf(GLOBAL_stderr,"%% %s marked " UInt_FORMAT " (%s)\n", RepAtom((Atom)(pe->FunctorOfPred))->StrOfAE, LOCAL_total_marked, Yap_op_names[opnum]);
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
if (pe == NULL) {
|
if (pe == NULL) {
|
||||||
@ -2450,11 +2459,19 @@ sweep_trail(choiceptr gc_B, tr_fr_ptr old_TR USES_REGS)
|
|||||||
CELL *pt0 = RepPair(trail_cell);
|
CELL *pt0 = RepPair(trail_cell);
|
||||||
CELL flags;
|
CELL flags;
|
||||||
|
|
||||||
if (IN_BETWEEN(LOCAL_GlobalBase, pt0, H) && GlobalIsAttVar(pt0)) {
|
if (IN_BETWEEN(LOCAL_GlobalBase, pt0, H)) {
|
||||||
TrailTerm(dest) = trail_cell;
|
if (GlobalIsAttVar(pt0)) {
|
||||||
/* be careful with partial gc */
|
TrailTerm(dest) = trail_cell;
|
||||||
if (HEAP_PTR(TrailTerm(dest))) {
|
/* be careful with partial gc */
|
||||||
into_relocation_chain(&TrailTerm(dest), GET_NEXT(trail_cell) PASS_REGS);
|
if (HEAP_PTR(TrailTerm(dest))) {
|
||||||
|
into_relocation_chain(&TrailTerm(dest), GET_NEXT(trail_cell) PASS_REGS);
|
||||||
|
}
|
||||||
|
} else if (*pt0 == (CELL)FunctorBigInt) {
|
||||||
|
TrailTerm(dest) = trail_cell;
|
||||||
|
/* be careful with partial gc */
|
||||||
|
if (HEAP_PTR(TrailTerm(dest))) {
|
||||||
|
into_relocation_chain(&TrailTerm(dest), GET_NEXT(trail_cell) PASS_REGS);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
dest++;
|
dest++;
|
||||||
trail_ptr++;
|
trail_ptr++;
|
||||||
|
23
H/TermExt.h
23
H/TermExt.h
@ -86,7 +86,9 @@ typedef enum
|
|||||||
CLAUSE_LIST = 0x40,
|
CLAUSE_LIST = 0x40,
|
||||||
BLOB_STRING = 0x80, /* SWI style strings */
|
BLOB_STRING = 0x80, /* SWI style strings */
|
||||||
BLOB_WIDE_STRING = 0x81, /* SWI style strings */
|
BLOB_WIDE_STRING = 0x81, /* SWI style strings */
|
||||||
EXTERNAL_BLOB = 0x100 /* for SWI emulation */
|
EXTERNAL_BLOB = 0x100, /* generic data */
|
||||||
|
USER_BLOB_START = 0x1000, /* user defined blob */
|
||||||
|
USER_BLOB_END = 0x1100 /* end of user defined blob */
|
||||||
}
|
}
|
||||||
big_blob_type;
|
big_blob_type;
|
||||||
|
|
||||||
@ -438,6 +440,25 @@ IsLargeNumTerm (Term t)
|
|||||||
&& (FunctorOfTerm (t) >= FunctorLongInt)));
|
&& (FunctorOfTerm (t) >= FunctorLongInt)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline EXTERN int IsExternalBlobTerm (Term, CELL);
|
||||||
|
|
||||||
|
inline EXTERN int
|
||||||
|
IsExternalBlobTerm (Term t, CELL tag)
|
||||||
|
{
|
||||||
|
return (int) (IsApplTerm (t) &&
|
||||||
|
FunctorOfTerm (t) == FunctorBigInt &&
|
||||||
|
RepAppl(t)[1] == tag);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline EXTERN void *ExternalBlobFromTerm (Term);
|
||||||
|
|
||||||
|
inline EXTERN void *
|
||||||
|
ExternalBlobFromTerm (Term t)
|
||||||
|
{
|
||||||
|
MP_INT *base = (MP_INT *)(RepAppl(t)+2);
|
||||||
|
return (void *) (base+1);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -33,6 +33,12 @@ typedef int (*SWI_PLGetStreamPositionFunction)(void *);
|
|||||||
|
|
||||||
#include "../include/dswiatoms.h"
|
#include "../include/dswiatoms.h"
|
||||||
|
|
||||||
|
typedef int (*Opaque_CallOnFail)(void *);
|
||||||
|
|
||||||
|
typedef struct opaque_handler_struct {
|
||||||
|
Opaque_CallOnFail fail_handler;
|
||||||
|
} opaque_handler_t;
|
||||||
|
|
||||||
#ifndef INT_KEYS_DEFAULT_SIZE
|
#ifndef INT_KEYS_DEFAULT_SIZE
|
||||||
#define INT_KEYS_DEFAULT_SIZE 256
|
#define INT_KEYS_DEFAULT_SIZE 256
|
||||||
#endif
|
#endif
|
||||||
|
@ -120,6 +120,8 @@ int STD_PROTO(Yap_IsStringTerm, (Term));
|
|||||||
int STD_PROTO(Yap_IsWideStringTerm, (Term));
|
int STD_PROTO(Yap_IsWideStringTerm, (Term));
|
||||||
Term STD_PROTO(Yap_RatTermToApplTerm, (Term));
|
Term STD_PROTO(Yap_RatTermToApplTerm, (Term));
|
||||||
void STD_PROTO(Yap_InitBigNums, (void));
|
void STD_PROTO(Yap_InitBigNums, (void));
|
||||||
|
Term STD_PROTO(Yap_AllocExternalDataInStack, (CELL, size_t));
|
||||||
|
int STD_PROTO(Yap_CleanOpaqueVariable, (CELL *));
|
||||||
|
|
||||||
/* c_interface.c */
|
/* c_interface.c */
|
||||||
Int STD_PROTO(YAP_Execute,(struct pred_entry *, CPredicate));
|
Int STD_PROTO(YAP_Execute,(struct pred_entry *, CPredicate));
|
||||||
|
@ -102,6 +102,8 @@
|
|||||||
|
|
||||||
#define GLOBAL_Executable Yap_global->Executable_
|
#define GLOBAL_Executable Yap_global->Executable_
|
||||||
#endif
|
#endif
|
||||||
|
#define GLOBAL_OpaqueHandlersCount Yap_global->OpaqueHandlersCount_
|
||||||
|
#define GLOBAL_OpaqueHandlers Yap_global->OpaqueHandlers_
|
||||||
#if __simplescalar__
|
#if __simplescalar__
|
||||||
#define GLOBAL_pwd Yap_global->pwd_
|
#define GLOBAL_pwd Yap_global->pwd_
|
||||||
#endif
|
#endif
|
||||||
|
@ -102,6 +102,8 @@ typedef struct global_data {
|
|||||||
|
|
||||||
char Executable_[YAP_FILENAME_MAX];
|
char Executable_[YAP_FILENAME_MAX];
|
||||||
#endif
|
#endif
|
||||||
|
int OpaqueHandlersCount_;
|
||||||
|
struct opaque_handler_struct* OpaqueHandlers_;
|
||||||
#if __simplescalar__
|
#if __simplescalar__
|
||||||
char pwd_[YAP_FILENAME_MAX];
|
char pwd_[YAP_FILENAME_MAX];
|
||||||
#endif
|
#endif
|
||||||
|
@ -102,6 +102,8 @@ static void InitGlobal(void) {
|
|||||||
|
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
GLOBAL_OpaqueHandlersCount = 0;
|
||||||
|
GLOBAL_OpaqueHandlers = NULL;
|
||||||
#if __simplescalar__
|
#if __simplescalar__
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
@ -102,6 +102,8 @@ static void RestoreGlobal(void) {
|
|||||||
|
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
|
||||||
#if __simplescalar__
|
#if __simplescalar__
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
@ -67,6 +67,7 @@
|
|||||||
#define YP_FILE FILE
|
#define YP_FILE FILE
|
||||||
|
|
||||||
int STD_PROTO(YP_putc,(int, int));
|
int STD_PROTO(YP_putc,(int, int));
|
||||||
|
void STD_PROTO(Yap_dowrite, (Term, IOSTREAM *, int, int));
|
||||||
|
|
||||||
#else
|
#else
|
||||||
|
|
||||||
@ -96,8 +97,6 @@ int STD_PROTO(YP_putc,(int, int));
|
|||||||
#define fclose ERR_fclose
|
#define fclose ERR_fclose
|
||||||
#define fflush ERR_fflush
|
#define fflush ERR_fflush
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
/* flags for files in IOSTREAM struct */
|
/* flags for files in IOSTREAM struct */
|
||||||
#define _YP_IO_WRITE 1
|
#define _YP_IO_WRITE 1
|
||||||
#define _YP_IO_READ 2
|
#define _YP_IO_READ 2
|
||||||
|
@ -154,18 +154,6 @@ static int SQLBINDCOL(SQLHSTMT sthandle,
|
|||||||
return TRUE;
|
return TRUE;
|
||||||
}
|
}
|
||||||
|
|
||||||
static int SQLFREESTMT(SQLHSTMT sthandle,
|
|
||||||
SQLUSMALLINT opt,
|
|
||||||
char * print)
|
|
||||||
{
|
|
||||||
SQLRETURN retcode;
|
|
||||||
retcode = SQLFreeStmt(sthandle,opt);
|
|
||||||
|
|
||||||
if (retcode != SQL_SUCCESS && retcode != SQL_SUCCESS_WITH_INFO)
|
|
||||||
return odbc_error(SQL_HANDLE_STMT, sthandle, "SQLFreeStmt", print);
|
|
||||||
return TRUE;
|
|
||||||
}
|
|
||||||
|
|
||||||
static int SQLNUMRESULTCOLS(SQLHSTMT sthandle,
|
static int SQLNUMRESULTCOLS(SQLHSTMT sthandle,
|
||||||
SQLSMALLINT * ncols,
|
SQLSMALLINT * ncols,
|
||||||
char * print)
|
char * print)
|
||||||
@ -421,8 +409,9 @@ c_db_odbc_query( USES_REGS1 ) {
|
|||||||
/* +1 because of '\0' */
|
/* +1 because of '\0' */
|
||||||
bind_space = malloc(sizeof(char)*(ColumnSizePtr+1));
|
bind_space = malloc(sizeof(char)*(ColumnSizePtr+1));
|
||||||
data_info = malloc(sizeof(SQLINTEGER));
|
data_info = malloc(sizeof(SQLINTEGER));
|
||||||
if (!SQLBINDCOL(hstmt,i,SQL_C_CHAR,bind_space,(ColumnSizePtr+1),data_info,"db_query"))
|
if (!SQLBINDCOL(hstmt,i,SQL_C_CHAR,bind_space,(ColumnSizePtr+1),data_info,"db_query")) {
|
||||||
return FALSE;
|
return FALSE;
|
||||||
|
}
|
||||||
|
|
||||||
properties[0] = MkIntegerTerm((Int)bind_space);
|
properties[0] = MkIntegerTerm((Int)bind_space);
|
||||||
properties[2] = MkIntegerTerm((Int)data_info);
|
properties[2] = MkIntegerTerm((Int)data_info);
|
||||||
@ -444,7 +433,7 @@ c_db_odbc_query( USES_REGS1 ) {
|
|||||||
{
|
{
|
||||||
if (!SQLCLOSECURSOR(hstmt,"db_query"))
|
if (!SQLCLOSECURSOR(hstmt,"db_query"))
|
||||||
return FALSE;
|
return FALSE;
|
||||||
if (!SQLFREESTMT(hstmt,SQL_CLOSE,"db_query"))
|
if (!SQLFREEHANDLE(SQL_HANDLE_STMT, hstmt, "db_query"))
|
||||||
return FALSE;
|
return FALSE;
|
||||||
return FALSE;
|
return FALSE;
|
||||||
}
|
}
|
||||||
@ -466,7 +455,7 @@ c_db_odbc_number_of_fields( USES_REGS1 ) {
|
|||||||
char sql[256];
|
char sql[256];
|
||||||
SQLSMALLINT number_fields;
|
SQLSMALLINT number_fields;
|
||||||
|
|
||||||
sprintf(sql,"DESCRIBE %s",relation);
|
sprintf(sql,"SELECT column_name from INFORMATION_SCHEMA.COLUMNS where table_name = \'%s\'",relation);
|
||||||
|
|
||||||
if (!SQLALLOCHANDLE(SQL_HANDLE_STMT, hdbc, &hstmt, "db_number_of_fields"))
|
if (!SQLALLOCHANDLE(SQL_HANDLE_STMT, hdbc, &hstmt, "db_number_of_fields"))
|
||||||
return FALSE;
|
return FALSE;
|
||||||
@ -482,7 +471,7 @@ c_db_odbc_number_of_fields( USES_REGS1 ) {
|
|||||||
|
|
||||||
if (!SQLCLOSECURSOR(hstmt,"db_number_of_fields"))
|
if (!SQLCLOSECURSOR(hstmt,"db_number_of_fields"))
|
||||||
return FALSE;
|
return FALSE;
|
||||||
if (!SQLFREESTMT(hstmt,SQL_CLOSE,"db_number_of_fields"))
|
if (!SQLFREEHANDLE(SQL_HANDLE_STMT, hstmt, "db_number_of_fields"))
|
||||||
return FALSE;
|
return FALSE;
|
||||||
|
|
||||||
if (!Yap_unify(arg_fields, MkIntegerTerm(number_fields)))
|
if (!Yap_unify(arg_fields, MkIntegerTerm(number_fields)))
|
||||||
@ -506,7 +495,7 @@ c_db_odbc_get_attributes_types( USES_REGS1 ) {
|
|||||||
Term head, list;
|
Term head, list;
|
||||||
list = arg_types_list;
|
list = arg_types_list;
|
||||||
|
|
||||||
sprintf(sql,"DESCRIBE %s",relation);
|
sprintf(sql,"SELECT column_name,data_type FROM INFORMATION_SCHEMA.COLUMNS where table_name = \'%s\'",relation);
|
||||||
|
|
||||||
if (!SQLALLOCHANDLE(SQL_HANDLE_STMT, hdbc, &hstmt, "db_get_attributes_types"))
|
if (!SQLALLOCHANDLE(SQL_HANDLE_STMT, hdbc, &hstmt, "db_get_attributes_types"))
|
||||||
return FALSE;
|
return FALSE;
|
||||||
@ -547,7 +536,7 @@ c_db_odbc_get_attributes_types( USES_REGS1 ) {
|
|||||||
|
|
||||||
if (!SQLCLOSECURSOR(hstmt,"db_get_attributes_types"))
|
if (!SQLCLOSECURSOR(hstmt,"db_get_attributes_types"))
|
||||||
return FALSE;
|
return FALSE;
|
||||||
if (!SQLFREESTMT(hstmt,SQL_CLOSE, "db_get_attributes_types"))
|
if (!SQLFREEHANDLE(SQL_HANDLE_STMT, hstmt, "db_get_attributes_types"))
|
||||||
return FALSE;
|
return FALSE;
|
||||||
return TRUE;
|
return TRUE;
|
||||||
}
|
}
|
||||||
@ -585,12 +574,31 @@ c_db_odbc_row_cut( USES_REGS1 ) {
|
|||||||
|
|
||||||
if (!SQLCLOSECURSOR(hstmt,"db_row_cut"))
|
if (!SQLCLOSECURSOR(hstmt,"db_row_cut"))
|
||||||
return FALSE;
|
return FALSE;
|
||||||
if (!SQLFREESTMT(hstmt,SQL_CLOSE,"db_row_cut"))
|
if (!SQLFREEHANDLE(SQL_HANDLE_STMT, hstmt, "db_row_cut"))
|
||||||
return FALSE;
|
return FALSE;
|
||||||
|
|
||||||
return TRUE;
|
return TRUE;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static int
|
||||||
|
release_list_args(Term arg_list_args, Term arg_bind_list, const char *error_msg)
|
||||||
|
{
|
||||||
|
Term list = arg_list_args;
|
||||||
|
Term list_bind = arg_bind_list;
|
||||||
|
|
||||||
|
while (IsPairTerm(list_bind))
|
||||||
|
{
|
||||||
|
Term head_bind = HeadOfTerm(list_bind);
|
||||||
|
|
||||||
|
list = TailOfTerm(list);
|
||||||
|
list_bind = TailOfTerm(list_bind);
|
||||||
|
|
||||||
|
free((char *)IntegerOfTerm(ArgOfTerm(1,head_bind)));
|
||||||
|
free((SQLINTEGER *)IntegerOfTerm(ArgOfTerm(3,head_bind)));
|
||||||
|
}
|
||||||
|
return TRUE;
|
||||||
|
}
|
||||||
|
|
||||||
/* db_row: ResultSet x BindList x ListOfArgs -> */
|
/* db_row: ResultSet x BindList x ListOfArgs -> */
|
||||||
static Int
|
static Int
|
||||||
c_db_odbc_row( USES_REGS1 ) {
|
c_db_odbc_row( USES_REGS1 ) {
|
||||||
@ -611,9 +619,12 @@ c_db_odbc_row( USES_REGS1 ) {
|
|||||||
{
|
{
|
||||||
if (!SQLCLOSECURSOR(hstmt,"db_row"))
|
if (!SQLCLOSECURSOR(hstmt,"db_row"))
|
||||||
return FALSE;
|
return FALSE;
|
||||||
if (!SQLFREESTMT(hstmt,SQL_CLOSE,"db_row"))
|
if (!SQLFREEHANDLE(SQL_HANDLE_STMT, hstmt, "db_row"))
|
||||||
return FALSE;
|
return FALSE;
|
||||||
|
if (!release_list_args(arg_list_args, arg_bind_list, "db_row")) {
|
||||||
|
return FALSE;
|
||||||
|
}
|
||||||
|
|
||||||
cut_fail();
|
cut_fail();
|
||||||
return FALSE;
|
return FALSE;
|
||||||
}
|
}
|
||||||
@ -699,7 +710,7 @@ c_db_odbc_number_of_fields_in_query( USES_REGS1 ) {
|
|||||||
if (!Yap_unify(arg_fields, MkIntegerTerm(number_cols))){
|
if (!Yap_unify(arg_fields, MkIntegerTerm(number_cols))){
|
||||||
if (!SQLCLOSECURSOR(hstmt,"db_number_of_fields_in_query"))
|
if (!SQLCLOSECURSOR(hstmt,"db_number_of_fields_in_query"))
|
||||||
return FALSE;
|
return FALSE;
|
||||||
if (!SQLFREESTMT(hstmt,SQL_CLOSE, "db_number_of_fields_in_query"))
|
if (!SQLFREEHANDLE(SQL_HANDLE_STMT, hstmt, "db_number_of_fields_in_query"))
|
||||||
return FALSE;
|
return FALSE;
|
||||||
|
|
||||||
return FALSE;
|
return FALSE;
|
||||||
@ -707,7 +718,7 @@ c_db_odbc_number_of_fields_in_query( USES_REGS1 ) {
|
|||||||
|
|
||||||
if (!SQLCLOSECURSOR(hstmt,"db_number_of_fields_in_query"))
|
if (!SQLCLOSECURSOR(hstmt,"db_number_of_fields_in_query"))
|
||||||
return FALSE;
|
return FALSE;
|
||||||
if (!SQLFREESTMT(hstmt,SQL_CLOSE, "db_number_of_fields_in_query"))
|
if (!SQLFREEHANDLE(SQL_HANDLE_STMT, hstmt, "db_number_of_fields_in_query"))
|
||||||
return FALSE;
|
return FALSE;
|
||||||
|
|
||||||
return TRUE;
|
return TRUE;
|
||||||
@ -775,7 +786,7 @@ c_db_odbc_get_fields_properties( USES_REGS1 ) {
|
|||||||
|
|
||||||
if (!SQLCLOSECURSOR(hstmt2,"db_get_fields_properties"))
|
if (!SQLCLOSECURSOR(hstmt2,"db_get_fields_properties"))
|
||||||
return FALSE;
|
return FALSE;
|
||||||
if (!SQLFREESTMT(hstmt2,SQL_CLOSE,"db_get_fields_properties"))
|
if (!SQLFREEHANDLE(SQL_HANDLE_STMT, hstmt2, "db_get_fields_properties"))
|
||||||
return FALSE;
|
return FALSE;
|
||||||
|
|
||||||
for (i=1;i<=num_fields;i++)
|
for (i=1;i<=num_fields;i++)
|
||||||
@ -816,9 +827,8 @@ c_db_odbc_get_fields_properties( USES_REGS1 ) {
|
|||||||
|
|
||||||
if (!SQLCLOSECURSOR(hstmt,"db_get_fields_properties"))
|
if (!SQLCLOSECURSOR(hstmt,"db_get_fields_properties"))
|
||||||
return FALSE;
|
return FALSE;
|
||||||
if (!SQLFREESTMT(hstmt,SQL_CLOSE,"db_get_fields_properties"))
|
if (!SQLFREEHANDLE(SQL_HANDLE_STMT, hstmt2, "db_get_fields_properties"))
|
||||||
return FALSE;
|
return FALSE;
|
||||||
|
|
||||||
return TRUE;
|
return TRUE;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
7
configure
vendored
7
configure
vendored
@ -9927,6 +9927,9 @@ mkdir -p packages/clib/maildrop
|
|||||||
mkdir -p packages/clib/maildrop/rfc822
|
mkdir -p packages/clib/maildrop/rfc822
|
||||||
mkdir -p packages/clib/maildrop/rfc2045
|
mkdir -p packages/clib/maildrop/rfc2045
|
||||||
mkdir -p packages/CLPBN
|
mkdir -p packages/CLPBN
|
||||||
|
mkdir -p packages/CLPBN/clpbn
|
||||||
|
mkdir -p packages/CLPBN/clpbn/bp
|
||||||
|
mkdir -p packages/CLPBN/clpbn/bp/xmlParser
|
||||||
mkdir -p packages/clpqr
|
mkdir -p packages/clpqr
|
||||||
mkdir -p packages/cplint
|
mkdir -p packages/cplint
|
||||||
mkdir -p packages/cplint/approx
|
mkdir -p packages/cplint/approx
|
||||||
@ -10619,8 +10622,8 @@ esac
|
|||||||
|
|
||||||
cat >>$CONFIG_STATUS <<_ACEOF || ac_write_fail=1
|
cat >>$CONFIG_STATUS <<_ACEOF || ac_write_fail=1
|
||||||
# Files that config.status was made for.
|
# Files that config.status was made for.
|
||||||
config_files="`echo $ac_config_files`"
|
config_files="$ac_config_files"
|
||||||
config_headers="`echo $ac_config_headers`"
|
config_headers="$ac_config_headers"
|
||||||
|
|
||||||
_ACEOF
|
_ACEOF
|
||||||
|
|
||||||
|
@ -2103,6 +2103,9 @@ mkdir -p packages/clib/maildrop
|
|||||||
mkdir -p packages/clib/maildrop/rfc822
|
mkdir -p packages/clib/maildrop/rfc822
|
||||||
mkdir -p packages/clib/maildrop/rfc2045
|
mkdir -p packages/clib/maildrop/rfc2045
|
||||||
mkdir -p packages/CLPBN
|
mkdir -p packages/CLPBN
|
||||||
|
mkdir -p packages/CLPBN/clpbn
|
||||||
|
mkdir -p packages/CLPBN/clpbn/bp
|
||||||
|
mkdir -p packages/CLPBN/clpbn/bp/xmlParser
|
||||||
mkdir -p packages/clpqr
|
mkdir -p packages/clpqr
|
||||||
mkdir -p packages/cplint
|
mkdir -p packages/cplint
|
||||||
mkdir -p packages/cplint/approx
|
mkdir -p packages/cplint/approx
|
||||||
|
80
docs/yap.tex
80
docs/yap.tex
@ -16811,13 +16811,29 @@ only two boolean flags are accepted: @code{YAPC_ENABLE_GC} and
|
|||||||
@code{YAPC_ENABLE_AGC}. The first enables/disables the standard garbage
|
@code{YAPC_ENABLE_AGC}. The first enables/disables the standard garbage
|
||||||
collector, the second does the same for the atom garbage collector.`
|
collector, the second does the same for the atom garbage collector.`
|
||||||
|
|
||||||
|
@item @code{YAP_TERM} YAP_AllocExternalDataInStack(@code{size_t bytes})
|
||||||
|
@item @code{void *} YAP_ExternalDataInStackFromTerm(@code{YAP_Term t})
|
||||||
|
@item @code{YAP_Bool} YAP_IsExternalDataInStackTerm(@code{YAP_Term t})
|
||||||
|
@findex YAP_AllocExternalDataInStack (C-Interface function)
|
||||||
|
|
||||||
|
The next routines allow one to store external data in the Prolog
|
||||||
|
execution stack. The first routine reserves space for @var{sz} bytes
|
||||||
|
and returns an opaque handle. The second routines receives the handle
|
||||||
|
and returns a pointer to the data. The last routine checks if a term
|
||||||
|
is an opaque handle.
|
||||||
|
|
||||||
|
Data will be automatically reclaimed during
|
||||||
|
backtracking. Also, this storage is opaque to the Prolog garbage compiler,
|
||||||
|
so it should not be used to store Prolog terms. On the other hand, it
|
||||||
|
may be useful to store arrays in a compact way, or pointers to external objects.
|
||||||
|
|
||||||
@item @code{int} YAP_HaltRegisterHook(@code{YAP_halt_hook f, void *closure})
|
@item @code{int} YAP_HaltRegisterHook(@code{YAP_halt_hook f, void *closure})
|
||||||
@findex YAP_HaltRegisterHook (C-Interface function)
|
@findex YAP_HaltRegisterHook (C-Interface function)
|
||||||
|
|
||||||
Register the function @var{f} to be called if YAP is halted. The
|
Register the function @var{f} to be called if YAP is halted. The
|
||||||
function is called with two arguments: the exit code of the process (@code{0}
|
function is called with two arguments: the exit code of the process
|
||||||
if this cannot be determined on your operating system) and the closure
|
(@code{0} if this cannot be determined on your operating system) and
|
||||||
argument @var{closure}.
|
the closure argument @var{closure}.
|
||||||
@c See also @code{at_halt/1}.
|
@c See also @code{at_halt/1}.
|
||||||
|
|
||||||
@end table
|
@end table
|
||||||
@ -16850,6 +16866,7 @@ implementing the predicate and @var{arity} is its arity.
|
|||||||
@findex YAP_UserBackCutCPredicate (C-Interface function)
|
@findex YAP_UserBackCutCPredicate (C-Interface function)
|
||||||
@findex YAP_PRESERVE_DATA (C-Interface function)
|
@findex YAP_PRESERVE_DATA (C-Interface function)
|
||||||
@findex YAP_PRESERVED_DATA (C-Interface function)
|
@findex YAP_PRESERVED_DATA (C-Interface function)
|
||||||
|
@findex YAP_PRESERVED_DATA_CUT (C-Interface function)
|
||||||
@findex YAP_cutsucceed (C-Interface function)
|
@findex YAP_cutsucceed (C-Interface function)
|
||||||
@findex YAP_cutfail (C-Interface function)
|
@findex YAP_cutfail (C-Interface function)
|
||||||
For the second kind of predicates we need three C functions. The first one
|
For the second kind of predicates we need three C functions. The first one
|
||||||
@ -16915,11 +16932,15 @@ static int start_n100(void)
|
|||||||
@end example
|
@end example
|
||||||
|
|
||||||
The routine starts by getting the dereference value of the argument.
|
The routine starts by getting the dereference value of the argument.
|
||||||
The call to @code{YAP_PRESERVE_DATA} is used to initialize the memory which will
|
The call to @code{YAP_PRESERVE_DATA} is used to initialize the memory
|
||||||
hold the information to be preserved across backtracking. The first
|
which will hold the information to be preserved across
|
||||||
argument is the variable we shall use, and the second its type. Note
|
backtracking. The first argument is the variable we shall use, and the
|
||||||
that we can only use @code{YAP_PRESERVE_DATA} once, so often we will
|
second its type. Note that we can only use @code{YAP_PRESERVE_DATA}
|
||||||
want the variable to be a structure.
|
once, so often we will want the variable to be a structure. This data
|
||||||
|
is visible to the garbage collector, so it should consist of Prolog
|
||||||
|
terms, as in the example. It is also correct to store pointers to
|
||||||
|
objects external to YAP stacks, as the garbage collector will ignore
|
||||||
|
such references.
|
||||||
|
|
||||||
If the argument of the predicate is a variable, the routine initializes the
|
If the argument of the predicate is a variable, the routine initializes the
|
||||||
structure to be preserved across backtracking with the information
|
structure to be preserved across backtracking with the information
|
||||||
@ -16988,6 +17009,34 @@ when pruning the execution of the predicate, @var{arity} is the
|
|||||||
predicate arity, and @var{sizeof} is the size of the data to be
|
predicate arity, and @var{sizeof} is the size of the data to be
|
||||||
preserved in the stack. In this example, we would have something like
|
preserved in the stack. In this example, we would have something like
|
||||||
|
|
||||||
|
@example
|
||||||
|
void
|
||||||
|
init_n100(void)
|
||||||
|
@{
|
||||||
|
YAP_UserBackCutCPredicate("n100", start_n100, continue_n100, cut_n100, 1, 1);
|
||||||
|
@}
|
||||||
|
@end example
|
||||||
|
The argument before last is the predicate's arity. Notice again the
|
||||||
|
last argument to the call. function argument gives the extra space we
|
||||||
|
want to use for @code{PRESERVED_DATA}. Space is given in cells, where
|
||||||
|
a cell is the same size as a pointer. The garbage collector has access
|
||||||
|
to this space, hence users should use it either to store terms or to
|
||||||
|
store pointers to objects outside the stacks.
|
||||||
|
|
||||||
|
The code for @code{cut_n100} could be:
|
||||||
|
@example
|
||||||
|
static int cut_n100(void)
|
||||||
|
@{
|
||||||
|
YAP_PRESERVED_DATA_CUT(n100_data,n100_data_type*);
|
||||||
|
|
||||||
|
fprintf("n100 cut with counter %ld\n", YAP_IntOfTerm(n100_data->next_solution));
|
||||||
|
return TRUE;
|
||||||
|
@}
|
||||||
|
@end example
|
||||||
|
Notice that we have to use @code{YAP_PRESERVED_DATA_CUT}: this is
|
||||||
|
because the Prolog engine is at a different state during cut.
|
||||||
|
|
||||||
|
If no work is required at cut, we can use:
|
||||||
@example
|
@example
|
||||||
void
|
void
|
||||||
init_n100(void)
|
init_n100(void)
|
||||||
@ -16995,8 +17044,7 @@ init_n100(void)
|
|||||||
YAP_UserBackCutCPredicate("n100", start_n100, continue_n100, NULL, 1, 1);
|
YAP_UserBackCutCPredicate("n100", start_n100, continue_n100, NULL, 1, 1);
|
||||||
@}
|
@}
|
||||||
@end example
|
@end example
|
||||||
Notice that we do not actually need to do anything on receiving a cut in
|
in this case no code is executed at cut time.
|
||||||
this case.
|
|
||||||
|
|
||||||
@node Loading Objects, Save&Rest, Writing C, C-Interface
|
@node Loading Objects, Save&Rest, Writing C, C-Interface
|
||||||
@section Loading Object Files
|
@section Loading Object Files
|
||||||
@ -17272,9 +17320,9 @@ Associate the term @var{value} with the atom @var{at}. The term
|
|||||||
@var{value} must be a constant. This functionality is used by YAP as a
|
@var{value} must be a constant. This functionality is used by YAP as a
|
||||||
simple way for controlling and communicating with the Prolog run-time.
|
simple way for controlling and communicating with the Prolog run-time.
|
||||||
|
|
||||||
@item @code{YAP_Term} YAP_Read(@code{int (*)(void)} @var{GetC})
|
@item @code{YAP_Term} YAP_Read(@code{IOSTREAM *Stream})
|
||||||
@findex YAP_Read/1
|
@findex YAP_Read/1
|
||||||
Parse a Term using the function @var{GetC} to input characters.
|
Parse a @var{Term} from the stream @var{Stream}.
|
||||||
|
|
||||||
@item @code{YAP_Term} YAP_Write(@code{YAP_Term} @var{t})
|
@item @code{YAP_Term} YAP_Write(@code{YAP_Term} @var{t})
|
||||||
@findex YAP_CopyTerm/1
|
@findex YAP_CopyTerm/1
|
||||||
@ -17282,13 +17330,13 @@ Copy a Term @var{t} and all associated constraints. May call the garbage
|
|||||||
collector and returns @code{0L} on error (such as no space being
|
collector and returns @code{0L} on error (such as no space being
|
||||||
available).
|
available).
|
||||||
|
|
||||||
@item @code{void} YAP_Write(@code{YAP_Term} @var{t}, @code{void (*)(int)}
|
@item @code{void} YAP_Write(@code{YAP_Term} @var{t}, @code{IOSTREAM}
|
||||||
@var{PutC}, @code{int} @var{flags})
|
@var{stream}, @code{int} @var{flags})
|
||||||
@findex YAP_Write/3
|
@findex YAP_Write/3
|
||||||
Write a Term @var{t} using the function @var{PutC} to output
|
Write a Term @var{t} using the stream @var{stream} to output
|
||||||
characters. The term is written according to a mask of the following
|
characters. The term is written according to a mask of the following
|
||||||
flags in the @code{flag} argument: @code{YAP_WRITE_QUOTED},
|
flags in the @code{flag} argument: @code{YAP_WRITE_QUOTED},
|
||||||
@code{YAP_WRITE_HANDLE_VARS}, and @code{YAP_WRITE_IGNORE_OPS}.
|
@code{YAP_WRITE_HANDLE_VARS}, @code{YAP_WRITE_USE_PORTRAY}, and @code{YAP_WRITE_IGNORE_OPS}.
|
||||||
|
|
||||||
@item @code{void} YAP_WriteBuffer(@code{YAP_Term} @var{t}, @code{char *}
|
@item @code{void} YAP_WriteBuffer(@code{YAP_Term} @var{t}, @code{char *}
|
||||||
@var{buff}, @code{unsigned int}
|
@var{buff}, @code{unsigned int}
|
||||||
|
@ -237,7 +237,7 @@ extern X_API void PROTO(YAP_UserBackCPredicate,(CONST char *, YAP_Bool (*)(void)
|
|||||||
|
|
||||||
/* void UserBackCPredicate(char *name, int *init(), int *cont(), int *cut(), int
|
/* void UserBackCPredicate(char *name, int *init(), int *cont(), int *cut(), int
|
||||||
arity, int extra) */
|
arity, int extra) */
|
||||||
extern X_API void PROTO(YAP_UserBackCutCPredicate,(char *, YAP_Bool (*)(void), YAP_Bool (*)(void), YAP_Bool (*)(void), YAP_Arity, unsigned int));
|
extern X_API void PROTO(YAP_UserBackCutCPredicate,(CONST char *, YAP_Bool (*)(void), YAP_Bool (*)(void), YAP_Bool (*)(void), YAP_Arity, unsigned int));
|
||||||
|
|
||||||
/* void CallProlog(YAP_Term t) */
|
/* void CallProlog(YAP_Term t) */
|
||||||
extern X_API YAP_Bool PROTO(YAP_CallProlog,(YAP_Term t));
|
extern X_API YAP_Bool PROTO(YAP_CallProlog,(YAP_Term t));
|
||||||
@ -245,9 +245,9 @@ extern X_API YAP_Bool PROTO(YAP_CallProlog,(YAP_Term t));
|
|||||||
/* void cut_fail(void) */
|
/* void cut_fail(void) */
|
||||||
extern X_API void PROTO(YAP_cut_up,(void));
|
extern X_API void PROTO(YAP_cut_up,(void));
|
||||||
|
|
||||||
#define YAP_cut_succeed() { YAP_cut_up(); return TRUE; }
|
#define YAP_cut_succeed() do { YAP_cut_up(); return TRUE; } while(0)
|
||||||
|
|
||||||
#define YAP_cut_fail() { YAP_cut_up(); return FALSE; }
|
#define YAP_cut_fail() do { YAP_cut_up(); return FALSE; } while(0)
|
||||||
|
|
||||||
/* void *AllocSpaceFromYAP_(int) */
|
/* void *AllocSpaceFromYAP_(int) */
|
||||||
extern X_API void *PROTO(YAP_AllocSpaceFromYap,(unsigned int));
|
extern X_API void *PROTO(YAP_AllocSpaceFromYap,(unsigned int));
|
||||||
@ -555,6 +555,17 @@ extern X_API int PROTO(YAP_MaxOpPriority,(YAP_Atom, YAP_Term));
|
|||||||
/* int YAP_OpInfo(Atom, Term, int, int *, int *) */
|
/* int YAP_OpInfo(Atom, Term, int, int *, int *) */
|
||||||
extern X_API int PROTO(YAP_OpInfo,(YAP_Atom, YAP_Term, int, int *, int *));
|
extern X_API int PROTO(YAP_OpInfo,(YAP_Atom, YAP_Term, int, int *, int *));
|
||||||
|
|
||||||
|
/* YAP_Bool YAP_IsExternalDataInStackTerm(YAP_Term) */
|
||||||
|
extern X_API YAP_Bool PROTO(YAP_IsExternalDataInStackTerm,(YAP_Term));
|
||||||
|
|
||||||
|
extern X_API YAP_opaque_tag_t PROTO(YAP_NewOpaqueType,(struct YAP_opaque_handler_struct *));
|
||||||
|
|
||||||
|
extern X_API YAP_Bool PROTO(YAP_IsOpaqueObjectTerm,(YAP_Term, YAP_opaque_tag_t));
|
||||||
|
|
||||||
|
extern X_API YAP_Term PROTO(YAP_NewOpaqueObject,(YAP_opaque_tag_t, size_t));
|
||||||
|
|
||||||
|
extern X_API void *PROTO(YAP_OpaqueObjectFromTerm,(YAP_Term));
|
||||||
|
|
||||||
#define YAP_InitCPred(N,A,F) YAP_UserCPredicate(N,F,A)
|
#define YAP_InitCPred(N,A,F) YAP_UserCPredicate(N,F,A)
|
||||||
|
|
||||||
__END_DECLS
|
__END_DECLS
|
||||||
|
@ -101,9 +101,10 @@ typedef double YAP_Float;
|
|||||||
#define YAP_FULL_BOOT_FROM_PROLOG 4
|
#define YAP_FULL_BOOT_FROM_PROLOG 4
|
||||||
#define YAP_BOOT_ERROR -1
|
#define YAP_BOOT_ERROR -1
|
||||||
|
|
||||||
#define YAP_WRITE_QUOTED 0
|
#define YAP_WRITE_QUOTED 1
|
||||||
#define YAP_WRITE_HANDLE_VARS 1
|
|
||||||
#define YAP_WRITE_IGNORE_OPS 2
|
#define YAP_WRITE_IGNORE_OPS 2
|
||||||
|
#define YAP_WRITE_HANDLE_VARS 2
|
||||||
|
#define YAP_WRITE_USE_PORTRAY 8
|
||||||
|
|
||||||
#define YAP_CONSULT_MODE 0
|
#define YAP_CONSULT_MODE 0
|
||||||
#define YAP_RECONSULT_MODE 1
|
#define YAP_RECONSULT_MODE 1
|
||||||
@ -204,6 +205,14 @@ typedef int (*YAP_agc_hook)(void *_Atom);
|
|||||||
|
|
||||||
typedef void (*YAP_halt_hook)(int exit_code, void *closure);
|
typedef void (*YAP_halt_hook)(int exit_code, void *closure);
|
||||||
|
|
||||||
|
typedef int YAP_opaque_tag_t;
|
||||||
|
|
||||||
|
typedef int (*YAP_Opaque_CallOnFail)(void *);
|
||||||
|
|
||||||
|
typedef struct YAP_opaque_handler_struct {
|
||||||
|
YAP_Opaque_CallOnFail fail_handler;
|
||||||
|
} YAP_opaque_handler_t;
|
||||||
|
|
||||||
/********* execution mode ***********************/
|
/********* execution mode ***********************/
|
||||||
|
|
||||||
typedef enum
|
typedef enum
|
||||||
|
@ -241,6 +241,8 @@ tokenize_arguments([FirstArg|RestArgs],[TokFirstArg|TokRestArgs]):-
|
|||||||
%
|
%
|
||||||
% --------------------------------------------------------------------------------------
|
% --------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
:- dynamic attribute/4.
|
||||||
|
|
||||||
query_generation([],_,[]).
|
query_generation([],_,[]).
|
||||||
|
|
||||||
query_generation([Conjunction|Conjunctions],ProjectionTerm,[Query|Queries]):-
|
query_generation([Conjunction|Conjunctions],ProjectionTerm,[Query|Queries]):-
|
||||||
@ -1157,9 +1159,9 @@ column_atom(att(RangeVar,Attribute),QueryList,Diff):-
|
|||||||
column_atom(Attribute,X2,Diff).
|
column_atom(Attribute,X2,Diff).
|
||||||
|
|
||||||
column_atom(rel(Relation,RangeVar),QueryList,Diff):-
|
column_atom(rel(Relation,RangeVar),QueryList,Diff):-
|
||||||
column_atom('`',QueryList,X0),
|
column_atom('',QueryList,X0),
|
||||||
column_atom(Relation,X0,X1),
|
column_atom(Relation,X0,X1),
|
||||||
column_atom('` ',X1,X2),
|
column_atom(' ',X1,X2),
|
||||||
column_atom(RangeVar,X2,Diff).
|
column_atom(RangeVar,X2,Diff).
|
||||||
|
|
||||||
column_atom('$const$'(String),QueryList,Diff):-
|
column_atom('$const$'(String),QueryList,Diff):-
|
||||||
|
@ -142,12 +142,6 @@ p2c_putc(const int c) {
|
|||||||
/*
|
/*
|
||||||
* Function used by YAP to read a char from a string
|
* Function used by YAP to read a char from a string
|
||||||
*/
|
*/
|
||||||
static int
|
|
||||||
p2c_getc(void) {
|
|
||||||
if( BUFFER_POS < BUFFER_LEN )
|
|
||||||
return BUFFER_PTR[BUFFER_POS++];
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
/*
|
/*
|
||||||
* Writes a term to a stream.
|
* Writes a term to a stream.
|
||||||
*/
|
*/
|
||||||
@ -177,7 +171,7 @@ read_term_from_stream(const int fd) {
|
|||||||
if ( size> BUFFER_SIZE)
|
if ( size> BUFFER_SIZE)
|
||||||
expand_buffer(size-BUFFER_SIZE);
|
expand_buffer(size-BUFFER_SIZE);
|
||||||
read(fd,BUFFER_PTR,size); // read term from stream
|
read(fd,BUFFER_PTR,size); // read term from stream
|
||||||
return YAP_Read( p2c_getc );
|
return YAP_ReadBuffer( BUFFER_PTR , NULL);
|
||||||
}
|
}
|
||||||
/*********************************************************************************************
|
/*********************************************************************************************
|
||||||
* Conversion: Prolog Term->char[] and char[]->Prolog Term
|
* Conversion: Prolog Term->char[] and char[]->Prolog Term
|
||||||
@ -229,7 +223,7 @@ string2term(char *const ptr,const size_t *size) {
|
|||||||
}
|
}
|
||||||
BUFFER_POS=0;
|
BUFFER_POS=0;
|
||||||
LOCAL_ErrorMessage=NULL;
|
LOCAL_ErrorMessage=NULL;
|
||||||
t = YAP_Read(p2c_getc);
|
t = YAP_ReadBuffer( BUFFER_PTR , NULL );
|
||||||
if ( t==FALSE ) {
|
if ( t==FALSE ) {
|
||||||
write_msg(__FUNCTION__,__FILE__,__LINE__,"FAILED string2term>>>>size:%d %d %s\n",BUFFER_SIZE,strlen(BUFFER_PTR),LOCAL_ErrorMessage);
|
write_msg(__FUNCTION__,__FILE__,__LINE__,"FAILED string2term>>>>size:%d %d %s\n",BUFFER_SIZE,strlen(BUFFER_PTR),LOCAL_ErrorMessage);
|
||||||
exit(1);
|
exit(1);
|
||||||
|
@ -30,7 +30,6 @@ static char *rcsid = "$Header: /Users/vitor/Yap/yap-cvsbackup/library/mpi/mpi.c,
|
|||||||
#include <string.h>
|
#include <string.h>
|
||||||
#include <mpi.h>
|
#include <mpi.h>
|
||||||
|
|
||||||
Term STD_PROTO(YAP_Read, (int (*)(void)));
|
|
||||||
void STD_PROTO(YAP_Write, (Term, void (*)(int), int));
|
void STD_PROTO(YAP_Write, (Term, void (*)(int), int));
|
||||||
|
|
||||||
STATIC_PROTO (Int p_mpi_open, (void));
|
STATIC_PROTO (Int p_mpi_open, (void));
|
||||||
@ -105,15 +104,6 @@ mpi_putc(Int ch)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static Int
|
|
||||||
mpi_getc(void)
|
|
||||||
{
|
|
||||||
if( bufptr < bufsize ) return buf[bufptr++];
|
|
||||||
else return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* C Predicates
|
* C Predicates
|
||||||
@ -301,7 +291,7 @@ p_mpi_receive() /* mpi_receive(-data, ?orig, ?tag) */
|
|||||||
/* parse received string into a Prolog term */
|
/* parse received string into a Prolog term */
|
||||||
|
|
||||||
bufptr = 0;
|
bufptr = 0;
|
||||||
t = YAP_Read( mpi_getc );
|
t = YAP_ReadBuffer( buf, NULL );
|
||||||
|
|
||||||
if( t == TermNil ) {
|
if( t == TermNil ) {
|
||||||
retv = FALSE;
|
retv = FALSE;
|
||||||
@ -384,7 +374,7 @@ p_mpi_bcast3() /* mpi_bcast( ?data, +root, +max_size ) */
|
|||||||
bufptr = 0;
|
bufptr = 0;
|
||||||
|
|
||||||
/* parse received string into a Prolog term */
|
/* parse received string into a Prolog term */
|
||||||
return Yap_unify( YAP_Read(mpi_getc), ARG1 );
|
return Yap_unify( YAP_ReadBuffer( buf, NULL ), ARG1 );
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -464,7 +454,7 @@ p_mpi_bcast2() /* mpi_bcast( ?data, +root ) */
|
|||||||
bufstrlen = strlen(buf);
|
bufstrlen = strlen(buf);
|
||||||
bufptr = 0;
|
bufptr = 0;
|
||||||
|
|
||||||
return Yap_unify(YAP_Read( mpi_getc ), ARG1);
|
return Yap_unify(YAP_ReadBuffer( buf, NULL ), ARG1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -120,6 +120,8 @@ char* DIRNAME =NULL
|
|||||||
char Executable[YAP_FILENAME_MAX] void
|
char Executable[YAP_FILENAME_MAX] void
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
int OpaqueHandlersCount =0
|
||||||
|
struct opaque_handler_struct* OpaqueHandlers =NULL
|
||||||
|
|
||||||
#if __simplescalar__
|
#if __simplescalar__
|
||||||
char pwd[YAP_FILENAME_MAX] void
|
char pwd[YAP_FILENAME_MAX] void
|
||||||
|
149
packages/CLPBN/clpbn/bp/BPNodeInfo.cpp
Executable file
149
packages/CLPBN/clpbn/bp/BPNodeInfo.cpp
Executable file
@ -0,0 +1,149 @@
|
|||||||
|
#include <cassert>
|
||||||
|
#include <cmath>
|
||||||
|
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
|
#include "BPNodeInfo.h"
|
||||||
|
#include "BPSolver.h"
|
||||||
|
|
||||||
|
BPNodeInfo::BPNodeInfo (BayesNode* node)
|
||||||
|
{
|
||||||
|
node_ = node;
|
||||||
|
ds_ = node->getDomainSize();
|
||||||
|
piValsCalc_ = false;
|
||||||
|
ldValsCalc_ = false;
|
||||||
|
nPiMsgsRcv_ = 0;
|
||||||
|
nLdMsgsRcv_ = 0;
|
||||||
|
piVals_.resize (ds_, 1);
|
||||||
|
ldVals_.resize (ds_, 1);
|
||||||
|
const BnNodeSet& childs = node->getChilds();
|
||||||
|
for (unsigned i = 0; i < childs.size(); i++) {
|
||||||
|
cmsgs_.insert (make_pair (childs[i], false));
|
||||||
|
}
|
||||||
|
const BnNodeSet& parents = node->getParents();
|
||||||
|
for (unsigned i = 0; i < parents.size(); i++) {
|
||||||
|
pmsgs_.insert (make_pair (parents[i], false));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
ParamSet
|
||||||
|
BPNodeInfo::getBeliefs (void) const
|
||||||
|
{
|
||||||
|
double sum = 0.0;
|
||||||
|
ParamSet beliefs (ds_);
|
||||||
|
for (unsigned xi = 0; xi < ds_; xi++) {
|
||||||
|
double prod = piVals_[xi] * ldVals_[xi];
|
||||||
|
beliefs[xi] = prod;
|
||||||
|
sum += prod;
|
||||||
|
}
|
||||||
|
assert (sum);
|
||||||
|
//normalize the beliefs
|
||||||
|
for (unsigned xi = 0; xi < ds_; xi++) {
|
||||||
|
beliefs[xi] /= sum;
|
||||||
|
}
|
||||||
|
return beliefs;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
BPNodeInfo::readyToSendPiMsgTo (const BayesNode* child) const
|
||||||
|
{
|
||||||
|
for (unsigned i = 0; i < inChildLinks_.size(); i++) {
|
||||||
|
if (inChildLinks_[i]->getSource() != child
|
||||||
|
&& !inChildLinks_[i]->messageWasSended()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
BPNodeInfo::readyToSendLambdaMsgTo (const BayesNode* parent) const
|
||||||
|
{
|
||||||
|
for (unsigned i = 0; i < inParentLinks_.size(); i++) {
|
||||||
|
if (inParentLinks_[i]->getSource() != parent
|
||||||
|
&& !inParentLinks_[i]->messageWasSended()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
double
|
||||||
|
BPNodeInfo::getPiValue (unsigned idx) const
|
||||||
|
{
|
||||||
|
assert (idx >=0 && idx < ds_);
|
||||||
|
return piVals_[idx];
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
BPNodeInfo::setPiValue (unsigned idx, Param value)
|
||||||
|
{
|
||||||
|
assert (idx >=0 && idx < ds_);
|
||||||
|
piVals_[idx] = value;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
double
|
||||||
|
BPNodeInfo::getLambdaValue (unsigned idx) const
|
||||||
|
{
|
||||||
|
assert (idx >=0 && idx < ds_);
|
||||||
|
return ldVals_[idx];
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
BPNodeInfo::setLambdaValue (unsigned idx, Param value)
|
||||||
|
{
|
||||||
|
assert (idx >=0 && idx < ds_);
|
||||||
|
ldVals_[idx] = value;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
double
|
||||||
|
BPNodeInfo::getBeliefChange (void)
|
||||||
|
{
|
||||||
|
double change = 0.0;
|
||||||
|
if (oldBeliefs_.size() == 0) {
|
||||||
|
oldBeliefs_ = getBeliefs();
|
||||||
|
change = 9999999999.0;
|
||||||
|
} else {
|
||||||
|
ParamSet currentBeliefs = getBeliefs();
|
||||||
|
for (unsigned xi = 0; xi < ds_; xi++) {
|
||||||
|
change += abs (currentBeliefs[xi] - oldBeliefs_[xi]);
|
||||||
|
}
|
||||||
|
oldBeliefs_ = currentBeliefs;
|
||||||
|
}
|
||||||
|
return change;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
BPNodeInfo::receivedBottomInfluence (void) const
|
||||||
|
{
|
||||||
|
// if all lambda values are equal, then neither
|
||||||
|
// this node neither its descendents have evidence,
|
||||||
|
// we can use this to don't send lambda messages his parents
|
||||||
|
bool childInfluenced = false;
|
||||||
|
for (unsigned xi = 1; xi < ds_; xi++) {
|
||||||
|
if (ldVals_[xi] != ldVals_[0]) {
|
||||||
|
childInfluenced = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return childInfluenced;
|
||||||
|
}
|
||||||
|
|
82
packages/CLPBN/clpbn/bp/BPNodeInfo.h
Executable file
82
packages/CLPBN/clpbn/bp/BPNodeInfo.h
Executable file
@ -0,0 +1,82 @@
|
|||||||
|
#ifndef BP_BP_NODE_H
|
||||||
|
#define BP_BP_NODE_H
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include <map>
|
||||||
|
|
||||||
|
#include "BPSolver.h"
|
||||||
|
#include "BayesNode.h"
|
||||||
|
#include "Shared.h"
|
||||||
|
|
||||||
|
//class Edge;
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
|
||||||
|
class BPNodeInfo
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
BPNodeInfo (int);
|
||||||
|
BPNodeInfo (BayesNode*);
|
||||||
|
|
||||||
|
ParamSet getBeliefs (void) const;
|
||||||
|
double getPiValue (unsigned) const;
|
||||||
|
void setPiValue (unsigned, Param);
|
||||||
|
double getLambdaValue (unsigned) const;
|
||||||
|
void setLambdaValue (unsigned, Param);
|
||||||
|
double getBeliefChange (void);
|
||||||
|
bool receivedBottomInfluence (void) const;
|
||||||
|
|
||||||
|
ParamSet& getPiValues (void) { return piVals_; }
|
||||||
|
ParamSet& getLambdaValues (void) { return ldVals_; }
|
||||||
|
bool arePiValuesCalculated (void) { return piValsCalc_; }
|
||||||
|
bool areLambdaValuesCalculated (void) { return ldValsCalc_; }
|
||||||
|
void markPiValuesAsCalculated (void) { piValsCalc_ = true; }
|
||||||
|
void markLambdaValuesAsCalculated (void) { ldValsCalc_ = true; }
|
||||||
|
void incNumPiMsgsRcv (void) { nPiMsgsRcv_ ++; }
|
||||||
|
void incNumLambdaMsgsRcv (void) { nLdMsgsRcv_ ++; }
|
||||||
|
|
||||||
|
bool receivedAllPiMessages (void)
|
||||||
|
{
|
||||||
|
return node_->getParents().size() == nPiMsgsRcv_;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool receivedAllLambdaMessages (void)
|
||||||
|
{
|
||||||
|
return node_->getChilds().size() == nLdMsgsRcv_;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool readyToSendPiMsgTo (const BayesNode*) const ;
|
||||||
|
bool readyToSendLambdaMsgTo (const BayesNode*) const;
|
||||||
|
|
||||||
|
CEdgeSet getIncomingParentLinks (void) { return inParentLinks_; }
|
||||||
|
CEdgeSet getIncomingChildLinks (void) { return inChildLinks_; }
|
||||||
|
CEdgeSet getOutcomingParentLinks (void) { return outParentLinks_; }
|
||||||
|
CEdgeSet getOutcomingChildLinks (void) { return outChildLinks_; }
|
||||||
|
|
||||||
|
void addIncomingParentLink (Edge* l) { inParentLinks_.push_back (l); }
|
||||||
|
void addIncomingChildLink (Edge* l) { inChildLinks_.push_back (l); }
|
||||||
|
void addOutcomingParentLink (Edge* l) { outParentLinks_.push_back (l); }
|
||||||
|
void addOutcomingChildLink (Edge* l) { outChildLinks_.push_back (l); }
|
||||||
|
|
||||||
|
private:
|
||||||
|
DISALLOW_COPY_AND_ASSIGN (BPNodeInfo);
|
||||||
|
|
||||||
|
ParamSet piVals_; // pi values
|
||||||
|
ParamSet ldVals_; // lambda values
|
||||||
|
ParamSet oldBeliefs_;
|
||||||
|
unsigned nPiMsgsRcv_;
|
||||||
|
unsigned nLdMsgsRcv_;
|
||||||
|
bool piValsCalc_;
|
||||||
|
bool ldValsCalc_;
|
||||||
|
EdgeSet inParentLinks_;
|
||||||
|
EdgeSet inChildLinks_;
|
||||||
|
EdgeSet outParentLinks_;
|
||||||
|
EdgeSet outChildLinks_;
|
||||||
|
unsigned ds_;
|
||||||
|
const BayesNode* node_;
|
||||||
|
map<const BayesNode*, bool> pmsgs_;
|
||||||
|
map<const BayesNode*, bool> cmsgs_;
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif //BP_BP_NODE_H
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
@ -1,259 +1,106 @@
|
|||||||
#ifndef BP_BPSOLVER_H
|
#ifndef BP_BP_SOLVER_H
|
||||||
#define BP_BPSOLVER_H
|
#define BP_BP_SOLVER_H
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <string>
|
|
||||||
#include <set>
|
#include <set>
|
||||||
|
|
||||||
#include "Solver.h"
|
#include "Solver.h"
|
||||||
#include "BayesNet.h"
|
#include "BayesNet.h"
|
||||||
#include "BpNode.h"
|
#include "BPNodeInfo.h"
|
||||||
#include "Shared.h"
|
#include "Shared.h"
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
class BPSolver;
|
class BPNodeInfo;
|
||||||
|
|
||||||
static const string PI = "pi" ;
|
static const string PI = "pi" ;
|
||||||
static const string LD = "ld" ;
|
static const string LD = "ld" ;
|
||||||
|
|
||||||
|
|
||||||
enum MessageType {PI_MSG, LAMBDA_MSG};
|
enum MessageType {PI_MSG, LAMBDA_MSG};
|
||||||
|
enum JointCalcType {CHAIN_RULE, JUNCTION_NODE};
|
||||||
|
|
||||||
class BPSolver;
|
class Edge
|
||||||
struct Edge
|
|
||||||
{
|
{
|
||||||
Edge (BayesNode* s, BayesNode* d, MessageType t)
|
public:
|
||||||
{
|
Edge (BayesNode* s, BayesNode* d, MessageType t)
|
||||||
source = s;
|
{
|
||||||
destination = d;
|
source_ = s;
|
||||||
type = t;
|
destin_ = d;
|
||||||
}
|
type_ = t;
|
||||||
string getId (void) const
|
if (type_ == PI_MSG) {
|
||||||
{
|
currMsg_.resize (s->getDomainSize(), 1);
|
||||||
stringstream ss;
|
nextMsg_.resize (s->getDomainSize(), 1);
|
||||||
type == PI_MSG ? ss << PI : ss << LD;
|
} else {
|
||||||
ss << source->getVarId() << "." << destination->getVarId();
|
currMsg_.resize (d->getDomainSize(), 1);
|
||||||
return ss.str();
|
nextMsg_.resize (d->getDomainSize(), 1);
|
||||||
}
|
|
||||||
string toString (void) const
|
|
||||||
{
|
|
||||||
stringstream ss;
|
|
||||||
type == PI_MSG ? ss << PI << "(" : ss << LD << "(" ;
|
|
||||||
ss << source->getLabel() << " --> " ;
|
|
||||||
ss << destination->getLabel();
|
|
||||||
ss << ")" ;
|
|
||||||
return ss.str();
|
|
||||||
}
|
|
||||||
BayesNode* source;
|
|
||||||
BayesNode* destination;
|
|
||||||
MessageType type;
|
|
||||||
static BPSolver* klass;
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
/*
|
|
||||||
class BPMessage
|
|
||||||
{
|
|
||||||
BPMessage (BayesNode* parent, BayesNode* child)
|
|
||||||
{
|
|
||||||
parent_ = parent;
|
|
||||||
child_ = child;
|
|
||||||
currPiMsg_.resize (child->getDomainSize(), 1);
|
|
||||||
currLdMsg_.resize (parent->getDomainSize(), 1);
|
|
||||||
nextLdMsg_.resize (parent->getDomainSize(), 1);
|
|
||||||
nextPiMsg_.resize (child->getDomainSize(), 1);
|
|
||||||
piResidual_ = 1.0;
|
|
||||||
ldResidual_ = 1.0;
|
|
||||||
}
|
|
||||||
|
|
||||||
Param getPiMessageValue (int idx) const
|
|
||||||
{
|
|
||||||
assert (idx >=0 && idx < child->getDomainSize());
|
|
||||||
return currPiMsg_[idx];
|
|
||||||
}
|
|
||||||
|
|
||||||
Param getLambdaMessageValue (int idx) const
|
|
||||||
{
|
|
||||||
assert (idx >=0 && idx < parent->getDomainSize());
|
|
||||||
return currLdMsg_[idx];
|
|
||||||
}
|
|
||||||
|
|
||||||
const ParamSet& getPiMessage (void) const
|
|
||||||
{
|
|
||||||
return currPiMsg_;
|
|
||||||
}
|
|
||||||
|
|
||||||
const ParamSet& getLambdaMessage (void) const
|
|
||||||
{
|
|
||||||
return currLdMsg_;
|
|
||||||
}
|
|
||||||
|
|
||||||
ParamSet& piNextMessageReference (void)
|
|
||||||
{
|
|
||||||
return nextPiMsg_;
|
|
||||||
}
|
|
||||||
|
|
||||||
ParamSet& lambdaNextMessageReference (const BayesNode* source)
|
|
||||||
{
|
|
||||||
return nextLdMsg_;
|
|
||||||
}
|
|
||||||
|
|
||||||
void updatePiMessage (void)
|
|
||||||
{
|
|
||||||
currPiMsg_ = nextPiMsg_;
|
|
||||||
Util::normalize (currPiMsg_);
|
|
||||||
}
|
|
||||||
|
|
||||||
void updateLambdaMessage (void)
|
|
||||||
{
|
|
||||||
currLdMsg_ = nextLdMsg_;
|
|
||||||
Util::normalize (currLdMsg_);
|
|
||||||
}
|
|
||||||
|
|
||||||
double getPiResidual (void)
|
|
||||||
{
|
|
||||||
return piResidual_;
|
|
||||||
}
|
|
||||||
|
|
||||||
double getLambdaResidual (void)
|
|
||||||
{
|
|
||||||
return ldResidual_;
|
|
||||||
}
|
|
||||||
|
|
||||||
void updatePiResidual (void)
|
|
||||||
{
|
|
||||||
piResidual_ = Util::getL1dist (currPiMsg_, nextPiMsg_);
|
|
||||||
}
|
|
||||||
|
|
||||||
void updateLambdaResidual (void)
|
|
||||||
{
|
|
||||||
ldResidual_ = Util::getL1dist (currLdMsg_, nextLdMsg_);
|
|
||||||
}
|
|
||||||
|
|
||||||
void clearPiResidual (void)
|
|
||||||
{
|
|
||||||
piResidual_ = 0.0;
|
|
||||||
}
|
|
||||||
|
|
||||||
void clearLambdaResidual (void)
|
|
||||||
{
|
|
||||||
ldResidual_ = 0.0;
|
|
||||||
}
|
|
||||||
|
|
||||||
BayesNode* parent_;
|
|
||||||
BayesNode* child_;
|
|
||||||
ParamSet currPiMsg_; // current pi messages
|
|
||||||
ParamSet currLdMsg_; // current lambda messages
|
|
||||||
ParamSet nextPiMsg_;
|
|
||||||
ParamSet nextLdMsg_;
|
|
||||||
Param piResidual_;
|
|
||||||
Param ldResidual_;
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class NodeInfo
|
|
||||||
{
|
|
||||||
NodeInfo (BayesNode* node)
|
|
||||||
{
|
|
||||||
node_ = node;
|
|
||||||
piVals_.resize (node->getDomainSize(), 1);
|
|
||||||
ldVals_.resize (node->getDomainSize(), 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
ParamSet getBeliefs (void) const
|
|
||||||
{
|
|
||||||
double sum = 0.0;
|
|
||||||
ParamSet beliefs (node_->getDomainSize());
|
|
||||||
for (int xi = 0; xi < node_->getDomainSize(); xi++) {
|
|
||||||
double prod = piVals_[xi] * ldVals_[xi];
|
|
||||||
beliefs[xi] = prod;
|
|
||||||
sum += prod;
|
|
||||||
}
|
|
||||||
assert (sum);
|
|
||||||
//normalize the beliefs
|
|
||||||
for (int xi = 0; xi < node_->getDomainSize(); xi++) {
|
|
||||||
beliefs[xi] /= sum;
|
|
||||||
}
|
|
||||||
return beliefs;
|
|
||||||
}
|
|
||||||
|
|
||||||
double getPiValue (int idx) const
|
|
||||||
{
|
|
||||||
assert (idx >=0 && idx < node_->getDomainSize());
|
|
||||||
return piVals_[idx];
|
|
||||||
}
|
|
||||||
|
|
||||||
void setPiValue (int idx, double value)
|
|
||||||
{
|
|
||||||
assert (idx >=0 && idx < node_->getDomainSize());
|
|
||||||
piVals_[idx] = value;
|
|
||||||
}
|
|
||||||
|
|
||||||
double getLambdaValue (int idx) const
|
|
||||||
{
|
|
||||||
assert (idx >=0 && idx < node_->getDomainSize());
|
|
||||||
return ldVals_[idx];
|
|
||||||
}
|
|
||||||
|
|
||||||
void setLambdaValue (int idx, double value)
|
|
||||||
{
|
|
||||||
assert (idx >=0 && idx < node_->getDomainSize());
|
|
||||||
ldVals_[idx] = value;
|
|
||||||
}
|
|
||||||
|
|
||||||
ParamSet& getPiValues (void)
|
|
||||||
{
|
|
||||||
return piVals_;
|
|
||||||
}
|
|
||||||
|
|
||||||
ParamSet& getLambdaValues (void)
|
|
||||||
{
|
|
||||||
return ldVals_;
|
|
||||||
}
|
|
||||||
|
|
||||||
double getBeliefChange (void)
|
|
||||||
{
|
|
||||||
double change = 0.0;
|
|
||||||
if (oldBeliefs_.size() == 0) {
|
|
||||||
oldBeliefs_ = getBeliefs();
|
|
||||||
change = MAX_CHANGE_;
|
|
||||||
} else {
|
|
||||||
ParamSet currentBeliefs = getBeliefs();
|
|
||||||
for (int xi = 0; xi < node_->getDomainSize(); xi++) {
|
|
||||||
change += abs (currentBeliefs[xi] - oldBeliefs_[xi]);
|
|
||||||
}
|
}
|
||||||
oldBeliefs_ = currentBeliefs;
|
msgSended_ = false;
|
||||||
|
residual_ = 0.0;
|
||||||
}
|
}
|
||||||
return change;
|
|
||||||
}
|
//void setMessage (ParamSet msg)
|
||||||
|
//{
|
||||||
|
// Util::normalize (msg);
|
||||||
|
// residual_ = Util::getMaxNorm (currMsg_, msg);
|
||||||
|
// currMsg_ = msg;
|
||||||
|
//}
|
||||||
|
|
||||||
bool hasReceivedChildInfluence (void) const
|
void setNextMessage (CParamSet msg)
|
||||||
{
|
{
|
||||||
// if all lambda values are equal, then neither
|
nextMsg_ = msg;
|
||||||
// this node neither its descendents have evidence,
|
Util::normalize (nextMsg_);
|
||||||
// we can use this to don't send lambda messages his parents
|
residual_ = Util::getMaxNorm (currMsg_, nextMsg_);
|
||||||
bool childInfluenced = false;
|
}
|
||||||
for (int xi = 1; xi < node_->getDomainSize(); xi++) {
|
|
||||||
if (ldVals_[xi] != ldVals_[0]) {
|
void updateMessage (void)
|
||||||
childInfluenced = true;
|
{
|
||||||
break;
|
currMsg_ = nextMsg_;
|
||||||
|
if (DL >= 3) {
|
||||||
|
cout << "updating " << toString() << endl;
|
||||||
}
|
}
|
||||||
|
msgSended_ = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void updateResidual (void)
|
||||||
|
{
|
||||||
|
residual_ = Util::getMaxNorm (currMsg_, nextMsg_);
|
||||||
}
|
}
|
||||||
return childInfluenced;
|
|
||||||
}
|
|
||||||
|
|
||||||
BayesNode* node_;
|
string toString (void) const
|
||||||
ParamSet piVals_; // pi values
|
{
|
||||||
ParamSet ldVals_; // lambda values
|
stringstream ss;
|
||||||
ParamSet oldBeliefs_;
|
if (type_ == PI_MSG) {
|
||||||
|
ss << PI;
|
||||||
|
} else if (type_ == LAMBDA_MSG) {
|
||||||
|
ss << LD;
|
||||||
|
} else {
|
||||||
|
abort();
|
||||||
|
}
|
||||||
|
ss << "(" << source_->getLabel();
|
||||||
|
ss << " --> " << destin_->getLabel() << ")" ;
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
BayesNode* getSource (void) const { return source_; }
|
||||||
|
BayesNode* getDestination (void) const { return destin_; }
|
||||||
|
MessageType getMessageType (void) const { return type_; }
|
||||||
|
CParamSet getMessage (void) const { return currMsg_; }
|
||||||
|
bool messageWasSended (void) const { return msgSended_; }
|
||||||
|
double getResidual (void) const { return residual_; }
|
||||||
|
void clearResidual (void) { residual_ = 0.0; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
BayesNode* source_;
|
||||||
|
BayesNode* destin_;
|
||||||
|
MessageType type_;
|
||||||
|
ParamSet currMsg_;
|
||||||
|
ParamSet nextMsg_;
|
||||||
|
bool msgSended_;
|
||||||
|
double residual_;
|
||||||
};
|
};
|
||||||
*/
|
|
||||||
|
|
||||||
|
|
||||||
bool compareResidual (const Edge&, const Edge&);
|
|
||||||
|
|
||||||
class BPSolver : public Solver
|
class BPSolver : public Solver
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
@ -261,190 +108,85 @@ class BPSolver : public Solver
|
|||||||
~BPSolver (void);
|
~BPSolver (void);
|
||||||
|
|
||||||
void runSolver (void);
|
void runSolver (void);
|
||||||
ParamSet getPosterioriOf (const Variable* var) const;
|
ParamSet getPosterioriOf (Vid) const;
|
||||||
ParamSet getJointDistribution (const NodeSet&) const;
|
ParamSet getJointDistributionOf (const VidSet&);
|
||||||
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
DISALLOW_COPY_AND_ASSIGN (BPSolver);
|
DISALLOW_COPY_AND_ASSIGN (BPSolver);
|
||||||
|
|
||||||
void initializeSolver (void);
|
void initializeSolver (void);
|
||||||
void incorporateEvidence (BayesNode*);
|
|
||||||
void runPolyTreeSolver (void);
|
void runPolyTreeSolver (void);
|
||||||
void polyTreePiMessage (BayesNode*, BayesNode*);
|
void runLoopySolver (void);
|
||||||
void polyTreeLambdaMessage (BayesNode*, BayesNode*);
|
|
||||||
void runGenericSolver (void);
|
|
||||||
void maxResidualSchedule (void);
|
void maxResidualSchedule (void);
|
||||||
bool converged (void) const;
|
bool converged (void) const;
|
||||||
void updatePiValues (BayesNode*);
|
void updatePiValues (BayesNode*);
|
||||||
void updateLambdaValues (BayesNode*);
|
void updateLambdaValues (BayesNode*);
|
||||||
void calculateNextPiMessage (BayesNode*, BayesNode*);
|
ParamSet calculateNextLambdaMessage (Edge* edge);
|
||||||
void calculateNextLambdaMessage (BayesNode*, BayesNode*);
|
ParamSet calculateNextPiMessage (Edge* edge);
|
||||||
|
ParamSet getJointByJunctionNode (const VidSet&) const;
|
||||||
|
ParamSet getJointByChainRule (const VidSet&) const;
|
||||||
void printMessageStatusOf (const BayesNode*) const;
|
void printMessageStatusOf (const BayesNode*) const;
|
||||||
void printAllMessageStatus (void) const;
|
void printAllMessageStatus (void) const;
|
||||||
// inlines
|
|
||||||
void updatePiMessage (BayesNode*, BayesNode*);
|
ParamSet getMessage (Edge* edge)
|
||||||
void updateLambdaMessage (BayesNode*, BayesNode*);
|
{
|
||||||
void calculateNextMessage (const Edge&);
|
if (DL >= 3) {
|
||||||
void updateMessage (const Edge&);
|
cout << " calculating " << edge->toString() << endl;
|
||||||
void updateValues (const Edge&);
|
}
|
||||||
double getResidual (const Edge&) const;
|
if (edge->getMessageType() == PI_MSG) {
|
||||||
void updateResidual (const Edge&);
|
return calculateNextPiMessage (edge);
|
||||||
void clearResidual (const Edge&);
|
} else if (edge->getMessageType() == LAMBDA_MSG) {
|
||||||
BpNode* M (const BayesNode*) const;
|
return calculateNextLambdaMessage (edge);
|
||||||
friend bool compareResidual (const Edge&, const Edge&);
|
} else {
|
||||||
|
abort();
|
||||||
|
}
|
||||||
|
return ParamSet();
|
||||||
|
}
|
||||||
|
|
||||||
|
void updateValues (Edge* edge)
|
||||||
|
{
|
||||||
|
if (!edge->getDestination()->hasEvidence()) {
|
||||||
|
if (edge->getMessageType() == PI_MSG) {
|
||||||
|
updatePiValues (edge->getDestination());
|
||||||
|
} else if (edge->getMessageType() == LAMBDA_MSG) {
|
||||||
|
updateLambdaValues (edge->getDestination());
|
||||||
|
} else {
|
||||||
|
abort();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
BPNodeInfo* M (const BayesNode* node) const
|
||||||
|
{
|
||||||
|
assert (node);
|
||||||
|
assert (node == bn_->getBayesNode (node->getVarId()));
|
||||||
|
assert (node->getIndex() < nodesI_.size());
|
||||||
|
return nodesI_[node->getIndex()];
|
||||||
|
}
|
||||||
|
|
||||||
const BayesNet* bn_;
|
const BayesNet* bn_;
|
||||||
vector<BpNode*> msgs_;
|
vector<BPNodeInfo*> nodesI_;
|
||||||
Schedule schedule_;
|
unsigned nIter_;
|
||||||
int nIter_;
|
vector<Edge*> links_;
|
||||||
int maxIter_;
|
bool useAlwaysLoopySolver_;
|
||||||
double accuracy_;
|
JointCalcType jointCalcType_;
|
||||||
vector<Edge> updateOrder_;
|
|
||||||
bool forceGenericSolver_;
|
|
||||||
|
|
||||||
struct compare
|
struct compare
|
||||||
{
|
{
|
||||||
inline bool operator() (const Edge& e1, const Edge& e2)
|
inline bool operator() (const Edge* e1, const Edge* e2)
|
||||||
{
|
{
|
||||||
return compareResidual (e1, e2);
|
return e1->getResidual() > e2->getResidual();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
typedef multiset<Edge, compare> SortedOrder;
|
typedef multiset<Edge*, compare> SortedOrder;
|
||||||
SortedOrder sortedOrder_;
|
SortedOrder sortedOrder_;
|
||||||
|
|
||||||
typedef unordered_map<string, SortedOrder::iterator> EdgeMap;
|
typedef map<Edge*, SortedOrder::iterator> EdgeMap;
|
||||||
EdgeMap edgeMap_;
|
EdgeMap edgeMap_;
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#endif //BP_BP_SOLVER_H
|
||||||
|
|
||||||
inline void
|
|
||||||
BPSolver::updatePiMessage (BayesNode* source, BayesNode* destination)
|
|
||||||
{
|
|
||||||
M(source)->updatePiMessage(destination);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
inline void
|
|
||||||
BPSolver::updateLambdaMessage (BayesNode* source, BayesNode* destination)
|
|
||||||
{
|
|
||||||
M(destination)->updateLambdaMessage(source);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
inline void
|
|
||||||
BPSolver::calculateNextMessage (const Edge& e)
|
|
||||||
{
|
|
||||||
if (DL >= 1) {
|
|
||||||
cout << "calculating " << e.toString() << endl;
|
|
||||||
}
|
|
||||||
if (e.type == PI_MSG) {
|
|
||||||
calculateNextPiMessage (e.source, e.destination);
|
|
||||||
} else {
|
|
||||||
calculateNextLambdaMessage (e.source, e.destination);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
inline void
|
|
||||||
BPSolver::updateMessage (const Edge& e)
|
|
||||||
{
|
|
||||||
if (DL >= 1) {
|
|
||||||
cout << "updating " << e.toString() << endl;
|
|
||||||
}
|
|
||||||
if (e.type == PI_MSG) {
|
|
||||||
M(e.source)->updatePiMessage(e.destination);
|
|
||||||
} else {
|
|
||||||
M(e.destination)->updateLambdaMessage(e.source);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
inline void
|
|
||||||
BPSolver::updateValues (const Edge& e)
|
|
||||||
{
|
|
||||||
if (!e.destination->hasEvidence()) {
|
|
||||||
if (e.type == PI_MSG) {
|
|
||||||
updatePiValues (e.destination);
|
|
||||||
} else {
|
|
||||||
updateLambdaValues (e.destination);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
inline double
|
|
||||||
BPSolver::getResidual (const Edge& e) const
|
|
||||||
{
|
|
||||||
if (e.type == PI_MSG) {
|
|
||||||
return M(e.source)->getPiResidual(e.destination);
|
|
||||||
} else {
|
|
||||||
return M(e.destination)->getLambdaResidual(e.source);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
inline void
|
|
||||||
BPSolver::updateResidual (const Edge& e)
|
|
||||||
{
|
|
||||||
if (e.type == PI_MSG) {
|
|
||||||
M(e.source)->updatePiResidual(e.destination);
|
|
||||||
} else {
|
|
||||||
M(e.destination)->updateLambdaResidual(e.source);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
inline void
|
|
||||||
BPSolver::clearResidual (const Edge& e)
|
|
||||||
{
|
|
||||||
if (e.type == PI_MSG) {
|
|
||||||
M(e.source)->clearPiResidual(e.destination);
|
|
||||||
} else {
|
|
||||||
M(e.destination)->clearLambdaResidual(e.source);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
inline bool
|
|
||||||
compareResidual (const Edge& e1, const Edge& e2)
|
|
||||||
{
|
|
||||||
double residual1;
|
|
||||||
double residual2;
|
|
||||||
if (e1.type == PI_MSG) {
|
|
||||||
residual1 = Edge::klass->M(e1.source)->getPiResidual(e1.destination);
|
|
||||||
} else {
|
|
||||||
residual1 = Edge::klass->M(e1.destination)->getLambdaResidual(e1.source);
|
|
||||||
}
|
|
||||||
if (e2.type == PI_MSG) {
|
|
||||||
residual2 = Edge::klass->M(e2.source)->getPiResidual(e2.destination);
|
|
||||||
} else {
|
|
||||||
residual2 = Edge::klass->M(e2.destination)->getLambdaResidual(e2.source);
|
|
||||||
}
|
|
||||||
return residual1 > residual2;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
inline BpNode*
|
|
||||||
BPSolver::M (const BayesNode* node) const
|
|
||||||
{
|
|
||||||
assert (node);
|
|
||||||
assert (node == bn_->getNode (node->getVarId()));
|
|
||||||
assert (node->getIndex() < msgs_.size());
|
|
||||||
return msgs_[node->getIndex()];
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
#endif
|
|
||||||
|
|
||||||
|
@ -1,30 +1,24 @@
|
|||||||
|
#include <cstdlib>
|
||||||
|
#include <cassert>
|
||||||
|
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <iomanip>
|
#include <iomanip>
|
||||||
#include <cassert>
|
|
||||||
#include <cstdlib>
|
|
||||||
#include <map>
|
|
||||||
|
|
||||||
#include "xmlParser/xmlParser.h"
|
#include "xmlParser/xmlParser.h"
|
||||||
|
|
||||||
#include "BayesNet.h"
|
#include "BayesNet.h"
|
||||||
|
|
||||||
|
|
||||||
BayesNet::BayesNet (void)
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
BayesNet::BayesNet (const char* fileName)
|
BayesNet::BayesNet (const char* fileName)
|
||||||
{
|
{
|
||||||
map<string, Domain> domains;
|
map<string, Domain> domains;
|
||||||
XMLNode xMainNode = XMLNode::openFileHelper (fileName, "BIF");
|
XMLNode xMainNode = XMLNode::openFileHelper (fileName, "BIF");
|
||||||
// only the first network is parsed, others are ignored
|
// only the first network is parsed, others are ignored
|
||||||
XMLNode xNode = xMainNode.getChildNode ("NETWORK");
|
XMLNode xNode = xMainNode.getChildNode ("NETWORK");
|
||||||
int nVars = xNode.nChildNode ("VARIABLE");
|
unsigned nVars = xNode.nChildNode ("VARIABLE");
|
||||||
for (int i = 0; i < nVars; i++) {
|
for (unsigned i = 0; i < nVars; i++) {
|
||||||
XMLNode var = xNode.getChildNode ("VARIABLE", i);
|
XMLNode var = xNode.getChildNode ("VARIABLE", i);
|
||||||
string type = var.getAttribute ("TYPE");
|
string type = var.getAttribute ("TYPE");
|
||||||
if (type != "nature") {
|
if (type != "nature") {
|
||||||
@ -32,9 +26,9 @@ BayesNet::BayesNet (const char* fileName)
|
|||||||
abort();
|
abort();
|
||||||
}
|
}
|
||||||
Domain domain;
|
Domain domain;
|
||||||
string label = var.getChildNode("NAME").getText();
|
string varLabel = var.getChildNode("NAME").getText();
|
||||||
int domainSize = var.nChildNode ("OUTCOME");
|
unsigned dsize = var.nChildNode ("OUTCOME");
|
||||||
for (int j = 0; j < domainSize; j++) {
|
for (unsigned j = 0; j < dsize; j++) {
|
||||||
if (var.getChildNode("OUTCOME", j).getText() == 0) {
|
if (var.getChildNode("OUTCOME", j).getText() == 0) {
|
||||||
stringstream ss;
|
stringstream ss;
|
||||||
ss << j + 1;
|
ss << j + 1;
|
||||||
@ -43,37 +37,37 @@ BayesNet::BayesNet (const char* fileName)
|
|||||||
domain.push_back (var.getChildNode("OUTCOME", j).getText());
|
domain.push_back (var.getChildNode("OUTCOME", j).getText());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
domains.insert (make_pair (label, domain));
|
domains.insert (make_pair (varLabel, domain));
|
||||||
}
|
}
|
||||||
|
|
||||||
int nDefs = xNode.nChildNode ("DEFINITION");
|
unsigned nDefs = xNode.nChildNode ("DEFINITION");
|
||||||
if (nVars != nDefs) {
|
if (nVars != nDefs) {
|
||||||
cerr << "error: different number of variables and definitions";
|
cerr << "error: different number of variables and definitions" << endl;
|
||||||
cerr << endl;
|
abort();
|
||||||
}
|
}
|
||||||
|
|
||||||
queue<int> indexes;
|
queue<unsigned> indexes;
|
||||||
for (int i = 0; i < nDefs; i++) {
|
for (unsigned i = 0; i < nDefs; i++) {
|
||||||
indexes.push (i);
|
indexes.push (i);
|
||||||
}
|
}
|
||||||
|
|
||||||
while (!indexes.empty()) {
|
while (!indexes.empty()) {
|
||||||
int index = indexes.front();
|
unsigned index = indexes.front();
|
||||||
indexes.pop();
|
indexes.pop();
|
||||||
XMLNode def = xNode.getChildNode ("DEFINITION", index);
|
XMLNode def = xNode.getChildNode ("DEFINITION", index);
|
||||||
string label = def.getChildNode("FOR").getText();
|
string varLabel = def.getChildNode("FOR").getText();
|
||||||
map<string, Domain>::const_iterator iter;
|
map<string, Domain>::const_iterator iter;
|
||||||
iter = domains.find (label);
|
iter = domains.find (varLabel);
|
||||||
if (iter == domains.end()) {
|
if (iter == domains.end()) {
|
||||||
cerr << "error: unknow variable `" << label << "'" << endl;
|
cerr << "error: unknow variable `" << varLabel << "'" << endl;
|
||||||
abort();
|
abort();
|
||||||
}
|
}
|
||||||
bool processItLatter = false;
|
bool processItLatter = false;
|
||||||
NodeSet parents;
|
BnNodeSet parents;
|
||||||
int nParams = iter->second.size();
|
unsigned nParams = iter->second.size();
|
||||||
for (int j = 0; j < def.nChildNode ("GIVEN"); j++) {
|
for (int j = 0; j < def.nChildNode ("GIVEN"); j++) {
|
||||||
string parentLabel = def.getChildNode("GIVEN", j).getText();
|
string parentLabel = def.getChildNode("GIVEN", j).getText();
|
||||||
BayesNode* parentNode = getNode (parentLabel);
|
BayesNode* parentNode = getBayesNode (parentLabel);
|
||||||
if (parentNode) {
|
if (parentNode) {
|
||||||
nParams *= parentNode->getDomainSize();
|
nParams *= parentNode->getDomainSize();
|
||||||
parents.push_back (parentNode);
|
parents.push_back (parentNode);
|
||||||
@ -95,7 +89,7 @@ BayesNet::BayesNet (const char* fileName)
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (!processItLatter) {
|
if (!processItLatter) {
|
||||||
int count = 0;
|
unsigned count = 0;
|
||||||
ParamSet params (nParams);
|
ParamSet params (nParams);
|
||||||
stringstream s (def.getChildNode("TABLE").getText());
|
stringstream s (def.getChildNode("TABLE").getText());
|
||||||
while (!s.eof() && count < nParams) {
|
while (!s.eof() && count < nParams) {
|
||||||
@ -104,11 +98,11 @@ BayesNet::BayesNet (const char* fileName)
|
|||||||
}
|
}
|
||||||
if (count != nParams) {
|
if (count != nParams) {
|
||||||
cerr << "error: invalid number of parameters " ;
|
cerr << "error: invalid number of parameters " ;
|
||||||
cerr << "for variable `" << label << "'" << endl;
|
cerr << "for variable `" << varLabel << "'" << endl;
|
||||||
abort();
|
abort();
|
||||||
}
|
}
|
||||||
params = reorderParameters (params, iter->second.size());
|
params = reorderParameters (params, iter->second.size());
|
||||||
addNode (label, iter->second, parents, params);
|
addNode (varLabel, iter->second, parents, params);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
setIndexes();
|
setIndexes();
|
||||||
@ -118,7 +112,6 @@ BayesNet::BayesNet (const char* fileName)
|
|||||||
|
|
||||||
BayesNet::~BayesNet (void)
|
BayesNet::~BayesNet (void)
|
||||||
{
|
{
|
||||||
Statistics::writeStats();
|
|
||||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||||
delete nodes_[i];
|
delete nodes_[i];
|
||||||
}
|
}
|
||||||
@ -127,25 +120,25 @@ BayesNet::~BayesNet (void)
|
|||||||
|
|
||||||
|
|
||||||
BayesNode*
|
BayesNode*
|
||||||
BayesNet::addNode (unsigned varId)
|
BayesNet::addNode (Vid vid)
|
||||||
{
|
{
|
||||||
indexMap_.insert (make_pair (varId, nodes_.size()));
|
indexMap_.insert (make_pair (vid, nodes_.size()));
|
||||||
nodes_.push_back (new BayesNode (varId));
|
nodes_.push_back (new BayesNode (vid));
|
||||||
return nodes_.back();
|
return nodes_.back();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
BayesNode*
|
BayesNode*
|
||||||
BayesNet::addNode (unsigned varId,
|
BayesNet::addNode (Vid vid,
|
||||||
unsigned dsize,
|
unsigned dsize,
|
||||||
int evidence,
|
int evidence,
|
||||||
NodeSet& parents,
|
BnNodeSet& parents,
|
||||||
Distribution* dist)
|
Distribution* dist)
|
||||||
{
|
{
|
||||||
indexMap_.insert (make_pair (varId, nodes_.size()));
|
indexMap_.insert (make_pair (vid, nodes_.size()));
|
||||||
nodes_.push_back (new BayesNode (
|
nodes_.push_back (new BayesNode (
|
||||||
varId, dsize, evidence, parents, dist));
|
vid, dsize, evidence, parents, dist));
|
||||||
return nodes_.back();
|
return nodes_.back();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -154,7 +147,7 @@ BayesNet::addNode (unsigned varId,
|
|||||||
BayesNode*
|
BayesNode*
|
||||||
BayesNet::addNode (string label,
|
BayesNet::addNode (string label,
|
||||||
Domain domain,
|
Domain domain,
|
||||||
NodeSet& parents,
|
BnNodeSet& parents,
|
||||||
ParamSet& params)
|
ParamSet& params)
|
||||||
{
|
{
|
||||||
indexMap_.insert (make_pair (nodes_.size(), nodes_.size()));
|
indexMap_.insert (make_pair (nodes_.size(), nodes_.size()));
|
||||||
@ -169,9 +162,9 @@ BayesNet::addNode (string label,
|
|||||||
|
|
||||||
|
|
||||||
BayesNode*
|
BayesNode*
|
||||||
BayesNet::getNode (unsigned varId) const
|
BayesNet::getBayesNode (Vid vid) const
|
||||||
{
|
{
|
||||||
IndexMap::const_iterator it = indexMap_.find(varId);
|
IndexMap::const_iterator it = indexMap_.find (vid);
|
||||||
if (it == indexMap_.end()) {
|
if (it == indexMap_.end()) {
|
||||||
return 0;
|
return 0;
|
||||||
} else {
|
} else {
|
||||||
@ -182,7 +175,7 @@ BayesNet::getNode (unsigned varId) const
|
|||||||
|
|
||||||
|
|
||||||
BayesNode*
|
BayesNode*
|
||||||
BayesNet::getNode (string label) const
|
BayesNet::getBayesNode (string label) const
|
||||||
{
|
{
|
||||||
BayesNode* node = 0;
|
BayesNode* node = 0;
|
||||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||||
@ -196,6 +189,15 @@ BayesNet::getNode (string label) const
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Variable*
|
||||||
|
BayesNet::getVariable (Vid vid) const
|
||||||
|
{
|
||||||
|
return getBayesNode (vid);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
BayesNet::addDistribution (Distribution* dist)
|
BayesNet::addDistribution (Distribution* dist)
|
||||||
{
|
{
|
||||||
@ -219,15 +221,15 @@ BayesNet::getDistribution (unsigned distId) const
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
const NodeSet&
|
const BnNodeSet&
|
||||||
BayesNet::getNodes (void) const
|
BayesNet::getBayesNodes (void) const
|
||||||
{
|
{
|
||||||
return nodes_;
|
return nodes_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
int
|
unsigned
|
||||||
BayesNet::getNumberOfNodes (void) const
|
BayesNet::getNumberOfNodes (void) const
|
||||||
{
|
{
|
||||||
return nodes_.size();
|
return nodes_.size();
|
||||||
@ -235,10 +237,10 @@ BayesNet::getNumberOfNodes (void) const
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
NodeSet
|
BnNodeSet
|
||||||
BayesNet::getRootNodes (void) const
|
BayesNet::getRootNodes (void) const
|
||||||
{
|
{
|
||||||
NodeSet roots;
|
BnNodeSet roots;
|
||||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||||
if (nodes_[i]->isRoot()) {
|
if (nodes_[i]->isRoot()) {
|
||||||
roots.push_back (nodes_[i]);
|
roots.push_back (nodes_[i]);
|
||||||
@ -249,10 +251,10 @@ BayesNet::getRootNodes (void) const
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
NodeSet
|
BnNodeSet
|
||||||
BayesNet::getLeafNodes (void) const
|
BayesNet::getLeafNodes (void) const
|
||||||
{
|
{
|
||||||
NodeSet leafs;
|
BnNodeSet leafs;
|
||||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||||
if (nodes_[i]->isLeaf()) {
|
if (nodes_[i]->isLeaf()) {
|
||||||
leafs.push_back (nodes_[i]);
|
leafs.push_back (nodes_[i]);
|
||||||
@ -276,30 +278,32 @@ BayesNet::getVariables (void) const
|
|||||||
|
|
||||||
|
|
||||||
BayesNet*
|
BayesNet*
|
||||||
BayesNet::pruneNetwork (BayesNode* queryNode) const
|
BayesNet::getMinimalRequesiteNetwork (Vid vid) const
|
||||||
{
|
{
|
||||||
NodeSet queryNodes;
|
return getMinimalRequesiteNetwork (VidSet() = {vid});
|
||||||
queryNodes.push_back (queryNode);
|
|
||||||
return pruneNetwork (queryNodes);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
BayesNet*
|
BayesNet*
|
||||||
BayesNet::pruneNetwork (const NodeSet& interestedVars) const
|
BayesNet::getMinimalRequesiteNetwork (const VidSet& queryVids) const
|
||||||
{
|
{
|
||||||
/*
|
BnNodeSet queryVars;
|
||||||
cout << "interested vars: " ;
|
for (unsigned i = 0; i < queryVids.size(); i++) {
|
||||||
for (unsigned i = 0; i < interestedVars.size(); i++) {
|
assert (getBayesNode (queryVids[i]));
|
||||||
cout << interestedVars[i]->getLabel() << " " ;
|
queryVars.push_back (getBayesNode (queryVids[i]));
|
||||||
}
|
}
|
||||||
cout << endl;
|
// cout << "query vars: " ;
|
||||||
*/
|
// for (unsigned i = 0; i < queryVars.size(); i++) {
|
||||||
|
// cout << queryVars[i]->getLabel() << " " ;
|
||||||
|
// }
|
||||||
|
// cout << endl;
|
||||||
|
|
||||||
vector<StateInfo*> states (nodes_.size(), 0);
|
vector<StateInfo*> states (nodes_.size(), 0);
|
||||||
|
|
||||||
Scheduling scheduling;
|
Scheduling scheduling;
|
||||||
for (NodeSet::const_iterator it = interestedVars.begin();
|
for (BnNodeSet::const_iterator it = queryVars.begin();
|
||||||
it != interestedVars.end(); it++) {
|
it != queryVars.end(); it++) {
|
||||||
scheduling.push (ScheduleInfo (*it, false, true));
|
scheduling.push (ScheduleInfo (*it, false, true));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -378,18 +382,18 @@ BayesNet::constructGraph (BayesNet* bn,
|
|||||||
states[i]->markedOnTop;
|
states[i]->markedOnTop;
|
||||||
}
|
}
|
||||||
if (isRequired) {
|
if (isRequired) {
|
||||||
NodeSet parents;
|
BnNodeSet parents;
|
||||||
if (states[i]->markedOnTop) {
|
if (states[i]->markedOnTop) {
|
||||||
const NodeSet& ps = nodes_[i]->getParents();
|
const BnNodeSet& ps = nodes_[i]->getParents();
|
||||||
for (unsigned j = 0; j < ps.size(); j++) {
|
for (unsigned j = 0; j < ps.size(); j++) {
|
||||||
BayesNode* parent = bn->getNode (ps[j]->getVarId());
|
BayesNode* parent = bn->getBayesNode (ps[j]->getVarId());
|
||||||
if (!parent) {
|
if (!parent) {
|
||||||
parent = bn->addNode (ps[j]->getVarId());
|
parent = bn->addNode (ps[j]->getVarId());
|
||||||
}
|
}
|
||||||
parents.push_back (parent);
|
parents.push_back (parent);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
BayesNode* node = bn->getNode (nodes_[i]->getVarId());
|
BayesNode* node = bn->getBayesNode (nodes_[i]->getVarId());
|
||||||
if (node) {
|
if (node) {
|
||||||
node->setData (nodes_[i]->getDomainSize(),
|
node->setData (nodes_[i]->getDomainSize(),
|
||||||
nodes_[i]->getEvidence(), parents,
|
nodes_[i]->getEvidence(), parents,
|
||||||
@ -411,65 +415,6 @@ BayesNet::constructGraph (BayesNet* bn,
|
|||||||
bn->setIndexes();
|
bn->setIndexes();
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
|
||||||
void
|
|
||||||
BayesNet::constructGraph (BayesNet* bn,
|
|
||||||
const vector<StateInfo*>& states) const
|
|
||||||
{
|
|
||||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
|
||||||
if (states[i]) {
|
|
||||||
if (nodes_[i]->hasEvidence() && states[i]->visited) {
|
|
||||||
NodeSet parents;
|
|
||||||
if (states[i]->markedOnTop) {
|
|
||||||
const NodeSet& ps = nodes_[i]->getParents();
|
|
||||||
for (unsigned j = 0; j < ps.size(); j++) {
|
|
||||||
BayesNode* parent = bn->getNode (ps[j]->getVarId());
|
|
||||||
if (parent == 0) {
|
|
||||||
parent = bn->addNode (ps[j]->getVarId());
|
|
||||||
}
|
|
||||||
parents.push_back (parent);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
BayesNode* n = bn->getNode (nodes_[i]->getVarId());
|
|
||||||
if (n) {
|
|
||||||
n->setData (nodes_[i]->getDomainSize(),
|
|
||||||
nodes_[i]->getEvidence(), parents,
|
|
||||||
nodes_[i]->getDistribution());
|
|
||||||
} else {
|
|
||||||
bn->addNode (nodes_[i]->getVarId(),
|
|
||||||
nodes_[i]->getDomainSize(),
|
|
||||||
nodes_[i]->getEvidence(), parents,
|
|
||||||
nodes_[i]->getDistribution());
|
|
||||||
}
|
|
||||||
|
|
||||||
} else if (states[i]->markedOnTop) {
|
|
||||||
NodeSet parents;
|
|
||||||
const NodeSet& ps = nodes_[i]->getParents();
|
|
||||||
for (unsigned j = 0; j < ps.size(); j++) {
|
|
||||||
BayesNode* parent = bn->getNode (ps[j]->getVarId());
|
|
||||||
if (parent == 0) {
|
|
||||||
parent = bn->addNode (ps[j]->getVarId());
|
|
||||||
}
|
|
||||||
parents.push_back (parent);
|
|
||||||
}
|
|
||||||
|
|
||||||
BayesNode* n = bn->getNode (nodes_[i]->getVarId());
|
|
||||||
if (n) {
|
|
||||||
n->setData (nodes_[i]->getDomainSize(),
|
|
||||||
nodes_[i]->getEvidence(), parents,
|
|
||||||
nodes_[i]->getDistribution());
|
|
||||||
} else {
|
|
||||||
bn->addNode (nodes_[i]->getVarId(),
|
|
||||||
nodes_[i]->getDomainSize(),
|
|
||||||
nodes_[i]->getEvidence(), parents,
|
|
||||||
nodes_[i]->getDistribution());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}*/
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
bool
|
bool
|
||||||
@ -480,70 +425,6 @@ BayesNet::isSingleConnected (void) const
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
vector<DomainConf>
|
|
||||||
BayesNet::getDomainConfigurationsOf (const NodeSet& nodes)
|
|
||||||
{
|
|
||||||
int nConfs = 1;
|
|
||||||
for (unsigned i = 0; i < nodes.size(); i++) {
|
|
||||||
nConfs *= nodes[i]->getDomainSize();
|
|
||||||
}
|
|
||||||
|
|
||||||
vector<DomainConf> confs (nConfs);
|
|
||||||
for (int i = 0; i < nConfs; i++) {
|
|
||||||
confs[i].resize (nodes.size());
|
|
||||||
}
|
|
||||||
|
|
||||||
int nReps = 1;
|
|
||||||
for (int i = nodes.size() - 1; i >= 0; i--) {
|
|
||||||
int index = 0;
|
|
||||||
while (index < nConfs) {
|
|
||||||
for (int j = 0; j < nodes[i]->getDomainSize(); j++) {
|
|
||||||
for (int r = 0; r < nReps; r++) {
|
|
||||||
confs[index][i] = j;
|
|
||||||
index++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
nReps *= nodes[i]->getDomainSize();
|
|
||||||
}
|
|
||||||
|
|
||||||
return confs;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
vector<string>
|
|
||||||
BayesNet::getInstantiations (const NodeSet& parents_)
|
|
||||||
{
|
|
||||||
int nParents = parents_.size();
|
|
||||||
int rowSize = 1;
|
|
||||||
for (unsigned i = 0; i < parents_.size(); i++) {
|
|
||||||
rowSize *= parents_[i]->getDomainSize();
|
|
||||||
}
|
|
||||||
int nReps = 1;
|
|
||||||
vector<string> headers (rowSize);
|
|
||||||
for (int i = nParents - 1; i >= 0; i--) {
|
|
||||||
Domain domain = parents_[i]->getDomain();
|
|
||||||
int index = 0;
|
|
||||||
while (index < rowSize) {
|
|
||||||
for (int j = 0; j < parents_[i]->getDomainSize(); j++) {
|
|
||||||
for (int r = 0; r < nReps; r++) {
|
|
||||||
if (headers[index] != "") {
|
|
||||||
headers[index] = domain[j] + "," + headers[index];
|
|
||||||
} else {
|
|
||||||
headers[index] = domain[j];
|
|
||||||
}
|
|
||||||
index++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
nReps *= parents_[i]->getDomainSize();
|
|
||||||
}
|
|
||||||
return headers;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
BayesNet::setIndexes (void)
|
BayesNet::setIndexes (void)
|
||||||
{
|
{
|
||||||
@ -565,7 +446,7 @@ BayesNet::freeDistributions (void)
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
BayesNet::printNetwork (void) const
|
BayesNet::printGraphicalModel (void) const
|
||||||
{
|
{
|
||||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||||
cout << *nodes_[i];
|
cout << *nodes_[i];
|
||||||
@ -575,32 +456,11 @@ BayesNet::printNetwork (void) const
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
BayesNet::printNetworkToFile (const char* fileName) const
|
BayesNet::exportToDotFormat (const char* fileName,
|
||||||
|
bool showNeighborless,
|
||||||
|
CVidSet& highlightVids) const
|
||||||
{
|
{
|
||||||
string s = "../../" ;
|
ofstream out (fileName);
|
||||||
s += fileName;
|
|
||||||
ofstream out (s.c_str());
|
|
||||||
if (!out.is_open()) {
|
|
||||||
cerr << "error: cannot open file to write at " ;
|
|
||||||
cerr << "BayesNet::printToFile()" << endl;
|
|
||||||
abort();
|
|
||||||
}
|
|
||||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
|
||||||
out << *nodes_[i];
|
|
||||||
}
|
|
||||||
out.close();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
BayesNet::exportToDotFile (const char* fileName,
|
|
||||||
bool showNeighborless,
|
|
||||||
const NodeSet& highlightNodes) const
|
|
||||||
{
|
|
||||||
string s = "../../" ;
|
|
||||||
s+= fileName;
|
|
||||||
ofstream out (s.c_str());
|
|
||||||
if (!out.is_open()) {
|
if (!out.is_open()) {
|
||||||
cerr << "error: cannot open file to write at " ;
|
cerr << "error: cannot open file to write at " ;
|
||||||
cerr << "BayesNet::exportToDotFile()" << endl;
|
cerr << "BayesNet::exportToDotFile()" << endl;
|
||||||
@ -608,13 +468,6 @@ BayesNet::exportToDotFile (const char* fileName,
|
|||||||
}
|
}
|
||||||
|
|
||||||
out << "digraph \"" << fileName << "\" {" << endl;
|
out << "digraph \"" << fileName << "\" {" << endl;
|
||||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
|
||||||
const NodeSet& childs = nodes_[i]->getChilds();
|
|
||||||
for (unsigned j = 0; j < childs.size(); j++) {
|
|
||||||
out << '"' << nodes_[i]->getLabel() << '"' << " -> " ;
|
|
||||||
out << '"' << childs[j]->getLabel() << '"' << endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||||
if (showNeighborless || nodes_[i]->hasNeighbors()) {
|
if (showNeighborless || nodes_[i]->hasNeighbors()) {
|
||||||
@ -627,9 +480,24 @@ BayesNet::exportToDotFile (const char* fileName,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (unsigned i = 0; i < highlightNodes.size(); i++) {
|
for (unsigned i = 0; i < highlightVids.size(); i++) {
|
||||||
out << '"' << highlightNodes[i]->getLabel() << '"' ;
|
BayesNode* node = getBayesNode (highlightVids[i]);
|
||||||
out << " [shape=box]" << endl;
|
if (node) {
|
||||||
|
out << '"' << node->getLabel() << '"' ;
|
||||||
|
// out << " [shape=polygon, sides=6]" << endl;
|
||||||
|
out << " [shape=box3d]" << endl;
|
||||||
|
} else {
|
||||||
|
cout << "error: invalid variable id: " << highlightVids[i] << endl;
|
||||||
|
abort();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||||
|
const BnNodeSet& childs = nodes_[i]->getChilds();
|
||||||
|
for (unsigned j = 0; j < childs.size(); j++) {
|
||||||
|
out << '"' << nodes_[i]->getLabel() << '"' << " -> " ;
|
||||||
|
out << '"' << childs[j]->getLabel() << '"' << endl;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
out << "}" << endl;
|
out << "}" << endl;
|
||||||
@ -639,11 +507,9 @@ BayesNet::exportToDotFile (const char* fileName,
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
BayesNet::exportToBifFile (const char* fileName) const
|
BayesNet::exportToBifFormat (const char* fileName) const
|
||||||
{
|
{
|
||||||
string s = "../../" ;
|
ofstream out (fileName);
|
||||||
s += fileName;
|
|
||||||
ofstream out (s.c_str());
|
|
||||||
if(!out.is_open()) {
|
if(!out.is_open()) {
|
||||||
cerr << "error: cannot open file to write at " ;
|
cerr << "error: cannot open file to write at " ;
|
||||||
cerr << "BayesNet::exportToBifFile()" << endl;
|
cerr << "BayesNet::exportToBifFile()" << endl;
|
||||||
@ -666,7 +532,7 @@ BayesNet::exportToBifFile (const char* fileName) const
|
|||||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||||
out << "<DEFINITION>" << endl;
|
out << "<DEFINITION>" << endl;
|
||||||
out << "\t<FOR>" << nodes_[i]->getLabel() << "</FOR>" << endl;
|
out << "\t<FOR>" << nodes_[i]->getLabel() << "</FOR>" << endl;
|
||||||
const NodeSet& parents = nodes_[i]->getParents();
|
const BnNodeSet& parents = nodes_[i]->getParents();
|
||||||
for (unsigned j = 0; j < parents.size(); j++) {
|
for (unsigned j = 0; j < parents.size(); j++) {
|
||||||
out << "\t<GIVEN>" << parents[j]->getLabel();
|
out << "\t<GIVEN>" << parents[j]->getLabel();
|
||||||
out << "</GIVEN>" << endl;
|
out << "</GIVEN>" << endl;
|
||||||
@ -682,7 +548,7 @@ BayesNet::exportToBifFile (const char* fileName) const
|
|||||||
}
|
}
|
||||||
out << "</NETWORK>" << endl;
|
out << "</NETWORK>" << endl;
|
||||||
out << "</BIF>" << endl << endl;
|
out << "</BIF>" << endl << endl;
|
||||||
out.close();
|
out.close();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -731,8 +597,8 @@ vector<int>
|
|||||||
BayesNet::getAdjacentNodes (int v) const
|
BayesNet::getAdjacentNodes (int v) const
|
||||||
{
|
{
|
||||||
vector<int> adjacencies;
|
vector<int> adjacencies;
|
||||||
const NodeSet& parents = nodes_[v]->getParents();
|
const BnNodeSet& parents = nodes_[v]->getParents();
|
||||||
const NodeSet& childs = nodes_[v]->getChilds();
|
const BnNodeSet& childs = nodes_[v]->getChilds();
|
||||||
for (unsigned i = 0; i < parents.size(); i++) {
|
for (unsigned i = 0; i < parents.size(); i++) {
|
||||||
adjacencies.push_back (parents[i]->getIndex());
|
adjacencies.push_back (parents[i]->getIndex());
|
||||||
}
|
}
|
||||||
@ -745,8 +611,8 @@ BayesNet::getAdjacentNodes (int v) const
|
|||||||
|
|
||||||
|
|
||||||
ParamSet
|
ParamSet
|
||||||
BayesNet::reorderParameters (const ParamSet& params,
|
BayesNet::reorderParameters (CParamSet params,
|
||||||
int domainSize) const
|
unsigned domainSize) const
|
||||||
{
|
{
|
||||||
// the interchange format for bayesian networks keeps the probabilities
|
// the interchange format for bayesian networks keeps the probabilities
|
||||||
// in the following order:
|
// in the following order:
|
||||||
@ -773,15 +639,15 @@ BayesNet::reorderParameters (const ParamSet& params,
|
|||||||
|
|
||||||
|
|
||||||
ParamSet
|
ParamSet
|
||||||
BayesNet::revertParameterReorder (const ParamSet& params,
|
BayesNet::revertParameterReorder (CParamSet params,
|
||||||
int domainSize) const
|
unsigned domainSize) const
|
||||||
{
|
{
|
||||||
unsigned count = 0;
|
unsigned count = 0;
|
||||||
unsigned rowSize = params.size() / domainSize;
|
unsigned rowSize = params.size() / domainSize;
|
||||||
ParamSet reordered;
|
ParamSet reordered;
|
||||||
while (reordered.size() < params.size()) {
|
while (reordered.size() < params.size()) {
|
||||||
unsigned idx = count;
|
unsigned idx = count;
|
||||||
for (int i = 0; i < domainSize; i++) {
|
for (unsigned i = 0; i < domainSize; i++) {
|
||||||
reordered.push_back (params[idx]);
|
reordered.push_back (params[idx]);
|
||||||
idx += rowSize;
|
idx += rowSize;
|
||||||
}
|
}
|
||||||
|
@ -4,8 +4,6 @@
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
#include <queue>
|
#include <queue>
|
||||||
#include <list>
|
#include <list>
|
||||||
#include <string>
|
|
||||||
#include <unordered_map>
|
|
||||||
#include <map>
|
#include <map>
|
||||||
|
|
||||||
#include "GraphicalModel.h"
|
#include "GraphicalModel.h"
|
||||||
@ -46,42 +44,42 @@ struct StateInfo
|
|||||||
|
|
||||||
typedef vector<Distribution*> DistSet;
|
typedef vector<Distribution*> DistSet;
|
||||||
typedef queue<ScheduleInfo, list<ScheduleInfo> > Scheduling;
|
typedef queue<ScheduleInfo, list<ScheduleInfo> > Scheduling;
|
||||||
typedef unordered_map<unsigned, unsigned> Histogram;
|
typedef map<unsigned, unsigned> Histogram;
|
||||||
typedef unordered_map<unsigned, double> Times;
|
typedef map<unsigned, double> Times;
|
||||||
|
|
||||||
|
|
||||||
class BayesNet : public GraphicalModel
|
class BayesNet : public GraphicalModel
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
BayesNet (void);
|
BayesNet (void) {};
|
||||||
BayesNet (const char*);
|
BayesNet (const char*);
|
||||||
~BayesNet (void);
|
~BayesNet (void);
|
||||||
|
|
||||||
BayesNode* addNode (unsigned);
|
BayesNode* addNode (unsigned);
|
||||||
BayesNode* addNode (unsigned, unsigned, int, NodeSet&, Distribution*);
|
BayesNode* addNode (unsigned, unsigned, int, BnNodeSet&,
|
||||||
BayesNode* addNode (string, Domain, NodeSet&, ParamSet&);
|
Distribution*);
|
||||||
BayesNode* getNode (unsigned) const;
|
BayesNode* addNode (string, Domain, BnNodeSet&, ParamSet&);
|
||||||
BayesNode* getNode (string) const;
|
BayesNode* getBayesNode (Vid) const;
|
||||||
|
BayesNode* getBayesNode (string) const;
|
||||||
|
Variable* getVariable (Vid) const;
|
||||||
void addDistribution (Distribution*);
|
void addDistribution (Distribution*);
|
||||||
Distribution* getDistribution (unsigned) const;
|
Distribution* getDistribution (unsigned) const;
|
||||||
const NodeSet& getNodes (void) const;
|
const BnNodeSet& getBayesNodes (void) const;
|
||||||
int getNumberOfNodes (void) const;
|
unsigned getNumberOfNodes (void) const;
|
||||||
NodeSet getRootNodes (void) const;
|
BnNodeSet getRootNodes (void) const;
|
||||||
NodeSet getLeafNodes (void) const;
|
BnNodeSet getLeafNodes (void) const;
|
||||||
VarSet getVariables (void) const;
|
VarSet getVariables (void) const;
|
||||||
BayesNet* pruneNetwork (BayesNode*) const;
|
BayesNet* getMinimalRequesiteNetwork (Vid) const;
|
||||||
BayesNet* pruneNetwork (const NodeSet& queryNodes) const;
|
BayesNet* getMinimalRequesiteNetwork (const VidSet&) const;
|
||||||
void constructGraph (BayesNet*, const vector<StateInfo*>&) const;
|
void constructGraph (BayesNet*,
|
||||||
|
const vector<StateInfo*>&) const;
|
||||||
bool isSingleConnected (void) const;
|
bool isSingleConnected (void) const;
|
||||||
static vector<DomainConf> getDomainConfigurationsOf (const NodeSet&);
|
|
||||||
static vector<string> getInstantiations (const NodeSet& nodes);
|
|
||||||
void setIndexes (void);
|
void setIndexes (void);
|
||||||
void freeDistributions (void);
|
void freeDistributions (void);
|
||||||
void printNetwork (void) const;
|
void printGraphicalModel (void) const;
|
||||||
void printNetworkToFile (const char*) const;
|
void exportToDotFormat (const char*, bool = true,
|
||||||
void exportToDotFile (const char*, bool = true,
|
CVidSet = VidSet()) const;
|
||||||
const NodeSet& = NodeSet()) const;
|
void exportToBifFormat (const char*) const;
|
||||||
void exportToBifFile (const char*) const;
|
|
||||||
|
|
||||||
static Histogram histogram_;
|
static Histogram histogram_;
|
||||||
static Times times_;
|
static Times times_;
|
||||||
@ -93,12 +91,12 @@ class BayesNet : public GraphicalModel
|
|||||||
bool containsUndirectedCycle (int, int,
|
bool containsUndirectedCycle (int, int,
|
||||||
vector<bool>&)const;
|
vector<bool>&)const;
|
||||||
vector<int> getAdjacentNodes (int) const ;
|
vector<int> getAdjacentNodes (int) const ;
|
||||||
ParamSet reorderParameters (const ParamSet&, int) const;
|
ParamSet reorderParameters (CParamSet, unsigned) const;
|
||||||
ParamSet revertParameterReorder (const ParamSet&, int) const;
|
ParamSet revertParameterReorder (CParamSet, unsigned) const;
|
||||||
void scheduleParents (const BayesNode*, Scheduling&) const;
|
void scheduleParents (const BayesNode*, Scheduling&) const;
|
||||||
void scheduleChilds (const BayesNode*, Scheduling&) const;
|
void scheduleChilds (const BayesNode*, Scheduling&) const;
|
||||||
|
|
||||||
NodeSet nodes_;
|
BnNodeSet nodes_;
|
||||||
DistSet dists_;
|
DistSet dists_;
|
||||||
IndexMap indexMap_;
|
IndexMap indexMap_;
|
||||||
};
|
};
|
||||||
@ -108,8 +106,8 @@ class BayesNet : public GraphicalModel
|
|||||||
inline void
|
inline void
|
||||||
BayesNet::scheduleParents (const BayesNode* n, Scheduling& sch) const
|
BayesNet::scheduleParents (const BayesNode* n, Scheduling& sch) const
|
||||||
{
|
{
|
||||||
const NodeSet& ps = n->getParents();
|
const BnNodeSet& ps = n->getParents();
|
||||||
for (NodeSet::const_iterator it = ps.begin(); it != ps.end(); it++) {
|
for (BnNodeSet::const_iterator it = ps.begin(); it != ps.end(); it++) {
|
||||||
sch.push (ScheduleInfo (*it, false, true));
|
sch.push (ScheduleInfo (*it, false, true));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -119,11 +117,11 @@ BayesNet::scheduleParents (const BayesNode* n, Scheduling& sch) const
|
|||||||
inline void
|
inline void
|
||||||
BayesNet::scheduleChilds (const BayesNode* n, Scheduling& sch) const
|
BayesNet::scheduleChilds (const BayesNode* n, Scheduling& sch) const
|
||||||
{
|
{
|
||||||
const NodeSet& cs = n->getChilds();
|
const BnNodeSet& cs = n->getChilds();
|
||||||
for (NodeSet::const_iterator it = cs.begin(); it != cs.end(); it++) {
|
for (BnNodeSet::const_iterator it = cs.begin(); it != cs.end(); it++) {
|
||||||
sch.push (ScheduleInfo (*it, true, false));
|
sch.push (ScheduleInfo (*it, true, false));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif
|
#endif //BP_BAYES_NET_H
|
||||||
|
|
||||||
|
@ -1,26 +1,21 @@
|
|||||||
|
#include <cstdlib>
|
||||||
|
#include <cassert>
|
||||||
|
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <iomanip>
|
#include <iomanip>
|
||||||
#include <cassert>
|
|
||||||
#include <cstdlib>
|
|
||||||
|
|
||||||
#include "BayesNode.h"
|
#include "BayesNode.h"
|
||||||
|
|
||||||
|
|
||||||
BayesNode::BayesNode (unsigned varId) : Variable (varId)
|
BayesNode::BayesNode (Vid vid,
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
BayesNode::BayesNode (unsigned varId,
|
|
||||||
unsigned dsize,
|
unsigned dsize,
|
||||||
int evidence,
|
int evidence,
|
||||||
const NodeSet& parents,
|
const BnNodeSet& parents,
|
||||||
Distribution* dist) : Variable(varId, dsize, evidence)
|
Distribution* dist) : Variable (vid, dsize, evidence)
|
||||||
{
|
{
|
||||||
parents_ = parents;
|
parents_ = parents;
|
||||||
dist_ = dist;
|
dist_ = dist;
|
||||||
for (unsigned int i = 0; i < parents.size(); i++) {
|
for (unsigned int i = 0; i < parents.size(); i++) {
|
||||||
parents[i]->addChild (this);
|
parents[i]->addChild (this);
|
||||||
}
|
}
|
||||||
@ -28,15 +23,15 @@ BayesNode::BayesNode (unsigned varId,
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
BayesNode::BayesNode (unsigned varId,
|
BayesNode::BayesNode (Vid vid,
|
||||||
string label,
|
string label,
|
||||||
const Domain& domain,
|
const Domain& domain,
|
||||||
const NodeSet& parents,
|
const BnNodeSet& parents,
|
||||||
Distribution* dist) : Variable(varId, domain)
|
Distribution* dist) : Variable (vid, domain,
|
||||||
|
NO_EVIDENCE, label)
|
||||||
{
|
{
|
||||||
label_ = new string (label);
|
parents_ = parents;
|
||||||
parents_ = parents;
|
dist_ = dist;
|
||||||
dist_ = dist;
|
|
||||||
for (unsigned int i = 0; i < parents.size(); i++) {
|
for (unsigned int i = 0; i < parents.size(); i++) {
|
||||||
parents[i]->addChild (this);
|
parents[i]->addChild (this);
|
||||||
}
|
}
|
||||||
@ -47,11 +42,11 @@ BayesNode::BayesNode (unsigned varId,
|
|||||||
void
|
void
|
||||||
BayesNode::setData (unsigned dsize,
|
BayesNode::setData (unsigned dsize,
|
||||||
int evidence,
|
int evidence,
|
||||||
const NodeSet& parents,
|
const BnNodeSet& parents,
|
||||||
Distribution* dist)
|
Distribution* dist)
|
||||||
{
|
{
|
||||||
setDomainSize (dsize);
|
setDomainSize (dsize);
|
||||||
evidence_ = evidence;
|
setEvidence (evidence);
|
||||||
parents_ = parents;
|
parents_ = parents;
|
||||||
dist_ = dist;
|
dist_ = dist;
|
||||||
for (unsigned int i = 0; i < parents.size(); i++) {
|
for (unsigned int i = 0; i < parents.size(); i++) {
|
||||||
@ -135,19 +130,18 @@ BayesNode::getCptEntries (void)
|
|||||||
{
|
{
|
||||||
if (dist_->entries.size() == 0) {
|
if (dist_->entries.size() == 0) {
|
||||||
unsigned rowSize = getRowSize();
|
unsigned rowSize = getRowSize();
|
||||||
unsigned nParents = parents_.size();
|
vector<DConf> confs (rowSize);
|
||||||
vector<DomainConf> confs (rowSize);
|
|
||||||
|
|
||||||
for (unsigned i = 0; i < rowSize; i++) {
|
for (unsigned i = 0; i < rowSize; i++) {
|
||||||
confs[i].resize (nParents);
|
confs[i].resize (parents_.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
int nReps = 1;
|
unsigned nReps = 1;
|
||||||
for (int i = nParents - 1; i >= 0; i--) {
|
for (int i = parents_.size() - 1; i >= 0; i--) {
|
||||||
unsigned index = 0;
|
unsigned index = 0;
|
||||||
while (index < rowSize) {
|
while (index < rowSize) {
|
||||||
for (int j = 0; j < parents_[i]->getDomainSize(); j++) {
|
for (unsigned j = 0; j < parents_[i]->getDomainSize(); j++) {
|
||||||
for (int r = 0; r < nReps; r++) {
|
for (unsigned r = 0; r < nReps; r++) {
|
||||||
confs[index][i] = j;
|
confs[index][i] = j;
|
||||||
index++;
|
index++;
|
||||||
}
|
}
|
||||||
@ -184,7 +178,7 @@ BayesNode::cptEntryToString (const CptEntry& entry) const
|
|||||||
{
|
{
|
||||||
stringstream ss;
|
stringstream ss;
|
||||||
ss << "p(" ;
|
ss << "p(" ;
|
||||||
const DomainConf& conf = entry.getParentConfigurations();
|
const DConf& conf = entry.getDomainConfiguration();
|
||||||
int row = entry.getParameterIndex() / getRowSize();
|
int row = entry.getParameterIndex() / getRowSize();
|
||||||
ss << getDomain()[row];
|
ss << getDomain()[row];
|
||||||
if (parents_.size() > 0) {
|
if (parents_.size() > 0) {
|
||||||
@ -207,7 +201,7 @@ BayesNode::cptEntryToString (int row, const CptEntry& entry) const
|
|||||||
{
|
{
|
||||||
stringstream ss;
|
stringstream ss;
|
||||||
ss << "p(" ;
|
ss << "p(" ;
|
||||||
const DomainConf& conf = entry.getParentConfigurations();
|
const DConf& conf = entry.getDomainConfiguration();
|
||||||
ss << getDomain()[row];
|
ss << getDomain()[row];
|
||||||
if (parents_.size() > 0) {
|
if (parents_.size() > 0) {
|
||||||
ss << "|" ;
|
ss << "|" ;
|
||||||
@ -227,16 +221,16 @@ BayesNode::cptEntryToString (int row, const CptEntry& entry) const
|
|||||||
vector<string>
|
vector<string>
|
||||||
BayesNode::getDomainHeaders (void) const
|
BayesNode::getDomainHeaders (void) const
|
||||||
{
|
{
|
||||||
int nParents = parents_.size();
|
unsigned nParents = parents_.size();
|
||||||
int rowSize = getRowSize();
|
unsigned rowSize = getRowSize();
|
||||||
int nReps = 1;
|
unsigned nReps = 1;
|
||||||
vector<string> headers (rowSize);
|
vector<string> headers (rowSize);
|
||||||
for (int i = nParents - 1; i >= 0; i--) {
|
for (int i = nParents - 1; i >= 0; i--) {
|
||||||
Domain domain = parents_[i]->getDomain();
|
Domain domain = parents_[i]->getDomain();
|
||||||
int index = 0;
|
unsigned index = 0;
|
||||||
while (index < rowSize) {
|
while (index < rowSize) {
|
||||||
for (int j = 0; j < parents_[i]->getDomainSize(); j++) {
|
for (unsigned j = 0; j < parents_[i]->getDomainSize(); j++) {
|
||||||
for (int r = 0; r < nReps; r++) {
|
for (unsigned r = 0; r < nReps; r++) {
|
||||||
if (headers[index] != "") {
|
if (headers[index] != "") {
|
||||||
headers[index] = domain[j] + "," + headers[index];
|
headers[index] = domain[j] + "," + headers[index];
|
||||||
} else {
|
} else {
|
||||||
@ -270,7 +264,7 @@ operator << (ostream& o, const BayesNode& node)
|
|||||||
o << endl;
|
o << endl;
|
||||||
|
|
||||||
o << "Parents: " ;
|
o << "Parents: " ;
|
||||||
const NodeSet& parents = node.getParents();
|
const BnNodeSet& parents = node.getParents();
|
||||||
if (parents.size() != 0) {
|
if (parents.size() != 0) {
|
||||||
for (unsigned int i = 0; i < parents.size() - 1; i++) {
|
for (unsigned int i = 0; i < parents.size() - 1; i++) {
|
||||||
o << parents[i]->getLabel() << ", " ;
|
o << parents[i]->getLabel() << ", " ;
|
||||||
@ -280,7 +274,7 @@ operator << (ostream& o, const BayesNode& node)
|
|||||||
o << endl;
|
o << endl;
|
||||||
|
|
||||||
o << "Childs: " ;
|
o << "Childs: " ;
|
||||||
const NodeSet& childs = node.getChilds();
|
const BnNodeSet& childs = node.getChilds();
|
||||||
if (childs.size() != 0) {
|
if (childs.size() != 0) {
|
||||||
for (unsigned int i = 0; i < childs.size() - 1; i++) {
|
for (unsigned int i = 0; i < childs.size() - 1; i++) {
|
||||||
o << childs[i]->getLabel() << ", " ;
|
o << childs[i]->getLabel() << ", " ;
|
||||||
|
@ -1,9 +1,7 @@
|
|||||||
#ifndef BP_BAYESNODE_H
|
#ifndef BP_BAYES_NODE_H
|
||||||
#define BP_BAYESNODE_H
|
#define BP_BAYES_NODE_H
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <string>
|
|
||||||
#include <sstream>
|
|
||||||
|
|
||||||
#include "Variable.h"
|
#include "Variable.h"
|
||||||
#include "CptEntry.h"
|
#include "CptEntry.h"
|
||||||
@ -16,11 +14,12 @@ using namespace std;
|
|||||||
class BayesNode : public Variable
|
class BayesNode : public Variable
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
BayesNode (unsigned);
|
BayesNode (Vid vid) : Variable (vid) {}
|
||||||
BayesNode (unsigned, unsigned, int, const NodeSet&, Distribution*);
|
BayesNode (Vid, unsigned, int, const BnNodeSet&, Distribution*);
|
||||||
BayesNode (unsigned, string, const Domain&, const NodeSet&, Distribution*);
|
BayesNode (Vid, string, const Domain&, const BnNodeSet&, Distribution*);
|
||||||
|
|
||||||
void setData (unsigned, int, const NodeSet&, Distribution*);
|
void setData (unsigned, int, const BnNodeSet&,
|
||||||
|
Distribution*);
|
||||||
void addChild (BayesNode*);
|
void addChild (BayesNode*);
|
||||||
Distribution* getDistribution (void);
|
Distribution* getDistribution (void);
|
||||||
const ParamSet& getParameters (void);
|
const ParamSet& getParameters (void);
|
||||||
@ -34,11 +33,21 @@ class BayesNode : public Variable
|
|||||||
int getIndexOfParent (const BayesNode*) const;
|
int getIndexOfParent (const BayesNode*) const;
|
||||||
string cptEntryToString (const CptEntry&) const;
|
string cptEntryToString (const CptEntry&) const;
|
||||||
string cptEntryToString (int, const CptEntry&) const;
|
string cptEntryToString (int, const CptEntry&) const;
|
||||||
// inlines
|
|
||||||
const NodeSet& getParents (void) const;
|
const BnNodeSet& getParents (void) const { return parents_; }
|
||||||
const NodeSet& getChilds (void) const;
|
const BnNodeSet& getChilds (void) const { return childs_; }
|
||||||
double getProbability (int, const CptEntry& entry);
|
|
||||||
unsigned getRowSize (void) const;
|
unsigned getRowSize (void) const
|
||||||
|
{
|
||||||
|
return dist_->params.size() / getDomainSize();
|
||||||
|
}
|
||||||
|
|
||||||
|
double getProbability (int row, const CptEntry& entry)
|
||||||
|
{
|
||||||
|
int col = entry.getParameterIndex();
|
||||||
|
int idx = (row * getRowSize()) + col;
|
||||||
|
return dist_->params[idx];
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
DISALLOW_COPY_AND_ASSIGN (BayesNode);
|
DISALLOW_COPY_AND_ASSIGN (BayesNode);
|
||||||
@ -46,46 +55,12 @@ class BayesNode : public Variable
|
|||||||
Domain getDomainHeaders (void) const;
|
Domain getDomainHeaders (void) const;
|
||||||
friend ostream& operator << (ostream&, const BayesNode&);
|
friend ostream& operator << (ostream&, const BayesNode&);
|
||||||
|
|
||||||
NodeSet parents_;
|
BnNodeSet parents_;
|
||||||
NodeSet childs_;
|
BnNodeSet childs_;
|
||||||
Distribution* dist_;
|
Distribution* dist_;
|
||||||
};
|
};
|
||||||
|
|
||||||
ostream& operator << (ostream&, const BayesNode&);
|
ostream& operator << (ostream&, const BayesNode&);
|
||||||
|
|
||||||
|
#endif //BP_BAYES_NODE_H
|
||||||
|
|
||||||
inline const NodeSet&
|
|
||||||
BayesNode::getParents (void) const
|
|
||||||
{
|
|
||||||
return parents_;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
inline const NodeSet&
|
|
||||||
BayesNode::getChilds (void) const
|
|
||||||
{
|
|
||||||
return childs_;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
inline double
|
|
||||||
BayesNode::getProbability (int row, const CptEntry& entry)
|
|
||||||
{
|
|
||||||
int col = entry.getParameterIndex();
|
|
||||||
int idx = (row * getRowSize()) + col;
|
|
||||||
return dist_->params[idx];
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
inline unsigned
|
|
||||||
BayesNode::getRowSize (void) const
|
|
||||||
{
|
|
||||||
return dist_->params.size() / getDomainSize();
|
|
||||||
}
|
|
||||||
|
|
||||||
#endif
|
|
||||||
|
|
||||||
|
198
packages/CLPBN/clpbn/bp/CountingBP.cpp
Normal file
198
packages/CLPBN/clpbn/bp/CountingBP.cpp
Normal file
@ -0,0 +1,198 @@
|
|||||||
|
#include "CountingBP.h"
|
||||||
|
|
||||||
|
|
||||||
|
CountingBP::~CountingBP (void)
|
||||||
|
{
|
||||||
|
delete lfg_;
|
||||||
|
delete fg_;
|
||||||
|
for (unsigned i = 0; i < links_.size(); i++) {
|
||||||
|
delete links_[i];
|
||||||
|
}
|
||||||
|
links_.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
ParamSet
|
||||||
|
CountingBP::getPosterioriOf (Vid vid) const
|
||||||
|
{
|
||||||
|
FgVarNode* var = lfg_->getEquivalentVariable (vid);
|
||||||
|
ParamSet probs;
|
||||||
|
|
||||||
|
if (var->hasEvidence()) {
|
||||||
|
probs.resize (var->getDomainSize(), 0.0);
|
||||||
|
probs[var->getEvidence()] = 1.0;
|
||||||
|
} else {
|
||||||
|
probs.resize (var->getDomainSize(), 1.0);
|
||||||
|
CLinkSet links = varsI_[var->getIndex()]->getLinks();
|
||||||
|
for (unsigned i = 0; i < links.size(); i++) {
|
||||||
|
ParamSet msg = links[i]->getMessage();
|
||||||
|
CountingBPLink* l = static_cast<CountingBPLink*> (links[i]);
|
||||||
|
Util::pow (msg, l->getNumberOfEdges());
|
||||||
|
for (unsigned j = 0; j < msg.size(); j++) {
|
||||||
|
probs[j] *= msg[j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Util::normalize (probs);
|
||||||
|
}
|
||||||
|
return probs;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
CountingBP::initializeSolver (void)
|
||||||
|
{
|
||||||
|
lfg_ = new LiftedFG (*fg_);
|
||||||
|
unsigned nUncVars = fg_->getFgVarNodes().size();
|
||||||
|
unsigned nUncFactors = fg_->getFactors().size();
|
||||||
|
CFgVarSet vars = fg_->getFgVarNodes();
|
||||||
|
unsigned nNeighborLessVars = 0;
|
||||||
|
for (unsigned i = 0; i < vars.size(); i++) {
|
||||||
|
CFactorSet factors = vars[i]->getFactors();
|
||||||
|
if (factors.size() == 1 && factors[0]->getFgVarNodes().size() == 1) {
|
||||||
|
nNeighborLessVars ++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// cout << "UNCOMPRESSED FACTOR GRAPH" << endl;
|
||||||
|
// fg_->printGraphicalModel();
|
||||||
|
fg_->exportToDotFormat ("uncompress.dot");
|
||||||
|
|
||||||
|
FactorGraph *temp;
|
||||||
|
temp = fg_;
|
||||||
|
fg_ = lfg_->getCompressedFactorGraph();
|
||||||
|
unsigned nCompVars = fg_->getFgVarNodes().size();
|
||||||
|
unsigned nCompFactors = fg_->getFactors().size();
|
||||||
|
|
||||||
|
Statistics::updateCompressingStats (nUncVars,
|
||||||
|
nUncFactors,
|
||||||
|
nCompVars,
|
||||||
|
nCompFactors,
|
||||||
|
nNeighborLessVars);
|
||||||
|
|
||||||
|
cout << "COMPRESSED FACTOR GRAPH" << endl;
|
||||||
|
fg_->printGraphicalModel();
|
||||||
|
//fg_->exportToDotFormat ("compress.dot");
|
||||||
|
|
||||||
|
SPSolver::initializeSolver();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
CountingBP::createLinks (void)
|
||||||
|
{
|
||||||
|
const FactorClusterSet fcs = lfg_->getFactorClusters();
|
||||||
|
for (unsigned i = 0; i < fcs.size(); i++) {
|
||||||
|
const VarClusterSet vcs = fcs[i]->getVarClusters();
|
||||||
|
for (unsigned j = 0; j < vcs.size(); j++) {
|
||||||
|
unsigned c = lfg_->getGroundEdgeCount (fcs[i], vcs[j]);
|
||||||
|
links_.push_back (
|
||||||
|
new CountingBPLink (fcs[i]->getRepresentativeFactor(),
|
||||||
|
vcs[j]->getRepresentativeVariable(), c));
|
||||||
|
//cout << (links_.back())->toString() << " edge count =" << c << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
CountingBP::deleteJunction (Factor* f, FgVarNode*)
|
||||||
|
{
|
||||||
|
f->freeDistribution();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
CountingBP::maxResidualSchedule (void)
|
||||||
|
{
|
||||||
|
if (nIter_ == 1) {
|
||||||
|
for (unsigned i = 0; i < links_.size(); i++) {
|
||||||
|
links_[i]->setNextMessage (getFactor2VarMsg (links_[i]));
|
||||||
|
SortedOrder::iterator it = sortedOrder_.insert (links_[i]);
|
||||||
|
linkMap_.insert (make_pair (links_[i], it));
|
||||||
|
if (DL >= 2 && DL < 5) {
|
||||||
|
cout << "calculating " << links_[i]->toString() << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (unsigned c = 0; c < links_.size(); c++) {
|
||||||
|
if (DL >= 2) {
|
||||||
|
cout << endl << "current residuals:" << endl;
|
||||||
|
for (SortedOrder::iterator it = sortedOrder_.begin();
|
||||||
|
it != sortedOrder_.end(); it ++) {
|
||||||
|
cout << " " << setw (30) << left << (*it)->toString();
|
||||||
|
cout << "residual = " << (*it)->getResidual() << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
SortedOrder::iterator it = sortedOrder_.begin();
|
||||||
|
Link* link = *it;
|
||||||
|
if (DL >= 2) {
|
||||||
|
cout << "updating " << (*sortedOrder_.begin())->toString() << endl;
|
||||||
|
}
|
||||||
|
if (link->getResidual() < SolverOptions::accuracy) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
link->updateMessage();
|
||||||
|
link->clearResidual();
|
||||||
|
sortedOrder_.erase (it);
|
||||||
|
linkMap_.find (link)->second = sortedOrder_.insert (link);
|
||||||
|
|
||||||
|
// update the messages that depend on message source --> destin
|
||||||
|
CFactorSet factorNeighbors = link->getVariable()->getFactors();
|
||||||
|
for (unsigned i = 0; i < factorNeighbors.size(); i++) {
|
||||||
|
CLinkSet links = factorsI_[factorNeighbors[i]->getIndex()]->getLinks();
|
||||||
|
for (unsigned j = 0; j < links.size(); j++) {
|
||||||
|
if (links[j]->getVariable() != link->getVariable()) { //FIXMEFIXME
|
||||||
|
if (DL >= 2 && DL < 5) {
|
||||||
|
cout << " calculating " << links[j]->toString() << endl;
|
||||||
|
}
|
||||||
|
links[j]->setNextMessage (getFactor2VarMsg (links[j]));
|
||||||
|
LinkMap::iterator iter = linkMap_.find (links[j]);
|
||||||
|
sortedOrder_.erase (iter->second);
|
||||||
|
iter->second = sortedOrder_.insert (links[j]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
ParamSet
|
||||||
|
CountingBP::getVar2FactorMsg (const Link* link) const
|
||||||
|
{
|
||||||
|
const FgVarNode* src = link->getVariable();
|
||||||
|
const Factor* dest = link->getFactor();
|
||||||
|
ParamSet msg;
|
||||||
|
if (src->hasEvidence()) {
|
||||||
|
cout << "has evidence" << endl;
|
||||||
|
msg.resize (src->getDomainSize(), 0.0);
|
||||||
|
msg[src->getEvidence()] = link->getMessage()[src->getEvidence()];
|
||||||
|
cout << "-> " << link->getVariable()->getLabel() << " " << link->getFactor()->getLabel() << endl;
|
||||||
|
cout << "-> p2s " << Util::parametersToString (msg) << endl;
|
||||||
|
} else {
|
||||||
|
msg = link->getMessage();
|
||||||
|
}
|
||||||
|
const CountingBPLink* l = static_cast<const CountingBPLink*> (link);
|
||||||
|
Util::pow (msg, l->getNumberOfEdges() - 1);
|
||||||
|
CLinkSet links = varsI_[src->getIndex()]->getLinks();
|
||||||
|
for (unsigned i = 0; i < links.size(); i++) {
|
||||||
|
if (links[i]->getFactor() != dest) {
|
||||||
|
ParamSet msgFromFactor = links[i]->getMessage();
|
||||||
|
CountingBPLink* l = static_cast<CountingBPLink*> (links[i]);
|
||||||
|
Util::pow (msgFromFactor, l->getNumberOfEdges());
|
||||||
|
for (unsigned j = 0; j < msgFromFactor.size(); j++) {
|
||||||
|
msg[j] *= msgFromFactor[j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return msg;
|
||||||
|
}
|
||||||
|
|
45
packages/CLPBN/clpbn/bp/CountingBP.h
Normal file
45
packages/CLPBN/clpbn/bp/CountingBP.h
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
#ifndef BP_COUNTING_BP_H
|
||||||
|
#define BP_COUNTING_BP_H
|
||||||
|
|
||||||
|
#include "SPSolver.h"
|
||||||
|
#include "LiftedFG.h"
|
||||||
|
|
||||||
|
class Factor;
|
||||||
|
class FgVarNode;
|
||||||
|
|
||||||
|
class CountingBPLink : public Link
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
CountingBPLink (Factor* f, FgVarNode* v, unsigned c) : Link (f, v)
|
||||||
|
{
|
||||||
|
edgeCount_ = c;
|
||||||
|
}
|
||||||
|
|
||||||
|
unsigned getNumberOfEdges (void) const { return edgeCount_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
unsigned edgeCount_;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
class CountingBP : public SPSolver
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
CountingBP (FactorGraph& fg) : SPSolver (fg) { }
|
||||||
|
~CountingBP (void);
|
||||||
|
|
||||||
|
ParamSet getPosterioriOf (Vid) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
void initializeSolver (void);
|
||||||
|
void createLinks (void);
|
||||||
|
void deleteJunction (Factor*, FgVarNode*);
|
||||||
|
|
||||||
|
void maxResidualSchedule (void);
|
||||||
|
ParamSet getVar2FactorMsg (const Link*) const;
|
||||||
|
|
||||||
|
LiftedFG* lfg_;
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif // BP_COUNTING_BP_H
|
||||||
|
|
@ -1,5 +1,5 @@
|
|||||||
#ifndef BP_CPTENTRY_H
|
#ifndef BP_CPT_ENTRY_H
|
||||||
#define BP_CPTENTRY_H
|
#define BP_CPT_ENTRY_H
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
@ -10,62 +10,34 @@ using namespace std;
|
|||||||
class CptEntry
|
class CptEntry
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
CptEntry (unsigned, const vector<unsigned>&);
|
CptEntry (unsigned index, const DConf& conf)
|
||||||
|
{
|
||||||
|
index_ = index;
|
||||||
|
conf_ = conf;
|
||||||
|
}
|
||||||
|
|
||||||
unsigned getParameterIndex (void) const;
|
unsigned getParameterIndex (void) const { return index_; }
|
||||||
const vector<unsigned>& getParentConfigurations (void) const;
|
const DConf& getDomainConfiguration (void) const { return conf_; }
|
||||||
bool matchConstraints (const DomainConstr&) const;
|
|
||||||
bool matchConstraints (const vector<DomainConstr>&) const;
|
bool matchConstraints (const DConstraint& constr) const
|
||||||
|
{
|
||||||
|
return conf_[constr.first] == constr.second;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool matchConstraints (const vector<DConstraint>& constrs) const
|
||||||
|
{
|
||||||
|
for (unsigned j = 0; j < constrs.size(); j++) {
|
||||||
|
if (conf_[constrs[j].first] != constrs[j].second) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
unsigned index_;
|
unsigned index_;
|
||||||
vector<unsigned> confs_;
|
DConf conf_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#endif //BP_CPT_ENTRY_H
|
||||||
|
|
||||||
|
|
||||||
inline
|
|
||||||
CptEntry::CptEntry (unsigned index, const vector<unsigned>& confs)
|
|
||||||
{
|
|
||||||
index_ = index;
|
|
||||||
confs_ = confs;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
inline unsigned
|
|
||||||
CptEntry::getParameterIndex (void) const
|
|
||||||
{
|
|
||||||
return index_;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
inline const vector<unsigned>&
|
|
||||||
CptEntry::getParentConfigurations (void) const
|
|
||||||
{
|
|
||||||
return confs_;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
inline bool
|
|
||||||
CptEntry::matchConstraints (const DomainConstr& constr) const
|
|
||||||
{
|
|
||||||
return confs_[constr.first] == constr.second;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
inline bool
|
|
||||||
CptEntry::matchConstraints (const vector<DomainConstr>& constrs) const
|
|
||||||
{
|
|
||||||
for (unsigned j = 0; j < constrs.size(); j++) {
|
|
||||||
if (confs_[constrs[j].first] != constrs[j].second) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
#endif
|
|
||||||
|
@ -2,8 +2,8 @@
|
|||||||
#define BP_DISTRIBUTION_H
|
#define BP_DISTRIBUTION_H
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <string>
|
|
||||||
|
|
||||||
|
#include "CptEntry.h"
|
||||||
#include "Shared.h"
|
#include "Shared.h"
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
@ -11,16 +11,18 @@ using namespace std;
|
|||||||
struct Distribution
|
struct Distribution
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
Distribution (unsigned id)
|
Distribution (unsigned id, bool shared = false)
|
||||||
{
|
{
|
||||||
this->id = id;
|
this->id = id;
|
||||||
this->params = params;
|
this->params = params;
|
||||||
|
this->shared = shared;
|
||||||
}
|
}
|
||||||
|
|
||||||
Distribution (const ParamSet& params)
|
Distribution (const ParamSet& params, bool shared = false)
|
||||||
{
|
{
|
||||||
this->id = -1;
|
this->id = -1;
|
||||||
this->params = params;
|
this->params = params;
|
||||||
|
this->shared = shared;
|
||||||
}
|
}
|
||||||
|
|
||||||
void updateParameters (const ParamSet& params)
|
void updateParameters (const ParamSet& params)
|
||||||
@ -31,10 +33,11 @@ struct Distribution
|
|||||||
unsigned id;
|
unsigned id;
|
||||||
ParamSet params;
|
ParamSet params;
|
||||||
vector<CptEntry> entries;
|
vector<CptEntry> entries;
|
||||||
|
bool shared;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
DISALLOW_COPY_AND_ASSIGN (Distribution);
|
DISALLOW_COPY_AND_ASSIGN (Distribution);
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif
|
#endif //BP_DISTRIBUTION_H
|
||||||
|
|
||||||
|
@ -1,37 +1,37 @@
|
|||||||
#include <iostream>
|
|
||||||
#include <sstream>
|
|
||||||
#include <cstdlib>
|
#include <cstdlib>
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
|
|
||||||
|
#include <iostream>
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
#include "Factor.h"
|
#include "Factor.h"
|
||||||
#include "FgVarNode.h"
|
#include "FgVarNode.h"
|
||||||
|
|
||||||
|
|
||||||
int Factor::indexCount_ = 0;
|
Factor::Factor (const Factor& g)
|
||||||
|
{
|
||||||
Factor::Factor (FgVarNode* var) {
|
copyFactor (g);
|
||||||
vs_.push_back (var);
|
|
||||||
int nParams = var->getDomainSize();
|
|
||||||
// create a uniform distribution
|
|
||||||
double val = 1.0 / nParams;
|
|
||||||
ps_ = ParamSet (nParams, val);
|
|
||||||
id_ = indexCount_;
|
|
||||||
indexCount_ ++;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Factor::Factor (const FgVarSet& vars) {
|
Factor::Factor (FgVarNode* var)
|
||||||
vs_ = vars;
|
{
|
||||||
|
Factor (FgVarSet() = {var});
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Factor::Factor (const FgVarSet& vars)
|
||||||
|
{
|
||||||
|
vars_ = vars;
|
||||||
int nParams = 1;
|
int nParams = 1;
|
||||||
for (unsigned i = 0; i < vs_.size(); i++) {
|
for (unsigned i = 0; i < vars_.size(); i++) {
|
||||||
nParams *= vs_[i]->getDomainSize();
|
nParams *= vars_[i]->getDomainSize();
|
||||||
}
|
}
|
||||||
// create a uniform distribution
|
// create a uniform distribution
|
||||||
double val = 1.0 / nParams;
|
double val = 1.0 / nParams;
|
||||||
ps_ = ParamSet (nParams, val);
|
dist_ = new Distribution (ParamSet (nParams, val));
|
||||||
id_ = indexCount_;
|
|
||||||
indexCount_ ++;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -39,10 +39,17 @@ Factor::Factor (const FgVarSet& vars) {
|
|||||||
Factor::Factor (FgVarNode* var,
|
Factor::Factor (FgVarNode* var,
|
||||||
const ParamSet& params)
|
const ParamSet& params)
|
||||||
{
|
{
|
||||||
vs_.push_back (var);
|
vars_.push_back (var);
|
||||||
ps_ = params;
|
dist_ = new Distribution (params);
|
||||||
id_ = indexCount_;
|
}
|
||||||
indexCount_ ++;
|
|
||||||
|
|
||||||
|
|
||||||
|
Factor::Factor (FgVarSet& vars,
|
||||||
|
Distribution* dist)
|
||||||
|
{
|
||||||
|
vars_ = vars;
|
||||||
|
dist_ = dist;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -50,42 +57,8 @@ Factor::Factor (FgVarNode* var,
|
|||||||
Factor::Factor (const FgVarSet& vars,
|
Factor::Factor (const FgVarSet& vars,
|
||||||
const ParamSet& params)
|
const ParamSet& params)
|
||||||
{
|
{
|
||||||
vs_ = vars;
|
vars_ = vars;
|
||||||
ps_ = params;
|
dist_ = new Distribution (params);
|
||||||
id_ = indexCount_;
|
|
||||||
indexCount_ ++;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
const FgVarSet&
|
|
||||||
Factor::getFgVarNodes (void) const
|
|
||||||
{
|
|
||||||
return vs_;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
FgVarSet&
|
|
||||||
Factor::getFgVarNodes (void)
|
|
||||||
{
|
|
||||||
return vs_;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
const ParamSet&
|
|
||||||
Factor::getParameters (void) const
|
|
||||||
{
|
|
||||||
return ps_;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
ParamSet&
|
|
||||||
Factor::getParameters (void)
|
|
||||||
{
|
|
||||||
return ps_;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -93,75 +66,95 @@ Factor::getParameters (void)
|
|||||||
void
|
void
|
||||||
Factor::setParameters (const ParamSet& params)
|
Factor::setParameters (const ParamSet& params)
|
||||||
{
|
{
|
||||||
//cout << "ps size: " << ps_.size() << endl;
|
assert (dist_->params.size() == params.size());
|
||||||
//cout << "params size: " << params.size() << endl;
|
dist_->updateParameters (params);
|
||||||
assert (ps_.size() == params.size());
|
|
||||||
ps_ = params;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Factor&
|
void
|
||||||
Factor::operator= (const Factor& g)
|
Factor::copyFactor (const Factor& g)
|
||||||
{
|
{
|
||||||
FgVarSet vars = g.getFgVarNodes();
|
vars_ = g.getFgVarNodes();
|
||||||
ParamSet params = g.getParameters();
|
dist_ = new Distribution (g.getDistribution()->params);
|
||||||
return *this;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Factor&
|
void
|
||||||
Factor::operator*= (const Factor& g)
|
Factor::multiplyByFactor (const Factor& g, const vector<CptEntry>* entries)
|
||||||
{
|
{
|
||||||
FgVarSet gVs = g.getFgVarNodes();
|
if (vars_.size() == 0) {
|
||||||
|
copyFactor (g);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const FgVarSet& gVs = g.getFgVarNodes();
|
||||||
const ParamSet& gPs = g.getParameters();
|
const ParamSet& gPs = g.getParameters();
|
||||||
|
|
||||||
bool hasCommonVars = false;
|
bool factorsAreEqual = true;
|
||||||
vector<int> varIndexes;
|
if (gVs.size() == vars_.size()) {
|
||||||
for (unsigned i = 0; i < gVs.size(); i++) {
|
for (unsigned i = 0; i < vars_.size(); i++) {
|
||||||
int idx = getIndexOf (gVs[i]);
|
if (gVs[i] != vars_[i]) {
|
||||||
if (idx == -1) {
|
factorsAreEqual = false;
|
||||||
insertVariable (gVs[i]);
|
break;
|
||||||
varIndexes.push_back (vs_.size() - 1);
|
|
||||||
} else {
|
|
||||||
hasCommonVars = true;
|
|
||||||
varIndexes.push_back (idx);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (hasCommonVars) {
|
|
||||||
vector<int> offsets (gVs.size());
|
|
||||||
offsets[gVs.size() - 1] = 1;
|
|
||||||
for (int i = gVs.size() - 2; i >= 0; i--) {
|
|
||||||
offsets[i] = offsets[i + 1] * gVs[i + 1]->getDomainSize();
|
|
||||||
}
|
|
||||||
vector<CptEntry> entries = getCptEntries();
|
|
||||||
for (unsigned i = 0; i < entries.size(); i++) {
|
|
||||||
int idx = 0;
|
|
||||||
const DomainConf conf = entries[i].getParentConfigurations();
|
|
||||||
for (unsigned j = 0; j < varIndexes.size(); j++) {
|
|
||||||
idx += offsets[j] * conf[varIndexes[j]];
|
|
||||||
}
|
}
|
||||||
//cout << "ps_[" << i << "] = " << ps_[i] << " * " ;
|
|
||||||
//cout << gPs[idx] << " , idx = " << idx << endl;
|
|
||||||
ps_[i] = ps_[i] * gPs[idx];
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// if the originally factors doesn't have common factors.
|
factorsAreEqual = false;
|
||||||
// we don't have to make domain comparations
|
}
|
||||||
unsigned idx = 0;
|
|
||||||
for (unsigned i = 0; i < ps_.size(); i++) {
|
if (factorsAreEqual) {
|
||||||
//cout << "ps_[" << i << "] = " << ps_[i] << " * " ;
|
// optimization: if the factors contain the same set of variables,
|
||||||
//cout << gPs[idx] << " , idx = " << idx << endl;
|
// we can do 1 to 1 operations on the parameteres
|
||||||
ps_[i] = ps_[i] * gPs[idx];
|
for (unsigned i = 0; i < dist_->params.size(); i++) {
|
||||||
idx ++;
|
dist_->params[i] *= gPs[i];
|
||||||
if (idx >= gPs.size()) {
|
}
|
||||||
idx = 0;
|
} else {
|
||||||
|
bool hasCommonVars = false;
|
||||||
|
vector<unsigned> gVsIndexes;
|
||||||
|
for (unsigned i = 0; i < gVs.size(); i++) {
|
||||||
|
int idx = getIndexOf (gVs[i]);
|
||||||
|
if (idx == -1) {
|
||||||
|
insertVariable (gVs[i]);
|
||||||
|
gVsIndexes.push_back (vars_.size() - 1);
|
||||||
|
} else {
|
||||||
|
hasCommonVars = true;
|
||||||
|
gVsIndexes.push_back (idx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (hasCommonVars) {
|
||||||
|
vector<unsigned> gVsOffsets (gVs.size());
|
||||||
|
gVsOffsets[gVs.size() - 1] = 1;
|
||||||
|
for (int i = gVs.size() - 2; i >= 0; i--) {
|
||||||
|
gVsOffsets[i] = gVsOffsets[i + 1] * gVs[i + 1]->getDomainSize();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (entries == 0) {
|
||||||
|
entries = &getCptEntries();
|
||||||
|
}
|
||||||
|
|
||||||
|
for (unsigned i = 0; i < entries->size(); i++) {
|
||||||
|
unsigned idx = 0;
|
||||||
|
const DConf& conf = (*entries)[i].getDomainConfiguration();
|
||||||
|
for (unsigned j = 0; j < gVsIndexes.size(); j++) {
|
||||||
|
idx += gVsOffsets[j] * conf[ gVsIndexes[j] ];
|
||||||
|
}
|
||||||
|
dist_->params[i] = dist_->params[i] * gPs[idx];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// optimization: if the original factors doesn't have common variables,
|
||||||
|
// we don't need to marry the states of the common variables
|
||||||
|
unsigned count = 0;
|
||||||
|
for (unsigned i = 0; i < dist_->params.size(); i++) {
|
||||||
|
dist_->params[i] *= gPs[count];
|
||||||
|
count ++;
|
||||||
|
if (count >= gPs.size()) {
|
||||||
|
count = 0;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return *this;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -169,81 +162,109 @@ Factor::operator*= (const Factor& g)
|
|||||||
void
|
void
|
||||||
Factor::insertVariable (FgVarNode* var)
|
Factor::insertVariable (FgVarNode* var)
|
||||||
{
|
{
|
||||||
int c = 0;
|
assert (getIndexOf (var) == -1);
|
||||||
ParamSet newPs (ps_.size() * var->getDomainSize());
|
ParamSet newPs;
|
||||||
for (unsigned i = 0; i < ps_.size(); i++) {
|
newPs.reserve (dist_->params.size() * var->getDomainSize());
|
||||||
for (int j = 0; j < var->getDomainSize(); j++) {
|
for (unsigned i = 0; i < dist_->params.size(); i++) {
|
||||||
newPs[c] = ps_[i];
|
for (unsigned j = 0; j < var->getDomainSize(); j++) {
|
||||||
c ++;
|
newPs.push_back (dist_->params[i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
vs_.push_back (var);
|
vars_.push_back (var);
|
||||||
ps_ = newPs;
|
dist_->updateParameters (newPs);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
Factor::marginalizeVariable (const FgVarNode* var) {
|
Factor::removeVariable (const FgVarNode* var)
|
||||||
int varIndex = getIndexOf (var);
|
|
||||||
marginalizeVariable (varIndex);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
Factor::marginalizeVariable (unsigned varIndex)
|
|
||||||
{
|
{
|
||||||
assert (varIndex >= 0 && varIndex < vs_.size());
|
int varIndex = getIndexOf (var);
|
||||||
int distOffset = 1;
|
assert (varIndex >= 0 && varIndex < (int)vars_.size());
|
||||||
int leftVarOffset = 1;
|
|
||||||
for (unsigned i = vs_.size() - 1; i > varIndex; i--) {
|
// number of parameters separating a different state of `var',
|
||||||
distOffset *= vs_[i]->getDomainSize();
|
// with the states of the remaining variables fixed
|
||||||
leftVarOffset *= vs_[i]->getDomainSize();
|
unsigned varOffset = 1;
|
||||||
}
|
|
||||||
leftVarOffset *= vs_[varIndex]->getDomainSize();
|
// number of parameters separating a different state of the variable
|
||||||
|
// on the left of `var', with the states of the remaining vars fixed
|
||||||
|
unsigned leftVarOffset = 1;
|
||||||
|
|
||||||
|
for (int i = vars_.size() - 1; i > varIndex; i--) {
|
||||||
|
varOffset *= vars_[i]->getDomainSize();
|
||||||
|
leftVarOffset *= vars_[i]->getDomainSize();
|
||||||
|
}
|
||||||
|
leftVarOffset *= vars_[varIndex]->getDomainSize();
|
||||||
|
|
||||||
|
unsigned offset = 0;
|
||||||
|
unsigned count1 = 0;
|
||||||
|
unsigned count2 = 0;
|
||||||
|
unsigned newPsSize = dist_->params.size() / vars_[varIndex]->getDomainSize();
|
||||||
|
|
||||||
int ds = vs_[varIndex]->getDomainSize();
|
|
||||||
int count = 0;
|
|
||||||
int offset = 0;
|
|
||||||
int startIndex = 0;
|
|
||||||
int currDomainIdx = 0;
|
|
||||||
unsigned newPsSize = ps_.size() / ds;
|
|
||||||
ParamSet newPs;
|
ParamSet newPs;
|
||||||
newPs.reserve (newPsSize);
|
newPs.reserve (newPsSize);
|
||||||
|
|
||||||
stringstream ss;
|
// stringstream ss;
|
||||||
ss << "marginalizing " << vs_[varIndex]->getLabel();
|
// ss << "marginalizing " << vars_[varIndex]->getLabel();
|
||||||
ss << " from factor " << getLabel() << endl;
|
// ss << " from factor " << getLabel() << endl;
|
||||||
while (newPs.size() < newPsSize) {
|
while (newPs.size() < newPsSize) {
|
||||||
ss << " sum = ";
|
// ss << " sum = ";
|
||||||
double sum = 0.0;
|
double sum = 0.0;
|
||||||
for (int j = 0; j < ds; j++) {
|
for (unsigned i = 0; i < vars_[varIndex]->getDomainSize(); i++) {
|
||||||
if (j != 0) ss << " + ";
|
// if (i != 0) ss << " + ";
|
||||||
ss << ps_[offset];
|
// ss << dist_->params[offset];
|
||||||
sum = sum + ps_[offset];
|
sum += dist_->params[offset];
|
||||||
offset = offset + distOffset;
|
offset += varOffset;
|
||||||
}
|
}
|
||||||
newPs.push_back (sum);
|
newPs.push_back (sum);
|
||||||
count ++;
|
count1 ++;
|
||||||
if (varIndex == vs_.size() - 1) {
|
if (varIndex == (int)vars_.size() - 1) {
|
||||||
offset = count * ds;
|
offset = count1 * vars_[varIndex]->getDomainSize();
|
||||||
} else {
|
} else {
|
||||||
offset = offset - distOffset + 1;
|
if (((offset - varOffset + 1) % leftVarOffset) == 0) {
|
||||||
if ((offset % leftVarOffset) == 0) {
|
count1 = 0;
|
||||||
currDomainIdx ++;
|
count2 ++;
|
||||||
startIndex = leftVarOffset * currDomainIdx;
|
|
||||||
offset = startIndex;
|
|
||||||
count = 0;
|
|
||||||
} else {
|
|
||||||
offset = startIndex + count;
|
|
||||||
}
|
}
|
||||||
|
offset = (leftVarOffset * count2) + count1;
|
||||||
}
|
}
|
||||||
ss << " = " << sum << endl;
|
// ss << " = " << sum << endl;
|
||||||
}
|
}
|
||||||
//cout << ss.str() << endl;
|
// cout << ss.str() << endl;
|
||||||
ps_ = newPs;
|
vars_.erase (vars_.begin() + varIndex);
|
||||||
vs_.erase (vs_.begin() + varIndex);
|
dist_->updateParameters (newPs);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
const vector<CptEntry>&
|
||||||
|
Factor::getCptEntries (void) const
|
||||||
|
{
|
||||||
|
if (dist_->entries.size() == 0) {
|
||||||
|
vector<DConf> confs (dist_->params.size());
|
||||||
|
for (unsigned i = 0; i < dist_->params.size(); i++) {
|
||||||
|
confs[i].resize (vars_.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
unsigned nReps = 1;
|
||||||
|
for (int i = vars_.size() - 1; i >= 0; i--) {
|
||||||
|
unsigned index = 0;
|
||||||
|
while (index < dist_->params.size()) {
|
||||||
|
for (unsigned j = 0; j < vars_[i]->getDomainSize(); j++) {
|
||||||
|
for (unsigned r = 0; r < nReps; r++) {
|
||||||
|
confs[index][i] = j;
|
||||||
|
index++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
nReps *= vars_[i]->getDomainSize();
|
||||||
|
}
|
||||||
|
dist_->entries.clear();
|
||||||
|
dist_->entries.reserve (dist_->params.size());
|
||||||
|
for (unsigned i = 0; i < dist_->params.size(); i++) {
|
||||||
|
dist_->entries.push_back (CptEntry (i, confs[i]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return dist_->entries;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -252,11 +273,10 @@ string
|
|||||||
Factor::getLabel (void) const
|
Factor::getLabel (void) const
|
||||||
{
|
{
|
||||||
stringstream ss;
|
stringstream ss;
|
||||||
ss << "f(" ;
|
ss << "Φ(" ;
|
||||||
// ss << "Φ(" ;
|
for (unsigned i = 0; i < vars_.size(); i++) {
|
||||||
for (unsigned i = 0; i < vs_.size(); i++) {
|
if (i != 0) ss << "," ;
|
||||||
if (i != 0) ss << ", " ;
|
ss << vars_[i]->getLabel();
|
||||||
ss << "v" << vs_[i]->getVarId();
|
|
||||||
}
|
}
|
||||||
ss << ")" ;
|
ss << ")" ;
|
||||||
return ss.str();
|
return ss.str();
|
||||||
@ -264,62 +284,24 @@ Factor::getLabel (void) const
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
string
|
void
|
||||||
Factor::toString (void) const
|
Factor::printFactor (void)
|
||||||
{
|
{
|
||||||
stringstream ss;
|
stringstream ss;
|
||||||
ss << "vars: " ;
|
ss << getLabel() << endl;
|
||||||
for (unsigned i = 0; i < vs_.size(); i++) {
|
ss << "--------------------" << endl;
|
||||||
if (i != 0) ss << ", " ;
|
VarSet vs;
|
||||||
ss << "v" << vs_[i]->getVarId();
|
for (unsigned i = 0; i < vars_.size(); i++) {
|
||||||
|
vs.push_back (vars_[i]);
|
||||||
}
|
}
|
||||||
ss << endl;
|
vector<string> domainConfs = Util::getInstantiations (vs);
|
||||||
vector<CptEntry> entries = getCptEntries();
|
const vector<CptEntry>& entries = getCptEntries();
|
||||||
for (unsigned i = 0; i < entries.size(); i++) {
|
for (unsigned i = 0; i < entries.size(); i++) {
|
||||||
ss << "Φ(" ;
|
ss << "Φ(" << domainConfs[i] << ")" ;
|
||||||
char s = 'a' ;
|
unsigned idx = entries[i].getParameterIndex();
|
||||||
const DomainConf& conf = entries[i].getParentConfigurations();
|
ss << " = " << dist_->params[idx] << endl;
|
||||||
for (unsigned j = 0; j < conf.size(); j++) {
|
|
||||||
if (j != 0) ss << "," ;
|
|
||||||
ss << s << conf[j] + 1;
|
|
||||||
s++;
|
|
||||||
}
|
|
||||||
ss << ") = " << ps_[entries[i].getParameterIndex()] << endl;
|
|
||||||
}
|
}
|
||||||
return ss.str();
|
cout << ss.str();
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
vector<CptEntry>
|
|
||||||
Factor::getCptEntries (void) const
|
|
||||||
{
|
|
||||||
vector<DomainConf> confs (ps_.size());
|
|
||||||
for (unsigned i = 0; i < ps_.size(); i++) {
|
|
||||||
confs[i].resize (vs_.size());
|
|
||||||
}
|
|
||||||
|
|
||||||
int nReps = 1;
|
|
||||||
for (int i = vs_.size() - 1; i >= 0; i--) {
|
|
||||||
unsigned index = 0;
|
|
||||||
while (index < ps_.size()) {
|
|
||||||
for (int j = 0; j < vs_[i]->getDomainSize(); j++) {
|
|
||||||
for (int r = 0; r < nReps; r++) {
|
|
||||||
confs[index][i] = j;
|
|
||||||
index++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
nReps *= vs_[i]->getDomainSize();
|
|
||||||
}
|
|
||||||
|
|
||||||
vector<CptEntry> entries;
|
|
||||||
for (unsigned i = 0; i < ps_.size(); i++) {
|
|
||||||
for (unsigned j = 0; j < vs_.size(); j++) {
|
|
||||||
}
|
|
||||||
entries.push_back (CptEntry (i, confs[i]));
|
|
||||||
}
|
|
||||||
return entries;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -327,20 +309,11 @@ Factor::getCptEntries (void) const
|
|||||||
int
|
int
|
||||||
Factor::getIndexOf (const FgVarNode* var) const
|
Factor::getIndexOf (const FgVarNode* var) const
|
||||||
{
|
{
|
||||||
for (unsigned i = 0; i < vs_.size(); i++) {
|
for (unsigned i = 0; i < vars_.size(); i++) {
|
||||||
if (vs_[i] == var) {
|
if (vars_[i] == var) {
|
||||||
return i;
|
return i;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Factor operator* (const Factor& f, const Factor& g)
|
|
||||||
{
|
|
||||||
Factor r = f;
|
|
||||||
r *= g;
|
|
||||||
return r;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
@ -3,43 +3,46 @@
|
|||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "Distribution.h"
|
||||||
#include "CptEntry.h"
|
#include "CptEntry.h"
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
class FgVarNode;
|
class FgVarNode;
|
||||||
|
class Distribution;
|
||||||
|
|
||||||
class Factor
|
class Factor
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
|
Factor (void) { }
|
||||||
|
Factor (const Factor&);
|
||||||
Factor (FgVarNode*);
|
Factor (FgVarNode*);
|
||||||
Factor (const FgVarSet&);
|
Factor (CFgVarSet);
|
||||||
Factor (FgVarNode*, const ParamSet&);
|
Factor (FgVarNode*, const ParamSet&);
|
||||||
Factor (const FgVarSet&, const ParamSet&);
|
Factor (FgVarSet&, Distribution*);
|
||||||
|
Factor (CFgVarSet, CParamSet);
|
||||||
|
|
||||||
const FgVarSet& getFgVarNodes (void) const;
|
void setParameters (CParamSet);
|
||||||
FgVarSet& getFgVarNodes (void);
|
void copyFactor (const Factor& f);
|
||||||
const ParamSet& getParameters (void) const;
|
void multiplyByFactor (const Factor& f, const vector<CptEntry>* = 0);
|
||||||
ParamSet& getParameters (void);
|
void insertVariable (FgVarNode* index);
|
||||||
void setParameters (const ParamSet&);
|
void removeVariable (const FgVarNode* var);
|
||||||
Factor& operator= (const Factor& f);
|
const vector<CptEntry>& getCptEntries (void) const;
|
||||||
Factor& operator*= (const Factor& f);
|
string getLabel (void) const;
|
||||||
void insertVariable (FgVarNode* index);
|
void printFactor (void);
|
||||||
void marginalizeVariable (const FgVarNode* var);
|
|
||||||
void marginalizeVariable (unsigned);
|
CFgVarSet getFgVarNodes (void) const { return vars_; }
|
||||||
string getLabel (void) const;
|
CParamSet getParameters (void) const { return dist_->params; }
|
||||||
string toString (void) const;
|
Distribution* getDistribution (void) const { return dist_; }
|
||||||
|
unsigned getIndex (void) const { return index_; }
|
||||||
|
void setIndex (unsigned index) { index_ = index; }
|
||||||
|
void freeDistribution (void) { delete dist_; dist_ = 0;}
|
||||||
|
int getIndexOf (const FgVarNode*) const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
vector<CptEntry> getCptEntries() const;
|
FgVarSet vars_;
|
||||||
int getIndexOf (const FgVarNode*) const;
|
Distribution* dist_;
|
||||||
|
unsigned index_;
|
||||||
FgVarSet vs_;
|
|
||||||
ParamSet ps_;
|
|
||||||
int id_;
|
|
||||||
static int indexCount_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
Factor operator* (const Factor&, const Factor&);
|
#endif //BP_FACTOR_H
|
||||||
|
|
||||||
#endif
|
|
||||||
|
@ -1,23 +1,26 @@
|
|||||||
|
#include <cstdlib>
|
||||||
|
#include <vector>
|
||||||
|
#include <set>
|
||||||
|
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <vector>
|
|
||||||
#include <cstdlib>
|
|
||||||
|
|
||||||
#include "FactorGraph.h"
|
#include "FactorGraph.h"
|
||||||
#include "FgVarNode.h"
|
#include "FgVarNode.h"
|
||||||
#include "Factor.h"
|
#include "Factor.h"
|
||||||
|
#include "BayesNet.h"
|
||||||
|
|
||||||
|
|
||||||
FactorGraph::FactorGraph (const char* fileName)
|
FactorGraph::FactorGraph (const char* fileName)
|
||||||
{
|
{
|
||||||
string line;
|
|
||||||
ifstream is (fileName);
|
ifstream is (fileName);
|
||||||
if (!is.is_open()) {
|
if (!is.is_open()) {
|
||||||
cerr << "error: cannot read from file " + std::string (fileName) << endl;
|
cerr << "error: cannot read from file " + std::string (fileName) << endl;
|
||||||
abort();
|
abort();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
string line;
|
||||||
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
|
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
|
||||||
getline (is, line);
|
getline (is, line);
|
||||||
if (line != "MARKOV") {
|
if (line != "MARKOV") {
|
||||||
@ -39,7 +42,7 @@ FactorGraph::FactorGraph (const char* fileName)
|
|||||||
|
|
||||||
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
|
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
|
||||||
for (int i = 0; i < nVars; i++) {
|
for (int i = 0; i < nVars; i++) {
|
||||||
varNodes_.push_back (new FgVarNode (i, domainSizes[i]));
|
addVariable (new FgVarNode (i, domainSizes[i]));
|
||||||
}
|
}
|
||||||
|
|
||||||
int nFactors;
|
int nFactors;
|
||||||
@ -50,11 +53,11 @@ FactorGraph::FactorGraph (const char* fileName)
|
|||||||
is >> nFactorVars;
|
is >> nFactorVars;
|
||||||
FgVarSet factorVars;
|
FgVarSet factorVars;
|
||||||
for (int j = 0; j < nFactorVars; j++) {
|
for (int j = 0; j < nFactorVars; j++) {
|
||||||
int varId;
|
int vid;
|
||||||
is >> varId;
|
is >> vid;
|
||||||
FgVarNode* var = getVariableById (varId);
|
FgVarNode* var = getFgVarNode (vid);
|
||||||
if (var == 0) {
|
if (!var) {
|
||||||
cerr << "error: invalid variable identifier (" << varId << ")" << endl;
|
cerr << "error: invalid variable identifier (" << vid << ")" << endl;
|
||||||
abort();
|
abort();
|
||||||
}
|
}
|
||||||
factorVars.push_back (var);
|
factorVars.push_back (var);
|
||||||
@ -87,6 +90,33 @@ FactorGraph::FactorGraph (const char* fileName)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
FactorGraph::FactorGraph (const BayesNet& bn)
|
||||||
|
{
|
||||||
|
const BnNodeSet& nodes = bn.getBayesNodes();
|
||||||
|
for (unsigned i = 0; i < nodes.size(); i++) {
|
||||||
|
FgVarNode* varNode = new FgVarNode (nodes[i]);
|
||||||
|
varNode->setIndex (i);
|
||||||
|
addVariable (varNode);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (unsigned i = 0; i < nodes.size(); i++) {
|
||||||
|
const BnNodeSet& parents = nodes[i]->getParents();
|
||||||
|
if (!(nodes[i]->hasEvidence() && parents.size() == 0)) {
|
||||||
|
FgVarSet factorVars = { varNodes_[nodes[i]->getIndex()] };
|
||||||
|
for (unsigned j = 0; j < parents.size(); j++) {
|
||||||
|
factorVars.push_back (varNodes_[parents[j]->getIndex()]);
|
||||||
|
}
|
||||||
|
Factor* f = new Factor (factorVars, nodes[i]->getDistribution());
|
||||||
|
factors_.push_back (f);
|
||||||
|
for (unsigned j = 0; j < factorVars.size(); j++) {
|
||||||
|
factorVars[j]->addFactor (f);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
FactorGraph::~FactorGraph (void)
|
FactorGraph::~FactorGraph (void)
|
||||||
{
|
{
|
||||||
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
||||||
@ -99,18 +129,67 @@ FactorGraph::~FactorGraph (void)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
FgVarSet
|
void
|
||||||
FactorGraph::getFgVarNodes (void) const
|
FactorGraph::addVariable (FgVarNode* varNode)
|
||||||
{
|
{
|
||||||
return varNodes_;
|
varNodes_.push_back (varNode);
|
||||||
|
varNode->setIndex (varNodes_.size() - 1);
|
||||||
|
indexMap_.insert (make_pair (varNode->getVarId(), varNodes_.size() - 1));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
vector<Factor*>
|
void
|
||||||
FactorGraph::getFactors (void) const
|
FactorGraph::removeVariable (const FgVarNode* var)
|
||||||
{
|
{
|
||||||
return factors_;
|
if (varNodes_[varNodes_.size() - 1] == var) {
|
||||||
|
varNodes_.pop_back();
|
||||||
|
} else {
|
||||||
|
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
||||||
|
if (varNodes_[i] == var) {
|
||||||
|
varNodes_.erase (varNodes_.begin() + i);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert (false);
|
||||||
|
}
|
||||||
|
indexMap_.erase (indexMap_.find (var->getVarId()));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
FactorGraph::addFactor (Factor* f)
|
||||||
|
{
|
||||||
|
factors_.push_back (f);
|
||||||
|
const FgVarSet& factorVars = f->getFgVarNodes();
|
||||||
|
for (unsigned i = 0; i < factorVars.size(); i++) {
|
||||||
|
factorVars[i]->addFactor (f);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
FactorGraph::removeFactor (const Factor* f)
|
||||||
|
{
|
||||||
|
const FgVarSet& factorVars = f->getFgVarNodes();
|
||||||
|
for (unsigned i = 0; i < factorVars.size(); i++) {
|
||||||
|
if (factorVars[i]) {
|
||||||
|
factorVars[i]->removeFactor (f);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (factors_[factors_.size() - 1] == f) {
|
||||||
|
factors_.pop_back();
|
||||||
|
} else {
|
||||||
|
for (unsigned i = 0; i < factors_.size(); i++) {
|
||||||
|
if (factors_[i] == f) {
|
||||||
|
factors_.erase (factors_.begin() + i);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert (false);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -127,47 +206,142 @@ FactorGraph::getVariables (void) const
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
FgVarNode*
|
Variable*
|
||||||
FactorGraph::getVariableById (unsigned id) const
|
FactorGraph::getVariable (Vid vid) const
|
||||||
{
|
{
|
||||||
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
return getFgVarNode (vid);
|
||||||
if (varNodes_[i]->getVarId() == id) {
|
|
||||||
return varNodes_[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
FgVarNode*
|
|
||||||
FactorGraph::getVariableByLabel (string label) const
|
|
||||||
{
|
|
||||||
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
|
||||||
stringstream ss;
|
|
||||||
ss << "v" << varNodes_[i]->getVarId();
|
|
||||||
if (ss.str() == label) {
|
|
||||||
return varNodes_[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return 0;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
FactorGraph::printFactorGraph (void) const
|
FactorGraph::setIndexes (void)
|
||||||
|
{
|
||||||
|
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
||||||
|
varNodes_[i]->setIndex (i);
|
||||||
|
}
|
||||||
|
for (unsigned i = 0; i < factors_.size(); i++) {
|
||||||
|
factors_[i]->setIndex (i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
FactorGraph::freeDistributions (void)
|
||||||
|
{
|
||||||
|
set<Distribution*> dists;
|
||||||
|
for (unsigned i = 0; i < factors_.size(); i++) {
|
||||||
|
dists.insert (factors_[i]->getDistribution());
|
||||||
|
}
|
||||||
|
for (set<Distribution*>::iterator it = dists.begin();
|
||||||
|
it != dists.end(); it++) {
|
||||||
|
delete *it;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
FactorGraph::printGraphicalModel (void) const
|
||||||
{
|
{
|
||||||
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
||||||
cout << "variable number " << varNodes_[i]->getIndex() << endl;
|
cout << "variable number " << varNodes_[i]->getIndex() << endl;
|
||||||
cout << "Id = " << varNodes_[i]->getVarId() << endl;
|
cout << "Id = " << varNodes_[i]->getVarId() << endl;
|
||||||
|
cout << "Label = " << varNodes_[i]->getLabel() << endl;
|
||||||
cout << "Domain size = " << varNodes_[i]->getDomainSize() << endl;
|
cout << "Domain size = " << varNodes_[i]->getDomainSize() << endl;
|
||||||
cout << "Evidence = " << varNodes_[i]->getEvidence() << endl;
|
cout << "Evidence = " << varNodes_[i]->getEvidence() << endl;
|
||||||
cout << endl;
|
cout << "Factors = " ;
|
||||||
|
for (unsigned j = 0; j < varNodes_[i]->getFactors().size(); j++) {
|
||||||
|
cout << varNodes_[i]->getFactors()[j]->getLabel() << " " ;
|
||||||
|
}
|
||||||
|
cout << endl << endl;
|
||||||
}
|
}
|
||||||
cout << endl;
|
|
||||||
for (unsigned i = 0; i < factors_.size(); i++) {
|
for (unsigned i = 0; i < factors_.size(); i++) {
|
||||||
cout << factors_[i]->toString() << endl;
|
factors_[i]->printFactor();
|
||||||
|
cout << endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
FactorGraph::exportToDotFormat (const char* fileName) const
|
||||||
|
{
|
||||||
|
ofstream out (fileName);
|
||||||
|
if (!out.is_open()) {
|
||||||
|
cerr << "error: cannot open file to write at " ;
|
||||||
|
cerr << "FactorGraph::exportToDotFile()" << endl;
|
||||||
|
abort();
|
||||||
|
}
|
||||||
|
|
||||||
|
out << "graph \"" << fileName << "\" {" << endl;
|
||||||
|
|
||||||
|
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
||||||
|
if (varNodes_[i]->hasEvidence()) {
|
||||||
|
out << '"' << varNodes_[i]->getLabel() << '"' ;
|
||||||
|
out << " [style=filled, fillcolor=yellow]" << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (unsigned i = 0; i < factors_.size(); i++) {
|
||||||
|
out << '"' << factors_[i]->getLabel() << '"' ;
|
||||||
|
out << " [label=\"" << factors_[i]->getLabel() << "\\n(";
|
||||||
|
out << factors_[i]->getDistribution()->id << ")" << "\"" ;
|
||||||
|
out << ", shape=box]" << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (unsigned i = 0; i < factors_.size(); i++) {
|
||||||
|
CFgVarSet myVars = factors_[i]->getFgVarNodes();
|
||||||
|
for (unsigned j = 0; j < myVars.size(); j++) {
|
||||||
|
out << '"' << factors_[i]->getLabel() << '"' ;
|
||||||
|
out << " -- " ;
|
||||||
|
out << '"' << myVars[j]->getLabel() << '"' << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
out << "}" << endl;
|
||||||
|
out.close();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
FactorGraph::exportToUaiFormat (const char* fileName) const
|
||||||
|
{
|
||||||
|
ofstream out (fileName);
|
||||||
|
if (!out.is_open()) {
|
||||||
|
cerr << "error: cannot open file to write at " ;
|
||||||
|
cerr << "FactorGraph::exportToUaiFormat()" << endl;
|
||||||
|
abort();
|
||||||
|
}
|
||||||
|
|
||||||
|
out << "MARKOV" << endl;
|
||||||
|
out << varNodes_.size() << endl;
|
||||||
|
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
||||||
|
out << varNodes_[i]->getDomainSize() << " " ;
|
||||||
|
}
|
||||||
|
out << endl;
|
||||||
|
|
||||||
|
out << factors_.size() << endl;
|
||||||
|
for (unsigned i = 0; i < factors_.size(); i++) {
|
||||||
|
CFgVarSet factorVars = factors_[i]->getFgVarNodes();
|
||||||
|
out << factorVars.size();
|
||||||
|
for (unsigned j = 0; j < factorVars.size(); j++) {
|
||||||
|
out << " " << factorVars[j]->getIndex();
|
||||||
|
}
|
||||||
|
out << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (unsigned i = 0; i < factors_.size(); i++) {
|
||||||
|
CParamSet params = factors_[i]->getParameters();
|
||||||
|
out << endl << params.size() << endl << " " ;
|
||||||
|
for (unsigned j = 0; j < params.size(); j++) {
|
||||||
|
out << params[j] << " " ;
|
||||||
|
}
|
||||||
|
out << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
out.close();
|
||||||
|
}
|
||||||
|
|
||||||
|
@ -1,8 +1,7 @@
|
|||||||
#ifndef BP_FACTORGRAPH_H
|
#ifndef BP_FACTOR_GRAPH_H
|
||||||
#define BP_FACTORGRAPH_H
|
#define BP_FACTOR_GRAPH_H
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <string>
|
|
||||||
|
|
||||||
#include "GraphicalModel.h"
|
#include "GraphicalModel.h"
|
||||||
#include "Shared.h"
|
#include "Shared.h"
|
||||||
@ -11,25 +10,48 @@ using namespace std;
|
|||||||
|
|
||||||
class FgVarNode;
|
class FgVarNode;
|
||||||
class Factor;
|
class Factor;
|
||||||
|
class BayesNet;
|
||||||
|
|
||||||
class FactorGraph : public GraphicalModel
|
class FactorGraph : public GraphicalModel
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
FactorGraph (const char* fileName);
|
FactorGraph (void) {};
|
||||||
|
FactorGraph (const char*);
|
||||||
|
FactorGraph (const BayesNet&);
|
||||||
~FactorGraph (void);
|
~FactorGraph (void);
|
||||||
|
|
||||||
FgVarSet getFgVarNodes (void) const;
|
void addVariable (FgVarNode*);
|
||||||
vector<Factor*> getFactors (void) const;
|
void removeVariable (const FgVarNode*);
|
||||||
|
void addFactor (Factor*);
|
||||||
|
void removeFactor (const Factor*);
|
||||||
VarSet getVariables (void) const;
|
VarSet getVariables (void) const;
|
||||||
FgVarNode* getVariableById (unsigned) const;
|
Variable* getVariable (unsigned) const;
|
||||||
FgVarNode* getVariableByLabel (string) const;
|
void setIndexes (void);
|
||||||
void printFactorGraph (void) const;
|
void freeDistributions (void);
|
||||||
|
void printGraphicalModel (void) const;
|
||||||
|
void exportToDotFormat (const char*) const;
|
||||||
|
void exportToUaiFormat (const char*) const;
|
||||||
|
|
||||||
|
const FgVarSet& getFgVarNodes (void) const { return varNodes_; }
|
||||||
|
const FactorSet& getFactors (void) const { return factors_; }
|
||||||
|
|
||||||
|
FgVarNode* getFgVarNode (Vid vid) const
|
||||||
|
{
|
||||||
|
IndexMap::const_iterator it = indexMap_.find (vid);
|
||||||
|
if (it == indexMap_.end()) {
|
||||||
|
return 0;
|
||||||
|
} else {
|
||||||
|
return varNodes_[it->second];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
DISALLOW_COPY_AND_ASSIGN (FactorGraph);
|
DISALLOW_COPY_AND_ASSIGN (FactorGraph);
|
||||||
|
|
||||||
FgVarSet varNodes_;
|
FgVarSet varNodes_;
|
||||||
vector<Factor*> factors_;
|
FactorSet factors_;
|
||||||
|
IndexMap indexMap_;
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif
|
#endif // BP_FACTOR_GRAPH_H
|
||||||
|
|
||||||
|
@ -1,8 +1,7 @@
|
|||||||
#ifndef BP_VARIABLE_H
|
#ifndef BP_FG_VAR_NODE_H
|
||||||
#define BP_VARIABLE_H
|
#define BP_FG_VAR_NODE_H
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <string>
|
|
||||||
|
|
||||||
#include "Variable.h"
|
#include "Variable.h"
|
||||||
#include "Shared.h"
|
#include "Shared.h"
|
||||||
@ -14,15 +13,31 @@ class Factor;
|
|||||||
class FgVarNode : public Variable
|
class FgVarNode : public Variable
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
FgVarNode (int varId, int dsize) : Variable (varId, dsize) { }
|
FgVarNode (unsigned vid, unsigned dsize) : Variable (vid, dsize) { }
|
||||||
|
FgVarNode (const Variable* v) : Variable (v) { }
|
||||||
|
|
||||||
void addFactor (Factor* f) { factors_.push_back (f); }
|
void addFactor (Factor* f) { factors_.push_back (f); }
|
||||||
vector<Factor*> getFactors (void) const { return factors_; }
|
CFactorSet getFactors (void) const { return factors_; }
|
||||||
|
|
||||||
|
void removeFactor (const Factor* f)
|
||||||
|
{
|
||||||
|
if (factors_[factors_.size() -1] == f) {
|
||||||
|
factors_.pop_back();
|
||||||
|
} else {
|
||||||
|
for (unsigned i = 0; i < factors_.size(); i++) {
|
||||||
|
if (factors_[i] == f) {
|
||||||
|
factors_.erase (factors_.begin() + i);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert (false);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
DISALLOW_COPY_AND_ASSIGN (FgVarNode);
|
DISALLOW_COPY_AND_ASSIGN (FgVarNode);
|
||||||
// members
|
// members
|
||||||
vector<Factor*> factors_;
|
FactorSet factors_;
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif // BP_VARIABLE_H
|
#endif // BP_FG_VAR_NODE_H
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
#ifndef BP_GRAPHICALMODEL_H
|
#ifndef BP_GRAPHICAL_MODEL_H
|
||||||
#define BP_GRAPHICALMODEL_H
|
#define BP_GRAPHICAL_MODEL_H
|
||||||
|
|
||||||
#include "Variable.h"
|
#include "Variable.h"
|
||||||
#include "Shared.h"
|
#include "Shared.h"
|
||||||
@ -9,9 +9,10 @@ using namespace std;
|
|||||||
class GraphicalModel
|
class GraphicalModel
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
virtual VarSet getVariables (void) const = 0;
|
virtual ~GraphicalModel (void) {};
|
||||||
|
virtual Variable* getVariable (Vid) const = 0;
|
||||||
private:
|
virtual VarSet getVariables (void) const = 0;
|
||||||
|
virtual void printGraphicalModel (void) const = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif
|
#endif // BP_GRAPHICAL_MODEL_H
|
||||||
|
@ -1,17 +1,19 @@
|
|||||||
#include <iostream>
|
|
||||||
#include <cstdlib>
|
#include <cstdlib>
|
||||||
|
|
||||||
|
#include <iostream>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
|
||||||
#include "BayesNet.h"
|
#include "BayesNet.h"
|
||||||
#include "BPSolver.h"
|
|
||||||
|
|
||||||
#include "FactorGraph.h"
|
#include "FactorGraph.h"
|
||||||
#include "SPSolver.h"
|
#include "SPSolver.h"
|
||||||
|
#include "BPSolver.h"
|
||||||
|
#include "CountingBP.h"
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
void BayesianNetwork (int, const char* []);
|
void BayesianNetwork (int, const char* []);
|
||||||
void markovNetwork (int, const char* []);
|
void markovNetwork (int, const char* []);
|
||||||
|
void runSolver (Solver*, const VarSet&);
|
||||||
|
|
||||||
const string USAGE = "usage: \
|
const string USAGE = "usage: \
|
||||||
./hcli FILE [VARIABLE | OBSERVED_VARIABLE=EVIDENCE]..." ;
|
./hcli FILE [VARIABLE | OBSERVED_VARIABLE=EVIDENCE]..." ;
|
||||||
@ -19,14 +21,40 @@ const string USAGE = "usage: \
|
|||||||
|
|
||||||
int
|
int
|
||||||
main (int argc, const char* argv[])
|
main (int argc, const char* argv[])
|
||||||
{
|
{
|
||||||
|
/*
|
||||||
|
FactorGraph fg;
|
||||||
|
FgVarNode* varNode1 = new FgVarNode (0, 2);
|
||||||
|
FgVarNode* varNode2 = new FgVarNode (1, 2);
|
||||||
|
FgVarNode* varNode3 = new FgVarNode (2, 2);
|
||||||
|
fg.addVariable (varNode1);
|
||||||
|
fg.addVariable (varNode2);
|
||||||
|
fg.addVariable (varNode3);
|
||||||
|
Distribution* dist = new Distribution (ParamSet() = {1.2, 1.4, 2.0, 0.4});
|
||||||
|
fg.addFactor (new Factor (FgVarSet() = {varNode1, varNode2}, dist));
|
||||||
|
fg.addFactor (new Factor (FgVarSet() = {varNode3, varNode2}, dist));
|
||||||
|
//fg.printGraphicalModel();
|
||||||
|
//SPSolver sp (fg);
|
||||||
|
//sp.runSolver();
|
||||||
|
//sp.printAllPosterioris();
|
||||||
|
//ParamSet p = sp.getJointDistributionOf (VidSet() = {0, 1, 2});
|
||||||
|
//cout << Util::parametersToString (p) << endl;
|
||||||
|
CountingBP cbp (fg);
|
||||||
|
//cbp.runSolver();
|
||||||
|
//cbp.printAllPosterioris();
|
||||||
|
ParamSet p2 = cbp.getJointDistributionOf (VidSet() = {0, 1, 2});
|
||||||
|
cout << Util::parametersToString (p2) << endl;
|
||||||
|
fg.freeDistributions();
|
||||||
|
Statistics::printCompressingStats ("compressing.stats");
|
||||||
|
return 0;
|
||||||
|
*/
|
||||||
if (!argv[1]) {
|
if (!argv[1]) {
|
||||||
cerr << "error: no graphical model specified" << endl;
|
cerr << "error: no graphical model specified" << endl;
|
||||||
cerr << USAGE << endl;
|
cerr << USAGE << endl;
|
||||||
exit (0);
|
exit (0);
|
||||||
}
|
}
|
||||||
string fileName = argv[1];
|
const string& fileName = argv[1];
|
||||||
string extension = fileName.substr (fileName.find_last_of ('.') + 1);
|
const string& extension = fileName.substr (fileName.find_last_of ('.') + 1);
|
||||||
if (extension == "xml") {
|
if (extension == "xml") {
|
||||||
BayesianNetwork (argc, argv);
|
BayesianNetwork (argc, argv);
|
||||||
} else if (extension == "uai") {
|
} else if (extension == "uai") {
|
||||||
@ -45,13 +73,13 @@ void
|
|||||||
BayesianNetwork (int argc, const char* argv[])
|
BayesianNetwork (int argc, const char* argv[])
|
||||||
{
|
{
|
||||||
BayesNet bn (argv[1]);
|
BayesNet bn (argv[1]);
|
||||||
//bn.printNetwork();
|
//bn.printGraphicalModel();
|
||||||
|
|
||||||
NodeSet queryVars;
|
VarSet queryVars;
|
||||||
for (int i = 2; i < argc; i++) {
|
for (int i = 2; i < argc; i++) {
|
||||||
string arg = argv[i];
|
const string& arg = argv[i];
|
||||||
if (arg.find ('=') == std::string::npos) {
|
if (arg.find ('=') == std::string::npos) {
|
||||||
BayesNode* queryVar = bn.getNode (arg);
|
BayesNode* queryVar = bn.getBayesNode (arg);
|
||||||
if (queryVar) {
|
if (queryVar) {
|
||||||
queryVars.push_back (queryVar);
|
queryVars.push_back (queryVar);
|
||||||
} else {
|
} else {
|
||||||
@ -61,9 +89,9 @@ BayesianNetwork (int argc, const char* argv[])
|
|||||||
exit (0);
|
exit (0);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
size_t pos = arg.find ('=');
|
size_t pos = arg.find ('=');
|
||||||
string label = arg.substr (0, pos);
|
const string& label = arg.substr (0, pos);
|
||||||
string state = arg.substr (pos + 1);
|
const string& state = arg.substr (pos + 1);
|
||||||
if (label.empty()) {
|
if (label.empty()) {
|
||||||
cerr << "error: missing left argument" << endl;
|
cerr << "error: missing left argument" << endl;
|
||||||
cerr << USAGE << endl;
|
cerr << USAGE << endl;
|
||||||
@ -74,7 +102,7 @@ BayesianNetwork (int argc, const char* argv[])
|
|||||||
cerr << USAGE << endl;
|
cerr << USAGE << endl;
|
||||||
exit (0);
|
exit (0);
|
||||||
}
|
}
|
||||||
BayesNode* node = bn.getNode (label);
|
BayesNode* node = bn.getBayesNode (label);
|
||||||
if (node) {
|
if (node) {
|
||||||
if (node->isValidState (state)) {
|
if (node->isValidState (state)) {
|
||||||
node->setEvidence (state);
|
node->setEvidence (state);
|
||||||
@ -94,19 +122,16 @@ BayesianNetwork (int argc, const char* argv[])
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
BPSolver solver (bn);
|
Solver* solver;
|
||||||
if (queryVars.size() == 0) {
|
if (SolverOptions::convertBn2Fg) {
|
||||||
solver.runSolver();
|
FactorGraph* fg = new FactorGraph (bn);
|
||||||
solver.printAllPosterioris();
|
fg->printGraphicalModel();
|
||||||
} else if (queryVars.size() == 1) {
|
solver = new SPSolver (*fg);
|
||||||
solver.runSolver();
|
runSolver (solver, queryVars);
|
||||||
solver.printPosterioriOf (queryVars[0]);
|
delete fg;
|
||||||
} else {
|
} else {
|
||||||
Domain domain = BayesNet::getInstantiations(queryVars);
|
solver = new BPSolver (bn);
|
||||||
ParamSet params = solver.getJointDistribution (queryVars);
|
runSolver (solver, queryVars);
|
||||||
for (unsigned i = 0; i < params.size(); i++) {
|
|
||||||
cout << domain[i] << "\t" << params[i] << endl;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
bn.freeDistributions();
|
bn.freeDistributions();
|
||||||
}
|
}
|
||||||
@ -117,11 +142,11 @@ void
|
|||||||
markovNetwork (int argc, const char* argv[])
|
markovNetwork (int argc, const char* argv[])
|
||||||
{
|
{
|
||||||
FactorGraph fg (argv[1]);
|
FactorGraph fg (argv[1]);
|
||||||
//fg.printFactorGraph();
|
//fg.printGraphicalModel();
|
||||||
|
|
||||||
VarSet queryVars;
|
VarSet queryVars;
|
||||||
for (int i = 2; i < argc; i++) {
|
for (int i = 2; i < argc; i++) {
|
||||||
string arg = argv[i];
|
const string& arg = argv[i];
|
||||||
if (arg.find ('=') == std::string::npos) {
|
if (arg.find ('=') == std::string::npos) {
|
||||||
if (!Util::isInteger (arg)) {
|
if (!Util::isInteger (arg)) {
|
||||||
cerr << "error: `" << arg << "' " ;
|
cerr << "error: `" << arg << "' " ;
|
||||||
@ -129,16 +154,16 @@ markovNetwork (int argc, const char* argv[])
|
|||||||
cerr << endl;
|
cerr << endl;
|
||||||
exit (0);
|
exit (0);
|
||||||
}
|
}
|
||||||
unsigned varId;
|
Vid vid;
|
||||||
stringstream ss;
|
stringstream ss;
|
||||||
ss << arg;
|
ss << arg;
|
||||||
ss >> varId;
|
ss >> vid;
|
||||||
Variable* queryVar = fg.getVariableById (varId);
|
Variable* queryVar = fg.getFgVarNode (vid);
|
||||||
if (queryVar) {
|
if (queryVar) {
|
||||||
queryVars.push_back (queryVar);
|
queryVars.push_back (queryVar);
|
||||||
} else {
|
} else {
|
||||||
cerr << "error: there isn't a variable with " ;
|
cerr << "error: there isn't a variable with " ;
|
||||||
cerr << "`" << varId << "' as id" ;
|
cerr << "`" << vid << "' as id" ;
|
||||||
cerr << endl;
|
cerr << endl;
|
||||||
exit (0);
|
exit (0);
|
||||||
}
|
}
|
||||||
@ -160,11 +185,11 @@ markovNetwork (int argc, const char* argv[])
|
|||||||
cerr << endl;
|
cerr << endl;
|
||||||
exit (0);
|
exit (0);
|
||||||
}
|
}
|
||||||
unsigned varId;
|
Vid vid;
|
||||||
stringstream ss;
|
stringstream ss;
|
||||||
ss << arg.substr (0, pos);
|
ss << arg.substr (0, pos);
|
||||||
ss >> varId;
|
ss >> vid;
|
||||||
Variable* var = fg.getVariableById (varId);
|
Variable* var = fg.getFgVarNode (vid);
|
||||||
if (var) {
|
if (var) {
|
||||||
if (!Util::isInteger (arg.substr (pos + 1))) {
|
if (!Util::isInteger (arg.substr (pos + 1))) {
|
||||||
cerr << "error: `" << arg.substr (pos + 1) << "' " ;
|
cerr << "error: `" << arg.substr (pos + 1) << "' " ;
|
||||||
@ -176,7 +201,6 @@ markovNetwork (int argc, const char* argv[])
|
|||||||
stringstream ss;
|
stringstream ss;
|
||||||
ss << arg.substr (pos + 1);
|
ss << arg.substr (pos + 1);
|
||||||
ss >> stateIndex;
|
ss >> stateIndex;
|
||||||
cout << "si: " << stateIndex << endl;
|
|
||||||
if (var->isValidStateIndex (stateIndex)) {
|
if (var->isValidStateIndex (stateIndex)) {
|
||||||
var->setEvidence (stateIndex);
|
var->setEvidence (stateIndex);
|
||||||
} else {
|
} else {
|
||||||
@ -188,27 +212,35 @@ markovNetwork (int argc, const char* argv[])
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
cerr << "error: there isn't a variable with " ;
|
cerr << "error: there isn't a variable with " ;
|
||||||
cerr << "`" << varId << "' as id" ;
|
cerr << "`" << vid << "' as id" ;
|
||||||
cerr << endl;
|
cerr << endl;
|
||||||
exit (0);
|
exit (0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Solver* solver = new SPSolver (fg);
|
||||||
SPSolver solver (fg);
|
runSolver (solver, queryVars);
|
||||||
if (queryVars.size() == 0) {
|
fg.freeDistributions();
|
||||||
solver.runSolver();
|
}
|
||||||
solver.printAllPosterioris();
|
|
||||||
} else if (queryVars.size() == 1) {
|
|
||||||
solver.runSolver();
|
|
||||||
solver.printPosterioriOf (queryVars[0]);
|
void
|
||||||
} else {
|
runSolver (Solver* solver, const VarSet& queryVars)
|
||||||
assert (false); //FIXME
|
{
|
||||||
//Domain domain = BayesNet::getInstantiations(queryVars);
|
VidSet vids;
|
||||||
//ParamSet params = solver.getJointDistribution (queryVars);
|
for (unsigned i = 0; i < queryVars.size(); i++) {
|
||||||
//for (unsigned i = 0; i < params.size(); i++) {
|
vids.push_back (queryVars[i]->getVarId());
|
||||||
// cout << domain[i] << "\t" << params[i] << endl;
|
}
|
||||||
//}
|
if (queryVars.size() == 0) {
|
||||||
}
|
solver->runSolver();
|
||||||
|
solver->printAllPosterioris();
|
||||||
|
} else if (queryVars.size() == 1) {
|
||||||
|
solver->runSolver();
|
||||||
|
solver->printPosterioriOf (vids[0]);
|
||||||
|
} else {
|
||||||
|
solver->printJointDistributionOf (vids);
|
||||||
|
}
|
||||||
|
delete solver;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,41 +1,39 @@
|
|||||||
#include <cstdlib>
|
#include <cstdlib>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <vector>
|
|
||||||
#include <string>
|
|
||||||
|
|
||||||
#include <YapInterface.h>
|
#include <YapInterface.h>
|
||||||
|
|
||||||
#include "callgrind.h"
|
|
||||||
|
|
||||||
#include "BayesNet.h"
|
#include "BayesNet.h"
|
||||||
#include "BayesNode.h"
|
#include "FactorGraph.h"
|
||||||
#include "BPSolver.h"
|
#include "BPSolver.h"
|
||||||
|
#include "SPSolver.h"
|
||||||
|
#include "CountingBP.h"
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
|
|
||||||
int
|
int
|
||||||
createNetwork (void)
|
createNetwork (void)
|
||||||
{
|
{
|
||||||
Statistics::numCreatedNets ++;
|
//Statistics::numCreatedNets ++;
|
||||||
cout << "creating network number " << Statistics::numCreatedNets << endl;
|
//cout << "creating network number " << Statistics::numCreatedNets << endl;
|
||||||
if (Statistics::numCreatedNets == 1) {
|
|
||||||
//CALLGRIND_START_INSTRUMENTATION;
|
|
||||||
}
|
|
||||||
BayesNet* bn = new BayesNet();
|
|
||||||
|
|
||||||
|
BayesNet* bn = new BayesNet();
|
||||||
YAP_Term varList = YAP_ARG1;
|
YAP_Term varList = YAP_ARG1;
|
||||||
while (varList != YAP_TermNil()) {
|
while (varList != YAP_TermNil()) {
|
||||||
YAP_Term var = YAP_HeadOfTerm (varList);
|
YAP_Term var = YAP_HeadOfTerm (varList);
|
||||||
unsigned varId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (1, var));
|
Vid vid = (Vid) YAP_IntOfTerm (YAP_ArgOfTerm (1, var));
|
||||||
unsigned dsize = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (2, var));
|
unsigned dsize = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (2, var));
|
||||||
int evidence = (int) YAP_IntOfTerm (YAP_ArgOfTerm (3, var));
|
int evidence = (int) YAP_IntOfTerm (YAP_ArgOfTerm (3, var));
|
||||||
YAP_Term parentL = YAP_ArgOfTerm (4, var);
|
YAP_Term parentL = YAP_ArgOfTerm (4, var);
|
||||||
unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (5, var));
|
unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (5, var));
|
||||||
NodeSet parents;
|
BnNodeSet parents;
|
||||||
while (parentL != YAP_TermNil()) {
|
while (parentL != YAP_TermNil()) {
|
||||||
unsigned parentId = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (parentL));
|
unsigned parentId = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (parentL));
|
||||||
BayesNode* parent = bn->getNode (parentId);
|
BayesNode* parent = bn->getBayesNode (parentId);
|
||||||
if (!parent) {
|
if (!parent) {
|
||||||
parent = bn->addNode (parentId);
|
parent = bn->addNode (parentId);
|
||||||
}
|
}
|
||||||
@ -47,23 +45,20 @@ createNetwork (void)
|
|||||||
dist = new Distribution (distId);
|
dist = new Distribution (distId);
|
||||||
bn->addDistribution (dist);
|
bn->addDistribution (dist);
|
||||||
}
|
}
|
||||||
BayesNode* node = bn->getNode (varId);
|
BayesNode* node = bn->getBayesNode (vid);
|
||||||
if (node) {
|
if (node) {
|
||||||
node->setData (dsize, evidence, parents, dist);
|
node->setData (dsize, evidence, parents, dist);
|
||||||
} else {
|
} else {
|
||||||
bn->addNode (varId, dsize, evidence, parents, dist);
|
bn->addNode (vid, dsize, evidence, parents, dist);
|
||||||
}
|
}
|
||||||
varList = YAP_TailOfTerm (varList);
|
varList = YAP_TailOfTerm (varList);
|
||||||
}
|
}
|
||||||
bn->setIndexes();
|
bn->setIndexes();
|
||||||
|
|
||||||
if (Statistics::numCreatedNets == 1688) {
|
// if (Statistics::numCreatedNets == 1688) {
|
||||||
Statistics::writeStats();
|
// Statistics::writeStats();
|
||||||
//Statistics::writeStats();
|
// exit (0);
|
||||||
//CALLGRIND_STOP_INSTRUMENTATION;
|
// }
|
||||||
//CALLGRIND_DUMP_STATS;
|
|
||||||
//exit (0);
|
|
||||||
}
|
|
||||||
YAP_Int p = (YAP_Int) (bn);
|
YAP_Int p = (YAP_Int) (bn);
|
||||||
return YAP_Unify (YAP_MkIntTerm (p), YAP_ARG2);
|
return YAP_Unify (YAP_MkIntTerm (p), YAP_ARG2);
|
||||||
}
|
}
|
||||||
@ -73,20 +68,20 @@ createNetwork (void)
|
|||||||
int
|
int
|
||||||
setExtraVarsInfo (void)
|
setExtraVarsInfo (void)
|
||||||
{
|
{
|
||||||
BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
|
BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
|
||||||
YAP_Term varsInfoL = YAP_ARG2;
|
YAP_Term varsInfoL = YAP_ARG2;
|
||||||
while (varsInfoL != YAP_TermNil()) {
|
while (varsInfoL != YAP_TermNil()) {
|
||||||
YAP_Term head = YAP_HeadOfTerm (varsInfoL);
|
YAP_Term head = YAP_HeadOfTerm (varsInfoL);
|
||||||
unsigned varId = YAP_IntOfTerm (YAP_ArgOfTerm (1, head));
|
Vid vid = YAP_IntOfTerm (YAP_ArgOfTerm (1, head));
|
||||||
YAP_Atom label = YAP_AtomOfTerm (YAP_ArgOfTerm (2, head));
|
YAP_Atom label = YAP_AtomOfTerm (YAP_ArgOfTerm (2, head));
|
||||||
YAP_Term domainL = YAP_ArgOfTerm (3, head);
|
YAP_Term domainL = YAP_ArgOfTerm (3, head);
|
||||||
Domain domain;
|
Domain domain;
|
||||||
while (domainL != YAP_TermNil()) {
|
while (domainL != YAP_TermNil()) {
|
||||||
YAP_Atom atom = YAP_AtomOfTerm (YAP_HeadOfTerm (domainL));
|
YAP_Atom atom = YAP_AtomOfTerm (YAP_HeadOfTerm (domainL));
|
||||||
domain.push_back ((char*) YAP_AtomName (atom));
|
domain.push_back ((char*) YAP_AtomName (atom));
|
||||||
domainL = YAP_TailOfTerm (domainL);
|
domainL = YAP_TailOfTerm (domainL);
|
||||||
}
|
}
|
||||||
BayesNode* node = bn->getNode (varId);
|
BayesNode* node = bn->getBayesNode (vid);
|
||||||
assert (node);
|
assert (node);
|
||||||
node->setLabel ((char*) YAP_AtomName (label));
|
node->setLabel ((char*) YAP_AtomName (label));
|
||||||
node->setDomain (domain);
|
node->setDomain (domain);
|
||||||
@ -100,8 +95,8 @@ setExtraVarsInfo (void)
|
|||||||
int
|
int
|
||||||
setParameters (void)
|
setParameters (void)
|
||||||
{
|
{
|
||||||
BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
|
BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
|
||||||
YAP_Term distList = YAP_ARG2;
|
YAP_Term distList = YAP_ARG2;
|
||||||
while (distList != YAP_TermNil()) {
|
while (distList != YAP_TermNil()) {
|
||||||
YAP_Term dist = YAP_HeadOfTerm (distList);
|
YAP_Term dist = YAP_HeadOfTerm (distList);
|
||||||
unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (1, dist));
|
unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (1, dist));
|
||||||
@ -112,6 +107,11 @@ setParameters (void)
|
|||||||
paramL = YAP_TailOfTerm (paramL);
|
paramL = YAP_TailOfTerm (paramL);
|
||||||
}
|
}
|
||||||
bn->getDistribution(distId)->updateParameters(params);
|
bn->getDistribution(distId)->updateParameters(params);
|
||||||
|
if (Statistics::numCreatedNets == 4) {
|
||||||
|
cout << "dist " << distId << " parameters:" ;
|
||||||
|
cout << Util::parametersToString (params);
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
distList = YAP_TailOfTerm (distList);
|
distList = YAP_TailOfTerm (distList);
|
||||||
}
|
}
|
||||||
return TRUE;
|
return TRUE;
|
||||||
@ -122,84 +122,126 @@ setParameters (void)
|
|||||||
int
|
int
|
||||||
runSolver (void)
|
runSolver (void)
|
||||||
{
|
{
|
||||||
BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
|
BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
|
||||||
YAP_Term taskList = YAP_ARG2;
|
YAP_Term taskList = YAP_ARG2;
|
||||||
|
vector<VidSet> tasks;
|
||||||
vector<NodeSet> tasks;
|
VidSet marginalVids;
|
||||||
NodeSet marginalVars;
|
|
||||||
|
|
||||||
while (taskList != YAP_TermNil()) {
|
while (taskList != YAP_TermNil()) {
|
||||||
if (YAP_IsPairTerm (YAP_HeadOfTerm (taskList))) {
|
if (YAP_IsPairTerm (YAP_HeadOfTerm (taskList))) {
|
||||||
NodeSet jointVars;
|
VidSet jointVids;
|
||||||
YAP_Term jointList = YAP_HeadOfTerm (taskList);
|
YAP_Term jointList = YAP_HeadOfTerm (taskList);
|
||||||
while (jointList != YAP_TermNil()) {
|
while (jointList != YAP_TermNil()) {
|
||||||
unsigned varId = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (jointList));
|
Vid vid = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (jointList));
|
||||||
assert (bn->getNode (varId));
|
assert (bn->getBayesNode (vid));
|
||||||
jointVars.push_back (bn->getNode (varId));
|
jointVids.push_back (vid);
|
||||||
jointList = YAP_TailOfTerm (jointList);
|
jointList = YAP_TailOfTerm (jointList);
|
||||||
}
|
}
|
||||||
tasks.push_back (jointVars);
|
tasks.push_back (jointVids);
|
||||||
} else {
|
} else {
|
||||||
unsigned varId = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (taskList));
|
Vid vid = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (taskList));
|
||||||
BayesNode* node = bn->getNode (varId);
|
assert (bn->getBayesNode (vid));
|
||||||
assert (node);
|
tasks.push_back (VidSet() = {vid});
|
||||||
tasks.push_back (NodeSet() = {node});
|
marginalVids.push_back (vid);
|
||||||
marginalVars.push_back (node);
|
|
||||||
}
|
}
|
||||||
taskList = YAP_TailOfTerm (taskList);
|
taskList = YAP_TailOfTerm (taskList);
|
||||||
}
|
}
|
||||||
/*
|
|
||||||
cout << "tasks to resolve:" << endl;
|
|
||||||
for (unsigned i = 0; i < tasks.size(); i++) {
|
|
||||||
cout << "i" << ": " ;
|
|
||||||
if (tasks[i].size() == 1) {
|
|
||||||
cout << tasks[i][0]->getVarId() << endl;
|
|
||||||
} else {
|
|
||||||
for (unsigned j = 0; j < tasks[i].size(); j++) {
|
|
||||||
cout << tasks[i][j]->getVarId() << " " ;
|
|
||||||
}
|
|
||||||
cout << endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
|
|
||||||
cerr << "prunning now..." << endl;
|
// cout << "inference tasks:" << endl;
|
||||||
BayesNet* prunedNet = bn->pruneNetwork (marginalVars);
|
// for (unsigned i = 0; i < tasks.size(); i++) {
|
||||||
bn->printNetworkToFile ("net.txt");
|
// cout << "i" << ": " ;
|
||||||
BPSolver solver (*prunedNet);
|
// if (tasks[i].size() == 1) {
|
||||||
cerr << "solving marginals now..." << endl;
|
// cout << tasks[i][0] << endl;
|
||||||
solver.runSolver();
|
// } else {
|
||||||
cerr << "calculating joints now ..." << endl;
|
// for (unsigned j = 0; j < tasks[i].size(); j++) {
|
||||||
|
// cout << tasks[i][j] << " " ;
|
||||||
|
// }
|
||||||
|
// cout << endl;
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
Solver* solver = 0;
|
||||||
|
GraphicalModel* gm = 0;
|
||||||
|
VidSet vids;
|
||||||
|
const BnNodeSet& nodes = bn->getBayesNodes();
|
||||||
|
for (unsigned i = 0; i < nodes.size(); i++) {
|
||||||
|
vids.push_back (nodes[i]->getVarId());
|
||||||
|
}
|
||||||
|
if (marginalVids.size() != 0) {
|
||||||
|
bn->exportToDotFormat ("bn unbayes.dot");
|
||||||
|
BayesNet* mrn = bn->getMinimalRequesiteNetwork (marginalVids);
|
||||||
|
mrn->exportToDotFormat ("bn bayes.dot");
|
||||||
|
//BayesNet* mrn = bn->getMinimalRequesiteNetwork (vids);
|
||||||
|
if (SolverOptions::convertBn2Fg) {
|
||||||
|
gm = new FactorGraph (*mrn);
|
||||||
|
if (SolverOptions::compressFactorGraph) {
|
||||||
|
solver = new CountingBP (*static_cast<FactorGraph*> (gm));
|
||||||
|
} else {
|
||||||
|
solver = new SPSolver (*static_cast<FactorGraph*> (gm));
|
||||||
|
}
|
||||||
|
if (SolverOptions::runBayesBall) {
|
||||||
|
delete mrn;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
gm = mrn;
|
||||||
|
solver = new BPSolver (*static_cast<BayesNet*> (gm));
|
||||||
|
}
|
||||||
|
solver->runSolver();
|
||||||
|
}
|
||||||
|
|
||||||
vector<ParamSet> results;
|
vector<ParamSet> results;
|
||||||
results.reserve (tasks.size());
|
results.reserve (tasks.size());
|
||||||
for (unsigned i = 0; i < tasks.size(); i++) {
|
for (unsigned i = 0; i < tasks.size(); i++) {
|
||||||
if (tasks[i].size() == 1) {
|
if (tasks[i].size() == 1) {
|
||||||
BayesNode* node = prunedNet->getNode (tasks[i][0]->getVarId());
|
results.push_back (solver->getPosterioriOf (tasks[i][0]));
|
||||||
results.push_back (solver.getPosterioriOf (node));
|
|
||||||
} else {
|
} else {
|
||||||
BPSolver solver2 (*bn);
|
static int count = 0;
|
||||||
cout << "calculating an join dist on: " ;
|
cout << "calculating joint... " << count ++ << endl;
|
||||||
for (unsigned j = 0; j < tasks[i].size(); j++) {
|
//if (count == 5225) {
|
||||||
cout << tasks[i][j]->getVarId() << " " ;
|
// Statistics::printCompressingStats ("compressing.stats");
|
||||||
|
//}
|
||||||
|
Solver* solver2 = 0;
|
||||||
|
GraphicalModel* gm2 = 0;
|
||||||
|
bn->exportToDotFormat ("joint.dot");
|
||||||
|
BayesNet* mrn2;
|
||||||
|
if (SolverOptions::runBayesBall) {
|
||||||
|
mrn2 = bn->getMinimalRequesiteNetwork (tasks[i]);
|
||||||
|
} else {
|
||||||
|
mrn2 = bn;
|
||||||
}
|
}
|
||||||
cout << "..." << endl;
|
if (SolverOptions::convertBn2Fg) {
|
||||||
results.push_back (solver2.getJointDistribution (tasks[i]));
|
gm2 = new FactorGraph (*mrn2);
|
||||||
|
if (SolverOptions::compressFactorGraph) {
|
||||||
|
solver2 = new CountingBP (*static_cast<FactorGraph*> (gm2));
|
||||||
|
} else {
|
||||||
|
solver2 = new SPSolver (*static_cast<FactorGraph*> (gm2));
|
||||||
|
}
|
||||||
|
if (SolverOptions::runBayesBall) {
|
||||||
|
delete mrn2;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
gm2 = mrn2;
|
||||||
|
solver2 = new BPSolver (*static_cast<BayesNet*> (gm2));
|
||||||
|
}
|
||||||
|
results.push_back (solver2->getJointDistributionOf (tasks[i]));
|
||||||
|
delete solver2;
|
||||||
|
delete gm2;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
delete prunedNet;
|
delete solver;
|
||||||
|
delete gm;
|
||||||
|
|
||||||
YAP_Term list = YAP_TermNil();
|
YAP_Term list = YAP_TermNil();
|
||||||
for (int i = results.size() - 1; i >= 0; i--) {
|
for (int i = results.size() - 1; i >= 0; i--) {
|
||||||
const ParamSet& beliefs = results[i];
|
const ParamSet& beliefs = results[i];
|
||||||
YAP_Term queryBeliefsL = YAP_TermNil();
|
YAP_Term queryBeliefsL = YAP_TermNil();
|
||||||
for (int j = beliefs.size() - 1; j >= 0; j--) {
|
for (int j = beliefs.size() - 1; j >= 0; j--) {
|
||||||
YAP_Int sl1 = YAP_InitSlot(list);
|
YAP_Int sl1 = YAP_InitSlot (list);
|
||||||
YAP_Term belief = YAP_MkFloatTerm (beliefs[j]);
|
YAP_Term belief = YAP_MkFloatTerm (beliefs[j]);
|
||||||
queryBeliefsL = YAP_MkPairTerm (belief, queryBeliefsL);
|
queryBeliefsL = YAP_MkPairTerm (belief, queryBeliefsL);
|
||||||
list = YAP_GetFromSlot(sl1);
|
list = YAP_GetFromSlot (sl1);
|
||||||
YAP_RecoverSlots(1);
|
YAP_RecoverSlots (1);
|
||||||
}
|
}
|
||||||
list = YAP_MkPairTerm (queryBeliefsL, list);
|
list = YAP_MkPairTerm (queryBeliefsL, list);
|
||||||
}
|
}
|
||||||
@ -210,8 +252,9 @@ runSolver (void)
|
|||||||
|
|
||||||
|
|
||||||
int
|
int
|
||||||
deleteBayesNet (void)
|
freeBayesNetwork (void)
|
||||||
{
|
{
|
||||||
|
//Statistics::printCompressingStats ("../../compressing.stats");
|
||||||
BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
|
BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
|
||||||
bn->freeDistributions();
|
bn->freeDistributions();
|
||||||
delete bn;
|
delete bn;
|
||||||
@ -223,10 +266,10 @@ deleteBayesNet (void)
|
|||||||
extern "C" void
|
extern "C" void
|
||||||
init_predicates (void)
|
init_predicates (void)
|
||||||
{
|
{
|
||||||
YAP_UserCPredicate ("create_network", createNetwork, 2);
|
YAP_UserCPredicate ("create_network", createNetwork, 2);
|
||||||
YAP_UserCPredicate ("set_extra_vars_info", setExtraVarsInfo, 2);
|
YAP_UserCPredicate ("set_extra_vars_info", setExtraVarsInfo, 2);
|
||||||
YAP_UserCPredicate ("set_parameters", setParameters, 2);
|
YAP_UserCPredicate ("set_parameters", setParameters, 2);
|
||||||
YAP_UserCPredicate ("run_solver", runSolver, 3);
|
YAP_UserCPredicate ("run_solver", runSolver, 3);
|
||||||
YAP_UserCPredicate ("delete_bayes_net", deleteBayesNet, 1);
|
YAP_UserCPredicate ("free_bayesian_network", freeBayesNetwork, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
278
packages/CLPBN/clpbn/bp/LiftedFG.cpp
Normal file
278
packages/CLPBN/clpbn/bp/LiftedFG.cpp
Normal file
@ -0,0 +1,278 @@
|
|||||||
|
|
||||||
|
#include "LiftedFG.h"
|
||||||
|
#include "FgVarNode.h"
|
||||||
|
#include "Factor.h"
|
||||||
|
#include "Distribution.h"
|
||||||
|
|
||||||
|
LiftedFG::LiftedFG (const FactorGraph& fg)
|
||||||
|
{
|
||||||
|
groundFg_ = &fg;
|
||||||
|
freeColor_ = 0;
|
||||||
|
|
||||||
|
const FgVarSet& varNodes = fg.getFgVarNodes();
|
||||||
|
const FactorSet& factors = fg.getFactors();
|
||||||
|
varColors_.resize (varNodes.size());
|
||||||
|
factorColors_.resize (factors.size());
|
||||||
|
for (unsigned i = 0; i < factors.size(); i++) {
|
||||||
|
factors[i]->setIndex (i);
|
||||||
|
}
|
||||||
|
|
||||||
|
// create the initial variable colors
|
||||||
|
VarColorMap colorMap;
|
||||||
|
for (unsigned i = 0; i < varNodes.size(); i++) {
|
||||||
|
unsigned dsize = varNodes[i]->getDomainSize();
|
||||||
|
VarColorMap::iterator it = colorMap.find (dsize);
|
||||||
|
if (it == colorMap.end()) {
|
||||||
|
it = colorMap.insert (make_pair (
|
||||||
|
dsize, vector<Color> (dsize + 1,-1))).first;
|
||||||
|
}
|
||||||
|
unsigned idx;
|
||||||
|
if (varNodes[i]->hasEvidence()) {
|
||||||
|
idx = varNodes[i]->getEvidence();
|
||||||
|
} else {
|
||||||
|
idx = dsize;
|
||||||
|
}
|
||||||
|
vector<Color>& stateColors = it->second;
|
||||||
|
if (stateColors[idx] == -1) {
|
||||||
|
stateColors[idx] = getFreeColor();
|
||||||
|
}
|
||||||
|
setColor (varNodes[i], stateColors[idx]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// create the initial factor colors
|
||||||
|
DistColorMap distColors;
|
||||||
|
for (unsigned i = 0; i < factors.size(); i++) {
|
||||||
|
Distribution* dist = factors[i]->getDistribution();
|
||||||
|
DistColorMap::iterator it = distColors.find (dist);
|
||||||
|
if (it == distColors.end()) {
|
||||||
|
it = distColors.insert (make_pair (dist, getFreeColor())).first;
|
||||||
|
}
|
||||||
|
setColor (factors[i], it->second);
|
||||||
|
}
|
||||||
|
|
||||||
|
VarSignMap varGroups;
|
||||||
|
FactorSignMap factorGroups;
|
||||||
|
bool groupsHaveChanged = true;
|
||||||
|
unsigned nIter = 0;
|
||||||
|
while (groupsHaveChanged || nIter == 1) {
|
||||||
|
nIter ++;
|
||||||
|
if (Statistics::numCreatedNets == 4) {
|
||||||
|
cout << "--------------------------------------------" << endl;
|
||||||
|
cout << "Iteration " << nIter << endl;
|
||||||
|
cout << "--------------------------------------------" << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
unsigned prevFactorGroupsSize = factorGroups.size();
|
||||||
|
factorGroups.clear();
|
||||||
|
// set a new color to the factors with the same signature
|
||||||
|
for (unsigned i = 0; i < factors.size(); i++) {
|
||||||
|
const string& signatureId = getSignatureId (factors[i]);
|
||||||
|
// cout << factors[i]->getLabel() << " signature: " ;
|
||||||
|
// cout<< signatureId << endl;
|
||||||
|
FactorSignMap::iterator it = factorGroups.find (signatureId);
|
||||||
|
if (it == factorGroups.end()) {
|
||||||
|
it = factorGroups.insert (make_pair (signatureId, FactorSet())).first;
|
||||||
|
}
|
||||||
|
it->second.push_back (factors[i]);
|
||||||
|
}
|
||||||
|
if (nIter > 0)
|
||||||
|
for (FactorSignMap::iterator it = factorGroups.begin();
|
||||||
|
it != factorGroups.end(); it++) {
|
||||||
|
Color newColor = getFreeColor();
|
||||||
|
FactorSet& groupMembers = it->second;
|
||||||
|
for (unsigned i = 0; i < groupMembers.size(); i++) {
|
||||||
|
setColor (groupMembers[i], newColor);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// set a new color to the variables with the same signature
|
||||||
|
unsigned prevVarGroupsSize = varGroups.size();
|
||||||
|
varGroups.clear();
|
||||||
|
for (unsigned i = 0; i < varNodes.size(); i++) {
|
||||||
|
const string& signatureId = getSignatureId (varNodes[i]);
|
||||||
|
VarSignMap::iterator it = varGroups.find (signatureId);
|
||||||
|
// cout << varNodes[i]->getLabel() << " signature: " ;
|
||||||
|
// cout << signatureId << endl;
|
||||||
|
if (it == varGroups.end()) {
|
||||||
|
it = varGroups.insert (make_pair (signatureId, FgVarSet())).first;
|
||||||
|
}
|
||||||
|
it->second.push_back (varNodes[i]);
|
||||||
|
}
|
||||||
|
if (nIter > 0)
|
||||||
|
for (VarSignMap::iterator it = varGroups.begin();
|
||||||
|
it != varGroups.end(); it++) {
|
||||||
|
Color newColor = getFreeColor();
|
||||||
|
FgVarSet& groupMembers = it->second;
|
||||||
|
for (unsigned i = 0; i < groupMembers.size(); i++) {
|
||||||
|
setColor (groupMembers[i], newColor);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//if (nIter >= 3) cout << "bigger than three: " << nIter << endl;
|
||||||
|
groupsHaveChanged = prevVarGroupsSize != varGroups.size()
|
||||||
|
|| prevFactorGroupsSize != factorGroups.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
printGroups (varGroups, factorGroups);
|
||||||
|
for (VarSignMap::iterator it = varGroups.begin();
|
||||||
|
it != varGroups.end(); it++) {
|
||||||
|
CFgVarSet vars = it->second;
|
||||||
|
VarCluster* vc = new VarCluster (vars);
|
||||||
|
for (unsigned i = 0; i < vars.size(); i++) {
|
||||||
|
vid2VarCluster_.insert (make_pair (vars[i]->getVarId(), vc));
|
||||||
|
}
|
||||||
|
varClusters_.push_back (vc);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (FactorSignMap::iterator it = factorGroups.begin();
|
||||||
|
it != factorGroups.end(); it++) {
|
||||||
|
VarClusterSet varClusters;
|
||||||
|
Factor* groundFactor = it->second[0];
|
||||||
|
FgVarSet groundVars = groundFactor->getFgVarNodes();
|
||||||
|
for (unsigned i = 0; i < groundVars.size(); i++) {
|
||||||
|
Vid vid = groundVars[i]->getVarId();
|
||||||
|
varClusters.push_back (vid2VarCluster_.find (vid)->second);
|
||||||
|
}
|
||||||
|
factorClusters_.push_back (new FactorCluster (it->second, varClusters));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
LiftedFG::~LiftedFG (void)
|
||||||
|
{
|
||||||
|
for (unsigned i = 0; i < varClusters_.size(); i++) {
|
||||||
|
delete varClusters_[i];
|
||||||
|
}
|
||||||
|
for (unsigned i = 0; i < factorClusters_.size(); i++) {
|
||||||
|
delete factorClusters_[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
string
|
||||||
|
LiftedFG::getSignatureId (FgVarNode* var) const
|
||||||
|
{
|
||||||
|
stringstream ss;
|
||||||
|
CFactorSet myFactors = var->getFactors();
|
||||||
|
ss << myFactors.size();
|
||||||
|
for (unsigned i = 0; i < myFactors.size(); i++) {
|
||||||
|
ss << "." << getColor (myFactors[i]);
|
||||||
|
ss << "." << myFactors[i]->getIndexOf(var);
|
||||||
|
}
|
||||||
|
ss << "." << getColor (var);
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
string
|
||||||
|
LiftedFG::getSignatureId (Factor* factor) const
|
||||||
|
{
|
||||||
|
stringstream ss;
|
||||||
|
CFgVarSet myVars = factor->getFgVarNodes();
|
||||||
|
ss << myVars.size();
|
||||||
|
for (unsigned i = 0; i < myVars.size(); i++) {
|
||||||
|
ss << "." << getColor (myVars[i]);
|
||||||
|
}
|
||||||
|
ss << "." << getColor (factor);
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
FactorGraph*
|
||||||
|
LiftedFG::getCompressedFactorGraph (void)
|
||||||
|
{
|
||||||
|
FactorGraph* fg = new FactorGraph();
|
||||||
|
for (unsigned i = 0; i < varClusters_.size(); i++) {
|
||||||
|
FgVarNode* var = varClusters_[i]->getGroundFgVarNodes()[0];
|
||||||
|
FgVarNode* newVar = new FgVarNode (var);
|
||||||
|
newVar->setIndex (i);
|
||||||
|
varClusters_[i]->setRepresentativeVariable (newVar);
|
||||||
|
fg->addVariable (newVar);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (unsigned i = 0; i < factorClusters_.size(); i++) {
|
||||||
|
FgVarSet myGroundVars;
|
||||||
|
const VarClusterSet& myVarClusters = factorClusters_[i]->getVarClusters();
|
||||||
|
for (unsigned j = 0; j < myVarClusters.size(); j++) {
|
||||||
|
myGroundVars.push_back (myVarClusters[j]->getRepresentativeVariable());
|
||||||
|
}
|
||||||
|
Factor* newFactor = new Factor (myGroundVars,
|
||||||
|
factorClusters_[i]->getGroundFactors()[0]->getDistribution());
|
||||||
|
factorClusters_[i]->setRepresentativeFactor (newFactor);
|
||||||
|
fg->addFactor (newFactor);
|
||||||
|
}
|
||||||
|
return fg;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
unsigned
|
||||||
|
LiftedFG::getGroundEdgeCount (FactorCluster* fc, VarCluster* vc) const
|
||||||
|
{
|
||||||
|
CFactorSet clusterGroundFactors = fc->getGroundFactors();
|
||||||
|
FgVarNode* var = vc->getGroundFgVarNodes()[0];
|
||||||
|
unsigned count = 0;
|
||||||
|
for (unsigned i = 0; i < clusterGroundFactors.size(); i++) {
|
||||||
|
if (clusterGroundFactors[i]->getIndexOf (var) != -1) {
|
||||||
|
count ++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
/*
|
||||||
|
CFgVarSet vars = vc->getGroundFgVarNodes();
|
||||||
|
for (unsigned i = 1; i < vars.size(); i++) {
|
||||||
|
FgVarNode* var = vc->getGroundFgVarNodes()[i];
|
||||||
|
unsigned count2 = 0;
|
||||||
|
for (unsigned i = 0; i < clusterGroundFactors.size(); i++) {
|
||||||
|
if (clusterGroundFactors[i]->getIndexOf (var) != -1) {
|
||||||
|
count2 ++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (count != count2) { cout << "oops!" << endl; abort(); }
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
return count;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
LiftedFG::printGroups (const VarSignMap& varGroups,
|
||||||
|
const FactorSignMap& factorGroups) const
|
||||||
|
{
|
||||||
|
cout << "variable groups:" << endl;
|
||||||
|
unsigned count = 0;
|
||||||
|
for (VarSignMap::const_iterator it = varGroups.begin();
|
||||||
|
it != varGroups.end(); it++) {
|
||||||
|
const FgVarSet& groupMembers = it->second;
|
||||||
|
if (groupMembers.size() > 0) {
|
||||||
|
cout << ++count << ": " ;
|
||||||
|
//if (groupMembers.size() > 1) {
|
||||||
|
for (unsigned i = 0; i < groupMembers.size(); i++) {
|
||||||
|
cout << groupMembers[i]->getLabel() << " " ;
|
||||||
|
}
|
||||||
|
//}
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cout << endl;
|
||||||
|
cout << "factor groups:" << endl;
|
||||||
|
count = 0;
|
||||||
|
for (FactorSignMap::const_iterator it = factorGroups.begin();
|
||||||
|
it != factorGroups.end(); it++) {
|
||||||
|
const FactorSet& groupMembers = it->second;
|
||||||
|
if (groupMembers.size() > 0) {
|
||||||
|
cout << ++count << ": " ;
|
||||||
|
//if (groupMembers.size() > 1) {
|
||||||
|
for (unsigned i = 0; i < groupMembers.size(); i++) {
|
||||||
|
cout << groupMembers[i]->getLabel() << " " ;
|
||||||
|
}
|
||||||
|
//}
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
152
packages/CLPBN/clpbn/bp/LiftedFG.h
Normal file
152
packages/CLPBN/clpbn/bp/LiftedFG.h
Normal file
@ -0,0 +1,152 @@
|
|||||||
|
#ifndef BP_LIFTED_FG_H
|
||||||
|
#define BP_LIFTED_FG_H
|
||||||
|
|
||||||
|
#include <unordered_map>
|
||||||
|
|
||||||
|
#include "FactorGraph.h"
|
||||||
|
#include "FgVarNode.h"
|
||||||
|
#include "Factor.h"
|
||||||
|
#include "Shared.h"
|
||||||
|
|
||||||
|
class VarCluster;
|
||||||
|
class FactorCluster;
|
||||||
|
class Distribution;
|
||||||
|
|
||||||
|
typedef long Color;
|
||||||
|
typedef vector<Color> Signature;
|
||||||
|
typedef vector<VarCluster*> VarClusterSet;
|
||||||
|
typedef vector<FactorCluster*> FactorClusterSet;
|
||||||
|
|
||||||
|
typedef map<string, FgVarSet> VarSignMap;
|
||||||
|
typedef map<string, FactorSet> FactorSignMap;
|
||||||
|
|
||||||
|
typedef map<unsigned, vector<Color> > VarColorMap;
|
||||||
|
typedef map<Distribution*, Color> DistColorMap;
|
||||||
|
|
||||||
|
typedef map<Vid, VarCluster*> Vid2VarCluster;
|
||||||
|
|
||||||
|
|
||||||
|
class VarCluster
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
VarCluster (CFgVarSet vs)
|
||||||
|
{
|
||||||
|
for (unsigned i = 0; i < vs.size(); i++) {
|
||||||
|
groundVars_.push_back (vs[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void addFactorCluster (FactorCluster* fc)
|
||||||
|
{
|
||||||
|
factorClusters_.push_back (fc);
|
||||||
|
}
|
||||||
|
|
||||||
|
const FactorClusterSet& getFactorClusters (void) const
|
||||||
|
{
|
||||||
|
return factorClusters_;
|
||||||
|
}
|
||||||
|
|
||||||
|
FgVarNode* getRepresentativeVariable (void) const { return representVar_; }
|
||||||
|
void setRepresentativeVariable (FgVarNode* v) { representVar_ = v; }
|
||||||
|
CFgVarSet getGroundFgVarNodes (void) const { return groundVars_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
FgVarSet groundVars_;
|
||||||
|
FactorClusterSet factorClusters_;
|
||||||
|
FgVarNode* representVar_;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
class FactorCluster
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
FactorCluster (CFactorSet groundFactors, const VarClusterSet& vcs)
|
||||||
|
{
|
||||||
|
groundFactors_ = groundFactors;
|
||||||
|
varClusters_ = vcs;
|
||||||
|
for (unsigned i = 0; i < varClusters_.size(); i++) {
|
||||||
|
varClusters_[i]->addFactorCluster (this);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const VarClusterSet& getVarClusters (void) const
|
||||||
|
{
|
||||||
|
return varClusters_;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool containsGround (const Factor* f)
|
||||||
|
{
|
||||||
|
for (unsigned i = 0; i < groundFactors_.size(); i++) {
|
||||||
|
if (groundFactors_[i] == f) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
Factor* getRepresentativeFactor (void) const { return representFactor_; }
|
||||||
|
void setRepresentativeFactor (Factor* f) { representFactor_ = f; }
|
||||||
|
CFactorSet getGroundFactors (void) const { return groundFactors_; }
|
||||||
|
|
||||||
|
|
||||||
|
private:
|
||||||
|
FactorSet groundFactors_;
|
||||||
|
VarClusterSet varClusters_;
|
||||||
|
Factor* representFactor_;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
class LiftedFG
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
LiftedFG (const FactorGraph&);
|
||||||
|
~LiftedFG (void);
|
||||||
|
|
||||||
|
FactorGraph* getCompressedFactorGraph (void);
|
||||||
|
unsigned getGroundEdgeCount (FactorCluster*, VarCluster*) const;
|
||||||
|
void printGroups (const VarSignMap& varGroups,
|
||||||
|
const FactorSignMap& factorGroups) const;
|
||||||
|
|
||||||
|
FgVarNode* getEquivalentVariable (Vid vid)
|
||||||
|
{
|
||||||
|
VarCluster* vc = vid2VarCluster_.find (vid)->second;
|
||||||
|
return vc->getRepresentativeVariable();
|
||||||
|
}
|
||||||
|
|
||||||
|
const VarClusterSet& getVariableClusters (void) { return varClusters_; }
|
||||||
|
const FactorClusterSet& getFactorClusters (void) { return factorClusters_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
string getSignatureId (FgVarNode*) const;
|
||||||
|
string getSignatureId (Factor*) const;
|
||||||
|
|
||||||
|
Color getFreeColor (void) { return ++freeColor_ -1; }
|
||||||
|
Color getColor (FgVarNode* v) const { return varColors_[v->getIndex()]; }
|
||||||
|
Color getColor (Factor* f) const { return factorColors_[f->getIndex()]; }
|
||||||
|
|
||||||
|
void setColor (FgVarNode* v, Color c)
|
||||||
|
{
|
||||||
|
varColors_[v->getIndex()] = c;
|
||||||
|
}
|
||||||
|
|
||||||
|
void setColor (Factor* f, Color c)
|
||||||
|
{
|
||||||
|
factorColors_[f->getIndex()] = c;
|
||||||
|
}
|
||||||
|
|
||||||
|
VarCluster* getVariableCluster (Vid vid) const
|
||||||
|
{
|
||||||
|
return vid2VarCluster_.find (vid)->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
Color freeColor_;
|
||||||
|
vector<Color> varColors_;
|
||||||
|
vector<Color> factorColors_;
|
||||||
|
VarClusterSet varClusters_;
|
||||||
|
FactorClusterSet factorClusters_;
|
||||||
|
Vid2VarCluster vid2VarCluster_;
|
||||||
|
const FactorGraph* groundFg_;
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif // BP_LIFTED_FG_H
|
||||||
|
|
@ -50,28 +50,33 @@ CWD=$(PWD)
|
|||||||
HEADERS = \
|
HEADERS = \
|
||||||
$(srcdir)/GraphicalModel.h \
|
$(srcdir)/GraphicalModel.h \
|
||||||
$(srcdir)/Variable.h \
|
$(srcdir)/Variable.h \
|
||||||
|
$(srcdir)/Distribution.h \
|
||||||
$(srcdir)/BayesNet.h \
|
$(srcdir)/BayesNet.h \
|
||||||
$(srcdir)/BayesNode.h \
|
$(srcdir)/BayesNode.h \
|
||||||
$(srcdir)/Distribution.h \
|
$(srcdir)/LiftedFG.h \
|
||||||
$(srcdir)/CptEntry.h \
|
$(srcdir)/CptEntry.h \
|
||||||
$(srcdir)/FactorGraph.h \
|
$(srcdir)/FactorGraph.h \
|
||||||
$(srcdir)/FgVarNode.h \
|
$(srcdir)/FgVarNode.h \
|
||||||
$(srcdir)/Factor.h \
|
$(srcdir)/Factor.h \
|
||||||
$(srcdir)/Solver.h \
|
$(srcdir)/Solver.h \
|
||||||
$(srcdir)/BPSolver.h \
|
$(srcdir)/BPSolver.h \
|
||||||
$(srcdir)/BpNode.h \
|
$(srcdir)/BPNodeInfo.h \
|
||||||
$(srcdir)/SPSolver.h \
|
$(srcdir)/SPSolver.h \
|
||||||
|
$(srcdir)/CountingBP.h \
|
||||||
$(srcdir)/Shared.h \
|
$(srcdir)/Shared.h \
|
||||||
$(srcdir)/xmlParser/xmlParser.h
|
$(srcdir)/xmlParser/xmlParser.h
|
||||||
|
|
||||||
CPP_SOURCES = \
|
CPP_SOURCES = \
|
||||||
$(srcdir)/BayesNet.cpp \
|
$(srcdir)/BayesNet.cpp \
|
||||||
$(srcdir)/BayesNode.cpp \
|
$(srcdir)/BayesNode.cpp \
|
||||||
$(srcdir)/FactorGraph.cpp \
|
$(srcdir)/FactorGraph.cpp \
|
||||||
$(srcdir)/Factor.cpp \
|
$(srcdir)/Factor.cpp \
|
||||||
|
$(srcdir)/LiftedFG.cpp \
|
||||||
$(srcdir)/BPSolver.cpp \
|
$(srcdir)/BPSolver.cpp \
|
||||||
$(srcdir)/BpNode.cpp \
|
$(srcdir)/BPNodeInfo.cpp \
|
||||||
$(srcdir)/SPSolver.cpp \
|
$(srcdir)/SPSolver.cpp \
|
||||||
|
$(srcdir)/CountingBP.cpp \
|
||||||
|
$(srcdir)/Util.cpp \
|
||||||
$(srcdir)/HorusYap.cpp \
|
$(srcdir)/HorusYap.cpp \
|
||||||
$(srcdir)/HorusCli.cpp \
|
$(srcdir)/HorusCli.cpp \
|
||||||
$(srcdir)/xmlParser/xmlParser.cpp
|
$(srcdir)/xmlParser/xmlParser.cpp
|
||||||
@ -82,22 +87,38 @@ OBJS = \
|
|||||||
FactorGraph.o \
|
FactorGraph.o \
|
||||||
Factor.o \
|
Factor.o \
|
||||||
BPSolver.o \
|
BPSolver.o \
|
||||||
BpNode.o \
|
BPNodeInfo.o \
|
||||||
SPSolver.o \
|
SPSolver.o \
|
||||||
HorusYap.o \
|
Util.o \
|
||||||
xmlParser.o
|
LiftedFG.o \
|
||||||
|
CountingBP.o \
|
||||||
|
HorusYap.o
|
||||||
|
|
||||||
|
HCLI_OBJS = \
|
||||||
|
BayesNet.o \
|
||||||
|
BayesNode.o \
|
||||||
|
FactorGraph.o \
|
||||||
|
Factor.o \
|
||||||
|
BPSolver.o \
|
||||||
|
BPNodeInfo.o \
|
||||||
|
SPSolver.o \
|
||||||
|
Util.o \
|
||||||
|
LiftedFG.o \
|
||||||
|
CountingBP.o \
|
||||||
|
HorusCli.o \
|
||||||
|
xmlParser/xmlParser.o
|
||||||
|
|
||||||
SOBJS=horus.@SO@
|
SOBJS=horus.@SO@
|
||||||
|
|
||||||
|
|
||||||
all: $(SOBJS)
|
all: $(SOBJS) hcli
|
||||||
|
|
||||||
# default rule
|
# default rule
|
||||||
%.o : $(srcdir)/%.cpp
|
%.o : $(srcdir)/%.cpp
|
||||||
$(CXX) -c $(CXXFLAGS) $< -o $@
|
$(CXX) -c $(CXXFLAGS) $< -o $@
|
||||||
|
|
||||||
|
|
||||||
xmlParser.o : $(srcdir)/xmlParser/xmlParser.cpp
|
xmlParser/xmlParser.o : $(srcdir)/xmlParser/xmlParser.cpp
|
||||||
$(CXX) -c $(CXXFLAGS) $< -o $@
|
$(CXX) -c $(CXXFLAGS) $< -o $@
|
||||||
|
|
||||||
|
|
||||||
@ -105,7 +126,7 @@ xmlParser.o : $(srcdir)/xmlParser/xmlParser.cpp
|
|||||||
@DO_SECOND_LD@ @SHLIB_CXX_LD@ -o horus.@SO@ $(OBJS) @EXTRA_LIBS_FOR_SWIDLLS@
|
@DO_SECOND_LD@ @SHLIB_CXX_LD@ -o horus.@SO@ $(OBJS) @EXTRA_LIBS_FOR_SWIDLLS@
|
||||||
|
|
||||||
|
|
||||||
hcli: $(OBJS)
|
hcli: $(HCLI_OBJS)
|
||||||
$(CXX) -o hcli $(HCLI_OBJS)
|
$(CXX) -o hcli $(HCLI_OBJS)
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,38 +1,77 @@
|
|||||||
#include <cassert>
|
#include <cassert>
|
||||||
#include <algorithm>
|
#include <limits>
|
||||||
|
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
|
|
||||||
#include "SPSolver.h"
|
#include "SPSolver.h"
|
||||||
#include "FactorGraph.h"
|
#include "FactorGraph.h"
|
||||||
#include "FgVarNode.h"
|
#include "FgVarNode.h"
|
||||||
#include "Factor.h"
|
#include "Factor.h"
|
||||||
|
#include "Shared.h"
|
||||||
SPSolver* Link::klass = 0;
|
|
||||||
|
|
||||||
|
|
||||||
SPSolver::SPSolver (const FactorGraph& fg) : Solver (&fg)
|
SPSolver::SPSolver (FactorGraph& fg) : Solver (&fg)
|
||||||
{
|
{
|
||||||
fg_ = &fg;
|
fg_ = &fg;
|
||||||
accuracy_ = 0.0001;
|
|
||||||
maxIter_ = 10000;
|
|
||||||
//schedule_ = S_SEQ_FIXED;
|
|
||||||
//schedule_ = S_SEQ_RANDOM;
|
|
||||||
//schedule_ = S_SEQ_PARALLEL;
|
|
||||||
schedule_ = S_MAX_RESIDUAL;
|
|
||||||
Link::klass = this;
|
|
||||||
FgVarSet vars = fg_->getFgVarNodes();
|
|
||||||
for (unsigned i = 0; i < vars.size(); i++) {
|
|
||||||
msgs_.push_back (new MessageBanket (vars[i]));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
SPSolver::~SPSolver (void)
|
SPSolver::~SPSolver (void)
|
||||||
{
|
{
|
||||||
for (unsigned i = 0; i < msgs_.size(); i++) {
|
for (unsigned i = 0; i < varsI_.size(); i++) {
|
||||||
delete msgs_[i];
|
delete varsI_[i];
|
||||||
}
|
}
|
||||||
|
for (unsigned i = 0; i < factorsI_.size(); i++) {
|
||||||
|
delete factorsI_[i];
|
||||||
|
}
|
||||||
|
for (unsigned i = 0; i < links_.size(); i++) {
|
||||||
|
delete links_[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
SPSolver::runTreeSolver (void)
|
||||||
|
{
|
||||||
|
CFactorSet factors = fg_->getFactors();
|
||||||
|
bool finish = false;
|
||||||
|
while (!finish) {
|
||||||
|
finish = true;
|
||||||
|
for (unsigned i = 0; i < factors.size(); i++) {
|
||||||
|
CLinkSet links = factorsI_[factors[i]->getIndex()]->getLinks();
|
||||||
|
for (unsigned j = 0; j < links.size(); j++) {
|
||||||
|
if (!links[j]->messageWasSended()) {
|
||||||
|
if (readyToSendMessage(links[j])) {
|
||||||
|
links[j]->setNextMessage (getFactor2VarMsg (links[j]));
|
||||||
|
links[j]->updateMessage();
|
||||||
|
}
|
||||||
|
finish = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
SPSolver::readyToSendMessage (const Link* link) const
|
||||||
|
{
|
||||||
|
CFgVarSet factorVars = link->getFactor()->getFgVarNodes();
|
||||||
|
for (unsigned i = 0; i < factorVars.size(); i++) {
|
||||||
|
if (factorVars[i] != link->getVariable()) {
|
||||||
|
CLinkSet links = varsI_[factorVars[i]->getIndex()]->getLinks();
|
||||||
|
for (unsigned j = 0; j < links.size(); j++) {
|
||||||
|
if (links[j]->getFactor() != link->getFactor() &&
|
||||||
|
!links[j]->messageWasSended()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -40,62 +79,54 @@ SPSolver::~SPSolver (void)
|
|||||||
void
|
void
|
||||||
SPSolver::runSolver (void)
|
SPSolver::runSolver (void)
|
||||||
{
|
{
|
||||||
|
initializeSolver();
|
||||||
|
runTreeSolver();
|
||||||
|
return;
|
||||||
nIter_ = 0;
|
nIter_ = 0;
|
||||||
vector<Factor*> factors = fg_->getFactors();
|
while (!converged() && nIter_ < SolverOptions::maxIter) {
|
||||||
for (unsigned i = 0; i < factors.size(); i++) {
|
|
||||||
FgVarSet neighbors = factors[i]->getFgVarNodes();
|
nIter_ ++;
|
||||||
for (unsigned j = 0; j < neighbors.size(); j++) {
|
if (DL >= 2) {
|
||||||
updateOrder_.push_back (Link (factors[i], neighbors[j]));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
while (!converged() && nIter_ < maxIter_) {
|
|
||||||
if (DL >= 1) {
|
|
||||||
cout << endl;
|
cout << endl;
|
||||||
cout << "****************************************" ;
|
cout << "****************************************" ;
|
||||||
cout << "****************************************" ;
|
cout << "****************************************" ;
|
||||||
cout << endl;
|
cout << endl;
|
||||||
cout << " Iteration " << nIter_ + 1 << endl;
|
cout << " Iteration " << nIter_ << endl;
|
||||||
cout << "****************************************" ;
|
cout << "****************************************" ;
|
||||||
cout << "****************************************" ;
|
cout << "****************************************" ;
|
||||||
cout << endl;
|
cout << endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
switch (schedule_) {
|
switch (SolverOptions::schedule) {
|
||||||
|
case SolverOptions::S_SEQ_RANDOM:
|
||||||
case S_SEQ_RANDOM:
|
random_shuffle (links_.begin(), links_.end());
|
||||||
random_shuffle (updateOrder_.begin(), updateOrder_.end());
|
|
||||||
// no break
|
// no break
|
||||||
|
|
||||||
case S_SEQ_FIXED:
|
case SolverOptions::S_SEQ_FIXED:
|
||||||
for (unsigned c = 0; c < updateOrder_.size(); c++) {
|
for (unsigned i = 0; i < links_.size(); i++) {
|
||||||
Link& link = updateOrder_[c];
|
links_[i]->setNextMessage (getFactor2VarMsg (links_[i]));
|
||||||
calculateNextMessage (link.source, link.destination);
|
links_[i]->updateMessage();
|
||||||
updateMessage (updateOrder_[c]);
|
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
|
|
||||||
case S_PARALLEL:
|
case SolverOptions::S_PARALLEL:
|
||||||
for (unsigned c = 0; c < updateOrder_.size(); c++) {
|
for (unsigned i = 0; i < links_.size(); i++) {
|
||||||
Link link = updateOrder_[c];
|
links_[i]->setNextMessage (getFactor2VarMsg (links_[i]));
|
||||||
calculateNextMessage (link.source, link.destination);
|
|
||||||
}
|
}
|
||||||
for (unsigned c = 0; c < updateOrder_.size(); c++) {
|
for (unsigned i = 0; i < links_.size(); i++) {
|
||||||
Link link = updateOrder_[c];
|
links_[i]->updateMessage();
|
||||||
updateMessage (updateOrder_[c]);
|
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
|
|
||||||
case S_MAX_RESIDUAL:
|
case SolverOptions::S_MAX_RESIDUAL:
|
||||||
maxResidualSchedule();
|
maxResidualSchedule();
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
nIter_++;
|
|
||||||
}
|
}
|
||||||
cout << endl;
|
|
||||||
if (DL >= 1) {
|
if (DL >= 2) {
|
||||||
if (nIter_ < maxIter_) {
|
cout << endl;
|
||||||
|
if (nIter_ < SolverOptions::maxIter) {
|
||||||
cout << "Loopy Sum-Product converged in " ;
|
cout << "Loopy Sum-Product converged in " ;
|
||||||
cout << nIter_ << " iterations" << endl;
|
cout << nIter_ << " iterations" << endl;
|
||||||
} else {
|
} else {
|
||||||
@ -108,58 +139,168 @@ SPSolver::runSolver (void)
|
|||||||
|
|
||||||
|
|
||||||
ParamSet
|
ParamSet
|
||||||
SPSolver::getPosterioriOf (const Variable* var) const
|
SPSolver::getPosterioriOf (Vid vid) const
|
||||||
{
|
{
|
||||||
assert (var);
|
assert (fg_->getFgVarNode (vid));
|
||||||
assert (var == fg_->getVariableById (var->getVarId()));
|
FgVarNode* var = fg_->getFgVarNode (vid);
|
||||||
assert (var->getIndex() < msgs_.size());
|
ParamSet probs;
|
||||||
|
|
||||||
ParamSet probs (var->getDomainSize(), 1);
|
|
||||||
if (var->hasEvidence()) {
|
if (var->hasEvidence()) {
|
||||||
for (unsigned i = 0; i < probs.size(); i++) {
|
probs.resize (var->getDomainSize(), 0.0);
|
||||||
if ((int)i != var->getEvidence()) {
|
probs[var->getEvidence()] = 1.0;
|
||||||
probs[i] = 0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
|
probs.resize (var->getDomainSize(), 1.0);
|
||||||
MessageBanket* mb = msgs_[var->getIndex()];
|
CLinkSet links = varsI_[var->getIndex()]->getLinks();
|
||||||
const FgVarNode* varNode = fg_->getFgVarNodes()[var->getIndex()];
|
for (unsigned i = 0; i < links.size(); i++) {
|
||||||
vector<Factor*> neighbors = varNode->getFactors();
|
CParamSet msg = links[i]->getMessage();
|
||||||
for (unsigned i = 0; i < neighbors.size(); i++) {
|
|
||||||
const Message& msg = mb->getMessage (neighbors[i]);
|
|
||||||
for (unsigned j = 0; j < msg.size(); j++) {
|
for (unsigned j = 0; j < msg.size(); j++) {
|
||||||
probs[j] *= msg[j];
|
probs[j] *= msg[j];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Util::normalize (probs);
|
Util::normalize (probs);
|
||||||
}
|
}
|
||||||
|
|
||||||
return probs;
|
return probs;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
ParamSet
|
||||||
|
SPSolver::getJointDistributionOf (const VidSet& jointVids)
|
||||||
|
{
|
||||||
|
FgVarSet jointVars;
|
||||||
|
unsigned dsize = 1;
|
||||||
|
for (unsigned i = 0; i < jointVids.size(); i++) {
|
||||||
|
FgVarNode* varNode = fg_->getFgVarNode (jointVids[i]);
|
||||||
|
dsize *= varNode->getDomainSize();
|
||||||
|
jointVars.push_back (varNode);
|
||||||
|
}
|
||||||
|
|
||||||
|
unsigned maxVid = std::numeric_limits<unsigned>::max();
|
||||||
|
FgVarNode* junctionVar = new FgVarNode (maxVid, dsize);
|
||||||
|
FgVarSet factorVars = { junctionVar };
|
||||||
|
for (unsigned i = 0; i < jointVars.size(); i++) {
|
||||||
|
factorVars.push_back (jointVars[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
unsigned nParams = dsize * dsize;
|
||||||
|
ParamSet params (nParams);
|
||||||
|
for (unsigned i = 0; i < nParams; i++) {
|
||||||
|
unsigned row = i / dsize;
|
||||||
|
unsigned col = i % dsize;
|
||||||
|
if (row == col) {
|
||||||
|
params[i] = 1;
|
||||||
|
} else {
|
||||||
|
params[i] = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Distribution* dist = new Distribution (params, maxVid);
|
||||||
|
Factor* newFactor = new Factor (factorVars, dist);
|
||||||
|
fg_->addVariable (junctionVar);
|
||||||
|
fg_->addFactor (newFactor);
|
||||||
|
|
||||||
|
runSolver();
|
||||||
|
ParamSet results = getPosterioriOf (maxVid);
|
||||||
|
deleteJunction (newFactor, junctionVar);
|
||||||
|
|
||||||
|
return results;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
SPSolver::initializeSolver (void)
|
||||||
|
{
|
||||||
|
fg_->setIndexes();
|
||||||
|
|
||||||
|
CFgVarSet vars = fg_->getFgVarNodes();
|
||||||
|
for (unsigned i = 0; i < varsI_.size(); i++) {
|
||||||
|
delete varsI_[i];
|
||||||
|
}
|
||||||
|
varsI_.reserve (vars.size());
|
||||||
|
for (unsigned i = 0; i < vars.size(); i++) {
|
||||||
|
varsI_.push_back (new SPNodeInfo());
|
||||||
|
}
|
||||||
|
|
||||||
|
CFactorSet factors = fg_->getFactors();
|
||||||
|
for (unsigned i = 0; i < factorsI_.size(); i++) {
|
||||||
|
delete factorsI_[i];
|
||||||
|
}
|
||||||
|
factorsI_.reserve (factors.size());
|
||||||
|
for (unsigned i = 0; i < factors.size(); i++) {
|
||||||
|
factorsI_.push_back (new SPNodeInfo());
|
||||||
|
}
|
||||||
|
|
||||||
|
for (unsigned i = 0; i < links_.size(); i++) {
|
||||||
|
delete links_[i];
|
||||||
|
}
|
||||||
|
createLinks();
|
||||||
|
|
||||||
|
for (unsigned i = 0; i < links_.size(); i++) {
|
||||||
|
Factor* source = links_[i]->getFactor();
|
||||||
|
FgVarNode* dest = links_[i]->getVariable();
|
||||||
|
varsI_[dest->getIndex()]->addLink (links_[i]);
|
||||||
|
factorsI_[source->getIndex()]->addLink (links_[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
SPSolver::createLinks (void)
|
||||||
|
{
|
||||||
|
CFactorSet factors = fg_->getFactors();
|
||||||
|
for (unsigned i = 0; i < factors.size(); i++) {
|
||||||
|
CFgVarSet neighbors = factors[i]->getFgVarNodes();
|
||||||
|
for (unsigned j = 0; j < neighbors.size(); j++) {
|
||||||
|
links_.push_back (new Link (factors[i], neighbors[j]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
SPSolver::deleteJunction (Factor* f, FgVarNode* v)
|
||||||
|
{
|
||||||
|
fg_->removeFactor (f);
|
||||||
|
f->freeDistribution();
|
||||||
|
delete f;
|
||||||
|
fg_->removeVariable (v);
|
||||||
|
delete v;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
bool
|
bool
|
||||||
SPSolver::converged (void)
|
SPSolver::converged (void)
|
||||||
{
|
{
|
||||||
|
// this can happen if the graph is fully disconnected
|
||||||
|
if (links_.size() == 0) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
if (nIter_ == 0 || nIter_ == 1) {
|
if (nIter_ == 0 || nIter_ == 1) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
bool converged = true;
|
bool converged = true;
|
||||||
for (unsigned i = 0; i < updateOrder_.size(); i++) {
|
if (SolverOptions::schedule == SolverOptions::S_MAX_RESIDUAL) {
|
||||||
double residual = getResidual (updateOrder_[i]);
|
Param maxResidual = (*(sortedOrder_.begin()))->getResidual();
|
||||||
if (DL >= 1) {
|
if (maxResidual < SolverOptions::accuracy) {
|
||||||
cout << updateOrder_[i].toString();
|
converged = true;
|
||||||
cout << " residual = " << residual << endl;
|
} else {
|
||||||
}
|
|
||||||
if (residual > accuracy_) {
|
|
||||||
converged = false;
|
converged = false;
|
||||||
if (DL == 0) {
|
}
|
||||||
break;
|
} else {
|
||||||
|
for (unsigned i = 0; i < links_.size(); i++) {
|
||||||
|
double residual = links_[i]->getResidual();
|
||||||
|
if (DL >= 2) {
|
||||||
|
cout << links_[i]->toString() + " residual = " << residual << endl;
|
||||||
}
|
}
|
||||||
}
|
if (residual > SolverOptions::accuracy) {
|
||||||
|
converged = false;
|
||||||
|
if (DL == 0) break;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return converged;
|
return converged;
|
||||||
}
|
}
|
||||||
@ -169,127 +310,161 @@ SPSolver::converged (void)
|
|||||||
void
|
void
|
||||||
SPSolver::maxResidualSchedule (void)
|
SPSolver::maxResidualSchedule (void)
|
||||||
{
|
{
|
||||||
if (nIter_ == 0) {
|
if (nIter_ == 1) {
|
||||||
for (unsigned c = 0; c < updateOrder_.size(); c++) {
|
for (unsigned i = 0; i < links_.size(); i++) {
|
||||||
Link& l = updateOrder_[c];
|
links_[i]->setNextMessage (getFactor2VarMsg (links_[i]));
|
||||||
calculateNextMessage (l.source, l.destination);
|
SortedOrder::iterator it = sortedOrder_.insert (links_[i]);
|
||||||
if (DL >= 1) {
|
linkMap_.insert (make_pair (links_[i], it));
|
||||||
cout << updateOrder_[c].toString() << " residual = " ;
|
if (DL >= 2 && DL < 5) {
|
||||||
cout << getResidual (updateOrder_[c]) << endl;
|
cout << "calculating " << links_[i]->toString() << endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
sort (updateOrder_.begin(), updateOrder_.end(), compareResidual);
|
return;
|
||||||
} else {
|
}
|
||||||
|
|
||||||
for (unsigned c = 0; c < updateOrder_.size(); c++) {
|
for (unsigned c = 0; c < links_.size(); c++) {
|
||||||
Link& link = updateOrder_.front();
|
if (DL >= 2) {
|
||||||
updateMessage (link);
|
cout << endl << "current residuals:" << endl;
|
||||||
resetResidual (link);
|
for (SortedOrder::iterator it = sortedOrder_.begin();
|
||||||
|
it != sortedOrder_.end(); it ++) {
|
||||||
|
cout << " " << setw (30) << left << (*it)->toString();
|
||||||
|
cout << "residual = " << (*it)->getResidual() << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// update the messages that depend on message source --> destination
|
SortedOrder::iterator it = sortedOrder_.begin();
|
||||||
vector<Factor*> fstLevelNeighbors = link.destination->getFactors();
|
Link* link = *it;
|
||||||
for (unsigned i = 0; i < fstLevelNeighbors.size(); i++) {
|
if (DL >= 2) {
|
||||||
if (fstLevelNeighbors[i] != link.source) {
|
cout << "updating " << (*sortedOrder_.begin())->toString() << endl;
|
||||||
FgVarSet sndLevelNeighbors;
|
}
|
||||||
sndLevelNeighbors = fstLevelNeighbors[i]->getFgVarNodes();
|
if (link->getResidual() < SolverOptions::accuracy) {
|
||||||
for (unsigned j = 0; j < sndLevelNeighbors.size(); j++) {
|
return;
|
||||||
if (sndLevelNeighbors[j] != link.destination) {
|
}
|
||||||
calculateNextMessage (fstLevelNeighbors[i], sndLevelNeighbors[j]);
|
link->updateMessage();
|
||||||
|
link->clearResidual();
|
||||||
|
sortedOrder_.erase (it);
|
||||||
|
linkMap_.find (link)->second = sortedOrder_.insert (link);
|
||||||
|
|
||||||
|
// update the messages that depend on message source --> destin
|
||||||
|
CFactorSet factorNeighbors = link->getVariable()->getFactors();
|
||||||
|
for (unsigned i = 0; i < factorNeighbors.size(); i++) {
|
||||||
|
if (factorNeighbors[i] != link->getFactor()) {
|
||||||
|
CLinkSet links = factorsI_[factorNeighbors[i]->getIndex()]->getLinks();
|
||||||
|
for (unsigned j = 0; j < links.size(); j++) {
|
||||||
|
if (links[j]->getVariable() != link->getVariable()) {
|
||||||
|
if (DL >= 2 && DL < 5) {
|
||||||
|
cout << " calculating " << links[j]->toString() << endl;
|
||||||
}
|
}
|
||||||
|
links[j]->setNextMessage (getFactor2VarMsg (links[j]));
|
||||||
|
LinkMap::iterator iter = linkMap_.find (links[j]);
|
||||||
|
sortedOrder_.erase (iter->second);
|
||||||
|
iter->second = sortedOrder_.insert (links[j]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
sort (updateOrder_.begin(), updateOrder_.end(), compareResidual);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
ParamSet
|
||||||
SPSolver::updateMessage (const Link& link)
|
SPSolver::getFactor2VarMsg (const Link* link) const
|
||||||
{
|
{
|
||||||
updateMessage (link.source, link.destination);
|
const Factor* src = link->getFactor();
|
||||||
}
|
const FgVarNode* dest = link->getVariable();
|
||||||
|
CFgVarSet neighbors = src->getFgVarNodes();
|
||||||
|
CLinkSet links = factorsI_[src->getIndex()]->getLinks();
|
||||||
|
// calculate the product of messages that were sent
|
||||||
void
|
|
||||||
SPSolver::updateMessage (const Factor* src, const FgVarNode* dest)
|
|
||||||
{
|
|
||||||
msgs_[dest->getIndex()]->updateMessage (src);
|
|
||||||
/* cout << src->getLabel() << " --> " << dest->getLabel() << endl;
|
|
||||||
cout << " m: " ;
|
|
||||||
Message msg = msgs_[dest->getIndex()]->getMessage (src);
|
|
||||||
for (unsigned i = 0; i < msg.size(); i++) {
|
|
||||||
if (i != 0) cout << ", " ;
|
|
||||||
cout << msg[i];
|
|
||||||
}
|
|
||||||
cout << endl;
|
|
||||||
*/
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
SPSolver::calculateNextMessage (const Link& link)
|
|
||||||
{
|
|
||||||
calculateNextMessage (link.source, link.destination);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
SPSolver::calculateNextMessage (const Factor* src, const FgVarNode* dest)
|
|
||||||
{
|
|
||||||
FgVarSet neighbors = src->getFgVarNodes();
|
|
||||||
// calculate the product of MessageBankets sended
|
|
||||||
// to factor `src', except from var `dest'
|
// to factor `src', except from var `dest'
|
||||||
Factor result = *src;
|
Factor result (*src);
|
||||||
for (unsigned i = 0; i < neighbors.size(); i++) {
|
Factor temp;
|
||||||
if (neighbors[i] != dest) {
|
if (DL >= 5) {
|
||||||
Message msg (neighbors[i]->getDomainSize(), 1);
|
cout << "calculating " ;
|
||||||
calculateVarFactorMessage (neighbors[i], src, msg);
|
cout << src->getLabel() << " --> " << dest->getLabel();
|
||||||
result *= Factor (neighbors[i], msg);
|
cout << endl;
|
||||||
}
|
|
||||||
}
|
}
|
||||||
// marginalize all vars except `dest'
|
|
||||||
for (unsigned i = 0; i < neighbors.size(); i++) {
|
for (unsigned i = 0; i < neighbors.size(); i++) {
|
||||||
if (neighbors[i] != dest) {
|
if (links[i]->getVariable() != dest) {
|
||||||
result.marginalizeVariable (neighbors[i]);
|
if (DL >= 5) {
|
||||||
}
|
cout << " message from " << links[i]->getVariable()->getLabel();
|
||||||
}
|
cout << ": " ;
|
||||||
msgs_[dest->getIndex()]->setNextMessage (src, result.getParameters());
|
ParamSet p = getVar2FactorMsg (links[i]);
|
||||||
}
|
cout << endl;
|
||||||
|
Factor temp2 (links[i]->getVariable(), p);
|
||||||
|
temp.multiplyByFactor (temp2);
|
||||||
|
temp2.freeDistribution();
|
||||||
void
|
|
||||||
SPSolver::calculateVarFactorMessage (const FgVarNode* src,
|
|
||||||
const Factor* dest,
|
|
||||||
Message& placeholder) const
|
|
||||||
{
|
|
||||||
assert (src->getDomainSize() == (int)placeholder.size());
|
|
||||||
if (src->hasEvidence()) {
|
|
||||||
for (unsigned i = 0; i < placeholder.size(); i++) {
|
|
||||||
if ((int)i != src->getEvidence()) {
|
|
||||||
placeholder[i] = 0.0;
|
|
||||||
} else {
|
} else {
|
||||||
placeholder[i] = 1.0;
|
Factor temp2 (links[i]->getVariable(), getVar2FactorMsg (links[i]));
|
||||||
}
|
temp.multiplyByFactor (temp2);
|
||||||
}
|
temp2.freeDistribution();
|
||||||
|
|
||||||
} else {
|
|
||||||
|
|
||||||
MessageBanket* mb = msgs_[src->getIndex()];
|
|
||||||
vector<Factor*> neighbors = src->getFactors();
|
|
||||||
for (unsigned i = 0; i < neighbors.size(); i++) {
|
|
||||||
if (neighbors[i] != dest) {
|
|
||||||
const Message& fromFactor = mb->getMessage (neighbors[i]);
|
|
||||||
for (unsigned j = 0; j < fromFactor.size(); j++) {
|
|
||||||
placeholder[j] *= fromFactor[j];
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (links.size() >= 2) {
|
||||||
|
result.multiplyByFactor (temp, &(src->getCptEntries()));
|
||||||
|
if (DL >= 5) {
|
||||||
|
cout << " message product: " ;
|
||||||
|
cout << Util::parametersToString (temp.getParameters()) << endl;
|
||||||
|
cout << " factor product: " ;
|
||||||
|
cout << Util::parametersToString (src->getParameters());
|
||||||
|
cout << " x " ;
|
||||||
|
cout << Util::parametersToString (temp.getParameters());
|
||||||
|
cout << " = " ;
|
||||||
|
cout << Util::parametersToString (result.getParameters()) << endl;
|
||||||
|
}
|
||||||
|
temp.freeDistribution();
|
||||||
|
}
|
||||||
|
|
||||||
|
for (unsigned i = 0; i < links.size(); i++) {
|
||||||
|
if (links[i]->getVariable() != dest) {
|
||||||
|
result.removeVariable (links[i]->getVariable());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (DL >= 5) {
|
||||||
|
cout << " final message: " ;
|
||||||
|
cout << Util::parametersToString (result.getParameters()) << endl << endl;
|
||||||
|
}
|
||||||
|
ParamSet msg = result.getParameters();
|
||||||
|
result.freeDistribution();
|
||||||
|
return msg;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
ParamSet
|
||||||
|
SPSolver::getVar2FactorMsg (const Link* link) const
|
||||||
|
{
|
||||||
|
const FgVarNode* src = link->getVariable();
|
||||||
|
const Factor* dest = link->getFactor();
|
||||||
|
ParamSet msg;
|
||||||
|
if (src->hasEvidence()) {
|
||||||
|
msg.resize (src->getDomainSize(), 0.0);
|
||||||
|
msg[src->getEvidence()] = 1.0;
|
||||||
|
if (DL >= 5) {
|
||||||
|
cout << Util::parametersToString (msg);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
msg.resize (src->getDomainSize(), 1.0);
|
||||||
|
}
|
||||||
|
if (DL >= 5) {
|
||||||
|
cout << Util::parametersToString (msg);
|
||||||
|
}
|
||||||
|
CLinkSet links = varsI_[src->getIndex()]->getLinks();
|
||||||
|
for (unsigned i = 0; i < links.size(); i++) {
|
||||||
|
if (links[i]->getFactor() != dest) {
|
||||||
|
CParamSet msgFromFactor = links[i]->getMessage();
|
||||||
|
for (unsigned j = 0; j < msgFromFactor.size(); j++) {
|
||||||
|
msg[j] *= msgFromFactor[j];
|
||||||
|
}
|
||||||
|
if (DL >= 5) {
|
||||||
|
cout << " x " << Util::parametersToString (msgFromFactor);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (DL >= 5) {
|
||||||
|
cout << " = " << Util::parametersToString (msg);
|
||||||
|
}
|
||||||
|
return msg;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,10 +1,8 @@
|
|||||||
#ifndef BP_SPSOLVER_H
|
#ifndef BP_SP_SOLVER_H
|
||||||
#define BP_SPSOLVER_H
|
#define BP_SP_SOLVER_H
|
||||||
|
|
||||||
#include <cmath>
|
|
||||||
#include <map>
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <string>
|
#include <set>
|
||||||
|
|
||||||
#include "Solver.h"
|
#include "Solver.h"
|
||||||
#include "FgVarNode.h"
|
#include "FgVarNode.h"
|
||||||
@ -15,157 +13,118 @@ using namespace std;
|
|||||||
class FactorGraph;
|
class FactorGraph;
|
||||||
class SPSolver;
|
class SPSolver;
|
||||||
|
|
||||||
struct Link
|
|
||||||
{
|
|
||||||
Link (Factor* s, FgVarNode* d)
|
|
||||||
{
|
|
||||||
source = s;
|
|
||||||
destination = d;
|
|
||||||
}
|
|
||||||
string toString (void) const
|
|
||||||
{
|
|
||||||
stringstream ss;
|
|
||||||
ss << source->getLabel() << " --> " ;
|
|
||||||
ss << destination->getLabel();
|
|
||||||
return ss.str();
|
|
||||||
}
|
|
||||||
Factor* source;
|
|
||||||
FgVarNode* destination;
|
|
||||||
static SPSolver* klass;
|
|
||||||
};
|
|
||||||
|
|
||||||
|
class Link
|
||||||
|
|
||||||
class MessageBanket
|
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
MessageBanket (const FgVarNode* var)
|
Link (Factor* f, FgVarNode* v)
|
||||||
|
{
|
||||||
|
factor_ = f;
|
||||||
|
var_ = v;
|
||||||
|
currMsg_.resize (v->getDomainSize(), 1);
|
||||||
|
nextMsg_.resize (v->getDomainSize(), 1);
|
||||||
|
msgSended_ = false;
|
||||||
|
residual_ = 0.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
void setMessage (ParamSet msg)
|
||||||
{
|
{
|
||||||
vector<Factor*> sources = var->getFactors();
|
Util::normalize (msg);
|
||||||
for (unsigned i = 0; i < sources.size(); i++) {
|
residual_ = Util::getMaxNorm (currMsg_, msg);
|
||||||
indexMap_.insert (make_pair (sources[i], i));
|
currMsg_ = msg;
|
||||||
currMsgs_.push_back (Message(var->getDomainSize(), 1));
|
|
||||||
nextMsgs_.push_back (Message(var->getDomainSize(), -10));
|
|
||||||
residuals_.push_back (0.0);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void updateMessage (const Factor* source)
|
void setNextMessage (CParamSet msg)
|
||||||
{
|
{
|
||||||
unsigned idx = getIndex(source);
|
nextMsg_ = msg;
|
||||||
currMsgs_[idx] = nextMsgs_[idx];
|
Util::normalize (nextMsg_);
|
||||||
|
residual_ = Util::getMaxNorm (currMsg_, nextMsg_);
|
||||||
}
|
}
|
||||||
|
|
||||||
void setNextMessage (const Factor* source, const Message& msg)
|
void updateMessage (void)
|
||||||
{
|
{
|
||||||
unsigned idx = getIndex(source);
|
currMsg_ = nextMsg_;
|
||||||
nextMsgs_[idx] = msg;
|
msgSended_ = true;
|
||||||
residuals_[idx] = computeResidual (source);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const Message& getMessage (const Factor* source) const
|
string toString (void) const
|
||||||
{
|
{
|
||||||
return currMsgs_[getIndex(source)];
|
stringstream ss;
|
||||||
}
|
ss << factor_->getLabel();
|
||||||
|
ss << " -- " ;
|
||||||
double getResidual (const Factor* source) const
|
ss << var_->getLabel();
|
||||||
{
|
return ss.str();
|
||||||
return residuals_[getIndex(source)];
|
|
||||||
}
|
|
||||||
|
|
||||||
void resetResidual (const Factor* source)
|
|
||||||
{
|
|
||||||
residuals_[getIndex(source)] = 0.0;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Factor* getFactor (void) const { return factor_; }
|
||||||
|
FgVarNode* getVariable (void) const { return var_; }
|
||||||
|
CParamSet getMessage (void) const { return currMsg_; }
|
||||||
|
bool messageWasSended (void) const { return msgSended_; }
|
||||||
|
double getResidual (void) const { return residual_; }
|
||||||
|
void clearResidual (void) { residual_ = 0.0; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
double computeResidual (const Factor* source)
|
Factor* factor_;
|
||||||
{
|
FgVarNode* var_;
|
||||||
double change = 0.0;
|
ParamSet currMsg_;
|
||||||
unsigned idx = getIndex (source);
|
ParamSet nextMsg_;
|
||||||
const Message& currMessage = currMsgs_[idx];
|
bool msgSended_;
|
||||||
const Message& nextMessage = nextMsgs_[idx];
|
double residual_;
|
||||||
for (unsigned i = 0; i < currMessage.size(); i++) {
|
|
||||||
change += abs (currMessage[i] - nextMessage[i]);
|
|
||||||
}
|
|
||||||
return change;
|
|
||||||
}
|
|
||||||
|
|
||||||
unsigned getIndex (const Factor* factor) const
|
|
||||||
{
|
|
||||||
assert (factor);
|
|
||||||
assert (indexMap_.find(factor) != indexMap_.end());
|
|
||||||
return indexMap_.find(factor)->second;
|
|
||||||
}
|
|
||||||
|
|
||||||
typedef map<const Factor*, unsigned> IndexMap;
|
|
||||||
|
|
||||||
IndexMap indexMap_;
|
|
||||||
vector<Message> currMsgs_;
|
|
||||||
vector<Message> nextMsgs_;
|
|
||||||
vector<double> residuals_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
class SPNodeInfo
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
void addLink (Link* link) { links_.push_back (link); }
|
||||||
|
CLinkSet getLinks (void) { return links_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
LinkSet links_;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
class SPSolver : public Solver
|
class SPSolver : public Solver
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
SPSolver (const FactorGraph&);
|
SPSolver (FactorGraph&);
|
||||||
~SPSolver (void);
|
virtual ~SPSolver (void);
|
||||||
|
|
||||||
void runSolver (void);
|
void runSolver (void);
|
||||||
ParamSet getPosterioriOf (const Variable* var) const;
|
virtual ParamSet getPosterioriOf (Vid) const;
|
||||||
|
ParamSet getJointDistributionOf (CVidSet);
|
||||||
|
|
||||||
|
protected:
|
||||||
|
virtual void initializeSolver (void);
|
||||||
|
void runTreeSolver (void);
|
||||||
|
bool readyToSendMessage (const Link*) const;
|
||||||
|
virtual void createLinks (void);
|
||||||
|
virtual void deleteJunction (Factor*, FgVarNode*);
|
||||||
|
bool converged (void);
|
||||||
|
virtual void maxResidualSchedule (void);
|
||||||
|
virtual ParamSet getFactor2VarMsg (const Link*) const;
|
||||||
|
virtual ParamSet getVar2FactorMsg (const Link*) const;
|
||||||
|
|
||||||
private:
|
struct CompareResidual {
|
||||||
bool converged (void);
|
inline bool operator() (const Link* link1, const Link* link2)
|
||||||
void maxResidualSchedule (void);
|
{
|
||||||
void updateMessage (const Link&);
|
return link1->getResidual() > link2->getResidual();
|
||||||
void updateMessage (const Factor*, const FgVarNode*);
|
}
|
||||||
void calculateNextMessage (const Link&);
|
};
|
||||||
void calculateNextMessage (const Factor*, const FgVarNode*);
|
|
||||||
void calculateVarFactorMessage (
|
FactorGraph* fg_;
|
||||||
const FgVarNode*, const Factor*, Message&) const;
|
LinkSet links_;
|
||||||
double getResidual (const Link&) const;
|
vector<SPNodeInfo*> varsI_;
|
||||||
void resetResidual (const Link&) const;
|
vector<SPNodeInfo*> factorsI_;
|
||||||
friend bool compareResidual (const Link&, const Link&);
|
unsigned nIter_;
|
||||||
|
|
||||||
|
typedef multiset<Link*, CompareResidual> SortedOrder;
|
||||||
|
SortedOrder sortedOrder_;
|
||||||
|
|
||||||
|
typedef map<Link*, SortedOrder::iterator> LinkMap;
|
||||||
|
LinkMap linkMap_;
|
||||||
|
|
||||||
const FactorGraph* fg_;
|
|
||||||
vector<MessageBanket*> msgs_;
|
|
||||||
Schedule schedule_;
|
|
||||||
int nIter_;
|
|
||||||
double accuracy_;
|
|
||||||
int maxIter_;
|
|
||||||
vector<Link> updateOrder_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#endif // BP_SP_SOLVER_H
|
||||||
|
|
||||||
inline double
|
|
||||||
SPSolver::getResidual (const Link& link) const
|
|
||||||
{
|
|
||||||
MessageBanket* mb = Link::klass->msgs_[link.destination->getIndex()];
|
|
||||||
return mb->getResidual (link.source);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
inline void
|
|
||||||
SPSolver::resetResidual (const Link& link) const
|
|
||||||
{
|
|
||||||
MessageBanket* mb = Link::klass->msgs_[link.destination->getIndex()];
|
|
||||||
mb->resetResidual (link.source);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
inline bool
|
|
||||||
compareResidual (const Link& link1, const Link& link2)
|
|
||||||
{
|
|
||||||
MessageBanket* mb1 = Link::klass->msgs_[link1.destination->getIndex()];
|
|
||||||
MessageBanket* mb2 = Link::klass->msgs_[link2.destination->getIndex()];
|
|
||||||
return mb1->getResidual(link1.source) > mb2->getResidual(link2.source);
|
|
||||||
}
|
|
||||||
|
|
||||||
#endif
|
|
||||||
|
|
||||||
|
@ -2,14 +2,15 @@
|
|||||||
#define BP_SHARED_H
|
#define BP_SHARED_H
|
||||||
|
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <iostream>
|
|
||||||
#include <fstream>
|
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
|
|
||||||
// Macro to disallow the copy constructor and operator= functions
|
#include <iostream>
|
||||||
|
#include <fstream>
|
||||||
|
#include <iomanip>
|
||||||
|
|
||||||
#define DISALLOW_COPY_AND_ASSIGN(TypeName) \
|
#define DISALLOW_COPY_AND_ASSIGN(TypeName) \
|
||||||
TypeName(const TypeName&); \
|
TypeName(const TypeName&); \
|
||||||
void operator=(const TypeName&)
|
void operator=(const TypeName&)
|
||||||
@ -19,61 +20,162 @@ using namespace std;
|
|||||||
class Variable;
|
class Variable;
|
||||||
class BayesNode;
|
class BayesNode;
|
||||||
class FgVarNode;
|
class FgVarNode;
|
||||||
|
class Factor;
|
||||||
|
class Link;
|
||||||
|
class Edge;
|
||||||
|
|
||||||
typedef double Param;
|
typedef double Param;
|
||||||
typedef vector<Param> ParamSet;
|
typedef vector<Param> ParamSet;
|
||||||
typedef vector<Param> Message;
|
typedef const ParamSet& CParamSet;
|
||||||
|
typedef unsigned Vid;
|
||||||
|
typedef vector<Vid> VidSet;
|
||||||
|
typedef const VidSet& CVidSet;
|
||||||
typedef vector<Variable*> VarSet;
|
typedef vector<Variable*> VarSet;
|
||||||
typedef vector<BayesNode*> NodeSet;
|
typedef vector<BayesNode*> BnNodeSet;
|
||||||
|
typedef const BnNodeSet& CBnNodeSet;
|
||||||
typedef vector<FgVarNode*> FgVarSet;
|
typedef vector<FgVarNode*> FgVarSet;
|
||||||
|
typedef const FgVarSet& CFgVarSet;
|
||||||
|
typedef vector<Factor*> FactorSet;
|
||||||
|
typedef const FactorSet& CFactorSet;
|
||||||
|
typedef vector<Link*> LinkSet;
|
||||||
|
typedef const LinkSet& CLinkSet;
|
||||||
|
typedef vector<Edge*> EdgeSet;
|
||||||
|
typedef const EdgeSet& CEdgeSet;
|
||||||
typedef vector<string> Domain;
|
typedef vector<string> Domain;
|
||||||
typedef vector<unsigned> DomainConf;
|
typedef vector<unsigned> DConf;
|
||||||
typedef pair<unsigned, unsigned> DomainConstr;
|
typedef pair<unsigned, unsigned> DConstraint;
|
||||||
typedef unordered_map<unsigned, unsigned> IndexMap;
|
typedef map<unsigned, unsigned> IndexMap;
|
||||||
|
|
||||||
|
// level of debug information
|
||||||
//extern unsigned DL;
|
|
||||||
static const unsigned DL = 0;
|
static const unsigned DL = 0;
|
||||||
|
|
||||||
// number of digits to show when printing a parameter
|
static const int NO_EVIDENCE = -1;
|
||||||
static const unsigned PRECISION = 10;
|
|
||||||
|
|
||||||
// shared by bp and sp solver
|
// number of digits to show when printing a parameter
|
||||||
enum Schedule
|
static const unsigned PRECISION = 5;
|
||||||
|
|
||||||
|
static const bool EXPORT_TO_DOT = false;
|
||||||
|
static const unsigned EXPORT_MIN_SIZE = 30;
|
||||||
|
|
||||||
|
|
||||||
|
namespace SolverOptions
|
||||||
{
|
{
|
||||||
S_SEQ_FIXED,
|
enum Schedule
|
||||||
S_SEQ_RANDOM,
|
{
|
||||||
S_PARALLEL,
|
S_SEQ_FIXED,
|
||||||
S_MAX_RESIDUAL
|
S_SEQ_RANDOM,
|
||||||
|
S_PARALLEL,
|
||||||
|
S_MAX_RESIDUAL
|
||||||
|
};
|
||||||
|
extern bool runBayesBall;
|
||||||
|
extern bool convertBn2Fg;
|
||||||
|
extern bool compressFactorGraph;
|
||||||
|
extern Schedule schedule;
|
||||||
|
extern double accuracy;
|
||||||
|
extern unsigned maxIter;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
namespace Util
|
||||||
|
{
|
||||||
|
void normalize (ParamSet&);
|
||||||
|
void pow (ParamSet&, unsigned);
|
||||||
|
double getL1dist (CParamSet, CParamSet);
|
||||||
|
double getMaxNorm (CParamSet, CParamSet);
|
||||||
|
bool isInteger (const string&);
|
||||||
|
string parametersToString (CParamSet);
|
||||||
|
vector<DConf> getDomainConfigurations (const VarSet&);
|
||||||
|
vector<string> getInstantiations (const VarSet&);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
struct NetInfo
|
struct NetInfo
|
||||||
{
|
{
|
||||||
NetInfo (unsigned c, double t)
|
NetInfo (void)
|
||||||
{
|
{
|
||||||
counting = c;
|
counting = 0;
|
||||||
solvingTime = t;
|
nIters = 0;
|
||||||
|
solvingTime = 0.0;
|
||||||
}
|
}
|
||||||
unsigned counting;
|
unsigned counting;
|
||||||
double solvingTime;
|
double solvingTime;
|
||||||
|
unsigned nIters;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
struct CompressInfo
|
||||||
|
{
|
||||||
|
CompressInfo (unsigned a, unsigned b, unsigned c,
|
||||||
|
unsigned d, unsigned e) {
|
||||||
|
nUncVars = a;
|
||||||
|
nUncFactors = b;
|
||||||
|
nCompVars = c;
|
||||||
|
nCompFactors = d;
|
||||||
|
nNeighborlessVars = e;
|
||||||
|
}
|
||||||
|
unsigned nUncVars;
|
||||||
|
unsigned nUncFactors;
|
||||||
|
unsigned nCompVars;
|
||||||
|
unsigned nCompFactors;
|
||||||
|
unsigned nNeighborlessVars;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
typedef map<unsigned, NetInfo> StatisticMap;
|
typedef map<unsigned, NetInfo> StatisticMap;
|
||||||
|
|
||||||
|
|
||||||
class Statistics
|
class Statistics
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
|
|
||||||
static void updateStats (unsigned size, double time)
|
static void updateStats (unsigned size, unsigned nIters, double time)
|
||||||
{
|
{
|
||||||
StatisticMap::iterator it = stats_.find(size);
|
StatisticMap::iterator it = stats_.find (size);
|
||||||
if (it == stats_.end()) {
|
if (it == stats_.end()) {
|
||||||
stats_.insert (make_pair (size, NetInfo (1, 0.0)));
|
it = (stats_.insert (make_pair (size, NetInfo()))).first;
|
||||||
} else {
|
} else {
|
||||||
it->second.counting ++;
|
it->second.counting ++;
|
||||||
|
it->second.nIters += nIters;
|
||||||
it->second.solvingTime += time;
|
it->second.solvingTime += time;
|
||||||
|
totalOfIterations += nIters;
|
||||||
|
if (nIters > maxIterations) {
|
||||||
|
maxIterations = nIters;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void updateCompressingStats (unsigned nUncVars,
|
||||||
|
unsigned nUncFactors,
|
||||||
|
unsigned nCompVars,
|
||||||
|
unsigned nCompFactors,
|
||||||
|
unsigned nNeighborlessVars) {
|
||||||
|
compressInfo_.push_back (CompressInfo (
|
||||||
|
nUncVars, nUncFactors, nCompVars, nCompFactors, nNeighborlessVars));
|
||||||
|
}
|
||||||
|
|
||||||
|
static void printCompressingStats (const char* fileName)
|
||||||
|
{
|
||||||
|
ofstream out (fileName);
|
||||||
|
if (!out.is_open()) {
|
||||||
|
cerr << "error: cannot open file to write at " ;
|
||||||
|
cerr << "BayesNet::printCompressingStats()" << endl;
|
||||||
|
abort();
|
||||||
|
}
|
||||||
|
out << "--------------------------------------" ;
|
||||||
|
out << "--------------------------------------" << endl;
|
||||||
|
out << " Compression Stats" << endl;
|
||||||
|
out << "--------------------------------------" ;
|
||||||
|
out << "--------------------------------------" << endl;
|
||||||
|
out << left;
|
||||||
|
out << "Uncompress Compressed Uncompress Compressed Neighborless";
|
||||||
|
out << endl;
|
||||||
|
out << "Vars Vars Factors Factors Vars" ;
|
||||||
|
out << endl;
|
||||||
|
for (unsigned i = 0; i < compressInfo_.size(); i++) {
|
||||||
|
out << setw (13) << compressInfo_[i].nUncVars;
|
||||||
|
out << setw (13) << compressInfo_[i].nCompVars;
|
||||||
|
out << setw (13) << compressInfo_[i].nUncFactors;
|
||||||
|
out << setw (13) << compressInfo_[i].nCompFactors;
|
||||||
|
out << setw (13) << compressInfo_[i].nNeighborlessVars;
|
||||||
|
out << endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -84,20 +186,12 @@ class Statistics
|
|||||||
return it->second.counting;
|
return it->second.counting;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void updateIterations (unsigned nIters)
|
|
||||||
{
|
|
||||||
totalOfIterations += nIters;
|
|
||||||
if (nIters > maxIterations) {
|
|
||||||
maxIterations = nIters;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
static void writeStats (void)
|
static void writeStats (void)
|
||||||
{
|
{
|
||||||
ofstream out ("../../stats.txt");
|
ofstream out ("../../stats.txt");
|
||||||
if (!out.is_open()) {
|
if (!out.is_open()) {
|
||||||
cerr << "error: cannot open file to write at " ;
|
cerr << "error: cannot open file to write at " ;
|
||||||
cerr << "Statistics:::updateStats()" << endl;
|
cerr << "Statistics::updateStats()" << endl;
|
||||||
abort();
|
abort();
|
||||||
}
|
}
|
||||||
unsigned avgIterations = 0;
|
unsigned avgIterations = 0;
|
||||||
@ -117,17 +211,24 @@ class Statistics
|
|||||||
out << " average iterations: " << avgIterations << endl;
|
out << " average iterations: " << avgIterations << endl;
|
||||||
out << "total solving time " << totalSolvingTime << endl;
|
out << "total solving time " << totalSolvingTime << endl;
|
||||||
out << endl;
|
out << endl;
|
||||||
out << "Network Size\tCounting\tSolving Time\tAverage Time" << endl;
|
out << left << endl;
|
||||||
|
out << setw (15) << "Network Size" ;
|
||||||
|
out << setw (15) << "Counting" ;
|
||||||
|
out << setw (15) << "Solving Time" ;
|
||||||
|
out << setw (15) << "Average Time" ;
|
||||||
|
out << setw (15) << "#Iterations" ;
|
||||||
|
out << endl;
|
||||||
for (StatisticMap::iterator it = stats_.begin();
|
for (StatisticMap::iterator it = stats_.begin();
|
||||||
it != stats_.end(); it++) {
|
it != stats_.end(); it++) {
|
||||||
out << it->first;
|
out << setw (15) << it->first;
|
||||||
out << "\t\t" << it->second.counting;
|
out << setw (15) << it->second.counting;
|
||||||
out << "\t\t" << it->second.solvingTime;
|
out << setw (15) << it->second.solvingTime;
|
||||||
if (it->second.counting > 0) {
|
if (it->second.counting > 0) {
|
||||||
out << "\t\t" << it->second.solvingTime / it->second.counting;
|
out << setw (15) << it->second.solvingTime / it->second.counting;
|
||||||
} else {
|
} else {
|
||||||
out << "\t\t0.0" ;
|
out << setw (15) << "0.0" ;
|
||||||
}
|
}
|
||||||
|
out << setw (15) << it->second.nIters;
|
||||||
out << endl;
|
out << endl;
|
||||||
}
|
}
|
||||||
out.close();
|
out.close();
|
||||||
@ -142,62 +243,8 @@ class Statistics
|
|||||||
static StatisticMap stats_;
|
static StatisticMap stats_;
|
||||||
static unsigned maxIterations;
|
static unsigned maxIterations;
|
||||||
static unsigned totalOfIterations;
|
static unsigned totalOfIterations;
|
||||||
|
static vector<CompressInfo> compressInfo_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#endif //BP_SHARED_H
|
||||||
|
|
||||||
class Util
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
static void normalize (ParamSet& v)
|
|
||||||
{
|
|
||||||
double sum = 0.0;
|
|
||||||
for (unsigned i = 0; i < v.size(); i++) {
|
|
||||||
sum += v[i];
|
|
||||||
}
|
|
||||||
assert (sum != 0.0);
|
|
||||||
for (unsigned i = 0; i < v.size(); i++) {
|
|
||||||
v[i] /= sum;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
static double getL1dist (const ParamSet& v1, const ParamSet& v2)
|
|
||||||
{
|
|
||||||
assert (v1.size() == v2.size());
|
|
||||||
double dist = 0.0;
|
|
||||||
for (unsigned i = 0; i < v1.size(); i++) {
|
|
||||||
dist += abs (v1[i] - v2[i]);
|
|
||||||
}
|
|
||||||
return dist;
|
|
||||||
}
|
|
||||||
|
|
||||||
static double getMaxNorm (const ParamSet& v1, const ParamSet& v2)
|
|
||||||
{
|
|
||||||
assert (v1.size() == v2.size());
|
|
||||||
double max = 0.0;
|
|
||||||
for (unsigned i = 0; i < v1.size(); i++) {
|
|
||||||
double diff = abs (v1[i] - v2[i]);
|
|
||||||
if (diff > max) {
|
|
||||||
max = diff;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return max;
|
|
||||||
}
|
|
||||||
|
|
||||||
static bool isInteger (const string& s)
|
|
||||||
{
|
|
||||||
stringstream ss1 (s);
|
|
||||||
stringstream ss2;
|
|
||||||
int integer;
|
|
||||||
ss1 >> integer;
|
|
||||||
ss2 << integer;
|
|
||||||
return (ss1.str() == ss2.str());
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
//unsigned Statistics::totalOfIterations = 0;
|
|
||||||
|
|
||||||
#endif
|
|
||||||
|
|
||||||
|
@ -15,19 +15,30 @@ class Solver
|
|||||||
{
|
{
|
||||||
gm_ = gm;
|
gm_ = gm;
|
||||||
}
|
}
|
||||||
|
virtual ~Solver() {} // to call subclass destructor
|
||||||
virtual void runSolver (void) = 0;
|
virtual void runSolver (void) = 0;
|
||||||
virtual ParamSet getPosterioriOf (const Variable*) const = 0;
|
virtual ParamSet getPosterioriOf (Vid) const = 0;
|
||||||
|
virtual ParamSet getJointDistributionOf (const VidSet&) = 0;
|
||||||
|
|
||||||
void printPosterioriOf (const Variable* var) const
|
void printAllPosterioris (void) const
|
||||||
{
|
{
|
||||||
|
VarSet vars = gm_->getVariables();
|
||||||
|
for (unsigned i = 0; i < vars.size(); i++) {
|
||||||
|
printPosterioriOf (vars[i]->getVarId());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void printPosterioriOf (Vid vid) const
|
||||||
|
{
|
||||||
|
Variable* var = gm_->getVariable (vid);
|
||||||
cout << endl;
|
cout << endl;
|
||||||
cout << setw (20) << left << var->getLabel() << "posteriori" ;
|
cout << setw (20) << left << var->getLabel() << "posteriori" ;
|
||||||
cout << endl;
|
cout << endl;
|
||||||
cout << "------------------------------" ;
|
cout << "------------------------------" ;
|
||||||
cout << endl;
|
cout << endl;
|
||||||
const Domain& domain = var->getDomain();
|
const Domain& domain = var->getDomain();
|
||||||
ParamSet results = getPosterioriOf (var);
|
ParamSet results = getPosterioriOf (vid);
|
||||||
for (int xi = 0; xi < var->getDomainSize(); xi++) {
|
for (unsigned xi = 0; xi < var->getDomainSize(); xi++) {
|
||||||
cout << setw (20) << domain[xi];
|
cout << setw (20) << domain[xi];
|
||||||
cout << setprecision (PRECISION) << results[xi];
|
cout << setprecision (PRECISION) << results[xi];
|
||||||
cout << endl;
|
cout << endl;
|
||||||
@ -35,16 +46,35 @@ class Solver
|
|||||||
cout << endl;
|
cout << endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
void printAllPosterioris (void) const
|
void printJointDistributionOf (const VidSet& vids)
|
||||||
{
|
{
|
||||||
VarSet vars = gm_->getVariables();
|
const ParamSet& jointDist = getJointDistributionOf (vids);
|
||||||
for (unsigned i = 0; i < vars.size(); i++) {
|
cout << endl;
|
||||||
printPosterioriOf (vars[i]);
|
cout << "joint distribution of " ;
|
||||||
|
VarSet vars;
|
||||||
|
for (unsigned i = 0; i < vids.size() - 1; i++) {
|
||||||
|
Variable* var = gm_->getVariable (vids[i]);
|
||||||
|
cout << var->getLabel() << ", " ;
|
||||||
|
vars.push_back (var);
|
||||||
}
|
}
|
||||||
|
Variable* var = gm_->getVariable (vids[vids.size() - 1]);
|
||||||
|
cout << var->getLabel() ;
|
||||||
|
vars.push_back (var);
|
||||||
|
cout << endl;
|
||||||
|
cout << "------------------------------" ;
|
||||||
|
cout << endl;
|
||||||
|
const vector<string>& domainConfs = Util::getInstantiations (vars);
|
||||||
|
for (unsigned i = 0; i < jointDist.size(); i++) {
|
||||||
|
cout << left << setw (20) << domainConfs[i];
|
||||||
|
cout << setprecision (PRECISION) << jointDist[i];
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
cout << endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
const GraphicalModel* gm_;
|
const GraphicalModel* gm_;
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif
|
#endif //BP_SOLVER_H
|
||||||
|
|
||||||
|
191
packages/CLPBN/clpbn/bp/Util.cpp
Normal file
191
packages/CLPBN/clpbn/bp/Util.cpp
Normal file
@ -0,0 +1,191 @@
|
|||||||
|
#include <sstream>
|
||||||
|
|
||||||
|
#include "Variable.h"
|
||||||
|
#include "Shared.h"
|
||||||
|
|
||||||
|
namespace SolverOptions {
|
||||||
|
|
||||||
|
bool runBayesBall = false;
|
||||||
|
bool convertBn2Fg = true;
|
||||||
|
bool compressFactorGraph = true;
|
||||||
|
Schedule schedule = S_SEQ_FIXED;
|
||||||
|
//Schedule schedule = S_SEQ_RANDOM;
|
||||||
|
//Schedule schedule = S_PARALLEL;
|
||||||
|
//Schedule schedule = S_MAX_RESIDUAL;
|
||||||
|
double accuracy = 0.0001;
|
||||||
|
unsigned maxIter = 1000; //FIXME
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
unsigned Statistics::numCreatedNets = 0;
|
||||||
|
unsigned Statistics::numSolvedPolyTrees = 0;
|
||||||
|
unsigned Statistics::numSolvedLoopyNets = 0;
|
||||||
|
unsigned Statistics::numUnconvergedRuns = 0;
|
||||||
|
unsigned Statistics::maxIterations = 0;
|
||||||
|
unsigned Statistics::totalOfIterations = 0;
|
||||||
|
vector<CompressInfo> Statistics::compressInfo_;
|
||||||
|
StatisticMap Statistics::stats_;
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
namespace Util {
|
||||||
|
|
||||||
|
void
|
||||||
|
normalize (ParamSet& v)
|
||||||
|
{
|
||||||
|
double sum = 0.0;
|
||||||
|
for (unsigned i = 0; i < v.size(); i++) {
|
||||||
|
sum += v[i];
|
||||||
|
}
|
||||||
|
assert (sum != 0.0);
|
||||||
|
for (unsigned i = 0; i < v.size(); i++) {
|
||||||
|
v[i] /= sum;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
pow (ParamSet& v, unsigned expoent)
|
||||||
|
{
|
||||||
|
for (unsigned i = 0; i < v.size(); i++) {
|
||||||
|
double value = 1;
|
||||||
|
for (unsigned j = 0; j < expoent; j++) {
|
||||||
|
value *= v[i];
|
||||||
|
}
|
||||||
|
v[i] = value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
double
|
||||||
|
getL1dist (const ParamSet& v1, const ParamSet& v2)
|
||||||
|
{
|
||||||
|
assert (v1.size() == v2.size());
|
||||||
|
double dist = 0.0;
|
||||||
|
for (unsigned i = 0; i < v1.size(); i++) {
|
||||||
|
dist += abs (v1[i] - v2[i]);
|
||||||
|
}
|
||||||
|
return dist;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
double
|
||||||
|
getMaxNorm (const ParamSet& v1, const ParamSet& v2)
|
||||||
|
{
|
||||||
|
assert (v1.size() == v2.size());
|
||||||
|
double max = 0.0;
|
||||||
|
for (unsigned i = 0; i < v1.size(); i++) {
|
||||||
|
double diff = abs (v1[i] - v2[i]);
|
||||||
|
if (diff > max) {
|
||||||
|
max = diff;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return max;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
isInteger (const string& s)
|
||||||
|
{
|
||||||
|
stringstream ss1 (s);
|
||||||
|
stringstream ss2;
|
||||||
|
int integer;
|
||||||
|
ss1 >> integer;
|
||||||
|
ss2 << integer;
|
||||||
|
return (ss1.str() == ss2.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
string
|
||||||
|
parametersToString (CParamSet v)
|
||||||
|
{
|
||||||
|
stringstream ss;
|
||||||
|
ss << "[" ;
|
||||||
|
for (unsigned i = 0; i < v.size() - 1; i++) {
|
||||||
|
ss << v[i] << ", " ;
|
||||||
|
}
|
||||||
|
if (v.size() != 0) {
|
||||||
|
ss << v[v.size() - 1];
|
||||||
|
}
|
||||||
|
ss << "]" ;
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
vector<DConf>
|
||||||
|
getDomainConfigurations (const VarSet& vars)
|
||||||
|
{
|
||||||
|
unsigned nConfs = 1;
|
||||||
|
for (unsigned i = 0; i < vars.size(); i++) {
|
||||||
|
nConfs *= vars[i]->getDomainSize();
|
||||||
|
}
|
||||||
|
|
||||||
|
vector<DConf> confs (nConfs);
|
||||||
|
for (unsigned i = 0; i < nConfs; i++) {
|
||||||
|
confs[i].resize (vars.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
unsigned nReps = 1;
|
||||||
|
for (int i = vars.size() - 1; i >= 0; i--) {
|
||||||
|
unsigned index = 0;
|
||||||
|
while (index < nConfs) {
|
||||||
|
for (unsigned j = 0; j < vars[i]->getDomainSize(); j++) {
|
||||||
|
for (unsigned r = 0; r < nReps; r++) {
|
||||||
|
confs[index][i] = j;
|
||||||
|
index++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
nReps *= vars[i]->getDomainSize();
|
||||||
|
}
|
||||||
|
return confs;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
vector<string>
|
||||||
|
getInstantiations (const VarSet& vars)
|
||||||
|
{
|
||||||
|
//FIXME handle variables without domain
|
||||||
|
/*
|
||||||
|
char c = 'a' ;
|
||||||
|
const DConf& conf = entries[i].getDomainConfiguration();
|
||||||
|
for (unsigned j = 0; j < conf.size(); j++) {
|
||||||
|
if (j != 0) ss << "," ;
|
||||||
|
ss << c << conf[j] + 1;
|
||||||
|
c ++;
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
unsigned rowSize = 1;
|
||||||
|
for (unsigned i = 0; i < vars.size(); i++) {
|
||||||
|
rowSize *= vars[i]->getDomainSize();
|
||||||
|
}
|
||||||
|
|
||||||
|
vector<string> headers (rowSize);
|
||||||
|
|
||||||
|
unsigned nReps = 1;
|
||||||
|
for (int i = vars.size() - 1; i >= 0; i--) {
|
||||||
|
Domain domain = vars[i]->getDomain();
|
||||||
|
unsigned index = 0;
|
||||||
|
while (index < rowSize) {
|
||||||
|
for (unsigned j = 0; j < vars[i]->getDomainSize(); j++) {
|
||||||
|
for (unsigned r = 0; r < nReps; r++) {
|
||||||
|
if (headers[index] != "") {
|
||||||
|
headers[index] = domain[j] + ", " + headers[index];
|
||||||
|
} else {
|
||||||
|
headers[index] = domain[j];
|
||||||
|
}
|
||||||
|
index++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
nReps *= vars[i]->getDomainSize();
|
||||||
|
}
|
||||||
|
return headers;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
@ -1,9 +1,10 @@
|
|||||||
#ifndef BP_GENERIC_VARIABLE_H
|
#ifndef BP_VARIABLE_H
|
||||||
#define BP_GENERIC_VARIABLE_H
|
#define BP_VARIABLE_H
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
|
||||||
#include <algorithm>
|
|
||||||
#include "Shared.h"
|
#include "Shared.h"
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
@ -12,33 +13,61 @@ class Variable
|
|||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
|
|
||||||
Variable (unsigned varId)
|
Variable (const Variable* v)
|
||||||
{
|
{
|
||||||
this->varId_ = varId;
|
vid_ = v->getVarId();
|
||||||
this->dsize_ = 0;
|
dsize_ = v->getDomainSize();
|
||||||
this->evidence_ = -1;
|
if (v->hasDomain()) {
|
||||||
this->label_ = 0;
|
domain_ = v->getDomain();
|
||||||
|
dsize_ = domain_.size();
|
||||||
|
} else {
|
||||||
|
dsize_ = v->getDomainSize();
|
||||||
|
}
|
||||||
|
evidence_ = v->getEvidence();
|
||||||
|
if (v->hasLabel()) {
|
||||||
|
label_ = new string (v->getLabel());
|
||||||
|
} else {
|
||||||
|
label_ = 0;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Variable (unsigned varId, unsigned dsize, int evidence = -1)
|
Variable (Vid vid)
|
||||||
|
{
|
||||||
|
this->vid_ = vid;
|
||||||
|
this->dsize_ = 0;
|
||||||
|
this->evidence_ = NO_EVIDENCE;
|
||||||
|
this->label_ = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
Variable (Vid vid, unsigned dsize, int evidence = NO_EVIDENCE,
|
||||||
|
const string& lbl = string())
|
||||||
{
|
{
|
||||||
assert (dsize != 0);
|
assert (dsize != 0);
|
||||||
assert (evidence < (int)dsize);
|
assert (evidence < (int)dsize);
|
||||||
this->varId_ = varId;
|
this->vid_ = vid;
|
||||||
this->dsize_ = dsize;
|
this->dsize_ = dsize;
|
||||||
this->evidence_ = evidence;
|
this->evidence_ = evidence;
|
||||||
this->label_ = 0;
|
if (!lbl.empty()) {
|
||||||
|
this->label_ = new string (lbl);
|
||||||
|
} else {
|
||||||
|
this->label_ = 0;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Variable (unsigned varId, const Domain& domain, int evidence = -1)
|
Variable (Vid vid, const Domain& domain, int evidence = NO_EVIDENCE,
|
||||||
|
const string& lbl = string())
|
||||||
{
|
{
|
||||||
assert (!domain.empty());
|
assert (!domain.empty());
|
||||||
assert (evidence < (int)domain.size());
|
assert (evidence < (int)domain.size());
|
||||||
this->varId_ = varId;
|
this->vid_ = vid;
|
||||||
this->dsize_ = domain.size();
|
this->dsize_ = domain.size();
|
||||||
this->domain_ = domain;
|
this->domain_ = domain;
|
||||||
this->evidence_ = evidence;
|
this->evidence_ = evidence;
|
||||||
this->label_ = 0;
|
if (!lbl.empty()) {
|
||||||
|
this->label_ = new string (lbl);
|
||||||
|
} else {
|
||||||
|
this->label_ = 0;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
~Variable (void)
|
~Variable (void)
|
||||||
@ -46,19 +75,19 @@ class Variable
|
|||||||
delete label_;
|
delete label_;
|
||||||
}
|
}
|
||||||
|
|
||||||
unsigned getVarId (void) const { return varId_; }
|
unsigned getVarId (void) const { return vid_; }
|
||||||
unsigned getIndex (void) const { return index_; }
|
unsigned getIndex (void) const { return index_; }
|
||||||
void setIndex (unsigned idx) { index_ = idx; }
|
void setIndex (unsigned idx) { index_ = idx; }
|
||||||
int getDomainSize (void) const { return dsize_; }
|
unsigned getDomainSize (void) const { return dsize_; }
|
||||||
bool hasEvidence (void) const { return evidence_ != -1; }
|
bool hasEvidence (void) const { return evidence_ != NO_EVIDENCE; }
|
||||||
int getEvidence (void) const { return evidence_; }
|
int getEvidence (void) const { return evidence_; }
|
||||||
bool hasDomain (void) { return !domain_.empty(); }
|
bool hasDomain (void) const { return !domain_.empty(); }
|
||||||
bool hasLabel (void) { return label_ != 0; }
|
bool hasLabel (void) const { return label_ != 0; }
|
||||||
|
|
||||||
bool isValidStateIndex (int index)
|
bool isValidStateIndex (int index)
|
||||||
{
|
{
|
||||||
return index >= 0 && index < dsize_;
|
return index >= 0 && index < (int)dsize_;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool isValidState (const string& state)
|
bool isValidState (const string& state)
|
||||||
{
|
{
|
||||||
@ -70,7 +99,7 @@ class Variable
|
|||||||
assert (dsize_ != 0);
|
assert (dsize_ != 0);
|
||||||
if (domain_.size() == 0) {
|
if (domain_.size() == 0) {
|
||||||
Domain d;
|
Domain d;
|
||||||
for (int i = 0; i < dsize_; i++) {
|
for (unsigned i = 0; i < dsize_; i++) {
|
||||||
stringstream ss;
|
stringstream ss;
|
||||||
ss << "x" << i ;
|
ss << "x" << i ;
|
||||||
d.push_back (ss.str());
|
d.push_back (ss.str());
|
||||||
@ -110,7 +139,7 @@ class Variable
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void setLabel (string label)
|
void setLabel (const string& label)
|
||||||
{
|
{
|
||||||
label_ = new string (label);
|
label_ = new string (label);
|
||||||
}
|
}
|
||||||
@ -119,25 +148,25 @@ class Variable
|
|||||||
{
|
{
|
||||||
if (label_ == 0) {
|
if (label_ == 0) {
|
||||||
stringstream ss;
|
stringstream ss;
|
||||||
ss << "v" << varId_;
|
ss << "v" << vid_;
|
||||||
return ss.str();
|
return ss.str();
|
||||||
} else {
|
} else {
|
||||||
return *label_;
|
return *label_;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
protected:
|
|
||||||
unsigned varId_;
|
|
||||||
string* label_;
|
|
||||||
unsigned index_;
|
|
||||||
int evidence_;
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
DISALLOW_COPY_AND_ASSIGN (Variable);
|
DISALLOW_COPY_AND_ASSIGN (Variable);
|
||||||
Domain domain_;
|
|
||||||
int dsize_;
|
Vid vid_;
|
||||||
|
unsigned dsize_;
|
||||||
|
int evidence_;
|
||||||
|
Domain domain_;
|
||||||
|
string* label_;
|
||||||
|
unsigned index_;
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif // BP_GENERIC_VARIABLE_H
|
#endif // BP_VARIABLE_H
|
||||||
|
|
||||||
|
34
packages/CLPBN/clpbn/bp/examples/1parentNchilds.yap
Executable file
34
packages/CLPBN/clpbn/bp/examples/1parentNchilds.yap
Executable file
@ -0,0 +1,34 @@
|
|||||||
|
|
||||||
|
:- use_module(library(clpbn)).
|
||||||
|
|
||||||
|
:- set_clpbn_flag(solver, bp).
|
||||||
|
|
||||||
|
%
|
||||||
|
% R
|
||||||
|
% / | \
|
||||||
|
% / | \
|
||||||
|
% A B C
|
||||||
|
%
|
||||||
|
|
||||||
|
|
||||||
|
r(R) :-
|
||||||
|
{ R = r with p([t, f], [0.35, 0.65]) }.
|
||||||
|
|
||||||
|
a(A) :-
|
||||||
|
r(R),
|
||||||
|
child_dist(R,Dist),
|
||||||
|
{ A = a with Dist }.
|
||||||
|
|
||||||
|
b(B) :-
|
||||||
|
r(R),
|
||||||
|
child_dist(R,Dist),
|
||||||
|
{ B = b with Dist }.
|
||||||
|
|
||||||
|
c(C) :-
|
||||||
|
r(R),
|
||||||
|
child_dist(R,Dist),
|
||||||
|
{ C = c with Dist }.
|
||||||
|
|
||||||
|
|
||||||
|
child_dist(R, p([t, f], [0.3, 0.4, 0.25, 0.05], [R])).
|
||||||
|
|
53
packages/CLPBN/clpbn/bp/examples/bp-example.xml
Executable file
53
packages/CLPBN/clpbn/bp/examples/bp-example.xml
Executable file
@ -0,0 +1,53 @@
|
|||||||
|
<?xml version="1.0" encoding="US-ASCII"?>
|
||||||
|
|
||||||
|
<!--
|
||||||
|
|
||||||
|
A B
|
||||||
|
\ /
|
||||||
|
\ /
|
||||||
|
C
|
||||||
|
|
||||||
|
-->
|
||||||
|
|
||||||
|
<BIF VERSION="0.3">
|
||||||
|
<NETWORK>
|
||||||
|
<NAME>Neapolitan</NAME>
|
||||||
|
|
||||||
|
<VARIABLE TYPE="nature">
|
||||||
|
<NAME>A</NAME>
|
||||||
|
<OUTCOME>a1</OUTCOME>
|
||||||
|
<OUTCOME>a2</OUTCOME>
|
||||||
|
</VARIABLE>
|
||||||
|
|
||||||
|
<VARIABLE TYPE="nature">
|
||||||
|
<NAME>B</NAME>
|
||||||
|
<OUTCOME>b1</OUTCOME>
|
||||||
|
<OUTCOME>b2</OUTCOME>
|
||||||
|
</VARIABLE>
|
||||||
|
|
||||||
|
<VARIABLE TYPE="nature">
|
||||||
|
<NAME>C</NAME>
|
||||||
|
<OUTCOME>c1</OUTCOME>
|
||||||
|
<OUTCOME>c2</OUTCOME>
|
||||||
|
</VARIABLE>
|
||||||
|
|
||||||
|
<DEFINITION>
|
||||||
|
<FOR>A</FOR>
|
||||||
|
<TABLE> .695 .305 </TABLE>
|
||||||
|
</DEFINITION>
|
||||||
|
|
||||||
|
<DEFINITION>
|
||||||
|
<FOR>B</FOR>
|
||||||
|
<TABLE> .25 .75 </TABLE>
|
||||||
|
</DEFINITION>
|
||||||
|
|
||||||
|
<DEFINITION>
|
||||||
|
<FOR>C</FOR>
|
||||||
|
<GIVEN>A</GIVEN>
|
||||||
|
<GIVEN>B</GIVEN>
|
||||||
|
<TABLE> .2 .8 .45 .55 .32 .68 .7 .3 </TABLE>
|
||||||
|
</DEFINITION>
|
||||||
|
|
||||||
|
</NETWORK>
|
||||||
|
</BIF>
|
||||||
|
|
@ -9,10 +9,10 @@ MARKOV
|
|||||||
2 4 2
|
2 4 2
|
||||||
|
|
||||||
2
|
2
|
||||||
.001 .009
|
.001 .999
|
||||||
|
|
||||||
2
|
2
|
||||||
.002 .008
|
.002 .998
|
||||||
|
|
||||||
8
|
8
|
||||||
.95 .94 .29 .001
|
.95 .94 .29 .001
|
||||||
|
@ -49,12 +49,12 @@
|
|||||||
|
|
||||||
<DEFINITION>
|
<DEFINITION>
|
||||||
<FOR>B</FOR>
|
<FOR>B</FOR>
|
||||||
<TABLE> .001 .009 </TABLE>
|
<TABLE> .001 .999 </TABLE>
|
||||||
</DEFINITION>
|
</DEFINITION>
|
||||||
|
|
||||||
<DEFINITION>
|
<DEFINITION>
|
||||||
<FOR>E</FOR>
|
<FOR>E</FOR>
|
||||||
<TABLE> .002 .008 </TABLE>
|
<TABLE> .002 .998 </TABLE>
|
||||||
</DEFINITION>
|
</DEFINITION>
|
||||||
|
|
||||||
<DEFINITION>
|
<DEFINITION>
|
||||||
|
@ -1,54 +1,29 @@
|
|||||||
|
|
||||||
:- use_module(library(clpbn)).
|
:- use_module(library(clpbn)).
|
||||||
|
|
||||||
:- set_clpbn_flag(solver, vel).
|
:- set_clpbn_flag(solver, bp).
|
||||||
|
|
||||||
%
|
r(R) :- r_cpt(RCpt),
|
||||||
% B E
|
{ R = r with p([r1, r2], RCpt) }.
|
||||||
% \ /
|
|
||||||
% \ /
|
|
||||||
% A
|
|
||||||
% / \
|
|
||||||
% / \
|
|
||||||
% J M
|
|
||||||
%
|
|
||||||
|
|
||||||
|
t(T) :- t_cpt(TCpt),
|
||||||
|
{ T = t with p([t1, t2], TCpt) }.
|
||||||
|
|
||||||
b(B) :-
|
a(A) :- r(R), t(T), a_cpt(ACpt),
|
||||||
b_table(BDist),
|
{ A = a with p([a1, a2], ACpt, [R, T]) }.
|
||||||
{ B = b with p([b1, b2], BDist) }.
|
|
||||||
|
|
||||||
e(E) :-
|
j(J) :- a(A), j_cpt(JCpt),
|
||||||
e_table(EDist),
|
{ J = j with p([j1, j2], JCpt, [A]) }.
|
||||||
{ E = e with p([e1, e2], EDist) }.
|
|
||||||
|
|
||||||
a(A) :-
|
m(M) :- a(A), m_cpt(MCpt),
|
||||||
b(B),
|
{ M = m with p([m1, m2], MCpt, [A]) }.
|
||||||
e(E),
|
|
||||||
a_table(ADist),
|
|
||||||
{ A = a with p([a1, a2], ADist, [B, E]) }.
|
|
||||||
|
|
||||||
j(J):-
|
|
||||||
a(A),
|
|
||||||
j_table(JDist),
|
|
||||||
{ J = j with p([j1, j2], JDist, [A]) }.
|
|
||||||
|
|
||||||
m(M):-
|
|
||||||
a(A),
|
|
||||||
m_table(MDist),
|
|
||||||
{ M = m with p([m1, m2], MDist, [A]) }.
|
|
||||||
|
|
||||||
|
r_cpt([0.001, 0.999]).
|
||||||
b_table([0.001, 0.009]).
|
t_cpt([0.002, 0.998]).
|
||||||
|
a_cpt([0.95, 0.94, 0.29, 0.001,
|
||||||
e_table([0.002, 0.008]).
|
0.05, 0.06, 0.71, 0.999]).
|
||||||
|
j_cpt([0.9, 0.05,
|
||||||
a_table([0.95, 0.94, 0.29, 0.001,
|
0.1, 0.95]).
|
||||||
0.05, 0.06, 0.71, 0.999]).
|
m_cpt([0.7, 0.01,
|
||||||
|
0.3, 0.99]).
|
||||||
j_table([0.9, 0.05,
|
|
||||||
0.1, 0.95]).
|
|
||||||
|
|
||||||
m_table([0.7, 0.01,
|
|
||||||
0.3, 0.99]).
|
|
||||||
|
|
||||||
|
@ -16,34 +16,37 @@
|
|||||||
|
|
||||||
<VARIABLE TYPE="nature">
|
<VARIABLE TYPE="nature">
|
||||||
<NAME>A</NAME>
|
<NAME>A</NAME>
|
||||||
<OUTCOME></OUTCOME>
|
<OUTCOME>a1</OUTCOME>
|
||||||
|
<OUTCOME>a2</OUTCOME>
|
||||||
</VARIABLE>
|
</VARIABLE>
|
||||||
|
|
||||||
<VARIABLE TYPE="nature">
|
<VARIABLE TYPE="nature">
|
||||||
<NAME>B</NAME>
|
<NAME>B</NAME>
|
||||||
<OUTCOME></OUTCOME>
|
<OUTCOME>b1</OUTCOME>
|
||||||
|
<OUTCOME>b2</OUTCOME>
|
||||||
</VARIABLE>
|
</VARIABLE>
|
||||||
|
|
||||||
<VARIABLE TYPE="nature">
|
<VARIABLE TYPE="nature">
|
||||||
<NAME>C</NAME>
|
<NAME>C</NAME>
|
||||||
<OUTCOME></OUTCOME>
|
<OUTCOME>c1</OUTCOME>
|
||||||
|
<OUTCOME>c2</OUTCOME>
|
||||||
</VARIABLE>
|
</VARIABLE>
|
||||||
|
|
||||||
<DEFINITION>
|
<DEFINITION>
|
||||||
<FOR>A</FOR>
|
<FOR>A</FOR>
|
||||||
<TABLE>1</TABLE>
|
<TABLE>.695 .305</TABLE>
|
||||||
</DEFINITION>
|
</DEFINITION>
|
||||||
|
|
||||||
<DEFINITION>
|
<DEFINITION>
|
||||||
<FOR>B</FOR>
|
<FOR>B</FOR>
|
||||||
<TABLE>1</TABLE>
|
<TABLE>0.25 0.75</TABLE>
|
||||||
</DEFINITION>
|
</DEFINITION>
|
||||||
|
|
||||||
<DEFINITION>
|
<DEFINITION>
|
||||||
<FOR>C</FOR>
|
<FOR>C</FOR>
|
||||||
<GIVEN>A</GIVEN>
|
<GIVEN>A</GIVEN>
|
||||||
<GIVEN>B</GIVEN>
|
<GIVEN>B</GIVEN>
|
||||||
<TABLE>1</TABLE>
|
<TABLE>0.2 0.8 0.45 0.55 0.32 0.68 0.7 0.3</TABLE>
|
||||||
</DEFINITION>
|
</DEFINITION>
|
||||||
|
|
||||||
</NETWORK>
|
</NETWORK>
|
||||||
|
67
packages/CLPBN/clpbn/bp/examples/lambda fail.xml
Executable file
67
packages/CLPBN/clpbn/bp/examples/lambda fail.xml
Executable file
@ -0,0 +1,67 @@
|
|||||||
|
<?xml version="1.0" encoding="US-ASCII"?>
|
||||||
|
|
||||||
|
<!--
|
||||||
|
|
||||||
|
P1 P2 P3
|
||||||
|
\ | /
|
||||||
|
\ | /
|
||||||
|
-
|
||||||
|
C
|
||||||
|
|
||||||
|
-->
|
||||||
|
|
||||||
|
<BIF VERSION="0.3">
|
||||||
|
<NETWORK>
|
||||||
|
|
||||||
|
<NAME>Simple Convergence</NAME>
|
||||||
|
<VARIABLE TYPE="nature">
|
||||||
|
<NAME>P1</NAME>
|
||||||
|
<OUTCOME>p1</OUTCOME>
|
||||||
|
<OUTCOME>p2</OUTCOME>
|
||||||
|
</VARIABLE>
|
||||||
|
|
||||||
|
<VARIABLE TYPE="nature">
|
||||||
|
<NAME>P2</NAME>
|
||||||
|
<OUTCOME>p1</OUTCOME>
|
||||||
|
<OUTCOME>p2</OUTCOME>
|
||||||
|
<OUTCOME>p3</OUTCOME>
|
||||||
|
</VARIABLE>
|
||||||
|
|
||||||
|
<VARIABLE TYPE="nature">
|
||||||
|
<NAME>P3</NAME>
|
||||||
|
<OUTCOME>p1</OUTCOME>
|
||||||
|
<OUTCOME>p2</OUTCOME>
|
||||||
|
</VARIABLE>
|
||||||
|
|
||||||
|
<VARIABLE TYPE="nature">
|
||||||
|
<NAME>C</NAME>
|
||||||
|
<OUTCOME>c1</OUTCOME>
|
||||||
|
<OUTCOME>c2</OUTCOME>
|
||||||
|
</VARIABLE>
|
||||||
|
|
||||||
|
<DEFINITION>
|
||||||
|
<FOR>P1</FOR>
|
||||||
|
<TABLE>.695 .305</TABLE>
|
||||||
|
</DEFINITION>
|
||||||
|
|
||||||
|
<DEFINITION>
|
||||||
|
<FOR>P2</FOR>
|
||||||
|
<TABLE>0.2 0.3 0.5</TABLE>
|
||||||
|
</DEFINITION>
|
||||||
|
|
||||||
|
<DEFINITION>
|
||||||
|
<FOR>P3</FOR>
|
||||||
|
<TABLE>0.25 0.75</TABLE>
|
||||||
|
</DEFINITION>
|
||||||
|
|
||||||
|
<DEFINITION>
|
||||||
|
<FOR>C</FOR>
|
||||||
|
<GIVEN>P1</GIVEN>
|
||||||
|
<GIVEN>P2</GIVEN>
|
||||||
|
<GIVEN>P3</GIVEN>
|
||||||
|
<TABLE>0.2 0.8 0.45 0.55 0.32 0.68 0.7 0.3 0.3 0.7 0.55 0.45 0.22 0.78 0.25 0.75 0.11 0.89 0.34 0.66 0.1 0.9 0.6 0.4</TABLE>
|
||||||
|
</DEFINITION>
|
||||||
|
|
||||||
|
</NETWORK>
|
||||||
|
</BIF>
|
||||||
|
|
@ -2,6 +2,7 @@
|
|||||||
:- use_module(library(clpbn)).
|
:- use_module(library(clpbn)).
|
||||||
|
|
||||||
:- set_clpbn_flag(solver, bp).
|
:- set_clpbn_flag(solver, bp).
|
||||||
|
%:- set_clpbn_flag(solver, jt).
|
||||||
|
|
||||||
%
|
%
|
||||||
% B F
|
% B F
|
||||||
|
17
packages/CLPBN/clpbn/bp/examples/sp-example.uai
Executable file
17
packages/CLPBN/clpbn/bp/examples/sp-example.uai
Executable file
@ -0,0 +1,17 @@
|
|||||||
|
MARKOV
|
||||||
|
3
|
||||||
|
2 2 2
|
||||||
|
3
|
||||||
|
1 0
|
||||||
|
1 1
|
||||||
|
3 2 0 1
|
||||||
|
|
||||||
|
2
|
||||||
|
.695 .305
|
||||||
|
|
||||||
|
2
|
||||||
|
.25 .75
|
||||||
|
|
||||||
|
8
|
||||||
|
0.2 0.45 0.32 0.7
|
||||||
|
0.8 0.55 0.68 0.3
|
128
packages/CLPBN/clpbn/bp/examples/test_bn.xml
Executable file
128
packages/CLPBN/clpbn/bp/examples/test_bn.xml
Executable file
@ -0,0 +1,128 @@
|
|||||||
|
<?xml version="1.0" encoding="US-ASCII"?>
|
||||||
|
|
||||||
|
<!--
|
||||||
|
|
||||||
|
A B C
|
||||||
|
\ | /
|
||||||
|
\ | /
|
||||||
|
D
|
||||||
|
/ | \
|
||||||
|
/ | \
|
||||||
|
E F G
|
||||||
|
|
||||||
|
-->
|
||||||
|
|
||||||
|
<BIF VERSION="0.3">
|
||||||
|
<NETWORK>
|
||||||
|
<NAME>Node with several parents and childs</NAME>
|
||||||
|
|
||||||
|
<VARIABLE TYPE="nature">
|
||||||
|
<NAME>A</NAME>
|
||||||
|
<OUTCOME>a1</OUTCOME>
|
||||||
|
<OUTCOME>a2</OUTCOME>
|
||||||
|
</VARIABLE>
|
||||||
|
|
||||||
|
<VARIABLE TYPE="nature">
|
||||||
|
<NAME>B</NAME>
|
||||||
|
<OUTCOME>b1</OUTCOME>
|
||||||
|
<OUTCOME>b2</OUTCOME>
|
||||||
|
<OUTCOME>b3</OUTCOME>
|
||||||
|
<OUTCOME>b4</OUTCOME>
|
||||||
|
</VARIABLE>
|
||||||
|
|
||||||
|
<VARIABLE TYPE="nature">
|
||||||
|
<NAME>C</NAME>
|
||||||
|
<OUTCOME>c1</OUTCOME>
|
||||||
|
<OUTCOME>c2</OUTCOME>
|
||||||
|
<OUTCOME>c3</OUTCOME>
|
||||||
|
</VARIABLE>
|
||||||
|
|
||||||
|
<VARIABLE TYPE="nature">
|
||||||
|
<NAME>D</NAME>
|
||||||
|
<OUTCOME>d1</OUTCOME>
|
||||||
|
<OUTCOME>d2</OUTCOME>
|
||||||
|
<OUTCOME>d3</OUTCOME>
|
||||||
|
</VARIABLE>
|
||||||
|
|
||||||
|
<VARIABLE TYPE="nature">
|
||||||
|
<NAME>E</NAME>
|
||||||
|
<OUTCOME>e1</OUTCOME>
|
||||||
|
<OUTCOME>e2</OUTCOME>
|
||||||
|
<OUTCOME>e3</OUTCOME>
|
||||||
|
<OUTCOME>e4</OUTCOME>
|
||||||
|
</VARIABLE>
|
||||||
|
|
||||||
|
<VARIABLE TYPE="nature">
|
||||||
|
<NAME>F</NAME>
|
||||||
|
<OUTCOME>f1</OUTCOME>
|
||||||
|
<OUTCOME>f2</OUTCOME>
|
||||||
|
<OUTCOME>f3</OUTCOME>
|
||||||
|
</VARIABLE>
|
||||||
|
|
||||||
|
<VARIABLE TYPE="nature">
|
||||||
|
<NAME>G</NAME>
|
||||||
|
<OUTCOME>g1</OUTCOME>
|
||||||
|
<OUTCOME>g2</OUTCOME>
|
||||||
|
</VARIABLE>
|
||||||
|
|
||||||
|
|
||||||
|
<DEFINITION>
|
||||||
|
<FOR>A</FOR>
|
||||||
|
<TABLE> .1 .2 </TABLE>
|
||||||
|
</DEFINITION>
|
||||||
|
|
||||||
|
<DEFINITION>
|
||||||
|
<FOR>B</FOR>
|
||||||
|
<TABLE> .01 .02 .03 .04 </TABLE>
|
||||||
|
</DEFINITION>
|
||||||
|
|
||||||
|
<DEFINITION>
|
||||||
|
<FOR>C</FOR>
|
||||||
|
<TABLE> .11 .22 .33 </TABLE>
|
||||||
|
</DEFINITION>
|
||||||
|
|
||||||
|
<DEFINITION>
|
||||||
|
<FOR>D</FOR>
|
||||||
|
<GIVEN>A</GIVEN>
|
||||||
|
<GIVEN>B</GIVEN>
|
||||||
|
<GIVEN>C</GIVEN>
|
||||||
|
<TABLE>
|
||||||
|
.522 .008 .99 .01 .2 .8 .003 .457 .423 .007 .92 .04 .5 .232 .033 .227 .112 .048 .91 .21 .24 .18 .005 .227
|
||||||
|
.212 .04 .59 .21 .6 .1 .023 .215 .913 .017 .96 .01 .55 .422 .013 .417 .272 .068 .61 .11 .26 .28 .205 .322
|
||||||
|
.142 .028 .19 .11 .5 .67 .013 .437 .163 .067 .12 .06 .1 .262 .063 .167 .512 .028 .11 .41 .14 .68 .015 .92
|
||||||
|
</TABLE>
|
||||||
|
</DEFINITION>
|
||||||
|
|
||||||
|
<DEFINITION>
|
||||||
|
<FOR>E</FOR>
|
||||||
|
<GIVEN>D</GIVEN>
|
||||||
|
<TABLE>
|
||||||
|
.111 .11 .1
|
||||||
|
.222 .22 .2
|
||||||
|
.333 .33 .3
|
||||||
|
.444 .44 .4
|
||||||
|
</TABLE>
|
||||||
|
</DEFINITION>
|
||||||
|
|
||||||
|
<DEFINITION>
|
||||||
|
<FOR>F</FOR>
|
||||||
|
<GIVEN>D</GIVEN>
|
||||||
|
<TABLE>
|
||||||
|
.112 .111 .110
|
||||||
|
.223 .222 .221
|
||||||
|
.334 .333 .332
|
||||||
|
</TABLE>
|
||||||
|
</DEFINITION>
|
||||||
|
|
||||||
|
<DEFINITION>
|
||||||
|
<FOR>G</FOR>
|
||||||
|
<GIVEN>D</GIVEN>
|
||||||
|
<TABLE>
|
||||||
|
.101 .102 .103
|
||||||
|
.201 .202 .203
|
||||||
|
</TABLE>
|
||||||
|
</DEFINITION>
|
||||||
|
|
||||||
|
</NETWORK>
|
||||||
|
</BIF>
|
||||||
|
|
36
packages/CLPBN/clpbn/bp/examples/test_mk.uai
Executable file
36
packages/CLPBN/clpbn/bp/examples/test_mk.uai
Executable file
@ -0,0 +1,36 @@
|
|||||||
|
MARKOV
|
||||||
|
5
|
||||||
|
4 2 3 2 3
|
||||||
|
7
|
||||||
|
1 0
|
||||||
|
1 1
|
||||||
|
1 2
|
||||||
|
1 3
|
||||||
|
1 4
|
||||||
|
2 0 1
|
||||||
|
4 1 2 3 4
|
||||||
|
|
||||||
|
4
|
||||||
|
0.1 0.7 0.43 0.22
|
||||||
|
|
||||||
|
2
|
||||||
|
0.2 0.6
|
||||||
|
|
||||||
|
3
|
||||||
|
0.3 0.5 0.2
|
||||||
|
|
||||||
|
2
|
||||||
|
0.15 0.75
|
||||||
|
|
||||||
|
3
|
||||||
|
0.25 0.45 0.15
|
||||||
|
|
||||||
|
8
|
||||||
|
0.210 0.333 0.457 0.4
|
||||||
|
0.811 0.000 0.189 0.89
|
||||||
|
|
||||||
|
36
|
||||||
|
0.1 0.15 0.2 0.25 0.3 0.45 0.5 0.55 0.65 0.7 0.75 0.9
|
||||||
|
0.11 0.22 0.33 0.44 0.55 0.66 0.77 0.88 0.91 0.93 0.95 0.97
|
||||||
|
0.42 0.22 0.33 0.44 0.15 0.36 0.27 0.28 0.21 0.13 0.25 0.17
|
||||||
|
|
69
packages/CLPBN/clpbn/bp/examples/ve_example.xml
Executable file
69
packages/CLPBN/clpbn/bp/examples/ve_example.xml
Executable file
@ -0,0 +1,69 @@
|
|||||||
|
<?xml version="1.0" encoding="US-ASCII"?>
|
||||||
|
|
||||||
|
<!--
|
||||||
|
|
||||||
|
A B
|
||||||
|
\ /
|
||||||
|
\ /
|
||||||
|
C
|
||||||
|
|
|
||||||
|
|
|
||||||
|
D
|
||||||
|
|
||||||
|
-->
|
||||||
|
|
||||||
|
|
||||||
|
<BIF VERSION="0.3">
|
||||||
|
<NETWORK>
|
||||||
|
<NAME>Simple Loop</NAME>
|
||||||
|
|
||||||
|
<VARIABLE TYPE="nature">
|
||||||
|
<NAME>A</NAME>
|
||||||
|
<OUTCOME>a1</OUTCOME>
|
||||||
|
<OUTCOME>a2</OUTCOME>
|
||||||
|
</VARIABLE>
|
||||||
|
|
||||||
|
<VARIABLE TYPE="nature">
|
||||||
|
<NAME>B</NAME>
|
||||||
|
<OUTCOME>b1</OUTCOME>
|
||||||
|
<OUTCOME>b2</OUTCOME>
|
||||||
|
</VARIABLE>
|
||||||
|
|
||||||
|
<VARIABLE TYPE="nature">
|
||||||
|
<NAME>C</NAME>
|
||||||
|
<OUTCOME>c1</OUTCOME>
|
||||||
|
<OUTCOME>c2</OUTCOME>
|
||||||
|
</VARIABLE>
|
||||||
|
|
||||||
|
<VARIABLE TYPE="nature">
|
||||||
|
<NAME>D</NAME>
|
||||||
|
<OUTCOME>d1</OUTCOME>
|
||||||
|
<OUTCOME>d2</OUTCOME>
|
||||||
|
</VARIABLE>
|
||||||
|
|
||||||
|
<DEFINITION>
|
||||||
|
<FOR>A</FOR>
|
||||||
|
<TABLE> .001 .009 </TABLE>
|
||||||
|
</DEFINITION>
|
||||||
|
|
||||||
|
<DEFINITION>
|
||||||
|
<FOR>B</FOR>
|
||||||
|
<TABLE> .002 .008 </TABLE>
|
||||||
|
</DEFINITION>
|
||||||
|
|
||||||
|
<DEFINITION>
|
||||||
|
<FOR>C</FOR>
|
||||||
|
<GIVEN>A</GIVEN>
|
||||||
|
<GIVEN>B</GIVEN>
|
||||||
|
<TABLE> .95 .05 .94 .06 .29 .71 .001 .999 </TABLE>
|
||||||
|
</DEFINITION>
|
||||||
|
|
||||||
|
<DEFINITION>
|
||||||
|
<FOR>D</FOR>
|
||||||
|
<GIVEN>C</GIVEN>
|
||||||
|
<TABLE> .9 .1 .05 .95 </TABLE>
|
||||||
|
</DEFINITION>
|
||||||
|
|
||||||
|
</NETWORK>
|
||||||
|
</BIF>
|
||||||
|
|
@ -884,6 +884,15 @@ writePrimitive(term_t t, write_options *options)
|
|||||||
return writeString(t, options);
|
return writeString(t, options);
|
||||||
#endif /* O_STRING */
|
#endif /* O_STRING */
|
||||||
|
|
||||||
|
#if __YAP_PROLOG__
|
||||||
|
{
|
||||||
|
number n;
|
||||||
|
n.type = V_INTEGER;
|
||||||
|
n.value.i = 0;
|
||||||
|
return WriteNumber(&n, options);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
assert(0);
|
assert(0);
|
||||||
fail;
|
fail;
|
||||||
}
|
}
|
||||||
|
@ -1121,6 +1121,9 @@ Yap_StreamPosition(IOSTREAM *st)
|
|||||||
return StreamPosition(st);
|
return StreamPosition(st);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
IOSTREAM *STD_PROTO(Yap_Scurin, (void));
|
||||||
|
int STD_PROTO(Yap_dowrite, (Term, IOSTREAM *, int, int));
|
||||||
|
|
||||||
IOSTREAM *
|
IOSTREAM *
|
||||||
Yap_Scurin(void)
|
Yap_Scurin(void)
|
||||||
{
|
{
|
||||||
@ -1128,6 +1131,32 @@ Yap_Scurin(void)
|
|||||||
return Scurin;
|
return Scurin;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int
|
||||||
|
Yap_dowrite(Term t, IOSTREAM *stream, int flags, int priority)
|
||||||
|
/* term to be written */
|
||||||
|
/* consumer */
|
||||||
|
/* write options */
|
||||||
|
{
|
||||||
|
CACHE_REGS
|
||||||
|
int swi_flags;
|
||||||
|
int res;
|
||||||
|
Int slot = Yap_InitSlot(t PASS_REGS);
|
||||||
|
|
||||||
|
swi_flags = 0;
|
||||||
|
if (flags & Quote_illegal_f)
|
||||||
|
swi_flags |= PL_WRT_QUOTED;
|
||||||
|
if (flags & Handle_vars_f)
|
||||||
|
swi_flags |= PL_WRT_NUMBERVARS;
|
||||||
|
if (flags & Use_portray_f)
|
||||||
|
swi_flags |= PL_WRT_PORTRAY;
|
||||||
|
if (flags & Ignore_ops_f)
|
||||||
|
swi_flags |= PL_WRT_IGNOREOPS;
|
||||||
|
|
||||||
|
res = PL_write_term(stream, slot, priority, swi_flags);
|
||||||
|
Yap_RecoverSlots(1 PASS_REGS);
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
#if THREADS
|
#if THREADS
|
||||||
|
|
||||||
@ -1178,6 +1207,7 @@ error:
|
|||||||
return rc;
|
return rc;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
int
|
int
|
||||||
recursiveMutexInit(recursiveMutex *m)
|
recursiveMutexInit(recursiveMutex *m)
|
||||||
{
|
{
|
||||||
|
@ -1 +1 @@
|
|||||||
Subproject commit cee6c346ba77e046ef1873b9d8c88c52c3baae2d
|
Subproject commit a5d2a3755f86ebffd8eb6e2a6921790d299eab30
|
@ -1 +1 @@
|
|||||||
Subproject commit f6fce313722d2f69e5ce6074f304ce05065bbf40
|
Subproject commit a6f0f4ec7d5fd51ca8b268b8392da9b20bfd1b44
|
@ -1 +1 @@
|
|||||||
Subproject commit 3f7eee803d46071f10804010f7266b3864bab7e6
|
Subproject commit 2a0843683e8790d8129fa4bad99c82a5fcaf441b
|
@ -1 +1 @@
|
|||||||
Subproject commit aa90b8f8a67e605e82566b719a8f20d125598bd2
|
Subproject commit 2daa5b9942a9fc22005ec48b6560f8596f33d995
|
@ -1 +1 @@
|
|||||||
Subproject commit 2229eb3807b21497ffd680686ceeeddeac9f57fb
|
Subproject commit e3c7eb9a9a54d6a9069396ecd1c3b537a8f6165a
|
@ -1 +1 @@
|
|||||||
Subproject commit babcbfe9cc5ab269cd5fd4f024e9d57bd3d0a8db
|
Subproject commit f1c3ef54f4d9431ba5b4188cb72ca3056d20b202
|
@ -1,59 +0,0 @@
|
|||||||
#
|
|
||||||
# default base directory for YAP installation
|
|
||||||
# (EROOT for architecture-dependent files)
|
|
||||||
#
|
|
||||||
prefix = @prefix@
|
|
||||||
exec_prefix = @exec_prefix@
|
|
||||||
ROOTDIR = $(prefix)
|
|
||||||
EROOTDIR = @exec_prefix@
|
|
||||||
abs_top_builddir = @abs_top_builddir@
|
|
||||||
#
|
|
||||||
# where the binary should be
|
|
||||||
#
|
|
||||||
BINDIR = $(EROOTDIR)/bin
|
|
||||||
#
|
|
||||||
# where YAP should look for libraries
|
|
||||||
#
|
|
||||||
LIBDIR=@libdir@
|
|
||||||
YAPLIBDIR=@libdir@/Yap
|
|
||||||
#
|
|
||||||
#
|
|
||||||
DEFS=@DEFS@ -D_YAP_NOT_INSTALLED_=1
|
|
||||||
CC=@CC@
|
|
||||||
CFLAGS= @SHLIB_CFLAGS@ $(YAP_EXTRAS) $(DEFS) -I$(srcdir) -I../.. -I$(srcdir)/../../include -I$(srcdir)/../PLStream -I$(srcdir)/../PLStream/windows -I$(srcdir)/../../H
|
|
||||||
#
|
|
||||||
#
|
|
||||||
# You shouldn't need to change what follows.
|
|
||||||
#
|
|
||||||
INSTALL=@INSTALL@
|
|
||||||
INSTALL_DATA=@INSTALL_DATA@
|
|
||||||
INSTALL_PROGRAM=@INSTALL_PROGRAM@
|
|
||||||
SHELL=/bin/sh
|
|
||||||
RANLIB=@RANLIB@
|
|
||||||
srcdir=@srcdir@
|
|
||||||
SO=@SO@
|
|
||||||
#4.1VPATH=@srcdir@:@srcdir@/OPTYap
|
|
||||||
CWD=$(PWD)
|
|
||||||
#
|
|
||||||
|
|
||||||
OBJS=pl-tai.o
|
|
||||||
SOBJS=pl-tai.@SO@
|
|
||||||
|
|
||||||
#in some systems we just create a single object, in others we need to
|
|
||||||
# create a libray
|
|
||||||
|
|
||||||
all: $(SOBJS)
|
|
||||||
|
|
||||||
pl-tai.o: $(srcdir)/pl-tai.c
|
|
||||||
(cd libtai ; $(MAKE))
|
|
||||||
$(CC) -c $(CFLAGS) $(srcdir)/pl-tai.c -o pl-tai.o
|
|
||||||
|
|
||||||
@DO_SECOND_LD@pl-tai.@SO@: pl-tai.o
|
|
||||||
@DO_SECOND_LD@ @SHLIB_LD@ $(LDFLAGS) -o pl-tai.@SO@ pl-tai.o libtai/libtai.a @EXTRA_LIBS_FOR_SWIDLLS@
|
|
||||||
|
|
||||||
install: all
|
|
||||||
$(INSTALL_PROGRAM) $(SOBJS) $(DESTDIR)$(YAPLIBDIR)
|
|
||||||
|
|
||||||
clean:
|
|
||||||
rm -f *.o *~ $(OBJS) $(SOBJS) *.BAK
|
|
||||||
-(cd libtai && $(MAKE) clean)
|
|
@ -1 +1 @@
|
|||||||
Subproject commit deea4bfdf7041387e91eca37978a5c8db9287eda
|
Subproject commit 109bf1d224009dc049ab12a0fbbb50511ef8e1fb
|
@ -1050,6 +1050,8 @@ make :-
|
|||||||
fail.
|
fail.
|
||||||
make.
|
make.
|
||||||
|
|
||||||
|
make_library_index(_Directory).
|
||||||
|
|
||||||
'$file_name'(Stream,F) :-
|
'$file_name'(Stream,F) :-
|
||||||
stream_property(Stream, file_name(F)), !.
|
stream_property(Stream, file_name(F)), !.
|
||||||
'$file_name'(user_input,user_output).
|
'$file_name'(user_input,user_output).
|
||||||
|
@ -495,7 +495,7 @@ debugging :-
|
|||||||
'$continue_debugging'(no, '$execute_nonstop'(G,M)).
|
'$continue_debugging'(no, '$execute_nonstop'(G,M)).
|
||||||
'$spycall'(G, M, CalledFromDebugger, InRedo) :-
|
'$spycall'(G, M, CalledFromDebugger, InRedo) :-
|
||||||
'$flags'(G,M,F,F),
|
'$flags'(G,M,F,F),
|
||||||
F /\ 0x18402000 =\= 0, !, % dynamic procedure, logical semantics, user-C, or source
|
F /\ 0x08402000 =\= 0, !, % dynamic procedure, logical semantics, or source
|
||||||
% use the interpreter
|
% use the interpreter
|
||||||
CP is '$last_choice_pt',
|
CP is '$last_choice_pt',
|
||||||
'$clause'(G, M, Cl, _),
|
'$clause'(G, M, Cl, _),
|
||||||
|
@ -245,22 +245,15 @@ print_message(Severity, Msg) :-
|
|||||||
'$notrace'(user:portray_message(Severity, Msg)), !.
|
'$notrace'(user:portray_message(Severity, Msg)), !.
|
||||||
% This predicate has more hooks than a pirate ship!
|
% This predicate has more hooks than a pirate ship!
|
||||||
print_message(Severity, Term) :-
|
print_message(Severity, Term) :-
|
||||||
(
|
% first step at hook processing
|
||||||
(
|
'$message_to_lines'(Term, Lines),
|
||||||
'$oncenotrace'(user:generate_message_hook(Term, [], Lines)) ->
|
( nonvar(Term),
|
||||||
true
|
'$oncenotrace'(user:message_hook(Term, Severity, Lines))
|
||||||
;
|
->
|
||||||
'$oncenotrace'(prolog:message(Term, Lines, [])) ->
|
true
|
||||||
true
|
;
|
||||||
;
|
'$print_system_message'(Term, Severity, Lines)
|
||||||
'$messages':generate_message(Term, Lines, [])
|
), !.
|
||||||
)
|
|
||||||
-> ( nonvar(Term),
|
|
||||||
'$oncenotrace'(user:message_hook(Term, Severity, Lines))
|
|
||||||
-> !
|
|
||||||
; !, '$print_system_message'(Term, Severity, Lines)
|
|
||||||
)
|
|
||||||
).
|
|
||||||
print_message(silent, _) :- !.
|
print_message(silent, _) :- !.
|
||||||
print_message(_, error(syntax_error(syntax_error(_,between(_,L,_),_,_,_,_,StreamName)),_)) :- !,
|
print_message(_, error(syntax_error(syntax_error(_,between(_,L,_),_,_,_,_,StreamName)),_)) :- !,
|
||||||
format(user_error,'SYNTAX ERROR at ~a, close to ~d~n',[StreamName,L]).
|
format(user_error,'SYNTAX ERROR at ~a, close to ~d~n',[StreamName,L]).
|
||||||
@ -271,6 +264,14 @@ print_message(_, loaded(A, F, _, Time, Space)) :- !,
|
|||||||
print_message(_, Term) :-
|
print_message(_, Term) :-
|
||||||
format(user_error,'~q~n',[Term]).
|
format(user_error,'~q~n',[Term]).
|
||||||
|
|
||||||
|
'$message_to_lines'(Term, Lines) :-
|
||||||
|
'$oncenotrace'(user:generate_message_hook(Term, [], Lines)), !.
|
||||||
|
'$message_to_lines'(Term, Lines) :-
|
||||||
|
'$oncenotrace'(prolog:message(Term, Lines, [])), !.
|
||||||
|
'$message_to_lines'(Term, Lines) :-
|
||||||
|
'$messages':generate_message(Term, Lines, []), !.
|
||||||
|
|
||||||
|
|
||||||
% print_system_message(+Term, +Level, +Lines)
|
% print_system_message(+Term, +Level, +Lines)
|
||||||
%
|
%
|
||||||
% Print the message if the user did not intecept the message.
|
% Print the message if the user did not intecept the message.
|
||||||
|
@ -69,6 +69,10 @@ generate_message(debug) --> !,
|
|||||||
[ debug ].
|
[ debug ].
|
||||||
generate_message(trace) --> !,
|
generate_message(trace) --> !,
|
||||||
[ trace ].
|
[ trace ].
|
||||||
|
generate_message(error(Error,Context)) -->
|
||||||
|
{ Error = existence_error(procedure,_) }, !,
|
||||||
|
system_message(error(Error,Context)),
|
||||||
|
stack_dump(error(Error,Context)).
|
||||||
generate_message(error(Error,context(Cause,Extra))) -->
|
generate_message(error(Error,context(Cause,Extra))) -->
|
||||||
system_message(error(Error,Cause)),
|
system_message(error(Error,Cause)),
|
||||||
stack_dump(error(Error,context(Cause,Extra))).
|
stack_dump(error(Error,context(Cause,Extra))).
|
||||||
@ -130,8 +134,6 @@ system_message(no_match(P)) -->
|
|||||||
[ 'No matching predicate for ~w.' - [P] ].
|
[ 'No matching predicate for ~w.' - [P] ].
|
||||||
system_message(leash([A|B])) -->
|
system_message(leash([A|B])) -->
|
||||||
[ 'Leashing set to ~w.' - [[A|B]] ].
|
[ 'Leashing set to ~w.' - [[A|B]] ].
|
||||||
system_message(existence_error(prolog_flag,F)) -->
|
|
||||||
[ 'Prolog Flag ~w: new Prolog flags must be created using create_prolog_flag/3.' - [F] ].
|
|
||||||
system_message(singletons([SV],P)) -->
|
system_message(singletons([SV],P)) -->
|
||||||
[ 'Singleton variable ~s in ~q.' - [SV,P] ].
|
[ 'Singleton variable ~s in ~q.' - [SV,P] ].
|
||||||
system_message(singletons(SVs,P)) -->
|
system_message(singletons(SVs,P)) -->
|
||||||
@ -159,14 +161,18 @@ system_message(error(context_error(Goal,Who),Where)) -->
|
|||||||
system_message(error(domain_error(DomainType,Opt), Where)) -->
|
system_message(error(domain_error(DomainType,Opt), Where)) -->
|
||||||
[ 'DOMAIN ERROR- ~w: ' - Where],
|
[ 'DOMAIN ERROR- ~w: ' - Where],
|
||||||
domain_error(DomainType, Opt).
|
domain_error(DomainType, Opt).
|
||||||
|
system_message(error(existence_error(directory,Key), Where)) -->
|
||||||
|
[ 'EXISTENCE ERROR- ~w: ~w not an existing directory' - [Where,Key] ].
|
||||||
|
system_message(error(existence_error(key,Key), Where)) -->
|
||||||
|
[ 'EXISTENCE ERROR- ~w: ~w not an existing key' - [Where,Key] ].
|
||||||
|
system_message(existence_error(prolog_flag,F)) -->
|
||||||
|
[ 'Prolog Flag ~w: new Prolog flags must be created using create_prolog_flag/3.' - [F] ].
|
||||||
system_message(error(existence_error(prolog_flag,P), Where)) --> !,
|
system_message(error(existence_error(prolog_flag,P), Where)) --> !,
|
||||||
[ 'EXISTENCE ERROR- ~w: prolog flag ~w is undefined' - [Where,P] ].
|
[ 'EXISTENCE ERROR- ~w: prolog flag ~w is undefined' - [Where,P] ].
|
||||||
system_message(error(existence_error(procedure,P), context(Call,Parent))) --> !,
|
system_message(error(existence_error(procedure,P), context(Call,Parent))) --> !,
|
||||||
[ 'EXISTENCE ERROR- procedure ~w is undefined, called from context ~w~n Goal was ~w' - [P,Parent,Call] ].
|
[ 'EXISTENCE ERROR- procedure ~w is undefined, called from context ~w~n Goal was ~w' - [P,Parent,Call] ].
|
||||||
system_message(error(existence_error(stream,Stream), Where)) -->
|
system_message(error(existence_error(stream,Stream), Where)) -->
|
||||||
[ 'EXISTENCE ERROR- ~w: ~w not an open stream' - [Where,Stream] ].
|
[ 'EXISTENCE ERROR- ~w: ~w not an open stream' - [Where,Stream] ].
|
||||||
system_message(error(existence_error(key,Key), Where)) -->
|
|
||||||
[ 'EXISTENCE ERROR- ~w: ~w not an existing key' - [Where,Key] ].
|
|
||||||
system_message(error(existence_error(thread,Thread), Where)) -->
|
system_message(error(existence_error(thread,Thread), Where)) -->
|
||||||
[ 'EXISTENCE ERROR- ~w: ~w not a running thread' - [Where,Thread] ].
|
[ 'EXISTENCE ERROR- ~w: ~w not a running thread' - [Where,Thread] ].
|
||||||
system_message(error(existence_error(variable,Var), Where)) -->
|
system_message(error(existence_error(variable,Var), Where)) -->
|
||||||
|
Reference in New Issue
Block a user