Merge branch 'master' of git://yap.git.sourceforge.net/gitroot/yap/yap-6.3

This commit is contained in:
Denys Duchier 2012-04-16 21:51:54 +02:00
commit b66b261972
169 changed files with 8751 additions and 1277262 deletions

View File

@ -490,7 +490,6 @@ Yap_HasOp(Atom a)
OpEntry * OpEntry *
Yap_OpPropForModule(Atom a, Term mod) Yap_OpPropForModule(Atom a, Term mod)
{ /* look property list of atom a for kind */ { /* look property list of atom a for kind */
CACHE_REGS
AtomEntry *ae = RepAtom(a); AtomEntry *ae = RepAtom(a);
PropEntry *pp; PropEntry *pp;
OpEntry *info = NULL; OpEntry *info = NULL;
@ -767,6 +766,7 @@ ExpandPredHash(void)
Prop Prop
Yap_NewPredPropByFunctor(FunctorEntry *fe, Term cur_mod) Yap_NewPredPropByFunctor(FunctorEntry *fe, Term cur_mod)
{ {
CACHE_REGS
PredEntry *p = (PredEntry *) Yap_AllocAtomSpace(sizeof(*p)); PredEntry *p = (PredEntry *) Yap_AllocAtomSpace(sizeof(*p));
if (p == NULL) { if (p == NULL) {
@ -902,6 +902,7 @@ Yap_NewThreadPred(PredEntry *ap USES_REGS)
Prop Prop
Yap_NewPredPropByAtom(AtomEntry *ae, Term cur_mod) Yap_NewPredPropByAtom(AtomEntry *ae, Term cur_mod)
{ {
CACHE_REGS
Prop p0; Prop p0;
PredEntry *p = (PredEntry *) Yap_AllocAtomSpace(sizeof(*p)); PredEntry *p = (PredEntry *) Yap_AllocAtomSpace(sizeof(*p));

View File

@ -2053,6 +2053,7 @@ a_try(op_numbers opcode, CELL lab, CELL opr, int nofalts, int hascut, yamop *cod
yamop *newcp; yamop *newcp;
/* emit a special instruction and then a label for backpatching */ /* emit a special instruction and then a label for backpatching */
if (pass_no) { if (pass_no) {
CACHE_REGS
UInt size = (UInt)NEXTOP((yamop *)NULL,OtaLl); UInt size = (UInt)NEXTOP((yamop *)NULL,OtaLl);
if ((newcp = (yamop *)Yap_AllocCodeSpace(size)) == NULL) { if ((newcp = (yamop *)Yap_AllocCodeSpace(size)) == NULL) {
/* OOOPS, got in trouble, must do a longjmp and recover space */ /* OOOPS, got in trouble, must do a longjmp and recover space */

View File

@ -561,6 +561,7 @@ X_API YAP_tag_t STD_PROTO(YAP_TagOfTerm,(Term));
X_API size_t STD_PROTO(YAP_ExportTerm,(Term, char *, size_t)); X_API size_t STD_PROTO(YAP_ExportTerm,(Term, char *, size_t));
X_API size_t STD_PROTO(YAP_SizeOfExportedTerm,(char *)); X_API size_t STD_PROTO(YAP_SizeOfExportedTerm,(char *));
X_API Term STD_PROTO(YAP_ImportTerm,(char *)); X_API Term STD_PROTO(YAP_ImportTerm,(char *));
X_API int STD_PROTO(YAP_RequiresExtraStack,(size_t));
static UInt static UInt
current_arity(void) current_arity(void)
@ -2705,7 +2706,6 @@ YAP_InitConsult(int mode, char *filename)
X_API IOSTREAM * X_API IOSTREAM *
YAP_TermToStream(Term t) YAP_TermToStream(Term t)
{ {
CACHE_REGS
IOSTREAM *s; IOSTREAM *s;
BACKUP_MACHINE_REGS(); BACKUP_MACHINE_REGS();
@ -2937,7 +2937,13 @@ YAP_Init(YAP_init_args *yap_init)
int restore_result; int restore_result;
int do_bootstrap = (yap_init->YapPrologBootFile != NULL); int do_bootstrap = (yap_init->YapPrologBootFile != NULL);
CELL Trail = 0, Stack = 0, Heap = 0, Atts = 0; CELL Trail = 0, Stack = 0, Heap = 0, Atts = 0;
static char boot_file[256]; char boot_file[256];
static int initialised = FALSE;
/* ignore repeated calls to YAP_Init */
if (initialised)
return YAP_BOOT_DONE_BEFOREHAND;
initialised = TRUE;
Yap_InitPageSize(); /* init memory page size, required by later functions */ Yap_InitPageSize(); /* init memory page size, required by later functions */
#if defined(YAPOR_COPY) || defined(YAPOR_COW) || defined(YAPOR_SBA) #if defined(YAPOR_COPY) || defined(YAPOR_COW) || defined(YAPOR_SBA)
@ -3612,9 +3618,22 @@ YAP_ListToFloats(Term t, double *dblp, size_t sz)
if (!IsPairTerm(t)) if (!IsPairTerm(t))
return -1; return -1;
hd = HeadOfTerm(t); hd = HeadOfTerm(t);
if (!IsFloatTerm(hd)) if (IsFloatTerm(hd)) {
return -1; dblp[i++] = FloatOfTerm(hd);
dblp[i++] = FloatOfTerm(hd); } else {
extern double Yap_gmp_to_float(Term hd);
if (IsIntTerm(hd))
dblp[i++] = IntOfTerm(hd);
else if (IsLongIntTerm(hd))
dblp[i++] = LongIntOfTerm(hd);
#if USE_GMP
else if (IsBigIntTerm(hd))
dblp[i++] = Yap_gmp_to_float(hd);
#endif
else
return -1;
}
if (i == sz) if (i == sz)
return sz; return sz;
t = TailOfTerm(t); t = TailOfTerm(t);
@ -4108,3 +4127,24 @@ YAP_ImportTerm(char * buf) {
return Yap_ImportTerm(buf); return Yap_ImportTerm(buf);
} }
X_API int
YAP_RequiresExtraStack(size_t sz) {
CACHE_REGS
if (sz < 16*1024)
sz = 16*1024;
if (H <= ASP-sz) {
return FALSE;
}
BACKUP_H();
while (H > ASP-sz) {
CACHE_REGS
RECOVER_H();
if (!dogc( 0, NULL PASS_REGS )) {
return -1;
}
BACKUP_H();
}
RECOVER_H();
return TRUE;
}

View File

@ -5107,6 +5107,8 @@ p_continue_static_clause( USES_REGS1 )
static void static void
add_code_in_lu_index(LogUpdIndex *cl, PredEntry *pp) add_code_in_lu_index(LogUpdIndex *cl, PredEntry *pp)
{ {
CACHE_REGS
char *code_end = (char *)cl + cl->ClSize; char *code_end = (char *)cl + cl->ClSize;
Yap_inform_profiler_of_clause(cl, code_end, pp, GPROF_LU_INDEX); Yap_inform_profiler_of_clause(cl, code_end, pp, GPROF_LU_INDEX);
cl = cl->ChildIndex; cl = cl->ChildIndex;
@ -5119,6 +5121,7 @@ add_code_in_lu_index(LogUpdIndex *cl, PredEntry *pp)
static void static void
add_code_in_static_index(StaticIndex *cl, PredEntry *pp) add_code_in_static_index(StaticIndex *cl, PredEntry *pp)
{ {
CACHE_REGS
char *code_end = (char *)cl + cl->ClSize; char *code_end = (char *)cl + cl->ClSize;
Yap_inform_profiler_of_clause(cl, code_end, pp, GPROF_STATIC_INDEX); Yap_inform_profiler_of_clause(cl, code_end, pp, GPROF_STATIC_INDEX);
cl = cl->ChildIndex; cl = cl->ChildIndex;
@ -5131,6 +5134,7 @@ add_code_in_static_index(StaticIndex *cl, PredEntry *pp)
static void static void
add_code_in_pred(PredEntry *pp) { add_code_in_pred(PredEntry *pp) {
CACHE_REGS
yamop *clcode; yamop *clcode;
PELOCK(49,pp); PELOCK(49,pp);
@ -5202,6 +5206,7 @@ add_code_in_pred(PredEntry *pp) {
void void
Yap_dump_code_area_for_profiler(void) { Yap_dump_code_area_for_profiler(void) {
CACHE_REGS
ModEntry *me = CurrentModules; ModEntry *me = CurrentModules;
while (me) { while (me) {

View File

@ -1887,6 +1887,7 @@ Yap_new_ludbe(Term t, PredEntry *pe, UInt nargs)
static LogUpdClause * static LogUpdClause *
record_lu(PredEntry *pe, Term t, int position) record_lu(PredEntry *pe, Term t, int position)
{ {
CACHE_REGS
LogUpdClause *cl; LogUpdClause *cl;
if ((cl = new_lu_db_entry(t, pe)) == NULL) { if ((cl = new_lu_db_entry(t, pe)) == NULL) {

View File

@ -1066,6 +1066,7 @@ do_goal(Term t, yamop *CodeAdr, int arity, CELL *pt, int top USES_REGS)
S = CellPtr (RepPredProp (PredPropByFunc (Yap_MkFunctor(AtomCall, 1),0))); /* A1 mishaps */ S = CellPtr (RepPredProp (PredPropByFunc (Yap_MkFunctor(AtomCall, 1),0))); /* A1 mishaps */
out = exec_absmi(top PASS_REGS); out = exec_absmi(top PASS_REGS);
Yap_flush();
// if (out) { // if (out) {
// out = Yap_GetFromSlot(sl); // out = Yap_GetFromSlot(sl);
// } // }

View File

@ -168,6 +168,7 @@ RBfree(rb_red_blk_node *ptr)
static rb_red_blk_node * static rb_red_blk_node *
RBTreeCreate(void) { RBTreeCreate(void) {
CACHE_REGS
rb_red_blk_node* temp; rb_red_blk_node* temp;
/* see the comment in the rb_red_blk_tree structure in red_black_tree.h */ /* see the comment in the rb_red_blk_tree structure in red_black_tree.h */
@ -210,6 +211,7 @@ RBTreeCreate(void) {
static void static void
LeftRotate(rb_red_blk_node* x) { LeftRotate(rb_red_blk_node* x) {
CACHE_REGS
rb_red_blk_node* y; rb_red_blk_node* y;
rb_red_blk_node* nil=LOCAL_ProfilerNil; rb_red_blk_node* nil=LOCAL_ProfilerNil;
@ -266,6 +268,7 @@ LeftRotate(rb_red_blk_node* x) {
static void static void
RightRotate(rb_red_blk_node* y) { RightRotate(rb_red_blk_node* y) {
CACHE_REGS
rb_red_blk_node* x; rb_red_blk_node* x;
rb_red_blk_node* nil=LOCAL_ProfilerNil; rb_red_blk_node* nil=LOCAL_ProfilerNil;
@ -318,6 +321,7 @@ RightRotate(rb_red_blk_node* y) {
static void static void
TreeInsertHelp(rb_red_blk_node* z) { TreeInsertHelp(rb_red_blk_node* z) {
CACHE_REGS
/* This function should only be called by InsertRBTree (see above) */ /* This function should only be called by InsertRBTree (see above) */
rb_red_blk_node* x; rb_red_blk_node* x;
rb_red_blk_node* y; rb_red_blk_node* y;
@ -369,6 +373,7 @@ TreeInsertHelp(rb_red_blk_node* z) {
static rb_red_blk_node * static rb_red_blk_node *
RBTreeInsert(yamop *key, yamop *lim) { RBTreeInsert(yamop *key, yamop *lim) {
CACHE_REGS
rb_red_blk_node * y; rb_red_blk_node * y;
rb_red_blk_node * x; rb_red_blk_node * x;
rb_red_blk_node * newNode; rb_red_blk_node * newNode;
@ -440,6 +445,7 @@ RBTreeInsert(yamop *key, yamop *lim) {
static rb_red_blk_node* static rb_red_blk_node*
RBExactQuery(yamop* q) { RBExactQuery(yamop* q) {
CACHE_REGS
rb_red_blk_node* x; rb_red_blk_node* x;
rb_red_blk_node* nil=LOCAL_ProfilerNil; rb_red_blk_node* nil=LOCAL_ProfilerNil;
@ -460,6 +466,7 @@ RBExactQuery(yamop* q) {
static rb_red_blk_node* static rb_red_blk_node*
RBLookup(yamop *entry) { RBLookup(yamop *entry) {
CACHE_REGS
rb_red_blk_node *current; rb_red_blk_node *current;
if (!LOCAL_ProfilerRoot) if (!LOCAL_ProfilerRoot)
@ -495,6 +502,7 @@ RBLookup(yamop *entry) {
/***********************************************************************/ /***********************************************************************/
static void RBDeleteFixUp(rb_red_blk_node* x) { static void RBDeleteFixUp(rb_red_blk_node* x) {
CACHE_REGS
rb_red_blk_node* root=LOCAL_ProfilerRoot->left; rb_red_blk_node* root=LOCAL_ProfilerRoot->left;
rb_red_blk_node *w; rb_red_blk_node *w;
@ -574,6 +582,7 @@ static void RBDeleteFixUp(rb_red_blk_node* x) {
static rb_red_blk_node* static rb_red_blk_node*
TreeSuccessor(rb_red_blk_node* x) { TreeSuccessor(rb_red_blk_node* x) {
CACHE_REGS
rb_red_blk_node* y; rb_red_blk_node* y;
rb_red_blk_node* nil=LOCAL_ProfilerNil; rb_red_blk_node* nil=LOCAL_ProfilerNil;
rb_red_blk_node* root=LOCAL_ProfilerRoot; rb_red_blk_node* root=LOCAL_ProfilerRoot;
@ -612,6 +621,7 @@ TreeSuccessor(rb_red_blk_node* x) {
static void static void
RBDelete(rb_red_blk_node* z){ RBDelete(rb_red_blk_node* z){
CACHE_REGS
rb_red_blk_node* y; rb_red_blk_node* y;
rb_red_blk_node* x; rb_red_blk_node* x;
rb_red_blk_node* nil=LOCAL_ProfilerNil; rb_red_blk_node* nil=LOCAL_ProfilerNil;
@ -664,7 +674,8 @@ RBDelete(rb_red_blk_node* z){
char *set_profile_dir(char *); char *set_profile_dir(char *);
char *set_profile_dir(char *name){ char *set_profile_dir(char *name){
int size=0; CACHE_REGS
int size=0;
if (name!=NULL) { if (name!=NULL) {
size=strlen(name)+1; size=strlen(name)+1;
@ -687,8 +698,9 @@ return LOCAL_DIRNAME;
char *profile_names(int); char *profile_names(int);
char *profile_names(int k) { char *profile_names(int k) {
static char *FNAME=NULL; CACHE_REGS
int size=200; static char *FNAME=NULL;
int size=200;
if (LOCAL_DIRNAME==NULL) set_profile_dir(NULL); if (LOCAL_DIRNAME==NULL) set_profile_dir(NULL);
size=strlen(LOCAL_DIRNAME)+40; size=strlen(LOCAL_DIRNAME)+40;
@ -709,6 +721,7 @@ int size=200;
void del_profile_files(void); void del_profile_files(void);
void del_profile_files() { void del_profile_files() {
CACHE_REGS
if (LOCAL_DIRNAME!=NULL) { if (LOCAL_DIRNAME!=NULL) {
remove(profile_names(PROFPREDS_FILE)); remove(profile_names(PROFPREDS_FILE));
remove(profile_names(PROFILING_FILE)); remove(profile_names(PROFILING_FILE));
@ -717,6 +730,7 @@ void del_profile_files() {
void void
Yap_inform_profiler_of_clause__(void *code_start, void *code_end, PredEntry *pe,gprof_info index_code) { Yap_inform_profiler_of_clause__(void *code_start, void *code_end, PredEntry *pe,gprof_info index_code) {
CACHE_REGS
buf_ptr b; buf_ptr b;
buf_extra e; buf_extra e;
LOCAL_ProfOn = TRUE; LOCAL_ProfOn = TRUE;
@ -742,6 +756,7 @@ static Int profend( USES_REGS1 );
static void static void
clean_tree(rb_red_blk_node* node) { clean_tree(rb_red_blk_node* node) {
CACHE_REGS
if (node == LOCAL_ProfilerNil) if (node == LOCAL_ProfilerNil)
return; return;
clean_tree(node->left); clean_tree(node->left);
@ -751,6 +766,7 @@ clean_tree(rb_red_blk_node* node) {
static void static void
reset_tree(void) { reset_tree(void) {
CACHE_REGS
clean_tree(LOCAL_ProfilerRoot); clean_tree(LOCAL_ProfilerRoot);
Yap_FreeCodeSpace((char *)LOCAL_ProfilerNil); Yap_FreeCodeSpace((char *)LOCAL_ProfilerNil);
LOCAL_ProfilerNil = LOCAL_ProfilerRoot = NULL; LOCAL_ProfilerNil = LOCAL_ProfilerRoot = NULL;
@ -760,6 +776,7 @@ reset_tree(void) {
static int static int
InitProfTree(void) InitProfTree(void)
{ {
CACHE_REGS
if (LOCAL_ProfilerRoot) if (LOCAL_ProfilerRoot)
reset_tree(); reset_tree();
while (!(LOCAL_ProfilerRoot = RBTreeCreate())) { while (!(LOCAL_ProfilerRoot = RBTreeCreate())) {
@ -773,6 +790,7 @@ InitProfTree(void)
static void RemoveCode(CODEADDR clau) static void RemoveCode(CODEADDR clau)
{ {
CACHE_REGS
rb_red_blk_node* x, *node; rb_red_blk_node* x, *node;
PredEntry *pp; PredEntry *pp;
UInt count; UInt count;
@ -958,6 +976,7 @@ prof_alrm(int signo, siginfo_t *si, void *scv)
void void
Yap_InformOfRemoval(void *clau) Yap_InformOfRemoval(void *clau)
{ {
CACHE_REGS
LOCAL_ProfOn = TRUE; LOCAL_ProfOn = TRUE;
if (LOCAL_FPreds != NULL) { if (LOCAL_FPreds != NULL) {
/* just store info about what is going on */ /* just store info about what is going on */
@ -1048,6 +1067,7 @@ static Int profinit( USES_REGS1 )
static Int start_profilers(int msec) static Int start_profilers(int msec)
{ {
CACHE_REGS
struct itimerval t; struct itimerval t;
struct sigaction sa; struct sigaction sa;
@ -1157,6 +1177,7 @@ static Int profres0( USES_REGS1 ) {
void void
Yap_InitLowProf(void) Yap_InitLowProf(void)
{ {
CACHE_REGS
#if LOW_PROF #if LOW_PROF
LOCAL_ProfCalls = 0; LOCAL_ProfCalls = 0;
LOCAL_ProfilerOn = FALSE; LOCAL_ProfilerOn = FALSE;

View File

@ -718,6 +718,11 @@ AdjustScannerStacks(TokEntry **tksp, VarEntry **vep USES_REGS)
TokEntry *tktmp; TokEntry *tktmp;
switch (tks->Tok) { switch (tks->Tok) {
case Number_tok:
if (IsApplTerm(tks->TokInfo)) {
tks->TokInfo = AdjustAppl(tks->TokInfo PASS_REGS);
}
break;
case Var_tok: case Var_tok:
case String_tok: case String_tok:
if (IsOldTrail(tks->TokInfo)) if (IsOldTrail(tks->TokInfo))

View File

@ -1888,6 +1888,7 @@ emit_single_switch_case(ClauseDef *min, struct intermediates *cint, int first, i
static UInt static UInt
suspend_indexing(ClauseDef *min, ClauseDef *max, PredEntry *ap, struct intermediates *cint) suspend_indexing(ClauseDef *min, ClauseDef *max, PredEntry *ap, struct intermediates *cint)
{ {
CACHE_REGS
UInt tcls = ap->cs.p_code.NOfClauses; UInt tcls = ap->cs.p_code.NOfClauses;
UInt cls = (max-min)+1; UInt cls = (max-min)+1;

View File

@ -993,6 +993,15 @@ p_read_module_preds( USES_REGS1 )
return TRUE; return TRUE;
} }
static void
ReInitCatch(void)
{
Term t = Yap_MkNewApplTerm(PredHandleThrow->FunctorOfPred, PredHandleThrow->ArityOfPE);
YAP_RunGoalOnce(t);
}
static Int static Int
p_read_program( USES_REGS1 ) p_read_program( USES_REGS1 )
{ {
@ -1016,7 +1025,7 @@ p_read_program( USES_REGS1 )
Sclose( stream ); Sclose( stream );
/* back to the top level we go */ /* back to the top level we go */
Yap_CloseSlots(PASS_REGS1); Yap_CloseSlots(PASS_REGS1);
ReInitCatch();
Yap_RestartYap( 3 ); Yap_RestartYap( 3 );
return TRUE; return TRUE;
} }
@ -1030,6 +1039,7 @@ Yap_Restore(char *s, char *lib_dir)
return -1; return -1;
read_module(stream); read_module(stream);
Sclose( stream ); Sclose( stream );
ReInitCatch();
return DO_ONLY_CODE; return DO_ONLY_CODE;
} }

View File

@ -1619,6 +1619,7 @@ InteractSIGINT(int ch) {
static int static int
ProcessSIGINT(void) ProcessSIGINT(void)
{ {
CACHE_REGS
int ch, out; int ch, out;
LOCAL_PrologMode |= AsyncIntMode; LOCAL_PrologMode |= AsyncIntMode;

View File

@ -52,17 +52,7 @@ send_tracer_message(char *start, char *name, Int arity, char *mname, CELL *args)
if (args) { if (args) {
for (i= 0; i < arity; i++) { for (i= 0; i < arity; i++) {
if (i > 0) fprintf(GLOBAL_stderr, ","); if (i > 0) fprintf(GLOBAL_stderr, ",");
#if DEBUG Yap_plwrite(args[i], NULL, 15, Handle_vars_f|AttVar_Portray_f, 1200);
#if COROUTINING
Yap_Portray_delays = TRUE;
#endif
#endif
Yap_plwrite(args[i], NULL, 15, Handle_vars_f, 1200);
#if DEBUG
#if COROUTINING
Yap_Portray_delays = FALSE;
#endif
#endif
} }
if (arity) { if (arity) {
fprintf(GLOBAL_stderr, ")"); fprintf(GLOBAL_stderr, ")");

View File

@ -4255,7 +4255,7 @@ p_is_list_or_partial_list( USES_REGS1 )
} }
static Term static Term
numbervar(Int id) numbervar(Int id USES_REGS)
{ {
Term ts[1]; Term ts[1];
ts[0] = MkIntegerTerm(id); ts[0] = MkIntegerTerm(id);
@ -4263,7 +4263,7 @@ numbervar(Int id)
} }
static Term static Term
numbervar_singleton(void) numbervar_singleton(USES_REGS1)
{ {
Term ts[1]; Term ts[1];
ts[0] = MkIntegerTerm(-1); ts[0] = MkIntegerTerm(-1);
@ -4356,9 +4356,9 @@ static Int numbervars_in_complex_term(register CELL *pt0, register CELL *pt0_end
derefa_body(d0, ptd0, vars_in_term_unk, vars_in_term_nvar); derefa_body(d0, ptd0, vars_in_term_unk, vars_in_term_nvar);
/* do or pt2 are unbound */ /* do or pt2 are unbound */
if (singles) if (singles)
*ptd0 = numbervar_singleton(); *ptd0 = numbervar_singleton( PASS_REGS1 );
else else
*ptd0 = numbervar(numbv++); *ptd0 = numbervar(numbv++ PASS_REGS);
/* leave an empty slot to fill in later */ /* leave an empty slot to fill in later */
if (H+1024 > ASP) { if (H+1024 > ASP) {
goto global_overflow; goto global_overflow;
@ -4450,10 +4450,10 @@ Yap_NumberVars( Term inp, Int numbv, int handle_singles ) /* numbervariables in
CELL *ptd0 = VarOfTerm(t); CELL *ptd0 = VarOfTerm(t);
TrailTerm(TR++) = (CELL)ptd0; TrailTerm(TR++) = (CELL)ptd0;
if (handle_singles) { if (handle_singles) {
*ptd0 = numbervar_singleton(); *ptd0 = numbervar_singleton( PASS_REGS1 );
return numbv; return numbv;
} else { } else {
*ptd0 = numbervar(numbv); *ptd0 = numbervar(numbv PASS_REGS);
return numbv+1; return numbv+1;
} }
} else if (IsPrimitiveTerm(t)) { } else if (IsPrimitiveTerm(t)) {

164
C/write.c
View File

@ -66,7 +66,7 @@ typedef struct rewind_term {
typedef struct write_globs { typedef struct write_globs {
void *stream; void *stream;
int Quote_illegal, Ignore_ops, Handle_vars, Use_portray; int Quote_illegal, Ignore_ops, Handle_vars, Use_portray, Portray_delays;
int Keep_terms; int Keep_terms;
int Write_Loops; int Write_Loops;
int Write_strings; int Write_strings;
@ -90,28 +90,50 @@ STATIC_PROTO(void writeTerm, (Term, int, int, int, struct write_globs *, struct
#define wrputc(X,WF) Sputcode(X,WF) /* writes a character */ #define wrputc(X,WF) Sputcode(X,WF) /* writes a character */
/*
protect bracket from merging with previoous character.
avoid stuff like not (2,3) -> not(2,3) or
*/
static void static void
protect_open_number(struct write_globs *wglb, int minus_required) wropen_bracket(struct write_globs *wglb, int protect)
{ {
wrf stream = wglb->stream; wrf stream = wglb->stream;
if (lastw == symbol && last_minus && !minus_required) { if (lastw != separator && protect)
if (!wglb->Ignore_ops) { wrputc(' ', stream);
/* protect against collating - with number, and getting - 1 ^2 as (-(1))^2 */ wrputc('(', stream);
wrputc(' ', wglb->stream); lastw = separator;
}
wrputc('(', wglb->stream);
} else if (lastw == alphanum) {
wrputc(' ', stream);
}
} }
static void static void
protect_close_number(struct write_globs *wglb, int minus_required) wrclose_bracket(struct write_globs *wglb, int protect)
{ {
if (lastw == symbol && last_minus && !minus_required) { wrf stream = wglb->stream;
wrputc(')', wglb->stream);
lastw = separator; wrputc(')', stream);
lastw = separator;
}
static int
protect_open_number(struct write_globs *wglb, int lm, int minus_required)
{
wrf stream = wglb->stream;
if (lastw == symbol && lm && !minus_required) {
wropen_bracket(wglb, TRUE);
return TRUE;
} else if (lastw == alphanum ||
(lastw == symbol && minus_required)) {
wrputc(' ', stream);
}
return FALSE;
}
static void
protect_close_number(struct write_globs *wglb, int used_bracket)
{
if (used_bracket) {
wrclose_bracket(wglb, TRUE);
} else { } else {
lastw = alphanum; lastw = alphanum;
} }
@ -125,8 +147,9 @@ wrputn(Int n, struct write_globs *wglb) /* writes an integer */
wrf stream = wglb->stream; wrf stream = wglb->stream;
char s[256], *s1=s; /* that should be enough for most integers */ char s[256], *s1=s; /* that should be enough for most integers */
int has_minus = (n < 0); int has_minus = (n < 0);
int ob;
protect_open_number(wglb, has_minus); ob = protect_open_number(wglb, last_minus, has_minus);
#if HAVE_SNPRINTF #if HAVE_SNPRINTF
snprintf(s, 256, Int_FORMAT, n); snprintf(s, 256, Int_FORMAT, n);
#else #else
@ -134,7 +157,7 @@ wrputn(Int n, struct write_globs *wglb) /* writes an integer */
#endif #endif
while (*s1) while (*s1)
wrputc(*s1++, stream); wrputc(*s1++, stream);
protect_close_number(wglb, has_minus); protect_close_number(wglb, ob);
} }
#define wrputs(s, stream) Sfputs(s, stream) #define wrputs(s, stream) Sfputs(s, stream)
@ -190,9 +213,10 @@ static void
write_mpint(MP_INT *big, struct write_globs *wglb) { write_mpint(MP_INT *big, struct write_globs *wglb) {
char *s; char *s;
int has_minus = mpz_sgn(big); int has_minus = mpz_sgn(big);
int ob;
s = ensure_space(3+mpz_sizeinbase(big, 10)); s = ensure_space(3+mpz_sizeinbase(big, 10));
protect_open_number(wglb, has_minus); ob = protect_open_number(wglb, last_minus, has_minus);
if (!s) { if (!s) {
s = mpz_get_str(NULL, 10, big); s = mpz_get_str(NULL, 10, big);
if (!s) if (!s)
@ -203,7 +227,7 @@ write_mpint(MP_INT *big, struct write_globs *wglb) {
mpz_get_str(s, 10, big); mpz_get_str(s, 10, big);
wrputs(s,wglb->stream); wrputs(s,wglb->stream);
} }
protect_close_number(wglb, has_minus); protect_close_number(wglb, ob);
} }
#endif #endif
@ -271,6 +295,8 @@ wrputf(Float f, struct write_globs *wglb) /* writes a float */
char s[256]; char s[256];
wrf stream = wglb->stream; wrf stream = wglb->stream;
int sgn; int sgn;
int ob;
#if HAVE_ISNAN || defined(__WIN32) #if HAVE_ISNAN || defined(__WIN32)
if (isnan(f)) { if (isnan(f)) {
@ -291,7 +317,7 @@ wrputf(Float f, struct write_globs *wglb) /* writes a float */
return; return;
} }
#endif #endif
protect_open_number(wglb, sgn); ob = protect_open_number(wglb, last_minus, sgn);
#if THREADS #if THREADS
/* old style writing */ /* old style writing */
int found_dot = FALSE, found_exp = FALSE; int found_dot = FALSE, found_exp = FALSE;
@ -343,7 +369,7 @@ wrputf(Float f, struct write_globs *wglb) /* writes a float */
if (!buf) return; if (!buf) return;
wrputs(buf, stream); wrputs(buf, stream);
#endif #endif
protect_close_number(wglb, sgn); protect_close_number(wglb, ob);
} }
/* writes a data base reference */ /* writes a data base reference */
@ -423,7 +449,7 @@ AtomIsSymbols(unsigned char *s) /* Is this atom just formed by symbols ? */
return(separator); return(separator);
while ((ch = *s++) != '\0') { while ((ch = *s++) != '\0') {
if (Yap_chtype[ch] != SY) if (Yap_chtype[ch] != SY)
return(alphanum); return alphanum;
} }
return symbol; return symbol;
} }
@ -669,15 +695,14 @@ write_var(CELL *t, struct write_globs *wglb, struct rewind_term *rwt)
/* make sure we don't get no creepy spaces where they shouldn't be */ /* make sure we don't get no creepy spaces where they shouldn't be */
lastw = separator; lastw = separator;
if (IsAttVar(t)) { if (IsAttVar(t)) {
#if defined(COROUTINING) && defined(DEBUG)
Int vcount = (t-H0); Int vcount = (t-H0);
if (Yap_Portray_delays) { if (wglb->Portray_delays) {
exts ext = ExtFromCell(t); exts ext = ExtFromCell(t);
struct rewind_term nrwt; struct rewind_term nrwt;
nrwt.parent = rwt; nrwt.parent = rwt;
nrwt.u.s.ptr = 0; nrwt.u.s.ptr = 0;
Yap_Portray_delays = FALSE; wglb->Portray_delays = FALSE;
if (ext == attvars_ext) { if (ext == attvars_ext) {
attvar_record *attv = RepAttVar(t); attvar_record *attv = RepAttVar(t);
CELL *l = &attv->Value; /* dirty low-level hack, check atts.h */ CELL *l = &attv->Value; /* dirty low-level hack, check atts.h */
@ -691,14 +716,13 @@ write_var(CELL *t, struct write_globs *wglb, struct rewind_term *rwt)
l += 2; l += 2;
writeTerm(from_pointer(l, &nrwt, wglb), 999, 1, FALSE, wglb, &nrwt); writeTerm(from_pointer(l, &nrwt, wglb), 999, 1, FALSE, wglb, &nrwt);
restore_from_write(&nrwt, wglb); restore_from_write(&nrwt, wglb);
wrputc(')', wglb->stream); wrclose_bracket(wglb, TRUE);
} }
Yap_Portray_delays = TRUE; wglb->Portray_delays = TRUE;
return; return;
} }
wrputc('D', wglb->stream); wrputc('D', wglb->stream);
wrputn(vcount,wglb); wrputn(vcount,wglb);
#endif
} else { } else {
wrputn(((Int) (t- H0)),wglb); wrputn(((Int) (t- H0)),wglb);
} }
@ -790,6 +814,7 @@ write_list(Term t, int direction, int depth, struct write_globs *wglb, struct re
} }
} }
static void static void
writeTerm(Term t, int p, int depth, int rinfixarg, struct write_globs *wglb, struct rewind_term *rwt) writeTerm(Term t, int p, int depth, int rinfixarg, struct write_globs *wglb, struct rewind_term *rwt)
/* term to write */ /* term to write */
@ -823,8 +848,7 @@ writeTerm(Term t, int p, int depth, int rinfixarg, struct write_globs *wglb, str
wrputs(",",wglb->stream); wrputs(",",wglb->stream);
writeTerm(from_pointer(RepPair(t)+1, &nrwt, wglb), 999, depth + 1, FALSE, wglb, &nrwt); writeTerm(from_pointer(RepPair(t)+1, &nrwt, wglb), 999, depth + 1, FALSE, wglb, &nrwt);
restore_from_write(&nrwt, wglb); restore_from_write(&nrwt, wglb);
wrputc(')', wglb->stream); wrclose_bracket(wglb, TRUE);
lastw = separator;
return; return;
} }
if (wglb->Use_portray) { if (wglb->Use_portray) {
@ -886,7 +910,7 @@ writeTerm(Term t, int p, int depth, int rinfixarg, struct write_globs *wglb, str
int argno = 1; int argno = 1;
CELL *p = ArgsOfSFTerm(t); CELL *p = ArgsOfSFTerm(t);
putAtom(atom, wglb->Quote_illegal, wglb); putAtom(atom, wglb->Quote_illegal, wglb);
wrputc('(', wglb->stream); wropen_bracket(wglb, FALSE);
lastw = separator; lastw = separator;
while (*p) { while (*p) {
Int sl = 0; Int sl = 0;
@ -904,8 +928,7 @@ writeTerm(Term t, int p, int depth, int rinfixarg, struct write_globs *wglb, str
wrputc(',', wglb->stream); wrputc(',', wglb->stream);
argno++; argno++;
} }
wrputc(')', wglb->stream); wrclose_bracket(wglb, TRUE);
lastw = separator;
return; return;
} }
#endif #endif
@ -934,28 +957,22 @@ writeTerm(Term t, int p, int depth, int rinfixarg, struct write_globs *wglb, str
!IsVarTerm(tright) && IsAtomTerm(tright) && !IsVarTerm(tright) && IsAtomTerm(tright) &&
Yap_IsOp(AtomOfTerm(tright)); Yap_IsOp(AtomOfTerm(tright));
if (op > p) { if (op > p) {
/* avoid stuff such as \+ (a,b) being written as \+(a,b) */ wropen_bracket(wglb, TRUE);
if (lastw != separator && !rinfixarg)
wrputc(' ', wglb->stream);
wrputc('(', wglb->stream);
lastw = separator;
} }
putAtom(atom, wglb->Quote_illegal, wglb); putAtom(atom, wglb->Quote_illegal, wglb);
if (bracket_right) { if (bracket_right) {
wrputc('(', wglb->stream); /* avoid stuff such as \+ (a,b) being written as \+(a,b) */
lastw = separator; wropen_bracket(wglb, TRUE);
} else if (atom == AtomMinus) { } else if (atom == AtomMinus) {
last_minus = TRUE; last_minus = TRUE;
} }
writeTerm(from_pointer(RepAppl(t)+1, &nrwt, wglb), rp, depth + 1, FALSE, wglb, &nrwt); writeTerm(from_pointer(RepAppl(t)+1, &nrwt, wglb), rp, depth + 1, TRUE, wglb, &nrwt);
restore_from_write(&nrwt, wglb); restore_from_write(&nrwt, wglb);
if (bracket_right) { if (bracket_right) {
wrputc(')', wglb->stream); wrclose_bracket(wglb, TRUE);
lastw = separator;
} }
if (op > p) { if (op > p) {
wrputc(')', wglb->stream); wrclose_bracket(wglb, TRUE);
lastw = separator;
} }
} else if (!wglb->Ignore_ops && } else if (!wglb->Ignore_ops &&
Arity == 1 && Arity == 1 &&
@ -963,29 +980,24 @@ writeTerm(Term t, int p, int depth, int rinfixarg, struct write_globs *wglb, str
Term tleft = ArgOfTerm(1, t); Term tleft = ArgOfTerm(1, t);
int bracket_left = int bracket_left =
!IsVarTerm(tleft) && IsAtomTerm(tleft) && !IsVarTerm(tleft) &&
IsAtomTerm(tleft) &&
Yap_IsOp(AtomOfTerm(tleft)); Yap_IsOp(AtomOfTerm(tleft));
if (op > p) { if (op > p) {
/* avoid stuff such as \+ (a,b) being written as \+(a,b) */ /* avoid stuff such as \+ (a,b) being written as \+(a,b) */
if (lastw != separator && !rinfixarg) wropen_bracket(wglb, TRUE);
wrputc(' ', wglb->stream);
wrputc('(', wglb->stream);
lastw = separator;
} }
if (bracket_left) { if (bracket_left) {
wrputc('(', wglb->stream); wropen_bracket(wglb, TRUE);
lastw = separator;
} }
writeTerm(from_pointer(RepAppl(t)+1, &nrwt, wglb), lp, depth + 1, rinfixarg, wglb, &nrwt); writeTerm(from_pointer(RepAppl(t)+1, &nrwt, wglb), lp, depth + 1, rinfixarg, wglb, &nrwt);
restore_from_write(&nrwt, wglb); restore_from_write(&nrwt, wglb);
if (bracket_left) { if (bracket_left) {
wrputc(')', wglb->stream); wrclose_bracket(wglb, TRUE);
lastw = separator;
} }
putAtom(atom, wglb->Quote_illegal, wglb); putAtom(atom, wglb->Quote_illegal, wglb);
if (op > p) { if (op > p) {
wrputc(')', wglb->stream); wrclose_bracket(wglb, TRUE);
lastw = separator;
} }
} else if (!wglb->Ignore_ops && } else if (!wglb->Ignore_ops &&
Arity == 2 && Yap_IsInfixOp(atom, &op, &lp, Arity == 2 && Yap_IsInfixOp(atom, &op, &lp,
@ -1001,41 +1013,36 @@ writeTerm(Term t, int p, int depth, int rinfixarg, struct write_globs *wglb, str
if (op > p) { if (op > p) {
/* avoid stuff such as \+ (a,b) being written as \+(a,b) */ /* avoid stuff such as \+ (a,b) being written as \+(a,b) */
if (lastw != separator && !rinfixarg) wropen_bracket(wglb, TRUE);
wrputc(' ', wglb->stream);
wrputc('(', wglb->stream);
lastw = separator; lastw = separator;
} }
if (bracket_left) { if (bracket_left) {
wrputc('(', wglb->stream); wropen_bracket(wglb, TRUE);
lastw = separator;
} }
writeTerm(from_pointer(RepAppl(t)+1, &nrwt, wglb), lp, depth + 1, rinfixarg, wglb, &nrwt); writeTerm(from_pointer(RepAppl(t)+1, &nrwt, wglb), lp, depth + 1, rinfixarg, wglb, &nrwt);
t = AbsAppl(restore_from_write(&nrwt, wglb)-1); t = AbsAppl(restore_from_write(&nrwt, wglb)-1);
if (bracket_left) { if (bracket_left) {
wrputc(')', wglb->stream); wrclose_bracket(wglb, TRUE);
lastw = separator;
} }
/* avoid quoting commas */ /* avoid quoting commas and bars */
if (strcmp(RepAtom(atom)->StrOfAE,",")) if (!strcmp(RepAtom(atom)->StrOfAE,",")) {
putAtom(atom, wglb->Quote_illegal, wglb);
else {
wrputc(',', wglb->stream); wrputc(',', wglb->stream);
lastw = separator; lastw = separator;
} } else if (!strcmp(RepAtom(atom)->StrOfAE,"|")) {
if (bracket_right) { wrputc('|', wglb->stream);
wrputc('(', wglb->stream);
lastw = separator; lastw = separator;
} else
putAtom(atom, wglb->Quote_illegal, wglb);
if (bracket_right) {
wropen_bracket(wglb, TRUE);
} }
writeTerm(from_pointer(RepAppl(t)+2, &nrwt, wglb), rp, depth + 1, TRUE, wglb, &nrwt); writeTerm(from_pointer(RepAppl(t)+2, &nrwt, wglb), rp, depth + 1, TRUE, wglb, &nrwt);
restore_from_write(&nrwt, wglb); restore_from_write(&nrwt, wglb);
if (bracket_right) { if (bracket_right) {
wrputc(')', wglb->stream); wrclose_bracket(wglb, TRUE);
lastw = separator;
} }
if (op > p) { if (op > p) {
wrputc(')', wglb->stream); wrclose_bracket(wglb, TRUE);
lastw = separator;
} }
} else if (wglb->Handle_vars && functor == LOCAL_FunctorVar) { } else if (wglb->Handle_vars && functor == LOCAL_FunctorVar) {
Term ti = ArgOfTerm(1, t); Term ti = ArgOfTerm(1, t);
@ -1068,8 +1075,7 @@ writeTerm(Term t, int p, int depth, int rinfixarg, struct write_globs *wglb, str
lastw = separator; lastw = separator;
writeTerm(from_pointer(RepAppl(t)+1, &nrwt, wglb), 999, depth + 1, FALSE, wglb, &nrwt); writeTerm(from_pointer(RepAppl(t)+1, &nrwt, wglb), 999, depth + 1, FALSE, wglb, &nrwt);
restore_from_write(&nrwt, wglb); restore_from_write(&nrwt, wglb);
wrputc(')', wglb->stream); wrclose_bracket(wglb, TRUE);
lastw = separator;
} }
} else if (!wglb->Ignore_ops && functor == FunctorBraces) { } else if (!wglb->Ignore_ops && functor == FunctorBraces) {
wrputc('{', wglb->stream); wrputc('{', wglb->stream);
@ -1098,7 +1104,7 @@ writeTerm(Term t, int p, int depth, int rinfixarg, struct write_globs *wglb, str
} else { } else {
putAtom(atom, wglb->Quote_illegal, wglb); putAtom(atom, wglb->Quote_illegal, wglb);
lastw = separator; lastw = separator;
wrputc('(', wglb->stream); wropen_bracket(wglb, FALSE);
for (op = 1; op <= Arity; ++op) { for (op = 1; op <= Arity; ++op) {
if (op == wglb->MaxArgs) { if (op == wglb->MaxArgs) {
wrputc('.', wglb->stream); wrputc('.', wglb->stream);
@ -1113,8 +1119,7 @@ writeTerm(Term t, int p, int depth, int rinfixarg, struct write_globs *wglb, str
lastw = separator; lastw = separator;
} }
} }
wrputc(')', wglb->stream); wrclose_bracket(wglb, TRUE);
lastw = separator;
} }
} }
} }
@ -1138,6 +1143,7 @@ Yap_plwrite(Term t, void *mywrite, int max_depth, int flags, int priority)
wglb.Quote_illegal = flags & Quote_illegal_f; wglb.Quote_illegal = flags & Quote_illegal_f;
wglb.Handle_vars = flags & Handle_vars_f; wglb.Handle_vars = flags & Handle_vars_f;
wglb.Use_portray = flags & Use_portray_f; wglb.Use_portray = flags & Use_portray_f;
wglb.Portray_delays = flags & AttVar_Portray_f;
wglb.MaxDepth = max_depth; wglb.MaxDepth = max_depth;
wglb.MaxArgs = max_depth; wglb.MaxArgs = max_depth;
/* notice: we must have ASP well set when using portray, otherwise /* notice: we must have ASP well set when using portray, otherwise

View File

@ -498,6 +498,7 @@ void STD_PROTO(Yap_init_optyap_preds,(void));
/* pl-file.c */ /* pl-file.c */
struct PL_local_data *Yap_InitThreadIO(int wid); struct PL_local_data *Yap_InitThreadIO(int wid);
void Yap_flush(void);
static inline static inline
yamop * yamop *

View File

@ -159,6 +159,9 @@ typedef enum
#ifdef HAVE_LOCALE_H #ifdef HAVE_LOCALE_H
#include <locale.h> #include <locale.h>
#endif #endif
#ifdef HAVE_LIMITS_H /* get MAXPATHLEN */
#include <limits.h>
#endif
#include <setjmp.h> #include <setjmp.h>
#include <assert.h> #include <assert.h>
#if HAVE_SYS_PARAM_H #if HAVE_SYS_PARAM_H

View File

@ -323,12 +323,6 @@ int STD_PROTO(Yap_growtrail_in_parser, (tr_fr_ptr *, TokEntry **, VarEntry **)
extern int errno; extern int errno;
#endif #endif
#ifdef DEBUG
#if COROUTINING
extern int Yap_Portray_delays;
#endif
#endif
EXTERN inline UInt STD_PROTO(HashFunction, (unsigned char *)); EXTERN inline UInt STD_PROTO(HashFunction, (unsigned char *));
EXTERN inline UInt STD_PROTO(WideHashFunction, (wchar_t *)); EXTERN inline UInt STD_PROTO(WideHashFunction, (wchar_t *));

View File

@ -710,6 +710,7 @@ all: startup.yss
@ENABLE_CPLINT@ (cd packages/cplint; $(MAKE)) @ENABLE_CPLINT@ (cd packages/cplint; $(MAKE))
@ENABLE_CPLINT@ (cd packages/cplint/slipcase; $(MAKE)) @ENABLE_CPLINT@ (cd packages/cplint/slipcase; $(MAKE))
@ENABLE_PRISM@ (cd packages/prism/src/c; $(MAKE)) @ENABLE_PRISM@ (cd packages/prism/src/c; $(MAKE))
@ENABLE_BDDLIB@ (cd packages/bdd; $(MAKE))
@ENABLE_CUDD@ (cd packages/ProbLog/simplecudd; $(MAKE)) @ENABLE_CUDD@ (cd packages/ProbLog/simplecudd; $(MAKE))
@ENABLE_CUDD@ (cd packages/ProbLog/simplecudd_lfi; $(MAKE)) @ENABLE_CUDD@ (cd packages/ProbLog/simplecudd_lfi; $(MAKE))
@ENABLE_JPL@ @INSTALL_DLLS@ (cd packages/jpl; $(MAKE)) @ENABLE_JPL@ @INSTALL_DLLS@ (cd packages/jpl; $(MAKE))
@ -788,6 +789,7 @@ install_unix: startup.yss libYap.a
@ENABLE_CPLINT@ (cd packages/cplint/approx/simplecuddLPADs; $(MAKE) install) @ENABLE_CPLINT@ (cd packages/cplint/approx/simplecuddLPADs; $(MAKE) install)
@ENABLE_PRISM@ (cd packages/prism/src/c; $(MAKE) install) @ENABLE_PRISM@ (cd packages/prism/src/c; $(MAKE) install)
@ENABLE_PRISM@ (cd packages/prism/src/prolog; $(MAKE) install) @ENABLE_PRISM@ (cd packages/prism/src/prolog; $(MAKE) install)
@ENABLE_BDDLIB@ (cd packages/bdd; $(MAKE) install)
@ENABLE_CUDD@ (cd packages/ProbLog/simplecudd; $(MAKE) install) @ENABLE_CUDD@ (cd packages/ProbLog/simplecudd; $(MAKE) install)
@ENABLE_CUDD@ (cd packages/ProbLog/simplecudd_lfi; $(MAKE) install) @ENABLE_CUDD@ (cd packages/ProbLog/simplecudd_lfi; $(MAKE) install)
@ -840,6 +842,7 @@ install_win32: startup.yss @ENABLE_WINCONSOLE@ pl-yap@EXEC_SUFFIX@
@ENABLE_CPLINT@ (cd packages/cplint; $(MAKE) install) @ENABLE_CPLINT@ (cd packages/cplint; $(MAKE) install)
@ENABLE_PRISM@ (cd packages/prism/src/c; $(MAKE) install) @ENABLE_PRISM@ (cd packages/prism/src/c; $(MAKE) install)
@ENABLE_PRISM@ (cd packages/prism/src/prolog; $(MAKE) install) @ENABLE_PRISM@ (cd packages/prism/src/prolog; $(MAKE) install)
@ENABLE_BDDLIB@ (cd packages/bdd; $(MAKE) install)
@ENABLE_CUDD@ (cd packages/ProbLog/simplecudd; $(MAKE) install) @ENABLE_CUDD@ (cd packages/ProbLog/simplecudd; $(MAKE) install)
@ENABLE_CUDD@ (cd packages/ProbLog/simplecudd_lfi; $(MAKE) install) @ENABLE_CUDD@ (cd packages/ProbLog/simplecudd_lfi; $(MAKE) install)
@ -904,6 +907,7 @@ clean: clean_docs
@ENABLE_PRISM@ (cd packages/prism/src/prolog; $(MAKE) clean) @ENABLE_PRISM@ (cd packages/prism/src/prolog; $(MAKE) clean)
@ENABLE_CPLINT@ (cd packages/cplint/approx/simplecuddLPADs; $(MAKE) clean) @ENABLE_CPLINT@ (cd packages/cplint/approx/simplecuddLPADs; $(MAKE) clean)
@ENABLE_CPLINT@ (cd packages/cplint; $(MAKE) clean) @ENABLE_CPLINT@ (cd packages/cplint; $(MAKE) clean)
@ENABLE_BDDLIB@ (cd packages/bdd; $(MAKE) clean)
@ENABLE_CUDD@ (cd packages/ProbLog/simplecudd; $(MAKE) clean) @ENABLE_CUDD@ (cd packages/ProbLog/simplecudd; $(MAKE) clean)
@ENABLE_CUDD@ (cd packages/ProbLog/simplecudd_lfi; $(MAKE) clean) @ENABLE_CUDD@ (cd packages/ProbLog/simplecudd_lfi; $(MAKE) clean)
@ENABLE_JPL@ @INSTALL_DLLS@ (cd packages/jpl; $(MAKE) clean) @ENABLE_JPL@ @INSTALL_DLLS@ (cd packages/jpl; $(MAKE) clean)

View File

@ -257,6 +257,10 @@
#undef HAVE_WAITPID #undef HAVE_WAITPID
#undef HAVE_MPZ_XOR #undef HAVE_MPZ_XOR
#if HAVE_GETHOSTNAME==1
#define HAS_GETHOSTNAME 1
#endif
#undef HAVE_SIGINFO #undef HAVE_SIGINFO
#undef HAVE_SIGSEGV #undef HAVE_SIGSEGV
#undef HAVE_SIGPROF #undef HAVE_SIGPROF

22
configure vendored
View File

@ -625,6 +625,7 @@ ENABLE_REAL
ENABLE_MINISAT ENABLE_MINISAT
CUDD_CPPFLAGS CUDD_CPPFLAGS
CUDD_LDFLAGS CUDD_LDFLAGS
ENABLE_BDDLIB
ENABLE_CUDD ENABLE_CUDD
EXTRA_INCLUDES_FOR_WIN32 EXTRA_INCLUDES_FOR_WIN32
ENABLE_WINCONSOLE ENABLE_WINCONSOLE
@ -788,6 +789,7 @@ enable_depth_limit
enable_wam_profile enable_wam_profile
enable_low_level_tracer enable_low_level_tracer
enable_threads enable_threads
enable_bddlib
enable_pthread_locking enable_pthread_locking
enable_max_performance enable_max_performance
enable_max_memory enable_max_memory
@ -1462,6 +1464,7 @@ Optional Features:
--enable-wam-profile support low level profiling of abstract machine --enable-wam-profile support low level profiling of abstract machine
--enable-low-level-tracer support support for procedure-call tracing --enable-low-level-tracer support support for procedure-call tracing
--enable-threads support system threads --enable-threads support system threads
--enable-bddlib dynamic bdd library
--enable-pthread-locking use pthread locking primitives for internal locking (requires threads) --enable-pthread-locking use pthread locking primitives for internal locking (requires threads)
--enable-max-performance try using the best flags for specific architecture --enable-max-performance try using the best flags for specific architecture
--enable-max-memory try using the best flags for using the memory to the most --enable-max-memory try using the best flags for using the memory to the most
@ -4486,6 +4489,13 @@ else
threads=no threads=no
fi fi
# Check whether --enable-bddlib was given.
if test "${enable_bddlib+set}" = set; then :
enableval=$enable_bddlib; dynamic_bdd="$enableval"
else
dynamic_bdd=no
fi
# Check whether --enable-pthread-locking was given. # Check whether --enable-pthread-locking was given.
if test "${enable_pthread_locking+set}" = set; then : if test "${enable_pthread_locking+set}" = set; then :
enableval=$enable_pthread_locking; pthreadlocking="$enableval" enableval=$enable_pthread_locking; pthreadlocking="$enableval"
@ -5000,7 +5010,14 @@ fi
if test "$yap_cv_cudd" = no if test "$yap_cv_cudd" = no
then then
ENABLE_CUDD="@# " ENABLE_CUDD="@# "
ENABLE_BDDLIB="@# "
else else
if test "$dynamic_bdd" = yes
then
ENABLE_BDDLIB=""
else
ENABLE_BDDLIB="@# "
fi
ENABLE_CUDD="" ENABLE_CUDD=""
fi fi
@ -9220,6 +9237,7 @@ fi
{ $as_echo "$as_me:${as_lineno-$LINENO}: checking for gcc threaded code" >&5 { $as_echo "$as_me:${as_lineno-$LINENO}: checking for gcc threaded code" >&5
@ -10454,6 +10472,7 @@ mkdir -p LGPL/clp
mkdir -p LGPL/swi_console mkdir -p LGPL/swi_console
mkdir -p GPL mkdir -p GPL
mkdir -p packages/ mkdir -p packages/
mkdir -p packages/bdd
mkdir -p packages/clib mkdir -p packages/clib
mkdir -p packages/clib/sha1 mkdir -p packages/clib/sha1
mkdir -p packages/clib/maildrop mkdir -p packages/clib/maildrop
@ -10616,6 +10635,8 @@ ac_config_files="$ac_config_files packages/zlib/Makefile"
fi fi
if test "$ENABLE_CUDD" = ""; then if test "$ENABLE_CUDD" = ""; then
ac_config_files="$ac_config_files packages/bdd/Makefile"
ac_config_files="$ac_config_files packages/ProbLog/simplecudd/Makefile" ac_config_files="$ac_config_files packages/ProbLog/simplecudd/Makefile"
ac_config_files="$ac_config_files packages/ProbLog/simplecudd_lfi/Makefile" ac_config_files="$ac_config_files packages/ProbLog/simplecudd_lfi/Makefile"
@ -11398,6 +11419,7 @@ do
"packages/semweb/Makefile") CONFIG_FILES="$CONFIG_FILES packages/semweb/Makefile" ;; "packages/semweb/Makefile") CONFIG_FILES="$CONFIG_FILES packages/semweb/Makefile" ;;
"packages/sgml/Makefile") CONFIG_FILES="$CONFIG_FILES packages/sgml/Makefile" ;; "packages/sgml/Makefile") CONFIG_FILES="$CONFIG_FILES packages/sgml/Makefile" ;;
"packages/zlib/Makefile") CONFIG_FILES="$CONFIG_FILES packages/zlib/Makefile" ;; "packages/zlib/Makefile") CONFIG_FILES="$CONFIG_FILES packages/zlib/Makefile" ;;
"packages/bdd/Makefile") CONFIG_FILES="$CONFIG_FILES packages/bdd/Makefile" ;;
"packages/ProbLog/simplecudd/Makefile") CONFIG_FILES="$CONFIG_FILES packages/ProbLog/simplecudd/Makefile" ;; "packages/ProbLog/simplecudd/Makefile") CONFIG_FILES="$CONFIG_FILES packages/ProbLog/simplecudd/Makefile" ;;
"packages/ProbLog/simplecudd_lfi/Makefile") CONFIG_FILES="$CONFIG_FILES packages/ProbLog/simplecudd_lfi/Makefile" ;; "packages/ProbLog/simplecudd_lfi/Makefile") CONFIG_FILES="$CONFIG_FILES packages/ProbLog/simplecudd_lfi/Makefile" ;;
"packages/swi-minisat2/Makefile") CONFIG_FILES="$CONFIG_FILES packages/swi-minisat2/Makefile" ;; "packages/swi-minisat2/Makefile") CONFIG_FILES="$CONFIG_FILES packages/swi-minisat2/Makefile" ;;

View File

@ -158,6 +158,9 @@ AC_ARG_ENABLE(low-level-tracer,
AC_ARG_ENABLE(threads, AC_ARG_ENABLE(threads,
[ --enable-threads support system threads ], [ --enable-threads support system threads ],
threads="$enableval", threads=no) threads="$enableval", threads=no)
AC_ARG_ENABLE(bddlib,
[ --enable-bddlib dynamic bdd library ],
dynamic_bdd="$enableval", dynamic_bdd=no)
AC_ARG_ENABLE(pthread-locking, AC_ARG_ENABLE(pthread-locking,
[ --enable-pthread-locking use pthread locking primitives for internal locking (requires threads) ], [ --enable-pthread-locking use pthread locking primitives for internal locking (requires threads) ],
pthreadlocking="$enableval", pthreadlocking=no) pthreadlocking="$enableval", pthreadlocking=no)
@ -510,7 +513,14 @@ fi
if test "$yap_cv_cudd" = no if test "$yap_cv_cudd" = no
then then
ENABLE_CUDD="@# " ENABLE_CUDD="@# "
ENABLE_BDDLIB="@# "
else else
if test "$dynamic_bdd" = yes
then
ENABLE_BDDLIB=""
else
ENABLE_BDDLIB="@# "
fi
ENABLE_CUDD="" ENABLE_CUDD=""
fi fi
@ -1789,6 +1799,7 @@ AC_SUBST(ENABLE_WINCONSOLE)
AC_SUBST(EXTRA_INCLUDES_FOR_WIN32) AC_SUBST(EXTRA_INCLUDES_FOR_WIN32)
AC_SUBST(ENABLE_CUDD) AC_SUBST(ENABLE_CUDD)
AC_SUBST(ENABLE_BDDLIB)
AC_SUBST(CUDD_LDFLAGS) AC_SUBST(CUDD_LDFLAGS)
AC_SUBST(CUDD_CPPFLAGS) AC_SUBST(CUDD_CPPFLAGS)
AC_SUBST(ENABLE_MINISAT) AC_SUBST(ENABLE_MINISAT)
@ -2269,6 +2280,7 @@ mkdir -p LGPL/clp
mkdir -p LGPL/swi_console mkdir -p LGPL/swi_console
mkdir -p GPL mkdir -p GPL
mkdir -p packages/ mkdir -p packages/
mkdir -p packages/bdd
mkdir -p packages/clib mkdir -p packages/clib
mkdir -p packages/clib/sha1 mkdir -p packages/clib/sha1
mkdir -p packages/clib/maildrop mkdir -p packages/clib/maildrop
@ -2392,6 +2404,7 @@ AC_CONFIG_FILES([packages/zlib/Makefile])
fi fi
if test "$ENABLE_CUDD" = ""; then if test "$ENABLE_CUDD" = ""; then
AC_CONFIG_FILES([packages/bdd/Makefile])
AC_CONFIG_FILES([packages/ProbLog/simplecudd/Makefile]) AC_CONFIG_FILES([packages/ProbLog/simplecudd/Makefile])
AC_CONFIG_FILES([packages/ProbLog/simplecudd_lfi/Makefile]) AC_CONFIG_FILES([packages/ProbLog/simplecudd_lfi/Makefile])
fi fi

View File

@ -75,6 +75,9 @@
#if HAVE_STRING_H #if HAVE_STRING_H
#include <string.h> #include <string.h>
#endif #endif
#if HAVE_IEEEFP_H
#include <ieeefp.h>
#endif
static void PROTO(do_top_goal,(YAP_Term)); static void PROTO(do_top_goal,(YAP_Term));
static void PROTO(exec_top_level,(int, YAP_init_args *)); static void PROTO(exec_top_level,(int, YAP_init_args *));

View File

@ -12635,6 +12635,13 @@ The path @var{Path} is a path starting at vertex @var{Vertex} in graph
The path @var{Path} is a path starting at vertex @var{Vertex} in graph The path @var{Path} is a path starting at vertex @var{Vertex} in graph
@var{Graph}. @var{Graph}.
@item dgraph_leaves(+@var{Graph}, ?@var{Vertices})
@findex dgraph_leaves/2
@snindex dgraph_leaves/2
@cnindex dgraph_leaves/2
The vertices @var{Vertices} have no outgoing edge in graph
@var{Graph}.
@end table @end table
@node UnDGraphs, Lambda , DGraphs, Library @node UnDGraphs, Lambda , DGraphs, Library
@ -16464,6 +16471,21 @@ then allow one to construct functors, and to obtain their name and arity.
Note that the functor is essentially a pair formed by an atom, and Note that the functor is essentially a pair formed by an atom, and
arity. arity.
Constructing terms in the stack may lead to overflow. The routine
@example
int YAP_RequiresExtraStack(size_t @var{min})
@end example
verifies whether you have at least @var{min} cells free in the stack,
and it returns true if it has to ensure enough memory by calling the
garbage collector and or stack shifter. The routine returns false if no
memory is needed, and a negative number if it cannot provide enough
memory.
You can set @var{min} to zero if you do not know how much room you need
but you do know you do not need a big chunk at a single go. Usually, the routine
would usually be called together with a long-jump to restart the
code. Slots can also be used if there is small state.
@node Unifying Terms, Manipulating Strings, Manipulating Terms, C-Interface @node Unifying Terms, Manipulating Strings, Manipulating Terms, C-Interface
@section Unification @section Unification

View File

@ -599,6 +599,8 @@ extern X_API size_t PROTO(YAP_SizeOfExportedTerm,(char *));
extern X_API YAP_Term PROTO(YAP_ImportTerm,(char *)); extern X_API YAP_Term PROTO(YAP_ImportTerm,(char *));
extern X_API int PROTO(YAP_RequiresExtraStack,(size_t));
#define YAP_InitCPred(N,A,F) YAP_UserCPredicate(N,F,A) #define YAP_InitCPred(N,A,F) YAP_UserCPredicate(N,F,A)
__END_DECLS __END_DECLS

View File

@ -115,6 +115,7 @@ typedef enum {
#define YAP_BOOT_FROM_SAVED_CODE 1 #define YAP_BOOT_FROM_SAVED_CODE 1
#define YAP_BOOT_FROM_SAVED_STACKS 2 #define YAP_BOOT_FROM_SAVED_STACKS 2
#define YAP_FULL_BOOT_FROM_PROLOG 4 #define YAP_FULL_BOOT_FROM_PROLOG 4
#define YAP_BOOT_DONE_BEFOREHAND 8
#define YAP_BOOT_ERROR -1 #define YAP_BOOT_ERROR -1
#define YAP_WRITE_QUOTED 1 #define YAP_WRITE_QUOTED 1

View File

@ -32,6 +32,7 @@
dgraph_min_paths/3, dgraph_min_paths/3,
dgraph_isomorphic/4, dgraph_isomorphic/4,
dgraph_path/3, dgraph_path/3,
dgraph_leaves/2,
dgraph_reachable/3 dgraph_reachable/3
]). ]).
@ -414,3 +415,13 @@ reachable([V|Vertices], Done0, DoneF, G, [V|EdgesF], Edges0) :-
rb_insert(Done0, V, [], Done1), rb_insert(Done0, V, [], Done1),
reachable(Kids, Done1, DoneI, G, EdgesF, EdgesI), reachable(Kids, Done1, DoneI, G, EdgesF, EdgesI),
reachable(Vertices, DoneI, DoneF, G, EdgesI, Edges0). reachable(Vertices, DoneI, DoneF, G, EdgesI, Edges0).
dgraph_leaves(Graph, Vertices) :-
rb_visit(Graph, Pairs),
vertices_without_children(Pairs, Vertices).
vertices_without_children([], []).
vertices_without_children((V-[]).Pairs, V.Vertices) :-
vertices_without_children(Pairs, Vertices).
vertices_without_children(_V-[_|_].Pairs, Vertices) :-
vertices_without_children(Pairs, Vertices).

View File

@ -25,7 +25,10 @@
:- ensure_loaded(library(lists)). :- ensure_loaded(library(lists)).
:- load_foreign_files([matlab], ['eng','mx','ut'], init_matlab). tell_warning :-
print_message(warning,functionality(matlab)).
:- ( catch(load_foreign_files([matlab], ['eng','mx','ut'], init_matlab),_,fail) -> true ; tell_warning).
matlab_eval_sequence(S) :- matlab_eval_sequence(S) :-
atomic_concat(S,S1), atomic_concat(S,S1),

View File

@ -4671,6 +4671,11 @@ EndPredDefs
#if __YAP_PROLOG__ #if __YAP_PROLOG__
void Yap_flush(void)
{
flush_output(0);
}
void * void *
Yap_GetStreamHandle(Atom at) Yap_GetStreamHandle(Atom at)
{ GET_LD { GET_LD

View File

@ -16,6 +16,11 @@ YAPLIBDIR=@libdir@/Yap
# #
SHAREDIR=$(ROOTDIR)/share/Yap SHAREDIR=$(ROOTDIR)/share/Yap
# #
# where YAP should store documentation
#
DOCDIR=$(ROOTDIR)/share/doc/Yap
EXDIR=$(DOCDIR)/examples/CLPBN
#
# #
# You shouldn't need to change what follows. # You shouldn't need to change what follows.
# #
@ -35,6 +40,7 @@ CLPBN_EXDIR = $(srcdir)/examples
CLPBN_PROGRAMS= \ CLPBN_PROGRAMS= \
$(CLPBN_SRCDIR)/aggregates.yap \ $(CLPBN_SRCDIR)/aggregates.yap \
$(CLPBN_SRCDIR)/bdd.yap \
$(CLPBN_SRCDIR)/bnt.yap \ $(CLPBN_SRCDIR)/bnt.yap \
$(CLPBN_SRCDIR)/bp.yap \ $(CLPBN_SRCDIR)/bp.yap \
$(CLPBN_SRCDIR)/connected.yap \ $(CLPBN_SRCDIR)/connected.yap \
@ -48,6 +54,7 @@ CLPBN_PROGRAMS= \
$(CLPBN_SRCDIR)/graphviz.yap \ $(CLPBN_SRCDIR)/graphviz.yap \
$(CLPBN_SRCDIR)/ground_factors.yap \ $(CLPBN_SRCDIR)/ground_factors.yap \
$(CLPBN_SRCDIR)/hmm.yap \ $(CLPBN_SRCDIR)/hmm.yap \
$(CLPBN_SRCDIR)/horus.yap \
$(CLPBN_SRCDIR)/jt.yap \ $(CLPBN_SRCDIR)/jt.yap \
$(CLPBN_SRCDIR)/matrix_cpt_utils.yap \ $(CLPBN_SRCDIR)/matrix_cpt_utils.yap \
$(CLPBN_SRCDIR)/pgrammar.yap \ $(CLPBN_SRCDIR)/pgrammar.yap \
@ -72,6 +79,8 @@ CLPBN_SCHOOL_EXAMPLES= \
$(CLPBN_EXDIR)/School/parschema.yap \ $(CLPBN_EXDIR)/School/parschema.yap \
$(CLPBN_EXDIR)/School/school_128.yap \ $(CLPBN_EXDIR)/School/school_128.yap \
$(CLPBN_EXDIR)/School/school_32.yap \ $(CLPBN_EXDIR)/School/school_32.yap \
$(CLPBN_EXDIR)/School/sch32.yap \
$(CLPBN_EXDIR)/School/school32_data.yap \
$(CLPBN_EXDIR)/School/school_64.yap \ $(CLPBN_EXDIR)/School/school_64.yap \
$(CLPBN_EXDIR)/School/tables.yap $(CLPBN_EXDIR)/School/tables.yap
@ -92,12 +101,13 @@ CLPBN_EXAMPLES= \
install: $(CLBN_TOP) $(CLBN_PROGRAMS) $(CLPBN_PROGRAMS) install: $(CLBN_TOP) $(CLBN_PROGRAMS) $(CLPBN_PROGRAMS)
mkdir -p $(DESTDIR)$(SHAREDIR)/clpbn mkdir -p $(DESTDIR)$(SHAREDIR)/clpbn
mkdir -p $(DESTDIR)$(SHAREDIR)/clpbn/learning mkdir -p $(DESTDIR)$(SHAREDIR)/clpbn/learning
mkdir -p $(DESTDIR)$(SHAREDIR)/clpbn/examples/School mkdir -p $(DESTDIR)$(EXDIR)
mkdir -p $(DESTDIR)$(SHAREDIR)/clpbn/examples/HMMer mkdir -p $(DESTDIR)$(EXDIR)/School
mkdir -p $(DESTDIR)$(EXDIR)/HMMer
for h in $(CLPBN_TOP); do $(INSTALL_DATA) $$h $(DESTDIR)$(SHAREDIR); done for h in $(CLPBN_TOP); do $(INSTALL_DATA) $$h $(DESTDIR)$(SHAREDIR); done
for h in $(CLPBN_PROGRAMS); do $(INSTALL_DATA) $$h $(DESTDIR)$(SHAREDIR)/clpbn; done for h in $(CLPBN_PROGRAMS); do $(INSTALL_DATA) $$h $(DESTDIR)$(SHAREDIR)/clpbn; done
for h in $(CLPBN_LEARNING_PROGRAMS); do $(INSTALL_DATA) $$h $(DESTDIR)$(SHAREDIR)/clpbn/learning; done for h in $(CLPBN_LEARNING_PROGRAMS); do $(INSTALL_DATA) $$h $(DESTDIR)$(SHAREDIR)/clpbn/learning; done
for h in $(CLPBN_EXAMPLES); do $(INSTALL_DATA) $$h $(DESTDIR)$(SHAREDIR)/clpbn/examples; done for h in $(CLPBN_EXAMPLES); do $(INSTALL_DATA) $$h $(DESTDIR)$(EXDIR); done
for h in $(CLPBN_SCHOOL_EXAMPLES); do $(INSTALL_DATA) $$h $(DESTDIR)$(SHAREDIR)/clpbn/examples/School; done for h in $(CLPBN_SCHOOL_EXAMPLES); do $(INSTALL_DATA) $$h $(DESTDIR)$(EXDIR)/School; done
for h in $(CLPBN_HMMER_EXAMPLES); do $(INSTALL_DATA) $$h $(DESTDIR)$(SHAREDIR)/clpbn/examples/HMMer; done for h in $(CLPBN_HMMER_EXAMPLES); do $(INSTALL_DATA) $$h $(DESTDIR)$(EXDIR)/HMMer; done

View File

@ -13,6 +13,7 @@
clpbn_init_graph/1, clpbn_init_graph/1,
probability/2, probability/2,
conditional_probability/3, conditional_probability/3,
use_parfactors/1,
op( 500, xfy, with)]). op( 500, xfy, with)]).
:- use_module(library(atts)). :- use_module(library(atts)).
@ -43,6 +44,7 @@
check_if_bp_done/1, check_if_bp_done/1,
init_bp_solver/4, init_bp_solver/4,
run_bp_solver/3, run_bp_solver/3,
call_bp_ground/6,
finalize_bp_solver/1 finalize_bp_solver/1
]). ]).
@ -61,11 +63,17 @@
run_jt_solver/3 run_jt_solver/3
]). ]).
:- use_module('clpbn/bnt', :- use_module('clpbn/bdd',
[do_bnt/3, [bdd/3,
check_if_bnt_done/1 init_bdd_solver/4,
run_bdd_solver/3
]). ]).
%% :- use_module('clpbn/bnt',
%% [do_bnt/3,
%% check_if_bnt_done/1
%% ]).
:- use_module('clpbn/gibbs', :- use_module('clpbn/gibbs',
[gibbs/3, [gibbs/3,
check_if_gibbs_done/1, check_if_gibbs_done/1,
@ -111,7 +119,7 @@
[clpbn2gviz/4]). [clpbn2gviz/4]).
:- use_module(clpbn/ground_factors, :- use_module(clpbn/ground_factors,
[generate_bn/2]). [generate_network/5]).
:- dynamic solver/1,output/1,use/1,suppress_attribute_display/1, parameter_softening/1, em_solver/1, use_parfactors/1. :- dynamic solver/1,output/1,use/1,suppress_attribute_display/1, parameter_softening/1, em_solver/1, use_parfactors/1.
@ -223,9 +231,17 @@ clpbn_marginalise(V, Dist) :-
% called by top-level % called by top-level
% or by call_residue/2 % or by call_residue/2
% %
project_attributes(GVars, AVars0) :- project_attributes(GVars, _AVars0) :-
use_parfactors(on),
clpbn_flag(solver, Solver), Solver \= fove, !,
generate_network(GVars, GKeys, Keys, Factors, Evidence),
(ground(GVars) ->
true
;
call_ground_solver(Solver, GVars, GKeys, Keys, Factors, Evidence, _Avars0)
).
project_attributes(GVars, AVars) :-
suppress_attribute_display(false), suppress_attribute_display(false),
generate_vars(GVars, AVars0, AVars),
AVars = [_|_], AVars = [_|_],
solver(Solver), solver(Solver),
( GVars = [_|_] ; Solver = graphs), !, ( GVars = [_|_] ; Solver = graphs), !,
@ -243,11 +259,6 @@ project_attributes(GVars, AVars0) :-
). ).
project_attributes(_, _). project_attributes(_, _).
generate_vars(GVars, _, NewAVars) :-
use_parfactors(on), !,
generate_bn(GVars, NewAVars).
generate_vars(_GVars, AVars, AVars).
clpbn_vars(AVars, DiffVars, AllVars) :- clpbn_vars(AVars, DiffVars, AllVars) :-
sort_vars_by_key(AVars,SortedAVars,DiffVars), sort_vars_by_key(AVars,SortedAVars,DiffVars),
incorporate_evidence(SortedAVars, AllVars). incorporate_evidence(SortedAVars, AllVars).
@ -289,6 +300,8 @@ write_out(ve, GVars, AVars, DiffVars) :-
ve(GVars, AVars, DiffVars). ve(GVars, AVars, DiffVars).
write_out(jt, GVars, AVars, DiffVars) :- write_out(jt, GVars, AVars, DiffVars) :-
jt(GVars, AVars, DiffVars). jt(GVars, AVars, DiffVars).
write_out(bdd, GVars, AVars, DiffVars) :-
bdd(GVars, AVars, DiffVars).
write_out(bp, GVars, AVars, DiffVars) :- write_out(bp, GVars, AVars, DiffVars) :-
bp(GVars, AVars, DiffVars). bp(GVars, AVars, DiffVars).
write_out(gibbs, GVars, AVars, DiffVars) :- write_out(gibbs, GVars, AVars, DiffVars) :-
@ -298,6 +311,11 @@ write_out(bnt, GVars, AVars, DiffVars) :-
write_out(fove, GVars, AVars, DiffVars) :- write_out(fove, GVars, AVars, DiffVars) :-
fove(GVars, AVars, DiffVars). fove(GVars, AVars, DiffVars).
% call a solver with keys, not actual variables
call_ground_solver(bp, GVars, GoalKeys, Keys, Factors, Evidence, Answ) :-
call_bp_ground(GVars, GoalKeys, Keys, Factors, Evidence, Answ).
get_bnode(Var, Goal) :- get_bnode(Var, Goal) :-
get_atts(Var, [key(Key),dist(Dist,Parents)]), get_atts(Var, [key(Key),dist(Dist,Parents)]),
get_dist(Dist,_,Domain,CPT), get_dist(Dist,_,Domain,CPT),
@ -382,6 +400,9 @@ bind_clpbn(_, Var, _, _, _, _, []) :-
bind_clpbn(_, Var, _, _, _, _, []) :- bind_clpbn(_, Var, _, _, _, _, []) :-
use(jt), use(jt),
check_if_ve_done(Var), !. check_if_ve_done(Var), !.
bind_clpbn(_, Var, _, _, _, _, []) :-
use(bdd),
check_if_bdd_done(Var), !.
bind_clpbn(T, Var, Key0, _, _, _, []) :- bind_clpbn(T, Var, Key0, _, _, _, []) :-
get_atts(Var, [key(Key)]), !, get_atts(Var, [key(Key)]), !,
( (
@ -397,11 +418,12 @@ fresh_attvar(Var, NVar) :-
% I will now allow two CLPBN variables to be bound together. % I will now allow two CLPBN variables to be bound together.
%bind_clpbns(Key, Dist, Parents, Key, Dist, Parents). %bind_clpbns(Key, Dist, Parents, Key, Dist, Parents).
bind_clpbns(Key, Dist, _Parents, Key1, Dist1, _Parents1) :- bind_clpbns(Key, Dist, Parents, Key1, Dist1, Parents1) :-
Key == Key1, !, Key == Key1, !,
get_dist(Dist,_Type,_Domain,_Table), get_dist(Dist,_Type,_Domain,_Table),
get_dist(Dist1,_Type1,_Domain1,_Table1), get_dist(Dist1,_Type1,_Domain1,_Table1),
Dist = Dist1. Dist = Dist1,
Parents = Parents1.
bind_clpbns(Key, _, _, _, Key1, _, _, _) :- bind_clpbns(Key, _, _, _, Key1, _, _, _) :-
Key\=Key1, !, fail. Key\=Key1, !, fail.
bind_clpbns(_, _, _, _, _, _, _, _) :- bind_clpbns(_, _, _, _, _, _, _, _) :-
@ -452,6 +474,8 @@ clpbn_init_solver(bp, LVs, Vs0, VarsWithUnboundKeys, State) :-
init_bp_solver(LVs, Vs0, VarsWithUnboundKeys, State). init_bp_solver(LVs, Vs0, VarsWithUnboundKeys, State).
clpbn_init_solver(jt, LVs, Vs0, VarsWithUnboundKeys, State) :- clpbn_init_solver(jt, LVs, Vs0, VarsWithUnboundKeys, State) :-
init_jt_solver(LVs, Vs0, VarsWithUnboundKeys, State). init_jt_solver(LVs, Vs0, VarsWithUnboundKeys, State).
clpbn_init_solver(bdd, LVs, Vs0, VarsWithUnboundKeys, State) :-
init_bdd_solver(LVs, Vs0, VarsWithUnboundKeys, State).
clpbn_init_solver(pcg, LVs, Vs0, VarsWithUnboundKeys, State) :- clpbn_init_solver(pcg, LVs, Vs0, VarsWithUnboundKeys, State) :-
init_pcg_solver(LVs, Vs0, VarsWithUnboundKeys, State). init_pcg_solver(LVs, Vs0, VarsWithUnboundKeys, State).
@ -478,6 +502,9 @@ clpbn_run_solver(bp, LVs, LPs, State) :-
clpbn_run_solver(jt, LVs, LPs, State) :- clpbn_run_solver(jt, LVs, LPs, State) :-
run_jt_solver(LVs, LPs, State). run_jt_solver(LVs, LPs, State).
clpbn_run_solver(bdd, LVs, LPs, State) :-
run_bdd_solver(LVs, LPs, State).
clpbn_run_solver(pcg, LVs, LPs, State) :- clpbn_run_solver(pcg, LVs, LPs, State) :-
run_pcg_solver(LVs, LPs, State). run_pcg_solver(LVs, LPs, State).
@ -538,4 +565,5 @@ match_probability([p(V0=C)=Prob|_], C, V, Prob) :-
match_probability([_|Probs], C, V, Prob) :- match_probability([_|Probs], C, V, Prob) :-
match_probability(Probs, C, V, Prob). match_probability(Probs, C, V, Prob).
:- use_parfactors(on) -> true ; assert(use_parfactors(off)).

View File

@ -27,7 +27,7 @@
:- use_module(library('clpbn/dists'), :- use_module(library('clpbn/dists'),
[ [
dist/4, add_dist/6,
get_dist_domain_size/2]). get_dist_domain_size/2]).
:- use_module(library('clpbn/matrix_cpt_utils'), :- use_module(library('clpbn/matrix_cpt_utils'),
@ -44,8 +44,9 @@ check_for_agg_vars([_|Vs0], Vs1) :-
% transform aggregate distribution into tree % transform aggregate distribution into tree
simplify_dist(avg(Domain), V, Key, Parents, Vs0, VsF) :- !, simplify_dist(avg(Domain), V, Key, Parents, Vs0, VsF) :- !,
cpt_average([V|Parents], Key, Domain, NewDist, Vs0, VsF), cpt_average([V|Parents], Key, Domain, NewDist, Vs0, VsF),
dist(NewDist, Id, Key, ParentsF), NewDist = p(Dom, Tab, Ps),
clpbn:put_atts(V, [dist(Id,ParentsF)]). add_dist(Dom, tab, Tab, Ps, Key, Id),
clpbn:put_atts(V, [dist(Id,Ps)]).
simplify_dist(_, _, _, _, Vs0, Vs0). simplify_dist(_, _, _, _, Vs0, Vs0).
cpt_average(AllVars, Key, Els0, Tab, Vs, NewVs) :- cpt_average(AllVars, Key, Els0, Tab, Vs, NewVs) :-

View File

@ -0,0 +1,802 @@
/************************************************
BDDs in CLP(BN)
A variable is represented by the N possible cases it can take
V = v(Va, Vb, Vc)
The generic formula is
V <- X, Y
Va <- P*X1*Y1 + Q*X2*Y2 + ...
**************************************************/
:- module(clpbn_bdd,
[bdd/3,
set_solver_parameter/2,
init_bdd_solver/4,
run_bdd_solver/3,
finalize_bdd_solver/1,
check_if_bdd_done/1
]).
:- use_module(library('clpbn/dists'),
[dist/4,
get_dist_domain/2,
get_dist_domain_size/2,
get_dist_all_sizes/2,
get_dist_params/2
]).
:- use_module(library('clpbn/display'),
[clpbn_bind_vals/3]).
:- use_module(library('clpbn/aggregates'),
[check_for_agg_vars/2]).
:- use_module(library(atts)).
:- use_module(library(hacks)).
:- use_module(library(lists)).
:- use_module(library(dgraphs)).
:- use_module(library(bdd)).
:- use_module(library(rbtrees)).
:- use_module(library(bhash)).
:- use_module(library(matrix)).
:- dynamic network_counting/1.
:- attribute order/1.
check_if_bdd_done(_Var).
bdd([[]],_,_) :- !.
bdd([QueryVars], AllVars, AllDiffs) :-
init_bdd_solver(_, AllVars, _, BayesNet),
run_bdd_solver([QueryVars], LPs, BayesNet),
finalize_bdd_solver(BayesNet),
clpbn_bind_vals([QueryVars], [LPs], AllDiffs).
init_bdd_solver(_, AllVars0, _, bdd(Term, Leaves, Tops)) :-
% check_for_agg_vars(AllVars0, AllVars1),
sort_vars(AllVars0, AllVars, Leaves),
order_vars(AllVars, 0),
rb_new(Vars0),
rb_new(Pars0),
init_tops(Leaves,Tops),
get_vars_info(AllVars, Vars0, _Vars, Pars0, _Pars, Leaves, Tops, Term, []).
order_vars([], _).
order_vars([V|AllVars], I0) :-
put_atts(V, [order(I0)]),
I is I0+1,
order_vars(AllVars, I).
init_tops([],[]).
init_tops(_.Leaves,_.Tops) :-
init_tops(Leaves,Tops).
sort_vars(AllVars0, AllVars, Leaves) :-
dgraph_new(Graph0),
build_graph(AllVars0, Graph0, Graph),
dgraph_leaves(Graph, Leaves),
dgraph_top_sort(Graph, AllVars).
build_graph([], Graph, Graph).
build_graph(V.AllVars0, Graph0, Graph) :-
clpbn:get_atts(V, [dist(_DistId, Parents)]), !,
dgraph_add_vertex(Graph0, V, Graph1),
add_parents(Parents, V, Graph1, GraphI),
build_graph(AllVars0, GraphI, Graph).
build_graph(_V.AllVars0, Graph0, Graph) :-
build_graph(AllVars0, Graph0, Graph).
add_parents([], _V, Graph, Graph).
add_parents(V0.Parents, V, Graph0, GraphF) :-
dgraph_add_edge(Graph0, V0, V, GraphI),
add_parents(Parents, V, GraphI, GraphF).
get_vars_info([], Vs, Vs, Ps, Ps, _, _) --> [].
get_vars_info([V|MoreVs], Vs, VsF, Ps, PsF, Lvs, Outs) -->
{ clpbn:get_atts(V, [dist(DistId, Parents)]) }, !,
%{writeln(v:DistId:Parents)},
[DIST],
{ get_var_info(V, DistId, Parents, Vs, Vs2, Ps, Ps1, Lvs, Outs, DIST) },
get_vars_info(MoreVs, Vs2, VsF, Ps1, PsF, Lvs, Outs).
get_vars_info([_|MoreVs], Vs0, VsF, Ps0, PsF, VarsInfo, Lvs, Outs) :-
get_vars_info(MoreVs, Vs0, VsF, Ps0, PsF, VarsInfo, Lvs, Outs).
%
% let's have some fun with avg
%
get_var_info(V, avg(Domain), Parents, Vs, Vs2, Ps, Ps, Lvs, Outs, DIST) :- !,
length(Domain, DSize),
% run_though_avg(V, DSize, Domain, Parents, Vs, Vs2, Lvs, Outs, DIST).
top_down_with_tabling(V, DSize, Domain, Parents, Vs, Vs2, Lvs, Outs, DIST).
% bup_avg(V, DSize, Domain, Parents, Vs, Vs2, Lvs, Outs, DIST).
% standard random variable
get_var_info(V, DistId, Parents0, Vs, Vs2, Ps, Ps1, Lvs, Outs, DIST) :-
% clpbn:get_atts(V, [key(K)]), writeln(V:K:DistId:Parents),
reorder_vars(Parents0, Parents, Map),
check_p(DistId, Map, Parms, _ParmVars, Ps, Ps1),
unbound_parms(Parms, ParmVars),
check_v(V, DistId, DIST, Vs, Vs1),
DIST = info(V, Tree, Ev, Values, Formula, ParmVars, Parms),
% get a list of form [[P00,P01], [P10,P11], [P20,P21]]
get_parents(Parents, PVars, Vs1, Vs2),
cross_product(Values, Ev, PVars, ParmVars, Formula0),
% (numbervars(Formula0,0,_),writeln(formula0:Ev:Formula0), fail ; true),
get_evidence(V, Tree, Ev, Formula0, Formula, Lvs, Outs).
%, (numbervars(Formula,0,_),writeln(formula:Formula), fail ; true)
%
% reorder all variables and make sure we get a
% map of how the transfer was done.
%
% position zero is output
%
reorder_vars(Vs, OVs, Map) :-
add_pos(Vs, 1, PVs),
keysort(PVs, SVs),
remove_key(SVs, OVs, Map).
add_pos([], _, []).
add_pos([V|Vs], I0, [K-(I0,V)|PVs]) :-
get_atts(V,[order(K)]),
I is I0+1,
add_pos(Vs, I, PVs).
remove_key([], [], []).
remove_key([_-(I,V)|SVs], [V|OVs], [I|Map]) :-
remove_key(SVs, OVs, Map).
%%%%%%%%%%%%%%%%%%%%%%%%%
%
% use top-down to generate average
%
run_though_avg(V, 3, Domain, Parents0, Vs, Vs2, Lvs, Outs, DIST) :-
reorder_vars(Parents0, Parents, _Map),
check_v(V, avg(Domain,Parents0), DIST, Vs, Vs1),
DIST = info(V, Tree, Ev, [V0,V1,V2], Formula, [], []),
get_parents(Parents, PVars, Vs1, Vs2),
length(Parents, N),
generate_3tree(F00, PVars, 0, 0, 0, N, N0, N1, N2, R, (N1+2*N2 =< N/2), (N1+2*(N2+R) =< N/2)),
simplify_exp(F00, F0),
% generate_3tree(F1, PVars, 0, 0, 0, N, N0, N1, N2, R, ((N1+2*(N2+R) > N/2, N1+2*N2 < (3*N)/2))),
generate_3tree(F20, PVars, 0, 0, 0, N, N0, N1, N2, R, (N1+2*(N2+R) >= (3*N)/2), N1+2*N2 >= (3*N)/2),
% simplify_exp(F20, F2),
F20=F2,
Formula0 = [V0=F0*Ev0,V2=F2*Ev2,V1=not(F0+F2)*Ev1],
Ev = [Ev0,Ev1,Ev2],
get_evidence(V, Tree, Ev, Formula0, Formula, Lvs, Outs).
generate_3tree(OUT, _, I00, I10, I20, IR0, N0, N1, N2, R, _Exp, ExpF) :-
IR is IR0-1,
satisf(I00, I10, I20, IR, N0, N1, N2, R, ExpF),
!,
OUT = 1.
generate_3tree(OUT, [[P0,P1,P2]], I00, I10, I20, IR0, N0, N1, N2, R, Exp, _ExpF) :-
IR is IR0-1,
( satisf(I00+1, I10, I20, IR, N0, N1, N2, R, Exp) ->
L0 = [P0|L1]
;
L0 = L1
),
( satisf(I00, I10+1, I20, IR, N0, N1, N2, R, Exp) ->
L1 = [P1|L2]
;
L1 = L2
),
( satisf(I00, I10, I20+1, IR, N0, N1, N2, R, Exp) ->
L2 = [P2]
;
L2 = []
),
to_disj(L0, OUT).
generate_3tree(OUT, [[P0,P1,P2]|Ps], I00, I10, I20, IR0, N0, N1, N2, R, Exp, ExpF) :-
IR is IR0-1,
( satisf(I00+1, I10, I20, IR, N0, N1, N2, R, Exp) ->
I0 is I00+1, generate_3tree(O0, Ps, I0, I10, I20, IR, N0, N1, N2, R, Exp, ExpF)
->
L0 = [P0*O0|L1]
;
L0 = L1
),
( satisf(I00, I10+1, I20, IR0, N0, N1, N2, R, Exp) ->
I1 is I10+1, generate_3tree(O1, Ps, I00, I1, I20, IR, N0, N1, N2, R, Exp, ExpF)
->
L1 = [P1*O1|L2]
;
L1 = L2
),
( satisf(I00, I10, I20+1, IR0, N0, N1, N2, R, Exp) ->
I2 is I20+1, generate_3tree(O2, Ps, I00, I10, I2, IR, N0, N1, N2, R, Exp, ExpF)
->
L2 = [P2*O2]
;
L2 = []
),
to_disj(L0, OUT).
satisf(I0, I1, I2, IR, N0, N1, N2, R, Exp) :-
\+ \+ ( I0 = N0, I1=N1, I2=N2, IR=R, call(Exp) ).
not_satisf(I0, I1, I2, IR, N0, N1, N2, R, Exp) :-
\+ ( I0 = N0, I1=N1, I2=N2, IR=R, call(Exp) ).
%%%%%%%%%%%%%%%%%%%%%%%%%
%
% use top-down to generate average
%
top_down_with_tabling(V, Size, Domain, Parents0, Vs, Vs2, Lvs, Outs, DIST) :-
reorder_vars(Parents0, Parents, _Map),
check_v(V, avg(Domain,Parents), DIST, Vs, Vs1),
DIST = info(V, Tree, Ev, OVs, Formula, [], []),
get_parents(Parents, PVars, Vs1, Vs2),
length(Parents, N),
Max is (Size-1)*N, % This should be true
avg_borders(0, Size, Max, Borders),
b_hash_new(H0),
avg_trees(0, Max, PVars, Size, F1, 0, Borders, OVs, Ev, H0, H),
generate_avg_code(H, Formula, F),
% Formula0 = [V0=F0*Ev0,V2=F2*Ev2,V1=not(F0+F2)*Ev1],
% Ev = [Ev0,Ev1,Ev2],
get_evidence(V, Tree, Ev, F1, F, Lvs, Outs).
avg_trees(Size, _, _, Size, F0, _, F0, [], [], H, H) :- !.
avg_trees(I0, Max, PVars, Size, [V=O*E|F0], Im, [IM|Borders], [V|OVs], [E|Ev], H0, H) :-
I is I0+1,
avg_tree(PVars, 0, Max, Im, IM, Size, O, H0, HI),
Im1 is IM+1,
avg_trees(I, Max, PVars, Size, F0, Im1, Borders, OVs, Ev, HI, H).
avg_tree( _PVars, P, _, Im, IM, _Size, O, H0, H0) :-
b_hash_lookup(k(P,Im,IM), O=_Exp, H0), !.
avg_tree([], _P, _Max, _Im, _IM, _Size, 1, H, H).
avg_tree([Vals|PVars], P, Max, Im, IM, Size, O, H0, HF) :-
b_hash_insert(H0, k(P,Im,IM), O=Simp, HI),
MaxI is Max-(Size-1),
avg_exp(Vals, PVars, 0, P, MaxI, Size, Im, IM, HI, HF, Exp),
simplify_exp(Exp, Simp).
avg_exp([], _, _, _P, _Max, _Size, _Im, _IM, H, H, 0).
avg_exp([Val|Vals], PVars, I0, P0, Max, Size, Im, IM, HI, HF, O) :-
(Vals = [] -> O=O1 ; O = Val*O1+not(Val)*O2 ),
Im1 is max(0, Im-I0),
IM1 is IM-I0,
( IM1 < 0 -> O1 = 0, H2 = HI; /* we have exceed maximum */
Im1 > Max -> O1 = 0, H2 = HI; /* we cannot make to minimum */
Im1 = 0, IM1 > Max -> O1 = 1, H2 = HI; /* we cannot exceed maximum */
P is P0+1,
avg_tree(PVars, P, Max, Im1, IM1, Size, O1, HI, H2)
),
I is I0+1,
avg_exp(Vals, PVars, I, P0, Max, Size, Im, IM, H2, HF, O2).
generate_avg_code(H, Formula, Formula0) :-
b_hash_to_list(H,L),
sort(L, S),
strip_and_add(S, Formula0, Formula).
strip_and_add([], F, F).
strip_and_add([_-Exp|S], F0, F) :-
strip_and_add(S, [Exp|F0], F).
%%%%%%%%%%%%%%%%%%%%%%%%%
%
% use bottom-up dynamic programming to generate average
%
bup_avg(V, Size, Domain, Parents0, Vs, Vs2, Lvs, Outs, DIST) :-
reorder_vars(Parents0, Parents, _),
check_v(V, avg(Domain,Parents), DIST, Vs, Vs1),
DIST = info(V, Tree, Ev, OVs, Formula, [], []),
get_parents(Parents, PVars, Vs1, Vs2),
length(Parents, N),
Max is (Size-1)*N, % This should be true
ArraySize is Max+1,
functor(Protected, protected, ArraySize),
avg_domains(0, Size, 0, Max, LDomains),
Domains =.. [d|LDomains],
Reach is (Size-1),
generate_sums(PVars, Size, Max, Reach, Protected, Domains, ArraySize, Sums, F0),
% bin_sums(PVars, Sums, F00),
% reverse(F00,F0),
% easier to do recursion on lists
Sums =.. [_|LSums],
generate_avg(0, Size, 0, Max, LSums, OVs, Ev, F1, []),
reverse(F0, RF0),
get_evidence(V, Tree, Ev, F1, F2, Lvs, Outs),
append(RF0, F2, Formula).
%
% use binary approach, like what is standard
%
bin_sums(Vs, Sums, F) :-
vs_to_sums(Vs, Sums0),
bin_sums(Sums0, Sums, F, []).
vs_to_sums([], []).
vs_to_sums([V|Vs], [Sum|Sums0]) :-
Sum =.. [sum|V],
vs_to_sums(Vs, Sums0).
bin_sums([Sum], Sum) --> !.
bin_sums(LSums, Sum) -->
{ halve(LSums, Sums1, Sums2) },
bin_sums(Sums1, Sum1),
bin_sums(Sums2, Sum2),
sum(Sum1, Sum2, Sum).
halve(LSums, Sums1, Sums2) :-
length(LSums, L),
Take is L div 2,
head(Take, LSums, Sums1, Sums2).
head(0, L, [], L) :- !.
head(Take, [H|L], [H|Sums1], Sum2) :-
Take1 is Take-1,
head(Take1, L, Sums1, Sum2).
sum(Sum1, Sum2, Sum) -->
{ functor(Sum1, _, M1),
functor(Sum2, _, M2),
Max is M1+M2-2,
Max1 is Max+1,
Max0 is M2-1,
functor(Sum, sum, Max1),
Sum1 =.. [_|PVals] },
expand_sums(PVals, 0, Max0, Max1, M2, Sum2, Sum).
%
% bottom up step by step
%
%
generate_sums([PVals], Size, Max, _, _Protected, _Domains, _, Sum, []) :- !,
Max is Size-1,
Sum =.. [sum|PVals].
generate_sums([PVals|Parents], Size, Max, Reach, Protected, Domains, ASize, NewSums, F) :-
NewReach is Reach+(Size-1),
generate_sums(Parents, Size, Max0, NewReach, Protected, Domains, ASize, Sums, F0),
Max is Max0+(Size-1),
Max1 is Max+1,
functor(NewSums, sum, Max1),
protect_avg(0, Max0, Protected, Domains, ASize, Reach),
expand_sums(PVals, 0, Max0, Max1, Size, Sums, Protected, NewSums, F, F0).
protect_avg(Max0,Max0,_Protected, _Domains, _ASize, _Reach) :- !.
protect_avg(I0, Max0, Protected, Domains, ASize, Reach) :-
I is I0+1,
Top is I+Reach,
( Top > ASize ;
arg(I, Domains, CD),
arg(Top, Domains, CD)
), !,
arg(I, Protected, yes),
protect_avg(I, Max0, Protected, Domains, ASize, Reach).
protect_avg(I0, Max0, Protected, Domains, ASize, Reach) :-
I is I0+1,
protect_avg(I, Max0, Protected, Domains, ASize, Reach).
%
% outer loop: generate array of sums at level j= Sum[j0...jMax]
%
expand_sums(_Parents, Max, _, Max, _Size, _Sums, _P, _NewSums, F0, F0) :- !.
expand_sums(Parents, I0, Max0, Max, Size, Sums, Prot, NewSums, [O=SUM|F], F0) :-
I is I0+1,
arg(I, Prot, P),
var(P), !,
arg(I, NewSums, O),
sum_all(Parents, 0, I0, Max0, Sums, List),
to_disj(List, SUM),
expand_sums(Parents, I, Max0, Max, Size, Sums, Prot, NewSums, F, F0).
expand_sums(Parents, I0, Max0, Max, Size, Sums, Prot, NewSums, F, F0) :-
I is I0+1,
arg(I, Sums, O),
arg(I, NewSums, O),
expand_sums(Parents, I, Max0, Max, Size, Sums, Prot, NewSums, F, F0).
%
%inner loop: find all parents that contribute to A_ji,
% that is generate Pk*Sum_(j-1)l and k+l st k+l = i
%
sum_all([], _, _, _, _, []).
sum_all([V|Vs], Pos, I, Max0, Sums, [O|List]) :-
J is I-Pos,
J >= 0,
J =< Max0, !,
J1 is J+1,
arg(J1, Sums, S0),
( J < I -> O = V*S0 ; O = S0*V ),
Pos1 is Pos+1,
sum_all(Vs, Pos1, I, Max0, Sums, List).
sum_all([_V|Vs], Pos, I, Max0, Sums, List) :-
Pos1 is Pos+1,
sum_all(Vs, Pos1, I, Max0, Sums, List).
gen_arg(J, Sums, Max, S0) :-
gen_arg(0, Max, J, Sums, S0).
gen_arg(Max, Max, J, Sums, S0) :- !,
I is Max+1,
arg(I, Sums, A),
( Max = J -> S0 = A ; S0 = not(A)).
gen_arg(I0, Max, J, Sums, S) :-
I is I0+1,
arg(I, Sums, A),
( I0 = J -> S = A*S0 ; S = not(A)*S0),
gen_arg(I, Max, J, Sums, S0).
avg_borders(Size, Size, _Max, []) :- !.
avg_borders(I0, Size, Max, [J|Vals]) :-
I is I0+1,
Border is (I*Max)/Size,
J is integer(round(Border)),
avg_borders(I, Size, Max, Vals).
avg_domains(Size, Size, _J, _Max, []).
avg_domains(I0, Size, J0, Max, Vals) :-
I is I0+1,
Border is (I*Max)/Size,
fetch_domain_for_avg(J0, Border, J, I0, Vals, ValsI),
avg_domains(I, Size, J, Max, ValsI).
fetch_domain_for_avg(J, Border, J, _, Vals, Vals) :-
J > Border, !.
fetch_domain_for_avg(J0, Border, J, I0, [I0|LVals], RLVals) :-
J1 is J0+1,
fetch_domain_for_avg(J1, Border, J, I0, LVals, RLVals).
generate_avg(Size, Size, _J, _Max, [], [], [], F, F).
generate_avg(I0, Size, J0, Max, LSums, [O|OVs], [Ev|Evs], [O=Ev*Disj|F], F0) :-
I is I0+1,
Border is (I*Max)/Size,
fetch_for_avg(J0, Border, J, LSums, MySums, RSums),
to_disj(MySums, Disj),
generate_avg(I, Size, J, Max, RSums, OVs, Evs, F, F0).
fetch_for_avg(J, Border, J, RSums, [], RSums) :-
J > Border, !.
fetch_for_avg(J0, Border, J, [S|LSums], [S|MySums], RSums) :-
J1 is J0+1,
fetch_for_avg(J1, Border, J, LSums, MySums, RSums).
to_disj([], 0).
to_disj([V], V).
to_disj([V,V1|Vs], Out) :-
to_disj2([V1|Vs], V, Out).
to_disj2([V], V0, V0+V).
to_disj2([V,V1|Vs], V0, Out) :-
to_disj2([V1|Vs], V0+V, Out).
%
% look for parameters in the rb-tree, or add a new.
% distid is the key
%
check_p(DistId, Map, Parms, ParmVars, Ps, Ps) :-
rb_lookup(DistId-Map, theta(Parms, ParmVars), Ps), !.
check_p(DistId, Map, Parms, ParmVars, Ps, PsF) :-
get_dist_params(DistId, Parms0),
get_dist_all_sizes(DistId, Sizes),
swap_parms(Parms0, Sizes, [0|Map], Parms1),
length(Parms1, L0),
get_dist_domain_size(DistId, Size),
L1 is L0 div Size,
L is L0-L1,
initial_maxes(L1, Multipliers),
copy(L, Multipliers, NextMults, NextMults, Parms1, Parms, ParmVars),
%writeln(t:Size:Parms0:Parms:ParmVars),
rb_insert(Ps, DistId-Map, theta(Parms, ParmVars), PsF).
swap_parms(Parms0, Sizes, Map, Parms1) :-
matrix_new(floats, Sizes, Parms0, T0),
matrix_shuffle(T0,Map,TF),
matrix_to_list(TF, Parms1).
%
% we are using switches by two
%
initial_maxes(0, []) :- !.
initial_maxes(Size, [1.0|Multipliers]) :- !,
Size1 is Size-1,
initial_maxes(Size1, Multipliers).
copy(0, [], [], _, _Parms0, [], []) :- !.
copy(N, [], [], Ms, Parms0, Parms, ParmVars) :-!,
copy(N, Ms, NewMs, NewMs, Parms0, Parms, ParmVars).
copy(N, D.Ds, ND.NDs, New, El.Parms0, NEl.Parms, V.ParmVars) :-
N1 is N-1,
(El == 0.0 ->
NEl = 0,
V = NEl,
ND = D
;El == 1.0 ->
NEl = 1,
V = NEl,
ND = 0.0
;El == 0 ->
NEl = 0,
V = NEl,
ND = D
;El =:= 1 ->
NEl = 1,
V = NEl,
ND = 0.0,
V = NEl
;
NEl is El/D,
ND is D-El,
V = NEl
),
copy(N1, Ds, NDs, New, Parms0, Parms, ParmVars).
unbound_parms([], []).
unbound_parms(_.Parms, _.ParmVars) :-
unbound_parms(Parms, ParmVars).
check_v(V, _, INFO, Vs, Vs) :-
rb_lookup(V, INFO, Vs), !.
check_v(V, DistId, INFO, Vs0, Vs) :-
get_dist_domain_size(DistId, Size),
length(Values, Size),
length(Ev, Size),
INFO = info(V, _Tree, Ev, Values, _Formula, _, _),
rb_insert(Vs0, V, INFO, Vs).
get_parents([], [], Vs, Vs).
get_parents(V.Parents, Values.PVars, Vs0, Vs) :-
clpbn:get_atts(V, [dist(DistId, _)]),
check_v(V, DistId, INFO, Vs0, Vs1),
INFO = info(V, _Parent, _Ev, Values, _, _, _),
get_parents(Parents, PVars, Vs1, Vs).
%
% construct the formula, this is the key...
%
cross_product(Values, Ev, PVars, ParmVars, Formulas) :-
arrangements(PVars, Arranges),
apply_parents_first(Values, Ev, ParmCombos, ParmCombos, Arranges, Formulas, ParmVars).
%
% if we have the parent variables with two values, we get
% [[XP,YP],[XP,YN],[XN,YP],[XN,YN]]
%
arrangements([], [[]]).
arrangements([L1|Ls],O) :-
arrangements(Ls, LN),
expand(L1, LN, O, []).
expand([], _LN) --> [].
expand([H|L1], LN) -->
concatenate_all(H, LN),
expand(L1, LN).
concatenate_all(_H, []) --> [].
concatenate_all(H, L.LN) -->
[[H|L]],
concatenate_all(H, LN).
%
% core of algorithm
%
% Values -> Output Vars for BDD
% Es -> Evidence variables
% Previous -> top of difference list with parameters used so far
% P0 -> end of difference list with parameters used so far
% Pvars -> Parents
% Eqs -> Output Equations
% Pars -> Output Theta Parameters
%
apply_parents_first([Value], [E], Previous, [], PVars, [Value=Disj*E], Parameters) :- !,
apply_last_parent(PVars, Previous, Disj),
flatten(Previous, Parameters).
apply_parents_first([Value|Values], [E|Ev], Previous, P0, PVars, (Value=Disj*E).Formulas, Parameters) :-
P0 = [TheseParents|End],
apply_first_parent(PVars, Disj, TheseParents),
apply_parents_second(Values, Ev, Previous, End, PVars, Formulas, Parameters).
apply_parents_second([Value], [E], Previous, [], PVars, [Value=Disj*E], Parameters) :- !,
apply_last_parent(PVars, Previous, Disj),
flatten(Previous, Parameters).
apply_parents_second([Value|Values], [E|Ev], Previous, P0, PVars, (Value=Disj*E).Formulas, Parameters) :-
apply_middle_parent(PVars, Previous, Disj, TheseParents),
% this must be done after applying middle parents because of the var
% test.
P0 = [TheseParents|End],
apply_parents_second(Values, Ev, Previous, End, PVars, Formulas, Parameters).
apply_first_parent([Parents], Conj, [Theta]) :- !,
parents_to_conj(Parents,Theta,Conj).
apply_first_parent(Parents.PVars, Conj+Disj, Theta.TheseParents) :-
parents_to_conj(Parents,Theta,Conj),
apply_first_parent(PVars, Disj, TheseParents).
apply_middle_parent([Parents], Other, Conj, [ThetaPar]) :- !,
skim_for_theta(Other, Theta, _, ThetaPar),
parents_to_conj(Parents,Theta,Conj).
apply_middle_parent(Parents.PVars, Other, Conj+Disj, ThetaPar.TheseParents) :-
skim_for_theta(Other, Theta, Remaining, ThetaPar),
parents_to_conj(Parents,(Theta),Conj),
apply_middle_parent(PVars, Remaining, Disj, TheseParents).
apply_last_parent([Parents], Other, Conj) :- !,
parents_to_conj(Parents,(Theta),Conj),
skim_for_theta(Other, Theta, _, _).
apply_last_parent(Parents.PVars, Other, Conj+Disj) :-
parents_to_conj(Parents,(Theta),Conj),
skim_for_theta(Other, Theta, Remaining, _),
apply_last_parent(PVars, Remaining, Disj).
%
%
% simplify stuff, removing process that is cancelled by 0s
%
parents_to_conj([], Theta, Theta) :- !.
parents_to_conj(Ps, Theta, Theta*Conj) :-
parents_to_conj2(Ps, Conj).
parents_to_conj2([P],P) :- !.
parents_to_conj2(P.Ps,P*Conj) :-
parents_to_conj2(Ps,Conj).
%
% first case we haven't reached the end of the list so we need
% to create a new parameter variable
%
skim_for_theta([[P|Other]|V], not(P)*New, [Other|_], New) :- var(V), !.
%
% last theta, it is just negation of the other ones
%
skim_for_theta([[P|Other]], not(P), [Other], _) :- !.
%
% recursive case, build-up
%
skim_for_theta([[P|Other]|More], not(P)*Ps, [Other|Left], New ) :-
skim_for_theta(More, Ps, Left, New ).
get_evidence(V, Tree, Ev, F0, F, Leaves, Finals) :-
clpbn:get_atts(V, [evidence(Pos)]), !,
zero_pos(0, Pos, Ev),
insert_output(Leaves, V, Finals, Tree, Outs, SendOut),
get_outs(F0, F, SendOut, Outs).
% hidden deterministic node, can be removed.
get_evidence(V, _Tree, Ev, F0, [], _Leaves, _Finals) :-
clpbn:get_atts(V, [key(K)]),
functor(K, Name, 2),
( Name = 'AVG' ; Name = 'MAX' ; Name = 'MIN' ),
!,
one_list(Ev),
eval_outs(F0).
%% no evidence !!!
get_evidence(V, Tree, _Values, F0, F1, Leaves, Finals) :-
insert_output(Leaves, V, Finals, Tree, Outs, SendOut),
get_outs(F0, F1, SendOut, Outs).
zero_pos(_, _Pos, []).
zero_pos(Pos, Pos, 1.Values) :- !,
I is Pos+1,
zero_pos(I, Pos, Values).
zero_pos(I0, Pos, 0.Values) :-
I is I0+1,
zero_pos(I, Pos, Values).
one_list([]).
one_list(1.Ev) :-
one_list(Ev).
%
% insert a node with the disj of all alternatives, this is only done if node ends up to be in the output
%
insert_output([], _V, [], _Out, _Outs, []).
insert_output(V._Leaves, V0, [Top|_], Top, Outs, [Top = Outs]) :- V == V0, !.
insert_output(_.Leaves, V, _.Finals, Top, Outs, SendOut) :-
insert_output(Leaves, V, Finals, Top, Outs, SendOut).
get_outs([V=F], [V=NF|End], End, V) :- !,
% writeln(f0:F),
simplify_exp(F,NF).
get_outs((V=F).Outs, (V=NF).NOuts, End, (F0 + V)) :-
% writeln(f0:F),
simplify_exp(F,NF),
get_outs(Outs, NOuts, End, F0).
eval_outs([]).
eval_outs((V=F).Outs) :-
simplify_exp(F,NF),
V = NF,
eval_outs(Outs).
run_bdd_solver([[V]], LPs, bdd(Term, _Leaves, Nodes)) :-
build_out_node(Nodes, Node),
findall(Prob, get_prob(Term, Node, V, Prob),TermProbs),
sumlist(TermProbs, Sum),
writeln(TermProbs:Sum),
normalise(TermProbs, Sum, LPs).
build_out_node([_Top], []).
build_out_node([T,T1|Tops], [Top = T*Top]) :-
build_out_node2(T1.Tops, Top).
build_out_node2([Top], Top).
build_out_node2([T,T1|Tops], T*Top) :-
build_out_node2(T1.Tops, Top).
get_prob(Term, Node, V, SP) :-
bind_all(Term, Node, Bindings, V, AllParms, AllParmValues),
% reverse(AllParms, RAllParms),
term_variables(AllParms, NVs),
build_bdd(Bindings, NVs, AllParms, AllParmValues, Bdd),
bdd_to_probability_sum_product(Bdd, SP),
bdd_close(Bdd).
build_bdd(Bindings, NVs, VTheta, Theta, Bdd) :-
bdd_from_list(Bindings, NVs, Bdd),
bdd_size(Bdd, Len),
number_codes(Len,Codes),
atom_codes(Name,Codes),
bdd_print(Bdd, Name),
writeln(length=Len),
VTheta = Theta.
bind_all([], End, End, _V, [], []).
bind_all(info(V, _Tree, Ev, _Values, Formula, ParmVars, Parms).Term, End, BindsF, V0, ParmVars.AllParms, Parms.AllTheta) :-
V0 == V, !,
set_to_one_zeros(Ev),
bind_formula(Formula, BindsF, BindsI),
bind_all(Term, End, BindsI, V0, AllParms, AllTheta).
bind_all(info(_V, _Tree, Ev, _Values, Formula, ParmVars, Parms).Term, End, BindsF, V0, ParmVars.AllParms, Parms.AllTheta) :-
set_to_ones(Ev),!,
bind_formula(Formula, BindsF, BindsI),
bind_all(Term, End, BindsI, V0, AllParms, AllTheta).
% evidence: no need to add any stuff.
bind_all(info(_V, _Tree, _Ev, _Values, Formula, ParmVars, Parms).Term, End, BindsF, V0, ParmVars.AllParms, Parms.AllTheta) :-
bind_formula(Formula, BindsF, BindsI),
bind_all(Term, End, BindsI, V0, AllParms, AllTheta).
bind_formula([], L, L).
bind_formula(B.Formula, B.BsF, Bs0) :-
bind_formula(Formula, BsF, Bs0).
set_to_one_zeros([1|Values]) :-
set_to_zeros(Values).
set_to_one_zeros([0|Values]) :-
set_to_one_zeros(Values).
set_to_zeros([]).
set_to_zeros(0.Values) :-
set_to_zeros(Values).
set_to_ones([]).
set_to_ones(1.Values) :-
set_to_ones(Values).
normalise([], _Sum, []).
normalise(P.TermProbs, Sum, NP.LPs) :-
NP is P/Sum,
normalise(TermProbs, Sum, LPs).
finalize_bdd_solver(_).

View File

@ -1,16 +1,16 @@
/************************************************ /*******************************************************
Belief Propagation in CLP(BN) Belief Propagation and Variable Elimination Interface
**************************************************/ ********************************************************/
:- module(clpbn_bp, :- module(clpbn_bp,
[bp/3, [bp/3,
check_if_bp_done/1, check_if_bp_done/1,
set_horus_flag/2,
init_bp_solver/4, init_bp_solver/4,
run_bp_solver/3, run_bp_solver/3,
call_bp_ground/6,
finalize_bp_solver/1 finalize_bp_solver/1
]). ]).
@ -24,154 +24,143 @@
:- use_module(library('clpbn/display'), :- use_module(library('clpbn/display'),
[clpbn_bind_vals/3]). [clpbn_bind_vals/3]).
:- use_module(library('clpbn/aggregates'), :- use_module(library('clpbn/aggregates'),
[check_for_agg_vars/2]). [check_for_agg_vars/2]).
:- use_module(library(charsio),
[term_to_atom/2]).
:- use_module(library(pfl),
[skolem/2,
get_pfl_parameters/2
]).
:- use_module(library(lists)).
:- use_module(library(atts)). :- use_module(library(atts)).
:- use_module(library(lists)).
:- use_module(library(charsio)).
:- load_foreign_files(['horus'], [], init_predicates). :- use_module(library(bhash)).
:- attribute id/1.
%:- set_horus_flag(inf_alg, ve). :- use_module(horus,
:- set_horus_flag(inf_alg, bn_bp). [create_ground_network/4,
%:- set_horus_flag(inf_alg, fg_bp). set_factors_params/2,
%: -set_horus_flag(inf_alg, cbp). run_ground_solver/3,
set_vars_information/2,
free_ground_network/1
]).
:- set_horus_flag(schedule, seq_fixed).
%:- set_horus_flag(schedule, seq_random).
%:- set_horus_flag(schedule, parallel).
%:- set_horus_flag(schedule, max_residual).
:- set_horus_flag(accuracy, 0.0001). call_bp_ground(QueryVars, QueryKeys, AllKeys, Factors, Evidence, Output) :-
writeln(here:Factors),
b_hash_new(Hash0),
keys_to_ids(AllKeys, 0, Hash0, Hash),
get_factors_type(Factors, Type),
evidence_to_ids(Evidence, Hash, EvidenceIds),
factors_to_ids(Factors, Hash, FactorIds),
writeln(type:Type), writeln(''),
writeln(allKeys:AllKeys), writeln(''),
writeln(factors:Factors), writeln(''),
writeln(factorIds:FactorIds), writeln(''),
writeln(evidence:Evidence), writeln(''),
writeln(evidenceIds:EvidenceIds), writeln(''),
create_ground_network(Type, FactorIds, EvidenceIds, Network),
%get_vars_information(AllKeys, StatesNames),
%set_vars_information(AllKeys, StatesNames),
run_solver(ground(Network,Hash), QueryKeys, Solutions),
writeln(answer:Solutions),
clpbn_bind_vals([QueryVars], Solutions, Output),
free_ground_network(Network).
:- set_horus_flag(max_iter, 1000).
:- set_horus_flag(use_logarithms, false). run_solver(ground(Network,Hash), QueryKeys, Solutions) :-
%:- set_horus_flag(use_logarithms, true). %get_dists_parameters(DistIds, DistsParams),
%set_factors_params(Network, DistsParams),
list_of_keys_to_ids(QueryKeys, Hash, QueryIds),
writeln(queryKeys:QueryKeys), writeln(''),
writeln(queryIds:QueryIds), writeln(''),
list_of_keys_to_ids(QueryKeys, Hash, QueryIds),
run_ground_solver(Network, [QueryIds], Solutions).
:- set_horus_flag(order_factor_variables, false).
%:- set_horus_flag(order_factor_variables, true).
keys_to_ids([], _, Hash, Hash).
keys_to_ids([Key|AllKeys], I0, Hash0, Hash) :-
b_hash_insert(Hash0, Key, I0, HashI),
I is I0+1,
keys_to_ids(AllKeys, I, HashI, Hash).
get_factors_type([f(bayes, _, _, _)|_], bayes) :- ! .
get_factors_type([f(markov, _, _, _)|_], markov) :- ! .
list_of_keys_to_ids([], _, []).
list_of_keys_to_ids([Key|QueryKeys], Hash, [Id|QueryIds]) :-
b_hash_lookup(Key, Id, Hash),
list_of_keys_to_ids(QueryKeys, Hash, QueryIds).
factors_to_ids([], _, []).
factors_to_ids([f(_, DistId, Keys, CPT)|Fs], Hash, [f(Ids, Ranges, CPT, DistId)|NFs]) :-
list_of_keys_to_ids(Keys, Hash, Ids),
get_ranges(Keys, Ranges),
factors_to_ids(Fs, Hash, NFs).
get_ranges([],[]).
get_ranges(K.Ks, Range.Rs) :- !,
skolem(K,Domain),
length(Domain,Range),
get_ranges(Ks, Rs).
evidence_to_ids([], _, []).
evidence_to_ids([Key=Ev|QueryKeys], Hash, [Id=Ev|QueryIds]) :-
b_hash_lookup(Key, Id, Hash),
evidence_to_ids(QueryKeys, Hash, QueryIds).
get_vars_information([], []).
get_vars_information(Key.QueryKeys, Domain.StatesNames) :-
pfl:skolem(Key, Domain),
get_vars_information(QueryKeys, StatesNames).
finalize_bp_solver(bp(Network, _)) :-
free_ground_network(Network).
bp([[]],_,_) :- !. bp([[]],_,_) :- !.
bp([QueryVars], AllVars, Output) :- bp([QueryVars], AllVars, Output) :-
init_bp_solver(_, AllVars, _, Network), init_bp_solver(_, AllVars, _, Network),
run_bp_solver([QueryVars], LPs, Network), run_bp_solver([QueryVars], LPs, Network),
finalize_bp_solver(Network), finalize_bp_solver(Network),
clpbn_bind_vals([QueryVars], LPs, Output). clpbn_bind_vals([QueryVars], LPs, Output).
init_bp_solver(_, AllVars0, _, bp(BayesNet, DistIds)) :- init_bp_solver(_, AllVars0, _, bp(BayesNet, DistIds)) :-
check_for_agg_vars(AllVars0, AllVars), %check_for_agg_vars(AllVars0, AllVars),
writeln('clpbn_vars:'), get_vars_info(AllVars0, VarsInfo, DistIds0),
print_clpbn_vars(AllVars), sort(DistIds0, DistIds),
assign_ids(AllVars, 0), create_ground_network(VarsInfo, BayesNet),
get_vars_info(AllVars, VarsInfo, DistIds0), true.
sort(DistIds0, DistIds),
create_ground_network(VarsInfo, BayesNet).
%get_extra_vars_info(AllVars, ExtraVarsInfo),
%set_extra_vars_info(BayesNet, ExtraVarsInfo).
run_bp_solver(QueryVars, Solutions, bp(Network, DistIds)) :- run_bp_solver(QueryVars, Solutions, bp(Network, DistIds)) :-
get_dists_parameters(DistIds, DistsParams), get_dists_parameters(DistIds, DistsParams),
set_bayes_net_params(Network, DistsParams), set_factors_params(Network, DistsParams),
flatten_1_element_sublists(QueryVars, QueryVars1), vars_to_ids(QueryVars, QueryVarsIds),
vars_to_ids(QueryVars1, QueryVarsIds), run_ground_solver(Network, QueryVarsIds, Solutions).
run_other_solvers(Network, QueryVarsIds, Solutions).
finalize_bp_solver(bp(Network, _)) :-
free_bayesian_network(Network).
assign_ids([], _).
assign_ids([V|Vs], Count) :-
put_atts(V, [id(Count)]),
Count1 is Count + 1,
assign_ids(Vs, Count1).
get_vars_info([], [], []).
get_vars_info(V.Vs,
var(VarId,DS,Ev,PIds,DistId).VarsInfo,
DistId.DistIds) :-
clpbn:get_atts(V, [dist(DistId, Parents)]), !,
get_atts(V, [id(VarId)]),
get_dist_domain_size(DistId, DS),
get_evidence(V, Ev),
vars_to_ids(Parents, PIds),
get_vars_info(Vs, VarsInfo, DistIds).
get_evidence(V, Ev) :-
clpbn:get_atts(V, [evidence(Ev)]), !.
get_evidence(_V, -1). % no evidence !!!
vars_to_ids([], []).
vars_to_ids([L|Vars], [LIds|Ids]) :-
is_list(L), !,
vars_to_ids(L, LIds),
vars_to_ids(Vars, Ids).
vars_to_ids([V|Vars], [VarId|Ids]) :-
get_atts(V, [id(VarId)]),
vars_to_ids(Vars, Ids).
get_extra_vars_info([], []).
get_extra_vars_info([V|Vs], [v(VarId, Label, Domain)|VarsInfo]) :-
get_atts(V, [id(VarId)]), !,
clpbn:get_atts(V, [key(Key),dist(DistId, _)]),
term_to_atom(Key, Label),
get_dist_domain(DistId, Domain0),
numbers_to_atoms(Domain0, Domain),
get_extra_vars_info(Vs, VarsInfo).
get_extra_vars_info([_|Vs], VarsInfo) :-
get_extra_vars_info(Vs, VarsInfo).
get_dists_parameters([],[]). get_dists_parameters([],[]).
get_dists_parameters([Id|Ids], [dist(Id, Params)|DistsInfo]) :- get_dists_parameters([Id|Ids], [dist(Id, Params)|DistsInfo]) :-
get_dist_params(Id, Params), get_dist_params(Id, Params),
get_dists_parameters(Ids, DistsInfo). get_dists_parameters(Ids, DistsInfo).
numbers_to_atoms([], []).
numbers_to_atoms([Atom|L0], [Atom|L]) :-
atom(Atom), !,
numbers_to_atoms(L0, L).
numbers_to_atoms([Number|L0], [Atom|L]) :-
number_atom(Number, Atom),
numbers_to_atoms(L0, L).
flatten_1_element_sublists([],[]).
flatten_1_element_sublists([[H|[]]|T],[H|R]) :- !,
flatten_1_element_sublists(T,R).
flatten_1_element_sublists([H|T],[H|R]) :-
flatten_1_element_sublists(T,R).
print_clpbn_vars(Var.AllVars) :-
clpbn:get_atts(Var, [key(Key),dist(DistId,Parents)]),
parents_to_keys(Parents, ParentKeys),
writeln(Var:Key:ParentKeys:DistId),
print_clpbn_vars(AllVars).
print_clpbn_vars([]).
parents_to_keys([], []).
parents_to_keys(Var.Parents, Key.Keys) :-
clpbn:get_atts(Var, [key(Key)]),
parents_to_keys(Parents, Keys).

View File

@ -0,0 +1,77 @@
#include <cstdlib>
#include <cassert>
#include <iostream>
#include <fstream>
#include <sstream>
#include "BayesBall.h"
#include "Util.h"
FactorGraph*
BayesBall::getMinimalFactorGraph (const VarIds& queryIds)
{
assert (fg_.isFromBayesNetwork());
Scheduling scheduling;
for (unsigned i = 0; i < queryIds.size(); i++) {
assert (dag_.getNode (queryIds[i]));
DAGraphNode* n = dag_.getNode (queryIds[i]);
scheduling.push (ScheduleInfo (n, false, true));
}
while (!scheduling.empty()) {
ScheduleInfo& sch = scheduling.front();
DAGraphNode* n = sch.node;
n->setAsVisited();
if (n->hasEvidence() == false && sch.visitedFromChild) {
if (n->isMarkedOnTop() == false) {
n->markOnTop();
scheduleParents (n, scheduling);
}
if (n->isMarkedOnBottom() == false) {
n->markOnBottom();
scheduleChilds (n, scheduling);
}
}
if (sch.visitedFromParent) {
if (n->hasEvidence() && n->isMarkedOnTop() == false) {
n->markOnTop();
scheduleParents (n, scheduling);
}
if (n->hasEvidence() == false && n->isMarkedOnBottom() == false) {
n->markOnBottom();
scheduleChilds (n, scheduling);
}
}
scheduling.pop();
}
FactorGraph* fg = new FactorGraph();
constructGraph (fg);
return fg;
}
void
BayesBall::constructGraph (FactorGraph* fg) const
{
const FacNodes& facNodes = fg_.facNodes();
for (unsigned i = 0; i < facNodes.size(); i++) {
const DAGraphNode* n = dag_.getNode (
facNodes[i]->factor().argument (0));
if (n->isMarkedOnTop()) {
fg->addFactor (Factor (facNodes[i]->factor()));
} else if (n->hasEvidence() && n->isVisited()) {
VarIds varIds = { facNodes[i]->factor().argument (0) };
Ranges ranges = { facNodes[i]->factor().range (0) };
Params params (ranges[0], LogAware::noEvidence());
params[n->getEvidence()] = LogAware::withEvidence();
fg->addFactor (Factor (varIds, ranges, params));
}
}
}

View File

@ -0,0 +1,85 @@
#ifndef HORUS_BAYESBALL_H
#define HORUS_BAYESBALL_H
#include <vector>
#include <queue>
#include <list>
#include <map>
#include "FactorGraph.h"
#include "BayesNet.h"
#include "Horus.h"
using namespace std;
struct ScheduleInfo
{
ScheduleInfo (DAGraphNode* n, bool vfp, bool vfc) :
node(n), visitedFromParent(vfp), visitedFromChild(vfc) { }
DAGraphNode* node;
bool visitedFromParent;
bool visitedFromChild;
};
typedef queue<ScheduleInfo, list<ScheduleInfo>> Scheduling;
class BayesBall
{
public:
BayesBall (FactorGraph& fg)
: fg_(fg) , dag_(fg.getStructure())
{
dag_.clear();
}
FactorGraph* getMinimalFactorGraph (const VarIds&);
static FactorGraph* getMinimalFactorGraph (FactorGraph& fg, VarIds vids)
{
BayesBall bb (fg);
return bb.getMinimalFactorGraph (vids);
}
private:
void constructGraph (FactorGraph* fg) const;
void scheduleParents (const DAGraphNode* n, Scheduling& sch) const;
void scheduleChilds (const DAGraphNode* n, Scheduling& sch) const;
FactorGraph& fg_;
DAGraph& dag_;
};
inline void
BayesBall::scheduleParents (const DAGraphNode* n, Scheduling& sch) const
{
const vector<DAGraphNode*>& ps = n->parents();
for (vector<DAGraphNode*>::const_iterator it = ps.begin();
it != ps.end(); it++) {
sch.push (ScheduleInfo (*it, false, true));
}
}
inline void
BayesBall::scheduleChilds (const DAGraphNode* n, Scheduling& sch) const
{
const vector<DAGraphNode*>& cs = n->childs();
for (vector<DAGraphNode*>::const_iterator it = cs.begin();
it != cs.end(); it++) {
sch.push (ScheduleInfo (*it, true, false));
}
}
#endif // HORUS_BAYESBALL_H

View File

@ -5,381 +5,57 @@
#include <fstream> #include <fstream>
#include <sstream> #include <sstream>
#include "xmlParser/xmlParser.h"
#include "BayesNet.h" #include "BayesNet.h"
#include "Util.h" #include "Util.h"
void
BayesNet::~BayesNet (void) DAGraph::addNode (DAGraphNode* n)
{ {
for (unsigned i = 0; i < nodes_.size(); i++) { assert (Util::contains (varMap_, n->varId()) == false);
delete nodes_[i]; nodes_.push_back (n);
} varMap_[n->varId()] = n;
} }
void void
BayesNet::readFromBifFormat (const char* fileName) DAGraph::addEdge (VarId vid1, VarId vid2)
{ {
XMLNode xMainNode = XMLNode::openFileHelper (fileName, "BIF"); unordered_map<VarId, DAGraphNode*>::iterator it1;
// only the first network is parsed, others are ignored unordered_map<VarId, DAGraphNode*>::iterator it2;
XMLNode xNode = xMainNode.getChildNode ("NETWORK"); it1 = varMap_.find (vid1);
unsigned nVars = xNode.nChildNode ("VARIABLE"); it2 = varMap_.find (vid2);
for (unsigned i = 0; i < nVars; i++) { assert (it1 != varMap_.end());
XMLNode var = xNode.getChildNode ("VARIABLE", i); assert (it2 != varMap_.end());
if (string (var.getAttribute ("TYPE")) != "nature") { it1->second->addChild (it2->second);
cerr << "error: only \"nature\" variables are supported" << endl; it2->second->addParent (it1->second);
abort();
}
States states;
string label = var.getChildNode("NAME").getText();
unsigned nrStates = var.nChildNode ("OUTCOME");
for (unsigned j = 0; j < nrStates; j++) {
if (var.getChildNode("OUTCOME", j).getText() == 0) {
stringstream ss;
ss << j + 1;
states.push_back (ss.str());
} else {
states.push_back (var.getChildNode("OUTCOME", j).getText());
}
}
addNode (label, states);
}
unsigned nDefs = xNode.nChildNode ("DEFINITION");
if (nVars != nDefs) {
cerr << "error: different number of variables and definitions" << endl;
abort();
}
for (unsigned i = 0; i < nDefs; i++) {
XMLNode def = xNode.getChildNode ("DEFINITION", i);
string label = def.getChildNode("FOR").getText();
BayesNode* node = getBayesNode (label);
if (!node) {
cerr << "error: unknow variable `" << label << "'" << endl;
abort();
}
BnNodeSet parents;
unsigned nParams = node->nrStates();
for (int j = 0; j < def.nChildNode ("GIVEN"); j++) {
string parentLabel = def.getChildNode("GIVEN", j).getText();
BayesNode* parentNode = getBayesNode (parentLabel);
if (!parentNode) {
cerr << "error: unknow variable `" << parentLabel << "'" << endl;
abort();
}
nParams *= parentNode->nrStates();
parents.push_back (parentNode);
}
node->setParents (parents);
unsigned count = 0;
Params params (nParams);
stringstream s (def.getChildNode("TABLE").getText());
while (!s.eof() && count < nParams) {
s >> params[count];
count ++;
}
if (count != nParams) {
cerr << "error: invalid number of parameters " ;
cerr << "for variable `" << label << "'" << endl;
abort();
}
params = reorderParameters (params, node->nrStates());
Distribution* dist = new Distribution (params);
node->setDistribution (dist);
addDistribution (dist);
}
setIndexes();
if (Globals::logDomain) {
distributionsToLogs();
}
} }
BayesNode* const DAGraphNode*
BayesNet::addNode (string label, const States& states) DAGraph::getNode (VarId vid) const
{ {
VarId vid = nodes_.size(); unordered_map<VarId, DAGraphNode*>::const_iterator it;
varMap_.insert (make_pair (vid, nodes_.size())); it = varMap_.find (vid);
GraphicalModel::addVariableInformation (vid, label, states); return it != varMap_.end() ? it->second : 0;
BayesNode* node = new BayesNode (VarNode (vid, states.size()));
nodes_.push_back (node);
return node;
} }
BayesNode* DAGraphNode*
BayesNet::addNode (VarId vid, unsigned dsize, int evidence, Distribution* dist) DAGraph::getNode (VarId vid)
{ {
varMap_.insert (make_pair (vid, nodes_.size())); unordered_map<VarId, DAGraphNode*>::const_iterator it;
nodes_.push_back (new BayesNode (vid, dsize, evidence, dist)); it = varMap_.find (vid);
return nodes_.back(); return it != varMap_.end() ? it->second : 0;
}
BayesNode*
BayesNet::getBayesNode (VarId vid) const
{
IndexMap::const_iterator it = varMap_.find (vid);
if (it == varMap_.end()) {
return 0;
} else {
return nodes_[it->second];
}
}
BayesNode*
BayesNet::getBayesNode (string label) const
{
BayesNode* node = 0;
for (unsigned i = 0; i < nodes_.size(); i++) {
if (nodes_[i]->label() == label) {
node = nodes_[i];
break;
}
}
return node;
}
VarNode*
BayesNet::getVariableNode (VarId vid) const
{
BayesNode* node = getBayesNode (vid);
assert (node);
return node;
}
VarNodes
BayesNet::getVariableNodes (void) const
{
VarNodes vars;
for (unsigned i = 0; i < nodes_.size(); i++) {
vars.push_back (nodes_[i]);
}
return vars;
} }
void void
BayesNet::addDistribution (Distribution* dist) DAGraph::setIndexes (void)
{
dists_.push_back (dist);
}
Distribution*
BayesNet::getDistribution (unsigned distId) const
{
Distribution* dist = 0;
for (unsigned i = 0; i < dists_.size(); i++) {
if (dists_[i]->id == (int) distId) {
dist = dists_[i];
break;
}
}
return dist;
}
const BnNodeSet&
BayesNet::getBayesNodes (void) const
{
return nodes_;
}
unsigned
BayesNet::nrNodes (void) const
{
return nodes_.size();
}
BnNodeSet
BayesNet::getRootNodes (void) const
{
BnNodeSet roots;
for (unsigned i = 0; i < nodes_.size(); i++) {
if (nodes_[i]->isRoot()) {
roots.push_back (nodes_[i]);
}
}
return roots;
}
BnNodeSet
BayesNet::getLeafNodes (void) const
{
BnNodeSet leafs;
for (unsigned i = 0; i < nodes_.size(); i++) {
if (nodes_[i]->isLeaf()) {
leafs.push_back (nodes_[i]);
}
}
return leafs;
}
BayesNet*
BayesNet::getMinimalRequesiteNetwork (VarId vid) const
{
return getMinimalRequesiteNetwork (VarIds() = {vid});
}
BayesNet*
BayesNet::getMinimalRequesiteNetwork (const VarIds& queryVarIds) const
{
BnNodeSet queryVars;
Scheduling scheduling;
for (unsigned i = 0; i < queryVarIds.size(); i++) {
BayesNode* n = getBayesNode (queryVarIds[i]);
assert (n);
queryVars.push_back (n);
scheduling.push (ScheduleInfo (n, false, true));
}
vector<StateInfo*> states (nodes_.size(), 0);
while (!scheduling.empty()) {
ScheduleInfo& sch = scheduling.front();
StateInfo* state = states[sch.node->getIndex()];
if (!state) {
state = new StateInfo();
states[sch.node->getIndex()] = state;
} else {
state->visited = true;
}
if (!sch.node->hasEvidence() && sch.visitedFromChild) {
if (!state->markedOnTop) {
state->markedOnTop = true;
scheduleParents (sch.node, scheduling);
}
if (!state->markedOnBottom) {
state->markedOnBottom = true;
scheduleChilds (sch.node, scheduling);
}
}
if (sch.visitedFromParent) {
if (sch.node->hasEvidence() && !state->markedOnTop) {
state->markedOnTop = true;
scheduleParents (sch.node, scheduling);
}
if (!sch.node->hasEvidence() && !state->markedOnBottom) {
state->markedOnBottom = true;
scheduleChilds (sch.node, scheduling);
}
}
scheduling.pop();
}
/*
cout << "\t\ttop\tbottom" << endl;
cout << "variable\t\tmarked\tmarked\tvisited\tobserved" << endl;
cout << "----------------------------------------------------------" ;
cout << endl;
for (unsigned i = 0; i < states.size(); i++) {
cout << nodes_[i]->label() << ":\t\t" ;
if (states[i]) {
states[i]->markedOnTop ? cout << "yes\t" : cout << "no\t" ;
states[i]->markedOnBottom ? cout << "yes\t" : cout << "no\t" ;
states[i]->visited ? cout << "yes\t" : cout << "no\t" ;
nodes_[i]->hasEvidence() ? cout << "yes" : cout << "no" ;
cout << endl;
} else {
cout << "no\tno\tno\t" ;
nodes_[i]->hasEvidence() ? cout << "yes" : cout << "no" ;
cout << endl;
}
}
cout << endl;
*/
BayesNet* bn = new BayesNet();
constructGraph (bn, states);
for (unsigned i = 0; i < nodes_.size(); i++) {
delete states[i];
}
return bn;
}
void
BayesNet::constructGraph (BayesNet* bn,
const vector<StateInfo*>& states) const
{
BnNodeSet mrnNodes;
vector<VarIds> parents;
for (unsigned i = 0; i < nodes_.size(); i++) {
bool isRequired = false;
if (states[i]) {
isRequired = (nodes_[i]->hasEvidence() && states[i]->visited)
||
states[i]->markedOnTop;
}
if (isRequired) {
parents.push_back (VarIds());
if (states[i]->markedOnTop) {
const BnNodeSet& ps = nodes_[i]->getParents();
for (unsigned j = 0; j < ps.size(); j++) {
parents.back().push_back (ps[j]->varId());
}
}
assert (bn->getBayesNode (nodes_[i]->varId()) == 0);
BayesNode* mrnNode = bn->addNode (nodes_[i]->varId(),
nodes_[i]->nrStates(),
nodes_[i]->getEvidence(),
nodes_[i]->getDistribution());
mrnNodes.push_back (mrnNode);
}
}
for (unsigned i = 0; i < mrnNodes.size(); i++) {
BnNodeSet ps;
for (unsigned j = 0; j < parents[i].size(); j++) {
assert (bn->getBayesNode (parents[i][j]) != 0);
ps.push_back (bn->getBayesNode (parents[i][j]));
}
mrnNodes[i]->setParents (ps);
}
bn->setIndexes();
}
bool
BayesNet::isPolyTree (void) const
{
return !containsUndirectedCycle();
}
void
BayesNet::setIndexes (void)
{ {
for (unsigned i = 0; i < nodes_.size(); i++) { for (unsigned i = 0; i < nodes_.size(); i++) {
nodes_[i]->setIndex (i); nodes_[i]->setIndex (i);
@ -389,233 +65,43 @@ BayesNet::setIndexes (void)
void void
BayesNet::distributionsToLogs (void) DAGraph::clear (void)
{
for (unsigned i = 0; i < dists_.size(); i++) {
Util::toLog (dists_[i]->params);
}
}
void
BayesNet::freeDistributions (void)
{
for (unsigned i = 0; i < dists_.size(); i++) {
delete dists_[i];
}
}
void
BayesNet::printGraphicalModel (void) const
{ {
for (unsigned i = 0; i < nodes_.size(); i++) { for (unsigned i = 0; i < nodes_.size(); i++) {
cout << *nodes_[i]; nodes_[i]->clear();
} }
} }
void void
BayesNet::exportToGraphViz (const char* fileName, DAGraph::exportToGraphViz (const char* fileName)
bool showNeighborless,
const VarIds& highlightVarIds) const
{ {
ofstream out (fileName); ofstream out (fileName);
if (!out.is_open()) { if (!out.is_open()) {
cerr << "error: cannot open file to write at " ; cerr << "error: cannot open file to write at " ;
cerr << "BayesNet::exportToDotFile()" << endl; cerr << "DAGraph::exportToDotFile()" << endl;
abort(); abort();
} }
out << "digraph {" << endl; out << "digraph {" << endl;
out << "ranksep=1" << endl; out << "ranksep=1" << endl;
for (unsigned i = 0; i < nodes_.size(); i++) { for (unsigned i = 0; i < nodes_.size(); i++) {
if (showNeighborless || nodes_[i]->hasNeighbors()) { out << nodes_[i]->varId() ;
out << nodes_[i]->varId() ; out << " [" ;
if (nodes_[i]->hasEvidence()) { out << "label=\"" << nodes_[i]->label() << "\"" ;
out << " [" ; if (nodes_[i]->hasEvidence()) {
out << "label=\"" << nodes_[i]->label() << "\"," ; out << ",style=filled, fillcolor=yellow" ;
out << "style=filled, fillcolor=yellow" ;
out << "]" ;
} else {
out << " [" ;
out << "label=\"" << nodes_[i]->label() << "\"" ;
out << "]" ;
}
out << endl;
} }
out << "]" << endl;
} }
for (unsigned i = 0; i < highlightVarIds.size(); i++) {
BayesNode* node = getBayesNode (highlightVarIds[i]);
if (node) {
out << node->varId() ;
out << " [shape=box3d]" << endl;
} else {
cout << "error: invalid variable id: " << highlightVarIds[i] << endl;
abort();
}
}
for (unsigned i = 0; i < nodes_.size(); i++) { for (unsigned i = 0; i < nodes_.size(); i++) {
const BnNodeSet& childs = nodes_[i]->getChilds(); const vector<DAGraphNode*>& childs = nodes_[i]->childs();
for (unsigned j = 0; j < childs.size(); j++) { for (unsigned j = 0; j < childs.size(); j++) {
out << nodes_[i]->varId() << " -> " << childs[j]->varId() << " [style=bold]" << endl ; out << nodes_[i]->varId() << " -> " << childs[j]->varId();
out << " [style=bold]" << endl ;
} }
} }
out << "}" << endl; out << "}" << endl;
out.close(); out.close();
} }
void
BayesNet::exportToBifFormat (const char* fileName) const
{
ofstream out (fileName);
if(!out.is_open()) {
cerr << "error: cannot open file to write at " ;
cerr << "BayesNet::exportToBifFile()" << endl;
abort();
}
out << "<?xml version=\"1.0\" encoding=\"US-ASCII\"?>" << endl;
out << "<BIF VERSION=\"0.3\">" << endl;
out << "<NETWORK>" << endl;
out << "<NAME>" << fileName << "</NAME>" << endl << endl;
for (unsigned i = 0; i < nodes_.size(); i++) {
out << "<VARIABLE TYPE=\"nature\">" << endl;
out << "\t<NAME>" << nodes_[i]->label() << "</NAME>" << endl;
const States& states = nodes_[i]->states();
for (unsigned j = 0; j < states.size(); j++) {
out << "\t<OUTCOME>" << states[j] << "</OUTCOME>" << endl;
}
out << "</VARIABLE>" << endl << endl;
}
for (unsigned i = 0; i < nodes_.size(); i++) {
out << "<DEFINITION>" << endl;
out << "\t<FOR>" << nodes_[i]->label() << "</FOR>" << endl;
const BnNodeSet& parents = nodes_[i]->getParents();
for (unsigned j = 0; j < parents.size(); j++) {
out << "\t<GIVEN>" << parents[j]->label();
out << "</GIVEN>" << endl;
}
Params params = revertParameterReorder (nodes_[i]->getParameters(),
nodes_[i]->nrStates());
out << "\t<TABLE>" ;
for (unsigned j = 0; j < params.size(); j++) {
out << " " << params[j];
}
out << " </TABLE>" << endl;
out << "</DEFINITION>" << endl << endl;
}
out << "</NETWORK>" << endl;
out << "</BIF>" << endl << endl;
out.close();
}
bool
BayesNet::containsUndirectedCycle (void) const
{
vector<bool> visited (nodes_.size(), false);
for (unsigned i = 0; i < nodes_.size(); i++) {
int v = nodes_[i]->getIndex();
if (!visited[v]) {
if (containsUndirectedCycle (v, -1, visited)) {
return true;
}
}
}
return false;
}
bool
BayesNet::containsUndirectedCycle (int v, int p, vector<bool>& visited) const
{
visited[v] = true;
vector<int> adjacencies = getAdjacentNodes (v);
for (unsigned i = 0; i < adjacencies.size(); i++) {
int w = adjacencies[i];
if (!visited[w]) {
if (containsUndirectedCycle (w, v, visited)) {
return true;
}
}
else if (visited[w] && w != p) {
return true;
}
}
return false; // no cycle detected in this component
}
vector<int>
BayesNet::getAdjacentNodes (int v) const
{
vector<int> adjacencies;
const BnNodeSet& parents = nodes_[v]->getParents();
const BnNodeSet& childs = nodes_[v]->getChilds();
for (unsigned i = 0; i < parents.size(); i++) {
adjacencies.push_back (parents[i]->getIndex());
}
for (unsigned i = 0; i < childs.size(); i++) {
adjacencies.push_back (childs[i]->getIndex());
}
return adjacencies;
}
Params
BayesNet::reorderParameters (const Params& params, unsigned dsize) const
{
// the interchange format for bayesian networks keeps the probabilities
// in the following order:
// p(a1|b1,c1) p(a2|b1,c1) p(a1|b1,c2) p(a2|b1,c2) p(a1|b2,c1) p(a2|b2,c1)
// p(a1|b2,c2) p(a2|b2,c2).
//
// however, in clpbn we keep the probabilities in this order:
// p(a1|b1,c1) p(a1|b1,c2) p(a1|b2,c1) p(a1|b2,c2) p(a2|b1,c1) p(a2|b1,c2)
// p(a2|b2,c1) p(a2|b2,c2).
unsigned count = 0;
unsigned rowSize = params.size() / dsize;
Params reordered;
while (reordered.size() < params.size()) {
unsigned idx = count;
for (unsigned i = 0; i < rowSize; i++) {
reordered.push_back (params[idx]);
idx += dsize ;
}
count++;
}
return reordered;
}
Params
BayesNet::revertParameterReorder (const Params& params, unsigned dsize) const
{
unsigned count = 0;
unsigned rowSize = params.size() / dsize;
Params reordered;
while (reordered.size() < params.size()) {
unsigned idx = count;
for (unsigned i = 0; i < dsize; i++) {
reordered.push_back (params[idx]);
idx += rowSize;
}
count ++;
}
return reordered;
}

View File

@ -6,118 +6,83 @@
#include <list> #include <list>
#include <map> #include <map>
#include "GraphicalModel.h" #include "Var.h"
#include "BayesNode.h"
#include "Horus.h" #include "Horus.h"
using namespace std; using namespace std;
class Distribution;
struct ScheduleInfo class Var;
{
ScheduleInfo (BayesNode* n, bool vfp, bool vfc)
{
node = n;
visitedFromParent = vfp;
visitedFromChild = vfc;
}
BayesNode* node;
bool visitedFromParent;
bool visitedFromChild;
};
class DAGraphNode : public Var
struct StateInfo
{
StateInfo (void)
{
visited = true;
markedOnTop = false;
markedOnBottom = false;
}
bool visited;
bool markedOnTop;
bool markedOnBottom;
};
typedef vector<Distribution*> DistSet;
typedef queue<ScheduleInfo, list<ScheduleInfo> > Scheduling;
class BayesNet : public GraphicalModel
{ {
public: public:
BayesNet (void) {}; DAGraphNode (Var* v) : Var (v) , visited_(false),
~BayesNet (void); markedOnTop_(false), markedOnBottom_(false) { }
void readFromBifFormat (const char*); const vector<DAGraphNode*>& childs (void) const { return childs_; }
BayesNode* addNode (string, const States&);
// BayesNode* addNode (VarId, unsigned, int, BnNodeSet&, Distribution*); vector<DAGraphNode*>& childs (void) { return childs_; }
BayesNode* addNode (VarId, unsigned, int, Distribution*);
BayesNode* getBayesNode (VarId) const; const vector<DAGraphNode*>& parents (void) const { return parents_; }
BayesNode* getBayesNode (string) const;
VarNode* getVariableNode (VarId) const; vector<DAGraphNode*>& parents (void) { return parents_; }
VarNodes getVariableNodes (void) const;
void addDistribution (Distribution*); void addParent (DAGraphNode* p) { parents_.push_back (p); }
Distribution* getDistribution (unsigned) const;
const BnNodeSet& getBayesNodes (void) const; void addChild (DAGraphNode* c) { childs_.push_back (c); }
unsigned nrNodes (void) const;
BnNodeSet getRootNodes (void) const; bool isVisited (void) const { return visited_; }
BnNodeSet getLeafNodes (void) const;
BayesNet* getMinimalRequesiteNetwork (VarId) const; void setAsVisited (void) { visited_ = true; }
BayesNet* getMinimalRequesiteNetwork (const VarIds&) const;
void constructGraph ( bool isMarkedOnTop (void) const { return markedOnTop_; }
BayesNet*, const vector<StateInfo*>&) const;
bool isPolyTree (void) const; void markOnTop (void) { markedOnTop_ = true; }
void setIndexes (void);
void distributionsToLogs (void); bool isMarkedOnBottom (void) const { return markedOnBottom_; }
void freeDistributions (void);
void printGraphicalModel (void) const; void markOnBottom (void) { markedOnBottom_ = true; }
void exportToGraphViz (const char*, bool = true,
const VarIds& = VarIds()) const; void clear (void) { visited_ = markedOnTop_ = markedOnBottom_ = false; }
void exportToBifFormat (const char*) const;
private: private:
DISALLOW_COPY_AND_ASSIGN (BayesNet); bool visited_;
bool markedOnTop_;
bool markedOnBottom_;
bool containsUndirectedCycle (void) const; vector<DAGraphNode*> childs_;
bool containsUndirectedCycle (int, int, vector<bool>&)const; vector<DAGraphNode*> parents_;
vector<int> getAdjacentNodes (int) const;
Params reorderParameters (const Params&, unsigned) const;
Params revertParameterReorder (const Params&, unsigned) const;
void scheduleParents (const BayesNode*, Scheduling&) const;
void scheduleChilds (const BayesNode*, Scheduling&) const;
BnNodeSet nodes_;
DistSet dists_;
typedef unordered_map<unsigned, unsigned> IndexMap;
IndexMap varMap_;
}; };
class DAGraph
inline void
BayesNet::scheduleParents (const BayesNode* n, Scheduling& sch) const
{ {
const BnNodeSet& ps = n->getParents(); public:
for (BnNodeSet::const_iterator it = ps.begin(); it != ps.end(); it++) { DAGraph (void) { }
sch.push (ScheduleInfo (*it, false, true));
}
}
void addNode (DAGraphNode* n);
void addEdge (VarId vid1, VarId vid2);
inline void const DAGraphNode* getNode (VarId vid) const;
BayesNet::scheduleChilds (const BayesNode* n, Scheduling& sch) const
{ DAGraphNode* getNode (VarId vid);
const BnNodeSet& cs = n->getChilds();
for (BnNodeSet::const_iterator it = cs.begin(); it != cs.end(); it++) { bool empty (void) const { return nodes_.empty(); }
sch.push (ScheduleInfo (*it, true, false));
} void setIndexes (void);
}
void clear (void);
void exportToGraphViz (const char*);
private:
vector<DAGraphNode*> nodes_;
unordered_map<VarId, DAGraphNode*> varMap_;
};
#endif // HORUS_BAYESNET_H #endif // HORUS_BAYESNET_H

View File

@ -1,291 +0,0 @@
#include <cstdlib>
#include <cassert>
#include <iomanip>
#include <iostream>
#include <sstream>
#include "BayesNode.h"
BayesNode::BayesNode (VarId vid,
unsigned dsize,
int evidence,
Distribution* dist)
: VarNode (vid, dsize, evidence)
{
dist_ = dist;
}
BayesNode::BayesNode (VarId vid,
unsigned dsize,
int evidence,
const BnNodeSet& parents,
Distribution* dist)
: VarNode (vid, dsize, evidence)
{
parents_ = parents;
dist_ = dist;
for (unsigned int i = 0; i < parents.size(); i++) {
parents[i]->addChild (this);
}
}
void
BayesNode::setParents (const BnNodeSet& parents)
{
parents_ = parents;
for (unsigned int i = 0; i < parents.size(); i++) {
parents[i]->addChild (this);
}
}
void
BayesNode::addChild (BayesNode* node)
{
childs_.push_back (node);
}
void
BayesNode::setDistribution (Distribution* dist)
{
assert (dist);
dist_ = dist;
}
Distribution*
BayesNode::getDistribution (void)
{
return dist_;
}
const Params&
BayesNode::getParameters (void)
{
return dist_->params;
}
Params
BayesNode::getRow (int rowIndex) const
{
int rowSize = getRowSize();
int offset = rowSize * rowIndex;
Params row (rowSize);
for (int i = 0; i < rowSize; i++) {
row[i] = dist_->params[offset + i] ;
}
return row;
}
bool
BayesNode::isRoot (void)
{
return getParents().empty();
}
bool
BayesNode::isLeaf (void)
{
return getChilds().empty();
}
bool
BayesNode::hasNeighbors (void) const
{
return childs_.size() != 0 || parents_.size() != 0;
}
int
BayesNode::getCptSize (void)
{
return dist_->params.size();
}
int
BayesNode::getIndexOfParent (const BayesNode* parent) const
{
for (unsigned int i = 0; i < parents_.size(); i++) {
if (parents_[i] == parent) {
return i;
}
}
return -1;
}
string
BayesNode::cptEntryToString (
int row,
const vector<unsigned>& stateConf) const
{
stringstream ss;
ss << "p(" ;
ss << states()[row];
if (parents_.size() > 0) {
ss << "|" ;
for (unsigned int i = 0; i < stateConf.size(); i++) {
if (i != 0) {
ss << ",";
}
ss << parents_[i]->states()[stateConf[i]];
}
}
ss << ")" ;
return ss.str();
}
vector<string>
BayesNode::getDomainHeaders (void) const
{
unsigned nParents = parents_.size();
unsigned rowSize = getRowSize();
unsigned nReps = 1;
vector<string> headers (rowSize);
for (int i = nParents - 1; i >= 0; i--) {
States states = parents_[i]->states();
unsigned index = 0;
while (index < rowSize) {
for (unsigned j = 0; j < parents_[i]->nrStates(); j++) {
for (unsigned r = 0; r < nReps; r++) {
if (headers[index] != "") {
headers[index] = states[j] + "," + headers[index];
} else {
headers[index] = states[j];
}
index++;
}
}
}
nReps *= parents_[i]->nrStates();
}
return headers;
}
ostream&
operator << (ostream& o, const BayesNode& node)
{
o << "variable " << node.getIndex() << endl;
o << "Var Id: " << node.varId() << endl;
o << "Label: " << node.label() << endl;
o << "Evidence: " ;
if (node.hasEvidence()) {
o << node.getEvidence();
}
else {
o << "no" ;
}
o << endl;
o << "Parents: " ;
const BnNodeSet& parents = node.getParents();
if (parents.size() != 0) {
for (unsigned int i = 0; i < parents.size() - 1; i++) {
o << parents[i]->label() << ", " ;
}
o << parents[parents.size() - 1]->label();
}
o << endl;
o << "Childs: " ;
const BnNodeSet& childs = node.getChilds();
if (childs.size() != 0) {
for (unsigned int i = 0; i < childs.size() - 1; i++) {
o << childs[i]->label() << ", " ;
}
o << childs[childs.size() - 1]->label();
}
o << endl;
o << "Domain: " ;
States states = node.states();
for (unsigned int i = 0; i < states.size() - 1; i++) {
o << states[i] << ", " ;
}
if (states.size() != 0) {
o << states[states.size() - 1];
}
o << endl;
// min width of first column
const unsigned int MIN_DOMAIN_WIDTH = 4;
// min width of following columns
const unsigned int MIN_COMBO_WIDTH = 12;
unsigned int domainWidth = states[0].length();
for (unsigned int i = 1; i < states.size(); i++) {
if (states[i].length() > domainWidth) {
domainWidth = states[i].length();
}
}
domainWidth = (domainWidth < MIN_DOMAIN_WIDTH)
? MIN_DOMAIN_WIDTH
: domainWidth;
o << left << setw (domainWidth) << "cpt" << right;
vector<int> widths;
int lineWidth = domainWidth;
vector<string> headers = node.getDomainHeaders();
if (!headers.empty()) {
for (unsigned int i = 0; i < headers.size(); i++) {
unsigned int len = headers[i].length();
int w = (len < MIN_COMBO_WIDTH) ? MIN_COMBO_WIDTH : len;
widths.push_back (w);
o << setw (w) << headers[i];
lineWidth += w;
}
o << endl;
} else {
cout << endl;
widths.push_back (domainWidth);
lineWidth += MIN_COMBO_WIDTH;
}
for (int i = 0; i < lineWidth; i++) {
o << "-" ;
}
o << endl;
for (unsigned int i = 0; i < states.size(); i++) {
Params row = node.getRow (i);
o << left << setw (domainWidth) << states[i] << right;
for (unsigned j = 0; j < node.getRowSize(); j++) {
o << setw (widths[j]) << row[j];
}
o << endl;
}
o << endl;
return o;
}

View File

@ -1,61 +0,0 @@
#ifndef HORUS_BAYESNODE_H
#define HORUS_BAYESNODE_H
#include <vector>
#include "VarNode.h"
#include "Distribution.h"
#include "Horus.h"
using namespace std;
class BayesNode : public VarNode
{
public:
BayesNode (const VarNode& v) : VarNode (v) {}
BayesNode (VarId, unsigned, int, Distribution*);
BayesNode (VarId, unsigned, int, const BnNodeSet&, Distribution*);
void setParents (const BnNodeSet&);
void addChild (BayesNode*);
void setDistribution (Distribution*);
Distribution* getDistribution (void);
const Params& getParameters (void);
Params getRow (int) const;
bool isRoot (void);
bool isLeaf (void);
bool hasNeighbors (void) const;
int getCptSize (void);
int getIndexOfParent (const BayesNode*) const;
string cptEntryToString (int, const vector<unsigned>&) const;
const BnNodeSet& getParents (void) const { return parents_; }
const BnNodeSet& getChilds (void) const { return childs_; }
unsigned getRowSize (void) const
{
return dist_->params.size() / nrStates();
}
double getProbability (int row, unsigned col)
{
int idx = (row * getRowSize()) + col;
return dist_->params[idx];
}
private:
DISALLOW_COPY_AND_ASSIGN (BayesNode);
States getDomainHeaders (void) const;
friend ostream& operator << (ostream&, const BayesNode&);
BnNodeSet parents_;
BnNodeSet childs_;
Distribution* dist_;
};
ostream& operator << (ostream&, const BayesNode&);
#endif // HORUS_BAYESNODE_H

View File

@ -1,803 +0,0 @@
#include <cstdlib>
#include <limits>
#include <time.h>
#include <algorithm>
#include <iostream>
#include <sstream>
#include <iomanip>
#include "BnBpSolver.h"
#include "Indexer.h"
BnBpSolver::BnBpSolver (const BayesNet& bn) : Solver (&bn)
{
bayesNet_ = &bn;
}
BnBpSolver::~BnBpSolver (void)
{
for (unsigned i = 0; i < nodesI_.size(); i++) {
delete nodesI_[i];
}
for (unsigned i = 0; i < links_.size(); i++) {
delete links_[i];
}
}
void
BnBpSolver::runSolver (void)
{
clock_t start;
if (COLLECT_STATISTICS) {
start = clock();
}
initializeSolver();
runLoopySolver();
if (DL >= 2) {
cout << endl;
if (nIters_ < BpOptions::maxIter) {
cout << "Belief propagation converged in " ;
cout << nIters_ << " iterations" << endl;
} else {
cout << "The maximum number of iterations was hit, terminating..." ;
cout << endl;
}
}
unsigned size = bayesNet_->nrNodes();
if (COLLECT_STATISTICS) {
unsigned nIters = 0;
bool loopy = bayesNet_->isPolyTree() == false;
if (loopy) nIters = nIters_;
double time = (double (clock() - start)) / CLOCKS_PER_SEC;
Statistics::updateStatistics (size, loopy, nIters, time);
}
if (EXPORT_TO_GRAPHVIZ && size > EXPORT_MINIMAL_SIZE) {
stringstream ss;
ss << Statistics::getSolvedNetworksCounting() << "." << size << ".dot" ;
bayesNet_->exportToGraphViz (ss.str().c_str());
}
}
Params
BnBpSolver::getPosterioriOf (VarId vid)
{
BayesNode* node = bayesNet_->getBayesNode (vid);
assert (node);
return nodesI_[node->getIndex()]->getBeliefs();
}
Params
BnBpSolver::getJointDistributionOf (const VarIds& jointVarIds)
{
if (DL >= 2) {
cout << "calculating joint distribution on: " ;
for (unsigned i = 0; i < jointVarIds.size(); i++) {
VarNode* var = bayesNet_->getBayesNode (jointVarIds[i]);
cout << var->label() << " " ;
}
cout << endl;
}
return getJointByConditioning (jointVarIds);
}
void
BnBpSolver::initializeSolver (void)
{
const BnNodeSet& nodes = bayesNet_->getBayesNodes();
for (unsigned i = 0; i < nodesI_.size(); i++) {
delete nodesI_[i];
}
nodesI_.clear();
nodesI_.reserve (nodes.size());
links_.clear();
sortedOrder_.clear();
linkMap_.clear();
for (unsigned i = 0; i < nodes.size(); i++) {
nodesI_.push_back (new BpNodeInfo (nodes[i]));
}
BnNodeSet roots = bayesNet_->getRootNodes();
for (unsigned i = 0; i < roots.size(); i++) {
const Params& params = roots[i]->getParameters();
Params& piVals = ninf(roots[i])->getPiValues();
for (unsigned ri = 0; ri < roots[i]->nrStates(); ri++) {
piVals[ri] = params[ri];
}
}
for (unsigned i = 0; i < nodes.size(); i++) {
const BnNodeSet& parents = nodes[i]->getParents();
for (unsigned j = 0; j < parents.size(); j++) {
BpLink* newLink = new BpLink (
parents[j], nodes[i], LinkOrientation::DOWN);
links_.push_back (newLink);
ninf(nodes[i])->addIncomingParentLink (newLink);
ninf(parents[j])->addOutcomingChildLink (newLink);
}
const BnNodeSet& childs = nodes[i]->getChilds();
for (unsigned j = 0; j < childs.size(); j++) {
BpLink* newLink = new BpLink (
childs[j], nodes[i], LinkOrientation::UP);
links_.push_back (newLink);
ninf(nodes[i])->addIncomingChildLink (newLink);
ninf(childs[j])->addOutcomingParentLink (newLink);
}
}
for (unsigned i = 0; i < nodes.size(); i++) {
if (nodes[i]->hasEvidence()) {
Params& piVals = ninf(nodes[i])->getPiValues();
Params& ldVals = ninf(nodes[i])->getLambdaValues();
for (unsigned xi = 0; xi < nodes[i]->nrStates(); xi++) {
piVals[xi] = Util::noEvidence();
ldVals[xi] = Util::noEvidence();
}
piVals[nodes[i]->getEvidence()] = Util::withEvidence();
ldVals[nodes[i]->getEvidence()] = Util::withEvidence();
}
}
}
void
BnBpSolver::runLoopySolver()
{
nIters_ = 0;
while (!converged() && nIters_ < BpOptions::maxIter) {
nIters_++;
if (DL >= 2) {
cout << "****************************************" ;
cout << "****************************************" ;
cout << endl;
cout << " Iteration " << nIters_ << endl;
cout << "****************************************" ;
cout << "****************************************" ;
cout << endl;
}
switch (BpOptions::schedule) {
case BpOptions::Schedule::SEQ_RANDOM:
random_shuffle (links_.begin(), links_.end());
// no break
case BpOptions::Schedule::SEQ_FIXED:
for (unsigned i = 0; i < links_.size(); i++) {
calculateAndUpdateMessage (links_[i]);
updateValues (links_[i]);
}
break;
case BpOptions::Schedule::PARALLEL:
for (unsigned i = 0; i < links_.size(); i++) {
calculateMessage (links_[i]);
}
for (unsigned i = 0; i < links_.size(); i++) {
updateMessage (links_[i]);
updateValues (links_[i]);
}
break;
case BpOptions::Schedule::MAX_RESIDUAL:
maxResidualSchedule();
break;
}
if (DL >= 2) {
cout << endl;
}
}
}
bool
BnBpSolver::converged (void) const
{
// this can happen if the graph is fully disconnected
if (links_.size() == 0) {
return true;
}
if (nIters_ == 0 || nIters_ == 1) {
return false;
}
bool converged = true;
if (BpOptions::schedule == BpOptions::Schedule::MAX_RESIDUAL) {
double maxResidual = (*(sortedOrder_.begin()))->getResidual();
if (maxResidual < BpOptions::accuracy) {
converged = true;
} else {
converged = false;
}
} else {
for (unsigned i = 0; i < links_.size(); i++) {
double residual = links_[i]->getResidual();
if (DL >= 2) {
cout << links_[i]->toString() + " residual change = " ;
cout << residual << endl;
}
if (residual > BpOptions::accuracy) {
converged = false;
break;
}
}
}
return converged;
}
void
BnBpSolver::maxResidualSchedule (void)
{
if (nIters_ == 1) {
for (unsigned i = 0; i < links_.size(); i++) {
calculateMessage (links_[i]);
SortedOrder::iterator it = sortedOrder_.insert (links_[i]);
linkMap_.insert (make_pair (links_[i], it));
}
return;
}
for (unsigned c = 0; c < sortedOrder_.size(); c++) {
if (DL >= 2) {
cout << "current residuals:" << endl;
for (SortedOrder::iterator it = sortedOrder_.begin();
it != sortedOrder_.end(); it ++) {
cout << " " << setw (30) << left << (*it)->toString();
cout << "residual = " << (*it)->getResidual() << endl;
}
}
SortedOrder::iterator it = sortedOrder_.begin();
BpLink* link = *it;
if (link->getResidual() < BpOptions::accuracy) {
sortedOrder_.erase (it);
it = sortedOrder_.begin();
return;
}
updateMessage (link);
updateValues (link);
link->clearResidual();
sortedOrder_.erase (it);
linkMap_.find (link)->second = sortedOrder_.insert (link);
const BpLinkSet& outParentLinks =
ninf(link->getDestination())->getOutcomingParentLinks();
for (unsigned i = 0; i < outParentLinks.size(); i++) {
if (outParentLinks[i]->getDestination() != link->getSource()
&& outParentLinks[i]->getDestination()->hasEvidence() == false) {
calculateMessage (outParentLinks[i]);
BpLinkMap::iterator iter = linkMap_.find (outParentLinks[i]);
sortedOrder_.erase (iter->second);
iter->second = sortedOrder_.insert (outParentLinks[i]);
}
}
const BpLinkSet& outChildLinks =
ninf(link->getDestination())->getOutcomingChildLinks();
for (unsigned i = 0; i < outChildLinks.size(); i++) {
if (outChildLinks[i]->getDestination() != link->getSource()) {
calculateMessage (outChildLinks[i]);
BpLinkMap::iterator iter = linkMap_.find (outChildLinks[i]);
sortedOrder_.erase (iter->second);
iter->second = sortedOrder_.insert (outChildLinks[i]);
}
}
if (DL >= 2) {
cout << "----------------------------------------" ;
cout << "----------------------------------------" << endl;
}
}
}
void
BnBpSolver::updatePiValues (BayesNode* x)
{
// π(Xi)
if (DL >= 3) {
cout << "updating " << PI_SYMBOL << " values for " << x->label() << endl;
}
Params& piValues = ninf(x)->getPiValues();
const BpLinkSet& parentLinks = ninf(x)->getIncomingParentLinks();
const BnNodeSet& ps = x->getParents();
Ranges ranges;
for (unsigned i = 0; i < ps.size(); i++) {
ranges.push_back (ps[i]->nrStates());
}
StatesIndexer indexer (ranges, false);
stringstream* calcs1 = 0;
stringstream* calcs2 = 0;
Params messageProducts (indexer.size());
for (unsigned k = 0; k < indexer.size(); k++) {
if (DL >= 5) {
calcs1 = new stringstream;
calcs2 = new stringstream;
}
double messageProduct = Util::multIdenty();
if (Globals::logDomain) {
for (unsigned i = 0; i < parentLinks.size(); i++) {
messageProduct += parentLinks[i]->getMessage()[indexer[i]];
}
} else {
for (unsigned i = 0; i < parentLinks.size(); i++) {
messageProduct *= parentLinks[i]->getMessage()[indexer[i]];
if (DL >= 5) {
if (i != 0) *calcs1 << " + " ;
if (i != 0) *calcs2 << " + " ;
*calcs1 << parentLinks[i]->toString (indexer[i]);
*calcs2 << parentLinks[i]->getMessage()[indexer[i]];
}
}
}
messageProducts[k] = messageProduct;
if (DL >= 5) {
cout << " mp" << k;
cout << " = " << (*calcs1).str();
if (parentLinks.size() == 1) {
cout << " = " << messageProduct << endl;
} else {
cout << " = " << (*calcs2).str();
cout << " = " << messageProduct << endl;
}
delete calcs1;
delete calcs2;
}
++ indexer;
}
for (unsigned xi = 0; xi < x->nrStates(); xi++) {
double sum = Util::addIdenty();
if (DL >= 5) {
calcs1 = new stringstream;
calcs2 = new stringstream;
}
indexer.reset();
if (Globals::logDomain) {
for (unsigned k = 0; k < indexer.size(); k++) {
Util::logSum (sum,
x->getProbability(xi, indexer.linearIndex()) + messageProducts[k]);
++ indexer;
}
} else {
for (unsigned k = 0; k < indexer.size(); k++) {
sum += x->getProbability (xi, indexer.linearIndex()) * messageProducts[k];
if (DL >= 5) {
if (k != 0) *calcs1 << " + " ;
if (k != 0) *calcs2 << " + " ;
*calcs1 << x->cptEntryToString (xi, indexer.indices());
*calcs1 << ".mp" << k;
*calcs2 << Util::fl (x->getProbability (xi, indexer.linearIndex()));
*calcs2 << "*" << messageProducts[k];
}
++ indexer;
}
}
piValues[xi] = sum;
if (DL >= 5) {
cout << " " << PI_SYMBOL << "(" << x->label() << ")" ;
cout << "[" << x->states()[xi] << "]" ;
cout << " = " << (*calcs1).str();
cout << " = " << (*calcs2).str();
cout << " = " << piValues[xi] << endl;
delete calcs1;
delete calcs2;
}
}
}
void
BnBpSolver::updateLambdaValues (BayesNode* x)
{
// λ(Xi)
if (DL >= 3) {
cout << "updating " << LD_SYMBOL << " values for " << x->label() << endl;
}
Params& lambdaValues = ninf(x)->getLambdaValues();
const BpLinkSet& childLinks = ninf(x)->getIncomingChildLinks();
stringstream* calcs1 = 0;
stringstream* calcs2 = 0;
for (unsigned xi = 0; xi < x->nrStates(); xi++) {
if (DL >= 5) {
calcs1 = new stringstream;
calcs2 = new stringstream;
}
double product = Util::multIdenty();
if (Globals::logDomain) {
for (unsigned i = 0; i < childLinks.size(); i++) {
product += childLinks[i]->getMessage()[xi];
}
} else {
for (unsigned i = 0; i < childLinks.size(); i++) {
product *= childLinks[i]->getMessage()[xi];
if (DL >= 5) {
if (i != 0) *calcs1 << "." ;
if (i != 0) *calcs2 << "*" ;
*calcs1 << childLinks[i]->toString (xi);
*calcs2 << childLinks[i]->getMessage()[xi];
}
}
}
lambdaValues[xi] = product;
if (DL >= 5) {
cout << " " << LD_SYMBOL << "(" << x->label() << ")" ;
cout << "[" << x->states()[xi] << "]" ;
cout << " = " << (*calcs1).str();
if (childLinks.size() == 1) {
cout << " = " << product << endl;
} else {
cout << " = " << (*calcs2).str();
cout << " = " << lambdaValues[xi] << endl;
}
delete calcs1;
delete calcs2;
}
}
}
void
BnBpSolver::calculatePiMessage (BpLink* link)
{
// πX(Zi)
BayesNode* z = link->getSource();
BayesNode* x = link->getDestination();
Params& zxPiNextMessage = link->getNextMessage();
const BpLinkSet& zChildLinks = ninf(z)->getIncomingChildLinks();
stringstream* calcs1 = 0;
stringstream* calcs2 = 0;
const Params& zPiValues = ninf(z)->getPiValues();
for (unsigned zi = 0; zi < z->nrStates(); zi++) {
double product = zPiValues[zi];
if (DL >= 5) {
calcs1 = new stringstream;
calcs2 = new stringstream;
*calcs1 << PI_SYMBOL << "(" << z->label() << ")";
*calcs1 << "[" << z->states()[zi] << "]" ;
*calcs2 << product;
}
if (Globals::logDomain) {
for (unsigned i = 0; i < zChildLinks.size(); i++) {
if (zChildLinks[i]->getSource() != x) {
product += zChildLinks[i]->getMessage()[zi];
}
}
} else {
for (unsigned i = 0; i < zChildLinks.size(); i++) {
if (zChildLinks[i]->getSource() != x) {
product *= zChildLinks[i]->getMessage()[zi];
if (DL >= 5) {
*calcs1 << "." << zChildLinks[i]->toString (zi);
*calcs2 << " * " << zChildLinks[i]->getMessage()[zi];
}
}
}
}
zxPiNextMessage[zi] = product;
if (DL >= 5) {
cout << " " << link->toString();
cout << "[" << z->states()[zi] << "]" ;
cout << " = " << (*calcs1).str();
if (zChildLinks.size() == 1) {
cout << " = " << product << endl;
} else {
cout << " = " << (*calcs2).str();
cout << " = " << product << endl;
}
delete calcs1;
delete calcs2;
}
}
Util::normalize (zxPiNextMessage);
}
void
BnBpSolver::calculateLambdaMessage (BpLink* link)
{
// λY(Xi)
BayesNode* y = link->getSource();
BayesNode* x = link->getDestination();
if (x->hasEvidence()) {
return;
}
Params& yxLambdaNextMessage = link->getNextMessage();
const BpLinkSet& yParentLinks = ninf(y)->getIncomingParentLinks();
const Params& yLambdaValues = ninf(y)->getLambdaValues();
int parentIndex = y->getIndexOfParent (x);
stringstream* calcs1 = 0;
stringstream* calcs2 = 0;
const BnNodeSet& ps = y->getParents();
Ranges ranges;
for (unsigned i = 0; i < ps.size(); i++) {
ranges.push_back (ps[i]->nrStates());
}
StatesIndexer indexer (ranges, false);
unsigned N = indexer.size() / x->nrStates();
Params messageProducts (N);
for (unsigned k = 0; k < N; k++) {
while (indexer[parentIndex] != 0) {
++ indexer;
}
if (DL >= 5) {
calcs1 = new stringstream;
calcs2 = new stringstream;
}
double messageProduct = Util::multIdenty();
if (Globals::logDomain) {
for (unsigned i = 0; i < yParentLinks.size(); i++) {
if (yParentLinks[i]->getSource() != x) {
messageProduct += yParentLinks[i]->getMessage()[indexer[i]];
}
}
} else {
for (unsigned i = 0; i < yParentLinks.size(); i++) {
if (yParentLinks[i]->getSource() != x) {
if (DL >= 5) {
if (messageProduct != Util::multIdenty()) *calcs1 << "*" ;
if (messageProduct != Util::multIdenty()) *calcs2 << "*" ;
*calcs1 << yParentLinks[i]->toString (indexer[i]);
*calcs2 << yParentLinks[i]->getMessage()[indexer[i]];
}
messageProduct *= yParentLinks[i]->getMessage()[indexer[i]];
}
}
}
messageProducts[k] = messageProduct;
++ indexer;
if (DL >= 5) {
cout << " mp" << k;
cout << " = " << (*calcs1).str();
if (yParentLinks.size() == 1) {
cout << 1 << endl;
} else if (yParentLinks.size() == 2) {
cout << " = " << messageProduct << endl;
} else {
cout << " = " << (*calcs2).str();
cout << " = " << messageProduct << endl;
}
delete calcs1;
delete calcs2;
}
}
for (unsigned xi = 0; xi < x->nrStates(); xi++) {
if (DL >= 5) {
calcs1 = new stringstream;
calcs2 = new stringstream;
}
double outerSum = Util::addIdenty();
for (unsigned yi = 0; yi < y->nrStates(); yi++) {
if (DL >= 5) {
(yi != 0) ? *calcs1 << " + {" : *calcs1 << "{" ;
(yi != 0) ? *calcs2 << " + {" : *calcs2 << "{" ;
}
double innerSum = Util::addIdenty();
indexer.reset();
if (Globals::logDomain) {
for (unsigned k = 0; k < N; k++) {
while (indexer[parentIndex] != xi) {
++ indexer;
}
Util::logSum (innerSum, y->getProbability (
yi, indexer.linearIndex()) + messageProducts[k]);
++ indexer;
}
Util::logSum (outerSum, innerSum + yLambdaValues[yi]);
} else {
for (unsigned k = 0; k < N; k++) {
while (indexer[parentIndex] != xi) {
++ indexer;
}
if (DL >= 5) {
if (k != 0) *calcs1 << " + " ;
if (k != 0) *calcs2 << " + " ;
*calcs1 << y->cptEntryToString (yi, indexer.indices());
*calcs1 << ".mp" << k;
*calcs2 << y->getProbability (yi, indexer.linearIndex());
*calcs2 << "*" << messageProducts[k];
}
innerSum += y->getProbability (
yi, indexer.linearIndex()) * messageProducts[k];
++ indexer;
}
outerSum += innerSum * yLambdaValues[yi];
}
if (DL >= 5) {
*calcs1 << "}." << LD_SYMBOL << "(" << y->label() << ")" ;
*calcs1 << "[" << y->states()[yi] << "]";
*calcs2 << "}*" << yLambdaValues[yi];
}
}
yxLambdaNextMessage[xi] = outerSum;
if (DL >= 5) {
cout << " " << link->toString();
cout << "[" << x->states()[xi] << "]" ;
cout << " = " << (*calcs1).str();
cout << " = " << (*calcs2).str();
cout << " = " << yxLambdaNextMessage[xi] << endl;
delete calcs1;
delete calcs2;
}
}
Util::normalize (yxLambdaNextMessage);
}
Params
BnBpSolver::getJointByConditioning (const VarIds& jointVarIds) const
{
BnNodeSet jointVars;
for (unsigned i = 0; i < jointVarIds.size(); i++) {
assert (bayesNet_->getBayesNode (jointVarIds[i]));
jointVars.push_back (bayesNet_->getBayesNode (jointVarIds[i]));
}
BayesNet* mrn = bayesNet_->getMinimalRequesiteNetwork (jointVarIds[0]);
BnBpSolver solver (*mrn);
solver.runSolver();
Params prevBeliefs = solver.getPosterioriOf (jointVarIds[0]);
delete mrn;
VarIds observedVids = {jointVars[0]->varId()};
for (unsigned i = 1; i < jointVarIds.size(); i++) {
assert (jointVars[i]->hasEvidence() == false);
VarIds reqVars = {jointVarIds[i]};
reqVars.insert (reqVars.end(), observedVids.begin(), observedVids.end());
mrn = bayesNet_->getMinimalRequesiteNetwork (reqVars);
Params newBeliefs;
VarNodes observedVars;
for (unsigned j = 0; j < observedVids.size(); j++) {
observedVars.push_back (mrn->getBayesNode (observedVids[j]));
}
StatesIndexer idx (observedVars, false);
while (idx.valid()) {
for (unsigned j = 0; j < observedVars.size(); j++) {
observedVars[j]->setEvidence (idx[j]);
}
BnBpSolver solver (*mrn);
solver.runSolver();
Params beliefs = solver.getPosterioriOf (jointVarIds[i]);
for (unsigned k = 0; k < beliefs.size(); k++) {
newBeliefs.push_back (beliefs[k]);
}
++ idx;
}
int count = -1;
for (unsigned j = 0; j < newBeliefs.size(); j++) {
if (j % jointVars[i]->nrStates() == 0) {
count ++;
}
newBeliefs[j] *= prevBeliefs[count];
}
prevBeliefs = newBeliefs;
observedVids.push_back (jointVars[i]->varId());
delete mrn;
}
return prevBeliefs;
}
void
BnBpSolver::printPiLambdaValues (const BayesNode* var) const
{
cout << left;
cout << setw (10) << "states" ;
cout << setw (20) << PI_SYMBOL << "(" + var->label() + ")" ;
cout << setw (20) << LD_SYMBOL << "(" + var->label() + ")" ;
cout << setw (16) << "belief" ;
cout << endl;
cout << "--------------------------------" ;
cout << "--------------------------------" ;
cout << endl;
const States& states = var->states();
const Params& piVals = ninf(var)->getPiValues();
const Params& ldVals = ninf(var)->getLambdaValues();
const Params& beliefs = ninf(var)->getBeliefs();
for (unsigned xi = 0; xi < var->nrStates(); xi++) {
cout << setw (10) << states[xi];
cout << setw (19) << piVals[xi];
cout << setw (19) << ldVals[xi];
cout.precision (PRECISION);
cout << setw (16) << beliefs[xi];
cout << endl;
}
cout << endl;
}
void
BnBpSolver::printAllMessageStatus (void) const
{
const BnNodeSet& nodes = bayesNet_->getBayesNodes();
for (unsigned i = 0; i < nodes.size(); i++) {
printPiLambdaValues (nodes[i]);
}
}
BpNodeInfo::BpNodeInfo (BayesNode* node)
{
node_ = node;
piVals_.resize (node->nrStates(), Util::one());
ldVals_.resize (node->nrStates(), Util::one());
}
Params
BpNodeInfo::getBeliefs (void) const
{
double sum = 0.0;
Params beliefs (node_->nrStates());
if (Globals::logDomain) {
for (unsigned xi = 0; xi < node_->nrStates(); xi++) {
beliefs[xi] = exp (piVals_[xi] + ldVals_[xi]);
sum += beliefs[xi];
}
} else {
for (unsigned xi = 0; xi < node_->nrStates(); xi++) {
beliefs[xi] = piVals_[xi] * ldVals_[xi];
sum += beliefs[xi];
}
}
assert (sum);
for (unsigned xi = 0; xi < node_->nrStates(); xi++) {
beliefs[xi] /= sum;
}
return beliefs;
}
bool
BpNodeInfo::receivedBottomInfluence (void) const
{
// if all lambda values are equal, then neither
// this node neither its descendents have evidence,
// we can use this to don't send lambda messages his parents
bool childInfluenced = false;
for (unsigned xi = 1; xi < node_->nrStates(); xi++) {
if (ldVals_[xi] != ldVals_[0]) {
childInfluenced = true;
break;
}
}
return childInfluenced;
}

View File

@ -1,245 +0,0 @@
#ifndef HORUS_BNBPSOLVER_H
#define HORUS_BNBPSOLVER_H
#include <vector>
#include <set>
#include "Solver.h"
#include "BayesNet.h"
#include "Horus.h"
#include "Util.h"
using namespace std;
class BpNodeInfo;
static const string PI_SYMBOL = "pi" ;
static const string LD_SYMBOL = "ld" ;
enum LinkOrientation {UP, DOWN};
class BpLink
{
public:
BpLink (BayesNode* s, BayesNode* d, LinkOrientation o)
{
source_ = s;
destin_ = d;
orientation_ = o;
if (orientation_ == LinkOrientation::DOWN) {
v1_.resize (s->nrStates(), Util::tl (1.0 / s->nrStates()));
v2_.resize (s->nrStates(), Util::tl (1.0 / s->nrStates()));
} else {
v1_.resize (d->nrStates(), Util::tl (1.0 / d->nrStates()));
v2_.resize (d->nrStates(), Util::tl (1.0 / d->nrStates()));
}
currMsg_ = &v1_;
nextMsg_ = &v2_;
residual_ = 0;
msgSended_ = false;
}
void updateMessage (void)
{
swap (currMsg_, nextMsg_);
msgSended_ = true;
}
void updateResidual (void)
{
residual_ = Util::getMaxNorm (v1_, v2_);
}
string toString (void) const
{
stringstream ss;
if (orientation_ == LinkOrientation::DOWN) {
ss << PI_SYMBOL;
} else {
ss << LD_SYMBOL;
}
ss << "(" << source_->label();
ss << " --> " << destin_->label() << ")" ;
return ss.str();
}
string toString (unsigned stateIndex) const
{
stringstream ss;
ss << toString() << "[" ;
if (orientation_ == LinkOrientation::DOWN) {
ss << source_->states()[stateIndex] << "]" ;
} else {
ss << destin_->states()[stateIndex] << "]" ;
}
return ss.str();
}
BayesNode* getSource (void) const { return source_; }
BayesNode* getDestination (void) const { return destin_; }
LinkOrientation getOrientation (void) const { return orientation_; }
const Params& getMessage (void) const { return *currMsg_; }
Params& getNextMessage (void) { return *nextMsg_; }
bool messageWasSended (void) const { return msgSended_; }
double getResidual (void) const { return residual_; }
void clearResidual (void) { residual_ = 0;}
private:
BayesNode* source_;
BayesNode* destin_;
LinkOrientation orientation_;
Params v1_;
Params v2_;
Params* currMsg_;
Params* nextMsg_;
bool msgSended_;
double residual_;
};
typedef vector<BpLink*> BpLinkSet;
class BpNodeInfo
{
public:
BpNodeInfo (BayesNode*);
Params getBeliefs (void) const;
bool receivedBottomInfluence (void) const;
Params& getPiValues (void) { return piVals_; }
Params& getLambdaValues (void) { return ldVals_; }
const BpLinkSet& getIncomingParentLinks (void) { return inParentLinks_; }
const BpLinkSet& getIncomingChildLinks (void) { return inChildLinks_; }
const BpLinkSet& getOutcomingParentLinks (void) { return outParentLinks_; }
const BpLinkSet& getOutcomingChildLinks (void) { return outChildLinks_; }
void addIncomingParentLink (BpLink* l) { inParentLinks_.push_back (l); }
void addIncomingChildLink (BpLink* l) { inChildLinks_.push_back (l); }
void addOutcomingParentLink (BpLink* l) { outParentLinks_.push_back (l); }
void addOutcomingChildLink (BpLink* l) { outChildLinks_.push_back (l); }
private:
DISALLOW_COPY_AND_ASSIGN (BpNodeInfo);
const BayesNode* node_;
Params piVals_; // pi values
Params ldVals_; // lambda values
BpLinkSet inParentLinks_;
BpLinkSet inChildLinks_;
BpLinkSet outParentLinks_;
BpLinkSet outChildLinks_;
};
class BnBpSolver : public Solver
{
public:
BnBpSolver (const BayesNet&);
~BnBpSolver (void);
void runSolver (void);
Params getPosterioriOf (VarId);
Params getJointDistributionOf (const VarIds&);
private:
DISALLOW_COPY_AND_ASSIGN (BnBpSolver);
void initializeSolver (void);
void runLoopySolver (void);
void maxResidualSchedule (void);
bool converged (void) const;
void updatePiValues (BayesNode*);
void updateLambdaValues (BayesNode*);
void calculateLambdaMessage (BpLink*);
void calculatePiMessage (BpLink*);
Params getJointByJunctionNode (const VarIds&);
Params getJointByConditioning (const VarIds&) const;
void printPiLambdaValues (const BayesNode*) const;
void printAllMessageStatus (void) const;
void calculateAndUpdateMessage (BpLink* link, bool calcResidual = true)
{
if (DL >= 3) {
cout << "calculating & updating " << link->toString() << endl;
}
if (link->getOrientation() == LinkOrientation::DOWN) {
calculatePiMessage (link);
} else if (link->getOrientation() == LinkOrientation::UP) {
calculateLambdaMessage (link);
}
if (calcResidual) {
link->updateResidual();
}
link->updateMessage();
}
void calculateMessage (BpLink* link, bool calcResidual = true)
{
if (DL >= 3) {
cout << "calculating " << link->toString() << endl;
}
if (link->getOrientation() == LinkOrientation::DOWN) {
calculatePiMessage (link);
} else if (link->getOrientation() == LinkOrientation::UP) {
calculateLambdaMessage (link);
}
if (calcResidual) {
link->updateResidual();
}
}
void updateMessage (BpLink* link)
{
if (DL >= 3) {
cout << "updating " << link->toString() << endl;
}
link->updateMessage();
}
void updateValues (BpLink* link)
{
if (!link->getDestination()->hasEvidence()) {
if (link->getOrientation() == LinkOrientation::DOWN) {
updatePiValues (link->getDestination());
} else if (link->getOrientation() == LinkOrientation::UP) {
updateLambdaValues (link->getDestination());
}
}
}
BpNodeInfo* ninf (const BayesNode* node) const
{
assert (node);
assert (node == bayesNet_->getBayesNode (node->varId()));
assert (node->getIndex() < nodesI_.size());
return nodesI_[node->getIndex()];
}
const BayesNet* bayesNet_;
vector<BpLink*> links_;
vector<BpNodeInfo*> nodesI_;
unsigned nIters_;
struct compare
{
inline bool operator() (const BpLink* e1, const BpLink* e2)
{
return e1->getResidual() > e2->getResidual();
}
};
typedef multiset<BpLink*, compare> SortedOrder;
SortedOrder sortedOrder_;
typedef unordered_map<BpLink*, SortedOrder::iterator> BpLinkMap;
BpLinkMap linkMap_;
};
#endif // HORUS_BNBPSOLVER_H

View File

@ -5,21 +5,22 @@
#include <iostream> #include <iostream>
#include "FgBpSolver.h" #include "BpSolver.h"
#include "FactorGraph.h" #include "FactorGraph.h"
#include "Factor.h" #include "Factor.h"
#include "Indexer.h" #include "Indexer.h"
#include "Horus.h" #include "Horus.h"
FgBpSolver::FgBpSolver (const FactorGraph& fg) : Solver (&fg) BpSolver::BpSolver (const FactorGraph& fg) : Solver (fg)
{ {
factorGraph_ = &fg; fg_ = &fg;
runned_ = false;
} }
FgBpSolver::~FgBpSolver (void) BpSolver::~BpSolver (void)
{ {
for (unsigned i = 0; i < varsI_.size(); i++) { for (unsigned i = 0; i < varsI_.size(); i++) {
delete varsI_[i]; delete varsI_[i];
@ -34,64 +35,45 @@ FgBpSolver::~FgBpSolver (void)
void Params
FgBpSolver::runSolver (void) BpSolver::solveQuery (VarIds queryVids)
{ {
clock_t start; assert (queryVids.empty() == false);
if (COLLECT_STATISTICS) { if (queryVids.size() == 1) {
start = clock(); return getPosterioriOf (queryVids[0]);
} } else {
runLoopySolver(); return getJointDistributionOf (queryVids);
if (DL >= 2) {
cout << endl;
if (nIters_ < BpOptions::maxIter) {
cout << "Sum-Product converged in " ;
cout << nIters_ << " iterations" << endl;
} else {
cout << "The maximum number of iterations was hit, terminating..." ;
cout << endl;
}
}
unsigned size = factorGraph_->getVarNodes().size();
if (COLLECT_STATISTICS) {
unsigned nIters = 0;
bool loopy = factorGraph_->isTree() == false;
if (loopy) nIters = nIters_;
double time = (double (clock() - start)) / CLOCKS_PER_SEC;
Statistics::updateStatistics (size, loopy, nIters, time);
}
if (EXPORT_TO_GRAPHVIZ && size > EXPORT_MINIMAL_SIZE) {
stringstream ss;
ss << Statistics::getSolvedNetworksCounting() << "." << size << ".dot" ;
factorGraph_->exportToGraphViz (ss.str().c_str());
} }
} }
Params Params
FgBpSolver::getPosterioriOf (VarId vid) BpSolver::getPosterioriOf (VarId vid)
{ {
assert (factorGraph_->getFgVarNode (vid)); if (runned_ == false) {
FgVarNode* var = factorGraph_->getFgVarNode (vid); runSolver();
}
assert (fg_->getVarNode (vid));
VarNode* var = fg_->getVarNode (vid);
Params probs; Params probs;
if (var->hasEvidence()) { if (var->hasEvidence()) {
probs.resize (var->nrStates(), Util::noEvidence()); probs.resize (var->range(), LogAware::noEvidence());
probs[var->getEvidence()] = Util::withEvidence(); probs[var->getEvidence()] = LogAware::withEvidence();
} else { } else {
probs.resize (var->nrStates(), Util::multIdenty()); probs.resize (var->range(), LogAware::multIdenty());
const SpLinkSet& links = ninf(var)->getLinks(); const SpLinkSet& links = ninf(var)->getLinks();
if (Globals::logDomain) { if (Globals::logDomain) {
for (unsigned i = 0; i < links.size(); i++) { for (unsigned i = 0; i < links.size(); i++) {
Util::add (probs, links[i]->getMessage()); Util::add (probs, links[i]->getMessage());
} }
Util::normalize (probs); LogAware::normalize (probs);
Util::fromLog (probs); Util::fromLog (probs);
} else { } else {
for (unsigned i = 0; i < links.size(); i++) { for (unsigned i = 0; i < links.size(); i++) {
Util::multiply (probs, links[i]->getMessage()); Util::multiply (probs, links[i]->getMessage());
} }
Util::normalize (probs); LogAware::normalize (probs);
} }
} }
return probs; return probs;
@ -100,13 +82,16 @@ FgBpSolver::getPosterioriOf (VarId vid)
Params Params
FgBpSolver::getJointDistributionOf (const VarIds& jointVarIds) BpSolver::getJointDistributionOf (const VarIds& jointVarIds)
{ {
FgVarNode* vn = factorGraph_->getFgVarNode (jointVarIds[0]); if (runned_ == false) {
const FgFacSet& factorNodes = vn->neighbors(); runSolver();
}
int idx = -1; int idx = -1;
for (unsigned i = 0; i < factorNodes.size(); i++) { VarNode* vn = fg_->getVarNode (jointVarIds[0]);
if (factorNodes[i]->factor()->contains (jointVarIds)) { const FacNodes& facNodes = vn->neighbors();
for (unsigned i = 0; i < facNodes.size(); i++) {
if (facNodes[i]->factor().contains (jointVarIds)) {
idx = i; idx = i;
break; break;
} }
@ -114,18 +99,18 @@ FgBpSolver::getJointDistributionOf (const VarIds& jointVarIds)
if (idx == -1) { if (idx == -1) {
return getJointByConditioning (jointVarIds); return getJointByConditioning (jointVarIds);
} else { } else {
Factor r (*factorNodes[idx]->factor()); Factor res (facNodes[idx]->factor());
const SpLinkSet& links = ninf(factorNodes[idx])->getLinks(); const SpLinkSet& links = ninf(facNodes[idx])->getLinks();
for (unsigned i = 0; i < links.size(); i++) { for (unsigned i = 0; i < links.size(); i++) {
Factor msg (links[i]->getVariable()->varId(), Factor msg ({links[i]->getVariable()->varId()},
links[i]->getVariable()->nrStates(), {links[i]->getVariable()->range()},
getVar2FactorMsg (links[i])); getVar2FactorMsg (links[i]));
r.multiply (msg); res.multiply (msg);
} }
r.sumOutAllExcept (jointVarIds); res.sumOutAllExcept (jointVarIds);
r.reorderVariables (jointVarIds); res.reorderArguments (jointVarIds);
r.normalize(); res.normalize();
Params jointDist = r.getParameters(); Params jointDist = res.params();
if (Globals::logDomain) { if (Globals::logDomain) {
Util::fromLog (jointDist); Util::fromLog (jointDist);
} }
@ -136,35 +121,29 @@ FgBpSolver::getJointDistributionOf (const VarIds& jointVarIds)
void void
FgBpSolver::runLoopySolver (void) BpSolver::runSolver (void)
{ {
clock_t start;
if (Constants::COLLECT_STATS) {
start = clock();
}
initializeSolver(); initializeSolver();
nIters_ = 0; nIters_ = 0;
while (!converged() && nIters_ < BpOptions::maxIter) { while (!converged() && nIters_ < BpOptions::maxIter) {
nIters_ ++; nIters_ ++;
if (DL >= 2) { if (Constants::DEBUG >= 2) {
cout << "****************************************" ; Util::printHeader (string ("Iteration ") + Util::toString (nIters_));
cout << "****************************************" ; // cout << endl;
cout << endl;
cout << " Iteration " << nIters_ << endl;
cout << "****************************************" ;
cout << "****************************************" ;
cout << endl;
} }
switch (BpOptions::schedule) { switch (BpOptions::schedule) {
case BpOptions::Schedule::SEQ_RANDOM: case BpOptions::Schedule::SEQ_RANDOM:
random_shuffle (links_.begin(), links_.end()); random_shuffle (links_.begin(), links_.end());
// no break // no break
case BpOptions::Schedule::SEQ_FIXED: case BpOptions::Schedule::SEQ_FIXED:
for (unsigned i = 0; i < links_.size(); i++) { for (unsigned i = 0; i < links_.size(); i++) {
calculateAndUpdateMessage (links_[i]); calculateAndUpdateMessage (links_[i]);
} }
break; break;
case BpOptions::Schedule::PARALLEL: case BpOptions::Schedule::PARALLEL:
for (unsigned i = 0; i < links_.size(); i++) { for (unsigned i = 0; i < links_.size(); i++) {
calculateMessage (links_[i]); calculateMessage (links_[i]);
@ -173,61 +152,43 @@ FgBpSolver::runLoopySolver (void)
updateMessage(links_[i]); updateMessage(links_[i]);
} }
break; break;
case BpOptions::Schedule::MAX_RESIDUAL: case BpOptions::Schedule::MAX_RESIDUAL:
maxResidualSchedule(); maxResidualSchedule();
break; break;
} }
if (DL >= 2) { if (Constants::DEBUG >= 2) {
cout << endl; cout << endl;
} }
} }
if (Constants::DEBUG >= 2) {
cout << endl;
if (nIters_ < BpOptions::maxIter) {
cout << "Sum-Product converged in " ;
cout << nIters_ << " iterations" << endl;
} else {
cout << "The maximum number of iterations was hit, terminating..." ;
cout << endl;
}
}
unsigned size = fg_->varNodes().size();
if (Constants::COLLECT_STATS) {
unsigned nIters = 0;
bool loopy = fg_->isTree() == false;
if (loopy) nIters = nIters_;
double time = (double (clock() - start)) / CLOCKS_PER_SEC;
Statistics::updateStatistics (size, loopy, nIters, time);
}
runned_ = true;
} }
void void
FgBpSolver::initializeSolver (void) BpSolver::createLinks (void)
{ {
const FgVarSet& varNodes = factorGraph_->getVarNodes(); const FacNodes& facNodes = fg_->facNodes();
for (unsigned i = 0; i < varsI_.size(); i++) {
delete varsI_[i];
}
varsI_.reserve (varNodes.size());
for (unsigned i = 0; i < varNodes.size(); i++) {
varsI_.push_back (new SPNodeInfo());
}
const FgFacSet& facNodes = factorGraph_->getFactorNodes();
for (unsigned i = 0; i < facsI_.size(); i++) {
delete facsI_[i];
}
facsI_.reserve (facNodes.size());
for (unsigned i = 0; i < facNodes.size(); i++) { for (unsigned i = 0; i < facNodes.size(); i++) {
facsI_.push_back (new SPNodeInfo()); const VarNodes& neighbors = facNodes[i]->neighbors();
}
for (unsigned i = 0; i < links_.size(); i++) {
delete links_[i];
}
createLinks();
for (unsigned i = 0; i < links_.size(); i++) {
FgFacNode* src = links_[i]->getFactor();
FgVarNode* dst = links_[i]->getVariable();
ninf (dst)->addSpLink (links_[i]);
ninf (src)->addSpLink (links_[i]);
}
}
void
FgBpSolver::createLinks (void)
{
const FgFacSet& facNodes = factorGraph_->getFactorNodes();
for (unsigned i = 0; i < facNodes.size(); i++) {
const FgVarSet& neighbors = facNodes[i]->neighbors();
for (unsigned j = 0; j < neighbors.size(); j++) { for (unsigned j = 0; j < neighbors.size(); j++) {
links_.push_back (new SpLink (facNodes[i], neighbors[j])); links_.push_back (new SpLink (facNodes[i], neighbors[j]));
} }
@ -236,42 +197,8 @@ FgBpSolver::createLinks (void)
bool
FgBpSolver::converged (void)
{
if (links_.size() == 0) {
return true;
}
if (nIters_ == 0 || nIters_ == 1) {
return false;
}
bool converged = true;
if (BpOptions::schedule == BpOptions::Schedule::MAX_RESIDUAL) {
double maxResidual = (*(sortedOrder_.begin()))->getResidual();
if (maxResidual > BpOptions::accuracy) {
converged = false;
} else {
converged = true;
}
} else {
for (unsigned i = 0; i < links_.size(); i++) {
double residual = links_[i]->getResidual();
if (DL >= 2) {
cout << links_[i]->toString() + " residual = " << residual << endl;
}
if (residual > BpOptions::accuracy) {
converged = false;
if (DL == 0) break;
}
}
}
return converged;
}
void void
FgBpSolver::maxResidualSchedule (void) BpSolver::maxResidualSchedule (void)
{ {
if (nIters_ == 1) { if (nIters_ == 1) {
for (unsigned i = 0; i < links_.size(); i++) { for (unsigned i = 0; i < links_.size(); i++) {
@ -283,7 +210,7 @@ FgBpSolver::maxResidualSchedule (void)
} }
for (unsigned c = 0; c < links_.size(); c++) { for (unsigned c = 0; c < links_.size(); c++) {
if (DL >= 2) { if (Constants::DEBUG >= 2) {
cout << "current residuals:" << endl; cout << "current residuals:" << endl;
for (SortedOrder::iterator it = sortedOrder_.begin(); for (SortedOrder::iterator it = sortedOrder_.begin();
it != sortedOrder_.end(); it ++) { it != sortedOrder_.end(); it ++) {
@ -303,7 +230,7 @@ FgBpSolver::maxResidualSchedule (void)
linkMap_.find (link)->second = sortedOrder_.insert (link); linkMap_.find (link)->second = sortedOrder_.insert (link);
// update the messages that depend on message source --> destin // update the messages that depend on message source --> destin
const FgFacSet& factorNeighbors = link->getVariable()->neighbors(); const FacNodes& factorNeighbors = link->getVariable()->neighbors();
for (unsigned i = 0; i < factorNeighbors.size(); i++) { for (unsigned i = 0; i < factorNeighbors.size(); i++) {
if (factorNeighbors[i] != link->getFactor()) { if (factorNeighbors[i] != link->getFactor()) {
const SpLinkSet& links = ninf(factorNeighbors[i])->getLinks(); const SpLinkSet& links = ninf(factorNeighbors[i])->getLinks();
@ -317,9 +244,8 @@ FgBpSolver::maxResidualSchedule (void)
} }
} }
} }
if (DL >= 2) { if (Constants::DEBUG >= 2) {
cout << "----------------------------------------" ; Util::printDashedLine();
cout << "----------------------------------------" << endl;
} }
} }
} }
@ -327,26 +253,26 @@ FgBpSolver::maxResidualSchedule (void)
void void
FgBpSolver::calculateFactor2VariableMsg (SpLink* link) const BpSolver::calculateFactor2VariableMsg (SpLink* link)
{ {
const FgFacNode* src = link->getFactor(); FacNode* src = link->getFactor();
const FgVarNode* dst = link->getVariable(); const VarNode* dst = link->getVariable();
const SpLinkSet& links = ninf(src)->getLinks(); const SpLinkSet& links = ninf(src)->getLinks();
// calculate the product of messages that were sent // calculate the product of messages that were sent
// to factor `src', except from var `dst' // to factor `src', except from var `dst'
unsigned msgSize = 1; unsigned msgSize = 1;
for (unsigned i = 0; i < links.size(); i++) { for (unsigned i = 0; i < links.size(); i++) {
msgSize *= links[i]->getVariable()->nrStates(); msgSize *= links[i]->getVariable()->range();
} }
unsigned repetitions = 1; unsigned repetitions = 1;
Params msgProduct (msgSize, Util::multIdenty()); Params msgProduct (msgSize, LogAware::multIdenty());
if (Globals::logDomain) { if (Globals::logDomain) {
for (int i = links.size() - 1; i >= 0; i--) { for (int i = links.size() - 1; i >= 0; i--) {
if (links[i]->getVariable() != dst) { if (links[i]->getVariable() != dst) {
Util::add (msgProduct, getVar2FactorMsg (links[i]), repetitions); Util::add (msgProduct, getVar2FactorMsg (links[i]), repetitions);
repetitions *= links[i]->getVariable()->nrStates(); repetitions *= links[i]->getVariable()->range();
} else { } else {
unsigned ds = links[i]->getVariable()->nrStates(); unsigned ds = links[i]->getVariable()->range();
Util::add (msgProduct, Params (ds, 1.0), repetitions); Util::add (msgProduct, Params (ds, 1.0), repetitions);
repetitions *= ds; repetitions *= ds;
} }
@ -354,70 +280,64 @@ FgBpSolver::calculateFactor2VariableMsg (SpLink* link) const
} else { } else {
for (int i = links.size() - 1; i >= 0; i--) { for (int i = links.size() - 1; i >= 0; i--) {
if (links[i]->getVariable() != dst) { if (links[i]->getVariable() != dst) {
if (DL >= 5) { if (Constants::DEBUG >= 5) {
cout << " message from " << links[i]->getVariable()->label(); cout << " message from " << links[i]->getVariable()->label();
cout << ": " << endl; cout << ": " << endl;
} }
Util::multiply (msgProduct, getVar2FactorMsg (links[i]), repetitions); Util::multiply (msgProduct, getVar2FactorMsg (links[i]), repetitions);
repetitions *= links[i]->getVariable()->nrStates(); repetitions *= links[i]->getVariable()->range();
} else { } else {
unsigned ds = links[i]->getVariable()->nrStates(); unsigned ds = links[i]->getVariable()->range();
Util::multiply (msgProduct, Params (ds, 1.0), repetitions); Util::multiply (msgProduct, Params (ds, 1.0), repetitions);
repetitions *= ds; repetitions *= ds;
} }
} }
} }
Factor result (src->factor()->getVarIds(), Factor result (src->factor().arguments(),
src->factor()->getRanges(), src->factor().ranges(), msgProduct);
msgProduct); result.multiply (src->factor());
result.multiply (*(src->factor())); if (Constants::DEBUG >= 5) {
if (DL >= 5) { cout << " message product: " << msgProduct << endl;
cout << " message product: " ; cout << " original factor: " << src->factor().params() << endl;
cout << Util::parametersToString (msgProduct) << endl; cout << " factor product: " << result.params() << endl;
cout << " original factor: " ;
cout << Util::parametersToString (src->getParameters()) << endl;
cout << " factor product: " ;
cout << Util::parametersToString (result.getParameters()) << endl;
} }
result.sumOutAllExcept (dst->varId()); result.sumOutAllExcept (dst->varId());
if (DL >= 5) { if (Constants::DEBUG >= 5) {
cout << " marginalized: " ; cout << " marginalized: " ;
cout << Util::parametersToString (result.getParameters()) << endl; cout << result.params() << endl;
} }
const Params& resultParams = result.getParameters(); const Params& resultParams = result.params();
Params& message = link->getNextMessage(); Params& message = link->getNextMessage();
for (unsigned i = 0; i < resultParams.size(); i++) { for (unsigned i = 0; i < resultParams.size(); i++) {
message[i] = resultParams[i]; message[i] = resultParams[i];
} }
Util::normalize (message); LogAware::normalize (message);
if (DL >= 5) { if (Constants::DEBUG >= 5) {
cout << " curr msg: " ; cout << " curr msg: " << link->getMessage() << endl;
cout << Util::parametersToString (link->getMessage()) << endl; cout << " next msg: " << message << endl;
cout << " next msg: " ;
cout << Util::parametersToString (message) << endl;
} }
} }
Params Params
FgBpSolver::getVar2FactorMsg (const SpLink* link) const BpSolver::getVar2FactorMsg (const SpLink* link) const
{ {
const FgVarNode* src = link->getVariable(); const VarNode* src = link->getVariable();
const FgFacNode* dst = link->getFactor(); const FacNode* dst = link->getFactor();
Params msg; Params msg;
if (src->hasEvidence()) { if (src->hasEvidence()) {
msg.resize (src->nrStates(), Util::noEvidence()); msg.resize (src->range(), LogAware::noEvidence());
msg[src->getEvidence()] = Util::withEvidence(); msg[src->getEvidence()] = LogAware::withEvidence();
if (DL >= 5) { if (Constants::DEBUG >= 5) {
cout << Util::parametersToString (msg); cout << msg;
} }
} else { } else {
msg.resize (src->nrStates(), Util::one()); msg.resize (src->range(), LogAware::one());
} }
if (DL >= 5) { if (Constants::DEBUG >= 5) {
cout << Util::parametersToString (msg); cout << msg;
} }
const SpLinkSet& links = ninf (src)->getLinks(); const SpLinkSet& links = ninf (src)->getLinks();
if (Globals::logDomain) { if (Globals::logDomain) {
@ -430,14 +350,14 @@ FgBpSolver::getVar2FactorMsg (const SpLink* link) const
for (unsigned i = 0; i < links.size(); i++) { for (unsigned i = 0; i < links.size(); i++) {
if (links[i]->getFactor() != dst) { if (links[i]->getFactor() != dst) {
Util::multiply (msg, links[i]->getMessage()); Util::multiply (msg, links[i]->getMessage());
if (DL >= 5) { if (Constants::DEBUG >= 5) {
cout << " x " << Util::parametersToString (links[i]->getMessage()); cout << " x " << links[i]->getMessage();
} }
} }
} }
} }
if (DL >= 5) { if (Constants::DEBUG >= 5) {
cout << " = " << Util::parametersToString (msg); cout << " = " << msg;
} }
return msg; return msg;
} }
@ -445,16 +365,16 @@ FgBpSolver::getVar2FactorMsg (const SpLink* link) const
Params Params
FgBpSolver::getJointByConditioning (const VarIds& jointVarIds) const BpSolver::getJointByConditioning (const VarIds& jointVarIds) const
{ {
FgVarSet jointVars; VarNodes jointVars;
for (unsigned i = 0; i < jointVarIds.size(); i++) { for (unsigned i = 0; i < jointVarIds.size(); i++) {
assert (factorGraph_->getFgVarNode (jointVarIds[i])); assert (fg_->getVarNode (jointVarIds[i]));
jointVars.push_back (factorGraph_->getFgVarNode (jointVarIds[i])); jointVars.push_back (fg_->getVarNode (jointVarIds[i]));
} }
FactorGraph* fg = new FactorGraph (*factorGraph_); FactorGraph* fg = new FactorGraph (*fg_);
FgBpSolver solver (*fg); BpSolver solver (*fg);
solver.runSolver(); solver.runSolver();
Params prevBeliefs = solver.getPosterioriOf (jointVarIds[0]); Params prevBeliefs = solver.getPosterioriOf (jointVarIds[0]);
@ -463,9 +383,9 @@ FgBpSolver::getJointByConditioning (const VarIds& jointVarIds) const
for (unsigned i = 1; i < jointVarIds.size(); i++) { for (unsigned i = 1; i < jointVarIds.size(); i++) {
assert (jointVars[i]->hasEvidence() == false); assert (jointVars[i]->hasEvidence() == false);
Params newBeliefs; Params newBeliefs;
VarNodes observedVars; Vars observedVars;
for (unsigned j = 0; j < observedVids.size(); j++) { for (unsigned j = 0; j < observedVids.size(); j++) {
observedVars.push_back (fg->getFgVarNode (observedVids[j])); observedVars.push_back (fg->getVarNode (observedVids[j]));
} }
StatesIndexer idx (observedVars, false); StatesIndexer idx (observedVars, false);
while (idx.valid()) { while (idx.valid()) {
@ -473,7 +393,7 @@ FgBpSolver::getJointByConditioning (const VarIds& jointVarIds) const
observedVars[j]->setEvidence (idx[j]); observedVars[j]->setEvidence (idx[j]);
} }
++ idx; ++ idx;
FgBpSolver solver (*fg); BpSolver solver (*fg);
solver.runSolver(); solver.runSolver();
Params beliefs = solver.getPosterioriOf (jointVarIds[i]); Params beliefs = solver.getPosterioriOf (jointVarIds[i]);
for (unsigned k = 0; k < beliefs.size(); k++) { for (unsigned k = 0; k < beliefs.size(); k++) {
@ -483,7 +403,7 @@ FgBpSolver::getJointByConditioning (const VarIds& jointVarIds) const
int count = -1; int count = -1;
for (unsigned j = 0; j < newBeliefs.size(); j++) { for (unsigned j = 0; j < newBeliefs.size(); j++) {
if (j % jointVars[i]->nrStates() == 0) { if (j % jointVars[i]->range() == 0) {
count ++; count ++;
} }
newBeliefs[j] *= prevBeliefs[count]; newBeliefs[j] *= prevBeliefs[count];
@ -497,15 +417,76 @@ FgBpSolver::getJointByConditioning (const VarIds& jointVarIds) const
void void
FgBpSolver::printLinkInformation (void) const BpSolver::initializeSolver (void)
{
const VarNodes& varNodes = fg_->varNodes();
varsI_.reserve (varNodes.size());
for (unsigned i = 0; i < varNodes.size(); i++) {
varsI_.push_back (new SPNodeInfo());
}
const FacNodes& facNodes = fg_->facNodes();
facsI_.reserve (facNodes.size());
for (unsigned i = 0; i < facNodes.size(); i++) {
facsI_.push_back (new SPNodeInfo());
}
createLinks();
for (unsigned i = 0; i < links_.size(); i++) {
FacNode* src = links_[i]->getFactor();
VarNode* dst = links_[i]->getVariable();
ninf (dst)->addSpLink (links_[i]);
ninf (src)->addSpLink (links_[i]);
}
}
bool
BpSolver::converged (void)
{
if (links_.size() == 0) {
return true;
}
if (nIters_ <= 1) {
return false;
}
bool converged = true;
if (BpOptions::schedule == BpOptions::Schedule::MAX_RESIDUAL) {
double maxResidual = (*(sortedOrder_.begin()))->getResidual();
if (maxResidual > BpOptions::accuracy) {
converged = false;
} else {
converged = true;
}
} else {
for (unsigned i = 0; i < links_.size(); i++) {
double residual = links_[i]->getResidual();
if (Constants::DEBUG >= 2) {
cout << links_[i]->toString() + " residual = " << residual << endl;
}
if (residual > BpOptions::accuracy) {
converged = false;
if (Constants::DEBUG == 0) break;
}
}
if (Constants::DEBUG >= 2) {
cout << endl;
}
}
return converged;
}
void
BpSolver::printLinkInformation (void) const
{ {
for (unsigned i = 0; i < links_.size(); i++) { for (unsigned i = 0; i < links_.size(); i++) {
SpLink* l = links_[i]; SpLink* l = links_[i];
cout << l->toString() << ":" << endl; cout << l->toString() << ":" << endl;
cout << " curr msg = " ; cout << " curr msg = " ;
cout << Util::parametersToString (l->getMessage()) << endl; cout << l->getMessage() << endl;
cout << " next msg = " ; cout << " next msg = " ;
cout << Util::parametersToString (l->getNextMessage()) << endl; cout << l->getNextMessage() << endl;
cout << " residual = " << l->getResidual() << endl; cout << " residual = " << l->getResidual() << endl;
} }
} }

View File

@ -0,0 +1,188 @@
#ifndef HORUS_BPSOLVER_H
#define HORUS_BPSOLVER_H
#include <set>
#include <vector>
#include <sstream>
#include "Solver.h"
#include "Factor.h"
#include "FactorGraph.h"
#include "Util.h"
using namespace std;
class SpLink
{
public:
SpLink (FacNode* fn, VarNode* vn)
{
fac_ = fn;
var_ = vn;
v1_.resize (vn->range(), LogAware::tl (1.0 / vn->range()));
v2_.resize (vn->range(), LogAware::tl (1.0 / vn->range()));
currMsg_ = &v1_;
nextMsg_ = &v2_;
msgSended_ = false;
residual_ = 0.0;
}
virtual ~SpLink (void) { };
FacNode* getFactor (void) const { return fac_; }
VarNode* getVariable (void) const { return var_; }
const Params& getMessage (void) const { return *currMsg_; }
Params& getNextMessage (void) { return *nextMsg_; }
bool messageWasSended (void) const { return msgSended_; }
double getResidual (void) const { return residual_; }
void clearResidual (void) { residual_ = 0.0; }
void updateResidual (void)
{
residual_ = LogAware::getMaxNorm (v1_,v2_);
}
virtual void updateMessage (void)
{
swap (currMsg_, nextMsg_);
msgSended_ = true;
}
string toString (void) const
{
stringstream ss;
ss << fac_->getLabel();
ss << " -- " ;
ss << var_->label();
return ss.str();
}
protected:
FacNode* fac_;
VarNode* var_;
Params v1_;
Params v2_;
Params* currMsg_;
Params* nextMsg_;
bool msgSended_;
double residual_;
};
typedef vector<SpLink*> SpLinkSet;
class SPNodeInfo
{
public:
void addSpLink (SpLink* link) { links_.push_back (link); }
const SpLinkSet& getLinks (void) { return links_; }
private:
SpLinkSet links_;
};
class BpSolver : public Solver
{
public:
BpSolver (const FactorGraph&);
virtual ~BpSolver (void);
Params solveQuery (VarIds);
virtual Params getPosterioriOf (VarId);
virtual Params getJointDistributionOf (const VarIds&);
protected:
void runSolver (void);
virtual void createLinks (void);
virtual void maxResidualSchedule (void);
virtual void calculateFactor2VariableMsg (SpLink*);
virtual Params getVar2FactorMsg (const SpLink*) const;
virtual Params getJointByConditioning (const VarIds&) const;
SPNodeInfo* ninf (const VarNode* var) const
{
return varsI_[var->getIndex()];
}
SPNodeInfo* ninf (const FacNode* fac) const
{
return facsI_[fac->getIndex()];
}
void calculateAndUpdateMessage (SpLink* link, bool calcResidual = true)
{
if (Constants::DEBUG >= 3) {
cout << "calculating & updating " << link->toString() << endl;
}
calculateFactor2VariableMsg (link);
if (calcResidual) {
link->updateResidual();
}
link->updateMessage();
}
void calculateMessage (SpLink* link, bool calcResidual = true)
{
if (Constants::DEBUG >= 3) {
cout << "calculating " << link->toString() << endl;
}
calculateFactor2VariableMsg (link);
if (calcResidual) {
link->updateResidual();
}
}
void updateMessage (SpLink* link)
{
link->updateMessage();
if (Constants::DEBUG >= 3) {
cout << "updating " << link->toString() << endl;
}
}
struct CompareResidual
{
inline bool operator() (const SpLink* link1, const SpLink* link2)
{
return link1->getResidual() > link2->getResidual();
}
};
SpLinkSet links_;
unsigned nIters_;
vector<SPNodeInfo*> varsI_;
vector<SPNodeInfo*> facsI_;
bool runned_;
const FactorGraph* fg_;
typedef multiset<SpLink*, CompareResidual> SortedOrder;
SortedOrder sortedOrder_;
typedef unordered_map<SpLink*, SortedOrder::iterator> SpLinkMap;
SpLinkMap linkMap_;
private:
void initializeSolver (void);
bool converged (void);
void printLinkInformation (void) const;
};
#endif // HORUS_BPSOLVER_H

View File

@ -1,7 +1,6 @@
#include "CFactorGraph.h" #include "CFactorGraph.h"
#include "Factor.h" #include "Factor.h"
#include "Distribution.h"
bool CFactorGraph::checkForIdenticalFactors = true; bool CFactorGraph::checkForIdenticalFactors = true;
@ -11,22 +10,22 @@ CFactorGraph::CFactorGraph (const FactorGraph& fg)
groundFg_ = &fg; groundFg_ = &fg;
freeColor_ = 0; freeColor_ = 0;
const FgVarSet& varNodes = fg.getVarNodes(); const VarNodes& varNodes = fg.varNodes();
varSignatures_.reserve (varNodes.size()); varSignatures_.reserve (varNodes.size());
for (unsigned i = 0; i < varNodes.size(); i++) { for (unsigned i = 0; i < varNodes.size(); i++) {
unsigned c = (varNodes[i]->neighbors().size() * 2) + 1; unsigned c = (varNodes[i]->neighbors().size() * 2) + 1;
varSignatures_.push_back (Signature (c)); varSignatures_.push_back (Signature (c));
} }
const FgFacSet& facNodes = fg.getFactorNodes(); const FacNodes& facNodes = fg.facNodes();
factorSignatures_.reserve (facNodes.size()); facSignatures_.reserve (facNodes.size());
for (unsigned i = 0; i < facNodes.size(); i++) { for (unsigned i = 0; i < facNodes.size(); i++) {
unsigned c = facNodes[i]->neighbors().size() + 1; unsigned c = facNodes[i]->neighbors().size() + 1;
factorSignatures_.push_back (Signature (c)); facSignatures_.push_back (Signature (c));
} }
varColors_.resize (varNodes.size()); varColors_.resize (varNodes.size());
factorColors_.resize (facNodes.size()); facColors_.resize (facNodes.size());
setInitialColors(); setInitialColors();
createGroups(); createGroups();
} }
@ -50,9 +49,9 @@ CFactorGraph::setInitialColors (void)
{ {
// create the initial variable colors // create the initial variable colors
VarColorMap colorMap; VarColorMap colorMap;
const FgVarSet& varNodes = groundFg_->getVarNodes(); const VarNodes& varNodes = groundFg_->varNodes();
for (unsigned i = 0; i < varNodes.size(); i++) { for (unsigned i = 0; i < varNodes.size(); i++) {
unsigned dsize = varNodes[i]->nrStates(); unsigned dsize = varNodes[i]->range();
VarColorMap::iterator it = colorMap.find (dsize); VarColorMap::iterator it = colorMap.find (dsize);
if (it == colorMap.end()) { if (it == colorMap.end()) {
it = colorMap.insert (make_pair ( it = colorMap.insert (make_pair (
@ -71,29 +70,40 @@ CFactorGraph::setInitialColors (void)
setColor (varNodes[i], stateColors[idx]); setColor (varNodes[i], stateColors[idx]);
} }
const FgFacSet& facNodes = groundFg_->getFactorNodes(); const FacNodes& facNodes = groundFg_->facNodes();
if (checkForIdenticalFactors) { for (unsigned i = 0; i < facNodes.size(); i++) {
for (unsigned i = 0, s = facNodes.size(); i < s; i++) { facNodes[i]->factor().setDistId (Util::maxUnsigned());
Distribution* dist1 = facNodes[i]->getDistribution(); }
for (unsigned j = 0; j < i; j++) { // FIXME FIXME FIXME : pfl should give correct dist ids.
Distribution* dist2 = facNodes[j]->getDistribution(); if (checkForIdenticalFactors || true) {
if (dist1 != dist2 && dist1->params == dist2->params) { unsigned groupCount = 1;
if (facNodes[i]->factor()->getRanges() == for (unsigned i = 0; i < facNodes.size(); i++) {
facNodes[j]->factor()->getRanges()) { Factor& f1 = facNodes[i]->factor();
facNodes[i]->factor()->setDistribution (dist2); if (f1.distId() != Util::maxUnsigned()) {
} continue;
}
f1.setDistId (groupCount);
for (unsigned j = i + 1; j < facNodes.size(); j++) {
Factor& f2 = facNodes[j]->factor();
if (f2.distId() != Util::maxUnsigned()) {
continue;
}
if (f1.size() == f2.size() &&
f1.ranges() == f2.ranges() &&
f1.params() == f2.params()) {
f2.setDistId (groupCount);
} }
} }
groupCount ++;
} }
} }
// create the initial factor colors // create the initial factor colors
DistColorMap distColors; DistColorMap distColors;
for (unsigned i = 0; i < facNodes.size(); i++) { for (unsigned i = 0; i < facNodes.size(); i++) {
const Distribution* dist = facNodes[i]->getDistribution(); unsigned distId = facNodes[i]->factor().distId();
DistColorMap::iterator it = distColors.find (dist); DistColorMap::iterator it = distColors.find (distId);
if (it == distColors.end()) { if (it == distColors.end()) {
it = distColors.insert (make_pair (dist, getFreeColor())).first; it = distColors.insert (make_pair (distId, getFreeColor())).first;
} }
setColor (facNodes[i], it->second); setColor (facNodes[i], it->second);
} }
@ -104,31 +114,31 @@ CFactorGraph::setInitialColors (void)
void void
CFactorGraph::createGroups (void) CFactorGraph::createGroups (void)
{ {
VarSignMap varGroups; VarSignMap varGroups;
FacSignMap factorGroups; FacSignMap facGroups;
unsigned nIters = 0; unsigned nIters = 0;
bool groupsHaveChanged = true; bool groupsHaveChanged = true;
const FgVarSet& varNodes = groundFg_->getVarNodes(); const VarNodes& varNodes = groundFg_->varNodes();
const FgFacSet& facNodes = groundFg_->getFactorNodes(); const FacNodes& facNodes = groundFg_->facNodes();
while (groupsHaveChanged || nIters == 1) { while (groupsHaveChanged || nIters == 1) {
nIters ++; nIters ++;
unsigned prevFactorGroupsSize = factorGroups.size(); unsigned prevFactorGroupsSize = facGroups.size();
factorGroups.clear(); facGroups.clear();
// set a new color to the factors with the same signature // set a new color to the factors with the same signature
for (unsigned i = 0; i < facNodes.size(); i++) { for (unsigned i = 0; i < facNodes.size(); i++) {
const Signature& signature = getSignature (facNodes[i]); const Signature& signature = getSignature (facNodes[i]);
FacSignMap::iterator it = factorGroups.find (signature); FacSignMap::iterator it = facGroups.find (signature);
if (it == factorGroups.end()) { if (it == facGroups.end()) {
it = factorGroups.insert (make_pair (signature, FgFacSet())).first; it = facGroups.insert (make_pair (signature, FacNodes())).first;
} }
it->second.push_back (facNodes[i]); it->second.push_back (facNodes[i]);
} }
for (FacSignMap::iterator it = factorGroups.begin(); for (FacSignMap::iterator it = facGroups.begin();
it != factorGroups.end(); it++) { it != facGroups.end(); it++) {
Color newColor = getFreeColor(); Color newColor = getFreeColor();
FgFacSet& groupMembers = it->second; FacNodes& groupMembers = it->second;
for (unsigned i = 0; i < groupMembers.size(); i++) { for (unsigned i = 0; i < groupMembers.size(); i++) {
setColor (groupMembers[i], newColor); setColor (groupMembers[i], newColor);
} }
@ -141,36 +151,37 @@ CFactorGraph::createGroups (void)
const Signature& signature = getSignature (varNodes[i]); const Signature& signature = getSignature (varNodes[i]);
VarSignMap::iterator it = varGroups.find (signature); VarSignMap::iterator it = varGroups.find (signature);
if (it == varGroups.end()) { if (it == varGroups.end()) {
it = varGroups.insert (make_pair (signature, FgVarSet())).first; it = varGroups.insert (make_pair (signature, VarNodes())).first;
} }
it->second.push_back (varNodes[i]); it->second.push_back (varNodes[i]);
} }
for (VarSignMap::iterator it = varGroups.begin(); for (VarSignMap::iterator it = varGroups.begin();
it != varGroups.end(); it++) { it != varGroups.end(); it++) {
Color newColor = getFreeColor(); Color newColor = getFreeColor();
FgVarSet& groupMembers = it->second; VarNodes& groupMembers = it->second;
for (unsigned i = 0; i < groupMembers.size(); i++) { for (unsigned i = 0; i < groupMembers.size(); i++) {
setColor (groupMembers[i], newColor); setColor (groupMembers[i], newColor);
} }
} }
groupsHaveChanged = prevVarGroupsSize != varGroups.size() groupsHaveChanged = prevVarGroupsSize != varGroups.size()
|| prevFactorGroupsSize != factorGroups.size(); || prevFactorGroupsSize != facGroups.size();
} }
//printGroups (varGroups, factorGroups); printGroups (varGroups, facGroups);
createClusters (varGroups, factorGroups); createClusters (varGroups, facGroups);
} }
void void
CFactorGraph::createClusters (const VarSignMap& varGroups, CFactorGraph::createClusters (
const FacSignMap& factorGroups) const VarSignMap& varGroups,
const FacSignMap& facGroups)
{ {
varClusters_.reserve (varGroups.size()); varClusters_.reserve (varGroups.size());
for (VarSignMap::const_iterator it = varGroups.begin(); for (VarSignMap::const_iterator it = varGroups.begin();
it != varGroups.end(); it++) { it != varGroups.end(); it++) {
const FgVarSet& groupVars = it->second; const VarNodes& groupVars = it->second;
VarCluster* vc = new VarCluster (groupVars); VarCluster* vc = new VarCluster (groupVars);
for (unsigned i = 0; i < groupVars.size(); i++) { for (unsigned i = 0; i < groupVars.size(); i++) {
vid2VarCluster_.insert (make_pair (groupVars[i]->varId(), vc)); vid2VarCluster_.insert (make_pair (groupVars[i]->varId(), vc));
@ -178,12 +189,12 @@ CFactorGraph::createClusters (const VarSignMap& varGroups,
varClusters_.push_back (vc); varClusters_.push_back (vc);
} }
facClusters_.reserve (factorGroups.size()); facClusters_.reserve (facGroups.size());
for (FacSignMap::const_iterator it = factorGroups.begin(); for (FacSignMap::const_iterator it = facGroups.begin();
it != factorGroups.end(); it++) { it != facGroups.end(); it++) {
FgFacNode* groupFactor = it->second[0]; FacNode* groupFactor = it->second[0];
const FgVarSet& neighs = groupFactor->neighbors(); const VarNodes& neighs = groupFactor->neighbors();
VarClusterSet varClusters; VarClusters varClusters;
varClusters.reserve (neighs.size()); varClusters.reserve (neighs.size());
for (unsigned i = 0; i < neighs.size(); i++) { for (unsigned i = 0; i < neighs.size(); i++) {
VarId vid = neighs[i]->varId(); VarId vid = neighs[i]->varId();
@ -196,15 +207,15 @@ CFactorGraph::createClusters (const VarSignMap& varGroups,
const Signature& const Signature&
CFactorGraph::getSignature (const FgVarNode* varNode) CFactorGraph::getSignature (const VarNode* varNode)
{ {
Signature& sign = varSignatures_[varNode->getIndex()]; Signature& sign = varSignatures_[varNode->getIndex()];
vector<Color>::iterator it = sign.colors.begin(); vector<Color>::iterator it = sign.colors.begin();
const FgFacSet& neighs = varNode->neighbors(); const FacNodes& neighs = varNode->neighbors();
for (unsigned i = 0; i < neighs.size(); i++) { for (unsigned i = 0; i < neighs.size(); i++) {
*it = getColor (neighs[i]); *it = getColor (neighs[i]);
it ++; it ++;
*it = neighs[i]->factor()->indexOf (varNode->varId()); *it = neighs[i]->factor().indexOf (varNode->varId());
it ++; it ++;
} }
*it = getColor (varNode); *it = getColor (varNode);
@ -214,11 +225,11 @@ CFactorGraph::getSignature (const FgVarNode* varNode)
const Signature& const Signature&
CFactorGraph::getSignature (const FgFacNode* facNode) CFactorGraph::getSignature (const FacNode* facNode)
{ {
Signature& sign = factorSignatures_[facNode->getIndex()]; Signature& sign = facSignatures_[facNode->getIndex()];
vector<Color>::iterator it = sign.colors.begin(); vector<Color>::iterator it = sign.colors.begin();
const FgVarSet& neighs = facNode->neighbors(); const VarNodes& neighs = facNode->neighbors();
for (unsigned i = 0; i < neighs.size(); i++) { for (unsigned i = 0; i < neighs.size(); i++) {
*it = getColor (neighs[i]); *it = getColor (neighs[i]);
it ++; it ++;
@ -230,55 +241,53 @@ CFactorGraph::getSignature (const FgFacNode* facNode)
FactorGraph* FactorGraph*
CFactorGraph::getCompressedFactorGraph (void) CFactorGraph::getGroundFactorGraph (void) const
{ {
FactorGraph* fg = new FactorGraph(); FactorGraph* fg = new FactorGraph();
for (unsigned i = 0; i < varClusters_.size(); i++) { for (unsigned i = 0; i < varClusters_.size(); i++) {
FgVarNode* var = varClusters_[i]->getGroundFgVarNodes()[0]; VarNode* var = varClusters_[i]->getGroundVarNodes()[0];
FgVarNode* newVar = new FgVarNode (var); VarNode* newVar = new VarNode (var);
varClusters_[i]->setRepresentativeVariable (newVar); varClusters_[i]->setRepresentativeVariable (newVar);
fg->addVariable (newVar); fg->addVarNode (newVar);
} }
for (unsigned i = 0; i < facClusters_.size(); i++) { for (unsigned i = 0; i < facClusters_.size(); i++) {
const VarClusterSet& myVarClusters = facClusters_[i]->getVarClusters(); const VarClusters& myVarClusters = facClusters_[i]->getVarClusters();
VarNodes myGroundVars; Vars myGroundVars;
myGroundVars.reserve (myVarClusters.size()); myGroundVars.reserve (myVarClusters.size());
for (unsigned j = 0; j < myVarClusters.size(); j++) { for (unsigned j = 0; j < myVarClusters.size(); j++) {
FgVarNode* v = myVarClusters[j]->getRepresentativeVariable(); VarNode* v = myVarClusters[j]->getRepresentativeVariable();
myGroundVars.push_back (v); myGroundVars.push_back (v);
} }
Factor* newFactor = new Factor (myGroundVars, FacNode* fn = new FacNode (Factor (myGroundVars,
facClusters_[i]->getGroundFactors()[0]->getDistribution()); facClusters_[i]->getGroundFactors()[0]->factor().params()));
FgFacNode* fn = new FgFacNode (newFactor);
facClusters_[i]->setRepresentativeFactor (fn); facClusters_[i]->setRepresentativeFactor (fn);
fg->addFactor (fn); fg->addFacNode (fn);
for (unsigned j = 0; j < myGroundVars.size(); j++) { for (unsigned j = 0; j < myGroundVars.size(); j++) {
fg->addEdge (fn, static_cast<FgVarNode*> (myGroundVars[j])); fg->addEdge (static_cast<VarNode*> (myGroundVars[j]), fn);
} }
} }
fg->setIndexes();
return fg; return fg;
} }
unsigned unsigned
CFactorGraph::getGroundEdgeCount ( CFactorGraph::getEdgeCount (
const FacCluster* fc, const FacCluster* fc,
const VarCluster* vc) const const VarCluster* vc) const
{ {
const FgFacSet& clusterGroundFactors = fc->getGroundFactors();
FgVarNode* varNode = vc->getGroundFgVarNodes()[0];
unsigned count = 0; unsigned count = 0;
VarId vid = vc->getGroundVarNodes().front()->varId();
const FacNodes& clusterGroundFactors = fc->getGroundFactors();
for (unsigned i = 0; i < clusterGroundFactors.size(); i++) { for (unsigned i = 0; i < clusterGroundFactors.size(); i++) {
if (clusterGroundFactors[i]->factor()->indexOf (varNode->varId()) != -1) { if (clusterGroundFactors[i]->factor().contains (vid)) {
count ++; count ++;
} }
} }
// CFgVarSet vars = vc->getGroundFgVarNodes(); // CVarNodes vars = vc->getGroundVarNodes();
// for (unsigned i = 1; i < vars.size(); i++) { // for (unsigned i = 1; i < vars.size(); i++) {
// FgVarNode* var = vc->getGroundFgVarNodes()[i]; // VarNode* var = vc->getGroundVarNodes()[i];
// unsigned count2 = 0; // unsigned count2 = 0;
// for (unsigned i = 0; i < clusterGroundFactors.size(); i++) { // for (unsigned i = 0; i < clusterGroundFactors.size(); i++) {
// if (clusterGroundFactors[i]->getPosition (var) != -1) { // if (clusterGroundFactors[i]->getPosition (var) != -1) {
@ -293,14 +302,15 @@ CFactorGraph::getGroundEdgeCount (
void void
CFactorGraph::printGroups (const VarSignMap& varGroups, CFactorGraph::printGroups (
const FacSignMap& factorGroups) const const VarSignMap& varGroups,
const FacSignMap& facGroups) const
{ {
unsigned count = 1; unsigned count = 1;
cout << "variable groups:" << endl; cout << "variable groups:" << endl;
for (VarSignMap::const_iterator it = varGroups.begin(); for (VarSignMap::const_iterator it = varGroups.begin();
it != varGroups.end(); it++) { it != varGroups.end(); it++) {
const FgVarSet& groupMembers = it->second; const VarNodes& groupMembers = it->second;
if (groupMembers.size() > 0) { if (groupMembers.size() > 0) {
cout << count << ": " ; cout << count << ": " ;
for (unsigned i = 0; i < groupMembers.size(); i++) { for (unsigned i = 0; i < groupMembers.size(); i++) {
@ -313,9 +323,9 @@ CFactorGraph::printGroups (const VarSignMap& varGroups,
count = 1; count = 1;
cout << endl << "factor groups:" << endl; cout << endl << "factor groups:" << endl;
for (FacSignMap::const_iterator it = factorGroups.begin(); for (FacSignMap::const_iterator it = facGroups.begin();
it != factorGroups.end(); it++) { it != facGroups.end(); it++) {
const FgFacSet& groupMembers = it->second; const FacNodes& groupMembers = it->second;
if (groupMembers.size() > 0) { if (groupMembers.size() > 0) {
cout << ++count << ": " ; cout << ++count << ": " ;
for (unsigned i = 0; i < groupMembers.size(); i++) { for (unsigned i = 0; i < groupMembers.size(); i++) {

View File

@ -15,23 +15,25 @@ class Signature;
class SignatureHash; class SignatureHash;
typedef long Color; typedef long Color;
typedef unordered_map<unsigned, vector<Color> > VarColorMap;
typedef unordered_map<const Distribution*, Color> DistColorMap; typedef unordered_map<unsigned, vector<Color>> VarColorMap;
typedef unordered_map<VarId, VarCluster*> VarId2VarCluster;
typedef vector<VarCluster*> VarClusterSet; typedef unordered_map<unsigned, Color> DistColorMap;
typedef vector<FacCluster*> FacClusterSet; typedef unordered_map<VarId, VarCluster*> VarId2VarCluster;
typedef unordered_map<Signature, FgVarSet, SignatureHash> VarSignMap;
typedef unordered_map<Signature, FgFacSet, SignatureHash> FacSignMap; typedef vector<VarCluster*> VarClusters;
typedef vector<FacCluster*> FacClusters;
typedef unordered_map<Signature, VarNodes, SignatureHash> VarSignMap;
typedef unordered_map<Signature, FacNodes, SignatureHash> FacSignMap;
struct Signature struct Signature
{ {
Signature (unsigned size) Signature (unsigned size) : colors(size) { }
{
colors.resize (size);
}
bool operator< (const Signature& sig) const bool operator< (const Signature& sig) const
{ {
if (colors.size() < sig.colors.size()) { if (colors.size() < sig.colors.size()) {
@ -49,6 +51,7 @@ struct Signature
} }
return false; return false;
} }
bool operator== (const Signature& sig) const bool operator== (const Signature& sig) const
{ {
if (colors.size() != sig.colors.size()) { if (colors.size() != sig.colors.size()) {
@ -61,12 +64,14 @@ struct Signature
} }
return true; return true;
} }
vector<Color> colors; vector<Color> colors;
}; };
struct SignatureHash { struct SignatureHash
{
size_t operator() (const Signature &sig) const size_t operator() (const Signature &sig) const
{ {
size_t val = hash<size_t>()(sig.colors.size()); size_t val = hash<size_t>()(sig.colors.size());
@ -82,7 +87,7 @@ struct SignatureHash {
class VarCluster class VarCluster
{ {
public: public:
VarCluster (const FgVarSet& vs) VarCluster (const VarNodes& vs)
{ {
for (unsigned i = 0; i < vs.size(); i++) { for (unsigned i = 0; i < vs.size(); i++) {
groundVars_.push_back (vs[i]); groundVars_.push_back (vs[i]);
@ -94,26 +99,28 @@ class VarCluster
facClusters_.push_back (fc); facClusters_.push_back (fc);
} }
const FacClusterSet& getFacClusters (void) const const FacClusters& getFacClusters (void) const
{ {
return facClusters_; return facClusters_;
} }
FgVarNode* getRepresentativeVariable (void) const { return representVar_; } VarNode* getRepresentativeVariable (void) const { return representVar_; }
void setRepresentativeVariable (FgVarNode* v) { representVar_ = v; }
const FgVarSet& getGroundFgVarNodes (void) const { return groundVars_; } void setRepresentativeVariable (VarNode* v) { representVar_ = v; }
const VarNodes& getGroundVarNodes (void) const { return groundVars_; }
private: private:
FgVarSet groundVars_; VarNodes groundVars_;
FacClusterSet facClusters_; FacClusters facClusters_;
FgVarNode* representVar_; VarNode* representVar_;
}; };
class FacCluster class FacCluster
{ {
public: public:
FacCluster (const FgFacSet& groundFactors, const VarClusterSet& vcs) FacCluster (const FacNodes& groundFactors, const VarClusters& vcs)
{ {
groundFactors_ = groundFactors; groundFactors_ = groundFactors;
varClusters_ = vcs; varClusters_ = vcs;
@ -122,12 +129,12 @@ class FacCluster
} }
} }
const VarClusterSet& getVarClusters (void) const const VarClusters& getVarClusters (void) const
{ {
return varClusters_; return varClusters_;
} }
bool containsGround (const FgFacNode* fn) bool containsGround (const FacNode* fn)
{ {
for (unsigned i = 0; i < groundFactors_.size(); i++) { for (unsigned i = 0; i < groundFactors_.size(); i++) {
if (groundFactors_[i] == fn) { if (groundFactors_[i] == fn) {
@ -137,24 +144,26 @@ class FacCluster
return false; return false;
} }
FgFacNode* getRepresentativeFactor (void) const FacNode* getRepresentativeFactor (void) const
{ {
return representFactor_; return representFactor_;
} }
void setRepresentativeFactor (FgFacNode* fn)
void setRepresentativeFactor (FacNode* fn)
{ {
representFactor_ = fn; representFactor_ = fn;
} }
const FgFacSet& getGroundFactors (void) const
const FacNodes& getGroundFactors (void) const
{ {
return groundFactors_; return groundFactors_;
} }
private: private:
FgFacSet groundFactors_; FacNodes groundFactors_;
VarClusterSet varClusters_; VarClusters varClusters_;
FgFacNode* representFactor_; FacNode* representFactor_;
}; };
@ -162,51 +171,48 @@ class CFactorGraph
{ {
public: public:
CFactorGraph (const FactorGraph&); CFactorGraph (const FactorGraph&);
~CFactorGraph (void); ~CFactorGraph (void);
FactorGraph* getCompressedFactorGraph (void); const VarClusters& getVarClusters (void) { return varClusters_; }
unsigned getGroundEdgeCount (const FacCluster*, const VarCluster*) const;
FgVarNode* getEquivalentVariable (VarId vid) const FacClusters& getFacClusters (void) { return facClusters_; }
VarNode* getEquivalentVariable (VarId vid)
{ {
VarCluster* vc = vid2VarCluster_.find (vid)->second; VarCluster* vc = vid2VarCluster_.find (vid)->second;
return vc->getRepresentativeVariable(); return vc->getRepresentativeVariable();
} }
const VarClusterSet& getVarClusters (void) { return varClusters_; }
const FacClusterSet& getFacClusters (void) { return facClusters_; }
FactorGraph* getGroundFactorGraph (void) const;
unsigned getEdgeCount (const FacCluster*, const VarCluster*) const;
static bool checkForIdenticalFactors; static bool checkForIdenticalFactors;
private: private:
void setInitialColors (void); Color getFreeColor (void)
void createGroups (void); {
void createClusters (const VarSignMap&, const FacSignMap&);
const Signature& getSignature (const FgVarNode*);
const Signature& getSignature (const FgFacNode*);
void printGroups (const VarSignMap&, const FacSignMap&) const;
Color getFreeColor (void) {
++ freeColor_; ++ freeColor_;
return freeColor_ - 1; return freeColor_ - 1;
} }
Color getColor (const FgVarNode* vn) const Color getColor (const VarNode* vn) const
{ {
return varColors_[vn->getIndex()]; return varColors_[vn->getIndex()];
} }
Color getColor (const FgFacNode* fn) const { Color getColor (const FacNode* fn) const {
return factorColors_[fn->getIndex()]; return facColors_[fn->getIndex()];
} }
void setColor (const FgVarNode* vn, Color c) void setColor (const VarNode* vn, Color c)
{ {
varColors_[vn->getIndex()] = c; varColors_[vn->getIndex()] = c;
} }
void setColor (const FgFacNode* fn, Color c) void setColor (const FacNode* fn, Color c)
{ {
factorColors_[fn->getIndex()] = c; facColors_[fn->getIndex()] = c;
} }
VarCluster* getVariableCluster (VarId vid) const VarCluster* getVariableCluster (VarId vid) const
@ -214,14 +220,26 @@ class CFactorGraph
return vid2VarCluster_.find (vid)->second; return vid2VarCluster_.find (vid)->second;
} }
void setInitialColors (void);
void createGroups (void);
void createClusters (const VarSignMap&, const FacSignMap&);
const Signature& getSignature (const VarNode*);
const Signature& getSignature (const FacNode*);
void printGroups (const VarSignMap&, const FacSignMap&) const;
Color freeColor_; Color freeColor_;
vector<Color> varColors_; vector<Color> varColors_;
vector<Color> factorColors_; vector<Color> facColors_;
vector<Signature> varSignatures_; vector<Signature> varSignatures_;
vector<Signature> factorSignatures_; vector<Signature> facSignatures_;
VarClusterSet varClusters_; VarClusters varClusters_;
FacClusterSet facClusters_; FacClusters facClusters_;
VarId2VarCluster vid2VarCluster_; VarId2VarCluster vid2VarCluster_;
const FactorGraph* groundFg_; const FactorGraph* groundFg_;
}; };

View File

@ -1,10 +1,41 @@
#include "CbpSolver.h" #include "CbpSolver.h"
CbpSolver::CbpSolver (const FactorGraph& fg) : BpSolver (fg)
{
unsigned nGroundVars, nGroundFacs, nWithoutNeighs;
if (Constants::COLLECT_STATS) {
nGroundVars = fg_->varNodes().size();
nGroundFacs = fg_->facNodes().size();
const VarNodes& vars = fg_->varNodes();
nWithoutNeighs = 0;
for (unsigned i = 0; i < vars.size(); i++) {
const FacNodes& factors = vars[i]->neighbors();
if (factors.size() == 1 && factors[0]->neighbors().size() == 1) {
nWithoutNeighs ++;
}
}
}
cfg_ = new CFactorGraph (fg);
fg_ = cfg_->getGroundFactorGraph();
if (Constants::COLLECT_STATS) {
unsigned nClusterVars = fg_->varNodes().size();
unsigned nClusterFacs = fg_->facNodes().size();
Statistics::updateCompressingStatistics (nGroundVars,
nGroundFacs, nClusterVars, nClusterFacs, nWithoutNeighs);
}
Util::printHeader ("Uncompressed Factor Graph");
fg.print();
Util::printHeader ("Compressed Factor Graph");
fg_->print();
}
CbpSolver::~CbpSolver (void) CbpSolver::~CbpSolver (void)
{ {
delete lfg_; delete cfg_;
delete factorGraph_; delete fg_;
for (unsigned i = 0; i < links_.size(); i++) { for (unsigned i = 0; i < links_.size(); i++) {
delete links_[i]; delete links_[i];
} }
@ -16,28 +47,31 @@ CbpSolver::~CbpSolver (void)
Params Params
CbpSolver::getPosterioriOf (VarId vid) CbpSolver::getPosterioriOf (VarId vid)
{ {
assert (lfg_->getEquivalentVariable (vid)); if (runned_ == false) {
FgVarNode* var = lfg_->getEquivalentVariable (vid); runSolver();
}
assert (cfg_->getEquivalentVariable (vid));
VarNode* var = cfg_->getEquivalentVariable (vid);
Params probs; Params probs;
if (var->hasEvidence()) { if (var->hasEvidence()) {
probs.resize (var->nrStates(), Util::noEvidence()); probs.resize (var->range(), LogAware::noEvidence());
probs[var->getEvidence()] = Util::withEvidence(); probs[var->getEvidence()] = LogAware::withEvidence();
} else { } else {
probs.resize (var->nrStates(), Util::multIdenty()); probs.resize (var->range(), LogAware::multIdenty());
const SpLinkSet& links = ninf(var)->getLinks(); const SpLinkSet& links = ninf(var)->getLinks();
if (Globals::logDomain) { if (Globals::logDomain) {
for (unsigned i = 0; i < links.size(); i++) { for (unsigned i = 0; i < links.size(); i++) {
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]); CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
Util::add (probs, l->getPoweredMessage()); Util::add (probs, l->poweredMessage());
} }
Util::normalize (probs); LogAware::normalize (probs);
Util::fromLog (probs); Util::fromLog (probs);
} else { } else {
for (unsigned i = 0; i < links.size(); i++) { for (unsigned i = 0; i < links.size(); i++) {
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]); CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
Util::multiply (probs, l->getPoweredMessage()); Util::multiply (probs, l->poweredMessage());
} }
Util::normalize (probs); LogAware::normalize (probs);
} }
} }
return probs; return probs;
@ -46,55 +80,14 @@ CbpSolver::getPosterioriOf (VarId vid)
Params Params
CbpSolver::getJointDistributionOf (const VarIds& jointVarIds) CbpSolver::getJointDistributionOf (const VarIds& jointVids)
{ {
VarIds eqVarIds; VarIds eqVarIds;
for (unsigned i = 0; i < jointVarIds.size(); i++) { for (unsigned i = 0; i < jointVids.size(); i++) {
eqVarIds.push_back (lfg_->getEquivalentVariable (jointVarIds[i])->varId()); VarNode* vn = cfg_->getEquivalentVariable (jointVids[i]);
eqVarIds.push_back (vn->varId());
} }
return FgBpSolver::getJointDistributionOf (eqVarIds); return BpSolver::getJointDistributionOf (eqVarIds);
}
void
CbpSolver::initializeSolver (void)
{
unsigned nGroundVars, nGroundFacs, nWithoutNeighs;
if (COLLECT_STATISTICS) {
nGroundVars = factorGraph_->getVarNodes().size();
nGroundFacs = factorGraph_->getFactorNodes().size();
const FgVarSet& vars = factorGraph_->getVarNodes();
nWithoutNeighs = 0;
for (unsigned i = 0; i < vars.size(); i++) {
const FgFacSet& factors = vars[i]->neighbors();
if (factors.size() == 1 && factors[0]->neighbors().size() == 1) {
nWithoutNeighs ++;
}
}
}
lfg_ = new CFactorGraph (*factorGraph_);
// cout << "Uncompressed Factor Graph" << endl;
// factorGraph_->printGraphicalModel();
// factorGraph_->exportToGraphViz ("uncompressed_fg.dot");
factorGraph_ = lfg_->getCompressedFactorGraph();
if (COLLECT_STATISTICS) {
unsigned nClusterVars = factorGraph_->getVarNodes().size();
unsigned nClusterFacs = factorGraph_->getFactorNodes().size();
Statistics::updateCompressingStatistics (nGroundVars, nGroundFacs,
nClusterVars, nClusterFacs,
nWithoutNeighs);
}
// cout << "Compressed Factor Graph" << endl;
// factorGraph_->printGraphicalModel();
// factorGraph_->exportToGraphViz ("compressed_fg.dot");
// abort();
FgBpSolver::initializeSolver();
} }
@ -102,12 +95,13 @@ CbpSolver::initializeSolver (void)
void void
CbpSolver::createLinks (void) CbpSolver::createLinks (void)
{ {
const FacClusterSet fcs = lfg_->getFacClusters(); const FacClusters& fcs = cfg_->getFacClusters();
for (unsigned i = 0; i < fcs.size(); i++) { for (unsigned i = 0; i < fcs.size(); i++) {
const VarClusterSet vcs = fcs[i]->getVarClusters(); const VarClusters& vcs = fcs[i]->getVarClusters();
for (unsigned j = 0; j < vcs.size(); j++) { for (unsigned j = 0; j < vcs.size(); j++) {
unsigned c = lfg_->getGroundEdgeCount (fcs[i], vcs[j]); unsigned c = cfg_->getEdgeCount (fcs[i], vcs[j]);
links_.push_back (new CbpSolverLink (fcs[i]->getRepresentativeFactor(), links_.push_back (new CbpSolverLink (
fcs[i]->getRepresentativeFactor(),
vcs[j]->getRepresentativeVariable(), c)); vcs[j]->getRepresentativeVariable(), c));
} }
} }
@ -123,7 +117,7 @@ CbpSolver::maxResidualSchedule (void)
calculateMessage (links_[i]); calculateMessage (links_[i]);
SortedOrder::iterator it = sortedOrder_.insert (links_[i]); SortedOrder::iterator it = sortedOrder_.insert (links_[i]);
linkMap_.insert (make_pair (links_[i], it)); linkMap_.insert (make_pair (links_[i], it));
if (DL >= 2 && DL < 5) { if (Constants::DEBUG >= 2 && Constants::DEBUG < 5) {
cout << "calculating " << links_[i]->toString() << endl; cout << "calculating " << links_[i]->toString() << endl;
} }
} }
@ -131,7 +125,7 @@ CbpSolver::maxResidualSchedule (void)
} }
for (unsigned c = 0; c < links_.size(); c++) { for (unsigned c = 0; c < links_.size(); c++) {
if (DL >= 2) { if (Constants::DEBUG >= 2) {
cout << endl << "current residuals:" << endl; cout << endl << "current residuals:" << endl;
for (SortedOrder::iterator it = sortedOrder_.begin(); for (SortedOrder::iterator it = sortedOrder_.begin();
it != sortedOrder_.end(); it ++) { it != sortedOrder_.end(); it ++) {
@ -142,7 +136,7 @@ CbpSolver::maxResidualSchedule (void)
SortedOrder::iterator it = sortedOrder_.begin(); SortedOrder::iterator it = sortedOrder_.begin();
SpLink* link = *it; SpLink* link = *it;
if (DL >= 2) { if (Constants::DEBUG >= 2) {
cout << "updating " << (*sortedOrder_.begin())->toString() << endl; cout << "updating " << (*sortedOrder_.begin())->toString() << endl;
} }
if (link->getResidual() < BpOptions::accuracy) { if (link->getResidual() < BpOptions::accuracy) {
@ -154,12 +148,12 @@ CbpSolver::maxResidualSchedule (void)
linkMap_.find (link)->second = sortedOrder_.insert (link); linkMap_.find (link)->second = sortedOrder_.insert (link);
// update the messages that depend on message source --> destin // update the messages that depend on message source --> destin
const FgFacSet& factorNeighbors = link->getVariable()->neighbors(); const FacNodes& factorNeighbors = link->getVariable()->neighbors();
for (unsigned i = 0; i < factorNeighbors.size(); i++) { for (unsigned i = 0; i < factorNeighbors.size(); i++) {
const SpLinkSet& links = ninf(factorNeighbors[i])->getLinks(); const SpLinkSet& links = ninf(factorNeighbors[i])->getLinks();
for (unsigned j = 0; j < links.size(); j++) { for (unsigned j = 0; j < links.size(); j++) {
if (links[j]->getVariable() != link->getVariable()) { if (links[j]->getVariable() != link->getVariable()) {
if (DL >= 2 && DL < 5) { if (Constants::DEBUG >= 2 && Constants::DEBUG < 5) {
cout << " calculating " << links[j]->toString() << endl; cout << " calculating " << links[j]->toString() << endl;
} }
calculateMessage (links[j]); calculateMessage (links[j]);
@ -174,7 +168,7 @@ CbpSolver::maxResidualSchedule (void)
const SpLinkSet& links = ninf(link->getFactor())->getLinks(); const SpLinkSet& links = ninf(link->getFactor())->getLinks();
for (unsigned i = 0; i < links.size(); i++) { for (unsigned i = 0; i < links.size(); i++) {
if (links[i]->getVariable() != link->getVariable()) { if (links[i]->getVariable() != link->getVariable()) {
if (DL >= 2 && DL < 5) { if (Constants::DEBUG >= 2 && Constants::DEBUG < 5) {
cout << " calculating " << links[i]->toString() << endl; cout << " calculating " << links[i]->toString() << endl;
} }
calculateMessage (links[i]); calculateMessage (links[i]);
@ -192,43 +186,43 @@ Params
CbpSolver::getVar2FactorMsg (const SpLink* link) const CbpSolver::getVar2FactorMsg (const SpLink* link) const
{ {
Params msg; Params msg;
const FgVarNode* src = link->getVariable(); const VarNode* src = link->getVariable();
const FgFacNode* dst = link->getFactor(); const FacNode* dst = link->getFactor();
const CbpSolverLink* l = static_cast<const CbpSolverLink*> (link); const CbpSolverLink* l = static_cast<const CbpSolverLink*> (link);
if (src->hasEvidence()) { if (src->hasEvidence()) {
msg.resize (src->nrStates(), Util::noEvidence()); msg.resize (src->range(), LogAware::noEvidence());
double value = link->getMessage()[src->getEvidence()]; double value = link->getMessage()[src->getEvidence()];
msg[src->getEvidence()] = Util::pow (value, l->getNumberOfEdges() - 1); msg[src->getEvidence()] = LogAware::pow (value, l->nrEdges() - 1);
} else { } else {
msg = link->getMessage(); msg = link->getMessage();
Util::pow (msg, l->getNumberOfEdges() - 1); LogAware::pow (msg, l->nrEdges() - 1);
} }
if (DL >= 5) { if (Constants::DEBUG >= 5) {
cout << " " << "init: " << Util::parametersToString (msg) << endl; cout << " " << "init: " << msg << endl;
} }
const SpLinkSet& links = ninf(src)->getLinks(); const SpLinkSet& links = ninf(src)->getLinks();
if (Globals::logDomain) { if (Globals::logDomain) {
for (unsigned i = 0; i < links.size(); i++) { for (unsigned i = 0; i < links.size(); i++) {
if (links[i]->getFactor() != dst) { if (links[i]->getFactor() != dst) {
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]); CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
Util::add (msg, l->getPoweredMessage()); Util::add (msg, l->poweredMessage());
} }
} }
} else { } else {
for (unsigned i = 0; i < links.size(); i++) { for (unsigned i = 0; i < links.size(); i++) {
if (links[i]->getFactor() != dst) { if (links[i]->getFactor() != dst) {
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]); CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
Util::multiply (msg, l->getPoweredMessage()); Util::multiply (msg, l->poweredMessage());
if (DL >= 5) { if (Constants::DEBUG >= 5) {
cout << " msg from " << l->getFactor()->getLabel() << ": " ; cout << " msg from " << l->getFactor()->getLabel() << ": " ;
cout << Util::parametersToString (l->getPoweredMessage()) << endl; cout << l->poweredMessage() << endl;
} }
} }
} }
} }
if (DL >= 5) { if (Constants::DEBUG >= 5) {
cout << " result = " << Util::parametersToString (msg) << endl; cout << " result = " << msg << endl;
} }
return msg; return msg;
} }
@ -241,12 +235,9 @@ CbpSolver::printLinkInformation (void) const
for (unsigned i = 0; i < links_.size(); i++) { for (unsigned i = 0; i < links_.size(); i++) {
CbpSolverLink* l = static_cast<CbpSolverLink*> (links_[i]); CbpSolverLink* l = static_cast<CbpSolverLink*> (links_[i]);
cout << l->toString() << ":" << endl; cout << l->toString() << ":" << endl;
cout << " curr msg = " ; cout << " curr msg = " << l->getMessage() << endl;
cout << Util::parametersToString (l->getMessage()) << endl; cout << " next msg = " << l->getNextMessage() << endl;
cout << " next msg = " ; cout << " powered = " << l->poweredMessage() << endl;
cout << Util::parametersToString (l->getNextMessage()) << endl;
cout << " powered = " ;
cout << Util::parametersToString (l->getPoweredMessage()) << endl;
cout << " residual = " << l->getResidual() << endl; cout << " residual = " << l->getResidual() << endl;
} }
} }

View File

@ -1,7 +1,7 @@
#ifndef HORUS_CBP_H #ifndef HORUS_CBP_H
#define HORUS_CBP_H #define HORUS_CBP_H
#include "FgBpSolver.h" #include "BpSolver.h"
#include "CFactorGraph.h" #include "CFactorGraph.h"
class Factor; class Factor;
@ -9,49 +9,51 @@ class Factor;
class CbpSolverLink : public SpLink class CbpSolverLink : public SpLink
{ {
public: public:
CbpSolverLink (FgFacNode* fn, FgVarNode* vn, unsigned c) : SpLink (fn, vn) CbpSolverLink (FacNode* fn, VarNode* vn, unsigned c)
{ : SpLink (fn, vn), nrEdges_(c),
edgeCount_ = c; pwdMsg_(vn->range(), LogAware::one()) { }
poweredMsg_.resize (vn->nrStates(), Util::one());
} unsigned nrEdges (void) const { return nrEdges_; }
const Params& poweredMessage (void) const { return pwdMsg_; }
void updateMessage (void) void updateMessage (void)
{ {
poweredMsg_ = *nextMsg_; pwdMsg_ = *nextMsg_;
swap (currMsg_, nextMsg_); swap (currMsg_, nextMsg_);
msgSended_ = true; msgSended_ = true;
Util::pow (poweredMsg_, edgeCount_); LogAware::pow (pwdMsg_, nrEdges_);
} }
unsigned getNumberOfEdges (void) const { return edgeCount_; }
const Params& getPoweredMessage (void) const { return poweredMsg_; }
private: private:
Params poweredMsg_; unsigned nrEdges_;
unsigned edgeCount_; Params pwdMsg_;
}; };
class CbpSolver : public FgBpSolver class CbpSolver : public BpSolver
{ {
public: public:
CbpSolver (FactorGraph& fg) : FgBpSolver (fg) { } CbpSolver (const FactorGraph& fg);
~CbpSolver (void);
Params getPosterioriOf (VarId); ~CbpSolver (void);
Params getJointDistributionOf (const VarIds&);
Params getPosterioriOf (VarId);
Params getJointDistributionOf (const VarIds&);
private: private:
void initializeSolver (void);
void createLinks (void);
void maxResidualSchedule (void); void createLinks (void);
Params getVar2FactorMsg (const SpLink*) const;
void printLinkInformation (void) const;
void maxResidualSchedule (void);
CFactorGraph* lfg_; Params getVar2FactorMsg (const SpLink*) const;
void printLinkInformation (void) const;
CFactorGraph* cfg_;
}; };
#endif // HORUS_CBP_H #endif // HORUS_CBP_H

View File

@ -1,10 +1,11 @@
#include <queue> #include <queue>
#include <fstream>
#include "ConstraintTree.h" #include "ConstraintTree.h"
#include "Util.h" #include "Util.h"
void void
CTNode::addChild (CTNode* child, bool updateLevels) CTNode::addChild (CTNode* child, bool updateLevels)
{ {
@ -42,6 +43,34 @@ CTNode::removeChild (CTNode* child)
void
CTNode::removeChilds (void)
{
childs_.clear();
}
void
CTNode::removeAndDeleteChild (CTNode* child)
{
removeChild (child);
CTNode::deleteSubtree (child);
}
void
CTNode::removeAndDeleteAllChilds (void)
{
for (unsigned i = 0; i < childs_.size(); i++) {
deleteSubtree (childs_[i]);
}
childs_.clear();
}
SymbolSet SymbolSet
CTNode::childSymbols (void) const CTNode::childSymbols (void) const
{ {
@ -66,6 +95,32 @@ CTNode::updateChildLevels (CTNode* n, unsigned level)
CTNode*
CTNode::copySubtree (const CTNode* n)
{
CTNode* newNode = new CTNode (*n);
const CTNodes& childs = n->childs();
for (unsigned i = 0; i < childs.size(); i++) {
newNode->addChild (copySubtree (childs[i]));
}
return newNode;
}
void
CTNode::deleteSubtree (CTNode* n)
{
assert (n);
const CTNodes& childs = n->childs();
for (unsigned i = 0; i < childs.size(); i++) {
deleteSubtree (childs[i]);
}
delete n;
}
ostream& operator<< (ostream &out, const CTNode& n) ostream& operator<< (ostream &out, const CTNode& n)
{ {
// out << "(" << n.level() << ") " ; // out << "(" << n.level() << ") " ;
@ -75,6 +130,17 @@ ostream& operator<< (ostream &out, const CTNode& n)
ConstraintTree::ConstraintTree (unsigned nrLvs)
{
for (unsigned i = 0; i < nrLvs; i++) {
logVars_.push_back (LogVar (i));
}
root_ = new CTNode (0, 0);
logVarSet_ = LogVarSet (logVars_);
}
ConstraintTree::ConstraintTree (const LogVars& logVars) ConstraintTree::ConstraintTree (const LogVars& logVars)
{ {
root_ = new CTNode (0, 0); root_ = new CTNode (0, 0);
@ -99,7 +165,7 @@ ConstraintTree::ConstraintTree (const LogVars& logVars,
ConstraintTree::ConstraintTree (const ConstraintTree& ct) ConstraintTree::ConstraintTree (const ConstraintTree& ct)
{ {
root_ = copySubtree (ct.root_); root_ = CTNode::copySubtree (ct.root_);
logVars_ = ct.logVars_; logVars_ = ct.logVars_;
logVarSet_ = ct.logVarSet_; logVarSet_ = ct.logVarSet_;
} }
@ -108,7 +174,7 @@ ConstraintTree::ConstraintTree (const ConstraintTree& ct)
ConstraintTree::~ConstraintTree (void) ConstraintTree::~ConstraintTree (void)
{ {
deleteSubtree (root_); CTNode::deleteSubtree (root_);
} }
@ -200,21 +266,28 @@ ConstraintTree::moveToBottom (const LogVars& lvs)
void void
ConstraintTree::join (ConstraintTree* ct, bool assertWhenNotFound) ConstraintTree::join (ConstraintTree* ct, bool assertWhenNotFound)
{ {
if (logVarSet_.empty()) {
delete root_;
root_ = CTNode::copySubtree (ct->root());
logVars_ = ct->logVars();
logVarSet_ = ct->logVarSet();
return;
}
LogVarSet intersect = logVarSet_ & ct->logVarSet_; LogVarSet intersect = logVarSet_ & ct->logVarSet_;
if (intersect.empty()) { if (intersect.empty()) {
const CTNodes& childs = ct->root()->childs(); const CTNodes& childs = ct->root()->childs();
CTNodes leafs = getNodesAtLevel (getLevel (logVars_.back())); CTNodes leafs = getNodesAtLevel (getLevel (logVars_.back()));
for (unsigned i = 0; i < leafs.size(); i++) { for (unsigned i = 0; i < leafs.size(); i++) {
for (unsigned j = 0; j < childs.size(); j++) { for (unsigned j = 0; j < childs.size(); j++) {
leafs[i]->addChild (copySubtree (childs[j])); leafs[i]->addChild (CTNode::copySubtree (childs[j]));
} }
} }
logVars_.insert (logVars_.end(), ct->logVars_.begin(), ct->logVars_.end()); Util::addToVector (logVars_, ct->logVars_);
logVarSet_ |= ct->logVarSet_; logVarSet_ |= ct->logVarSet_;
} else { } else {
moveToBottom (intersect.elements()); moveToBottom (intersect.elements());
ct->moveToTop (intersect.elements()); ct->moveToTop (intersect.elements());
@ -222,25 +295,27 @@ ConstraintTree::join (ConstraintTree* ct, bool assertWhenNotFound)
CTNodes nodes = getNodesAtLevel (level); CTNodes nodes = getNodesAtLevel (level);
Tuples tuples; Tuples tuples;
CTNodes continuationNodes; CTNodes continNodes;
getTuples (ct->root(), getTuples (ct->root(),
Tuples(), Tuples(),
intersect.size(), intersect.size(),
tuples, tuples,
continuationNodes); continNodes);
for (unsigned i = 0; i < tuples.size(); i++) { for (unsigned i = 0; i < tuples.size(); i++) {
bool tupleFounded = false; bool tupleFounded = false;
for (unsigned j = 0; j < nodes.size(); j++) { for (unsigned j = 0; j < nodes.size(); j++) {
tupleFounded |= join (nodes[j], tuples[i], 0, continuationNodes[i]); tupleFounded |= join (nodes[j], tuples[i], 0, continNodes[i]);
} }
if (assertWhenNotFound) { if (assertWhenNotFound) {
assert (tupleFounded); assert (tupleFounded);
} }
} }
LogVarSet newLvs = ct->logVarSet_ - intersect;
logVars_.insert (logVars_.end(), newLvs.begin(), newLvs.end()); LogVars newLvs (ct->logVars().begin() + intersect.size(),
logVarSet_ |= newLvs; ct->logVars().end());
Util::addToVector (logVars_, newLvs);
logVarSet_ |= LogVarSet (newLvs);
} }
} }
@ -280,6 +355,10 @@ ConstraintTree::rename (LogVar X_old, LogVar X_new)
void void
ConstraintTree::applySubstitution (const Substitution& theta) ConstraintTree::applySubstitution (const Substitution& theta)
{ {
LogVars discardedLvs = theta.getDiscardedLogVars();
for (unsigned i = 0; i < discardedLvs.size(); i++) {
remove(discardedLvs[i]);
}
for (unsigned i = 0; i < logVars_.size(); i++) { for (unsigned i = 0; i < logVars_.size(); i++) {
logVars_[i] = theta.newNameFor (logVars_[i]); logVars_[i] = theta.newNameFor (logVars_[i]);
} }
@ -308,11 +387,7 @@ ConstraintTree::remove (const LogVarSet& X)
unsigned level = getLevel (X.front()) - 1; unsigned level = getLevel (X.front()) - 1;
CTNodes nodes = getNodesAtLevel (level); CTNodes nodes = getNodesAtLevel (level);
for (unsigned i = 0; i < nodes.size(); i++) { for (unsigned i = 0; i < nodes.size(); i++) {
CTNodes childs = nodes[i]->childs(); nodes[i]->removeAndDeleteAllChilds();
for (unsigned j = 0; j < childs.size(); j++) {
nodes[i]->removeChild (childs[j]);
deleteSubtree (childs[j]);
}
} }
logVars_.resize (logVars_.size() - X.size()); logVars_.resize (logVars_.size() - X.size());
logVarSet_ -= X; logVarSet_ -= X;
@ -545,16 +620,16 @@ ConstraintTree::split (
for (unsigned i = 0; i < commNodes.size(); i++) { for (unsigned i = 0; i < commNodes.size(); i++) {
commCt->root()->addChild (commNodes[i]); commCt->root()->addChild (commNodes[i]);
} }
//cout << commCt->tupleSet() << " + " ; // cout << commCt->tupleSet() << " + " ;
//cout << exclCt->tupleSet() << " = " ; // cout << exclCt->tupleSet() << " = " ;
//cout << tupleSet() << endl << endl; // cout << tupleSet() << endl << endl;
// if (((commCt->tupleSet() | exclCt->tupleSet()) == tupleSet()) == false) { // if (((commCt->tupleSet() | exclCt->tupleSet()) == tupleSet()) == false) {
// exportToGraphViz ("_fail.dot", true); // exportToGraphViz ("_fail.dot", true);
// commCt->exportToGraphViz ("_fail_comm.dot", true); // commCt->exportToGraphViz ("_fail_comm.dot", true);
// exclCt->exportToGraphViz ("_fail_excl.dot", true); // exclCt->exportToGraphViz ("_fail_excl.dot", true);
// } // }
assert ((commCt->tupleSet() | exclCt->tupleSet()) == tupleSet()); // assert ((commCt->tupleSet() | exclCt->tupleSet()) == tupleSet());
assert ((exclCt->tupleSet (stopLevel) & ct->tupleSet (stopLevel)).empty()); // assert ((exclCt->tupleSet (stopLevel) & ct->tupleSet (stopLevel)).empty());
return {commCt, exclCt}; return {commCt, exclCt};
} }
@ -601,36 +676,32 @@ ConstraintTree::jointCountNormalize (
LogVar X_new1, LogVar X_new1,
LogVar X_new2) LogVar X_new2)
{ {
exportToGraphViz ("C.dot", true);
commCt->exportToGraphViz ("C_comm.dot", true);
exclCt->exportToGraphViz ("C_exlc.dot", true);
unsigned N = getConditionalCount (X); unsigned N = getConditionalCount (X);
cout << "My tuples: " << tupleSet() << endl; // cout << "My tuples: " << tupleSet() << endl;
cout << "CommCt tuples: " << commCt->tupleSet() << endl; // cout << "CommCt tuples: " << commCt->tupleSet() << endl;
cout << "ExclCt tuples: " << exclCt->tupleSet() << endl; // cout << "ExclCt tuples: " << exclCt->tupleSet() << endl;
cout << "Counted Lv: " << X << endl; // cout << "Counted Lv: " << X << endl;
cout << "Original N: " << N << endl; // cout << "X_new1: " << X_new1 << endl;
cout << endl; // cout << "X_new2: " << X_new2 << endl;
// cout << "Original N: " << N << endl;
// cout << endl;
ConstraintTrees normCts1 = commCt->countNormalize (X); ConstraintTrees normCts1 = commCt->countNormalize (X);
vector<unsigned> counts1 (normCts1.size()); vector<unsigned> counts1 (normCts1.size());
for (unsigned i = 0; i < normCts1.size(); i++) { for (unsigned i = 0; i < normCts1.size(); i++) {
counts1[i] = normCts1[i]->getConditionalCount (X); counts1[i] = normCts1[i]->getConditionalCount (X);
cout << "normCts1[" << i << "] #" << counts1[i] ; // cout << "normCts1[" << i << "] #" << counts1[i] ;
cout << " " << normCts1[i]->tupleSet() << endl; // cout << " " << normCts1[i]->tupleSet() << endl;
} }
ConstraintTrees normCts2 = exclCt->countNormalize (X); ConstraintTrees normCts2 = exclCt->countNormalize (X);
vector<unsigned> counts2 (normCts2.size()); vector<unsigned> counts2 (normCts2.size());
for (unsigned i = 0; i < normCts2.size(); i++) { for (unsigned i = 0; i < normCts2.size(); i++) {
counts2[i] = normCts2[i]->getConditionalCount (X); counts2[i] = normCts2[i]->getConditionalCount (X);
cout << "normCts2[" << i << "] #" << counts2[i] ; // cout << "normCts2[" << i << "] #" << counts2[i] ;
cout << " " << normCts2[i]->tupleSet() << endl; // cout << " " << normCts2[i]->tupleSet() << endl;
} }
cout << endl; // cout << endl;
cout << "1###### " << normCts1.size() << endl;
cout << "2###### " << normCts2.size() << endl;
ConstraintTree* excl1 = 0; ConstraintTree* excl1 = 0;
for (unsigned i = 0; i < normCts1.size(); i++) { for (unsigned i = 0; i < normCts1.size(); i++) {
@ -638,7 +709,7 @@ ConstraintTree::jointCountNormalize (
excl1 = normCts1[i]; excl1 = normCts1[i];
normCts1.erase (normCts1.begin() + i); normCts1.erase (normCts1.begin() + i);
counts1.erase (counts1.begin() + i); counts1.erase (counts1.begin() + i);
cout << ">joint-count(" << N << ",0)" << endl; // cout << "joint-count(" << N << ",0)" << endl;
break; break;
} }
} }
@ -649,22 +720,21 @@ ConstraintTree::jointCountNormalize (
excl2 = normCts2[i]; excl2 = normCts2[i];
normCts2.erase (normCts2.begin() + i); normCts2.erase (normCts2.begin() + i);
counts2.erase (counts2.begin() + i); counts2.erase (counts2.begin() + i);
cout << ">>joint-count(0," << N << ")" << endl; // cout << "joint-count(0," << N << ")" << endl;
break; break;
} }
} }
cout << "3###### " << normCts1.size() << endl;
cout << "4###### " << normCts2.size() << endl;
for (unsigned i = 0; i < normCts1.size(); i++) { for (unsigned i = 0; i < normCts1.size(); i++) {
unsigned j; unsigned j;
for (j = 0; counts1[i] + counts2[j] != N; j++) ; for (j = 0; counts1[i] + counts2[j] != N; j++) ;
cout << "joint-count(" << counts1[i] << "," << counts2[j] << ")" << endl; // cout << "joint-count(" << counts1[i] ;
// cout << "," << counts2[j] << ")" << endl;
const CTNodes& childs = normCts2[j]->root_->childs(); const CTNodes& childs = normCts2[j]->root_->childs();
for (unsigned k = 0; k < childs.size(); k++) { for (unsigned k = 0; k < childs.size(); k++) {
normCts1[i]->root_->addChild (childs[k]); normCts1[i]->root_->addChild (CTNode::copySubtree (childs[k]));
} }
delete normCts2[j];
} }
ConstraintTrees cts = normCts1; ConstraintTrees cts = normCts1;
@ -683,11 +753,6 @@ ConstraintTree::jointCountNormalize (
cts.push_back (excl2); cts.push_back (excl2);
} }
for (unsigned i = 0; i < cts.size(); i++) {
stringstream ss;
ss << "aaacts_" << i + 1 << ".dot" ;
cts[i]->exportToGraphViz (ss.str().c_str(), true);
}
return cts; return cts;
} }
@ -735,11 +800,11 @@ ConstraintTree::expand (LogVar X)
unsigned nrSymbols = getConditionalCount (X); unsigned nrSymbols = getConditionalCount (X);
for (unsigned i = 0; i < nodes.size(); i++) { for (unsigned i = 0; i < nodes.size(); i++) {
Symbols symbols; Symbols symbols;
CTNodes childs = nodes[i]->childs(); const CTNodes& childs = nodes[i]->childs();
for (unsigned j = 0; j < childs.size(); j++) { for (unsigned j = 0; j < childs.size(); j++) {
symbols.push_back (childs[j]->symbol()); symbols.push_back (childs[j]->symbol());
nodes[i]->removeChild (childs[j]);
} }
nodes[i]->removeAndDeleteAllChilds();
CTNode* prev = nodes[i]; CTNode* prev = nodes[i];
assert (symbols.size() == nrSymbols); assert (symbols.size() == nrSymbols);
for (unsigned j = 0; j < nrSymbols; j++) { for (unsigned j = 0; j < nrSymbols; j++) {
@ -768,7 +833,7 @@ ConstraintTree::ground (LogVar X)
ConstraintTrees cts; ConstraintTrees cts;
const CTNodes& nodes = root_->childs(); const CTNodes& nodes = root_->childs();
for (unsigned i = 0; i < nodes.size(); i++) { for (unsigned i = 0; i < nodes.size(); i++) {
CTNode* copy = copySubtree (nodes[i]); CTNode* copy = CTNode::copySubtree (nodes[i]);
copy->setSymbol (nodes[i]->symbol()); copy->setSymbol (nodes[i]->symbol());
ConstraintTree* newCt = new ConstraintTree (logVars_); ConstraintTree* newCt = new ConstraintTree (logVars_);
newCt->root()->addChild (copy); newCt->root()->addChild (copy);
@ -840,19 +905,19 @@ ConstraintTree::getNodesAtLevel (unsigned level) const
void void
ConstraintTree::swapLogVar (LogVar X) ConstraintTree::swapLogVar (LogVar X)
{ {
TupleSet before = tupleSet();
LogVars::iterator it = LogVars::iterator it =
std::find (logVars_.begin(),logVars_.end(), X); std::find (logVars_.begin(),logVars_.end(), X);
assert (it != logVars_.end()); assert (it != logVars_.end());
unsigned pos = std::distance (logVars_.begin(), it); unsigned pos = std::distance (logVars_.begin(), it);
const CTNodes& nodes = getNodesAtLevel (pos); const CTNodes& nodes = getNodesAtLevel (pos);
for (unsigned i = 0; i < nodes.size(); i++) { for (unsigned i = 0; i < nodes.size(); i++) {
const CTNodes childs = nodes[i]->childs(); CTNodes childsCopy = nodes[i]->childs();
for (unsigned j = 0; j < childs.size(); j++) { nodes[i]->removeChilds();
nodes[i]->removeChild (childs[j]); for (unsigned j = 0; j < childsCopy.size(); j++) {
const CTNodes grandsons = childs[j]->childs(); const CTNodes grandsons = childsCopy[j]->childs();
for (unsigned k = 0; k < grandsons.size(); k++) { for (unsigned k = 0; k < grandsons.size(); k++) {
CTNode* childCopy = new CTNode (*childs[j]); CTNode* childCopy = new CTNode (*childsCopy[j]);
const CTNodes greatGrandsons = grandsons[k]->childs(); const CTNodes greatGrandsons = grandsons[k]->childs();
for (unsigned t = 0; t < greatGrandsons.size(); t++) { for (unsigned t = 0; t < greatGrandsons.size(); t++) {
grandsons[k]->removeChild (greatGrandsons[t]); grandsons[k]->removeChild (greatGrandsons[t]);
@ -863,10 +928,9 @@ ConstraintTree::swapLogVar (LogVar X)
grandsons[k]->setLevel (grandsons[k]->level() - 1); grandsons[k]->setLevel (grandsons[k]->level() - 1);
nodes[i]->addChild (grandsons[k], false); nodes[i]->addChild (grandsons[k], false);
} }
delete childs[j]; delete childsCopy[j];
} }
} }
std::swap (logVars_[pos], logVars_[pos + 1]); std::swap (logVars_[pos], logVars_[pos + 1]);
} }
@ -884,7 +948,7 @@ ConstraintTree::join (
if (currIdx == tuple.size() - 1) { if (currIdx == tuple.size() - 1) {
const CTNodes& childs = appendNode->childs(); const CTNodes& childs = appendNode->childs();
for (unsigned i = 0; i < childs.size(); i++) { for (unsigned i = 0; i < childs.size(); i++) {
n->addChild (copySubtree (childs[i])); n->addChild (CTNode::copySubtree (childs[i]));
} }
return true; return true;
} }
@ -985,7 +1049,7 @@ ConstraintTree::countNormalize (
{ {
if (n->level() == stopLevel) { if (n->level() == stopLevel) {
return vector<pair<CTNode*, unsigned>>() = { return vector<pair<CTNode*, unsigned>>() = {
make_pair (copySubtree (n), countTuples (n)) make_pair (CTNode::copySubtree (n), countTuples (n))
}; };
} }
@ -1004,65 +1068,6 @@ ConstraintTree::countNormalize (
} }
/*
void
ConstraintTree::split (
CTNode* n1,
CTNode* n2,
CTNodes& nodes,
unsigned stopLevel)
{
CTNodes& childs1 = n1->childs();
CTNodes& childs2 = n2->childs();
// cout << string (n1->level() * 8, '-') << "Level = " << n1->level() + 1;
// cout << ", #I = " << childs1.size();
// cout << ", #J = " << childs2.size() << endl;
for (unsigned i = 0; i < childs1.size(); i++) {
for (unsigned j = 0; j < childs2.size(); j++) {
if (childs1[i]->symbol() != childs2[j]->symbol()) {
continue;
}
if (childs1[i]->level() == stopLevel) {
CTNode* newNode = copySubtree (childs1[i]);
newNode->setSymbol (childs1[i]->symbol());
nodes.push_back (newNode);
childs1[i]->setSymbol (Symbol::invalid());
break;
} else {
CTNodes lowerNodes;
split (childs1[i], childs2[j], lowerNodes, stopLevel);
if (lowerNodes.empty() == false) {
CTNode* me = new CTNode (childs1[i]->symbol(), childs1[i]->level());
for (unsigned k = 0; k < lowerNodes.size(); k++) {
me->addChild (lowerNodes[k]);
}
nodes.push_back (me);
}
if (childs1[i]->isLeaf()) {
break;
}
}
}
}
for (int i = 0; i < (int)childs1.size(); i++) {
// cout << string (n1->level() * 8, '-') << childs1[i];
if (childs1[i]->symbol() == Symbol::invalid()) {
// cout << " empty, removing..." ;
n1->removeChild (childs1[i]);
i --;
} else if (childs1[i]->isLeaf() &&
childs1[i]->level() != stopLevel) {
// cout << " leaf, removing..." ;
n1->removeChild (childs1[i]);
i --;
}
// cout << endl;
}
}
*/
void void
ConstraintTree::split ( ConstraintTree::split (
@ -1085,7 +1090,7 @@ ConstraintTree::split (
continue; continue;
} }
if (childs1[i]->level() == stopLevel) { if (childs1[i]->level() == stopLevel) {
CTNode* newNode = copySubtree (childs1[i]); CTNode* newNode = CTNode::copySubtree (childs1[i]);
nodes.push_back (newNode); nodes.push_back (newNode);
childs1[i]->setSymbol (Symbol::invalid()); childs1[i]->setSymbol (Symbol::invalid());
} else { } else {
@ -1103,11 +1108,11 @@ ConstraintTree::split (
for (int i = 0; i < (int)childs1.size(); i++) { for (int i = 0; i < (int)childs1.size(); i++) {
if (childs1[i]->symbol() == Symbol::invalid()) { if (childs1[i]->symbol() == Symbol::invalid()) {
n1->removeChild (childs1[i]); n1->removeAndDeleteChild (childs1[i]);
i --; i --;
} else if (childs1[i]->isLeaf() && } else if (childs1[i]->isLeaf() &&
childs1[i]->level() != stopLevel) { childs1[i]->level() != stopLevel) {
n1->removeChild (childs1[i]); n1->removeAndDeleteChild (childs1[i]);
i --; i --;
} }
} }
@ -1141,29 +1146,3 @@ ConstraintTree::overlap (
return false; return false;
} }
CTNode*
ConstraintTree::copySubtree (const CTNode* n)
{
CTNode* newNode = new CTNode (*n);
const CTNodes& childs = n->childs();
for (unsigned i = 0; i < childs.size(); i++) {
newNode->addChild (copySubtree (childs[i]));
}
return newNode;
}
void
ConstraintTree::deleteSubtree (CTNode* n)
{
assert (n);
const CTNodes& childs = n->childs();
for (unsigned i = 0; i < childs.size(); i++) {
deleteSubtree (childs[i]);
}
delete n;
}

View File

@ -21,7 +21,6 @@ typedef vector<ConstraintTree*> ConstraintTrees;
class CTNode class CTNode
{ {
public: public:
@ -47,29 +46,44 @@ class CTNode
bool isLeaf (void) const { return childs_.empty(); } bool isLeaf (void) const { return childs_.empty(); }
void addChild (CTNode*, bool = true); void addChild (CTNode*, bool = true);
void removeChild (CTNode*);
SymbolSet childSymbols (void) const; void removeChild (CTNode*);
void removeChilds (void);
void removeAndDeleteChild (CTNode*);
void removeAndDeleteAllChilds (void);
SymbolSet childSymbols (void) const;
static CTNode* copySubtree (const CTNode*);
static void deleteSubtree (CTNode*);
private: private:
void updateChildLevels (CTNode*, unsigned); void updateChildLevels (CTNode*, unsigned);
Symbol symbol_; Symbol symbol_;
CTNodes childs_; CTNodes childs_;
unsigned level_; unsigned level_;
}; };
ostream& operator<< (ostream &out, const CTNode&); ostream& operator<< (ostream &out, const CTNode&);
class ConstraintTree class ConstraintTree
{ {
public: public:
ConstraintTree (unsigned);
ConstraintTree (const LogVars&); ConstraintTree (const LogVars&);
ConstraintTree (const LogVars&, const Tuples&); ConstraintTree (const LogVars&, const Tuples&);
ConstraintTree (const ConstraintTree&); ConstraintTree (const ConstraintTree&);
~ConstraintTree (void); ~ConstraintTree (void);
CTNode* root (void) const { return root_; } CTNode* root (void) const { return root_; }
@ -94,94 +108,95 @@ class ConstraintTree
assert (LogVarSet (logVars_) == logVarSet_); assert (LogVarSet (logVars_) == logVarSet_);
} }
void addTuple (const Tuple&); void addTuple (const Tuple&);
bool containsTuple (const Tuple&);
void moveToTop (const LogVars&); bool containsTuple (const Tuple&);
void moveToBottom (const LogVars&);
void join (ConstraintTree*, bool = false); void moveToTop (const LogVars&);
unsigned getLevel (LogVar) const;
void rename (LogVar, LogVar); void moveToBottom (const LogVars&);
void applySubstitution (const Substitution&);
void project (const LogVarSet&); void join (ConstraintTree*, bool = false);
void remove (const LogVarSet&);
bool isSingleton (LogVar); unsigned getLevel (LogVar) const;
LogVarSet singletons (void);
TupleSet tupleSet (unsigned = 0) const; void rename (LogVar, LogVar);
TupleSet tupleSet (const LogVars&);
unsigned size (void) const; void applySubstitution (const Substitution&);
unsigned nrSymbols (LogVar);
void exportToGraphViz (const char*, bool = false) const; void project (const LogVarSet&);
bool isCountNormalized (const LogVarSet&);
unsigned getConditionalCount (const LogVarSet&); void remove (const LogVarSet&);
TinySet<unsigned> getConditionalCounts (const LogVarSet&);
bool isCarteesianProduct (const LogVarSet&) const; bool isSingleton (LogVar);
LogVarSet singletons (void);
TupleSet tupleSet (unsigned = 0) const;
TupleSet tupleSet (const LogVars&);
unsigned size (void) const;
unsigned nrSymbols (LogVar);
void exportToGraphViz (const char*, bool = false) const;
bool isCountNormalized (const LogVarSet&);
unsigned getConditionalCount (const LogVarSet&);
TinySet<unsigned> getConditionalCounts (const LogVarSet&);
bool isCarteesianProduct (const LogVarSet&) const;
std::pair<ConstraintTree*, ConstraintTree*> split ( std::pair<ConstraintTree*, ConstraintTree*> split (
const Tuple&, const Tuple&, unsigned);
unsigned);
std::pair<ConstraintTree*, ConstraintTree*> split ( std::pair<ConstraintTree*, ConstraintTree*> split (
const ConstraintTree*, const ConstraintTree*, unsigned) const;
unsigned) const;
ConstraintTrees countNormalize (const LogVarSet&); ConstraintTrees countNormalize (const LogVarSet&);
ConstraintTrees jointCountNormalize ( ConstraintTrees jointCountNormalize (
ConstraintTree*, ConstraintTree*, ConstraintTree*, LogVar, LogVar, LogVar);
ConstraintTree*,
LogVar,
LogVar,
LogVar);
static bool identical ( static bool identical (
const ConstraintTree*, const ConstraintTree*, const ConstraintTree*, unsigned);
const ConstraintTree*,
unsigned);
static bool overlap ( static bool overlap (
const ConstraintTree*, const ConstraintTree*, const ConstraintTree*, unsigned);
const ConstraintTree*,
unsigned);
LogVars expand (LogVar); LogVars expand (LogVar);
ConstraintTrees ground (LogVar); ConstraintTrees ground (LogVar);
private: private:
unsigned countTuples (const CTNode*) const; unsigned countTuples (const CTNode*) const;
CTNodes getNodesBelow (CTNode*) const;
CTNodes getNodesAtLevel (unsigned) const;
void swapLogVar (LogVar);
bool join (CTNode*, const Tuple&, unsigned, CTNode*);
bool indenticalSubtrees ( CTNodes getNodesBelow (CTNode*) const;
const CTNode*,
const CTNode*,
bool) const;
void getTuples ( CTNodes getNodesAtLevel (unsigned) const;
CTNode*,
Tuples, void swapLogVar (LogVar);
unsigned,
Tuples&, bool join (CTNode*, const Tuple&, unsigned, CTNode*);
CTNodes&) const;
bool indenticalSubtrees (
const CTNode*, const CTNode*, bool) const;
void getTuples (CTNode*, Tuples, unsigned, Tuples&, CTNodes&) const;
vector<std::pair<CTNode*, unsigned>> countNormalize ( vector<std::pair<CTNode*, unsigned>> countNormalize (
const CTNode*, const CTNode*, unsigned);
unsigned);
static void split ( static void split (
CTNode*, CTNode*, CTNode*, CTNodes&, unsigned);
CTNode*,
CTNodes&,
unsigned);
static bool overlap (const CTNode*, const CTNode*, unsigned); static bool overlap (const CTNode*, const CTNode*, unsigned);
static CTNode* copySubtree (const CTNode*);
static void deleteSubtree (CTNode*);
CTNode* root_; CTNode* root_;
LogVars logVars_; LogVars logVars_;
LogVarSet logVarSet_; LogVarSet logVarSet_;
}; };

View File

@ -1,45 +0,0 @@
#ifndef HORUS_DISTRIBUTION_H
#define HORUS_DISTRIBUTION_H
#include <vector>
#include "Horus.h"
//TODO die die die die die
using namespace std;
struct Distribution
{
public:
Distribution (int id)
{
this->id = id;
}
Distribution (const Params& params, int id = -1)
{
this->id = id;
this->params = params;
}
void updateParameters (const Params& params)
{
this->params = params;
}
bool shared (void)
{
return id != -1;
}
int id;
Params params;
private:
DISALLOW_COPY_AND_ASSIGN (Distribution);
};
#endif // HORUS_DISTRIBUTION_H

View File

@ -1,53 +1,39 @@
#include <limits> #include <limits>
#include "ElimGraph.h" #include <fstream>
#include "BayesNet.h"
#include "ElimGraph.h"
ElimHeuristic ElimGraph::elimHeuristic_ = MIN_NEIGHBORS; ElimHeuristic ElimGraph::elimHeuristic_ = MIN_NEIGHBORS;
ElimGraph::ElimGraph (const BayesNet& bayesNet) ElimGraph::ElimGraph (const vector<Factor*>& factors)
{ {
const BnNodeSet& bnNodes = bayesNet.getBayesNodes(); for (unsigned i = 0; i < factors.size(); i++) {
for (unsigned i = 0; i < bnNodes.size(); i++) { const VarIds& vids = factors[i]->arguments();
if (bnNodes[i]->hasEvidence() == false) { for (unsigned j = 0; j < vids.size() - 1; j++) {
addNode (new EgNode (bnNodes[i])); EgNode* n1 = getEgNode (vids[j]);
} if (n1 == 0) {
} n1 = new EgNode (vids[j], factors[i]->range (j));
addNode (n1);
for (unsigned i = 0; i < bnNodes.size(); i++) { }
if (bnNodes[i]->hasEvidence() == false) { for (unsigned k = j + 1; k < vids.size(); k++) {
EgNode* n = getEgNode (bnNodes[i]->varId()); EgNode* n2 = getEgNode (vids[k]);
const BnNodeSet& childs = bnNodes[i]->getChilds(); if (n2 == 0) {
for (unsigned j = 0; j < childs.size(); j++) { n2 = new EgNode (vids[k], factors[i]->range (k));
if (childs[j]->hasEvidence() == false) { addNode (n2);
addEdge (n, getEgNode (childs[j]->varId()));
} }
if (neighbors (n1, n2) == false) {
addEdge (n1, n2);
}
}
}
if (vids.size() == 1) {
if (getEgNode (vids[0]) == 0) {
addNode (new EgNode (vids[0], factors[i]->range (0)));
} }
} }
} }
for (unsigned i = 0; i < bnNodes.size(); i++) {
vector<EgNode*> neighs;
const vector<BayesNode*>& parents = bnNodes[i]->getParents();
for (unsigned i = 0; i < parents.size(); i++) {
if (parents[i]->hasEvidence() == false) {
neighs.push_back (getEgNode (parents[i]->varId()));
}
}
if (neighs.size() > 0) {
for (unsigned i = 0; i < neighs.size() - 1; i++) {
for (unsigned j = i+1; j < neighs.size(); j++) {
if (!neighbors (neighs[i], neighs[j])) {
addEdge (neighs[i], neighs[j]);
}
}
}
}
}
setIndexes();
} }
@ -61,40 +47,16 @@ ElimGraph::~ElimGraph (void)
void
ElimGraph::addNode (EgNode* n)
{
nodes_.push_back (n);
varMap_.insert (make_pair (n->varId(), n));
}
EgNode*
ElimGraph::getEgNode (VarId vid) const
{
unordered_map<VarId,EgNode*>::const_iterator it =varMap_.find (vid);
if (it ==varMap_.end()) {
return 0;
} else {
return it->second;
}
}
VarIds VarIds
ElimGraph::getEliminatingOrder (const VarIds& exclude) ElimGraph::getEliminatingOrder (const VarIds& exclude)
{ {
VarIds elimOrder; VarIds elimOrder;
marked_.resize (nodes_.size(), false); marked_.resize (nodes_.size(), false);
for (unsigned i = 0; i < exclude.size(); i++) { for (unsigned i = 0; i < exclude.size(); i++) {
assert (getEgNode (exclude[i]));
EgNode* node = getEgNode (exclude[i]); EgNode* node = getEgNode (exclude[i]);
assert (node);
marked_[*node] = true; marked_[*node] = true;
} }
unsigned nVarsToEliminate = nodes_.size() - exclude.size(); unsigned nVarsToEliminate = nodes_.size() - exclude.size();
for (unsigned i = 0; i < nVarsToEliminate; i++) { for (unsigned i = 0; i < nVarsToEliminate; i++) {
EgNode* node = getLowestCostNode(); EgNode* node = getLowestCostNode();
@ -107,6 +69,100 @@ ElimGraph::getEliminatingOrder (const VarIds& exclude)
void
ElimGraph::print (void) const
{
for (unsigned i = 0; i < nodes_.size(); i++) {
cout << "node " << nodes_[i]->label() << " neighs:" ;
vector<EgNode*> neighs = nodes_[i]->neighbors();
for (unsigned j = 0; j < neighs.size(); j++) {
cout << " " << neighs[j]->label();
}
cout << endl;
}
}
void
ElimGraph::exportToGraphViz (
const char* fileName,
bool showNeighborless,
const VarIds& highlightVarIds) const
{
ofstream out (fileName);
if (!out.is_open()) {
cerr << "error: cannot open file to write at " ;
cerr << "Markov::exportToDotFile()" << endl;
abort();
}
out << "strict graph {" << endl;
for (unsigned i = 0; i < nodes_.size(); i++) {
if (showNeighborless || nodes_[i]->neighbors().size() != 0) {
out << '"' << nodes_[i]->label() << '"' << endl;
}
}
for (unsigned i = 0; i < highlightVarIds.size(); i++) {
EgNode* node =getEgNode (highlightVarIds[i]);
if (node) {
out << '"' << node->label() << '"' ;
out << " [shape=box3d]" << endl;
} else {
cout << "error: invalid variable id: " << highlightVarIds[i] << endl;
abort();
}
}
for (unsigned i = 0; i < nodes_.size(); i++) {
vector<EgNode*> neighs = nodes_[i]->neighbors();
for (unsigned j = 0; j < neighs.size(); j++) {
out << '"' << nodes_[i]->label() << '"' << " -- " ;
out << '"' << neighs[j]->label() << '"' << endl;
}
}
out << "}" << endl;
out.close();
}
VarIds
ElimGraph::getEliminationOrder (
const vector<Factor*> factors,
VarIds excludedVids)
{
ElimGraph graph (factors);
// graph.print();
// graph.exportToGraphViz ("_egg.dot");
return graph.getEliminatingOrder (excludedVids);
}
void
ElimGraph::addNode (EgNode* n)
{
nodes_.push_back (n);
n->setIndex (nodes_.size() - 1);
varMap_.insert (make_pair (n->varId(), n));
}
EgNode*
ElimGraph::getEgNode (VarId vid) const
{
unordered_map<VarId, EgNode*>::const_iterator it;
it = varMap_.find (vid);
return (it != varMap_.end()) ? it->second : 0;
}
EgNode* EgNode*
ElimGraph::getLowestCostNode (void) const ElimGraph::getLowestCostNode (void) const
{ {
@ -164,7 +220,7 @@ ElimGraph::getWeightCost (const EgNode* n) const
const vector<EgNode*>& neighs = n->neighbors(); const vector<EgNode*>& neighs = n->neighbors();
for (unsigned i = 0; i < neighs.size(); i++) { for (unsigned i = 0; i < neighs.size(); i++) {
if (marked_[*neighs[i]] == false) { if (marked_[*neighs[i]] == false) {
cost *= neighs[i]->nrStates(); cost *= neighs[i]->range();
} }
} }
return cost; return cost;
@ -204,7 +260,7 @@ ElimGraph::getWeightedFillCost (const EgNode* n) const
for (unsigned j = i+1; j < neighs.size(); j++) { for (unsigned j = i+1; j < neighs.size(); j++) {
if (marked_[*neighs[j]] == true) continue; if (marked_[*neighs[j]] == true) continue;
if (!neighbors (neighs[i], neighs[j])) { if (!neighbors (neighs[i], neighs[j])) {
cost += neighs[i]->nrStates() * neighs[j]->nrStates(); cost += neighs[i]->range() * neighs[j]->range();
} }
} }
} }
@ -245,78 +301,3 @@ ElimGraph::neighbors (const EgNode* n1, const EgNode* n2) const
return false; return false;
} }
void
ElimGraph::setIndexes (void)
{
for (unsigned i = 0; i < nodes_.size(); i++) {
nodes_[i]->setIndex (i);
}
}
void
ElimGraph::printGraphicalModel (void) const
{
for (unsigned i = 0; i < nodes_.size(); i++) {
cout << "node " << nodes_[i]->label() << " neighs:" ;
vector<EgNode*> neighs = nodes_[i]->neighbors();
for (unsigned j = 0; j < neighs.size(); j++) {
cout << " " << neighs[j]->label();
}
cout << endl;
}
}
void
ElimGraph::exportToGraphViz (const char* fileName,
bool showNeighborless,
const VarIds& highlightVarIds) const
{
ofstream out (fileName);
if (!out.is_open()) {
cerr << "error: cannot open file to write at " ;
cerr << "Markov::exportToDotFile()" << endl;
abort();
}
out << "strict graph {" << endl;
for (unsigned i = 0; i < nodes_.size(); i++) {
if (showNeighborless || nodes_[i]->neighbors().size() != 0) {
out << '"' << nodes_[i]->label() << '"' ;
if (nodes_[i]->hasEvidence()) {
out << " [style=filled, fillcolor=yellow]" << endl;
} else {
out << endl;
}
}
}
for (unsigned i = 0; i < highlightVarIds.size(); i++) {
EgNode* node =getEgNode (highlightVarIds[i]);
if (node) {
out << '"' << node->label() << '"' ;
out << " [shape=box3d]" << endl;
} else {
cout << "error: invalid variable id: " << highlightVarIds[i] << endl;
abort();
}
}
for (unsigned i = 0; i < nodes_.size(); i++) {
vector<EgNode*> neighs = nodes_[i]->neighbors();
for (unsigned j = 0; j < neighs.size(); j++) {
out << '"' << nodes_[i]->label() << '"' << " -- " ;
out << '"' << neighs[j]->label() << '"' << endl;
}
}
out << "}" << endl;
out.close();
}

View File

@ -17,15 +17,15 @@ enum ElimHeuristic
}; };
class EgNode : public VarNode { class EgNode : public Var
{
public: public:
EgNode (VarNode* var) : VarNode (var) { } EgNode (VarId vid, unsigned range) : Var (vid, range) { }
void addNeighbor (EgNode* n)
{ void addNeighbor (EgNode* n) { neighs_.push_back (n); }
neighs_.push_back (n);
}
const vector<EgNode*>& neighbors (void) const { return neighs_; } const vector<EgNode*>& neighbors (void) const { return neighs_; }
private: private:
vector<EgNode*> neighs_; vector<EgNode*> neighs_;
}; };
@ -34,22 +34,18 @@ class EgNode : public VarNode {
class ElimGraph class ElimGraph
{ {
public: public:
ElimGraph (const BayesNet&); ElimGraph (const vector<Factor*>&); // TODO
~ElimGraph (void); ~ElimGraph (void);
void addEdge (EgNode* n1, EgNode* n2) VarIds getEliminatingOrder (const VarIds&);
{
assert (n1 != n2); void print (void) const;
n1->addNeighbor (n2);
n2->addNeighbor (n1); void exportToGraphViz (const char*, bool = true,
} const VarIds& = VarIds()) const;
void addNode (EgNode*);
EgNode* getEgNode (VarId) const; static VarIds getEliminationOrder (const vector<Factor*>, VarIds);
VarIds getEliminatingOrder (const VarIds&);
void printGraphicalModel (void) const;
void exportToGraphViz (const char*, bool = true,
const VarIds& = VarIds()) const;
void setIndexes();
static void setEliminationHeuristic (ElimHeuristic h) static void setEliminationHeuristic (ElimHeuristic h)
{ {
@ -57,18 +53,34 @@ class ElimGraph
} }
private: private:
EgNode* getLowestCostNode (void) const;
unsigned getNeighborsCost (const EgNode*) const;
unsigned getWeightCost (const EgNode*) const;
unsigned getFillCost (const EgNode*) const;
unsigned getWeightedFillCost (const EgNode*) const;
void connectAllNeighbors (const EgNode*);
bool neighbors (const EgNode*, const EgNode*) const;
void addEdge (EgNode* n1, EgNode* n2)
{
assert (n1 != n2);
n1->addNeighbor (n2);
n2->addNeighbor (n1);
}
void addNode (EgNode*);
EgNode* getEgNode (VarId) const;
EgNode* getLowestCostNode (void) const;
unsigned getNeighborsCost (const EgNode*) const;
unsigned getWeightCost (const EgNode*) const;
unsigned getFillCost (const EgNode*) const;
unsigned getWeightedFillCost (const EgNode*) const;
void connectAllNeighbors (const EgNode*);
bool neighbors (const EgNode*, const EgNode*) const;
vector<EgNode*> nodes_; vector<EgNode*> nodes_;
vector<bool> marked_; vector<bool> marked_;
unordered_map<VarId,EgNode*> varMap_; unordered_map<VarId, EgNode*> varMap_;
static ElimHeuristic elimHeuristic_; static ElimHeuristic elimHeuristic_;
}; };

View File

@ -8,7 +8,7 @@
#include "Factor.h" #include "Factor.h"
#include "Indexer.h" #include "Indexer.h"
#include "Util.h"
Factor::Factor (const Factor& g) Factor::Factor (const Factor& g)
@ -18,206 +18,33 @@ Factor::Factor (const Factor& g)
Factor::Factor (VarId vid, unsigned nStates) Factor::Factor (
const VarIds& vids,
const Ranges& ranges,
const Params& params,
unsigned distId)
{ {
varids_.push_back (vid); args_ = vids;
ranges_.push_back (nStates);
dist_ = new Distribution (Params (nStates, 1.0));
}
Factor::Factor (const VarNodes& vars)
{
int nParams = 1;
for (unsigned i = 0; i < vars.size(); i++) {
varids_.push_back (vars[i]->varId());
ranges_.push_back (vars[i]->nrStates());
nParams *= vars[i]->nrStates();
}
// create a uniform distribution
double val = 1.0 / nParams;
dist_ = new Distribution (Params (nParams, val));
}
Factor::Factor (VarId vid, unsigned nStates, const Params& params)
{
varids_.push_back (vid);
ranges_.push_back (nStates);
dist_ = new Distribution (params);
}
Factor::Factor (VarNodes& vars, Distribution* dist)
{
for (unsigned i = 0; i < vars.size(); i++) {
varids_.push_back (vars[i]->varId());
ranges_.push_back (vars[i]->nrStates());
}
dist_ = dist;
}
Factor::Factor (const VarNodes& vars, const Params& params)
{
for (unsigned i = 0; i < vars.size(); i++) {
varids_.push_back (vars[i]->varId());
ranges_.push_back (vars[i]->nrStates());
}
dist_ = new Distribution (params);
}
Factor::Factor (const VarIds& vids,
const Ranges& ranges,
const Params& params)
{
varids_ = vids;
ranges_ = ranges; ranges_ = ranges;
dist_ = new Distribution (params); params_ = params;
distId_ = distId;
assert (params_.size() == Util::expectedSize (ranges_));
} }
Factor::~Factor (void) Factor::Factor (
const Vars& vars,
const Params& params,
unsigned distId)
{ {
if (dist_->shared() == false) { for (unsigned i = 0; i < vars.size(); i++) {
delete dist_; args_.push_back (vars[i]->varId());
} ranges_.push_back (vars[i]->range());
}
void
Factor::setParameters (const Params& params)
{
assert (dist_->params.size() == params.size());
dist_->params = params;
}
void
Factor::copyFromFactor (const Factor& g)
{
varids_ = g.getVarIds();
ranges_ = g.getRanges();
dist_ = new Distribution (g.getParameters());
}
void
Factor::multiply (const Factor& g)
{
if (varids_.size() == 0) {
copyFromFactor (g);
return;
}
const VarIds& g_varids = g.getVarIds();
const Ranges& g_ranges = g.getRanges();
const Params& g_params = g.getParameters();
if (varids_ == g_varids) {
// optimization: if the factors contain the same set of variables,
// we can do a 1 to 1 operation on the parameters
if (Globals::logDomain) {
Util::add (dist_->params, g_params);
} else {
Util::multiply (dist_->params, g_params);
}
} else {
bool sharedVars = false;
vector<unsigned> gvarpos;
for (unsigned i = 0; i < g_varids.size(); i++) {
int idx = indexOf (g_varids[i]);
if (idx == -1) {
insertVariable (g_varids[i], g_ranges[i]);
gvarpos.push_back (varids_.size() - 1);
} else {
sharedVars = true;
gvarpos.push_back (idx);
}
}
if (sharedVars == false) {
// optimization: if the original factors doesn't have common variables,
// we don't need to marry the states of the common variables
unsigned count = 0;
for (unsigned i = 0; i < dist_->params.size(); i++) {
if (Globals::logDomain) {
dist_->params[i] += g_params[count];
} else {
dist_->params[i] *= g_params[count];
}
count ++;
if (count >= g_params.size()) {
count = 0;
}
}
} else {
StatesIndexer indexer (ranges_, false);
while (indexer.valid()) {
unsigned g_li = 0;
unsigned prod = 1;
for (int j = gvarpos.size() - 1; j >= 0; j--) {
g_li += indexer[gvarpos[j]] * prod;
prod *= g_ranges[j];
}
if (Globals::logDomain) {
dist_->params[indexer] += g_params[g_li];
} else {
dist_->params[indexer] *= g_params[g_li];
}
++ indexer;
}
}
}
}
void
Factor::insertVariable (VarId varId, unsigned nrStates)
{
assert (indexOf (varId) == -1);
Params oldParams = dist_->params;
dist_->params.clear();
dist_->params.reserve (oldParams.size() * nrStates);
for (unsigned i = 0; i < oldParams.size(); i++) {
for (unsigned reps = 0; reps < nrStates; reps++) {
dist_->params.push_back (oldParams[i]);
}
}
varids_.push_back (varId);
ranges_.push_back (nrStates);
}
void
Factor::insertVariables (const VarIds& varIds, const Ranges& ranges)
{
Params oldParams = dist_->params;
unsigned nrStates = 1;
for (unsigned i = 0; i < varIds.size(); i++) {
assert (indexOf (varIds[i]) == -1);
varids_.push_back (varIds[i]);
ranges_.push_back (ranges[i]);
nrStates *= ranges[i];
}
dist_->params.clear();
dist_->params.reserve (oldParams.size() * nrStates);
for (unsigned i = 0; i < oldParams.size(); i++) {
for (unsigned reps = 0; reps < nrStates; reps++) {
dist_->params.push_back (oldParams[i]);
}
} }
params_ = params;
distId_ = distId;
assert (params_.size() == Util::expectedSize (ranges_));
} }
@ -226,10 +53,10 @@ void
Factor::sumOutAllExcept (VarId vid) Factor::sumOutAllExcept (VarId vid)
{ {
assert (indexOf (vid) != -1); assert (indexOf (vid) != -1);
while (varids_.back() != vid) { while (args_.back() != vid) {
sumOutLastVariable(); sumOutLastVariable();
} }
while (varids_.front() != vid) { while (args_.front() != vid) {
sumOutFirstVariable(); sumOutFirstVariable();
} }
} }
@ -239,9 +66,10 @@ Factor::sumOutAllExcept (VarId vid)
void void
Factor::sumOutAllExcept (const VarIds& vids) Factor::sumOutAllExcept (const VarIds& vids)
{ {
for (unsigned i = 0; i < varids_.size(); i++) { for (int i = 0; i < (int)args_.size(); i++) {
if (std::find (vids.begin(), vids.end(), varids_[i]) == vids.end()) { if (Util::contains (vids, args_[i]) == false) {
sumOut (varids_[i]); sumOut (args_[i]);
i --;
} }
} }
} }
@ -254,11 +82,11 @@ Factor::sumOut (VarId vid)
int idx = indexOf (vid); int idx = indexOf (vid);
assert (idx != -1); assert (idx != -1);
if (vid == varids_.back()) { if (vid == args_.back()) {
sumOutLastVariable(); // optimization sumOutLastVariable(); // optimization
return; return;
} }
if (vid == varids_.front()) { if (vid == args_.front()) {
sumOutFirstVariable(); // optimization sumOutFirstVariable(); // optimization
return; return;
} }
@ -271,7 +99,7 @@ Factor::sumOut (VarId vid)
// on the left of `var', with the states of the remaining vars fixed // on the left of `var', with the states of the remaining vars fixed
unsigned leftVarOffset = 1; unsigned leftVarOffset = 1;
for (int i = varids_.size() - 1; i > idx; i--) { for (int i = args_.size() - 1; i > idx; i--) {
varOffset *= ranges_[i]; varOffset *= ranges_[i];
leftVarOffset *= ranges_[i]; leftVarOffset *= ranges_[i];
} }
@ -280,25 +108,24 @@ Factor::sumOut (VarId vid)
unsigned offset = 0; unsigned offset = 0;
unsigned count1 = 0; unsigned count1 = 0;
unsigned count2 = 0; unsigned count2 = 0;
unsigned newpsSize = dist_->params.size() / ranges_[idx]; unsigned newpsSize = params_.size() / ranges_[idx];
Params newps; Params newps;
newps.reserve (newpsSize); newps.reserve (newpsSize);
Params& params = dist_->params;
while (newps.size() < newpsSize) { while (newps.size() < newpsSize) {
double sum = Util::addIdenty(); double sum = LogAware::addIdenty();
for (unsigned i = 0; i < ranges_[idx]; i++) { for (unsigned i = 0; i < ranges_[idx]; i++) {
if (Globals::logDomain) { if (Globals::logDomain) {
Util::logSum (sum, params[offset]); sum = Util::logSum (sum, params_[offset]);
} else { } else {
sum += params[offset]; sum += params_[offset];
} }
offset += varOffset; offset += varOffset;
} }
newps.push_back (sum); newps.push_back (sum);
count1 ++; count1 ++;
if (idx == (int)varids_.size() - 1) { if (idx == (int)args_.size() - 1) {
offset = count1 * ranges_[idx]; offset = count1 * ranges_[idx];
} else { } else {
if (((offset - varOffset + 1) % leftVarOffset) == 0) { if (((offset - varOffset + 1) % leftVarOffset) == 0) {
@ -308,9 +135,9 @@ Factor::sumOut (VarId vid)
offset = (leftVarOffset * count2) + count1; offset = (leftVarOffset * count2) + count1;
} }
} }
varids_.erase (varids_.begin() + idx); args_.erase (args_.begin() + idx);
ranges_.erase (ranges_.begin() + idx); ranges_.erase (ranges_.begin() + idx);
dist_->params = newps; params_ = newps;
} }
@ -318,20 +145,19 @@ Factor::sumOut (VarId vid)
void void
Factor::sumOutFirstVariable (void) Factor::sumOutFirstVariable (void)
{ {
Params& params = dist_->params; unsigned range = ranges_.front();
unsigned nStates = ranges_.front(); unsigned sep = params_.size() / range;
unsigned sep = params.size() / nStates;
if (Globals::logDomain) { if (Globals::logDomain) {
for (unsigned i = sep; i < params.size(); i++) { for (unsigned i = sep; i < params_.size(); i++) {
Util::logSum (params[i % sep], params[i]); params_[i % sep] = Util::logSum (params_[i % sep], params_[i]);
} }
} else { } else {
for (unsigned i = sep; i < params.size(); i++) { for (unsigned i = sep; i < params_.size(); i++) {
params[i % sep] += params[i]; params_[i % sep] += params_[i];
} }
} }
params.resize (sep); params_.resize (sep);
varids_.erase (varids_.begin()); args_.erase (args_.begin());
ranges_.erase (ranges_.begin()); ranges_.erase (ranges_.begin());
} }
@ -340,143 +166,55 @@ Factor::sumOutFirstVariable (void)
void void
Factor::sumOutLastVariable (void) Factor::sumOutLastVariable (void)
{ {
Params& params = dist_->params; unsigned range = ranges_.back();
unsigned nStates = ranges_.back();
unsigned idx1 = 0; unsigned idx1 = 0;
unsigned idx2 = 0; unsigned idx2 = 0;
if (Globals::logDomain) { if (Globals::logDomain) {
while (idx1 < params.size()) { while (idx1 < params_.size()) {
params[idx2] = params[idx1]; params_[idx2] = params_[idx1];
idx1 ++; idx1 ++;
for (unsigned j = 1; j < nStates; j++) { for (unsigned j = 1; j < range; j++) {
Util::logSum (params[idx2], params[idx1]); params_[idx2] = Util::logSum (params_[idx2], params_[idx1]);
idx1 ++; idx1 ++;
} }
idx2 ++; idx2 ++;
} }
} else { } else {
while (idx1 < params.size()) { while (idx1 < params_.size()) {
params[idx2] = params[idx1]; params_[idx2] = params_[idx1];
idx1 ++; idx1 ++;
for (unsigned j = 1; j < nStates; j++) { for (unsigned j = 1; j < range; j++) {
params[idx2] += params[idx1]; params_[idx2] += params_[idx1];
idx1 ++; idx1 ++;
} }
idx2 ++; idx2 ++;
} }
} }
params.resize (idx2); params_.resize (idx2);
varids_.pop_back(); args_.pop_back();
ranges_.pop_back(); ranges_.pop_back();
} }
void void
Factor::orderVariables (void) Factor::multiply (Factor& g)
{ {
VarIds sortedVarIds = varids_; if (args_.size() == 0) {
sort (sortedVarIds.begin(), sortedVarIds.end()); copyFromFactor (g);
reorderVariables (sortedVarIds);
}
void
Factor::reorderVariables (const VarIds& newVarIds)
{
assert (newVarIds.size() == varids_.size());
if (newVarIds == varids_) {
return; return;
} }
TFactor<VarId>::multiply (g);
Ranges newRanges;
vector<unsigned> positions;
for (unsigned i = 0; i < newVarIds.size(); i++) {
unsigned idx = indexOf (newVarIds[i]);
newRanges.push_back (ranges_[idx]);
positions.push_back (idx);
}
unsigned N = ranges_.size();
Params newParams (dist_->params.size());
for (unsigned i = 0; i < dist_->params.size(); i++) {
unsigned li = i;
// calculate vector index corresponding to linear index
vector<unsigned> vi (N);
for (int k = N-1; k >= 0; k--) {
vi[k] = li % ranges_[k];
li /= ranges_[k];
}
// convert permuted vector index to corresponding linear index
unsigned prod = 1;
unsigned new_li = 0;
for (int k = N-1; k >= 0; k--) {
new_li += vi[positions[k]] * prod;
prod *= ranges_[positions[k]];
}
newParams[new_li] = dist_->params[i];
}
varids_ = newVarIds;
ranges_ = newRanges;
dist_->params = newParams;
} }
void void
Factor::absorveEvidence (VarId vid, unsigned evidence) Factor::reorderAccordingVarIds (void)
{ {
int idx = indexOf (vid); VarIds sortedVarIds = args_;
assert (idx != -1); sort (sortedVarIds.begin(), sortedVarIds.end());
reorderArguments (sortedVarIds);
Params oldParams = dist_->params;
dist_->params.clear();
dist_->params.reserve (oldParams.size() / ranges_[idx]);
StatesIndexer indexer (ranges_);
for (unsigned i = 0; i < evidence; i++) {
indexer.increment (idx);
}
while (indexer.valid()) {
dist_->params.push_back (oldParams[indexer]);
indexer.incrementExcluding (idx);
}
varids_.erase (varids_.begin() + idx);
ranges_.erase (ranges_.begin() + idx);
}
void
Factor::normalize (void)
{
Util::normalize (dist_->params);
}
bool
Factor::contains (const VarIds& vars) const
{
for (unsigned i = 0; i < vars.size(); i++) {
if (indexOf (vars[i]) == -1) {
return false;
}
}
return true;
}
int
Factor::indexOf (VarId vid) const
{
for (unsigned i = 0; i < varids_.size(); i++) {
if (varids_[i] == vid) {
return i;
}
}
return -1;
} }
@ -486,9 +224,9 @@ Factor::getLabel (void) const
{ {
stringstream ss; stringstream ss;
ss << "f(" ; ss << "f(" ;
for (unsigned i = 0; i < varids_.size(); i++) { for (unsigned i = 0; i < args_.size(); i++) {
if (i != 0) ss << "," ; if (i != 0) ss << "," ;
ss << VarNode (varids_[i], ranges_[i]).label(); ss << Var (args_[i], ranges_[i]).label();
} }
ss << ")" ; ss << ")" ;
return ss.str(); return ss.str();
@ -499,14 +237,14 @@ Factor::getLabel (void) const
void void
Factor::print (void) const Factor::print (void) const
{ {
VarNodes vars; Vars vars;
for (unsigned i = 0; i < varids_.size(); i++) { for (unsigned i = 0; i < args_.size(); i++) {
vars.push_back (new VarNode (varids_[i], ranges_[i])); vars.push_back (new Var (args_[i], ranges_[i]));
} }
vector<string> jointStrings = Util::getJointStateStrings (vars); vector<string> jointStrings = Util::getStateLines (vars);
for (unsigned i = 0; i < dist_->params.size(); i++) { for (unsigned i = 0; i < params_.size(); i++) {
cout << "f(" << jointStrings[i] << ")" ; cout << "[" << distId_ << "] f(" << jointStrings[i] << ")" ;
cout << " = " << dist_->params[i] << endl; cout << " = " << params_[i] << endl;
} }
cout << endl; cout << endl;
for (unsigned i = 0; i < vars.size(); i++) { for (unsigned i = 0; i < vars.size(); i++) {
@ -515,3 +253,13 @@ Factor::print (void) const
} }
void
Factor::copyFromFactor (const Factor& g)
{
args_ = g.arguments();
ranges_ = g.ranges();
params_ = g.params();
distId_ = g.distId();
}

View File

@ -3,64 +3,285 @@
#include <vector> #include <vector>
#include "Distribution.h" #include "Var.h"
#include "VarNode.h" #include "Indexer.h"
#include "Util.h"
using namespace std; using namespace std;
class Distribution;
template <typename T>
class TFactor
{
public:
const vector<T>& arguments (void) const { return args_; }
vector<T>& arguments (void) { return args_; }
const Ranges& ranges (void) const { return ranges_; }
const Params& params (void) const { return params_; }
Params& params (void) { return params_; }
unsigned nrArguments (void) const { return args_.size(); }
unsigned size (void) const { return params_.size(); }
unsigned distId (void) const { return distId_; }
void setDistId (unsigned id) { distId_ = id; }
void normalize (void) { LogAware::normalize (params_); }
void setParams (const Params& newParams)
{
params_ = newParams;
assert (params_.size() == Util::expectedSize (ranges_));
}
int indexOf (const T& t) const
{
int idx = -1;
for (unsigned i = 0; i < args_.size(); i++) {
if (args_[i] == t) {
idx = i;
break;
}
}
return idx;
}
const T& argument (unsigned idx) const
{
assert (idx < args_.size());
return args_[idx];
}
T& argument (unsigned idx)
{
assert (idx < args_.size());
return args_[idx];
}
unsigned range (unsigned idx) const
{
assert (idx < ranges_.size());
return ranges_[idx];
}
void multiply (TFactor<T>& g)
{
const vector<T>& g_args = g.arguments();
const Ranges& g_ranges = g.ranges();
const Params& g_params = g.params();
if (args_ == g_args) {
// optimization: if the factors contain the same set of args,
// we can do a 1 to 1 operation on the parameters
if (Globals::logDomain) {
Util::add (params_, g_params);
} else {
Util::multiply (params_, g_params);
}
} else {
bool sharedArgs = false;
vector<unsigned> gvarpos;
for (unsigned i = 0; i < g_args.size(); i++) {
int idx = indexOf (g_args[i]);
if (idx == -1) {
insertArgument (g_args[i], g_ranges[i]);
gvarpos.push_back (args_.size() - 1);
} else {
sharedArgs = true;
gvarpos.push_back (idx);
}
}
if (sharedArgs == false) {
// optimization: if the original factors doesn't have common args,
// we don't need to marry the states of the common args
unsigned count = 0;
for (unsigned i = 0; i < params_.size(); i++) {
if (Globals::logDomain) {
params_[i] += g_params[count];
} else {
params_[i] *= g_params[count];
}
count ++;
if (count >= g_params.size()) {
count = 0;
}
}
} else {
StatesIndexer indexer (ranges_, false);
while (indexer.valid()) {
unsigned g_li = 0;
unsigned prod = 1;
for (int j = gvarpos.size() - 1; j >= 0; j--) {
g_li += indexer[gvarpos[j]] * prod;
prod *= g_ranges[j];
}
if (Globals::logDomain) {
params_[indexer] += g_params[g_li];
} else {
params_[indexer] *= g_params[g_li];
}
++ indexer;
}
}
}
}
void absorveEvidence (const T& arg, unsigned evidence)
{
int idx = indexOf (arg);
assert (idx != -1);
assert (evidence < ranges_[idx]);
Params copy = params_;
params_.clear();
params_.reserve (copy.size() / ranges_[idx]);
StatesIndexer indexer (ranges_);
for (unsigned i = 0; i < evidence; i++) {
indexer.increment (idx);
}
while (indexer.valid()) {
params_.push_back (copy[indexer]);
indexer.incrementExcluding (idx);
}
args_.erase (args_.begin() + idx);
ranges_.erase (ranges_.begin() + idx);
}
void reorderArguments (const vector<T> newArgs)
{
assert (newArgs.size() == args_.size());
if (newArgs == args_) {
return; // already in the wanted order
}
Ranges newRanges;
vector<unsigned> positions;
for (unsigned i = 0; i < newArgs.size(); i++) {
unsigned idx = indexOf (newArgs[i]);
newRanges.push_back (ranges_[idx]);
positions.push_back (idx);
}
unsigned N = ranges_.size();
Params newParams (params_.size());
for (unsigned i = 0; i < params_.size(); i++) {
unsigned li = i;
// calculate vector index corresponding to linear index
vector<unsigned> vi (N);
for (int k = N-1; k >= 0; k--) {
vi[k] = li % ranges_[k];
li /= ranges_[k];
}
// convert permuted vector index to corresponding linear index
unsigned prod = 1;
unsigned new_li = 0;
for (int k = N - 1; k >= 0; k--) {
new_li += vi[positions[k]] * prod;
prod *= ranges_[positions[k]];
}
newParams[new_li] = params_[i];
}
args_ = newArgs;
ranges_ = newRanges;
params_ = newParams;
}
bool contains (const T& arg) const
{
return Util::contains (args_, arg);
}
bool contains (const vector<T>& args) const
{
for (unsigned i = 0; i < args_.size(); i++) {
if (contains (args[i]) == false) {
return false;
}
}
return true;
}
protected:
vector<T> args_;
Ranges ranges_;
Params params_;
unsigned distId_;
private:
void insertArgument (const T& arg, unsigned range)
{
assert (indexOf (arg) == -1);
Params copy = params_;
params_.clear();
params_.reserve (copy.size() * range);
for (unsigned i = 0; i < copy.size(); i++) {
for (unsigned reps = 0; reps < range; reps++) {
params_.push_back (copy[i]);
}
}
args_.push_back (arg);
ranges_.push_back (range);
}
void insertArguments (const vector<T>& args, const Ranges& ranges)
{
Params copy = params_;
unsigned nrStates = 1;
for (unsigned i = 0; i < args.size(); i++) {
assert (indexOf (args[i]) == -1);
args_.push_back (args[i]);
ranges_.push_back (ranges[i]);
nrStates *= ranges[i];
}
params_.clear();
params_.reserve (copy.size() * nrStates);
for (unsigned i = 0; i < copy.size(); i++) {
for (unsigned reps = 0; reps < nrStates; reps++) {
params_.push_back (copy[i]);
}
}
}
};
class Factor
class Factor : public TFactor<VarId>
{ {
public: public:
Factor (void) { } Factor (void) { }
Factor (const Factor&); Factor (const Factor&);
Factor (VarId, unsigned);
Factor (const VarNodes&);
Factor (VarId, unsigned, const Params&);
Factor (VarNodes&, Distribution*);
Factor (const VarNodes&, const Params&);
Factor (const VarIds&, const Ranges&, const Params&);
~Factor (void);
void setParameters (const Params&); Factor (const VarIds&, const Ranges&, const Params&,
void copyFromFactor (const Factor& f); unsigned = Util::maxUnsigned());
void multiply (const Factor&);
void insertVariable (VarId, unsigned);
void insertVariables (const VarIds&, const Ranges&);
void sumOutAllExcept (VarId);
void sumOutAllExcept (const VarIds&);
void sumOut (VarId);
void sumOutFirstVariable (void);
void sumOutLastVariable (void);
void orderVariables (void);
void reorderVariables (const VarIds&);
void absorveEvidence (VarId, unsigned);
void normalize (void);
bool contains (const VarIds&) const;
int indexOf (VarId) const;
string getLabel (void) const;
void print (void) const;
const VarIds& getVarIds (void) const { return varids_; } Factor (const Vars&, const Params&,
const Ranges& getRanges (void) const { return ranges_; } unsigned = Util::maxUnsigned());
const Params& getParameters (void) const { return dist_->params; }
Distribution* getDistribution (void) const { return dist_; }
unsigned nrVariables (void) const { return varids_.size(); }
unsigned nrParameters() const { return dist_->params.size(); }
void setDistribution (Distribution* dist) void sumOutAllExcept (VarId);
{
dist_ = dist; void sumOutAllExcept (const VarIds&);
}
void sumOut (VarId);
void sumOutFirstVariable (void);
void sumOutLastVariable (void);
void multiply (Factor&);
void reorderAccordingVarIds (void);
string getLabel (void) const;
void print (void) const;
private: private:
void copyFromFactor (const Factor& f);
VarIds varids_;
Ranges ranges_;
Distribution* dist_;
}; };
#endif // HORUS_FACTOR_H #endif // HORUS_FACTOR_H

View File

@ -9,6 +9,7 @@
#include "FactorGraph.h" #include "FactorGraph.h"
#include "Factor.h" #include "Factor.h"
#include "BayesNet.h" #include "BayesNet.h"
#include "BayesBall.h"
#include "Util.h" #include "Util.h"
@ -17,140 +18,92 @@ bool FactorGraph::orderFactorVariables = false;
FactorGraph::FactorGraph (const FactorGraph& fg) FactorGraph::FactorGraph (const FactorGraph& fg)
{ {
const FgVarSet& vars = fg.getVarNodes(); const VarNodes& varNodes = fg.varNodes();
for (unsigned i = 0; i < vars.size(); i++) { for (unsigned i = 0; i < varNodes.size(); i++) {
FgVarNode* varNode = new FgVarNode (vars[i]); addVarNode (new VarNode (varNodes[i]));
addVariable (varNode);
} }
const FacNodes& facNodes = fg.facNodes();
const FgFacSet& facs = fg.getFactorNodes(); for (unsigned i = 0; i < facNodes.size(); i++) {
for (unsigned i = 0; i < facs.size(); i++) { FacNode* facNode = new FacNode (facNodes[i]->factor());
FgFacNode* facNode = new FgFacNode (facs[i]); addFacNode (facNode);
addFactor (facNode); const VarNodes& neighs = facNodes[i]->neighbors();
const FgVarSet& neighs = facs[i]->neighbors();
for (unsigned j = 0; j < neighs.size(); j++) { for (unsigned j = 0; j < neighs.size(); j++) {
addEdge (facNode, varNodes_[neighs[j]->getIndex()]); addEdge (varNodes_[neighs[j]->getIndex()], facNode);
} }
} }
} }
FactorGraph::FactorGraph (const BayesNet& bn)
{
const BnNodeSet& nodes = bn.getBayesNodes();
for (unsigned i = 0; i < nodes.size(); i++) {
FgVarNode* varNode = new FgVarNode (nodes[i]);
addVariable (varNode);
}
for (unsigned i = 0; i < nodes.size(); i++) {
const BnNodeSet& parents = nodes[i]->getParents();
if (!(nodes[i]->hasEvidence() && parents.size() == 0)) {
VarNodes neighs;
neighs.push_back (varNodes_[nodes[i]->getIndex()]);
for (unsigned j = 0; j < parents.size(); j++) {
neighs.push_back (varNodes_[parents[j]->getIndex()]);
}
FgFacNode* fn = new FgFacNode (
new Factor (neighs, nodes[i]->getDistribution()));
if (orderFactorVariables) {
sort (neighs.begin(), neighs.end(), CompVarId());
fn->factor()->orderVariables();
}
addFactor (fn);
for (unsigned j = 0; j < neighs.size(); j++) {
addEdge (fn, static_cast<FgVarNode*> (neighs[j]));
}
}
}
setIndexes();
}
void void
FactorGraph::readFromUaiFormat (const char* fileName) FactorGraph::readFromUaiFormat (const char* fileName)
{ {
ifstream is (fileName); std::ifstream is (fileName);
if (!is.is_open()) { if (!is.is_open()) {
cerr << "error: cannot read from file " + std::string (fileName) << endl; cerr << "error: cannot read from file " << fileName << endl;
abort(); abort();
} }
ignoreLines (is);
string line; string line;
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
getline (is, line); getline (is, line);
if (line != "MARKOV") { if (line != "MARKOV") {
cerr << "error: the network must be a MARKOV network " << endl; cerr << "error: the network must be a MARKOV network " << endl;
abort(); abort();
} }
// read the number of vars
while (is.peek() == '#' || is.peek() == '\n') getline (is, line); ignoreLines (is);
unsigned nVars; unsigned nrVars;
is >> nVars; is >> nrVars;
// read the range of each var
while (is.peek() == '#' || is.peek() == '\n') getline (is, line); ignoreLines (is);
vector<int> domainSizes (nVars); Ranges ranges (nrVars);
for (unsigned i = 0; i < nVars; i++) { for (unsigned i = 0; i < nrVars; i++) {
unsigned ds; is >> ranges[i];
is >> ds;
domainSizes[i] = ds;
} }
unsigned nrFactors;
while (is.peek() == '#' || is.peek() == '\n') getline (is, line); unsigned nrArgs;
for (unsigned i = 0; i < nVars; i++) { unsigned vid;
addVariable (new FgVarNode (i, domainSizes[i])); is >> nrFactors;
} vector<VarIds> factorVarIds;
vector<Ranges> factorRanges;
unsigned nFactors; for (unsigned i = 0; i < nrFactors; i++) {
is >> nFactors; ignoreLines (is);
for (unsigned i = 0; i < nFactors; i++) { is >> nrArgs;
while (is.peek() == '#' || is.peek() == '\n') getline (is, line); factorVarIds.push_back ({ });
unsigned nFactorVars; factorRanges.push_back ({ });
is >> nFactorVars; for (unsigned j = 0; j < nrArgs; j++) {
VarNodes neighs;
for (unsigned j = 0; j < nFactorVars; j++) {
unsigned vid;
is >> vid; is >> vid;
FgVarNode* neigh = getFgVarNode (vid); if (vid >= ranges.size()) {
if (!neigh) { cerr << "error: invalid variable identifier `" << vid << "'" << endl;
cerr << "error: invalid variable identifier (" << vid << ")" << endl; cerr << "identifiers must be between 0 and " << ranges.size() - 1 ;
cerr << endl;
abort(); abort();
} }
neighs.push_back (neigh); factorVarIds.back().push_back (vid);
} factorRanges.back().push_back (ranges[vid]);
FgFacNode* fn = new FgFacNode (new Factor (neighs));
addFactor (fn);
for (unsigned j = 0; j < neighs.size(); j++) {
addEdge (fn, static_cast<FgVarNode*> (neighs[j]));
} }
} }
// read the parameters
for (unsigned i = 0; i < nFactors; i++) { unsigned nrParams;
while (is.peek() == '#' || is.peek() == '\n') getline (is, line); for (unsigned i = 0; i < nrFactors; i++) {
unsigned nParams; ignoreLines (is);
is >> nParams; is >> nrParams;
if (facNodes_[i]->getParameters().size() != nParams) { if (nrParams != Util::expectedSize (factorRanges[i])) {
cerr << "error: invalid number of parameters for factor " ; cerr << "error: invalid number of parameters for factor nº " << i ;
cerr << facNodes_[i]->getLabel() ; cerr << ", expected: " << Util::expectedSize (factorRanges[i]);
cerr << ", expected: " << facNodes_[i]->getParameters().size(); cerr << ", given: " << nrParams << endl;
cerr << ", given: " << nParams << endl;
abort(); abort();
} }
Params params (nParams); Params params (nrParams);
for (unsigned j = 0; j < nParams; j++) { for (unsigned j = 0; j < nrParams; j++) {
double param; is >> params[j];
is >> param;
params[j] = param;
} }
if (Globals::logDomain) { if (Globals::logDomain) {
Util::toLog (params); Util::toLog (params);
} }
facNodes_[i]->factor()->setParameters (params); addFactor (Factor (factorVarIds[i], factorRanges[i], params));
} }
is.close(); is.close();
setIndexes();
} }
@ -158,87 +111,58 @@ FactorGraph::readFromUaiFormat (const char* fileName)
void void
FactorGraph::readFromLibDaiFormat (const char* fileName) FactorGraph::readFromLibDaiFormat (const char* fileName)
{ {
ifstream is (fileName); std::ifstream is (fileName);
if (!is.is_open()) { if (!is.is_open()) {
cerr << "error: cannot read from file " + std::string (fileName) << endl; cerr << "error: cannot read from file " << fileName << endl;
abort(); abort();
} }
ignoreLines (is);
string line; unsigned nrFactors;
unsigned nFactors; unsigned nrArgs;
VarId vid;
while ((is.peek()) == '#') getline (is, line); is >> nrFactors;
is >> nFactors; for (unsigned i = 0; i < nrFactors; i++) {
ignoreLines (is);
if (is.fail()) { // read the factor arguments
cerr << "error: cannot read the number of factors" << endl; is >> nrArgs;
abort();
}
getline (is, line);
if (is.fail() || line.size() > 0) {
cerr << "error: cannot read the number of factors" << endl;
abort();
}
for (unsigned i = 0; i < nFactors; i++) {
unsigned nVars;
while ((is.peek()) == '#') getline (is, line);
is >> nVars;
VarIds vids; VarIds vids;
for (unsigned j = 0; j < nVars; j++) { for (unsigned j = 0; j < nrArgs; j++) {
VarId vid; ignoreLines (is);
while ((is.peek()) == '#') getline (is, line);
is >> vid; is >> vid;
vids.push_back (vid); vids.push_back (vid);
} }
// read ranges
VarNodes neighs; Ranges ranges (nrArgs);
unsigned nParams = 1; for (unsigned j = 0; j < nrArgs; j++) {
for (unsigned j = 0; j < nVars; j++) { ignoreLines (is);
unsigned dsize; is >> ranges[j];
while ((is.peek()) == '#') getline (is, line); VarNode* var = getVarNode (vids[j]);
is >> dsize; if (var != 0 && ranges[j] != var->range()) {
FgVarNode* var = getFgVarNode (vids[j]); cerr << "error: variable `" << vids[j] << "' appears in two or " ;
if (var == 0) { cerr << "more factors with a different range" << endl;
var = new FgVarNode (vids[j], dsize);
addVariable (var);
} else {
if (var->nrStates() != dsize) {
cerr << "error: variable `" << vids[j] << "' appears in two or " ;
cerr << "more factors with different domain sizes" << endl;
}
} }
neighs.push_back (var);
nParams *= var->nrStates();
} }
Params params (nParams, 0); // read parameters
ignoreLines (is);
unsigned nNonzeros; unsigned nNonzeros;
while ((is.peek()) == '#') getline (is, line);
is >> nNonzeros; is >> nNonzeros;
Params params (Util::expectedSize (ranges), 0);
for (unsigned j = 0; j < nNonzeros; j++) { for (unsigned j = 0; j < nNonzeros; j++) {
ignoreLines (is);
unsigned index; unsigned index;
double val;
while ((is.peek()) == '#') getline (is, line);
is >> index; is >> index;
while ((is.peek()) == '#') getline (is, line); ignoreLines (is);
double val;
is >> val; is >> val;
params[index] = val; params[index] = val;
} }
reverse (neighs.begin(), neighs.end()); reverse (vids.begin(), vids.end());
if (Globals::logDomain) { if (Globals::logDomain) {
Util::toLog (params); Util::toLog (params);
} }
FgFacNode* fn = new FgFacNode (new Factor (neighs, params)); addFactor (Factor (vids, ranges, params));
addFactor (fn);
for (unsigned j = 0; j < neighs.size(); j++) {
addEdge (fn, static_cast<FgVarNode*> (neighs[j]));
}
} }
is.close(); is.close();
setIndexes();
} }
@ -256,17 +180,41 @@ FactorGraph::~FactorGraph (void)
void void
FactorGraph::addVariable (FgVarNode* vn) FactorGraph::addFactor (const Factor& factor)
{ {
varNodes_.push_back (vn); FacNode* fn = new FacNode (factor);
vn->setIndex (varNodes_.size() - 1); addFacNode (fn);
varMap_.insert (make_pair (vn->varId(), varNodes_.size() - 1)); const VarIds& vids = factor.arguments();
for (unsigned i = 0; i < vids.size(); i++) {
bool found = false;
for (unsigned j = 0; j < varNodes_.size(); j++) {
if (varNodes_[j]->varId() == vids[i]) {
addEdge (varNodes_[j], fn);
found = true;
}
}
if (found == false) {
VarNode* vn = new VarNode (vids[i], factor.range (i));
addVarNode (vn);
addEdge (vn, fn);
}
}
} }
void void
FactorGraph::addFactor (FgFacNode* fn) FactorGraph::addVarNode (VarNode* vn)
{
varNodes_.push_back (vn);
vn->setIndex (varNodes_.size() - 1);
varMap_.insert (make_pair (vn->varId(), vn));
}
void
FactorGraph::addFacNode (FacNode* fn)
{ {
facNodes_.push_back (fn); facNodes_.push_back (fn);
fn->setIndex (facNodes_.size() - 1); fn->setIndex (facNodes_.size() - 1);
@ -275,7 +223,7 @@ FactorGraph::addFactor (FgFacNode* fn)
void void
FactorGraph::addEdge (FgVarNode* vn, FgFacNode* fn) FactorGraph::addEdge (VarNode* vn, FacNode* fn)
{ {
vn->addNeighbor (fn); vn->addNeighbor (fn);
fn->addNeighbor (vn); fn->addNeighbor (vn);
@ -283,37 +231,6 @@ FactorGraph::addEdge (FgVarNode* vn, FgFacNode* fn)
void
FactorGraph::addEdge (FgFacNode* fn, FgVarNode* vn)
{
fn->addNeighbor (vn);
vn->addNeighbor (fn);
}
VarNode*
FactorGraph::getVariableNode (VarId vid) const
{
FgVarNode* vn = getFgVarNode (vid);
assert (vn);
return vn;
}
VarNodes
FactorGraph::getVariableNodes (void) const
{
VarNodes vars;
for (unsigned i = 0; i < varNodes_.size(); i++) {
vars.push_back (varNodes_[i]);
}
return vars;
}
bool bool
FactorGraph::isTree (void) const FactorGraph::isTree (void) const
{ {
@ -322,51 +239,42 @@ FactorGraph::isTree (void) const
void DAGraph&
FactorGraph::setIndexes (void) FactorGraph::getStructure (void)
{ {
for (unsigned i = 0; i < varNodes_.size(); i++) { assert (fromBayesNet_);
varNodes_[i]->setIndex (i); if (structure_.empty()) {
} for (unsigned i = 0; i < varNodes_.size(); i++) {
for (unsigned i = 0; i < facNodes_.size(); i++) { structure_.addNode (new DAGraphNode (varNodes_[i]));
facNodes_[i]->setIndex (i); }
for (unsigned i = 0; i < facNodes_.size(); i++) {
const VarIds& vids = facNodes_[i]->factor().arguments();
for (unsigned j = 1; j < vids.size(); j++) {
structure_.addEdge (vids[j], vids[0]);
}
}
} }
return structure_;
} }
void void
FactorGraph::freeDistributions (void) FactorGraph::print (void) const
{
set<Distribution*> dists;
for (unsigned i = 0; i < facNodes_.size(); i++) {
dists.insert (facNodes_[i]->factor()->getDistribution());
}
for (set<Distribution*>::iterator it = dists.begin();
it != dists.end(); it++) {
delete *it;
}
}
void
FactorGraph::printGraphicalModel (void) const
{ {
for (unsigned i = 0; i < varNodes_.size(); i++) { for (unsigned i = 0; i < varNodes_.size(); i++) {
cout << "VarId = " << varNodes_[i]->varId() << endl; cout << "var id = " << varNodes_[i]->varId() << endl;
cout << "Label = " << varNodes_[i]->label() << endl; cout << "label = " << varNodes_[i]->label() << endl;
cout << "Nr States = " << varNodes_[i]->nrStates() << endl; cout << "range = " << varNodes_[i]->range() << endl;
cout << "Evidence = " << varNodes_[i]->getEvidence() << endl; cout << "evidence = " << varNodes_[i]->getEvidence() << endl;
cout << "Factors = " ; cout << "factors = " ;
for (unsigned j = 0; j < varNodes_[i]->neighbors().size(); j++) { for (unsigned j = 0; j < varNodes_[i]->neighbors().size(); j++) {
cout << varNodes_[i]->neighbors()[j]->getLabel() << " " ; cout << varNodes_[i]->neighbors()[j]->getLabel() << " " ;
} }
cout << endl << endl; cout << endl << endl;
} }
for (unsigned i = 0; i < facNodes_.size(); i++) { for (unsigned i = 0; i < facNodes_.size(); i++) {
facNodes_[i]->factor()->print(); facNodes_[i]->factor().print();
cout << endl;
} }
} }
@ -381,31 +289,26 @@ FactorGraph::exportToGraphViz (const char* fileName) const
cerr << "FactorGraph::exportToDotFile()" << endl; cerr << "FactorGraph::exportToDotFile()" << endl;
abort(); abort();
} }
out << "graph \"" << fileName << "\" {" << endl; out << "graph \"" << fileName << "\" {" << endl;
for (unsigned i = 0; i < varNodes_.size(); i++) { for (unsigned i = 0; i < varNodes_.size(); i++) {
if (varNodes_[i]->hasEvidence()) { if (varNodes_[i]->hasEvidence()) {
out << '"' << varNodes_[i]->label() << '"' ; out << '"' << varNodes_[i]->label() << '"' ;
out << " [style=filled, fillcolor=yellow]" << endl; out << " [style=filled, fillcolor=yellow]" << endl;
} }
} }
for (unsigned i = 0; i < facNodes_.size(); i++) { for (unsigned i = 0; i < facNodes_.size(); i++) {
out << '"' << facNodes_[i]->getLabel() << '"' ; out << '"' << facNodes_[i]->getLabel() << '"' ;
out << " [label=\"" << facNodes_[i]->getLabel(); out << " [label=\"" << facNodes_[i]->getLabel();
out << "\"" << ", shape=box]" << endl; out << "\"" << ", shape=box]" << endl;
} }
for (unsigned i = 0; i < facNodes_.size(); i++) { for (unsigned i = 0; i < facNodes_.size(); i++) {
const FgVarSet& myVars = facNodes_[i]->neighbors(); const VarNodes& myVars = facNodes_[i]->neighbors();
for (unsigned j = 0; j < myVars.size(); j++) { for (unsigned j = 0; j < myVars.size(); j++) {
out << '"' << facNodes_[i]->getLabel() << '"' ; out << '"' << facNodes_[i]->getLabel() << '"' ;
out << " -- " ; out << " -- " ;
out << '"' << myVars[j]->label() << '"' << endl; out << '"' << myVars[j]->label() << '"' << endl;
} }
} }
out << "}" << endl; out << "}" << endl;
out.close(); out.close();
} }
@ -417,30 +320,26 @@ FactorGraph::exportToUaiFormat (const char* fileName) const
{ {
ofstream out (fileName); ofstream out (fileName);
if (!out.is_open()) { if (!out.is_open()) {
cerr << "error: cannot open file to write at " ; cerr << "error: cannot open file " << fileName << endl;
cerr << "FactorGraph::exportToUaiFormat()" << endl;
abort(); abort();
} }
out << "MARKOV" << endl; out << "MARKOV" << endl;
out << varNodes_.size() << endl; out << varNodes_.size() << endl;
for (unsigned i = 0; i < varNodes_.size(); i++) { for (unsigned i = 0; i < varNodes_.size(); i++) {
out << varNodes_[i]->nrStates() << " " ; out << varNodes_[i]->range() << " " ;
} }
out << endl; out << endl;
out << facNodes_.size() << endl; out << facNodes_.size() << endl;
for (unsigned i = 0; i < facNodes_.size(); i++) { for (unsigned i = 0; i < facNodes_.size(); i++) {
const FgVarSet& factorVars = facNodes_[i]->neighbors(); const VarNodes& factorVars = facNodes_[i]->neighbors();
out << factorVars.size(); out << factorVars.size();
for (unsigned j = 0; j < factorVars.size(); j++) { for (unsigned j = 0; j < factorVars.size(); j++) {
out << " " << factorVars[j]->getIndex(); out << " " << factorVars[j]->getIndex();
} }
out << endl; out << endl;
} }
for (unsigned i = 0; i < facNodes_.size(); i++) { for (unsigned i = 0; i < facNodes_.size(); i++) {
Params params = facNodes_[i]->getParameters(); Params params = facNodes_[i]->factor().params();
if (Globals::logDomain) { if (Globals::logDomain) {
Util::fromLog (params); Util::fromLog (params);
} }
@ -450,7 +349,6 @@ FactorGraph::exportToUaiFormat (const char* fileName) const
} }
out << endl; out << endl;
} }
out.close(); out.close();
} }
@ -461,23 +359,22 @@ FactorGraph::exportToLibDaiFormat (const char* fileName) const
{ {
ofstream out (fileName); ofstream out (fileName);
if (!out.is_open()) { if (!out.is_open()) {
cerr << "error: cannot open file to write at " ; cerr << "error: cannot open file " << fileName << endl;
cerr << "FactorGraph::exportToLibDaiFormat()" << endl;
abort(); abort();
} }
out << facNodes_.size() << endl << endl; out << facNodes_.size() << endl << endl;
for (unsigned i = 0; i < facNodes_.size(); i++) { for (unsigned i = 0; i < facNodes_.size(); i++) {
const FgVarSet& factorVars = facNodes_[i]->neighbors(); const VarNodes& factorVars = facNodes_[i]->neighbors();
out << factorVars.size() << endl; out << factorVars.size() << endl;
for (int j = factorVars.size() - 1; j >= 0; j--) { for (int j = factorVars.size() - 1; j >= 0; j--) {
out << factorVars[j]->varId() << " " ; out << factorVars[j]->varId() << " " ;
} }
out << endl; out << endl;
for (unsigned j = 0; j < factorVars.size(); j++) { for (unsigned j = 0; j < factorVars.size(); j++) {
out << factorVars[j]->nrStates() << " " ; out << factorVars[j]->range() << " " ;
} }
out << endl; out << endl;
Params params = facNodes_[i]->factor()->getParameters(); Params params = facNodes_[i]->factor().params();
if (Globals::logDomain) { if (Globals::logDomain) {
Util::fromLog (params); Util::fromLog (params);
} }
@ -492,6 +389,17 @@ FactorGraph::exportToLibDaiFormat (const char* fileName) const
void
FactorGraph::ignoreLines (std::ifstream& is) const
{
string ignoreStr;
while (is.peek() == '#' || is.peek() == '\n') {
getline (is, ignoreStr);
}
}
bool bool
FactorGraph::containsCycle (void) const FactorGraph::containsCycle (void) const
{ {
@ -511,13 +419,14 @@ FactorGraph::containsCycle (void) const
bool bool
FactorGraph::containsCycle (const FgVarNode* v, FactorGraph::containsCycle (
const FgFacNode* p, const VarNode* v,
vector<bool>& visitedVars, const FacNode* p,
vector<bool>& visitedFactors) const vector<bool>& visitedVars,
vector<bool>& visitedFactors) const
{ {
visitedVars[v->getIndex()] = true; visitedVars[v->getIndex()] = true;
const FgFacSet& adjacencies = v->neighbors(); const FacNodes& adjacencies = v->neighbors();
for (unsigned i = 0; i < adjacencies.size(); i++) { for (unsigned i = 0; i < adjacencies.size(); i++) {
int w = adjacencies[i]->getIndex(); int w = adjacencies[i]->getIndex();
if (!visitedFactors[w]) { if (!visitedFactors[w]) {
@ -535,13 +444,14 @@ FactorGraph::containsCycle (const FgVarNode* v,
bool bool
FactorGraph::containsCycle (const FgFacNode* v, FactorGraph::containsCycle (
const FgVarNode* p, const FacNode* v,
vector<bool>& visitedVars, const VarNode* p,
vector<bool>& visitedFactors) const vector<bool>& visitedVars,
vector<bool>& visitedFactors) const
{ {
visitedFactors[v->getIndex()] = true; visitedFactors[v->getIndex()] = true;
const FgVarSet& adjacencies = v->neighbors(); const VarNodes& adjacencies = v->neighbors();
for (unsigned i = 0; i < adjacencies.size(); i++) { for (unsigned i = 0; i < adjacencies.size(); i++) {
int w = adjacencies[i]->getIndex(); int w = adjacencies[i]->getIndex();
if (!visitedVars[w]) { if (!visitedVars[w]) {

View File

@ -3,135 +3,139 @@
#include <vector> #include <vector>
#include "GraphicalModel.h"
#include "Distribution.h"
#include "Factor.h" #include "Factor.h"
#include "BayesNet.h"
#include "Horus.h" #include "Horus.h"
using namespace std; using namespace std;
class BayesNet;
class FgFacNode;
class FgVarNode : public VarNode class FacNode;
class VarNode : public Var
{ {
public: public:
FgVarNode (VarId varId, unsigned nrStates) : VarNode (varId, nrStates) { } VarNode (VarId varId, unsigned nrStates)
FgVarNode (const VarNode* v) : VarNode (v) { } : Var (varId, nrStates) { }
void addNeighbor (FgFacNode* fn) { neighs_.push_back (fn); } VarNode (const Var* v) : Var (v) { }
const FgFacSet& neighbors (void) const { return neighs_; }
void addNeighbor (FacNode* fn) { neighs_.push_back (fn); }
const FacNodes& neighbors (void) const { return neighs_; }
private: private:
DISALLOW_COPY_AND_ASSIGN (FgVarNode); DISALLOW_COPY_AND_ASSIGN (VarNode);
// members
FgFacSet neighs_; FacNodes neighs_;
}; };
class FgFacNode class FacNode
{ {
public: public:
FgFacNode (const FgFacNode* fn) { FacNode (const Factor& f) : factor_(f), index_(-1) { }
factor_ = new Factor (*fn->factor());
index_ = -1; const Factor& factor (void) const { return factor_; }
}
FgFacNode (Factor* f) : factor_(new Factor(*f)), index_(-1) { } Factor& factor (void) { return factor_; }
Factor* factor() const { return factor_; }
void addNeighbor (FgVarNode* vn) { neighs_.push_back (vn); } void addNeighbor (VarNode* vn) { neighs_.push_back (vn); }
const FgVarSet& neighbors (void) const { return neighs_; }
const VarNodes& neighbors (void) const { return neighs_; }
int getIndex (void) const { return index_; }
void setIndex (int index) { index_ = index; }
string getLabel (void) { return factor_.getLabel(); }
int getIndex (void) const
{
assert (index_ != -1);
return index_;
}
void setIndex (int index)
{
index_ = index;
}
Distribution* getDistribution (void)
{
return factor_->getDistribution();
}
const Params& getParameters (void) const
{
return factor_->getParameters();
}
string getLabel (void)
{
return factor_->getLabel();
}
private: private:
DISALLOW_COPY_AND_ASSIGN (FgFacNode); DISALLOW_COPY_AND_ASSIGN (FacNode);
Factor* factor_; VarNodes neighs_;
int index_; Factor factor_;
FgVarSet neighs_; int index_;
}; };
struct CompVarId struct CompVarId
{ {
bool operator() (const VarNode* vn1, const VarNode* vn2) const bool operator() (const Var* v1, const Var* v2) const
{ {
return vn1->varId() < vn2->varId(); return v1->varId() < v2->varId();
} }
}; };
class FactorGraph : public GraphicalModel class FactorGraph
{ {
public: public:
FactorGraph (void) {}; FactorGraph (bool fbn = false) : fromBayesNet_(fbn) { }
FactorGraph (const FactorGraph&); FactorGraph (const FactorGraph&);
FactorGraph (const BayesNet&);
~FactorGraph (void); ~FactorGraph (void);
void readFromUaiFormat (const char*); const VarNodes& varNodes (void) const { return varNodes_; }
void readFromLibDaiFormat (const char*);
void addVariable (FgVarNode*);
void addFactor (FgFacNode*);
void addEdge (FgVarNode*, FgFacNode*);
void addEdge (FgFacNode*, FgVarNode*);
VarNode* getVariableNode (unsigned) const;
VarNodes getVariableNodes (void) const;
bool isTree (void) const;
void setIndexes (void);
void freeDistributions (void);
void printGraphicalModel (void) const;
void exportToGraphViz (const char*) const;
void exportToUaiFormat (const char*) const;
void exportToLibDaiFormat (const char*) const;
const FgVarSet& getVarNodes (void) const { return varNodes_; }
const FgFacSet& getFactorNodes (void) const { return facNodes_; }
FgVarNode* getFgVarNode (VarId vid) const const FacNodes& facNodes (void) const { return facNodes_; }
bool isFromBayesNetwork (void) const { return fromBayesNet_ ; }
VarNode* getVarNode (VarId vid) const
{ {
IndexMap::const_iterator it = varMap_.find (vid); VarMap::const_iterator it = varMap_.find (vid);
if (it == varMap_.end()) { return it != varMap_.end() ? it->second : 0;
return 0;
} else {
return varNodes_[it->second];
}
} }
void readFromUaiFormat (const char*);
void readFromLibDaiFormat (const char*);
void addFactor (const Factor& factor);
void addVarNode (VarNode*);
void addFacNode (FacNode*);
void addEdge (VarNode*, FacNode*);
bool isTree (void) const;
DAGraph& getStructure (void);
void print (void) const;
void exportToGraphViz (const char*) const;
void exportToUaiFormat (const char*) const;
void exportToLibDaiFormat (const char*) const;
static bool orderFactorVariables; static bool orderFactorVariables;
private: private:
//DISALLOW_COPY_AND_ASSIGN (FactorGraph); // DISALLOW_COPY_AND_ASSIGN (FactorGraph);
bool containsCycle (void) const;
bool containsCycle (const FgVarNode*, const FgFacNode*,
vector<bool>&, vector<bool>&) const;
bool containsCycle (const FgFacNode*, const FgVarNode*,
vector<bool>&, vector<bool>&) const;
FgVarSet varNodes_; void ignoreLines (std::ifstream&) const;
FgFacSet facNodes_;
typedef unordered_map<unsigned, unsigned> IndexMap; bool containsCycle (void) const;
IndexMap varMap_;
bool containsCycle (const VarNode*, const FacNode*,
vector<bool>&, vector<bool>&) const;
bool containsCycle (const FacNode*, const VarNode*,
vector<bool>&, vector<bool>&) const;
VarNodes varNodes_;
FacNodes facNodes_;
DAGraph structure_;
bool fromBayesNet_;
typedef unordered_map<unsigned, VarNode*> VarMap;
VarMap varMap_;
}; };
#endif // HORUS_FACTORGRAPH_H #endif // HORUS_FACTORGRAPH_H

View File

@ -1,175 +0,0 @@
#ifndef HORUS_FGBPSOLVER_H
#define HORUS_FGBPSOLVER_H
#include <set>
#include <vector>
#include <sstream>
#include "Solver.h"
#include "Factor.h"
#include "FactorGraph.h"
#include "Util.h"
using namespace std;
class SpLink
{
public:
SpLink (FgFacNode* fn, FgVarNode* vn)
{
fac_ = fn;
var_ = vn;
v1_.resize (vn->nrStates(), Util::tl (1.0 / vn->nrStates()));
v2_.resize (vn->nrStates(), Util::tl (1.0 / vn->nrStates()));
currMsg_ = &v1_;
nextMsg_ = &v2_;
msgSended_ = false;
residual_ = 0.0;
}
virtual ~SpLink (void) {};
virtual void updateMessage (void)
{
swap (currMsg_, nextMsg_);
msgSended_ = true;
}
void updateResidual (void)
{
residual_ = Util::getMaxNorm (v1_, v2_);
}
string toString (void) const
{
stringstream ss;
ss << fac_->getLabel();
ss << " -- " ;
ss << var_->label();
return ss.str();
}
FgFacNode* getFactor (void) const { return fac_; }
FgVarNode* getVariable (void) const { return var_; }
const Params& getMessage (void) const { return *currMsg_; }
Params& getNextMessage (void) { return *nextMsg_; }
bool messageWasSended (void) const { return msgSended_; }
double getResidual (void) const { return residual_; }
void clearResidual (void) { residual_ = 0.0; }
protected:
FgFacNode* fac_;
FgVarNode* var_;
Params v1_;
Params v2_;
Params* currMsg_;
Params* nextMsg_;
bool msgSended_;
double residual_;
};
typedef vector<SpLink*> SpLinkSet;
class SPNodeInfo
{
public:
void addSpLink (SpLink* link) { links_.push_back (link); }
const SpLinkSet& getLinks (void) { return links_; }
private:
SpLinkSet links_;
};
class FgBpSolver : public Solver
{
public:
FgBpSolver (const FactorGraph&);
virtual ~FgBpSolver (void);
void runSolver (void);
virtual Params getPosterioriOf (VarId);
virtual Params getJointDistributionOf (const VarIds&);
protected:
virtual void initializeSolver (void);
virtual void createLinks (void);
virtual void maxResidualSchedule (void);
virtual void calculateFactor2VariableMsg (SpLink*) const;
virtual Params getVar2FactorMsg (const SpLink*) const;
virtual Params getJointByConditioning (const VarIds&) const;
virtual void printLinkInformation (void) const;
void calculateAndUpdateMessage (SpLink* link, bool calcResidual = true)
{
if (DL >= 3) {
cout << "calculating & updating " << link->toString() << endl;
}
calculateFactor2VariableMsg (link);
if (calcResidual) {
link->updateResidual();
}
link->updateMessage();
}
void calculateMessage (SpLink* link, bool calcResidual = true)
{
if (DL >= 3) {
cout << "calculating " << link->toString() << endl;
}
calculateFactor2VariableMsg (link);
if (calcResidual) {
link->updateResidual();
}
}
void updateMessage (SpLink* link)
{
link->updateMessage();
if (DL >= 3) {
cout << "updating " << link->toString() << endl;
}
}
SPNodeInfo* ninf (const FgVarNode* var) const
{
return varsI_[var->getIndex()];
}
SPNodeInfo* ninf (const FgFacNode* fac) const
{
return facsI_[fac->getIndex()];
}
struct CompareResidual {
inline bool operator() (const SpLink* link1, const SpLink* link2)
{
return link1->getResidual() > link2->getResidual();
}
};
SpLinkSet links_;
unsigned nIters_;
vector<SPNodeInfo*> varsI_;
vector<SPNodeInfo*> facsI_;
const FactorGraph* factorGraph_;
typedef multiset<SpLink*, CompareResidual> SortedOrder;
SortedOrder sortedOrder_;
typedef unordered_map<SpLink*, SortedOrder::iterator> SpLinkMap;
SpLinkMap linkMap_;
private:
void runLoopySolver (void);
bool converged (void);
};
#endif // HORUS_FGBPSOLVER_H

View File

@ -8,7 +8,9 @@
vector<LiftedOperator*> vector<LiftedOperator*>
LiftedOperator::getValidOps (ParfactorList& pfList, const Grounds& query) LiftedOperator::getValidOps (
ParfactorList& pfList,
const Grounds& query)
{ {
vector<LiftedOperator*> validOps; vector<LiftedOperator*> validOps;
vector<SumOutOperator*> sumOutOps; vector<SumOutOperator*> sumOutOps;
@ -28,12 +30,15 @@ LiftedOperator::getValidOps (ParfactorList& pfList, const Grounds& query)
void void
LiftedOperator::printValidOps (ParfactorList& pfList, const Grounds& query) LiftedOperator::printValidOps (
ParfactorList& pfList,
const Grounds& query)
{ {
vector<LiftedOperator*> validOps; vector<LiftedOperator*> validOps;
validOps = LiftedOperator::getValidOps (pfList, query); validOps = LiftedOperator::getValidOps (pfList, query);
for (unsigned i = 0; i < validOps.size(); i++) { for (unsigned i = 0; i < validOps.size(); i++) {
cout << "-> " << validOps[i]->toString() << endl; cout << "-> " << validOps[i]->toString() << endl;
delete validOps[i];
} }
} }
@ -56,14 +61,14 @@ SumOutOperator::getCost (void)
pfIter = pfList_.begin(); pfIter = pfList_.begin();
while (pfIter != pfList_.end()) { while (pfIter != pfList_.end()) {
if ((*pfIter)->containsGroup (groupSet[i])) { if ((*pfIter)->containsGroup (groupSet[i])) {
int idx = (*pfIter)->indexOfFormulaWithGroup (groupSet[i]); int idx = (*pfIter)->indexOfGroup (groupSet[i]);
cost *= (*pfIter)->range (idx); cost *= (*pfIter)->range (idx);
break; break;
} }
++ pfIter; ++ pfIter;
} }
} }
return cost; return cost;
} }
@ -77,14 +82,13 @@ SumOutOperator::apply (void)
pfList_.remove (iters[0]); pfList_.remove (iters[0]);
for (unsigned i = 1; i < iters.size(); i++) { for (unsigned i = 1; i < iters.size(); i++) {
product->multiply (**(iters[i])); product->multiply (**(iters[i]));
delete *(iters[i]); pfList_.removeAndDelete (iters[i]);
pfList_.remove (iters[i]);
} }
if (product->nrFormulas() == 1) { if (product->nrArguments() == 1) {
delete product; delete product;
return; return;
} }
int fIdx = product->indexOfFormulaWithGroup (group_); int fIdx = product->indexOfGroup (group_);
LogVarSet excl = product->exclusiveLogVars (fIdx); LogVarSet excl = product->exclusiveLogVars (fIdx);
if (product->constr()->isCountNormalized (excl)) { if (product->constr()->isCountNormalized (excl)) {
product->sumOut (fIdx); product->sumOut (fIdx);
@ -96,21 +100,21 @@ SumOutOperator::apply (void)
pfList_.add (pfs[i]); pfList_.add (pfs[i]);
} }
delete product; delete product;
pfList_.shatter();
} }
} }
vector<SumOutOperator*> vector<SumOutOperator*>
SumOutOperator::getValidOps (ParfactorList& pfList, const Grounds& query) SumOutOperator::getValidOps (
ParfactorList& pfList,
const Grounds& query)
{ {
vector<SumOutOperator*> validOps; vector<SumOutOperator*> validOps;
set<unsigned> allGroups; set<unsigned> allGroups;
ParfactorList::const_iterator it = pfList.begin(); ParfactorList::const_iterator it = pfList.begin();
while (it != pfList.end()) { while (it != pfList.end()) {
assert (*it); const ProbFormulas& formulas = (*it)->arguments();
const ProbFormulas& formulas = (*it)->formulas();
for (unsigned i = 0; i < formulas.size(); i++) { for (unsigned i = 0; i < formulas.size(); i++) {
allGroups.insert (formulas[i].group()); allGroups.insert (formulas[i].group());
} }
@ -134,8 +138,8 @@ SumOutOperator::toString (void)
stringstream ss; stringstream ss;
vector<ParfactorList::iterator> pfIters; vector<ParfactorList::iterator> pfIters;
pfIters = parfactorsWithGroup (pfList_, group_); pfIters = parfactorsWithGroup (pfList_, group_);
int idx = (*pfIters[0])->indexOfFormulaWithGroup (group_); int idx = (*pfIters[0])->indexOfGroup (group_);
ProbFormula f = (*pfIters[0])->formula (idx); ProbFormula f = (*pfIters[0])->argument (idx);
TupleSet tupleSet = (*pfIters[0])->constr()->tupleSet (f.logVars()); TupleSet tupleSet = (*pfIters[0])->constr()->tupleSet (f.logVars());
ss << "sum out " << f.functor() << "/" << f.arity(); ss << "sum out " << f.functor() << "/" << f.arity();
ss << "|" << tupleSet << " (group " << group_ << ")"; ss << "|" << tupleSet << " (group " << group_ << ")";
@ -158,9 +162,9 @@ SumOutOperator::validOp (
} }
unordered_map<unsigned, unsigned> groupToRange; unordered_map<unsigned, unsigned> groupToRange;
for (unsigned i = 0; i < pfIters.size(); i++) { for (unsigned i = 0; i < pfIters.size(); i++) {
int fIdx = (*pfIters[i])->indexOfFormulaWithGroup (group); int fIdx = (*pfIters[i])->indexOfGroup (group);
if ((*pfIters[i])->formulas()[fIdx].contains ( if ((*pfIters[i])->argument (fIdx).contains (
(*pfIters[i])->elimLogVars()) == false) { (*pfIters[i])->elimLogVars()) == false) {
return false; return false;
} }
vector<unsigned> ranges = (*pfIters[i])->ranges(); vector<unsigned> ranges = (*pfIters[i])->ranges();
@ -206,8 +210,8 @@ SumOutOperator::isToEliminate (
unsigned group, unsigned group,
const Grounds& query) const Grounds& query)
{ {
int fIdx = g->indexOfFormulaWithGroup (group); int fIdx = g->indexOfGroup (group);
const ProbFormula& formula = g->formula (fIdx); const ProbFormula& formula = g->argument (fIdx);
bool toElim = true; bool toElim = true;
for (unsigned i = 0; i < query.size(); i++) { for (unsigned i = 0; i < query.size(); i++) {
if (formula.functor() == query[i].functor() && if (formula.functor() == query[i].functor() &&
@ -228,7 +232,7 @@ unsigned
CountingOperator::getCost (void) CountingOperator::getCost (void)
{ {
unsigned cost = 0; unsigned cost = 0;
int fIdx = (*pfIter_)->indexOfFormulaWithLogVar (X_); int fIdx = (*pfIter_)->indexOfLogVar (X_);
unsigned range = (*pfIter_)->range (fIdx); unsigned range = (*pfIter_)->range (fIdx);
unsigned size = (*pfIter_)->size() / range; unsigned size = (*pfIter_)->size() / range;
TinySet<unsigned> counts; TinySet<unsigned> counts;
@ -247,18 +251,19 @@ CountingOperator::apply (void)
if ((*pfIter_)->constr()->isCountNormalized (X_)) { if ((*pfIter_)->constr()->isCountNormalized (X_)) {
(*pfIter_)->countConvert (X_); (*pfIter_)->countConvert (X_);
} else { } else {
Parfactors pfs = FoveSolver::countNormalize (*pfIter_, X_); Parfactor* pf = *pfIter_;
pfList_.remove (pfIter_);
Parfactors pfs = FoveSolver::countNormalize (pf, X_);
for (unsigned i = 0; i < pfs.size(); i++) { for (unsigned i = 0; i < pfs.size(); i++) {
unsigned condCount = pfs[i]->constr()->getConditionalCount (X_); unsigned condCount = pfs[i]->constr()->getConditionalCount (X_);
bool cartProduct = pfs[i]->constr()->isCarteesianProduct ( bool cartProduct = pfs[i]->constr()->isCarteesianProduct (
(*pfIter_)->countedLogVars() | X_); pfs[i]->countedLogVars() | X_);
if (condCount > 1 && cartProduct) { if (condCount > 1 && cartProduct) {
pfs[i]->countConvert (X_); pfs[i]->countConvert (X_);
} }
pfList_.add (pfs[i]); pfList_.add (pfs[i]);
} }
pfList_.deleteAndRemove (pfIter_); delete pf;
pfList_.shatter();
} }
} }
@ -289,14 +294,17 @@ CountingOperator::toString (void)
{ {
stringstream ss; stringstream ss;
ss << "count convert " << X_ << " in " ; ss << "count convert " << X_ << " in " ;
ss << (*pfIter_)->getHeaderString(); ss << (*pfIter_)->getLabel();
ss << " [cost=" << getCost() << "]" << endl; ss << " [cost=" << getCost() << "]" << endl;
Parfactors pfs = FoveSolver::countNormalize (*pfIter_, X_); Parfactors pfs = FoveSolver::countNormalize (*pfIter_, X_);
if ((*pfIter_)->constr()->isCountNormalized (X_) == false) { if ((*pfIter_)->constr()->isCountNormalized (X_) == false) {
for (unsigned i = 0; i < pfs.size(); i++) { for (unsigned i = 0; i < pfs.size(); i++) {
ss << " º " << pfs[i]->getHeaderString() << endl; ss << " º " << pfs[i]->getLabel() << endl;
} }
} }
for (unsigned i = 0; i < pfs.size(); i++) {
delete pfs[i];
}
return ss.str(); return ss.str();
} }
@ -308,8 +316,8 @@ CountingOperator::validOp (Parfactor* g, LogVar X)
if (g->nrFormulas (X) != 1) { if (g->nrFormulas (X) != 1) {
return false; return false;
} }
int fIdx = g->indexOfFormulaWithLogVar (X); int fIdx = g->indexOfLogVar (X);
if (g->formulas()[fIdx].isCounting()) { if (g->argument (fIdx).isCounting()) {
return false; return false;
} }
bool countNormalized = g->constr()->isCountNormalized (X); bool countNormalized = g->constr()->isCountNormalized (X);
@ -332,10 +340,10 @@ GroundOperator::getCost (void)
unsigned cost = 0; unsigned cost = 0;
bool isCountingLv = (*pfIter_)->countedLogVars().contains (X_); bool isCountingLv = (*pfIter_)->countedLogVars().contains (X_);
if (isCountingLv) { if (isCountingLv) {
int fIdx = (*pfIter_)->indexOfFormulaWithLogVar (X_); int fIdx = (*pfIter_)->indexOfLogVar (X_);
unsigned currSize = (*pfIter_)->size(); unsigned currSize = (*pfIter_)->size();
unsigned nrHists = (*pfIter_)->range (fIdx); unsigned nrHists = (*pfIter_)->range (fIdx);
unsigned range = (*pfIter_)->formula(fIdx).range(); unsigned range = (*pfIter_)->argument (fIdx).range();
unsigned nrSymbols = (*pfIter_)->constr()->getConditionalCount (X_); unsigned nrSymbols = (*pfIter_)->constr()->getConditionalCount (X_);
cost = (currSize / nrHists) * (std::pow (range, nrSymbols)); cost = (currSize / nrHists) * (std::pow (range, nrSymbols));
} else { } else {
@ -350,18 +358,17 @@ void
GroundOperator::apply (void) GroundOperator::apply (void)
{ {
bool countedLv = (*pfIter_)->countedLogVars().contains (X_); bool countedLv = (*pfIter_)->countedLogVars().contains (X_);
Parfactor* pf = *pfIter_;
pfList_.remove (pfIter_);
if (countedLv) { if (countedLv) {
(*pfIter_)->fullExpand (X_); pf->fullExpand (X_);
(*pfIter_)->setNewGroups(); pfList_.add (pf);
pfList_.shatter();
} else { } else {
ConstraintTrees cts = (*pfIter_)->constr()->ground (X_); ConstraintTrees cts = pf->constr()->ground (X_);
for (unsigned i = 0; i < cts.size(); i++) { for (unsigned i = 0; i < cts.size(); i++) {
Parfactor* newPf = new Parfactor (*pfIter_, cts[i]); pfList_.add (new Parfactor (pf, cts[i]));
pfList_.add (newPf);
} }
pfList_.deleteAndRemove (pfIter_); delete pf;
pfList_.shatter();
} }
} }
@ -393,24 +400,13 @@ GroundOperator::toString (void)
((*pfIter_)->countedLogVars().contains (X_)) ((*pfIter_)->countedLogVars().contains (X_))
? ss << "full expanding " ? ss << "full expanding "
: ss << "grounding " ; : ss << "grounding " ;
ss << X_ << " in " << (*pfIter_)->getHeaderString(); ss << X_ << " in " << (*pfIter_)->getLabel();
ss << " [cost=" << getCost() << "]" << endl; ss << " [cost=" << getCost() << "]" << endl;
return ss.str(); return ss.str();
} }
FoveSolver::FoveSolver (const ParfactorList* pfList)
{
for (ParfactorList::const_iterator it = pfList->begin();
it != pfList->end();
it ++) {
pfList_.addShattered (new Parfactor (**it));
}
}
Params Params
FoveSolver::getPosterioriOf (const Ground& query) FoveSolver::getPosterioriOf (const Ground& query)
{ {
@ -422,14 +418,12 @@ FoveSolver::getPosterioriOf (const Ground& query)
Params Params
FoveSolver::getJointDistributionOf (const Grounds& query) FoveSolver::getJointDistributionOf (const Grounds& query)
{ {
shatterAgainstQuery (query);
runSolver (query); runSolver (query);
(*pfList_.begin())->normalize(); (*pfList_.begin())->normalize();
Params params = (*pfList_.begin())->params(); Params params = (*pfList_.begin())->params();
if (Globals::logDomain) { if (Globals::logDomain) {
Util::fromLog (params); Util::fromLog (params);
} }
delete *pfList_.begin();
return params; return params;
} }
@ -438,32 +432,38 @@ FoveSolver::getJointDistributionOf (const Grounds& query)
void void
FoveSolver::absorveEvidence ( FoveSolver::absorveEvidence (
ParfactorList& pfList, ParfactorList& pfList,
const ObservedFormulas& obsFormulas) ObservedFormulas& obsFormulas)
{ {
ParfactorList::iterator it = pfList.begin(); for (unsigned i = 0; i < obsFormulas.size(); i++) {
while (it != pfList.end()) { Parfactors newPfs;
bool increment = true; ParfactorList::iterator it = pfList.begin();
for (unsigned i = 0; i < obsFormulas.size(); i++) { while (it != pfList.end()) {
if (absorved (pfList, it, obsFormulas[i])) { Parfactor* pf = *it;
it = pfList.deleteAndRemove (it); it = pfList.remove (it);
increment = false; Parfactors absorvedPfs = absorve (obsFormulas[i], pf);
break; if (absorvedPfs.empty() == false) {
} if (absorvedPfs.size() == 1 && absorvedPfs[0] == 0) {
} // just remove pf;
if (increment) { } else {
++ it; Util::addToVector (newPfs, absorvedPfs);
}
delete pf;
} else {
it = pfList.insertShattered (it, pf);
++ it;
}
} }
pfList.add (newPfs);
} }
pfList.shatter(); if (Constants::DEBUG >= 2 && obsFormulas.empty() == false) {
if (obsFormulas.empty() == false) { Util::printAsteriskLine();
cout << "*******************************************************" << endl;
cout << "AFTER EVIDENCE ABSORVED" << endl; cout << "AFTER EVIDENCE ABSORVED" << endl;
for (unsigned i = 0; i < obsFormulas.size(); i++) { for (unsigned i = 0; i < obsFormulas.size(); i++) {
cout << " -> " << *obsFormulas[i] << endl; cout << " -> " << obsFormulas[i] << endl;
} }
cout << "*******************************************************" << endl; Util::printAsteriskLine();
pfList.print();
} }
pfList.print();
} }
@ -473,14 +473,14 @@ FoveSolver::countNormalize (
Parfactor* g, Parfactor* g,
const LogVarSet& set) const LogVarSet& set)
{ {
if (set.empty()) {
assert (false); // TODO
return {};
}
Parfactors normPfs; Parfactors normPfs;
ConstraintTrees normCts = g->constr()->countNormalize (set); if (set.empty()) {
for (unsigned i = 0; i < normCts.size(); i++) { normPfs.push_back (new Parfactor (*g));
normPfs.push_back (new Parfactor (g, normCts[i])); } else {
ConstraintTrees normCts = g->constr()->countNormalize (set);
for (unsigned i = 0; i < normCts.size(); i++) {
normPfs.push_back (new Parfactor (g, normCts[i]));
}
} }
return normPfs; return normPfs;
} }
@ -490,17 +490,25 @@ FoveSolver::countNormalize (
void void
FoveSolver::runSolver (const Grounds& query) FoveSolver::runSolver (const Grounds& query)
{ {
shatterAgainstQuery (query);
runWeakBayesBall (query);
while (true) { while (true) {
cout << "---------------------------------------------------" << endl; if (Constants::DEBUG >= 2) {
pfList_.print(); Util::printDashedLine();
LiftedOperator::printValidOps (pfList_, query); pfList_.print();
LiftedOperator::printValidOps (pfList_, query);
}
LiftedOperator* op = getBestOperation (query); LiftedOperator* op = getBestOperation (query);
if (op == 0) { if (op == 0) {
break; break;
} }
cout << "best operation: " << op->toString() << endl; if (Constants::DEBUG >= 2) {
cout << "best operation: " << op->toString() << endl;
}
op->apply(); op->apply();
delete op;
} }
assert (pfList_.size() > 0);
if (pfList_.size() > 1) { if (pfList_.size() > 1) {
ParfactorList::iterator pfIter = pfList_.begin(); ParfactorList::iterator pfIter = pfList_.begin();
pfIter ++; pfIter ++;
@ -514,26 +522,6 @@ FoveSolver::runSolver (const Grounds& query)
bool
FoveSolver::allEliminated (const Grounds&)
{
ParfactorList::iterator pfIter = pfList_.begin();
while (pfIter != pfList_.end()) {
const ProbFormulas formulas = (*pfIter)->formulas();
for (unsigned i = 0; i < formulas.size(); i++) {
//bool toElim = false;
//for (unsigned j = 0; j < queries.size(); j++) {
// if ((*pfIter)->containsGround (queries[i]) == false) {
// return
// }
}
++ pfIter;
}
return false;
}
LiftedOperator* LiftedOperator*
FoveSolver::getBestOperation (const Grounds& query) FoveSolver::getBestOperation (const Grounds& query)
{ {
@ -548,156 +536,176 @@ FoveSolver::getBestOperation (const Grounds& query)
bestCost = cost; bestCost = cost;
} }
} }
for (unsigned i = 0; i < validOps.size(); i++) {
if (validOps[i] != bestOp) {
delete validOps[i];
}
}
return bestOp; return bestOp;
} }
void
FoveSolver::runWeakBayesBall (const Grounds& query)
{
queue<unsigned> todo; // groups to process
set<unsigned> done; // processed or in queue
for (unsigned i = 0; i < query.size(); i++) {
ParfactorList::iterator it = pfList_.begin();
while (it != pfList_.end()) {
int group = (*it)->findGroup (query[i]);
if (group != -1) {
todo.push (group);
done.insert (group);
break;
}
++ it;
}
}
set<Parfactor*> requiredPfs;
while (todo.empty() == false) {
unsigned group = todo.front();
ParfactorList::iterator it = pfList_.begin();
while (it != pfList_.end()) {
if (Util::contains (requiredPfs, *it) == false &&
(*it)->containsGroup (group)) {
vector<unsigned> groups = (*it)->getAllGroups();
for (unsigned i = 0; i < groups.size(); i++) {
if (Util::contains (done, groups[i]) == false) {
todo.push (groups[i]);
done.insert (groups[i]);
}
}
requiredPfs.insert (*it);
}
++ it;
}
todo.pop();
}
ParfactorList::iterator it = pfList_.begin();
while (it != pfList_.end()) {
if (Util::contains (requiredPfs, *it) == false) {
it = pfList_.removeAndDelete (it);
} else {
++ it;
}
}
if (Constants::DEBUG >= 2) {
Util::printHeader ("REQUIRED PARFACTORS");
pfList_.print();
}
}
void void
FoveSolver::shatterAgainstQuery (const Grounds& query) FoveSolver::shatterAgainstQuery (const Grounds& query)
{ {
// return;
for (unsigned i = 0; i < query.size(); i++) { for (unsigned i = 0; i < query.size(); i++) {
if (query[i].isAtom()) { if (query[i].isAtom()) {
continue; continue;
} }
ParfactorList pfListCopy = pfList_; bool found = false;
pfList_.clear(); Parfactors newPfs;
for (ParfactorList::iterator it = pfListCopy.begin(); ParfactorList::iterator it = pfList_.begin();
it != pfListCopy.end(); ++ it) { while (it != pfList_.end()) {
Parfactor* pf = *it; if ((*it)->containsGround (query[i])) {
if (pf->containsGround (query[i])) { found = true;
std::pair<ConstraintTree*, ConstraintTree*> split = std::pair<ConstraintTree*, ConstraintTree*> split =
pf->constr()->split (query[i].args(), query[i].arity()); (*it)->constr()->split (query[i].args(), query[i].arity());
ConstraintTree* commCt = split.first; ConstraintTree* commCt = split.first;
ConstraintTree* exclCt = split.second; ConstraintTree* exclCt = split.second;
pfList_.add (new Parfactor (pf, commCt)); newPfs.push_back (new Parfactor (*it, commCt));
if (exclCt->empty() == false) { if (exclCt->empty() == false) {
pfList_.add (new Parfactor (pf, exclCt)); newPfs.push_back (new Parfactor (*it, exclCt));
} else { } else {
delete exclCt; delete exclCt;
} }
delete pf; it = pfList_.removeAndDelete (it);
} else { } else {
pfList_.add (pf); ++ it;
} }
} }
pfList_.shatter(); if (found == false) {
cerr << "error: could not find a parfactor with ground " ;
cerr << "`" << query[i] << "'" << endl;
exit (0);
}
pfList_.add (newPfs);
} }
cout << endl; if (Constants::DEBUG >= 2) {
cout << "*******************************************************" << endl; cout << endl;
cout << "SHATTERED AGAINST THE QUERY" << endl; Util::printAsteriskLine();
for (unsigned i = 0; i < query.size(); i++) { cout << "SHATTERED AGAINST THE QUERY" << endl;
cout << " -> " << query[i] << endl; for (unsigned i = 0; i < query.size(); i++) {
cout << " -> " << query[i] << endl;
}
Util::printAsteriskLine();
pfList_.print();
} }
cout << "*******************************************************" << endl;
pfList_.print();
} }
bool Parfactors
FoveSolver::absorved ( FoveSolver::absorve (
ParfactorList& pfList, ObservedFormula& obsFormula,
ParfactorList::iterator pfIter, Parfactor* g)
const ObservedFormula* obsFormula)
{ {
Parfactors absorvedPfs; Parfactors absorvedPfs;
Parfactor* g = *pfIter; const ProbFormulas& formulas = g->arguments();
const ProbFormulas& formulas = g->formulas();
for (unsigned i = 0; i < formulas.size(); i++) { for (unsigned i = 0; i < formulas.size(); i++) {
if (obsFormula->functor() == formulas[i].functor() && if (obsFormula.functor() == formulas[i].functor() &&
obsFormula->arity() == formulas[i].arity()) { obsFormula.arity() == formulas[i].arity()) {
if (obsFormula->isAtom()) { if (obsFormula.isAtom()) {
if (formulas.size() > 1) { if (formulas.size() > 1) {
g->absorveEvidence (i, obsFormula->evidence()); g->absorveEvidence (formulas[i], obsFormula.evidence());
} else { } else {
return true; // hack to erase parfactor g
absorvedPfs.push_back (0);
} }
break;
} }
g->constr()->moveToTop (formulas[i].logVars()); g->constr()->moveToTop (formulas[i].logVars());
std::pair<ConstraintTree*, ConstraintTree*> res std::pair<ConstraintTree*, ConstraintTree*> res
= g->constr()->split (obsFormula->constr(), formulas[i].arity()); = g->constr()->split (&(obsFormula.constr()), formulas[i].arity());
ConstraintTree* commCt = res.first; ConstraintTree* commCt = res.first;
ConstraintTree* exclCt = res.second; ConstraintTree* exclCt = res.second;
if (commCt->empty()) { if (commCt->empty() == false) {
delete commCt; if (formulas.size() > 1) {
delete exclCt; LogVarSet excl = g->exclusiveLogVars (i);
continue; Parfactors countNormPfs = countNormalize (g, excl);
} for (unsigned j = 0; j < countNormPfs.size(); j++) {
countNormPfs[j]->absorveEvidence (
if (exclCt->empty() == false) { formulas[i], obsFormula.evidence());
pfList.add (new Parfactor (g, exclCt)); absorvedPfs.push_back (countNormPfs[j]);
} else { }
delete exclCt; } else {
} delete commCt;
if (formulas.size() > 1) {
LogVarSet excl = g->exclusiveLogVars (i);
Parfactors countNormPfs = countNormalize (g, excl);
for (unsigned j = 0; j < countNormPfs.size(); j++) {
countNormPfs[j]->absorveEvidence (i, obsFormula->evidence());
absorvedPfs.push_back (countNormPfs[j]);
} }
if (exclCt->empty() == false) {
absorvedPfs.push_back (new Parfactor (g, exclCt));
} else {
delete exclCt;
}
if (absorvedPfs.empty()) {
// hack to erase parfactor g
absorvedPfs.push_back (0);
}
break;
} else { } else {
delete commCt; delete commCt;
delete exclCt;
} }
return true;
} }
} }
return false; return absorvedPfs;
}
bool
FoveSolver::proper (
const ProbFormula& f1,
ConstraintTree* c1,
const ProbFormula& f2,
ConstraintTree* c2)
{
return disjoint (f1, c1, f2, c2)
|| identical (f1, c1, f2, c2);
}
bool
FoveSolver::identical (
const ProbFormula& f1,
ConstraintTree* c1,
const ProbFormula& f2,
ConstraintTree* c2)
{
if (f1.sameSkeletonAs (f2) == false) {
return false;
}
c1->moveToTop (f1.logVars());
c2->moveToTop (f2.logVars());
return ConstraintTree::identical (
c1, c2, f1.logVars().size());
}
bool
FoveSolver::disjoint (
const ProbFormula& f1,
ConstraintTree* c1,
const ProbFormula& f2,
ConstraintTree* c2)
{
if (f1.sameSkeletonAs (f2) == false) {
return true;
}
c1->moveToTop (f1.logVars());
c2->moveToTop (f2.logVars());
return ConstraintTree::overlap (
c1, c2, f1.arity()) == false;
} }

View File

@ -9,10 +9,14 @@ class LiftedOperator
{ {
public: public:
virtual unsigned getCost (void) = 0; virtual unsigned getCost (void) = 0;
virtual void apply (void) = 0; virtual void apply (void) = 0;
virtual string toString (void) = 0; virtual string toString (void) = 0;
static vector<LiftedOperator*> getValidOps ( static vector<LiftedOperator*> getValidOps (
ParfactorList&, const Grounds&); ParfactorList&, const Grounds&);
static void printValidOps (ParfactorList&, const Grounds&); static void printValidOps (ParfactorList&, const Grounds&);
}; };
@ -23,18 +27,26 @@ class SumOutOperator : public LiftedOperator
public: public:
SumOutOperator (unsigned group, ParfactorList& pfList) SumOutOperator (unsigned group, ParfactorList& pfList)
: group_(group), pfList_(pfList) { } : group_(group), pfList_(pfList) { }
unsigned getCost (void); unsigned getCost (void);
void apply (void); void apply (void);
static vector<SumOutOperator*> getValidOps ( static vector<SumOutOperator*> getValidOps (
ParfactorList&, const Grounds&); ParfactorList&, const Grounds&);
string toString (void); string toString (void);
private: private:
static bool validOp (unsigned, ParfactorList&, const Grounds&); static bool validOp (unsigned, ParfactorList&, const Grounds&);
static vector<ParfactorList::iterator> parfactorsWithGroup ( static vector<ParfactorList::iterator> parfactorsWithGroup (
ParfactorList& pfList, unsigned group); ParfactorList& pfList, unsigned group);
static bool isToEliminate (Parfactor*, unsigned, const Grounds&); static bool isToEliminate (Parfactor*, unsigned, const Grounds&);
unsigned group_;
ParfactorList& pfList_; unsigned group_;
ParfactorList& pfList_;
}; };
@ -47,15 +59,21 @@ class CountingOperator : public LiftedOperator
LogVar X, LogVar X,
ParfactorList& pfList) ParfactorList& pfList)
: pfIter_(pfIter), X_(X), pfList_(pfList) { } : pfIter_(pfIter), X_(X), pfList_(pfList) { }
unsigned getCost (void); unsigned getCost (void);
void apply (void); void apply (void);
static vector<CountingOperator*> getValidOps (ParfactorList&); static vector<CountingOperator*> getValidOps (ParfactorList&);
string toString (void); string toString (void);
private: private:
static bool validOp (Parfactor*, LogVar); static bool validOp (Parfactor*, LogVar);
ParfactorList::iterator pfIter_;
LogVar X_; ParfactorList::iterator pfIter_;
ParfactorList& pfList_; LogVar X_;
ParfactorList& pfList_;
}; };
@ -68,14 +86,19 @@ class GroundOperator : public LiftedOperator
LogVar X, LogVar X,
ParfactorList& pfList) ParfactorList& pfList)
: pfIter_(pfIter), X_(X), pfList_(pfList) { } : pfIter_(pfIter), X_(X), pfList_(pfList) { }
unsigned getCost (void); unsigned getCost (void);
void apply (void); void apply (void);
static vector<GroundOperator*> getValidOps (ParfactorList&); static vector<GroundOperator*> getValidOps (ParfactorList&);
string toString (void); string toString (void);
private: private:
ParfactorList::iterator pfIter_; ParfactorList::iterator pfIter_;
LogVar X_; LogVar X_;
ParfactorList& pfList_; ParfactorList& pfList_;
}; };
@ -83,49 +106,29 @@ class GroundOperator : public LiftedOperator
class FoveSolver class FoveSolver
{ {
public: public:
FoveSolver (const ParfactorList*); FoveSolver (const ParfactorList& pfList) : pfList_(pfList) { }
Params getPosterioriOf (const Ground&); Params getPosterioriOf (const Ground&);
Params getJointDistributionOf (const Grounds&);
static void absorveEvidence ( Params getJointDistributionOf (const Grounds&);
ParfactorList& pfList,
const ObservedFormulas& obsFormulas);
static Parfactors countNormalize (Parfactor*, const LogVarSet&); static void absorveEvidence (
ParfactorList& pfList, ObservedFormulas& obsFormulas);
static Parfactors countNormalize (Parfactor*, const LogVarSet&);
private: private:
void runSolver (const Grounds&); void runSolver (const Grounds&);
bool allEliminated (const Grounds&);
LiftedOperator* getBestOperation (const Grounds&);
void shatterAgainstQuery (const Grounds&);
static bool absorved ( LiftedOperator* getBestOperation (const Grounds&);
ParfactorList& pfList,
ParfactorList::iterator pfIter,
const ObservedFormula*);
public: void runWeakBayesBall (const Grounds&);
static bool proper ( void shatterAgainstQuery (const Grounds&);
const ProbFormula&,
ConstraintTree*,
const ProbFormula&,
ConstraintTree*);
static bool identical ( static Parfactors absorve (ObservedFormula&, Parfactor*);
const ProbFormula&,
ConstraintTree*,
const ProbFormula&,
ConstraintTree*);
static bool disjoint ( ParfactorList pfList_;
const ProbFormula&,
ConstraintTree*,
const ProbFormula&,
ConstraintTree*);
ParfactorList pfList_;
}; };
#endif // HORUS_FOVESOLVER_H #endif // HORUS_FOVESOLVER_H

View File

@ -1,67 +0,0 @@
#ifndef HORUS_GRAPHICALMODEL_H
#define HORUS_GRAPHICALMODEL_H
#include <sstream>
#include "VarNode.h"
#include "Distribution.h"
#include "Horus.h"
using namespace std;
struct VariableInfo
{
VariableInfo (string l, const States& sts)
{
label = l;
states = sts;
}
string label;
States states;
};
class GraphicalModel
{
public:
virtual ~GraphicalModel (void) {};
virtual VarNode* getVariableNode (VarId) const = 0;
virtual VarNodes getVariableNodes (void) const = 0;
virtual void printGraphicalModel (void) const = 0;
static void addVariableInformation (VarId vid, string label,
const States& states)
{
assert (varsInfo_.find (vid) == varsInfo_.end());
varsInfo_.insert (make_pair (vid, VariableInfo (label, states)));
}
static VariableInfo getVariableInformation (VarId vid)
{
assert (varsInfo_.find (vid) != varsInfo_.end());
return varsInfo_.find (vid)->second;
}
static bool variablesHaveInformation (void)
{
return varsInfo_.size() != 0;
}
static void clearVariablesInformation (void)
{
varsInfo_.clear();
}
static void addDistribution (unsigned id, Distribution* dist)
{
distsInfo_[id] = dist;
}
static void updateDistribution (unsigned id, const Params& params)
{
distsInfo_[id]->updateParameters (params);
}
private:
static unordered_map<VarId,VariableInfo> varsInfo_;
static unordered_map<unsigned,Distribution*> distsInfo_;
};
#endif // HORUS_GRAPHICALMODEL_H

View File

@ -84,16 +84,34 @@ HistogramSet::nrHistograms (unsigned N, unsigned R)
unsigned unsigned
HistogramSet::findIndex ( HistogramSet::findIndex (
const Histogram& hist, const Histogram& h,
const vector<Histogram>& histograms) const vector<Histogram>& hists)
{ {
vector<Histogram>::const_iterator it = std::lower_bound ( vector<Histogram>::const_iterator it = std::lower_bound (
histograms.begin(), hists.begin(), hists.end(), h, std::greater<Histogram>());
histograms.end(), assert (it != hists.end() && *it == h);
hist, return std::distance (hists.begin(), it);
std::greater<Histogram>()); }
assert (it != histograms.end() && *it == hist);
return std::distance (histograms.begin(), it);
vector<double>
HistogramSet::getNumAssigns (unsigned N, unsigned R)
{
HistogramSet hs (N, R);
unsigned N_factorial = Util::factorial (N);
unsigned H = hs.nrHistograms();
vector<double> numAssigns;
numAssigns.reserve (H);
for (unsigned h = 0; h < H; h++) {
unsigned prod = 1;
for (unsigned r = 0; r < R; r++) {
prod *= Util::factorial (hs[r]);
}
numAssigns.push_back (LogAware::tl (N_factorial / prod));
hs.nextHistogram();
}
return numAssigns;
} }

View File

@ -26,8 +26,9 @@ class HistogramSet
static unsigned nrHistograms (unsigned, unsigned); static unsigned nrHistograms (unsigned, unsigned);
static unsigned findIndex ( static unsigned findIndex (
const Histogram&, const Histogram&, const vector<Histogram>&);
const vector<Histogram>&);
static vector<double> getNumAssigns (unsigned, unsigned);
friend std::ostream& operator<< (ostream &os, const HistogramSet& hs); friend std::ostream& operator<< (ostream &os, const HistogramSet& hs);

View File

@ -1,17 +1,9 @@
#ifndef HORUS_HORUS_H #ifndef HORUS_HORUS_H
#define HORUS_HORUS_H #define HORUS_HORUS_H
#include <cmath>
#include <cassert>
#include <limits> #include <limits>
#include <algorithm>
#include <vector> #include <vector>
#include <unordered_map>
#include <iostream>
#include <fstream>
#include <sstream>
#define DISALLOW_COPY_AND_ASSIGN(TypeName) \ #define DISALLOW_COPY_AND_ASSIGN(TypeName) \
TypeName(const TypeName&); \ TypeName(const TypeName&); \
@ -19,55 +11,51 @@
using namespace std; using namespace std;
class VarNode; class Var;
class BayesNode;
class FgVarNode;
class FgFacNode;
class Factor; class Factor;
class VarNode;
class FacNode;
typedef vector<double> Params; typedef vector<double> Params;
typedef unsigned VarId; typedef unsigned VarId;
typedef vector<VarId> VarIds; typedef vector<VarId> VarIds;
typedef vector<VarNode*> VarNodes; typedef vector<Var*> Vars;
typedef vector<BayesNode*> BnNodeSet; typedef vector<VarNode*> VarNodes;
typedef vector<FgVarNode*> FgVarSet; typedef vector<FacNode*> FacNodes;
typedef vector<FgFacNode*> FgFacSet; typedef vector<Factor*> Factors;
typedef vector<Factor*> FactorSet; typedef vector<string> States;
typedef vector<string> States; typedef vector<unsigned> Ranges;
typedef vector<unsigned> Ranges;
namespace Globals { enum InfAlgorithms
extern bool logDomain; {
VE, // variable elimination
BP, // belief propagation
CBP // counting belief propagation
}; };
// level of debug information namespace Globals {
static const unsigned DL = 1;
static const int NO_EVIDENCE = -1; extern bool logDomain;
extern InfAlgorithms infAlgorithm;
};
namespace Constants {
// level of debug information
const unsigned DEBUG = 0;
const int NO_EVIDENCE = -1;
// number of digits to show when printing a parameter // number of digits to show when printing a parameter
static const unsigned PRECISION = 5; const unsigned PRECISION = 5;
static const bool COLLECT_STATISTICS = false; const bool COLLECT_STATS = false;
static const bool EXPORT_TO_GRAPHVIZ = false;
static const unsigned EXPORT_MINIMAL_SIZE = 100;
static const double INF = -numeric_limits<double>::infinity();
namespace InfAlgorithms {
enum InfAlgs
{
VE, // variable elimination
BN_BP, // bayesian network belief propagation
FG_BP, // factor graph belief propagation
CBP // counting bp solver
};
extern InfAlgs infAlgorithm;
}; };

View File

@ -3,197 +3,89 @@
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
#include "BayesNet.h"
#include "FactorGraph.h" #include "FactorGraph.h"
#include "VarElimSolver.h" #include "VarElimSolver.h"
#include "BnBpSolver.h" #include "BpSolver.h"
#include "FgBpSolver.h"
#include "CbpSolver.h" #include "CbpSolver.h"
//#include "TinySet.h"
#include "LiftedUtils.h"
using namespace std; using namespace std;
void processArguments (BayesNet&, int, const char* []);
void processArguments (FactorGraph&, int, const char* []); void processArguments (FactorGraph&, int, const char* []);
void runSolver (Solver*, const VarNodes&); void runSolver (const FactorGraph&, const VarIds&);
const string USAGE = "usage: \ const string USAGE = "usage: \
./hcli FILE [VARIABLE | OBSERVED_VARIABLE=EVIDENCE]..." ; ./hcli ve|bp|cbp NETWORK_FILE [VARIABLE | OBSERVED_VARIABLE=EVIDENCE]..." ;
class Cenas
{
public:
Cenas (int cc)
{
c = cc;
}
//operator int (void) const
//{
// cout << "return int" << endl;
// return c;
//}
operator double (void) const
{
cout << "return double" << endl;
return 0.0;
}
private:
int c;
};
int int
main (int argc, const char* argv[]) main (int argc, const char* argv[])
{ {
LogVar X = 3; if (argc <= 1) {
LogVarSet Xs = X; cerr << "error: no solver specified" << endl;
cout << "set: " << X << endl;
Cenas c1 (1);
Cenas c2 (3);
cout << (c1 < c2) << endl;
return 0;
if (!argv[1]) {
cerr << "error: no graphical model specified" << endl; cerr << "error: no graphical model specified" << endl;
cerr << USAGE << endl; cerr << USAGE << endl;
exit (0); exit (0);
} }
const string& fileName = argv[1]; if (argc <= 2) {
const string& extension = fileName.substr (fileName.find_last_of ('.') + 1); cerr << "error: no graphical model specified" << endl;
if (extension == "xml") { cerr << USAGE << endl;
BayesNet bn;
bn.readFromBifFormat (argv[1]);
processArguments (bn, argc, argv);
} else if (extension == "uai") {
FactorGraph fg;
fg.readFromUaiFormat (argv[1]);
processArguments (fg, argc, argv);
} else if (extension == "fg") {
FactorGraph fg;
fg.readFromLibDaiFormat (argv[1]);
processArguments (fg, argc, argv);
} else {
cerr << "error: the graphical model must be defined either " ;
cerr << "in a xml, uai or libDAI file" << endl;
exit (0); exit (0);
} }
string solver (argv[1]);
if (solver == "ve") {
Globals::infAlgorithm = InfAlgorithms::VE;
} else if (solver == "bp") {
Globals::infAlgorithm = InfAlgorithms::BP;
} else if (solver == "cbp") {
Globals::infAlgorithm = InfAlgorithms::CBP;
} else {
cerr << "error: unknow solver `" << solver << "'" << endl ;
cerr << USAGE << endl;
exit(0);
}
string fileName (argv[2]);
string extension = fileName.substr (
fileName.find_last_of ('.') + 1);
FactorGraph fg;
if (extension == "uai") {
fg.readFromUaiFormat (fileName.c_str());
} else if (extension == "fg") {
fg.readFromLibDaiFormat (fileName.c_str());
} else {
cerr << "error: the graphical model must be defined either " ;
cerr << "in a UAI or libDAI file" << endl;
exit (0);
}
processArguments (fg, argc, argv);
return 0; return 0;
} }
void
processArguments (BayesNet& bn, int argc, const char* argv[])
{
VarNodes queryVars;
for (int i = 2; i < argc; i++) {
const string& arg = argv[i];
if (arg.find ('=') == std::string::npos) {
BayesNode* queryVar = bn.getBayesNode (arg);
if (queryVar) {
queryVars.push_back (queryVar);
} else {
cerr << "error: there isn't a variable labeled of " ;
cerr << "`" << arg << "'" ;
cerr << endl;
bn.freeDistributions();
exit (0);
}
} else {
size_t pos = arg.find ('=');
const string& label = arg.substr (0, pos);
const string& state = arg.substr (pos + 1);
if (label.empty()) {
cerr << "error: missing left argument" << endl;
cerr << USAGE << endl;
bn.freeDistributions();
exit (0);
}
if (state.empty()) {
cerr << "error: missing right argument" << endl;
cerr << USAGE << endl;
bn.freeDistributions();
exit (0);
}
BayesNode* node = bn.getBayesNode (label);
if (node) {
if (node->isValidState (state)) {
node->setEvidence (state);
} else {
cerr << "error: `" << state << "' " ;
cerr << "is not a valid state for " ;
cerr << "`" << node->label() << "'" ;
cerr << endl;
bn.freeDistributions();
exit (0);
}
} else {
cerr << "error: there isn't a variable labeled of " ;
cerr << "`" << label << "'" ;
cerr << endl;
bn.freeDistributions();
exit (0);
}
}
}
Solver* solver = 0;
FactorGraph* fg = 0;
switch (InfAlgorithms::infAlgorithm) {
case InfAlgorithms::VE:
fg = new FactorGraph (bn);
solver = new VarElimSolver (*fg);
break;
case InfAlgorithms::BN_BP:
solver = new BnBpSolver (bn);
break;
case InfAlgorithms::FG_BP:
fg = new FactorGraph (bn);
solver = new FgBpSolver (*fg);
break;
case InfAlgorithms::CBP:
fg = new FactorGraph (bn);
solver = new CbpSolver (*fg);
break;
default:
assert (false);
}
runSolver (solver, queryVars);
delete fg;
bn.freeDistributions();
}
void void
processArguments (FactorGraph& fg, int argc, const char* argv[]) processArguments (FactorGraph& fg, int argc, const char* argv[])
{ {
VarNodes queryVars; VarIds queryIds;
for (int i = 2; i < argc; i++) { for (int i = 3; i < argc; i++) {
const string& arg = argv[i]; const string& arg = argv[i];
if (arg.find ('=') == std::string::npos) { if (arg.find ('=') == std::string::npos) {
if (!Util::isInteger (arg)) { if (!Util::isInteger (arg)) {
cerr << "error: `" << arg << "' " ; cerr << "error: `" << arg << "' " ;
cerr << "is not a valid variable id" ; cerr << "is not a valid variable id" ;
cerr << endl; cerr << endl;
fg.freeDistributions();
exit (0); exit (0);
} }
VarId vid; VarId vid;
stringstream ss; stringstream ss;
ss << arg; ss << arg;
ss >> vid; ss >> vid;
VarNode* queryVar = fg.getFgVarNode (vid); VarNode* queryVar = fg.getVarNode (vid);
if (queryVar) { if (queryVar) {
queryVars.push_back (queryVar); queryIds.push_back (vid);
} else { } else {
cerr << "error: there isn't a variable with " ; cerr << "error: there isn't a variable with " ;
cerr << "`" << vid << "' as id" ; cerr << "`" << vid << "' as id" ;
cerr << endl; cerr << endl;
fg.freeDistributions();
exit (0); exit (0);
} }
} else { } else {
@ -201,33 +93,29 @@ processArguments (FactorGraph& fg, int argc, const char* argv[])
if (arg.substr (0, pos).empty()) { if (arg.substr (0, pos).empty()) {
cerr << "error: missing left argument" << endl; cerr << "error: missing left argument" << endl;
cerr << USAGE << endl; cerr << USAGE << endl;
fg.freeDistributions();
exit (0); exit (0);
} }
if (arg.substr (pos + 1).empty()) { if (arg.substr (pos + 1).empty()) {
cerr << "error: missing right argument" << endl; cerr << "error: missing right argument" << endl;
cerr << USAGE << endl; cerr << USAGE << endl;
fg.freeDistributions();
exit (0); exit (0);
} }
if (!Util::isInteger (arg.substr (0, pos))) { if (!Util::isInteger (arg.substr (0, pos))) {
cerr << "error: `" << arg.substr (0, pos) << "' " ; cerr << "error: `" << arg.substr (0, pos) << "' " ;
cerr << "is not a variable id" ; cerr << "is not a variable id" ;
cerr << endl; cerr << endl;
fg.freeDistributions();
exit (0); exit (0);
} }
VarId vid; VarId vid;
stringstream ss; stringstream ss;
ss << arg.substr (0, pos); ss << arg.substr (0, pos);
ss >> vid; ss >> vid;
VarNode* var = fg.getFgVarNode (vid); VarNode* var = fg.getVarNode (vid);
if (var) { if (var) {
if (!Util::isInteger (arg.substr (pos + 1))) { if (!Util::isInteger (arg.substr (pos + 1))) {
cerr << "error: `" << arg.substr (pos + 1) << "' " ; cerr << "error: `" << arg.substr (pos + 1) << "' " ;
cerr << "is not a state index" ; cerr << "is not a state index" ;
cerr << endl; cerr << endl;
fg.freeDistributions();
exit (0); exit (0);
} }
int stateIndex; int stateIndex;
@ -241,29 +129,31 @@ processArguments (FactorGraph& fg, int argc, const char* argv[])
cerr << "is not a valid state index for variable " ; cerr << "is not a valid state index for variable " ;
cerr << "`" << var->varId() << "'" ; cerr << "`" << var->varId() << "'" ;
cerr << endl; cerr << endl;
fg.freeDistributions();
exit (0); exit (0);
} }
} else { } else {
cerr << "error: there isn't a variable with " ; cerr << "error: there isn't a variable with " ;
cerr << "`" << vid << "' as id" ; cerr << "`" << vid << "' as id" ;
cerr << endl; cerr << endl;
fg.freeDistributions();
exit (0); exit (0);
} }
} }
} }
runSolver (fg, queryIds);
}
void
runSolver (const FactorGraph& fg, const VarIds& queryIds)
{
Solver* solver = 0; Solver* solver = 0;
switch (InfAlgorithms::infAlgorithm) { switch (Globals::infAlgorithm) {
case InfAlgorithms::VE: case InfAlgorithms::VE:
solver = new VarElimSolver (fg); solver = new VarElimSolver (fg);
break; break;
case InfAlgorithms::BN_BP: case InfAlgorithms::BP:
case InfAlgorithms::FG_BP: solver = new BpSolver (fg);
//cout << "here!" << endl;
//fg.printGraphicalModel();
//fg.exportToLibDaiFormat ("net.fg");
solver = new FgBpSolver (fg);
break; break;
case InfAlgorithms::CBP: case InfAlgorithms::CBP:
solver = new CbpSolver (fg); solver = new CbpSolver (fg);
@ -271,28 +161,10 @@ processArguments (FactorGraph& fg, int argc, const char* argv[])
default: default:
assert (false); assert (false);
} }
runSolver (solver, queryVars); if (queryIds.size() == 0) {
fg.freeDistributions();
}
void
runSolver (Solver* solver, const VarNodes& queryVars)
{
VarIds vids;
for (unsigned i = 0; i < queryVars.size(); i++) {
vids.push_back (queryVars[i]->varId());
}
if (queryVars.size() == 0) {
solver->runSolver();
solver->printAllPosterioris(); solver->printAllPosterioris();
} else if (queryVars.size() == 1) {
solver->runSolver();
solver->printPosterioriOf (vids[0]);
} else { } else {
solver->runSolver(); solver->printAnswer (queryIds);
solver->printJointDistributionOf (vids);
} }
delete solver; delete solver;
} }

View File

@ -7,22 +7,50 @@
#include <YapInterface.h> #include <YapInterface.h>
#include "BayesNet.h" #include "ParfactorList.h"
#include "FactorGraph.h" #include "FactorGraph.h"
#include "FoveSolver.h"
#include "VarElimSolver.h" #include "VarElimSolver.h"
#include "BnBpSolver.h" #include "BpSolver.h"
#include "FgBpSolver.h"
#include "CbpSolver.h" #include "CbpSolver.h"
#include "ElimGraph.h" #include "ElimGraph.h"
#include "FoveSolver.h" #include "BayesBall.h"
#include "ParfactorList.h"
using namespace std; using namespace std;
typedef std::pair<ParfactorList*, ObservedFormulas*> LiftedNetwork;
Params readParameters (YAP_Term);
vector<unsigned> readUnsignedList (YAP_Term);
void readLiftedEvidence (YAP_Term, ObservedFormulas&);
Parfactor* readParfactor (YAP_Term);
void runVeSolver (FactorGraph* fg, const vector<VarIds>& tasks,
vector<Params>& results);
void runBpSolver (FactorGraph* fg, const vector<VarIds>& tasks,
vector<Params>& results);
vector<unsigned>
readUnsignedList (YAP_Term list)
{
vector<unsigned> vec;
while (list != YAP_TermNil()) {
vec.push_back ((unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (list)));
list = YAP_TailOfTerm (list);
}
return vec;
}
Params readParams (YAP_Term);
int createLiftedNetwork (void) int createLiftedNetwork (void)
@ -30,107 +58,121 @@ int createLiftedNetwork (void)
Parfactors parfactors; Parfactors parfactors;
YAP_Term parfactorList = YAP_ARG1; YAP_Term parfactorList = YAP_ARG1;
while (parfactorList != YAP_TermNil()) { while (parfactorList != YAP_TermNil()) {
YAP_Term parfactor = YAP_HeadOfTerm (parfactorList); YAP_Term pfTerm = YAP_HeadOfTerm (parfactorList);
parfactors.push_back (readParfactor (pfTerm));
// read dist id
unsigned distId = YAP_IntOfTerm (YAP_ArgOfTerm (1, parfactor));
// read the ranges
Ranges ranges;
YAP_Term rangeList = YAP_ArgOfTerm (3, parfactor);
while (rangeList != YAP_TermNil()) {
unsigned range = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (rangeList));
ranges.push_back (range);
rangeList = YAP_TailOfTerm (rangeList);
}
// read parametric random vars
ProbFormulas formulas;
unsigned count = 0;
unordered_map<YAP_Term, LogVar> lvMap;
YAP_Term pvList = YAP_ArgOfTerm (2, parfactor);
while (pvList != YAP_TermNil()) {
YAP_Term formulaTerm = YAP_HeadOfTerm (pvList);
if (YAP_IsAtomTerm (formulaTerm)) {
string name ((char*) YAP_AtomName (YAP_AtomOfTerm (formulaTerm)));
Symbol functor = LiftedUtils::getSymbol (name);
formulas.push_back (ProbFormula (functor, ranges[count]));
} else {
LogVars logVars;
YAP_Functor yapFunctor = YAP_FunctorOfTerm (formulaTerm);
string name ((char*) YAP_AtomName (YAP_NameOfFunctor (yapFunctor)));
Symbol functor = LiftedUtils::getSymbol (name);
unsigned arity = (unsigned) YAP_ArityOfFunctor (yapFunctor);
for (unsigned i = 1; i <= arity; i++) {
YAP_Term ti = YAP_ArgOfTerm (i, formulaTerm);
unordered_map<YAP_Term, LogVar>::iterator it = lvMap.find (ti);
if (it != lvMap.end()) {
logVars.push_back (it->second);
} else {
unsigned newLv = lvMap.size();
lvMap[ti] = newLv;
logVars.push_back (newLv);
}
}
formulas.push_back (ProbFormula (functor, logVars, ranges[count]));
}
count ++;
pvList = YAP_TailOfTerm (pvList);
}
// read the parameters
const Params& params = readParams (YAP_ArgOfTerm (4, parfactor));
// read the constraint
Tuples tuples;
if (lvMap.size() >= 1) {
YAP_Term tupleList = YAP_ArgOfTerm (5, parfactor);
while (tupleList != YAP_TermNil()) {
YAP_Term term = YAP_HeadOfTerm (tupleList);
assert (YAP_IsApplTerm (term));
YAP_Functor yapFunctor = YAP_FunctorOfTerm (term);
unsigned arity = (unsigned) YAP_ArityOfFunctor (yapFunctor);
assert (lvMap.size() == arity);
Tuple tuple (arity);
for (unsigned i = 1; i <= arity; i++) {
YAP_Term ti = YAP_ArgOfTerm (i, term);
if (YAP_IsAtomTerm (ti) == false) {
cerr << "error: bad formed constraint" << endl;
abort();
}
string name ((char*) YAP_AtomName (YAP_AtomOfTerm (ti)));
tuple[i - 1] = LiftedUtils::getSymbol (name);
}
tuples.push_back (tuple);
tupleList = YAP_TailOfTerm (tupleList);
}
}
parfactors.push_back (new Parfactor (formulas, params, tuples, distId));
parfactorList = YAP_TailOfTerm (parfactorList); parfactorList = YAP_TailOfTerm (parfactorList);
} }
// LiftedUtils::printSymbolDictionary(); // LiftedUtils::printSymbolDictionary();
cout << "*******************************************************" << endl; if (Constants::DEBUG > 2) {
cout << "INITIAL PARFACTORS" << endl; // Util::printHeader ("INITIAL PARFACTORS");
cout << "*******************************************************" << endl; // for (unsigned i = 0; i < parfactors.size(); i++) {
for (unsigned i = 0; i < parfactors.size(); i++) { // parfactors[i]->print();
parfactors[i]->print(); // }
cout << endl;
} }
ParfactorList* pfList = new ParfactorList();
for (unsigned i = 0; i < parfactors.size(); i++) {
pfList->add (parfactors[i]);
}
cout << endl;
cout << "*******************************************************" << endl;
cout << "SHATTERED PARFACTORS" << endl;
cout << "*******************************************************" << endl;
pfList->shatter();
pfList->print();
// insert the evidence ParfactorList* pfList = new ParfactorList (parfactors);
ObservedFormulas obsFormulas;
YAP_Term observedList = YAP_ARG2; if (Constants::DEBUG >= 2) {
Util::printHeader ("SHATTERED PARFACTORS");
pfList->print();
}
// read evidence
ObservedFormulas* obsFormulas = new ObservedFormulas();
readLiftedEvidence (YAP_ARG2, *(obsFormulas));
LiftedNetwork* net = new LiftedNetwork (pfList, obsFormulas);
YAP_Int p = (YAP_Int) (net);
return YAP_Unify (YAP_MkIntTerm (p), YAP_ARG3);
}
Parfactor* readParfactor (YAP_Term pfTerm)
{
// read dist id
unsigned distId = YAP_IntOfTerm (YAP_ArgOfTerm (1, pfTerm));
// read the ranges
Ranges ranges;
YAP_Term rangeList = YAP_ArgOfTerm (3, pfTerm);
while (rangeList != YAP_TermNil()) {
unsigned range = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (rangeList));
ranges.push_back (range);
rangeList = YAP_TailOfTerm (rangeList);
}
// read parametric random vars
ProbFormulas formulas;
unsigned count = 0;
unordered_map<YAP_Term, LogVar> lvMap;
YAP_Term pvList = YAP_ArgOfTerm (2, pfTerm);
while (pvList != YAP_TermNil()) {
YAP_Term formulaTerm = YAP_HeadOfTerm (pvList);
if (YAP_IsAtomTerm (formulaTerm)) {
string name ((char*) YAP_AtomName (YAP_AtomOfTerm (formulaTerm)));
Symbol functor = LiftedUtils::getSymbol (name);
formulas.push_back (ProbFormula (functor, ranges[count]));
} else {
LogVars logVars;
YAP_Functor yapFunctor = YAP_FunctorOfTerm (formulaTerm);
string name ((char*) YAP_AtomName (YAP_NameOfFunctor (yapFunctor)));
Symbol functor = LiftedUtils::getSymbol (name);
unsigned arity = (unsigned) YAP_ArityOfFunctor (yapFunctor);
for (unsigned i = 1; i <= arity; i++) {
YAP_Term ti = YAP_ArgOfTerm (i, formulaTerm);
unordered_map<YAP_Term, LogVar>::iterator it = lvMap.find (ti);
if (it != lvMap.end()) {
logVars.push_back (it->second);
} else {
unsigned newLv = lvMap.size();
lvMap[ti] = newLv;
logVars.push_back (newLv);
}
}
formulas.push_back (ProbFormula (functor, logVars, ranges[count]));
}
count ++;
pvList = YAP_TailOfTerm (pvList);
}
// read the parameters
const Params& params = readParameters (YAP_ArgOfTerm (4, pfTerm));
// read the constraint
Tuples tuples;
if (lvMap.size() >= 1) {
YAP_Term tupleList = YAP_ArgOfTerm (5, pfTerm);
while (tupleList != YAP_TermNil()) {
YAP_Term term = YAP_HeadOfTerm (tupleList);
assert (YAP_IsApplTerm (term));
YAP_Functor yapFunctor = YAP_FunctorOfTerm (term);
unsigned arity = (unsigned) YAP_ArityOfFunctor (yapFunctor);
assert (lvMap.size() == arity);
Tuple tuple (arity);
for (unsigned i = 1; i <= arity; i++) {
YAP_Term ti = YAP_ArgOfTerm (i, term);
if (YAP_IsAtomTerm (ti) == false) {
cerr << "error: constraint has free variables" << endl;
abort();
}
string name ((char*) YAP_AtomName (YAP_AtomOfTerm (ti)));
tuple[i - 1] = LiftedUtils::getSymbol (name);
}
tuples.push_back (tuple);
tupleList = YAP_TailOfTerm (tupleList);
}
}
return new Parfactor (formulas, params, tuples, distId);
}
void readLiftedEvidence (
YAP_Term observedList,
ObservedFormulas& obsFormulas)
{
while (observedList != YAP_TermNil()) { while (observedList != YAP_TermNil()) {
YAP_Term pair = YAP_HeadOfTerm (observedList); YAP_Term pair = YAP_HeadOfTerm (observedList);
YAP_Term ground = YAP_ArgOfTerm (1, pair); YAP_Term ground = YAP_ArgOfTerm (1, pair);
@ -155,22 +197,18 @@ int createLiftedNetwork (void)
unsigned evidence = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (2, pair)); unsigned evidence = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (2, pair));
bool found = false; bool found = false;
for (unsigned i = 0; i < obsFormulas.size(); i++) { for (unsigned i = 0; i < obsFormulas.size(); i++) {
if (obsFormulas[i]->functor() == functor && if (obsFormulas[i].functor() == functor &&
obsFormulas[i]->arity() == args.size() && obsFormulas[i].arity() == args.size() &&
obsFormulas[i]->evidence() == evidence) { obsFormulas[i].evidence() == evidence) {
obsFormulas[i]->addTuple (args); obsFormulas[i].addTuple (args);
found = true; found = true;
} }
} }
if (found == false) { if (found == false) {
obsFormulas.push_back (new ObservedFormula (functor, evidence, args)); obsFormulas.push_back (ObservedFormula (functor, evidence, args));
} }
observedList = YAP_TailOfTerm (observedList); observedList = YAP_TailOfTerm (observedList);
} }
FoveSolver::absorveEvidence (*pfList, obsFormulas);
YAP_Int p = (YAP_Int) (pfList);
return YAP_Unify (YAP_MkIntTerm (p), YAP_ARG3);
} }
@ -178,93 +216,46 @@ int createLiftedNetwork (void)
int int
createGroundNetwork (void) createGroundNetwork (void)
{ {
Statistics::incrementPrimaryNetworksCounting(); string factorsType ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG1)));
// cout << "creating network number " ; bool fromBayesNet = factorsType == "bayes";
// cout << Statistics::getPrimaryNetworksCounting() << endl; FactorGraph* fg = new FactorGraph (fromBayesNet);
// if (Statistics::getPrimaryNetworksCounting() > 98) { YAP_Term factorList = YAP_ARG2;
// Statistics::writeStatisticsToFile ("../../compressing.stats"); while (factorList != YAP_TermNil()) {
// } YAP_Term factor = YAP_HeadOfTerm (factorList);
BayesNet* bn = new BayesNet(); // read the var ids
YAP_Term varList = YAP_ARG1; VarIds varIds = readUnsignedList (YAP_ArgOfTerm (1, factor));
BnNodeSet nodes; // read the ranges
vector<VarIds> parents; Ranges ranges = readUnsignedList (YAP_ArgOfTerm (2, factor));
while (varList != YAP_TermNil()) { // read the parameters
YAP_Term var = YAP_HeadOfTerm (varList); Params params = readParameters (YAP_ArgOfTerm (3, factor));
VarId vid = (VarId) YAP_IntOfTerm (YAP_ArgOfTerm (1, var)); // read dist id
unsigned dsize = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (2, var)); unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (4, factor));
int evidence = (int) YAP_IntOfTerm (YAP_ArgOfTerm (3, var)); fg->addFactor (Factor (varIds, ranges, params, distId));
YAP_Term parentL = YAP_ArgOfTerm (4, var); factorList = YAP_TailOfTerm (factorList);
unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (5, var));
parents.push_back (VarIds());
while (parentL != YAP_TermNil()) {
unsigned parentId = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (parentL));
parents.back().push_back (parentId);
parentL = YAP_TailOfTerm (parentL);
}
Distribution* dist = bn->getDistribution (distId);
if (!dist) {
dist = new Distribution (distId);
bn->addDistribution (dist);
}
assert (bn->getBayesNode (vid) == 0);
nodes.push_back (bn->addNode (vid, dsize, evidence, dist));
varList = YAP_TailOfTerm (varList);
} }
for (unsigned i = 0; i < nodes.size(); i++) {
BnNodeSet ps; YAP_Term evidenceList = YAP_ARG3;
for (unsigned j = 0; j < parents[i].size(); j++) { while (evidenceList != YAP_TermNil()) {
assert (bn->getBayesNode (parents[i][j]) != 0); YAP_Term evTerm = YAP_HeadOfTerm (evidenceList);
ps.push_back (bn->getBayesNode (parents[i][j])); unsigned vid = (unsigned) YAP_IntOfTerm ((YAP_ArgOfTerm (1, evTerm)));
} unsigned ev = (unsigned) YAP_IntOfTerm ((YAP_ArgOfTerm (2, evTerm)));
nodes[i]->setParents (ps); assert (fg->getVarNode (vid));
fg->getVarNode (vid)->setEvidence (ev);
evidenceList = YAP_TailOfTerm (evidenceList);
} }
bn->setIndexes();
YAP_Int p = (YAP_Int) (bn);
return YAP_Unify (YAP_MkIntTerm (p), YAP_ARG2);
}
YAP_Int p = (YAP_Int) (fg);
return YAP_Unify (YAP_MkIntTerm (p), YAP_ARG4);
int
setBayesNetParams (void)
{
BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
YAP_Term distList = YAP_ARG2;
while (distList != YAP_TermNil()) {
YAP_Term dist = YAP_HeadOfTerm (distList);
unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (1, dist));
const Params params = readParams (YAP_ArgOfTerm (2, dist));
bn->getDistribution(distId)->updateParameters (params);
distList = YAP_TailOfTerm (distList);
}
return TRUE;
}
int
setParfactorGraphParams (void)
{
// FIXME
// ParfactorGraph* pfg = (ParfactorGraph*) YAP_IntOfTerm (YAP_ARG1);
YAP_Term distList = YAP_ARG2;
while (distList != YAP_TermNil()) {
// YAP_Term dist = YAP_HeadOfTerm (distList);
// unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (1, dist));
// const Params params = readParams (YAP_ArgOfTerm (2, dist));
// pfg->getDistribution(distId)->setData (params);
distList = YAP_TailOfTerm (distList);
}
return TRUE;
} }
Params Params
readParams (YAP_Term paramL) readParameters (YAP_Term paramL)
{ {
Params params; Params params;
while (paramL!= YAP_TermNil()) { assert (YAP_IsPairTerm (paramL));
while (paramL != YAP_TermNil()) {
params.push_back ((double) YAP_FloatOfTerm (YAP_HeadOfTerm (paramL))); params.push_back ((double) YAP_FloatOfTerm (YAP_HeadOfTerm (paramL)));
paramL = YAP_TailOfTerm (paramL); paramL = YAP_TailOfTerm (paramL);
} }
@ -279,15 +270,14 @@ readParams (YAP_Term paramL)
int int
runLiftedSolver (void) runLiftedSolver (void)
{ {
ParfactorList* pfList = (ParfactorList*) YAP_IntOfTerm (YAP_ARG1); LiftedNetwork* network = (LiftedNetwork*) YAP_IntOfTerm (YAP_ARG1);
YAP_Term taskList = YAP_ARG2; YAP_Term taskList = YAP_ARG2;
vector<Params> results; vector<Params> results;
ParfactorList pfListCopy (*network->first);
FoveSolver::absorveEvidence (pfListCopy, *network->second);
while (taskList != YAP_TermNil()) { while (taskList != YAP_TermNil()) {
YAP_Term jointList = YAP_HeadOfTerm (taskList);
Grounds queryVars; Grounds queryVars;
assert (YAP_IsPairTerm (taskList)); YAP_Term jointList = YAP_HeadOfTerm (taskList);
assert (YAP_IsPairTerm (jointList));
while (jointList != YAP_TermNil()) { while (jointList != YAP_TermNil()) {
YAP_Term ground = YAP_HeadOfTerm (jointList); YAP_Term ground = YAP_HeadOfTerm (jointList);
if (YAP_IsAtomTerm (ground)) { if (YAP_IsAtomTerm (ground)) {
@ -310,11 +300,11 @@ runLiftedSolver (void)
} }
jointList = YAP_TailOfTerm (jointList); jointList = YAP_TailOfTerm (jointList);
} }
FoveSolver solver (pfList); FoveSolver solver (pfListCopy);
if (queryVars.size() == 1) { if (queryVars.size() == 1) {
results.push_back (solver.getPosterioriOf (queryVars[0])); results.push_back (solver.getPosterioriOf (queryVars[0]));
} else { } else {
assert (false); // TODO joint dist results.push_back (solver.getJointDistributionOf (queryVars));
} }
taskList = YAP_TailOfTerm (taskList); taskList = YAP_TailOfTerm (taskList);
} }
@ -339,77 +329,23 @@ runLiftedSolver (void)
int int
runOtherSolvers (void) runGroundSolver (void)
{ {
BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1); FactorGraph* fg = (FactorGraph*) YAP_IntOfTerm (YAP_ARG1);
YAP_Term taskList = YAP_ARG2;
vector<VarIds> tasks; vector<VarIds> tasks;
std::set<VarId> vids; YAP_Term taskList = YAP_ARG2;
while (taskList != YAP_TermNil()) { while (taskList != YAP_TermNil()) {
if (YAP_IsPairTerm (YAP_HeadOfTerm (taskList))) { tasks.push_back (readUnsignedList (YAP_HeadOfTerm (taskList)));
tasks.push_back (VarIds());
YAP_Term jointList = YAP_HeadOfTerm (taskList);
while (jointList != YAP_TermNil()) {
VarId vid = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (jointList));
assert (bn->getBayesNode (vid));
tasks.back().push_back (vid);
vids.insert (vid);
jointList = YAP_TailOfTerm (jointList);
}
} else {
VarId vid = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (taskList));
assert (bn->getBayesNode (vid));
tasks.push_back (VarIds() = {vid});
vids.insert (vid);
}
taskList = YAP_TailOfTerm (taskList); taskList = YAP_TailOfTerm (taskList);
} }
Solver* bpSolver = 0;
GraphicalModel* graphicalModel = 0;
CFactorGraph::checkForIdenticalFactors = false;
if (InfAlgorithms::infAlgorithm != InfAlgorithms::VE) {
BayesNet* mrn = bn->getMinimalRequesiteNetwork (
VarIds (vids.begin(), vids.end()));
if (InfAlgorithms::infAlgorithm == InfAlgorithms::BN_BP) {
graphicalModel = mrn;
bpSolver = new BnBpSolver (*static_cast<BayesNet*> (graphicalModel));
} else if (InfAlgorithms::infAlgorithm == InfAlgorithms::FG_BP) {
graphicalModel = new FactorGraph (*mrn);
bpSolver = new FgBpSolver (*static_cast<FactorGraph*> (graphicalModel));
delete mrn;
} else if (InfAlgorithms::infAlgorithm == InfAlgorithms::CBP) {
graphicalModel = new FactorGraph (*mrn);
bpSolver = new CbpSolver (*static_cast<FactorGraph*> (graphicalModel));
delete mrn;
}
bpSolver->runSolver();
}
vector<Params> results; vector<Params> results;
results.reserve (tasks.size()); if (Globals::infAlgorithm == InfAlgorithms::VE) {
for (unsigned i = 0; i < tasks.size(); i++) { runVeSolver (fg, tasks, results);
//if (i == 1) exit (0); } else {
if (InfAlgorithms::infAlgorithm == InfAlgorithms::VE) { runBpSolver (fg, tasks, results);
BayesNet* mrn = bn->getMinimalRequesiteNetwork (tasks[i]);
VarElimSolver* veSolver = new VarElimSolver (*mrn);
if (tasks[i].size() == 1) {
results.push_back (veSolver->getPosterioriOf (tasks[i][0]));
} else {
results.push_back (veSolver->getJointDistributionOf (tasks[i]));
}
delete mrn;
delete veSolver;
} else {
if (tasks[i].size() == 1) {
results.push_back (bpSolver->getPosterioriOf (tasks[i][0]));
} else {
results.push_back (bpSolver->getJointDistributionOf (tasks[i]));
}
}
} }
delete bpSolver;
delete graphicalModel;
YAP_Term list = YAP_TermNil(); YAP_Term list = YAP_TermNil();
for (int i = results.size() - 1; i >= 0; i--) { for (int i = results.size() - 1; i >= 0; i--) {
@ -424,32 +360,142 @@ runOtherSolvers (void)
} }
list = YAP_MkPairTerm (queryBeliefsL, list); list = YAP_MkPairTerm (queryBeliefsL, list);
} }
return YAP_Unify (list, YAP_ARG3); return YAP_Unify (list, YAP_ARG3);
} }
int void runVeSolver (
setExtraVarsInfo (void) FactorGraph* fg,
const vector<VarIds>& tasks,
vector<Params>& results)
{ {
// BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1); results.reserve (tasks.size());
GraphicalModel::clearVariablesInformation(); for (unsigned i = 0; i < tasks.size(); i++) {
YAP_Term varsInfoL = YAP_ARG2; FactorGraph* mfg = fg;
while (varsInfoL != YAP_TermNil()) { if (fg->isFromBayesNetwork()) {
YAP_Term head = YAP_HeadOfTerm (varsInfoL); mfg = BayesBall::getMinimalFactorGraph (*fg, tasks[i]);
VarId vid = YAP_IntOfTerm (YAP_ArgOfTerm (1, head));
YAP_Atom label = YAP_AtomOfTerm (YAP_ArgOfTerm (2, head));
YAP_Term statesL = YAP_ArgOfTerm (3, head);
States states;
while (statesL != YAP_TermNil()) {
YAP_Atom atom = YAP_AtomOfTerm (YAP_HeadOfTerm (statesL));
states.push_back ((char*) YAP_AtomName (atom));
statesL = YAP_TailOfTerm (statesL);
} }
GraphicalModel::addVariableInformation (vid, VarElimSolver solver (*mfg);
(char*) YAP_AtomName (label), states); results.push_back (solver.solveQuery (tasks[i]));
varsInfoL = YAP_TailOfTerm (varsInfoL); if (fg->isFromBayesNetwork()) {
delete mfg;
}
}
}
void runBpSolver (
FactorGraph* fg,
const vector<VarIds>& tasks,
vector<Params>& results)
{
std::set<VarId> vids;
for (unsigned i = 0; i < tasks.size(); i++) {
Util::addToSet (vids, tasks[i]);
}
Solver* solver = 0;
FactorGraph* mfg = fg;
if (fg->isFromBayesNetwork()) {
mfg = BayesBall::getMinimalFactorGraph (
*fg, VarIds (vids.begin(),vids.end()));
}
if (Globals::infAlgorithm == InfAlgorithms::BP) {
solver = new BpSolver (*mfg);
} else if (Globals::infAlgorithm == InfAlgorithms::CBP) {
CFactorGraph::checkForIdenticalFactors = false;
solver = new CbpSolver (*mfg);
} else {
cerr << "error: unknow solver" << endl;
abort();
}
results.reserve (tasks.size());
for (unsigned i = 0; i < tasks.size(); i++) {
results.push_back (solver->solveQuery (tasks[i]));
}
if (fg->isFromBayesNetwork()) {
delete mfg;
}
delete solver;
}
int
setParfactorsParams (void)
{
LiftedNetwork* network = (LiftedNetwork*) YAP_IntOfTerm (YAP_ARG1);
ParfactorList* pfList = network->first;
YAP_Term distList = YAP_ARG2;
unordered_map<unsigned, Params> paramsMap;
while (distList != YAP_TermNil()) {
YAP_Term dist = YAP_HeadOfTerm (distList);
unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (1, dist));
assert (Util::contains (paramsMap, distId) == false);
paramsMap[distId] = readParameters (YAP_ArgOfTerm (2, dist));
distList = YAP_TailOfTerm (distList);
}
ParfactorList::iterator it = pfList->begin();
while (it != pfList->end()) {
assert (Util::contains (paramsMap, (*it)->distId()));
// (*it)->setParams (paramsMap[(*it)->distId()]);
++ it;
}
return TRUE;
}
int
setFactorsParams (void)
{
return TRUE; // TODO
FactorGraph* fg = (FactorGraph*) YAP_IntOfTerm (YAP_ARG1);
YAP_Term distList = YAP_ARG2;
unordered_map<unsigned, Params> paramsMap;
while (distList != YAP_TermNil()) {
YAP_Term dist = YAP_HeadOfTerm (distList);
unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (1, dist));
assert (Util::contains (paramsMap, distId) == false);
paramsMap[distId] = readParameters (YAP_ArgOfTerm (2, dist));
distList = YAP_TailOfTerm (distList);
}
const FacNodes& facNodes = fg->facNodes();
for (unsigned i = 0; i < facNodes.size(); i++) {
unsigned distId = facNodes[i]->factor().distId();
assert (Util::contains (paramsMap, distId));
facNodes[i]->factor().setParams (paramsMap[distId]);
}
return TRUE;
}
int
setVarsInformation (void)
{
Var::clearVarsInfo();
YAP_Term labelsL = YAP_ARG1;
vector<string> labels;
while (labelsL != YAP_TermNil()) {
YAP_Atom atom = YAP_AtomOfTerm (YAP_HeadOfTerm (labelsL));
labels.push_back ((char*) YAP_AtomName (atom));
labelsL = YAP_TailOfTerm (labelsL);
}
unsigned count = 0;
YAP_Term stateNamesL = YAP_ARG2;
while (stateNamesL != YAP_TermNil()) {
States states;
YAP_Term namesL = YAP_HeadOfTerm (stateNamesL);
while (namesL != YAP_TermNil()) {
YAP_Atom atom = YAP_AtomOfTerm (YAP_HeadOfTerm (namesL));
states.push_back ((char*) YAP_AtomName (atom));
namesL = YAP_TailOfTerm (namesL);
}
Var::addVarInfo (count, labels[count], states);
count ++;
stateNamesL = YAP_TailOfTerm (stateNamesL);
} }
return TRUE; return TRUE;
} }
@ -463,13 +509,11 @@ setHorusFlag (void)
if (key == "inf_alg") { if (key == "inf_alg") {
string value ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG2))); string value ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG2)));
if ( value == "ve") { if ( value == "ve") {
InfAlgorithms::infAlgorithm = InfAlgorithms::VE; Globals::infAlgorithm = InfAlgorithms::VE;
} else if (value == "bn_bp") { } else if (value == "bp") {
InfAlgorithms::infAlgorithm = InfAlgorithms::BN_BP; Globals::infAlgorithm = InfAlgorithms::BP;
} else if (value == "fg_bp") {
InfAlgorithms::infAlgorithm = InfAlgorithms::FG_BP;
} else if (value == "cbp") { } else if (value == "cbp") {
InfAlgorithms::infAlgorithm = InfAlgorithms::CBP; Globals::infAlgorithm = InfAlgorithms::CBP;
} else { } else {
cerr << "warning: invalid value `" << value << "' " ; cerr << "warning: invalid value `" << value << "' " ;
cerr << "for `" << key << "'" << endl; cerr << "for `" << key << "'" << endl;
@ -541,21 +585,21 @@ setHorusFlag (void)
int int
freeBayesNetwork (void) freeGroundNetwork (void)
{ {
//Statistics::writeStatisticsToFile ("stats.txt"); delete (FactorGraph*) YAP_IntOfTerm (YAP_ARG1);
BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
bn->freeDistributions();
delete bn;
return TRUE; return TRUE;
} }
int int
freeParfactorGraph (void) freeParfactors (void)
{ {
delete (ParfactorList*) YAP_IntOfTerm (YAP_ARG1); LiftedNetwork* network = (LiftedNetwork*) YAP_IntOfTerm (YAP_ARG1);
delete network->first;
delete network->second;
delete network;
return TRUE; return TRUE;
} }
@ -564,15 +608,15 @@ freeParfactorGraph (void)
extern "C" void extern "C" void
init_predicates (void) init_predicates (void)
{ {
YAP_UserCPredicate ("create_lifted_network", createLiftedNetwork, 3); YAP_UserCPredicate ("create_lifted_network", createLiftedNetwork, 3);
YAP_UserCPredicate ("create_ground_network", createGroundNetwork, 2); YAP_UserCPredicate ("create_ground_network", createGroundNetwork, 4);
YAP_UserCPredicate ("set_parfactor_graph_params", setParfactorGraphParams, 2); YAP_UserCPredicate ("run_lifted_solver", runLiftedSolver, 3);
YAP_UserCPredicate ("set_bayes_net_params", setBayesNetParams, 2); YAP_UserCPredicate ("run_ground_solver", runGroundSolver, 3);
YAP_UserCPredicate ("run_lifted_solver", runLiftedSolver, 3); YAP_UserCPredicate ("set_parfactors_params", setParfactorsParams, 2);
YAP_UserCPredicate ("run_other_solvers", runOtherSolvers, 3); YAP_UserCPredicate ("set_factors_params", setFactorsParams, 2);
YAP_UserCPredicate ("set_extra_vars_info", setExtraVarsInfo, 2); YAP_UserCPredicate ("set_vars_information", setVarsInformation, 2);
YAP_UserCPredicate ("set_horus_flag", setHorusFlag, 2); YAP_UserCPredicate ("set_horus_flag", setHorusFlag, 2);
YAP_UserCPredicate ("free_bayesian_network", freeBayesNetwork, 1); YAP_UserCPredicate ("free_parfactors", freeParfactors, 1);
YAP_UserCPredicate ("free_parfactor_graph", freeParfactorGraph, 1); YAP_UserCPredicate ("free_ground_network", freeGroundNetwork, 1);
} }

View File

@ -8,11 +8,13 @@
#include <sstream> #include <sstream>
#include <iomanip> #include <iomanip>
#include "VarNode.h" #include "Var.h"
#include "Util.h" #include "Util.h"
class StatesIndexer {
class StatesIndexer
{
public: public:
StatesIndexer (const Ranges& ranges, bool calcOffsets = true) StatesIndexer (const Ranges& ranges, bool calcOffsets = true)
@ -29,14 +31,14 @@ class StatesIndexer {
} }
} }
StatesIndexer (const VarNodes& vars, bool calcOffsets = true) StatesIndexer (const Vars& vars, bool calcOffsets = true)
{ {
size_ = 1; size_ = 1;
indices_.resize (vars.size(), 0); indices_.resize (vars.size(), 0);
ranges_.reserve (vars.size()); ranges_.reserve (vars.size());
for (unsigned i = 0; i < vars.size(); i++) { for (unsigned i = 0; i < vars.size(); i++) {
ranges_.push_back (vars[i]->nrStates()); ranges_.push_back (vars[i]->range());
size_ *= vars[i]->nrStates(); size_ *= vars[i]->range();
} }
li_ = 0; li_ = 0;
if (calcOffsets) { if (calcOffsets) {
@ -134,11 +136,11 @@ class StatesIndexer {
return size_ ; return size_ ;
} }
friend ostream& operator<< (ostream &out, const StatesIndexer& idx) friend ostream& operator<< (ostream &os, const StatesIndexer& idx)
{ {
out << "(" << std::setw (2) << std::setfill('0') << idx.li_ << ") " ; os << "(" << std::setw (2) << std::setfill('0') << idx.li_ << ") " ;
out << idx.indices_; os << idx.indices_;
return out; return os;
} }
private: private:
@ -274,21 +276,14 @@ class MapIndexer
index_ = 0; index_ = 0;
} }
friend ostream& operator<< (ostream &out, const MapIndexer& idx) friend ostream& operator<< (ostream &os, const MapIndexer& idx)
{ {
out << "(" << std::setw (2) << std::setfill('0') << idx.index_ << ") " ; os << "(" << std::setw (2) << std::setfill('0') << idx.index_ << ") " ;
out << idx.indices_; os << idx.indices_;
return out; return os;
} }
private: private:
MapIndexer (const Ranges& ranges) :
ranges_(ranges),
indices_(ranges.size(), 0),
offsets_(ranges.size())
{
index_ = 0;
}
unsigned index_; unsigned index_;
bool valid_; bool valid_;
vector<unsigned> ranges_; vector<unsigned> ranges_;

View File

@ -95,26 +95,37 @@ ostream& operator<< (ostream &os, const Ground& gr)
void LogVars
ObservedFormula::addTuple (const Tuple& t) Substitution::getDiscardedLogVars (void) const
{ {
if (constr_ == 0) { LogVars discardedLvs;
LogVars lvs (arity_); set<LogVar> doneLvs;
for (unsigned i = 0; i < arity_; i++) { unordered_map<LogVar, LogVar>::const_iterator it;
lvs[i] = i; it = subs_.begin();
while (it != subs_.end()) {
if (Util::contains (doneLvs, it->second)) {
discardedLvs.push_back (it->first);
} else {
doneLvs.insert (it->second);
} }
constr_ = new ConstraintTree (lvs); it ++;
} }
constr_->addTuple (t); return discardedLvs;
} }
ostream& operator<< (ostream &os, const ObservedFormula of) ostream& operator<< (ostream &os, const Substitution& theta)
{ {
os << of.functor_ << "/" << of.arity_; unordered_map<LogVar, LogVar>::const_iterator it;
os << "|" << of.constr_->tupleSet(); os << "[" ;
os << " [evidence=" << of.evidence_ << "]"; it = theta.subs_.begin();
while (it != theta.subs_.end()) {
if (it != theta.subs_.begin()) os << ", " ;
os << it->first << "->" << it->second ;
++ it;
}
os << "]" ;
return os; return os;
} }

View File

@ -18,11 +18,17 @@ class Symbol
{ {
public: public:
Symbol (void) : id_(numeric_limits<unsigned>::max()) { } Symbol (void) : id_(numeric_limits<unsigned>::max()) { }
Symbol (unsigned id) : id_(id) { } Symbol (unsigned id) : id_(id) { }
operator unsigned (void) const { return id_; } operator unsigned (void) const { return id_; }
bool valid (void) const { return id_ != numeric_limits<unsigned>::max(); } bool valid (void) const { return id_ != numeric_limits<unsigned>::max(); }
static Symbol invalid (void) { return Symbol(); } static Symbol invalid (void) { return Symbol(); }
friend ostream& operator<< (ostream &os, const Symbol& s); friend ostream& operator<< (ostream &os, const Symbol& s);
private: private:
unsigned id_; unsigned id_;
}; };
@ -32,7 +38,9 @@ class LogVar
{ {
public: public:
LogVar (void) : id_(numeric_limits<unsigned>::max()) { } LogVar (void) : id_(numeric_limits<unsigned>::max()) { }
LogVar (unsigned id) : id_(id) { } LogVar (unsigned id) : id_(id) { }
operator unsigned (void) const { return id_; } operator unsigned (void) const { return id_; }
LogVar& operator++ (void) LogVar& operator++ (void)
@ -48,6 +56,7 @@ class LogVar
} }
friend ostream& operator<< (ostream &os, const LogVar& X); friend ostream& operator<< (ostream &os, const LogVar& X);
private: private:
unsigned id_; unsigned id_;
}; };
@ -79,8 +88,8 @@ ostream& operator<< (ostream &os, const Tuple& t);
namespace LiftedUtils { namespace LiftedUtils {
Symbol getSymbol (const string&); Symbol getSymbol (const string&);
void printSymbolDictionary (void); void printSymbolDictionary (void);
} }
@ -89,71 +98,56 @@ class Ground
{ {
public: public:
Ground (Symbol f) : functor_(f) { } Ground (Symbol f) : functor_(f) { }
Ground (Symbol f, const Symbols& args) : functor_(f), args_(args) { } Ground (Symbol f, const Symbols& args) : functor_(f), args_(args) { }
Symbol functor (void) const { return functor_; } Symbol functor (void) const { return functor_; }
Symbols args (void) const { return args_; }
unsigned arity (void) const { return args_.size(); } Symbols args (void) const { return args_; }
bool isAtom (void) const { return args_.size() == 0; }
unsigned arity (void) const { return args_.size(); }
bool isAtom (void) const { return args_.size() == 0; }
friend ostream& operator<< (ostream &os, const Ground& gr); friend ostream& operator<< (ostream &os, const Ground& gr);
private: private:
Symbol functor_; Symbol functor_;
Symbols args_; Symbols args_;
}; };
typedef vector<Ground> Grounds; typedef vector<Ground> Grounds;
class ConstraintTree;
class ObservedFormula
{
public:
ObservedFormula (Symbol f, unsigned a, unsigned ev)
: functor_(f), arity_(a), evidence_(ev), constr_(0) { }
ObservedFormula (Symbol f, unsigned ev, const Tuple& tuple)
: functor_(f), arity_(tuple.size()), evidence_(ev), constr_(0)
{
addTuple (tuple);
}
Symbol functor (void) const { return functor_; }
unsigned arity (void) const { return arity_; }
unsigned evidence (void) const { return evidence_; }
ConstraintTree* constr (void) const { return constr_; }
bool isAtom (void) const { return arity_ == 0; }
void addTuple (const Tuple& t);
friend ostream& operator<< (ostream &os, const ObservedFormula opv);
private:
Symbol functor_;
unsigned arity_;
unsigned evidence_;
ConstraintTree* constr_;
};
typedef vector<ObservedFormula*> ObservedFormulas;
class Substitution class Substitution
{ {
public: public:
void add (LogVar X_old, LogVar X_new) void add (LogVar X_old, LogVar X_new)
{ {
assert (Util::contains (subs_, X_old) == false);
subs_.insert (make_pair (X_old, X_new)); subs_.insert (make_pair (X_old, X_new));
} }
void rename (LogVar X_old, LogVar X_new) void rename (LogVar X_old, LogVar X_new)
{ {
assert (subs_.find (X_old) != subs_.end()); assert (Util::contains (subs_, X_old));
subs_.find (X_old)->second = X_new; subs_.find (X_old)->second = X_new;
} }
LogVar newNameFor (LogVar X) const LogVar newNameFor (LogVar X) const
{ {
assert (subs_.find (X) != subs_.end()); assert (Util::contains (subs_, X));
return subs_.find (X)->second; return subs_.find (X)->second;
} }
LogVars getDiscardedLogVars (void) const;
friend ostream& operator<< (ostream &os, const Substitution& theta);
private: private:
unordered_map<LogVar, LogVar> subs_; unordered_map<LogVar, LogVar> subs_;
}; };

View File

@ -45,9 +45,8 @@ CWD=$(PWD)
HEADERS = \ HEADERS = \
$(srcdir)/GraphicalModel.h \
$(srcdir)/BayesNet.h \ $(srcdir)/BayesNet.h \
$(srcdir)/BayesNode.h \ $(srcdir)/BayesBall.h \
$(srcdir)/ElimGraph.h \ $(srcdir)/ElimGraph.h \
$(srcdir)/FactorGraph.h \ $(srcdir)/FactorGraph.h \
$(srcdir)/Factor.h \ $(srcdir)/Factor.h \
@ -55,12 +54,10 @@ HEADERS = \
$(srcdir)/ConstraintTree.h \ $(srcdir)/ConstraintTree.h \
$(srcdir)/Solver.h \ $(srcdir)/Solver.h \
$(srcdir)/VarElimSolver.h \ $(srcdir)/VarElimSolver.h \
$(srcdir)/BnBpSolver.h \ $(srcdir)/BpSolver.h \
$(srcdir)/FgBpSolver.h \
$(srcdir)/CbpSolver.h \ $(srcdir)/CbpSolver.h \
$(srcdir)/FoveSolver.h \ $(srcdir)/FoveSolver.h \
$(srcdir)/VarNode.h \ $(srcdir)/Var.h \
$(srcdir)/Distribution.h \
$(srcdir)/Indexer.h \ $(srcdir)/Indexer.h \
$(srcdir)/Parfactor.h \ $(srcdir)/Parfactor.h \
$(srcdir)/ProbFormula.h \ $(srcdir)/ProbFormula.h \
@ -69,22 +66,20 @@ HEADERS = \
$(srcdir)/LiftedUtils.h \ $(srcdir)/LiftedUtils.h \
$(srcdir)/TinySet.h \ $(srcdir)/TinySet.h \
$(srcdir)/Util.h \ $(srcdir)/Util.h \
$(srcdir)/Horus.h \ $(srcdir)/Horus.h
$(srcdir)/xmlParser/xmlParser.h
CPP_SOURCES = \ CPP_SOURCES = \
$(srcdir)/BayesNet.cpp \ $(srcdir)/BayesNet.cpp \
$(srcdir)/BayesNode.cpp \ $(srcdir)/BayesBall.cpp \
$(srcdir)/ElimGraph.cpp \ $(srcdir)/ElimGraph.cpp \
$(srcdir)/FactorGraph.cpp \ $(srcdir)/FactorGraph.cpp \
$(srcdir)/Factor.cpp \ $(srcdir)/Factor.cpp \
$(srcdir)/CFactorGraph.cpp \ $(srcdir)/CFactorGraph.cpp \
$(srcdir)/ConstraintTree.cpp \ $(srcdir)/ConstraintTree.cpp \
$(srcdir)/VarNode.cpp \ $(srcdir)/Var.cpp \
$(srcdir)/Solver.cpp \ $(srcdir)/Solver.cpp \
$(srcdir)/VarElimSolver.cpp \ $(srcdir)/VarElimSolver.cpp \
$(srcdir)/BnBpSolver.cpp \ $(srcdir)/BpSolver.cpp \
$(srcdir)/FgBpSolver.cpp \
$(srcdir)/CbpSolver.cpp \ $(srcdir)/CbpSolver.cpp \
$(srcdir)/FoveSolver.cpp \ $(srcdir)/FoveSolver.cpp \
$(srcdir)/Parfactor.cpp \ $(srcdir)/Parfactor.cpp \
@ -94,22 +89,20 @@ CPP_SOURCES = \
$(srcdir)/LiftedUtils.cpp \ $(srcdir)/LiftedUtils.cpp \
$(srcdir)/Util.cpp \ $(srcdir)/Util.cpp \
$(srcdir)/HorusYap.cpp \ $(srcdir)/HorusYap.cpp \
$(srcdir)/HorusCli.cpp \ $(srcdir)/HorusCli.cpp
$(srcdir)/xmlParser/xmlParser.cpp
OBJS = \ OBJS = \
BayesNet.o \ BayesNet.o \
BayesNode.o \ BayesBall.o \
ElimGraph.o \ ElimGraph.o \
FactorGraph.o \ FactorGraph.o \
Factor.o \ Factor.o \
CFactorGraph.o \ CFactorGraph.o \
ConstraintTree.o \ ConstraintTree.o \
VarNode.o \ Var.o \
Solver.o \ Solver.o \
VarElimSolver.o \ VarElimSolver.o \
BnBpSolver.o \ BpSolver.o \
FgBpSolver.o \
CbpSolver.o \ CbpSolver.o \
FoveSolver.o \ FoveSolver.o \
Parfactor.o \ Parfactor.o \
@ -122,17 +115,16 @@ OBJS = \
HCLI_OBJS = \ HCLI_OBJS = \
BayesNet.o \ BayesNet.o \
BayesNode.o \ BayesBall.o \
ElimGraph.o \ ElimGraph.o \
FactorGraph.o \ FactorGraph.o \
Factor.o \ Factor.o \
CFactorGraph.o \ CFactorGraph.o \
ConstraintTree.o \ ConstraintTree.o \
VarNode.o \ Var.o \
Solver.o \ Solver.o \
VarElimSolver.o \ VarElimSolver.o \
BnBpSolver.o \ BpSolver.o \
FgBpSolver.o \
CbpSolver.o \ CbpSolver.o \
FoveSolver.o \ FoveSolver.o \
Parfactor.o \ Parfactor.o \
@ -141,7 +133,6 @@ HCLI_OBJS = \
ParfactorList.o \ ParfactorList.o \
LiftedUtils.o \ LiftedUtils.o \
Util.o \ Util.o \
xmlParser/xmlParser.o \
HorusCli.o HorusCli.o
SOBJS=horus.@SO@ SOBJS=horus.@SO@
@ -154,10 +145,6 @@ all: $(SOBJS) hcli
$(CXX) -c $(CXXFLAGS) $< -o $@ $(CXX) -c $(CXXFLAGS) $< -o $@
xmlParser/xmlParser.o : $(srcdir)/xmlParser/xmlParser.cpp
$(CXX) -c $(CXXFLAGS) $< -o $@
@DO_SECOND_LD@horus.@SO@: $(OBJS) @DO_SECOND_LD@horus.@SO@: $(OBJS)
@DO_SECOND_LD@ @SHLIB_CXX_LD@ -o horus.@SO@ $(OBJS) @EXTRA_LIBS_FOR_SWIDLLS@ @DO_SECOND_LD@ @SHLIB_CXX_LD@ -o horus.@SO@ $(OBJS) @EXTRA_LIBS_FOR_SWIDLLS@
@ -171,7 +158,7 @@ install: all
clean: clean:
rm -f *.o *~ $(OBJS) $(SOBJS) *.BAK hcli xmlParser/*.o rm -f *.o *~ $(OBJS) $(SOBJS) *.BAK hcli
erase_dots: erase_dots:

View File

@ -2,6 +2,7 @@
#include "Parfactor.h" #include "Parfactor.h"
#include "Histogram.h" #include "Histogram.h"
#include "Indexer.h" #include "Indexer.h"
#include "Util.h"
#include "Horus.h" #include "Horus.h"
@ -11,55 +12,58 @@ Parfactor::Parfactor (
const Tuples& tuples, const Tuples& tuples,
unsigned distId) unsigned distId)
{ {
formulas_ = formulas; args_ = formulas;
params_ = params; params_ = params;
distId_ = distId; distId_ = distId;
LogVars logVars; LogVars logVars;
for (unsigned i = 0; i < formulas_.size(); i++) { for (unsigned i = 0; i < args_.size(); i++) {
ranges_.push_back (formulas_[i].range()); ranges_.push_back (args_[i].range());
const LogVars& lvs = formulas_[i].logVars(); const LogVars& lvs = args_[i].logVars();
for (unsigned j = 0; j < lvs.size(); j++) { for (unsigned j = 0; j < lvs.size(); j++) {
if (std::find (logVars.begin(), logVars.end(), lvs[j]) == if (Util::contains (logVars, lvs[j]) == false) {
logVars.end()) {
logVars.push_back (lvs[j]); logVars.push_back (lvs[j]);
} }
} }
} }
constr_ = new ConstraintTree (logVars, tuples); constr_ = new ConstraintTree (logVars, tuples);
assert (params_.size() == Util::expectedSize (ranges_));
} }
Parfactor::Parfactor (const Parfactor* g, const Tuple& tuple) Parfactor::Parfactor (const Parfactor* g, const Tuple& tuple)
{ {
formulas_ = g->formulas(); args_ = g->arguments();
params_ = g->params(); params_ = g->params();
ranges_ = g->ranges(); ranges_ = g->ranges();
distId_ = g->distId(); distId_ = g->distId();
constr_ = new ConstraintTree (g->logVars(), {tuple}); constr_ = new ConstraintTree (g->logVars(), {tuple});
assert (params_.size() == Util::expectedSize (ranges_));
} }
Parfactor::Parfactor (const Parfactor* g, ConstraintTree* constr) Parfactor::Parfactor (const Parfactor* g, ConstraintTree* constr)
{ {
formulas_ = g->formulas(); args_ = g->arguments();
params_ = g->params(); params_ = g->params();
ranges_ = g->ranges(); ranges_ = g->ranges();
distId_ = g->distId(); distId_ = g->distId();
constr_ = constr; constr_ = constr;
assert (params_.size() == Util::expectedSize (ranges_));
} }
Parfactor::Parfactor (const Parfactor& g) Parfactor::Parfactor (const Parfactor& g)
{ {
formulas_ = g.formulas(); args_ = g.arguments();
params_ = g.params(); params_ = g.params();
ranges_ = g.ranges(); ranges_ = g.ranges();
distId_ = g.distId(); distId_ = g.distId();
constr_ = new ConstraintTree (*g.constr()); constr_ = new ConstraintTree (*g.constr());
assert (params_.size() == Util::expectedSize (ranges_));
} }
@ -75,9 +79,9 @@ LogVarSet
Parfactor::countedLogVars (void) const Parfactor::countedLogVars (void) const
{ {
LogVarSet set; LogVarSet set;
for (unsigned i = 0; i < formulas_.size(); i++) { for (unsigned i = 0; i < args_.size(); i++) {
if (formulas_[i].isCounting()) { if (args_[i].isCounting()) {
set.insert (formulas_[i].countedLogVar()); set.insert (args_[i].countedLogVar());
} }
} }
return set; return set;
@ -107,14 +111,14 @@ Parfactor::elimLogVars (void) const
LogVarSet LogVarSet
Parfactor::exclusiveLogVars (unsigned fIdx) const Parfactor::exclusiveLogVars (unsigned fIdx) const
{ {
assert (fIdx < formulas_.size()); assert (fIdx < args_.size());
LogVarSet remaining; LogVarSet remaining;
for (unsigned i = 0; i < formulas_.size(); i++) { for (unsigned i = 0; i < args_.size(); i++) {
if (i != fIdx) { if (i != fIdx) {
remaining |= formulas_[i].logVarSet(); remaining |= args_[i].logVarSet();
} }
} }
return formulas_[fIdx].logVarSet() - remaining; return args_[fIdx].logVarSet() - remaining;
} }
@ -131,44 +135,51 @@ Parfactor::setConstraintTree (ConstraintTree* newTree)
void void
Parfactor::sumOut (unsigned fIdx) Parfactor::sumOut (unsigned fIdx)
{ {
assert (fIdx < formulas_.size()); assert (fIdx < args_.size());
assert (formulas_[fIdx].contains (elimLogVars())); assert (args_[fIdx].contains (elimLogVars()));
LogVarSet excl = exclusiveLogVars (fIdx); LogVarSet excl = exclusiveLogVars (fIdx);
unsigned condCount = constr_->getConditionalCount (excl); if (args_[fIdx].isCounting()) {
Util::pow (params_, condCount); LogAware::pow (params_, constr_->getConditionalCount (
excl - args_[fIdx].countedLogVar()));
} else {
LogAware::pow (params_, constr_->getConditionalCount (excl));
}
vector<unsigned> numAssigns (ranges_[fIdx], 1); if (args_[fIdx].isCounting()) {
if (formulas_[fIdx].isCounting()) {
unsigned N = constr_->getConditionalCount ( unsigned N = constr_->getConditionalCount (
formulas_[fIdx].countedLogVar()); args_[fIdx].countedLogVar());
unsigned R = formulas_[fIdx].range(); unsigned R = args_[fIdx].range();
unsigned H = ranges_[fIdx]; vector<double> numAssigns = HistogramSet::getNumAssigns (N, R);
HistogramSet hs (N, R); StatesIndexer sindexer (ranges_, fIdx);
unsigned N_factorial = Util::factorial (N); while (sindexer.valid()) {
for (unsigned h = 0; h < H; h++) { unsigned h = sindexer[fIdx];
unsigned prod = 1; if (Globals::logDomain) {
for (unsigned r = 0; r < R; r++) { params_[sindexer] += numAssigns[h];
prod *= Util::factorial (hs[r]); } else {
params_[sindexer] *= numAssigns[h];
} }
numAssigns[h] = N_factorial / prod; ++ sindexer;
hs.nextHistogram();
} }
cout << endl;
} }
Params copy = params_; Params copy = params_;
params_.clear(); params_.clear();
params_.resize (copy.size() / ranges_[fIdx], 0.0); params_.resize (copy.size() / ranges_[fIdx], LogAware::addIdenty());
MapIndexer indexer (ranges_, fIdx); MapIndexer indexer (ranges_, fIdx);
for (unsigned i = 0; i < copy.size(); i++) { if (Globals::logDomain) {
unsigned h = indexer[fIdx]; for (unsigned i = 0; i < copy.size(); i++) {
// TODO NOT LOG DOMAIN AWARE :( params_[indexer] = Util::logSum (params_[indexer], copy[i]);
params_[indexer] += numAssigns[h] * copy[i]; ++ indexer;
++ indexer; }
} else {
for (unsigned i = 0; i < copy.size(); i++) {
params_[indexer] += copy[i];
++ indexer;
}
} }
formulas_.erase (formulas_.begin() + fIdx);
args_.erase (args_.begin() + fIdx);
ranges_.erase (ranges_.begin() + fIdx); ranges_.erase (ranges_.begin() + fIdx);
constr_->remove (excl); constr_->remove (excl);
} }
@ -179,55 +190,7 @@ void
Parfactor::multiply (Parfactor& g) Parfactor::multiply (Parfactor& g)
{ {
alignAndExponentiate (this, &g); alignAndExponentiate (this, &g);
bool sharedVars = false; TFactor<ProbFormula>::multiply (g);
vector<unsigned> g_varpos;
const ProbFormulas& g_formulas = g.formulas();
const Params& g_params = g.params();
const Ranges& g_ranges = g.ranges();
for (unsigned i = 0; i < g_formulas.size(); i++) {
int group = g_formulas[i].group();
if (indexOfFormulaWithGroup (group) == -1) {
insertDimension (g.ranges()[i]);
formulas_.push_back (g_formulas[i]);
g_varpos.push_back (formulas_.size() - 1);
} else {
sharedVars = true;
g_varpos.push_back (indexOfFormulaWithGroup (group));
}
}
if (sharedVars == false) {
unsigned count = 0;
for (unsigned i = 0; i < params_.size(); i++) {
if (Globals::logDomain) {
params_[i] += g_params[count];
} else {
params_[i] *= g_params[count];
}
count ++;
if (count >= g_params.size()) {
count = 0;
}
}
} else {
StatesIndexer indexer (ranges_, false);
while (indexer.valid()) {
unsigned g_li = 0;
unsigned prod = 1;
for (int j = g_varpos.size() - 1; j >= 0; j--) {
g_li += indexer[g_varpos[j]] * prod;
prod *= g_ranges[j];
}
if (Globals::logDomain) {
params_[indexer] += g_params[g_li];
} else {
params_[indexer] *= g_params[g_li];
}
++ indexer;
}
}
constr_->join (g.constr(), true); constr_->join (g.constr(), true);
} }
@ -236,7 +199,7 @@ Parfactor::multiply (Parfactor& g)
void void
Parfactor::countConvert (LogVar X) Parfactor::countConvert (LogVar X)
{ {
int fIdx = indexOfFormulaWithLogVar (X); int fIdx = indexOfLogVar (X);
assert (fIdx != -1); assert (fIdx != -1);
assert (constr_->isCountNormalized (X)); assert (constr_->isCountNormalized (X));
assert (constr_->getConditionalCount (X) > 1); assert (constr_->getConditionalCount (X) > 1);
@ -248,12 +211,12 @@ Parfactor::countConvert (LogVar X)
vector<Histogram> histograms = HistogramSet::getHistograms (N, R); vector<Histogram> histograms = HistogramSet::getHistograms (N, R);
StatesIndexer indexer (ranges_); StatesIndexer indexer (ranges_);
vector<Params> summout (params_.size() / R); vector<Params> sumout (params_.size() / R);
unsigned count = 0; unsigned count = 0;
while (indexer.valid()) { while (indexer.valid()) {
summout[count].reserve (R); sumout[count].reserve (R);
for (unsigned r = 0; r < R; r++) { for (unsigned r = 0; r < R; r++) {
summout[count].push_back (params_[indexer]); sumout[count].push_back (params_[indexer]);
indexer.increment (fIdx); indexer.increment (fIdx);
} }
count ++; count ++;
@ -262,45 +225,42 @@ Parfactor::countConvert (LogVar X)
} }
params_.clear(); params_.clear();
params_.reserve (summout.size() * H); params_.reserve (sumout.size() * H);
vector<bool> mapDims (ranges_.size(), true);
ranges_[fIdx] = H; ranges_[fIdx] = H;
mapDims[fIdx] = false; MapIndexer mapIndexer (ranges_, fIdx);
MapIndexer mapIndexer (ranges_, mapDims);
while (mapIndexer.valid()) { while (mapIndexer.valid()) {
double prod = 1.0; double prod = LogAware::multIdenty();
unsigned i = mapIndexer.mappedIndex(); unsigned i = mapIndexer.mappedIndex();
unsigned h = mapIndexer[fIdx]; unsigned h = mapIndexer[fIdx];
for (unsigned r = 0; r < R; r++) { for (unsigned r = 0; r < R; r++) {
// TODO not log domain aware if (Globals::logDomain) {
prod *= Util::pow (summout[i][r], histograms[h][r]); prod += LogAware::pow (sumout[i][r], histograms[h][r]);
} else {
prod *= LogAware::pow (sumout[i][r], histograms[h][r]);
}
} }
params_.push_back (prod); params_.push_back (prod);
++ mapIndexer; ++ mapIndexer;
} }
formulas_[fIdx].setCountedLogVar (X); args_[fIdx].setCountedLogVar (X);
} }
void void
Parfactor::expandPotential ( Parfactor::expand (LogVar X, LogVar X_new1, LogVar X_new2)
LogVar X,
LogVar X_new1,
LogVar X_new2)
{ {
int fIdx = indexOfFormulaWithLogVar (X); int fIdx = indexOfLogVar (X);
assert (fIdx != -1); assert (fIdx != -1);
assert (formulas_[fIdx].isCounting()); assert (args_[fIdx].isCounting());
unsigned N1 = constr_->getConditionalCount (X_new1); unsigned N1 = constr_->getConditionalCount (X_new1);
unsigned N2 = constr_->getConditionalCount (X_new2); unsigned N2 = constr_->getConditionalCount (X_new2);
unsigned N = N1 + N2; unsigned N = N1 + N2;
unsigned R = formulas_[fIdx].range(); unsigned R = args_[fIdx].range();
unsigned H1 = HistogramSet::nrHistograms (N1, R); unsigned H1 = HistogramSet::nrHistograms (N1, R);
unsigned H2 = HistogramSet::nrHistograms (N2, R); unsigned H2 = HistogramSet::nrHistograms (N2, R);
unsigned H = ranges_[fIdx];
vector<Histogram> histograms = HistogramSet::getHistograms (N, R); vector<Histogram> histograms = HistogramSet::getHistograms (N, R);
vector<Histogram> histograms1 = HistogramSet::getHistograms (N1, R); vector<Histogram> histograms1 = HistogramSet::getHistograms (N1, R);
@ -320,48 +280,11 @@ Parfactor::expandPotential (
} }
} }
unsigned size = (params_.size() / H) * H1 * H2; expandPotential (fIdx, H1 * H2, sumIndexes);
Params copy = params_;
params_.clear();
params_.reserve (size);
unsigned prod = 1; args_.insert (args_.begin() + fIdx + 1, args_[fIdx]);
vector<unsigned> offsets_ (ranges_.size()); args_[fIdx].rename (X, X_new1);
for (int i = ranges_.size() - 1; i >= 0; i--) { args_[fIdx + 1].rename (X, X_new2);
offsets_[i] = prod;
prod *= ranges_[i];
}
unsigned index = 0;
ranges_[fIdx] = H1 * H2;
vector<unsigned> indices (ranges_.size(), 0);
for (unsigned k = 0; k < size; k++) {
params_.push_back (copy[index]);
for (int i = ranges_.size() - 1; i >= 0; i--) {
indices[i] ++;
if (i == fIdx) {
int diff = sumIndexes[indices[i]] - sumIndexes[indices[i] - 1];
index += diff * offsets_[i];
} else {
index += offsets_[i];
}
if (indices[i] != ranges_[i]) {
break;
} else {
if (i == fIdx) {
int diff = sumIndexes[0] - sumIndexes[indices[i]];
index += diff * offsets_[i];
} else {
index -= offsets_[i] * ranges_[i];
}
indices[i] = 0;
}
}
}
formulas_.insert (formulas_.begin() + fIdx + 1, formulas_[fIdx]);
formulas_[fIdx].rename (X, X_new1);
formulas_[fIdx + 1].rename (X, X_new2);
ranges_.insert (ranges_.begin() + fIdx + 1, H2); ranges_.insert (ranges_.begin() + fIdx + 1, H2);
ranges_[fIdx] = H1; ranges_[fIdx] = H1;
} }
@ -371,13 +294,12 @@ Parfactor::expandPotential (
void void
Parfactor::fullExpand (LogVar X) Parfactor::fullExpand (LogVar X)
{ {
int fIdx = indexOfFormulaWithLogVar (X); int fIdx = indexOfLogVar (X);
assert (fIdx != -1); assert (fIdx != -1);
assert (formulas_[fIdx].isCounting()); assert (args_[fIdx].isCounting());
unsigned N = constr_->getConditionalCount (X); unsigned N = constr_->getConditionalCount (X);
unsigned R = formulas_[fIdx].range(); unsigned R = args_[fIdx].range();
unsigned H = ranges_[fIdx];
vector<Histogram> originHists = HistogramSet::getHistograms (N, R); vector<Histogram> originHists = HistogramSet::getHistograms (N, R);
vector<Histogram> expandHists = HistogramSet::getHistograms (1, R); vector<Histogram> expandHists = HistogramSet::getHistograms (1, R);
@ -400,54 +322,17 @@ Parfactor::fullExpand (LogVar X)
++ indexer; ++ indexer;
} }
unsigned size = (params_.size() / H) * std::pow (R, N); expandPotential (fIdx, std::pow (R, N), sumIndexes);
Params copy = params_;
params_.clear();
params_.reserve (size);
unsigned prod = 1; ProbFormula f = args_[fIdx];
vector<unsigned> offsets_ (ranges_.size()); args_.erase (args_.begin() + fIdx);
for (int i = ranges_.size() - 1; i >= 0; i--) {
offsets_[i] = prod;
prod *= ranges_[i];
}
unsigned index = 0;
ranges_[fIdx] = std::pow (R, N);
vector<unsigned> indices (ranges_.size(), 0);
for (unsigned k = 0; k < size; k++) {
params_.push_back (copy[index]);
for (int i = ranges_.size() - 1; i >= 0; i--) {
indices[i] ++;
if (i == fIdx) {
int diff = sumIndexes[indices[i]] - sumIndexes[indices[i] - 1];
index += diff * offsets_[i];
} else {
index += offsets_[i];
}
if (indices[i] != ranges_[i]) {
break;
} else {
if (i == fIdx) {
int diff = sumIndexes[0] - sumIndexes[indices[i]];
index += diff * offsets_[i];
} else {
index -= offsets_[i] * ranges_[i];
}
indices[i] = 0;
}
}
}
ProbFormula f = formulas_[fIdx];
formulas_.erase (formulas_.begin() + fIdx);
ranges_.erase (ranges_.begin() + fIdx); ranges_.erase (ranges_.begin() + fIdx);
LogVars newLvs = constr_->expand (X); LogVars newLvs = constr_->expand (X);
assert (newLvs.size() == N); assert (newLvs.size() == N);
for (unsigned i = 0 ; i < N; i++) { for (unsigned i = 0 ; i < N; i++) {
ProbFormula newFormula (f.functor(), f.logVars(), f.range()); ProbFormula newFormula (f.functor(), f.logVars(), f.range());
newFormula.rename (X, newLvs[i]); newFormula.rename (X, newLvs[i]);
formulas_.insert (formulas_.begin() + fIdx + i, newFormula); args_.insert (args_.begin() + fIdx + i, newFormula);
ranges_.insert (ranges_.begin() + fIdx + i, R); ranges_.insert (ranges_.begin() + fIdx + i, R);
} }
} }
@ -459,117 +344,43 @@ Parfactor::reorderAccordingGrounds (const Grounds& grounds)
{ {
ProbFormulas newFormulas; ProbFormulas newFormulas;
for (unsigned i = 0; i < grounds.size(); i++) { for (unsigned i = 0; i < grounds.size(); i++) {
for (unsigned j = 0; j < formulas_.size(); j++) { for (unsigned j = 0; j < args_.size(); j++) {
if (grounds[i].functor() == formulas_[j].functor() && if (grounds[i].functor() == args_[j].functor() &&
grounds[i].arity() == formulas_[j].arity()) { grounds[i].arity() == args_[j].arity()) {
constr_->moveToTop (formulas_[j].logVars()); constr_->moveToTop (args_[j].logVars());
if (constr_->containsTuple (grounds[i].args())) { if (constr_->containsTuple (grounds[i].args())) {
newFormulas.push_back (formulas_[j]); newFormulas.push_back (args_[j]);
break; break;
} }
} }
} }
assert (newFormulas.size() == i + 1); assert (newFormulas.size() == i + 1);
} }
reorderFormulas (newFormulas); reorderArguments (newFormulas);
} }
void void
Parfactor::reorderFormulas (const ProbFormulas& newFormulas) Parfactor::absorveEvidence (const ProbFormula& formula, unsigned evidence)
{
assert (newFormulas.size() == formulas_.size());
if (newFormulas == formulas_) {
return;
}
Ranges newRanges;
vector<unsigned> positions;
for (unsigned i = 0; i < newFormulas.size(); i++) {
unsigned idx = indexOf (newFormulas[i]);
newRanges.push_back (ranges_[idx]);
positions.push_back (idx);
}
unsigned N = ranges_.size();
Params newParams (params_.size());
for (unsigned i = 0; i < params_.size(); i++) {
unsigned li = i;
// calculate vector index corresponding to linear index
vector<unsigned> vi (N);
for (int k = N-1; k >= 0; k--) {
vi[k] = li % ranges_[k];
li /= ranges_[k];
}
// convert permuted vector index to corresponding linear index
unsigned prod = 1;
unsigned new_li = 0;
for (int k = N - 1; k >= 0; k--) {
new_li += vi[positions[k]] * prod;
prod *= ranges_[positions[k]];
}
newParams[new_li] = params_[i];
}
formulas_ = newFormulas;
ranges_ = newRanges;
params_ = newParams;
}
void
Parfactor::absorveEvidence (unsigned fIdx, unsigned evidence)
{ {
int fIdx = indexOf (formula);
assert (fIdx != -1);
LogVarSet excl = exclusiveLogVars (fIdx); LogVarSet excl = exclusiveLogVars (fIdx);
assert (fIdx < formulas_.size()); assert (args_[fIdx].isCounting() == false);
assert (evidence < formulas_[fIdx].range());
assert (formulas_[fIdx].isCounting() == false);
assert (constr_->isCountNormalized (excl)); assert (constr_->isCountNormalized (excl));
LogAware::pow (params_, constr_->getConditionalCount (excl));
Util::pow (params_, constr_->getConditionalCount (excl)); TFactor<ProbFormula>::absorveEvidence (formula, evidence);
Params copy = params_;
params_.clear();
params_.reserve (copy.size() / formulas_[fIdx].range());
StatesIndexer indexer (ranges_);
for (unsigned i = 0; i < evidence; i++) {
indexer.increment (fIdx);
}
while (indexer.valid()) {
params_.push_back (copy[indexer]);
indexer.incrementExcluding (fIdx);
}
formulas_.erase (formulas_.begin() + fIdx);
ranges_.erase (ranges_.begin() + fIdx);
constr_->remove (excl); constr_->remove (excl);
} }
void
Parfactor::normalize (void)
{
Util::normalize (params_);
}
void
Parfactor::setFormulaGroup (const ProbFormula& f, int group)
{
assert (indexOf (f) != -1);
formulas_[indexOf (f)].setGroup (group);
}
void void
Parfactor::setNewGroups (void) Parfactor::setNewGroups (void)
{ {
for (unsigned i = 0; i < formulas_.size(); i++) { for (unsigned i = 0; i < args_.size(); i++) {
formulas_[i].setGroup (ProbFormula::getNewGroup()); args_[i].setGroup (ProbFormula::getNewGroup());
} }
} }
@ -578,14 +389,14 @@ Parfactor::setNewGroups (void)
void void
Parfactor::applySubstitution (const Substitution& theta) Parfactor::applySubstitution (const Substitution& theta)
{ {
for (unsigned i = 0; i < formulas_.size(); i++) { for (unsigned i = 0; i < args_.size(); i++) {
LogVars& lvs = formulas_[i].logVars(); LogVars& lvs = args_[i].logVars();
for (unsigned j = 0; j < lvs.size(); j++) { for (unsigned j = 0; j < lvs.size(); j++) {
lvs[j] = theta.newNameFor (lvs[j]); lvs[j] = theta.newNameFor (lvs[j]);
} }
if (formulas_[i].isCounting()) { if (args_[i].isCounting()) {
LogVar clv = formulas_[i].countedLogVar(); LogVar clv = args_[i].countedLogVar();
formulas_[i].setCountedLogVar (theta.newNameFor (clv)); args_[i].setCountedLogVar (theta.newNameFor (clv));
} }
} }
constr_->applySubstitution (theta); constr_->applySubstitution (theta);
@ -593,19 +404,29 @@ Parfactor::applySubstitution (const Substitution& theta)
bool int
Parfactor::containsGround (const Ground& ground) const Parfactor::findGroup (const Ground& ground) const
{ {
for (unsigned i = 0; i < formulas_.size(); i++) { int group = -1;
if (formulas_[i].functor() == ground.functor() && for (unsigned i = 0; i < args_.size(); i++) {
formulas_[i].arity() == ground.arity()) { if (args_[i].functor() == ground.functor() &&
constr_->moveToTop (formulas_[i].logVars()); args_[i].arity() == ground.arity()) {
constr_->moveToTop (args_[i].logVars());
if (constr_->containsTuple (ground.args())) { if (constr_->containsTuple (ground.args())) {
return true; group = args_[i].group();
break;
} }
} }
} }
return false; return group;
}
bool
Parfactor::containsGround (const Ground& ground) const
{
return findGroup (ground) != -1;
} }
@ -613,8 +434,8 @@ Parfactor::containsGround (const Ground& ground) const
bool bool
Parfactor::containsGroup (unsigned group) const Parfactor::containsGroup (unsigned group) const
{ {
for (unsigned i = 0; i < formulas_.size(); i++) { for (unsigned i = 0; i < args_.size(); i++) {
if (formulas_[i].group() == group) { if (args_[i].group() == group) {
return true; return true;
} }
} }
@ -623,30 +444,12 @@ Parfactor::containsGroup (unsigned group) const
const ProbFormula&
Parfactor::formula (unsigned fIdx) const
{
assert (fIdx < formulas_.size());
return formulas_[fIdx];
}
unsigned
Parfactor::range (unsigned fIdx) const
{
assert (fIdx < ranges_.size());
return ranges_[fIdx];
}
unsigned unsigned
Parfactor::nrFormulas (LogVar X) const Parfactor::nrFormulas (LogVar X) const
{ {
unsigned count = 0; unsigned count = 0;
for (unsigned i = 0; i < formulas_.size(); i++) { for (unsigned i = 0; i < args_.size(); i++) {
if (formulas_[i].contains (X)) { if (args_[i].contains (X)) {
count ++; count ++;
} }
} }
@ -656,27 +459,12 @@ Parfactor::nrFormulas (LogVar X) const
int int
Parfactor::indexOf (const ProbFormula& f) const Parfactor::indexOfLogVar (LogVar X) const
{
int idx = -1;
for (unsigned i = 0; i < formulas_.size(); i++) {
if (f == formulas_[i]) {
idx = i;
break;
}
}
return idx;
}
int
Parfactor::indexOfFormulaWithLogVar (LogVar X) const
{ {
int idx = -1; int idx = -1;
assert (nrFormulas (X) == 1); assert (nrFormulas (X) == 1);
for (unsigned i = 0; i < formulas_.size(); i++) { for (unsigned i = 0; i < args_.size(); i++) {
if (formulas_[i].contains (X)) { if (args_[i].contains (X)) {
idx = i; idx = i;
break; break;
} }
@ -687,11 +475,11 @@ Parfactor::indexOfFormulaWithLogVar (LogVar X) const
int int
Parfactor::indexOfFormulaWithGroup (unsigned group) const Parfactor::indexOfGroup (unsigned group) const
{ {
int pos = -1; int pos = -1;
for (unsigned i = 0; i < formulas_.size(); i++) { for (unsigned i = 0; i < args_.size(); i++) {
if (formulas_[i].group() == group) { if (args_[i].group() == group) {
pos = i; pos = i;
break; break;
} }
@ -704,9 +492,9 @@ Parfactor::indexOfFormulaWithGroup (unsigned group) const
vector<unsigned> vector<unsigned>
Parfactor::getAllGroups (void) const Parfactor::getAllGroups (void) const
{ {
vector<unsigned> groups (formulas_.size()); vector<unsigned> groups (args_.size());
for (unsigned i = 0; i < formulas_.size(); i++) { for (unsigned i = 0; i < args_.size(); i++) {
groups[i] = formulas_[i].group(); groups[i] = args_[i].group();
} }
return groups; return groups;
} }
@ -714,13 +502,13 @@ Parfactor::getAllGroups (void) const
string string
Parfactor::getHeaderString (void) const Parfactor::getLabel (void) const
{ {
stringstream ss; stringstream ss;
ss << "phi(" ; ss << "phi(" ;
for (unsigned i = 0; i < formulas_.size(); i++) { for (unsigned i = 0; i < args_.size(); i++) {
if (i != 0) ss << "," ; if (i != 0) ss << "," ;
ss << formulas_[i]; ss << args_[i];
} }
ss << ")" ; ss << ")" ;
ConstraintTree copy (*constr_); ConstraintTree copy (*constr_);
@ -735,32 +523,37 @@ void
Parfactor::print (bool printParams) const Parfactor::print (bool printParams) const
{ {
cout << "Formulas: " ; cout << "Formulas: " ;
for (unsigned i = 0; i < formulas_.size(); i++) { for (unsigned i = 0; i < args_.size(); i++) {
if (i != 0) cout << ", " ; if (i != 0) cout << ", " ;
cout << formulas_[i]; cout << args_[i];
} }
cout << endl; cout << endl;
vector<string> groups; if (args_[0].group() != Util::maxUnsigned()) {
for (unsigned i = 0; i < formulas_.size(); i++) { vector<string> groups;
groups.push_back (string ("g") + Util::toString (formulas_[i].group())); for (unsigned i = 0; i < args_.size(); i++) {
groups.push_back (string ("g") + Util::toString (args_[i].group()));
}
cout << "Groups: " << groups << endl;
} }
cout << "Groups: " << groups << endl; cout << "LogVars: " << constr_->logVarSet() << endl;
cout << "LogVars: " << constr_->logVars() << endl;
cout << "Ranges: " << ranges_ << endl; cout << "Ranges: " << ranges_ << endl;
if (printParams == false) { if (printParams == false) {
cout << "Params: " << params_ << endl; cout << "Params: " << params_ << endl;
} }
cout << "Tuples: " << constr_->tupleSet() << endl; ConstraintTree copy (*constr_);
copy.moveToTop (copy.logVarSet().elements());
cout << "Tuples: " << copy.tupleSet() << endl;
if (printParams) { if (printParams) {
vector<string> jointStrings; vector<string> jointStrings;
StatesIndexer indexer (ranges_); StatesIndexer indexer (ranges_);
while (indexer.valid()) { while (indexer.valid()) {
stringstream ss; stringstream ss;
for (unsigned i = 0; i < formulas_.size(); i++) { for (unsigned i = 0; i < args_.size(); i++) {
if (i != 0) ss << ", " ; if (i != 0) ss << ", " ;
if (formulas_[i].isCounting()) { if (args_[i].isCounting()) {
unsigned N = constr_->getConditionalCount (formulas_[i].countedLogVar()); unsigned N = constr_->getConditionalCount (
HistogramSet hs (N, formulas_[i].range()); args_[i].countedLogVar());
HistogramSet hs (N, args_[i].range());
unsigned c = 0; unsigned c = 0;
while (c < indexer[i]) { while (c < indexer[i]) {
hs.nextHistogram(); hs.nextHistogram();
@ -779,22 +572,56 @@ Parfactor::print (bool printParams) const
cout << " = " << params_[i] << endl; cout << " = " << params_[i] << endl;
} }
} }
cout << endl;
} }
void void
Parfactor::insertDimension (unsigned range) Parfactor::expandPotential (
int fIdx,
unsigned newRange,
const vector<unsigned>& sumIndexes)
{ {
unsigned size = (params_.size() / ranges_[fIdx]) * newRange;
Params copy = params_; Params copy = params_;
params_.clear(); params_.clear();
params_.reserve (copy.size() * range); params_.reserve (size);
for (unsigned i = 0; i < copy.size(); i++) {
for (unsigned reps = 0; reps < range; reps++) { unsigned prod = 1;
params_.push_back (copy[i]); vector<unsigned> offsets_ (ranges_.size());
for (int i = ranges_.size() - 1; i >= 0; i--) {
offsets_[i] = prod;
prod *= ranges_[i];
}
unsigned index = 0;
ranges_[fIdx] = newRange;
vector<unsigned> indices (ranges_.size(), 0);
for (unsigned k = 0; k < size; k++) {
params_.push_back (copy[index]);
for (int i = ranges_.size() - 1; i >= 0; i--) {
indices[i] ++;
if (i == fIdx) {
assert (indices[i] - 1 < sumIndexes.size());
int diff = sumIndexes[indices[i]] - sumIndexes[indices[i] - 1];
index += diff * offsets_[i];
} else {
index += offsets_[i];
}
if (indices[i] != ranges_[i]) {
break;
} else {
if (i == fIdx) {
int diff = sumIndexes[0] - sumIndexes[indices[i]];
index += diff * offsets_[i];
} else {
index -= offsets_[i] * ranges_[i];
}
indices[i] = 0;
}
} }
} }
ranges_.push_back (range);
} }
@ -803,29 +630,27 @@ void
Parfactor::alignAndExponentiate (Parfactor* g1, Parfactor* g2) Parfactor::alignAndExponentiate (Parfactor* g1, Parfactor* g2)
{ {
LogVars X_1, X_2; LogVars X_1, X_2;
const ProbFormulas& formulas1 = g1->formulas(); const ProbFormulas& formulas1 = g1->arguments();
const ProbFormulas& formulas2 = g2->formulas(); const ProbFormulas& formulas2 = g2->arguments();
for (unsigned i = 0; i < formulas1.size(); i++) { for (unsigned i = 0; i < formulas1.size(); i++) {
for (unsigned j = 0; j < formulas2.size(); j++) { for (unsigned j = 0; j < formulas2.size(); j++) {
if (formulas1[i].group() == formulas2[j].group()) { if (formulas1[i].group() == formulas2[j].group()) {
X_1.insert (X_1.end(), Util::addToVector (X_1, formulas1[i].logVars());
formulas1[i].logVars().begin(), Util::addToVector (X_2, formulas2[j].logVars());
formulas1[i].logVars().end());
X_2.insert (X_2.end(),
formulas2[j].logVars().begin(),
formulas2[j].logVars().end());
} }
} }
} }
align (g1, X_1, g2, X_2);
LogVarSet Y_1 = g1->logVarSet() - LogVarSet (X_1); LogVarSet Y_1 = g1->logVarSet() - LogVarSet (X_1);
LogVarSet Y_2 = g2->logVarSet() - LogVarSet (X_2); LogVarSet Y_2 = g2->logVarSet() - LogVarSet (X_2);
assert (g1->constr()->isCountNormalized (Y_1)); assert (g1->constr()->isCountNormalized (Y_1));
assert (g2->constr()->isCountNormalized (Y_2)); assert (g2->constr()->isCountNormalized (Y_2));
unsigned condCount1 = g1->constr()->getConditionalCount (Y_1); unsigned condCount1 = g1->constr()->getConditionalCount (Y_1);
unsigned condCount2 = g2->constr()->getConditionalCount (Y_2); unsigned condCount2 = g2->constr()->getConditionalCount (Y_2);
Util::pow (g1->params(), 1.0 / condCount2); LogAware::pow (g1->params(), 1.0 / condCount2);
Util::pow (g2->params(), 1.0 / condCount1); LogAware::pow (g2->params(), 1.0 / condCount1);
// this must be done in the end or else X_1 and X_2
// will refer the old log var names in the code above
align (g1, X_1, g2, X_2);
} }
@ -838,7 +663,6 @@ Parfactor::align (
LogVar freeLogVar = 0; LogVar freeLogVar = 0;
Substitution theta1; Substitution theta1;
Substitution theta2; Substitution theta2;
const LogVarSet& allLvs1 = g1->logVarSet(); const LogVarSet& allLvs1 = g1->logVarSet();
for (unsigned i = 0; i < allLvs1.size(); i++) { for (unsigned i = 0; i < allLvs1.size(); i++) {
theta1.add (allLvs1[i], freeLogVar); theta1.add (allLvs1[i], freeLogVar);
@ -850,7 +674,7 @@ Parfactor::align (
theta2.add (allLvs2[i], freeLogVar); theta2.add (allLvs2[i], freeLogVar);
++ freeLogVar; ++ freeLogVar;
} }
assert (alignLvs1.size() == alignLvs2.size()); assert (alignLvs1.size() == alignLvs2.size());
for (unsigned i = 0; i < alignLvs1.size(); i++) { for (unsigned i = 0; i < alignLvs1.size(); i++) {
theta1.rename (alignLvs1[i], theta2.newNameFor (alignLvs2[i])); theta1.rename (alignLvs1[i], theta2.newNameFor (alignLvs2[i]));

View File

@ -9,8 +9,9 @@
#include "LiftedUtils.h" #include "LiftedUtils.h"
#include "Horus.h" #include "Horus.h"
#include "Factor.h"
class Parfactor class Parfactor : public TFactor<ProbFormula>
{ {
public: public:
Parfactor ( Parfactor (
@ -18,27 +19,15 @@ class Parfactor
const Params&, const Params&,
const Tuples&, const Tuples&,
unsigned); unsigned);
Parfactor (const Parfactor*, const Tuple&); Parfactor (const Parfactor*, const Tuple&);
Parfactor (const Parfactor*, ConstraintTree*); Parfactor (const Parfactor*, ConstraintTree*);
Parfactor (const Parfactor&); Parfactor (const Parfactor&);
~Parfactor (void); ~Parfactor (void);
ProbFormulas& formulas (void) { return formulas_; }
const ProbFormulas& formulas (void) const { return formulas_; }
unsigned nrFormulas (void) const { return formulas_.size(); }
Params& params (void) { return params_; }
const Params& params (void) const { return params_; }
unsigned size (void) const { return params_.size(); }
const Ranges& ranges (void) const { return ranges_; }
unsigned distId (void) const { return distId_; }
ConstraintTree* constr (void) { return constr_; } ConstraintTree* constr (void) { return constr_; }
const ConstraintTree* constr (void) const { return constr_; } const ConstraintTree* constr (void) const { return constr_; }
@ -57,64 +46,52 @@ class Parfactor
void setConstraintTree (ConstraintTree*); void setConstraintTree (ConstraintTree*);
void sumOut (unsigned); void sumOut (unsigned fIdx);
void multiply (Parfactor&); void multiply (Parfactor&);
void countConvert (LogVar); void countConvert (LogVar);
void expandPotential (LogVar, LogVar, LogVar); void expand (LogVar, LogVar, LogVar);
void fullExpand (LogVar); void fullExpand (LogVar);
void reorderAccordingGrounds (const Grounds&); void reorderAccordingGrounds (const Grounds&);
void reorderFormulas (const ProbFormulas&); void absorveEvidence (const ProbFormula&, unsigned);
void absorveEvidence (unsigned, unsigned);
void normalize (void);
void setFormulaGroup (const ProbFormula&, int);
void setNewGroups (void); void setNewGroups (void);
void applySubstitution (const Substitution&); void applySubstitution (const Substitution&);
int findGroup (const Ground&) const;
bool containsGround (const Ground&) const; bool containsGround (const Ground&) const;
bool containsGroup (unsigned) const; bool containsGroup (unsigned) const;
const ProbFormula& formula (unsigned) const;
unsigned range (unsigned) const;
unsigned nrFormulas (LogVar) const; unsigned nrFormulas (LogVar) const;
int indexOf (const ProbFormula&) const; int indexOfLogVar (LogVar) const;
int indexOfFormulaWithLogVar (LogVar) const; int indexOfGroup (unsigned) const;
int indexOfFormulaWithGroup (unsigned) const;
vector<unsigned> getAllGroups (void) const; vector<unsigned> getAllGroups (void) const;
void print (bool = false) const; void print (bool = false) const;
string getHeaderString (void) const; string getLabel (void) const;
private: private:
void expandPotential (int fIdx, unsigned newRange,
const vector<unsigned>& sumIndexes);
static void alignAndExponentiate (Parfactor*, Parfactor*); static void alignAndExponentiate (Parfactor*, Parfactor*);
static void align ( static void align (
Parfactor*, const LogVars&, Parfactor*, const LogVars&); Parfactor*, const LogVars&, Parfactor*, const LogVars&);
void insertDimension (unsigned);
ProbFormulas formulas_; ConstraintTree* constr_;
Ranges ranges_;
Params params_;
unsigned distId_;
ConstraintTree* constr_;
}; };

View File

@ -3,9 +3,32 @@
#include "ParfactorList.h" #include "ParfactorList.h"
ParfactorList::ParfactorList (Parfactors& pfs) ParfactorList::ParfactorList (const ParfactorList& pfList)
{ {
pfList_.insert (pfList_.end(), pfs.begin(), pfs.end()); ParfactorList::const_iterator it = pfList.begin();
while (it != pfList.end()) {
addShattered (new Parfactor (**it));
++ it;
}
}
ParfactorList::ParfactorList (const Parfactors& pfs)
{
add (pfs);
}
ParfactorList::~ParfactorList (void)
{
ParfactorList::const_iterator it = pfList_.begin();
while (it != pfList_.end()) {
delete *it;
++ it;
}
} }
@ -14,17 +37,17 @@ void
ParfactorList::add (Parfactor* pf) ParfactorList::add (Parfactor* pf)
{ {
pf->setNewGroups(); pf->setNewGroups();
pfList_.push_back (pf); addToShatteredList (pf);
} }
void void
ParfactorList::add (Parfactors& pfs) ParfactorList::add (const Parfactors& pfs)
{ {
for (unsigned i = 0; i < pfs.size(); i++) { for (unsigned i = 0; i < pfs.size(); i++) {
pfs[i]->setNewGroups(); pfs[i]->setNewGroups();
pfList_.push_back (pfs[i]); addToShatteredList (pfs[i]);
} }
} }
@ -33,7 +56,20 @@ ParfactorList::add (Parfactors& pfs)
void void
ParfactorList::addShattered (Parfactor* pf) ParfactorList::addShattered (Parfactor* pf)
{ {
assert (isAllShattered());
pfList_.push_back (pf); pfList_.push_back (pf);
assert (isAllShattered());
}
list<Parfactor*>::iterator
ParfactorList::insertShattered (
list<Parfactor*>::iterator it,
Parfactor* pf)
{
return pfList_.insert (it, pf);
assert (isAllShattered());
} }
@ -47,7 +83,7 @@ ParfactorList::remove (list<Parfactor*>::iterator it)
list<Parfactor*>::iterator list<Parfactor*>::iterator
ParfactorList::deleteAndRemove (list<Parfactor*>::iterator it) ParfactorList::removeAndDelete (list<Parfactor*>::iterator it)
{ {
delete *it; delete *it;
return pfList_.erase (it); return pfList_.erase (it);
@ -55,58 +91,21 @@ ParfactorList::deleteAndRemove (list<Parfactor*>::iterator it)
void bool
ParfactorList::shatter (void) ParfactorList::isAllShattered (void) const
{ {
list<Parfactor*> tempList; if (pfList_.size() <= 1) {
Parfactors newPfs; return true;
newPfs.insert (newPfs.end(), pfList_.begin(), pfList_.end()); }
while (newPfs.empty() == false) { vector<Parfactor*> pfs (pfList_.begin(), pfList_.end());
tempList.insert (tempList.end(), newPfs.begin(), newPfs.end()); for (unsigned i = 0; i < pfs.size() - 1; i++) {
newPfs.clear(); for (unsigned j = i + 1; j < pfs.size(); j++) {
list<Parfactor*>::iterator iter1 = tempList.begin(); if (isShattered (pfs[i], pfs[j]) == false) {
while (tempList.size() > 1 && iter1 != -- tempList.end()) { return false;
list<Parfactor*>::iterator iter2 = iter1;
++ iter2;
bool incIter1 = true;
while (iter2 != tempList.end()) {
assert (iter1 != iter2);
std::pair<Parfactors, Parfactors> res = shatter (
(*iter1)->formulas(), *iter1, (*iter2)->formulas(), *iter2);
bool incIter2 = true;
if (res.second.empty() == false) {
// cout << "second unshattered" << endl;
delete *iter2;
iter2 = tempList.erase (iter2);
incIter2 = false;
newPfs.insert (
newPfs.begin(), res.second.begin(), res.second.end());
}
if (res.first.empty() == false) {
// cout << "first unshattered" << endl;
delete *iter1;
iter1 = tempList.erase (iter1);
newPfs.insert (
newPfs.begin(), res.first.begin(), res.first.end());
incIter1 = false;
break;
}
if (incIter2) {
++ iter2;
}
}
if (incIter1) {
++ iter1;
} }
} }
// cout << "|||||||||||||||||||||||||||||||||||||||||||||||||" << endl;
// cout << "||||||||||||| SHATTERING ITERATION ||||||||||||||" << endl;
// cout << "|||||||||||||||||||||||||||||||||||||||||||||||||" << endl;
// printParfactors (newPfs);
// cout << "|||||||||||||||||||||||||||||||||||||||||||||||||" << endl;
} }
pfList_.clear(); return true;
pfList_.insert (pfList_.end(), tempList.begin(), tempList.end());
} }
@ -117,25 +116,88 @@ ParfactorList::print (void) const
list<Parfactor*>::const_iterator it; list<Parfactor*>::const_iterator it;
for (it = pfList_.begin(); it != pfList_.end(); ++it) { for (it = pfList_.begin(); it != pfList_.end(); ++it) {
(*it)->print(); (*it)->print();
cout << endl;
} }
} }
std::pair<Parfactors, Parfactors> bool
ParfactorList::shatter ( ParfactorList::isShattered (
ProbFormulas& formulas1, const Parfactor* g1,
Parfactor* g1, const Parfactor* g2) const
ProbFormulas& formulas2,
Parfactor* g2)
{ {
assert (g1 != g2);
const ProbFormulas& fms1 = g1->arguments();
const ProbFormulas& fms2 = g2->arguments();
for (unsigned i = 0; i < fms1.size(); i++) {
for (unsigned j = 0; j < fms2.size(); j++) {
if (fms1[i].group() == fms2[j].group()) {
if (identical (
fms1[i], *(g1->constr()),
fms2[j], *(g2->constr())) == false) {
return false;
}
} else {
if (disjoint (
fms1[i], *(g1->constr()),
fms2[j], *(g2->constr())) == false) {
return false;
}
}
}
}
return true;
}
void
ParfactorList::addToShatteredList (Parfactor* g)
{
queue<Parfactor*> residuals;
residuals.push (g);
while (residuals.empty() == false) {
Parfactor* pf = residuals.front();
bool pfSplitted = false;
list<Parfactor*>::iterator pfIter;
pfIter = pfList_.begin();
while (pfIter != pfList_.end()) {
std::pair<Parfactors, Parfactors> shattRes;
shattRes = shatter (*pfIter, pf);
if (shattRes.first.empty() == false) {
pfIter = removeAndDelete (pfIter);
Util::addToQueue (residuals, shattRes.first);
} else {
++ pfIter;
}
if (shattRes.second.empty() == false) {
delete pf;
Util::addToQueue (residuals, shattRes.second);
pfSplitted = true;
break;
}
}
residuals.pop();
if (pfSplitted == false) {
addShattered (pf);
}
}
assert (isAllShattered());
}
std::pair<Parfactors, Parfactors>
ParfactorList::shatter (Parfactor* g1, Parfactor* g2)
{
ProbFormulas& formulas1 = g1->arguments();
ProbFormulas& formulas2 = g2->arguments();
assert (g1 != 0 && g2 != 0 && g1 != g2); assert (g1 != 0 && g2 != 0 && g1 != g2);
for (unsigned i = 0; i < formulas1.size(); i++) { for (unsigned i = 0; i < formulas1.size(); i++) {
for (unsigned j = 0; j < formulas2.size(); j++) { for (unsigned j = 0; j < formulas2.size(); j++) {
if (formulas1[i].sameSkeletonAs (formulas2[j])) { if (formulas1[i].sameSkeletonAs (formulas2[j])) {
std::pair<Parfactors, Parfactors> res std::pair<Parfactors, Parfactors> res;
= shatter (formulas1[i], g1, formulas2[j], g2); res = shatter (i, g1, j, g2);
if (res.first.empty() == false || if (res.first.empty() == false ||
res.second.empty() == false) { res.second.empty() == false) {
return res; return res;
@ -150,21 +212,22 @@ ParfactorList::shatter (
std::pair<Parfactors, Parfactors> std::pair<Parfactors, Parfactors>
ParfactorList::shatter ( ParfactorList::shatter (
ProbFormula& f1, unsigned fIdx1, Parfactor* g1,
Parfactor* g1, unsigned fIdx2, Parfactor* g2)
ProbFormula& f2,
Parfactor* g2)
{ {
ProbFormula& f1 = g1->argument (fIdx1);
ProbFormula& f2 = g2->argument (fIdx2);
// cout << endl; // cout << endl;
// cout << "-------------------------------------------------" << endl; // Util::printDashedLine();
// cout << "-> SHATTERING (#" << g1 << ", #" << g2 << ")" << endl; // cout << "-> SHATTERING (#" << g1 << ", #" << g2 << ")" << endl;
// g1->print(); // g1->print();
// cout << "-> WITH" << endl; // cout << "-> WITH" << endl;
// g2->print(); // g2->print();
// cout << "-> ON: " << f1.toString (g1->constr()) << endl; // cout << "-> ON: " << f1 << "|" ;
// cout << "-> ON: " << f2.toString (g2->constr()) << endl; // cout << g1->constr()->tupleSet (f1.logVars()) << endl;
// cout << "-------------------------------------------------" << endl; // cout << "-> ON: " << f2 << "|" ;
// cout << g2->constr()->tupleSet (f2.logVars()) << endl;
// Util::printDashedLine();
if (f1.isAtom()) { if (f1.isAtom()) {
unsigned group = (f1.group() < f2.group()) ? f1.group() : f2.group(); unsigned group = (f1.group() < f2.group()) ? f1.group() : f2.group();
f1.setGroup (group); f1.setGroup (group);
@ -174,7 +237,7 @@ ParfactorList::shatter (
assert (g1->constr()->empty() == false); assert (g1->constr()->empty() == false);
assert (g2->constr()->empty() == false); assert (g2->constr()->empty() == false);
if (f1.group() == f2.group()) { if (f1.group() == f2.group()) {
// assert (identical (f1, g1->constr(), f2, g2->constr())); assert (identical (f1, *(g1->constr()), f2, *(g2->constr())));
return { }; return { };
} }
@ -201,21 +264,24 @@ ParfactorList::shatter (
assert (commCt1->tupleSet (f1.arity()) == assert (commCt1->tupleSet (f1.arity()) ==
commCt2->tupleSet (f2.arity())); commCt2->tupleSet (f2.arity()));
// stringstream ss1; ss1 << "" << count << "_A.dot" ; // unsigned static count = 0; count ++;
// stringstream ss2; ss2 << "" << count << "_B.dot" ; // stringstream ss1; ss1 << "" << count << "_A.dot" ;
// stringstream ss3; ss3 << "" << count << "_A_comm.dot" ; // stringstream ss2; ss2 << "" << count << "_B.dot" ;
// stringstream ss4; ss4 << "" << count << "_A_excl.dot" ; // stringstream ss3; ss3 << "" << count << "_A_comm.dot" ;
// stringstream ss5; ss5 << "" << count << "_B_comm.dot" ; // stringstream ss4; ss4 << "" << count << "_A_excl.dot" ;
// stringstream ss6; ss6 << "" << count << "_B_excl.dot" ; // stringstream ss5; ss5 << "" << count << "_B_comm.dot" ;
// ct1->exportToGraphViz (ss1.str().c_str(), true); // stringstream ss6; ss6 << "" << count << "_B_excl.dot" ;
// ct2->exportToGraphViz (ss2.str().c_str(), true); // g1->constr()->exportToGraphViz (ss1.str().c_str(), true);
// commCt1->exportToGraphViz (ss3.str().c_str(), true); // g2->constr()->exportToGraphViz (ss2.str().c_str(), true);
// exclCt1->exportToGraphViz (ss4.str().c_str(), true); // commCt1->exportToGraphViz (ss3.str().c_str(), true);
// commCt2->exportToGraphViz (ss5.str().c_str(), true); // exclCt1->exportToGraphViz (ss4.str().c_str(), true);
// exclCt2->exportToGraphViz (ss6.str().c_str(), true); // commCt2->exportToGraphViz (ss5.str().c_str(), true);
// exclCt2->exportToGraphViz (ss6.str().c_str(), true);
if (exclCt1->empty() && exclCt2->empty()) { if (exclCt1->empty() && exclCt2->empty()) {
unsigned group = (f1.group() < f2.group()) ? f1.group() : f2.group(); unsigned group = (f1.group() < f2.group())
? f1.group()
: f2.group();
// identical // identical
f1.setGroup (group); f1.setGroup (group);
f2.setGroup (group); f2.setGroup (group);
@ -235,8 +301,8 @@ ParfactorList::shatter (
} else { } else {
group = ProbFormula::getNewGroup(); group = ProbFormula::getNewGroup();
} }
Parfactors res1 = shatter (g1, f1, commCt1, exclCt1, group); Parfactors res1 = shatter (g1, fIdx1, commCt1, exclCt1, group);
Parfactors res2 = shatter (g2, f2, commCt2, exclCt2, group); Parfactors res2 = shatter (g2, fIdx2, commCt2, exclCt2, group);
return make_pair (res1, res2); return make_pair (res1, res2);
} }
@ -245,11 +311,19 @@ ParfactorList::shatter (
Parfactors Parfactors
ParfactorList::shatter ( ParfactorList::shatter (
Parfactor* g, Parfactor* g,
const ProbFormula& f, unsigned fIdx,
ConstraintTree* commCt, ConstraintTree* commCt,
ConstraintTree* exclCt, ConstraintTree* exclCt,
unsigned commGroup) unsigned commGroup)
{ {
ProbFormula& f = g->argument (fIdx);
if (exclCt->empty()) {
delete commCt;
delete exclCt;
f.setGroup (commGroup);
return { };
}
Parfactors result; Parfactors result;
if (f.isCounting()) { if (f.isCounting()) {
LogVar X_new1 = g->constr()->logVarSet().back() + 1; LogVar X_new1 = g->constr()->logVarSet().back() + 1;
@ -259,7 +333,7 @@ ParfactorList::shatter (
for (unsigned i = 0; i < cts.size(); i++) { for (unsigned i = 0; i < cts.size(); i++) {
Parfactor* newPf = new Parfactor (g, cts[i]); Parfactor* newPf = new Parfactor (g, cts[i]);
if (cts[i]->nrLogVars() == g->constr()->nrLogVars() + 1) { if (cts[i]->nrLogVars() == g->constr()->nrLogVars() + 1) {
newPf->expandPotential (f.countedLogVar(), X_new1, X_new2); newPf->expand (f.countedLogVar(), X_new1, X_new2);
assert (g->constr()->getConditionalCount (f.countedLogVar()) == assert (g->constr()->getConditionalCount (f.countedLogVar()) ==
cts[i]->getConditionalCount (X_new1) + cts[i]->getConditionalCount (X_new1) +
cts[i]->getConditionalCount (X_new2)); cts[i]->getConditionalCount (X_new2));
@ -270,20 +344,16 @@ ParfactorList::shatter (
newPf->setNewGroups(); newPf->setNewGroups();
result.push_back (newPf); result.push_back (newPf);
} }
delete commCt;
delete exclCt;
} else { } else {
if (exclCt->empty()) { Parfactor* newPf = new Parfactor (g, commCt);
delete commCt; newPf->setNewGroups();
delete exclCt; newPf->argument (fIdx).setGroup (commGroup);
g->setFormulaGroup (f, commGroup); result.push_back (newPf);
} else { newPf = new Parfactor (g, exclCt);
Parfactor* newPf = new Parfactor (g, commCt); newPf->setNewGroups();
newPf->setNewGroups(); result.push_back (newPf);
newPf->setFormulaGroup (f, commGroup);
result.push_back (newPf);
newPf = new Parfactor (g, exclCt);
newPf->setNewGroups();
result.push_back (newPf);
}
} }
return result; return result;
} }
@ -296,7 +366,7 @@ ParfactorList::unifyGroups (unsigned group1, unsigned group2)
unsigned newGroup = ProbFormula::getNewGroup(); unsigned newGroup = ProbFormula::getNewGroup();
for (ParfactorList::iterator it = pfList_.begin(); for (ParfactorList::iterator it = pfList_.begin();
it != pfList_.end(); it++) { it != pfList_.end(); it++) {
ProbFormulas& formulas = (*it)->formulas(); ProbFormulas& formulas = (*it)->arguments();
for (unsigned i = 0; i < formulas.size(); i++) { for (unsigned i = 0; i < formulas.size(); i++) {
if (formulas[i].group() == group1 || if (formulas[i].group() == group1 ||
formulas[i].group() == group2) { formulas[i].group() == group2) {
@ -306,3 +376,52 @@ ParfactorList::unifyGroups (unsigned group1, unsigned group2)
} }
} }
bool
ParfactorList::proper (
const ProbFormula& f1, ConstraintTree c1,
const ProbFormula& f2, ConstraintTree c2) const
{
return disjoint (f1, c1, f2, c2)
|| identical (f1, c1, f2, c2);
}
bool
ParfactorList::identical (
const ProbFormula& f1, ConstraintTree c1,
const ProbFormula& f2, ConstraintTree c2) const
{
if (f1.sameSkeletonAs (f2) == false) {
return false;
}
if (f1.isAtom()) {
return true;
}
c1.moveToTop (f1.logVars());
c2.moveToTop (f2.logVars());
return ConstraintTree::identical (
&c1, &c2, f1.logVars().size());
}
bool
ParfactorList::disjoint (
const ProbFormula& f1, ConstraintTree c1,
const ProbFormula& f2, ConstraintTree c2) const
{
if (f1.sameSkeletonAs (f2) == false) {
return true;
}
if (f1.isAtom()) {
return true;
}
c1.moveToTop (f1.logVars());
c2.moveToTop (f2.logVars());
return ConstraintTree::overlap (
&c1, &c2, f1.arity()) == false;
}

View File

@ -2,6 +2,7 @@
#define HORUS_PARFACTORLIST_H #define HORUS_PARFACTORLIST_H
#include <list> #include <list>
#include <queue>
#include "Parfactor.h" #include "Parfactor.h"
#include "ProbFormula.h" #include "ProbFormula.h"
@ -14,56 +15,82 @@ class ParfactorList
{ {
public: public:
ParfactorList (void) { } ParfactorList (void) { }
ParfactorList (Parfactors&);
list<Parfactor*>& getParfactors (void) { return pfList_; }
const list<Parfactor*>& getParfactors (void) const { return pfList_; }
void add (Parfactor* pf); ParfactorList (const ParfactorList&);
void add (Parfactors& pfs);
void addShattered (Parfactor* pf);
list<Parfactor*>::iterator remove (list<Parfactor*>::iterator);
list<Parfactor*>::iterator deleteAndRemove (list<Parfactor*>::iterator);
void clear (void) { pfList_.clear(); } ParfactorList (const Parfactors&);
unsigned size (void) const { return pfList_.size(); }
void shatter (void); ~ParfactorList (void);
const list<Parfactor*>& parfactors (void) const { return pfList_; }
void clear (void) { pfList_.clear(); }
unsigned size (void) const { return pfList_.size(); }
typedef std::list<Parfactor*>::iterator iterator; typedef std::list<Parfactor*>::iterator iterator;
iterator begin (void) { return pfList_.begin(); } iterator begin (void) { return pfList_.begin(); }
iterator end (void) { return pfList_.end(); }
iterator end (void) { return pfList_.end(); }
typedef std::list<Parfactor*>::const_iterator const_iterator; typedef std::list<Parfactor*>::const_iterator const_iterator;
const_iterator begin (void) const { return pfList_.begin(); } const_iterator begin (void) const { return pfList_.begin(); }
const_iterator end (void) const { return pfList_.end(); }
const_iterator end (void) const { return pfList_.end(); }
void add (Parfactor* pf);
void add (const Parfactors& pfs);
void addShattered (Parfactor* pf);
list<Parfactor*>::iterator insertShattered (
list<Parfactor*>::iterator, Parfactor*);
list<Parfactor*>::iterator remove (list<Parfactor*>::iterator);
list<Parfactor*>::iterator removeAndDelete (list<Parfactor*>::iterator);
bool isAllShattered (void) const;
void print (void) const; void print (void) const;
private: private:
bool isShattered (const Parfactor*, const Parfactor*) const;
static std::pair<Parfactors, Parfactors> shatter ( void addToShatteredList (Parfactor*);
ProbFormulas&,
Parfactor*, std::pair<Parfactors, Parfactors> shatter (
ProbFormulas&, Parfactor*, Parfactor*);
Parfactor*);
static std::pair<Parfactors, Parfactors> shatter ( std::pair<Parfactors, Parfactors> shatter (
ProbFormula&, unsigned, Parfactor*, unsigned, Parfactor*);
Parfactor*,
ProbFormula&,
Parfactor*);
static Parfactors shatter ( Parfactors shatter (
Parfactor*, Parfactor*,
const ProbFormula&, unsigned,
ConstraintTree*, ConstraintTree*,
ConstraintTree*, ConstraintTree*,
unsigned); unsigned);
void unifyGroups (unsigned group1, unsigned group2); void unifyGroups (unsigned group1, unsigned group2);
list<Parfactor*> pfList_; bool proper (
const ProbFormula&, ConstraintTree,
const ProbFormula&, ConstraintTree) const;
bool identical (
const ProbFormula&, ConstraintTree,
const ProbFormula&, ConstraintTree) const;
bool disjoint (
const ProbFormula&, ConstraintTree,
const ProbFormula&, ConstraintTree) const;
list<Parfactor*> pfList_;
}; };
#endif // HORUS_PARFACTORLIST_H #endif // HORUS_PARFACTORLIST_H

View File

@ -16,8 +16,7 @@ ProbFormula::sameSkeletonAs (const ProbFormula& f) const
bool bool
ProbFormula::contains (LogVar lv) const ProbFormula::contains (LogVar lv) const
{ {
return std::find (logVars_.begin(), logVars_.end(), lv) != return Util::contains (logVars_, lv);
logVars_.end();
} }
@ -77,16 +76,15 @@ ProbFormula::rename (LogVar oldName, LogVar newName)
} }
bool operator== (const ProbFormula& f1, const ProbFormula& f2)
bool
ProbFormula::operator== (const ProbFormula& f) const
{ {
return functor_ == f.functor_ && logVars_ == f.logVars_ ; return f1.group_ == f2.group_;
//return functor_ == f.functor_ && logVars_ == f.logVars_ ;
} }
ostream& operator<< (ostream &os, const ProbFormula& f) std::ostream& operator<< (ostream &os, const ProbFormula& f)
{ {
os << f.functor_; os << f.functor_;
if (f.isAtom() == false) { if (f.isAtom() == false) {
@ -113,3 +111,13 @@ ProbFormula::getNewGroup (void)
return freeGroup_; return freeGroup_;
} }
ostream& operator<< (ostream &os, const ObservedFormula& of)
{
os << of.functor_ << "/" << of.arity_;
os << "|" << of.constr_.tupleSet();
os << " [evidence=" << of.evidence_ << "]";
return os;
}

View File

@ -8,14 +8,16 @@
#include "Horus.h" #include "Horus.h"
class ProbFormula class ProbFormula
{ {
public: public:
ProbFormula (Symbol f, const LogVars& lvs, unsigned range) ProbFormula (Symbol f, const LogVars& lvs, unsigned range)
: functor_(f), logVars_(lvs), range_(range), : functor_(f), logVars_(lvs), range_(range),
countedLogVar_() { } countedLogVar_(), group_(Util::maxUnsigned()) { }
ProbFormula (Symbol f, unsigned r) : functor_(f), range_(r) { } ProbFormula (Symbol f, unsigned r)
: functor_(f), range_(r), group_(Util::maxUnsigned()) { }
Symbol functor (void) const { return functor_; } Symbol functor (void) const { return functor_; }
@ -29,9 +31,9 @@ class ProbFormula
LogVarSet logVarSet (void) const { return LogVarSet (logVars_); } LogVarSet logVarSet (void) const { return LogVarSet (logVars_); }
unsigned group (void) const { return groupId_; } unsigned group (void) const { return group_; }
void setGroup (unsigned g) { groupId_ = g; } void setGroup (unsigned g) { group_ = g; }
bool sameSkeletonAs (const ProbFormula&) const; bool sameSkeletonAs (const ProbFormula&) const;
@ -49,23 +51,58 @@ class ProbFormula
void rename (LogVar, LogVar); void rename (LogVar, LogVar);
bool operator== (const ProbFormula& f) const;
friend ostream& operator<< (ostream &out, const ProbFormula& f);
static unsigned getNewGroup (void); static unsigned getNewGroup (void);
friend std::ostream& operator<< (ostream &os, const ProbFormula& f);
friend bool operator== (const ProbFormula& f1, const ProbFormula& f2);
private: private:
Symbol functor_; Symbol functor_;
LogVars logVars_; LogVars logVars_;
unsigned range_; unsigned range_;
LogVar countedLogVar_; LogVar countedLogVar_;
unsigned groupId_; unsigned group_;
static int freeGroup_; static int freeGroup_;
}; };
typedef vector<ProbFormula> ProbFormulas; typedef vector<ProbFormula> ProbFormulas;
class ObservedFormula
{
public:
ObservedFormula (Symbol f, unsigned a, unsigned ev)
: functor_(f), arity_(a), evidence_(ev), constr_(a) { }
ObservedFormula (Symbol f, unsigned ev, const Tuple& tuple)
: functor_(f), arity_(tuple.size()), evidence_(ev), constr_(arity_)
{
constr_.addTuple (tuple);
}
Symbol functor (void) const { return functor_; }
unsigned arity (void) const { return arity_; }
unsigned evidence (void) const { return evidence_; }
ConstraintTree& constr (void) { return constr_; }
bool isAtom (void) const { return arity_ == 0; }
void addTuple (const Tuple& tuple) { constr_.addTuple (tuple); }
friend ostream& operator<< (ostream &os, const ObservedFormula& of);
private:
Symbol functor_;
unsigned arity_;
unsigned evidence_;
ConstraintTree constr_;
};
typedef vector<ObservedFormula> ObservedFormulas;
#endif // HORUS_PROBFORMULA_H #endif // HORUS_PROBFORMULA_H

View File

@ -3,51 +3,35 @@
void void
Solver::printAllPosterioris (void) Solver::printAnswer (const VarIds& vids)
{ {
const VarNodes& vars = gm_->getVariableNodes(); Vars unobservedVars;
for (unsigned i = 0; i < vars.size(); i++) { VarIds unobservedVids;
printPosterioriOf (vars[i]->varId());
}
}
void
Solver::printPosterioriOf (VarId vid)
{
VarNode* var = gm_->getVariableNode (vid);
const Params& posterioriDist = getPosterioriOf (vid);
const States& states = var->states();
for (unsigned i = 0; i < states.size(); i++) {
cout << "P(" << var->label() << "=" << states[i] << ") = " ;
cout << setprecision (PRECISION) << posterioriDist[i];
cout << endl;
}
cout << endl;
}
void
Solver::printJointDistributionOf (const VarIds& vids)
{
VarNodes vars;
VarIds vidsWithoutEvidence;
for (unsigned i = 0; i < vids.size(); i++) { for (unsigned i = 0; i < vids.size(); i++) {
VarNode* var = gm_->getVariableNode (vids[i]); VarNode* vn = fg.getVarNode (vids[i]);
if (var->hasEvidence() == false) { if (vn->hasEvidence() == false) {
vars.push_back (var); unobservedVars.push_back (vn);
vidsWithoutEvidence.push_back (vids[i]); unobservedVids.push_back (vids[i]);
} }
} }
const Params& jointDist = getJointDistributionOf (vidsWithoutEvidence); Params res = solveQuery (unobservedVids);
vector<string> jointStrings = Util::getJointStateStrings (vars); vector<string> stateLines = Util::getStateLines (unobservedVars);
for (unsigned i = 0; i < jointDist.size(); i++) { for (unsigned i = 0; i < res.size(); i++) {
cout << "P(" << jointStrings[i] << ") = " ; cout << "P(" << stateLines[i] << ") = " ;
cout << setprecision (PRECISION) << jointDist[i]; cout << std::setprecision (Constants::PRECISION) << res[i];
cout << endl; cout << endl;
} }
cout << endl; cout << endl;
} }
void
Solver::printAllPosterioris (void)
{
const VarNodes& vars = fg.varNodes();
for (unsigned i = 0; i < vars.size(); i++) {
printAnswer ({vars[i]->varId()});
}
}

View File

@ -3,29 +3,27 @@
#include <iomanip> #include <iomanip>
#include "GraphicalModel.h" #include "Var.h"
#include "VarNode.h" #include "FactorGraph.h"
using namespace std; using namespace std;
class Solver class Solver
{ {
public: public:
Solver (const GraphicalModel* gm) Solver (const FactorGraph& factorGraph) : fg(factorGraph) { }
{
gm_ = gm; virtual ~Solver() { } // ensure that subclass destructor is called
}
virtual ~Solver() {} // to ensure that subclass destructor is called virtual Params solveQuery (VarIds queryVids) = 0;
virtual void runSolver (void) = 0;
virtual Params getPosterioriOf (VarId) = 0; void printAnswer (const VarIds& vids);
virtual Params getJointDistributionOf (const VarIds&) = 0;
void printAllPosterioris (void); void printAllPosterioris (void);
void printPosterioriOf (VarId vid);
void printJointDistributionOf (const VarIds& vids);
private: protected:
const GraphicalModel* gm_; const FactorGraph& fg;
}; };
#endif // HORUS_SOLVER_H #endif // HORUS_SOLVER_H

View File

@ -0,0 +1,4 @@
TODO
- add a way to calculate combinations and factorials with large numbers
- refactor sumOut in parfactor -> is really ugly code
- Indexer: start receiving ranges as constant reference

View File

@ -1,21 +1,22 @@
#include <limits>
#include <sstream> #include <sstream>
#include <fstream>
#include "Util.h" #include "Util.h"
#include "Indexer.h" #include "Indexer.h"
#include "GraphicalModel.h"
namespace Globals { namespace Globals {
bool logDomain = false; bool logDomain = false;
//InfAlgs infAlgorithm = InfAlgorithms::VE;
//InfAlgs infAlgorithm = InfAlgorithms::BN_BP;
//InfAlgs infAlgorithm = InfAlgorithms::FG_BP;
InfAlgorithms infAlgorithm = InfAlgorithms::CBP;
}; };
namespace InfAlgorithms {
//InfAlgs infAlgorithm = InfAlgorithms::VE;
//InfAlgs infAlgorithm = InfAlgorithms::BN_BP;
InfAlgs infAlgorithm = InfAlgorithms::FG_BP;
//InfAlgs infAlgorithm = InfAlgorithms::CBP;
}
namespace BpOptions { namespace BpOptions {
@ -23,13 +24,11 @@ Schedule schedule = BpOptions::Schedule::SEQ_FIXED;
//Schedule schedule = BpOptions::Schedule::SEQ_RANDOM; //Schedule schedule = BpOptions::Schedule::SEQ_RANDOM;
//Schedule schedule = BpOptions::Schedule::PARALLEL; //Schedule schedule = BpOptions::Schedule::PARALLEL;
//Schedule schedule = BpOptions::Schedule::MAX_RESIDUAL; //Schedule schedule = BpOptions::Schedule::MAX_RESIDUAL;
double accuracy = 0.0001; double accuracy = 0.0001;
unsigned maxIter = 1000; unsigned maxIter = 1000;
} }
unordered_map<VarId,VariableInfo> GraphicalModel::varsInfo_;
unordered_map<unsigned,Distribution*> GraphicalModel::distsInfo_;
vector<NetInfo> Statistics::netInfo_; vector<NetInfo> Statistics::netInfo_;
vector<CompressInfo> Statistics::compressInfo_; vector<CompressInfo> Statistics::compressInfo_;
@ -58,76 +57,6 @@ fromLog (Params& v)
void
normalize (Params& v)
{
double sum;
if (Globals::logDomain) {
sum = addIdenty();
for (unsigned i = 0; i < v.size(); i++) {
logSum (sum, v[i]);
}
assert (sum != -numeric_limits<double>::infinity());
for (unsigned i = 0; i < v.size(); i++) {
v[i] -= sum;
}
} else {
sum = 0.0;
for (unsigned i = 0; i < v.size(); i++) {
sum += v[i];
}
assert (sum != 0.0);
for (unsigned i = 0; i < v.size(); i++) {
v[i] /= sum;
}
}
}
void
pow (Params& v, double expoent)
{
if (Globals::logDomain) {
for (unsigned i = 0; i < v.size(); i++) {
v[i] *= expoent;
}
} else {
for (unsigned i = 0; i < v.size(); i++) {
v[i] = std::pow (v[i], expoent);
}
}
}
void
pow (Params& v, unsigned expoent)
{
if (expoent == 1) {
return;
}
if (Globals::logDomain) {
for (unsigned i = 0; i < v.size(); i++) {
v[i] *= expoent;
}
} else {
for (unsigned i = 0; i < v.size(); i++) {
v[i] = std::pow (v[i], expoent);
}
}
}
double
pow (double p, unsigned expoent)
{
return Globals::logDomain ? p * expoent : std::pow (p, expoent);
}
double double
factorial (double num) factorial (double num)
{ {
@ -153,6 +82,151 @@ nrCombinations (unsigned n, unsigned r)
unsigned
expectedSize (const Ranges& ranges)
{
unsigned prod = 1;
for (unsigned i = 0; i < ranges.size(); i++) {
prod *= ranges[i];
}
return prod;
}
unsigned
getNumberOfDigits (int number)
{
unsigned count = 1;
while (number >= 10) {
number /= 10;
count ++;
}
return count;
}
bool
isInteger (const string& s)
{
stringstream ss1 (s);
stringstream ss2;
int integer;
ss1 >> integer;
ss2 << integer;
return (ss1.str() == ss2.str());
}
string
parametersToString (const Params& v, unsigned precision)
{
stringstream ss;
ss.precision (precision);
ss << "[" ;
for (unsigned i = 0; i < v.size(); i++) {
if (i != 0) ss << ", " ;
ss << v[i];
}
ss << "]" ;
return ss.str();
}
vector<string>
getStateLines (const Vars& vars)
{
StatesIndexer idx (vars);
vector<string> jointStrings;
while (idx.valid()) {
stringstream ss;
for (unsigned i = 0; i < vars.size(); i++) {
if (i != 0) ss << ", " ;
ss << vars[i]->label() << "=" << vars[i]->states()[(idx[i])];
}
jointStrings.push_back (ss.str());
++ idx;
}
return jointStrings;
}
void
printHeader (string header, std::ostream& os)
{
printAsteriskLine (os);
os << header << endl;
printAsteriskLine (os);
}
void
printSubHeader (string header, std::ostream& os)
{
printDashedLine (os);
os << header << endl;
printDashedLine (os);
}
void
printAsteriskLine (std::ostream& os)
{
os << "********************************" ;
os << "********************************" ;
os << endl;
}
void
printDashedLine (std::ostream& os)
{
os << "--------------------------------" ;
os << "--------------------------------" ;
os << endl;
}
}
namespace LogAware {
void
normalize (Params& v)
{
double sum;
if (Globals::logDomain) {
sum = LogAware::addIdenty();
for (unsigned i = 0; i < v.size(); i++) {
sum = Util::logSum (sum, v[i]);
}
assert (sum != -numeric_limits<double>::infinity());
for (unsigned i = 0; i < v.size(); i++) {
v[i] -= sum;
}
} else {
sum = 0.0;
for (unsigned i = 0; i < v.size(); i++) {
sum += v[i];
}
assert (sum != 0.0);
for (unsigned i = 0; i < v.size(); i++) {
v[i] /= sum;
}
}
}
double double
getL1Distance (const Params& v1, const Params& v2) getL1Distance (const Params& v1, const Params& v2)
{ {
@ -196,67 +270,57 @@ getMaxNorm (const Params& v1, const Params& v2)
} }
double
pow (double p, unsigned expoent)
{
return Globals::logDomain ? p * expoent : std::pow (p, expoent);
}
unsigned
getNumberOfDigits (int number) {
unsigned count = 1; double
while (number >= 10) { pow (double p, double expoent)
number /= 10; {
count ++; // assumes that `expoent' is never in log domain
return Globals::logDomain ? p * expoent : std::pow (p, expoent);
}
void
pow (Params& v, unsigned expoent)
{
if (expoent == 1) {
return;
} }
return count; if (Globals::logDomain) {
} for (unsigned i = 0; i < v.size(); i++) {
v[i] *= expoent;
}
} else {
bool for (unsigned i = 0; i < v.size(); i++) {
isInteger (const string& s) v[i] = std::pow (v[i], expoent);
{
stringstream ss1 (s);
stringstream ss2;
int integer;
ss1 >> integer;
ss2 << integer;
return (ss1.str() == ss2.str());
}
string
parametersToString (const Params& v, unsigned precision)
{
stringstream ss;
ss.precision (precision);
ss << "[" ;
for (unsigned i = 0; i < v.size(); i++) {
if (i != 0) ss << ", " ;
ss << v[i];
}
ss << "]" ;
return ss.str();
}
vector<string>
getJointStateStrings (const VarNodes& vars)
{
StatesIndexer idx (vars);
vector<string> jointStrings;
while (idx.valid()) {
stringstream ss;
for (unsigned i = 0; i < vars.size(); i++) {
if (i != 0) ss << ", " ;
ss << vars[i]->label() << "=" << vars[i]->states()[(idx[i])];
} }
jointStrings.push_back (ss.str());
++ idx;
} }
return jointStrings;
} }
void
pow (Params& v, double expoent)
{
// assumes that `expoent' is never in log domain
if (Globals::logDomain) {
for (unsigned i = 0; i < v.size(); i++) {
v[i] *= expoent;
}
} else {
for (unsigned i = 0; i < v.size(); i++) {
v[i] = std::pow (v[i], expoent);
}
}
}
} }
@ -286,8 +350,11 @@ Statistics::getPrimaryNetworksCounting (void)
void void
Statistics::updateStatistics (unsigned size, bool loopy, Statistics::updateStatistics (
unsigned nIters, double time) unsigned size,
bool loopy,
unsigned nIters,
double time)
{ {
netInfo_.push_back (NetInfo (size, loopy, nIters, time)); netInfo_.push_back (NetInfo (size, loopy, nIters, time));
} }
@ -303,12 +370,12 @@ Statistics::printStatistics (void)
void void
Statistics::writeStatisticsToFile (const char* fileName) Statistics::writeStatistics (const char* fileName)
{ {
ofstream out (fileName); ofstream out (fileName);
if (!out.is_open()) { if (!out.is_open()) {
cerr << "error: cannot open file to write at " ; cerr << "error: cannot open file to write at " ;
cerr << "Statistics::writeStatisticsToFile()" << endl; cerr << "Statistics::writeStats()" << endl;
abort(); abort();
} }
out << getStatisticString(); out << getStatisticString();
@ -318,13 +385,14 @@ Statistics::writeStatisticsToFile (const char* fileName)
void void
Statistics::updateCompressingStatistics (unsigned nGroundVars, Statistics::updateCompressingStatistics (
unsigned nGroundFactors, unsigned nrGroundVars,
unsigned nClusterVars, unsigned nrGroundFactors,
unsigned nClusterFactors, unsigned nrClusterVars,
unsigned nWithoutNeighs) { unsigned nrClusterFactors,
compressInfo_.push_back (CompressInfo (nGroundVars, nGroundFactors, unsigned nrNeighborless) {
nClusterVars, nClusterFactors, nWithoutNeighs)); compressInfo_.push_back (CompressInfo (nrGroundVars, nrGroundFactors,
nrClusterVars, nrClusterFactors, nrNeighborless));
} }
@ -334,26 +402,30 @@ Statistics::getStatisticString (void)
{ {
stringstream ss2, ss3, ss4, ss1; stringstream ss2, ss3, ss4, ss1;
ss1 << "running mode: " ; ss1 << "running mode: " ;
switch (InfAlgorithms::infAlgorithm) { switch (Globals::infAlgorithm) {
case InfAlgorithms::VE: ss1 << "ve" << endl; break; case InfAlgorithms::VE: ss1 << "ve" << endl; break;
case InfAlgorithms::BN_BP: ss1 << "bn_bp" << endl; break; case InfAlgorithms::BP: ss1 << "bp" << endl; break;
case InfAlgorithms::FG_BP: ss1 << "fg_bp" << endl; break; case InfAlgorithms::CBP: ss1 << "cbp" << endl; break;
case InfAlgorithms::CBP: ss1 << "cbp" << endl; break;
} }
ss1 << "message schedule: " ; ss1 << "message schedule: " ;
switch (BpOptions::schedule) { switch (BpOptions::schedule) {
case BpOptions::Schedule::SEQ_FIXED: ss1 << "sequential fixed" << endl; break; case BpOptions::Schedule::SEQ_FIXED:
case BpOptions::Schedule::SEQ_RANDOM: ss1 << "sequential random" << endl; break; ss1 << "sequential fixed" << endl;
case BpOptions::Schedule::PARALLEL: ss1 << "parallel" << endl; break; break;
case BpOptions::Schedule::MAX_RESIDUAL: ss1 << "max residual" << endl; break; case BpOptions::Schedule::SEQ_RANDOM:
ss1 << "sequential random" << endl;
break;
case BpOptions::Schedule::PARALLEL:
ss1 << "parallel" << endl;
break;
case BpOptions::Schedule::MAX_RESIDUAL:
ss1 << "max residual" << endl;
break;
} }
ss1 << "max iterations: " << BpOptions::maxIter << endl; ss1 << "max iterations: " << BpOptions::maxIter << endl;
ss1 << "accuracy " << BpOptions::accuracy << endl; ss1 << "accuracy " << BpOptions::accuracy << endl;
ss1 << endl << endl; ss1 << endl << endl;
Util::printSubHeader ("Network information", ss2);
ss2 << "---------------------------------------------------" << endl;
ss2 << " Network information" << endl;
ss2 << "---------------------------------------------------" << endl;
ss2 << left; ss2 << left;
ss2 << setw (15) << "Network Size" ; ss2 << setw (15) << "Network Size" ;
ss2 << setw (9) << "Loopy" ; ss2 << setw (9) << "Loopy" ;
@ -387,24 +459,22 @@ Statistics::getStatisticString (void)
unsigned c1 = 0, c2 = 0, c3 = 0, c4 = 0; unsigned c1 = 0, c2 = 0, c3 = 0, c4 = 0;
if (compressInfo_.size() > 0) { if (compressInfo_.size() > 0) {
ss3 << "---------------------------------------------------" << endl; Util::printSubHeader ("Compress information", ss3);
ss3 << " Compression information" << endl;
ss3 << "---------------------------------------------------" << endl;
ss3 << left; ss3 << left;
ss3 << "Ground Cluster Ground Cluster Neighborless" << endl; ss3 << "Ground Cluster Ground Cluster Neighborless" << endl;
ss3 << "Vars Vars Factors Factors Vars" << endl; ss3 << "Vars Vars Factors Factors Vars" << endl;
for (unsigned i = 0; i < compressInfo_.size(); i++) { for (unsigned i = 0; i < compressInfo_.size(); i++) {
ss3 << setw (9) << compressInfo_[i].nGroundVars; ss3 << setw (9) << compressInfo_[i].nrGroundVars;
ss3 << setw (10) << compressInfo_[i].nClusterVars; ss3 << setw (10) << compressInfo_[i].nrClusterVars;
ss3 << setw (10) << compressInfo_[i].nGroundFactors; ss3 << setw (10) << compressInfo_[i].nrGroundFactors;
ss3 << setw (10) << compressInfo_[i].nClusterFactors; ss3 << setw (10) << compressInfo_[i].nrClusterFactors;
ss3 << setw (10) << compressInfo_[i].nWithoutNeighs; ss3 << setw (10) << compressInfo_[i].nrNeighborless;
ss3 << endl; ss3 << endl;
c1 += compressInfo_[i].nGroundVars - compressInfo_[i].nWithoutNeighs; c1 += compressInfo_[i].nrGroundVars - compressInfo_[i].nrNeighborless;
c2 += compressInfo_[i].nClusterVars; c2 += compressInfo_[i].nrClusterVars;
c3 += compressInfo_[i].nGroundFactors - compressInfo_[i].nWithoutNeighs; c3 += compressInfo_[i].nrGroundFactors - compressInfo_[i].nrNeighborless;
c4 += compressInfo_[i].nClusterFactors; c4 += compressInfo_[i].nrClusterFactors;
if (compressInfo_[i].nWithoutNeighs != 0) { if (compressInfo_[i].nrNeighborless != 0) {
c2 --; c2 --;
c4 --; c4 --;
} }

View File

@ -1,53 +1,141 @@
#ifndef HORUS_UTIL_H #ifndef HORUS_UTIL_H
#define HORUS_UTIL_H #define HORUS_UTIL_H
#include <cmath>
#include <cassert>
#include <limits>
#include <algorithm>
#include <vector> #include <vector>
#include <set>
#include <queue>
#include <unordered_map>
#include <sstream>
#include <iostream>
#include "Horus.h" #include "Horus.h"
using namespace std; using namespace std;
namespace Util { namespace Util {
void toLog (Params&); template <typename T> void addToVector (vector<T>&, const vector<T>&);
void fromLog (Params&);
void normalize (Params&); template <typename T> void addToSet (set<T>&, const vector<T>&);
void logSum (double&, double);
void multiply (Params&, const Params&); template <typename T> void addToQueue (queue<T>&, const vector<T>&);
void multiply (Params&, const Params&, unsigned);
void add (Params&, const Params&); template <typename T> bool contains (const vector<T>&, const T&);
void add (Params&, const Params&, unsigned);
void pow (Params&, double); template <typename T> bool contains (const set<T>&, const T&);
void pow (Params&, unsigned);
double pow (double, unsigned); template <typename K, typename V> bool contains (
double factorial (double); const unordered_map<K, V>&, const K&);
unsigned nrCombinations (unsigned, unsigned);
double getL1Distance (const Params&, const Params&); template <typename T> std::string toString (const T&);
double getMaxNorm (const Params&, const Params&);
unsigned getNumberOfDigits (int); void toLog (Params&);
bool isInteger (const string&);
string parametersToString (const Params&, unsigned = PRECISION); void fromLog (Params&);
vector<string> getJointStateStrings (const VarNodes&);
double tl (double); double logSum (double, double);
double fl (double);
double multIdenty(); void multiply (Params&, const Params&);
double addIdenty();
double withEvidence(); void multiply (Params&, const Params&, unsigned);
double noEvidence();
double one(); void add (Params&, const Params&);
double zero();
void add (Params&, const Params&, unsigned);
double factorial (double);
unsigned nrCombinations (unsigned, unsigned);
unsigned expectedSize (const Ranges&);
unsigned getNumberOfDigits (int);
bool isInteger (const string&);
string parametersToString (const Params&, unsigned = Constants::PRECISION);
vector<string> getStateLines (const Vars&);
void printHeader (string, std::ostream& os = std::cout);
void printSubHeader (string, std::ostream& os = std::cout);
void printAsteriskLine (std::ostream& os = std::cout);
void printDashedLine (std::ostream& os = std::cout);
unsigned maxUnsigned (void);
};
template <class T>
std::string toString (const T& t) template <typename T> void
Util::addToVector (vector<T>& v, const vector<T>& elements)
{
v.insert (v.end(), elements.begin(), elements.end());
}
template <typename T> void
Util::addToSet (set<T>& s, const vector<T>& elements)
{
s.insert (elements.begin(), elements.end());
}
template <typename T> void
Util::addToQueue (queue<T>& q, const vector<T>& elements)
{
for (unsigned i = 0; i < elements.size(); i++) {
q.push (elements[i]);
}
}
template <typename T> bool
Util::contains (const vector<T>& v, const T& e)
{
return std::find (v.begin(), v.end(), e) != v.end();
}
template <typename T> bool
Util::contains (const set<T>& s, const T& e)
{
return s.find (e) != s.end();
}
template <typename K, typename V> bool
Util::contains (const unordered_map<K, V>& m, const K& k)
{
return m.find (k) != m.end();
}
template <typename T> std::string
Util::toString (const T& t)
{ {
std::stringstream ss; std::stringstream ss;
ss << t; ss << t;
return ss.str(); return ss.str();
} }
};
template <typename T> template <typename T>
@ -62,28 +150,31 @@ std::ostream& operator << (std::ostream& os, const vector<T>& v)
} }
namespace {
const double INF = -numeric_limits<double>::infinity();
};
inline void inline double
Util::logSum (double& x, double y) Util::logSum (double x, double y)
{ {
x = log (exp (x) + exp (y)); return; return log (exp (x) + exp (y));
assert (isfinite (x) && isfinite (y)); assert (isfinite (x) && isfinite (y));
// If one value is much smaller than the other, keep the larger value. // If one value is much smaller than the other, keep the larger value.
if (x < (y - log (1e200))) { if (x < (y - log (1e200))) {
x = y; return y;
return;
} }
if (y < (x - log (1e200))) { if (y < (x - log (1e200))) {
return; return x;
} }
double diff = x - y; double diff = x - y;
assert (isfinite (diff) && isfinite (x) && isfinite (y)); assert (isfinite (diff) && isfinite (x) && isfinite (y));
if (!isfinite (exp (diff))) { // difference is too large if (!isfinite (exp (diff))) {
x = x > y ? x : y; // difference is too large
} else { // otherwise return the sum. return x > y ? x : y;
x = y + log (static_cast<double>(1.0) + exp (diff));
} }
// otherwise return the sum.
return y + log (static_cast<double>(1.0) + exp (diff));
} }
@ -140,52 +231,87 @@ Util::add (Params& v1, const Params& v2, unsigned repetitions)
inline double inline unsigned
Util::tl (double v) Util::maxUnsigned (void)
{ {
return Globals::logDomain ? log(v) : v; return numeric_limits<unsigned>::max();
} }
inline double
Util::fl (double v)
{ namespace LogAware {
return Globals::logDomain ? exp(v) : v;
}
inline double inline double
Util::multIdenty() { one()
return Globals::logDomain ? 0.0 : 1.0;
}
inline double
Util::addIdenty()
{
return Globals::logDomain ? INF : 0.0;
}
inline double
Util::withEvidence()
{ {
return Globals::logDomain ? 0.0 : 1.0; return Globals::logDomain ? 0.0 : 1.0;
} }
inline double
Util::noEvidence() {
return Globals::logDomain ? INF : 0.0;
}
inline double inline double
Util::one() zero() {
{
return Globals::logDomain ? 0.0 : 1.0;
}
inline double
Util::zero() {
return Globals::logDomain ? INF : 0.0 ; return Globals::logDomain ? INF : 0.0 ;
} }
inline double
addIdenty()
{
return Globals::logDomain ? INF : 0.0;
}
inline double
multIdenty()
{
return Globals::logDomain ? 0.0 : 1.0;
}
inline double
withEvidence()
{
return Globals::logDomain ? 0.0 : 1.0;
}
inline double
noEvidence() {
return Globals::logDomain ? INF : 0.0;
}
inline double
tl (double v)
{
return Globals::logDomain ? log (v) : v;
}
inline double
fl (double v)
{
return Globals::logDomain ? exp (v) : v;
}
void normalize (Params&);
double getL1Distance (const Params&, const Params&);
double getMaxNorm (const Params&, const Params&);
double pow (double, unsigned);
double pow (double, double);
void pow (Params&, unsigned);
void pow (Params&, double);
};
struct NetInfo struct NetInfo
{ {
NetInfo (unsigned size, bool loopy, unsigned nIters, double time) NetInfo (unsigned size, bool loopy, unsigned nIters, double time)
@ -206,17 +332,17 @@ struct CompressInfo
{ {
CompressInfo (unsigned a, unsigned b, unsigned c, unsigned d, unsigned e) CompressInfo (unsigned a, unsigned b, unsigned c, unsigned d, unsigned e)
{ {
nGroundVars = a; nrGroundVars = a;
nGroundFactors = b; nrGroundFactors = b;
nClusterVars = c; nrClusterVars = c;
nClusterFactors = d; nrClusterFactors = d;
nWithoutNeighs = e; nrNeighborless = e;
} }
unsigned nGroundVars; unsigned nrGroundVars;
unsigned nGroundFactors; unsigned nrGroundFactors;
unsigned nClusterVars; unsigned nrClusterVars;
unsigned nClusterFactors; unsigned nrClusterFactors;
unsigned nWithoutNeighs; unsigned nrNeighborless;
}; };
@ -224,11 +350,17 @@ class Statistics
{ {
public: public:
static unsigned getSolvedNetworksCounting (void); static unsigned getSolvedNetworksCounting (void);
static void incrementPrimaryNetworksCounting (void); static void incrementPrimaryNetworksCounting (void);
static unsigned getPrimaryNetworksCounting (void); static unsigned getPrimaryNetworksCounting (void);
static void updateStatistics (unsigned, bool, unsigned, double); static void updateStatistics (unsigned, bool, unsigned, double);
static void printStatistics (void); static void printStatistics (void);
static void writeStatisticsToFile (const char*);
static void writeStatistics (const char*);
static void updateCompressingStatistics ( static void updateCompressingStatistics (
unsigned, unsigned, unsigned, unsigned, unsigned); unsigned, unsigned, unsigned, unsigned, unsigned);

View File

@ -0,0 +1,102 @@
#include <algorithm>
#include <sstream>
#include "Var.h"
using namespace std;
unordered_map<VarId, VarInfo> Var::varsInfo_;
Var::Var (const Var* v)
{
varId_ = v->varId();
range_ = v->range();
evidence_ = v->getEvidence();
index_ = std::numeric_limits<unsigned>::max();
}
Var::Var (VarId varId, unsigned range, int evidence)
{
assert (range != 0);
assert (evidence < (int) range);
varId_ = varId;
range_ = range;
evidence_ = evidence;
index_ = std::numeric_limits<unsigned>::max();
}
bool
Var::isValidState (int stateIndex)
{
return stateIndex >= 0 && stateIndex < (int) range_;
}
bool
Var::isValidState (const string& stateName)
{
States states = Var::getVarInfo (varId_).states;
return Util::contains (states, stateName);
}
void
Var::setEvidence (int ev)
{
assert (ev < (int) range_);
evidence_ = ev;
}
void
Var::setEvidence (const string& ev)
{
States states = Var::getVarInfo (varId_).states;
for (unsigned i = 0; i < states.size(); i++) {
if (states[i] == ev) {
evidence_ = i;
return;
}
}
assert (false);
}
string
Var::label (void) const
{
if (Var::varsHaveInfo()) {
return Var::getVarInfo (varId_).label;
}
stringstream ss;
ss << "x" << varId_;
return ss.str();
}
States
Var::states (void) const
{
if (Var::varsHaveInfo()) {
return Var::getVarInfo (varId_).states;
}
States states;
for (unsigned i = 0; i < range_; i++) {
stringstream ss;
ss << i ;
states.push_back (ss.str());
}
return states;
}

View File

@ -0,0 +1,108 @@
#ifndef HORUS_Var_H
#define HORUS_Var_H
#include <cassert>
#include <iostream>
#include "Util.h"
#include "Horus.h"
using namespace std;
struct VarInfo
{
VarInfo (string l, const States& sts) : label(l), states(sts) { }
string label;
States states;
};
class Var
{
public:
Var (const Var*);
Var (VarId, unsigned, int = Constants::NO_EVIDENCE);
virtual ~Var (void) { };
unsigned varId (void) const { return varId_; }
unsigned range (void) const { return range_; }
int getEvidence (void) const { return evidence_; }
unsigned getIndex (void) const { return index_; }
void setIndex (unsigned idx) { index_ = idx; }
operator unsigned () const { return index_; }
bool hasEvidence (void) const
{
return evidence_ != Constants::NO_EVIDENCE;
}
bool operator== (const Var& var) const
{
assert (!(varId_ == var.varId() && range_ != var.range()));
return varId_ == var.varId();
}
bool operator!= (const Var& var) const
{
assert (!(varId_ == var.varId() && range_ != var.range()));
return varId_ != var.varId();
}
bool isValidState (int);
bool isValidState (const string&);
void setEvidence (int);
void setEvidence (const string&);
string label (void) const;
States states (void) const;
static void addVarInfo (
VarId vid, string label, const States& states)
{
assert (Util::contains (varsInfo_, vid) == false);
varsInfo_.insert (make_pair (vid, VarInfo (label, states)));
}
static VarInfo getVarInfo (VarId vid)
{
assert (Util::contains (varsInfo_, vid));
return varsInfo_.find (vid)->second;
}
static bool varsHaveInfo (void)
{
return varsInfo_.size() != 0;
}
static void clearVarsInfo (void)
{
varsInfo_.clear();
}
private:
VarId varId_;
unsigned range_;
int evidence_;
unsigned index_;
static unordered_map<VarId, VarInfo> varsInfo_;
};
#endif // BP_Var_H

View File

@ -6,61 +6,27 @@
#include "Util.h" #include "Util.h"
VarElimSolver::VarElimSolver (const BayesNet& bn) : Solver (&bn)
{
bayesNet_ = &bn;
factorGraph_ = new FactorGraph (bn);
}
VarElimSolver::VarElimSolver (const FactorGraph& fg) : Solver (&fg)
{
bayesNet_ = 0;
factorGraph_ = &fg;
}
VarElimSolver::~VarElimSolver (void) VarElimSolver::~VarElimSolver (void)
{ {
if (bayesNet_) { delete factorList_.back();
delete factorGraph_;
}
} }
Params Params
VarElimSolver::getPosterioriOf (VarId vid) VarElimSolver::solveQuery (VarIds queryVids)
{
assert (factorGraph_->getFgVarNode (vid));
FgVarNode* vn = factorGraph_->getFgVarNode (vid);
if (vn->hasEvidence()) {
Params params (vn->nrStates(), 0.0);
params[vn->getEvidence()] = 1.0;
return params;
}
return getJointDistributionOf (VarIds() = {vid});
}
Params
VarElimSolver::getJointDistributionOf (const VarIds& vids)
{ {
factorList_.clear(); factorList_.clear();
varFactors_.clear(); varFactors_.clear();
elimOrder_.clear(); elimOrder_.clear();
createFactorList(); createFactorList();
introduceEvidence(); absorveEvidence();
chooseEliminationOrder (vids); findEliminationOrder (queryVids);
processFactorList (vids); processFactorList (queryVids);
Params params = factorList_.back()->getParameters(); Params params = factorList_.back()->params();
if (Globals::logDomain) { if (Globals::logDomain) {
Util::fromLog (params); Util::fromLog (params);
} }
delete factorList_.back();
return params; return params;
} }
@ -69,11 +35,11 @@ VarElimSolver::getJointDistributionOf (const VarIds& vids)
void void
VarElimSolver::createFactorList (void) VarElimSolver::createFactorList (void)
{ {
const FgFacSet& factorNodes = factorGraph_->getFactorNodes(); const FacNodes& facNodes = fg.facNodes();
factorList_.reserve (factorNodes.size() * 2); factorList_.reserve (facNodes.size() * 2);
for (unsigned i = 0; i < factorNodes.size(); i++) { for (unsigned i = 0; i < facNodes.size(); i++) {
factorList_.push_back (new Factor (*factorNodes[i]->factor())); factorList_.push_back (new Factor (facNodes[i]->factor()));
const FgVarSet& neighs = factorNodes[i]->neighbors(); const VarNodes& neighs = facNodes[i]->neighbors();
for (unsigned j = 0; j < neighs.size(); j++) { for (unsigned j = 0; j < neighs.size(); j++) {
unordered_map<VarId,vector<unsigned> >::iterator it unordered_map<VarId,vector<unsigned> >::iterator it
= varFactors_.find (neighs[j]->varId()); = varFactors_.find (neighs[j]->varId());
@ -89,16 +55,16 @@ VarElimSolver::createFactorList (void)
void void
VarElimSolver::introduceEvidence (void) VarElimSolver::absorveEvidence (void)
{ {
const FgVarSet& varNodes = factorGraph_->getVarNodes(); const VarNodes& varNodes = fg.varNodes();
for (unsigned i = 0; i < varNodes.size(); i++) { for (unsigned i = 0; i < varNodes.size(); i++) {
if (varNodes[i]->hasEvidence()) { if (varNodes[i]->hasEvidence()) {
const vector<unsigned>& idxs = const vector<unsigned>& idxs =
varFactors_.find (varNodes[i]->varId())->second; varFactors_.find (varNodes[i]->varId())->second;
for (unsigned j = 0; j < idxs.size(); j++) { for (unsigned j = 0; j < idxs.size(); j++) {
Factor* factor = factorList_[idxs[j]]; Factor* factor = factorList_[idxs[j]];
if (factor->nrVariables() == 1) { if (factor->nrArguments() == 1) {
factorList_[idxs[j]] = 0; factorList_[idxs[j]] = 0;
} else { } else {
factorList_[idxs[j]]->absorveEvidence ( factorList_[idxs[j]]->absorveEvidence (
@ -112,21 +78,9 @@ VarElimSolver::introduceEvidence (void)
void void
VarElimSolver::chooseEliminationOrder (const VarIds& vids) VarElimSolver::findEliminationOrder (const VarIds& vids)
{ {
if (bayesNet_) { elimOrder_ = ElimGraph::getEliminationOrder (factorList_, vids);
ElimGraph graph (*bayesNet_);
elimOrder_ = graph.getEliminatingOrder (vids);
} else {
const FgVarSet& varNodes = factorGraph_->getVarNodes();
for (unsigned i = 0; i < varNodes.size(); i++) {
VarId vid = varNodes[i]->varId();
if (std::find (vids.begin(), vids.end(), vid) == vids.end()
&& !varNodes[i]->hasEvidence()) {
elimOrder_.push_back (vid);
}
}
}
} }
@ -149,12 +103,12 @@ VarElimSolver::processFactorList (const VarIds& vids)
VarIds unobservedVids; VarIds unobservedVids;
for (unsigned i = 0; i < vids.size(); i++) { for (unsigned i = 0; i < vids.size(); i++) {
if (factorGraph_->getFgVarNode (vids[i])->hasEvidence() == false) { if (fg.getVarNode (vids[i])->hasEvidence() == false) {
unobservedVids.push_back (vids[i]); unobservedVids.push_back (vids[i]);
} }
} }
finalFactor->reorderVariables (unobservedVids); finalFactor->reorderArguments (unobservedVids);
finalFactor->normalize(); finalFactor->normalize();
factorList_.push_back (finalFactor); factorList_.push_back (finalFactor);
} }
@ -165,13 +119,12 @@ void
VarElimSolver::eliminate (VarId elimVar) VarElimSolver::eliminate (VarId elimVar)
{ {
Factor* result = 0; Factor* result = 0;
FgVarNode* vn = factorGraph_->getFgVarNode (elimVar);
vector<unsigned>& idxs = varFactors_.find (elimVar)->second; vector<unsigned>& idxs = varFactors_.find (elimVar)->second;
for (unsigned i = 0; i < idxs.size(); i++) { for (unsigned i = 0; i < idxs.size(); i++) {
unsigned idx = idxs[i]; unsigned idx = idxs[i];
if (factorList_[idx]) { if (factorList_[idx]) {
if (result == 0) { if (result == 0) {
result = new Factor(*factorList_[idx]); result = new Factor (*factorList_[idx]);
} else { } else {
result->multiply (*factorList_[idx]); result->multiply (*factorList_[idx]);
} }
@ -179,10 +132,10 @@ VarElimSolver::eliminate (VarId elimVar)
factorList_[idx] = 0; factorList_[idx] = 0;
} }
} }
if (result != 0 && result->nrVariables() != 1) { if (result != 0 && result->nrArguments() != 1) {
result->sumOut (vn->varId()); result->sumOut (elimVar);
factorList_.push_back (result); factorList_.push_back (result);
const VarIds& resultVarIds = result->getVarIds(); const VarIds& resultVarIds = result->arguments();
for (unsigned i = 0; i < resultVarIds.size(); i++) { for (unsigned i = 0; i < resultVarIds.size(); i++) {
vector<unsigned>& idxs = vector<unsigned>& idxs =
varFactors_.find (resultVarIds[i])->second; varFactors_.find (resultVarIds[i])->second;
@ -199,7 +152,6 @@ VarElimSolver::printActiveFactors (void)
for (unsigned i = 0; i < factorList_.size(); i++) { for (unsigned i = 0; i < factorList_.size(); i++) {
if (factorList_[i] != 0) { if (factorList_[i] != 0) {
factorList_[i]->print(); factorList_[i]->print();
cout << endl;
} }
} }
} }

View File

@ -5,7 +5,6 @@
#include "Solver.h" #include "Solver.h"
#include "FactorGraph.h" #include "FactorGraph.h"
#include "BayesNet.h"
#include "Horus.h" #include "Horus.h"
@ -15,23 +14,25 @@ using namespace std;
class VarElimSolver : public Solver class VarElimSolver : public Solver
{ {
public: public:
VarElimSolver (const BayesNet&); VarElimSolver (const FactorGraph& fg) : Solver (fg) { }
VarElimSolver (const FactorGraph&);
~VarElimSolver (void); ~VarElimSolver (void);
void runSolver (void) { }
Params getPosterioriOf (VarId); Params solveQuery (VarIds);
Params getJointDistributionOf (const VarIds&);
private: private:
void createFactorList (void); void createFactorList (void);
void introduceEvidence (void);
void chooseEliminationOrder (const VarIds&); void absorveEvidence (void);
void findEliminationOrder (const VarIds&);
void processFactorList (const VarIds&); void processFactorList (const VarIds&);
void eliminate (VarId); void eliminate (VarId);
void printActiveFactors (void); void printActiveFactors (void);
const BayesNet* bayesNet_;
const FactorGraph* factorGraph_;
vector<Factor*> factorList_; vector<Factor*> factorList_;
VarIds elimOrder_; VarIds elimOrder_;
unordered_map<VarId, vector<unsigned>> varFactors_; unordered_map<VarId, vector<unsigned>> varFactors_;

View File

@ -1,100 +0,0 @@
#include <algorithm>
#include <sstream>
#include "VarNode.h"
#include "GraphicalModel.h"
using namespace std;
VarNode::VarNode (const VarNode* v)
{
varId_ = v->varId();
nrStates_ = v->nrStates();
evidence_ = v->getEvidence();
index_ = std::numeric_limits<unsigned>::max();
}
VarNode::VarNode (VarId varId, unsigned nrStates, int evidence)
{
assert (nrStates != 0);
assert (evidence < (int) nrStates);
varId_ = varId;
nrStates_ = nrStates;
evidence_ = evidence;
index_ = std::numeric_limits<unsigned>::max();
}
bool
VarNode::isValidState (int stateIndex)
{
return stateIndex >= 0 && stateIndex < (int) nrStates_;
}
bool
VarNode::isValidState (const string& stateName)
{
States states = GraphicalModel::getVariableInformation (varId_).states;
return find (states.begin(), states.end(), stateName) != states.end();
}
void
VarNode::setEvidence (int ev)
{
assert (ev < (int) nrStates_);
evidence_ = ev;
}
void
VarNode::setEvidence (const string& ev)
{
States states = GraphicalModel::getVariableInformation (varId_).states;
for (unsigned i = 0; i < states.size(); i++) {
if (states[i] == ev) {
evidence_ = i;
return;
}
}
assert (false);
}
string
VarNode::label (void) const
{
if (GraphicalModel::variablesHaveInformation()) {
return GraphicalModel::getVariableInformation (varId_).label;
}
stringstream ss;
ss << "x" << varId_;
return ss.str();
}
States
VarNode::states (void) const
{
if (GraphicalModel::variablesHaveInformation()) {
return GraphicalModel::getVariableInformation (varId_).states;
}
States states;
for (unsigned i = 0; i < nrStates_; i++) {
stringstream ss;
ss << i ;
states.push_back (ss.str());
}
return states;
}

View File

@ -1,54 +0,0 @@
#ifndef HORUS_VARNODE_H
#define HORUS_VARNODE_H
#include "Horus.h"
using namespace std;
class VarNode
{
public:
VarNode (const VarNode*);
VarNode (VarId, unsigned, int = NO_EVIDENCE);
virtual ~VarNode (void) {};
bool isValidState (int);
bool isValidState (const string&);
void setEvidence (int);
void setEvidence (const string&);
string label (void) const;
States states (void) const;
unsigned varId (void) const { return varId_; }
unsigned nrStates (void) const { return nrStates_; }
bool hasEvidence (void) const { return evidence_ != NO_EVIDENCE; }
int getEvidence (void) const { return evidence_; }
unsigned getIndex (void) const { return index_; }
void setIndex (unsigned idx) { index_ = idx; }
operator unsigned () const { return index_; }
bool operator== (const VarNode& var) const
{
cout << "equal operator called" << endl;
assert (!(varId_ == var.varId() && nrStates_ != var.nrStates()));
return varId_ == var.varId();
}
bool operator!= (const VarNode& var) const
{
cout << "diff operator called" << endl;
assert (!(varId_ == var.varId() && nrStates_ != var.nrStates()));
return varId_ != var.varId();
}
private:
VarId varId_;
unsigned nrStates_;
int evidence_;
unsigned index_;
};
#endif // BP_VARNODE_H

View File

@ -0,0 +1,35 @@
if [ $1 ] && [ $1 == "clear" ]; then
rm *~
rm -f school/*.log school/*~
rm -f city/*.log city/*~
rm -f workshop_attrs/*.log workshop_attrs/*~
fi
function run_solver
{
constraint=$1
solver_flag=true
if [ -n "$2" ]; then
if [ $SOLVER = hve ]; then
extra_flag=clpbn_horus:set_horus_flag\(elim_heuristic,$2\)
elif [ $SOLVER = bp ]; then
extra_flag=clpbn_horus:set_horus_flag\(schedule,$2\)
elif [ $SOLVER = cbp ]; then
extra_flag=clpbn_horus:set_horus_flag\(schedule,$2\)
else
echo "unknow flag $2"
fi
fi
/usr/bin/time -o $LOG_FILE -a -f "real:%E\tuser:%U\tsys:%S" \
$YAP << EOF >> $LOG_FILE 2>> ignore.$LOG_FILE
[$NETWORK].
[$constraint].
clpbn_horus:set_solver($SOLVER).
clpbn_horus:set_horus_flag(use_logarithms, true).
$solver_flag.
$QUERY.
open("$LOG_FILE", 'append', S), format(S, '$constraint: ~15+ ', []), close(S).
EOF
}

View File

@ -1,50 +0,0 @@
#!/bin/bash
cp ~/bin/yap ~/bin/town_bnbp
YAP=~/bin/town_bnbp
#OUT_FILE_NAME=results`date "+ %H:%M:%S %d-%m-%Y"`.log
OUT_FILE_NAME=bnbp.log
rm -f $OUT_FILE_NAME
rm -f ignore.$OUT_FILE_NAME
function run_solver
{
if [ $2 = bp ]
then
extra_flag1=clpbn_bp:set_horus_flag\(inf_alg,$4\)
extra_flag2=clpbn_bp:set_horus_flag\(schedule,$5\)
else
extra_flag1=true
extra_flag2=true
fi
/usr/bin/time -o $OUT_FILE_NAME -a -f "real:%E\tuser:%U\tsys:%S" $YAP << EOF >> $OUT_FILE_NAME 2>> ignore.$OUT_FILE_NAME
[$1].
clpbn:set_clpbn_flag(solver,$2),
clpbn_bp:set_horus_flag(use_logarithms, true),
$extra_flag1, $extra_flag2,
run_query(_R),
open("$OUT_FILE_NAME", 'append',S),
format(S, '$3: ~15+ ',[]),
close(S).
EOF
}
function run_all_graphs
{
echo "*******************************************************************" >> "$OUT_FILE_NAME"
echo "results for solver $2" >> $OUT_FILE_NAME
echo "*******************************************************************" >> "$OUT_FILE_NAME"
run_solver town_1000 $1 town_1000 $3 $4 $5
run_solver town_5000 $1 town_5000 $3 $4 $5
run_solver town_10000 $1 town_10000 $3 $4 $5
run_solver town_50000 $1 town_50000 $3 $4 $5
run_solver town_100000 $1 town_100000 $3 $4 $5
run_solver town_500000 $1 town_500000 $3 $4 $5
run_solver town_1000000 $1 town_1000000 $3 $4 $5
}
run_all_graphs bp "bn_bp(seq_fixed) " bn_bp seq_fixed

View File

@ -0,0 +1,17 @@
#!/bin/bash
source city.sh
source ../benchs.sh
SOLVER="bp"
YAP=~/bin/$SHORTNAME-$SOLVER
LOG_FILE=$SOLVER.log
#LOG_FILE=results`date "+ %H:%M:%S %d-%m-%Y"`.
rm -f $LOG_FILE
rm -f ignore.$LOG_FILE
run_all_graphs "bp(shedule=seq_fixed) " seq_fixed

View File

@ -1,54 +1,17 @@
#!/bin/bash #!/bin/bash
cp ~/bin/yap ~/bin/town_cbp source city.sh
YAP=~/bin/town_cbp source ../benchs.sh
#OUT_FILE_NAME=results`date "+ %H:%M:%S %d-%m-%Y"`.log SOLVER="cbp"
OUT_FILE_NAME=cbp.log
rm -f $OUT_FILE_NAME
rm -f ignore.$OUT_FILE_NAME
YAP=~/bin/$SHORTNAME-$SOLVER
function run_solver LOG_FILE=$SOLVER.log
{ #LOG_FILE=results`date "+ %H:%M:%S %d-%m-%Y"`.
if [ $2 = bp ]
then
extra_flag1=clpbn_bp:set_horus_flag\(inf_alg,$4\)
extra_flag2=clpbn_bp:set_horus_flag\(schedule,$5\)
else
extra_flag1=true
extra_flag2=true
fi
/usr/bin/time -o $OUT_FILE_NAME -a -f "real:%E\tuser:%U\tsys:%S" $YAP << EOF >> $OUT_FILE_NAME 2>> ignore.$OUT_FILE_NAME
[$1].
clpbn:set_clpbn_flag(solver,$2),
clpbn_bp:set_horus_flag(use_logarithms, true),
$extra_flag1, $extra_flag2,
run_query(_R),
open("$OUT_FILE_NAME", 'append',S),
format(S, '$3: ~15+ ',[]),
close(S).
EOF
}
rm -f $LOG_FILE
rm -f ignore.$LOG_FILE
function run_all_graphs run_all_graphs "cbp(shedule=seq_fixed) " seq_fixed
{
echo "*******************************************************************" >> "$OUT_FILE_NAME"
echo "results for solver $2" >> $OUT_FILE_NAME
echo "*******************************************************************" >> "$OUT_FILE_NAME"
run_solver town_1000 $1 town_1000 $3 $4 $5
run_solver town_5000 $1 town_5000 $3 $4 $5
run_solver town_10000 $1 town_10000 $3 $4 $5
run_solver town_50000 $1 town_50000 $3 $4 $5
run_solver town_100000 $1 town_100000 $3 $4 $5
run_solver town_500000 $1 town_500000 $3 $4 $5
run_solver town_1000000 $1 town_1000000 $3 $4 $5
run_solver town_2500000 $1 town_2500000 $3 $4 $5
run_solver town_5000000 $1 town_5000000 $3 $4 $5
run_solver town_7500000 $1 town_7500000 $3 $4 $5
run_solver town_10000000 $1 town_10000000 $3 $4 $5
}
run_all_graphs bp "cbp(seq_fixed) " cbp seq_fixed

View File

@ -0,0 +1,25 @@
#!/bin/bash
NETWORK="'../../examples/city'"
SHORTNAME="city"
QUERY="is_joe_guilty(X)"
function run_all_graphs
{
cp ~/bin/yap $YAP
echo -n "**********************************" >> $LOG_FILE
echo "**********************************" >> $LOG_FILE
echo "results for solver $1" >> $LOG_FILE
echo -n "**********************************" >> $LOG_FILE
echo "**********************************" >> $LOG_FILE
run_solver city_5 $2
#run_solver city_1000 $2
#run_solver city_5000 $2
#run_solver city_10000 $2
#run_solver city_50000 $2
#run_solver city_100000 $2
#run_solver city_500000 $2
#run_solver city_1000000 $2
}

View File

@ -0,0 +1,37 @@
#!/home/tiago/bin/yap -L --
:- initialization(main).
main :-
unix(argv([H])),
generate_town(H).
generate_town(N) :-
atomic_concat(['city_', N, '.yap'], FileName),
open(FileName, 'write', S),
atom_number(N, N2),
generate_people(S, N2, 4),
write(S, '\n'),
generate_query(S, N2, 4),
write(S, '\n'),
close(S).
generate_people(S, N, Counting) :-
Counting > N, !.
generate_people(S, N, Counting) :-
format(S, 'people(p~w, nyc).~n', [Counting]),
Counting1 is Counting + 1,
generate_people(S, N, Counting1).
generate_query(S, N, Counting) :-
Counting > N, !.
generate_query(S, N, Counting) :- !,
format(S, 'ev(descn(p~w, t)).~n', [Counting]),
Counting1 is Counting + 1,
generate_query(S, N, Counting1).

View File

@ -1,50 +0,0 @@
#!/bin/bash
cp ~/bin/yap ~/bin/town_fgbp
YAP=~/bin/town_fgbp
#OUT_FILE_NAME=results`date "+ %H:%M:%S %d-%m-%Y"`.log
OUT_FILE_NAME=fb_bp.log
rm -f $OUT_FILE_NAME
rm -f ignore.$OUT_FILE_NAME
function run_solver
{
if [ $2 = bp ]
then
extra_flag1=clpbn_bp:set_horus_flag\(inf_alg,$4\)
extra_flag2=clpbn_bp:set_horus_flag\(schedule,$5\)
else
extra_flag1=true
extra_flag2=true
fi
/usr/bin/time -o $OUT_FILE_NAME -a -f "real:%E\tuser:%U\tsys:%S" $YAP << EOF >> $OUT_FILE_NAME 2>> ignore.$OUT_FILE_NAME
[$1].
clpbn:set_clpbn_flag(solver,$2),
clpbn_bp:set_horus_flag(use_logarithms, true),
$extra_flag1, $extra_flag2,
run_query(_R),
open("$OUT_FILE_NAME", 'append',S),
format(S, '$3: ~15+ ',[]),
close(S).
EOF
}
function run_all_graphs
{
echo "*******************************************************************" >> "$OUT_FILE_NAME"
echo "results for solver $2" >> $OUT_FILE_NAME
echo "*******************************************************************" >> "$OUT_FILE_NAME"
run_solver town_1000 $1 town_1000 $3 $4 $5
#run_solver town_5000 $1 town_5000 $3 $4 $5
#run_solver town_10000 $1 town_10000 $3 $4 $5
#run_solver town_50000 $1 town_50000 $3 $4 $5
#run_solver town_100000 $1 town_100000 $3 $4 $5
#run_solver town_500000 $1 town_500000 $3 $4 $5
#run_solver town_1000000 $1 town_1000000 $3 $4 $5
}
run_all_graphs bp "fg_bp(seq_fixed) " fg_bp seq_fixed

View File

@ -0,0 +1,17 @@
#!/bin/bash
source city.sh
source ../benchs.sh
SOLVER="fove"
YAP=~/bin/$SHORTNAME-$SOLVER
LOG_FILE=$SOLVER.log
#LOG_FILE=results`date "+ %H:%M:%S %d-%m-%Y"`.
rm -f $LOG_FILE
rm -f ignore.$LOG_FILEE
run_all_graphs "fove "

View File

@ -1,50 +0,0 @@
#!/bin/bash
cp ~/bin/yap ~/bin/town_gibbs
YAP=~/bin/town_gibbs
#OUT_FILE_NAME=results`date "+ %H:%M:%S %d-%m-%Y"`.log
OUT_FILE_NAME=gibbs.log
rm -f $OUT_FILE_NAME
rm -f ignore.$OUT_FILE_NAME
function run_solver
{
if [ $2 = bp ]
then
extra_flag1=clpbn_bp:set_horus_flag\(inf_alg,$4\)
extra_flag2=clpbn_bp:set_horus_flag\(schedule,$5\)
else
extra_flag1=true
extra_flag2=true
fi
/usr/bin/time -o $OUT_FILE_NAME -a -f "real:%E\tuser:%U\tsys:%S" $YAP << EOF >> $OUT_FILE_NAME 2>> ignore.$OUT_FILE_NAME
[$1].
clpbn:set_clpbn_flag(solver,$2),
clpbn_bp:set_horus_flag(use_logarithms, true),
$extra_flag1, $extra_flag2,
run_query(_R),
open("$OUT_FILE_NAME", 'append',S),
format(S, '$3: ~15+ ',[]),
close(S).
EOF
}
function run_all_graphs
{
echo "*******************************************************************" >> "$OUT_FILE_NAME"
echo "results for solver $2" >> $OUT_FILE_NAME
echo "*******************************************************************" >> "$OUT_FILE_NAME"
run_solver town_1000 $1 town_1000 $3 $4 $5
run_solver town_5000 $1 town_5000 $3 $4 $5
run_solver town_10000 $1 town_10000 $3 $4 $5
run_solver town_50000 $1 town_50000 $3 $4 $5
run_solver town_100000 $1 town_100000 $3 $4 $5
run_solver town_500000 $1 town_500000 $3 $4 $5
run_solver town_1000000 $1 town_1000000 $3 $4 $5
}
run_all_graphs gibbs "gibbs "

View File

@ -0,0 +1,17 @@
#!/bin/bash
source city.sh
source ../benchs.sh
SOLVER="hve"
YAP=~/bin/$SHORTNAME-$SOLVER
LOG_FILE=$SOLVER.log
#LOG_FILE=results`date "+ %H:%M:%S %d-%m-%Y"`.
rm -f $LOG_FILE
rm -f ignore.$LOG_FILE
run_all_graphs "hve(elim_heuristic=min_neighbors) " min_neighbors

View File

@ -1,50 +0,0 @@
#!/bin/bash
cp ~/bin/yap ~/bin/town_jt
YAP=~/bin/town_jt
#OUT_FILE_NAME=results`date "+ %H:%M:%S %d-%m-%Y"`.log
OUT_FILE_NAME=jt.log
rm -f $OUT_FILE_NAME
rm -f ignore.$OUT_FILE_NAME
function run_solver
{
if [ $2 = bp ]
then
extra_flag1=clpbn_bp:set_horus_flag\(inf_alg,$4\)
extra_flag2=clpbn_bp:set_horus_flag\(schedule,$5\)
else
extra_flag1=true
extra_flag2=true
fi
/usr/bin/time -o $OUT_FILE_NAME -a -f "real:%E\tuser:%U\tsys:%S" $YAP << EOF >> $OUT_FILE_NAME 2>> ignore.$OUT_FILE_NAME
[$1].
clpbn:set_clpbn_flag(solver,$2),
clpbn_bp:set_horus_flag(use_logarithms, true),
$extra_flag1, $extra_flag2,
run_query(_R),
open("$OUT_FILE_NAME", 'append',S),
format(S, '$3: ~15+ ',[]),
close(S).
EOF
}
function run_all_graphs
{
echo "*******************************************************************" >> "$OUT_FILE_NAME"
echo "results for solver $2" >> $OUT_FILE_NAME
echo "*******************************************************************" >> "$OUT_FILE_NAME"
run_solver town_1000 $1 town_1000 $3 $4 $5
run_solver town_5000 $1 town_5000 $3 $4 $5
run_solver town_10000 $1 town_10000 $3 $4 $5
run_solver town_50000 $1 town_50000 $3 $4 $5
run_solver town_100000 $1 town_100000 $3 $4 $5
run_solver town_500000 $1 town_500000 $3 $4 $5
run_solver town_1000000 $1 town_1000000 $3 $4 $5
}
run_all_graphs jt "jt "

View File

@ -1,65 +0,0 @@
conservative_city(City, Cons) :-
cons_table(City, ConsDist),
{ Cons = conservative_city(City) with p([y,n], ConsDist) }.
gender(X, Gender) :-
gender_table(X, GenderDist),
{ Gender = gender(X) with p([m,f], GenderDist) }.
hair_color(X, Color) :-
lives(X, City),
conservative_city(City, Cons),
hair_color_table(X,ColorTable),
{ Color = hair_color(X) with
p([t,f], ColorTable,[Cons]) }.
car_color(X, Color) :-
hair_color(X, HColor),
car_color_table(X,CColorTable),
{ Color = car_color(X) with
p([t,f], CColorTable,[HColor]) }.
height(X, Height) :-
gender(X, Gender),
height_table(X,HeightTable),
{ Height = height(X) with
p([t,f], HeightTable,[Gender]) }.
shoe_size(X, Shoesize) :-
height(X, Height),
shoe_size_table(X,ShoesizeTable),
{ Shoesize = shoe_size(X) with
p([t,f], ShoesizeTable,[Height]) }.
guilty(X, Guilt) :-
guilty_table(X, GuiltDist),
{ Guilt = guilty(X) with p([y,n], GuiltDist) }.
descn(X, Descn) :-
car_color(X, Car),
hair_color(X, Hair),
height(X, Height),
guilty(X, Guilt),
descn_table(X, DescTable),
{ Descn = descn(X) with
p([t,f], DescTable,[Car,Hair,Height,Guilt]) }.
witness(City, Witness) :-
descn(joe, DescnJ),
descn(p2, Descn2),
wit_table(WitTable),
{ Witness = witness(City) with
p([t,f], WitTable,[DescnJ, Descn2]) }.
:- ensure_loaded(tables).

View File

@ -1,46 +0,0 @@
cons_table(amsterdam, [0.2, 0.8]) :- !.
cons_table(_, [0.8, 0.2]).
gender_table(_, [0.55, 0.44]).
hair_color_table(_,
/* conservative_city */
/* y n */
[ 0.05, 0.1,
0.95, 0.9 ]).
car_color_table(_,
/* t f */
[ 0.9, 0.2,
0.1, 0.8 ]).
height_table(_,
/* m f */
[ 0.6, 0.4,
0.4, 0.6 ]).
shoe_size_table(_,
/* t f */
[ 0.9, 0.1,
0.1, 0.9 ]).
guilty_table(_, [0.23, 0.77]).
descn_table(_,
/* color, hair, height, guilt */
/* ttttt tttf ttft ttff tfttt tftf tfft tfff ttttt fttf ftft ftff ffttt fftf ffft ffff */
[ 0.99, 0.5, 0.23, 0.88, 0.41, 0.3, 0.76, 0.87, 0.44, 0.43, 0.29, 0.72, 0.33, 0.91, 0.95, 0.92,
0.01, 0.5, 0.77, 0.12, 0.59, 0.7, 0.24, 0.13, 0.56, 0.57, 0.61, 0.28, 0.77, 0.09, 0.05, 0.08]).
wit_table([0.2, 0.45, 0.24, 0.34,
0.8, 0.55, 0.76, 0.66]).

File diff suppressed because it is too large Load Diff

Some files were not shown because too many files have changed in this diff Show More