Merge branch 'master' of /cygdrive/z/vitor/Yap/yap-6.3
This commit is contained in:
commit
4fe1833ece
@ -2081,8 +2081,13 @@ Yap_absmi(int inp)
|
||||
goto failloop;
|
||||
} else
|
||||
#endif /* FROZEN_STACKS */
|
||||
if (IN_BETWEEN(H0,pt1,H) && IsAttVar(pt1))
|
||||
goto failloop;
|
||||
if (IN_BETWEEN(H0,pt1,H)) {
|
||||
if (IsAttVar(pt1)) {
|
||||
goto failloop;
|
||||
} else if (*pt1 == (CELL)FunctorBigInt) {
|
||||
Yap_CleanOpaqueVariable(pt1);
|
||||
}
|
||||
}
|
||||
#ifdef FROZEN_STACKS /* TRAIL */
|
||||
/* don't reset frozen variables */
|
||||
if (pt0 < TR_FZ)
|
||||
|
14
C/amasm.c
14
C/amasm.c
@ -252,7 +252,7 @@ STATIC_PROTO(yamop *a_if, (op_numbers, union clause_obj *, int, yamop *, int, st
|
||||
STATIC_PROTO(yamop *a_cut, (clause_info *,yamop *, int, struct intermediates *));
|
||||
#ifdef YAPOR
|
||||
STATIC_PROTO(yamop *a_try, (op_numbers, CELL, CELL, int, int, yamop *, int, struct intermediates *));
|
||||
STATIC_PROTO(yamop *a_either, (op_numbers, CELL, CELL, int, int, yamop *, int, struct intermediates *));
|
||||
STATIC_PROTO(yamop *a_either, (op_numbers, CELL, CELL, int, yamop *, int, struct intermediates *));
|
||||
#else
|
||||
STATIC_PROTO(yamop *a_try, (op_numbers, CELL, CELL, yamop *, int, struct intermediates *));
|
||||
STATIC_PROTO(yamop *a_either, (op_numbers, CELL, CELL, yamop *, int, struct intermediates *));
|
||||
@ -2155,7 +2155,7 @@ a_try(op_numbers opcode, CELL lab, CELL opr, int nofalts, int hascut, yamop *cod
|
||||
#endif
|
||||
#ifdef YAPOR
|
||||
INIT_YAMOP_LTT(code_p, nofalts);
|
||||
if (hascut)
|
||||
if (cip->clause_has_cut)
|
||||
PUT_YAMOP_CUT(code_p);
|
||||
if (ap->PredFlags & SequentialPredFlag)
|
||||
PUT_YAMOP_SEQ(code_p);
|
||||
@ -2167,7 +2167,7 @@ a_try(op_numbers opcode, CELL lab, CELL opr, int nofalts, int hascut, yamop *cod
|
||||
|
||||
static yamop *
|
||||
#ifdef YAPOR
|
||||
a_either(op_numbers opcode, CELL opr, CELL lab, int nofalts, int hascut, yamop *code_p, int pass_no, struct intermediates *cip)
|
||||
a_either(op_numbers opcode, CELL opr, CELL lab, int nofalts, yamop *code_p, int pass_no, struct intermediates *cip)
|
||||
#else
|
||||
a_either(op_numbers opcode, CELL opr, CELL lab, yamop *code_p, int pass_no, struct intermediates *cip)
|
||||
#endif /* YAPOR */
|
||||
@ -2179,7 +2179,7 @@ a_either(op_numbers opcode, CELL opr, CELL lab, yamop *code_p, int pass_no, stru
|
||||
code_p->u.Osblp.p0 = cip->CurrentPred;
|
||||
#ifdef YAPOR
|
||||
INIT_YAMOP_LTT(code_p, nofalts);
|
||||
if (hascut)
|
||||
if (cip->clause_has_cut)
|
||||
PUT_YAMOP_CUT(code_p);
|
||||
if (cip->CurrentPred->PredFlags & SequentialPredFlag)
|
||||
PUT_YAMOP_SEQ(code_p);
|
||||
@ -3583,7 +3583,7 @@ do_pass(int pass_no, yamop **entry_codep, int assembling, int *clause_has_blobsp
|
||||
}
|
||||
code_p = a_either(_either,
|
||||
-Signed(RealEnvSize) - CELLSIZE * cip->cpc->rnd2,
|
||||
Unsigned(cip->code_addr) + cip->label_offset[cip->cpc->rnd1], 0, 0, code_p, pass_no, cip);
|
||||
Unsigned(cip->code_addr) + cip->label_offset[cip->cpc->rnd1], 0, code_p, pass_no, cip);
|
||||
#else
|
||||
code_p = a_either(_either,
|
||||
-Signed(RealEnvSize) - CELLSIZE * cip->cpc->rnd2,
|
||||
@ -3596,7 +3596,7 @@ do_pass(int pass_no, yamop **entry_codep, int assembling, int *clause_has_blobsp
|
||||
either_inst[either_cont++] = code_p;
|
||||
code_p = a_either(_or_else,
|
||||
-Signed(RealEnvSize) - CELLSIZE * cip->cpc->rnd2,
|
||||
Unsigned(cip->code_addr) + cip->label_offset[cip->cpc->rnd1], 0, 0, code_p, pass_no, cip);
|
||||
Unsigned(cip->code_addr) + cip->label_offset[cip->cpc->rnd1], 0, code_p, pass_no, cip);
|
||||
#else
|
||||
code_p = a_either(_or_else,
|
||||
-Signed(RealEnvSize) - CELLSIZE * cip->cpc->rnd2,
|
||||
@ -3608,7 +3608,7 @@ do_pass(int pass_no, yamop **entry_codep, int assembling, int *clause_has_blobsp
|
||||
#ifdef YAPOR
|
||||
if (pass_no)
|
||||
either_inst[either_cont++] = code_p;
|
||||
code_p = a_either(_or_last, 0, 0, 0, 0, code_p, pass_no, cip);
|
||||
code_p = a_either(_or_last, 0, 0, 0, code_p, pass_no, cip);
|
||||
if (pass_no) {
|
||||
int cont = 1;
|
||||
do {
|
||||
|
57
C/bignum.c
57
C/bignum.c
@ -25,9 +25,10 @@ static char SccsId[] = "%W% %G%";
|
||||
#include <string.h>
|
||||
#endif
|
||||
|
||||
#include "YapHeap.h"
|
||||
|
||||
#ifdef USE_GMP
|
||||
|
||||
#include "YapHeap.h"
|
||||
#include "eval.h"
|
||||
#include "alloc.h"
|
||||
|
||||
@ -59,6 +60,7 @@ Yap_MkBigIntTerm(MP_INT *big)
|
||||
return AbsAppl(ret);
|
||||
}
|
||||
|
||||
|
||||
MP_INT *
|
||||
Yap_BigIntOfTerm(Term t)
|
||||
{
|
||||
@ -127,9 +129,60 @@ Yap_RatTermToApplTerm(Term t)
|
||||
return Yap_MkApplTerm(FunctorRDiv,2,ts);
|
||||
}
|
||||
|
||||
|
||||
#endif
|
||||
|
||||
Term
|
||||
Yap_AllocExternalDataInStack(CELL tag, size_t bytes)
|
||||
{
|
||||
CACHE_REGS
|
||||
Int nlimbs;
|
||||
MP_INT *dst = (MP_INT *)(H+2);
|
||||
CELL *ret = H;
|
||||
|
||||
nlimbs = ALIGN_YAPTYPE(bytes,CELL)/CellSize;
|
||||
if (nlimbs > (ASP-ret)-1024) {
|
||||
return TermNil;
|
||||
}
|
||||
H[0] = (CELL)FunctorBigInt;
|
||||
H[1] = tag;
|
||||
dst->_mp_size = 0;
|
||||
dst->_mp_alloc = nlimbs;
|
||||
H = (CELL *)(dst+1)+nlimbs;
|
||||
H[0] = EndSpecials;
|
||||
H++;
|
||||
if (tag != EXTERNAL_BLOB) {
|
||||
TrailTerm(TR) = AbsPair(ret);
|
||||
TR++;
|
||||
}
|
||||
return AbsAppl(ret);
|
||||
}
|
||||
|
||||
int Yap_CleanOpaqueVariable(CELL *pt)
|
||||
{
|
||||
CELL blob_info, blob_tag;
|
||||
MP_INT *blobp;
|
||||
#ifdef DEBUG
|
||||
/* sanity checking */
|
||||
if (pt[0] != (CELL)FunctorBigInt) {
|
||||
Yap_Error(SYSTEM_ERROR, TermNil, "CleanOpaqueVariable bad call");
|
||||
return FALSE;
|
||||
}
|
||||
#endif
|
||||
blob_tag = pt[1];
|
||||
if (blob_tag < USER_BLOB_START ||
|
||||
blob_tag >= USER_BLOB_END) {
|
||||
Yap_Error(SYSTEM_ERROR, AbsAppl(pt), "clean opaque: bad blob with tag " UInt_FORMAT ,blob_tag);
|
||||
return FALSE;
|
||||
}
|
||||
blob_info = blob_tag - USER_BLOB_START;
|
||||
if (!GLOBAL_OpaqueHandlers)
|
||||
return FALSE;
|
||||
blobp = (MP_INT *)(pt+2);
|
||||
if (!GLOBAL_OpaqueHandlers[blob_info].fail_handler)
|
||||
return TRUE;
|
||||
return (GLOBAL_OpaqueHandlers[blob_info].fail_handler)((void *)(blobp+1));
|
||||
}
|
||||
|
||||
Term
|
||||
Yap_MkULLIntTerm(YAP_ULONG_LONG n)
|
||||
{
|
||||
|
@ -390,6 +390,8 @@ X_API Bool STD_PROTO(YAP_IsDbRefTerm,(Term));
|
||||
X_API Bool STD_PROTO(YAP_IsAtomTerm,(Term));
|
||||
X_API Bool STD_PROTO(YAP_IsPairTerm,(Term));
|
||||
X_API Bool STD_PROTO(YAP_IsApplTerm,(Term));
|
||||
X_API Bool STD_PROTO(YAP_IsExternalDataInStackTerm,(Term));
|
||||
X_API Bool STD_PROTO(YAP_IsOpaqueObjectTerm,(Term, int));
|
||||
X_API Term STD_PROTO(YAP_MkIntTerm,(Int));
|
||||
X_API Term STD_PROTO(YAP_MkBigNumTerm,(void *));
|
||||
X_API Term STD_PROTO(YAP_MkRationalTerm,(void *));
|
||||
@ -463,7 +465,7 @@ X_API IOSTREAM *STD_PROTO(YAP_TermToStream,(Term));
|
||||
X_API IOSTREAM *STD_PROTO(YAP_InitConsult,(int, char *));
|
||||
X_API void STD_PROTO(YAP_EndConsult,(IOSTREAM *));
|
||||
X_API Term STD_PROTO(YAP_Read, (IOSTREAM *));
|
||||
X_API void STD_PROTO(YAP_Write, (Term, int (*)(wchar_t), int));
|
||||
X_API void STD_PROTO(YAP_Write, (Term, IOSTREAM *, int));
|
||||
X_API Term STD_PROTO(YAP_CopyTerm, (Term));
|
||||
X_API Term STD_PROTO(YAP_WriteBuffer, (Term, char *, unsigned int, int));
|
||||
X_API char *STD_PROTO(YAP_CompileClause, (Term));
|
||||
@ -536,13 +538,11 @@ X_API Term STD_PROTO(YAP_ModuleUser,(void));
|
||||
X_API Int STD_PROTO(YAP_NumberOfClausesForPredicate,(PredEntry *));
|
||||
X_API int STD_PROTO(YAP_MaxOpPriority,(Atom, Term));
|
||||
X_API int STD_PROTO(YAP_OpInfo,(Atom, Term, int, int *, int *));
|
||||
|
||||
static int (*do_putcf)(wchar_t);
|
||||
|
||||
static int do_yap_putc(int streamno,wchar_t ch) {
|
||||
do_putcf(ch);
|
||||
return(ch);
|
||||
}
|
||||
X_API Term STD_PROTO(YAP_AllocExternalDataInStack,(size_t));
|
||||
X_API void *STD_PROTO(YAP_ExternalDataInStackFromTerm,(Term));
|
||||
X_API int STD_PROTO(YAP_NewOpaqueType,(void *));
|
||||
X_API Term STD_PROTO(YAP_NewOpaqueObject,(int, size_t));
|
||||
X_API void *STD_PROTO(YAP_OpaqueObjectFromTerm,(Term));
|
||||
|
||||
static int
|
||||
dogc(void)
|
||||
@ -1677,7 +1677,6 @@ YAP_ExecuteOnCut(PredEntry *pe, CPredicate exec_code, struct cut_c_str *top)
|
||||
if (pe->PredFlags & CArgsPredFlag) {
|
||||
val = execute_cargs_back(pe, exec_code, ctx PASS_REGS);
|
||||
} else {
|
||||
fprintf(stderr,"ctx=%p\n",ctx);
|
||||
val = ((codev)(args-LCL0,0,ctx));
|
||||
}
|
||||
/* make sure we clean up the frames left by the user */
|
||||
@ -2314,6 +2313,65 @@ YAP_RunGoal(Term t)
|
||||
return(out);
|
||||
}
|
||||
|
||||
X_API Term
|
||||
YAP_AllocExternalDataInStack(size_t bytes)
|
||||
{
|
||||
Term t = Yap_AllocExternalDataInStack(EXTERNAL_BLOB, bytes);
|
||||
if (t == TermNil)
|
||||
return 0L;
|
||||
return t;
|
||||
}
|
||||
|
||||
X_API Bool
|
||||
YAP_IsExternalDataInStackTerm(Term t)
|
||||
{
|
||||
return IsExternalBlobTerm(t, EXTERNAL_BLOB);
|
||||
}
|
||||
|
||||
X_API void *
|
||||
YAP_ExternalDataInStackFromTerm(Term t)
|
||||
{
|
||||
return ExternalBlobFromTerm (t);
|
||||
}
|
||||
|
||||
int YAP_NewOpaqueType(void *f)
|
||||
{
|
||||
int i;
|
||||
if (!GLOBAL_OpaqueHandlers) {
|
||||
GLOBAL_OpaqueHandlers = malloc(sizeof(opaque_handler_t)*(USER_BLOB_END-USER_BLOB_START));
|
||||
if (!GLOBAL_OpaqueHandlers) {
|
||||
/* no room */
|
||||
return -1;
|
||||
}
|
||||
} else if (GLOBAL_OpaqueHandlersCount == USER_BLOB_END-USER_BLOB_START) {
|
||||
/* all types used */
|
||||
return -1;
|
||||
}
|
||||
i = GLOBAL_OpaqueHandlersCount++;
|
||||
memcpy(GLOBAL_OpaqueHandlers+i,f,sizeof(opaque_handler_t));
|
||||
return i+USER_BLOB_START;
|
||||
}
|
||||
|
||||
Term YAP_NewOpaqueObject(int tag, size_t bytes)
|
||||
{
|
||||
Term t = Yap_AllocExternalDataInStack((CELL)tag, bytes);
|
||||
if (t == TermNil)
|
||||
return 0L;
|
||||
return t;
|
||||
}
|
||||
|
||||
X_API Bool
|
||||
YAP_IsOpaqueObjectTerm(Term t, int tag)
|
||||
{
|
||||
return IsExternalBlobTerm(t, (CELL)tag);
|
||||
}
|
||||
|
||||
X_API void *
|
||||
YAP_OpaqueObjectFromTerm(Term t)
|
||||
{
|
||||
return ExternalBlobFromTerm (t);
|
||||
}
|
||||
|
||||
X_API Term
|
||||
YAP_RunGoalOnce(Term t)
|
||||
{
|
||||
@ -2369,7 +2427,6 @@ YAP_RestartGoal(void)
|
||||
BACKUP_MACHINE_REGS();
|
||||
if (LOCAL_AllowRestart) {
|
||||
P = (yamop *)FAILCODE;
|
||||
do_putcf = myputc;
|
||||
LOCAL_PrologMode = UserMode;
|
||||
out = Yap_exec_absmi(TRUE);
|
||||
LOCAL_PrologMode = UserCCallMode;
|
||||
@ -2589,12 +2646,11 @@ YAP_Read(IOSTREAM *inp)
|
||||
}
|
||||
|
||||
X_API void
|
||||
YAP_Write(Term t, int (*myputc)(wchar_t), int flags)
|
||||
YAP_Write(Term t, IOSTREAM *stream, int flags)
|
||||
{
|
||||
BACKUP_MACHINE_REGS();
|
||||
|
||||
do_putcf = myputc; /* */
|
||||
Yap_plwrite (t, do_yap_putc, flags, 1200);
|
||||
Yap_dowrite (t, stream, flags, 1200);
|
||||
|
||||
RECOVER_MACHINE_REGS();
|
||||
}
|
||||
@ -2774,6 +2830,7 @@ YAP_Init(YAP_init_args *yap_init)
|
||||
Yap_init_yapor_global_local_memory();
|
||||
LOCAL = REMOTE(0);
|
||||
#endif /* YAPOR_COPY || YAPOR_COW || YAPOR_SBA */
|
||||
GLOBAL_PrologShouldHandleInterrupts = yap_init->PrologShouldHandleInterrupts;
|
||||
Yap_InitSysbits(); /* init signal handling and time, required by later functions */
|
||||
GLOBAL_argv = yap_init->Argv;
|
||||
GLOBAL_argc = yap_init->Argc;
|
||||
@ -2821,7 +2878,6 @@ YAP_Init(YAP_init_args *yap_init)
|
||||
} else {
|
||||
Heap = yap_init->HeapSize;
|
||||
}
|
||||
GLOBAL_PrologShouldHandleInterrupts = yap_init->PrologShouldHandleInterrupts;
|
||||
Yap_InitWorkspace(Heap, Stack, Trail, Atts,
|
||||
yap_init->MaxTableSpaceSize,
|
||||
yap_init->NumberWorkers,
|
||||
|
@ -735,7 +735,7 @@ c_arg(Int argno, Term t, unsigned int arity, unsigned int level, compiler_struct
|
||||
} else if (IsPairTerm(t)) {
|
||||
cglobs->space_used += 2;
|
||||
if (optimizer_on && level < 6) {
|
||||
#if !defined(THREADS)
|
||||
#if !defined(THREADS) && !defined(YAPOR)
|
||||
/* discard code sharing because we cannot write on shared stuff */
|
||||
if (!(cglobs->cint.CurrentPred->PredFlags & (DynamicPredFlag|LogUpdatePredFlag))) {
|
||||
if (try_store_as_dbterm(t, argno, arity, level, cglobs))
|
||||
|
4
C/exec.c
4
C/exec.c
@ -961,7 +961,7 @@ exec_absmi(int top USES_REGS)
|
||||
restore_H();
|
||||
/* set stack */
|
||||
ASP = (CELL *)PROTECT_FROZEN_B(B);
|
||||
Yap_StartSlots( PASS_REGS1 );
|
||||
Yap_PopSlots();
|
||||
LOCK(LOCAL_SignalLock);
|
||||
/* forget any signals active, we're reborne */
|
||||
LOCAL_ActiveSignals = 0;
|
||||
@ -991,9 +991,9 @@ exec_absmi(int top USES_REGS)
|
||||
LOCAL_PrologMode = UserMode;
|
||||
}
|
||||
} else {
|
||||
Yap_CloseSlots( PASS_REGS1 );
|
||||
LOCAL_PrologMode = UserMode;
|
||||
}
|
||||
Yap_CloseSlots( PASS_REGS1 );
|
||||
YENV = ASP;
|
||||
YENV[E_CB] = Unsigned (B);
|
||||
out = Yap_absmi(0);
|
||||
|
4
C/grow.c
4
C/grow.c
@ -399,7 +399,7 @@ AdjustTrail(int adjusting_heap, int thread_copying USES_REGS)
|
||||
#if defined(YAPOR_THREADS)
|
||||
}
|
||||
#endif
|
||||
/* moving the trail is simple */
|
||||
/* moving the trail is simple, yeaahhh! */
|
||||
while (ptt != tr_base) {
|
||||
register CELL reg = TrailTerm(ptt-1);
|
||||
#ifdef FROZEN_STACKS
|
||||
@ -420,8 +420,6 @@ AdjustTrail(int adjusting_heap, int thread_copying USES_REGS)
|
||||
} else if (IsPairTerm(reg)) {
|
||||
TrailTerm(ptt) = AdjustPair(reg PASS_REGS);
|
||||
#ifdef MULTI_ASSIGNMENT_VARIABLES /* does not work with new structures */
|
||||
/* check it whether we are protecting a
|
||||
multi-assignment */
|
||||
} else if (IsApplTerm(reg)) {
|
||||
TrailTerm(ptt) = AdjustAppl(reg PASS_REGS);
|
||||
#endif
|
||||
|
49
C/heapgc.c
49
C/heapgc.c
@ -1325,7 +1325,7 @@ mark_variable(CELL_PTR current USES_REGS)
|
||||
sz++;
|
||||
#if DEBUG
|
||||
if (next[sz] != EndSpecials) {
|
||||
fprintf(stderr,"[ Error: could not find EndSpecials at blob %p type %lx ]\n", next, next[1]);
|
||||
fprintf(stderr,"[ Error: could not find EndSpecials at blob %p type " UInt_FORMAT " ]\n", next, next[1]);
|
||||
}
|
||||
#endif
|
||||
MARK(next+sz);
|
||||
@ -1658,14 +1658,23 @@ mark_trail(tr_fr_ptr trail_ptr, tr_fr_ptr trail_base, CELL *gc_H, choiceptr gc_B
|
||||
#endif
|
||||
}
|
||||
} else if (IsPairTerm(trail_cell)) {
|
||||
/* can safely ignore this */
|
||||
/* cannot safely ignore this */
|
||||
CELL *cptr = RepPair(trail_cell);
|
||||
if (IN_BETWEEN(LOCAL_GlobalBase,cptr,H) &&
|
||||
GlobalIsAttVar(cptr)) {
|
||||
TrailTerm(trail_base) = (CELL)cptr;
|
||||
mark_external_reference(&TrailTerm(trail_base) PASS_REGS);
|
||||
TrailTerm(trail_base) = trail_cell;
|
||||
}
|
||||
if (IN_BETWEEN(LOCAL_GlobalBase,cptr,H)) {
|
||||
if (GlobalIsAttVar(cptr)) {
|
||||
TrailTerm(trail_base) = (CELL)cptr;
|
||||
mark_external_reference(&TrailTerm(trail_base) PASS_REGS);
|
||||
TrailTerm(trail_base) = trail_cell;
|
||||
} else if (*cptr == (CELL)FunctorBigInt) {
|
||||
TrailTerm(trail_base) = AbsAppl(cptr);
|
||||
mark_external_reference(&TrailTerm(trail_base) PASS_REGS);
|
||||
TrailTerm(trail_base) = trail_cell;
|
||||
}
|
||||
#ifdef DEBUG
|
||||
else
|
||||
fprintf(GLOBAL_stderr,"OOPS in GC: weird trail entry at %p:" UInt_FORMAT "\n", &TrailTerm(trail_base), (CELL)cptr);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
#if MULTI_ASSIGNMENT_VARIABLES
|
||||
else {
|
||||
@ -1904,11 +1913,11 @@ mark_choicepoints(register choiceptr gc_B, tr_fr_ptr saved_TR, int very_verbose
|
||||
PredEntry *pe = Yap_PredForChoicePt(gc_B);
|
||||
#if defined(ANALYST) || defined(DEBUG)
|
||||
if (pe == NULL) {
|
||||
fprintf(GLOBAL_stderr,"%% marked %ld (%s)\n", LOCAL_total_marked, Yap_op_names[opnum]);
|
||||
fprintf(GLOBAL_stderr,"%% marked " UInt_FORMAT " (%s)\n", LOCAL_total_marked, Yap_op_names[opnum]);
|
||||
} else if (pe->ArityOfPE) {
|
||||
fprintf(GLOBAL_stderr,"%% %s/%d marked %ld (%s)\n", RepAtom(NameOfFunctor(pe->FunctorOfPred))->StrOfAE, pe->ArityOfPE, LOCAL_total_marked, Yap_op_names[opnum]);
|
||||
fprintf(GLOBAL_stderr,"%% %s/%d marked " UInt_FORMAT " (%s)\n", RepAtom(NameOfFunctor(pe->FunctorOfPred))->StrOfAE, pe->ArityOfPE, LOCAL_total_marked, Yap_op_names[opnum]);
|
||||
} else {
|
||||
fprintf(GLOBAL_stderr,"%% %s marked %ld (%s)\n", RepAtom((Atom)(pe->FunctorOfPred))->StrOfAE, LOCAL_total_marked, Yap_op_names[opnum]);
|
||||
fprintf(GLOBAL_stderr,"%% %s marked " UInt_FORMAT " (%s)\n", RepAtom((Atom)(pe->FunctorOfPred))->StrOfAE, LOCAL_total_marked, Yap_op_names[opnum]);
|
||||
}
|
||||
#else
|
||||
if (pe == NULL) {
|
||||
@ -2450,11 +2459,19 @@ sweep_trail(choiceptr gc_B, tr_fr_ptr old_TR USES_REGS)
|
||||
CELL *pt0 = RepPair(trail_cell);
|
||||
CELL flags;
|
||||
|
||||
if (IN_BETWEEN(LOCAL_GlobalBase, pt0, H) && GlobalIsAttVar(pt0)) {
|
||||
TrailTerm(dest) = trail_cell;
|
||||
/* be careful with partial gc */
|
||||
if (HEAP_PTR(TrailTerm(dest))) {
|
||||
into_relocation_chain(&TrailTerm(dest), GET_NEXT(trail_cell) PASS_REGS);
|
||||
if (IN_BETWEEN(LOCAL_GlobalBase, pt0, H)) {
|
||||
if (GlobalIsAttVar(pt0)) {
|
||||
TrailTerm(dest) = trail_cell;
|
||||
/* be careful with partial gc */
|
||||
if (HEAP_PTR(TrailTerm(dest))) {
|
||||
into_relocation_chain(&TrailTerm(dest), GET_NEXT(trail_cell) PASS_REGS);
|
||||
}
|
||||
} else if (*pt0 == (CELL)FunctorBigInt) {
|
||||
TrailTerm(dest) = trail_cell;
|
||||
/* be careful with partial gc */
|
||||
if (HEAP_PTR(TrailTerm(dest))) {
|
||||
into_relocation_chain(&TrailTerm(dest), GET_NEXT(trail_cell) PASS_REGS);
|
||||
}
|
||||
}
|
||||
dest++;
|
||||
trail_ptr++;
|
||||
|
23
H/TermExt.h
23
H/TermExt.h
@ -86,7 +86,9 @@ typedef enum
|
||||
CLAUSE_LIST = 0x40,
|
||||
BLOB_STRING = 0x80, /* SWI style strings */
|
||||
BLOB_WIDE_STRING = 0x81, /* SWI style strings */
|
||||
EXTERNAL_BLOB = 0x100 /* for SWI emulation */
|
||||
EXTERNAL_BLOB = 0x100, /* generic data */
|
||||
USER_BLOB_START = 0x1000, /* user defined blob */
|
||||
USER_BLOB_END = 0x1100 /* end of user defined blob */
|
||||
}
|
||||
big_blob_type;
|
||||
|
||||
@ -438,6 +440,25 @@ IsLargeNumTerm (Term t)
|
||||
&& (FunctorOfTerm (t) >= FunctorLongInt)));
|
||||
}
|
||||
|
||||
inline EXTERN int IsExternalBlobTerm (Term, CELL);
|
||||
|
||||
inline EXTERN int
|
||||
IsExternalBlobTerm (Term t, CELL tag)
|
||||
{
|
||||
return (int) (IsApplTerm (t) &&
|
||||
FunctorOfTerm (t) == FunctorBigInt &&
|
||||
RepAppl(t)[1] == tag);
|
||||
}
|
||||
|
||||
inline EXTERN void *ExternalBlobFromTerm (Term);
|
||||
|
||||
inline EXTERN void *
|
||||
ExternalBlobFromTerm (Term t)
|
||||
{
|
||||
MP_INT *base = (MP_INT *)(RepAppl(t)+2);
|
||||
return (void *) (base+1);
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
@ -33,6 +33,12 @@ typedef int (*SWI_PLGetStreamPositionFunction)(void *);
|
||||
|
||||
#include "../include/dswiatoms.h"
|
||||
|
||||
typedef int (*Opaque_CallOnFail)(void *);
|
||||
|
||||
typedef struct opaque_handler_struct {
|
||||
Opaque_CallOnFail fail_handler;
|
||||
} opaque_handler_t;
|
||||
|
||||
#ifndef INT_KEYS_DEFAULT_SIZE
|
||||
#define INT_KEYS_DEFAULT_SIZE 256
|
||||
#endif
|
||||
|
@ -120,6 +120,8 @@ int STD_PROTO(Yap_IsStringTerm, (Term));
|
||||
int STD_PROTO(Yap_IsWideStringTerm, (Term));
|
||||
Term STD_PROTO(Yap_RatTermToApplTerm, (Term));
|
||||
void STD_PROTO(Yap_InitBigNums, (void));
|
||||
Term STD_PROTO(Yap_AllocExternalDataInStack, (CELL, size_t));
|
||||
int STD_PROTO(Yap_CleanOpaqueVariable, (CELL *));
|
||||
|
||||
/* c_interface.c */
|
||||
Int STD_PROTO(YAP_Execute,(struct pred_entry *, CPredicate));
|
||||
|
@ -102,6 +102,8 @@
|
||||
|
||||
#define GLOBAL_Executable Yap_global->Executable_
|
||||
#endif
|
||||
#define GLOBAL_OpaqueHandlersCount Yap_global->OpaqueHandlersCount_
|
||||
#define GLOBAL_OpaqueHandlers Yap_global->OpaqueHandlers_
|
||||
#if __simplescalar__
|
||||
#define GLOBAL_pwd Yap_global->pwd_
|
||||
#endif
|
||||
|
@ -102,6 +102,8 @@ typedef struct global_data {
|
||||
|
||||
char Executable_[YAP_FILENAME_MAX];
|
||||
#endif
|
||||
int OpaqueHandlersCount_;
|
||||
struct opaque_handler_struct* OpaqueHandlers_;
|
||||
#if __simplescalar__
|
||||
char pwd_[YAP_FILENAME_MAX];
|
||||
#endif
|
||||
|
@ -102,6 +102,8 @@ static void InitGlobal(void) {
|
||||
|
||||
|
||||
#endif
|
||||
GLOBAL_OpaqueHandlersCount = 0;
|
||||
GLOBAL_OpaqueHandlers = NULL;
|
||||
#if __simplescalar__
|
||||
|
||||
#endif
|
||||
|
@ -102,6 +102,8 @@ static void RestoreGlobal(void) {
|
||||
|
||||
|
||||
#endif
|
||||
|
||||
|
||||
#if __simplescalar__
|
||||
|
||||
#endif
|
||||
|
@ -67,6 +67,7 @@
|
||||
#define YP_FILE FILE
|
||||
|
||||
int STD_PROTO(YP_putc,(int, int));
|
||||
void STD_PROTO(Yap_dowrite, (Term, IOSTREAM *, int, int));
|
||||
|
||||
#else
|
||||
|
||||
@ -96,8 +97,6 @@ int STD_PROTO(YP_putc,(int, int));
|
||||
#define fclose ERR_fclose
|
||||
#define fflush ERR_fflush
|
||||
|
||||
|
||||
|
||||
/* flags for files in IOSTREAM struct */
|
||||
#define _YP_IO_WRITE 1
|
||||
#define _YP_IO_READ 2
|
||||
|
@ -154,18 +154,6 @@ static int SQLBINDCOL(SQLHSTMT sthandle,
|
||||
return TRUE;
|
||||
}
|
||||
|
||||
static int SQLFREESTMT(SQLHSTMT sthandle,
|
||||
SQLUSMALLINT opt,
|
||||
char * print)
|
||||
{
|
||||
SQLRETURN retcode;
|
||||
retcode = SQLFreeStmt(sthandle,opt);
|
||||
|
||||
if (retcode != SQL_SUCCESS && retcode != SQL_SUCCESS_WITH_INFO)
|
||||
return odbc_error(SQL_HANDLE_STMT, sthandle, "SQLFreeStmt", print);
|
||||
return TRUE;
|
||||
}
|
||||
|
||||
static int SQLNUMRESULTCOLS(SQLHSTMT sthandle,
|
||||
SQLSMALLINT * ncols,
|
||||
char * print)
|
||||
@ -421,8 +409,9 @@ c_db_odbc_query( USES_REGS1 ) {
|
||||
/* +1 because of '\0' */
|
||||
bind_space = malloc(sizeof(char)*(ColumnSizePtr+1));
|
||||
data_info = malloc(sizeof(SQLINTEGER));
|
||||
if (!SQLBINDCOL(hstmt,i,SQL_C_CHAR,bind_space,(ColumnSizePtr+1),data_info,"db_query"))
|
||||
if (!SQLBINDCOL(hstmt,i,SQL_C_CHAR,bind_space,(ColumnSizePtr+1),data_info,"db_query")) {
|
||||
return FALSE;
|
||||
}
|
||||
|
||||
properties[0] = MkIntegerTerm((Int)bind_space);
|
||||
properties[2] = MkIntegerTerm((Int)data_info);
|
||||
@ -444,7 +433,7 @@ c_db_odbc_query( USES_REGS1 ) {
|
||||
{
|
||||
if (!SQLCLOSECURSOR(hstmt,"db_query"))
|
||||
return FALSE;
|
||||
if (!SQLFREESTMT(hstmt,SQL_CLOSE,"db_query"))
|
||||
if (!SQLFREEHANDLE(SQL_HANDLE_STMT, hstmt, "db_query"))
|
||||
return FALSE;
|
||||
return FALSE;
|
||||
}
|
||||
@ -466,7 +455,7 @@ c_db_odbc_number_of_fields( USES_REGS1 ) {
|
||||
char sql[256];
|
||||
SQLSMALLINT number_fields;
|
||||
|
||||
sprintf(sql,"DESCRIBE %s",relation);
|
||||
sprintf(sql,"SELECT column_name from INFORMATION_SCHEMA.COLUMNS where table_name = \'%s\'",relation);
|
||||
|
||||
if (!SQLALLOCHANDLE(SQL_HANDLE_STMT, hdbc, &hstmt, "db_number_of_fields"))
|
||||
return FALSE;
|
||||
@ -482,7 +471,7 @@ c_db_odbc_number_of_fields( USES_REGS1 ) {
|
||||
|
||||
if (!SQLCLOSECURSOR(hstmt,"db_number_of_fields"))
|
||||
return FALSE;
|
||||
if (!SQLFREESTMT(hstmt,SQL_CLOSE,"db_number_of_fields"))
|
||||
if (!SQLFREEHANDLE(SQL_HANDLE_STMT, hstmt, "db_number_of_fields"))
|
||||
return FALSE;
|
||||
|
||||
if (!Yap_unify(arg_fields, MkIntegerTerm(number_fields)))
|
||||
@ -506,7 +495,7 @@ c_db_odbc_get_attributes_types( USES_REGS1 ) {
|
||||
Term head, list;
|
||||
list = arg_types_list;
|
||||
|
||||
sprintf(sql,"DESCRIBE %s",relation);
|
||||
sprintf(sql,"SELECT column_name,data_type FROM INFORMATION_SCHEMA.COLUMNS where table_name = \'%s\'",relation);
|
||||
|
||||
if (!SQLALLOCHANDLE(SQL_HANDLE_STMT, hdbc, &hstmt, "db_get_attributes_types"))
|
||||
return FALSE;
|
||||
@ -547,7 +536,7 @@ c_db_odbc_get_attributes_types( USES_REGS1 ) {
|
||||
|
||||
if (!SQLCLOSECURSOR(hstmt,"db_get_attributes_types"))
|
||||
return FALSE;
|
||||
if (!SQLFREESTMT(hstmt,SQL_CLOSE, "db_get_attributes_types"))
|
||||
if (!SQLFREEHANDLE(SQL_HANDLE_STMT, hstmt, "db_get_attributes_types"))
|
||||
return FALSE;
|
||||
return TRUE;
|
||||
}
|
||||
@ -585,12 +574,31 @@ c_db_odbc_row_cut( USES_REGS1 ) {
|
||||
|
||||
if (!SQLCLOSECURSOR(hstmt,"db_row_cut"))
|
||||
return FALSE;
|
||||
if (!SQLFREESTMT(hstmt,SQL_CLOSE,"db_row_cut"))
|
||||
if (!SQLFREEHANDLE(SQL_HANDLE_STMT, hstmt, "db_row_cut"))
|
||||
return FALSE;
|
||||
|
||||
return TRUE;
|
||||
}
|
||||
|
||||
static int
|
||||
release_list_args(Term arg_list_args, Term arg_bind_list, const char *error_msg)
|
||||
{
|
||||
Term list = arg_list_args;
|
||||
Term list_bind = arg_bind_list;
|
||||
|
||||
while (IsPairTerm(list_bind))
|
||||
{
|
||||
Term head_bind = HeadOfTerm(list_bind);
|
||||
|
||||
list = TailOfTerm(list);
|
||||
list_bind = TailOfTerm(list_bind);
|
||||
|
||||
free((char *)IntegerOfTerm(ArgOfTerm(1,head_bind)));
|
||||
free((SQLINTEGER *)IntegerOfTerm(ArgOfTerm(3,head_bind)));
|
||||
}
|
||||
return TRUE;
|
||||
}
|
||||
|
||||
/* db_row: ResultSet x BindList x ListOfArgs -> */
|
||||
static Int
|
||||
c_db_odbc_row( USES_REGS1 ) {
|
||||
@ -611,9 +619,12 @@ c_db_odbc_row( USES_REGS1 ) {
|
||||
{
|
||||
if (!SQLCLOSECURSOR(hstmt,"db_row"))
|
||||
return FALSE;
|
||||
if (!SQLFREESTMT(hstmt,SQL_CLOSE,"db_row"))
|
||||
if (!SQLFREEHANDLE(SQL_HANDLE_STMT, hstmt, "db_row"))
|
||||
return FALSE;
|
||||
|
||||
if (!release_list_args(arg_list_args, arg_bind_list, "db_row")) {
|
||||
return FALSE;
|
||||
}
|
||||
|
||||
cut_fail();
|
||||
return FALSE;
|
||||
}
|
||||
@ -699,7 +710,7 @@ c_db_odbc_number_of_fields_in_query( USES_REGS1 ) {
|
||||
if (!Yap_unify(arg_fields, MkIntegerTerm(number_cols))){
|
||||
if (!SQLCLOSECURSOR(hstmt,"db_number_of_fields_in_query"))
|
||||
return FALSE;
|
||||
if (!SQLFREESTMT(hstmt,SQL_CLOSE, "db_number_of_fields_in_query"))
|
||||
if (!SQLFREEHANDLE(SQL_HANDLE_STMT, hstmt, "db_number_of_fields_in_query"))
|
||||
return FALSE;
|
||||
|
||||
return FALSE;
|
||||
@ -707,7 +718,7 @@ c_db_odbc_number_of_fields_in_query( USES_REGS1 ) {
|
||||
|
||||
if (!SQLCLOSECURSOR(hstmt,"db_number_of_fields_in_query"))
|
||||
return FALSE;
|
||||
if (!SQLFREESTMT(hstmt,SQL_CLOSE, "db_number_of_fields_in_query"))
|
||||
if (!SQLFREEHANDLE(SQL_HANDLE_STMT, hstmt, "db_number_of_fields_in_query"))
|
||||
return FALSE;
|
||||
|
||||
return TRUE;
|
||||
@ -775,7 +786,7 @@ c_db_odbc_get_fields_properties( USES_REGS1 ) {
|
||||
|
||||
if (!SQLCLOSECURSOR(hstmt2,"db_get_fields_properties"))
|
||||
return FALSE;
|
||||
if (!SQLFREESTMT(hstmt2,SQL_CLOSE,"db_get_fields_properties"))
|
||||
if (!SQLFREEHANDLE(SQL_HANDLE_STMT, hstmt2, "db_get_fields_properties"))
|
||||
return FALSE;
|
||||
|
||||
for (i=1;i<=num_fields;i++)
|
||||
@ -816,9 +827,8 @@ c_db_odbc_get_fields_properties( USES_REGS1 ) {
|
||||
|
||||
if (!SQLCLOSECURSOR(hstmt,"db_get_fields_properties"))
|
||||
return FALSE;
|
||||
if (!SQLFREESTMT(hstmt,SQL_CLOSE,"db_get_fields_properties"))
|
||||
if (!SQLFREEHANDLE(SQL_HANDLE_STMT, hstmt2, "db_get_fields_properties"))
|
||||
return FALSE;
|
||||
|
||||
return TRUE;
|
||||
}
|
||||
|
||||
|
7
configure
vendored
7
configure
vendored
@ -9927,6 +9927,9 @@ mkdir -p packages/clib/maildrop
|
||||
mkdir -p packages/clib/maildrop/rfc822
|
||||
mkdir -p packages/clib/maildrop/rfc2045
|
||||
mkdir -p packages/CLPBN
|
||||
mkdir -p packages/CLPBN/clpbn
|
||||
mkdir -p packages/CLPBN/clpbn/bp
|
||||
mkdir -p packages/CLPBN/clpbn/bp/xmlParser
|
||||
mkdir -p packages/clpqr
|
||||
mkdir -p packages/cplint
|
||||
mkdir -p packages/cplint/approx
|
||||
@ -10619,8 +10622,8 @@ esac
|
||||
|
||||
cat >>$CONFIG_STATUS <<_ACEOF || ac_write_fail=1
|
||||
# Files that config.status was made for.
|
||||
config_files="`echo $ac_config_files`"
|
||||
config_headers="`echo $ac_config_headers`"
|
||||
config_files="$ac_config_files"
|
||||
config_headers="$ac_config_headers"
|
||||
|
||||
_ACEOF
|
||||
|
||||
|
@ -2103,6 +2103,9 @@ mkdir -p packages/clib/maildrop
|
||||
mkdir -p packages/clib/maildrop/rfc822
|
||||
mkdir -p packages/clib/maildrop/rfc2045
|
||||
mkdir -p packages/CLPBN
|
||||
mkdir -p packages/CLPBN/clpbn
|
||||
mkdir -p packages/CLPBN/clpbn/bp
|
||||
mkdir -p packages/CLPBN/clpbn/bp/xmlParser
|
||||
mkdir -p packages/clpqr
|
||||
mkdir -p packages/cplint
|
||||
mkdir -p packages/cplint/approx
|
||||
|
80
docs/yap.tex
80
docs/yap.tex
@ -16811,13 +16811,29 @@ only two boolean flags are accepted: @code{YAPC_ENABLE_GC} and
|
||||
@code{YAPC_ENABLE_AGC}. The first enables/disables the standard garbage
|
||||
collector, the second does the same for the atom garbage collector.`
|
||||
|
||||
@item @code{YAP_TERM} YAP_AllocExternalDataInStack(@code{size_t bytes})
|
||||
@item @code{void *} YAP_ExternalDataInStackFromTerm(@code{YAP_Term t})
|
||||
@item @code{YAP_Bool} YAP_IsExternalDataInStackTerm(@code{YAP_Term t})
|
||||
@findex YAP_AllocExternalDataInStack (C-Interface function)
|
||||
|
||||
The next routines allow one to store external data in the Prolog
|
||||
execution stack. The first routine reserves space for @var{sz} bytes
|
||||
and returns an opaque handle. The second routines receives the handle
|
||||
and returns a pointer to the data. The last routine checks if a term
|
||||
is an opaque handle.
|
||||
|
||||
Data will be automatically reclaimed during
|
||||
backtracking. Also, this storage is opaque to the Prolog garbage compiler,
|
||||
so it should not be used to store Prolog terms. On the other hand, it
|
||||
may be useful to store arrays in a compact way, or pointers to external objects.
|
||||
|
||||
@item @code{int} YAP_HaltRegisterHook(@code{YAP_halt_hook f, void *closure})
|
||||
@findex YAP_HaltRegisterHook (C-Interface function)
|
||||
|
||||
Register the function @var{f} to be called if YAP is halted. The
|
||||
function is called with two arguments: the exit code of the process (@code{0}
|
||||
if this cannot be determined on your operating system) and the closure
|
||||
argument @var{closure}.
|
||||
function is called with two arguments: the exit code of the process
|
||||
(@code{0} if this cannot be determined on your operating system) and
|
||||
the closure argument @var{closure}.
|
||||
@c See also @code{at_halt/1}.
|
||||
|
||||
@end table
|
||||
@ -16850,6 +16866,7 @@ implementing the predicate and @var{arity} is its arity.
|
||||
@findex YAP_UserBackCutCPredicate (C-Interface function)
|
||||
@findex YAP_PRESERVE_DATA (C-Interface function)
|
||||
@findex YAP_PRESERVED_DATA (C-Interface function)
|
||||
@findex YAP_PRESERVED_DATA_CUT (C-Interface function)
|
||||
@findex YAP_cutsucceed (C-Interface function)
|
||||
@findex YAP_cutfail (C-Interface function)
|
||||
For the second kind of predicates we need three C functions. The first one
|
||||
@ -16915,11 +16932,15 @@ static int start_n100(void)
|
||||
@end example
|
||||
|
||||
The routine starts by getting the dereference value of the argument.
|
||||
The call to @code{YAP_PRESERVE_DATA} is used to initialize the memory which will
|
||||
hold the information to be preserved across backtracking. The first
|
||||
argument is the variable we shall use, and the second its type. Note
|
||||
that we can only use @code{YAP_PRESERVE_DATA} once, so often we will
|
||||
want the variable to be a structure.
|
||||
The call to @code{YAP_PRESERVE_DATA} is used to initialize the memory
|
||||
which will hold the information to be preserved across
|
||||
backtracking. The first argument is the variable we shall use, and the
|
||||
second its type. Note that we can only use @code{YAP_PRESERVE_DATA}
|
||||
once, so often we will want the variable to be a structure. This data
|
||||
is visible to the garbage collector, so it should consist of Prolog
|
||||
terms, as in the example. It is also correct to store pointers to
|
||||
objects external to YAP stacks, as the garbage collector will ignore
|
||||
such references.
|
||||
|
||||
If the argument of the predicate is a variable, the routine initializes the
|
||||
structure to be preserved across backtracking with the information
|
||||
@ -16988,6 +17009,34 @@ when pruning the execution of the predicate, @var{arity} is the
|
||||
predicate arity, and @var{sizeof} is the size of the data to be
|
||||
preserved in the stack. In this example, we would have something like
|
||||
|
||||
@example
|
||||
void
|
||||
init_n100(void)
|
||||
@{
|
||||
YAP_UserBackCutCPredicate("n100", start_n100, continue_n100, cut_n100, 1, 1);
|
||||
@}
|
||||
@end example
|
||||
The argument before last is the predicate's arity. Notice again the
|
||||
last argument to the call. function argument gives the extra space we
|
||||
want to use for @code{PRESERVED_DATA}. Space is given in cells, where
|
||||
a cell is the same size as a pointer. The garbage collector has access
|
||||
to this space, hence users should use it either to store terms or to
|
||||
store pointers to objects outside the stacks.
|
||||
|
||||
The code for @code{cut_n100} could be:
|
||||
@example
|
||||
static int cut_n100(void)
|
||||
@{
|
||||
YAP_PRESERVED_DATA_CUT(n100_data,n100_data_type*);
|
||||
|
||||
fprintf("n100 cut with counter %ld\n", YAP_IntOfTerm(n100_data->next_solution));
|
||||
return TRUE;
|
||||
@}
|
||||
@end example
|
||||
Notice that we have to use @code{YAP_PRESERVED_DATA_CUT}: this is
|
||||
because the Prolog engine is at a different state during cut.
|
||||
|
||||
If no work is required at cut, we can use:
|
||||
@example
|
||||
void
|
||||
init_n100(void)
|
||||
@ -16995,8 +17044,7 @@ init_n100(void)
|
||||
YAP_UserBackCutCPredicate("n100", start_n100, continue_n100, NULL, 1, 1);
|
||||
@}
|
||||
@end example
|
||||
Notice that we do not actually need to do anything on receiving a cut in
|
||||
this case.
|
||||
in this case no code is executed at cut time.
|
||||
|
||||
@node Loading Objects, Save&Rest, Writing C, C-Interface
|
||||
@section Loading Object Files
|
||||
@ -17272,9 +17320,9 @@ Associate the term @var{value} with the atom @var{at}. The term
|
||||
@var{value} must be a constant. This functionality is used by YAP as a
|
||||
simple way for controlling and communicating with the Prolog run-time.
|
||||
|
||||
@item @code{YAP_Term} YAP_Read(@code{int (*)(void)} @var{GetC})
|
||||
@item @code{YAP_Term} YAP_Read(@code{IOSTREAM *Stream})
|
||||
@findex YAP_Read/1
|
||||
Parse a Term using the function @var{GetC} to input characters.
|
||||
Parse a @var{Term} from the stream @var{Stream}.
|
||||
|
||||
@item @code{YAP_Term} YAP_Write(@code{YAP_Term} @var{t})
|
||||
@findex YAP_CopyTerm/1
|
||||
@ -17282,13 +17330,13 @@ Copy a Term @var{t} and all associated constraints. May call the garbage
|
||||
collector and returns @code{0L} on error (such as no space being
|
||||
available).
|
||||
|
||||
@item @code{void} YAP_Write(@code{YAP_Term} @var{t}, @code{void (*)(int)}
|
||||
@var{PutC}, @code{int} @var{flags})
|
||||
@item @code{void} YAP_Write(@code{YAP_Term} @var{t}, @code{IOSTREAM}
|
||||
@var{stream}, @code{int} @var{flags})
|
||||
@findex YAP_Write/3
|
||||
Write a Term @var{t} using the function @var{PutC} to output
|
||||
Write a Term @var{t} using the stream @var{stream} to output
|
||||
characters. The term is written according to a mask of the following
|
||||
flags in the @code{flag} argument: @code{YAP_WRITE_QUOTED},
|
||||
@code{YAP_WRITE_HANDLE_VARS}, and @code{YAP_WRITE_IGNORE_OPS}.
|
||||
@code{YAP_WRITE_HANDLE_VARS}, @code{YAP_WRITE_USE_PORTRAY}, and @code{YAP_WRITE_IGNORE_OPS}.
|
||||
|
||||
@item @code{void} YAP_WriteBuffer(@code{YAP_Term} @var{t}, @code{char *}
|
||||
@var{buff}, @code{unsigned int}
|
||||
|
@ -237,7 +237,7 @@ extern X_API void PROTO(YAP_UserBackCPredicate,(CONST char *, YAP_Bool (*)(void)
|
||||
|
||||
/* void UserBackCPredicate(char *name, int *init(), int *cont(), int *cut(), int
|
||||
arity, int extra) */
|
||||
extern X_API void PROTO(YAP_UserBackCutCPredicate,(char *, YAP_Bool (*)(void), YAP_Bool (*)(void), YAP_Bool (*)(void), YAP_Arity, unsigned int));
|
||||
extern X_API void PROTO(YAP_UserBackCutCPredicate,(CONST char *, YAP_Bool (*)(void), YAP_Bool (*)(void), YAP_Bool (*)(void), YAP_Arity, unsigned int));
|
||||
|
||||
/* void CallProlog(YAP_Term t) */
|
||||
extern X_API YAP_Bool PROTO(YAP_CallProlog,(YAP_Term t));
|
||||
@ -245,9 +245,9 @@ extern X_API YAP_Bool PROTO(YAP_CallProlog,(YAP_Term t));
|
||||
/* void cut_fail(void) */
|
||||
extern X_API void PROTO(YAP_cut_up,(void));
|
||||
|
||||
#define YAP_cut_succeed() { YAP_cut_up(); return TRUE; }
|
||||
#define YAP_cut_succeed() do { YAP_cut_up(); return TRUE; } while(0)
|
||||
|
||||
#define YAP_cut_fail() { YAP_cut_up(); return FALSE; }
|
||||
#define YAP_cut_fail() do { YAP_cut_up(); return FALSE; } while(0)
|
||||
|
||||
/* void *AllocSpaceFromYAP_(int) */
|
||||
extern X_API void *PROTO(YAP_AllocSpaceFromYap,(unsigned int));
|
||||
@ -555,6 +555,17 @@ extern X_API int PROTO(YAP_MaxOpPriority,(YAP_Atom, YAP_Term));
|
||||
/* int YAP_OpInfo(Atom, Term, int, int *, int *) */
|
||||
extern X_API int PROTO(YAP_OpInfo,(YAP_Atom, YAP_Term, int, int *, int *));
|
||||
|
||||
/* YAP_Bool YAP_IsExternalDataInStackTerm(YAP_Term) */
|
||||
extern X_API YAP_Bool PROTO(YAP_IsExternalDataInStackTerm,(YAP_Term));
|
||||
|
||||
extern X_API YAP_opaque_tag_t PROTO(YAP_NewOpaqueType,(struct YAP_opaque_handler_struct *));
|
||||
|
||||
extern X_API YAP_Bool PROTO(YAP_IsOpaqueObjectTerm,(YAP_Term, YAP_opaque_tag_t));
|
||||
|
||||
extern X_API YAP_Term PROTO(YAP_NewOpaqueObject,(YAP_opaque_tag_t, size_t));
|
||||
|
||||
extern X_API void *PROTO(YAP_OpaqueObjectFromTerm,(YAP_Term));
|
||||
|
||||
#define YAP_InitCPred(N,A,F) YAP_UserCPredicate(N,F,A)
|
||||
|
||||
__END_DECLS
|
||||
|
@ -101,9 +101,10 @@ typedef double YAP_Float;
|
||||
#define YAP_FULL_BOOT_FROM_PROLOG 4
|
||||
#define YAP_BOOT_ERROR -1
|
||||
|
||||
#define YAP_WRITE_QUOTED 0
|
||||
#define YAP_WRITE_HANDLE_VARS 1
|
||||
#define YAP_WRITE_QUOTED 1
|
||||
#define YAP_WRITE_IGNORE_OPS 2
|
||||
#define YAP_WRITE_HANDLE_VARS 2
|
||||
#define YAP_WRITE_USE_PORTRAY 8
|
||||
|
||||
#define YAP_CONSULT_MODE 0
|
||||
#define YAP_RECONSULT_MODE 1
|
||||
@ -204,6 +205,14 @@ typedef int (*YAP_agc_hook)(void *_Atom);
|
||||
|
||||
typedef void (*YAP_halt_hook)(int exit_code, void *closure);
|
||||
|
||||
typedef int YAP_opaque_tag_t;
|
||||
|
||||
typedef int (*YAP_Opaque_CallOnFail)(void *);
|
||||
|
||||
typedef struct YAP_opaque_handler_struct {
|
||||
YAP_Opaque_CallOnFail fail_handler;
|
||||
} YAP_opaque_handler_t;
|
||||
|
||||
/********* execution mode ***********************/
|
||||
|
||||
typedef enum
|
||||
|
@ -241,6 +241,8 @@ tokenize_arguments([FirstArg|RestArgs],[TokFirstArg|TokRestArgs]):-
|
||||
%
|
||||
% --------------------------------------------------------------------------------------
|
||||
|
||||
:- dynamic attribute/4.
|
||||
|
||||
query_generation([],_,[]).
|
||||
|
||||
query_generation([Conjunction|Conjunctions],ProjectionTerm,[Query|Queries]):-
|
||||
@ -1157,9 +1159,9 @@ column_atom(att(RangeVar,Attribute),QueryList,Diff):-
|
||||
column_atom(Attribute,X2,Diff).
|
||||
|
||||
column_atom(rel(Relation,RangeVar),QueryList,Diff):-
|
||||
column_atom('`',QueryList,X0),
|
||||
column_atom('',QueryList,X0),
|
||||
column_atom(Relation,X0,X1),
|
||||
column_atom('` ',X1,X2),
|
||||
column_atom(' ',X1,X2),
|
||||
column_atom(RangeVar,X2,Diff).
|
||||
|
||||
column_atom('$const$'(String),QueryList,Diff):-
|
||||
|
@ -142,12 +142,6 @@ p2c_putc(const int c) {
|
||||
/*
|
||||
* Function used by YAP to read a char from a string
|
||||
*/
|
||||
static int
|
||||
p2c_getc(void) {
|
||||
if( BUFFER_POS < BUFFER_LEN )
|
||||
return BUFFER_PTR[BUFFER_POS++];
|
||||
return -1;
|
||||
}
|
||||
/*
|
||||
* Writes a term to a stream.
|
||||
*/
|
||||
@ -177,7 +171,7 @@ read_term_from_stream(const int fd) {
|
||||
if ( size> BUFFER_SIZE)
|
||||
expand_buffer(size-BUFFER_SIZE);
|
||||
read(fd,BUFFER_PTR,size); // read term from stream
|
||||
return YAP_Read( p2c_getc );
|
||||
return YAP_ReadBuffer( BUFFER_PTR , NULL);
|
||||
}
|
||||
/*********************************************************************************************
|
||||
* Conversion: Prolog Term->char[] and char[]->Prolog Term
|
||||
@ -229,7 +223,7 @@ string2term(char *const ptr,const size_t *size) {
|
||||
}
|
||||
BUFFER_POS=0;
|
||||
LOCAL_ErrorMessage=NULL;
|
||||
t = YAP_Read(p2c_getc);
|
||||
t = YAP_ReadBuffer( BUFFER_PTR , NULL );
|
||||
if ( t==FALSE ) {
|
||||
write_msg(__FUNCTION__,__FILE__,__LINE__,"FAILED string2term>>>>size:%d %d %s\n",BUFFER_SIZE,strlen(BUFFER_PTR),LOCAL_ErrorMessage);
|
||||
exit(1);
|
||||
|
@ -30,7 +30,6 @@ static char *rcsid = "$Header: /Users/vitor/Yap/yap-cvsbackup/library/mpi/mpi.c,
|
||||
#include <string.h>
|
||||
#include <mpi.h>
|
||||
|
||||
Term STD_PROTO(YAP_Read, (int (*)(void)));
|
||||
void STD_PROTO(YAP_Write, (Term, void (*)(int), int));
|
||||
|
||||
STATIC_PROTO (Int p_mpi_open, (void));
|
||||
@ -105,15 +104,6 @@ mpi_putc(Int ch)
|
||||
}
|
||||
}
|
||||
|
||||
static Int
|
||||
mpi_getc(void)
|
||||
{
|
||||
if( bufptr < bufsize ) return buf[bufptr++];
|
||||
else return -1;
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
/*
|
||||
* C Predicates
|
||||
@ -301,7 +291,7 @@ p_mpi_receive() /* mpi_receive(-data, ?orig, ?tag) */
|
||||
/* parse received string into a Prolog term */
|
||||
|
||||
bufptr = 0;
|
||||
t = YAP_Read( mpi_getc );
|
||||
t = YAP_ReadBuffer( buf, NULL );
|
||||
|
||||
if( t == TermNil ) {
|
||||
retv = FALSE;
|
||||
@ -384,7 +374,7 @@ p_mpi_bcast3() /* mpi_bcast( ?data, +root, +max_size ) */
|
||||
bufptr = 0;
|
||||
|
||||
/* parse received string into a Prolog term */
|
||||
return Yap_unify( YAP_Read(mpi_getc), ARG1 );
|
||||
return Yap_unify( YAP_ReadBuffer( buf, NULL ), ARG1 );
|
||||
}
|
||||
}
|
||||
|
||||
@ -464,7 +454,7 @@ p_mpi_bcast2() /* mpi_bcast( ?data, +root ) */
|
||||
bufstrlen = strlen(buf);
|
||||
bufptr = 0;
|
||||
|
||||
return Yap_unify(YAP_Read( mpi_getc ), ARG1);
|
||||
return Yap_unify(YAP_ReadBuffer( buf, NULL ), ARG1);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -120,6 +120,8 @@ char* DIRNAME =NULL
|
||||
char Executable[YAP_FILENAME_MAX] void
|
||||
#endif
|
||||
|
||||
int OpaqueHandlersCount =0
|
||||
struct opaque_handler_struct* OpaqueHandlers =NULL
|
||||
|
||||
#if __simplescalar__
|
||||
char pwd[YAP_FILENAME_MAX] void
|
||||
|
149
packages/CLPBN/clpbn/bp/BPNodeInfo.cpp
Executable file
149
packages/CLPBN/clpbn/bp/BPNodeInfo.cpp
Executable file
@ -0,0 +1,149 @@
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "BPNodeInfo.h"
|
||||
#include "BPSolver.h"
|
||||
|
||||
BPNodeInfo::BPNodeInfo (BayesNode* node)
|
||||
{
|
||||
node_ = node;
|
||||
ds_ = node->getDomainSize();
|
||||
piValsCalc_ = false;
|
||||
ldValsCalc_ = false;
|
||||
nPiMsgsRcv_ = 0;
|
||||
nLdMsgsRcv_ = 0;
|
||||
piVals_.resize (ds_, 1);
|
||||
ldVals_.resize (ds_, 1);
|
||||
const BnNodeSet& childs = node->getChilds();
|
||||
for (unsigned i = 0; i < childs.size(); i++) {
|
||||
cmsgs_.insert (make_pair (childs[i], false));
|
||||
}
|
||||
const BnNodeSet& parents = node->getParents();
|
||||
for (unsigned i = 0; i < parents.size(); i++) {
|
||||
pmsgs_.insert (make_pair (parents[i], false));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
ParamSet
|
||||
BPNodeInfo::getBeliefs (void) const
|
||||
{
|
||||
double sum = 0.0;
|
||||
ParamSet beliefs (ds_);
|
||||
for (unsigned xi = 0; xi < ds_; xi++) {
|
||||
double prod = piVals_[xi] * ldVals_[xi];
|
||||
beliefs[xi] = prod;
|
||||
sum += prod;
|
||||
}
|
||||
assert (sum);
|
||||
//normalize the beliefs
|
||||
for (unsigned xi = 0; xi < ds_; xi++) {
|
||||
beliefs[xi] /= sum;
|
||||
}
|
||||
return beliefs;
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
BPNodeInfo::readyToSendPiMsgTo (const BayesNode* child) const
|
||||
{
|
||||
for (unsigned i = 0; i < inChildLinks_.size(); i++) {
|
||||
if (inChildLinks_[i]->getSource() != child
|
||||
&& !inChildLinks_[i]->messageWasSended()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
BPNodeInfo::readyToSendLambdaMsgTo (const BayesNode* parent) const
|
||||
{
|
||||
for (unsigned i = 0; i < inParentLinks_.size(); i++) {
|
||||
if (inParentLinks_[i]->getSource() != parent
|
||||
&& !inParentLinks_[i]->messageWasSended()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
|
||||
double
|
||||
BPNodeInfo::getPiValue (unsigned idx) const
|
||||
{
|
||||
assert (idx >=0 && idx < ds_);
|
||||
return piVals_[idx];
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BPNodeInfo::setPiValue (unsigned idx, Param value)
|
||||
{
|
||||
assert (idx >=0 && idx < ds_);
|
||||
piVals_[idx] = value;
|
||||
}
|
||||
|
||||
|
||||
|
||||
double
|
||||
BPNodeInfo::getLambdaValue (unsigned idx) const
|
||||
{
|
||||
assert (idx >=0 && idx < ds_);
|
||||
return ldVals_[idx];
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BPNodeInfo::setLambdaValue (unsigned idx, Param value)
|
||||
{
|
||||
assert (idx >=0 && idx < ds_);
|
||||
ldVals_[idx] = value;
|
||||
}
|
||||
|
||||
|
||||
|
||||
double
|
||||
BPNodeInfo::getBeliefChange (void)
|
||||
{
|
||||
double change = 0.0;
|
||||
if (oldBeliefs_.size() == 0) {
|
||||
oldBeliefs_ = getBeliefs();
|
||||
change = 9999999999.0;
|
||||
} else {
|
||||
ParamSet currentBeliefs = getBeliefs();
|
||||
for (unsigned xi = 0; xi < ds_; xi++) {
|
||||
change += abs (currentBeliefs[xi] - oldBeliefs_[xi]);
|
||||
}
|
||||
oldBeliefs_ = currentBeliefs;
|
||||
}
|
||||
return change;
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
BPNodeInfo::receivedBottomInfluence (void) const
|
||||
{
|
||||
// if all lambda values are equal, then neither
|
||||
// this node neither its descendents have evidence,
|
||||
// we can use this to don't send lambda messages his parents
|
||||
bool childInfluenced = false;
|
||||
for (unsigned xi = 1; xi < ds_; xi++) {
|
||||
if (ldVals_[xi] != ldVals_[0]) {
|
||||
childInfluenced = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
return childInfluenced;
|
||||
}
|
||||
|
82
packages/CLPBN/clpbn/bp/BPNodeInfo.h
Executable file
82
packages/CLPBN/clpbn/bp/BPNodeInfo.h
Executable file
@ -0,0 +1,82 @@
|
||||
#ifndef BP_BP_NODE_H
|
||||
#define BP_BP_NODE_H
|
||||
|
||||
#include <vector>
|
||||
#include <map>
|
||||
|
||||
#include "BPSolver.h"
|
||||
#include "BayesNode.h"
|
||||
#include "Shared.h"
|
||||
|
||||
//class Edge;
|
||||
|
||||
using namespace std;
|
||||
|
||||
class BPNodeInfo
|
||||
{
|
||||
public:
|
||||
BPNodeInfo (int);
|
||||
BPNodeInfo (BayesNode*);
|
||||
|
||||
ParamSet getBeliefs (void) const;
|
||||
double getPiValue (unsigned) const;
|
||||
void setPiValue (unsigned, Param);
|
||||
double getLambdaValue (unsigned) const;
|
||||
void setLambdaValue (unsigned, Param);
|
||||
double getBeliefChange (void);
|
||||
bool receivedBottomInfluence (void) const;
|
||||
|
||||
ParamSet& getPiValues (void) { return piVals_; }
|
||||
ParamSet& getLambdaValues (void) { return ldVals_; }
|
||||
bool arePiValuesCalculated (void) { return piValsCalc_; }
|
||||
bool areLambdaValuesCalculated (void) { return ldValsCalc_; }
|
||||
void markPiValuesAsCalculated (void) { piValsCalc_ = true; }
|
||||
void markLambdaValuesAsCalculated (void) { ldValsCalc_ = true; }
|
||||
void incNumPiMsgsRcv (void) { nPiMsgsRcv_ ++; }
|
||||
void incNumLambdaMsgsRcv (void) { nLdMsgsRcv_ ++; }
|
||||
|
||||
bool receivedAllPiMessages (void)
|
||||
{
|
||||
return node_->getParents().size() == nPiMsgsRcv_;
|
||||
}
|
||||
|
||||
bool receivedAllLambdaMessages (void)
|
||||
{
|
||||
return node_->getChilds().size() == nLdMsgsRcv_;
|
||||
}
|
||||
|
||||
bool readyToSendPiMsgTo (const BayesNode*) const ;
|
||||
bool readyToSendLambdaMsgTo (const BayesNode*) const;
|
||||
|
||||
CEdgeSet getIncomingParentLinks (void) { return inParentLinks_; }
|
||||
CEdgeSet getIncomingChildLinks (void) { return inChildLinks_; }
|
||||
CEdgeSet getOutcomingParentLinks (void) { return outParentLinks_; }
|
||||
CEdgeSet getOutcomingChildLinks (void) { return outChildLinks_; }
|
||||
|
||||
void addIncomingParentLink (Edge* l) { inParentLinks_.push_back (l); }
|
||||
void addIncomingChildLink (Edge* l) { inChildLinks_.push_back (l); }
|
||||
void addOutcomingParentLink (Edge* l) { outParentLinks_.push_back (l); }
|
||||
void addOutcomingChildLink (Edge* l) { outChildLinks_.push_back (l); }
|
||||
|
||||
private:
|
||||
DISALLOW_COPY_AND_ASSIGN (BPNodeInfo);
|
||||
|
||||
ParamSet piVals_; // pi values
|
||||
ParamSet ldVals_; // lambda values
|
||||
ParamSet oldBeliefs_;
|
||||
unsigned nPiMsgsRcv_;
|
||||
unsigned nLdMsgsRcv_;
|
||||
bool piValsCalc_;
|
||||
bool ldValsCalc_;
|
||||
EdgeSet inParentLinks_;
|
||||
EdgeSet inChildLinks_;
|
||||
EdgeSet outParentLinks_;
|
||||
EdgeSet outChildLinks_;
|
||||
unsigned ds_;
|
||||
const BayesNode* node_;
|
||||
map<const BayesNode*, bool> pmsgs_;
|
||||
map<const BayesNode*, bool> cmsgs_;
|
||||
};
|
||||
|
||||
#endif //BP_BP_NODE_H
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -1,259 +1,106 @@
|
||||
#ifndef BP_BPSOLVER_H
|
||||
#define BP_BPSOLVER_H
|
||||
#ifndef BP_BP_SOLVER_H
|
||||
#define BP_BP_SOLVER_H
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <set>
|
||||
|
||||
#include "Solver.h"
|
||||
#include "BayesNet.h"
|
||||
#include "BpNode.h"
|
||||
#include "BPNodeInfo.h"
|
||||
#include "Shared.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
class BPSolver;
|
||||
class BPNodeInfo;
|
||||
|
||||
static const string PI = "pi" ;
|
||||
static const string LD = "ld" ;
|
||||
|
||||
|
||||
enum MessageType {PI_MSG, LAMBDA_MSG};
|
||||
enum JointCalcType {CHAIN_RULE, JUNCTION_NODE};
|
||||
|
||||
class BPSolver;
|
||||
struct Edge
|
||||
class Edge
|
||||
{
|
||||
Edge (BayesNode* s, BayesNode* d, MessageType t)
|
||||
{
|
||||
source = s;
|
||||
destination = d;
|
||||
type = t;
|
||||
}
|
||||
string getId (void) const
|
||||
{
|
||||
stringstream ss;
|
||||
type == PI_MSG ? ss << PI : ss << LD;
|
||||
ss << source->getVarId() << "." << destination->getVarId();
|
||||
return ss.str();
|
||||
}
|
||||
string toString (void) const
|
||||
{
|
||||
stringstream ss;
|
||||
type == PI_MSG ? ss << PI << "(" : ss << LD << "(" ;
|
||||
ss << source->getLabel() << " --> " ;
|
||||
ss << destination->getLabel();
|
||||
ss << ")" ;
|
||||
return ss.str();
|
||||
}
|
||||
BayesNode* source;
|
||||
BayesNode* destination;
|
||||
MessageType type;
|
||||
static BPSolver* klass;
|
||||
};
|
||||
|
||||
|
||||
|
||||
/*
|
||||
class BPMessage
|
||||
{
|
||||
BPMessage (BayesNode* parent, BayesNode* child)
|
||||
{
|
||||
parent_ = parent;
|
||||
child_ = child;
|
||||
currPiMsg_.resize (child->getDomainSize(), 1);
|
||||
currLdMsg_.resize (parent->getDomainSize(), 1);
|
||||
nextLdMsg_.resize (parent->getDomainSize(), 1);
|
||||
nextPiMsg_.resize (child->getDomainSize(), 1);
|
||||
piResidual_ = 1.0;
|
||||
ldResidual_ = 1.0;
|
||||
}
|
||||
|
||||
Param getPiMessageValue (int idx) const
|
||||
{
|
||||
assert (idx >=0 && idx < child->getDomainSize());
|
||||
return currPiMsg_[idx];
|
||||
}
|
||||
|
||||
Param getLambdaMessageValue (int idx) const
|
||||
{
|
||||
assert (idx >=0 && idx < parent->getDomainSize());
|
||||
return currLdMsg_[idx];
|
||||
}
|
||||
|
||||
const ParamSet& getPiMessage (void) const
|
||||
{
|
||||
return currPiMsg_;
|
||||
}
|
||||
|
||||
const ParamSet& getLambdaMessage (void) const
|
||||
{
|
||||
return currLdMsg_;
|
||||
}
|
||||
|
||||
ParamSet& piNextMessageReference (void)
|
||||
{
|
||||
return nextPiMsg_;
|
||||
}
|
||||
|
||||
ParamSet& lambdaNextMessageReference (const BayesNode* source)
|
||||
{
|
||||
return nextLdMsg_;
|
||||
}
|
||||
|
||||
void updatePiMessage (void)
|
||||
{
|
||||
currPiMsg_ = nextPiMsg_;
|
||||
Util::normalize (currPiMsg_);
|
||||
}
|
||||
|
||||
void updateLambdaMessage (void)
|
||||
{
|
||||
currLdMsg_ = nextLdMsg_;
|
||||
Util::normalize (currLdMsg_);
|
||||
}
|
||||
|
||||
double getPiResidual (void)
|
||||
{
|
||||
return piResidual_;
|
||||
}
|
||||
|
||||
double getLambdaResidual (void)
|
||||
{
|
||||
return ldResidual_;
|
||||
}
|
||||
|
||||
void updatePiResidual (void)
|
||||
{
|
||||
piResidual_ = Util::getL1dist (currPiMsg_, nextPiMsg_);
|
||||
}
|
||||
|
||||
void updateLambdaResidual (void)
|
||||
{
|
||||
ldResidual_ = Util::getL1dist (currLdMsg_, nextLdMsg_);
|
||||
}
|
||||
|
||||
void clearPiResidual (void)
|
||||
{
|
||||
piResidual_ = 0.0;
|
||||
}
|
||||
|
||||
void clearLambdaResidual (void)
|
||||
{
|
||||
ldResidual_ = 0.0;
|
||||
}
|
||||
|
||||
BayesNode* parent_;
|
||||
BayesNode* child_;
|
||||
ParamSet currPiMsg_; // current pi messages
|
||||
ParamSet currLdMsg_; // current lambda messages
|
||||
ParamSet nextPiMsg_;
|
||||
ParamSet nextLdMsg_;
|
||||
Param piResidual_;
|
||||
Param ldResidual_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class NodeInfo
|
||||
{
|
||||
NodeInfo (BayesNode* node)
|
||||
{
|
||||
node_ = node;
|
||||
piVals_.resize (node->getDomainSize(), 1);
|
||||
ldVals_.resize (node->getDomainSize(), 1);
|
||||
}
|
||||
|
||||
ParamSet getBeliefs (void) const
|
||||
{
|
||||
double sum = 0.0;
|
||||
ParamSet beliefs (node_->getDomainSize());
|
||||
for (int xi = 0; xi < node_->getDomainSize(); xi++) {
|
||||
double prod = piVals_[xi] * ldVals_[xi];
|
||||
beliefs[xi] = prod;
|
||||
sum += prod;
|
||||
}
|
||||
assert (sum);
|
||||
//normalize the beliefs
|
||||
for (int xi = 0; xi < node_->getDomainSize(); xi++) {
|
||||
beliefs[xi] /= sum;
|
||||
}
|
||||
return beliefs;
|
||||
}
|
||||
|
||||
double getPiValue (int idx) const
|
||||
{
|
||||
assert (idx >=0 && idx < node_->getDomainSize());
|
||||
return piVals_[idx];
|
||||
}
|
||||
|
||||
void setPiValue (int idx, double value)
|
||||
{
|
||||
assert (idx >=0 && idx < node_->getDomainSize());
|
||||
piVals_[idx] = value;
|
||||
}
|
||||
|
||||
double getLambdaValue (int idx) const
|
||||
{
|
||||
assert (idx >=0 && idx < node_->getDomainSize());
|
||||
return ldVals_[idx];
|
||||
}
|
||||
|
||||
void setLambdaValue (int idx, double value)
|
||||
{
|
||||
assert (idx >=0 && idx < node_->getDomainSize());
|
||||
ldVals_[idx] = value;
|
||||
}
|
||||
|
||||
ParamSet& getPiValues (void)
|
||||
{
|
||||
return piVals_;
|
||||
}
|
||||
|
||||
ParamSet& getLambdaValues (void)
|
||||
{
|
||||
return ldVals_;
|
||||
}
|
||||
|
||||
double getBeliefChange (void)
|
||||
{
|
||||
double change = 0.0;
|
||||
if (oldBeliefs_.size() == 0) {
|
||||
oldBeliefs_ = getBeliefs();
|
||||
change = MAX_CHANGE_;
|
||||
} else {
|
||||
ParamSet currentBeliefs = getBeliefs();
|
||||
for (int xi = 0; xi < node_->getDomainSize(); xi++) {
|
||||
change += abs (currentBeliefs[xi] - oldBeliefs_[xi]);
|
||||
public:
|
||||
Edge (BayesNode* s, BayesNode* d, MessageType t)
|
||||
{
|
||||
source_ = s;
|
||||
destin_ = d;
|
||||
type_ = t;
|
||||
if (type_ == PI_MSG) {
|
||||
currMsg_.resize (s->getDomainSize(), 1);
|
||||
nextMsg_.resize (s->getDomainSize(), 1);
|
||||
} else {
|
||||
currMsg_.resize (d->getDomainSize(), 1);
|
||||
nextMsg_.resize (d->getDomainSize(), 1);
|
||||
}
|
||||
oldBeliefs_ = currentBeliefs;
|
||||
msgSended_ = false;
|
||||
residual_ = 0.0;
|
||||
}
|
||||
return change;
|
||||
}
|
||||
|
||||
//void setMessage (ParamSet msg)
|
||||
//{
|
||||
// Util::normalize (msg);
|
||||
// residual_ = Util::getMaxNorm (currMsg_, msg);
|
||||
// currMsg_ = msg;
|
||||
//}
|
||||
|
||||
bool hasReceivedChildInfluence (void) const
|
||||
{
|
||||
// if all lambda values are equal, then neither
|
||||
// this node neither its descendents have evidence,
|
||||
// we can use this to don't send lambda messages his parents
|
||||
bool childInfluenced = false;
|
||||
for (int xi = 1; xi < node_->getDomainSize(); xi++) {
|
||||
if (ldVals_[xi] != ldVals_[0]) {
|
||||
childInfluenced = true;
|
||||
break;
|
||||
void setNextMessage (CParamSet msg)
|
||||
{
|
||||
nextMsg_ = msg;
|
||||
Util::normalize (nextMsg_);
|
||||
residual_ = Util::getMaxNorm (currMsg_, nextMsg_);
|
||||
}
|
||||
|
||||
void updateMessage (void)
|
||||
{
|
||||
currMsg_ = nextMsg_;
|
||||
if (DL >= 3) {
|
||||
cout << "updating " << toString() << endl;
|
||||
}
|
||||
msgSended_ = true;
|
||||
}
|
||||
|
||||
void updateResidual (void)
|
||||
{
|
||||
residual_ = Util::getMaxNorm (currMsg_, nextMsg_);
|
||||
}
|
||||
return childInfluenced;
|
||||
}
|
||||
|
||||
BayesNode* node_;
|
||||
ParamSet piVals_; // pi values
|
||||
ParamSet ldVals_; // lambda values
|
||||
ParamSet oldBeliefs_;
|
||||
string toString (void) const
|
||||
{
|
||||
stringstream ss;
|
||||
if (type_ == PI_MSG) {
|
||||
ss << PI;
|
||||
} else if (type_ == LAMBDA_MSG) {
|
||||
ss << LD;
|
||||
} else {
|
||||
abort();
|
||||
}
|
||||
ss << "(" << source_->getLabel();
|
||||
ss << " --> " << destin_->getLabel() << ")" ;
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
BayesNode* getSource (void) const { return source_; }
|
||||
BayesNode* getDestination (void) const { return destin_; }
|
||||
MessageType getMessageType (void) const { return type_; }
|
||||
CParamSet getMessage (void) const { return currMsg_; }
|
||||
bool messageWasSended (void) const { return msgSended_; }
|
||||
double getResidual (void) const { return residual_; }
|
||||
void clearResidual (void) { residual_ = 0.0; }
|
||||
|
||||
private:
|
||||
BayesNode* source_;
|
||||
BayesNode* destin_;
|
||||
MessageType type_;
|
||||
ParamSet currMsg_;
|
||||
ParamSet nextMsg_;
|
||||
bool msgSended_;
|
||||
double residual_;
|
||||
};
|
||||
*/
|
||||
|
||||
|
||||
bool compareResidual (const Edge&, const Edge&);
|
||||
|
||||
class BPSolver : public Solver
|
||||
{
|
||||
public:
|
||||
@ -261,190 +108,85 @@ class BPSolver : public Solver
|
||||
~BPSolver (void);
|
||||
|
||||
void runSolver (void);
|
||||
ParamSet getPosterioriOf (const Variable* var) const;
|
||||
ParamSet getJointDistribution (const NodeSet&) const;
|
||||
ParamSet getPosterioriOf (Vid) const;
|
||||
ParamSet getJointDistributionOf (const VidSet&);
|
||||
|
||||
|
||||
private:
|
||||
DISALLOW_COPY_AND_ASSIGN (BPSolver);
|
||||
|
||||
void initializeSolver (void);
|
||||
void incorporateEvidence (BayesNode*);
|
||||
void runPolyTreeSolver (void);
|
||||
void polyTreePiMessage (BayesNode*, BayesNode*);
|
||||
void polyTreeLambdaMessage (BayesNode*, BayesNode*);
|
||||
void runGenericSolver (void);
|
||||
void runLoopySolver (void);
|
||||
void maxResidualSchedule (void);
|
||||
bool converged (void) const;
|
||||
void updatePiValues (BayesNode*);
|
||||
void updateLambdaValues (BayesNode*);
|
||||
void calculateNextPiMessage (BayesNode*, BayesNode*);
|
||||
void calculateNextLambdaMessage (BayesNode*, BayesNode*);
|
||||
ParamSet calculateNextLambdaMessage (Edge* edge);
|
||||
ParamSet calculateNextPiMessage (Edge* edge);
|
||||
ParamSet getJointByJunctionNode (const VidSet&) const;
|
||||
ParamSet getJointByChainRule (const VidSet&) const;
|
||||
void printMessageStatusOf (const BayesNode*) const;
|
||||
void printAllMessageStatus (void) const;
|
||||
// inlines
|
||||
void updatePiMessage (BayesNode*, BayesNode*);
|
||||
void updateLambdaMessage (BayesNode*, BayesNode*);
|
||||
void calculateNextMessage (const Edge&);
|
||||
void updateMessage (const Edge&);
|
||||
void updateValues (const Edge&);
|
||||
double getResidual (const Edge&) const;
|
||||
void updateResidual (const Edge&);
|
||||
void clearResidual (const Edge&);
|
||||
BpNode* M (const BayesNode*) const;
|
||||
friend bool compareResidual (const Edge&, const Edge&);
|
||||
|
||||
ParamSet getMessage (Edge* edge)
|
||||
{
|
||||
if (DL >= 3) {
|
||||
cout << " calculating " << edge->toString() << endl;
|
||||
}
|
||||
if (edge->getMessageType() == PI_MSG) {
|
||||
return calculateNextPiMessage (edge);
|
||||
} else if (edge->getMessageType() == LAMBDA_MSG) {
|
||||
return calculateNextLambdaMessage (edge);
|
||||
} else {
|
||||
abort();
|
||||
}
|
||||
return ParamSet();
|
||||
}
|
||||
|
||||
void updateValues (Edge* edge)
|
||||
{
|
||||
if (!edge->getDestination()->hasEvidence()) {
|
||||
if (edge->getMessageType() == PI_MSG) {
|
||||
updatePiValues (edge->getDestination());
|
||||
} else if (edge->getMessageType() == LAMBDA_MSG) {
|
||||
updateLambdaValues (edge->getDestination());
|
||||
} else {
|
||||
abort();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
BPNodeInfo* M (const BayesNode* node) const
|
||||
{
|
||||
assert (node);
|
||||
assert (node == bn_->getBayesNode (node->getVarId()));
|
||||
assert (node->getIndex() < nodesI_.size());
|
||||
return nodesI_[node->getIndex()];
|
||||
}
|
||||
|
||||
const BayesNet* bn_;
|
||||
vector<BpNode*> msgs_;
|
||||
Schedule schedule_;
|
||||
int nIter_;
|
||||
int maxIter_;
|
||||
double accuracy_;
|
||||
vector<Edge> updateOrder_;
|
||||
bool forceGenericSolver_;
|
||||
vector<BPNodeInfo*> nodesI_;
|
||||
unsigned nIter_;
|
||||
vector<Edge*> links_;
|
||||
bool useAlwaysLoopySolver_;
|
||||
JointCalcType jointCalcType_;
|
||||
|
||||
struct compare
|
||||
{
|
||||
inline bool operator() (const Edge& e1, const Edge& e2)
|
||||
inline bool operator() (const Edge* e1, const Edge* e2)
|
||||
{
|
||||
return compareResidual (e1, e2);
|
||||
return e1->getResidual() > e2->getResidual();
|
||||
}
|
||||
};
|
||||
|
||||
typedef multiset<Edge, compare> SortedOrder;
|
||||
typedef multiset<Edge*, compare> SortedOrder;
|
||||
SortedOrder sortedOrder_;
|
||||
|
||||
typedef unordered_map<string, SortedOrder::iterator> EdgeMap;
|
||||
typedef map<Edge*, SortedOrder::iterator> EdgeMap;
|
||||
EdgeMap edgeMap_;
|
||||
|
||||
};
|
||||
|
||||
|
||||
|
||||
inline void
|
||||
BPSolver::updatePiMessage (BayesNode* source, BayesNode* destination)
|
||||
{
|
||||
M(source)->updatePiMessage(destination);
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline void
|
||||
BPSolver::updateLambdaMessage (BayesNode* source, BayesNode* destination)
|
||||
{
|
||||
M(destination)->updateLambdaMessage(source);
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline void
|
||||
BPSolver::calculateNextMessage (const Edge& e)
|
||||
{
|
||||
if (DL >= 1) {
|
||||
cout << "calculating " << e.toString() << endl;
|
||||
}
|
||||
if (e.type == PI_MSG) {
|
||||
calculateNextPiMessage (e.source, e.destination);
|
||||
} else {
|
||||
calculateNextLambdaMessage (e.source, e.destination);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline void
|
||||
BPSolver::updateMessage (const Edge& e)
|
||||
{
|
||||
if (DL >= 1) {
|
||||
cout << "updating " << e.toString() << endl;
|
||||
}
|
||||
if (e.type == PI_MSG) {
|
||||
M(e.source)->updatePiMessage(e.destination);
|
||||
} else {
|
||||
M(e.destination)->updateLambdaMessage(e.source);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline void
|
||||
BPSolver::updateValues (const Edge& e)
|
||||
{
|
||||
if (!e.destination->hasEvidence()) {
|
||||
if (e.type == PI_MSG) {
|
||||
updatePiValues (e.destination);
|
||||
} else {
|
||||
updateLambdaValues (e.destination);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline double
|
||||
BPSolver::getResidual (const Edge& e) const
|
||||
{
|
||||
if (e.type == PI_MSG) {
|
||||
return M(e.source)->getPiResidual(e.destination);
|
||||
} else {
|
||||
return M(e.destination)->getLambdaResidual(e.source);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline void
|
||||
BPSolver::updateResidual (const Edge& e)
|
||||
{
|
||||
if (e.type == PI_MSG) {
|
||||
M(e.source)->updatePiResidual(e.destination);
|
||||
} else {
|
||||
M(e.destination)->updateLambdaResidual(e.source);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline void
|
||||
BPSolver::clearResidual (const Edge& e)
|
||||
{
|
||||
if (e.type == PI_MSG) {
|
||||
M(e.source)->clearPiResidual(e.destination);
|
||||
} else {
|
||||
M(e.destination)->clearLambdaResidual(e.source);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline bool
|
||||
compareResidual (const Edge& e1, const Edge& e2)
|
||||
{
|
||||
double residual1;
|
||||
double residual2;
|
||||
if (e1.type == PI_MSG) {
|
||||
residual1 = Edge::klass->M(e1.source)->getPiResidual(e1.destination);
|
||||
} else {
|
||||
residual1 = Edge::klass->M(e1.destination)->getLambdaResidual(e1.source);
|
||||
}
|
||||
if (e2.type == PI_MSG) {
|
||||
residual2 = Edge::klass->M(e2.source)->getPiResidual(e2.destination);
|
||||
} else {
|
||||
residual2 = Edge::klass->M(e2.destination)->getLambdaResidual(e2.source);
|
||||
}
|
||||
return residual1 > residual2;
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline BpNode*
|
||||
BPSolver::M (const BayesNode* node) const
|
||||
{
|
||||
assert (node);
|
||||
assert (node == bn_->getNode (node->getVarId()));
|
||||
assert (node->getIndex() < msgs_.size());
|
||||
return msgs_[node->getIndex()];
|
||||
}
|
||||
|
||||
|
||||
#endif
|
||||
#endif //BP_BP_SOLVER_H
|
||||
|
||||
|
@ -1,30 +1,24 @@
|
||||
#include <cstdlib>
|
||||
#include <cassert>
|
||||
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
#include <iomanip>
|
||||
#include <cassert>
|
||||
#include <cstdlib>
|
||||
#include <map>
|
||||
|
||||
#include "xmlParser/xmlParser.h"
|
||||
|
||||
#include "BayesNet.h"
|
||||
|
||||
|
||||
BayesNet::BayesNet (void)
|
||||
{
|
||||
}
|
||||
|
||||
|
||||
|
||||
BayesNet::BayesNet (const char* fileName)
|
||||
{
|
||||
map<string, Domain> domains;
|
||||
XMLNode xMainNode = XMLNode::openFileHelper (fileName, "BIF");
|
||||
// only the first network is parsed, others are ignored
|
||||
XMLNode xNode = xMainNode.getChildNode ("NETWORK");
|
||||
int nVars = xNode.nChildNode ("VARIABLE");
|
||||
for (int i = 0; i < nVars; i++) {
|
||||
unsigned nVars = xNode.nChildNode ("VARIABLE");
|
||||
for (unsigned i = 0; i < nVars; i++) {
|
||||
XMLNode var = xNode.getChildNode ("VARIABLE", i);
|
||||
string type = var.getAttribute ("TYPE");
|
||||
if (type != "nature") {
|
||||
@ -32,9 +26,9 @@ BayesNet::BayesNet (const char* fileName)
|
||||
abort();
|
||||
}
|
||||
Domain domain;
|
||||
string label = var.getChildNode("NAME").getText();
|
||||
int domainSize = var.nChildNode ("OUTCOME");
|
||||
for (int j = 0; j < domainSize; j++) {
|
||||
string varLabel = var.getChildNode("NAME").getText();
|
||||
unsigned dsize = var.nChildNode ("OUTCOME");
|
||||
for (unsigned j = 0; j < dsize; j++) {
|
||||
if (var.getChildNode("OUTCOME", j).getText() == 0) {
|
||||
stringstream ss;
|
||||
ss << j + 1;
|
||||
@ -43,37 +37,37 @@ BayesNet::BayesNet (const char* fileName)
|
||||
domain.push_back (var.getChildNode("OUTCOME", j).getText());
|
||||
}
|
||||
}
|
||||
domains.insert (make_pair (label, domain));
|
||||
domains.insert (make_pair (varLabel, domain));
|
||||
}
|
||||
|
||||
int nDefs = xNode.nChildNode ("DEFINITION");
|
||||
unsigned nDefs = xNode.nChildNode ("DEFINITION");
|
||||
if (nVars != nDefs) {
|
||||
cerr << "error: different number of variables and definitions";
|
||||
cerr << endl;
|
||||
cerr << "error: different number of variables and definitions" << endl;
|
||||
abort();
|
||||
}
|
||||
|
||||
queue<int> indexes;
|
||||
for (int i = 0; i < nDefs; i++) {
|
||||
queue<unsigned> indexes;
|
||||
for (unsigned i = 0; i < nDefs; i++) {
|
||||
indexes.push (i);
|
||||
}
|
||||
|
||||
while (!indexes.empty()) {
|
||||
int index = indexes.front();
|
||||
unsigned index = indexes.front();
|
||||
indexes.pop();
|
||||
XMLNode def = xNode.getChildNode ("DEFINITION", index);
|
||||
string label = def.getChildNode("FOR").getText();
|
||||
string varLabel = def.getChildNode("FOR").getText();
|
||||
map<string, Domain>::const_iterator iter;
|
||||
iter = domains.find (label);
|
||||
iter = domains.find (varLabel);
|
||||
if (iter == domains.end()) {
|
||||
cerr << "error: unknow variable `" << label << "'" << endl;
|
||||
cerr << "error: unknow variable `" << varLabel << "'" << endl;
|
||||
abort();
|
||||
}
|
||||
bool processItLatter = false;
|
||||
NodeSet parents;
|
||||
int nParams = iter->second.size();
|
||||
BnNodeSet parents;
|
||||
unsigned nParams = iter->second.size();
|
||||
for (int j = 0; j < def.nChildNode ("GIVEN"); j++) {
|
||||
string parentLabel = def.getChildNode("GIVEN", j).getText();
|
||||
BayesNode* parentNode = getNode (parentLabel);
|
||||
BayesNode* parentNode = getBayesNode (parentLabel);
|
||||
if (parentNode) {
|
||||
nParams *= parentNode->getDomainSize();
|
||||
parents.push_back (parentNode);
|
||||
@ -95,7 +89,7 @@ BayesNet::BayesNet (const char* fileName)
|
||||
}
|
||||
|
||||
if (!processItLatter) {
|
||||
int count = 0;
|
||||
unsigned count = 0;
|
||||
ParamSet params (nParams);
|
||||
stringstream s (def.getChildNode("TABLE").getText());
|
||||
while (!s.eof() && count < nParams) {
|
||||
@ -104,11 +98,11 @@ BayesNet::BayesNet (const char* fileName)
|
||||
}
|
||||
if (count != nParams) {
|
||||
cerr << "error: invalid number of parameters " ;
|
||||
cerr << "for variable `" << label << "'" << endl;
|
||||
cerr << "for variable `" << varLabel << "'" << endl;
|
||||
abort();
|
||||
}
|
||||
params = reorderParameters (params, iter->second.size());
|
||||
addNode (label, iter->second, parents, params);
|
||||
addNode (varLabel, iter->second, parents, params);
|
||||
}
|
||||
}
|
||||
setIndexes();
|
||||
@ -118,7 +112,6 @@ BayesNet::BayesNet (const char* fileName)
|
||||
|
||||
BayesNet::~BayesNet (void)
|
||||
{
|
||||
Statistics::writeStats();
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
delete nodes_[i];
|
||||
}
|
||||
@ -127,25 +120,25 @@ BayesNet::~BayesNet (void)
|
||||
|
||||
|
||||
BayesNode*
|
||||
BayesNet::addNode (unsigned varId)
|
||||
BayesNet::addNode (Vid vid)
|
||||
{
|
||||
indexMap_.insert (make_pair (varId, nodes_.size()));
|
||||
nodes_.push_back (new BayesNode (varId));
|
||||
indexMap_.insert (make_pair (vid, nodes_.size()));
|
||||
nodes_.push_back (new BayesNode (vid));
|
||||
return nodes_.back();
|
||||
}
|
||||
|
||||
|
||||
|
||||
BayesNode*
|
||||
BayesNet::addNode (unsigned varId,
|
||||
BayesNet::addNode (Vid vid,
|
||||
unsigned dsize,
|
||||
int evidence,
|
||||
NodeSet& parents,
|
||||
BnNodeSet& parents,
|
||||
Distribution* dist)
|
||||
{
|
||||
indexMap_.insert (make_pair (varId, nodes_.size()));
|
||||
indexMap_.insert (make_pair (vid, nodes_.size()));
|
||||
nodes_.push_back (new BayesNode (
|
||||
varId, dsize, evidence, parents, dist));
|
||||
vid, dsize, evidence, parents, dist));
|
||||
return nodes_.back();
|
||||
}
|
||||
|
||||
@ -154,7 +147,7 @@ BayesNet::addNode (unsigned varId,
|
||||
BayesNode*
|
||||
BayesNet::addNode (string label,
|
||||
Domain domain,
|
||||
NodeSet& parents,
|
||||
BnNodeSet& parents,
|
||||
ParamSet& params)
|
||||
{
|
||||
indexMap_.insert (make_pair (nodes_.size(), nodes_.size()));
|
||||
@ -169,9 +162,9 @@ BayesNet::addNode (string label,
|
||||
|
||||
|
||||
BayesNode*
|
||||
BayesNet::getNode (unsigned varId) const
|
||||
BayesNet::getBayesNode (Vid vid) const
|
||||
{
|
||||
IndexMap::const_iterator it = indexMap_.find(varId);
|
||||
IndexMap::const_iterator it = indexMap_.find (vid);
|
||||
if (it == indexMap_.end()) {
|
||||
return 0;
|
||||
} else {
|
||||
@ -182,7 +175,7 @@ BayesNet::getNode (unsigned varId) const
|
||||
|
||||
|
||||
BayesNode*
|
||||
BayesNet::getNode (string label) const
|
||||
BayesNet::getBayesNode (string label) const
|
||||
{
|
||||
BayesNode* node = 0;
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
@ -196,6 +189,15 @@ BayesNet::getNode (string label) const
|
||||
|
||||
|
||||
|
||||
|
||||
Variable*
|
||||
BayesNet::getVariable (Vid vid) const
|
||||
{
|
||||
return getBayesNode (vid);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BayesNet::addDistribution (Distribution* dist)
|
||||
{
|
||||
@ -219,15 +221,15 @@ BayesNet::getDistribution (unsigned distId) const
|
||||
|
||||
|
||||
|
||||
const NodeSet&
|
||||
BayesNet::getNodes (void) const
|
||||
const BnNodeSet&
|
||||
BayesNet::getBayesNodes (void) const
|
||||
{
|
||||
return nodes_;
|
||||
}
|
||||
|
||||
|
||||
|
||||
int
|
||||
unsigned
|
||||
BayesNet::getNumberOfNodes (void) const
|
||||
{
|
||||
return nodes_.size();
|
||||
@ -235,10 +237,10 @@ BayesNet::getNumberOfNodes (void) const
|
||||
|
||||
|
||||
|
||||
NodeSet
|
||||
BnNodeSet
|
||||
BayesNet::getRootNodes (void) const
|
||||
{
|
||||
NodeSet roots;
|
||||
BnNodeSet roots;
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
if (nodes_[i]->isRoot()) {
|
||||
roots.push_back (nodes_[i]);
|
||||
@ -249,10 +251,10 @@ BayesNet::getRootNodes (void) const
|
||||
|
||||
|
||||
|
||||
NodeSet
|
||||
BnNodeSet
|
||||
BayesNet::getLeafNodes (void) const
|
||||
{
|
||||
NodeSet leafs;
|
||||
BnNodeSet leafs;
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
if (nodes_[i]->isLeaf()) {
|
||||
leafs.push_back (nodes_[i]);
|
||||
@ -276,30 +278,32 @@ BayesNet::getVariables (void) const
|
||||
|
||||
|
||||
BayesNet*
|
||||
BayesNet::pruneNetwork (BayesNode* queryNode) const
|
||||
BayesNet::getMinimalRequesiteNetwork (Vid vid) const
|
||||
{
|
||||
NodeSet queryNodes;
|
||||
queryNodes.push_back (queryNode);
|
||||
return pruneNetwork (queryNodes);
|
||||
return getMinimalRequesiteNetwork (VidSet() = {vid});
|
||||
}
|
||||
|
||||
|
||||
|
||||
BayesNet*
|
||||
BayesNet::pruneNetwork (const NodeSet& interestedVars) const
|
||||
BayesNet::getMinimalRequesiteNetwork (const VidSet& queryVids) const
|
||||
{
|
||||
/*
|
||||
cout << "interested vars: " ;
|
||||
for (unsigned i = 0; i < interestedVars.size(); i++) {
|
||||
cout << interestedVars[i]->getLabel() << " " ;
|
||||
BnNodeSet queryVars;
|
||||
for (unsigned i = 0; i < queryVids.size(); i++) {
|
||||
assert (getBayesNode (queryVids[i]));
|
||||
queryVars.push_back (getBayesNode (queryVids[i]));
|
||||
}
|
||||
cout << endl;
|
||||
*/
|
||||
// cout << "query vars: " ;
|
||||
// for (unsigned i = 0; i < queryVars.size(); i++) {
|
||||
// cout << queryVars[i]->getLabel() << " " ;
|
||||
// }
|
||||
// cout << endl;
|
||||
|
||||
vector<StateInfo*> states (nodes_.size(), 0);
|
||||
|
||||
Scheduling scheduling;
|
||||
for (NodeSet::const_iterator it = interestedVars.begin();
|
||||
it != interestedVars.end(); it++) {
|
||||
for (BnNodeSet::const_iterator it = queryVars.begin();
|
||||
it != queryVars.end(); it++) {
|
||||
scheduling.push (ScheduleInfo (*it, false, true));
|
||||
}
|
||||
|
||||
@ -378,18 +382,18 @@ BayesNet::constructGraph (BayesNet* bn,
|
||||
states[i]->markedOnTop;
|
||||
}
|
||||
if (isRequired) {
|
||||
NodeSet parents;
|
||||
BnNodeSet parents;
|
||||
if (states[i]->markedOnTop) {
|
||||
const NodeSet& ps = nodes_[i]->getParents();
|
||||
const BnNodeSet& ps = nodes_[i]->getParents();
|
||||
for (unsigned j = 0; j < ps.size(); j++) {
|
||||
BayesNode* parent = bn->getNode (ps[j]->getVarId());
|
||||
BayesNode* parent = bn->getBayesNode (ps[j]->getVarId());
|
||||
if (!parent) {
|
||||
parent = bn->addNode (ps[j]->getVarId());
|
||||
}
|
||||
parents.push_back (parent);
|
||||
}
|
||||
}
|
||||
BayesNode* node = bn->getNode (nodes_[i]->getVarId());
|
||||
BayesNode* node = bn->getBayesNode (nodes_[i]->getVarId());
|
||||
if (node) {
|
||||
node->setData (nodes_[i]->getDomainSize(),
|
||||
nodes_[i]->getEvidence(), parents,
|
||||
@ -411,65 +415,6 @@ BayesNet::constructGraph (BayesNet* bn,
|
||||
bn->setIndexes();
|
||||
}
|
||||
|
||||
/*
|
||||
void
|
||||
BayesNet::constructGraph (BayesNet* bn,
|
||||
const vector<StateInfo*>& states) const
|
||||
{
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
if (states[i]) {
|
||||
if (nodes_[i]->hasEvidence() && states[i]->visited) {
|
||||
NodeSet parents;
|
||||
if (states[i]->markedOnTop) {
|
||||
const NodeSet& ps = nodes_[i]->getParents();
|
||||
for (unsigned j = 0; j < ps.size(); j++) {
|
||||
BayesNode* parent = bn->getNode (ps[j]->getVarId());
|
||||
if (parent == 0) {
|
||||
parent = bn->addNode (ps[j]->getVarId());
|
||||
}
|
||||
parents.push_back (parent);
|
||||
}
|
||||
}
|
||||
|
||||
BayesNode* n = bn->getNode (nodes_[i]->getVarId());
|
||||
if (n) {
|
||||
n->setData (nodes_[i]->getDomainSize(),
|
||||
nodes_[i]->getEvidence(), parents,
|
||||
nodes_[i]->getDistribution());
|
||||
} else {
|
||||
bn->addNode (nodes_[i]->getVarId(),
|
||||
nodes_[i]->getDomainSize(),
|
||||
nodes_[i]->getEvidence(), parents,
|
||||
nodes_[i]->getDistribution());
|
||||
}
|
||||
|
||||
} else if (states[i]->markedOnTop) {
|
||||
NodeSet parents;
|
||||
const NodeSet& ps = nodes_[i]->getParents();
|
||||
for (unsigned j = 0; j < ps.size(); j++) {
|
||||
BayesNode* parent = bn->getNode (ps[j]->getVarId());
|
||||
if (parent == 0) {
|
||||
parent = bn->addNode (ps[j]->getVarId());
|
||||
}
|
||||
parents.push_back (parent);
|
||||
}
|
||||
|
||||
BayesNode* n = bn->getNode (nodes_[i]->getVarId());
|
||||
if (n) {
|
||||
n->setData (nodes_[i]->getDomainSize(),
|
||||
nodes_[i]->getEvidence(), parents,
|
||||
nodes_[i]->getDistribution());
|
||||
} else {
|
||||
bn->addNode (nodes_[i]->getVarId(),
|
||||
nodes_[i]->getDomainSize(),
|
||||
nodes_[i]->getEvidence(), parents,
|
||||
nodes_[i]->getDistribution());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}*/
|
||||
|
||||
|
||||
|
||||
bool
|
||||
@ -480,70 +425,6 @@ BayesNet::isSingleConnected (void) const
|
||||
|
||||
|
||||
|
||||
vector<DomainConf>
|
||||
BayesNet::getDomainConfigurationsOf (const NodeSet& nodes)
|
||||
{
|
||||
int nConfs = 1;
|
||||
for (unsigned i = 0; i < nodes.size(); i++) {
|
||||
nConfs *= nodes[i]->getDomainSize();
|
||||
}
|
||||
|
||||
vector<DomainConf> confs (nConfs);
|
||||
for (int i = 0; i < nConfs; i++) {
|
||||
confs[i].resize (nodes.size());
|
||||
}
|
||||
|
||||
int nReps = 1;
|
||||
for (int i = nodes.size() - 1; i >= 0; i--) {
|
||||
int index = 0;
|
||||
while (index < nConfs) {
|
||||
for (int j = 0; j < nodes[i]->getDomainSize(); j++) {
|
||||
for (int r = 0; r < nReps; r++) {
|
||||
confs[index][i] = j;
|
||||
index++;
|
||||
}
|
||||
}
|
||||
}
|
||||
nReps *= nodes[i]->getDomainSize();
|
||||
}
|
||||
|
||||
return confs;
|
||||
}
|
||||
|
||||
|
||||
|
||||
vector<string>
|
||||
BayesNet::getInstantiations (const NodeSet& parents_)
|
||||
{
|
||||
int nParents = parents_.size();
|
||||
int rowSize = 1;
|
||||
for (unsigned i = 0; i < parents_.size(); i++) {
|
||||
rowSize *= parents_[i]->getDomainSize();
|
||||
}
|
||||
int nReps = 1;
|
||||
vector<string> headers (rowSize);
|
||||
for (int i = nParents - 1; i >= 0; i--) {
|
||||
Domain domain = parents_[i]->getDomain();
|
||||
int index = 0;
|
||||
while (index < rowSize) {
|
||||
for (int j = 0; j < parents_[i]->getDomainSize(); j++) {
|
||||
for (int r = 0; r < nReps; r++) {
|
||||
if (headers[index] != "") {
|
||||
headers[index] = domain[j] + "," + headers[index];
|
||||
} else {
|
||||
headers[index] = domain[j];
|
||||
}
|
||||
index++;
|
||||
}
|
||||
}
|
||||
}
|
||||
nReps *= parents_[i]->getDomainSize();
|
||||
}
|
||||
return headers;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BayesNet::setIndexes (void)
|
||||
{
|
||||
@ -565,7 +446,7 @@ BayesNet::freeDistributions (void)
|
||||
|
||||
|
||||
void
|
||||
BayesNet::printNetwork (void) const
|
||||
BayesNet::printGraphicalModel (void) const
|
||||
{
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
cout << *nodes_[i];
|
||||
@ -575,32 +456,11 @@ BayesNet::printNetwork (void) const
|
||||
|
||||
|
||||
void
|
||||
BayesNet::printNetworkToFile (const char* fileName) const
|
||||
BayesNet::exportToDotFormat (const char* fileName,
|
||||
bool showNeighborless,
|
||||
CVidSet& highlightVids) const
|
||||
{
|
||||
string s = "../../" ;
|
||||
s += fileName;
|
||||
ofstream out (s.c_str());
|
||||
if (!out.is_open()) {
|
||||
cerr << "error: cannot open file to write at " ;
|
||||
cerr << "BayesNet::printToFile()" << endl;
|
||||
abort();
|
||||
}
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
out << *nodes_[i];
|
||||
}
|
||||
out.close();
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BayesNet::exportToDotFile (const char* fileName,
|
||||
bool showNeighborless,
|
||||
const NodeSet& highlightNodes) const
|
||||
{
|
||||
string s = "../../" ;
|
||||
s+= fileName;
|
||||
ofstream out (s.c_str());
|
||||
ofstream out (fileName);
|
||||
if (!out.is_open()) {
|
||||
cerr << "error: cannot open file to write at " ;
|
||||
cerr << "BayesNet::exportToDotFile()" << endl;
|
||||
@ -608,13 +468,6 @@ BayesNet::exportToDotFile (const char* fileName,
|
||||
}
|
||||
|
||||
out << "digraph \"" << fileName << "\" {" << endl;
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
const NodeSet& childs = nodes_[i]->getChilds();
|
||||
for (unsigned j = 0; j < childs.size(); j++) {
|
||||
out << '"' << nodes_[i]->getLabel() << '"' << " -> " ;
|
||||
out << '"' << childs[j]->getLabel() << '"' << endl;
|
||||
}
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
if (showNeighborless || nodes_[i]->hasNeighbors()) {
|
||||
@ -627,9 +480,24 @@ BayesNet::exportToDotFile (const char* fileName,
|
||||
}
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < highlightNodes.size(); i++) {
|
||||
out << '"' << highlightNodes[i]->getLabel() << '"' ;
|
||||
out << " [shape=box]" << endl;
|
||||
for (unsigned i = 0; i < highlightVids.size(); i++) {
|
||||
BayesNode* node = getBayesNode (highlightVids[i]);
|
||||
if (node) {
|
||||
out << '"' << node->getLabel() << '"' ;
|
||||
// out << " [shape=polygon, sides=6]" << endl;
|
||||
out << " [shape=box3d]" << endl;
|
||||
} else {
|
||||
cout << "error: invalid variable id: " << highlightVids[i] << endl;
|
||||
abort();
|
||||
}
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
const BnNodeSet& childs = nodes_[i]->getChilds();
|
||||
for (unsigned j = 0; j < childs.size(); j++) {
|
||||
out << '"' << nodes_[i]->getLabel() << '"' << " -> " ;
|
||||
out << '"' << childs[j]->getLabel() << '"' << endl;
|
||||
}
|
||||
}
|
||||
|
||||
out << "}" << endl;
|
||||
@ -639,11 +507,9 @@ BayesNet::exportToDotFile (const char* fileName,
|
||||
|
||||
|
||||
void
|
||||
BayesNet::exportToBifFile (const char* fileName) const
|
||||
BayesNet::exportToBifFormat (const char* fileName) const
|
||||
{
|
||||
string s = "../../" ;
|
||||
s += fileName;
|
||||
ofstream out (s.c_str());
|
||||
ofstream out (fileName);
|
||||
if(!out.is_open()) {
|
||||
cerr << "error: cannot open file to write at " ;
|
||||
cerr << "BayesNet::exportToBifFile()" << endl;
|
||||
@ -666,7 +532,7 @@ BayesNet::exportToBifFile (const char* fileName) const
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
out << "<DEFINITION>" << endl;
|
||||
out << "\t<FOR>" << nodes_[i]->getLabel() << "</FOR>" << endl;
|
||||
const NodeSet& parents = nodes_[i]->getParents();
|
||||
const BnNodeSet& parents = nodes_[i]->getParents();
|
||||
for (unsigned j = 0; j < parents.size(); j++) {
|
||||
out << "\t<GIVEN>" << parents[j]->getLabel();
|
||||
out << "</GIVEN>" << endl;
|
||||
@ -682,7 +548,7 @@ BayesNet::exportToBifFile (const char* fileName) const
|
||||
}
|
||||
out << "</NETWORK>" << endl;
|
||||
out << "</BIF>" << endl << endl;
|
||||
out.close();
|
||||
out.close();
|
||||
}
|
||||
|
||||
|
||||
@ -731,8 +597,8 @@ vector<int>
|
||||
BayesNet::getAdjacentNodes (int v) const
|
||||
{
|
||||
vector<int> adjacencies;
|
||||
const NodeSet& parents = nodes_[v]->getParents();
|
||||
const NodeSet& childs = nodes_[v]->getChilds();
|
||||
const BnNodeSet& parents = nodes_[v]->getParents();
|
||||
const BnNodeSet& childs = nodes_[v]->getChilds();
|
||||
for (unsigned i = 0; i < parents.size(); i++) {
|
||||
adjacencies.push_back (parents[i]->getIndex());
|
||||
}
|
||||
@ -745,8 +611,8 @@ BayesNet::getAdjacentNodes (int v) const
|
||||
|
||||
|
||||
ParamSet
|
||||
BayesNet::reorderParameters (const ParamSet& params,
|
||||
int domainSize) const
|
||||
BayesNet::reorderParameters (CParamSet params,
|
||||
unsigned domainSize) const
|
||||
{
|
||||
// the interchange format for bayesian networks keeps the probabilities
|
||||
// in the following order:
|
||||
@ -773,15 +639,15 @@ BayesNet::reorderParameters (const ParamSet& params,
|
||||
|
||||
|
||||
ParamSet
|
||||
BayesNet::revertParameterReorder (const ParamSet& params,
|
||||
int domainSize) const
|
||||
BayesNet::revertParameterReorder (CParamSet params,
|
||||
unsigned domainSize) const
|
||||
{
|
||||
unsigned count = 0;
|
||||
unsigned rowSize = params.size() / domainSize;
|
||||
ParamSet reordered;
|
||||
while (reordered.size() < params.size()) {
|
||||
unsigned idx = count;
|
||||
for (int i = 0; i < domainSize; i++) {
|
||||
for (unsigned i = 0; i < domainSize; i++) {
|
||||
reordered.push_back (params[idx]);
|
||||
idx += rowSize;
|
||||
}
|
||||
|
@ -4,8 +4,6 @@
|
||||
#include <vector>
|
||||
#include <queue>
|
||||
#include <list>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <map>
|
||||
|
||||
#include "GraphicalModel.h"
|
||||
@ -46,42 +44,42 @@ struct StateInfo
|
||||
|
||||
typedef vector<Distribution*> DistSet;
|
||||
typedef queue<ScheduleInfo, list<ScheduleInfo> > Scheduling;
|
||||
typedef unordered_map<unsigned, unsigned> Histogram;
|
||||
typedef unordered_map<unsigned, double> Times;
|
||||
typedef map<unsigned, unsigned> Histogram;
|
||||
typedef map<unsigned, double> Times;
|
||||
|
||||
|
||||
class BayesNet : public GraphicalModel
|
||||
{
|
||||
public:
|
||||
BayesNet (void);
|
||||
BayesNet (void) {};
|
||||
BayesNet (const char*);
|
||||
~BayesNet (void);
|
||||
|
||||
BayesNode* addNode (unsigned);
|
||||
BayesNode* addNode (unsigned, unsigned, int, NodeSet&, Distribution*);
|
||||
BayesNode* addNode (string, Domain, NodeSet&, ParamSet&);
|
||||
BayesNode* getNode (unsigned) const;
|
||||
BayesNode* getNode (string) const;
|
||||
BayesNode* addNode (unsigned, unsigned, int, BnNodeSet&,
|
||||
Distribution*);
|
||||
BayesNode* addNode (string, Domain, BnNodeSet&, ParamSet&);
|
||||
BayesNode* getBayesNode (Vid) const;
|
||||
BayesNode* getBayesNode (string) const;
|
||||
Variable* getVariable (Vid) const;
|
||||
void addDistribution (Distribution*);
|
||||
Distribution* getDistribution (unsigned) const;
|
||||
const NodeSet& getNodes (void) const;
|
||||
int getNumberOfNodes (void) const;
|
||||
NodeSet getRootNodes (void) const;
|
||||
NodeSet getLeafNodes (void) const;
|
||||
const BnNodeSet& getBayesNodes (void) const;
|
||||
unsigned getNumberOfNodes (void) const;
|
||||
BnNodeSet getRootNodes (void) const;
|
||||
BnNodeSet getLeafNodes (void) const;
|
||||
VarSet getVariables (void) const;
|
||||
BayesNet* pruneNetwork (BayesNode*) const;
|
||||
BayesNet* pruneNetwork (const NodeSet& queryNodes) const;
|
||||
void constructGraph (BayesNet*, const vector<StateInfo*>&) const;
|
||||
BayesNet* getMinimalRequesiteNetwork (Vid) const;
|
||||
BayesNet* getMinimalRequesiteNetwork (const VidSet&) const;
|
||||
void constructGraph (BayesNet*,
|
||||
const vector<StateInfo*>&) const;
|
||||
bool isSingleConnected (void) const;
|
||||
static vector<DomainConf> getDomainConfigurationsOf (const NodeSet&);
|
||||
static vector<string> getInstantiations (const NodeSet& nodes);
|
||||
void setIndexes (void);
|
||||
void freeDistributions (void);
|
||||
void printNetwork (void) const;
|
||||
void printNetworkToFile (const char*) const;
|
||||
void exportToDotFile (const char*, bool = true,
|
||||
const NodeSet& = NodeSet()) const;
|
||||
void exportToBifFile (const char*) const;
|
||||
void printGraphicalModel (void) const;
|
||||
void exportToDotFormat (const char*, bool = true,
|
||||
CVidSet = VidSet()) const;
|
||||
void exportToBifFormat (const char*) const;
|
||||
|
||||
static Histogram histogram_;
|
||||
static Times times_;
|
||||
@ -93,12 +91,12 @@ class BayesNet : public GraphicalModel
|
||||
bool containsUndirectedCycle (int, int,
|
||||
vector<bool>&)const;
|
||||
vector<int> getAdjacentNodes (int) const ;
|
||||
ParamSet reorderParameters (const ParamSet&, int) const;
|
||||
ParamSet revertParameterReorder (const ParamSet&, int) const;
|
||||
ParamSet reorderParameters (CParamSet, unsigned) const;
|
||||
ParamSet revertParameterReorder (CParamSet, unsigned) const;
|
||||
void scheduleParents (const BayesNode*, Scheduling&) const;
|
||||
void scheduleChilds (const BayesNode*, Scheduling&) const;
|
||||
|
||||
NodeSet nodes_;
|
||||
BnNodeSet nodes_;
|
||||
DistSet dists_;
|
||||
IndexMap indexMap_;
|
||||
};
|
||||
@ -108,8 +106,8 @@ class BayesNet : public GraphicalModel
|
||||
inline void
|
||||
BayesNet::scheduleParents (const BayesNode* n, Scheduling& sch) const
|
||||
{
|
||||
const NodeSet& ps = n->getParents();
|
||||
for (NodeSet::const_iterator it = ps.begin(); it != ps.end(); it++) {
|
||||
const BnNodeSet& ps = n->getParents();
|
||||
for (BnNodeSet::const_iterator it = ps.begin(); it != ps.end(); it++) {
|
||||
sch.push (ScheduleInfo (*it, false, true));
|
||||
}
|
||||
}
|
||||
@ -119,11 +117,11 @@ BayesNet::scheduleParents (const BayesNode* n, Scheduling& sch) const
|
||||
inline void
|
||||
BayesNet::scheduleChilds (const BayesNode* n, Scheduling& sch) const
|
||||
{
|
||||
const NodeSet& cs = n->getChilds();
|
||||
for (NodeSet::const_iterator it = cs.begin(); it != cs.end(); it++) {
|
||||
const BnNodeSet& cs = n->getChilds();
|
||||
for (BnNodeSet::const_iterator it = cs.begin(); it != cs.end(); it++) {
|
||||
sch.push (ScheduleInfo (*it, true, false));
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
#endif //BP_BAYES_NET_H
|
||||
|
||||
|
@ -1,26 +1,21 @@
|
||||
#include <cstdlib>
|
||||
#include <cassert>
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <iomanip>
|
||||
#include <cassert>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "BayesNode.h"
|
||||
|
||||
|
||||
BayesNode::BayesNode (unsigned varId) : Variable (varId)
|
||||
{
|
||||
}
|
||||
|
||||
|
||||
|
||||
BayesNode::BayesNode (unsigned varId,
|
||||
BayesNode::BayesNode (Vid vid,
|
||||
unsigned dsize,
|
||||
int evidence,
|
||||
const NodeSet& parents,
|
||||
Distribution* dist) : Variable(varId, dsize, evidence)
|
||||
const BnNodeSet& parents,
|
||||
Distribution* dist) : Variable (vid, dsize, evidence)
|
||||
{
|
||||
parents_ = parents;
|
||||
dist_ = dist;
|
||||
parents_ = parents;
|
||||
dist_ = dist;
|
||||
for (unsigned int i = 0; i < parents.size(); i++) {
|
||||
parents[i]->addChild (this);
|
||||
}
|
||||
@ -28,15 +23,15 @@ BayesNode::BayesNode (unsigned varId,
|
||||
|
||||
|
||||
|
||||
BayesNode::BayesNode (unsigned varId,
|
||||
BayesNode::BayesNode (Vid vid,
|
||||
string label,
|
||||
const Domain& domain,
|
||||
const NodeSet& parents,
|
||||
Distribution* dist) : Variable(varId, domain)
|
||||
const BnNodeSet& parents,
|
||||
Distribution* dist) : Variable (vid, domain,
|
||||
NO_EVIDENCE, label)
|
||||
{
|
||||
label_ = new string (label);
|
||||
parents_ = parents;
|
||||
dist_ = dist;
|
||||
parents_ = parents;
|
||||
dist_ = dist;
|
||||
for (unsigned int i = 0; i < parents.size(); i++) {
|
||||
parents[i]->addChild (this);
|
||||
}
|
||||
@ -47,11 +42,11 @@ BayesNode::BayesNode (unsigned varId,
|
||||
void
|
||||
BayesNode::setData (unsigned dsize,
|
||||
int evidence,
|
||||
const NodeSet& parents,
|
||||
const BnNodeSet& parents,
|
||||
Distribution* dist)
|
||||
{
|
||||
setDomainSize (dsize);
|
||||
evidence_ = evidence;
|
||||
setEvidence (evidence);
|
||||
parents_ = parents;
|
||||
dist_ = dist;
|
||||
for (unsigned int i = 0; i < parents.size(); i++) {
|
||||
@ -135,19 +130,18 @@ BayesNode::getCptEntries (void)
|
||||
{
|
||||
if (dist_->entries.size() == 0) {
|
||||
unsigned rowSize = getRowSize();
|
||||
unsigned nParents = parents_.size();
|
||||
vector<DomainConf> confs (rowSize);
|
||||
vector<DConf> confs (rowSize);
|
||||
|
||||
for (unsigned i = 0; i < rowSize; i++) {
|
||||
confs[i].resize (nParents);
|
||||
confs[i].resize (parents_.size());
|
||||
}
|
||||
|
||||
int nReps = 1;
|
||||
for (int i = nParents - 1; i >= 0; i--) {
|
||||
unsigned nReps = 1;
|
||||
for (int i = parents_.size() - 1; i >= 0; i--) {
|
||||
unsigned index = 0;
|
||||
while (index < rowSize) {
|
||||
for (int j = 0; j < parents_[i]->getDomainSize(); j++) {
|
||||
for (int r = 0; r < nReps; r++) {
|
||||
for (unsigned j = 0; j < parents_[i]->getDomainSize(); j++) {
|
||||
for (unsigned r = 0; r < nReps; r++) {
|
||||
confs[index][i] = j;
|
||||
index++;
|
||||
}
|
||||
@ -184,7 +178,7 @@ BayesNode::cptEntryToString (const CptEntry& entry) const
|
||||
{
|
||||
stringstream ss;
|
||||
ss << "p(" ;
|
||||
const DomainConf& conf = entry.getParentConfigurations();
|
||||
const DConf& conf = entry.getDomainConfiguration();
|
||||
int row = entry.getParameterIndex() / getRowSize();
|
||||
ss << getDomain()[row];
|
||||
if (parents_.size() > 0) {
|
||||
@ -207,7 +201,7 @@ BayesNode::cptEntryToString (int row, const CptEntry& entry) const
|
||||
{
|
||||
stringstream ss;
|
||||
ss << "p(" ;
|
||||
const DomainConf& conf = entry.getParentConfigurations();
|
||||
const DConf& conf = entry.getDomainConfiguration();
|
||||
ss << getDomain()[row];
|
||||
if (parents_.size() > 0) {
|
||||
ss << "|" ;
|
||||
@ -227,16 +221,16 @@ BayesNode::cptEntryToString (int row, const CptEntry& entry) const
|
||||
vector<string>
|
||||
BayesNode::getDomainHeaders (void) const
|
||||
{
|
||||
int nParents = parents_.size();
|
||||
int rowSize = getRowSize();
|
||||
int nReps = 1;
|
||||
unsigned nParents = parents_.size();
|
||||
unsigned rowSize = getRowSize();
|
||||
unsigned nReps = 1;
|
||||
vector<string> headers (rowSize);
|
||||
for (int i = nParents - 1; i >= 0; i--) {
|
||||
Domain domain = parents_[i]->getDomain();
|
||||
int index = 0;
|
||||
unsigned index = 0;
|
||||
while (index < rowSize) {
|
||||
for (int j = 0; j < parents_[i]->getDomainSize(); j++) {
|
||||
for (int r = 0; r < nReps; r++) {
|
||||
for (unsigned j = 0; j < parents_[i]->getDomainSize(); j++) {
|
||||
for (unsigned r = 0; r < nReps; r++) {
|
||||
if (headers[index] != "") {
|
||||
headers[index] = domain[j] + "," + headers[index];
|
||||
} else {
|
||||
@ -270,7 +264,7 @@ operator << (ostream& o, const BayesNode& node)
|
||||
o << endl;
|
||||
|
||||
o << "Parents: " ;
|
||||
const NodeSet& parents = node.getParents();
|
||||
const BnNodeSet& parents = node.getParents();
|
||||
if (parents.size() != 0) {
|
||||
for (unsigned int i = 0; i < parents.size() - 1; i++) {
|
||||
o << parents[i]->getLabel() << ", " ;
|
||||
@ -280,7 +274,7 @@ operator << (ostream& o, const BayesNode& node)
|
||||
o << endl;
|
||||
|
||||
o << "Childs: " ;
|
||||
const NodeSet& childs = node.getChilds();
|
||||
const BnNodeSet& childs = node.getChilds();
|
||||
if (childs.size() != 0) {
|
||||
for (unsigned int i = 0; i < childs.size() - 1; i++) {
|
||||
o << childs[i]->getLabel() << ", " ;
|
||||
|
@ -1,9 +1,7 @@
|
||||
#ifndef BP_BAYESNODE_H
|
||||
#define BP_BAYESNODE_H
|
||||
#ifndef BP_BAYES_NODE_H
|
||||
#define BP_BAYES_NODE_H
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
|
||||
#include "Variable.h"
|
||||
#include "CptEntry.h"
|
||||
@ -16,11 +14,12 @@ using namespace std;
|
||||
class BayesNode : public Variable
|
||||
{
|
||||
public:
|
||||
BayesNode (unsigned);
|
||||
BayesNode (unsigned, unsigned, int, const NodeSet&, Distribution*);
|
||||
BayesNode (unsigned, string, const Domain&, const NodeSet&, Distribution*);
|
||||
BayesNode (Vid vid) : Variable (vid) {}
|
||||
BayesNode (Vid, unsigned, int, const BnNodeSet&, Distribution*);
|
||||
BayesNode (Vid, string, const Domain&, const BnNodeSet&, Distribution*);
|
||||
|
||||
void setData (unsigned, int, const NodeSet&, Distribution*);
|
||||
void setData (unsigned, int, const BnNodeSet&,
|
||||
Distribution*);
|
||||
void addChild (BayesNode*);
|
||||
Distribution* getDistribution (void);
|
||||
const ParamSet& getParameters (void);
|
||||
@ -34,11 +33,21 @@ class BayesNode : public Variable
|
||||
int getIndexOfParent (const BayesNode*) const;
|
||||
string cptEntryToString (const CptEntry&) const;
|
||||
string cptEntryToString (int, const CptEntry&) const;
|
||||
// inlines
|
||||
const NodeSet& getParents (void) const;
|
||||
const NodeSet& getChilds (void) const;
|
||||
double getProbability (int, const CptEntry& entry);
|
||||
unsigned getRowSize (void) const;
|
||||
|
||||
const BnNodeSet& getParents (void) const { return parents_; }
|
||||
const BnNodeSet& getChilds (void) const { return childs_; }
|
||||
|
||||
unsigned getRowSize (void) const
|
||||
{
|
||||
return dist_->params.size() / getDomainSize();
|
||||
}
|
||||
|
||||
double getProbability (int row, const CptEntry& entry)
|
||||
{
|
||||
int col = entry.getParameterIndex();
|
||||
int idx = (row * getRowSize()) + col;
|
||||
return dist_->params[idx];
|
||||
}
|
||||
|
||||
private:
|
||||
DISALLOW_COPY_AND_ASSIGN (BayesNode);
|
||||
@ -46,46 +55,12 @@ class BayesNode : public Variable
|
||||
Domain getDomainHeaders (void) const;
|
||||
friend ostream& operator << (ostream&, const BayesNode&);
|
||||
|
||||
NodeSet parents_;
|
||||
NodeSet childs_;
|
||||
BnNodeSet parents_;
|
||||
BnNodeSet childs_;
|
||||
Distribution* dist_;
|
||||
};
|
||||
|
||||
ostream& operator << (ostream&, const BayesNode&);
|
||||
|
||||
|
||||
|
||||
inline const NodeSet&
|
||||
BayesNode::getParents (void) const
|
||||
{
|
||||
return parents_;
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline const NodeSet&
|
||||
BayesNode::getChilds (void) const
|
||||
{
|
||||
return childs_;
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline double
|
||||
BayesNode::getProbability (int row, const CptEntry& entry)
|
||||
{
|
||||
int col = entry.getParameterIndex();
|
||||
int idx = (row * getRowSize()) + col;
|
||||
return dist_->params[idx];
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline unsigned
|
||||
BayesNode::getRowSize (void) const
|
||||
{
|
||||
return dist_->params.size() / getDomainSize();
|
||||
}
|
||||
|
||||
#endif
|
||||
#endif //BP_BAYES_NODE_H
|
||||
|
||||
|
198
packages/CLPBN/clpbn/bp/CountingBP.cpp
Normal file
198
packages/CLPBN/clpbn/bp/CountingBP.cpp
Normal file
@ -0,0 +1,198 @@
|
||||
#include "CountingBP.h"
|
||||
|
||||
|
||||
CountingBP::~CountingBP (void)
|
||||
{
|
||||
delete lfg_;
|
||||
delete fg_;
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
delete links_[i];
|
||||
}
|
||||
links_.clear();
|
||||
}
|
||||
|
||||
|
||||
|
||||
ParamSet
|
||||
CountingBP::getPosterioriOf (Vid vid) const
|
||||
{
|
||||
FgVarNode* var = lfg_->getEquivalentVariable (vid);
|
||||
ParamSet probs;
|
||||
|
||||
if (var->hasEvidence()) {
|
||||
probs.resize (var->getDomainSize(), 0.0);
|
||||
probs[var->getEvidence()] = 1.0;
|
||||
} else {
|
||||
probs.resize (var->getDomainSize(), 1.0);
|
||||
CLinkSet links = varsI_[var->getIndex()]->getLinks();
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
ParamSet msg = links[i]->getMessage();
|
||||
CountingBPLink* l = static_cast<CountingBPLink*> (links[i]);
|
||||
Util::pow (msg, l->getNumberOfEdges());
|
||||
for (unsigned j = 0; j < msg.size(); j++) {
|
||||
probs[j] *= msg[j];
|
||||
}
|
||||
}
|
||||
Util::normalize (probs);
|
||||
}
|
||||
return probs;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
CountingBP::initializeSolver (void)
|
||||
{
|
||||
lfg_ = new LiftedFG (*fg_);
|
||||
unsigned nUncVars = fg_->getFgVarNodes().size();
|
||||
unsigned nUncFactors = fg_->getFactors().size();
|
||||
CFgVarSet vars = fg_->getFgVarNodes();
|
||||
unsigned nNeighborLessVars = 0;
|
||||
for (unsigned i = 0; i < vars.size(); i++) {
|
||||
CFactorSet factors = vars[i]->getFactors();
|
||||
if (factors.size() == 1 && factors[0]->getFgVarNodes().size() == 1) {
|
||||
nNeighborLessVars ++;
|
||||
}
|
||||
}
|
||||
// cout << "UNCOMPRESSED FACTOR GRAPH" << endl;
|
||||
// fg_->printGraphicalModel();
|
||||
fg_->exportToDotFormat ("uncompress.dot");
|
||||
|
||||
FactorGraph *temp;
|
||||
temp = fg_;
|
||||
fg_ = lfg_->getCompressedFactorGraph();
|
||||
unsigned nCompVars = fg_->getFgVarNodes().size();
|
||||
unsigned nCompFactors = fg_->getFactors().size();
|
||||
|
||||
Statistics::updateCompressingStats (nUncVars,
|
||||
nUncFactors,
|
||||
nCompVars,
|
||||
nCompFactors,
|
||||
nNeighborLessVars);
|
||||
|
||||
cout << "COMPRESSED FACTOR GRAPH" << endl;
|
||||
fg_->printGraphicalModel();
|
||||
//fg_->exportToDotFormat ("compress.dot");
|
||||
|
||||
SPSolver::initializeSolver();
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
CountingBP::createLinks (void)
|
||||
{
|
||||
const FactorClusterSet fcs = lfg_->getFactorClusters();
|
||||
for (unsigned i = 0; i < fcs.size(); i++) {
|
||||
const VarClusterSet vcs = fcs[i]->getVarClusters();
|
||||
for (unsigned j = 0; j < vcs.size(); j++) {
|
||||
unsigned c = lfg_->getGroundEdgeCount (fcs[i], vcs[j]);
|
||||
links_.push_back (
|
||||
new CountingBPLink (fcs[i]->getRepresentativeFactor(),
|
||||
vcs[j]->getRepresentativeVariable(), c));
|
||||
//cout << (links_.back())->toString() << " edge count =" << c << endl;
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
CountingBP::deleteJunction (Factor* f, FgVarNode*)
|
||||
{
|
||||
f->freeDistribution();
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
CountingBP::maxResidualSchedule (void)
|
||||
{
|
||||
if (nIter_ == 1) {
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
links_[i]->setNextMessage (getFactor2VarMsg (links_[i]));
|
||||
SortedOrder::iterator it = sortedOrder_.insert (links_[i]);
|
||||
linkMap_.insert (make_pair (links_[i], it));
|
||||
if (DL >= 2 && DL < 5) {
|
||||
cout << "calculating " << links_[i]->toString() << endl;
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
for (unsigned c = 0; c < links_.size(); c++) {
|
||||
if (DL >= 2) {
|
||||
cout << endl << "current residuals:" << endl;
|
||||
for (SortedOrder::iterator it = sortedOrder_.begin();
|
||||
it != sortedOrder_.end(); it ++) {
|
||||
cout << " " << setw (30) << left << (*it)->toString();
|
||||
cout << "residual = " << (*it)->getResidual() << endl;
|
||||
}
|
||||
}
|
||||
|
||||
SortedOrder::iterator it = sortedOrder_.begin();
|
||||
Link* link = *it;
|
||||
if (DL >= 2) {
|
||||
cout << "updating " << (*sortedOrder_.begin())->toString() << endl;
|
||||
}
|
||||
if (link->getResidual() < SolverOptions::accuracy) {
|
||||
return;
|
||||
}
|
||||
link->updateMessage();
|
||||
link->clearResidual();
|
||||
sortedOrder_.erase (it);
|
||||
linkMap_.find (link)->second = sortedOrder_.insert (link);
|
||||
|
||||
// update the messages that depend on message source --> destin
|
||||
CFactorSet factorNeighbors = link->getVariable()->getFactors();
|
||||
for (unsigned i = 0; i < factorNeighbors.size(); i++) {
|
||||
CLinkSet links = factorsI_[factorNeighbors[i]->getIndex()]->getLinks();
|
||||
for (unsigned j = 0; j < links.size(); j++) {
|
||||
if (links[j]->getVariable() != link->getVariable()) { //FIXMEFIXME
|
||||
if (DL >= 2 && DL < 5) {
|
||||
cout << " calculating " << links[j]->toString() << endl;
|
||||
}
|
||||
links[j]->setNextMessage (getFactor2VarMsg (links[j]));
|
||||
LinkMap::iterator iter = linkMap_.find (links[j]);
|
||||
sortedOrder_.erase (iter->second);
|
||||
iter->second = sortedOrder_.insert (links[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
ParamSet
|
||||
CountingBP::getVar2FactorMsg (const Link* link) const
|
||||
{
|
||||
const FgVarNode* src = link->getVariable();
|
||||
const Factor* dest = link->getFactor();
|
||||
ParamSet msg;
|
||||
if (src->hasEvidence()) {
|
||||
cout << "has evidence" << endl;
|
||||
msg.resize (src->getDomainSize(), 0.0);
|
||||
msg[src->getEvidence()] = link->getMessage()[src->getEvidence()];
|
||||
cout << "-> " << link->getVariable()->getLabel() << " " << link->getFactor()->getLabel() << endl;
|
||||
cout << "-> p2s " << Util::parametersToString (msg) << endl;
|
||||
} else {
|
||||
msg = link->getMessage();
|
||||
}
|
||||
const CountingBPLink* l = static_cast<const CountingBPLink*> (link);
|
||||
Util::pow (msg, l->getNumberOfEdges() - 1);
|
||||
CLinkSet links = varsI_[src->getIndex()]->getLinks();
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
if (links[i]->getFactor() != dest) {
|
||||
ParamSet msgFromFactor = links[i]->getMessage();
|
||||
CountingBPLink* l = static_cast<CountingBPLink*> (links[i]);
|
||||
Util::pow (msgFromFactor, l->getNumberOfEdges());
|
||||
for (unsigned j = 0; j < msgFromFactor.size(); j++) {
|
||||
msg[j] *= msgFromFactor[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
return msg;
|
||||
}
|
||||
|
45
packages/CLPBN/clpbn/bp/CountingBP.h
Normal file
45
packages/CLPBN/clpbn/bp/CountingBP.h
Normal file
@ -0,0 +1,45 @@
|
||||
#ifndef BP_COUNTING_BP_H
|
||||
#define BP_COUNTING_BP_H
|
||||
|
||||
#include "SPSolver.h"
|
||||
#include "LiftedFG.h"
|
||||
|
||||
class Factor;
|
||||
class FgVarNode;
|
||||
|
||||
class CountingBPLink : public Link
|
||||
{
|
||||
public:
|
||||
CountingBPLink (Factor* f, FgVarNode* v, unsigned c) : Link (f, v)
|
||||
{
|
||||
edgeCount_ = c;
|
||||
}
|
||||
|
||||
unsigned getNumberOfEdges (void) const { return edgeCount_; }
|
||||
|
||||
private:
|
||||
unsigned edgeCount_;
|
||||
};
|
||||
|
||||
|
||||
class CountingBP : public SPSolver
|
||||
{
|
||||
public:
|
||||
CountingBP (FactorGraph& fg) : SPSolver (fg) { }
|
||||
~CountingBP (void);
|
||||
|
||||
ParamSet getPosterioriOf (Vid) const;
|
||||
|
||||
private:
|
||||
void initializeSolver (void);
|
||||
void createLinks (void);
|
||||
void deleteJunction (Factor*, FgVarNode*);
|
||||
|
||||
void maxResidualSchedule (void);
|
||||
ParamSet getVar2FactorMsg (const Link*) const;
|
||||
|
||||
LiftedFG* lfg_;
|
||||
};
|
||||
|
||||
#endif // BP_COUNTING_BP_H
|
||||
|
@ -1,5 +1,5 @@
|
||||
#ifndef BP_CPTENTRY_H
|
||||
#define BP_CPTENTRY_H
|
||||
#ifndef BP_CPT_ENTRY_H
|
||||
#define BP_CPT_ENTRY_H
|
||||
|
||||
#include <vector>
|
||||
|
||||
@ -10,62 +10,34 @@ using namespace std;
|
||||
class CptEntry
|
||||
{
|
||||
public:
|
||||
CptEntry (unsigned, const vector<unsigned>&);
|
||||
CptEntry (unsigned index, const DConf& conf)
|
||||
{
|
||||
index_ = index;
|
||||
conf_ = conf;
|
||||
}
|
||||
|
||||
unsigned getParameterIndex (void) const;
|
||||
const vector<unsigned>& getParentConfigurations (void) const;
|
||||
bool matchConstraints (const DomainConstr&) const;
|
||||
bool matchConstraints (const vector<DomainConstr>&) const;
|
||||
unsigned getParameterIndex (void) const { return index_; }
|
||||
const DConf& getDomainConfiguration (void) const { return conf_; }
|
||||
|
||||
bool matchConstraints (const DConstraint& constr) const
|
||||
{
|
||||
return conf_[constr.first] == constr.second;
|
||||
}
|
||||
|
||||
bool matchConstraints (const vector<DConstraint>& constrs) const
|
||||
{
|
||||
for (unsigned j = 0; j < constrs.size(); j++) {
|
||||
if (conf_[constrs[j].first] != constrs[j].second) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
private:
|
||||
unsigned index_;
|
||||
vector<unsigned> confs_;
|
||||
unsigned index_;
|
||||
DConf conf_;
|
||||
};
|
||||
|
||||
#endif //BP_CPT_ENTRY_H
|
||||
|
||||
|
||||
inline
|
||||
CptEntry::CptEntry (unsigned index, const vector<unsigned>& confs)
|
||||
{
|
||||
index_ = index;
|
||||
confs_ = confs;
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline unsigned
|
||||
CptEntry::getParameterIndex (void) const
|
||||
{
|
||||
return index_;
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline const vector<unsigned>&
|
||||
CptEntry::getParentConfigurations (void) const
|
||||
{
|
||||
return confs_;
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline bool
|
||||
CptEntry::matchConstraints (const DomainConstr& constr) const
|
||||
{
|
||||
return confs_[constr.first] == constr.second;
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline bool
|
||||
CptEntry::matchConstraints (const vector<DomainConstr>& constrs) const
|
||||
{
|
||||
for (unsigned j = 0; j < constrs.size(); j++) {
|
||||
if (confs_[constrs[j].first] != constrs[j].second) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
#endif
|
||||
|
@ -2,8 +2,8 @@
|
||||
#define BP_DISTRIBUTION_H
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
#include "CptEntry.h"
|
||||
#include "Shared.h"
|
||||
|
||||
using namespace std;
|
||||
@ -11,16 +11,18 @@ using namespace std;
|
||||
struct Distribution
|
||||
{
|
||||
public:
|
||||
Distribution (unsigned id)
|
||||
Distribution (unsigned id, bool shared = false)
|
||||
{
|
||||
this->id = id;
|
||||
this->params = params;
|
||||
this->shared = shared;
|
||||
}
|
||||
|
||||
Distribution (const ParamSet& params)
|
||||
Distribution (const ParamSet& params, bool shared = false)
|
||||
{
|
||||
this->id = -1;
|
||||
this->params = params;
|
||||
this->shared = shared;
|
||||
}
|
||||
|
||||
void updateParameters (const ParamSet& params)
|
||||
@ -31,10 +33,11 @@ struct Distribution
|
||||
unsigned id;
|
||||
ParamSet params;
|
||||
vector<CptEntry> entries;
|
||||
bool shared;
|
||||
|
||||
private:
|
||||
DISALLOW_COPY_AND_ASSIGN (Distribution);
|
||||
};
|
||||
|
||||
#endif
|
||||
#endif //BP_DISTRIBUTION_H
|
||||
|
||||
|
@ -1,37 +1,37 @@
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <cstdlib>
|
||||
#include <cassert>
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "Factor.h"
|
||||
#include "FgVarNode.h"
|
||||
|
||||
|
||||
int Factor::indexCount_ = 0;
|
||||
|
||||
Factor::Factor (FgVarNode* var) {
|
||||
vs_.push_back (var);
|
||||
int nParams = var->getDomainSize();
|
||||
// create a uniform distribution
|
||||
double val = 1.0 / nParams;
|
||||
ps_ = ParamSet (nParams, val);
|
||||
id_ = indexCount_;
|
||||
indexCount_ ++;
|
||||
Factor::Factor (const Factor& g)
|
||||
{
|
||||
copyFactor (g);
|
||||
}
|
||||
|
||||
|
||||
|
||||
Factor::Factor (const FgVarSet& vars) {
|
||||
vs_ = vars;
|
||||
Factor::Factor (FgVarNode* var)
|
||||
{
|
||||
Factor (FgVarSet() = {var});
|
||||
}
|
||||
|
||||
|
||||
|
||||
Factor::Factor (const FgVarSet& vars)
|
||||
{
|
||||
vars_ = vars;
|
||||
int nParams = 1;
|
||||
for (unsigned i = 0; i < vs_.size(); i++) {
|
||||
nParams *= vs_[i]->getDomainSize();
|
||||
for (unsigned i = 0; i < vars_.size(); i++) {
|
||||
nParams *= vars_[i]->getDomainSize();
|
||||
}
|
||||
// create a uniform distribution
|
||||
double val = 1.0 / nParams;
|
||||
ps_ = ParamSet (nParams, val);
|
||||
id_ = indexCount_;
|
||||
indexCount_ ++;
|
||||
dist_ = new Distribution (ParamSet (nParams, val));
|
||||
}
|
||||
|
||||
|
||||
@ -39,10 +39,17 @@ Factor::Factor (const FgVarSet& vars) {
|
||||
Factor::Factor (FgVarNode* var,
|
||||
const ParamSet& params)
|
||||
{
|
||||
vs_.push_back (var);
|
||||
ps_ = params;
|
||||
id_ = indexCount_;
|
||||
indexCount_ ++;
|
||||
vars_.push_back (var);
|
||||
dist_ = new Distribution (params);
|
||||
}
|
||||
|
||||
|
||||
|
||||
Factor::Factor (FgVarSet& vars,
|
||||
Distribution* dist)
|
||||
{
|
||||
vars_ = vars;
|
||||
dist_ = dist;
|
||||
}
|
||||
|
||||
|
||||
@ -50,42 +57,8 @@ Factor::Factor (FgVarNode* var,
|
||||
Factor::Factor (const FgVarSet& vars,
|
||||
const ParamSet& params)
|
||||
{
|
||||
vs_ = vars;
|
||||
ps_ = params;
|
||||
id_ = indexCount_;
|
||||
indexCount_ ++;
|
||||
}
|
||||
|
||||
|
||||
|
||||
const FgVarSet&
|
||||
Factor::getFgVarNodes (void) const
|
||||
{
|
||||
return vs_;
|
||||
}
|
||||
|
||||
|
||||
|
||||
FgVarSet&
|
||||
Factor::getFgVarNodes (void)
|
||||
{
|
||||
return vs_;
|
||||
}
|
||||
|
||||
|
||||
|
||||
const ParamSet&
|
||||
Factor::getParameters (void) const
|
||||
{
|
||||
return ps_;
|
||||
}
|
||||
|
||||
|
||||
|
||||
ParamSet&
|
||||
Factor::getParameters (void)
|
||||
{
|
||||
return ps_;
|
||||
vars_ = vars;
|
||||
dist_ = new Distribution (params);
|
||||
}
|
||||
|
||||
|
||||
@ -93,75 +66,95 @@ Factor::getParameters (void)
|
||||
void
|
||||
Factor::setParameters (const ParamSet& params)
|
||||
{
|
||||
//cout << "ps size: " << ps_.size() << endl;
|
||||
//cout << "params size: " << params.size() << endl;
|
||||
assert (ps_.size() == params.size());
|
||||
ps_ = params;
|
||||
assert (dist_->params.size() == params.size());
|
||||
dist_->updateParameters (params);
|
||||
}
|
||||
|
||||
|
||||
|
||||
Factor&
|
||||
Factor::operator= (const Factor& g)
|
||||
void
|
||||
Factor::copyFactor (const Factor& g)
|
||||
{
|
||||
FgVarSet vars = g.getFgVarNodes();
|
||||
ParamSet params = g.getParameters();
|
||||
return *this;
|
||||
vars_ = g.getFgVarNodes();
|
||||
dist_ = new Distribution (g.getDistribution()->params);
|
||||
}
|
||||
|
||||
|
||||
|
||||
Factor&
|
||||
Factor::operator*= (const Factor& g)
|
||||
void
|
||||
Factor::multiplyByFactor (const Factor& g, const vector<CptEntry>* entries)
|
||||
{
|
||||
FgVarSet gVs = g.getFgVarNodes();
|
||||
if (vars_.size() == 0) {
|
||||
copyFactor (g);
|
||||
return;
|
||||
}
|
||||
|
||||
const FgVarSet& gVs = g.getFgVarNodes();
|
||||
const ParamSet& gPs = g.getParameters();
|
||||
|
||||
bool hasCommonVars = false;
|
||||
vector<int> varIndexes;
|
||||
for (unsigned i = 0; i < gVs.size(); i++) {
|
||||
int idx = getIndexOf (gVs[i]);
|
||||
if (idx == -1) {
|
||||
insertVariable (gVs[i]);
|
||||
varIndexes.push_back (vs_.size() - 1);
|
||||
} else {
|
||||
hasCommonVars = true;
|
||||
varIndexes.push_back (idx);
|
||||
}
|
||||
}
|
||||
|
||||
if (hasCommonVars) {
|
||||
vector<int> offsets (gVs.size());
|
||||
offsets[gVs.size() - 1] = 1;
|
||||
for (int i = gVs.size() - 2; i >= 0; i--) {
|
||||
offsets[i] = offsets[i + 1] * gVs[i + 1]->getDomainSize();
|
||||
}
|
||||
vector<CptEntry> entries = getCptEntries();
|
||||
for (unsigned i = 0; i < entries.size(); i++) {
|
||||
int idx = 0;
|
||||
const DomainConf conf = entries[i].getParentConfigurations();
|
||||
for (unsigned j = 0; j < varIndexes.size(); j++) {
|
||||
idx += offsets[j] * conf[varIndexes[j]];
|
||||
bool factorsAreEqual = true;
|
||||
if (gVs.size() == vars_.size()) {
|
||||
for (unsigned i = 0; i < vars_.size(); i++) {
|
||||
if (gVs[i] != vars_[i]) {
|
||||
factorsAreEqual = false;
|
||||
break;
|
||||
}
|
||||
//cout << "ps_[" << i << "] = " << ps_[i] << " * " ;
|
||||
//cout << gPs[idx] << " , idx = " << idx << endl;
|
||||
ps_[i] = ps_[i] * gPs[idx];
|
||||
}
|
||||
} else {
|
||||
// if the originally factors doesn't have common factors.
|
||||
// we don't have to make domain comparations
|
||||
unsigned idx = 0;
|
||||
for (unsigned i = 0; i < ps_.size(); i++) {
|
||||
//cout << "ps_[" << i << "] = " << ps_[i] << " * " ;
|
||||
//cout << gPs[idx] << " , idx = " << idx << endl;
|
||||
ps_[i] = ps_[i] * gPs[idx];
|
||||
idx ++;
|
||||
if (idx >= gPs.size()) {
|
||||
idx = 0;
|
||||
factorsAreEqual = false;
|
||||
}
|
||||
|
||||
if (factorsAreEqual) {
|
||||
// optimization: if the factors contain the same set of variables,
|
||||
// we can do 1 to 1 operations on the parameteres
|
||||
for (unsigned i = 0; i < dist_->params.size(); i++) {
|
||||
dist_->params[i] *= gPs[i];
|
||||
}
|
||||
} else {
|
||||
bool hasCommonVars = false;
|
||||
vector<unsigned> gVsIndexes;
|
||||
for (unsigned i = 0; i < gVs.size(); i++) {
|
||||
int idx = getIndexOf (gVs[i]);
|
||||
if (idx == -1) {
|
||||
insertVariable (gVs[i]);
|
||||
gVsIndexes.push_back (vars_.size() - 1);
|
||||
} else {
|
||||
hasCommonVars = true;
|
||||
gVsIndexes.push_back (idx);
|
||||
}
|
||||
}
|
||||
if (hasCommonVars) {
|
||||
vector<unsigned> gVsOffsets (gVs.size());
|
||||
gVsOffsets[gVs.size() - 1] = 1;
|
||||
for (int i = gVs.size() - 2; i >= 0; i--) {
|
||||
gVsOffsets[i] = gVsOffsets[i + 1] * gVs[i + 1]->getDomainSize();
|
||||
}
|
||||
|
||||
if (entries == 0) {
|
||||
entries = &getCptEntries();
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < entries->size(); i++) {
|
||||
unsigned idx = 0;
|
||||
const DConf& conf = (*entries)[i].getDomainConfiguration();
|
||||
for (unsigned j = 0; j < gVsIndexes.size(); j++) {
|
||||
idx += gVsOffsets[j] * conf[ gVsIndexes[j] ];
|
||||
}
|
||||
dist_->params[i] = dist_->params[i] * gPs[idx];
|
||||
}
|
||||
} else {
|
||||
// optimization: if the original factors doesn't have common variables,
|
||||
// we don't need to marry the states of the common variables
|
||||
unsigned count = 0;
|
||||
for (unsigned i = 0; i < dist_->params.size(); i++) {
|
||||
dist_->params[i] *= gPs[count];
|
||||
count ++;
|
||||
if (count >= gPs.size()) {
|
||||
count = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
|
||||
@ -169,81 +162,109 @@ Factor::operator*= (const Factor& g)
|
||||
void
|
||||
Factor::insertVariable (FgVarNode* var)
|
||||
{
|
||||
int c = 0;
|
||||
ParamSet newPs (ps_.size() * var->getDomainSize());
|
||||
for (unsigned i = 0; i < ps_.size(); i++) {
|
||||
for (int j = 0; j < var->getDomainSize(); j++) {
|
||||
newPs[c] = ps_[i];
|
||||
c ++;
|
||||
assert (getIndexOf (var) == -1);
|
||||
ParamSet newPs;
|
||||
newPs.reserve (dist_->params.size() * var->getDomainSize());
|
||||
for (unsigned i = 0; i < dist_->params.size(); i++) {
|
||||
for (unsigned j = 0; j < var->getDomainSize(); j++) {
|
||||
newPs.push_back (dist_->params[i]);
|
||||
}
|
||||
}
|
||||
vs_.push_back (var);
|
||||
ps_ = newPs;
|
||||
vars_.push_back (var);
|
||||
dist_->updateParameters (newPs);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Factor::marginalizeVariable (const FgVarNode* var) {
|
||||
int varIndex = getIndexOf (var);
|
||||
marginalizeVariable (varIndex);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Factor::marginalizeVariable (unsigned varIndex)
|
||||
Factor::removeVariable (const FgVarNode* var)
|
||||
{
|
||||
assert (varIndex >= 0 && varIndex < vs_.size());
|
||||
int distOffset = 1;
|
||||
int leftVarOffset = 1;
|
||||
for (unsigned i = vs_.size() - 1; i > varIndex; i--) {
|
||||
distOffset *= vs_[i]->getDomainSize();
|
||||
leftVarOffset *= vs_[i]->getDomainSize();
|
||||
}
|
||||
leftVarOffset *= vs_[varIndex]->getDomainSize();
|
||||
int varIndex = getIndexOf (var);
|
||||
assert (varIndex >= 0 && varIndex < (int)vars_.size());
|
||||
|
||||
// number of parameters separating a different state of `var',
|
||||
// with the states of the remaining variables fixed
|
||||
unsigned varOffset = 1;
|
||||
|
||||
// number of parameters separating a different state of the variable
|
||||
// on the left of `var', with the states of the remaining vars fixed
|
||||
unsigned leftVarOffset = 1;
|
||||
|
||||
for (int i = vars_.size() - 1; i > varIndex; i--) {
|
||||
varOffset *= vars_[i]->getDomainSize();
|
||||
leftVarOffset *= vars_[i]->getDomainSize();
|
||||
}
|
||||
leftVarOffset *= vars_[varIndex]->getDomainSize();
|
||||
|
||||
unsigned offset = 0;
|
||||
unsigned count1 = 0;
|
||||
unsigned count2 = 0;
|
||||
unsigned newPsSize = dist_->params.size() / vars_[varIndex]->getDomainSize();
|
||||
|
||||
int ds = vs_[varIndex]->getDomainSize();
|
||||
int count = 0;
|
||||
int offset = 0;
|
||||
int startIndex = 0;
|
||||
int currDomainIdx = 0;
|
||||
unsigned newPsSize = ps_.size() / ds;
|
||||
ParamSet newPs;
|
||||
newPs.reserve (newPsSize);
|
||||
|
||||
stringstream ss;
|
||||
ss << "marginalizing " << vs_[varIndex]->getLabel();
|
||||
ss << " from factor " << getLabel() << endl;
|
||||
// stringstream ss;
|
||||
// ss << "marginalizing " << vars_[varIndex]->getLabel();
|
||||
// ss << " from factor " << getLabel() << endl;
|
||||
while (newPs.size() < newPsSize) {
|
||||
ss << " sum = ";
|
||||
// ss << " sum = ";
|
||||
double sum = 0.0;
|
||||
for (int j = 0; j < ds; j++) {
|
||||
if (j != 0) ss << " + ";
|
||||
ss << ps_[offset];
|
||||
sum = sum + ps_[offset];
|
||||
offset = offset + distOffset;
|
||||
for (unsigned i = 0; i < vars_[varIndex]->getDomainSize(); i++) {
|
||||
// if (i != 0) ss << " + ";
|
||||
// ss << dist_->params[offset];
|
||||
sum += dist_->params[offset];
|
||||
offset += varOffset;
|
||||
}
|
||||
newPs.push_back (sum);
|
||||
count ++;
|
||||
if (varIndex == vs_.size() - 1) {
|
||||
offset = count * ds;
|
||||
count1 ++;
|
||||
if (varIndex == (int)vars_.size() - 1) {
|
||||
offset = count1 * vars_[varIndex]->getDomainSize();
|
||||
} else {
|
||||
offset = offset - distOffset + 1;
|
||||
if ((offset % leftVarOffset) == 0) {
|
||||
currDomainIdx ++;
|
||||
startIndex = leftVarOffset * currDomainIdx;
|
||||
offset = startIndex;
|
||||
count = 0;
|
||||
} else {
|
||||
offset = startIndex + count;
|
||||
if (((offset - varOffset + 1) % leftVarOffset) == 0) {
|
||||
count1 = 0;
|
||||
count2 ++;
|
||||
}
|
||||
offset = (leftVarOffset * count2) + count1;
|
||||
}
|
||||
ss << " = " << sum << endl;
|
||||
// ss << " = " << sum << endl;
|
||||
}
|
||||
//cout << ss.str() << endl;
|
||||
ps_ = newPs;
|
||||
vs_.erase (vs_.begin() + varIndex);
|
||||
// cout << ss.str() << endl;
|
||||
vars_.erase (vars_.begin() + varIndex);
|
||||
dist_->updateParameters (newPs);
|
||||
}
|
||||
|
||||
|
||||
|
||||
const vector<CptEntry>&
|
||||
Factor::getCptEntries (void) const
|
||||
{
|
||||
if (dist_->entries.size() == 0) {
|
||||
vector<DConf> confs (dist_->params.size());
|
||||
for (unsigned i = 0; i < dist_->params.size(); i++) {
|
||||
confs[i].resize (vars_.size());
|
||||
}
|
||||
|
||||
unsigned nReps = 1;
|
||||
for (int i = vars_.size() - 1; i >= 0; i--) {
|
||||
unsigned index = 0;
|
||||
while (index < dist_->params.size()) {
|
||||
for (unsigned j = 0; j < vars_[i]->getDomainSize(); j++) {
|
||||
for (unsigned r = 0; r < nReps; r++) {
|
||||
confs[index][i] = j;
|
||||
index++;
|
||||
}
|
||||
}
|
||||
}
|
||||
nReps *= vars_[i]->getDomainSize();
|
||||
}
|
||||
dist_->entries.clear();
|
||||
dist_->entries.reserve (dist_->params.size());
|
||||
for (unsigned i = 0; i < dist_->params.size(); i++) {
|
||||
dist_->entries.push_back (CptEntry (i, confs[i]));
|
||||
}
|
||||
}
|
||||
return dist_->entries;
|
||||
}
|
||||
|
||||
|
||||
@ -252,11 +273,10 @@ string
|
||||
Factor::getLabel (void) const
|
||||
{
|
||||
stringstream ss;
|
||||
ss << "f(" ;
|
||||
// ss << "Φ(" ;
|
||||
for (unsigned i = 0; i < vs_.size(); i++) {
|
||||
if (i != 0) ss << ", " ;
|
||||
ss << "v" << vs_[i]->getVarId();
|
||||
ss << "Φ(" ;
|
||||
for (unsigned i = 0; i < vars_.size(); i++) {
|
||||
if (i != 0) ss << "," ;
|
||||
ss << vars_[i]->getLabel();
|
||||
}
|
||||
ss << ")" ;
|
||||
return ss.str();
|
||||
@ -264,62 +284,24 @@ Factor::getLabel (void) const
|
||||
|
||||
|
||||
|
||||
string
|
||||
Factor::toString (void) const
|
||||
void
|
||||
Factor::printFactor (void)
|
||||
{
|
||||
stringstream ss;
|
||||
ss << "vars: " ;
|
||||
for (unsigned i = 0; i < vs_.size(); i++) {
|
||||
if (i != 0) ss << ", " ;
|
||||
ss << "v" << vs_[i]->getVarId();
|
||||
ss << getLabel() << endl;
|
||||
ss << "--------------------" << endl;
|
||||
VarSet vs;
|
||||
for (unsigned i = 0; i < vars_.size(); i++) {
|
||||
vs.push_back (vars_[i]);
|
||||
}
|
||||
ss << endl;
|
||||
vector<CptEntry> entries = getCptEntries();
|
||||
vector<string> domainConfs = Util::getInstantiations (vs);
|
||||
const vector<CptEntry>& entries = getCptEntries();
|
||||
for (unsigned i = 0; i < entries.size(); i++) {
|
||||
ss << "Φ(" ;
|
||||
char s = 'a' ;
|
||||
const DomainConf& conf = entries[i].getParentConfigurations();
|
||||
for (unsigned j = 0; j < conf.size(); j++) {
|
||||
if (j != 0) ss << "," ;
|
||||
ss << s << conf[j] + 1;
|
||||
s++;
|
||||
}
|
||||
ss << ") = " << ps_[entries[i].getParameterIndex()] << endl;
|
||||
ss << "Φ(" << domainConfs[i] << ")" ;
|
||||
unsigned idx = entries[i].getParameterIndex();
|
||||
ss << " = " << dist_->params[idx] << endl;
|
||||
}
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
|
||||
|
||||
vector<CptEntry>
|
||||
Factor::getCptEntries (void) const
|
||||
{
|
||||
vector<DomainConf> confs (ps_.size());
|
||||
for (unsigned i = 0; i < ps_.size(); i++) {
|
||||
confs[i].resize (vs_.size());
|
||||
}
|
||||
|
||||
int nReps = 1;
|
||||
for (int i = vs_.size() - 1; i >= 0; i--) {
|
||||
unsigned index = 0;
|
||||
while (index < ps_.size()) {
|
||||
for (int j = 0; j < vs_[i]->getDomainSize(); j++) {
|
||||
for (int r = 0; r < nReps; r++) {
|
||||
confs[index][i] = j;
|
||||
index++;
|
||||
}
|
||||
}
|
||||
}
|
||||
nReps *= vs_[i]->getDomainSize();
|
||||
}
|
||||
|
||||
vector<CptEntry> entries;
|
||||
for (unsigned i = 0; i < ps_.size(); i++) {
|
||||
for (unsigned j = 0; j < vs_.size(); j++) {
|
||||
}
|
||||
entries.push_back (CptEntry (i, confs[i]));
|
||||
}
|
||||
return entries;
|
||||
cout << ss.str();
|
||||
}
|
||||
|
||||
|
||||
@ -327,20 +309,11 @@ Factor::getCptEntries (void) const
|
||||
int
|
||||
Factor::getIndexOf (const FgVarNode* var) const
|
||||
{
|
||||
for (unsigned i = 0; i < vs_.size(); i++) {
|
||||
if (vs_[i] == var) {
|
||||
for (unsigned i = 0; i < vars_.size(); i++) {
|
||||
if (vars_[i] == var) {
|
||||
return i;
|
||||
}
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
|
||||
|
||||
Factor operator* (const Factor& f, const Factor& g)
|
||||
{
|
||||
Factor r = f;
|
||||
r *= g;
|
||||
return r;
|
||||
}
|
||||
|
||||
|
@ -3,43 +3,46 @@
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "Distribution.h"
|
||||
#include "CptEntry.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
class FgVarNode;
|
||||
class Distribution;
|
||||
|
||||
class Factor
|
||||
{
|
||||
public:
|
||||
Factor (void) { }
|
||||
Factor (const Factor&);
|
||||
Factor (FgVarNode*);
|
||||
Factor (const FgVarSet&);
|
||||
Factor (CFgVarSet);
|
||||
Factor (FgVarNode*, const ParamSet&);
|
||||
Factor (const FgVarSet&, const ParamSet&);
|
||||
Factor (FgVarSet&, Distribution*);
|
||||
Factor (CFgVarSet, CParamSet);
|
||||
|
||||
const FgVarSet& getFgVarNodes (void) const;
|
||||
FgVarSet& getFgVarNodes (void);
|
||||
const ParamSet& getParameters (void) const;
|
||||
ParamSet& getParameters (void);
|
||||
void setParameters (const ParamSet&);
|
||||
Factor& operator= (const Factor& f);
|
||||
Factor& operator*= (const Factor& f);
|
||||
void insertVariable (FgVarNode* index);
|
||||
void marginalizeVariable (const FgVarNode* var);
|
||||
void marginalizeVariable (unsigned);
|
||||
string getLabel (void) const;
|
||||
string toString (void) const;
|
||||
void setParameters (CParamSet);
|
||||
void copyFactor (const Factor& f);
|
||||
void multiplyByFactor (const Factor& f, const vector<CptEntry>* = 0);
|
||||
void insertVariable (FgVarNode* index);
|
||||
void removeVariable (const FgVarNode* var);
|
||||
const vector<CptEntry>& getCptEntries (void) const;
|
||||
string getLabel (void) const;
|
||||
void printFactor (void);
|
||||
|
||||
CFgVarSet getFgVarNodes (void) const { return vars_; }
|
||||
CParamSet getParameters (void) const { return dist_->params; }
|
||||
Distribution* getDistribution (void) const { return dist_; }
|
||||
unsigned getIndex (void) const { return index_; }
|
||||
void setIndex (unsigned index) { index_ = index; }
|
||||
void freeDistribution (void) { delete dist_; dist_ = 0;}
|
||||
int getIndexOf (const FgVarNode*) const;
|
||||
|
||||
private:
|
||||
vector<CptEntry> getCptEntries() const;
|
||||
int getIndexOf (const FgVarNode*) const;
|
||||
|
||||
FgVarSet vs_;
|
||||
ParamSet ps_;
|
||||
int id_;
|
||||
static int indexCount_;
|
||||
FgVarSet vars_;
|
||||
Distribution* dist_;
|
||||
unsigned index_;
|
||||
};
|
||||
|
||||
Factor operator* (const Factor&, const Factor&);
|
||||
|
||||
#endif
|
||||
#endif //BP_FACTOR_H
|
||||
|
@ -1,23 +1,26 @@
|
||||
#include <cstdlib>
|
||||
#include <vector>
|
||||
#include <set>
|
||||
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "FactorGraph.h"
|
||||
#include "FgVarNode.h"
|
||||
#include "Factor.h"
|
||||
#include "BayesNet.h"
|
||||
|
||||
|
||||
FactorGraph::FactorGraph (const char* fileName)
|
||||
{
|
||||
string line;
|
||||
ifstream is (fileName);
|
||||
if (!is.is_open()) {
|
||||
cerr << "error: cannot read from file " + std::string (fileName) << endl;
|
||||
abort();
|
||||
}
|
||||
|
||||
string line;
|
||||
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
|
||||
getline (is, line);
|
||||
if (line != "MARKOV") {
|
||||
@ -39,7 +42,7 @@ FactorGraph::FactorGraph (const char* fileName)
|
||||
|
||||
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
|
||||
for (int i = 0; i < nVars; i++) {
|
||||
varNodes_.push_back (new FgVarNode (i, domainSizes[i]));
|
||||
addVariable (new FgVarNode (i, domainSizes[i]));
|
||||
}
|
||||
|
||||
int nFactors;
|
||||
@ -50,11 +53,11 @@ FactorGraph::FactorGraph (const char* fileName)
|
||||
is >> nFactorVars;
|
||||
FgVarSet factorVars;
|
||||
for (int j = 0; j < nFactorVars; j++) {
|
||||
int varId;
|
||||
is >> varId;
|
||||
FgVarNode* var = getVariableById (varId);
|
||||
if (var == 0) {
|
||||
cerr << "error: invalid variable identifier (" << varId << ")" << endl;
|
||||
int vid;
|
||||
is >> vid;
|
||||
FgVarNode* var = getFgVarNode (vid);
|
||||
if (!var) {
|
||||
cerr << "error: invalid variable identifier (" << vid << ")" << endl;
|
||||
abort();
|
||||
}
|
||||
factorVars.push_back (var);
|
||||
@ -87,6 +90,33 @@ FactorGraph::FactorGraph (const char* fileName)
|
||||
|
||||
|
||||
|
||||
FactorGraph::FactorGraph (const BayesNet& bn)
|
||||
{
|
||||
const BnNodeSet& nodes = bn.getBayesNodes();
|
||||
for (unsigned i = 0; i < nodes.size(); i++) {
|
||||
FgVarNode* varNode = new FgVarNode (nodes[i]);
|
||||
varNode->setIndex (i);
|
||||
addVariable (varNode);
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < nodes.size(); i++) {
|
||||
const BnNodeSet& parents = nodes[i]->getParents();
|
||||
if (!(nodes[i]->hasEvidence() && parents.size() == 0)) {
|
||||
FgVarSet factorVars = { varNodes_[nodes[i]->getIndex()] };
|
||||
for (unsigned j = 0; j < parents.size(); j++) {
|
||||
factorVars.push_back (varNodes_[parents[j]->getIndex()]);
|
||||
}
|
||||
Factor* f = new Factor (factorVars, nodes[i]->getDistribution());
|
||||
factors_.push_back (f);
|
||||
for (unsigned j = 0; j < factorVars.size(); j++) {
|
||||
factorVars[j]->addFactor (f);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
FactorGraph::~FactorGraph (void)
|
||||
{
|
||||
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
||||
@ -99,18 +129,67 @@ FactorGraph::~FactorGraph (void)
|
||||
|
||||
|
||||
|
||||
FgVarSet
|
||||
FactorGraph::getFgVarNodes (void) const
|
||||
void
|
||||
FactorGraph::addVariable (FgVarNode* varNode)
|
||||
{
|
||||
return varNodes_;
|
||||
varNodes_.push_back (varNode);
|
||||
varNode->setIndex (varNodes_.size() - 1);
|
||||
indexMap_.insert (make_pair (varNode->getVarId(), varNodes_.size() - 1));
|
||||
}
|
||||
|
||||
|
||||
|
||||
vector<Factor*>
|
||||
FactorGraph::getFactors (void) const
|
||||
void
|
||||
FactorGraph::removeVariable (const FgVarNode* var)
|
||||
{
|
||||
return factors_;
|
||||
if (varNodes_[varNodes_.size() - 1] == var) {
|
||||
varNodes_.pop_back();
|
||||
} else {
|
||||
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
||||
if (varNodes_[i] == var) {
|
||||
varNodes_.erase (varNodes_.begin() + i);
|
||||
return;
|
||||
}
|
||||
}
|
||||
assert (false);
|
||||
}
|
||||
indexMap_.erase (indexMap_.find (var->getVarId()));
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::addFactor (Factor* f)
|
||||
{
|
||||
factors_.push_back (f);
|
||||
const FgVarSet& factorVars = f->getFgVarNodes();
|
||||
for (unsigned i = 0; i < factorVars.size(); i++) {
|
||||
factorVars[i]->addFactor (f);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::removeFactor (const Factor* f)
|
||||
{
|
||||
const FgVarSet& factorVars = f->getFgVarNodes();
|
||||
for (unsigned i = 0; i < factorVars.size(); i++) {
|
||||
if (factorVars[i]) {
|
||||
factorVars[i]->removeFactor (f);
|
||||
}
|
||||
}
|
||||
if (factors_[factors_.size() - 1] == f) {
|
||||
factors_.pop_back();
|
||||
} else {
|
||||
for (unsigned i = 0; i < factors_.size(); i++) {
|
||||
if (factors_[i] == f) {
|
||||
factors_.erase (factors_.begin() + i);
|
||||
return;
|
||||
}
|
||||
}
|
||||
assert (false);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -127,47 +206,142 @@ FactorGraph::getVariables (void) const
|
||||
|
||||
|
||||
|
||||
FgVarNode*
|
||||
FactorGraph::getVariableById (unsigned id) const
|
||||
Variable*
|
||||
FactorGraph::getVariable (Vid vid) const
|
||||
{
|
||||
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
||||
if (varNodes_[i]->getVarId() == id) {
|
||||
return varNodes_[i];
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
|
||||
FgVarNode*
|
||||
FactorGraph::getVariableByLabel (string label) const
|
||||
{
|
||||
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
||||
stringstream ss;
|
||||
ss << "v" << varNodes_[i]->getVarId();
|
||||
if (ss.str() == label) {
|
||||
return varNodes_[i];
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
return getFgVarNode (vid);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::printFactorGraph (void) const
|
||||
FactorGraph::setIndexes (void)
|
||||
{
|
||||
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
||||
varNodes_[i]->setIndex (i);
|
||||
}
|
||||
for (unsigned i = 0; i < factors_.size(); i++) {
|
||||
factors_[i]->setIndex (i);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::freeDistributions (void)
|
||||
{
|
||||
set<Distribution*> dists;
|
||||
for (unsigned i = 0; i < factors_.size(); i++) {
|
||||
dists.insert (factors_[i]->getDistribution());
|
||||
}
|
||||
for (set<Distribution*>::iterator it = dists.begin();
|
||||
it != dists.end(); it++) {
|
||||
delete *it;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::printGraphicalModel (void) const
|
||||
{
|
||||
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
||||
cout << "variable number " << varNodes_[i]->getIndex() << endl;
|
||||
cout << "Id = " << varNodes_[i]->getVarId() << endl;
|
||||
cout << "Label = " << varNodes_[i]->getLabel() << endl;
|
||||
cout << "Domain size = " << varNodes_[i]->getDomainSize() << endl;
|
||||
cout << "Evidence = " << varNodes_[i]->getEvidence() << endl;
|
||||
cout << endl;
|
||||
cout << "Factors = " ;
|
||||
for (unsigned j = 0; j < varNodes_[i]->getFactors().size(); j++) {
|
||||
cout << varNodes_[i]->getFactors()[j]->getLabel() << " " ;
|
||||
}
|
||||
cout << endl << endl;
|
||||
}
|
||||
cout << endl;
|
||||
for (unsigned i = 0; i < factors_.size(); i++) {
|
||||
cout << factors_[i]->toString() << endl;
|
||||
factors_[i]->printFactor();
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::exportToDotFormat (const char* fileName) const
|
||||
{
|
||||
ofstream out (fileName);
|
||||
if (!out.is_open()) {
|
||||
cerr << "error: cannot open file to write at " ;
|
||||
cerr << "FactorGraph::exportToDotFile()" << endl;
|
||||
abort();
|
||||
}
|
||||
|
||||
out << "graph \"" << fileName << "\" {" << endl;
|
||||
|
||||
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
||||
if (varNodes_[i]->hasEvidence()) {
|
||||
out << '"' << varNodes_[i]->getLabel() << '"' ;
|
||||
out << " [style=filled, fillcolor=yellow]" << endl;
|
||||
}
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < factors_.size(); i++) {
|
||||
out << '"' << factors_[i]->getLabel() << '"' ;
|
||||
out << " [label=\"" << factors_[i]->getLabel() << "\\n(";
|
||||
out << factors_[i]->getDistribution()->id << ")" << "\"" ;
|
||||
out << ", shape=box]" << endl;
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < factors_.size(); i++) {
|
||||
CFgVarSet myVars = factors_[i]->getFgVarNodes();
|
||||
for (unsigned j = 0; j < myVars.size(); j++) {
|
||||
out << '"' << factors_[i]->getLabel() << '"' ;
|
||||
out << " -- " ;
|
||||
out << '"' << myVars[j]->getLabel() << '"' << endl;
|
||||
}
|
||||
}
|
||||
|
||||
out << "}" << endl;
|
||||
out.close();
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::exportToUaiFormat (const char* fileName) const
|
||||
{
|
||||
ofstream out (fileName);
|
||||
if (!out.is_open()) {
|
||||
cerr << "error: cannot open file to write at " ;
|
||||
cerr << "FactorGraph::exportToUaiFormat()" << endl;
|
||||
abort();
|
||||
}
|
||||
|
||||
out << "MARKOV" << endl;
|
||||
out << varNodes_.size() << endl;
|
||||
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
||||
out << varNodes_[i]->getDomainSize() << " " ;
|
||||
}
|
||||
out << endl;
|
||||
|
||||
out << factors_.size() << endl;
|
||||
for (unsigned i = 0; i < factors_.size(); i++) {
|
||||
CFgVarSet factorVars = factors_[i]->getFgVarNodes();
|
||||
out << factorVars.size();
|
||||
for (unsigned j = 0; j < factorVars.size(); j++) {
|
||||
out << " " << factorVars[j]->getIndex();
|
||||
}
|
||||
out << endl;
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < factors_.size(); i++) {
|
||||
CParamSet params = factors_[i]->getParameters();
|
||||
out << endl << params.size() << endl << " " ;
|
||||
for (unsigned j = 0; j < params.size(); j++) {
|
||||
out << params[j] << " " ;
|
||||
}
|
||||
out << endl;
|
||||
}
|
||||
|
||||
out.close();
|
||||
}
|
||||
|
||||
|
@ -1,8 +1,7 @@
|
||||
#ifndef BP_FACTORGRAPH_H
|
||||
#define BP_FACTORGRAPH_H
|
||||
#ifndef BP_FACTOR_GRAPH_H
|
||||
#define BP_FACTOR_GRAPH_H
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
#include "GraphicalModel.h"
|
||||
#include "Shared.h"
|
||||
@ -11,25 +10,48 @@ using namespace std;
|
||||
|
||||
class FgVarNode;
|
||||
class Factor;
|
||||
class BayesNet;
|
||||
|
||||
class FactorGraph : public GraphicalModel
|
||||
{
|
||||
public:
|
||||
FactorGraph (const char* fileName);
|
||||
FactorGraph (void) {};
|
||||
FactorGraph (const char*);
|
||||
FactorGraph (const BayesNet&);
|
||||
~FactorGraph (void);
|
||||
|
||||
FgVarSet getFgVarNodes (void) const;
|
||||
vector<Factor*> getFactors (void) const;
|
||||
|
||||
void addVariable (FgVarNode*);
|
||||
void removeVariable (const FgVarNode*);
|
||||
void addFactor (Factor*);
|
||||
void removeFactor (const Factor*);
|
||||
VarSet getVariables (void) const;
|
||||
FgVarNode* getVariableById (unsigned) const;
|
||||
FgVarNode* getVariableByLabel (string) const;
|
||||
void printFactorGraph (void) const;
|
||||
Variable* getVariable (unsigned) const;
|
||||
void setIndexes (void);
|
||||
void freeDistributions (void);
|
||||
void printGraphicalModel (void) const;
|
||||
void exportToDotFormat (const char*) const;
|
||||
void exportToUaiFormat (const char*) const;
|
||||
|
||||
const FgVarSet& getFgVarNodes (void) const { return varNodes_; }
|
||||
const FactorSet& getFactors (void) const { return factors_; }
|
||||
|
||||
FgVarNode* getFgVarNode (Vid vid) const
|
||||
{
|
||||
IndexMap::const_iterator it = indexMap_.find (vid);
|
||||
if (it == indexMap_.end()) {
|
||||
return 0;
|
||||
} else {
|
||||
return varNodes_[it->second];
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
DISALLOW_COPY_AND_ASSIGN (FactorGraph);
|
||||
|
||||
FgVarSet varNodes_;
|
||||
vector<Factor*> factors_;
|
||||
FgVarSet varNodes_;
|
||||
FactorSet factors_;
|
||||
IndexMap indexMap_;
|
||||
};
|
||||
|
||||
#endif
|
||||
#endif // BP_FACTOR_GRAPH_H
|
||||
|
||||
|
@ -1,8 +1,7 @@
|
||||
#ifndef BP_VARIABLE_H
|
||||
#define BP_VARIABLE_H
|
||||
#ifndef BP_FG_VAR_NODE_H
|
||||
#define BP_FG_VAR_NODE_H
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
#include "Variable.h"
|
||||
#include "Shared.h"
|
||||
@ -14,15 +13,31 @@ class Factor;
|
||||
class FgVarNode : public Variable
|
||||
{
|
||||
public:
|
||||
FgVarNode (int varId, int dsize) : Variable (varId, dsize) { }
|
||||
FgVarNode (unsigned vid, unsigned dsize) : Variable (vid, dsize) { }
|
||||
FgVarNode (const Variable* v) : Variable (v) { }
|
||||
|
||||
void addFactor (Factor* f) { factors_.push_back (f); }
|
||||
vector<Factor*> getFactors (void) const { return factors_; }
|
||||
void addFactor (Factor* f) { factors_.push_back (f); }
|
||||
CFactorSet getFactors (void) const { return factors_; }
|
||||
|
||||
void removeFactor (const Factor* f)
|
||||
{
|
||||
if (factors_[factors_.size() -1] == f) {
|
||||
factors_.pop_back();
|
||||
} else {
|
||||
for (unsigned i = 0; i < factors_.size(); i++) {
|
||||
if (factors_[i] == f) {
|
||||
factors_.erase (factors_.begin() + i);
|
||||
return;
|
||||
}
|
||||
}
|
||||
assert (false);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
DISALLOW_COPY_AND_ASSIGN (FgVarNode);
|
||||
// members
|
||||
vector<Factor*> factors_;
|
||||
FactorSet factors_;
|
||||
};
|
||||
|
||||
#endif // BP_VARIABLE_H
|
||||
#endif // BP_FG_VAR_NODE_H
|
||||
|
@ -1,5 +1,5 @@
|
||||
#ifndef BP_GRAPHICALMODEL_H
|
||||
#define BP_GRAPHICALMODEL_H
|
||||
#ifndef BP_GRAPHICAL_MODEL_H
|
||||
#define BP_GRAPHICAL_MODEL_H
|
||||
|
||||
#include "Variable.h"
|
||||
#include "Shared.h"
|
||||
@ -9,9 +9,10 @@ using namespace std;
|
||||
class GraphicalModel
|
||||
{
|
||||
public:
|
||||
virtual VarSet getVariables (void) const = 0;
|
||||
|
||||
private:
|
||||
virtual ~GraphicalModel (void) {};
|
||||
virtual Variable* getVariable (Vid) const = 0;
|
||||
virtual VarSet getVariables (void) const = 0;
|
||||
virtual void printGraphicalModel (void) const = 0;
|
||||
};
|
||||
|
||||
#endif
|
||||
#endif // BP_GRAPHICAL_MODEL_H
|
||||
|
@ -1,17 +1,19 @@
|
||||
#include <iostream>
|
||||
#include <cstdlib>
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "BayesNet.h"
|
||||
#include "BPSolver.h"
|
||||
|
||||
#include "FactorGraph.h"
|
||||
#include "SPSolver.h"
|
||||
#include "BPSolver.h"
|
||||
#include "CountingBP.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
void BayesianNetwork (int, const char* []);
|
||||
void markovNetwork (int, const char* []);
|
||||
void runSolver (Solver*, const VarSet&);
|
||||
|
||||
const string USAGE = "usage: \
|
||||
./hcli FILE [VARIABLE | OBSERVED_VARIABLE=EVIDENCE]..." ;
|
||||
@ -19,14 +21,40 @@ const string USAGE = "usage: \
|
||||
|
||||
int
|
||||
main (int argc, const char* argv[])
|
||||
{
|
||||
{
|
||||
/*
|
||||
FactorGraph fg;
|
||||
FgVarNode* varNode1 = new FgVarNode (0, 2);
|
||||
FgVarNode* varNode2 = new FgVarNode (1, 2);
|
||||
FgVarNode* varNode3 = new FgVarNode (2, 2);
|
||||
fg.addVariable (varNode1);
|
||||
fg.addVariable (varNode2);
|
||||
fg.addVariable (varNode3);
|
||||
Distribution* dist = new Distribution (ParamSet() = {1.2, 1.4, 2.0, 0.4});
|
||||
fg.addFactor (new Factor (FgVarSet() = {varNode1, varNode2}, dist));
|
||||
fg.addFactor (new Factor (FgVarSet() = {varNode3, varNode2}, dist));
|
||||
//fg.printGraphicalModel();
|
||||
//SPSolver sp (fg);
|
||||
//sp.runSolver();
|
||||
//sp.printAllPosterioris();
|
||||
//ParamSet p = sp.getJointDistributionOf (VidSet() = {0, 1, 2});
|
||||
//cout << Util::parametersToString (p) << endl;
|
||||
CountingBP cbp (fg);
|
||||
//cbp.runSolver();
|
||||
//cbp.printAllPosterioris();
|
||||
ParamSet p2 = cbp.getJointDistributionOf (VidSet() = {0, 1, 2});
|
||||
cout << Util::parametersToString (p2) << endl;
|
||||
fg.freeDistributions();
|
||||
Statistics::printCompressingStats ("compressing.stats");
|
||||
return 0;
|
||||
*/
|
||||
if (!argv[1]) {
|
||||
cerr << "error: no graphical model specified" << endl;
|
||||
cerr << USAGE << endl;
|
||||
exit (0);
|
||||
}
|
||||
string fileName = argv[1];
|
||||
string extension = fileName.substr (fileName.find_last_of ('.') + 1);
|
||||
const string& fileName = argv[1];
|
||||
const string& extension = fileName.substr (fileName.find_last_of ('.') + 1);
|
||||
if (extension == "xml") {
|
||||
BayesianNetwork (argc, argv);
|
||||
} else if (extension == "uai") {
|
||||
@ -45,13 +73,13 @@ void
|
||||
BayesianNetwork (int argc, const char* argv[])
|
||||
{
|
||||
BayesNet bn (argv[1]);
|
||||
//bn.printNetwork();
|
||||
//bn.printGraphicalModel();
|
||||
|
||||
NodeSet queryVars;
|
||||
VarSet queryVars;
|
||||
for (int i = 2; i < argc; i++) {
|
||||
string arg = argv[i];
|
||||
const string& arg = argv[i];
|
||||
if (arg.find ('=') == std::string::npos) {
|
||||
BayesNode* queryVar = bn.getNode (arg);
|
||||
BayesNode* queryVar = bn.getBayesNode (arg);
|
||||
if (queryVar) {
|
||||
queryVars.push_back (queryVar);
|
||||
} else {
|
||||
@ -61,9 +89,9 @@ BayesianNetwork (int argc, const char* argv[])
|
||||
exit (0);
|
||||
}
|
||||
} else {
|
||||
size_t pos = arg.find ('=');
|
||||
string label = arg.substr (0, pos);
|
||||
string state = arg.substr (pos + 1);
|
||||
size_t pos = arg.find ('=');
|
||||
const string& label = arg.substr (0, pos);
|
||||
const string& state = arg.substr (pos + 1);
|
||||
if (label.empty()) {
|
||||
cerr << "error: missing left argument" << endl;
|
||||
cerr << USAGE << endl;
|
||||
@ -74,7 +102,7 @@ BayesianNetwork (int argc, const char* argv[])
|
||||
cerr << USAGE << endl;
|
||||
exit (0);
|
||||
}
|
||||
BayesNode* node = bn.getNode (label);
|
||||
BayesNode* node = bn.getBayesNode (label);
|
||||
if (node) {
|
||||
if (node->isValidState (state)) {
|
||||
node->setEvidence (state);
|
||||
@ -94,19 +122,16 @@ BayesianNetwork (int argc, const char* argv[])
|
||||
}
|
||||
}
|
||||
|
||||
BPSolver solver (bn);
|
||||
if (queryVars.size() == 0) {
|
||||
solver.runSolver();
|
||||
solver.printAllPosterioris();
|
||||
} else if (queryVars.size() == 1) {
|
||||
solver.runSolver();
|
||||
solver.printPosterioriOf (queryVars[0]);
|
||||
Solver* solver;
|
||||
if (SolverOptions::convertBn2Fg) {
|
||||
FactorGraph* fg = new FactorGraph (bn);
|
||||
fg->printGraphicalModel();
|
||||
solver = new SPSolver (*fg);
|
||||
runSolver (solver, queryVars);
|
||||
delete fg;
|
||||
} else {
|
||||
Domain domain = BayesNet::getInstantiations(queryVars);
|
||||
ParamSet params = solver.getJointDistribution (queryVars);
|
||||
for (unsigned i = 0; i < params.size(); i++) {
|
||||
cout << domain[i] << "\t" << params[i] << endl;
|
||||
}
|
||||
solver = new BPSolver (bn);
|
||||
runSolver (solver, queryVars);
|
||||
}
|
||||
bn.freeDistributions();
|
||||
}
|
||||
@ -117,11 +142,11 @@ void
|
||||
markovNetwork (int argc, const char* argv[])
|
||||
{
|
||||
FactorGraph fg (argv[1]);
|
||||
//fg.printFactorGraph();
|
||||
|
||||
//fg.printGraphicalModel();
|
||||
|
||||
VarSet queryVars;
|
||||
for (int i = 2; i < argc; i++) {
|
||||
string arg = argv[i];
|
||||
const string& arg = argv[i];
|
||||
if (arg.find ('=') == std::string::npos) {
|
||||
if (!Util::isInteger (arg)) {
|
||||
cerr << "error: `" << arg << "' " ;
|
||||
@ -129,16 +154,16 @@ markovNetwork (int argc, const char* argv[])
|
||||
cerr << endl;
|
||||
exit (0);
|
||||
}
|
||||
unsigned varId;
|
||||
Vid vid;
|
||||
stringstream ss;
|
||||
ss << arg;
|
||||
ss >> varId;
|
||||
Variable* queryVar = fg.getVariableById (varId);
|
||||
ss >> vid;
|
||||
Variable* queryVar = fg.getFgVarNode (vid);
|
||||
if (queryVar) {
|
||||
queryVars.push_back (queryVar);
|
||||
} else {
|
||||
cerr << "error: there isn't a variable with " ;
|
||||
cerr << "`" << varId << "' as id" ;
|
||||
cerr << "`" << vid << "' as id" ;
|
||||
cerr << endl;
|
||||
exit (0);
|
||||
}
|
||||
@ -160,11 +185,11 @@ markovNetwork (int argc, const char* argv[])
|
||||
cerr << endl;
|
||||
exit (0);
|
||||
}
|
||||
unsigned varId;
|
||||
Vid vid;
|
||||
stringstream ss;
|
||||
ss << arg.substr (0, pos);
|
||||
ss >> varId;
|
||||
Variable* var = fg.getVariableById (varId);
|
||||
ss >> vid;
|
||||
Variable* var = fg.getFgVarNode (vid);
|
||||
if (var) {
|
||||
if (!Util::isInteger (arg.substr (pos + 1))) {
|
||||
cerr << "error: `" << arg.substr (pos + 1) << "' " ;
|
||||
@ -176,7 +201,6 @@ markovNetwork (int argc, const char* argv[])
|
||||
stringstream ss;
|
||||
ss << arg.substr (pos + 1);
|
||||
ss >> stateIndex;
|
||||
cout << "si: " << stateIndex << endl;
|
||||
if (var->isValidStateIndex (stateIndex)) {
|
||||
var->setEvidence (stateIndex);
|
||||
} else {
|
||||
@ -188,27 +212,35 @@ markovNetwork (int argc, const char* argv[])
|
||||
}
|
||||
} else {
|
||||
cerr << "error: there isn't a variable with " ;
|
||||
cerr << "`" << varId << "' as id" ;
|
||||
cerr << "`" << vid << "' as id" ;
|
||||
cerr << endl;
|
||||
exit (0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
SPSolver solver (fg);
|
||||
if (queryVars.size() == 0) {
|
||||
solver.runSolver();
|
||||
solver.printAllPosterioris();
|
||||
} else if (queryVars.size() == 1) {
|
||||
solver.runSolver();
|
||||
solver.printPosterioriOf (queryVars[0]);
|
||||
} else {
|
||||
assert (false); //FIXME
|
||||
//Domain domain = BayesNet::getInstantiations(queryVars);
|
||||
//ParamSet params = solver.getJointDistribution (queryVars);
|
||||
//for (unsigned i = 0; i < params.size(); i++) {
|
||||
// cout << domain[i] << "\t" << params[i] << endl;
|
||||
//}
|
||||
}
|
||||
Solver* solver = new SPSolver (fg);
|
||||
runSolver (solver, queryVars);
|
||||
fg.freeDistributions();
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
runSolver (Solver* solver, const VarSet& queryVars)
|
||||
{
|
||||
VidSet vids;
|
||||
for (unsigned i = 0; i < queryVars.size(); i++) {
|
||||
vids.push_back (queryVars[i]->getVarId());
|
||||
}
|
||||
if (queryVars.size() == 0) {
|
||||
solver->runSolver();
|
||||
solver->printAllPosterioris();
|
||||
} else if (queryVars.size() == 1) {
|
||||
solver->runSolver();
|
||||
solver->printPosterioriOf (vids[0]);
|
||||
} else {
|
||||
solver->printJointDistributionOf (vids);
|
||||
}
|
||||
delete solver;
|
||||
}
|
||||
|
||||
|
@ -1,41 +1,39 @@
|
||||
#include <cstdlib>
|
||||
#include <vector>
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
#include <YapInterface.h>
|
||||
|
||||
#include "callgrind.h"
|
||||
|
||||
#include "BayesNet.h"
|
||||
#include "BayesNode.h"
|
||||
#include "FactorGraph.h"
|
||||
#include "BPSolver.h"
|
||||
#include "SPSolver.h"
|
||||
#include "CountingBP.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
|
||||
int
|
||||
createNetwork (void)
|
||||
{
|
||||
Statistics::numCreatedNets ++;
|
||||
cout << "creating network number " << Statistics::numCreatedNets << endl;
|
||||
if (Statistics::numCreatedNets == 1) {
|
||||
//CALLGRIND_START_INSTRUMENTATION;
|
||||
}
|
||||
BayesNet* bn = new BayesNet();
|
||||
//Statistics::numCreatedNets ++;
|
||||
//cout << "creating network number " << Statistics::numCreatedNets << endl;
|
||||
|
||||
BayesNet* bn = new BayesNet();
|
||||
YAP_Term varList = YAP_ARG1;
|
||||
while (varList != YAP_TermNil()) {
|
||||
YAP_Term var = YAP_HeadOfTerm (varList);
|
||||
unsigned varId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (1, var));
|
||||
Vid vid = (Vid) YAP_IntOfTerm (YAP_ArgOfTerm (1, var));
|
||||
unsigned dsize = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (2, var));
|
||||
int evidence = (int) YAP_IntOfTerm (YAP_ArgOfTerm (3, var));
|
||||
YAP_Term parentL = YAP_ArgOfTerm (4, var);
|
||||
int evidence = (int) YAP_IntOfTerm (YAP_ArgOfTerm (3, var));
|
||||
YAP_Term parentL = YAP_ArgOfTerm (4, var);
|
||||
unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (5, var));
|
||||
NodeSet parents;
|
||||
BnNodeSet parents;
|
||||
while (parentL != YAP_TermNil()) {
|
||||
unsigned parentId = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (parentL));
|
||||
BayesNode* parent = bn->getNode (parentId);
|
||||
BayesNode* parent = bn->getBayesNode (parentId);
|
||||
if (!parent) {
|
||||
parent = bn->addNode (parentId);
|
||||
}
|
||||
@ -47,23 +45,20 @@ createNetwork (void)
|
||||
dist = new Distribution (distId);
|
||||
bn->addDistribution (dist);
|
||||
}
|
||||
BayesNode* node = bn->getNode (varId);
|
||||
BayesNode* node = bn->getBayesNode (vid);
|
||||
if (node) {
|
||||
node->setData (dsize, evidence, parents, dist);
|
||||
} else {
|
||||
bn->addNode (varId, dsize, evidence, parents, dist);
|
||||
bn->addNode (vid, dsize, evidence, parents, dist);
|
||||
}
|
||||
varList = YAP_TailOfTerm (varList);
|
||||
}
|
||||
bn->setIndexes();
|
||||
|
||||
if (Statistics::numCreatedNets == 1688) {
|
||||
Statistics::writeStats();
|
||||
//Statistics::writeStats();
|
||||
//CALLGRIND_STOP_INSTRUMENTATION;
|
||||
//CALLGRIND_DUMP_STATS;
|
||||
//exit (0);
|
||||
}
|
||||
// if (Statistics::numCreatedNets == 1688) {
|
||||
// Statistics::writeStats();
|
||||
// exit (0);
|
||||
// }
|
||||
YAP_Int p = (YAP_Int) (bn);
|
||||
return YAP_Unify (YAP_MkIntTerm (p), YAP_ARG2);
|
||||
}
|
||||
@ -73,20 +68,20 @@ createNetwork (void)
|
||||
int
|
||||
setExtraVarsInfo (void)
|
||||
{
|
||||
BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
|
||||
YAP_Term varsInfoL = YAP_ARG2;
|
||||
BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
|
||||
YAP_Term varsInfoL = YAP_ARG2;
|
||||
while (varsInfoL != YAP_TermNil()) {
|
||||
YAP_Term head = YAP_HeadOfTerm (varsInfoL);
|
||||
unsigned varId = YAP_IntOfTerm (YAP_ArgOfTerm (1, head));
|
||||
Vid vid = YAP_IntOfTerm (YAP_ArgOfTerm (1, head));
|
||||
YAP_Atom label = YAP_AtomOfTerm (YAP_ArgOfTerm (2, head));
|
||||
YAP_Term domainL = YAP_ArgOfTerm (3, head);
|
||||
YAP_Term domainL = YAP_ArgOfTerm (3, head);
|
||||
Domain domain;
|
||||
while (domainL != YAP_TermNil()) {
|
||||
YAP_Atom atom = YAP_AtomOfTerm (YAP_HeadOfTerm (domainL));
|
||||
domain.push_back ((char*) YAP_AtomName (atom));
|
||||
domainL = YAP_TailOfTerm (domainL);
|
||||
}
|
||||
BayesNode* node = bn->getNode (varId);
|
||||
BayesNode* node = bn->getBayesNode (vid);
|
||||
assert (node);
|
||||
node->setLabel ((char*) YAP_AtomName (label));
|
||||
node->setDomain (domain);
|
||||
@ -100,8 +95,8 @@ setExtraVarsInfo (void)
|
||||
int
|
||||
setParameters (void)
|
||||
{
|
||||
BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
|
||||
YAP_Term distList = YAP_ARG2;
|
||||
BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
|
||||
YAP_Term distList = YAP_ARG2;
|
||||
while (distList != YAP_TermNil()) {
|
||||
YAP_Term dist = YAP_HeadOfTerm (distList);
|
||||
unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (1, dist));
|
||||
@ -112,6 +107,11 @@ setParameters (void)
|
||||
paramL = YAP_TailOfTerm (paramL);
|
||||
}
|
||||
bn->getDistribution(distId)->updateParameters(params);
|
||||
if (Statistics::numCreatedNets == 4) {
|
||||
cout << "dist " << distId << " parameters:" ;
|
||||
cout << Util::parametersToString (params);
|
||||
cout << endl;
|
||||
}
|
||||
distList = YAP_TailOfTerm (distList);
|
||||
}
|
||||
return TRUE;
|
||||
@ -122,84 +122,126 @@ setParameters (void)
|
||||
int
|
||||
runSolver (void)
|
||||
{
|
||||
BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
|
||||
YAP_Term taskList = YAP_ARG2;
|
||||
|
||||
vector<NodeSet> tasks;
|
||||
NodeSet marginalVars;
|
||||
BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
|
||||
YAP_Term taskList = YAP_ARG2;
|
||||
vector<VidSet> tasks;
|
||||
VidSet marginalVids;
|
||||
|
||||
while (taskList != YAP_TermNil()) {
|
||||
if (YAP_IsPairTerm (YAP_HeadOfTerm (taskList))) {
|
||||
NodeSet jointVars;
|
||||
VidSet jointVids;
|
||||
YAP_Term jointList = YAP_HeadOfTerm (taskList);
|
||||
while (jointList != YAP_TermNil()) {
|
||||
unsigned varId = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (jointList));
|
||||
assert (bn->getNode (varId));
|
||||
jointVars.push_back (bn->getNode (varId));
|
||||
Vid vid = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (jointList));
|
||||
assert (bn->getBayesNode (vid));
|
||||
jointVids.push_back (vid);
|
||||
jointList = YAP_TailOfTerm (jointList);
|
||||
}
|
||||
tasks.push_back (jointVars);
|
||||
tasks.push_back (jointVids);
|
||||
} else {
|
||||
unsigned varId = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (taskList));
|
||||
BayesNode* node = bn->getNode (varId);
|
||||
assert (node);
|
||||
tasks.push_back (NodeSet() = {node});
|
||||
marginalVars.push_back (node);
|
||||
Vid vid = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (taskList));
|
||||
assert (bn->getBayesNode (vid));
|
||||
tasks.push_back (VidSet() = {vid});
|
||||
marginalVids.push_back (vid);
|
||||
}
|
||||
taskList = YAP_TailOfTerm (taskList);
|
||||
}
|
||||
/*
|
||||
cout << "tasks to resolve:" << endl;
|
||||
for (unsigned i = 0; i < tasks.size(); i++) {
|
||||
cout << "i" << ": " ;
|
||||
if (tasks[i].size() == 1) {
|
||||
cout << tasks[i][0]->getVarId() << endl;
|
||||
} else {
|
||||
for (unsigned j = 0; j < tasks[i].size(); j++) {
|
||||
cout << tasks[i][j]->getVarId() << " " ;
|
||||
}
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
||||
cerr << "prunning now..." << endl;
|
||||
BayesNet* prunedNet = bn->pruneNetwork (marginalVars);
|
||||
bn->printNetworkToFile ("net.txt");
|
||||
BPSolver solver (*prunedNet);
|
||||
cerr << "solving marginals now..." << endl;
|
||||
solver.runSolver();
|
||||
cerr << "calculating joints now ..." << endl;
|
||||
// cout << "inference tasks:" << endl;
|
||||
// for (unsigned i = 0; i < tasks.size(); i++) {
|
||||
// cout << "i" << ": " ;
|
||||
// if (tasks[i].size() == 1) {
|
||||
// cout << tasks[i][0] << endl;
|
||||
// } else {
|
||||
// for (unsigned j = 0; j < tasks[i].size(); j++) {
|
||||
// cout << tasks[i][j] << " " ;
|
||||
// }
|
||||
// cout << endl;
|
||||
// }
|
||||
// }
|
||||
|
||||
Solver* solver = 0;
|
||||
GraphicalModel* gm = 0;
|
||||
VidSet vids;
|
||||
const BnNodeSet& nodes = bn->getBayesNodes();
|
||||
for (unsigned i = 0; i < nodes.size(); i++) {
|
||||
vids.push_back (nodes[i]->getVarId());
|
||||
}
|
||||
if (marginalVids.size() != 0) {
|
||||
bn->exportToDotFormat ("bn unbayes.dot");
|
||||
BayesNet* mrn = bn->getMinimalRequesiteNetwork (marginalVids);
|
||||
mrn->exportToDotFormat ("bn bayes.dot");
|
||||
//BayesNet* mrn = bn->getMinimalRequesiteNetwork (vids);
|
||||
if (SolverOptions::convertBn2Fg) {
|
||||
gm = new FactorGraph (*mrn);
|
||||
if (SolverOptions::compressFactorGraph) {
|
||||
solver = new CountingBP (*static_cast<FactorGraph*> (gm));
|
||||
} else {
|
||||
solver = new SPSolver (*static_cast<FactorGraph*> (gm));
|
||||
}
|
||||
if (SolverOptions::runBayesBall) {
|
||||
delete mrn;
|
||||
}
|
||||
} else {
|
||||
gm = mrn;
|
||||
solver = new BPSolver (*static_cast<BayesNet*> (gm));
|
||||
}
|
||||
solver->runSolver();
|
||||
}
|
||||
|
||||
vector<ParamSet> results;
|
||||
results.reserve (tasks.size());
|
||||
for (unsigned i = 0; i < tasks.size(); i++) {
|
||||
if (tasks[i].size() == 1) {
|
||||
BayesNode* node = prunedNet->getNode (tasks[i][0]->getVarId());
|
||||
results.push_back (solver.getPosterioriOf (node));
|
||||
results.push_back (solver->getPosterioriOf (tasks[i][0]));
|
||||
} else {
|
||||
BPSolver solver2 (*bn);
|
||||
cout << "calculating an join dist on: " ;
|
||||
for (unsigned j = 0; j < tasks[i].size(); j++) {
|
||||
cout << tasks[i][j]->getVarId() << " " ;
|
||||
static int count = 0;
|
||||
cout << "calculating joint... " << count ++ << endl;
|
||||
//if (count == 5225) {
|
||||
// Statistics::printCompressingStats ("compressing.stats");
|
||||
//}
|
||||
Solver* solver2 = 0;
|
||||
GraphicalModel* gm2 = 0;
|
||||
bn->exportToDotFormat ("joint.dot");
|
||||
BayesNet* mrn2;
|
||||
if (SolverOptions::runBayesBall) {
|
||||
mrn2 = bn->getMinimalRequesiteNetwork (tasks[i]);
|
||||
} else {
|
||||
mrn2 = bn;
|
||||
}
|
||||
cout << "..." << endl;
|
||||
results.push_back (solver2.getJointDistribution (tasks[i]));
|
||||
if (SolverOptions::convertBn2Fg) {
|
||||
gm2 = new FactorGraph (*mrn2);
|
||||
if (SolverOptions::compressFactorGraph) {
|
||||
solver2 = new CountingBP (*static_cast<FactorGraph*> (gm2));
|
||||
} else {
|
||||
solver2 = new SPSolver (*static_cast<FactorGraph*> (gm2));
|
||||
}
|
||||
if (SolverOptions::runBayesBall) {
|
||||
delete mrn2;
|
||||
}
|
||||
} else {
|
||||
gm2 = mrn2;
|
||||
solver2 = new BPSolver (*static_cast<BayesNet*> (gm2));
|
||||
}
|
||||
results.push_back (solver2->getJointDistributionOf (tasks[i]));
|
||||
delete solver2;
|
||||
delete gm2;
|
||||
}
|
||||
}
|
||||
|
||||
delete prunedNet;
|
||||
delete solver;
|
||||
delete gm;
|
||||
|
||||
YAP_Term list = YAP_TermNil();
|
||||
for (int i = results.size() - 1; i >= 0; i--) {
|
||||
const ParamSet& beliefs = results[i];
|
||||
YAP_Term queryBeliefsL = YAP_TermNil();
|
||||
for (int j = beliefs.size() - 1; j >= 0; j--) {
|
||||
YAP_Int sl1 = YAP_InitSlot(list);
|
||||
YAP_Term belief = YAP_MkFloatTerm (beliefs[j]);
|
||||
queryBeliefsL = YAP_MkPairTerm (belief, queryBeliefsL);
|
||||
list = YAP_GetFromSlot(sl1);
|
||||
YAP_RecoverSlots(1);
|
||||
YAP_Int sl1 = YAP_InitSlot (list);
|
||||
YAP_Term belief = YAP_MkFloatTerm (beliefs[j]);
|
||||
queryBeliefsL = YAP_MkPairTerm (belief, queryBeliefsL);
|
||||
list = YAP_GetFromSlot (sl1);
|
||||
YAP_RecoverSlots (1);
|
||||
}
|
||||
list = YAP_MkPairTerm (queryBeliefsL, list);
|
||||
}
|
||||
@ -210,8 +252,9 @@ runSolver (void)
|
||||
|
||||
|
||||
int
|
||||
deleteBayesNet (void)
|
||||
freeBayesNetwork (void)
|
||||
{
|
||||
//Statistics::printCompressingStats ("../../compressing.stats");
|
||||
BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
|
||||
bn->freeDistributions();
|
||||
delete bn;
|
||||
@ -223,10 +266,10 @@ deleteBayesNet (void)
|
||||
extern "C" void
|
||||
init_predicates (void)
|
||||
{
|
||||
YAP_UserCPredicate ("create_network", createNetwork, 2);
|
||||
YAP_UserCPredicate ("set_extra_vars_info", setExtraVarsInfo, 2);
|
||||
YAP_UserCPredicate ("set_parameters", setParameters, 2);
|
||||
YAP_UserCPredicate ("run_solver", runSolver, 3);
|
||||
YAP_UserCPredicate ("delete_bayes_net", deleteBayesNet, 1);
|
||||
YAP_UserCPredicate ("create_network", createNetwork, 2);
|
||||
YAP_UserCPredicate ("set_extra_vars_info", setExtraVarsInfo, 2);
|
||||
YAP_UserCPredicate ("set_parameters", setParameters, 2);
|
||||
YAP_UserCPredicate ("run_solver", runSolver, 3);
|
||||
YAP_UserCPredicate ("free_bayesian_network", freeBayesNetwork, 1);
|
||||
}
|
||||
|
||||
|
278
packages/CLPBN/clpbn/bp/LiftedFG.cpp
Normal file
278
packages/CLPBN/clpbn/bp/LiftedFG.cpp
Normal file
@ -0,0 +1,278 @@
|
||||
|
||||
#include "LiftedFG.h"
|
||||
#include "FgVarNode.h"
|
||||
#include "Factor.h"
|
||||
#include "Distribution.h"
|
||||
|
||||
LiftedFG::LiftedFG (const FactorGraph& fg)
|
||||
{
|
||||
groundFg_ = &fg;
|
||||
freeColor_ = 0;
|
||||
|
||||
const FgVarSet& varNodes = fg.getFgVarNodes();
|
||||
const FactorSet& factors = fg.getFactors();
|
||||
varColors_.resize (varNodes.size());
|
||||
factorColors_.resize (factors.size());
|
||||
for (unsigned i = 0; i < factors.size(); i++) {
|
||||
factors[i]->setIndex (i);
|
||||
}
|
||||
|
||||
// create the initial variable colors
|
||||
VarColorMap colorMap;
|
||||
for (unsigned i = 0; i < varNodes.size(); i++) {
|
||||
unsigned dsize = varNodes[i]->getDomainSize();
|
||||
VarColorMap::iterator it = colorMap.find (dsize);
|
||||
if (it == colorMap.end()) {
|
||||
it = colorMap.insert (make_pair (
|
||||
dsize, vector<Color> (dsize + 1,-1))).first;
|
||||
}
|
||||
unsigned idx;
|
||||
if (varNodes[i]->hasEvidence()) {
|
||||
idx = varNodes[i]->getEvidence();
|
||||
} else {
|
||||
idx = dsize;
|
||||
}
|
||||
vector<Color>& stateColors = it->second;
|
||||
if (stateColors[idx] == -1) {
|
||||
stateColors[idx] = getFreeColor();
|
||||
}
|
||||
setColor (varNodes[i], stateColors[idx]);
|
||||
}
|
||||
|
||||
// create the initial factor colors
|
||||
DistColorMap distColors;
|
||||
for (unsigned i = 0; i < factors.size(); i++) {
|
||||
Distribution* dist = factors[i]->getDistribution();
|
||||
DistColorMap::iterator it = distColors.find (dist);
|
||||
if (it == distColors.end()) {
|
||||
it = distColors.insert (make_pair (dist, getFreeColor())).first;
|
||||
}
|
||||
setColor (factors[i], it->second);
|
||||
}
|
||||
|
||||
VarSignMap varGroups;
|
||||
FactorSignMap factorGroups;
|
||||
bool groupsHaveChanged = true;
|
||||
unsigned nIter = 0;
|
||||
while (groupsHaveChanged || nIter == 1) {
|
||||
nIter ++;
|
||||
if (Statistics::numCreatedNets == 4) {
|
||||
cout << "--------------------------------------------" << endl;
|
||||
cout << "Iteration " << nIter << endl;
|
||||
cout << "--------------------------------------------" << endl;
|
||||
}
|
||||
|
||||
unsigned prevFactorGroupsSize = factorGroups.size();
|
||||
factorGroups.clear();
|
||||
// set a new color to the factors with the same signature
|
||||
for (unsigned i = 0; i < factors.size(); i++) {
|
||||
const string& signatureId = getSignatureId (factors[i]);
|
||||
// cout << factors[i]->getLabel() << " signature: " ;
|
||||
// cout<< signatureId << endl;
|
||||
FactorSignMap::iterator it = factorGroups.find (signatureId);
|
||||
if (it == factorGroups.end()) {
|
||||
it = factorGroups.insert (make_pair (signatureId, FactorSet())).first;
|
||||
}
|
||||
it->second.push_back (factors[i]);
|
||||
}
|
||||
if (nIter > 0)
|
||||
for (FactorSignMap::iterator it = factorGroups.begin();
|
||||
it != factorGroups.end(); it++) {
|
||||
Color newColor = getFreeColor();
|
||||
FactorSet& groupMembers = it->second;
|
||||
for (unsigned i = 0; i < groupMembers.size(); i++) {
|
||||
setColor (groupMembers[i], newColor);
|
||||
}
|
||||
}
|
||||
|
||||
// set a new color to the variables with the same signature
|
||||
unsigned prevVarGroupsSize = varGroups.size();
|
||||
varGroups.clear();
|
||||
for (unsigned i = 0; i < varNodes.size(); i++) {
|
||||
const string& signatureId = getSignatureId (varNodes[i]);
|
||||
VarSignMap::iterator it = varGroups.find (signatureId);
|
||||
// cout << varNodes[i]->getLabel() << " signature: " ;
|
||||
// cout << signatureId << endl;
|
||||
if (it == varGroups.end()) {
|
||||
it = varGroups.insert (make_pair (signatureId, FgVarSet())).first;
|
||||
}
|
||||
it->second.push_back (varNodes[i]);
|
||||
}
|
||||
if (nIter > 0)
|
||||
for (VarSignMap::iterator it = varGroups.begin();
|
||||
it != varGroups.end(); it++) {
|
||||
Color newColor = getFreeColor();
|
||||
FgVarSet& groupMembers = it->second;
|
||||
for (unsigned i = 0; i < groupMembers.size(); i++) {
|
||||
setColor (groupMembers[i], newColor);
|
||||
}
|
||||
}
|
||||
|
||||
//if (nIter >= 3) cout << "bigger than three: " << nIter << endl;
|
||||
groupsHaveChanged = prevVarGroupsSize != varGroups.size()
|
||||
|| prevFactorGroupsSize != factorGroups.size();
|
||||
}
|
||||
|
||||
printGroups (varGroups, factorGroups);
|
||||
for (VarSignMap::iterator it = varGroups.begin();
|
||||
it != varGroups.end(); it++) {
|
||||
CFgVarSet vars = it->second;
|
||||
VarCluster* vc = new VarCluster (vars);
|
||||
for (unsigned i = 0; i < vars.size(); i++) {
|
||||
vid2VarCluster_.insert (make_pair (vars[i]->getVarId(), vc));
|
||||
}
|
||||
varClusters_.push_back (vc);
|
||||
}
|
||||
|
||||
for (FactorSignMap::iterator it = factorGroups.begin();
|
||||
it != factorGroups.end(); it++) {
|
||||
VarClusterSet varClusters;
|
||||
Factor* groundFactor = it->second[0];
|
||||
FgVarSet groundVars = groundFactor->getFgVarNodes();
|
||||
for (unsigned i = 0; i < groundVars.size(); i++) {
|
||||
Vid vid = groundVars[i]->getVarId();
|
||||
varClusters.push_back (vid2VarCluster_.find (vid)->second);
|
||||
}
|
||||
factorClusters_.push_back (new FactorCluster (it->second, varClusters));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
LiftedFG::~LiftedFG (void)
|
||||
{
|
||||
for (unsigned i = 0; i < varClusters_.size(); i++) {
|
||||
delete varClusters_[i];
|
||||
}
|
||||
for (unsigned i = 0; i < factorClusters_.size(); i++) {
|
||||
delete factorClusters_[i];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
string
|
||||
LiftedFG::getSignatureId (FgVarNode* var) const
|
||||
{
|
||||
stringstream ss;
|
||||
CFactorSet myFactors = var->getFactors();
|
||||
ss << myFactors.size();
|
||||
for (unsigned i = 0; i < myFactors.size(); i++) {
|
||||
ss << "." << getColor (myFactors[i]);
|
||||
ss << "." << myFactors[i]->getIndexOf(var);
|
||||
}
|
||||
ss << "." << getColor (var);
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
|
||||
|
||||
string
|
||||
LiftedFG::getSignatureId (Factor* factor) const
|
||||
{
|
||||
stringstream ss;
|
||||
CFgVarSet myVars = factor->getFgVarNodes();
|
||||
ss << myVars.size();
|
||||
for (unsigned i = 0; i < myVars.size(); i++) {
|
||||
ss << "." << getColor (myVars[i]);
|
||||
}
|
||||
ss << "." << getColor (factor);
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
|
||||
|
||||
FactorGraph*
|
||||
LiftedFG::getCompressedFactorGraph (void)
|
||||
{
|
||||
FactorGraph* fg = new FactorGraph();
|
||||
for (unsigned i = 0; i < varClusters_.size(); i++) {
|
||||
FgVarNode* var = varClusters_[i]->getGroundFgVarNodes()[0];
|
||||
FgVarNode* newVar = new FgVarNode (var);
|
||||
newVar->setIndex (i);
|
||||
varClusters_[i]->setRepresentativeVariable (newVar);
|
||||
fg->addVariable (newVar);
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < factorClusters_.size(); i++) {
|
||||
FgVarSet myGroundVars;
|
||||
const VarClusterSet& myVarClusters = factorClusters_[i]->getVarClusters();
|
||||
for (unsigned j = 0; j < myVarClusters.size(); j++) {
|
||||
myGroundVars.push_back (myVarClusters[j]->getRepresentativeVariable());
|
||||
}
|
||||
Factor* newFactor = new Factor (myGroundVars,
|
||||
factorClusters_[i]->getGroundFactors()[0]->getDistribution());
|
||||
factorClusters_[i]->setRepresentativeFactor (newFactor);
|
||||
fg->addFactor (newFactor);
|
||||
}
|
||||
return fg;
|
||||
}
|
||||
|
||||
|
||||
|
||||
unsigned
|
||||
LiftedFG::getGroundEdgeCount (FactorCluster* fc, VarCluster* vc) const
|
||||
{
|
||||
CFactorSet clusterGroundFactors = fc->getGroundFactors();
|
||||
FgVarNode* var = vc->getGroundFgVarNodes()[0];
|
||||
unsigned count = 0;
|
||||
for (unsigned i = 0; i < clusterGroundFactors.size(); i++) {
|
||||
if (clusterGroundFactors[i]->getIndexOf (var) != -1) {
|
||||
count ++;
|
||||
}
|
||||
}
|
||||
/*
|
||||
CFgVarSet vars = vc->getGroundFgVarNodes();
|
||||
for (unsigned i = 1; i < vars.size(); i++) {
|
||||
FgVarNode* var = vc->getGroundFgVarNodes()[i];
|
||||
unsigned count2 = 0;
|
||||
for (unsigned i = 0; i < clusterGroundFactors.size(); i++) {
|
||||
if (clusterGroundFactors[i]->getIndexOf (var) != -1) {
|
||||
count2 ++;
|
||||
}
|
||||
}
|
||||
if (count != count2) { cout << "oops!" << endl; abort(); }
|
||||
}
|
||||
*/
|
||||
return count;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
LiftedFG::printGroups (const VarSignMap& varGroups,
|
||||
const FactorSignMap& factorGroups) const
|
||||
{
|
||||
cout << "variable groups:" << endl;
|
||||
unsigned count = 0;
|
||||
for (VarSignMap::const_iterator it = varGroups.begin();
|
||||
it != varGroups.end(); it++) {
|
||||
const FgVarSet& groupMembers = it->second;
|
||||
if (groupMembers.size() > 0) {
|
||||
cout << ++count << ": " ;
|
||||
//if (groupMembers.size() > 1) {
|
||||
for (unsigned i = 0; i < groupMembers.size(); i++) {
|
||||
cout << groupMembers[i]->getLabel() << " " ;
|
||||
}
|
||||
//}
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
cout << endl;
|
||||
cout << "factor groups:" << endl;
|
||||
count = 0;
|
||||
for (FactorSignMap::const_iterator it = factorGroups.begin();
|
||||
it != factorGroups.end(); it++) {
|
||||
const FactorSet& groupMembers = it->second;
|
||||
if (groupMembers.size() > 0) {
|
||||
cout << ++count << ": " ;
|
||||
//if (groupMembers.size() > 1) {
|
||||
for (unsigned i = 0; i < groupMembers.size(); i++) {
|
||||
cout << groupMembers[i]->getLabel() << " " ;
|
||||
}
|
||||
//}
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
152
packages/CLPBN/clpbn/bp/LiftedFG.h
Normal file
152
packages/CLPBN/clpbn/bp/LiftedFG.h
Normal file
@ -0,0 +1,152 @@
|
||||
#ifndef BP_LIFTED_FG_H
|
||||
#define BP_LIFTED_FG_H
|
||||
|
||||
#include <unordered_map>
|
||||
|
||||
#include "FactorGraph.h"
|
||||
#include "FgVarNode.h"
|
||||
#include "Factor.h"
|
||||
#include "Shared.h"
|
||||
|
||||
class VarCluster;
|
||||
class FactorCluster;
|
||||
class Distribution;
|
||||
|
||||
typedef long Color;
|
||||
typedef vector<Color> Signature;
|
||||
typedef vector<VarCluster*> VarClusterSet;
|
||||
typedef vector<FactorCluster*> FactorClusterSet;
|
||||
|
||||
typedef map<string, FgVarSet> VarSignMap;
|
||||
typedef map<string, FactorSet> FactorSignMap;
|
||||
|
||||
typedef map<unsigned, vector<Color> > VarColorMap;
|
||||
typedef map<Distribution*, Color> DistColorMap;
|
||||
|
||||
typedef map<Vid, VarCluster*> Vid2VarCluster;
|
||||
|
||||
|
||||
class VarCluster
|
||||
{
|
||||
public:
|
||||
VarCluster (CFgVarSet vs)
|
||||
{
|
||||
for (unsigned i = 0; i < vs.size(); i++) {
|
||||
groundVars_.push_back (vs[i]);
|
||||
}
|
||||
}
|
||||
|
||||
void addFactorCluster (FactorCluster* fc)
|
||||
{
|
||||
factorClusters_.push_back (fc);
|
||||
}
|
||||
|
||||
const FactorClusterSet& getFactorClusters (void) const
|
||||
{
|
||||
return factorClusters_;
|
||||
}
|
||||
|
||||
FgVarNode* getRepresentativeVariable (void) const { return representVar_; }
|
||||
void setRepresentativeVariable (FgVarNode* v) { representVar_ = v; }
|
||||
CFgVarSet getGroundFgVarNodes (void) const { return groundVars_; }
|
||||
|
||||
private:
|
||||
FgVarSet groundVars_;
|
||||
FactorClusterSet factorClusters_;
|
||||
FgVarNode* representVar_;
|
||||
};
|
||||
|
||||
|
||||
class FactorCluster
|
||||
{
|
||||
public:
|
||||
FactorCluster (CFactorSet groundFactors, const VarClusterSet& vcs)
|
||||
{
|
||||
groundFactors_ = groundFactors;
|
||||
varClusters_ = vcs;
|
||||
for (unsigned i = 0; i < varClusters_.size(); i++) {
|
||||
varClusters_[i]->addFactorCluster (this);
|
||||
}
|
||||
}
|
||||
|
||||
const VarClusterSet& getVarClusters (void) const
|
||||
{
|
||||
return varClusters_;
|
||||
}
|
||||
|
||||
bool containsGround (const Factor* f)
|
||||
{
|
||||
for (unsigned i = 0; i < groundFactors_.size(); i++) {
|
||||
if (groundFactors_[i] == f) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
Factor* getRepresentativeFactor (void) const { return representFactor_; }
|
||||
void setRepresentativeFactor (Factor* f) { representFactor_ = f; }
|
||||
CFactorSet getGroundFactors (void) const { return groundFactors_; }
|
||||
|
||||
|
||||
private:
|
||||
FactorSet groundFactors_;
|
||||
VarClusterSet varClusters_;
|
||||
Factor* representFactor_;
|
||||
};
|
||||
|
||||
|
||||
class LiftedFG
|
||||
{
|
||||
public:
|
||||
LiftedFG (const FactorGraph&);
|
||||
~LiftedFG (void);
|
||||
|
||||
FactorGraph* getCompressedFactorGraph (void);
|
||||
unsigned getGroundEdgeCount (FactorCluster*, VarCluster*) const;
|
||||
void printGroups (const VarSignMap& varGroups,
|
||||
const FactorSignMap& factorGroups) const;
|
||||
|
||||
FgVarNode* getEquivalentVariable (Vid vid)
|
||||
{
|
||||
VarCluster* vc = vid2VarCluster_.find (vid)->second;
|
||||
return vc->getRepresentativeVariable();
|
||||
}
|
||||
|
||||
const VarClusterSet& getVariableClusters (void) { return varClusters_; }
|
||||
const FactorClusterSet& getFactorClusters (void) { return factorClusters_; }
|
||||
|
||||
private:
|
||||
string getSignatureId (FgVarNode*) const;
|
||||
string getSignatureId (Factor*) const;
|
||||
|
||||
Color getFreeColor (void) { return ++freeColor_ -1; }
|
||||
Color getColor (FgVarNode* v) const { return varColors_[v->getIndex()]; }
|
||||
Color getColor (Factor* f) const { return factorColors_[f->getIndex()]; }
|
||||
|
||||
void setColor (FgVarNode* v, Color c)
|
||||
{
|
||||
varColors_[v->getIndex()] = c;
|
||||
}
|
||||
|
||||
void setColor (Factor* f, Color c)
|
||||
{
|
||||
factorColors_[f->getIndex()] = c;
|
||||
}
|
||||
|
||||
VarCluster* getVariableCluster (Vid vid) const
|
||||
{
|
||||
return vid2VarCluster_.find (vid)->second;
|
||||
}
|
||||
|
||||
Color freeColor_;
|
||||
vector<Color> varColors_;
|
||||
vector<Color> factorColors_;
|
||||
VarClusterSet varClusters_;
|
||||
FactorClusterSet factorClusters_;
|
||||
Vid2VarCluster vid2VarCluster_;
|
||||
const FactorGraph* groundFg_;
|
||||
};
|
||||
|
||||
#endif // BP_LIFTED_FG_H
|
||||
|
@ -50,28 +50,33 @@ CWD=$(PWD)
|
||||
HEADERS = \
|
||||
$(srcdir)/GraphicalModel.h \
|
||||
$(srcdir)/Variable.h \
|
||||
$(srcdir)/Distribution.h \
|
||||
$(srcdir)/BayesNet.h \
|
||||
$(srcdir)/BayesNode.h \
|
||||
$(srcdir)/Distribution.h \
|
||||
$(srcdir)/LiftedFG.h \
|
||||
$(srcdir)/CptEntry.h \
|
||||
$(srcdir)/FactorGraph.h \
|
||||
$(srcdir)/FgVarNode.h \
|
||||
$(srcdir)/Factor.h \
|
||||
$(srcdir)/Solver.h \
|
||||
$(srcdir)/BPSolver.h \
|
||||
$(srcdir)/BpNode.h \
|
||||
$(srcdir)/BPNodeInfo.h \
|
||||
$(srcdir)/SPSolver.h \
|
||||
$(srcdir)/CountingBP.h \
|
||||
$(srcdir)/Shared.h \
|
||||
$(srcdir)/xmlParser/xmlParser.h
|
||||
|
||||
|
||||
CPP_SOURCES = \
|
||||
$(srcdir)/BayesNet.cpp \
|
||||
$(srcdir)/BayesNode.cpp \
|
||||
$(srcdir)/FactorGraph.cpp \
|
||||
$(srcdir)/Factor.cpp \
|
||||
$(srcdir)/LiftedFG.cpp \
|
||||
$(srcdir)/BPSolver.cpp \
|
||||
$(srcdir)/BpNode.cpp \
|
||||
$(srcdir)/BPNodeInfo.cpp \
|
||||
$(srcdir)/SPSolver.cpp \
|
||||
$(srcdir)/CountingBP.cpp \
|
||||
$(srcdir)/Util.cpp \
|
||||
$(srcdir)/HorusYap.cpp \
|
||||
$(srcdir)/HorusCli.cpp \
|
||||
$(srcdir)/xmlParser/xmlParser.cpp
|
||||
@ -82,22 +87,38 @@ OBJS = \
|
||||
FactorGraph.o \
|
||||
Factor.o \
|
||||
BPSolver.o \
|
||||
BpNode.o \
|
||||
BPNodeInfo.o \
|
||||
SPSolver.o \
|
||||
HorusYap.o \
|
||||
xmlParser.o
|
||||
Util.o \
|
||||
LiftedFG.o \
|
||||
CountingBP.o \
|
||||
HorusYap.o
|
||||
|
||||
HCLI_OBJS = \
|
||||
BayesNet.o \
|
||||
BayesNode.o \
|
||||
FactorGraph.o \
|
||||
Factor.o \
|
||||
BPSolver.o \
|
||||
BPNodeInfo.o \
|
||||
SPSolver.o \
|
||||
Util.o \
|
||||
LiftedFG.o \
|
||||
CountingBP.o \
|
||||
HorusCli.o \
|
||||
xmlParser/xmlParser.o
|
||||
|
||||
SOBJS=horus.@SO@
|
||||
|
||||
|
||||
all: $(SOBJS)
|
||||
all: $(SOBJS) hcli
|
||||
|
||||
# default rule
|
||||
%.o : $(srcdir)/%.cpp
|
||||
$(CXX) -c $(CXXFLAGS) $< -o $@
|
||||
|
||||
|
||||
xmlParser.o : $(srcdir)/xmlParser/xmlParser.cpp
|
||||
xmlParser/xmlParser.o : $(srcdir)/xmlParser/xmlParser.cpp
|
||||
$(CXX) -c $(CXXFLAGS) $< -o $@
|
||||
|
||||
|
||||
@ -105,7 +126,7 @@ xmlParser.o : $(srcdir)/xmlParser/xmlParser.cpp
|
||||
@DO_SECOND_LD@ @SHLIB_CXX_LD@ -o horus.@SO@ $(OBJS) @EXTRA_LIBS_FOR_SWIDLLS@
|
||||
|
||||
|
||||
hcli: $(OBJS)
|
||||
hcli: $(HCLI_OBJS)
|
||||
$(CXX) -o hcli $(HCLI_OBJS)
|
||||
|
||||
|
||||
|
@ -1,38 +1,77 @@
|
||||
#include <cassert>
|
||||
#include <algorithm>
|
||||
#include <limits>
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "SPSolver.h"
|
||||
#include "FactorGraph.h"
|
||||
#include "FgVarNode.h"
|
||||
#include "Factor.h"
|
||||
|
||||
SPSolver* Link::klass = 0;
|
||||
#include "Shared.h"
|
||||
|
||||
|
||||
SPSolver::SPSolver (const FactorGraph& fg) : Solver (&fg)
|
||||
SPSolver::SPSolver (FactorGraph& fg) : Solver (&fg)
|
||||
{
|
||||
fg_ = &fg;
|
||||
accuracy_ = 0.0001;
|
||||
maxIter_ = 10000;
|
||||
//schedule_ = S_SEQ_FIXED;
|
||||
//schedule_ = S_SEQ_RANDOM;
|
||||
//schedule_ = S_SEQ_PARALLEL;
|
||||
schedule_ = S_MAX_RESIDUAL;
|
||||
Link::klass = this;
|
||||
FgVarSet vars = fg_->getFgVarNodes();
|
||||
for (unsigned i = 0; i < vars.size(); i++) {
|
||||
msgs_.push_back (new MessageBanket (vars[i]));
|
||||
}
|
||||
fg_ = &fg;
|
||||
}
|
||||
|
||||
|
||||
|
||||
SPSolver::~SPSolver (void)
|
||||
{
|
||||
for (unsigned i = 0; i < msgs_.size(); i++) {
|
||||
delete msgs_[i];
|
||||
for (unsigned i = 0; i < varsI_.size(); i++) {
|
||||
delete varsI_[i];
|
||||
}
|
||||
for (unsigned i = 0; i < factorsI_.size(); i++) {
|
||||
delete factorsI_[i];
|
||||
}
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
delete links_[i];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
SPSolver::runTreeSolver (void)
|
||||
{
|
||||
CFactorSet factors = fg_->getFactors();
|
||||
bool finish = false;
|
||||
while (!finish) {
|
||||
finish = true;
|
||||
for (unsigned i = 0; i < factors.size(); i++) {
|
||||
CLinkSet links = factorsI_[factors[i]->getIndex()]->getLinks();
|
||||
for (unsigned j = 0; j < links.size(); j++) {
|
||||
if (!links[j]->messageWasSended()) {
|
||||
if (readyToSendMessage(links[j])) {
|
||||
links[j]->setNextMessage (getFactor2VarMsg (links[j]));
|
||||
links[j]->updateMessage();
|
||||
}
|
||||
finish = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
SPSolver::readyToSendMessage (const Link* link) const
|
||||
{
|
||||
CFgVarSet factorVars = link->getFactor()->getFgVarNodes();
|
||||
for (unsigned i = 0; i < factorVars.size(); i++) {
|
||||
if (factorVars[i] != link->getVariable()) {
|
||||
CLinkSet links = varsI_[factorVars[i]->getIndex()]->getLinks();
|
||||
for (unsigned j = 0; j < links.size(); j++) {
|
||||
if (links[j]->getFactor() != link->getFactor() &&
|
||||
!links[j]->messageWasSended()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
@ -40,62 +79,54 @@ SPSolver::~SPSolver (void)
|
||||
void
|
||||
SPSolver::runSolver (void)
|
||||
{
|
||||
initializeSolver();
|
||||
runTreeSolver();
|
||||
return;
|
||||
nIter_ = 0;
|
||||
vector<Factor*> factors = fg_->getFactors();
|
||||
for (unsigned i = 0; i < factors.size(); i++) {
|
||||
FgVarSet neighbors = factors[i]->getFgVarNodes();
|
||||
for (unsigned j = 0; j < neighbors.size(); j++) {
|
||||
updateOrder_.push_back (Link (factors[i], neighbors[j]));
|
||||
}
|
||||
}
|
||||
|
||||
while (!converged() && nIter_ < maxIter_) {
|
||||
if (DL >= 1) {
|
||||
while (!converged() && nIter_ < SolverOptions::maxIter) {
|
||||
|
||||
nIter_ ++;
|
||||
if (DL >= 2) {
|
||||
cout << endl;
|
||||
cout << "****************************************" ;
|
||||
cout << "****************************************" ;
|
||||
cout << endl;
|
||||
cout << " Iteration " << nIter_ + 1 << endl;
|
||||
cout << " Iteration " << nIter_ << endl;
|
||||
cout << "****************************************" ;
|
||||
cout << "****************************************" ;
|
||||
cout << endl;
|
||||
}
|
||||
|
||||
switch (schedule_) {
|
||||
|
||||
case S_SEQ_RANDOM:
|
||||
random_shuffle (updateOrder_.begin(), updateOrder_.end());
|
||||
switch (SolverOptions::schedule) {
|
||||
case SolverOptions::S_SEQ_RANDOM:
|
||||
random_shuffle (links_.begin(), links_.end());
|
||||
// no break
|
||||
|
||||
case S_SEQ_FIXED:
|
||||
for (unsigned c = 0; c < updateOrder_.size(); c++) {
|
||||
Link& link = updateOrder_[c];
|
||||
calculateNextMessage (link.source, link.destination);
|
||||
updateMessage (updateOrder_[c]);
|
||||
case SolverOptions::S_SEQ_FIXED:
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
links_[i]->setNextMessage (getFactor2VarMsg (links_[i]));
|
||||
links_[i]->updateMessage();
|
||||
}
|
||||
break;
|
||||
|
||||
case S_PARALLEL:
|
||||
for (unsigned c = 0; c < updateOrder_.size(); c++) {
|
||||
Link link = updateOrder_[c];
|
||||
calculateNextMessage (link.source, link.destination);
|
||||
case SolverOptions::S_PARALLEL:
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
links_[i]->setNextMessage (getFactor2VarMsg (links_[i]));
|
||||
}
|
||||
for (unsigned c = 0; c < updateOrder_.size(); c++) {
|
||||
Link link = updateOrder_[c];
|
||||
updateMessage (updateOrder_[c]);
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
links_[i]->updateMessage();
|
||||
}
|
||||
break;
|
||||
|
||||
case S_MAX_RESIDUAL:
|
||||
case SolverOptions::S_MAX_RESIDUAL:
|
||||
maxResidualSchedule();
|
||||
break;
|
||||
}
|
||||
|
||||
nIter_++;
|
||||
}
|
||||
cout << endl;
|
||||
if (DL >= 1) {
|
||||
if (nIter_ < maxIter_) {
|
||||
|
||||
if (DL >= 2) {
|
||||
cout << endl;
|
||||
if (nIter_ < SolverOptions::maxIter) {
|
||||
cout << "Loopy Sum-Product converged in " ;
|
||||
cout << nIter_ << " iterations" << endl;
|
||||
} else {
|
||||
@ -108,58 +139,168 @@ SPSolver::runSolver (void)
|
||||
|
||||
|
||||
ParamSet
|
||||
SPSolver::getPosterioriOf (const Variable* var) const
|
||||
SPSolver::getPosterioriOf (Vid vid) const
|
||||
{
|
||||
assert (var);
|
||||
assert (var == fg_->getVariableById (var->getVarId()));
|
||||
assert (var->getIndex() < msgs_.size());
|
||||
assert (fg_->getFgVarNode (vid));
|
||||
FgVarNode* var = fg_->getFgVarNode (vid);
|
||||
ParamSet probs;
|
||||
|
||||
ParamSet probs (var->getDomainSize(), 1);
|
||||
if (var->hasEvidence()) {
|
||||
for (unsigned i = 0; i < probs.size(); i++) {
|
||||
if ((int)i != var->getEvidence()) {
|
||||
probs[i] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
probs.resize (var->getDomainSize(), 0.0);
|
||||
probs[var->getEvidence()] = 1.0;
|
||||
} else {
|
||||
|
||||
MessageBanket* mb = msgs_[var->getIndex()];
|
||||
const FgVarNode* varNode = fg_->getFgVarNodes()[var->getIndex()];
|
||||
vector<Factor*> neighbors = varNode->getFactors();
|
||||
for (unsigned i = 0; i < neighbors.size(); i++) {
|
||||
const Message& msg = mb->getMessage (neighbors[i]);
|
||||
probs.resize (var->getDomainSize(), 1.0);
|
||||
CLinkSet links = varsI_[var->getIndex()]->getLinks();
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
CParamSet msg = links[i]->getMessage();
|
||||
for (unsigned j = 0; j < msg.size(); j++) {
|
||||
probs[j] *= msg[j];
|
||||
}
|
||||
}
|
||||
Util::normalize (probs);
|
||||
}
|
||||
|
||||
return probs;
|
||||
}
|
||||
|
||||
|
||||
|
||||
ParamSet
|
||||
SPSolver::getJointDistributionOf (const VidSet& jointVids)
|
||||
{
|
||||
FgVarSet jointVars;
|
||||
unsigned dsize = 1;
|
||||
for (unsigned i = 0; i < jointVids.size(); i++) {
|
||||
FgVarNode* varNode = fg_->getFgVarNode (jointVids[i]);
|
||||
dsize *= varNode->getDomainSize();
|
||||
jointVars.push_back (varNode);
|
||||
}
|
||||
|
||||
unsigned maxVid = std::numeric_limits<unsigned>::max();
|
||||
FgVarNode* junctionVar = new FgVarNode (maxVid, dsize);
|
||||
FgVarSet factorVars = { junctionVar };
|
||||
for (unsigned i = 0; i < jointVars.size(); i++) {
|
||||
factorVars.push_back (jointVars[i]);
|
||||
}
|
||||
|
||||
unsigned nParams = dsize * dsize;
|
||||
ParamSet params (nParams);
|
||||
for (unsigned i = 0; i < nParams; i++) {
|
||||
unsigned row = i / dsize;
|
||||
unsigned col = i % dsize;
|
||||
if (row == col) {
|
||||
params[i] = 1;
|
||||
} else {
|
||||
params[i] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
Distribution* dist = new Distribution (params, maxVid);
|
||||
Factor* newFactor = new Factor (factorVars, dist);
|
||||
fg_->addVariable (junctionVar);
|
||||
fg_->addFactor (newFactor);
|
||||
|
||||
runSolver();
|
||||
ParamSet results = getPosterioriOf (maxVid);
|
||||
deleteJunction (newFactor, junctionVar);
|
||||
|
||||
return results;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
SPSolver::initializeSolver (void)
|
||||
{
|
||||
fg_->setIndexes();
|
||||
|
||||
CFgVarSet vars = fg_->getFgVarNodes();
|
||||
for (unsigned i = 0; i < varsI_.size(); i++) {
|
||||
delete varsI_[i];
|
||||
}
|
||||
varsI_.reserve (vars.size());
|
||||
for (unsigned i = 0; i < vars.size(); i++) {
|
||||
varsI_.push_back (new SPNodeInfo());
|
||||
}
|
||||
|
||||
CFactorSet factors = fg_->getFactors();
|
||||
for (unsigned i = 0; i < factorsI_.size(); i++) {
|
||||
delete factorsI_[i];
|
||||
}
|
||||
factorsI_.reserve (factors.size());
|
||||
for (unsigned i = 0; i < factors.size(); i++) {
|
||||
factorsI_.push_back (new SPNodeInfo());
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
delete links_[i];
|
||||
}
|
||||
createLinks();
|
||||
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
Factor* source = links_[i]->getFactor();
|
||||
FgVarNode* dest = links_[i]->getVariable();
|
||||
varsI_[dest->getIndex()]->addLink (links_[i]);
|
||||
factorsI_[source->getIndex()]->addLink (links_[i]);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
SPSolver::createLinks (void)
|
||||
{
|
||||
CFactorSet factors = fg_->getFactors();
|
||||
for (unsigned i = 0; i < factors.size(); i++) {
|
||||
CFgVarSet neighbors = factors[i]->getFgVarNodes();
|
||||
for (unsigned j = 0; j < neighbors.size(); j++) {
|
||||
links_.push_back (new Link (factors[i], neighbors[j]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
SPSolver::deleteJunction (Factor* f, FgVarNode* v)
|
||||
{
|
||||
fg_->removeFactor (f);
|
||||
f->freeDistribution();
|
||||
delete f;
|
||||
fg_->removeVariable (v);
|
||||
delete v;
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
SPSolver::converged (void)
|
||||
{
|
||||
// this can happen if the graph is fully disconnected
|
||||
if (links_.size() == 0) {
|
||||
return true;
|
||||
}
|
||||
if (nIter_ == 0 || nIter_ == 1) {
|
||||
return false;
|
||||
}
|
||||
bool converged = true;
|
||||
for (unsigned i = 0; i < updateOrder_.size(); i++) {
|
||||
double residual = getResidual (updateOrder_[i]);
|
||||
if (DL >= 1) {
|
||||
cout << updateOrder_[i].toString();
|
||||
cout << " residual = " << residual << endl;
|
||||
}
|
||||
if (residual > accuracy_) {
|
||||
if (SolverOptions::schedule == SolverOptions::S_MAX_RESIDUAL) {
|
||||
Param maxResidual = (*(sortedOrder_.begin()))->getResidual();
|
||||
if (maxResidual < SolverOptions::accuracy) {
|
||||
converged = true;
|
||||
} else {
|
||||
converged = false;
|
||||
if (DL == 0) {
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
double residual = links_[i]->getResidual();
|
||||
if (DL >= 2) {
|
||||
cout << links_[i]->toString() + " residual = " << residual << endl;
|
||||
}
|
||||
}
|
||||
if (residual > SolverOptions::accuracy) {
|
||||
converged = false;
|
||||
if (DL == 0) break;
|
||||
}
|
||||
}
|
||||
}
|
||||
return converged;
|
||||
}
|
||||
@ -169,127 +310,161 @@ SPSolver::converged (void)
|
||||
void
|
||||
SPSolver::maxResidualSchedule (void)
|
||||
{
|
||||
if (nIter_ == 0) {
|
||||
for (unsigned c = 0; c < updateOrder_.size(); c++) {
|
||||
Link& l = updateOrder_[c];
|
||||
calculateNextMessage (l.source, l.destination);
|
||||
if (DL >= 1) {
|
||||
cout << updateOrder_[c].toString() << " residual = " ;
|
||||
cout << getResidual (updateOrder_[c]) << endl;
|
||||
if (nIter_ == 1) {
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
links_[i]->setNextMessage (getFactor2VarMsg (links_[i]));
|
||||
SortedOrder::iterator it = sortedOrder_.insert (links_[i]);
|
||||
linkMap_.insert (make_pair (links_[i], it));
|
||||
if (DL >= 2 && DL < 5) {
|
||||
cout << "calculating " << links_[i]->toString() << endl;
|
||||
}
|
||||
}
|
||||
sort (updateOrder_.begin(), updateOrder_.end(), compareResidual);
|
||||
} else {
|
||||
return;
|
||||
}
|
||||
|
||||
for (unsigned c = 0; c < updateOrder_.size(); c++) {
|
||||
Link& link = updateOrder_.front();
|
||||
updateMessage (link);
|
||||
resetResidual (link);
|
||||
for (unsigned c = 0; c < links_.size(); c++) {
|
||||
if (DL >= 2) {
|
||||
cout << endl << "current residuals:" << endl;
|
||||
for (SortedOrder::iterator it = sortedOrder_.begin();
|
||||
it != sortedOrder_.end(); it ++) {
|
||||
cout << " " << setw (30) << left << (*it)->toString();
|
||||
cout << "residual = " << (*it)->getResidual() << endl;
|
||||
}
|
||||
}
|
||||
|
||||
// update the messages that depend on message source --> destination
|
||||
vector<Factor*> fstLevelNeighbors = link.destination->getFactors();
|
||||
for (unsigned i = 0; i < fstLevelNeighbors.size(); i++) {
|
||||
if (fstLevelNeighbors[i] != link.source) {
|
||||
FgVarSet sndLevelNeighbors;
|
||||
sndLevelNeighbors = fstLevelNeighbors[i]->getFgVarNodes();
|
||||
for (unsigned j = 0; j < sndLevelNeighbors.size(); j++) {
|
||||
if (sndLevelNeighbors[j] != link.destination) {
|
||||
calculateNextMessage (fstLevelNeighbors[i], sndLevelNeighbors[j]);
|
||||
SortedOrder::iterator it = sortedOrder_.begin();
|
||||
Link* link = *it;
|
||||
if (DL >= 2) {
|
||||
cout << "updating " << (*sortedOrder_.begin())->toString() << endl;
|
||||
}
|
||||
if (link->getResidual() < SolverOptions::accuracy) {
|
||||
return;
|
||||
}
|
||||
link->updateMessage();
|
||||
link->clearResidual();
|
||||
sortedOrder_.erase (it);
|
||||
linkMap_.find (link)->second = sortedOrder_.insert (link);
|
||||
|
||||
// update the messages that depend on message source --> destin
|
||||
CFactorSet factorNeighbors = link->getVariable()->getFactors();
|
||||
for (unsigned i = 0; i < factorNeighbors.size(); i++) {
|
||||
if (factorNeighbors[i] != link->getFactor()) {
|
||||
CLinkSet links = factorsI_[factorNeighbors[i]->getIndex()]->getLinks();
|
||||
for (unsigned j = 0; j < links.size(); j++) {
|
||||
if (links[j]->getVariable() != link->getVariable()) {
|
||||
if (DL >= 2 && DL < 5) {
|
||||
cout << " calculating " << links[j]->toString() << endl;
|
||||
}
|
||||
links[j]->setNextMessage (getFactor2VarMsg (links[j]));
|
||||
LinkMap::iterator iter = linkMap_.find (links[j]);
|
||||
sortedOrder_.erase (iter->second);
|
||||
iter->second = sortedOrder_.insert (links[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
sort (updateOrder_.begin(), updateOrder_.end(), compareResidual);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
SPSolver::updateMessage (const Link& link)
|
||||
ParamSet
|
||||
SPSolver::getFactor2VarMsg (const Link* link) const
|
||||
{
|
||||
updateMessage (link.source, link.destination);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
SPSolver::updateMessage (const Factor* src, const FgVarNode* dest)
|
||||
{
|
||||
msgs_[dest->getIndex()]->updateMessage (src);
|
||||
/* cout << src->getLabel() << " --> " << dest->getLabel() << endl;
|
||||
cout << " m: " ;
|
||||
Message msg = msgs_[dest->getIndex()]->getMessage (src);
|
||||
for (unsigned i = 0; i < msg.size(); i++) {
|
||||
if (i != 0) cout << ", " ;
|
||||
cout << msg[i];
|
||||
}
|
||||
cout << endl;
|
||||
*/
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
SPSolver::calculateNextMessage (const Link& link)
|
||||
{
|
||||
calculateNextMessage (link.source, link.destination);
|
||||
}
|
||||
|
||||
|
||||
void
|
||||
SPSolver::calculateNextMessage (const Factor* src, const FgVarNode* dest)
|
||||
{
|
||||
FgVarSet neighbors = src->getFgVarNodes();
|
||||
// calculate the product of MessageBankets sended
|
||||
const Factor* src = link->getFactor();
|
||||
const FgVarNode* dest = link->getVariable();
|
||||
CFgVarSet neighbors = src->getFgVarNodes();
|
||||
CLinkSet links = factorsI_[src->getIndex()]->getLinks();
|
||||
// calculate the product of messages that were sent
|
||||
// to factor `src', except from var `dest'
|
||||
Factor result = *src;
|
||||
for (unsigned i = 0; i < neighbors.size(); i++) {
|
||||
if (neighbors[i] != dest) {
|
||||
Message msg (neighbors[i]->getDomainSize(), 1);
|
||||
calculateVarFactorMessage (neighbors[i], src, msg);
|
||||
result *= Factor (neighbors[i], msg);
|
||||
}
|
||||
Factor result (*src);
|
||||
Factor temp;
|
||||
if (DL >= 5) {
|
||||
cout << "calculating " ;
|
||||
cout << src->getLabel() << " --> " << dest->getLabel();
|
||||
cout << endl;
|
||||
}
|
||||
// marginalize all vars except `dest'
|
||||
for (unsigned i = 0; i < neighbors.size(); i++) {
|
||||
if (neighbors[i] != dest) {
|
||||
result.marginalizeVariable (neighbors[i]);
|
||||
}
|
||||
}
|
||||
msgs_[dest->getIndex()]->setNextMessage (src, result.getParameters());
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
SPSolver::calculateVarFactorMessage (const FgVarNode* src,
|
||||
const Factor* dest,
|
||||
Message& placeholder) const
|
||||
{
|
||||
assert (src->getDomainSize() == (int)placeholder.size());
|
||||
if (src->hasEvidence()) {
|
||||
for (unsigned i = 0; i < placeholder.size(); i++) {
|
||||
if ((int)i != src->getEvidence()) {
|
||||
placeholder[i] = 0.0;
|
||||
if (links[i]->getVariable() != dest) {
|
||||
if (DL >= 5) {
|
||||
cout << " message from " << links[i]->getVariable()->getLabel();
|
||||
cout << ": " ;
|
||||
ParamSet p = getVar2FactorMsg (links[i]);
|
||||
cout << endl;
|
||||
Factor temp2 (links[i]->getVariable(), p);
|
||||
temp.multiplyByFactor (temp2);
|
||||
temp2.freeDistribution();
|
||||
} else {
|
||||
placeholder[i] = 1.0;
|
||||
}
|
||||
}
|
||||
|
||||
} else {
|
||||
|
||||
MessageBanket* mb = msgs_[src->getIndex()];
|
||||
vector<Factor*> neighbors = src->getFactors();
|
||||
for (unsigned i = 0; i < neighbors.size(); i++) {
|
||||
if (neighbors[i] != dest) {
|
||||
const Message& fromFactor = mb->getMessage (neighbors[i]);
|
||||
for (unsigned j = 0; j < fromFactor.size(); j++) {
|
||||
placeholder[j] *= fromFactor[j];
|
||||
}
|
||||
Factor temp2 (links[i]->getVariable(), getVar2FactorMsg (links[i]));
|
||||
temp.multiplyByFactor (temp2);
|
||||
temp2.freeDistribution();
|
||||
}
|
||||
}
|
||||
}
|
||||
if (links.size() >= 2) {
|
||||
result.multiplyByFactor (temp, &(src->getCptEntries()));
|
||||
if (DL >= 5) {
|
||||
cout << " message product: " ;
|
||||
cout << Util::parametersToString (temp.getParameters()) << endl;
|
||||
cout << " factor product: " ;
|
||||
cout << Util::parametersToString (src->getParameters());
|
||||
cout << " x " ;
|
||||
cout << Util::parametersToString (temp.getParameters());
|
||||
cout << " = " ;
|
||||
cout << Util::parametersToString (result.getParameters()) << endl;
|
||||
}
|
||||
temp.freeDistribution();
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
if (links[i]->getVariable() != dest) {
|
||||
result.removeVariable (links[i]->getVariable());
|
||||
}
|
||||
}
|
||||
if (DL >= 5) {
|
||||
cout << " final message: " ;
|
||||
cout << Util::parametersToString (result.getParameters()) << endl << endl;
|
||||
}
|
||||
ParamSet msg = result.getParameters();
|
||||
result.freeDistribution();
|
||||
return msg;
|
||||
}
|
||||
|
||||
|
||||
|
||||
ParamSet
|
||||
SPSolver::getVar2FactorMsg (const Link* link) const
|
||||
{
|
||||
const FgVarNode* src = link->getVariable();
|
||||
const Factor* dest = link->getFactor();
|
||||
ParamSet msg;
|
||||
if (src->hasEvidence()) {
|
||||
msg.resize (src->getDomainSize(), 0.0);
|
||||
msg[src->getEvidence()] = 1.0;
|
||||
if (DL >= 5) {
|
||||
cout << Util::parametersToString (msg);
|
||||
}
|
||||
} else {
|
||||
msg.resize (src->getDomainSize(), 1.0);
|
||||
}
|
||||
if (DL >= 5) {
|
||||
cout << Util::parametersToString (msg);
|
||||
}
|
||||
CLinkSet links = varsI_[src->getIndex()]->getLinks();
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
if (links[i]->getFactor() != dest) {
|
||||
CParamSet msgFromFactor = links[i]->getMessage();
|
||||
for (unsigned j = 0; j < msgFromFactor.size(); j++) {
|
||||
msg[j] *= msgFromFactor[j];
|
||||
}
|
||||
if (DL >= 5) {
|
||||
cout << " x " << Util::parametersToString (msgFromFactor);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (DL >= 5) {
|
||||
cout << " = " << Util::parametersToString (msg);
|
||||
}
|
||||
return msg;
|
||||
}
|
||||
|
||||
|
@ -1,10 +1,8 @@
|
||||
#ifndef BP_SPSOLVER_H
|
||||
#define BP_SPSOLVER_H
|
||||
#ifndef BP_SP_SOLVER_H
|
||||
#define BP_SP_SOLVER_H
|
||||
|
||||
#include <cmath>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <set>
|
||||
|
||||
#include "Solver.h"
|
||||
#include "FgVarNode.h"
|
||||
@ -15,157 +13,118 @@ using namespace std;
|
||||
class FactorGraph;
|
||||
class SPSolver;
|
||||
|
||||
struct Link
|
||||
{
|
||||
Link (Factor* s, FgVarNode* d)
|
||||
{
|
||||
source = s;
|
||||
destination = d;
|
||||
}
|
||||
string toString (void) const
|
||||
{
|
||||
stringstream ss;
|
||||
ss << source->getLabel() << " --> " ;
|
||||
ss << destination->getLabel();
|
||||
return ss.str();
|
||||
}
|
||||
Factor* source;
|
||||
FgVarNode* destination;
|
||||
static SPSolver* klass;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class MessageBanket
|
||||
class Link
|
||||
{
|
||||
public:
|
||||
MessageBanket (const FgVarNode* var)
|
||||
Link (Factor* f, FgVarNode* v)
|
||||
{
|
||||
factor_ = f;
|
||||
var_ = v;
|
||||
currMsg_.resize (v->getDomainSize(), 1);
|
||||
nextMsg_.resize (v->getDomainSize(), 1);
|
||||
msgSended_ = false;
|
||||
residual_ = 0.0;
|
||||
}
|
||||
|
||||
void setMessage (ParamSet msg)
|
||||
{
|
||||
vector<Factor*> sources = var->getFactors();
|
||||
for (unsigned i = 0; i < sources.size(); i++) {
|
||||
indexMap_.insert (make_pair (sources[i], i));
|
||||
currMsgs_.push_back (Message(var->getDomainSize(), 1));
|
||||
nextMsgs_.push_back (Message(var->getDomainSize(), -10));
|
||||
residuals_.push_back (0.0);
|
||||
}
|
||||
Util::normalize (msg);
|
||||
residual_ = Util::getMaxNorm (currMsg_, msg);
|
||||
currMsg_ = msg;
|
||||
}
|
||||
|
||||
void updateMessage (const Factor* source)
|
||||
void setNextMessage (CParamSet msg)
|
||||
{
|
||||
unsigned idx = getIndex(source);
|
||||
currMsgs_[idx] = nextMsgs_[idx];
|
||||
nextMsg_ = msg;
|
||||
Util::normalize (nextMsg_);
|
||||
residual_ = Util::getMaxNorm (currMsg_, nextMsg_);
|
||||
}
|
||||
|
||||
void setNextMessage (const Factor* source, const Message& msg)
|
||||
void updateMessage (void)
|
||||
{
|
||||
unsigned idx = getIndex(source);
|
||||
nextMsgs_[idx] = msg;
|
||||
residuals_[idx] = computeResidual (source);
|
||||
currMsg_ = nextMsg_;
|
||||
msgSended_ = true;
|
||||
}
|
||||
|
||||
const Message& getMessage (const Factor* source) const
|
||||
string toString (void) const
|
||||
{
|
||||
return currMsgs_[getIndex(source)];
|
||||
}
|
||||
|
||||
double getResidual (const Factor* source) const
|
||||
{
|
||||
return residuals_[getIndex(source)];
|
||||
}
|
||||
|
||||
void resetResidual (const Factor* source)
|
||||
{
|
||||
residuals_[getIndex(source)] = 0.0;
|
||||
stringstream ss;
|
||||
ss << factor_->getLabel();
|
||||
ss << " -- " ;
|
||||
ss << var_->getLabel();
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
Factor* getFactor (void) const { return factor_; }
|
||||
FgVarNode* getVariable (void) const { return var_; }
|
||||
CParamSet getMessage (void) const { return currMsg_; }
|
||||
bool messageWasSended (void) const { return msgSended_; }
|
||||
double getResidual (void) const { return residual_; }
|
||||
void clearResidual (void) { residual_ = 0.0; }
|
||||
|
||||
private:
|
||||
double computeResidual (const Factor* source)
|
||||
{
|
||||
double change = 0.0;
|
||||
unsigned idx = getIndex (source);
|
||||
const Message& currMessage = currMsgs_[idx];
|
||||
const Message& nextMessage = nextMsgs_[idx];
|
||||
for (unsigned i = 0; i < currMessage.size(); i++) {
|
||||
change += abs (currMessage[i] - nextMessage[i]);
|
||||
}
|
||||
return change;
|
||||
}
|
||||
|
||||
unsigned getIndex (const Factor* factor) const
|
||||
{
|
||||
assert (factor);
|
||||
assert (indexMap_.find(factor) != indexMap_.end());
|
||||
return indexMap_.find(factor)->second;
|
||||
}
|
||||
|
||||
typedef map<const Factor*, unsigned> IndexMap;
|
||||
|
||||
IndexMap indexMap_;
|
||||
vector<Message> currMsgs_;
|
||||
vector<Message> nextMsgs_;
|
||||
vector<double> residuals_;
|
||||
Factor* factor_;
|
||||
FgVarNode* var_;
|
||||
ParamSet currMsg_;
|
||||
ParamSet nextMsg_;
|
||||
bool msgSended_;
|
||||
double residual_;
|
||||
};
|
||||
|
||||
|
||||
class SPNodeInfo
|
||||
{
|
||||
public:
|
||||
void addLink (Link* link) { links_.push_back (link); }
|
||||
CLinkSet getLinks (void) { return links_; }
|
||||
|
||||
private:
|
||||
LinkSet links_;
|
||||
};
|
||||
|
||||
|
||||
class SPSolver : public Solver
|
||||
{
|
||||
public:
|
||||
SPSolver (const FactorGraph&);
|
||||
~SPSolver (void);
|
||||
SPSolver (FactorGraph&);
|
||||
virtual ~SPSolver (void);
|
||||
|
||||
void runSolver (void);
|
||||
ParamSet getPosterioriOf (const Variable* var) const;
|
||||
void runSolver (void);
|
||||
virtual ParamSet getPosterioriOf (Vid) const;
|
||||
ParamSet getJointDistributionOf (CVidSet);
|
||||
|
||||
protected:
|
||||
virtual void initializeSolver (void);
|
||||
void runTreeSolver (void);
|
||||
bool readyToSendMessage (const Link*) const;
|
||||
virtual void createLinks (void);
|
||||
virtual void deleteJunction (Factor*, FgVarNode*);
|
||||
bool converged (void);
|
||||
virtual void maxResidualSchedule (void);
|
||||
virtual ParamSet getFactor2VarMsg (const Link*) const;
|
||||
virtual ParamSet getVar2FactorMsg (const Link*) const;
|
||||
|
||||
private:
|
||||
bool converged (void);
|
||||
void maxResidualSchedule (void);
|
||||
void updateMessage (const Link&);
|
||||
void updateMessage (const Factor*, const FgVarNode*);
|
||||
void calculateNextMessage (const Link&);
|
||||
void calculateNextMessage (const Factor*, const FgVarNode*);
|
||||
void calculateVarFactorMessage (
|
||||
const FgVarNode*, const Factor*, Message&) const;
|
||||
double getResidual (const Link&) const;
|
||||
void resetResidual (const Link&) const;
|
||||
friend bool compareResidual (const Link&, const Link&);
|
||||
struct CompareResidual {
|
||||
inline bool operator() (const Link* link1, const Link* link2)
|
||||
{
|
||||
return link1->getResidual() > link2->getResidual();
|
||||
}
|
||||
};
|
||||
|
||||
FactorGraph* fg_;
|
||||
LinkSet links_;
|
||||
vector<SPNodeInfo*> varsI_;
|
||||
vector<SPNodeInfo*> factorsI_;
|
||||
unsigned nIter_;
|
||||
|
||||
typedef multiset<Link*, CompareResidual> SortedOrder;
|
||||
SortedOrder sortedOrder_;
|
||||
|
||||
typedef map<Link*, SortedOrder::iterator> LinkMap;
|
||||
LinkMap linkMap_;
|
||||
|
||||
const FactorGraph* fg_;
|
||||
vector<MessageBanket*> msgs_;
|
||||
Schedule schedule_;
|
||||
int nIter_;
|
||||
double accuracy_;
|
||||
int maxIter_;
|
||||
vector<Link> updateOrder_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
inline double
|
||||
SPSolver::getResidual (const Link& link) const
|
||||
{
|
||||
MessageBanket* mb = Link::klass->msgs_[link.destination->getIndex()];
|
||||
return mb->getResidual (link.source);
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline void
|
||||
SPSolver::resetResidual (const Link& link) const
|
||||
{
|
||||
MessageBanket* mb = Link::klass->msgs_[link.destination->getIndex()];
|
||||
mb->resetResidual (link.source);
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline bool
|
||||
compareResidual (const Link& link1, const Link& link2)
|
||||
{
|
||||
MessageBanket* mb1 = Link::klass->msgs_[link1.destination->getIndex()];
|
||||
MessageBanket* mb2 = Link::klass->msgs_[link2.destination->getIndex()];
|
||||
return mb1->getResidual(link1.source) > mb2->getResidual(link2.source);
|
||||
}
|
||||
|
||||
#endif
|
||||
#endif // BP_SP_SOLVER_H
|
||||
|
||||
|
@ -2,14 +2,15 @@
|
||||
#define BP_SHARED_H
|
||||
|
||||
#include <cmath>
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <cassert>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <unordered_map>
|
||||
|
||||
// Macro to disallow the copy constructor and operator= functions
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <iomanip>
|
||||
|
||||
#define DISALLOW_COPY_AND_ASSIGN(TypeName) \
|
||||
TypeName(const TypeName&); \
|
||||
void operator=(const TypeName&)
|
||||
@ -19,61 +20,162 @@ using namespace std;
|
||||
class Variable;
|
||||
class BayesNode;
|
||||
class FgVarNode;
|
||||
class Factor;
|
||||
class Link;
|
||||
class Edge;
|
||||
|
||||
typedef double Param;
|
||||
typedef vector<Param> ParamSet;
|
||||
typedef vector<Param> Message;
|
||||
typedef const ParamSet& CParamSet;
|
||||
typedef unsigned Vid;
|
||||
typedef vector<Vid> VidSet;
|
||||
typedef const VidSet& CVidSet;
|
||||
typedef vector<Variable*> VarSet;
|
||||
typedef vector<BayesNode*> NodeSet;
|
||||
typedef vector<BayesNode*> BnNodeSet;
|
||||
typedef const BnNodeSet& CBnNodeSet;
|
||||
typedef vector<FgVarNode*> FgVarSet;
|
||||
typedef const FgVarSet& CFgVarSet;
|
||||
typedef vector<Factor*> FactorSet;
|
||||
typedef const FactorSet& CFactorSet;
|
||||
typedef vector<Link*> LinkSet;
|
||||
typedef const LinkSet& CLinkSet;
|
||||
typedef vector<Edge*> EdgeSet;
|
||||
typedef const EdgeSet& CEdgeSet;
|
||||
typedef vector<string> Domain;
|
||||
typedef vector<unsigned> DomainConf;
|
||||
typedef pair<unsigned, unsigned> DomainConstr;
|
||||
typedef unordered_map<unsigned, unsigned> IndexMap;
|
||||
typedef vector<unsigned> DConf;
|
||||
typedef pair<unsigned, unsigned> DConstraint;
|
||||
typedef map<unsigned, unsigned> IndexMap;
|
||||
|
||||
|
||||
//extern unsigned DL;
|
||||
// level of debug information
|
||||
static const unsigned DL = 0;
|
||||
|
||||
// number of digits to show when printing a parameter
|
||||
static const unsigned PRECISION = 10;
|
||||
static const int NO_EVIDENCE = -1;
|
||||
|
||||
// shared by bp and sp solver
|
||||
enum Schedule
|
||||
// number of digits to show when printing a parameter
|
||||
static const unsigned PRECISION = 5;
|
||||
|
||||
static const bool EXPORT_TO_DOT = false;
|
||||
static const unsigned EXPORT_MIN_SIZE = 30;
|
||||
|
||||
|
||||
namespace SolverOptions
|
||||
{
|
||||
S_SEQ_FIXED,
|
||||
S_SEQ_RANDOM,
|
||||
S_PARALLEL,
|
||||
S_MAX_RESIDUAL
|
||||
enum Schedule
|
||||
{
|
||||
S_SEQ_FIXED,
|
||||
S_SEQ_RANDOM,
|
||||
S_PARALLEL,
|
||||
S_MAX_RESIDUAL
|
||||
};
|
||||
extern bool runBayesBall;
|
||||
extern bool convertBn2Fg;
|
||||
extern bool compressFactorGraph;
|
||||
extern Schedule schedule;
|
||||
extern double accuracy;
|
||||
extern unsigned maxIter;
|
||||
}
|
||||
|
||||
|
||||
namespace Util
|
||||
{
|
||||
void normalize (ParamSet&);
|
||||
void pow (ParamSet&, unsigned);
|
||||
double getL1dist (CParamSet, CParamSet);
|
||||
double getMaxNorm (CParamSet, CParamSet);
|
||||
bool isInteger (const string&);
|
||||
string parametersToString (CParamSet);
|
||||
vector<DConf> getDomainConfigurations (const VarSet&);
|
||||
vector<string> getInstantiations (const VarSet&);
|
||||
};
|
||||
|
||||
|
||||
struct NetInfo
|
||||
{
|
||||
NetInfo (unsigned c, double t)
|
||||
NetInfo (void)
|
||||
{
|
||||
counting = c;
|
||||
solvingTime = t;
|
||||
counting = 0;
|
||||
nIters = 0;
|
||||
solvingTime = 0.0;
|
||||
}
|
||||
unsigned counting;
|
||||
double solvingTime;
|
||||
unsigned nIters;
|
||||
};
|
||||
|
||||
|
||||
struct CompressInfo
|
||||
{
|
||||
CompressInfo (unsigned a, unsigned b, unsigned c,
|
||||
unsigned d, unsigned e) {
|
||||
nUncVars = a;
|
||||
nUncFactors = b;
|
||||
nCompVars = c;
|
||||
nCompFactors = d;
|
||||
nNeighborlessVars = e;
|
||||
}
|
||||
unsigned nUncVars;
|
||||
unsigned nUncFactors;
|
||||
unsigned nCompVars;
|
||||
unsigned nCompFactors;
|
||||
unsigned nNeighborlessVars;
|
||||
};
|
||||
|
||||
|
||||
typedef map<unsigned, NetInfo> StatisticMap;
|
||||
|
||||
|
||||
class Statistics
|
||||
{
|
||||
public:
|
||||
|
||||
static void updateStats (unsigned size, double time)
|
||||
static void updateStats (unsigned size, unsigned nIters, double time)
|
||||
{
|
||||
StatisticMap::iterator it = stats_.find(size);
|
||||
StatisticMap::iterator it = stats_.find (size);
|
||||
if (it == stats_.end()) {
|
||||
stats_.insert (make_pair (size, NetInfo (1, 0.0)));
|
||||
it = (stats_.insert (make_pair (size, NetInfo()))).first;
|
||||
} else {
|
||||
it->second.counting ++;
|
||||
it->second.nIters += nIters;
|
||||
it->second.solvingTime += time;
|
||||
totalOfIterations += nIters;
|
||||
if (nIters > maxIterations) {
|
||||
maxIterations = nIters;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void updateCompressingStats (unsigned nUncVars,
|
||||
unsigned nUncFactors,
|
||||
unsigned nCompVars,
|
||||
unsigned nCompFactors,
|
||||
unsigned nNeighborlessVars) {
|
||||
compressInfo_.push_back (CompressInfo (
|
||||
nUncVars, nUncFactors, nCompVars, nCompFactors, nNeighborlessVars));
|
||||
}
|
||||
|
||||
static void printCompressingStats (const char* fileName)
|
||||
{
|
||||
ofstream out (fileName);
|
||||
if (!out.is_open()) {
|
||||
cerr << "error: cannot open file to write at " ;
|
||||
cerr << "BayesNet::printCompressingStats()" << endl;
|
||||
abort();
|
||||
}
|
||||
out << "--------------------------------------" ;
|
||||
out << "--------------------------------------" << endl;
|
||||
out << " Compression Stats" << endl;
|
||||
out << "--------------------------------------" ;
|
||||
out << "--------------------------------------" << endl;
|
||||
out << left;
|
||||
out << "Uncompress Compressed Uncompress Compressed Neighborless";
|
||||
out << endl;
|
||||
out << "Vars Vars Factors Factors Vars" ;
|
||||
out << endl;
|
||||
for (unsigned i = 0; i < compressInfo_.size(); i++) {
|
||||
out << setw (13) << compressInfo_[i].nUncVars;
|
||||
out << setw (13) << compressInfo_[i].nCompVars;
|
||||
out << setw (13) << compressInfo_[i].nUncFactors;
|
||||
out << setw (13) << compressInfo_[i].nCompFactors;
|
||||
out << setw (13) << compressInfo_[i].nNeighborlessVars;
|
||||
out << endl;
|
||||
}
|
||||
}
|
||||
|
||||
@ -84,20 +186,12 @@ class Statistics
|
||||
return it->second.counting;
|
||||
}
|
||||
|
||||
static void updateIterations (unsigned nIters)
|
||||
{
|
||||
totalOfIterations += nIters;
|
||||
if (nIters > maxIterations) {
|
||||
maxIterations = nIters;
|
||||
}
|
||||
}
|
||||
|
||||
static void writeStats (void)
|
||||
{
|
||||
ofstream out ("../../stats.txt");
|
||||
if (!out.is_open()) {
|
||||
cerr << "error: cannot open file to write at " ;
|
||||
cerr << "Statistics:::updateStats()" << endl;
|
||||
cerr << "Statistics::updateStats()" << endl;
|
||||
abort();
|
||||
}
|
||||
unsigned avgIterations = 0;
|
||||
@ -117,17 +211,24 @@ class Statistics
|
||||
out << " average iterations: " << avgIterations << endl;
|
||||
out << "total solving time " << totalSolvingTime << endl;
|
||||
out << endl;
|
||||
out << "Network Size\tCounting\tSolving Time\tAverage Time" << endl;
|
||||
out << left << endl;
|
||||
out << setw (15) << "Network Size" ;
|
||||
out << setw (15) << "Counting" ;
|
||||
out << setw (15) << "Solving Time" ;
|
||||
out << setw (15) << "Average Time" ;
|
||||
out << setw (15) << "#Iterations" ;
|
||||
out << endl;
|
||||
for (StatisticMap::iterator it = stats_.begin();
|
||||
it != stats_.end(); it++) {
|
||||
out << it->first;
|
||||
out << "\t\t" << it->second.counting;
|
||||
out << "\t\t" << it->second.solvingTime;
|
||||
out << setw (15) << it->first;
|
||||
out << setw (15) << it->second.counting;
|
||||
out << setw (15) << it->second.solvingTime;
|
||||
if (it->second.counting > 0) {
|
||||
out << "\t\t" << it->second.solvingTime / it->second.counting;
|
||||
out << setw (15) << it->second.solvingTime / it->second.counting;
|
||||
} else {
|
||||
out << "\t\t0.0" ;
|
||||
out << setw (15) << "0.0" ;
|
||||
}
|
||||
out << setw (15) << it->second.nIters;
|
||||
out << endl;
|
||||
}
|
||||
out.close();
|
||||
@ -142,62 +243,8 @@ class Statistics
|
||||
static StatisticMap stats_;
|
||||
static unsigned maxIterations;
|
||||
static unsigned totalOfIterations;
|
||||
|
||||
static vector<CompressInfo> compressInfo_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class Util
|
||||
{
|
||||
public:
|
||||
static void normalize (ParamSet& v)
|
||||
{
|
||||
double sum = 0.0;
|
||||
for (unsigned i = 0; i < v.size(); i++) {
|
||||
sum += v[i];
|
||||
}
|
||||
assert (sum != 0.0);
|
||||
for (unsigned i = 0; i < v.size(); i++) {
|
||||
v[i] /= sum;
|
||||
}
|
||||
}
|
||||
|
||||
static double getL1dist (const ParamSet& v1, const ParamSet& v2)
|
||||
{
|
||||
assert (v1.size() == v2.size());
|
||||
double dist = 0.0;
|
||||
for (unsigned i = 0; i < v1.size(); i++) {
|
||||
dist += abs (v1[i] - v2[i]);
|
||||
}
|
||||
return dist;
|
||||
}
|
||||
|
||||
static double getMaxNorm (const ParamSet& v1, const ParamSet& v2)
|
||||
{
|
||||
assert (v1.size() == v2.size());
|
||||
double max = 0.0;
|
||||
for (unsigned i = 0; i < v1.size(); i++) {
|
||||
double diff = abs (v1[i] - v2[i]);
|
||||
if (diff > max) {
|
||||
max = diff;
|
||||
}
|
||||
}
|
||||
return max;
|
||||
}
|
||||
|
||||
static bool isInteger (const string& s)
|
||||
{
|
||||
stringstream ss1 (s);
|
||||
stringstream ss2;
|
||||
int integer;
|
||||
ss1 >> integer;
|
||||
ss2 << integer;
|
||||
return (ss1.str() == ss2.str());
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
//unsigned Statistics::totalOfIterations = 0;
|
||||
|
||||
#endif
|
||||
#endif //BP_SHARED_H
|
||||
|
||||
|
@ -15,19 +15,30 @@ class Solver
|
||||
{
|
||||
gm_ = gm;
|
||||
}
|
||||
virtual ~Solver() {} // to call subclass destructor
|
||||
virtual void runSolver (void) = 0;
|
||||
virtual ParamSet getPosterioriOf (const Variable*) const = 0;
|
||||
virtual ParamSet getPosterioriOf (Vid) const = 0;
|
||||
virtual ParamSet getJointDistributionOf (const VidSet&) = 0;
|
||||
|
||||
void printPosterioriOf (const Variable* var) const
|
||||
void printAllPosterioris (void) const
|
||||
{
|
||||
VarSet vars = gm_->getVariables();
|
||||
for (unsigned i = 0; i < vars.size(); i++) {
|
||||
printPosterioriOf (vars[i]->getVarId());
|
||||
}
|
||||
}
|
||||
|
||||
void printPosterioriOf (Vid vid) const
|
||||
{
|
||||
Variable* var = gm_->getVariable (vid);
|
||||
cout << endl;
|
||||
cout << setw (20) << left << var->getLabel() << "posteriori" ;
|
||||
cout << endl;
|
||||
cout << "------------------------------" ;
|
||||
cout << endl;
|
||||
const Domain& domain = var->getDomain();
|
||||
ParamSet results = getPosterioriOf (var);
|
||||
for (int xi = 0; xi < var->getDomainSize(); xi++) {
|
||||
ParamSet results = getPosterioriOf (vid);
|
||||
for (unsigned xi = 0; xi < var->getDomainSize(); xi++) {
|
||||
cout << setw (20) << domain[xi];
|
||||
cout << setprecision (PRECISION) << results[xi];
|
||||
cout << endl;
|
||||
@ -35,16 +46,35 @@ class Solver
|
||||
cout << endl;
|
||||
}
|
||||
|
||||
void printAllPosterioris (void) const
|
||||
void printJointDistributionOf (const VidSet& vids)
|
||||
{
|
||||
VarSet vars = gm_->getVariables();
|
||||
for (unsigned i = 0; i < vars.size(); i++) {
|
||||
printPosterioriOf (vars[i]);
|
||||
const ParamSet& jointDist = getJointDistributionOf (vids);
|
||||
cout << endl;
|
||||
cout << "joint distribution of " ;
|
||||
VarSet vars;
|
||||
for (unsigned i = 0; i < vids.size() - 1; i++) {
|
||||
Variable* var = gm_->getVariable (vids[i]);
|
||||
cout << var->getLabel() << ", " ;
|
||||
vars.push_back (var);
|
||||
}
|
||||
Variable* var = gm_->getVariable (vids[vids.size() - 1]);
|
||||
cout << var->getLabel() ;
|
||||
vars.push_back (var);
|
||||
cout << endl;
|
||||
cout << "------------------------------" ;
|
||||
cout << endl;
|
||||
const vector<string>& domainConfs = Util::getInstantiations (vars);
|
||||
for (unsigned i = 0; i < jointDist.size(); i++) {
|
||||
cout << left << setw (20) << domainConfs[i];
|
||||
cout << setprecision (PRECISION) << jointDist[i];
|
||||
cout << endl;
|
||||
}
|
||||
cout << endl;
|
||||
}
|
||||
|
||||
private:
|
||||
const GraphicalModel* gm_;
|
||||
const GraphicalModel* gm_;
|
||||
};
|
||||
|
||||
#endif
|
||||
#endif //BP_SOLVER_H
|
||||
|
||||
|
191
packages/CLPBN/clpbn/bp/Util.cpp
Normal file
191
packages/CLPBN/clpbn/bp/Util.cpp
Normal file
@ -0,0 +1,191 @@
|
||||
#include <sstream>
|
||||
|
||||
#include "Variable.h"
|
||||
#include "Shared.h"
|
||||
|
||||
namespace SolverOptions {
|
||||
|
||||
bool runBayesBall = false;
|
||||
bool convertBn2Fg = true;
|
||||
bool compressFactorGraph = true;
|
||||
Schedule schedule = S_SEQ_FIXED;
|
||||
//Schedule schedule = S_SEQ_RANDOM;
|
||||
//Schedule schedule = S_PARALLEL;
|
||||
//Schedule schedule = S_MAX_RESIDUAL;
|
||||
double accuracy = 0.0001;
|
||||
unsigned maxIter = 1000; //FIXME
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
unsigned Statistics::numCreatedNets = 0;
|
||||
unsigned Statistics::numSolvedPolyTrees = 0;
|
||||
unsigned Statistics::numSolvedLoopyNets = 0;
|
||||
unsigned Statistics::numUnconvergedRuns = 0;
|
||||
unsigned Statistics::maxIterations = 0;
|
||||
unsigned Statistics::totalOfIterations = 0;
|
||||
vector<CompressInfo> Statistics::compressInfo_;
|
||||
StatisticMap Statistics::stats_;
|
||||
|
||||
|
||||
|
||||
namespace Util {
|
||||
|
||||
void
|
||||
normalize (ParamSet& v)
|
||||
{
|
||||
double sum = 0.0;
|
||||
for (unsigned i = 0; i < v.size(); i++) {
|
||||
sum += v[i];
|
||||
}
|
||||
assert (sum != 0.0);
|
||||
for (unsigned i = 0; i < v.size(); i++) {
|
||||
v[i] /= sum;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void
|
||||
pow (ParamSet& v, unsigned expoent)
|
||||
{
|
||||
for (unsigned i = 0; i < v.size(); i++) {
|
||||
double value = 1;
|
||||
for (unsigned j = 0; j < expoent; j++) {
|
||||
value *= v[i];
|
||||
}
|
||||
v[i] = value;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
double
|
||||
getL1dist (const ParamSet& v1, const ParamSet& v2)
|
||||
{
|
||||
assert (v1.size() == v2.size());
|
||||
double dist = 0.0;
|
||||
for (unsigned i = 0; i < v1.size(); i++) {
|
||||
dist += abs (v1[i] - v2[i]);
|
||||
}
|
||||
return dist;
|
||||
}
|
||||
|
||||
|
||||
double
|
||||
getMaxNorm (const ParamSet& v1, const ParamSet& v2)
|
||||
{
|
||||
assert (v1.size() == v2.size());
|
||||
double max = 0.0;
|
||||
for (unsigned i = 0; i < v1.size(); i++) {
|
||||
double diff = abs (v1[i] - v2[i]);
|
||||
if (diff > max) {
|
||||
max = diff;
|
||||
}
|
||||
}
|
||||
return max;
|
||||
}
|
||||
|
||||
|
||||
bool
|
||||
isInteger (const string& s)
|
||||
{
|
||||
stringstream ss1 (s);
|
||||
stringstream ss2;
|
||||
int integer;
|
||||
ss1 >> integer;
|
||||
ss2 << integer;
|
||||
return (ss1.str() == ss2.str());
|
||||
}
|
||||
|
||||
|
||||
|
||||
string
|
||||
parametersToString (CParamSet v)
|
||||
{
|
||||
stringstream ss;
|
||||
ss << "[" ;
|
||||
for (unsigned i = 0; i < v.size() - 1; i++) {
|
||||
ss << v[i] << ", " ;
|
||||
}
|
||||
if (v.size() != 0) {
|
||||
ss << v[v.size() - 1];
|
||||
}
|
||||
ss << "]" ;
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
|
||||
|
||||
vector<DConf>
|
||||
getDomainConfigurations (const VarSet& vars)
|
||||
{
|
||||
unsigned nConfs = 1;
|
||||
for (unsigned i = 0; i < vars.size(); i++) {
|
||||
nConfs *= vars[i]->getDomainSize();
|
||||
}
|
||||
|
||||
vector<DConf> confs (nConfs);
|
||||
for (unsigned i = 0; i < nConfs; i++) {
|
||||
confs[i].resize (vars.size());
|
||||
}
|
||||
|
||||
unsigned nReps = 1;
|
||||
for (int i = vars.size() - 1; i >= 0; i--) {
|
||||
unsigned index = 0;
|
||||
while (index < nConfs) {
|
||||
for (unsigned j = 0; j < vars[i]->getDomainSize(); j++) {
|
||||
for (unsigned r = 0; r < nReps; r++) {
|
||||
confs[index][i] = j;
|
||||
index++;
|
||||
}
|
||||
}
|
||||
}
|
||||
nReps *= vars[i]->getDomainSize();
|
||||
}
|
||||
return confs;
|
||||
}
|
||||
|
||||
|
||||
vector<string>
|
||||
getInstantiations (const VarSet& vars)
|
||||
{
|
||||
//FIXME handle variables without domain
|
||||
/*
|
||||
char c = 'a' ;
|
||||
const DConf& conf = entries[i].getDomainConfiguration();
|
||||
for (unsigned j = 0; j < conf.size(); j++) {
|
||||
if (j != 0) ss << "," ;
|
||||
ss << c << conf[j] + 1;
|
||||
c ++;
|
||||
}
|
||||
*/
|
||||
unsigned rowSize = 1;
|
||||
for (unsigned i = 0; i < vars.size(); i++) {
|
||||
rowSize *= vars[i]->getDomainSize();
|
||||
}
|
||||
|
||||
vector<string> headers (rowSize);
|
||||
|
||||
unsigned nReps = 1;
|
||||
for (int i = vars.size() - 1; i >= 0; i--) {
|
||||
Domain domain = vars[i]->getDomain();
|
||||
unsigned index = 0;
|
||||
while (index < rowSize) {
|
||||
for (unsigned j = 0; j < vars[i]->getDomainSize(); j++) {
|
||||
for (unsigned r = 0; r < nReps; r++) {
|
||||
if (headers[index] != "") {
|
||||
headers[index] = domain[j] + ", " + headers[index];
|
||||
} else {
|
||||
headers[index] = domain[j];
|
||||
}
|
||||
index++;
|
||||
}
|
||||
}
|
||||
}
|
||||
nReps *= vars[i]->getDomainSize();
|
||||
}
|
||||
return headers;
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -1,9 +1,10 @@
|
||||
#ifndef BP_GENERIC_VARIABLE_H
|
||||
#define BP_GENERIC_VARIABLE_H
|
||||
#ifndef BP_VARIABLE_H
|
||||
#define BP_VARIABLE_H
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include <sstream>
|
||||
|
||||
#include <algorithm>
|
||||
#include "Shared.h"
|
||||
|
||||
using namespace std;
|
||||
@ -12,33 +13,61 @@ class Variable
|
||||
{
|
||||
public:
|
||||
|
||||
Variable (unsigned varId)
|
||||
Variable (const Variable* v)
|
||||
{
|
||||
this->varId_ = varId;
|
||||
this->dsize_ = 0;
|
||||
this->evidence_ = -1;
|
||||
this->label_ = 0;
|
||||
vid_ = v->getVarId();
|
||||
dsize_ = v->getDomainSize();
|
||||
if (v->hasDomain()) {
|
||||
domain_ = v->getDomain();
|
||||
dsize_ = domain_.size();
|
||||
} else {
|
||||
dsize_ = v->getDomainSize();
|
||||
}
|
||||
evidence_ = v->getEvidence();
|
||||
if (v->hasLabel()) {
|
||||
label_ = new string (v->getLabel());
|
||||
} else {
|
||||
label_ = 0;
|
||||
}
|
||||
}
|
||||
|
||||
Variable (unsigned varId, unsigned dsize, int evidence = -1)
|
||||
Variable (Vid vid)
|
||||
{
|
||||
this->vid_ = vid;
|
||||
this->dsize_ = 0;
|
||||
this->evidence_ = NO_EVIDENCE;
|
||||
this->label_ = 0;
|
||||
}
|
||||
|
||||
Variable (Vid vid, unsigned dsize, int evidence = NO_EVIDENCE,
|
||||
const string& lbl = string())
|
||||
{
|
||||
assert (dsize != 0);
|
||||
assert (evidence < (int)dsize);
|
||||
this->varId_ = varId;
|
||||
this->dsize_ = dsize;
|
||||
this->evidence_ = evidence;
|
||||
this->label_ = 0;
|
||||
this->vid_ = vid;
|
||||
this->dsize_ = dsize;
|
||||
this->evidence_ = evidence;
|
||||
if (!lbl.empty()) {
|
||||
this->label_ = new string (lbl);
|
||||
} else {
|
||||
this->label_ = 0;
|
||||
}
|
||||
}
|
||||
|
||||
Variable (unsigned varId, const Domain& domain, int evidence = -1)
|
||||
Variable (Vid vid, const Domain& domain, int evidence = NO_EVIDENCE,
|
||||
const string& lbl = string())
|
||||
{
|
||||
assert (!domain.empty());
|
||||
assert (evidence < (int)domain.size());
|
||||
this->varId_ = varId;
|
||||
this->dsize_ = domain.size();
|
||||
this->domain_ = domain;
|
||||
this->evidence_ = evidence;
|
||||
this->label_ = 0;
|
||||
this->vid_ = vid;
|
||||
this->dsize_ = domain.size();
|
||||
this->domain_ = domain;
|
||||
this->evidence_ = evidence;
|
||||
if (!lbl.empty()) {
|
||||
this->label_ = new string (lbl);
|
||||
} else {
|
||||
this->label_ = 0;
|
||||
}
|
||||
}
|
||||
|
||||
~Variable (void)
|
||||
@ -46,19 +75,19 @@ class Variable
|
||||
delete label_;
|
||||
}
|
||||
|
||||
unsigned getVarId (void) const { return varId_; }
|
||||
unsigned getIndex (void) const { return index_; }
|
||||
void setIndex (unsigned idx) { index_ = idx; }
|
||||
int getDomainSize (void) const { return dsize_; }
|
||||
bool hasEvidence (void) const { return evidence_ != -1; }
|
||||
int getEvidence (void) const { return evidence_; }
|
||||
bool hasDomain (void) { return !domain_.empty(); }
|
||||
bool hasLabel (void) { return label_ != 0; }
|
||||
unsigned getVarId (void) const { return vid_; }
|
||||
unsigned getIndex (void) const { return index_; }
|
||||
void setIndex (unsigned idx) { index_ = idx; }
|
||||
unsigned getDomainSize (void) const { return dsize_; }
|
||||
bool hasEvidence (void) const { return evidence_ != NO_EVIDENCE; }
|
||||
int getEvidence (void) const { return evidence_; }
|
||||
bool hasDomain (void) const { return !domain_.empty(); }
|
||||
bool hasLabel (void) const { return label_ != 0; }
|
||||
|
||||
bool isValidStateIndex (int index)
|
||||
{
|
||||
return index >= 0 && index < dsize_;
|
||||
}
|
||||
bool isValidStateIndex (int index)
|
||||
{
|
||||
return index >= 0 && index < (int)dsize_;
|
||||
}
|
||||
|
||||
bool isValidState (const string& state)
|
||||
{
|
||||
@ -70,7 +99,7 @@ class Variable
|
||||
assert (dsize_ != 0);
|
||||
if (domain_.size() == 0) {
|
||||
Domain d;
|
||||
for (int i = 0; i < dsize_; i++) {
|
||||
for (unsigned i = 0; i < dsize_; i++) {
|
||||
stringstream ss;
|
||||
ss << "x" << i ;
|
||||
d.push_back (ss.str());
|
||||
@ -110,7 +139,7 @@ class Variable
|
||||
}
|
||||
}
|
||||
|
||||
void setLabel (string label)
|
||||
void setLabel (const string& label)
|
||||
{
|
||||
label_ = new string (label);
|
||||
}
|
||||
@ -119,25 +148,25 @@ class Variable
|
||||
{
|
||||
if (label_ == 0) {
|
||||
stringstream ss;
|
||||
ss << "v" << varId_;
|
||||
ss << "v" << vid_;
|
||||
return ss.str();
|
||||
} else {
|
||||
return *label_;
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
unsigned varId_;
|
||||
string* label_;
|
||||
unsigned index_;
|
||||
int evidence_;
|
||||
|
||||
private:
|
||||
DISALLOW_COPY_AND_ASSIGN (Variable);
|
||||
Domain domain_;
|
||||
int dsize_;
|
||||
|
||||
Vid vid_;
|
||||
unsigned dsize_;
|
||||
int evidence_;
|
||||
Domain domain_;
|
||||
string* label_;
|
||||
unsigned index_;
|
||||
|
||||
};
|
||||
|
||||
#endif // BP_GENERIC_VARIABLE_H
|
||||
#endif // BP_VARIABLE_H
|
||||
|
||||
|
34
packages/CLPBN/clpbn/bp/examples/1parentNchilds.yap
Executable file
34
packages/CLPBN/clpbn/bp/examples/1parentNchilds.yap
Executable file
@ -0,0 +1,34 @@
|
||||
|
||||
:- use_module(library(clpbn)).
|
||||
|
||||
:- set_clpbn_flag(solver, bp).
|
||||
|
||||
%
|
||||
% R
|
||||
% / | \
|
||||
% / | \
|
||||
% A B C
|
||||
%
|
||||
|
||||
|
||||
r(R) :-
|
||||
{ R = r with p([t, f], [0.35, 0.65]) }.
|
||||
|
||||
a(A) :-
|
||||
r(R),
|
||||
child_dist(R,Dist),
|
||||
{ A = a with Dist }.
|
||||
|
||||
b(B) :-
|
||||
r(R),
|
||||
child_dist(R,Dist),
|
||||
{ B = b with Dist }.
|
||||
|
||||
c(C) :-
|
||||
r(R),
|
||||
child_dist(R,Dist),
|
||||
{ C = c with Dist }.
|
||||
|
||||
|
||||
child_dist(R, p([t, f], [0.3, 0.4, 0.25, 0.05], [R])).
|
||||
|
53
packages/CLPBN/clpbn/bp/examples/bp-example.xml
Executable file
53
packages/CLPBN/clpbn/bp/examples/bp-example.xml
Executable file
@ -0,0 +1,53 @@
|
||||
<?xml version="1.0" encoding="US-ASCII"?>
|
||||
|
||||
<!--
|
||||
|
||||
A B
|
||||
\ /
|
||||
\ /
|
||||
C
|
||||
|
||||
-->
|
||||
|
||||
<BIF VERSION="0.3">
|
||||
<NETWORK>
|
||||
<NAME>Neapolitan</NAME>
|
||||
|
||||
<VARIABLE TYPE="nature">
|
||||
<NAME>A</NAME>
|
||||
<OUTCOME>a1</OUTCOME>
|
||||
<OUTCOME>a2</OUTCOME>
|
||||
</VARIABLE>
|
||||
|
||||
<VARIABLE TYPE="nature">
|
||||
<NAME>B</NAME>
|
||||
<OUTCOME>b1</OUTCOME>
|
||||
<OUTCOME>b2</OUTCOME>
|
||||
</VARIABLE>
|
||||
|
||||
<VARIABLE TYPE="nature">
|
||||
<NAME>C</NAME>
|
||||
<OUTCOME>c1</OUTCOME>
|
||||
<OUTCOME>c2</OUTCOME>
|
||||
</VARIABLE>
|
||||
|
||||
<DEFINITION>
|
||||
<FOR>A</FOR>
|
||||
<TABLE> .695 .305 </TABLE>
|
||||
</DEFINITION>
|
||||
|
||||
<DEFINITION>
|
||||
<FOR>B</FOR>
|
||||
<TABLE> .25 .75 </TABLE>
|
||||
</DEFINITION>
|
||||
|
||||
<DEFINITION>
|
||||
<FOR>C</FOR>
|
||||
<GIVEN>A</GIVEN>
|
||||
<GIVEN>B</GIVEN>
|
||||
<TABLE> .2 .8 .45 .55 .32 .68 .7 .3 </TABLE>
|
||||
</DEFINITION>
|
||||
|
||||
</NETWORK>
|
||||
</BIF>
|
||||
|
@ -9,10 +9,10 @@ MARKOV
|
||||
2 4 2
|
||||
|
||||
2
|
||||
.001 .009
|
||||
.001 .999
|
||||
|
||||
2
|
||||
.002 .008
|
||||
.002 .998
|
||||
|
||||
8
|
||||
.95 .94 .29 .001
|
||||
|
@ -49,12 +49,12 @@
|
||||
|
||||
<DEFINITION>
|
||||
<FOR>B</FOR>
|
||||
<TABLE> .001 .009 </TABLE>
|
||||
<TABLE> .001 .999 </TABLE>
|
||||
</DEFINITION>
|
||||
|
||||
<DEFINITION>
|
||||
<FOR>E</FOR>
|
||||
<TABLE> .002 .008 </TABLE>
|
||||
<TABLE> .002 .998 </TABLE>
|
||||
</DEFINITION>
|
||||
|
||||
<DEFINITION>
|
||||
|
@ -1,54 +1,29 @@
|
||||
|
||||
:- use_module(library(clpbn)).
|
||||
|
||||
:- set_clpbn_flag(solver, vel).
|
||||
:- set_clpbn_flag(solver, bp).
|
||||
|
||||
%
|
||||
% B E
|
||||
% \ /
|
||||
% \ /
|
||||
% A
|
||||
% / \
|
||||
% / \
|
||||
% J M
|
||||
%
|
||||
r(R) :- r_cpt(RCpt),
|
||||
{ R = r with p([r1, r2], RCpt) }.
|
||||
|
||||
t(T) :- t_cpt(TCpt),
|
||||
{ T = t with p([t1, t2], TCpt) }.
|
||||
|
||||
b(B) :-
|
||||
b_table(BDist),
|
||||
{ B = b with p([b1, b2], BDist) }.
|
||||
a(A) :- r(R), t(T), a_cpt(ACpt),
|
||||
{ A = a with p([a1, a2], ACpt, [R, T]) }.
|
||||
|
||||
e(E) :-
|
||||
e_table(EDist),
|
||||
{ E = e with p([e1, e2], EDist) }.
|
||||
j(J) :- a(A), j_cpt(JCpt),
|
||||
{ J = j with p([j1, j2], JCpt, [A]) }.
|
||||
|
||||
a(A) :-
|
||||
b(B),
|
||||
e(E),
|
||||
a_table(ADist),
|
||||
{ A = a with p([a1, a2], ADist, [B, E]) }.
|
||||
|
||||
j(J):-
|
||||
a(A),
|
||||
j_table(JDist),
|
||||
{ J = j with p([j1, j2], JDist, [A]) }.
|
||||
|
||||
m(M):-
|
||||
a(A),
|
||||
m_table(MDist),
|
||||
{ M = m with p([m1, m2], MDist, [A]) }.
|
||||
m(M) :- a(A), m_cpt(MCpt),
|
||||
{ M = m with p([m1, m2], MCpt, [A]) }.
|
||||
|
||||
|
||||
b_table([0.001, 0.009]).
|
||||
|
||||
e_table([0.002, 0.008]).
|
||||
|
||||
a_table([0.95, 0.94, 0.29, 0.001,
|
||||
0.05, 0.06, 0.71, 0.999]).
|
||||
|
||||
j_table([0.9, 0.05,
|
||||
0.1, 0.95]).
|
||||
|
||||
m_table([0.7, 0.01,
|
||||
0.3, 0.99]).
|
||||
r_cpt([0.001, 0.999]).
|
||||
t_cpt([0.002, 0.998]).
|
||||
a_cpt([0.95, 0.94, 0.29, 0.001,
|
||||
0.05, 0.06, 0.71, 0.999]).
|
||||
j_cpt([0.9, 0.05,
|
||||
0.1, 0.95]).
|
||||
m_cpt([0.7, 0.01,
|
||||
0.3, 0.99]).
|
||||
|
||||
|
@ -16,34 +16,37 @@
|
||||
|
||||
<VARIABLE TYPE="nature">
|
||||
<NAME>A</NAME>
|
||||
<OUTCOME></OUTCOME>
|
||||
<OUTCOME>a1</OUTCOME>
|
||||
<OUTCOME>a2</OUTCOME>
|
||||
</VARIABLE>
|
||||
|
||||
<VARIABLE TYPE="nature">
|
||||
<NAME>B</NAME>
|
||||
<OUTCOME></OUTCOME>
|
||||
<OUTCOME>b1</OUTCOME>
|
||||
<OUTCOME>b2</OUTCOME>
|
||||
</VARIABLE>
|
||||
|
||||
<VARIABLE TYPE="nature">
|
||||
<NAME>C</NAME>
|
||||
<OUTCOME></OUTCOME>
|
||||
<OUTCOME>c1</OUTCOME>
|
||||
<OUTCOME>c2</OUTCOME>
|
||||
</VARIABLE>
|
||||
|
||||
<DEFINITION>
|
||||
<FOR>A</FOR>
|
||||
<TABLE>1</TABLE>
|
||||
<TABLE>.695 .305</TABLE>
|
||||
</DEFINITION>
|
||||
|
||||
<DEFINITION>
|
||||
<FOR>B</FOR>
|
||||
<TABLE>1</TABLE>
|
||||
<TABLE>0.25 0.75</TABLE>
|
||||
</DEFINITION>
|
||||
|
||||
<DEFINITION>
|
||||
<FOR>C</FOR>
|
||||
<GIVEN>A</GIVEN>
|
||||
<GIVEN>B</GIVEN>
|
||||
<TABLE>1</TABLE>
|
||||
<TABLE>0.2 0.8 0.45 0.55 0.32 0.68 0.7 0.3</TABLE>
|
||||
</DEFINITION>
|
||||
|
||||
</NETWORK>
|
||||
|
67
packages/CLPBN/clpbn/bp/examples/lambda fail.xml
Executable file
67
packages/CLPBN/clpbn/bp/examples/lambda fail.xml
Executable file
@ -0,0 +1,67 @@
|
||||
<?xml version="1.0" encoding="US-ASCII"?>
|
||||
|
||||
<!--
|
||||
|
||||
P1 P2 P3
|
||||
\ | /
|
||||
\ | /
|
||||
-
|
||||
C
|
||||
|
||||
-->
|
||||
|
||||
<BIF VERSION="0.3">
|
||||
<NETWORK>
|
||||
|
||||
<NAME>Simple Convergence</NAME>
|
||||
<VARIABLE TYPE="nature">
|
||||
<NAME>P1</NAME>
|
||||
<OUTCOME>p1</OUTCOME>
|
||||
<OUTCOME>p2</OUTCOME>
|
||||
</VARIABLE>
|
||||
|
||||
<VARIABLE TYPE="nature">
|
||||
<NAME>P2</NAME>
|
||||
<OUTCOME>p1</OUTCOME>
|
||||
<OUTCOME>p2</OUTCOME>
|
||||
<OUTCOME>p3</OUTCOME>
|
||||
</VARIABLE>
|
||||
|
||||
<VARIABLE TYPE="nature">
|
||||
<NAME>P3</NAME>
|
||||
<OUTCOME>p1</OUTCOME>
|
||||
<OUTCOME>p2</OUTCOME>
|
||||
</VARIABLE>
|
||||
|
||||
<VARIABLE TYPE="nature">
|
||||
<NAME>C</NAME>
|
||||
<OUTCOME>c1</OUTCOME>
|
||||
<OUTCOME>c2</OUTCOME>
|
||||
</VARIABLE>
|
||||
|
||||
<DEFINITION>
|
||||
<FOR>P1</FOR>
|
||||
<TABLE>.695 .305</TABLE>
|
||||
</DEFINITION>
|
||||
|
||||
<DEFINITION>
|
||||
<FOR>P2</FOR>
|
||||
<TABLE>0.2 0.3 0.5</TABLE>
|
||||
</DEFINITION>
|
||||
|
||||
<DEFINITION>
|
||||
<FOR>P3</FOR>
|
||||
<TABLE>0.25 0.75</TABLE>
|
||||
</DEFINITION>
|
||||
|
||||
<DEFINITION>
|
||||
<FOR>C</FOR>
|
||||
<GIVEN>P1</GIVEN>
|
||||
<GIVEN>P2</GIVEN>
|
||||
<GIVEN>P3</GIVEN>
|
||||
<TABLE>0.2 0.8 0.45 0.55 0.32 0.68 0.7 0.3 0.3 0.7 0.55 0.45 0.22 0.78 0.25 0.75 0.11 0.89 0.34 0.66 0.1 0.9 0.6 0.4</TABLE>
|
||||
</DEFINITION>
|
||||
|
||||
</NETWORK>
|
||||
</BIF>
|
||||
|
@ -2,6 +2,7 @@
|
||||
:- use_module(library(clpbn)).
|
||||
|
||||
:- set_clpbn_flag(solver, bp).
|
||||
%:- set_clpbn_flag(solver, jt).
|
||||
|
||||
%
|
||||
% B F
|
||||
|
17
packages/CLPBN/clpbn/bp/examples/sp-example.uai
Executable file
17
packages/CLPBN/clpbn/bp/examples/sp-example.uai
Executable file
@ -0,0 +1,17 @@
|
||||
MARKOV
|
||||
3
|
||||
2 2 2
|
||||
3
|
||||
1 0
|
||||
1 1
|
||||
3 2 0 1
|
||||
|
||||
2
|
||||
.695 .305
|
||||
|
||||
2
|
||||
.25 .75
|
||||
|
||||
8
|
||||
0.2 0.45 0.32 0.7
|
||||
0.8 0.55 0.68 0.3
|
128
packages/CLPBN/clpbn/bp/examples/test_bn.xml
Executable file
128
packages/CLPBN/clpbn/bp/examples/test_bn.xml
Executable file
@ -0,0 +1,128 @@
|
||||
<?xml version="1.0" encoding="US-ASCII"?>
|
||||
|
||||
<!--
|
||||
|
||||
A B C
|
||||
\ | /
|
||||
\ | /
|
||||
D
|
||||
/ | \
|
||||
/ | \
|
||||
E F G
|
||||
|
||||
-->
|
||||
|
||||
<BIF VERSION="0.3">
|
||||
<NETWORK>
|
||||
<NAME>Node with several parents and childs</NAME>
|
||||
|
||||
<VARIABLE TYPE="nature">
|
||||
<NAME>A</NAME>
|
||||
<OUTCOME>a1</OUTCOME>
|
||||
<OUTCOME>a2</OUTCOME>
|
||||
</VARIABLE>
|
||||
|
||||
<VARIABLE TYPE="nature">
|
||||
<NAME>B</NAME>
|
||||
<OUTCOME>b1</OUTCOME>
|
||||
<OUTCOME>b2</OUTCOME>
|
||||
<OUTCOME>b3</OUTCOME>
|
||||
<OUTCOME>b4</OUTCOME>
|
||||
</VARIABLE>
|
||||
|
||||
<VARIABLE TYPE="nature">
|
||||
<NAME>C</NAME>
|
||||
<OUTCOME>c1</OUTCOME>
|
||||
<OUTCOME>c2</OUTCOME>
|
||||
<OUTCOME>c3</OUTCOME>
|
||||
</VARIABLE>
|
||||
|
||||
<VARIABLE TYPE="nature">
|
||||
<NAME>D</NAME>
|
||||
<OUTCOME>d1</OUTCOME>
|
||||
<OUTCOME>d2</OUTCOME>
|
||||
<OUTCOME>d3</OUTCOME>
|
||||
</VARIABLE>
|
||||
|
||||
<VARIABLE TYPE="nature">
|
||||
<NAME>E</NAME>
|
||||
<OUTCOME>e1</OUTCOME>
|
||||
<OUTCOME>e2</OUTCOME>
|
||||
<OUTCOME>e3</OUTCOME>
|
||||
<OUTCOME>e4</OUTCOME>
|
||||
</VARIABLE>
|
||||
|
||||
<VARIABLE TYPE="nature">
|
||||
<NAME>F</NAME>
|
||||
<OUTCOME>f1</OUTCOME>
|
||||
<OUTCOME>f2</OUTCOME>
|
||||
<OUTCOME>f3</OUTCOME>
|
||||
</VARIABLE>
|
||||
|
||||
<VARIABLE TYPE="nature">
|
||||
<NAME>G</NAME>
|
||||
<OUTCOME>g1</OUTCOME>
|
||||
<OUTCOME>g2</OUTCOME>
|
||||
</VARIABLE>
|
||||
|
||||
|
||||
<DEFINITION>
|
||||
<FOR>A</FOR>
|
||||
<TABLE> .1 .2 </TABLE>
|
||||
</DEFINITION>
|
||||
|
||||
<DEFINITION>
|
||||
<FOR>B</FOR>
|
||||
<TABLE> .01 .02 .03 .04 </TABLE>
|
||||
</DEFINITION>
|
||||
|
||||
<DEFINITION>
|
||||
<FOR>C</FOR>
|
||||
<TABLE> .11 .22 .33 </TABLE>
|
||||
</DEFINITION>
|
||||
|
||||
<DEFINITION>
|
||||
<FOR>D</FOR>
|
||||
<GIVEN>A</GIVEN>
|
||||
<GIVEN>B</GIVEN>
|
||||
<GIVEN>C</GIVEN>
|
||||
<TABLE>
|
||||
.522 .008 .99 .01 .2 .8 .003 .457 .423 .007 .92 .04 .5 .232 .033 .227 .112 .048 .91 .21 .24 .18 .005 .227
|
||||
.212 .04 .59 .21 .6 .1 .023 .215 .913 .017 .96 .01 .55 .422 .013 .417 .272 .068 .61 .11 .26 .28 .205 .322
|
||||
.142 .028 .19 .11 .5 .67 .013 .437 .163 .067 .12 .06 .1 .262 .063 .167 .512 .028 .11 .41 .14 .68 .015 .92
|
||||
</TABLE>
|
||||
</DEFINITION>
|
||||
|
||||
<DEFINITION>
|
||||
<FOR>E</FOR>
|
||||
<GIVEN>D</GIVEN>
|
||||
<TABLE>
|
||||
.111 .11 .1
|
||||
.222 .22 .2
|
||||
.333 .33 .3
|
||||
.444 .44 .4
|
||||
</TABLE>
|
||||
</DEFINITION>
|
||||
|
||||
<DEFINITION>
|
||||
<FOR>F</FOR>
|
||||
<GIVEN>D</GIVEN>
|
||||
<TABLE>
|
||||
.112 .111 .110
|
||||
.223 .222 .221
|
||||
.334 .333 .332
|
||||
</TABLE>
|
||||
</DEFINITION>
|
||||
|
||||
<DEFINITION>
|
||||
<FOR>G</FOR>
|
||||
<GIVEN>D</GIVEN>
|
||||
<TABLE>
|
||||
.101 .102 .103
|
||||
.201 .202 .203
|
||||
</TABLE>
|
||||
</DEFINITION>
|
||||
|
||||
</NETWORK>
|
||||
</BIF>
|
||||
|
36
packages/CLPBN/clpbn/bp/examples/test_mk.uai
Executable file
36
packages/CLPBN/clpbn/bp/examples/test_mk.uai
Executable file
@ -0,0 +1,36 @@
|
||||
MARKOV
|
||||
5
|
||||
4 2 3 2 3
|
||||
7
|
||||
1 0
|
||||
1 1
|
||||
1 2
|
||||
1 3
|
||||
1 4
|
||||
2 0 1
|
||||
4 1 2 3 4
|
||||
|
||||
4
|
||||
0.1 0.7 0.43 0.22
|
||||
|
||||
2
|
||||
0.2 0.6
|
||||
|
||||
3
|
||||
0.3 0.5 0.2
|
||||
|
||||
2
|
||||
0.15 0.75
|
||||
|
||||
3
|
||||
0.25 0.45 0.15
|
||||
|
||||
8
|
||||
0.210 0.333 0.457 0.4
|
||||
0.811 0.000 0.189 0.89
|
||||
|
||||
36
|
||||
0.1 0.15 0.2 0.25 0.3 0.45 0.5 0.55 0.65 0.7 0.75 0.9
|
||||
0.11 0.22 0.33 0.44 0.55 0.66 0.77 0.88 0.91 0.93 0.95 0.97
|
||||
0.42 0.22 0.33 0.44 0.15 0.36 0.27 0.28 0.21 0.13 0.25 0.17
|
||||
|
69
packages/CLPBN/clpbn/bp/examples/ve_example.xml
Executable file
69
packages/CLPBN/clpbn/bp/examples/ve_example.xml
Executable file
@ -0,0 +1,69 @@
|
||||
<?xml version="1.0" encoding="US-ASCII"?>
|
||||
|
||||
<!--
|
||||
|
||||
A B
|
||||
\ /
|
||||
\ /
|
||||
C
|
||||
|
|
||||
|
|
||||
D
|
||||
|
||||
-->
|
||||
|
||||
|
||||
<BIF VERSION="0.3">
|
||||
<NETWORK>
|
||||
<NAME>Simple Loop</NAME>
|
||||
|
||||
<VARIABLE TYPE="nature">
|
||||
<NAME>A</NAME>
|
||||
<OUTCOME>a1</OUTCOME>
|
||||
<OUTCOME>a2</OUTCOME>
|
||||
</VARIABLE>
|
||||
|
||||
<VARIABLE TYPE="nature">
|
||||
<NAME>B</NAME>
|
||||
<OUTCOME>b1</OUTCOME>
|
||||
<OUTCOME>b2</OUTCOME>
|
||||
</VARIABLE>
|
||||
|
||||
<VARIABLE TYPE="nature">
|
||||
<NAME>C</NAME>
|
||||
<OUTCOME>c1</OUTCOME>
|
||||
<OUTCOME>c2</OUTCOME>
|
||||
</VARIABLE>
|
||||
|
||||
<VARIABLE TYPE="nature">
|
||||
<NAME>D</NAME>
|
||||
<OUTCOME>d1</OUTCOME>
|
||||
<OUTCOME>d2</OUTCOME>
|
||||
</VARIABLE>
|
||||
|
||||
<DEFINITION>
|
||||
<FOR>A</FOR>
|
||||
<TABLE> .001 .009 </TABLE>
|
||||
</DEFINITION>
|
||||
|
||||
<DEFINITION>
|
||||
<FOR>B</FOR>
|
||||
<TABLE> .002 .008 </TABLE>
|
||||
</DEFINITION>
|
||||
|
||||
<DEFINITION>
|
||||
<FOR>C</FOR>
|
||||
<GIVEN>A</GIVEN>
|
||||
<GIVEN>B</GIVEN>
|
||||
<TABLE> .95 .05 .94 .06 .29 .71 .001 .999 </TABLE>
|
||||
</DEFINITION>
|
||||
|
||||
<DEFINITION>
|
||||
<FOR>D</FOR>
|
||||
<GIVEN>C</GIVEN>
|
||||
<TABLE> .9 .1 .05 .95 </TABLE>
|
||||
</DEFINITION>
|
||||
|
||||
</NETWORK>
|
||||
</BIF>
|
||||
|
@ -884,6 +884,15 @@ writePrimitive(term_t t, write_options *options)
|
||||
return writeString(t, options);
|
||||
#endif /* O_STRING */
|
||||
|
||||
#if __YAP_PROLOG__
|
||||
{
|
||||
number n;
|
||||
n.type = V_INTEGER;
|
||||
n.value.i = 0;
|
||||
return WriteNumber(&n, options);
|
||||
}
|
||||
#endif
|
||||
|
||||
assert(0);
|
||||
fail;
|
||||
}
|
||||
|
@ -1121,6 +1121,9 @@ Yap_StreamPosition(IOSTREAM *st)
|
||||
return StreamPosition(st);
|
||||
}
|
||||
|
||||
IOSTREAM *STD_PROTO(Yap_Scurin, (void));
|
||||
int STD_PROTO(Yap_dowrite, (Term, IOSTREAM *, int, int));
|
||||
|
||||
IOSTREAM *
|
||||
Yap_Scurin(void)
|
||||
{
|
||||
@ -1128,6 +1131,32 @@ Yap_Scurin(void)
|
||||
return Scurin;
|
||||
}
|
||||
|
||||
int
|
||||
Yap_dowrite(Term t, IOSTREAM *stream, int flags, int priority)
|
||||
/* term to be written */
|
||||
/* consumer */
|
||||
/* write options */
|
||||
{
|
||||
CACHE_REGS
|
||||
int swi_flags;
|
||||
int res;
|
||||
Int slot = Yap_InitSlot(t PASS_REGS);
|
||||
|
||||
swi_flags = 0;
|
||||
if (flags & Quote_illegal_f)
|
||||
swi_flags |= PL_WRT_QUOTED;
|
||||
if (flags & Handle_vars_f)
|
||||
swi_flags |= PL_WRT_NUMBERVARS;
|
||||
if (flags & Use_portray_f)
|
||||
swi_flags |= PL_WRT_PORTRAY;
|
||||
if (flags & Ignore_ops_f)
|
||||
swi_flags |= PL_WRT_IGNOREOPS;
|
||||
|
||||
res = PL_write_term(stream, slot, priority, swi_flags);
|
||||
Yap_RecoverSlots(1 PASS_REGS);
|
||||
return res;
|
||||
}
|
||||
|
||||
|
||||
#if THREADS
|
||||
|
||||
@ -1178,6 +1207,7 @@ error:
|
||||
return rc;
|
||||
}
|
||||
|
||||
|
||||
int
|
||||
recursiveMutexInit(recursiveMutex *m)
|
||||
{
|
||||
|
@ -1 +1 @@
|
||||
Subproject commit cee6c346ba77e046ef1873b9d8c88c52c3baae2d
|
||||
Subproject commit a5d2a3755f86ebffd8eb6e2a6921790d299eab30
|
@ -1 +1 @@
|
||||
Subproject commit f6fce313722d2f69e5ce6074f304ce05065bbf40
|
||||
Subproject commit a6f0f4ec7d5fd51ca8b268b8392da9b20bfd1b44
|
@ -1 +1 @@
|
||||
Subproject commit 3f7eee803d46071f10804010f7266b3864bab7e6
|
||||
Subproject commit 2a0843683e8790d8129fa4bad99c82a5fcaf441b
|
@ -1 +1 @@
|
||||
Subproject commit aa90b8f8a67e605e82566b719a8f20d125598bd2
|
||||
Subproject commit 2daa5b9942a9fc22005ec48b6560f8596f33d995
|
@ -1 +1 @@
|
||||
Subproject commit 2229eb3807b21497ffd680686ceeeddeac9f57fb
|
||||
Subproject commit e3c7eb9a9a54d6a9069396ecd1c3b537a8f6165a
|
@ -1 +1 @@
|
||||
Subproject commit babcbfe9cc5ab269cd5fd4f024e9d57bd3d0a8db
|
||||
Subproject commit f1c3ef54f4d9431ba5b4188cb72ca3056d20b202
|
@ -1,59 +0,0 @@
|
||||
#
|
||||
# default base directory for YAP installation
|
||||
# (EROOT for architecture-dependent files)
|
||||
#
|
||||
prefix = @prefix@
|
||||
exec_prefix = @exec_prefix@
|
||||
ROOTDIR = $(prefix)
|
||||
EROOTDIR = @exec_prefix@
|
||||
abs_top_builddir = @abs_top_builddir@
|
||||
#
|
||||
# where the binary should be
|
||||
#
|
||||
BINDIR = $(EROOTDIR)/bin
|
||||
#
|
||||
# where YAP should look for libraries
|
||||
#
|
||||
LIBDIR=@libdir@
|
||||
YAPLIBDIR=@libdir@/Yap
|
||||
#
|
||||
#
|
||||
DEFS=@DEFS@ -D_YAP_NOT_INSTALLED_=1
|
||||
CC=@CC@
|
||||
CFLAGS= @SHLIB_CFLAGS@ $(YAP_EXTRAS) $(DEFS) -I$(srcdir) -I../.. -I$(srcdir)/../../include -I$(srcdir)/../PLStream -I$(srcdir)/../PLStream/windows -I$(srcdir)/../../H
|
||||
#
|
||||
#
|
||||
# You shouldn't need to change what follows.
|
||||
#
|
||||
INSTALL=@INSTALL@
|
||||
INSTALL_DATA=@INSTALL_DATA@
|
||||
INSTALL_PROGRAM=@INSTALL_PROGRAM@
|
||||
SHELL=/bin/sh
|
||||
RANLIB=@RANLIB@
|
||||
srcdir=@srcdir@
|
||||
SO=@SO@
|
||||
#4.1VPATH=@srcdir@:@srcdir@/OPTYap
|
||||
CWD=$(PWD)
|
||||
#
|
||||
|
||||
OBJS=pl-tai.o
|
||||
SOBJS=pl-tai.@SO@
|
||||
|
||||
#in some systems we just create a single object, in others we need to
|
||||
# create a libray
|
||||
|
||||
all: $(SOBJS)
|
||||
|
||||
pl-tai.o: $(srcdir)/pl-tai.c
|
||||
(cd libtai ; $(MAKE))
|
||||
$(CC) -c $(CFLAGS) $(srcdir)/pl-tai.c -o pl-tai.o
|
||||
|
||||
@DO_SECOND_LD@pl-tai.@SO@: pl-tai.o
|
||||
@DO_SECOND_LD@ @SHLIB_LD@ $(LDFLAGS) -o pl-tai.@SO@ pl-tai.o libtai/libtai.a @EXTRA_LIBS_FOR_SWIDLLS@
|
||||
|
||||
install: all
|
||||
$(INSTALL_PROGRAM) $(SOBJS) $(DESTDIR)$(YAPLIBDIR)
|
||||
|
||||
clean:
|
||||
rm -f *.o *~ $(OBJS) $(SOBJS) *.BAK
|
||||
-(cd libtai && $(MAKE) clean)
|
@ -1 +1 @@
|
||||
Subproject commit deea4bfdf7041387e91eca37978a5c8db9287eda
|
||||
Subproject commit 109bf1d224009dc049ab12a0fbbb50511ef8e1fb
|
@ -1050,6 +1050,8 @@ make :-
|
||||
fail.
|
||||
make.
|
||||
|
||||
make_library_index(_Directory).
|
||||
|
||||
'$file_name'(Stream,F) :-
|
||||
stream_property(Stream, file_name(F)), !.
|
||||
'$file_name'(user_input,user_output).
|
||||
|
@ -495,7 +495,7 @@ debugging :-
|
||||
'$continue_debugging'(no, '$execute_nonstop'(G,M)).
|
||||
'$spycall'(G, M, CalledFromDebugger, InRedo) :-
|
||||
'$flags'(G,M,F,F),
|
||||
F /\ 0x18402000 =\= 0, !, % dynamic procedure, logical semantics, user-C, or source
|
||||
F /\ 0x08402000 =\= 0, !, % dynamic procedure, logical semantics, or source
|
||||
% use the interpreter
|
||||
CP is '$last_choice_pt',
|
||||
'$clause'(G, M, Cl, _),
|
||||
|
@ -245,22 +245,15 @@ print_message(Severity, Msg) :-
|
||||
'$notrace'(user:portray_message(Severity, Msg)), !.
|
||||
% This predicate has more hooks than a pirate ship!
|
||||
print_message(Severity, Term) :-
|
||||
(
|
||||
(
|
||||
'$oncenotrace'(user:generate_message_hook(Term, [], Lines)) ->
|
||||
true
|
||||
;
|
||||
'$oncenotrace'(prolog:message(Term, Lines, [])) ->
|
||||
true
|
||||
;
|
||||
'$messages':generate_message(Term, Lines, [])
|
||||
)
|
||||
-> ( nonvar(Term),
|
||||
'$oncenotrace'(user:message_hook(Term, Severity, Lines))
|
||||
-> !
|
||||
; !, '$print_system_message'(Term, Severity, Lines)
|
||||
)
|
||||
).
|
||||
% first step at hook processing
|
||||
'$message_to_lines'(Term, Lines),
|
||||
( nonvar(Term),
|
||||
'$oncenotrace'(user:message_hook(Term, Severity, Lines))
|
||||
->
|
||||
true
|
||||
;
|
||||
'$print_system_message'(Term, Severity, Lines)
|
||||
), !.
|
||||
print_message(silent, _) :- !.
|
||||
print_message(_, error(syntax_error(syntax_error(_,between(_,L,_),_,_,_,_,StreamName)),_)) :- !,
|
||||
format(user_error,'SYNTAX ERROR at ~a, close to ~d~n',[StreamName,L]).
|
||||
@ -271,6 +264,14 @@ print_message(_, loaded(A, F, _, Time, Space)) :- !,
|
||||
print_message(_, Term) :-
|
||||
format(user_error,'~q~n',[Term]).
|
||||
|
||||
'$message_to_lines'(Term, Lines) :-
|
||||
'$oncenotrace'(user:generate_message_hook(Term, [], Lines)), !.
|
||||
'$message_to_lines'(Term, Lines) :-
|
||||
'$oncenotrace'(prolog:message(Term, Lines, [])), !.
|
||||
'$message_to_lines'(Term, Lines) :-
|
||||
'$messages':generate_message(Term, Lines, []), !.
|
||||
|
||||
|
||||
% print_system_message(+Term, +Level, +Lines)
|
||||
%
|
||||
% Print the message if the user did not intecept the message.
|
||||
|
@ -69,6 +69,10 @@ generate_message(debug) --> !,
|
||||
[ debug ].
|
||||
generate_message(trace) --> !,
|
||||
[ trace ].
|
||||
generate_message(error(Error,Context)) -->
|
||||
{ Error = existence_error(procedure,_) }, !,
|
||||
system_message(error(Error,Context)),
|
||||
stack_dump(error(Error,Context)).
|
||||
generate_message(error(Error,context(Cause,Extra))) -->
|
||||
system_message(error(Error,Cause)),
|
||||
stack_dump(error(Error,context(Cause,Extra))).
|
||||
@ -130,8 +134,6 @@ system_message(no_match(P)) -->
|
||||
[ 'No matching predicate for ~w.' - [P] ].
|
||||
system_message(leash([A|B])) -->
|
||||
[ 'Leashing set to ~w.' - [[A|B]] ].
|
||||
system_message(existence_error(prolog_flag,F)) -->
|
||||
[ 'Prolog Flag ~w: new Prolog flags must be created using create_prolog_flag/3.' - [F] ].
|
||||
system_message(singletons([SV],P)) -->
|
||||
[ 'Singleton variable ~s in ~q.' - [SV,P] ].
|
||||
system_message(singletons(SVs,P)) -->
|
||||
@ -159,14 +161,18 @@ system_message(error(context_error(Goal,Who),Where)) -->
|
||||
system_message(error(domain_error(DomainType,Opt), Where)) -->
|
||||
[ 'DOMAIN ERROR- ~w: ' - Where],
|
||||
domain_error(DomainType, Opt).
|
||||
system_message(error(existence_error(directory,Key), Where)) -->
|
||||
[ 'EXISTENCE ERROR- ~w: ~w not an existing directory' - [Where,Key] ].
|
||||
system_message(error(existence_error(key,Key), Where)) -->
|
||||
[ 'EXISTENCE ERROR- ~w: ~w not an existing key' - [Where,Key] ].
|
||||
system_message(existence_error(prolog_flag,F)) -->
|
||||
[ 'Prolog Flag ~w: new Prolog flags must be created using create_prolog_flag/3.' - [F] ].
|
||||
system_message(error(existence_error(prolog_flag,P), Where)) --> !,
|
||||
[ 'EXISTENCE ERROR- ~w: prolog flag ~w is undefined' - [Where,P] ].
|
||||
system_message(error(existence_error(procedure,P), context(Call,Parent))) --> !,
|
||||
[ 'EXISTENCE ERROR- procedure ~w is undefined, called from context ~w~n Goal was ~w' - [P,Parent,Call] ].
|
||||
system_message(error(existence_error(stream,Stream), Where)) -->
|
||||
[ 'EXISTENCE ERROR- ~w: ~w not an open stream' - [Where,Stream] ].
|
||||
system_message(error(existence_error(key,Key), Where)) -->
|
||||
[ 'EXISTENCE ERROR- ~w: ~w not an existing key' - [Where,Key] ].
|
||||
system_message(error(existence_error(thread,Thread), Where)) -->
|
||||
[ 'EXISTENCE ERROR- ~w: ~w not a running thread' - [Where,Thread] ].
|
||||
system_message(error(existence_error(variable,Var), Where)) -->
|
||||
|
Reference in New Issue
Block a user