Merge branch 'master' of /cygdrive/z/vitor/Yap/yap-6.3

This commit is contained in:
U-WIN-U2045GN0RNQ\Vítor Santos Costa 2011-07-25 17:09:43 +01:00
commit 4fe1833ece
81 changed files with 4239 additions and 2447 deletions

View File

@ -2081,8 +2081,13 @@ Yap_absmi(int inp)
goto failloop; goto failloop;
} else } else
#endif /* FROZEN_STACKS */ #endif /* FROZEN_STACKS */
if (IN_BETWEEN(H0,pt1,H) && IsAttVar(pt1)) if (IN_BETWEEN(H0,pt1,H)) {
goto failloop; if (IsAttVar(pt1)) {
goto failloop;
} else if (*pt1 == (CELL)FunctorBigInt) {
Yap_CleanOpaqueVariable(pt1);
}
}
#ifdef FROZEN_STACKS /* TRAIL */ #ifdef FROZEN_STACKS /* TRAIL */
/* don't reset frozen variables */ /* don't reset frozen variables */
if (pt0 < TR_FZ) if (pt0 < TR_FZ)

View File

@ -252,7 +252,7 @@ STATIC_PROTO(yamop *a_if, (op_numbers, union clause_obj *, int, yamop *, int, st
STATIC_PROTO(yamop *a_cut, (clause_info *,yamop *, int, struct intermediates *)); STATIC_PROTO(yamop *a_cut, (clause_info *,yamop *, int, struct intermediates *));
#ifdef YAPOR #ifdef YAPOR
STATIC_PROTO(yamop *a_try, (op_numbers, CELL, CELL, int, int, yamop *, int, struct intermediates *)); STATIC_PROTO(yamop *a_try, (op_numbers, CELL, CELL, int, int, yamop *, int, struct intermediates *));
STATIC_PROTO(yamop *a_either, (op_numbers, CELL, CELL, int, int, yamop *, int, struct intermediates *)); STATIC_PROTO(yamop *a_either, (op_numbers, CELL, CELL, int, yamop *, int, struct intermediates *));
#else #else
STATIC_PROTO(yamop *a_try, (op_numbers, CELL, CELL, yamop *, int, struct intermediates *)); STATIC_PROTO(yamop *a_try, (op_numbers, CELL, CELL, yamop *, int, struct intermediates *));
STATIC_PROTO(yamop *a_either, (op_numbers, CELL, CELL, yamop *, int, struct intermediates *)); STATIC_PROTO(yamop *a_either, (op_numbers, CELL, CELL, yamop *, int, struct intermediates *));
@ -2155,7 +2155,7 @@ a_try(op_numbers opcode, CELL lab, CELL opr, int nofalts, int hascut, yamop *cod
#endif #endif
#ifdef YAPOR #ifdef YAPOR
INIT_YAMOP_LTT(code_p, nofalts); INIT_YAMOP_LTT(code_p, nofalts);
if (hascut) if (cip->clause_has_cut)
PUT_YAMOP_CUT(code_p); PUT_YAMOP_CUT(code_p);
if (ap->PredFlags & SequentialPredFlag) if (ap->PredFlags & SequentialPredFlag)
PUT_YAMOP_SEQ(code_p); PUT_YAMOP_SEQ(code_p);
@ -2167,7 +2167,7 @@ a_try(op_numbers opcode, CELL lab, CELL opr, int nofalts, int hascut, yamop *cod
static yamop * static yamop *
#ifdef YAPOR #ifdef YAPOR
a_either(op_numbers opcode, CELL opr, CELL lab, int nofalts, int hascut, yamop *code_p, int pass_no, struct intermediates *cip) a_either(op_numbers opcode, CELL opr, CELL lab, int nofalts, yamop *code_p, int pass_no, struct intermediates *cip)
#else #else
a_either(op_numbers opcode, CELL opr, CELL lab, yamop *code_p, int pass_no, struct intermediates *cip) a_either(op_numbers opcode, CELL opr, CELL lab, yamop *code_p, int pass_no, struct intermediates *cip)
#endif /* YAPOR */ #endif /* YAPOR */
@ -2179,7 +2179,7 @@ a_either(op_numbers opcode, CELL opr, CELL lab, yamop *code_p, int pass_no, stru
code_p->u.Osblp.p0 = cip->CurrentPred; code_p->u.Osblp.p0 = cip->CurrentPred;
#ifdef YAPOR #ifdef YAPOR
INIT_YAMOP_LTT(code_p, nofalts); INIT_YAMOP_LTT(code_p, nofalts);
if (hascut) if (cip->clause_has_cut)
PUT_YAMOP_CUT(code_p); PUT_YAMOP_CUT(code_p);
if (cip->CurrentPred->PredFlags & SequentialPredFlag) if (cip->CurrentPred->PredFlags & SequentialPredFlag)
PUT_YAMOP_SEQ(code_p); PUT_YAMOP_SEQ(code_p);
@ -3583,7 +3583,7 @@ do_pass(int pass_no, yamop **entry_codep, int assembling, int *clause_has_blobsp
} }
code_p = a_either(_either, code_p = a_either(_either,
-Signed(RealEnvSize) - CELLSIZE * cip->cpc->rnd2, -Signed(RealEnvSize) - CELLSIZE * cip->cpc->rnd2,
Unsigned(cip->code_addr) + cip->label_offset[cip->cpc->rnd1], 0, 0, code_p, pass_no, cip); Unsigned(cip->code_addr) + cip->label_offset[cip->cpc->rnd1], 0, code_p, pass_no, cip);
#else #else
code_p = a_either(_either, code_p = a_either(_either,
-Signed(RealEnvSize) - CELLSIZE * cip->cpc->rnd2, -Signed(RealEnvSize) - CELLSIZE * cip->cpc->rnd2,
@ -3596,7 +3596,7 @@ do_pass(int pass_no, yamop **entry_codep, int assembling, int *clause_has_blobsp
either_inst[either_cont++] = code_p; either_inst[either_cont++] = code_p;
code_p = a_either(_or_else, code_p = a_either(_or_else,
-Signed(RealEnvSize) - CELLSIZE * cip->cpc->rnd2, -Signed(RealEnvSize) - CELLSIZE * cip->cpc->rnd2,
Unsigned(cip->code_addr) + cip->label_offset[cip->cpc->rnd1], 0, 0, code_p, pass_no, cip); Unsigned(cip->code_addr) + cip->label_offset[cip->cpc->rnd1], 0, code_p, pass_no, cip);
#else #else
code_p = a_either(_or_else, code_p = a_either(_or_else,
-Signed(RealEnvSize) - CELLSIZE * cip->cpc->rnd2, -Signed(RealEnvSize) - CELLSIZE * cip->cpc->rnd2,
@ -3608,7 +3608,7 @@ do_pass(int pass_no, yamop **entry_codep, int assembling, int *clause_has_blobsp
#ifdef YAPOR #ifdef YAPOR
if (pass_no) if (pass_no)
either_inst[either_cont++] = code_p; either_inst[either_cont++] = code_p;
code_p = a_either(_or_last, 0, 0, 0, 0, code_p, pass_no, cip); code_p = a_either(_or_last, 0, 0, 0, code_p, pass_no, cip);
if (pass_no) { if (pass_no) {
int cont = 1; int cont = 1;
do { do {

View File

@ -25,9 +25,10 @@ static char SccsId[] = "%W% %G%";
#include <string.h> #include <string.h>
#endif #endif
#include "YapHeap.h"
#ifdef USE_GMP #ifdef USE_GMP
#include "YapHeap.h"
#include "eval.h" #include "eval.h"
#include "alloc.h" #include "alloc.h"
@ -59,6 +60,7 @@ Yap_MkBigIntTerm(MP_INT *big)
return AbsAppl(ret); return AbsAppl(ret);
} }
MP_INT * MP_INT *
Yap_BigIntOfTerm(Term t) Yap_BigIntOfTerm(Term t)
{ {
@ -127,9 +129,60 @@ Yap_RatTermToApplTerm(Term t)
return Yap_MkApplTerm(FunctorRDiv,2,ts); return Yap_MkApplTerm(FunctorRDiv,2,ts);
} }
#endif #endif
Term
Yap_AllocExternalDataInStack(CELL tag, size_t bytes)
{
CACHE_REGS
Int nlimbs;
MP_INT *dst = (MP_INT *)(H+2);
CELL *ret = H;
nlimbs = ALIGN_YAPTYPE(bytes,CELL)/CellSize;
if (nlimbs > (ASP-ret)-1024) {
return TermNil;
}
H[0] = (CELL)FunctorBigInt;
H[1] = tag;
dst->_mp_size = 0;
dst->_mp_alloc = nlimbs;
H = (CELL *)(dst+1)+nlimbs;
H[0] = EndSpecials;
H++;
if (tag != EXTERNAL_BLOB) {
TrailTerm(TR) = AbsPair(ret);
TR++;
}
return AbsAppl(ret);
}
int Yap_CleanOpaqueVariable(CELL *pt)
{
CELL blob_info, blob_tag;
MP_INT *blobp;
#ifdef DEBUG
/* sanity checking */
if (pt[0] != (CELL)FunctorBigInt) {
Yap_Error(SYSTEM_ERROR, TermNil, "CleanOpaqueVariable bad call");
return FALSE;
}
#endif
blob_tag = pt[1];
if (blob_tag < USER_BLOB_START ||
blob_tag >= USER_BLOB_END) {
Yap_Error(SYSTEM_ERROR, AbsAppl(pt), "clean opaque: bad blob with tag " UInt_FORMAT ,blob_tag);
return FALSE;
}
blob_info = blob_tag - USER_BLOB_START;
if (!GLOBAL_OpaqueHandlers)
return FALSE;
blobp = (MP_INT *)(pt+2);
if (!GLOBAL_OpaqueHandlers[blob_info].fail_handler)
return TRUE;
return (GLOBAL_OpaqueHandlers[blob_info].fail_handler)((void *)(blobp+1));
}
Term Term
Yap_MkULLIntTerm(YAP_ULONG_LONG n) Yap_MkULLIntTerm(YAP_ULONG_LONG n)
{ {

View File

@ -390,6 +390,8 @@ X_API Bool STD_PROTO(YAP_IsDbRefTerm,(Term));
X_API Bool STD_PROTO(YAP_IsAtomTerm,(Term)); X_API Bool STD_PROTO(YAP_IsAtomTerm,(Term));
X_API Bool STD_PROTO(YAP_IsPairTerm,(Term)); X_API Bool STD_PROTO(YAP_IsPairTerm,(Term));
X_API Bool STD_PROTO(YAP_IsApplTerm,(Term)); X_API Bool STD_PROTO(YAP_IsApplTerm,(Term));
X_API Bool STD_PROTO(YAP_IsExternalDataInStackTerm,(Term));
X_API Bool STD_PROTO(YAP_IsOpaqueObjectTerm,(Term, int));
X_API Term STD_PROTO(YAP_MkIntTerm,(Int)); X_API Term STD_PROTO(YAP_MkIntTerm,(Int));
X_API Term STD_PROTO(YAP_MkBigNumTerm,(void *)); X_API Term STD_PROTO(YAP_MkBigNumTerm,(void *));
X_API Term STD_PROTO(YAP_MkRationalTerm,(void *)); X_API Term STD_PROTO(YAP_MkRationalTerm,(void *));
@ -463,7 +465,7 @@ X_API IOSTREAM *STD_PROTO(YAP_TermToStream,(Term));
X_API IOSTREAM *STD_PROTO(YAP_InitConsult,(int, char *)); X_API IOSTREAM *STD_PROTO(YAP_InitConsult,(int, char *));
X_API void STD_PROTO(YAP_EndConsult,(IOSTREAM *)); X_API void STD_PROTO(YAP_EndConsult,(IOSTREAM *));
X_API Term STD_PROTO(YAP_Read, (IOSTREAM *)); X_API Term STD_PROTO(YAP_Read, (IOSTREAM *));
X_API void STD_PROTO(YAP_Write, (Term, int (*)(wchar_t), int)); X_API void STD_PROTO(YAP_Write, (Term, IOSTREAM *, int));
X_API Term STD_PROTO(YAP_CopyTerm, (Term)); X_API Term STD_PROTO(YAP_CopyTerm, (Term));
X_API Term STD_PROTO(YAP_WriteBuffer, (Term, char *, unsigned int, int)); X_API Term STD_PROTO(YAP_WriteBuffer, (Term, char *, unsigned int, int));
X_API char *STD_PROTO(YAP_CompileClause, (Term)); X_API char *STD_PROTO(YAP_CompileClause, (Term));
@ -536,13 +538,11 @@ X_API Term STD_PROTO(YAP_ModuleUser,(void));
X_API Int STD_PROTO(YAP_NumberOfClausesForPredicate,(PredEntry *)); X_API Int STD_PROTO(YAP_NumberOfClausesForPredicate,(PredEntry *));
X_API int STD_PROTO(YAP_MaxOpPriority,(Atom, Term)); X_API int STD_PROTO(YAP_MaxOpPriority,(Atom, Term));
X_API int STD_PROTO(YAP_OpInfo,(Atom, Term, int, int *, int *)); X_API int STD_PROTO(YAP_OpInfo,(Atom, Term, int, int *, int *));
X_API Term STD_PROTO(YAP_AllocExternalDataInStack,(size_t));
static int (*do_putcf)(wchar_t); X_API void *STD_PROTO(YAP_ExternalDataInStackFromTerm,(Term));
X_API int STD_PROTO(YAP_NewOpaqueType,(void *));
static int do_yap_putc(int streamno,wchar_t ch) { X_API Term STD_PROTO(YAP_NewOpaqueObject,(int, size_t));
do_putcf(ch); X_API void *STD_PROTO(YAP_OpaqueObjectFromTerm,(Term));
return(ch);
}
static int static int
dogc(void) dogc(void)
@ -1677,7 +1677,6 @@ YAP_ExecuteOnCut(PredEntry *pe, CPredicate exec_code, struct cut_c_str *top)
if (pe->PredFlags & CArgsPredFlag) { if (pe->PredFlags & CArgsPredFlag) {
val = execute_cargs_back(pe, exec_code, ctx PASS_REGS); val = execute_cargs_back(pe, exec_code, ctx PASS_REGS);
} else { } else {
fprintf(stderr,"ctx=%p\n",ctx);
val = ((codev)(args-LCL0,0,ctx)); val = ((codev)(args-LCL0,0,ctx));
} }
/* make sure we clean up the frames left by the user */ /* make sure we clean up the frames left by the user */
@ -2314,6 +2313,65 @@ YAP_RunGoal(Term t)
return(out); return(out);
} }
X_API Term
YAP_AllocExternalDataInStack(size_t bytes)
{
Term t = Yap_AllocExternalDataInStack(EXTERNAL_BLOB, bytes);
if (t == TermNil)
return 0L;
return t;
}
X_API Bool
YAP_IsExternalDataInStackTerm(Term t)
{
return IsExternalBlobTerm(t, EXTERNAL_BLOB);
}
X_API void *
YAP_ExternalDataInStackFromTerm(Term t)
{
return ExternalBlobFromTerm (t);
}
int YAP_NewOpaqueType(void *f)
{
int i;
if (!GLOBAL_OpaqueHandlers) {
GLOBAL_OpaqueHandlers = malloc(sizeof(opaque_handler_t)*(USER_BLOB_END-USER_BLOB_START));
if (!GLOBAL_OpaqueHandlers) {
/* no room */
return -1;
}
} else if (GLOBAL_OpaqueHandlersCount == USER_BLOB_END-USER_BLOB_START) {
/* all types used */
return -1;
}
i = GLOBAL_OpaqueHandlersCount++;
memcpy(GLOBAL_OpaqueHandlers+i,f,sizeof(opaque_handler_t));
return i+USER_BLOB_START;
}
Term YAP_NewOpaqueObject(int tag, size_t bytes)
{
Term t = Yap_AllocExternalDataInStack((CELL)tag, bytes);
if (t == TermNil)
return 0L;
return t;
}
X_API Bool
YAP_IsOpaqueObjectTerm(Term t, int tag)
{
return IsExternalBlobTerm(t, (CELL)tag);
}
X_API void *
YAP_OpaqueObjectFromTerm(Term t)
{
return ExternalBlobFromTerm (t);
}
X_API Term X_API Term
YAP_RunGoalOnce(Term t) YAP_RunGoalOnce(Term t)
{ {
@ -2369,7 +2427,6 @@ YAP_RestartGoal(void)
BACKUP_MACHINE_REGS(); BACKUP_MACHINE_REGS();
if (LOCAL_AllowRestart) { if (LOCAL_AllowRestart) {
P = (yamop *)FAILCODE; P = (yamop *)FAILCODE;
do_putcf = myputc;
LOCAL_PrologMode = UserMode; LOCAL_PrologMode = UserMode;
out = Yap_exec_absmi(TRUE); out = Yap_exec_absmi(TRUE);
LOCAL_PrologMode = UserCCallMode; LOCAL_PrologMode = UserCCallMode;
@ -2589,12 +2646,11 @@ YAP_Read(IOSTREAM *inp)
} }
X_API void X_API void
YAP_Write(Term t, int (*myputc)(wchar_t), int flags) YAP_Write(Term t, IOSTREAM *stream, int flags)
{ {
BACKUP_MACHINE_REGS(); BACKUP_MACHINE_REGS();
do_putcf = myputc; /* */ Yap_dowrite (t, stream, flags, 1200);
Yap_plwrite (t, do_yap_putc, flags, 1200);
RECOVER_MACHINE_REGS(); RECOVER_MACHINE_REGS();
} }
@ -2774,6 +2830,7 @@ YAP_Init(YAP_init_args *yap_init)
Yap_init_yapor_global_local_memory(); Yap_init_yapor_global_local_memory();
LOCAL = REMOTE(0); LOCAL = REMOTE(0);
#endif /* YAPOR_COPY || YAPOR_COW || YAPOR_SBA */ #endif /* YAPOR_COPY || YAPOR_COW || YAPOR_SBA */
GLOBAL_PrologShouldHandleInterrupts = yap_init->PrologShouldHandleInterrupts;
Yap_InitSysbits(); /* init signal handling and time, required by later functions */ Yap_InitSysbits(); /* init signal handling and time, required by later functions */
GLOBAL_argv = yap_init->Argv; GLOBAL_argv = yap_init->Argv;
GLOBAL_argc = yap_init->Argc; GLOBAL_argc = yap_init->Argc;
@ -2821,7 +2878,6 @@ YAP_Init(YAP_init_args *yap_init)
} else { } else {
Heap = yap_init->HeapSize; Heap = yap_init->HeapSize;
} }
GLOBAL_PrologShouldHandleInterrupts = yap_init->PrologShouldHandleInterrupts;
Yap_InitWorkspace(Heap, Stack, Trail, Atts, Yap_InitWorkspace(Heap, Stack, Trail, Atts,
yap_init->MaxTableSpaceSize, yap_init->MaxTableSpaceSize,
yap_init->NumberWorkers, yap_init->NumberWorkers,

View File

@ -735,7 +735,7 @@ c_arg(Int argno, Term t, unsigned int arity, unsigned int level, compiler_struct
} else if (IsPairTerm(t)) { } else if (IsPairTerm(t)) {
cglobs->space_used += 2; cglobs->space_used += 2;
if (optimizer_on && level < 6) { if (optimizer_on && level < 6) {
#if !defined(THREADS) #if !defined(THREADS) && !defined(YAPOR)
/* discard code sharing because we cannot write on shared stuff */ /* discard code sharing because we cannot write on shared stuff */
if (!(cglobs->cint.CurrentPred->PredFlags & (DynamicPredFlag|LogUpdatePredFlag))) { if (!(cglobs->cint.CurrentPred->PredFlags & (DynamicPredFlag|LogUpdatePredFlag))) {
if (try_store_as_dbterm(t, argno, arity, level, cglobs)) if (try_store_as_dbterm(t, argno, arity, level, cglobs))

View File

@ -961,7 +961,7 @@ exec_absmi(int top USES_REGS)
restore_H(); restore_H();
/* set stack */ /* set stack */
ASP = (CELL *)PROTECT_FROZEN_B(B); ASP = (CELL *)PROTECT_FROZEN_B(B);
Yap_StartSlots( PASS_REGS1 ); Yap_PopSlots();
LOCK(LOCAL_SignalLock); LOCK(LOCAL_SignalLock);
/* forget any signals active, we're reborne */ /* forget any signals active, we're reborne */
LOCAL_ActiveSignals = 0; LOCAL_ActiveSignals = 0;
@ -991,9 +991,9 @@ exec_absmi(int top USES_REGS)
LOCAL_PrologMode = UserMode; LOCAL_PrologMode = UserMode;
} }
} else { } else {
Yap_CloseSlots( PASS_REGS1 );
LOCAL_PrologMode = UserMode; LOCAL_PrologMode = UserMode;
} }
Yap_CloseSlots( PASS_REGS1 );
YENV = ASP; YENV = ASP;
YENV[E_CB] = Unsigned (B); YENV[E_CB] = Unsigned (B);
out = Yap_absmi(0); out = Yap_absmi(0);

View File

@ -399,7 +399,7 @@ AdjustTrail(int adjusting_heap, int thread_copying USES_REGS)
#if defined(YAPOR_THREADS) #if defined(YAPOR_THREADS)
} }
#endif #endif
/* moving the trail is simple */ /* moving the trail is simple, yeaahhh! */
while (ptt != tr_base) { while (ptt != tr_base) {
register CELL reg = TrailTerm(ptt-1); register CELL reg = TrailTerm(ptt-1);
#ifdef FROZEN_STACKS #ifdef FROZEN_STACKS
@ -420,8 +420,6 @@ AdjustTrail(int adjusting_heap, int thread_copying USES_REGS)
} else if (IsPairTerm(reg)) { } else if (IsPairTerm(reg)) {
TrailTerm(ptt) = AdjustPair(reg PASS_REGS); TrailTerm(ptt) = AdjustPair(reg PASS_REGS);
#ifdef MULTI_ASSIGNMENT_VARIABLES /* does not work with new structures */ #ifdef MULTI_ASSIGNMENT_VARIABLES /* does not work with new structures */
/* check it whether we are protecting a
multi-assignment */
} else if (IsApplTerm(reg)) { } else if (IsApplTerm(reg)) {
TrailTerm(ptt) = AdjustAppl(reg PASS_REGS); TrailTerm(ptt) = AdjustAppl(reg PASS_REGS);
#endif #endif

View File

@ -1325,7 +1325,7 @@ mark_variable(CELL_PTR current USES_REGS)
sz++; sz++;
#if DEBUG #if DEBUG
if (next[sz] != EndSpecials) { if (next[sz] != EndSpecials) {
fprintf(stderr,"[ Error: could not find EndSpecials at blob %p type %lx ]\n", next, next[1]); fprintf(stderr,"[ Error: could not find EndSpecials at blob %p type " UInt_FORMAT " ]\n", next, next[1]);
} }
#endif #endif
MARK(next+sz); MARK(next+sz);
@ -1658,14 +1658,23 @@ mark_trail(tr_fr_ptr trail_ptr, tr_fr_ptr trail_base, CELL *gc_H, choiceptr gc_B
#endif #endif
} }
} else if (IsPairTerm(trail_cell)) { } else if (IsPairTerm(trail_cell)) {
/* can safely ignore this */ /* cannot safely ignore this */
CELL *cptr = RepPair(trail_cell); CELL *cptr = RepPair(trail_cell);
if (IN_BETWEEN(LOCAL_GlobalBase,cptr,H) && if (IN_BETWEEN(LOCAL_GlobalBase,cptr,H)) {
GlobalIsAttVar(cptr)) { if (GlobalIsAttVar(cptr)) {
TrailTerm(trail_base) = (CELL)cptr; TrailTerm(trail_base) = (CELL)cptr;
mark_external_reference(&TrailTerm(trail_base) PASS_REGS); mark_external_reference(&TrailTerm(trail_base) PASS_REGS);
TrailTerm(trail_base) = trail_cell; TrailTerm(trail_base) = trail_cell;
} } else if (*cptr == (CELL)FunctorBigInt) {
TrailTerm(trail_base) = AbsAppl(cptr);
mark_external_reference(&TrailTerm(trail_base) PASS_REGS);
TrailTerm(trail_base) = trail_cell;
}
#ifdef DEBUG
else
fprintf(GLOBAL_stderr,"OOPS in GC: weird trail entry at %p:" UInt_FORMAT "\n", &TrailTerm(trail_base), (CELL)cptr);
#endif
}
} }
#if MULTI_ASSIGNMENT_VARIABLES #if MULTI_ASSIGNMENT_VARIABLES
else { else {
@ -1904,11 +1913,11 @@ mark_choicepoints(register choiceptr gc_B, tr_fr_ptr saved_TR, int very_verbose
PredEntry *pe = Yap_PredForChoicePt(gc_B); PredEntry *pe = Yap_PredForChoicePt(gc_B);
#if defined(ANALYST) || defined(DEBUG) #if defined(ANALYST) || defined(DEBUG)
if (pe == NULL) { if (pe == NULL) {
fprintf(GLOBAL_stderr,"%% marked %ld (%s)\n", LOCAL_total_marked, Yap_op_names[opnum]); fprintf(GLOBAL_stderr,"%% marked " UInt_FORMAT " (%s)\n", LOCAL_total_marked, Yap_op_names[opnum]);
} else if (pe->ArityOfPE) { } else if (pe->ArityOfPE) {
fprintf(GLOBAL_stderr,"%% %s/%d marked %ld (%s)\n", RepAtom(NameOfFunctor(pe->FunctorOfPred))->StrOfAE, pe->ArityOfPE, LOCAL_total_marked, Yap_op_names[opnum]); fprintf(GLOBAL_stderr,"%% %s/%d marked " UInt_FORMAT " (%s)\n", RepAtom(NameOfFunctor(pe->FunctorOfPred))->StrOfAE, pe->ArityOfPE, LOCAL_total_marked, Yap_op_names[opnum]);
} else { } else {
fprintf(GLOBAL_stderr,"%% %s marked %ld (%s)\n", RepAtom((Atom)(pe->FunctorOfPred))->StrOfAE, LOCAL_total_marked, Yap_op_names[opnum]); fprintf(GLOBAL_stderr,"%% %s marked " UInt_FORMAT " (%s)\n", RepAtom((Atom)(pe->FunctorOfPred))->StrOfAE, LOCAL_total_marked, Yap_op_names[opnum]);
} }
#else #else
if (pe == NULL) { if (pe == NULL) {
@ -2450,11 +2459,19 @@ sweep_trail(choiceptr gc_B, tr_fr_ptr old_TR USES_REGS)
CELL *pt0 = RepPair(trail_cell); CELL *pt0 = RepPair(trail_cell);
CELL flags; CELL flags;
if (IN_BETWEEN(LOCAL_GlobalBase, pt0, H) && GlobalIsAttVar(pt0)) { if (IN_BETWEEN(LOCAL_GlobalBase, pt0, H)) {
TrailTerm(dest) = trail_cell; if (GlobalIsAttVar(pt0)) {
/* be careful with partial gc */ TrailTerm(dest) = trail_cell;
if (HEAP_PTR(TrailTerm(dest))) { /* be careful with partial gc */
into_relocation_chain(&TrailTerm(dest), GET_NEXT(trail_cell) PASS_REGS); if (HEAP_PTR(TrailTerm(dest))) {
into_relocation_chain(&TrailTerm(dest), GET_NEXT(trail_cell) PASS_REGS);
}
} else if (*pt0 == (CELL)FunctorBigInt) {
TrailTerm(dest) = trail_cell;
/* be careful with partial gc */
if (HEAP_PTR(TrailTerm(dest))) {
into_relocation_chain(&TrailTerm(dest), GET_NEXT(trail_cell) PASS_REGS);
}
} }
dest++; dest++;
trail_ptr++; trail_ptr++;

View File

@ -86,7 +86,9 @@ typedef enum
CLAUSE_LIST = 0x40, CLAUSE_LIST = 0x40,
BLOB_STRING = 0x80, /* SWI style strings */ BLOB_STRING = 0x80, /* SWI style strings */
BLOB_WIDE_STRING = 0x81, /* SWI style strings */ BLOB_WIDE_STRING = 0x81, /* SWI style strings */
EXTERNAL_BLOB = 0x100 /* for SWI emulation */ EXTERNAL_BLOB = 0x100, /* generic data */
USER_BLOB_START = 0x1000, /* user defined blob */
USER_BLOB_END = 0x1100 /* end of user defined blob */
} }
big_blob_type; big_blob_type;
@ -438,6 +440,25 @@ IsLargeNumTerm (Term t)
&& (FunctorOfTerm (t) >= FunctorLongInt))); && (FunctorOfTerm (t) >= FunctorLongInt)));
} }
inline EXTERN int IsExternalBlobTerm (Term, CELL);
inline EXTERN int
IsExternalBlobTerm (Term t, CELL tag)
{
return (int) (IsApplTerm (t) &&
FunctorOfTerm (t) == FunctorBigInt &&
RepAppl(t)[1] == tag);
}
inline EXTERN void *ExternalBlobFromTerm (Term);
inline EXTERN void *
ExternalBlobFromTerm (Term t)
{
MP_INT *base = (MP_INT *)(RepAppl(t)+2);
return (void *) (base+1);
}

View File

@ -33,6 +33,12 @@ typedef int (*SWI_PLGetStreamPositionFunction)(void *);
#include "../include/dswiatoms.h" #include "../include/dswiatoms.h"
typedef int (*Opaque_CallOnFail)(void *);
typedef struct opaque_handler_struct {
Opaque_CallOnFail fail_handler;
} opaque_handler_t;
#ifndef INT_KEYS_DEFAULT_SIZE #ifndef INT_KEYS_DEFAULT_SIZE
#define INT_KEYS_DEFAULT_SIZE 256 #define INT_KEYS_DEFAULT_SIZE 256
#endif #endif

View File

@ -120,6 +120,8 @@ int STD_PROTO(Yap_IsStringTerm, (Term));
int STD_PROTO(Yap_IsWideStringTerm, (Term)); int STD_PROTO(Yap_IsWideStringTerm, (Term));
Term STD_PROTO(Yap_RatTermToApplTerm, (Term)); Term STD_PROTO(Yap_RatTermToApplTerm, (Term));
void STD_PROTO(Yap_InitBigNums, (void)); void STD_PROTO(Yap_InitBigNums, (void));
Term STD_PROTO(Yap_AllocExternalDataInStack, (CELL, size_t));
int STD_PROTO(Yap_CleanOpaqueVariable, (CELL *));
/* c_interface.c */ /* c_interface.c */
Int STD_PROTO(YAP_Execute,(struct pred_entry *, CPredicate)); Int STD_PROTO(YAP_Execute,(struct pred_entry *, CPredicate));

View File

@ -102,6 +102,8 @@
#define GLOBAL_Executable Yap_global->Executable_ #define GLOBAL_Executable Yap_global->Executable_
#endif #endif
#define GLOBAL_OpaqueHandlersCount Yap_global->OpaqueHandlersCount_
#define GLOBAL_OpaqueHandlers Yap_global->OpaqueHandlers_
#if __simplescalar__ #if __simplescalar__
#define GLOBAL_pwd Yap_global->pwd_ #define GLOBAL_pwd Yap_global->pwd_
#endif #endif

View File

@ -102,6 +102,8 @@ typedef struct global_data {
char Executable_[YAP_FILENAME_MAX]; char Executable_[YAP_FILENAME_MAX];
#endif #endif
int OpaqueHandlersCount_;
struct opaque_handler_struct* OpaqueHandlers_;
#if __simplescalar__ #if __simplescalar__
char pwd_[YAP_FILENAME_MAX]; char pwd_[YAP_FILENAME_MAX];
#endif #endif

View File

@ -102,6 +102,8 @@ static void InitGlobal(void) {
#endif #endif
GLOBAL_OpaqueHandlersCount = 0;
GLOBAL_OpaqueHandlers = NULL;
#if __simplescalar__ #if __simplescalar__
#endif #endif

View File

@ -102,6 +102,8 @@ static void RestoreGlobal(void) {
#endif #endif
#if __simplescalar__ #if __simplescalar__
#endif #endif

View File

@ -67,6 +67,7 @@
#define YP_FILE FILE #define YP_FILE FILE
int STD_PROTO(YP_putc,(int, int)); int STD_PROTO(YP_putc,(int, int));
void STD_PROTO(Yap_dowrite, (Term, IOSTREAM *, int, int));
#else #else
@ -96,8 +97,6 @@ int STD_PROTO(YP_putc,(int, int));
#define fclose ERR_fclose #define fclose ERR_fclose
#define fflush ERR_fflush #define fflush ERR_fflush
/* flags for files in IOSTREAM struct */ /* flags for files in IOSTREAM struct */
#define _YP_IO_WRITE 1 #define _YP_IO_WRITE 1
#define _YP_IO_READ 2 #define _YP_IO_READ 2

View File

@ -154,18 +154,6 @@ static int SQLBINDCOL(SQLHSTMT sthandle,
return TRUE; return TRUE;
} }
static int SQLFREESTMT(SQLHSTMT sthandle,
SQLUSMALLINT opt,
char * print)
{
SQLRETURN retcode;
retcode = SQLFreeStmt(sthandle,opt);
if (retcode != SQL_SUCCESS && retcode != SQL_SUCCESS_WITH_INFO)
return odbc_error(SQL_HANDLE_STMT, sthandle, "SQLFreeStmt", print);
return TRUE;
}
static int SQLNUMRESULTCOLS(SQLHSTMT sthandle, static int SQLNUMRESULTCOLS(SQLHSTMT sthandle,
SQLSMALLINT * ncols, SQLSMALLINT * ncols,
char * print) char * print)
@ -421,8 +409,9 @@ c_db_odbc_query( USES_REGS1 ) {
/* +1 because of '\0' */ /* +1 because of '\0' */
bind_space = malloc(sizeof(char)*(ColumnSizePtr+1)); bind_space = malloc(sizeof(char)*(ColumnSizePtr+1));
data_info = malloc(sizeof(SQLINTEGER)); data_info = malloc(sizeof(SQLINTEGER));
if (!SQLBINDCOL(hstmt,i,SQL_C_CHAR,bind_space,(ColumnSizePtr+1),data_info,"db_query")) if (!SQLBINDCOL(hstmt,i,SQL_C_CHAR,bind_space,(ColumnSizePtr+1),data_info,"db_query")) {
return FALSE; return FALSE;
}
properties[0] = MkIntegerTerm((Int)bind_space); properties[0] = MkIntegerTerm((Int)bind_space);
properties[2] = MkIntegerTerm((Int)data_info); properties[2] = MkIntegerTerm((Int)data_info);
@ -444,7 +433,7 @@ c_db_odbc_query( USES_REGS1 ) {
{ {
if (!SQLCLOSECURSOR(hstmt,"db_query")) if (!SQLCLOSECURSOR(hstmt,"db_query"))
return FALSE; return FALSE;
if (!SQLFREESTMT(hstmt,SQL_CLOSE,"db_query")) if (!SQLFREEHANDLE(SQL_HANDLE_STMT, hstmt, "db_query"))
return FALSE; return FALSE;
return FALSE; return FALSE;
} }
@ -466,7 +455,7 @@ c_db_odbc_number_of_fields( USES_REGS1 ) {
char sql[256]; char sql[256];
SQLSMALLINT number_fields; SQLSMALLINT number_fields;
sprintf(sql,"DESCRIBE %s",relation); sprintf(sql,"SELECT column_name from INFORMATION_SCHEMA.COLUMNS where table_name = \'%s\'",relation);
if (!SQLALLOCHANDLE(SQL_HANDLE_STMT, hdbc, &hstmt, "db_number_of_fields")) if (!SQLALLOCHANDLE(SQL_HANDLE_STMT, hdbc, &hstmt, "db_number_of_fields"))
return FALSE; return FALSE;
@ -482,7 +471,7 @@ c_db_odbc_number_of_fields( USES_REGS1 ) {
if (!SQLCLOSECURSOR(hstmt,"db_number_of_fields")) if (!SQLCLOSECURSOR(hstmt,"db_number_of_fields"))
return FALSE; return FALSE;
if (!SQLFREESTMT(hstmt,SQL_CLOSE,"db_number_of_fields")) if (!SQLFREEHANDLE(SQL_HANDLE_STMT, hstmt, "db_number_of_fields"))
return FALSE; return FALSE;
if (!Yap_unify(arg_fields, MkIntegerTerm(number_fields))) if (!Yap_unify(arg_fields, MkIntegerTerm(number_fields)))
@ -506,7 +495,7 @@ c_db_odbc_get_attributes_types( USES_REGS1 ) {
Term head, list; Term head, list;
list = arg_types_list; list = arg_types_list;
sprintf(sql,"DESCRIBE %s",relation); sprintf(sql,"SELECT column_name,data_type FROM INFORMATION_SCHEMA.COLUMNS where table_name = \'%s\'",relation);
if (!SQLALLOCHANDLE(SQL_HANDLE_STMT, hdbc, &hstmt, "db_get_attributes_types")) if (!SQLALLOCHANDLE(SQL_HANDLE_STMT, hdbc, &hstmt, "db_get_attributes_types"))
return FALSE; return FALSE;
@ -547,7 +536,7 @@ c_db_odbc_get_attributes_types( USES_REGS1 ) {
if (!SQLCLOSECURSOR(hstmt,"db_get_attributes_types")) if (!SQLCLOSECURSOR(hstmt,"db_get_attributes_types"))
return FALSE; return FALSE;
if (!SQLFREESTMT(hstmt,SQL_CLOSE, "db_get_attributes_types")) if (!SQLFREEHANDLE(SQL_HANDLE_STMT, hstmt, "db_get_attributes_types"))
return FALSE; return FALSE;
return TRUE; return TRUE;
} }
@ -585,12 +574,31 @@ c_db_odbc_row_cut( USES_REGS1 ) {
if (!SQLCLOSECURSOR(hstmt,"db_row_cut")) if (!SQLCLOSECURSOR(hstmt,"db_row_cut"))
return FALSE; return FALSE;
if (!SQLFREESTMT(hstmt,SQL_CLOSE,"db_row_cut")) if (!SQLFREEHANDLE(SQL_HANDLE_STMT, hstmt, "db_row_cut"))
return FALSE; return FALSE;
return TRUE; return TRUE;
} }
static int
release_list_args(Term arg_list_args, Term arg_bind_list, const char *error_msg)
{
Term list = arg_list_args;
Term list_bind = arg_bind_list;
while (IsPairTerm(list_bind))
{
Term head_bind = HeadOfTerm(list_bind);
list = TailOfTerm(list);
list_bind = TailOfTerm(list_bind);
free((char *)IntegerOfTerm(ArgOfTerm(1,head_bind)));
free((SQLINTEGER *)IntegerOfTerm(ArgOfTerm(3,head_bind)));
}
return TRUE;
}
/* db_row: ResultSet x BindList x ListOfArgs -> */ /* db_row: ResultSet x BindList x ListOfArgs -> */
static Int static Int
c_db_odbc_row( USES_REGS1 ) { c_db_odbc_row( USES_REGS1 ) {
@ -611,9 +619,12 @@ c_db_odbc_row( USES_REGS1 ) {
{ {
if (!SQLCLOSECURSOR(hstmt,"db_row")) if (!SQLCLOSECURSOR(hstmt,"db_row"))
return FALSE; return FALSE;
if (!SQLFREESTMT(hstmt,SQL_CLOSE,"db_row")) if (!SQLFREEHANDLE(SQL_HANDLE_STMT, hstmt, "db_row"))
return FALSE; return FALSE;
if (!release_list_args(arg_list_args, arg_bind_list, "db_row")) {
return FALSE;
}
cut_fail(); cut_fail();
return FALSE; return FALSE;
} }
@ -699,7 +710,7 @@ c_db_odbc_number_of_fields_in_query( USES_REGS1 ) {
if (!Yap_unify(arg_fields, MkIntegerTerm(number_cols))){ if (!Yap_unify(arg_fields, MkIntegerTerm(number_cols))){
if (!SQLCLOSECURSOR(hstmt,"db_number_of_fields_in_query")) if (!SQLCLOSECURSOR(hstmt,"db_number_of_fields_in_query"))
return FALSE; return FALSE;
if (!SQLFREESTMT(hstmt,SQL_CLOSE, "db_number_of_fields_in_query")) if (!SQLFREEHANDLE(SQL_HANDLE_STMT, hstmt, "db_number_of_fields_in_query"))
return FALSE; return FALSE;
return FALSE; return FALSE;
@ -707,7 +718,7 @@ c_db_odbc_number_of_fields_in_query( USES_REGS1 ) {
if (!SQLCLOSECURSOR(hstmt,"db_number_of_fields_in_query")) if (!SQLCLOSECURSOR(hstmt,"db_number_of_fields_in_query"))
return FALSE; return FALSE;
if (!SQLFREESTMT(hstmt,SQL_CLOSE, "db_number_of_fields_in_query")) if (!SQLFREEHANDLE(SQL_HANDLE_STMT, hstmt, "db_number_of_fields_in_query"))
return FALSE; return FALSE;
return TRUE; return TRUE;
@ -775,7 +786,7 @@ c_db_odbc_get_fields_properties( USES_REGS1 ) {
if (!SQLCLOSECURSOR(hstmt2,"db_get_fields_properties")) if (!SQLCLOSECURSOR(hstmt2,"db_get_fields_properties"))
return FALSE; return FALSE;
if (!SQLFREESTMT(hstmt2,SQL_CLOSE,"db_get_fields_properties")) if (!SQLFREEHANDLE(SQL_HANDLE_STMT, hstmt2, "db_get_fields_properties"))
return FALSE; return FALSE;
for (i=1;i<=num_fields;i++) for (i=1;i<=num_fields;i++)
@ -816,9 +827,8 @@ c_db_odbc_get_fields_properties( USES_REGS1 ) {
if (!SQLCLOSECURSOR(hstmt,"db_get_fields_properties")) if (!SQLCLOSECURSOR(hstmt,"db_get_fields_properties"))
return FALSE; return FALSE;
if (!SQLFREESTMT(hstmt,SQL_CLOSE,"db_get_fields_properties")) if (!SQLFREEHANDLE(SQL_HANDLE_STMT, hstmt2, "db_get_fields_properties"))
return FALSE; return FALSE;
return TRUE; return TRUE;
} }

7
configure vendored
View File

@ -9927,6 +9927,9 @@ mkdir -p packages/clib/maildrop
mkdir -p packages/clib/maildrop/rfc822 mkdir -p packages/clib/maildrop/rfc822
mkdir -p packages/clib/maildrop/rfc2045 mkdir -p packages/clib/maildrop/rfc2045
mkdir -p packages/CLPBN mkdir -p packages/CLPBN
mkdir -p packages/CLPBN/clpbn
mkdir -p packages/CLPBN/clpbn/bp
mkdir -p packages/CLPBN/clpbn/bp/xmlParser
mkdir -p packages/clpqr mkdir -p packages/clpqr
mkdir -p packages/cplint mkdir -p packages/cplint
mkdir -p packages/cplint/approx mkdir -p packages/cplint/approx
@ -10619,8 +10622,8 @@ esac
cat >>$CONFIG_STATUS <<_ACEOF || ac_write_fail=1 cat >>$CONFIG_STATUS <<_ACEOF || ac_write_fail=1
# Files that config.status was made for. # Files that config.status was made for.
config_files="`echo $ac_config_files`" config_files="$ac_config_files"
config_headers="`echo $ac_config_headers`" config_headers="$ac_config_headers"
_ACEOF _ACEOF

View File

@ -2103,6 +2103,9 @@ mkdir -p packages/clib/maildrop
mkdir -p packages/clib/maildrop/rfc822 mkdir -p packages/clib/maildrop/rfc822
mkdir -p packages/clib/maildrop/rfc2045 mkdir -p packages/clib/maildrop/rfc2045
mkdir -p packages/CLPBN mkdir -p packages/CLPBN
mkdir -p packages/CLPBN/clpbn
mkdir -p packages/CLPBN/clpbn/bp
mkdir -p packages/CLPBN/clpbn/bp/xmlParser
mkdir -p packages/clpqr mkdir -p packages/clpqr
mkdir -p packages/cplint mkdir -p packages/cplint
mkdir -p packages/cplint/approx mkdir -p packages/cplint/approx

View File

@ -16811,13 +16811,29 @@ only two boolean flags are accepted: @code{YAPC_ENABLE_GC} and
@code{YAPC_ENABLE_AGC}. The first enables/disables the standard garbage @code{YAPC_ENABLE_AGC}. The first enables/disables the standard garbage
collector, the second does the same for the atom garbage collector.` collector, the second does the same for the atom garbage collector.`
@item @code{YAP_TERM} YAP_AllocExternalDataInStack(@code{size_t bytes})
@item @code{void *} YAP_ExternalDataInStackFromTerm(@code{YAP_Term t})
@item @code{YAP_Bool} YAP_IsExternalDataInStackTerm(@code{YAP_Term t})
@findex YAP_AllocExternalDataInStack (C-Interface function)
The next routines allow one to store external data in the Prolog
execution stack. The first routine reserves space for @var{sz} bytes
and returns an opaque handle. The second routines receives the handle
and returns a pointer to the data. The last routine checks if a term
is an opaque handle.
Data will be automatically reclaimed during
backtracking. Also, this storage is opaque to the Prolog garbage compiler,
so it should not be used to store Prolog terms. On the other hand, it
may be useful to store arrays in a compact way, or pointers to external objects.
@item @code{int} YAP_HaltRegisterHook(@code{YAP_halt_hook f, void *closure}) @item @code{int} YAP_HaltRegisterHook(@code{YAP_halt_hook f, void *closure})
@findex YAP_HaltRegisterHook (C-Interface function) @findex YAP_HaltRegisterHook (C-Interface function)
Register the function @var{f} to be called if YAP is halted. The Register the function @var{f} to be called if YAP is halted. The
function is called with two arguments: the exit code of the process (@code{0} function is called with two arguments: the exit code of the process
if this cannot be determined on your operating system) and the closure (@code{0} if this cannot be determined on your operating system) and
argument @var{closure}. the closure argument @var{closure}.
@c See also @code{at_halt/1}. @c See also @code{at_halt/1}.
@end table @end table
@ -16850,6 +16866,7 @@ implementing the predicate and @var{arity} is its arity.
@findex YAP_UserBackCutCPredicate (C-Interface function) @findex YAP_UserBackCutCPredicate (C-Interface function)
@findex YAP_PRESERVE_DATA (C-Interface function) @findex YAP_PRESERVE_DATA (C-Interface function)
@findex YAP_PRESERVED_DATA (C-Interface function) @findex YAP_PRESERVED_DATA (C-Interface function)
@findex YAP_PRESERVED_DATA_CUT (C-Interface function)
@findex YAP_cutsucceed (C-Interface function) @findex YAP_cutsucceed (C-Interface function)
@findex YAP_cutfail (C-Interface function) @findex YAP_cutfail (C-Interface function)
For the second kind of predicates we need three C functions. The first one For the second kind of predicates we need three C functions. The first one
@ -16915,11 +16932,15 @@ static int start_n100(void)
@end example @end example
The routine starts by getting the dereference value of the argument. The routine starts by getting the dereference value of the argument.
The call to @code{YAP_PRESERVE_DATA} is used to initialize the memory which will The call to @code{YAP_PRESERVE_DATA} is used to initialize the memory
hold the information to be preserved across backtracking. The first which will hold the information to be preserved across
argument is the variable we shall use, and the second its type. Note backtracking. The first argument is the variable we shall use, and the
that we can only use @code{YAP_PRESERVE_DATA} once, so often we will second its type. Note that we can only use @code{YAP_PRESERVE_DATA}
want the variable to be a structure. once, so often we will want the variable to be a structure. This data
is visible to the garbage collector, so it should consist of Prolog
terms, as in the example. It is also correct to store pointers to
objects external to YAP stacks, as the garbage collector will ignore
such references.
If the argument of the predicate is a variable, the routine initializes the If the argument of the predicate is a variable, the routine initializes the
structure to be preserved across backtracking with the information structure to be preserved across backtracking with the information
@ -16988,6 +17009,34 @@ when pruning the execution of the predicate, @var{arity} is the
predicate arity, and @var{sizeof} is the size of the data to be predicate arity, and @var{sizeof} is the size of the data to be
preserved in the stack. In this example, we would have something like preserved in the stack. In this example, we would have something like
@example
void
init_n100(void)
@{
YAP_UserBackCutCPredicate("n100", start_n100, continue_n100, cut_n100, 1, 1);
@}
@end example
The argument before last is the predicate's arity. Notice again the
last argument to the call. function argument gives the extra space we
want to use for @code{PRESERVED_DATA}. Space is given in cells, where
a cell is the same size as a pointer. The garbage collector has access
to this space, hence users should use it either to store terms or to
store pointers to objects outside the stacks.
The code for @code{cut_n100} could be:
@example
static int cut_n100(void)
@{
YAP_PRESERVED_DATA_CUT(n100_data,n100_data_type*);
fprintf("n100 cut with counter %ld\n", YAP_IntOfTerm(n100_data->next_solution));
return TRUE;
@}
@end example
Notice that we have to use @code{YAP_PRESERVED_DATA_CUT}: this is
because the Prolog engine is at a different state during cut.
If no work is required at cut, we can use:
@example @example
void void
init_n100(void) init_n100(void)
@ -16995,8 +17044,7 @@ init_n100(void)
YAP_UserBackCutCPredicate("n100", start_n100, continue_n100, NULL, 1, 1); YAP_UserBackCutCPredicate("n100", start_n100, continue_n100, NULL, 1, 1);
@} @}
@end example @end example
Notice that we do not actually need to do anything on receiving a cut in in this case no code is executed at cut time.
this case.
@node Loading Objects, Save&Rest, Writing C, C-Interface @node Loading Objects, Save&Rest, Writing C, C-Interface
@section Loading Object Files @section Loading Object Files
@ -17272,9 +17320,9 @@ Associate the term @var{value} with the atom @var{at}. The term
@var{value} must be a constant. This functionality is used by YAP as a @var{value} must be a constant. This functionality is used by YAP as a
simple way for controlling and communicating with the Prolog run-time. simple way for controlling and communicating with the Prolog run-time.
@item @code{YAP_Term} YAP_Read(@code{int (*)(void)} @var{GetC}) @item @code{YAP_Term} YAP_Read(@code{IOSTREAM *Stream})
@findex YAP_Read/1 @findex YAP_Read/1
Parse a Term using the function @var{GetC} to input characters. Parse a @var{Term} from the stream @var{Stream}.
@item @code{YAP_Term} YAP_Write(@code{YAP_Term} @var{t}) @item @code{YAP_Term} YAP_Write(@code{YAP_Term} @var{t})
@findex YAP_CopyTerm/1 @findex YAP_CopyTerm/1
@ -17282,13 +17330,13 @@ Copy a Term @var{t} and all associated constraints. May call the garbage
collector and returns @code{0L} on error (such as no space being collector and returns @code{0L} on error (such as no space being
available). available).
@item @code{void} YAP_Write(@code{YAP_Term} @var{t}, @code{void (*)(int)} @item @code{void} YAP_Write(@code{YAP_Term} @var{t}, @code{IOSTREAM}
@var{PutC}, @code{int} @var{flags}) @var{stream}, @code{int} @var{flags})
@findex YAP_Write/3 @findex YAP_Write/3
Write a Term @var{t} using the function @var{PutC} to output Write a Term @var{t} using the stream @var{stream} to output
characters. The term is written according to a mask of the following characters. The term is written according to a mask of the following
flags in the @code{flag} argument: @code{YAP_WRITE_QUOTED}, flags in the @code{flag} argument: @code{YAP_WRITE_QUOTED},
@code{YAP_WRITE_HANDLE_VARS}, and @code{YAP_WRITE_IGNORE_OPS}. @code{YAP_WRITE_HANDLE_VARS}, @code{YAP_WRITE_USE_PORTRAY}, and @code{YAP_WRITE_IGNORE_OPS}.
@item @code{void} YAP_WriteBuffer(@code{YAP_Term} @var{t}, @code{char *} @item @code{void} YAP_WriteBuffer(@code{YAP_Term} @var{t}, @code{char *}
@var{buff}, @code{unsigned int} @var{buff}, @code{unsigned int}

View File

@ -237,7 +237,7 @@ extern X_API void PROTO(YAP_UserBackCPredicate,(CONST char *, YAP_Bool (*)(void)
/* void UserBackCPredicate(char *name, int *init(), int *cont(), int *cut(), int /* void UserBackCPredicate(char *name, int *init(), int *cont(), int *cut(), int
arity, int extra) */ arity, int extra) */
extern X_API void PROTO(YAP_UserBackCutCPredicate,(char *, YAP_Bool (*)(void), YAP_Bool (*)(void), YAP_Bool (*)(void), YAP_Arity, unsigned int)); extern X_API void PROTO(YAP_UserBackCutCPredicate,(CONST char *, YAP_Bool (*)(void), YAP_Bool (*)(void), YAP_Bool (*)(void), YAP_Arity, unsigned int));
/* void CallProlog(YAP_Term t) */ /* void CallProlog(YAP_Term t) */
extern X_API YAP_Bool PROTO(YAP_CallProlog,(YAP_Term t)); extern X_API YAP_Bool PROTO(YAP_CallProlog,(YAP_Term t));
@ -245,9 +245,9 @@ extern X_API YAP_Bool PROTO(YAP_CallProlog,(YAP_Term t));
/* void cut_fail(void) */ /* void cut_fail(void) */
extern X_API void PROTO(YAP_cut_up,(void)); extern X_API void PROTO(YAP_cut_up,(void));
#define YAP_cut_succeed() { YAP_cut_up(); return TRUE; } #define YAP_cut_succeed() do { YAP_cut_up(); return TRUE; } while(0)
#define YAP_cut_fail() { YAP_cut_up(); return FALSE; } #define YAP_cut_fail() do { YAP_cut_up(); return FALSE; } while(0)
/* void *AllocSpaceFromYAP_(int) */ /* void *AllocSpaceFromYAP_(int) */
extern X_API void *PROTO(YAP_AllocSpaceFromYap,(unsigned int)); extern X_API void *PROTO(YAP_AllocSpaceFromYap,(unsigned int));
@ -555,6 +555,17 @@ extern X_API int PROTO(YAP_MaxOpPriority,(YAP_Atom, YAP_Term));
/* int YAP_OpInfo(Atom, Term, int, int *, int *) */ /* int YAP_OpInfo(Atom, Term, int, int *, int *) */
extern X_API int PROTO(YAP_OpInfo,(YAP_Atom, YAP_Term, int, int *, int *)); extern X_API int PROTO(YAP_OpInfo,(YAP_Atom, YAP_Term, int, int *, int *));
/* YAP_Bool YAP_IsExternalDataInStackTerm(YAP_Term) */
extern X_API YAP_Bool PROTO(YAP_IsExternalDataInStackTerm,(YAP_Term));
extern X_API YAP_opaque_tag_t PROTO(YAP_NewOpaqueType,(struct YAP_opaque_handler_struct *));
extern X_API YAP_Bool PROTO(YAP_IsOpaqueObjectTerm,(YAP_Term, YAP_opaque_tag_t));
extern X_API YAP_Term PROTO(YAP_NewOpaqueObject,(YAP_opaque_tag_t, size_t));
extern X_API void *PROTO(YAP_OpaqueObjectFromTerm,(YAP_Term));
#define YAP_InitCPred(N,A,F) YAP_UserCPredicate(N,F,A) #define YAP_InitCPred(N,A,F) YAP_UserCPredicate(N,F,A)
__END_DECLS __END_DECLS

View File

@ -101,9 +101,10 @@ typedef double YAP_Float;
#define YAP_FULL_BOOT_FROM_PROLOG 4 #define YAP_FULL_BOOT_FROM_PROLOG 4
#define YAP_BOOT_ERROR -1 #define YAP_BOOT_ERROR -1
#define YAP_WRITE_QUOTED 0 #define YAP_WRITE_QUOTED 1
#define YAP_WRITE_HANDLE_VARS 1
#define YAP_WRITE_IGNORE_OPS 2 #define YAP_WRITE_IGNORE_OPS 2
#define YAP_WRITE_HANDLE_VARS 2
#define YAP_WRITE_USE_PORTRAY 8
#define YAP_CONSULT_MODE 0 #define YAP_CONSULT_MODE 0
#define YAP_RECONSULT_MODE 1 #define YAP_RECONSULT_MODE 1
@ -204,6 +205,14 @@ typedef int (*YAP_agc_hook)(void *_Atom);
typedef void (*YAP_halt_hook)(int exit_code, void *closure); typedef void (*YAP_halt_hook)(int exit_code, void *closure);
typedef int YAP_opaque_tag_t;
typedef int (*YAP_Opaque_CallOnFail)(void *);
typedef struct YAP_opaque_handler_struct {
YAP_Opaque_CallOnFail fail_handler;
} YAP_opaque_handler_t;
/********* execution mode ***********************/ /********* execution mode ***********************/
typedef enum typedef enum

View File

@ -241,6 +241,8 @@ tokenize_arguments([FirstArg|RestArgs],[TokFirstArg|TokRestArgs]):-
% %
% -------------------------------------------------------------------------------------- % --------------------------------------------------------------------------------------
:- dynamic attribute/4.
query_generation([],_,[]). query_generation([],_,[]).
query_generation([Conjunction|Conjunctions],ProjectionTerm,[Query|Queries]):- query_generation([Conjunction|Conjunctions],ProjectionTerm,[Query|Queries]):-
@ -1157,9 +1159,9 @@ column_atom(att(RangeVar,Attribute),QueryList,Diff):-
column_atom(Attribute,X2,Diff). column_atom(Attribute,X2,Diff).
column_atom(rel(Relation,RangeVar),QueryList,Diff):- column_atom(rel(Relation,RangeVar),QueryList,Diff):-
column_atom('`',QueryList,X0), column_atom('',QueryList,X0),
column_atom(Relation,X0,X1), column_atom(Relation,X0,X1),
column_atom('` ',X1,X2), column_atom(' ',X1,X2),
column_atom(RangeVar,X2,Diff). column_atom(RangeVar,X2,Diff).
column_atom('$const$'(String),QueryList,Diff):- column_atom('$const$'(String),QueryList,Diff):-

View File

@ -142,12 +142,6 @@ p2c_putc(const int c) {
/* /*
* Function used by YAP to read a char from a string * Function used by YAP to read a char from a string
*/ */
static int
p2c_getc(void) {
if( BUFFER_POS < BUFFER_LEN )
return BUFFER_PTR[BUFFER_POS++];
return -1;
}
/* /*
* Writes a term to a stream. * Writes a term to a stream.
*/ */
@ -177,7 +171,7 @@ read_term_from_stream(const int fd) {
if ( size> BUFFER_SIZE) if ( size> BUFFER_SIZE)
expand_buffer(size-BUFFER_SIZE); expand_buffer(size-BUFFER_SIZE);
read(fd,BUFFER_PTR,size); // read term from stream read(fd,BUFFER_PTR,size); // read term from stream
return YAP_Read( p2c_getc ); return YAP_ReadBuffer( BUFFER_PTR , NULL);
} }
/********************************************************************************************* /*********************************************************************************************
* Conversion: Prolog Term->char[] and char[]->Prolog Term * Conversion: Prolog Term->char[] and char[]->Prolog Term
@ -229,7 +223,7 @@ string2term(char *const ptr,const size_t *size) {
} }
BUFFER_POS=0; BUFFER_POS=0;
LOCAL_ErrorMessage=NULL; LOCAL_ErrorMessage=NULL;
t = YAP_Read(p2c_getc); t = YAP_ReadBuffer( BUFFER_PTR , NULL );
if ( t==FALSE ) { if ( t==FALSE ) {
write_msg(__FUNCTION__,__FILE__,__LINE__,"FAILED string2term>>>>size:%d %d %s\n",BUFFER_SIZE,strlen(BUFFER_PTR),LOCAL_ErrorMessage); write_msg(__FUNCTION__,__FILE__,__LINE__,"FAILED string2term>>>>size:%d %d %s\n",BUFFER_SIZE,strlen(BUFFER_PTR),LOCAL_ErrorMessage);
exit(1); exit(1);

View File

@ -30,7 +30,6 @@ static char *rcsid = "$Header: /Users/vitor/Yap/yap-cvsbackup/library/mpi/mpi.c,
#include <string.h> #include <string.h>
#include <mpi.h> #include <mpi.h>
Term STD_PROTO(YAP_Read, (int (*)(void)));
void STD_PROTO(YAP_Write, (Term, void (*)(int), int)); void STD_PROTO(YAP_Write, (Term, void (*)(int), int));
STATIC_PROTO (Int p_mpi_open, (void)); STATIC_PROTO (Int p_mpi_open, (void));
@ -105,15 +104,6 @@ mpi_putc(Int ch)
} }
} }
static Int
mpi_getc(void)
{
if( bufptr < bufsize ) return buf[bufptr++];
else return -1;
}
/* /*
* C Predicates * C Predicates
@ -301,7 +291,7 @@ p_mpi_receive() /* mpi_receive(-data, ?orig, ?tag) */
/* parse received string into a Prolog term */ /* parse received string into a Prolog term */
bufptr = 0; bufptr = 0;
t = YAP_Read( mpi_getc ); t = YAP_ReadBuffer( buf, NULL );
if( t == TermNil ) { if( t == TermNil ) {
retv = FALSE; retv = FALSE;
@ -384,7 +374,7 @@ p_mpi_bcast3() /* mpi_bcast( ?data, +root, +max_size ) */
bufptr = 0; bufptr = 0;
/* parse received string into a Prolog term */ /* parse received string into a Prolog term */
return Yap_unify( YAP_Read(mpi_getc), ARG1 ); return Yap_unify( YAP_ReadBuffer( buf, NULL ), ARG1 );
} }
} }
@ -464,7 +454,7 @@ p_mpi_bcast2() /* mpi_bcast( ?data, +root ) */
bufstrlen = strlen(buf); bufstrlen = strlen(buf);
bufptr = 0; bufptr = 0;
return Yap_unify(YAP_Read( mpi_getc ), ARG1); return Yap_unify(YAP_ReadBuffer( buf, NULL ), ARG1);
} }
} }

View File

@ -120,6 +120,8 @@ char* DIRNAME =NULL
char Executable[YAP_FILENAME_MAX] void char Executable[YAP_FILENAME_MAX] void
#endif #endif
int OpaqueHandlersCount =0
struct opaque_handler_struct* OpaqueHandlers =NULL
#if __simplescalar__ #if __simplescalar__
char pwd[YAP_FILENAME_MAX] void char pwd[YAP_FILENAME_MAX] void

View File

@ -0,0 +1,149 @@
#include <cassert>
#include <cmath>
#include <iostream>
#include "BPNodeInfo.h"
#include "BPSolver.h"
BPNodeInfo::BPNodeInfo (BayesNode* node)
{
node_ = node;
ds_ = node->getDomainSize();
piValsCalc_ = false;
ldValsCalc_ = false;
nPiMsgsRcv_ = 0;
nLdMsgsRcv_ = 0;
piVals_.resize (ds_, 1);
ldVals_.resize (ds_, 1);
const BnNodeSet& childs = node->getChilds();
for (unsigned i = 0; i < childs.size(); i++) {
cmsgs_.insert (make_pair (childs[i], false));
}
const BnNodeSet& parents = node->getParents();
for (unsigned i = 0; i < parents.size(); i++) {
pmsgs_.insert (make_pair (parents[i], false));
}
}
ParamSet
BPNodeInfo::getBeliefs (void) const
{
double sum = 0.0;
ParamSet beliefs (ds_);
for (unsigned xi = 0; xi < ds_; xi++) {
double prod = piVals_[xi] * ldVals_[xi];
beliefs[xi] = prod;
sum += prod;
}
assert (sum);
//normalize the beliefs
for (unsigned xi = 0; xi < ds_; xi++) {
beliefs[xi] /= sum;
}
return beliefs;
}
bool
BPNodeInfo::readyToSendPiMsgTo (const BayesNode* child) const
{
for (unsigned i = 0; i < inChildLinks_.size(); i++) {
if (inChildLinks_[i]->getSource() != child
&& !inChildLinks_[i]->messageWasSended()) {
return false;
}
}
return true;
}
bool
BPNodeInfo::readyToSendLambdaMsgTo (const BayesNode* parent) const
{
for (unsigned i = 0; i < inParentLinks_.size(); i++) {
if (inParentLinks_[i]->getSource() != parent
&& !inParentLinks_[i]->messageWasSended()) {
return false;
}
}
return true;
}
double
BPNodeInfo::getPiValue (unsigned idx) const
{
assert (idx >=0 && idx < ds_);
return piVals_[idx];
}
void
BPNodeInfo::setPiValue (unsigned idx, Param value)
{
assert (idx >=0 && idx < ds_);
piVals_[idx] = value;
}
double
BPNodeInfo::getLambdaValue (unsigned idx) const
{
assert (idx >=0 && idx < ds_);
return ldVals_[idx];
}
void
BPNodeInfo::setLambdaValue (unsigned idx, Param value)
{
assert (idx >=0 && idx < ds_);
ldVals_[idx] = value;
}
double
BPNodeInfo::getBeliefChange (void)
{
double change = 0.0;
if (oldBeliefs_.size() == 0) {
oldBeliefs_ = getBeliefs();
change = 9999999999.0;
} else {
ParamSet currentBeliefs = getBeliefs();
for (unsigned xi = 0; xi < ds_; xi++) {
change += abs (currentBeliefs[xi] - oldBeliefs_[xi]);
}
oldBeliefs_ = currentBeliefs;
}
return change;
}
bool
BPNodeInfo::receivedBottomInfluence (void) const
{
// if all lambda values are equal, then neither
// this node neither its descendents have evidence,
// we can use this to don't send lambda messages his parents
bool childInfluenced = false;
for (unsigned xi = 1; xi < ds_; xi++) {
if (ldVals_[xi] != ldVals_[0]) {
childInfluenced = true;
break;
}
}
return childInfluenced;
}

View File

@ -0,0 +1,82 @@
#ifndef BP_BP_NODE_H
#define BP_BP_NODE_H
#include <vector>
#include <map>
#include "BPSolver.h"
#include "BayesNode.h"
#include "Shared.h"
//class Edge;
using namespace std;
class BPNodeInfo
{
public:
BPNodeInfo (int);
BPNodeInfo (BayesNode*);
ParamSet getBeliefs (void) const;
double getPiValue (unsigned) const;
void setPiValue (unsigned, Param);
double getLambdaValue (unsigned) const;
void setLambdaValue (unsigned, Param);
double getBeliefChange (void);
bool receivedBottomInfluence (void) const;
ParamSet& getPiValues (void) { return piVals_; }
ParamSet& getLambdaValues (void) { return ldVals_; }
bool arePiValuesCalculated (void) { return piValsCalc_; }
bool areLambdaValuesCalculated (void) { return ldValsCalc_; }
void markPiValuesAsCalculated (void) { piValsCalc_ = true; }
void markLambdaValuesAsCalculated (void) { ldValsCalc_ = true; }
void incNumPiMsgsRcv (void) { nPiMsgsRcv_ ++; }
void incNumLambdaMsgsRcv (void) { nLdMsgsRcv_ ++; }
bool receivedAllPiMessages (void)
{
return node_->getParents().size() == nPiMsgsRcv_;
}
bool receivedAllLambdaMessages (void)
{
return node_->getChilds().size() == nLdMsgsRcv_;
}
bool readyToSendPiMsgTo (const BayesNode*) const ;
bool readyToSendLambdaMsgTo (const BayesNode*) const;
CEdgeSet getIncomingParentLinks (void) { return inParentLinks_; }
CEdgeSet getIncomingChildLinks (void) { return inChildLinks_; }
CEdgeSet getOutcomingParentLinks (void) { return outParentLinks_; }
CEdgeSet getOutcomingChildLinks (void) { return outChildLinks_; }
void addIncomingParentLink (Edge* l) { inParentLinks_.push_back (l); }
void addIncomingChildLink (Edge* l) { inChildLinks_.push_back (l); }
void addOutcomingParentLink (Edge* l) { outParentLinks_.push_back (l); }
void addOutcomingChildLink (Edge* l) { outChildLinks_.push_back (l); }
private:
DISALLOW_COPY_AND_ASSIGN (BPNodeInfo);
ParamSet piVals_; // pi values
ParamSet ldVals_; // lambda values
ParamSet oldBeliefs_;
unsigned nPiMsgsRcv_;
unsigned nLdMsgsRcv_;
bool piValsCalc_;
bool ldValsCalc_;
EdgeSet inParentLinks_;
EdgeSet inChildLinks_;
EdgeSet outParentLinks_;
EdgeSet outChildLinks_;
unsigned ds_;
const BayesNode* node_;
map<const BayesNode*, bool> pmsgs_;
map<const BayesNode*, bool> cmsgs_;
};
#endif //BP_BP_NODE_H

File diff suppressed because it is too large Load Diff

View File

@ -1,259 +1,106 @@
#ifndef BP_BPSOLVER_H #ifndef BP_BP_SOLVER_H
#define BP_BPSOLVER_H #define BP_BP_SOLVER_H
#include <vector> #include <vector>
#include <string>
#include <set> #include <set>
#include "Solver.h" #include "Solver.h"
#include "BayesNet.h" #include "BayesNet.h"
#include "BpNode.h" #include "BPNodeInfo.h"
#include "Shared.h" #include "Shared.h"
using namespace std; using namespace std;
class BPSolver; class BPNodeInfo;
static const string PI = "pi" ; static const string PI = "pi" ;
static const string LD = "ld" ; static const string LD = "ld" ;
enum MessageType {PI_MSG, LAMBDA_MSG}; enum MessageType {PI_MSG, LAMBDA_MSG};
enum JointCalcType {CHAIN_RULE, JUNCTION_NODE};
class BPSolver; class Edge
struct Edge
{ {
Edge (BayesNode* s, BayesNode* d, MessageType t) public:
{ Edge (BayesNode* s, BayesNode* d, MessageType t)
source = s; {
destination = d; source_ = s;
type = t; destin_ = d;
} type_ = t;
string getId (void) const if (type_ == PI_MSG) {
{ currMsg_.resize (s->getDomainSize(), 1);
stringstream ss; nextMsg_.resize (s->getDomainSize(), 1);
type == PI_MSG ? ss << PI : ss << LD; } else {
ss << source->getVarId() << "." << destination->getVarId(); currMsg_.resize (d->getDomainSize(), 1);
return ss.str(); nextMsg_.resize (d->getDomainSize(), 1);
}
string toString (void) const
{
stringstream ss;
type == PI_MSG ? ss << PI << "(" : ss << LD << "(" ;
ss << source->getLabel() << " --> " ;
ss << destination->getLabel();
ss << ")" ;
return ss.str();
}
BayesNode* source;
BayesNode* destination;
MessageType type;
static BPSolver* klass;
};
/*
class BPMessage
{
BPMessage (BayesNode* parent, BayesNode* child)
{
parent_ = parent;
child_ = child;
currPiMsg_.resize (child->getDomainSize(), 1);
currLdMsg_.resize (parent->getDomainSize(), 1);
nextLdMsg_.resize (parent->getDomainSize(), 1);
nextPiMsg_.resize (child->getDomainSize(), 1);
piResidual_ = 1.0;
ldResidual_ = 1.0;
}
Param getPiMessageValue (int idx) const
{
assert (idx >=0 && idx < child->getDomainSize());
return currPiMsg_[idx];
}
Param getLambdaMessageValue (int idx) const
{
assert (idx >=0 && idx < parent->getDomainSize());
return currLdMsg_[idx];
}
const ParamSet& getPiMessage (void) const
{
return currPiMsg_;
}
const ParamSet& getLambdaMessage (void) const
{
return currLdMsg_;
}
ParamSet& piNextMessageReference (void)
{
return nextPiMsg_;
}
ParamSet& lambdaNextMessageReference (const BayesNode* source)
{
return nextLdMsg_;
}
void updatePiMessage (void)
{
currPiMsg_ = nextPiMsg_;
Util::normalize (currPiMsg_);
}
void updateLambdaMessage (void)
{
currLdMsg_ = nextLdMsg_;
Util::normalize (currLdMsg_);
}
double getPiResidual (void)
{
return piResidual_;
}
double getLambdaResidual (void)
{
return ldResidual_;
}
void updatePiResidual (void)
{
piResidual_ = Util::getL1dist (currPiMsg_, nextPiMsg_);
}
void updateLambdaResidual (void)
{
ldResidual_ = Util::getL1dist (currLdMsg_, nextLdMsg_);
}
void clearPiResidual (void)
{
piResidual_ = 0.0;
}
void clearLambdaResidual (void)
{
ldResidual_ = 0.0;
}
BayesNode* parent_;
BayesNode* child_;
ParamSet currPiMsg_; // current pi messages
ParamSet currLdMsg_; // current lambda messages
ParamSet nextPiMsg_;
ParamSet nextLdMsg_;
Param piResidual_;
Param ldResidual_;
};
class NodeInfo
{
NodeInfo (BayesNode* node)
{
node_ = node;
piVals_.resize (node->getDomainSize(), 1);
ldVals_.resize (node->getDomainSize(), 1);
}
ParamSet getBeliefs (void) const
{
double sum = 0.0;
ParamSet beliefs (node_->getDomainSize());
for (int xi = 0; xi < node_->getDomainSize(); xi++) {
double prod = piVals_[xi] * ldVals_[xi];
beliefs[xi] = prod;
sum += prod;
}
assert (sum);
//normalize the beliefs
for (int xi = 0; xi < node_->getDomainSize(); xi++) {
beliefs[xi] /= sum;
}
return beliefs;
}
double getPiValue (int idx) const
{
assert (idx >=0 && idx < node_->getDomainSize());
return piVals_[idx];
}
void setPiValue (int idx, double value)
{
assert (idx >=0 && idx < node_->getDomainSize());
piVals_[idx] = value;
}
double getLambdaValue (int idx) const
{
assert (idx >=0 && idx < node_->getDomainSize());
return ldVals_[idx];
}
void setLambdaValue (int idx, double value)
{
assert (idx >=0 && idx < node_->getDomainSize());
ldVals_[idx] = value;
}
ParamSet& getPiValues (void)
{
return piVals_;
}
ParamSet& getLambdaValues (void)
{
return ldVals_;
}
double getBeliefChange (void)
{
double change = 0.0;
if (oldBeliefs_.size() == 0) {
oldBeliefs_ = getBeliefs();
change = MAX_CHANGE_;
} else {
ParamSet currentBeliefs = getBeliefs();
for (int xi = 0; xi < node_->getDomainSize(); xi++) {
change += abs (currentBeliefs[xi] - oldBeliefs_[xi]);
} }
oldBeliefs_ = currentBeliefs; msgSended_ = false;
residual_ = 0.0;
} }
return change;
} //void setMessage (ParamSet msg)
//{
// Util::normalize (msg);
// residual_ = Util::getMaxNorm (currMsg_, msg);
// currMsg_ = msg;
//}
bool hasReceivedChildInfluence (void) const void setNextMessage (CParamSet msg)
{ {
// if all lambda values are equal, then neither nextMsg_ = msg;
// this node neither its descendents have evidence, Util::normalize (nextMsg_);
// we can use this to don't send lambda messages his parents residual_ = Util::getMaxNorm (currMsg_, nextMsg_);
bool childInfluenced = false; }
for (int xi = 1; xi < node_->getDomainSize(); xi++) {
if (ldVals_[xi] != ldVals_[0]) { void updateMessage (void)
childInfluenced = true; {
break; currMsg_ = nextMsg_;
if (DL >= 3) {
cout << "updating " << toString() << endl;
} }
msgSended_ = true;
}
void updateResidual (void)
{
residual_ = Util::getMaxNorm (currMsg_, nextMsg_);
} }
return childInfluenced;
}
BayesNode* node_; string toString (void) const
ParamSet piVals_; // pi values {
ParamSet ldVals_; // lambda values stringstream ss;
ParamSet oldBeliefs_; if (type_ == PI_MSG) {
ss << PI;
} else if (type_ == LAMBDA_MSG) {
ss << LD;
} else {
abort();
}
ss << "(" << source_->getLabel();
ss << " --> " << destin_->getLabel() << ")" ;
return ss.str();
}
BayesNode* getSource (void) const { return source_; }
BayesNode* getDestination (void) const { return destin_; }
MessageType getMessageType (void) const { return type_; }
CParamSet getMessage (void) const { return currMsg_; }
bool messageWasSended (void) const { return msgSended_; }
double getResidual (void) const { return residual_; }
void clearResidual (void) { residual_ = 0.0; }
private:
BayesNode* source_;
BayesNode* destin_;
MessageType type_;
ParamSet currMsg_;
ParamSet nextMsg_;
bool msgSended_;
double residual_;
}; };
*/
bool compareResidual (const Edge&, const Edge&);
class BPSolver : public Solver class BPSolver : public Solver
{ {
public: public:
@ -261,190 +108,85 @@ class BPSolver : public Solver
~BPSolver (void); ~BPSolver (void);
void runSolver (void); void runSolver (void);
ParamSet getPosterioriOf (const Variable* var) const; ParamSet getPosterioriOf (Vid) const;
ParamSet getJointDistribution (const NodeSet&) const; ParamSet getJointDistributionOf (const VidSet&);
private: private:
DISALLOW_COPY_AND_ASSIGN (BPSolver); DISALLOW_COPY_AND_ASSIGN (BPSolver);
void initializeSolver (void); void initializeSolver (void);
void incorporateEvidence (BayesNode*);
void runPolyTreeSolver (void); void runPolyTreeSolver (void);
void polyTreePiMessage (BayesNode*, BayesNode*); void runLoopySolver (void);
void polyTreeLambdaMessage (BayesNode*, BayesNode*);
void runGenericSolver (void);
void maxResidualSchedule (void); void maxResidualSchedule (void);
bool converged (void) const; bool converged (void) const;
void updatePiValues (BayesNode*); void updatePiValues (BayesNode*);
void updateLambdaValues (BayesNode*); void updateLambdaValues (BayesNode*);
void calculateNextPiMessage (BayesNode*, BayesNode*); ParamSet calculateNextLambdaMessage (Edge* edge);
void calculateNextLambdaMessage (BayesNode*, BayesNode*); ParamSet calculateNextPiMessage (Edge* edge);
ParamSet getJointByJunctionNode (const VidSet&) const;
ParamSet getJointByChainRule (const VidSet&) const;
void printMessageStatusOf (const BayesNode*) const; void printMessageStatusOf (const BayesNode*) const;
void printAllMessageStatus (void) const; void printAllMessageStatus (void) const;
// inlines
void updatePiMessage (BayesNode*, BayesNode*); ParamSet getMessage (Edge* edge)
void updateLambdaMessage (BayesNode*, BayesNode*); {
void calculateNextMessage (const Edge&); if (DL >= 3) {
void updateMessage (const Edge&); cout << " calculating " << edge->toString() << endl;
void updateValues (const Edge&); }
double getResidual (const Edge&) const; if (edge->getMessageType() == PI_MSG) {
void updateResidual (const Edge&); return calculateNextPiMessage (edge);
void clearResidual (const Edge&); } else if (edge->getMessageType() == LAMBDA_MSG) {
BpNode* M (const BayesNode*) const; return calculateNextLambdaMessage (edge);
friend bool compareResidual (const Edge&, const Edge&); } else {
abort();
}
return ParamSet();
}
void updateValues (Edge* edge)
{
if (!edge->getDestination()->hasEvidence()) {
if (edge->getMessageType() == PI_MSG) {
updatePiValues (edge->getDestination());
} else if (edge->getMessageType() == LAMBDA_MSG) {
updateLambdaValues (edge->getDestination());
} else {
abort();
}
}
}
BPNodeInfo* M (const BayesNode* node) const
{
assert (node);
assert (node == bn_->getBayesNode (node->getVarId()));
assert (node->getIndex() < nodesI_.size());
return nodesI_[node->getIndex()];
}
const BayesNet* bn_; const BayesNet* bn_;
vector<BpNode*> msgs_; vector<BPNodeInfo*> nodesI_;
Schedule schedule_; unsigned nIter_;
int nIter_; vector<Edge*> links_;
int maxIter_; bool useAlwaysLoopySolver_;
double accuracy_; JointCalcType jointCalcType_;
vector<Edge> updateOrder_;
bool forceGenericSolver_;
struct compare struct compare
{ {
inline bool operator() (const Edge& e1, const Edge& e2) inline bool operator() (const Edge* e1, const Edge* e2)
{ {
return compareResidual (e1, e2); return e1->getResidual() > e2->getResidual();
} }
}; };
typedef multiset<Edge, compare> SortedOrder; typedef multiset<Edge*, compare> SortedOrder;
SortedOrder sortedOrder_; SortedOrder sortedOrder_;
typedef unordered_map<string, SortedOrder::iterator> EdgeMap; typedef map<Edge*, SortedOrder::iterator> EdgeMap;
EdgeMap edgeMap_; EdgeMap edgeMap_;
}; };
#endif //BP_BP_SOLVER_H
inline void
BPSolver::updatePiMessage (BayesNode* source, BayesNode* destination)
{
M(source)->updatePiMessage(destination);
}
inline void
BPSolver::updateLambdaMessage (BayesNode* source, BayesNode* destination)
{
M(destination)->updateLambdaMessage(source);
}
inline void
BPSolver::calculateNextMessage (const Edge& e)
{
if (DL >= 1) {
cout << "calculating " << e.toString() << endl;
}
if (e.type == PI_MSG) {
calculateNextPiMessage (e.source, e.destination);
} else {
calculateNextLambdaMessage (e.source, e.destination);
}
}
inline void
BPSolver::updateMessage (const Edge& e)
{
if (DL >= 1) {
cout << "updating " << e.toString() << endl;
}
if (e.type == PI_MSG) {
M(e.source)->updatePiMessage(e.destination);
} else {
M(e.destination)->updateLambdaMessage(e.source);
}
}
inline void
BPSolver::updateValues (const Edge& e)
{
if (!e.destination->hasEvidence()) {
if (e.type == PI_MSG) {
updatePiValues (e.destination);
} else {
updateLambdaValues (e.destination);
}
}
}
inline double
BPSolver::getResidual (const Edge& e) const
{
if (e.type == PI_MSG) {
return M(e.source)->getPiResidual(e.destination);
} else {
return M(e.destination)->getLambdaResidual(e.source);
}
}
inline void
BPSolver::updateResidual (const Edge& e)
{
if (e.type == PI_MSG) {
M(e.source)->updatePiResidual(e.destination);
} else {
M(e.destination)->updateLambdaResidual(e.source);
}
}
inline void
BPSolver::clearResidual (const Edge& e)
{
if (e.type == PI_MSG) {
M(e.source)->clearPiResidual(e.destination);
} else {
M(e.destination)->clearLambdaResidual(e.source);
}
}
inline bool
compareResidual (const Edge& e1, const Edge& e2)
{
double residual1;
double residual2;
if (e1.type == PI_MSG) {
residual1 = Edge::klass->M(e1.source)->getPiResidual(e1.destination);
} else {
residual1 = Edge::klass->M(e1.destination)->getLambdaResidual(e1.source);
}
if (e2.type == PI_MSG) {
residual2 = Edge::klass->M(e2.source)->getPiResidual(e2.destination);
} else {
residual2 = Edge::klass->M(e2.destination)->getLambdaResidual(e2.source);
}
return residual1 > residual2;
}
inline BpNode*
BPSolver::M (const BayesNode* node) const
{
assert (node);
assert (node == bn_->getNode (node->getVarId()));
assert (node->getIndex() < msgs_.size());
return msgs_[node->getIndex()];
}
#endif

View File

@ -1,30 +1,24 @@
#include <cstdlib>
#include <cassert>
#include <iostream> #include <iostream>
#include <fstream> #include <fstream>
#include <sstream> #include <sstream>
#include <iomanip> #include <iomanip>
#include <cassert>
#include <cstdlib>
#include <map>
#include "xmlParser/xmlParser.h" #include "xmlParser/xmlParser.h"
#include "BayesNet.h" #include "BayesNet.h"
BayesNet::BayesNet (void)
{
}
BayesNet::BayesNet (const char* fileName) BayesNet::BayesNet (const char* fileName)
{ {
map<string, Domain> domains; map<string, Domain> domains;
XMLNode xMainNode = XMLNode::openFileHelper (fileName, "BIF"); XMLNode xMainNode = XMLNode::openFileHelper (fileName, "BIF");
// only the first network is parsed, others are ignored // only the first network is parsed, others are ignored
XMLNode xNode = xMainNode.getChildNode ("NETWORK"); XMLNode xNode = xMainNode.getChildNode ("NETWORK");
int nVars = xNode.nChildNode ("VARIABLE"); unsigned nVars = xNode.nChildNode ("VARIABLE");
for (int i = 0; i < nVars; i++) { for (unsigned i = 0; i < nVars; i++) {
XMLNode var = xNode.getChildNode ("VARIABLE", i); XMLNode var = xNode.getChildNode ("VARIABLE", i);
string type = var.getAttribute ("TYPE"); string type = var.getAttribute ("TYPE");
if (type != "nature") { if (type != "nature") {
@ -32,9 +26,9 @@ BayesNet::BayesNet (const char* fileName)
abort(); abort();
} }
Domain domain; Domain domain;
string label = var.getChildNode("NAME").getText(); string varLabel = var.getChildNode("NAME").getText();
int domainSize = var.nChildNode ("OUTCOME"); unsigned dsize = var.nChildNode ("OUTCOME");
for (int j = 0; j < domainSize; j++) { for (unsigned j = 0; j < dsize; j++) {
if (var.getChildNode("OUTCOME", j).getText() == 0) { if (var.getChildNode("OUTCOME", j).getText() == 0) {
stringstream ss; stringstream ss;
ss << j + 1; ss << j + 1;
@ -43,37 +37,37 @@ BayesNet::BayesNet (const char* fileName)
domain.push_back (var.getChildNode("OUTCOME", j).getText()); domain.push_back (var.getChildNode("OUTCOME", j).getText());
} }
} }
domains.insert (make_pair (label, domain)); domains.insert (make_pair (varLabel, domain));
} }
int nDefs = xNode.nChildNode ("DEFINITION"); unsigned nDefs = xNode.nChildNode ("DEFINITION");
if (nVars != nDefs) { if (nVars != nDefs) {
cerr << "error: different number of variables and definitions"; cerr << "error: different number of variables and definitions" << endl;
cerr << endl; abort();
} }
queue<int> indexes; queue<unsigned> indexes;
for (int i = 0; i < nDefs; i++) { for (unsigned i = 0; i < nDefs; i++) {
indexes.push (i); indexes.push (i);
} }
while (!indexes.empty()) { while (!indexes.empty()) {
int index = indexes.front(); unsigned index = indexes.front();
indexes.pop(); indexes.pop();
XMLNode def = xNode.getChildNode ("DEFINITION", index); XMLNode def = xNode.getChildNode ("DEFINITION", index);
string label = def.getChildNode("FOR").getText(); string varLabel = def.getChildNode("FOR").getText();
map<string, Domain>::const_iterator iter; map<string, Domain>::const_iterator iter;
iter = domains.find (label); iter = domains.find (varLabel);
if (iter == domains.end()) { if (iter == domains.end()) {
cerr << "error: unknow variable `" << label << "'" << endl; cerr << "error: unknow variable `" << varLabel << "'" << endl;
abort(); abort();
} }
bool processItLatter = false; bool processItLatter = false;
NodeSet parents; BnNodeSet parents;
int nParams = iter->second.size(); unsigned nParams = iter->second.size();
for (int j = 0; j < def.nChildNode ("GIVEN"); j++) { for (int j = 0; j < def.nChildNode ("GIVEN"); j++) {
string parentLabel = def.getChildNode("GIVEN", j).getText(); string parentLabel = def.getChildNode("GIVEN", j).getText();
BayesNode* parentNode = getNode (parentLabel); BayesNode* parentNode = getBayesNode (parentLabel);
if (parentNode) { if (parentNode) {
nParams *= parentNode->getDomainSize(); nParams *= parentNode->getDomainSize();
parents.push_back (parentNode); parents.push_back (parentNode);
@ -95,7 +89,7 @@ BayesNet::BayesNet (const char* fileName)
} }
if (!processItLatter) { if (!processItLatter) {
int count = 0; unsigned count = 0;
ParamSet params (nParams); ParamSet params (nParams);
stringstream s (def.getChildNode("TABLE").getText()); stringstream s (def.getChildNode("TABLE").getText());
while (!s.eof() && count < nParams) { while (!s.eof() && count < nParams) {
@ -104,11 +98,11 @@ BayesNet::BayesNet (const char* fileName)
} }
if (count != nParams) { if (count != nParams) {
cerr << "error: invalid number of parameters " ; cerr << "error: invalid number of parameters " ;
cerr << "for variable `" << label << "'" << endl; cerr << "for variable `" << varLabel << "'" << endl;
abort(); abort();
} }
params = reorderParameters (params, iter->second.size()); params = reorderParameters (params, iter->second.size());
addNode (label, iter->second, parents, params); addNode (varLabel, iter->second, parents, params);
} }
} }
setIndexes(); setIndexes();
@ -118,7 +112,6 @@ BayesNet::BayesNet (const char* fileName)
BayesNet::~BayesNet (void) BayesNet::~BayesNet (void)
{ {
Statistics::writeStats();
for (unsigned i = 0; i < nodes_.size(); i++) { for (unsigned i = 0; i < nodes_.size(); i++) {
delete nodes_[i]; delete nodes_[i];
} }
@ -127,25 +120,25 @@ BayesNet::~BayesNet (void)
BayesNode* BayesNode*
BayesNet::addNode (unsigned varId) BayesNet::addNode (Vid vid)
{ {
indexMap_.insert (make_pair (varId, nodes_.size())); indexMap_.insert (make_pair (vid, nodes_.size()));
nodes_.push_back (new BayesNode (varId)); nodes_.push_back (new BayesNode (vid));
return nodes_.back(); return nodes_.back();
} }
BayesNode* BayesNode*
BayesNet::addNode (unsigned varId, BayesNet::addNode (Vid vid,
unsigned dsize, unsigned dsize,
int evidence, int evidence,
NodeSet& parents, BnNodeSet& parents,
Distribution* dist) Distribution* dist)
{ {
indexMap_.insert (make_pair (varId, nodes_.size())); indexMap_.insert (make_pair (vid, nodes_.size()));
nodes_.push_back (new BayesNode ( nodes_.push_back (new BayesNode (
varId, dsize, evidence, parents, dist)); vid, dsize, evidence, parents, dist));
return nodes_.back(); return nodes_.back();
} }
@ -154,7 +147,7 @@ BayesNet::addNode (unsigned varId,
BayesNode* BayesNode*
BayesNet::addNode (string label, BayesNet::addNode (string label,
Domain domain, Domain domain,
NodeSet& parents, BnNodeSet& parents,
ParamSet& params) ParamSet& params)
{ {
indexMap_.insert (make_pair (nodes_.size(), nodes_.size())); indexMap_.insert (make_pair (nodes_.size(), nodes_.size()));
@ -169,9 +162,9 @@ BayesNet::addNode (string label,
BayesNode* BayesNode*
BayesNet::getNode (unsigned varId) const BayesNet::getBayesNode (Vid vid) const
{ {
IndexMap::const_iterator it = indexMap_.find(varId); IndexMap::const_iterator it = indexMap_.find (vid);
if (it == indexMap_.end()) { if (it == indexMap_.end()) {
return 0; return 0;
} else { } else {
@ -182,7 +175,7 @@ BayesNet::getNode (unsigned varId) const
BayesNode* BayesNode*
BayesNet::getNode (string label) const BayesNet::getBayesNode (string label) const
{ {
BayesNode* node = 0; BayesNode* node = 0;
for (unsigned i = 0; i < nodes_.size(); i++) { for (unsigned i = 0; i < nodes_.size(); i++) {
@ -196,6 +189,15 @@ BayesNet::getNode (string label) const
Variable*
BayesNet::getVariable (Vid vid) const
{
return getBayesNode (vid);
}
void void
BayesNet::addDistribution (Distribution* dist) BayesNet::addDistribution (Distribution* dist)
{ {
@ -219,15 +221,15 @@ BayesNet::getDistribution (unsigned distId) const
const NodeSet& const BnNodeSet&
BayesNet::getNodes (void) const BayesNet::getBayesNodes (void) const
{ {
return nodes_; return nodes_;
} }
int unsigned
BayesNet::getNumberOfNodes (void) const BayesNet::getNumberOfNodes (void) const
{ {
return nodes_.size(); return nodes_.size();
@ -235,10 +237,10 @@ BayesNet::getNumberOfNodes (void) const
NodeSet BnNodeSet
BayesNet::getRootNodes (void) const BayesNet::getRootNodes (void) const
{ {
NodeSet roots; BnNodeSet roots;
for (unsigned i = 0; i < nodes_.size(); i++) { for (unsigned i = 0; i < nodes_.size(); i++) {
if (nodes_[i]->isRoot()) { if (nodes_[i]->isRoot()) {
roots.push_back (nodes_[i]); roots.push_back (nodes_[i]);
@ -249,10 +251,10 @@ BayesNet::getRootNodes (void) const
NodeSet BnNodeSet
BayesNet::getLeafNodes (void) const BayesNet::getLeafNodes (void) const
{ {
NodeSet leafs; BnNodeSet leafs;
for (unsigned i = 0; i < nodes_.size(); i++) { for (unsigned i = 0; i < nodes_.size(); i++) {
if (nodes_[i]->isLeaf()) { if (nodes_[i]->isLeaf()) {
leafs.push_back (nodes_[i]); leafs.push_back (nodes_[i]);
@ -276,30 +278,32 @@ BayesNet::getVariables (void) const
BayesNet* BayesNet*
BayesNet::pruneNetwork (BayesNode* queryNode) const BayesNet::getMinimalRequesiteNetwork (Vid vid) const
{ {
NodeSet queryNodes; return getMinimalRequesiteNetwork (VidSet() = {vid});
queryNodes.push_back (queryNode);
return pruneNetwork (queryNodes);
} }
BayesNet* BayesNet*
BayesNet::pruneNetwork (const NodeSet& interestedVars) const BayesNet::getMinimalRequesiteNetwork (const VidSet& queryVids) const
{ {
/* BnNodeSet queryVars;
cout << "interested vars: " ; for (unsigned i = 0; i < queryVids.size(); i++) {
for (unsigned i = 0; i < interestedVars.size(); i++) { assert (getBayesNode (queryVids[i]));
cout << interestedVars[i]->getLabel() << " " ; queryVars.push_back (getBayesNode (queryVids[i]));
} }
cout << endl; // cout << "query vars: " ;
*/ // for (unsigned i = 0; i < queryVars.size(); i++) {
// cout << queryVars[i]->getLabel() << " " ;
// }
// cout << endl;
vector<StateInfo*> states (nodes_.size(), 0); vector<StateInfo*> states (nodes_.size(), 0);
Scheduling scheduling; Scheduling scheduling;
for (NodeSet::const_iterator it = interestedVars.begin(); for (BnNodeSet::const_iterator it = queryVars.begin();
it != interestedVars.end(); it++) { it != queryVars.end(); it++) {
scheduling.push (ScheduleInfo (*it, false, true)); scheduling.push (ScheduleInfo (*it, false, true));
} }
@ -378,18 +382,18 @@ BayesNet::constructGraph (BayesNet* bn,
states[i]->markedOnTop; states[i]->markedOnTop;
} }
if (isRequired) { if (isRequired) {
NodeSet parents; BnNodeSet parents;
if (states[i]->markedOnTop) { if (states[i]->markedOnTop) {
const NodeSet& ps = nodes_[i]->getParents(); const BnNodeSet& ps = nodes_[i]->getParents();
for (unsigned j = 0; j < ps.size(); j++) { for (unsigned j = 0; j < ps.size(); j++) {
BayesNode* parent = bn->getNode (ps[j]->getVarId()); BayesNode* parent = bn->getBayesNode (ps[j]->getVarId());
if (!parent) { if (!parent) {
parent = bn->addNode (ps[j]->getVarId()); parent = bn->addNode (ps[j]->getVarId());
} }
parents.push_back (parent); parents.push_back (parent);
} }
} }
BayesNode* node = bn->getNode (nodes_[i]->getVarId()); BayesNode* node = bn->getBayesNode (nodes_[i]->getVarId());
if (node) { if (node) {
node->setData (nodes_[i]->getDomainSize(), node->setData (nodes_[i]->getDomainSize(),
nodes_[i]->getEvidence(), parents, nodes_[i]->getEvidence(), parents,
@ -411,65 +415,6 @@ BayesNet::constructGraph (BayesNet* bn,
bn->setIndexes(); bn->setIndexes();
} }
/*
void
BayesNet::constructGraph (BayesNet* bn,
const vector<StateInfo*>& states) const
{
for (unsigned i = 0; i < nodes_.size(); i++) {
if (states[i]) {
if (nodes_[i]->hasEvidence() && states[i]->visited) {
NodeSet parents;
if (states[i]->markedOnTop) {
const NodeSet& ps = nodes_[i]->getParents();
for (unsigned j = 0; j < ps.size(); j++) {
BayesNode* parent = bn->getNode (ps[j]->getVarId());
if (parent == 0) {
parent = bn->addNode (ps[j]->getVarId());
}
parents.push_back (parent);
}
}
BayesNode* n = bn->getNode (nodes_[i]->getVarId());
if (n) {
n->setData (nodes_[i]->getDomainSize(),
nodes_[i]->getEvidence(), parents,
nodes_[i]->getDistribution());
} else {
bn->addNode (nodes_[i]->getVarId(),
nodes_[i]->getDomainSize(),
nodes_[i]->getEvidence(), parents,
nodes_[i]->getDistribution());
}
} else if (states[i]->markedOnTop) {
NodeSet parents;
const NodeSet& ps = nodes_[i]->getParents();
for (unsigned j = 0; j < ps.size(); j++) {
BayesNode* parent = bn->getNode (ps[j]->getVarId());
if (parent == 0) {
parent = bn->addNode (ps[j]->getVarId());
}
parents.push_back (parent);
}
BayesNode* n = bn->getNode (nodes_[i]->getVarId());
if (n) {
n->setData (nodes_[i]->getDomainSize(),
nodes_[i]->getEvidence(), parents,
nodes_[i]->getDistribution());
} else {
bn->addNode (nodes_[i]->getVarId(),
nodes_[i]->getDomainSize(),
nodes_[i]->getEvidence(), parents,
nodes_[i]->getDistribution());
}
}
}
}
}*/
bool bool
@ -480,70 +425,6 @@ BayesNet::isSingleConnected (void) const
vector<DomainConf>
BayesNet::getDomainConfigurationsOf (const NodeSet& nodes)
{
int nConfs = 1;
for (unsigned i = 0; i < nodes.size(); i++) {
nConfs *= nodes[i]->getDomainSize();
}
vector<DomainConf> confs (nConfs);
for (int i = 0; i < nConfs; i++) {
confs[i].resize (nodes.size());
}
int nReps = 1;
for (int i = nodes.size() - 1; i >= 0; i--) {
int index = 0;
while (index < nConfs) {
for (int j = 0; j < nodes[i]->getDomainSize(); j++) {
for (int r = 0; r < nReps; r++) {
confs[index][i] = j;
index++;
}
}
}
nReps *= nodes[i]->getDomainSize();
}
return confs;
}
vector<string>
BayesNet::getInstantiations (const NodeSet& parents_)
{
int nParents = parents_.size();
int rowSize = 1;
for (unsigned i = 0; i < parents_.size(); i++) {
rowSize *= parents_[i]->getDomainSize();
}
int nReps = 1;
vector<string> headers (rowSize);
for (int i = nParents - 1; i >= 0; i--) {
Domain domain = parents_[i]->getDomain();
int index = 0;
while (index < rowSize) {
for (int j = 0; j < parents_[i]->getDomainSize(); j++) {
for (int r = 0; r < nReps; r++) {
if (headers[index] != "") {
headers[index] = domain[j] + "," + headers[index];
} else {
headers[index] = domain[j];
}
index++;
}
}
}
nReps *= parents_[i]->getDomainSize();
}
return headers;
}
void void
BayesNet::setIndexes (void) BayesNet::setIndexes (void)
{ {
@ -565,7 +446,7 @@ BayesNet::freeDistributions (void)
void void
BayesNet::printNetwork (void) const BayesNet::printGraphicalModel (void) const
{ {
for (unsigned i = 0; i < nodes_.size(); i++) { for (unsigned i = 0; i < nodes_.size(); i++) {
cout << *nodes_[i]; cout << *nodes_[i];
@ -575,32 +456,11 @@ BayesNet::printNetwork (void) const
void void
BayesNet::printNetworkToFile (const char* fileName) const BayesNet::exportToDotFormat (const char* fileName,
bool showNeighborless,
CVidSet& highlightVids) const
{ {
string s = "../../" ; ofstream out (fileName);
s += fileName;
ofstream out (s.c_str());
if (!out.is_open()) {
cerr << "error: cannot open file to write at " ;
cerr << "BayesNet::printToFile()" << endl;
abort();
}
for (unsigned i = 0; i < nodes_.size(); i++) {
out << *nodes_[i];
}
out.close();
}
void
BayesNet::exportToDotFile (const char* fileName,
bool showNeighborless,
const NodeSet& highlightNodes) const
{
string s = "../../" ;
s+= fileName;
ofstream out (s.c_str());
if (!out.is_open()) { if (!out.is_open()) {
cerr << "error: cannot open file to write at " ; cerr << "error: cannot open file to write at " ;
cerr << "BayesNet::exportToDotFile()" << endl; cerr << "BayesNet::exportToDotFile()" << endl;
@ -608,13 +468,6 @@ BayesNet::exportToDotFile (const char* fileName,
} }
out << "digraph \"" << fileName << "\" {" << endl; out << "digraph \"" << fileName << "\" {" << endl;
for (unsigned i = 0; i < nodes_.size(); i++) {
const NodeSet& childs = nodes_[i]->getChilds();
for (unsigned j = 0; j < childs.size(); j++) {
out << '"' << nodes_[i]->getLabel() << '"' << " -> " ;
out << '"' << childs[j]->getLabel() << '"' << endl;
}
}
for (unsigned i = 0; i < nodes_.size(); i++) { for (unsigned i = 0; i < nodes_.size(); i++) {
if (showNeighborless || nodes_[i]->hasNeighbors()) { if (showNeighborless || nodes_[i]->hasNeighbors()) {
@ -627,9 +480,24 @@ BayesNet::exportToDotFile (const char* fileName,
} }
} }
for (unsigned i = 0; i < highlightNodes.size(); i++) { for (unsigned i = 0; i < highlightVids.size(); i++) {
out << '"' << highlightNodes[i]->getLabel() << '"' ; BayesNode* node = getBayesNode (highlightVids[i]);
out << " [shape=box]" << endl; if (node) {
out << '"' << node->getLabel() << '"' ;
// out << " [shape=polygon, sides=6]" << endl;
out << " [shape=box3d]" << endl;
} else {
cout << "error: invalid variable id: " << highlightVids[i] << endl;
abort();
}
}
for (unsigned i = 0; i < nodes_.size(); i++) {
const BnNodeSet& childs = nodes_[i]->getChilds();
for (unsigned j = 0; j < childs.size(); j++) {
out << '"' << nodes_[i]->getLabel() << '"' << " -> " ;
out << '"' << childs[j]->getLabel() << '"' << endl;
}
} }
out << "}" << endl; out << "}" << endl;
@ -639,11 +507,9 @@ BayesNet::exportToDotFile (const char* fileName,
void void
BayesNet::exportToBifFile (const char* fileName) const BayesNet::exportToBifFormat (const char* fileName) const
{ {
string s = "../../" ; ofstream out (fileName);
s += fileName;
ofstream out (s.c_str());
if(!out.is_open()) { if(!out.is_open()) {
cerr << "error: cannot open file to write at " ; cerr << "error: cannot open file to write at " ;
cerr << "BayesNet::exportToBifFile()" << endl; cerr << "BayesNet::exportToBifFile()" << endl;
@ -666,7 +532,7 @@ BayesNet::exportToBifFile (const char* fileName) const
for (unsigned i = 0; i < nodes_.size(); i++) { for (unsigned i = 0; i < nodes_.size(); i++) {
out << "<DEFINITION>" << endl; out << "<DEFINITION>" << endl;
out << "\t<FOR>" << nodes_[i]->getLabel() << "</FOR>" << endl; out << "\t<FOR>" << nodes_[i]->getLabel() << "</FOR>" << endl;
const NodeSet& parents = nodes_[i]->getParents(); const BnNodeSet& parents = nodes_[i]->getParents();
for (unsigned j = 0; j < parents.size(); j++) { for (unsigned j = 0; j < parents.size(); j++) {
out << "\t<GIVEN>" << parents[j]->getLabel(); out << "\t<GIVEN>" << parents[j]->getLabel();
out << "</GIVEN>" << endl; out << "</GIVEN>" << endl;
@ -682,7 +548,7 @@ BayesNet::exportToBifFile (const char* fileName) const
} }
out << "</NETWORK>" << endl; out << "</NETWORK>" << endl;
out << "</BIF>" << endl << endl; out << "</BIF>" << endl << endl;
out.close(); out.close();
} }
@ -731,8 +597,8 @@ vector<int>
BayesNet::getAdjacentNodes (int v) const BayesNet::getAdjacentNodes (int v) const
{ {
vector<int> adjacencies; vector<int> adjacencies;
const NodeSet& parents = nodes_[v]->getParents(); const BnNodeSet& parents = nodes_[v]->getParents();
const NodeSet& childs = nodes_[v]->getChilds(); const BnNodeSet& childs = nodes_[v]->getChilds();
for (unsigned i = 0; i < parents.size(); i++) { for (unsigned i = 0; i < parents.size(); i++) {
adjacencies.push_back (parents[i]->getIndex()); adjacencies.push_back (parents[i]->getIndex());
} }
@ -745,8 +611,8 @@ BayesNet::getAdjacentNodes (int v) const
ParamSet ParamSet
BayesNet::reorderParameters (const ParamSet& params, BayesNet::reorderParameters (CParamSet params,
int domainSize) const unsigned domainSize) const
{ {
// the interchange format for bayesian networks keeps the probabilities // the interchange format for bayesian networks keeps the probabilities
// in the following order: // in the following order:
@ -773,15 +639,15 @@ BayesNet::reorderParameters (const ParamSet& params,
ParamSet ParamSet
BayesNet::revertParameterReorder (const ParamSet& params, BayesNet::revertParameterReorder (CParamSet params,
int domainSize) const unsigned domainSize) const
{ {
unsigned count = 0; unsigned count = 0;
unsigned rowSize = params.size() / domainSize; unsigned rowSize = params.size() / domainSize;
ParamSet reordered; ParamSet reordered;
while (reordered.size() < params.size()) { while (reordered.size() < params.size()) {
unsigned idx = count; unsigned idx = count;
for (int i = 0; i < domainSize; i++) { for (unsigned i = 0; i < domainSize; i++) {
reordered.push_back (params[idx]); reordered.push_back (params[idx]);
idx += rowSize; idx += rowSize;
} }

View File

@ -4,8 +4,6 @@
#include <vector> #include <vector>
#include <queue> #include <queue>
#include <list> #include <list>
#include <string>
#include <unordered_map>
#include <map> #include <map>
#include "GraphicalModel.h" #include "GraphicalModel.h"
@ -46,42 +44,42 @@ struct StateInfo
typedef vector<Distribution*> DistSet; typedef vector<Distribution*> DistSet;
typedef queue<ScheduleInfo, list<ScheduleInfo> > Scheduling; typedef queue<ScheduleInfo, list<ScheduleInfo> > Scheduling;
typedef unordered_map<unsigned, unsigned> Histogram; typedef map<unsigned, unsigned> Histogram;
typedef unordered_map<unsigned, double> Times; typedef map<unsigned, double> Times;
class BayesNet : public GraphicalModel class BayesNet : public GraphicalModel
{ {
public: public:
BayesNet (void); BayesNet (void) {};
BayesNet (const char*); BayesNet (const char*);
~BayesNet (void); ~BayesNet (void);
BayesNode* addNode (unsigned); BayesNode* addNode (unsigned);
BayesNode* addNode (unsigned, unsigned, int, NodeSet&, Distribution*); BayesNode* addNode (unsigned, unsigned, int, BnNodeSet&,
BayesNode* addNode (string, Domain, NodeSet&, ParamSet&); Distribution*);
BayesNode* getNode (unsigned) const; BayesNode* addNode (string, Domain, BnNodeSet&, ParamSet&);
BayesNode* getNode (string) const; BayesNode* getBayesNode (Vid) const;
BayesNode* getBayesNode (string) const;
Variable* getVariable (Vid) const;
void addDistribution (Distribution*); void addDistribution (Distribution*);
Distribution* getDistribution (unsigned) const; Distribution* getDistribution (unsigned) const;
const NodeSet& getNodes (void) const; const BnNodeSet& getBayesNodes (void) const;
int getNumberOfNodes (void) const; unsigned getNumberOfNodes (void) const;
NodeSet getRootNodes (void) const; BnNodeSet getRootNodes (void) const;
NodeSet getLeafNodes (void) const; BnNodeSet getLeafNodes (void) const;
VarSet getVariables (void) const; VarSet getVariables (void) const;
BayesNet* pruneNetwork (BayesNode*) const; BayesNet* getMinimalRequesiteNetwork (Vid) const;
BayesNet* pruneNetwork (const NodeSet& queryNodes) const; BayesNet* getMinimalRequesiteNetwork (const VidSet&) const;
void constructGraph (BayesNet*, const vector<StateInfo*>&) const; void constructGraph (BayesNet*,
const vector<StateInfo*>&) const;
bool isSingleConnected (void) const; bool isSingleConnected (void) const;
static vector<DomainConf> getDomainConfigurationsOf (const NodeSet&);
static vector<string> getInstantiations (const NodeSet& nodes);
void setIndexes (void); void setIndexes (void);
void freeDistributions (void); void freeDistributions (void);
void printNetwork (void) const; void printGraphicalModel (void) const;
void printNetworkToFile (const char*) const; void exportToDotFormat (const char*, bool = true,
void exportToDotFile (const char*, bool = true, CVidSet = VidSet()) const;
const NodeSet& = NodeSet()) const; void exportToBifFormat (const char*) const;
void exportToBifFile (const char*) const;
static Histogram histogram_; static Histogram histogram_;
static Times times_; static Times times_;
@ -93,12 +91,12 @@ class BayesNet : public GraphicalModel
bool containsUndirectedCycle (int, int, bool containsUndirectedCycle (int, int,
vector<bool>&)const; vector<bool>&)const;
vector<int> getAdjacentNodes (int) const ; vector<int> getAdjacentNodes (int) const ;
ParamSet reorderParameters (const ParamSet&, int) const; ParamSet reorderParameters (CParamSet, unsigned) const;
ParamSet revertParameterReorder (const ParamSet&, int) const; ParamSet revertParameterReorder (CParamSet, unsigned) const;
void scheduleParents (const BayesNode*, Scheduling&) const; void scheduleParents (const BayesNode*, Scheduling&) const;
void scheduleChilds (const BayesNode*, Scheduling&) const; void scheduleChilds (const BayesNode*, Scheduling&) const;
NodeSet nodes_; BnNodeSet nodes_;
DistSet dists_; DistSet dists_;
IndexMap indexMap_; IndexMap indexMap_;
}; };
@ -108,8 +106,8 @@ class BayesNet : public GraphicalModel
inline void inline void
BayesNet::scheduleParents (const BayesNode* n, Scheduling& sch) const BayesNet::scheduleParents (const BayesNode* n, Scheduling& sch) const
{ {
const NodeSet& ps = n->getParents(); const BnNodeSet& ps = n->getParents();
for (NodeSet::const_iterator it = ps.begin(); it != ps.end(); it++) { for (BnNodeSet::const_iterator it = ps.begin(); it != ps.end(); it++) {
sch.push (ScheduleInfo (*it, false, true)); sch.push (ScheduleInfo (*it, false, true));
} }
} }
@ -119,11 +117,11 @@ BayesNet::scheduleParents (const BayesNode* n, Scheduling& sch) const
inline void inline void
BayesNet::scheduleChilds (const BayesNode* n, Scheduling& sch) const BayesNet::scheduleChilds (const BayesNode* n, Scheduling& sch) const
{ {
const NodeSet& cs = n->getChilds(); const BnNodeSet& cs = n->getChilds();
for (NodeSet::const_iterator it = cs.begin(); it != cs.end(); it++) { for (BnNodeSet::const_iterator it = cs.begin(); it != cs.end(); it++) {
sch.push (ScheduleInfo (*it, true, false)); sch.push (ScheduleInfo (*it, true, false));
} }
} }
#endif #endif //BP_BAYES_NET_H

View File

@ -1,26 +1,21 @@
#include <cstdlib>
#include <cassert>
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
#include <iomanip> #include <iomanip>
#include <cassert>
#include <cstdlib>
#include "BayesNode.h" #include "BayesNode.h"
BayesNode::BayesNode (unsigned varId) : Variable (varId) BayesNode::BayesNode (Vid vid,
{
}
BayesNode::BayesNode (unsigned varId,
unsigned dsize, unsigned dsize,
int evidence, int evidence,
const NodeSet& parents, const BnNodeSet& parents,
Distribution* dist) : Variable(varId, dsize, evidence) Distribution* dist) : Variable (vid, dsize, evidence)
{ {
parents_ = parents; parents_ = parents;
dist_ = dist; dist_ = dist;
for (unsigned int i = 0; i < parents.size(); i++) { for (unsigned int i = 0; i < parents.size(); i++) {
parents[i]->addChild (this); parents[i]->addChild (this);
} }
@ -28,15 +23,15 @@ BayesNode::BayesNode (unsigned varId,
BayesNode::BayesNode (unsigned varId, BayesNode::BayesNode (Vid vid,
string label, string label,
const Domain& domain, const Domain& domain,
const NodeSet& parents, const BnNodeSet& parents,
Distribution* dist) : Variable(varId, domain) Distribution* dist) : Variable (vid, domain,
NO_EVIDENCE, label)
{ {
label_ = new string (label); parents_ = parents;
parents_ = parents; dist_ = dist;
dist_ = dist;
for (unsigned int i = 0; i < parents.size(); i++) { for (unsigned int i = 0; i < parents.size(); i++) {
parents[i]->addChild (this); parents[i]->addChild (this);
} }
@ -47,11 +42,11 @@ BayesNode::BayesNode (unsigned varId,
void void
BayesNode::setData (unsigned dsize, BayesNode::setData (unsigned dsize,
int evidence, int evidence,
const NodeSet& parents, const BnNodeSet& parents,
Distribution* dist) Distribution* dist)
{ {
setDomainSize (dsize); setDomainSize (dsize);
evidence_ = evidence; setEvidence (evidence);
parents_ = parents; parents_ = parents;
dist_ = dist; dist_ = dist;
for (unsigned int i = 0; i < parents.size(); i++) { for (unsigned int i = 0; i < parents.size(); i++) {
@ -135,19 +130,18 @@ BayesNode::getCptEntries (void)
{ {
if (dist_->entries.size() == 0) { if (dist_->entries.size() == 0) {
unsigned rowSize = getRowSize(); unsigned rowSize = getRowSize();
unsigned nParents = parents_.size(); vector<DConf> confs (rowSize);
vector<DomainConf> confs (rowSize);
for (unsigned i = 0; i < rowSize; i++) { for (unsigned i = 0; i < rowSize; i++) {
confs[i].resize (nParents); confs[i].resize (parents_.size());
} }
int nReps = 1; unsigned nReps = 1;
for (int i = nParents - 1; i >= 0; i--) { for (int i = parents_.size() - 1; i >= 0; i--) {
unsigned index = 0; unsigned index = 0;
while (index < rowSize) { while (index < rowSize) {
for (int j = 0; j < parents_[i]->getDomainSize(); j++) { for (unsigned j = 0; j < parents_[i]->getDomainSize(); j++) {
for (int r = 0; r < nReps; r++) { for (unsigned r = 0; r < nReps; r++) {
confs[index][i] = j; confs[index][i] = j;
index++; index++;
} }
@ -184,7 +178,7 @@ BayesNode::cptEntryToString (const CptEntry& entry) const
{ {
stringstream ss; stringstream ss;
ss << "p(" ; ss << "p(" ;
const DomainConf& conf = entry.getParentConfigurations(); const DConf& conf = entry.getDomainConfiguration();
int row = entry.getParameterIndex() / getRowSize(); int row = entry.getParameterIndex() / getRowSize();
ss << getDomain()[row]; ss << getDomain()[row];
if (parents_.size() > 0) { if (parents_.size() > 0) {
@ -207,7 +201,7 @@ BayesNode::cptEntryToString (int row, const CptEntry& entry) const
{ {
stringstream ss; stringstream ss;
ss << "p(" ; ss << "p(" ;
const DomainConf& conf = entry.getParentConfigurations(); const DConf& conf = entry.getDomainConfiguration();
ss << getDomain()[row]; ss << getDomain()[row];
if (parents_.size() > 0) { if (parents_.size() > 0) {
ss << "|" ; ss << "|" ;
@ -227,16 +221,16 @@ BayesNode::cptEntryToString (int row, const CptEntry& entry) const
vector<string> vector<string>
BayesNode::getDomainHeaders (void) const BayesNode::getDomainHeaders (void) const
{ {
int nParents = parents_.size(); unsigned nParents = parents_.size();
int rowSize = getRowSize(); unsigned rowSize = getRowSize();
int nReps = 1; unsigned nReps = 1;
vector<string> headers (rowSize); vector<string> headers (rowSize);
for (int i = nParents - 1; i >= 0; i--) { for (int i = nParents - 1; i >= 0; i--) {
Domain domain = parents_[i]->getDomain(); Domain domain = parents_[i]->getDomain();
int index = 0; unsigned index = 0;
while (index < rowSize) { while (index < rowSize) {
for (int j = 0; j < parents_[i]->getDomainSize(); j++) { for (unsigned j = 0; j < parents_[i]->getDomainSize(); j++) {
for (int r = 0; r < nReps; r++) { for (unsigned r = 0; r < nReps; r++) {
if (headers[index] != "") { if (headers[index] != "") {
headers[index] = domain[j] + "," + headers[index]; headers[index] = domain[j] + "," + headers[index];
} else { } else {
@ -270,7 +264,7 @@ operator << (ostream& o, const BayesNode& node)
o << endl; o << endl;
o << "Parents: " ; o << "Parents: " ;
const NodeSet& parents = node.getParents(); const BnNodeSet& parents = node.getParents();
if (parents.size() != 0) { if (parents.size() != 0) {
for (unsigned int i = 0; i < parents.size() - 1; i++) { for (unsigned int i = 0; i < parents.size() - 1; i++) {
o << parents[i]->getLabel() << ", " ; o << parents[i]->getLabel() << ", " ;
@ -280,7 +274,7 @@ operator << (ostream& o, const BayesNode& node)
o << endl; o << endl;
o << "Childs: " ; o << "Childs: " ;
const NodeSet& childs = node.getChilds(); const BnNodeSet& childs = node.getChilds();
if (childs.size() != 0) { if (childs.size() != 0) {
for (unsigned int i = 0; i < childs.size() - 1; i++) { for (unsigned int i = 0; i < childs.size() - 1; i++) {
o << childs[i]->getLabel() << ", " ; o << childs[i]->getLabel() << ", " ;

View File

@ -1,9 +1,7 @@
#ifndef BP_BAYESNODE_H #ifndef BP_BAYES_NODE_H
#define BP_BAYESNODE_H #define BP_BAYES_NODE_H
#include <vector> #include <vector>
#include <string>
#include <sstream>
#include "Variable.h" #include "Variable.h"
#include "CptEntry.h" #include "CptEntry.h"
@ -16,11 +14,12 @@ using namespace std;
class BayesNode : public Variable class BayesNode : public Variable
{ {
public: public:
BayesNode (unsigned); BayesNode (Vid vid) : Variable (vid) {}
BayesNode (unsigned, unsigned, int, const NodeSet&, Distribution*); BayesNode (Vid, unsigned, int, const BnNodeSet&, Distribution*);
BayesNode (unsigned, string, const Domain&, const NodeSet&, Distribution*); BayesNode (Vid, string, const Domain&, const BnNodeSet&, Distribution*);
void setData (unsigned, int, const NodeSet&, Distribution*); void setData (unsigned, int, const BnNodeSet&,
Distribution*);
void addChild (BayesNode*); void addChild (BayesNode*);
Distribution* getDistribution (void); Distribution* getDistribution (void);
const ParamSet& getParameters (void); const ParamSet& getParameters (void);
@ -34,11 +33,21 @@ class BayesNode : public Variable
int getIndexOfParent (const BayesNode*) const; int getIndexOfParent (const BayesNode*) const;
string cptEntryToString (const CptEntry&) const; string cptEntryToString (const CptEntry&) const;
string cptEntryToString (int, const CptEntry&) const; string cptEntryToString (int, const CptEntry&) const;
// inlines
const NodeSet& getParents (void) const; const BnNodeSet& getParents (void) const { return parents_; }
const NodeSet& getChilds (void) const; const BnNodeSet& getChilds (void) const { return childs_; }
double getProbability (int, const CptEntry& entry);
unsigned getRowSize (void) const; unsigned getRowSize (void) const
{
return dist_->params.size() / getDomainSize();
}
double getProbability (int row, const CptEntry& entry)
{
int col = entry.getParameterIndex();
int idx = (row * getRowSize()) + col;
return dist_->params[idx];
}
private: private:
DISALLOW_COPY_AND_ASSIGN (BayesNode); DISALLOW_COPY_AND_ASSIGN (BayesNode);
@ -46,46 +55,12 @@ class BayesNode : public Variable
Domain getDomainHeaders (void) const; Domain getDomainHeaders (void) const;
friend ostream& operator << (ostream&, const BayesNode&); friend ostream& operator << (ostream&, const BayesNode&);
NodeSet parents_; BnNodeSet parents_;
NodeSet childs_; BnNodeSet childs_;
Distribution* dist_; Distribution* dist_;
}; };
ostream& operator << (ostream&, const BayesNode&); ostream& operator << (ostream&, const BayesNode&);
#endif //BP_BAYES_NODE_H
inline const NodeSet&
BayesNode::getParents (void) const
{
return parents_;
}
inline const NodeSet&
BayesNode::getChilds (void) const
{
return childs_;
}
inline double
BayesNode::getProbability (int row, const CptEntry& entry)
{
int col = entry.getParameterIndex();
int idx = (row * getRowSize()) + col;
return dist_->params[idx];
}
inline unsigned
BayesNode::getRowSize (void) const
{
return dist_->params.size() / getDomainSize();
}
#endif

View File

@ -0,0 +1,198 @@
#include "CountingBP.h"
CountingBP::~CountingBP (void)
{
delete lfg_;
delete fg_;
for (unsigned i = 0; i < links_.size(); i++) {
delete links_[i];
}
links_.clear();
}
ParamSet
CountingBP::getPosterioriOf (Vid vid) const
{
FgVarNode* var = lfg_->getEquivalentVariable (vid);
ParamSet probs;
if (var->hasEvidence()) {
probs.resize (var->getDomainSize(), 0.0);
probs[var->getEvidence()] = 1.0;
} else {
probs.resize (var->getDomainSize(), 1.0);
CLinkSet links = varsI_[var->getIndex()]->getLinks();
for (unsigned i = 0; i < links.size(); i++) {
ParamSet msg = links[i]->getMessage();
CountingBPLink* l = static_cast<CountingBPLink*> (links[i]);
Util::pow (msg, l->getNumberOfEdges());
for (unsigned j = 0; j < msg.size(); j++) {
probs[j] *= msg[j];
}
}
Util::normalize (probs);
}
return probs;
}
void
CountingBP::initializeSolver (void)
{
lfg_ = new LiftedFG (*fg_);
unsigned nUncVars = fg_->getFgVarNodes().size();
unsigned nUncFactors = fg_->getFactors().size();
CFgVarSet vars = fg_->getFgVarNodes();
unsigned nNeighborLessVars = 0;
for (unsigned i = 0; i < vars.size(); i++) {
CFactorSet factors = vars[i]->getFactors();
if (factors.size() == 1 && factors[0]->getFgVarNodes().size() == 1) {
nNeighborLessVars ++;
}
}
// cout << "UNCOMPRESSED FACTOR GRAPH" << endl;
// fg_->printGraphicalModel();
fg_->exportToDotFormat ("uncompress.dot");
FactorGraph *temp;
temp = fg_;
fg_ = lfg_->getCompressedFactorGraph();
unsigned nCompVars = fg_->getFgVarNodes().size();
unsigned nCompFactors = fg_->getFactors().size();
Statistics::updateCompressingStats (nUncVars,
nUncFactors,
nCompVars,
nCompFactors,
nNeighborLessVars);
cout << "COMPRESSED FACTOR GRAPH" << endl;
fg_->printGraphicalModel();
//fg_->exportToDotFormat ("compress.dot");
SPSolver::initializeSolver();
}
void
CountingBP::createLinks (void)
{
const FactorClusterSet fcs = lfg_->getFactorClusters();
for (unsigned i = 0; i < fcs.size(); i++) {
const VarClusterSet vcs = fcs[i]->getVarClusters();
for (unsigned j = 0; j < vcs.size(); j++) {
unsigned c = lfg_->getGroundEdgeCount (fcs[i], vcs[j]);
links_.push_back (
new CountingBPLink (fcs[i]->getRepresentativeFactor(),
vcs[j]->getRepresentativeVariable(), c));
//cout << (links_.back())->toString() << " edge count =" << c << endl;
}
}
return;
}
void
CountingBP::deleteJunction (Factor* f, FgVarNode*)
{
f->freeDistribution();
}
void
CountingBP::maxResidualSchedule (void)
{
if (nIter_ == 1) {
for (unsigned i = 0; i < links_.size(); i++) {
links_[i]->setNextMessage (getFactor2VarMsg (links_[i]));
SortedOrder::iterator it = sortedOrder_.insert (links_[i]);
linkMap_.insert (make_pair (links_[i], it));
if (DL >= 2 && DL < 5) {
cout << "calculating " << links_[i]->toString() << endl;
}
}
return;
}
for (unsigned c = 0; c < links_.size(); c++) {
if (DL >= 2) {
cout << endl << "current residuals:" << endl;
for (SortedOrder::iterator it = sortedOrder_.begin();
it != sortedOrder_.end(); it ++) {
cout << " " << setw (30) << left << (*it)->toString();
cout << "residual = " << (*it)->getResidual() << endl;
}
}
SortedOrder::iterator it = sortedOrder_.begin();
Link* link = *it;
if (DL >= 2) {
cout << "updating " << (*sortedOrder_.begin())->toString() << endl;
}
if (link->getResidual() < SolverOptions::accuracy) {
return;
}
link->updateMessage();
link->clearResidual();
sortedOrder_.erase (it);
linkMap_.find (link)->second = sortedOrder_.insert (link);
// update the messages that depend on message source --> destin
CFactorSet factorNeighbors = link->getVariable()->getFactors();
for (unsigned i = 0; i < factorNeighbors.size(); i++) {
CLinkSet links = factorsI_[factorNeighbors[i]->getIndex()]->getLinks();
for (unsigned j = 0; j < links.size(); j++) {
if (links[j]->getVariable() != link->getVariable()) { //FIXMEFIXME
if (DL >= 2 && DL < 5) {
cout << " calculating " << links[j]->toString() << endl;
}
links[j]->setNextMessage (getFactor2VarMsg (links[j]));
LinkMap::iterator iter = linkMap_.find (links[j]);
sortedOrder_.erase (iter->second);
iter->second = sortedOrder_.insert (links[j]);
}
}
}
}
}
ParamSet
CountingBP::getVar2FactorMsg (const Link* link) const
{
const FgVarNode* src = link->getVariable();
const Factor* dest = link->getFactor();
ParamSet msg;
if (src->hasEvidence()) {
cout << "has evidence" << endl;
msg.resize (src->getDomainSize(), 0.0);
msg[src->getEvidence()] = link->getMessage()[src->getEvidence()];
cout << "-> " << link->getVariable()->getLabel() << " " << link->getFactor()->getLabel() << endl;
cout << "-> p2s " << Util::parametersToString (msg) << endl;
} else {
msg = link->getMessage();
}
const CountingBPLink* l = static_cast<const CountingBPLink*> (link);
Util::pow (msg, l->getNumberOfEdges() - 1);
CLinkSet links = varsI_[src->getIndex()]->getLinks();
for (unsigned i = 0; i < links.size(); i++) {
if (links[i]->getFactor() != dest) {
ParamSet msgFromFactor = links[i]->getMessage();
CountingBPLink* l = static_cast<CountingBPLink*> (links[i]);
Util::pow (msgFromFactor, l->getNumberOfEdges());
for (unsigned j = 0; j < msgFromFactor.size(); j++) {
msg[j] *= msgFromFactor[j];
}
}
}
return msg;
}

View File

@ -0,0 +1,45 @@
#ifndef BP_COUNTING_BP_H
#define BP_COUNTING_BP_H
#include "SPSolver.h"
#include "LiftedFG.h"
class Factor;
class FgVarNode;
class CountingBPLink : public Link
{
public:
CountingBPLink (Factor* f, FgVarNode* v, unsigned c) : Link (f, v)
{
edgeCount_ = c;
}
unsigned getNumberOfEdges (void) const { return edgeCount_; }
private:
unsigned edgeCount_;
};
class CountingBP : public SPSolver
{
public:
CountingBP (FactorGraph& fg) : SPSolver (fg) { }
~CountingBP (void);
ParamSet getPosterioriOf (Vid) const;
private:
void initializeSolver (void);
void createLinks (void);
void deleteJunction (Factor*, FgVarNode*);
void maxResidualSchedule (void);
ParamSet getVar2FactorMsg (const Link*) const;
LiftedFG* lfg_;
};
#endif // BP_COUNTING_BP_H

View File

@ -1,5 +1,5 @@
#ifndef BP_CPTENTRY_H #ifndef BP_CPT_ENTRY_H
#define BP_CPTENTRY_H #define BP_CPT_ENTRY_H
#include <vector> #include <vector>
@ -10,62 +10,34 @@ using namespace std;
class CptEntry class CptEntry
{ {
public: public:
CptEntry (unsigned, const vector<unsigned>&); CptEntry (unsigned index, const DConf& conf)
{
index_ = index;
conf_ = conf;
}
unsigned getParameterIndex (void) const; unsigned getParameterIndex (void) const { return index_; }
const vector<unsigned>& getParentConfigurations (void) const; const DConf& getDomainConfiguration (void) const { return conf_; }
bool matchConstraints (const DomainConstr&) const;
bool matchConstraints (const vector<DomainConstr>&) const; bool matchConstraints (const DConstraint& constr) const
{
return conf_[constr.first] == constr.second;
}
bool matchConstraints (const vector<DConstraint>& constrs) const
{
for (unsigned j = 0; j < constrs.size(); j++) {
if (conf_[constrs[j].first] != constrs[j].second) {
return false;
}
}
return true;
}
private: private:
unsigned index_; unsigned index_;
vector<unsigned> confs_; DConf conf_;
}; };
#endif //BP_CPT_ENTRY_H
inline
CptEntry::CptEntry (unsigned index, const vector<unsigned>& confs)
{
index_ = index;
confs_ = confs;
}
inline unsigned
CptEntry::getParameterIndex (void) const
{
return index_;
}
inline const vector<unsigned>&
CptEntry::getParentConfigurations (void) const
{
return confs_;
}
inline bool
CptEntry::matchConstraints (const DomainConstr& constr) const
{
return confs_[constr.first] == constr.second;
}
inline bool
CptEntry::matchConstraints (const vector<DomainConstr>& constrs) const
{
for (unsigned j = 0; j < constrs.size(); j++) {
if (confs_[constrs[j].first] != constrs[j].second) {
return false;
}
}
return true;
}
#endif

View File

@ -2,8 +2,8 @@
#define BP_DISTRIBUTION_H #define BP_DISTRIBUTION_H
#include <vector> #include <vector>
#include <string>
#include "CptEntry.h"
#include "Shared.h" #include "Shared.h"
using namespace std; using namespace std;
@ -11,16 +11,18 @@ using namespace std;
struct Distribution struct Distribution
{ {
public: public:
Distribution (unsigned id) Distribution (unsigned id, bool shared = false)
{ {
this->id = id; this->id = id;
this->params = params; this->params = params;
this->shared = shared;
} }
Distribution (const ParamSet& params) Distribution (const ParamSet& params, bool shared = false)
{ {
this->id = -1; this->id = -1;
this->params = params; this->params = params;
this->shared = shared;
} }
void updateParameters (const ParamSet& params) void updateParameters (const ParamSet& params)
@ -31,10 +33,11 @@ struct Distribution
unsigned id; unsigned id;
ParamSet params; ParamSet params;
vector<CptEntry> entries; vector<CptEntry> entries;
bool shared;
private: private:
DISALLOW_COPY_AND_ASSIGN (Distribution); DISALLOW_COPY_AND_ASSIGN (Distribution);
}; };
#endif #endif //BP_DISTRIBUTION_H

View File

@ -1,37 +1,37 @@
#include <iostream>
#include <sstream>
#include <cstdlib> #include <cstdlib>
#include <cassert> #include <cassert>
#include <iostream>
#include <sstream>
#include "Factor.h" #include "Factor.h"
#include "FgVarNode.h" #include "FgVarNode.h"
int Factor::indexCount_ = 0; Factor::Factor (const Factor& g)
{
Factor::Factor (FgVarNode* var) { copyFactor (g);
vs_.push_back (var);
int nParams = var->getDomainSize();
// create a uniform distribution
double val = 1.0 / nParams;
ps_ = ParamSet (nParams, val);
id_ = indexCount_;
indexCount_ ++;
} }
Factor::Factor (const FgVarSet& vars) { Factor::Factor (FgVarNode* var)
vs_ = vars; {
Factor (FgVarSet() = {var});
}
Factor::Factor (const FgVarSet& vars)
{
vars_ = vars;
int nParams = 1; int nParams = 1;
for (unsigned i = 0; i < vs_.size(); i++) { for (unsigned i = 0; i < vars_.size(); i++) {
nParams *= vs_[i]->getDomainSize(); nParams *= vars_[i]->getDomainSize();
} }
// create a uniform distribution // create a uniform distribution
double val = 1.0 / nParams; double val = 1.0 / nParams;
ps_ = ParamSet (nParams, val); dist_ = new Distribution (ParamSet (nParams, val));
id_ = indexCount_;
indexCount_ ++;
} }
@ -39,10 +39,17 @@ Factor::Factor (const FgVarSet& vars) {
Factor::Factor (FgVarNode* var, Factor::Factor (FgVarNode* var,
const ParamSet& params) const ParamSet& params)
{ {
vs_.push_back (var); vars_.push_back (var);
ps_ = params; dist_ = new Distribution (params);
id_ = indexCount_; }
indexCount_ ++;
Factor::Factor (FgVarSet& vars,
Distribution* dist)
{
vars_ = vars;
dist_ = dist;
} }
@ -50,42 +57,8 @@ Factor::Factor (FgVarNode* var,
Factor::Factor (const FgVarSet& vars, Factor::Factor (const FgVarSet& vars,
const ParamSet& params) const ParamSet& params)
{ {
vs_ = vars; vars_ = vars;
ps_ = params; dist_ = new Distribution (params);
id_ = indexCount_;
indexCount_ ++;
}
const FgVarSet&
Factor::getFgVarNodes (void) const
{
return vs_;
}
FgVarSet&
Factor::getFgVarNodes (void)
{
return vs_;
}
const ParamSet&
Factor::getParameters (void) const
{
return ps_;
}
ParamSet&
Factor::getParameters (void)
{
return ps_;
} }
@ -93,75 +66,95 @@ Factor::getParameters (void)
void void
Factor::setParameters (const ParamSet& params) Factor::setParameters (const ParamSet& params)
{ {
//cout << "ps size: " << ps_.size() << endl; assert (dist_->params.size() == params.size());
//cout << "params size: " << params.size() << endl; dist_->updateParameters (params);
assert (ps_.size() == params.size());
ps_ = params;
} }
Factor& void
Factor::operator= (const Factor& g) Factor::copyFactor (const Factor& g)
{ {
FgVarSet vars = g.getFgVarNodes(); vars_ = g.getFgVarNodes();
ParamSet params = g.getParameters(); dist_ = new Distribution (g.getDistribution()->params);
return *this;
} }
Factor& void
Factor::operator*= (const Factor& g) Factor::multiplyByFactor (const Factor& g, const vector<CptEntry>* entries)
{ {
FgVarSet gVs = g.getFgVarNodes(); if (vars_.size() == 0) {
copyFactor (g);
return;
}
const FgVarSet& gVs = g.getFgVarNodes();
const ParamSet& gPs = g.getParameters(); const ParamSet& gPs = g.getParameters();
bool hasCommonVars = false; bool factorsAreEqual = true;
vector<int> varIndexes; if (gVs.size() == vars_.size()) {
for (unsigned i = 0; i < gVs.size(); i++) { for (unsigned i = 0; i < vars_.size(); i++) {
int idx = getIndexOf (gVs[i]); if (gVs[i] != vars_[i]) {
if (idx == -1) { factorsAreEqual = false;
insertVariable (gVs[i]); break;
varIndexes.push_back (vs_.size() - 1);
} else {
hasCommonVars = true;
varIndexes.push_back (idx);
}
}
if (hasCommonVars) {
vector<int> offsets (gVs.size());
offsets[gVs.size() - 1] = 1;
for (int i = gVs.size() - 2; i >= 0; i--) {
offsets[i] = offsets[i + 1] * gVs[i + 1]->getDomainSize();
}
vector<CptEntry> entries = getCptEntries();
for (unsigned i = 0; i < entries.size(); i++) {
int idx = 0;
const DomainConf conf = entries[i].getParentConfigurations();
for (unsigned j = 0; j < varIndexes.size(); j++) {
idx += offsets[j] * conf[varIndexes[j]];
} }
//cout << "ps_[" << i << "] = " << ps_[i] << " * " ;
//cout << gPs[idx] << " , idx = " << idx << endl;
ps_[i] = ps_[i] * gPs[idx];
} }
} else { } else {
// if the originally factors doesn't have common factors. factorsAreEqual = false;
// we don't have to make domain comparations }
unsigned idx = 0;
for (unsigned i = 0; i < ps_.size(); i++) { if (factorsAreEqual) {
//cout << "ps_[" << i << "] = " << ps_[i] << " * " ; // optimization: if the factors contain the same set of variables,
//cout << gPs[idx] << " , idx = " << idx << endl; // we can do 1 to 1 operations on the parameteres
ps_[i] = ps_[i] * gPs[idx]; for (unsigned i = 0; i < dist_->params.size(); i++) {
idx ++; dist_->params[i] *= gPs[i];
if (idx >= gPs.size()) { }
idx = 0; } else {
bool hasCommonVars = false;
vector<unsigned> gVsIndexes;
for (unsigned i = 0; i < gVs.size(); i++) {
int idx = getIndexOf (gVs[i]);
if (idx == -1) {
insertVariable (gVs[i]);
gVsIndexes.push_back (vars_.size() - 1);
} else {
hasCommonVars = true;
gVsIndexes.push_back (idx);
}
}
if (hasCommonVars) {
vector<unsigned> gVsOffsets (gVs.size());
gVsOffsets[gVs.size() - 1] = 1;
for (int i = gVs.size() - 2; i >= 0; i--) {
gVsOffsets[i] = gVsOffsets[i + 1] * gVs[i + 1]->getDomainSize();
}
if (entries == 0) {
entries = &getCptEntries();
}
for (unsigned i = 0; i < entries->size(); i++) {
unsigned idx = 0;
const DConf& conf = (*entries)[i].getDomainConfiguration();
for (unsigned j = 0; j < gVsIndexes.size(); j++) {
idx += gVsOffsets[j] * conf[ gVsIndexes[j] ];
}
dist_->params[i] = dist_->params[i] * gPs[idx];
}
} else {
// optimization: if the original factors doesn't have common variables,
// we don't need to marry the states of the common variables
unsigned count = 0;
for (unsigned i = 0; i < dist_->params.size(); i++) {
dist_->params[i] *= gPs[count];
count ++;
if (count >= gPs.size()) {
count = 0;
}
} }
} }
} }
return *this;
} }
@ -169,81 +162,109 @@ Factor::operator*= (const Factor& g)
void void
Factor::insertVariable (FgVarNode* var) Factor::insertVariable (FgVarNode* var)
{ {
int c = 0; assert (getIndexOf (var) == -1);
ParamSet newPs (ps_.size() * var->getDomainSize()); ParamSet newPs;
for (unsigned i = 0; i < ps_.size(); i++) { newPs.reserve (dist_->params.size() * var->getDomainSize());
for (int j = 0; j < var->getDomainSize(); j++) { for (unsigned i = 0; i < dist_->params.size(); i++) {
newPs[c] = ps_[i]; for (unsigned j = 0; j < var->getDomainSize(); j++) {
c ++; newPs.push_back (dist_->params[i]);
} }
} }
vs_.push_back (var); vars_.push_back (var);
ps_ = newPs; dist_->updateParameters (newPs);
} }
void void
Factor::marginalizeVariable (const FgVarNode* var) { Factor::removeVariable (const FgVarNode* var)
int varIndex = getIndexOf (var);
marginalizeVariable (varIndex);
}
void
Factor::marginalizeVariable (unsigned varIndex)
{ {
assert (varIndex >= 0 && varIndex < vs_.size()); int varIndex = getIndexOf (var);
int distOffset = 1; assert (varIndex >= 0 && varIndex < (int)vars_.size());
int leftVarOffset = 1;
for (unsigned i = vs_.size() - 1; i > varIndex; i--) { // number of parameters separating a different state of `var',
distOffset *= vs_[i]->getDomainSize(); // with the states of the remaining variables fixed
leftVarOffset *= vs_[i]->getDomainSize(); unsigned varOffset = 1;
}
leftVarOffset *= vs_[varIndex]->getDomainSize(); // number of parameters separating a different state of the variable
// on the left of `var', with the states of the remaining vars fixed
unsigned leftVarOffset = 1;
for (int i = vars_.size() - 1; i > varIndex; i--) {
varOffset *= vars_[i]->getDomainSize();
leftVarOffset *= vars_[i]->getDomainSize();
}
leftVarOffset *= vars_[varIndex]->getDomainSize();
unsigned offset = 0;
unsigned count1 = 0;
unsigned count2 = 0;
unsigned newPsSize = dist_->params.size() / vars_[varIndex]->getDomainSize();
int ds = vs_[varIndex]->getDomainSize();
int count = 0;
int offset = 0;
int startIndex = 0;
int currDomainIdx = 0;
unsigned newPsSize = ps_.size() / ds;
ParamSet newPs; ParamSet newPs;
newPs.reserve (newPsSize); newPs.reserve (newPsSize);
stringstream ss; // stringstream ss;
ss << "marginalizing " << vs_[varIndex]->getLabel(); // ss << "marginalizing " << vars_[varIndex]->getLabel();
ss << " from factor " << getLabel() << endl; // ss << " from factor " << getLabel() << endl;
while (newPs.size() < newPsSize) { while (newPs.size() < newPsSize) {
ss << " sum = "; // ss << " sum = ";
double sum = 0.0; double sum = 0.0;
for (int j = 0; j < ds; j++) { for (unsigned i = 0; i < vars_[varIndex]->getDomainSize(); i++) {
if (j != 0) ss << " + "; // if (i != 0) ss << " + ";
ss << ps_[offset]; // ss << dist_->params[offset];
sum = sum + ps_[offset]; sum += dist_->params[offset];
offset = offset + distOffset; offset += varOffset;
} }
newPs.push_back (sum); newPs.push_back (sum);
count ++; count1 ++;
if (varIndex == vs_.size() - 1) { if (varIndex == (int)vars_.size() - 1) {
offset = count * ds; offset = count1 * vars_[varIndex]->getDomainSize();
} else { } else {
offset = offset - distOffset + 1; if (((offset - varOffset + 1) % leftVarOffset) == 0) {
if ((offset % leftVarOffset) == 0) { count1 = 0;
currDomainIdx ++; count2 ++;
startIndex = leftVarOffset * currDomainIdx;
offset = startIndex;
count = 0;
} else {
offset = startIndex + count;
} }
offset = (leftVarOffset * count2) + count1;
} }
ss << " = " << sum << endl; // ss << " = " << sum << endl;
} }
//cout << ss.str() << endl; // cout << ss.str() << endl;
ps_ = newPs; vars_.erase (vars_.begin() + varIndex);
vs_.erase (vs_.begin() + varIndex); dist_->updateParameters (newPs);
}
const vector<CptEntry>&
Factor::getCptEntries (void) const
{
if (dist_->entries.size() == 0) {
vector<DConf> confs (dist_->params.size());
for (unsigned i = 0; i < dist_->params.size(); i++) {
confs[i].resize (vars_.size());
}
unsigned nReps = 1;
for (int i = vars_.size() - 1; i >= 0; i--) {
unsigned index = 0;
while (index < dist_->params.size()) {
for (unsigned j = 0; j < vars_[i]->getDomainSize(); j++) {
for (unsigned r = 0; r < nReps; r++) {
confs[index][i] = j;
index++;
}
}
}
nReps *= vars_[i]->getDomainSize();
}
dist_->entries.clear();
dist_->entries.reserve (dist_->params.size());
for (unsigned i = 0; i < dist_->params.size(); i++) {
dist_->entries.push_back (CptEntry (i, confs[i]));
}
}
return dist_->entries;
} }
@ -252,11 +273,10 @@ string
Factor::getLabel (void) const Factor::getLabel (void) const
{ {
stringstream ss; stringstream ss;
ss << "f(" ; ss << "Φ(" ;
// ss << "Φ(" ; for (unsigned i = 0; i < vars_.size(); i++) {
for (unsigned i = 0; i < vs_.size(); i++) { if (i != 0) ss << "," ;
if (i != 0) ss << ", " ; ss << vars_[i]->getLabel();
ss << "v" << vs_[i]->getVarId();
} }
ss << ")" ; ss << ")" ;
return ss.str(); return ss.str();
@ -264,62 +284,24 @@ Factor::getLabel (void) const
string void
Factor::toString (void) const Factor::printFactor (void)
{ {
stringstream ss; stringstream ss;
ss << "vars: " ; ss << getLabel() << endl;
for (unsigned i = 0; i < vs_.size(); i++) { ss << "--------------------" << endl;
if (i != 0) ss << ", " ; VarSet vs;
ss << "v" << vs_[i]->getVarId(); for (unsigned i = 0; i < vars_.size(); i++) {
vs.push_back (vars_[i]);
} }
ss << endl; vector<string> domainConfs = Util::getInstantiations (vs);
vector<CptEntry> entries = getCptEntries(); const vector<CptEntry>& entries = getCptEntries();
for (unsigned i = 0; i < entries.size(); i++) { for (unsigned i = 0; i < entries.size(); i++) {
ss << "Φ(" ; ss << "Φ(" << domainConfs[i] << ")" ;
char s = 'a' ; unsigned idx = entries[i].getParameterIndex();
const DomainConf& conf = entries[i].getParentConfigurations(); ss << " = " << dist_->params[idx] << endl;
for (unsigned j = 0; j < conf.size(); j++) {
if (j != 0) ss << "," ;
ss << s << conf[j] + 1;
s++;
}
ss << ") = " << ps_[entries[i].getParameterIndex()] << endl;
} }
return ss.str(); cout << ss.str();
}
vector<CptEntry>
Factor::getCptEntries (void) const
{
vector<DomainConf> confs (ps_.size());
for (unsigned i = 0; i < ps_.size(); i++) {
confs[i].resize (vs_.size());
}
int nReps = 1;
for (int i = vs_.size() - 1; i >= 0; i--) {
unsigned index = 0;
while (index < ps_.size()) {
for (int j = 0; j < vs_[i]->getDomainSize(); j++) {
for (int r = 0; r < nReps; r++) {
confs[index][i] = j;
index++;
}
}
}
nReps *= vs_[i]->getDomainSize();
}
vector<CptEntry> entries;
for (unsigned i = 0; i < ps_.size(); i++) {
for (unsigned j = 0; j < vs_.size(); j++) {
}
entries.push_back (CptEntry (i, confs[i]));
}
return entries;
} }
@ -327,20 +309,11 @@ Factor::getCptEntries (void) const
int int
Factor::getIndexOf (const FgVarNode* var) const Factor::getIndexOf (const FgVarNode* var) const
{ {
for (unsigned i = 0; i < vs_.size(); i++) { for (unsigned i = 0; i < vars_.size(); i++) {
if (vs_[i] == var) { if (vars_[i] == var) {
return i; return i;
} }
} }
return -1; return -1;
} }
Factor operator* (const Factor& f, const Factor& g)
{
Factor r = f;
r *= g;
return r;
}

View File

@ -3,43 +3,46 @@
#include <vector> #include <vector>
#include "Distribution.h"
#include "CptEntry.h" #include "CptEntry.h"
using namespace std; using namespace std;
class FgVarNode; class FgVarNode;
class Distribution;
class Factor class Factor
{ {
public: public:
Factor (void) { }
Factor (const Factor&);
Factor (FgVarNode*); Factor (FgVarNode*);
Factor (const FgVarSet&); Factor (CFgVarSet);
Factor (FgVarNode*, const ParamSet&); Factor (FgVarNode*, const ParamSet&);
Factor (const FgVarSet&, const ParamSet&); Factor (FgVarSet&, Distribution*);
Factor (CFgVarSet, CParamSet);
const FgVarSet& getFgVarNodes (void) const; void setParameters (CParamSet);
FgVarSet& getFgVarNodes (void); void copyFactor (const Factor& f);
const ParamSet& getParameters (void) const; void multiplyByFactor (const Factor& f, const vector<CptEntry>* = 0);
ParamSet& getParameters (void); void insertVariable (FgVarNode* index);
void setParameters (const ParamSet&); void removeVariable (const FgVarNode* var);
Factor& operator= (const Factor& f); const vector<CptEntry>& getCptEntries (void) const;
Factor& operator*= (const Factor& f); string getLabel (void) const;
void insertVariable (FgVarNode* index); void printFactor (void);
void marginalizeVariable (const FgVarNode* var);
void marginalizeVariable (unsigned); CFgVarSet getFgVarNodes (void) const { return vars_; }
string getLabel (void) const; CParamSet getParameters (void) const { return dist_->params; }
string toString (void) const; Distribution* getDistribution (void) const { return dist_; }
unsigned getIndex (void) const { return index_; }
void setIndex (unsigned index) { index_ = index; }
void freeDistribution (void) { delete dist_; dist_ = 0;}
int getIndexOf (const FgVarNode*) const;
private: private:
vector<CptEntry> getCptEntries() const; FgVarSet vars_;
int getIndexOf (const FgVarNode*) const; Distribution* dist_;
unsigned index_;
FgVarSet vs_;
ParamSet ps_;
int id_;
static int indexCount_;
}; };
Factor operator* (const Factor&, const Factor&); #endif //BP_FACTOR_H
#endif

View File

@ -1,23 +1,26 @@
#include <cstdlib>
#include <vector>
#include <set>
#include <iostream> #include <iostream>
#include <fstream> #include <fstream>
#include <sstream> #include <sstream>
#include <vector>
#include <cstdlib>
#include "FactorGraph.h" #include "FactorGraph.h"
#include "FgVarNode.h" #include "FgVarNode.h"
#include "Factor.h" #include "Factor.h"
#include "BayesNet.h"
FactorGraph::FactorGraph (const char* fileName) FactorGraph::FactorGraph (const char* fileName)
{ {
string line;
ifstream is (fileName); ifstream is (fileName);
if (!is.is_open()) { if (!is.is_open()) {
cerr << "error: cannot read from file " + std::string (fileName) << endl; cerr << "error: cannot read from file " + std::string (fileName) << endl;
abort(); abort();
} }
string line;
while (is.peek() == '#' || is.peek() == '\n') getline (is, line); while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
getline (is, line); getline (is, line);
if (line != "MARKOV") { if (line != "MARKOV") {
@ -39,7 +42,7 @@ FactorGraph::FactorGraph (const char* fileName)
while (is.peek() == '#' || is.peek() == '\n') getline (is, line); while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
for (int i = 0; i < nVars; i++) { for (int i = 0; i < nVars; i++) {
varNodes_.push_back (new FgVarNode (i, domainSizes[i])); addVariable (new FgVarNode (i, domainSizes[i]));
} }
int nFactors; int nFactors;
@ -50,11 +53,11 @@ FactorGraph::FactorGraph (const char* fileName)
is >> nFactorVars; is >> nFactorVars;
FgVarSet factorVars; FgVarSet factorVars;
for (int j = 0; j < nFactorVars; j++) { for (int j = 0; j < nFactorVars; j++) {
int varId; int vid;
is >> varId; is >> vid;
FgVarNode* var = getVariableById (varId); FgVarNode* var = getFgVarNode (vid);
if (var == 0) { if (!var) {
cerr << "error: invalid variable identifier (" << varId << ")" << endl; cerr << "error: invalid variable identifier (" << vid << ")" << endl;
abort(); abort();
} }
factorVars.push_back (var); factorVars.push_back (var);
@ -87,6 +90,33 @@ FactorGraph::FactorGraph (const char* fileName)
FactorGraph::FactorGraph (const BayesNet& bn)
{
const BnNodeSet& nodes = bn.getBayesNodes();
for (unsigned i = 0; i < nodes.size(); i++) {
FgVarNode* varNode = new FgVarNode (nodes[i]);
varNode->setIndex (i);
addVariable (varNode);
}
for (unsigned i = 0; i < nodes.size(); i++) {
const BnNodeSet& parents = nodes[i]->getParents();
if (!(nodes[i]->hasEvidence() && parents.size() == 0)) {
FgVarSet factorVars = { varNodes_[nodes[i]->getIndex()] };
for (unsigned j = 0; j < parents.size(); j++) {
factorVars.push_back (varNodes_[parents[j]->getIndex()]);
}
Factor* f = new Factor (factorVars, nodes[i]->getDistribution());
factors_.push_back (f);
for (unsigned j = 0; j < factorVars.size(); j++) {
factorVars[j]->addFactor (f);
}
}
}
}
FactorGraph::~FactorGraph (void) FactorGraph::~FactorGraph (void)
{ {
for (unsigned i = 0; i < varNodes_.size(); i++) { for (unsigned i = 0; i < varNodes_.size(); i++) {
@ -99,18 +129,67 @@ FactorGraph::~FactorGraph (void)
FgVarSet void
FactorGraph::getFgVarNodes (void) const FactorGraph::addVariable (FgVarNode* varNode)
{ {
return varNodes_; varNodes_.push_back (varNode);
varNode->setIndex (varNodes_.size() - 1);
indexMap_.insert (make_pair (varNode->getVarId(), varNodes_.size() - 1));
} }
vector<Factor*> void
FactorGraph::getFactors (void) const FactorGraph::removeVariable (const FgVarNode* var)
{ {
return factors_; if (varNodes_[varNodes_.size() - 1] == var) {
varNodes_.pop_back();
} else {
for (unsigned i = 0; i < varNodes_.size(); i++) {
if (varNodes_[i] == var) {
varNodes_.erase (varNodes_.begin() + i);
return;
}
}
assert (false);
}
indexMap_.erase (indexMap_.find (var->getVarId()));
}
void
FactorGraph::addFactor (Factor* f)
{
factors_.push_back (f);
const FgVarSet& factorVars = f->getFgVarNodes();
for (unsigned i = 0; i < factorVars.size(); i++) {
factorVars[i]->addFactor (f);
}
}
void
FactorGraph::removeFactor (const Factor* f)
{
const FgVarSet& factorVars = f->getFgVarNodes();
for (unsigned i = 0; i < factorVars.size(); i++) {
if (factorVars[i]) {
factorVars[i]->removeFactor (f);
}
}
if (factors_[factors_.size() - 1] == f) {
factors_.pop_back();
} else {
for (unsigned i = 0; i < factors_.size(); i++) {
if (factors_[i] == f) {
factors_.erase (factors_.begin() + i);
return;
}
}
assert (false);
}
} }
@ -127,47 +206,142 @@ FactorGraph::getVariables (void) const
FgVarNode* Variable*
FactorGraph::getVariableById (unsigned id) const FactorGraph::getVariable (Vid vid) const
{ {
for (unsigned i = 0; i < varNodes_.size(); i++) { return getFgVarNode (vid);
if (varNodes_[i]->getVarId() == id) {
return varNodes_[i];
}
}
return 0;
}
FgVarNode*
FactorGraph::getVariableByLabel (string label) const
{
for (unsigned i = 0; i < varNodes_.size(); i++) {
stringstream ss;
ss << "v" << varNodes_[i]->getVarId();
if (ss.str() == label) {
return varNodes_[i];
}
}
return 0;
} }
void void
FactorGraph::printFactorGraph (void) const FactorGraph::setIndexes (void)
{
for (unsigned i = 0; i < varNodes_.size(); i++) {
varNodes_[i]->setIndex (i);
}
for (unsigned i = 0; i < factors_.size(); i++) {
factors_[i]->setIndex (i);
}
}
void
FactorGraph::freeDistributions (void)
{
set<Distribution*> dists;
for (unsigned i = 0; i < factors_.size(); i++) {
dists.insert (factors_[i]->getDistribution());
}
for (set<Distribution*>::iterator it = dists.begin();
it != dists.end(); it++) {
delete *it;
}
}
void
FactorGraph::printGraphicalModel (void) const
{ {
for (unsigned i = 0; i < varNodes_.size(); i++) { for (unsigned i = 0; i < varNodes_.size(); i++) {
cout << "variable number " << varNodes_[i]->getIndex() << endl; cout << "variable number " << varNodes_[i]->getIndex() << endl;
cout << "Id = " << varNodes_[i]->getVarId() << endl; cout << "Id = " << varNodes_[i]->getVarId() << endl;
cout << "Label = " << varNodes_[i]->getLabel() << endl;
cout << "Domain size = " << varNodes_[i]->getDomainSize() << endl; cout << "Domain size = " << varNodes_[i]->getDomainSize() << endl;
cout << "Evidence = " << varNodes_[i]->getEvidence() << endl; cout << "Evidence = " << varNodes_[i]->getEvidence() << endl;
cout << endl; cout << "Factors = " ;
for (unsigned j = 0; j < varNodes_[i]->getFactors().size(); j++) {
cout << varNodes_[i]->getFactors()[j]->getLabel() << " " ;
}
cout << endl << endl;
} }
cout << endl;
for (unsigned i = 0; i < factors_.size(); i++) { for (unsigned i = 0; i < factors_.size(); i++) {
cout << factors_[i]->toString() << endl; factors_[i]->printFactor();
cout << endl;
} }
} }
void
FactorGraph::exportToDotFormat (const char* fileName) const
{
ofstream out (fileName);
if (!out.is_open()) {
cerr << "error: cannot open file to write at " ;
cerr << "FactorGraph::exportToDotFile()" << endl;
abort();
}
out << "graph \"" << fileName << "\" {" << endl;
for (unsigned i = 0; i < varNodes_.size(); i++) {
if (varNodes_[i]->hasEvidence()) {
out << '"' << varNodes_[i]->getLabel() << '"' ;
out << " [style=filled, fillcolor=yellow]" << endl;
}
}
for (unsigned i = 0; i < factors_.size(); i++) {
out << '"' << factors_[i]->getLabel() << '"' ;
out << " [label=\"" << factors_[i]->getLabel() << "\\n(";
out << factors_[i]->getDistribution()->id << ")" << "\"" ;
out << ", shape=box]" << endl;
}
for (unsigned i = 0; i < factors_.size(); i++) {
CFgVarSet myVars = factors_[i]->getFgVarNodes();
for (unsigned j = 0; j < myVars.size(); j++) {
out << '"' << factors_[i]->getLabel() << '"' ;
out << " -- " ;
out << '"' << myVars[j]->getLabel() << '"' << endl;
}
}
out << "}" << endl;
out.close();
}
void
FactorGraph::exportToUaiFormat (const char* fileName) const
{
ofstream out (fileName);
if (!out.is_open()) {
cerr << "error: cannot open file to write at " ;
cerr << "FactorGraph::exportToUaiFormat()" << endl;
abort();
}
out << "MARKOV" << endl;
out << varNodes_.size() << endl;
for (unsigned i = 0; i < varNodes_.size(); i++) {
out << varNodes_[i]->getDomainSize() << " " ;
}
out << endl;
out << factors_.size() << endl;
for (unsigned i = 0; i < factors_.size(); i++) {
CFgVarSet factorVars = factors_[i]->getFgVarNodes();
out << factorVars.size();
for (unsigned j = 0; j < factorVars.size(); j++) {
out << " " << factorVars[j]->getIndex();
}
out << endl;
}
for (unsigned i = 0; i < factors_.size(); i++) {
CParamSet params = factors_[i]->getParameters();
out << endl << params.size() << endl << " " ;
for (unsigned j = 0; j < params.size(); j++) {
out << params[j] << " " ;
}
out << endl;
}
out.close();
}

View File

@ -1,8 +1,7 @@
#ifndef BP_FACTORGRAPH_H #ifndef BP_FACTOR_GRAPH_H
#define BP_FACTORGRAPH_H #define BP_FACTOR_GRAPH_H
#include <vector> #include <vector>
#include <string>
#include "GraphicalModel.h" #include "GraphicalModel.h"
#include "Shared.h" #include "Shared.h"
@ -11,25 +10,48 @@ using namespace std;
class FgVarNode; class FgVarNode;
class Factor; class Factor;
class BayesNet;
class FactorGraph : public GraphicalModel class FactorGraph : public GraphicalModel
{ {
public: public:
FactorGraph (const char* fileName); FactorGraph (void) {};
FactorGraph (const char*);
FactorGraph (const BayesNet&);
~FactorGraph (void); ~FactorGraph (void);
FgVarSet getFgVarNodes (void) const; void addVariable (FgVarNode*);
vector<Factor*> getFactors (void) const; void removeVariable (const FgVarNode*);
void addFactor (Factor*);
void removeFactor (const Factor*);
VarSet getVariables (void) const; VarSet getVariables (void) const;
FgVarNode* getVariableById (unsigned) const; Variable* getVariable (unsigned) const;
FgVarNode* getVariableByLabel (string) const; void setIndexes (void);
void printFactorGraph (void) const; void freeDistributions (void);
void printGraphicalModel (void) const;
void exportToDotFormat (const char*) const;
void exportToUaiFormat (const char*) const;
const FgVarSet& getFgVarNodes (void) const { return varNodes_; }
const FactorSet& getFactors (void) const { return factors_; }
FgVarNode* getFgVarNode (Vid vid) const
{
IndexMap::const_iterator it = indexMap_.find (vid);
if (it == indexMap_.end()) {
return 0;
} else {
return varNodes_[it->second];
}
}
private: private:
DISALLOW_COPY_AND_ASSIGN (FactorGraph); DISALLOW_COPY_AND_ASSIGN (FactorGraph);
FgVarSet varNodes_; FgVarSet varNodes_;
vector<Factor*> factors_; FactorSet factors_;
IndexMap indexMap_;
}; };
#endif #endif // BP_FACTOR_GRAPH_H

View File

@ -1,8 +1,7 @@
#ifndef BP_VARIABLE_H #ifndef BP_FG_VAR_NODE_H
#define BP_VARIABLE_H #define BP_FG_VAR_NODE_H
#include <vector> #include <vector>
#include <string>
#include "Variable.h" #include "Variable.h"
#include "Shared.h" #include "Shared.h"
@ -14,15 +13,31 @@ class Factor;
class FgVarNode : public Variable class FgVarNode : public Variable
{ {
public: public:
FgVarNode (int varId, int dsize) : Variable (varId, dsize) { } FgVarNode (unsigned vid, unsigned dsize) : Variable (vid, dsize) { }
FgVarNode (const Variable* v) : Variable (v) { }
void addFactor (Factor* f) { factors_.push_back (f); } void addFactor (Factor* f) { factors_.push_back (f); }
vector<Factor*> getFactors (void) const { return factors_; } CFactorSet getFactors (void) const { return factors_; }
void removeFactor (const Factor* f)
{
if (factors_[factors_.size() -1] == f) {
factors_.pop_back();
} else {
for (unsigned i = 0; i < factors_.size(); i++) {
if (factors_[i] == f) {
factors_.erase (factors_.begin() + i);
return;
}
}
assert (false);
}
}
private: private:
DISALLOW_COPY_AND_ASSIGN (FgVarNode); DISALLOW_COPY_AND_ASSIGN (FgVarNode);
// members // members
vector<Factor*> factors_; FactorSet factors_;
}; };
#endif // BP_VARIABLE_H #endif // BP_FG_VAR_NODE_H

View File

@ -1,5 +1,5 @@
#ifndef BP_GRAPHICALMODEL_H #ifndef BP_GRAPHICAL_MODEL_H
#define BP_GRAPHICALMODEL_H #define BP_GRAPHICAL_MODEL_H
#include "Variable.h" #include "Variable.h"
#include "Shared.h" #include "Shared.h"
@ -9,9 +9,10 @@ using namespace std;
class GraphicalModel class GraphicalModel
{ {
public: public:
virtual VarSet getVariables (void) const = 0; virtual ~GraphicalModel (void) {};
virtual Variable* getVariable (Vid) const = 0;
private: virtual VarSet getVariables (void) const = 0;
virtual void printGraphicalModel (void) const = 0;
}; };
#endif #endif // BP_GRAPHICAL_MODEL_H

View File

@ -1,17 +1,19 @@
#include <iostream>
#include <cstdlib> #include <cstdlib>
#include <iostream>
#include <sstream> #include <sstream>
#include "BayesNet.h" #include "BayesNet.h"
#include "BPSolver.h"
#include "FactorGraph.h" #include "FactorGraph.h"
#include "SPSolver.h" #include "SPSolver.h"
#include "BPSolver.h"
#include "CountingBP.h"
using namespace std; using namespace std;
void BayesianNetwork (int, const char* []); void BayesianNetwork (int, const char* []);
void markovNetwork (int, const char* []); void markovNetwork (int, const char* []);
void runSolver (Solver*, const VarSet&);
const string USAGE = "usage: \ const string USAGE = "usage: \
./hcli FILE [VARIABLE | OBSERVED_VARIABLE=EVIDENCE]..." ; ./hcli FILE [VARIABLE | OBSERVED_VARIABLE=EVIDENCE]..." ;
@ -19,14 +21,40 @@ const string USAGE = "usage: \
int int
main (int argc, const char* argv[]) main (int argc, const char* argv[])
{ {
/*
FactorGraph fg;
FgVarNode* varNode1 = new FgVarNode (0, 2);
FgVarNode* varNode2 = new FgVarNode (1, 2);
FgVarNode* varNode3 = new FgVarNode (2, 2);
fg.addVariable (varNode1);
fg.addVariable (varNode2);
fg.addVariable (varNode3);
Distribution* dist = new Distribution (ParamSet() = {1.2, 1.4, 2.0, 0.4});
fg.addFactor (new Factor (FgVarSet() = {varNode1, varNode2}, dist));
fg.addFactor (new Factor (FgVarSet() = {varNode3, varNode2}, dist));
//fg.printGraphicalModel();
//SPSolver sp (fg);
//sp.runSolver();
//sp.printAllPosterioris();
//ParamSet p = sp.getJointDistributionOf (VidSet() = {0, 1, 2});
//cout << Util::parametersToString (p) << endl;
CountingBP cbp (fg);
//cbp.runSolver();
//cbp.printAllPosterioris();
ParamSet p2 = cbp.getJointDistributionOf (VidSet() = {0, 1, 2});
cout << Util::parametersToString (p2) << endl;
fg.freeDistributions();
Statistics::printCompressingStats ("compressing.stats");
return 0;
*/
if (!argv[1]) { if (!argv[1]) {
cerr << "error: no graphical model specified" << endl; cerr << "error: no graphical model specified" << endl;
cerr << USAGE << endl; cerr << USAGE << endl;
exit (0); exit (0);
} }
string fileName = argv[1]; const string& fileName = argv[1];
string extension = fileName.substr (fileName.find_last_of ('.') + 1); const string& extension = fileName.substr (fileName.find_last_of ('.') + 1);
if (extension == "xml") { if (extension == "xml") {
BayesianNetwork (argc, argv); BayesianNetwork (argc, argv);
} else if (extension == "uai") { } else if (extension == "uai") {
@ -45,13 +73,13 @@ void
BayesianNetwork (int argc, const char* argv[]) BayesianNetwork (int argc, const char* argv[])
{ {
BayesNet bn (argv[1]); BayesNet bn (argv[1]);
//bn.printNetwork(); //bn.printGraphicalModel();
NodeSet queryVars; VarSet queryVars;
for (int i = 2; i < argc; i++) { for (int i = 2; i < argc; i++) {
string arg = argv[i]; const string& arg = argv[i];
if (arg.find ('=') == std::string::npos) { if (arg.find ('=') == std::string::npos) {
BayesNode* queryVar = bn.getNode (arg); BayesNode* queryVar = bn.getBayesNode (arg);
if (queryVar) { if (queryVar) {
queryVars.push_back (queryVar); queryVars.push_back (queryVar);
} else { } else {
@ -61,9 +89,9 @@ BayesianNetwork (int argc, const char* argv[])
exit (0); exit (0);
} }
} else { } else {
size_t pos = arg.find ('='); size_t pos = arg.find ('=');
string label = arg.substr (0, pos); const string& label = arg.substr (0, pos);
string state = arg.substr (pos + 1); const string& state = arg.substr (pos + 1);
if (label.empty()) { if (label.empty()) {
cerr << "error: missing left argument" << endl; cerr << "error: missing left argument" << endl;
cerr << USAGE << endl; cerr << USAGE << endl;
@ -74,7 +102,7 @@ BayesianNetwork (int argc, const char* argv[])
cerr << USAGE << endl; cerr << USAGE << endl;
exit (0); exit (0);
} }
BayesNode* node = bn.getNode (label); BayesNode* node = bn.getBayesNode (label);
if (node) { if (node) {
if (node->isValidState (state)) { if (node->isValidState (state)) {
node->setEvidence (state); node->setEvidence (state);
@ -94,19 +122,16 @@ BayesianNetwork (int argc, const char* argv[])
} }
} }
BPSolver solver (bn); Solver* solver;
if (queryVars.size() == 0) { if (SolverOptions::convertBn2Fg) {
solver.runSolver(); FactorGraph* fg = new FactorGraph (bn);
solver.printAllPosterioris(); fg->printGraphicalModel();
} else if (queryVars.size() == 1) { solver = new SPSolver (*fg);
solver.runSolver(); runSolver (solver, queryVars);
solver.printPosterioriOf (queryVars[0]); delete fg;
} else { } else {
Domain domain = BayesNet::getInstantiations(queryVars); solver = new BPSolver (bn);
ParamSet params = solver.getJointDistribution (queryVars); runSolver (solver, queryVars);
for (unsigned i = 0; i < params.size(); i++) {
cout << domain[i] << "\t" << params[i] << endl;
}
} }
bn.freeDistributions(); bn.freeDistributions();
} }
@ -117,11 +142,11 @@ void
markovNetwork (int argc, const char* argv[]) markovNetwork (int argc, const char* argv[])
{ {
FactorGraph fg (argv[1]); FactorGraph fg (argv[1]);
//fg.printFactorGraph(); //fg.printGraphicalModel();
VarSet queryVars; VarSet queryVars;
for (int i = 2; i < argc; i++) { for (int i = 2; i < argc; i++) {
string arg = argv[i]; const string& arg = argv[i];
if (arg.find ('=') == std::string::npos) { if (arg.find ('=') == std::string::npos) {
if (!Util::isInteger (arg)) { if (!Util::isInteger (arg)) {
cerr << "error: `" << arg << "' " ; cerr << "error: `" << arg << "' " ;
@ -129,16 +154,16 @@ markovNetwork (int argc, const char* argv[])
cerr << endl; cerr << endl;
exit (0); exit (0);
} }
unsigned varId; Vid vid;
stringstream ss; stringstream ss;
ss << arg; ss << arg;
ss >> varId; ss >> vid;
Variable* queryVar = fg.getVariableById (varId); Variable* queryVar = fg.getFgVarNode (vid);
if (queryVar) { if (queryVar) {
queryVars.push_back (queryVar); queryVars.push_back (queryVar);
} else { } else {
cerr << "error: there isn't a variable with " ; cerr << "error: there isn't a variable with " ;
cerr << "`" << varId << "' as id" ; cerr << "`" << vid << "' as id" ;
cerr << endl; cerr << endl;
exit (0); exit (0);
} }
@ -160,11 +185,11 @@ markovNetwork (int argc, const char* argv[])
cerr << endl; cerr << endl;
exit (0); exit (0);
} }
unsigned varId; Vid vid;
stringstream ss; stringstream ss;
ss << arg.substr (0, pos); ss << arg.substr (0, pos);
ss >> varId; ss >> vid;
Variable* var = fg.getVariableById (varId); Variable* var = fg.getFgVarNode (vid);
if (var) { if (var) {
if (!Util::isInteger (arg.substr (pos + 1))) { if (!Util::isInteger (arg.substr (pos + 1))) {
cerr << "error: `" << arg.substr (pos + 1) << "' " ; cerr << "error: `" << arg.substr (pos + 1) << "' " ;
@ -176,7 +201,6 @@ markovNetwork (int argc, const char* argv[])
stringstream ss; stringstream ss;
ss << arg.substr (pos + 1); ss << arg.substr (pos + 1);
ss >> stateIndex; ss >> stateIndex;
cout << "si: " << stateIndex << endl;
if (var->isValidStateIndex (stateIndex)) { if (var->isValidStateIndex (stateIndex)) {
var->setEvidence (stateIndex); var->setEvidence (stateIndex);
} else { } else {
@ -188,27 +212,35 @@ markovNetwork (int argc, const char* argv[])
} }
} else { } else {
cerr << "error: there isn't a variable with " ; cerr << "error: there isn't a variable with " ;
cerr << "`" << varId << "' as id" ; cerr << "`" << vid << "' as id" ;
cerr << endl; cerr << endl;
exit (0); exit (0);
} }
} }
} }
Solver* solver = new SPSolver (fg);
SPSolver solver (fg); runSolver (solver, queryVars);
if (queryVars.size() == 0) { fg.freeDistributions();
solver.runSolver(); }
solver.printAllPosterioris();
} else if (queryVars.size() == 1) {
solver.runSolver();
solver.printPosterioriOf (queryVars[0]); void
} else { runSolver (Solver* solver, const VarSet& queryVars)
assert (false); //FIXME {
//Domain domain = BayesNet::getInstantiations(queryVars); VidSet vids;
//ParamSet params = solver.getJointDistribution (queryVars); for (unsigned i = 0; i < queryVars.size(); i++) {
//for (unsigned i = 0; i < params.size(); i++) { vids.push_back (queryVars[i]->getVarId());
// cout << domain[i] << "\t" << params[i] << endl; }
//} if (queryVars.size() == 0) {
} solver->runSolver();
solver->printAllPosterioris();
} else if (queryVars.size() == 1) {
solver->runSolver();
solver->printPosterioriOf (vids[0]);
} else {
solver->printJointDistributionOf (vids);
}
delete solver;
} }

View File

@ -1,41 +1,39 @@
#include <cstdlib> #include <cstdlib>
#include <vector>
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
#include <vector>
#include <string>
#include <YapInterface.h> #include <YapInterface.h>
#include "callgrind.h"
#include "BayesNet.h" #include "BayesNet.h"
#include "BayesNode.h" #include "FactorGraph.h"
#include "BPSolver.h" #include "BPSolver.h"
#include "SPSolver.h"
#include "CountingBP.h"
using namespace std; using namespace std;
int int
createNetwork (void) createNetwork (void)
{ {
Statistics::numCreatedNets ++; //Statistics::numCreatedNets ++;
cout << "creating network number " << Statistics::numCreatedNets << endl; //cout << "creating network number " << Statistics::numCreatedNets << endl;
if (Statistics::numCreatedNets == 1) {
//CALLGRIND_START_INSTRUMENTATION;
}
BayesNet* bn = new BayesNet();
BayesNet* bn = new BayesNet();
YAP_Term varList = YAP_ARG1; YAP_Term varList = YAP_ARG1;
while (varList != YAP_TermNil()) { while (varList != YAP_TermNil()) {
YAP_Term var = YAP_HeadOfTerm (varList); YAP_Term var = YAP_HeadOfTerm (varList);
unsigned varId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (1, var)); Vid vid = (Vid) YAP_IntOfTerm (YAP_ArgOfTerm (1, var));
unsigned dsize = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (2, var)); unsigned dsize = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (2, var));
int evidence = (int) YAP_IntOfTerm (YAP_ArgOfTerm (3, var)); int evidence = (int) YAP_IntOfTerm (YAP_ArgOfTerm (3, var));
YAP_Term parentL = YAP_ArgOfTerm (4, var); YAP_Term parentL = YAP_ArgOfTerm (4, var);
unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (5, var)); unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (5, var));
NodeSet parents; BnNodeSet parents;
while (parentL != YAP_TermNil()) { while (parentL != YAP_TermNil()) {
unsigned parentId = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (parentL)); unsigned parentId = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (parentL));
BayesNode* parent = bn->getNode (parentId); BayesNode* parent = bn->getBayesNode (parentId);
if (!parent) { if (!parent) {
parent = bn->addNode (parentId); parent = bn->addNode (parentId);
} }
@ -47,23 +45,20 @@ createNetwork (void)
dist = new Distribution (distId); dist = new Distribution (distId);
bn->addDistribution (dist); bn->addDistribution (dist);
} }
BayesNode* node = bn->getNode (varId); BayesNode* node = bn->getBayesNode (vid);
if (node) { if (node) {
node->setData (dsize, evidence, parents, dist); node->setData (dsize, evidence, parents, dist);
} else { } else {
bn->addNode (varId, dsize, evidence, parents, dist); bn->addNode (vid, dsize, evidence, parents, dist);
} }
varList = YAP_TailOfTerm (varList); varList = YAP_TailOfTerm (varList);
} }
bn->setIndexes(); bn->setIndexes();
if (Statistics::numCreatedNets == 1688) { // if (Statistics::numCreatedNets == 1688) {
Statistics::writeStats(); // Statistics::writeStats();
//Statistics::writeStats(); // exit (0);
//CALLGRIND_STOP_INSTRUMENTATION; // }
//CALLGRIND_DUMP_STATS;
//exit (0);
}
YAP_Int p = (YAP_Int) (bn); YAP_Int p = (YAP_Int) (bn);
return YAP_Unify (YAP_MkIntTerm (p), YAP_ARG2); return YAP_Unify (YAP_MkIntTerm (p), YAP_ARG2);
} }
@ -73,20 +68,20 @@ createNetwork (void)
int int
setExtraVarsInfo (void) setExtraVarsInfo (void)
{ {
BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1); BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
YAP_Term varsInfoL = YAP_ARG2; YAP_Term varsInfoL = YAP_ARG2;
while (varsInfoL != YAP_TermNil()) { while (varsInfoL != YAP_TermNil()) {
YAP_Term head = YAP_HeadOfTerm (varsInfoL); YAP_Term head = YAP_HeadOfTerm (varsInfoL);
unsigned varId = YAP_IntOfTerm (YAP_ArgOfTerm (1, head)); Vid vid = YAP_IntOfTerm (YAP_ArgOfTerm (1, head));
YAP_Atom label = YAP_AtomOfTerm (YAP_ArgOfTerm (2, head)); YAP_Atom label = YAP_AtomOfTerm (YAP_ArgOfTerm (2, head));
YAP_Term domainL = YAP_ArgOfTerm (3, head); YAP_Term domainL = YAP_ArgOfTerm (3, head);
Domain domain; Domain domain;
while (domainL != YAP_TermNil()) { while (domainL != YAP_TermNil()) {
YAP_Atom atom = YAP_AtomOfTerm (YAP_HeadOfTerm (domainL)); YAP_Atom atom = YAP_AtomOfTerm (YAP_HeadOfTerm (domainL));
domain.push_back ((char*) YAP_AtomName (atom)); domain.push_back ((char*) YAP_AtomName (atom));
domainL = YAP_TailOfTerm (domainL); domainL = YAP_TailOfTerm (domainL);
} }
BayesNode* node = bn->getNode (varId); BayesNode* node = bn->getBayesNode (vid);
assert (node); assert (node);
node->setLabel ((char*) YAP_AtomName (label)); node->setLabel ((char*) YAP_AtomName (label));
node->setDomain (domain); node->setDomain (domain);
@ -100,8 +95,8 @@ setExtraVarsInfo (void)
int int
setParameters (void) setParameters (void)
{ {
BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1); BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
YAP_Term distList = YAP_ARG2; YAP_Term distList = YAP_ARG2;
while (distList != YAP_TermNil()) { while (distList != YAP_TermNil()) {
YAP_Term dist = YAP_HeadOfTerm (distList); YAP_Term dist = YAP_HeadOfTerm (distList);
unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (1, dist)); unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (1, dist));
@ -112,6 +107,11 @@ setParameters (void)
paramL = YAP_TailOfTerm (paramL); paramL = YAP_TailOfTerm (paramL);
} }
bn->getDistribution(distId)->updateParameters(params); bn->getDistribution(distId)->updateParameters(params);
if (Statistics::numCreatedNets == 4) {
cout << "dist " << distId << " parameters:" ;
cout << Util::parametersToString (params);
cout << endl;
}
distList = YAP_TailOfTerm (distList); distList = YAP_TailOfTerm (distList);
} }
return TRUE; return TRUE;
@ -122,84 +122,126 @@ setParameters (void)
int int
runSolver (void) runSolver (void)
{ {
BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1); BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
YAP_Term taskList = YAP_ARG2; YAP_Term taskList = YAP_ARG2;
vector<VidSet> tasks;
vector<NodeSet> tasks; VidSet marginalVids;
NodeSet marginalVars;
while (taskList != YAP_TermNil()) { while (taskList != YAP_TermNil()) {
if (YAP_IsPairTerm (YAP_HeadOfTerm (taskList))) { if (YAP_IsPairTerm (YAP_HeadOfTerm (taskList))) {
NodeSet jointVars; VidSet jointVids;
YAP_Term jointList = YAP_HeadOfTerm (taskList); YAP_Term jointList = YAP_HeadOfTerm (taskList);
while (jointList != YAP_TermNil()) { while (jointList != YAP_TermNil()) {
unsigned varId = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (jointList)); Vid vid = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (jointList));
assert (bn->getNode (varId)); assert (bn->getBayesNode (vid));
jointVars.push_back (bn->getNode (varId)); jointVids.push_back (vid);
jointList = YAP_TailOfTerm (jointList); jointList = YAP_TailOfTerm (jointList);
} }
tasks.push_back (jointVars); tasks.push_back (jointVids);
} else { } else {
unsigned varId = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (taskList)); Vid vid = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (taskList));
BayesNode* node = bn->getNode (varId); assert (bn->getBayesNode (vid));
assert (node); tasks.push_back (VidSet() = {vid});
tasks.push_back (NodeSet() = {node}); marginalVids.push_back (vid);
marginalVars.push_back (node);
} }
taskList = YAP_TailOfTerm (taskList); taskList = YAP_TailOfTerm (taskList);
} }
/*
cout << "tasks to resolve:" << endl;
for (unsigned i = 0; i < tasks.size(); i++) {
cout << "i" << ": " ;
if (tasks[i].size() == 1) {
cout << tasks[i][0]->getVarId() << endl;
} else {
for (unsigned j = 0; j < tasks[i].size(); j++) {
cout << tasks[i][j]->getVarId() << " " ;
}
cout << endl;
}
}
*/
cerr << "prunning now..." << endl; // cout << "inference tasks:" << endl;
BayesNet* prunedNet = bn->pruneNetwork (marginalVars); // for (unsigned i = 0; i < tasks.size(); i++) {
bn->printNetworkToFile ("net.txt"); // cout << "i" << ": " ;
BPSolver solver (*prunedNet); // if (tasks[i].size() == 1) {
cerr << "solving marginals now..." << endl; // cout << tasks[i][0] << endl;
solver.runSolver(); // } else {
cerr << "calculating joints now ..." << endl; // for (unsigned j = 0; j < tasks[i].size(); j++) {
// cout << tasks[i][j] << " " ;
// }
// cout << endl;
// }
// }
Solver* solver = 0;
GraphicalModel* gm = 0;
VidSet vids;
const BnNodeSet& nodes = bn->getBayesNodes();
for (unsigned i = 0; i < nodes.size(); i++) {
vids.push_back (nodes[i]->getVarId());
}
if (marginalVids.size() != 0) {
bn->exportToDotFormat ("bn unbayes.dot");
BayesNet* mrn = bn->getMinimalRequesiteNetwork (marginalVids);
mrn->exportToDotFormat ("bn bayes.dot");
//BayesNet* mrn = bn->getMinimalRequesiteNetwork (vids);
if (SolverOptions::convertBn2Fg) {
gm = new FactorGraph (*mrn);
if (SolverOptions::compressFactorGraph) {
solver = new CountingBP (*static_cast<FactorGraph*> (gm));
} else {
solver = new SPSolver (*static_cast<FactorGraph*> (gm));
}
if (SolverOptions::runBayesBall) {
delete mrn;
}
} else {
gm = mrn;
solver = new BPSolver (*static_cast<BayesNet*> (gm));
}
solver->runSolver();
}
vector<ParamSet> results; vector<ParamSet> results;
results.reserve (tasks.size()); results.reserve (tasks.size());
for (unsigned i = 0; i < tasks.size(); i++) { for (unsigned i = 0; i < tasks.size(); i++) {
if (tasks[i].size() == 1) { if (tasks[i].size() == 1) {
BayesNode* node = prunedNet->getNode (tasks[i][0]->getVarId()); results.push_back (solver->getPosterioriOf (tasks[i][0]));
results.push_back (solver.getPosterioriOf (node));
} else { } else {
BPSolver solver2 (*bn); static int count = 0;
cout << "calculating an join dist on: " ; cout << "calculating joint... " << count ++ << endl;
for (unsigned j = 0; j < tasks[i].size(); j++) { //if (count == 5225) {
cout << tasks[i][j]->getVarId() << " " ; // Statistics::printCompressingStats ("compressing.stats");
//}
Solver* solver2 = 0;
GraphicalModel* gm2 = 0;
bn->exportToDotFormat ("joint.dot");
BayesNet* mrn2;
if (SolverOptions::runBayesBall) {
mrn2 = bn->getMinimalRequesiteNetwork (tasks[i]);
} else {
mrn2 = bn;
} }
cout << "..." << endl; if (SolverOptions::convertBn2Fg) {
results.push_back (solver2.getJointDistribution (tasks[i])); gm2 = new FactorGraph (*mrn2);
if (SolverOptions::compressFactorGraph) {
solver2 = new CountingBP (*static_cast<FactorGraph*> (gm2));
} else {
solver2 = new SPSolver (*static_cast<FactorGraph*> (gm2));
}
if (SolverOptions::runBayesBall) {
delete mrn2;
}
} else {
gm2 = mrn2;
solver2 = new BPSolver (*static_cast<BayesNet*> (gm2));
}
results.push_back (solver2->getJointDistributionOf (tasks[i]));
delete solver2;
delete gm2;
} }
} }
delete prunedNet; delete solver;
delete gm;
YAP_Term list = YAP_TermNil(); YAP_Term list = YAP_TermNil();
for (int i = results.size() - 1; i >= 0; i--) { for (int i = results.size() - 1; i >= 0; i--) {
const ParamSet& beliefs = results[i]; const ParamSet& beliefs = results[i];
YAP_Term queryBeliefsL = YAP_TermNil(); YAP_Term queryBeliefsL = YAP_TermNil();
for (int j = beliefs.size() - 1; j >= 0; j--) { for (int j = beliefs.size() - 1; j >= 0; j--) {
YAP_Int sl1 = YAP_InitSlot(list); YAP_Int sl1 = YAP_InitSlot (list);
YAP_Term belief = YAP_MkFloatTerm (beliefs[j]); YAP_Term belief = YAP_MkFloatTerm (beliefs[j]);
queryBeliefsL = YAP_MkPairTerm (belief, queryBeliefsL); queryBeliefsL = YAP_MkPairTerm (belief, queryBeliefsL);
list = YAP_GetFromSlot(sl1); list = YAP_GetFromSlot (sl1);
YAP_RecoverSlots(1); YAP_RecoverSlots (1);
} }
list = YAP_MkPairTerm (queryBeliefsL, list); list = YAP_MkPairTerm (queryBeliefsL, list);
} }
@ -210,8 +252,9 @@ runSolver (void)
int int
deleteBayesNet (void) freeBayesNetwork (void)
{ {
//Statistics::printCompressingStats ("../../compressing.stats");
BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1); BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
bn->freeDistributions(); bn->freeDistributions();
delete bn; delete bn;
@ -223,10 +266,10 @@ deleteBayesNet (void)
extern "C" void extern "C" void
init_predicates (void) init_predicates (void)
{ {
YAP_UserCPredicate ("create_network", createNetwork, 2); YAP_UserCPredicate ("create_network", createNetwork, 2);
YAP_UserCPredicate ("set_extra_vars_info", setExtraVarsInfo, 2); YAP_UserCPredicate ("set_extra_vars_info", setExtraVarsInfo, 2);
YAP_UserCPredicate ("set_parameters", setParameters, 2); YAP_UserCPredicate ("set_parameters", setParameters, 2);
YAP_UserCPredicate ("run_solver", runSolver, 3); YAP_UserCPredicate ("run_solver", runSolver, 3);
YAP_UserCPredicate ("delete_bayes_net", deleteBayesNet, 1); YAP_UserCPredicate ("free_bayesian_network", freeBayesNetwork, 1);
} }

View File

@ -0,0 +1,278 @@
#include "LiftedFG.h"
#include "FgVarNode.h"
#include "Factor.h"
#include "Distribution.h"
LiftedFG::LiftedFG (const FactorGraph& fg)
{
groundFg_ = &fg;
freeColor_ = 0;
const FgVarSet& varNodes = fg.getFgVarNodes();
const FactorSet& factors = fg.getFactors();
varColors_.resize (varNodes.size());
factorColors_.resize (factors.size());
for (unsigned i = 0; i < factors.size(); i++) {
factors[i]->setIndex (i);
}
// create the initial variable colors
VarColorMap colorMap;
for (unsigned i = 0; i < varNodes.size(); i++) {
unsigned dsize = varNodes[i]->getDomainSize();
VarColorMap::iterator it = colorMap.find (dsize);
if (it == colorMap.end()) {
it = colorMap.insert (make_pair (
dsize, vector<Color> (dsize + 1,-1))).first;
}
unsigned idx;
if (varNodes[i]->hasEvidence()) {
idx = varNodes[i]->getEvidence();
} else {
idx = dsize;
}
vector<Color>& stateColors = it->second;
if (stateColors[idx] == -1) {
stateColors[idx] = getFreeColor();
}
setColor (varNodes[i], stateColors[idx]);
}
// create the initial factor colors
DistColorMap distColors;
for (unsigned i = 0; i < factors.size(); i++) {
Distribution* dist = factors[i]->getDistribution();
DistColorMap::iterator it = distColors.find (dist);
if (it == distColors.end()) {
it = distColors.insert (make_pair (dist, getFreeColor())).first;
}
setColor (factors[i], it->second);
}
VarSignMap varGroups;
FactorSignMap factorGroups;
bool groupsHaveChanged = true;
unsigned nIter = 0;
while (groupsHaveChanged || nIter == 1) {
nIter ++;
if (Statistics::numCreatedNets == 4) {
cout << "--------------------------------------------" << endl;
cout << "Iteration " << nIter << endl;
cout << "--------------------------------------------" << endl;
}
unsigned prevFactorGroupsSize = factorGroups.size();
factorGroups.clear();
// set a new color to the factors with the same signature
for (unsigned i = 0; i < factors.size(); i++) {
const string& signatureId = getSignatureId (factors[i]);
// cout << factors[i]->getLabel() << " signature: " ;
// cout<< signatureId << endl;
FactorSignMap::iterator it = factorGroups.find (signatureId);
if (it == factorGroups.end()) {
it = factorGroups.insert (make_pair (signatureId, FactorSet())).first;
}
it->second.push_back (factors[i]);
}
if (nIter > 0)
for (FactorSignMap::iterator it = factorGroups.begin();
it != factorGroups.end(); it++) {
Color newColor = getFreeColor();
FactorSet& groupMembers = it->second;
for (unsigned i = 0; i < groupMembers.size(); i++) {
setColor (groupMembers[i], newColor);
}
}
// set a new color to the variables with the same signature
unsigned prevVarGroupsSize = varGroups.size();
varGroups.clear();
for (unsigned i = 0; i < varNodes.size(); i++) {
const string& signatureId = getSignatureId (varNodes[i]);
VarSignMap::iterator it = varGroups.find (signatureId);
// cout << varNodes[i]->getLabel() << " signature: " ;
// cout << signatureId << endl;
if (it == varGroups.end()) {
it = varGroups.insert (make_pair (signatureId, FgVarSet())).first;
}
it->second.push_back (varNodes[i]);
}
if (nIter > 0)
for (VarSignMap::iterator it = varGroups.begin();
it != varGroups.end(); it++) {
Color newColor = getFreeColor();
FgVarSet& groupMembers = it->second;
for (unsigned i = 0; i < groupMembers.size(); i++) {
setColor (groupMembers[i], newColor);
}
}
//if (nIter >= 3) cout << "bigger than three: " << nIter << endl;
groupsHaveChanged = prevVarGroupsSize != varGroups.size()
|| prevFactorGroupsSize != factorGroups.size();
}
printGroups (varGroups, factorGroups);
for (VarSignMap::iterator it = varGroups.begin();
it != varGroups.end(); it++) {
CFgVarSet vars = it->second;
VarCluster* vc = new VarCluster (vars);
for (unsigned i = 0; i < vars.size(); i++) {
vid2VarCluster_.insert (make_pair (vars[i]->getVarId(), vc));
}
varClusters_.push_back (vc);
}
for (FactorSignMap::iterator it = factorGroups.begin();
it != factorGroups.end(); it++) {
VarClusterSet varClusters;
Factor* groundFactor = it->second[0];
FgVarSet groundVars = groundFactor->getFgVarNodes();
for (unsigned i = 0; i < groundVars.size(); i++) {
Vid vid = groundVars[i]->getVarId();
varClusters.push_back (vid2VarCluster_.find (vid)->second);
}
factorClusters_.push_back (new FactorCluster (it->second, varClusters));
}
}
LiftedFG::~LiftedFG (void)
{
for (unsigned i = 0; i < varClusters_.size(); i++) {
delete varClusters_[i];
}
for (unsigned i = 0; i < factorClusters_.size(); i++) {
delete factorClusters_[i];
}
}
string
LiftedFG::getSignatureId (FgVarNode* var) const
{
stringstream ss;
CFactorSet myFactors = var->getFactors();
ss << myFactors.size();
for (unsigned i = 0; i < myFactors.size(); i++) {
ss << "." << getColor (myFactors[i]);
ss << "." << myFactors[i]->getIndexOf(var);
}
ss << "." << getColor (var);
return ss.str();
}
string
LiftedFG::getSignatureId (Factor* factor) const
{
stringstream ss;
CFgVarSet myVars = factor->getFgVarNodes();
ss << myVars.size();
for (unsigned i = 0; i < myVars.size(); i++) {
ss << "." << getColor (myVars[i]);
}
ss << "." << getColor (factor);
return ss.str();
}
FactorGraph*
LiftedFG::getCompressedFactorGraph (void)
{
FactorGraph* fg = new FactorGraph();
for (unsigned i = 0; i < varClusters_.size(); i++) {
FgVarNode* var = varClusters_[i]->getGroundFgVarNodes()[0];
FgVarNode* newVar = new FgVarNode (var);
newVar->setIndex (i);
varClusters_[i]->setRepresentativeVariable (newVar);
fg->addVariable (newVar);
}
for (unsigned i = 0; i < factorClusters_.size(); i++) {
FgVarSet myGroundVars;
const VarClusterSet& myVarClusters = factorClusters_[i]->getVarClusters();
for (unsigned j = 0; j < myVarClusters.size(); j++) {
myGroundVars.push_back (myVarClusters[j]->getRepresentativeVariable());
}
Factor* newFactor = new Factor (myGroundVars,
factorClusters_[i]->getGroundFactors()[0]->getDistribution());
factorClusters_[i]->setRepresentativeFactor (newFactor);
fg->addFactor (newFactor);
}
return fg;
}
unsigned
LiftedFG::getGroundEdgeCount (FactorCluster* fc, VarCluster* vc) const
{
CFactorSet clusterGroundFactors = fc->getGroundFactors();
FgVarNode* var = vc->getGroundFgVarNodes()[0];
unsigned count = 0;
for (unsigned i = 0; i < clusterGroundFactors.size(); i++) {
if (clusterGroundFactors[i]->getIndexOf (var) != -1) {
count ++;
}
}
/*
CFgVarSet vars = vc->getGroundFgVarNodes();
for (unsigned i = 1; i < vars.size(); i++) {
FgVarNode* var = vc->getGroundFgVarNodes()[i];
unsigned count2 = 0;
for (unsigned i = 0; i < clusterGroundFactors.size(); i++) {
if (clusterGroundFactors[i]->getIndexOf (var) != -1) {
count2 ++;
}
}
if (count != count2) { cout << "oops!" << endl; abort(); }
}
*/
return count;
}
void
LiftedFG::printGroups (const VarSignMap& varGroups,
const FactorSignMap& factorGroups) const
{
cout << "variable groups:" << endl;
unsigned count = 0;
for (VarSignMap::const_iterator it = varGroups.begin();
it != varGroups.end(); it++) {
const FgVarSet& groupMembers = it->second;
if (groupMembers.size() > 0) {
cout << ++count << ": " ;
//if (groupMembers.size() > 1) {
for (unsigned i = 0; i < groupMembers.size(); i++) {
cout << groupMembers[i]->getLabel() << " " ;
}
//}
cout << endl;
}
}
cout << endl;
cout << "factor groups:" << endl;
count = 0;
for (FactorSignMap::const_iterator it = factorGroups.begin();
it != factorGroups.end(); it++) {
const FactorSet& groupMembers = it->second;
if (groupMembers.size() > 0) {
cout << ++count << ": " ;
//if (groupMembers.size() > 1) {
for (unsigned i = 0; i < groupMembers.size(); i++) {
cout << groupMembers[i]->getLabel() << " " ;
}
//}
cout << endl;
}
}
}

View File

@ -0,0 +1,152 @@
#ifndef BP_LIFTED_FG_H
#define BP_LIFTED_FG_H
#include <unordered_map>
#include "FactorGraph.h"
#include "FgVarNode.h"
#include "Factor.h"
#include "Shared.h"
class VarCluster;
class FactorCluster;
class Distribution;
typedef long Color;
typedef vector<Color> Signature;
typedef vector<VarCluster*> VarClusterSet;
typedef vector<FactorCluster*> FactorClusterSet;
typedef map<string, FgVarSet> VarSignMap;
typedef map<string, FactorSet> FactorSignMap;
typedef map<unsigned, vector<Color> > VarColorMap;
typedef map<Distribution*, Color> DistColorMap;
typedef map<Vid, VarCluster*> Vid2VarCluster;
class VarCluster
{
public:
VarCluster (CFgVarSet vs)
{
for (unsigned i = 0; i < vs.size(); i++) {
groundVars_.push_back (vs[i]);
}
}
void addFactorCluster (FactorCluster* fc)
{
factorClusters_.push_back (fc);
}
const FactorClusterSet& getFactorClusters (void) const
{
return factorClusters_;
}
FgVarNode* getRepresentativeVariable (void) const { return representVar_; }
void setRepresentativeVariable (FgVarNode* v) { representVar_ = v; }
CFgVarSet getGroundFgVarNodes (void) const { return groundVars_; }
private:
FgVarSet groundVars_;
FactorClusterSet factorClusters_;
FgVarNode* representVar_;
};
class FactorCluster
{
public:
FactorCluster (CFactorSet groundFactors, const VarClusterSet& vcs)
{
groundFactors_ = groundFactors;
varClusters_ = vcs;
for (unsigned i = 0; i < varClusters_.size(); i++) {
varClusters_[i]->addFactorCluster (this);
}
}
const VarClusterSet& getVarClusters (void) const
{
return varClusters_;
}
bool containsGround (const Factor* f)
{
for (unsigned i = 0; i < groundFactors_.size(); i++) {
if (groundFactors_[i] == f) {
return true;
}
}
return false;
}
Factor* getRepresentativeFactor (void) const { return representFactor_; }
void setRepresentativeFactor (Factor* f) { representFactor_ = f; }
CFactorSet getGroundFactors (void) const { return groundFactors_; }
private:
FactorSet groundFactors_;
VarClusterSet varClusters_;
Factor* representFactor_;
};
class LiftedFG
{
public:
LiftedFG (const FactorGraph&);
~LiftedFG (void);
FactorGraph* getCompressedFactorGraph (void);
unsigned getGroundEdgeCount (FactorCluster*, VarCluster*) const;
void printGroups (const VarSignMap& varGroups,
const FactorSignMap& factorGroups) const;
FgVarNode* getEquivalentVariable (Vid vid)
{
VarCluster* vc = vid2VarCluster_.find (vid)->second;
return vc->getRepresentativeVariable();
}
const VarClusterSet& getVariableClusters (void) { return varClusters_; }
const FactorClusterSet& getFactorClusters (void) { return factorClusters_; }
private:
string getSignatureId (FgVarNode*) const;
string getSignatureId (Factor*) const;
Color getFreeColor (void) { return ++freeColor_ -1; }
Color getColor (FgVarNode* v) const { return varColors_[v->getIndex()]; }
Color getColor (Factor* f) const { return factorColors_[f->getIndex()]; }
void setColor (FgVarNode* v, Color c)
{
varColors_[v->getIndex()] = c;
}
void setColor (Factor* f, Color c)
{
factorColors_[f->getIndex()] = c;
}
VarCluster* getVariableCluster (Vid vid) const
{
return vid2VarCluster_.find (vid)->second;
}
Color freeColor_;
vector<Color> varColors_;
vector<Color> factorColors_;
VarClusterSet varClusters_;
FactorClusterSet factorClusters_;
Vid2VarCluster vid2VarCluster_;
const FactorGraph* groundFg_;
};
#endif // BP_LIFTED_FG_H

View File

@ -50,28 +50,33 @@ CWD=$(PWD)
HEADERS = \ HEADERS = \
$(srcdir)/GraphicalModel.h \ $(srcdir)/GraphicalModel.h \
$(srcdir)/Variable.h \ $(srcdir)/Variable.h \
$(srcdir)/Distribution.h \
$(srcdir)/BayesNet.h \ $(srcdir)/BayesNet.h \
$(srcdir)/BayesNode.h \ $(srcdir)/BayesNode.h \
$(srcdir)/Distribution.h \ $(srcdir)/LiftedFG.h \
$(srcdir)/CptEntry.h \ $(srcdir)/CptEntry.h \
$(srcdir)/FactorGraph.h \ $(srcdir)/FactorGraph.h \
$(srcdir)/FgVarNode.h \ $(srcdir)/FgVarNode.h \
$(srcdir)/Factor.h \ $(srcdir)/Factor.h \
$(srcdir)/Solver.h \ $(srcdir)/Solver.h \
$(srcdir)/BPSolver.h \ $(srcdir)/BPSolver.h \
$(srcdir)/BpNode.h \ $(srcdir)/BPNodeInfo.h \
$(srcdir)/SPSolver.h \ $(srcdir)/SPSolver.h \
$(srcdir)/CountingBP.h \
$(srcdir)/Shared.h \ $(srcdir)/Shared.h \
$(srcdir)/xmlParser/xmlParser.h $(srcdir)/xmlParser/xmlParser.h
CPP_SOURCES = \ CPP_SOURCES = \
$(srcdir)/BayesNet.cpp \ $(srcdir)/BayesNet.cpp \
$(srcdir)/BayesNode.cpp \ $(srcdir)/BayesNode.cpp \
$(srcdir)/FactorGraph.cpp \ $(srcdir)/FactorGraph.cpp \
$(srcdir)/Factor.cpp \ $(srcdir)/Factor.cpp \
$(srcdir)/LiftedFG.cpp \
$(srcdir)/BPSolver.cpp \ $(srcdir)/BPSolver.cpp \
$(srcdir)/BpNode.cpp \ $(srcdir)/BPNodeInfo.cpp \
$(srcdir)/SPSolver.cpp \ $(srcdir)/SPSolver.cpp \
$(srcdir)/CountingBP.cpp \
$(srcdir)/Util.cpp \
$(srcdir)/HorusYap.cpp \ $(srcdir)/HorusYap.cpp \
$(srcdir)/HorusCli.cpp \ $(srcdir)/HorusCli.cpp \
$(srcdir)/xmlParser/xmlParser.cpp $(srcdir)/xmlParser/xmlParser.cpp
@ -82,22 +87,38 @@ OBJS = \
FactorGraph.o \ FactorGraph.o \
Factor.o \ Factor.o \
BPSolver.o \ BPSolver.o \
BpNode.o \ BPNodeInfo.o \
SPSolver.o \ SPSolver.o \
HorusYap.o \ Util.o \
xmlParser.o LiftedFG.o \
CountingBP.o \
HorusYap.o
HCLI_OBJS = \
BayesNet.o \
BayesNode.o \
FactorGraph.o \
Factor.o \
BPSolver.o \
BPNodeInfo.o \
SPSolver.o \
Util.o \
LiftedFG.o \
CountingBP.o \
HorusCli.o \
xmlParser/xmlParser.o
SOBJS=horus.@SO@ SOBJS=horus.@SO@
all: $(SOBJS) all: $(SOBJS) hcli
# default rule # default rule
%.o : $(srcdir)/%.cpp %.o : $(srcdir)/%.cpp
$(CXX) -c $(CXXFLAGS) $< -o $@ $(CXX) -c $(CXXFLAGS) $< -o $@
xmlParser.o : $(srcdir)/xmlParser/xmlParser.cpp xmlParser/xmlParser.o : $(srcdir)/xmlParser/xmlParser.cpp
$(CXX) -c $(CXXFLAGS) $< -o $@ $(CXX) -c $(CXXFLAGS) $< -o $@
@ -105,7 +126,7 @@ xmlParser.o : $(srcdir)/xmlParser/xmlParser.cpp
@DO_SECOND_LD@ @SHLIB_CXX_LD@ -o horus.@SO@ $(OBJS) @EXTRA_LIBS_FOR_SWIDLLS@ @DO_SECOND_LD@ @SHLIB_CXX_LD@ -o horus.@SO@ $(OBJS) @EXTRA_LIBS_FOR_SWIDLLS@
hcli: $(OBJS) hcli: $(HCLI_OBJS)
$(CXX) -o hcli $(HCLI_OBJS) $(CXX) -o hcli $(HCLI_OBJS)

View File

@ -1,38 +1,77 @@
#include <cassert> #include <cassert>
#include <algorithm> #include <limits>
#include <iostream> #include <iostream>
#include "SPSolver.h" #include "SPSolver.h"
#include "FactorGraph.h" #include "FactorGraph.h"
#include "FgVarNode.h" #include "FgVarNode.h"
#include "Factor.h" #include "Factor.h"
#include "Shared.h"
SPSolver* Link::klass = 0;
SPSolver::SPSolver (const FactorGraph& fg) : Solver (&fg) SPSolver::SPSolver (FactorGraph& fg) : Solver (&fg)
{ {
fg_ = &fg; fg_ = &fg;
accuracy_ = 0.0001;
maxIter_ = 10000;
//schedule_ = S_SEQ_FIXED;
//schedule_ = S_SEQ_RANDOM;
//schedule_ = S_SEQ_PARALLEL;
schedule_ = S_MAX_RESIDUAL;
Link::klass = this;
FgVarSet vars = fg_->getFgVarNodes();
for (unsigned i = 0; i < vars.size(); i++) {
msgs_.push_back (new MessageBanket (vars[i]));
}
} }
SPSolver::~SPSolver (void) SPSolver::~SPSolver (void)
{ {
for (unsigned i = 0; i < msgs_.size(); i++) { for (unsigned i = 0; i < varsI_.size(); i++) {
delete msgs_[i]; delete varsI_[i];
} }
for (unsigned i = 0; i < factorsI_.size(); i++) {
delete factorsI_[i];
}
for (unsigned i = 0; i < links_.size(); i++) {
delete links_[i];
}
}
void
SPSolver::runTreeSolver (void)
{
CFactorSet factors = fg_->getFactors();
bool finish = false;
while (!finish) {
finish = true;
for (unsigned i = 0; i < factors.size(); i++) {
CLinkSet links = factorsI_[factors[i]->getIndex()]->getLinks();
for (unsigned j = 0; j < links.size(); j++) {
if (!links[j]->messageWasSended()) {
if (readyToSendMessage(links[j])) {
links[j]->setNextMessage (getFactor2VarMsg (links[j]));
links[j]->updateMessage();
}
finish = false;
}
}
}
}
}
bool
SPSolver::readyToSendMessage (const Link* link) const
{
CFgVarSet factorVars = link->getFactor()->getFgVarNodes();
for (unsigned i = 0; i < factorVars.size(); i++) {
if (factorVars[i] != link->getVariable()) {
CLinkSet links = varsI_[factorVars[i]->getIndex()]->getLinks();
for (unsigned j = 0; j < links.size(); j++) {
if (links[j]->getFactor() != link->getFactor() &&
!links[j]->messageWasSended()) {
return false;
}
}
}
}
return true;
} }
@ -40,62 +79,54 @@ SPSolver::~SPSolver (void)
void void
SPSolver::runSolver (void) SPSolver::runSolver (void)
{ {
initializeSolver();
runTreeSolver();
return;
nIter_ = 0; nIter_ = 0;
vector<Factor*> factors = fg_->getFactors(); while (!converged() && nIter_ < SolverOptions::maxIter) {
for (unsigned i = 0; i < factors.size(); i++) {
FgVarSet neighbors = factors[i]->getFgVarNodes(); nIter_ ++;
for (unsigned j = 0; j < neighbors.size(); j++) { if (DL >= 2) {
updateOrder_.push_back (Link (factors[i], neighbors[j]));
}
}
while (!converged() && nIter_ < maxIter_) {
if (DL >= 1) {
cout << endl; cout << endl;
cout << "****************************************" ; cout << "****************************************" ;
cout << "****************************************" ; cout << "****************************************" ;
cout << endl; cout << endl;
cout << " Iteration " << nIter_ + 1 << endl; cout << " Iteration " << nIter_ << endl;
cout << "****************************************" ; cout << "****************************************" ;
cout << "****************************************" ; cout << "****************************************" ;
cout << endl; cout << endl;
} }
switch (schedule_) { switch (SolverOptions::schedule) {
case SolverOptions::S_SEQ_RANDOM:
case S_SEQ_RANDOM: random_shuffle (links_.begin(), links_.end());
random_shuffle (updateOrder_.begin(), updateOrder_.end());
// no break // no break
case S_SEQ_FIXED: case SolverOptions::S_SEQ_FIXED:
for (unsigned c = 0; c < updateOrder_.size(); c++) { for (unsigned i = 0; i < links_.size(); i++) {
Link& link = updateOrder_[c]; links_[i]->setNextMessage (getFactor2VarMsg (links_[i]));
calculateNextMessage (link.source, link.destination); links_[i]->updateMessage();
updateMessage (updateOrder_[c]);
} }
break; break;
case S_PARALLEL: case SolverOptions::S_PARALLEL:
for (unsigned c = 0; c < updateOrder_.size(); c++) { for (unsigned i = 0; i < links_.size(); i++) {
Link link = updateOrder_[c]; links_[i]->setNextMessage (getFactor2VarMsg (links_[i]));
calculateNextMessage (link.source, link.destination);
} }
for (unsigned c = 0; c < updateOrder_.size(); c++) { for (unsigned i = 0; i < links_.size(); i++) {
Link link = updateOrder_[c]; links_[i]->updateMessage();
updateMessage (updateOrder_[c]);
} }
break; break;
case S_MAX_RESIDUAL: case SolverOptions::S_MAX_RESIDUAL:
maxResidualSchedule(); maxResidualSchedule();
break; break;
} }
nIter_++;
} }
cout << endl;
if (DL >= 1) { if (DL >= 2) {
if (nIter_ < maxIter_) { cout << endl;
if (nIter_ < SolverOptions::maxIter) {
cout << "Loopy Sum-Product converged in " ; cout << "Loopy Sum-Product converged in " ;
cout << nIter_ << " iterations" << endl; cout << nIter_ << " iterations" << endl;
} else { } else {
@ -108,58 +139,168 @@ SPSolver::runSolver (void)
ParamSet ParamSet
SPSolver::getPosterioriOf (const Variable* var) const SPSolver::getPosterioriOf (Vid vid) const
{ {
assert (var); assert (fg_->getFgVarNode (vid));
assert (var == fg_->getVariableById (var->getVarId())); FgVarNode* var = fg_->getFgVarNode (vid);
assert (var->getIndex() < msgs_.size()); ParamSet probs;
ParamSet probs (var->getDomainSize(), 1);
if (var->hasEvidence()) { if (var->hasEvidence()) {
for (unsigned i = 0; i < probs.size(); i++) { probs.resize (var->getDomainSize(), 0.0);
if ((int)i != var->getEvidence()) { probs[var->getEvidence()] = 1.0;
probs[i] = 0;
}
}
} else { } else {
probs.resize (var->getDomainSize(), 1.0);
MessageBanket* mb = msgs_[var->getIndex()]; CLinkSet links = varsI_[var->getIndex()]->getLinks();
const FgVarNode* varNode = fg_->getFgVarNodes()[var->getIndex()]; for (unsigned i = 0; i < links.size(); i++) {
vector<Factor*> neighbors = varNode->getFactors(); CParamSet msg = links[i]->getMessage();
for (unsigned i = 0; i < neighbors.size(); i++) {
const Message& msg = mb->getMessage (neighbors[i]);
for (unsigned j = 0; j < msg.size(); j++) { for (unsigned j = 0; j < msg.size(); j++) {
probs[j] *= msg[j]; probs[j] *= msg[j];
} }
} }
Util::normalize (probs); Util::normalize (probs);
} }
return probs; return probs;
} }
ParamSet
SPSolver::getJointDistributionOf (const VidSet& jointVids)
{
FgVarSet jointVars;
unsigned dsize = 1;
for (unsigned i = 0; i < jointVids.size(); i++) {
FgVarNode* varNode = fg_->getFgVarNode (jointVids[i]);
dsize *= varNode->getDomainSize();
jointVars.push_back (varNode);
}
unsigned maxVid = std::numeric_limits<unsigned>::max();
FgVarNode* junctionVar = new FgVarNode (maxVid, dsize);
FgVarSet factorVars = { junctionVar };
for (unsigned i = 0; i < jointVars.size(); i++) {
factorVars.push_back (jointVars[i]);
}
unsigned nParams = dsize * dsize;
ParamSet params (nParams);
for (unsigned i = 0; i < nParams; i++) {
unsigned row = i / dsize;
unsigned col = i % dsize;
if (row == col) {
params[i] = 1;
} else {
params[i] = 0;
}
}
Distribution* dist = new Distribution (params, maxVid);
Factor* newFactor = new Factor (factorVars, dist);
fg_->addVariable (junctionVar);
fg_->addFactor (newFactor);
runSolver();
ParamSet results = getPosterioriOf (maxVid);
deleteJunction (newFactor, junctionVar);
return results;
}
void
SPSolver::initializeSolver (void)
{
fg_->setIndexes();
CFgVarSet vars = fg_->getFgVarNodes();
for (unsigned i = 0; i < varsI_.size(); i++) {
delete varsI_[i];
}
varsI_.reserve (vars.size());
for (unsigned i = 0; i < vars.size(); i++) {
varsI_.push_back (new SPNodeInfo());
}
CFactorSet factors = fg_->getFactors();
for (unsigned i = 0; i < factorsI_.size(); i++) {
delete factorsI_[i];
}
factorsI_.reserve (factors.size());
for (unsigned i = 0; i < factors.size(); i++) {
factorsI_.push_back (new SPNodeInfo());
}
for (unsigned i = 0; i < links_.size(); i++) {
delete links_[i];
}
createLinks();
for (unsigned i = 0; i < links_.size(); i++) {
Factor* source = links_[i]->getFactor();
FgVarNode* dest = links_[i]->getVariable();
varsI_[dest->getIndex()]->addLink (links_[i]);
factorsI_[source->getIndex()]->addLink (links_[i]);
}
}
void
SPSolver::createLinks (void)
{
CFactorSet factors = fg_->getFactors();
for (unsigned i = 0; i < factors.size(); i++) {
CFgVarSet neighbors = factors[i]->getFgVarNodes();
for (unsigned j = 0; j < neighbors.size(); j++) {
links_.push_back (new Link (factors[i], neighbors[j]));
}
}
}
void
SPSolver::deleteJunction (Factor* f, FgVarNode* v)
{
fg_->removeFactor (f);
f->freeDistribution();
delete f;
fg_->removeVariable (v);
delete v;
}
bool bool
SPSolver::converged (void) SPSolver::converged (void)
{ {
// this can happen if the graph is fully disconnected
if (links_.size() == 0) {
return true;
}
if (nIter_ == 0 || nIter_ == 1) { if (nIter_ == 0 || nIter_ == 1) {
return false; return false;
} }
bool converged = true; bool converged = true;
for (unsigned i = 0; i < updateOrder_.size(); i++) { if (SolverOptions::schedule == SolverOptions::S_MAX_RESIDUAL) {
double residual = getResidual (updateOrder_[i]); Param maxResidual = (*(sortedOrder_.begin()))->getResidual();
if (DL >= 1) { if (maxResidual < SolverOptions::accuracy) {
cout << updateOrder_[i].toString(); converged = true;
cout << " residual = " << residual << endl; } else {
}
if (residual > accuracy_) {
converged = false; converged = false;
if (DL == 0) { }
break; } else {
for (unsigned i = 0; i < links_.size(); i++) {
double residual = links_[i]->getResidual();
if (DL >= 2) {
cout << links_[i]->toString() + " residual = " << residual << endl;
} }
} if (residual > SolverOptions::accuracy) {
converged = false;
if (DL == 0) break;
}
}
} }
return converged; return converged;
} }
@ -169,127 +310,161 @@ SPSolver::converged (void)
void void
SPSolver::maxResidualSchedule (void) SPSolver::maxResidualSchedule (void)
{ {
if (nIter_ == 0) { if (nIter_ == 1) {
for (unsigned c = 0; c < updateOrder_.size(); c++) { for (unsigned i = 0; i < links_.size(); i++) {
Link& l = updateOrder_[c]; links_[i]->setNextMessage (getFactor2VarMsg (links_[i]));
calculateNextMessage (l.source, l.destination); SortedOrder::iterator it = sortedOrder_.insert (links_[i]);
if (DL >= 1) { linkMap_.insert (make_pair (links_[i], it));
cout << updateOrder_[c].toString() << " residual = " ; if (DL >= 2 && DL < 5) {
cout << getResidual (updateOrder_[c]) << endl; cout << "calculating " << links_[i]->toString() << endl;
} }
} }
sort (updateOrder_.begin(), updateOrder_.end(), compareResidual); return;
} else { }
for (unsigned c = 0; c < updateOrder_.size(); c++) { for (unsigned c = 0; c < links_.size(); c++) {
Link& link = updateOrder_.front(); if (DL >= 2) {
updateMessage (link); cout << endl << "current residuals:" << endl;
resetResidual (link); for (SortedOrder::iterator it = sortedOrder_.begin();
it != sortedOrder_.end(); it ++) {
cout << " " << setw (30) << left << (*it)->toString();
cout << "residual = " << (*it)->getResidual() << endl;
}
}
// update the messages that depend on message source --> destination SortedOrder::iterator it = sortedOrder_.begin();
vector<Factor*> fstLevelNeighbors = link.destination->getFactors(); Link* link = *it;
for (unsigned i = 0; i < fstLevelNeighbors.size(); i++) { if (DL >= 2) {
if (fstLevelNeighbors[i] != link.source) { cout << "updating " << (*sortedOrder_.begin())->toString() << endl;
FgVarSet sndLevelNeighbors; }
sndLevelNeighbors = fstLevelNeighbors[i]->getFgVarNodes(); if (link->getResidual() < SolverOptions::accuracy) {
for (unsigned j = 0; j < sndLevelNeighbors.size(); j++) { return;
if (sndLevelNeighbors[j] != link.destination) { }
calculateNextMessage (fstLevelNeighbors[i], sndLevelNeighbors[j]); link->updateMessage();
link->clearResidual();
sortedOrder_.erase (it);
linkMap_.find (link)->second = sortedOrder_.insert (link);
// update the messages that depend on message source --> destin
CFactorSet factorNeighbors = link->getVariable()->getFactors();
for (unsigned i = 0; i < factorNeighbors.size(); i++) {
if (factorNeighbors[i] != link->getFactor()) {
CLinkSet links = factorsI_[factorNeighbors[i]->getIndex()]->getLinks();
for (unsigned j = 0; j < links.size(); j++) {
if (links[j]->getVariable() != link->getVariable()) {
if (DL >= 2 && DL < 5) {
cout << " calculating " << links[j]->toString() << endl;
} }
links[j]->setNextMessage (getFactor2VarMsg (links[j]));
LinkMap::iterator iter = linkMap_.find (links[j]);
sortedOrder_.erase (iter->second);
iter->second = sortedOrder_.insert (links[j]);
} }
} }
} }
sort (updateOrder_.begin(), updateOrder_.end(), compareResidual);
} }
} }
} }
void ParamSet
SPSolver::updateMessage (const Link& link) SPSolver::getFactor2VarMsg (const Link* link) const
{ {
updateMessage (link.source, link.destination); const Factor* src = link->getFactor();
} const FgVarNode* dest = link->getVariable();
CFgVarSet neighbors = src->getFgVarNodes();
CLinkSet links = factorsI_[src->getIndex()]->getLinks();
// calculate the product of messages that were sent
void
SPSolver::updateMessage (const Factor* src, const FgVarNode* dest)
{
msgs_[dest->getIndex()]->updateMessage (src);
/* cout << src->getLabel() << " --> " << dest->getLabel() << endl;
cout << " m: " ;
Message msg = msgs_[dest->getIndex()]->getMessage (src);
for (unsigned i = 0; i < msg.size(); i++) {
if (i != 0) cout << ", " ;
cout << msg[i];
}
cout << endl;
*/
}
void
SPSolver::calculateNextMessage (const Link& link)
{
calculateNextMessage (link.source, link.destination);
}
void
SPSolver::calculateNextMessage (const Factor* src, const FgVarNode* dest)
{
FgVarSet neighbors = src->getFgVarNodes();
// calculate the product of MessageBankets sended
// to factor `src', except from var `dest' // to factor `src', except from var `dest'
Factor result = *src; Factor result (*src);
for (unsigned i = 0; i < neighbors.size(); i++) { Factor temp;
if (neighbors[i] != dest) { if (DL >= 5) {
Message msg (neighbors[i]->getDomainSize(), 1); cout << "calculating " ;
calculateVarFactorMessage (neighbors[i], src, msg); cout << src->getLabel() << " --> " << dest->getLabel();
result *= Factor (neighbors[i], msg); cout << endl;
}
} }
// marginalize all vars except `dest'
for (unsigned i = 0; i < neighbors.size(); i++) { for (unsigned i = 0; i < neighbors.size(); i++) {
if (neighbors[i] != dest) { if (links[i]->getVariable() != dest) {
result.marginalizeVariable (neighbors[i]); if (DL >= 5) {
} cout << " message from " << links[i]->getVariable()->getLabel();
} cout << ": " ;
msgs_[dest->getIndex()]->setNextMessage (src, result.getParameters()); ParamSet p = getVar2FactorMsg (links[i]);
} cout << endl;
Factor temp2 (links[i]->getVariable(), p);
temp.multiplyByFactor (temp2);
temp2.freeDistribution();
void
SPSolver::calculateVarFactorMessage (const FgVarNode* src,
const Factor* dest,
Message& placeholder) const
{
assert (src->getDomainSize() == (int)placeholder.size());
if (src->hasEvidence()) {
for (unsigned i = 0; i < placeholder.size(); i++) {
if ((int)i != src->getEvidence()) {
placeholder[i] = 0.0;
} else { } else {
placeholder[i] = 1.0; Factor temp2 (links[i]->getVariable(), getVar2FactorMsg (links[i]));
} temp.multiplyByFactor (temp2);
} temp2.freeDistribution();
} else {
MessageBanket* mb = msgs_[src->getIndex()];
vector<Factor*> neighbors = src->getFactors();
for (unsigned i = 0; i < neighbors.size(); i++) {
if (neighbors[i] != dest) {
const Message& fromFactor = mb->getMessage (neighbors[i]);
for (unsigned j = 0; j < fromFactor.size(); j++) {
placeholder[j] *= fromFactor[j];
}
} }
} }
} }
if (links.size() >= 2) {
result.multiplyByFactor (temp, &(src->getCptEntries()));
if (DL >= 5) {
cout << " message product: " ;
cout << Util::parametersToString (temp.getParameters()) << endl;
cout << " factor product: " ;
cout << Util::parametersToString (src->getParameters());
cout << " x " ;
cout << Util::parametersToString (temp.getParameters());
cout << " = " ;
cout << Util::parametersToString (result.getParameters()) << endl;
}
temp.freeDistribution();
}
for (unsigned i = 0; i < links.size(); i++) {
if (links[i]->getVariable() != dest) {
result.removeVariable (links[i]->getVariable());
}
}
if (DL >= 5) {
cout << " final message: " ;
cout << Util::parametersToString (result.getParameters()) << endl << endl;
}
ParamSet msg = result.getParameters();
result.freeDistribution();
return msg;
}
ParamSet
SPSolver::getVar2FactorMsg (const Link* link) const
{
const FgVarNode* src = link->getVariable();
const Factor* dest = link->getFactor();
ParamSet msg;
if (src->hasEvidence()) {
msg.resize (src->getDomainSize(), 0.0);
msg[src->getEvidence()] = 1.0;
if (DL >= 5) {
cout << Util::parametersToString (msg);
}
} else {
msg.resize (src->getDomainSize(), 1.0);
}
if (DL >= 5) {
cout << Util::parametersToString (msg);
}
CLinkSet links = varsI_[src->getIndex()]->getLinks();
for (unsigned i = 0; i < links.size(); i++) {
if (links[i]->getFactor() != dest) {
CParamSet msgFromFactor = links[i]->getMessage();
for (unsigned j = 0; j < msgFromFactor.size(); j++) {
msg[j] *= msgFromFactor[j];
}
if (DL >= 5) {
cout << " x " << Util::parametersToString (msgFromFactor);
}
}
}
if (DL >= 5) {
cout << " = " << Util::parametersToString (msg);
}
return msg;
} }

View File

@ -1,10 +1,8 @@
#ifndef BP_SPSOLVER_H #ifndef BP_SP_SOLVER_H
#define BP_SPSOLVER_H #define BP_SP_SOLVER_H
#include <cmath>
#include <map>
#include <vector> #include <vector>
#include <string> #include <set>
#include "Solver.h" #include "Solver.h"
#include "FgVarNode.h" #include "FgVarNode.h"
@ -15,157 +13,118 @@ using namespace std;
class FactorGraph; class FactorGraph;
class SPSolver; class SPSolver;
struct Link
{
Link (Factor* s, FgVarNode* d)
{
source = s;
destination = d;
}
string toString (void) const
{
stringstream ss;
ss << source->getLabel() << " --> " ;
ss << destination->getLabel();
return ss.str();
}
Factor* source;
FgVarNode* destination;
static SPSolver* klass;
};
class Link
class MessageBanket
{ {
public: public:
MessageBanket (const FgVarNode* var) Link (Factor* f, FgVarNode* v)
{
factor_ = f;
var_ = v;
currMsg_.resize (v->getDomainSize(), 1);
nextMsg_.resize (v->getDomainSize(), 1);
msgSended_ = false;
residual_ = 0.0;
}
void setMessage (ParamSet msg)
{ {
vector<Factor*> sources = var->getFactors(); Util::normalize (msg);
for (unsigned i = 0; i < sources.size(); i++) { residual_ = Util::getMaxNorm (currMsg_, msg);
indexMap_.insert (make_pair (sources[i], i)); currMsg_ = msg;
currMsgs_.push_back (Message(var->getDomainSize(), 1));
nextMsgs_.push_back (Message(var->getDomainSize(), -10));
residuals_.push_back (0.0);
}
} }
void updateMessage (const Factor* source) void setNextMessage (CParamSet msg)
{ {
unsigned idx = getIndex(source); nextMsg_ = msg;
currMsgs_[idx] = nextMsgs_[idx]; Util::normalize (nextMsg_);
residual_ = Util::getMaxNorm (currMsg_, nextMsg_);
} }
void setNextMessage (const Factor* source, const Message& msg) void updateMessage (void)
{ {
unsigned idx = getIndex(source); currMsg_ = nextMsg_;
nextMsgs_[idx] = msg; msgSended_ = true;
residuals_[idx] = computeResidual (source);
} }
const Message& getMessage (const Factor* source) const string toString (void) const
{ {
return currMsgs_[getIndex(source)]; stringstream ss;
} ss << factor_->getLabel();
ss << " -- " ;
double getResidual (const Factor* source) const ss << var_->getLabel();
{ return ss.str();
return residuals_[getIndex(source)];
}
void resetResidual (const Factor* source)
{
residuals_[getIndex(source)] = 0.0;
} }
Factor* getFactor (void) const { return factor_; }
FgVarNode* getVariable (void) const { return var_; }
CParamSet getMessage (void) const { return currMsg_; }
bool messageWasSended (void) const { return msgSended_; }
double getResidual (void) const { return residual_; }
void clearResidual (void) { residual_ = 0.0; }
private: private:
double computeResidual (const Factor* source) Factor* factor_;
{ FgVarNode* var_;
double change = 0.0; ParamSet currMsg_;
unsigned idx = getIndex (source); ParamSet nextMsg_;
const Message& currMessage = currMsgs_[idx]; bool msgSended_;
const Message& nextMessage = nextMsgs_[idx]; double residual_;
for (unsigned i = 0; i < currMessage.size(); i++) {
change += abs (currMessage[i] - nextMessage[i]);
}
return change;
}
unsigned getIndex (const Factor* factor) const
{
assert (factor);
assert (indexMap_.find(factor) != indexMap_.end());
return indexMap_.find(factor)->second;
}
typedef map<const Factor*, unsigned> IndexMap;
IndexMap indexMap_;
vector<Message> currMsgs_;
vector<Message> nextMsgs_;
vector<double> residuals_;
}; };
class SPNodeInfo
{
public:
void addLink (Link* link) { links_.push_back (link); }
CLinkSet getLinks (void) { return links_; }
private:
LinkSet links_;
};
class SPSolver : public Solver class SPSolver : public Solver
{ {
public: public:
SPSolver (const FactorGraph&); SPSolver (FactorGraph&);
~SPSolver (void); virtual ~SPSolver (void);
void runSolver (void); void runSolver (void);
ParamSet getPosterioriOf (const Variable* var) const; virtual ParamSet getPosterioriOf (Vid) const;
ParamSet getJointDistributionOf (CVidSet);
protected:
virtual void initializeSolver (void);
void runTreeSolver (void);
bool readyToSendMessage (const Link*) const;
virtual void createLinks (void);
virtual void deleteJunction (Factor*, FgVarNode*);
bool converged (void);
virtual void maxResidualSchedule (void);
virtual ParamSet getFactor2VarMsg (const Link*) const;
virtual ParamSet getVar2FactorMsg (const Link*) const;
private: struct CompareResidual {
bool converged (void); inline bool operator() (const Link* link1, const Link* link2)
void maxResidualSchedule (void); {
void updateMessage (const Link&); return link1->getResidual() > link2->getResidual();
void updateMessage (const Factor*, const FgVarNode*); }
void calculateNextMessage (const Link&); };
void calculateNextMessage (const Factor*, const FgVarNode*);
void calculateVarFactorMessage ( FactorGraph* fg_;
const FgVarNode*, const Factor*, Message&) const; LinkSet links_;
double getResidual (const Link&) const; vector<SPNodeInfo*> varsI_;
void resetResidual (const Link&) const; vector<SPNodeInfo*> factorsI_;
friend bool compareResidual (const Link&, const Link&); unsigned nIter_;
typedef multiset<Link*, CompareResidual> SortedOrder;
SortedOrder sortedOrder_;
typedef map<Link*, SortedOrder::iterator> LinkMap;
LinkMap linkMap_;
const FactorGraph* fg_;
vector<MessageBanket*> msgs_;
Schedule schedule_;
int nIter_;
double accuracy_;
int maxIter_;
vector<Link> updateOrder_;
}; };
#endif // BP_SP_SOLVER_H
inline double
SPSolver::getResidual (const Link& link) const
{
MessageBanket* mb = Link::klass->msgs_[link.destination->getIndex()];
return mb->getResidual (link.source);
}
inline void
SPSolver::resetResidual (const Link& link) const
{
MessageBanket* mb = Link::klass->msgs_[link.destination->getIndex()];
mb->resetResidual (link.source);
}
inline bool
compareResidual (const Link& link1, const Link& link2)
{
MessageBanket* mb1 = Link::klass->msgs_[link1.destination->getIndex()];
MessageBanket* mb2 = Link::klass->msgs_[link2.destination->getIndex()];
return mb1->getResidual(link1.source) > mb2->getResidual(link2.source);
}
#endif

View File

@ -2,14 +2,15 @@
#define BP_SHARED_H #define BP_SHARED_H
#include <cmath> #include <cmath>
#include <iostream>
#include <fstream>
#include <cassert> #include <cassert>
#include <vector> #include <vector>
#include <map> #include <map>
#include <unordered_map> #include <unordered_map>
// Macro to disallow the copy constructor and operator= functions #include <iostream>
#include <fstream>
#include <iomanip>
#define DISALLOW_COPY_AND_ASSIGN(TypeName) \ #define DISALLOW_COPY_AND_ASSIGN(TypeName) \
TypeName(const TypeName&); \ TypeName(const TypeName&); \
void operator=(const TypeName&) void operator=(const TypeName&)
@ -19,61 +20,162 @@ using namespace std;
class Variable; class Variable;
class BayesNode; class BayesNode;
class FgVarNode; class FgVarNode;
class Factor;
class Link;
class Edge;
typedef double Param; typedef double Param;
typedef vector<Param> ParamSet; typedef vector<Param> ParamSet;
typedef vector<Param> Message; typedef const ParamSet& CParamSet;
typedef unsigned Vid;
typedef vector<Vid> VidSet;
typedef const VidSet& CVidSet;
typedef vector<Variable*> VarSet; typedef vector<Variable*> VarSet;
typedef vector<BayesNode*> NodeSet; typedef vector<BayesNode*> BnNodeSet;
typedef const BnNodeSet& CBnNodeSet;
typedef vector<FgVarNode*> FgVarSet; typedef vector<FgVarNode*> FgVarSet;
typedef const FgVarSet& CFgVarSet;
typedef vector<Factor*> FactorSet;
typedef const FactorSet& CFactorSet;
typedef vector<Link*> LinkSet;
typedef const LinkSet& CLinkSet;
typedef vector<Edge*> EdgeSet;
typedef const EdgeSet& CEdgeSet;
typedef vector<string> Domain; typedef vector<string> Domain;
typedef vector<unsigned> DomainConf; typedef vector<unsigned> DConf;
typedef pair<unsigned, unsigned> DomainConstr; typedef pair<unsigned, unsigned> DConstraint;
typedef unordered_map<unsigned, unsigned> IndexMap; typedef map<unsigned, unsigned> IndexMap;
// level of debug information
//extern unsigned DL;
static const unsigned DL = 0; static const unsigned DL = 0;
// number of digits to show when printing a parameter static const int NO_EVIDENCE = -1;
static const unsigned PRECISION = 10;
// shared by bp and sp solver // number of digits to show when printing a parameter
enum Schedule static const unsigned PRECISION = 5;
static const bool EXPORT_TO_DOT = false;
static const unsigned EXPORT_MIN_SIZE = 30;
namespace SolverOptions
{ {
S_SEQ_FIXED, enum Schedule
S_SEQ_RANDOM, {
S_PARALLEL, S_SEQ_FIXED,
S_MAX_RESIDUAL S_SEQ_RANDOM,
S_PARALLEL,
S_MAX_RESIDUAL
};
extern bool runBayesBall;
extern bool convertBn2Fg;
extern bool compressFactorGraph;
extern Schedule schedule;
extern double accuracy;
extern unsigned maxIter;
}
namespace Util
{
void normalize (ParamSet&);
void pow (ParamSet&, unsigned);
double getL1dist (CParamSet, CParamSet);
double getMaxNorm (CParamSet, CParamSet);
bool isInteger (const string&);
string parametersToString (CParamSet);
vector<DConf> getDomainConfigurations (const VarSet&);
vector<string> getInstantiations (const VarSet&);
}; };
struct NetInfo struct NetInfo
{ {
NetInfo (unsigned c, double t) NetInfo (void)
{ {
counting = c; counting = 0;
solvingTime = t; nIters = 0;
solvingTime = 0.0;
} }
unsigned counting; unsigned counting;
double solvingTime; double solvingTime;
unsigned nIters;
}; };
struct CompressInfo
{
CompressInfo (unsigned a, unsigned b, unsigned c,
unsigned d, unsigned e) {
nUncVars = a;
nUncFactors = b;
nCompVars = c;
nCompFactors = d;
nNeighborlessVars = e;
}
unsigned nUncVars;
unsigned nUncFactors;
unsigned nCompVars;
unsigned nCompFactors;
unsigned nNeighborlessVars;
};
typedef map<unsigned, NetInfo> StatisticMap; typedef map<unsigned, NetInfo> StatisticMap;
class Statistics class Statistics
{ {
public: public:
static void updateStats (unsigned size, double time) static void updateStats (unsigned size, unsigned nIters, double time)
{ {
StatisticMap::iterator it = stats_.find(size); StatisticMap::iterator it = stats_.find (size);
if (it == stats_.end()) { if (it == stats_.end()) {
stats_.insert (make_pair (size, NetInfo (1, 0.0))); it = (stats_.insert (make_pair (size, NetInfo()))).first;
} else { } else {
it->second.counting ++; it->second.counting ++;
it->second.nIters += nIters;
it->second.solvingTime += time; it->second.solvingTime += time;
totalOfIterations += nIters;
if (nIters > maxIterations) {
maxIterations = nIters;
}
}
}
static void updateCompressingStats (unsigned nUncVars,
unsigned nUncFactors,
unsigned nCompVars,
unsigned nCompFactors,
unsigned nNeighborlessVars) {
compressInfo_.push_back (CompressInfo (
nUncVars, nUncFactors, nCompVars, nCompFactors, nNeighborlessVars));
}
static void printCompressingStats (const char* fileName)
{
ofstream out (fileName);
if (!out.is_open()) {
cerr << "error: cannot open file to write at " ;
cerr << "BayesNet::printCompressingStats()" << endl;
abort();
}
out << "--------------------------------------" ;
out << "--------------------------------------" << endl;
out << " Compression Stats" << endl;
out << "--------------------------------------" ;
out << "--------------------------------------" << endl;
out << left;
out << "Uncompress Compressed Uncompress Compressed Neighborless";
out << endl;
out << "Vars Vars Factors Factors Vars" ;
out << endl;
for (unsigned i = 0; i < compressInfo_.size(); i++) {
out << setw (13) << compressInfo_[i].nUncVars;
out << setw (13) << compressInfo_[i].nCompVars;
out << setw (13) << compressInfo_[i].nUncFactors;
out << setw (13) << compressInfo_[i].nCompFactors;
out << setw (13) << compressInfo_[i].nNeighborlessVars;
out << endl;
} }
} }
@ -84,20 +186,12 @@ class Statistics
return it->second.counting; return it->second.counting;
} }
static void updateIterations (unsigned nIters)
{
totalOfIterations += nIters;
if (nIters > maxIterations) {
maxIterations = nIters;
}
}
static void writeStats (void) static void writeStats (void)
{ {
ofstream out ("../../stats.txt"); ofstream out ("../../stats.txt");
if (!out.is_open()) { if (!out.is_open()) {
cerr << "error: cannot open file to write at " ; cerr << "error: cannot open file to write at " ;
cerr << "Statistics:::updateStats()" << endl; cerr << "Statistics::updateStats()" << endl;
abort(); abort();
} }
unsigned avgIterations = 0; unsigned avgIterations = 0;
@ -117,17 +211,24 @@ class Statistics
out << " average iterations: " << avgIterations << endl; out << " average iterations: " << avgIterations << endl;
out << "total solving time " << totalSolvingTime << endl; out << "total solving time " << totalSolvingTime << endl;
out << endl; out << endl;
out << "Network Size\tCounting\tSolving Time\tAverage Time" << endl; out << left << endl;
out << setw (15) << "Network Size" ;
out << setw (15) << "Counting" ;
out << setw (15) << "Solving Time" ;
out << setw (15) << "Average Time" ;
out << setw (15) << "#Iterations" ;
out << endl;
for (StatisticMap::iterator it = stats_.begin(); for (StatisticMap::iterator it = stats_.begin();
it != stats_.end(); it++) { it != stats_.end(); it++) {
out << it->first; out << setw (15) << it->first;
out << "\t\t" << it->second.counting; out << setw (15) << it->second.counting;
out << "\t\t" << it->second.solvingTime; out << setw (15) << it->second.solvingTime;
if (it->second.counting > 0) { if (it->second.counting > 0) {
out << "\t\t" << it->second.solvingTime / it->second.counting; out << setw (15) << it->second.solvingTime / it->second.counting;
} else { } else {
out << "\t\t0.0" ; out << setw (15) << "0.0" ;
} }
out << setw (15) << it->second.nIters;
out << endl; out << endl;
} }
out.close(); out.close();
@ -142,62 +243,8 @@ class Statistics
static StatisticMap stats_; static StatisticMap stats_;
static unsigned maxIterations; static unsigned maxIterations;
static unsigned totalOfIterations; static unsigned totalOfIterations;
static vector<CompressInfo> compressInfo_;
}; };
#endif //BP_SHARED_H
class Util
{
public:
static void normalize (ParamSet& v)
{
double sum = 0.0;
for (unsigned i = 0; i < v.size(); i++) {
sum += v[i];
}
assert (sum != 0.0);
for (unsigned i = 0; i < v.size(); i++) {
v[i] /= sum;
}
}
static double getL1dist (const ParamSet& v1, const ParamSet& v2)
{
assert (v1.size() == v2.size());
double dist = 0.0;
for (unsigned i = 0; i < v1.size(); i++) {
dist += abs (v1[i] - v2[i]);
}
return dist;
}
static double getMaxNorm (const ParamSet& v1, const ParamSet& v2)
{
assert (v1.size() == v2.size());
double max = 0.0;
for (unsigned i = 0; i < v1.size(); i++) {
double diff = abs (v1[i] - v2[i]);
if (diff > max) {
max = diff;
}
}
return max;
}
static bool isInteger (const string& s)
{
stringstream ss1 (s);
stringstream ss2;
int integer;
ss1 >> integer;
ss2 << integer;
return (ss1.str() == ss2.str());
}
};
//unsigned Statistics::totalOfIterations = 0;
#endif

View File

@ -15,19 +15,30 @@ class Solver
{ {
gm_ = gm; gm_ = gm;
} }
virtual ~Solver() {} // to call subclass destructor
virtual void runSolver (void) = 0; virtual void runSolver (void) = 0;
virtual ParamSet getPosterioriOf (const Variable*) const = 0; virtual ParamSet getPosterioriOf (Vid) const = 0;
virtual ParamSet getJointDistributionOf (const VidSet&) = 0;
void printPosterioriOf (const Variable* var) const void printAllPosterioris (void) const
{ {
VarSet vars = gm_->getVariables();
for (unsigned i = 0; i < vars.size(); i++) {
printPosterioriOf (vars[i]->getVarId());
}
}
void printPosterioriOf (Vid vid) const
{
Variable* var = gm_->getVariable (vid);
cout << endl; cout << endl;
cout << setw (20) << left << var->getLabel() << "posteriori" ; cout << setw (20) << left << var->getLabel() << "posteriori" ;
cout << endl; cout << endl;
cout << "------------------------------" ; cout << "------------------------------" ;
cout << endl; cout << endl;
const Domain& domain = var->getDomain(); const Domain& domain = var->getDomain();
ParamSet results = getPosterioriOf (var); ParamSet results = getPosterioriOf (vid);
for (int xi = 0; xi < var->getDomainSize(); xi++) { for (unsigned xi = 0; xi < var->getDomainSize(); xi++) {
cout << setw (20) << domain[xi]; cout << setw (20) << domain[xi];
cout << setprecision (PRECISION) << results[xi]; cout << setprecision (PRECISION) << results[xi];
cout << endl; cout << endl;
@ -35,16 +46,35 @@ class Solver
cout << endl; cout << endl;
} }
void printAllPosterioris (void) const void printJointDistributionOf (const VidSet& vids)
{ {
VarSet vars = gm_->getVariables(); const ParamSet& jointDist = getJointDistributionOf (vids);
for (unsigned i = 0; i < vars.size(); i++) { cout << endl;
printPosterioriOf (vars[i]); cout << "joint distribution of " ;
VarSet vars;
for (unsigned i = 0; i < vids.size() - 1; i++) {
Variable* var = gm_->getVariable (vids[i]);
cout << var->getLabel() << ", " ;
vars.push_back (var);
} }
Variable* var = gm_->getVariable (vids[vids.size() - 1]);
cout << var->getLabel() ;
vars.push_back (var);
cout << endl;
cout << "------------------------------" ;
cout << endl;
const vector<string>& domainConfs = Util::getInstantiations (vars);
for (unsigned i = 0; i < jointDist.size(); i++) {
cout << left << setw (20) << domainConfs[i];
cout << setprecision (PRECISION) << jointDist[i];
cout << endl;
}
cout << endl;
} }
private: private:
const GraphicalModel* gm_; const GraphicalModel* gm_;
}; };
#endif #endif //BP_SOLVER_H

View File

@ -0,0 +1,191 @@
#include <sstream>
#include "Variable.h"
#include "Shared.h"
namespace SolverOptions {
bool runBayesBall = false;
bool convertBn2Fg = true;
bool compressFactorGraph = true;
Schedule schedule = S_SEQ_FIXED;
//Schedule schedule = S_SEQ_RANDOM;
//Schedule schedule = S_PARALLEL;
//Schedule schedule = S_MAX_RESIDUAL;
double accuracy = 0.0001;
unsigned maxIter = 1000; //FIXME
}
unsigned Statistics::numCreatedNets = 0;
unsigned Statistics::numSolvedPolyTrees = 0;
unsigned Statistics::numSolvedLoopyNets = 0;
unsigned Statistics::numUnconvergedRuns = 0;
unsigned Statistics::maxIterations = 0;
unsigned Statistics::totalOfIterations = 0;
vector<CompressInfo> Statistics::compressInfo_;
StatisticMap Statistics::stats_;
namespace Util {
void
normalize (ParamSet& v)
{
double sum = 0.0;
for (unsigned i = 0; i < v.size(); i++) {
sum += v[i];
}
assert (sum != 0.0);
for (unsigned i = 0; i < v.size(); i++) {
v[i] /= sum;
}
}
void
pow (ParamSet& v, unsigned expoent)
{
for (unsigned i = 0; i < v.size(); i++) {
double value = 1;
for (unsigned j = 0; j < expoent; j++) {
value *= v[i];
}
v[i] = value;
}
}
double
getL1dist (const ParamSet& v1, const ParamSet& v2)
{
assert (v1.size() == v2.size());
double dist = 0.0;
for (unsigned i = 0; i < v1.size(); i++) {
dist += abs (v1[i] - v2[i]);
}
return dist;
}
double
getMaxNorm (const ParamSet& v1, const ParamSet& v2)
{
assert (v1.size() == v2.size());
double max = 0.0;
for (unsigned i = 0; i < v1.size(); i++) {
double diff = abs (v1[i] - v2[i]);
if (diff > max) {
max = diff;
}
}
return max;
}
bool
isInteger (const string& s)
{
stringstream ss1 (s);
stringstream ss2;
int integer;
ss1 >> integer;
ss2 << integer;
return (ss1.str() == ss2.str());
}
string
parametersToString (CParamSet v)
{
stringstream ss;
ss << "[" ;
for (unsigned i = 0; i < v.size() - 1; i++) {
ss << v[i] << ", " ;
}
if (v.size() != 0) {
ss << v[v.size() - 1];
}
ss << "]" ;
return ss.str();
}
vector<DConf>
getDomainConfigurations (const VarSet& vars)
{
unsigned nConfs = 1;
for (unsigned i = 0; i < vars.size(); i++) {
nConfs *= vars[i]->getDomainSize();
}
vector<DConf> confs (nConfs);
for (unsigned i = 0; i < nConfs; i++) {
confs[i].resize (vars.size());
}
unsigned nReps = 1;
for (int i = vars.size() - 1; i >= 0; i--) {
unsigned index = 0;
while (index < nConfs) {
for (unsigned j = 0; j < vars[i]->getDomainSize(); j++) {
for (unsigned r = 0; r < nReps; r++) {
confs[index][i] = j;
index++;
}
}
}
nReps *= vars[i]->getDomainSize();
}
return confs;
}
vector<string>
getInstantiations (const VarSet& vars)
{
//FIXME handle variables without domain
/*
char c = 'a' ;
const DConf& conf = entries[i].getDomainConfiguration();
for (unsigned j = 0; j < conf.size(); j++) {
if (j != 0) ss << "," ;
ss << c << conf[j] + 1;
c ++;
}
*/
unsigned rowSize = 1;
for (unsigned i = 0; i < vars.size(); i++) {
rowSize *= vars[i]->getDomainSize();
}
vector<string> headers (rowSize);
unsigned nReps = 1;
for (int i = vars.size() - 1; i >= 0; i--) {
Domain domain = vars[i]->getDomain();
unsigned index = 0;
while (index < rowSize) {
for (unsigned j = 0; j < vars[i]->getDomainSize(); j++) {
for (unsigned r = 0; r < nReps; r++) {
if (headers[index] != "") {
headers[index] = domain[j] + ", " + headers[index];
} else {
headers[index] = domain[j];
}
index++;
}
}
}
nReps *= vars[i]->getDomainSize();
}
return headers;
}
}

View File

@ -1,9 +1,10 @@
#ifndef BP_GENERIC_VARIABLE_H #ifndef BP_VARIABLE_H
#define BP_GENERIC_VARIABLE_H #define BP_VARIABLE_H
#include <algorithm>
#include <sstream> #include <sstream>
#include <algorithm>
#include "Shared.h" #include "Shared.h"
using namespace std; using namespace std;
@ -12,33 +13,61 @@ class Variable
{ {
public: public:
Variable (unsigned varId) Variable (const Variable* v)
{ {
this->varId_ = varId; vid_ = v->getVarId();
this->dsize_ = 0; dsize_ = v->getDomainSize();
this->evidence_ = -1; if (v->hasDomain()) {
this->label_ = 0; domain_ = v->getDomain();
dsize_ = domain_.size();
} else {
dsize_ = v->getDomainSize();
}
evidence_ = v->getEvidence();
if (v->hasLabel()) {
label_ = new string (v->getLabel());
} else {
label_ = 0;
}
} }
Variable (unsigned varId, unsigned dsize, int evidence = -1) Variable (Vid vid)
{
this->vid_ = vid;
this->dsize_ = 0;
this->evidence_ = NO_EVIDENCE;
this->label_ = 0;
}
Variable (Vid vid, unsigned dsize, int evidence = NO_EVIDENCE,
const string& lbl = string())
{ {
assert (dsize != 0); assert (dsize != 0);
assert (evidence < (int)dsize); assert (evidence < (int)dsize);
this->varId_ = varId; this->vid_ = vid;
this->dsize_ = dsize; this->dsize_ = dsize;
this->evidence_ = evidence; this->evidence_ = evidence;
this->label_ = 0; if (!lbl.empty()) {
this->label_ = new string (lbl);
} else {
this->label_ = 0;
}
} }
Variable (unsigned varId, const Domain& domain, int evidence = -1) Variable (Vid vid, const Domain& domain, int evidence = NO_EVIDENCE,
const string& lbl = string())
{ {
assert (!domain.empty()); assert (!domain.empty());
assert (evidence < (int)domain.size()); assert (evidence < (int)domain.size());
this->varId_ = varId; this->vid_ = vid;
this->dsize_ = domain.size(); this->dsize_ = domain.size();
this->domain_ = domain; this->domain_ = domain;
this->evidence_ = evidence; this->evidence_ = evidence;
this->label_ = 0; if (!lbl.empty()) {
this->label_ = new string (lbl);
} else {
this->label_ = 0;
}
} }
~Variable (void) ~Variable (void)
@ -46,19 +75,19 @@ class Variable
delete label_; delete label_;
} }
unsigned getVarId (void) const { return varId_; } unsigned getVarId (void) const { return vid_; }
unsigned getIndex (void) const { return index_; } unsigned getIndex (void) const { return index_; }
void setIndex (unsigned idx) { index_ = idx; } void setIndex (unsigned idx) { index_ = idx; }
int getDomainSize (void) const { return dsize_; } unsigned getDomainSize (void) const { return dsize_; }
bool hasEvidence (void) const { return evidence_ != -1; } bool hasEvidence (void) const { return evidence_ != NO_EVIDENCE; }
int getEvidence (void) const { return evidence_; } int getEvidence (void) const { return evidence_; }
bool hasDomain (void) { return !domain_.empty(); } bool hasDomain (void) const { return !domain_.empty(); }
bool hasLabel (void) { return label_ != 0; } bool hasLabel (void) const { return label_ != 0; }
bool isValidStateIndex (int index) bool isValidStateIndex (int index)
{ {
return index >= 0 && index < dsize_; return index >= 0 && index < (int)dsize_;
} }
bool isValidState (const string& state) bool isValidState (const string& state)
{ {
@ -70,7 +99,7 @@ class Variable
assert (dsize_ != 0); assert (dsize_ != 0);
if (domain_.size() == 0) { if (domain_.size() == 0) {
Domain d; Domain d;
for (int i = 0; i < dsize_; i++) { for (unsigned i = 0; i < dsize_; i++) {
stringstream ss; stringstream ss;
ss << "x" << i ; ss << "x" << i ;
d.push_back (ss.str()); d.push_back (ss.str());
@ -110,7 +139,7 @@ class Variable
} }
} }
void setLabel (string label) void setLabel (const string& label)
{ {
label_ = new string (label); label_ = new string (label);
} }
@ -119,25 +148,25 @@ class Variable
{ {
if (label_ == 0) { if (label_ == 0) {
stringstream ss; stringstream ss;
ss << "v" << varId_; ss << "v" << vid_;
return ss.str(); return ss.str();
} else { } else {
return *label_; return *label_;
} }
} }
protected:
unsigned varId_;
string* label_;
unsigned index_;
int evidence_;
private: private:
DISALLOW_COPY_AND_ASSIGN (Variable); DISALLOW_COPY_AND_ASSIGN (Variable);
Domain domain_;
int dsize_; Vid vid_;
unsigned dsize_;
int evidence_;
Domain domain_;
string* label_;
unsigned index_;
}; };
#endif // BP_GENERIC_VARIABLE_H #endif // BP_VARIABLE_H

View File

@ -0,0 +1,34 @@
:- use_module(library(clpbn)).
:- set_clpbn_flag(solver, bp).
%
% R
% / | \
% / | \
% A B C
%
r(R) :-
{ R = r with p([t, f], [0.35, 0.65]) }.
a(A) :-
r(R),
child_dist(R,Dist),
{ A = a with Dist }.
b(B) :-
r(R),
child_dist(R,Dist),
{ B = b with Dist }.
c(C) :-
r(R),
child_dist(R,Dist),
{ C = c with Dist }.
child_dist(R, p([t, f], [0.3, 0.4, 0.25, 0.05], [R])).

View File

@ -0,0 +1,53 @@
<?xml version="1.0" encoding="US-ASCII"?>
<!--
A B
\ /
\ /
C
-->
<BIF VERSION="0.3">
<NETWORK>
<NAME>Neapolitan</NAME>
<VARIABLE TYPE="nature">
<NAME>A</NAME>
<OUTCOME>a1</OUTCOME>
<OUTCOME>a2</OUTCOME>
</VARIABLE>
<VARIABLE TYPE="nature">
<NAME>B</NAME>
<OUTCOME>b1</OUTCOME>
<OUTCOME>b2</OUTCOME>
</VARIABLE>
<VARIABLE TYPE="nature">
<NAME>C</NAME>
<OUTCOME>c1</OUTCOME>
<OUTCOME>c2</OUTCOME>
</VARIABLE>
<DEFINITION>
<FOR>A</FOR>
<TABLE> .695 .305 </TABLE>
</DEFINITION>
<DEFINITION>
<FOR>B</FOR>
<TABLE> .25 .75 </TABLE>
</DEFINITION>
<DEFINITION>
<FOR>C</FOR>
<GIVEN>A</GIVEN>
<GIVEN>B</GIVEN>
<TABLE> .2 .8 .45 .55 .32 .68 .7 .3 </TABLE>
</DEFINITION>
</NETWORK>
</BIF>

View File

@ -9,10 +9,10 @@ MARKOV
2 4 2 2 4 2
2 2
.001 .009 .001 .999
2 2
.002 .008 .002 .998
8 8
.95 .94 .29 .001 .95 .94 .29 .001

View File

@ -49,12 +49,12 @@
<DEFINITION> <DEFINITION>
<FOR>B</FOR> <FOR>B</FOR>
<TABLE> .001 .009 </TABLE> <TABLE> .001 .999 </TABLE>
</DEFINITION> </DEFINITION>
<DEFINITION> <DEFINITION>
<FOR>E</FOR> <FOR>E</FOR>
<TABLE> .002 .008 </TABLE> <TABLE> .002 .998 </TABLE>
</DEFINITION> </DEFINITION>
<DEFINITION> <DEFINITION>

View File

@ -1,54 +1,29 @@
:- use_module(library(clpbn)). :- use_module(library(clpbn)).
:- set_clpbn_flag(solver, vel). :- set_clpbn_flag(solver, bp).
% r(R) :- r_cpt(RCpt),
% B E { R = r with p([r1, r2], RCpt) }.
% \ /
% \ /
% A
% / \
% / \
% J M
%
t(T) :- t_cpt(TCpt),
{ T = t with p([t1, t2], TCpt) }.
b(B) :- a(A) :- r(R), t(T), a_cpt(ACpt),
b_table(BDist), { A = a with p([a1, a2], ACpt, [R, T]) }.
{ B = b with p([b1, b2], BDist) }.
e(E) :- j(J) :- a(A), j_cpt(JCpt),
e_table(EDist), { J = j with p([j1, j2], JCpt, [A]) }.
{ E = e with p([e1, e2], EDist) }.
a(A) :- m(M) :- a(A), m_cpt(MCpt),
b(B), { M = m with p([m1, m2], MCpt, [A]) }.
e(E),
a_table(ADist),
{ A = a with p([a1, a2], ADist, [B, E]) }.
j(J):-
a(A),
j_table(JDist),
{ J = j with p([j1, j2], JDist, [A]) }.
m(M):-
a(A),
m_table(MDist),
{ M = m with p([m1, m2], MDist, [A]) }.
r_cpt([0.001, 0.999]).
b_table([0.001, 0.009]). t_cpt([0.002, 0.998]).
a_cpt([0.95, 0.94, 0.29, 0.001,
e_table([0.002, 0.008]). 0.05, 0.06, 0.71, 0.999]).
j_cpt([0.9, 0.05,
a_table([0.95, 0.94, 0.29, 0.001, 0.1, 0.95]).
0.05, 0.06, 0.71, 0.999]). m_cpt([0.7, 0.01,
0.3, 0.99]).
j_table([0.9, 0.05,
0.1, 0.95]).
m_table([0.7, 0.01,
0.3, 0.99]).

View File

@ -16,34 +16,37 @@
<VARIABLE TYPE="nature"> <VARIABLE TYPE="nature">
<NAME>A</NAME> <NAME>A</NAME>
<OUTCOME></OUTCOME> <OUTCOME>a1</OUTCOME>
<OUTCOME>a2</OUTCOME>
</VARIABLE> </VARIABLE>
<VARIABLE TYPE="nature"> <VARIABLE TYPE="nature">
<NAME>B</NAME> <NAME>B</NAME>
<OUTCOME></OUTCOME> <OUTCOME>b1</OUTCOME>
<OUTCOME>b2</OUTCOME>
</VARIABLE> </VARIABLE>
<VARIABLE TYPE="nature"> <VARIABLE TYPE="nature">
<NAME>C</NAME> <NAME>C</NAME>
<OUTCOME></OUTCOME> <OUTCOME>c1</OUTCOME>
<OUTCOME>c2</OUTCOME>
</VARIABLE> </VARIABLE>
<DEFINITION> <DEFINITION>
<FOR>A</FOR> <FOR>A</FOR>
<TABLE>1</TABLE> <TABLE>.695 .305</TABLE>
</DEFINITION> </DEFINITION>
<DEFINITION> <DEFINITION>
<FOR>B</FOR> <FOR>B</FOR>
<TABLE>1</TABLE> <TABLE>0.25 0.75</TABLE>
</DEFINITION> </DEFINITION>
<DEFINITION> <DEFINITION>
<FOR>C</FOR> <FOR>C</FOR>
<GIVEN>A</GIVEN> <GIVEN>A</GIVEN>
<GIVEN>B</GIVEN> <GIVEN>B</GIVEN>
<TABLE>1</TABLE> <TABLE>0.2 0.8 0.45 0.55 0.32 0.68 0.7 0.3</TABLE>
</DEFINITION> </DEFINITION>
</NETWORK> </NETWORK>

View File

@ -0,0 +1,67 @@
<?xml version="1.0" encoding="US-ASCII"?>
<!--
P1 P2 P3
\ | /
\ | /
-
C
-->
<BIF VERSION="0.3">
<NETWORK>
<NAME>Simple Convergence</NAME>
<VARIABLE TYPE="nature">
<NAME>P1</NAME>
<OUTCOME>p1</OUTCOME>
<OUTCOME>p2</OUTCOME>
</VARIABLE>
<VARIABLE TYPE="nature">
<NAME>P2</NAME>
<OUTCOME>p1</OUTCOME>
<OUTCOME>p2</OUTCOME>
<OUTCOME>p3</OUTCOME>
</VARIABLE>
<VARIABLE TYPE="nature">
<NAME>P3</NAME>
<OUTCOME>p1</OUTCOME>
<OUTCOME>p2</OUTCOME>
</VARIABLE>
<VARIABLE TYPE="nature">
<NAME>C</NAME>
<OUTCOME>c1</OUTCOME>
<OUTCOME>c2</OUTCOME>
</VARIABLE>
<DEFINITION>
<FOR>P1</FOR>
<TABLE>.695 .305</TABLE>
</DEFINITION>
<DEFINITION>
<FOR>P2</FOR>
<TABLE>0.2 0.3 0.5</TABLE>
</DEFINITION>
<DEFINITION>
<FOR>P3</FOR>
<TABLE>0.25 0.75</TABLE>
</DEFINITION>
<DEFINITION>
<FOR>C</FOR>
<GIVEN>P1</GIVEN>
<GIVEN>P2</GIVEN>
<GIVEN>P3</GIVEN>
<TABLE>0.2 0.8 0.45 0.55 0.32 0.68 0.7 0.3 0.3 0.7 0.55 0.45 0.22 0.78 0.25 0.75 0.11 0.89 0.34 0.66 0.1 0.9 0.6 0.4</TABLE>
</DEFINITION>
</NETWORK>
</BIF>

View File

@ -2,6 +2,7 @@
:- use_module(library(clpbn)). :- use_module(library(clpbn)).
:- set_clpbn_flag(solver, bp). :- set_clpbn_flag(solver, bp).
%:- set_clpbn_flag(solver, jt).
% %
% B F % B F

View File

@ -0,0 +1,17 @@
MARKOV
3
2 2 2
3
1 0
1 1
3 2 0 1
2
.695 .305
2
.25 .75
8
0.2 0.45 0.32 0.7
0.8 0.55 0.68 0.3

View File

@ -0,0 +1,128 @@
<?xml version="1.0" encoding="US-ASCII"?>
<!--
A B C
\ | /
\ | /
D
/ | \
/ | \
E F G
-->
<BIF VERSION="0.3">
<NETWORK>
<NAME>Node with several parents and childs</NAME>
<VARIABLE TYPE="nature">
<NAME>A</NAME>
<OUTCOME>a1</OUTCOME>
<OUTCOME>a2</OUTCOME>
</VARIABLE>
<VARIABLE TYPE="nature">
<NAME>B</NAME>
<OUTCOME>b1</OUTCOME>
<OUTCOME>b2</OUTCOME>
<OUTCOME>b3</OUTCOME>
<OUTCOME>b4</OUTCOME>
</VARIABLE>
<VARIABLE TYPE="nature">
<NAME>C</NAME>
<OUTCOME>c1</OUTCOME>
<OUTCOME>c2</OUTCOME>
<OUTCOME>c3</OUTCOME>
</VARIABLE>
<VARIABLE TYPE="nature">
<NAME>D</NAME>
<OUTCOME>d1</OUTCOME>
<OUTCOME>d2</OUTCOME>
<OUTCOME>d3</OUTCOME>
</VARIABLE>
<VARIABLE TYPE="nature">
<NAME>E</NAME>
<OUTCOME>e1</OUTCOME>
<OUTCOME>e2</OUTCOME>
<OUTCOME>e3</OUTCOME>
<OUTCOME>e4</OUTCOME>
</VARIABLE>
<VARIABLE TYPE="nature">
<NAME>F</NAME>
<OUTCOME>f1</OUTCOME>
<OUTCOME>f2</OUTCOME>
<OUTCOME>f3</OUTCOME>
</VARIABLE>
<VARIABLE TYPE="nature">
<NAME>G</NAME>
<OUTCOME>g1</OUTCOME>
<OUTCOME>g2</OUTCOME>
</VARIABLE>
<DEFINITION>
<FOR>A</FOR>
<TABLE> .1 .2 </TABLE>
</DEFINITION>
<DEFINITION>
<FOR>B</FOR>
<TABLE> .01 .02 .03 .04 </TABLE>
</DEFINITION>
<DEFINITION>
<FOR>C</FOR>
<TABLE> .11 .22 .33 </TABLE>
</DEFINITION>
<DEFINITION>
<FOR>D</FOR>
<GIVEN>A</GIVEN>
<GIVEN>B</GIVEN>
<GIVEN>C</GIVEN>
<TABLE>
.522 .008 .99 .01 .2 .8 .003 .457 .423 .007 .92 .04 .5 .232 .033 .227 .112 .048 .91 .21 .24 .18 .005 .227
.212 .04 .59 .21 .6 .1 .023 .215 .913 .017 .96 .01 .55 .422 .013 .417 .272 .068 .61 .11 .26 .28 .205 .322
.142 .028 .19 .11 .5 .67 .013 .437 .163 .067 .12 .06 .1 .262 .063 .167 .512 .028 .11 .41 .14 .68 .015 .92
</TABLE>
</DEFINITION>
<DEFINITION>
<FOR>E</FOR>
<GIVEN>D</GIVEN>
<TABLE>
.111 .11 .1
.222 .22 .2
.333 .33 .3
.444 .44 .4
</TABLE>
</DEFINITION>
<DEFINITION>
<FOR>F</FOR>
<GIVEN>D</GIVEN>
<TABLE>
.112 .111 .110
.223 .222 .221
.334 .333 .332
</TABLE>
</DEFINITION>
<DEFINITION>
<FOR>G</FOR>
<GIVEN>D</GIVEN>
<TABLE>
.101 .102 .103
.201 .202 .203
</TABLE>
</DEFINITION>
</NETWORK>
</BIF>

View File

@ -0,0 +1,36 @@
MARKOV
5
4 2 3 2 3
7
1 0
1 1
1 2
1 3
1 4
2 0 1
4 1 2 3 4
4
0.1 0.7 0.43 0.22
2
0.2 0.6
3
0.3 0.5 0.2
2
0.15 0.75
3
0.25 0.45 0.15
8
0.210 0.333 0.457 0.4
0.811 0.000 0.189 0.89
36
0.1 0.15 0.2 0.25 0.3 0.45 0.5 0.55 0.65 0.7 0.75 0.9
0.11 0.22 0.33 0.44 0.55 0.66 0.77 0.88 0.91 0.93 0.95 0.97
0.42 0.22 0.33 0.44 0.15 0.36 0.27 0.28 0.21 0.13 0.25 0.17

View File

@ -0,0 +1,69 @@
<?xml version="1.0" encoding="US-ASCII"?>
<!--
A B
\ /
\ /
C
|
|
D
-->
<BIF VERSION="0.3">
<NETWORK>
<NAME>Simple Loop</NAME>
<VARIABLE TYPE="nature">
<NAME>A</NAME>
<OUTCOME>a1</OUTCOME>
<OUTCOME>a2</OUTCOME>
</VARIABLE>
<VARIABLE TYPE="nature">
<NAME>B</NAME>
<OUTCOME>b1</OUTCOME>
<OUTCOME>b2</OUTCOME>
</VARIABLE>
<VARIABLE TYPE="nature">
<NAME>C</NAME>
<OUTCOME>c1</OUTCOME>
<OUTCOME>c2</OUTCOME>
</VARIABLE>
<VARIABLE TYPE="nature">
<NAME>D</NAME>
<OUTCOME>d1</OUTCOME>
<OUTCOME>d2</OUTCOME>
</VARIABLE>
<DEFINITION>
<FOR>A</FOR>
<TABLE> .001 .009 </TABLE>
</DEFINITION>
<DEFINITION>
<FOR>B</FOR>
<TABLE> .002 .008 </TABLE>
</DEFINITION>
<DEFINITION>
<FOR>C</FOR>
<GIVEN>A</GIVEN>
<GIVEN>B</GIVEN>
<TABLE> .95 .05 .94 .06 .29 .71 .001 .999 </TABLE>
</DEFINITION>
<DEFINITION>
<FOR>D</FOR>
<GIVEN>C</GIVEN>
<TABLE> .9 .1 .05 .95 </TABLE>
</DEFINITION>
</NETWORK>
</BIF>

View File

@ -884,6 +884,15 @@ writePrimitive(term_t t, write_options *options)
return writeString(t, options); return writeString(t, options);
#endif /* O_STRING */ #endif /* O_STRING */
#if __YAP_PROLOG__
{
number n;
n.type = V_INTEGER;
n.value.i = 0;
return WriteNumber(&n, options);
}
#endif
assert(0); assert(0);
fail; fail;
} }

View File

@ -1121,6 +1121,9 @@ Yap_StreamPosition(IOSTREAM *st)
return StreamPosition(st); return StreamPosition(st);
} }
IOSTREAM *STD_PROTO(Yap_Scurin, (void));
int STD_PROTO(Yap_dowrite, (Term, IOSTREAM *, int, int));
IOSTREAM * IOSTREAM *
Yap_Scurin(void) Yap_Scurin(void)
{ {
@ -1128,6 +1131,32 @@ Yap_Scurin(void)
return Scurin; return Scurin;
} }
int
Yap_dowrite(Term t, IOSTREAM *stream, int flags, int priority)
/* term to be written */
/* consumer */
/* write options */
{
CACHE_REGS
int swi_flags;
int res;
Int slot = Yap_InitSlot(t PASS_REGS);
swi_flags = 0;
if (flags & Quote_illegal_f)
swi_flags |= PL_WRT_QUOTED;
if (flags & Handle_vars_f)
swi_flags |= PL_WRT_NUMBERVARS;
if (flags & Use_portray_f)
swi_flags |= PL_WRT_PORTRAY;
if (flags & Ignore_ops_f)
swi_flags |= PL_WRT_IGNOREOPS;
res = PL_write_term(stream, slot, priority, swi_flags);
Yap_RecoverSlots(1 PASS_REGS);
return res;
}
#if THREADS #if THREADS
@ -1178,6 +1207,7 @@ error:
return rc; return rc;
} }
int int
recursiveMutexInit(recursiveMutex *m) recursiveMutexInit(recursiveMutex *m)
{ {

@ -1 +1 @@
Subproject commit cee6c346ba77e046ef1873b9d8c88c52c3baae2d Subproject commit a5d2a3755f86ebffd8eb6e2a6921790d299eab30

@ -1 +1 @@
Subproject commit f6fce313722d2f69e5ce6074f304ce05065bbf40 Subproject commit a6f0f4ec7d5fd51ca8b268b8392da9b20bfd1b44

@ -1 +1 @@
Subproject commit 3f7eee803d46071f10804010f7266b3864bab7e6 Subproject commit 2a0843683e8790d8129fa4bad99c82a5fcaf441b

@ -1 +1 @@
Subproject commit aa90b8f8a67e605e82566b719a8f20d125598bd2 Subproject commit 2daa5b9942a9fc22005ec48b6560f8596f33d995

@ -1 +1 @@
Subproject commit 2229eb3807b21497ffd680686ceeeddeac9f57fb Subproject commit e3c7eb9a9a54d6a9069396ecd1c3b537a8f6165a

@ -1 +1 @@
Subproject commit babcbfe9cc5ab269cd5fd4f024e9d57bd3d0a8db Subproject commit f1c3ef54f4d9431ba5b4188cb72ca3056d20b202

View File

@ -1,59 +0,0 @@
#
# default base directory for YAP installation
# (EROOT for architecture-dependent files)
#
prefix = @prefix@
exec_prefix = @exec_prefix@
ROOTDIR = $(prefix)
EROOTDIR = @exec_prefix@
abs_top_builddir = @abs_top_builddir@
#
# where the binary should be
#
BINDIR = $(EROOTDIR)/bin
#
# where YAP should look for libraries
#
LIBDIR=@libdir@
YAPLIBDIR=@libdir@/Yap
#
#
DEFS=@DEFS@ -D_YAP_NOT_INSTALLED_=1
CC=@CC@
CFLAGS= @SHLIB_CFLAGS@ $(YAP_EXTRAS) $(DEFS) -I$(srcdir) -I../.. -I$(srcdir)/../../include -I$(srcdir)/../PLStream -I$(srcdir)/../PLStream/windows -I$(srcdir)/../../H
#
#
# You shouldn't need to change what follows.
#
INSTALL=@INSTALL@
INSTALL_DATA=@INSTALL_DATA@
INSTALL_PROGRAM=@INSTALL_PROGRAM@
SHELL=/bin/sh
RANLIB=@RANLIB@
srcdir=@srcdir@
SO=@SO@
#4.1VPATH=@srcdir@:@srcdir@/OPTYap
CWD=$(PWD)
#
OBJS=pl-tai.o
SOBJS=pl-tai.@SO@
#in some systems we just create a single object, in others we need to
# create a libray
all: $(SOBJS)
pl-tai.o: $(srcdir)/pl-tai.c
(cd libtai ; $(MAKE))
$(CC) -c $(CFLAGS) $(srcdir)/pl-tai.c -o pl-tai.o
@DO_SECOND_LD@pl-tai.@SO@: pl-tai.o
@DO_SECOND_LD@ @SHLIB_LD@ $(LDFLAGS) -o pl-tai.@SO@ pl-tai.o libtai/libtai.a @EXTRA_LIBS_FOR_SWIDLLS@
install: all
$(INSTALL_PROGRAM) $(SOBJS) $(DESTDIR)$(YAPLIBDIR)
clean:
rm -f *.o *~ $(OBJS) $(SOBJS) *.BAK
-(cd libtai && $(MAKE) clean)

@ -1 +1 @@
Subproject commit deea4bfdf7041387e91eca37978a5c8db9287eda Subproject commit 109bf1d224009dc049ab12a0fbbb50511ef8e1fb

View File

@ -1050,6 +1050,8 @@ make :-
fail. fail.
make. make.
make_library_index(_Directory).
'$file_name'(Stream,F) :- '$file_name'(Stream,F) :-
stream_property(Stream, file_name(F)), !. stream_property(Stream, file_name(F)), !.
'$file_name'(user_input,user_output). '$file_name'(user_input,user_output).

View File

@ -495,7 +495,7 @@ debugging :-
'$continue_debugging'(no, '$execute_nonstop'(G,M)). '$continue_debugging'(no, '$execute_nonstop'(G,M)).
'$spycall'(G, M, CalledFromDebugger, InRedo) :- '$spycall'(G, M, CalledFromDebugger, InRedo) :-
'$flags'(G,M,F,F), '$flags'(G,M,F,F),
F /\ 0x18402000 =\= 0, !, % dynamic procedure, logical semantics, user-C, or source F /\ 0x08402000 =\= 0, !, % dynamic procedure, logical semantics, or source
% use the interpreter % use the interpreter
CP is '$last_choice_pt', CP is '$last_choice_pt',
'$clause'(G, M, Cl, _), '$clause'(G, M, Cl, _),

View File

@ -245,22 +245,15 @@ print_message(Severity, Msg) :-
'$notrace'(user:portray_message(Severity, Msg)), !. '$notrace'(user:portray_message(Severity, Msg)), !.
% This predicate has more hooks than a pirate ship! % This predicate has more hooks than a pirate ship!
print_message(Severity, Term) :- print_message(Severity, Term) :-
( % first step at hook processing
( '$message_to_lines'(Term, Lines),
'$oncenotrace'(user:generate_message_hook(Term, [], Lines)) -> ( nonvar(Term),
true '$oncenotrace'(user:message_hook(Term, Severity, Lines))
; ->
'$oncenotrace'(prolog:message(Term, Lines, [])) -> true
true ;
; '$print_system_message'(Term, Severity, Lines)
'$messages':generate_message(Term, Lines, []) ), !.
)
-> ( nonvar(Term),
'$oncenotrace'(user:message_hook(Term, Severity, Lines))
-> !
; !, '$print_system_message'(Term, Severity, Lines)
)
).
print_message(silent, _) :- !. print_message(silent, _) :- !.
print_message(_, error(syntax_error(syntax_error(_,between(_,L,_),_,_,_,_,StreamName)),_)) :- !, print_message(_, error(syntax_error(syntax_error(_,between(_,L,_),_,_,_,_,StreamName)),_)) :- !,
format(user_error,'SYNTAX ERROR at ~a, close to ~d~n',[StreamName,L]). format(user_error,'SYNTAX ERROR at ~a, close to ~d~n',[StreamName,L]).
@ -271,6 +264,14 @@ print_message(_, loaded(A, F, _, Time, Space)) :- !,
print_message(_, Term) :- print_message(_, Term) :-
format(user_error,'~q~n',[Term]). format(user_error,'~q~n',[Term]).
'$message_to_lines'(Term, Lines) :-
'$oncenotrace'(user:generate_message_hook(Term, [], Lines)), !.
'$message_to_lines'(Term, Lines) :-
'$oncenotrace'(prolog:message(Term, Lines, [])), !.
'$message_to_lines'(Term, Lines) :-
'$messages':generate_message(Term, Lines, []), !.
% print_system_message(+Term, +Level, +Lines) % print_system_message(+Term, +Level, +Lines)
% %
% Print the message if the user did not intecept the message. % Print the message if the user did not intecept the message.

View File

@ -69,6 +69,10 @@ generate_message(debug) --> !,
[ debug ]. [ debug ].
generate_message(trace) --> !, generate_message(trace) --> !,
[ trace ]. [ trace ].
generate_message(error(Error,Context)) -->
{ Error = existence_error(procedure,_) }, !,
system_message(error(Error,Context)),
stack_dump(error(Error,Context)).
generate_message(error(Error,context(Cause,Extra))) --> generate_message(error(Error,context(Cause,Extra))) -->
system_message(error(Error,Cause)), system_message(error(Error,Cause)),
stack_dump(error(Error,context(Cause,Extra))). stack_dump(error(Error,context(Cause,Extra))).
@ -130,8 +134,6 @@ system_message(no_match(P)) -->
[ 'No matching predicate for ~w.' - [P] ]. [ 'No matching predicate for ~w.' - [P] ].
system_message(leash([A|B])) --> system_message(leash([A|B])) -->
[ 'Leashing set to ~w.' - [[A|B]] ]. [ 'Leashing set to ~w.' - [[A|B]] ].
system_message(existence_error(prolog_flag,F)) -->
[ 'Prolog Flag ~w: new Prolog flags must be created using create_prolog_flag/3.' - [F] ].
system_message(singletons([SV],P)) --> system_message(singletons([SV],P)) -->
[ 'Singleton variable ~s in ~q.' - [SV,P] ]. [ 'Singleton variable ~s in ~q.' - [SV,P] ].
system_message(singletons(SVs,P)) --> system_message(singletons(SVs,P)) -->
@ -159,14 +161,18 @@ system_message(error(context_error(Goal,Who),Where)) -->
system_message(error(domain_error(DomainType,Opt), Where)) --> system_message(error(domain_error(DomainType,Opt), Where)) -->
[ 'DOMAIN ERROR- ~w: ' - Where], [ 'DOMAIN ERROR- ~w: ' - Where],
domain_error(DomainType, Opt). domain_error(DomainType, Opt).
system_message(error(existence_error(directory,Key), Where)) -->
[ 'EXISTENCE ERROR- ~w: ~w not an existing directory' - [Where,Key] ].
system_message(error(existence_error(key,Key), Where)) -->
[ 'EXISTENCE ERROR- ~w: ~w not an existing key' - [Where,Key] ].
system_message(existence_error(prolog_flag,F)) -->
[ 'Prolog Flag ~w: new Prolog flags must be created using create_prolog_flag/3.' - [F] ].
system_message(error(existence_error(prolog_flag,P), Where)) --> !, system_message(error(existence_error(prolog_flag,P), Where)) --> !,
[ 'EXISTENCE ERROR- ~w: prolog flag ~w is undefined' - [Where,P] ]. [ 'EXISTENCE ERROR- ~w: prolog flag ~w is undefined' - [Where,P] ].
system_message(error(existence_error(procedure,P), context(Call,Parent))) --> !, system_message(error(existence_error(procedure,P), context(Call,Parent))) --> !,
[ 'EXISTENCE ERROR- procedure ~w is undefined, called from context ~w~n Goal was ~w' - [P,Parent,Call] ]. [ 'EXISTENCE ERROR- procedure ~w is undefined, called from context ~w~n Goal was ~w' - [P,Parent,Call] ].
system_message(error(existence_error(stream,Stream), Where)) --> system_message(error(existence_error(stream,Stream), Where)) -->
[ 'EXISTENCE ERROR- ~w: ~w not an open stream' - [Where,Stream] ]. [ 'EXISTENCE ERROR- ~w: ~w not an open stream' - [Where,Stream] ].
system_message(error(existence_error(key,Key), Where)) -->
[ 'EXISTENCE ERROR- ~w: ~w not an existing key' - [Where,Key] ].
system_message(error(existence_error(thread,Thread), Where)) --> system_message(error(existence_error(thread,Thread), Where)) -->
[ 'EXISTENCE ERROR- ~w: ~w not a running thread' - [Where,Thread] ]. [ 'EXISTENCE ERROR- ~w: ~w not a running thread' - [Where,Thread] ].
system_message(error(existence_error(variable,Var), Where)) --> system_message(error(existence_error(variable,Var), Where)) -->