Merge branch 'master' of git://yap.git.sourceforge.net/gitroot/yap/yap-6.3
This commit is contained in:
commit
b66b261972
@ -490,7 +490,6 @@ Yap_HasOp(Atom a)
|
||||
OpEntry *
|
||||
Yap_OpPropForModule(Atom a, Term mod)
|
||||
{ /* look property list of atom a for kind */
|
||||
CACHE_REGS
|
||||
AtomEntry *ae = RepAtom(a);
|
||||
PropEntry *pp;
|
||||
OpEntry *info = NULL;
|
||||
@ -767,6 +766,7 @@ ExpandPredHash(void)
|
||||
Prop
|
||||
Yap_NewPredPropByFunctor(FunctorEntry *fe, Term cur_mod)
|
||||
{
|
||||
CACHE_REGS
|
||||
PredEntry *p = (PredEntry *) Yap_AllocAtomSpace(sizeof(*p));
|
||||
|
||||
if (p == NULL) {
|
||||
@ -902,6 +902,7 @@ Yap_NewThreadPred(PredEntry *ap USES_REGS)
|
||||
Prop
|
||||
Yap_NewPredPropByAtom(AtomEntry *ae, Term cur_mod)
|
||||
{
|
||||
CACHE_REGS
|
||||
Prop p0;
|
||||
PredEntry *p = (PredEntry *) Yap_AllocAtomSpace(sizeof(*p));
|
||||
|
||||
|
@ -2053,6 +2053,7 @@ a_try(op_numbers opcode, CELL lab, CELL opr, int nofalts, int hascut, yamop *cod
|
||||
yamop *newcp;
|
||||
/* emit a special instruction and then a label for backpatching */
|
||||
if (pass_no) {
|
||||
CACHE_REGS
|
||||
UInt size = (UInt)NEXTOP((yamop *)NULL,OtaLl);
|
||||
if ((newcp = (yamop *)Yap_AllocCodeSpace(size)) == NULL) {
|
||||
/* OOOPS, got in trouble, must do a longjmp and recover space */
|
||||
|
@ -561,6 +561,7 @@ X_API YAP_tag_t STD_PROTO(YAP_TagOfTerm,(Term));
|
||||
X_API size_t STD_PROTO(YAP_ExportTerm,(Term, char *, size_t));
|
||||
X_API size_t STD_PROTO(YAP_SizeOfExportedTerm,(char *));
|
||||
X_API Term STD_PROTO(YAP_ImportTerm,(char *));
|
||||
X_API int STD_PROTO(YAP_RequiresExtraStack,(size_t));
|
||||
|
||||
static UInt
|
||||
current_arity(void)
|
||||
@ -2705,7 +2706,6 @@ YAP_InitConsult(int mode, char *filename)
|
||||
X_API IOSTREAM *
|
||||
YAP_TermToStream(Term t)
|
||||
{
|
||||
CACHE_REGS
|
||||
IOSTREAM *s;
|
||||
BACKUP_MACHINE_REGS();
|
||||
|
||||
@ -2937,7 +2937,13 @@ YAP_Init(YAP_init_args *yap_init)
|
||||
int restore_result;
|
||||
int do_bootstrap = (yap_init->YapPrologBootFile != NULL);
|
||||
CELL Trail = 0, Stack = 0, Heap = 0, Atts = 0;
|
||||
static char boot_file[256];
|
||||
char boot_file[256];
|
||||
static int initialised = FALSE;
|
||||
|
||||
/* ignore repeated calls to YAP_Init */
|
||||
if (initialised)
|
||||
return YAP_BOOT_DONE_BEFOREHAND;
|
||||
initialised = TRUE;
|
||||
|
||||
Yap_InitPageSize(); /* init memory page size, required by later functions */
|
||||
#if defined(YAPOR_COPY) || defined(YAPOR_COW) || defined(YAPOR_SBA)
|
||||
@ -3612,9 +3618,22 @@ YAP_ListToFloats(Term t, double *dblp, size_t sz)
|
||||
if (!IsPairTerm(t))
|
||||
return -1;
|
||||
hd = HeadOfTerm(t);
|
||||
if (!IsFloatTerm(hd))
|
||||
return -1;
|
||||
dblp[i++] = FloatOfTerm(hd);
|
||||
if (IsFloatTerm(hd)) {
|
||||
dblp[i++] = FloatOfTerm(hd);
|
||||
} else {
|
||||
extern double Yap_gmp_to_float(Term hd);
|
||||
|
||||
if (IsIntTerm(hd))
|
||||
dblp[i++] = IntOfTerm(hd);
|
||||
else if (IsLongIntTerm(hd))
|
||||
dblp[i++] = LongIntOfTerm(hd);
|
||||
#if USE_GMP
|
||||
else if (IsBigIntTerm(hd))
|
||||
dblp[i++] = Yap_gmp_to_float(hd);
|
||||
#endif
|
||||
else
|
||||
return -1;
|
||||
}
|
||||
if (i == sz)
|
||||
return sz;
|
||||
t = TailOfTerm(t);
|
||||
@ -4108,3 +4127,24 @@ YAP_ImportTerm(char * buf) {
|
||||
return Yap_ImportTerm(buf);
|
||||
}
|
||||
|
||||
X_API int
|
||||
YAP_RequiresExtraStack(size_t sz) {
|
||||
CACHE_REGS
|
||||
|
||||
if (sz < 16*1024)
|
||||
sz = 16*1024;
|
||||
if (H <= ASP-sz) {
|
||||
return FALSE;
|
||||
}
|
||||
BACKUP_H();
|
||||
while (H > ASP-sz) {
|
||||
CACHE_REGS
|
||||
RECOVER_H();
|
||||
if (!dogc( 0, NULL PASS_REGS )) {
|
||||
return -1;
|
||||
}
|
||||
BACKUP_H();
|
||||
}
|
||||
RECOVER_H();
|
||||
return TRUE;
|
||||
}
|
||||
|
@ -5107,6 +5107,8 @@ p_continue_static_clause( USES_REGS1 )
|
||||
static void
|
||||
add_code_in_lu_index(LogUpdIndex *cl, PredEntry *pp)
|
||||
{
|
||||
CACHE_REGS
|
||||
|
||||
char *code_end = (char *)cl + cl->ClSize;
|
||||
Yap_inform_profiler_of_clause(cl, code_end, pp, GPROF_LU_INDEX);
|
||||
cl = cl->ChildIndex;
|
||||
@ -5119,6 +5121,7 @@ add_code_in_lu_index(LogUpdIndex *cl, PredEntry *pp)
|
||||
static void
|
||||
add_code_in_static_index(StaticIndex *cl, PredEntry *pp)
|
||||
{
|
||||
CACHE_REGS
|
||||
char *code_end = (char *)cl + cl->ClSize;
|
||||
Yap_inform_profiler_of_clause(cl, code_end, pp, GPROF_STATIC_INDEX);
|
||||
cl = cl->ChildIndex;
|
||||
@ -5131,6 +5134,7 @@ add_code_in_static_index(StaticIndex *cl, PredEntry *pp)
|
||||
|
||||
static void
|
||||
add_code_in_pred(PredEntry *pp) {
|
||||
CACHE_REGS
|
||||
yamop *clcode;
|
||||
|
||||
PELOCK(49,pp);
|
||||
@ -5202,6 +5206,7 @@ add_code_in_pred(PredEntry *pp) {
|
||||
|
||||
void
|
||||
Yap_dump_code_area_for_profiler(void) {
|
||||
CACHE_REGS
|
||||
ModEntry *me = CurrentModules;
|
||||
|
||||
while (me) {
|
||||
|
@ -1887,6 +1887,7 @@ Yap_new_ludbe(Term t, PredEntry *pe, UInt nargs)
|
||||
static LogUpdClause *
|
||||
record_lu(PredEntry *pe, Term t, int position)
|
||||
{
|
||||
CACHE_REGS
|
||||
LogUpdClause *cl;
|
||||
|
||||
if ((cl = new_lu_db_entry(t, pe)) == NULL) {
|
||||
|
1
C/exec.c
1
C/exec.c
@ -1066,6 +1066,7 @@ do_goal(Term t, yamop *CodeAdr, int arity, CELL *pt, int top USES_REGS)
|
||||
S = CellPtr (RepPredProp (PredPropByFunc (Yap_MkFunctor(AtomCall, 1),0))); /* A1 mishaps */
|
||||
|
||||
out = exec_absmi(top PASS_REGS);
|
||||
Yap_flush();
|
||||
// if (out) {
|
||||
// out = Yap_GetFromSlot(sl);
|
||||
// }
|
||||
|
27
C/gprof.c
27
C/gprof.c
@ -168,6 +168,7 @@ RBfree(rb_red_blk_node *ptr)
|
||||
|
||||
static rb_red_blk_node *
|
||||
RBTreeCreate(void) {
|
||||
CACHE_REGS
|
||||
rb_red_blk_node* temp;
|
||||
|
||||
/* see the comment in the rb_red_blk_tree structure in red_black_tree.h */
|
||||
@ -210,6 +211,7 @@ RBTreeCreate(void) {
|
||||
|
||||
static void
|
||||
LeftRotate(rb_red_blk_node* x) {
|
||||
CACHE_REGS
|
||||
rb_red_blk_node* y;
|
||||
rb_red_blk_node* nil=LOCAL_ProfilerNil;
|
||||
|
||||
@ -266,6 +268,7 @@ LeftRotate(rb_red_blk_node* x) {
|
||||
|
||||
static void
|
||||
RightRotate(rb_red_blk_node* y) {
|
||||
CACHE_REGS
|
||||
rb_red_blk_node* x;
|
||||
rb_red_blk_node* nil=LOCAL_ProfilerNil;
|
||||
|
||||
@ -318,6 +321,7 @@ RightRotate(rb_red_blk_node* y) {
|
||||
|
||||
static void
|
||||
TreeInsertHelp(rb_red_blk_node* z) {
|
||||
CACHE_REGS
|
||||
/* This function should only be called by InsertRBTree (see above) */
|
||||
rb_red_blk_node* x;
|
||||
rb_red_blk_node* y;
|
||||
@ -369,6 +373,7 @@ TreeInsertHelp(rb_red_blk_node* z) {
|
||||
|
||||
static rb_red_blk_node *
|
||||
RBTreeInsert(yamop *key, yamop *lim) {
|
||||
CACHE_REGS
|
||||
rb_red_blk_node * y;
|
||||
rb_red_blk_node * x;
|
||||
rb_red_blk_node * newNode;
|
||||
@ -440,6 +445,7 @@ RBTreeInsert(yamop *key, yamop *lim) {
|
||||
|
||||
static rb_red_blk_node*
|
||||
RBExactQuery(yamop* q) {
|
||||
CACHE_REGS
|
||||
rb_red_blk_node* x;
|
||||
rb_red_blk_node* nil=LOCAL_ProfilerNil;
|
||||
|
||||
@ -460,6 +466,7 @@ RBExactQuery(yamop* q) {
|
||||
|
||||
static rb_red_blk_node*
|
||||
RBLookup(yamop *entry) {
|
||||
CACHE_REGS
|
||||
rb_red_blk_node *current;
|
||||
|
||||
if (!LOCAL_ProfilerRoot)
|
||||
@ -495,6 +502,7 @@ RBLookup(yamop *entry) {
|
||||
/***********************************************************************/
|
||||
|
||||
static void RBDeleteFixUp(rb_red_blk_node* x) {
|
||||
CACHE_REGS
|
||||
rb_red_blk_node* root=LOCAL_ProfilerRoot->left;
|
||||
rb_red_blk_node *w;
|
||||
|
||||
@ -574,6 +582,7 @@ static void RBDeleteFixUp(rb_red_blk_node* x) {
|
||||
|
||||
static rb_red_blk_node*
|
||||
TreeSuccessor(rb_red_blk_node* x) {
|
||||
CACHE_REGS
|
||||
rb_red_blk_node* y;
|
||||
rb_red_blk_node* nil=LOCAL_ProfilerNil;
|
||||
rb_red_blk_node* root=LOCAL_ProfilerRoot;
|
||||
@ -612,6 +621,7 @@ TreeSuccessor(rb_red_blk_node* x) {
|
||||
|
||||
static void
|
||||
RBDelete(rb_red_blk_node* z){
|
||||
CACHE_REGS
|
||||
rb_red_blk_node* y;
|
||||
rb_red_blk_node* x;
|
||||
rb_red_blk_node* nil=LOCAL_ProfilerNil;
|
||||
@ -664,7 +674,8 @@ RBDelete(rb_red_blk_node* z){
|
||||
|
||||
char *set_profile_dir(char *);
|
||||
char *set_profile_dir(char *name){
|
||||
int size=0;
|
||||
CACHE_REGS
|
||||
int size=0;
|
||||
|
||||
if (name!=NULL) {
|
||||
size=strlen(name)+1;
|
||||
@ -687,8 +698,9 @@ return LOCAL_DIRNAME;
|
||||
|
||||
char *profile_names(int);
|
||||
char *profile_names(int k) {
|
||||
static char *FNAME=NULL;
|
||||
int size=200;
|
||||
CACHE_REGS
|
||||
static char *FNAME=NULL;
|
||||
int size=200;
|
||||
|
||||
if (LOCAL_DIRNAME==NULL) set_profile_dir(NULL);
|
||||
size=strlen(LOCAL_DIRNAME)+40;
|
||||
@ -709,6 +721,7 @@ int size=200;
|
||||
|
||||
void del_profile_files(void);
|
||||
void del_profile_files() {
|
||||
CACHE_REGS
|
||||
if (LOCAL_DIRNAME!=NULL) {
|
||||
remove(profile_names(PROFPREDS_FILE));
|
||||
remove(profile_names(PROFILING_FILE));
|
||||
@ -717,6 +730,7 @@ void del_profile_files() {
|
||||
|
||||
void
|
||||
Yap_inform_profiler_of_clause__(void *code_start, void *code_end, PredEntry *pe,gprof_info index_code) {
|
||||
CACHE_REGS
|
||||
buf_ptr b;
|
||||
buf_extra e;
|
||||
LOCAL_ProfOn = TRUE;
|
||||
@ -742,6 +756,7 @@ static Int profend( USES_REGS1 );
|
||||
|
||||
static void
|
||||
clean_tree(rb_red_blk_node* node) {
|
||||
CACHE_REGS
|
||||
if (node == LOCAL_ProfilerNil)
|
||||
return;
|
||||
clean_tree(node->left);
|
||||
@ -751,6 +766,7 @@ clean_tree(rb_red_blk_node* node) {
|
||||
|
||||
static void
|
||||
reset_tree(void) {
|
||||
CACHE_REGS
|
||||
clean_tree(LOCAL_ProfilerRoot);
|
||||
Yap_FreeCodeSpace((char *)LOCAL_ProfilerNil);
|
||||
LOCAL_ProfilerNil = LOCAL_ProfilerRoot = NULL;
|
||||
@ -760,6 +776,7 @@ reset_tree(void) {
|
||||
static int
|
||||
InitProfTree(void)
|
||||
{
|
||||
CACHE_REGS
|
||||
if (LOCAL_ProfilerRoot)
|
||||
reset_tree();
|
||||
while (!(LOCAL_ProfilerRoot = RBTreeCreate())) {
|
||||
@ -773,6 +790,7 @@ InitProfTree(void)
|
||||
|
||||
static void RemoveCode(CODEADDR clau)
|
||||
{
|
||||
CACHE_REGS
|
||||
rb_red_blk_node* x, *node;
|
||||
PredEntry *pp;
|
||||
UInt count;
|
||||
@ -958,6 +976,7 @@ prof_alrm(int signo, siginfo_t *si, void *scv)
|
||||
void
|
||||
Yap_InformOfRemoval(void *clau)
|
||||
{
|
||||
CACHE_REGS
|
||||
LOCAL_ProfOn = TRUE;
|
||||
if (LOCAL_FPreds != NULL) {
|
||||
/* just store info about what is going on */
|
||||
@ -1048,6 +1067,7 @@ static Int profinit( USES_REGS1 )
|
||||
|
||||
static Int start_profilers(int msec)
|
||||
{
|
||||
CACHE_REGS
|
||||
struct itimerval t;
|
||||
struct sigaction sa;
|
||||
|
||||
@ -1157,6 +1177,7 @@ static Int profres0( USES_REGS1 ) {
|
||||
void
|
||||
Yap_InitLowProf(void)
|
||||
{
|
||||
CACHE_REGS
|
||||
#if LOW_PROF
|
||||
LOCAL_ProfCalls = 0;
|
||||
LOCAL_ProfilerOn = FALSE;
|
||||
|
5
C/grow.c
5
C/grow.c
@ -718,6 +718,11 @@ AdjustScannerStacks(TokEntry **tksp, VarEntry **vep USES_REGS)
|
||||
TokEntry *tktmp;
|
||||
|
||||
switch (tks->Tok) {
|
||||
case Number_tok:
|
||||
if (IsApplTerm(tks->TokInfo)) {
|
||||
tks->TokInfo = AdjustAppl(tks->TokInfo PASS_REGS);
|
||||
}
|
||||
break;
|
||||
case Var_tok:
|
||||
case String_tok:
|
||||
if (IsOldTrail(tks->TokInfo))
|
||||
|
@ -1888,6 +1888,7 @@ emit_single_switch_case(ClauseDef *min, struct intermediates *cint, int first, i
|
||||
static UInt
|
||||
suspend_indexing(ClauseDef *min, ClauseDef *max, PredEntry *ap, struct intermediates *cint)
|
||||
{
|
||||
CACHE_REGS
|
||||
UInt tcls = ap->cs.p_code.NOfClauses;
|
||||
UInt cls = (max-min)+1;
|
||||
|
||||
|
12
C/qlyr.c
12
C/qlyr.c
@ -993,6 +993,15 @@ p_read_module_preds( USES_REGS1 )
|
||||
return TRUE;
|
||||
}
|
||||
|
||||
static void
|
||||
ReInitCatch(void)
|
||||
{
|
||||
Term t = Yap_MkNewApplTerm(PredHandleThrow->FunctorOfPred, PredHandleThrow->ArityOfPE);
|
||||
YAP_RunGoalOnce(t);
|
||||
}
|
||||
|
||||
|
||||
|
||||
static Int
|
||||
p_read_program( USES_REGS1 )
|
||||
{
|
||||
@ -1016,7 +1025,7 @@ p_read_program( USES_REGS1 )
|
||||
Sclose( stream );
|
||||
/* back to the top level we go */
|
||||
Yap_CloseSlots(PASS_REGS1);
|
||||
|
||||
ReInitCatch();
|
||||
Yap_RestartYap( 3 );
|
||||
return TRUE;
|
||||
}
|
||||
@ -1030,6 +1039,7 @@ Yap_Restore(char *s, char *lib_dir)
|
||||
return -1;
|
||||
read_module(stream);
|
||||
Sclose( stream );
|
||||
ReInitCatch();
|
||||
return DO_ONLY_CODE;
|
||||
}
|
||||
|
||||
|
@ -1619,6 +1619,7 @@ InteractSIGINT(int ch) {
|
||||
static int
|
||||
ProcessSIGINT(void)
|
||||
{
|
||||
CACHE_REGS
|
||||
int ch, out;
|
||||
|
||||
LOCAL_PrologMode |= AsyncIntMode;
|
||||
|
12
C/tracer.c
12
C/tracer.c
@ -52,17 +52,7 @@ send_tracer_message(char *start, char *name, Int arity, char *mname, CELL *args)
|
||||
if (args) {
|
||||
for (i= 0; i < arity; i++) {
|
||||
if (i > 0) fprintf(GLOBAL_stderr, ",");
|
||||
#if DEBUG
|
||||
#if COROUTINING
|
||||
Yap_Portray_delays = TRUE;
|
||||
#endif
|
||||
#endif
|
||||
Yap_plwrite(args[i], NULL, 15, Handle_vars_f, 1200);
|
||||
#if DEBUG
|
||||
#if COROUTINING
|
||||
Yap_Portray_delays = FALSE;
|
||||
#endif
|
||||
#endif
|
||||
Yap_plwrite(args[i], NULL, 15, Handle_vars_f|AttVar_Portray_f, 1200);
|
||||
}
|
||||
if (arity) {
|
||||
fprintf(GLOBAL_stderr, ")");
|
||||
|
@ -4255,7 +4255,7 @@ p_is_list_or_partial_list( USES_REGS1 )
|
||||
}
|
||||
|
||||
static Term
|
||||
numbervar(Int id)
|
||||
numbervar(Int id USES_REGS)
|
||||
{
|
||||
Term ts[1];
|
||||
ts[0] = MkIntegerTerm(id);
|
||||
@ -4263,7 +4263,7 @@ numbervar(Int id)
|
||||
}
|
||||
|
||||
static Term
|
||||
numbervar_singleton(void)
|
||||
numbervar_singleton(USES_REGS1)
|
||||
{
|
||||
Term ts[1];
|
||||
ts[0] = MkIntegerTerm(-1);
|
||||
@ -4356,9 +4356,9 @@ static Int numbervars_in_complex_term(register CELL *pt0, register CELL *pt0_end
|
||||
derefa_body(d0, ptd0, vars_in_term_unk, vars_in_term_nvar);
|
||||
/* do or pt2 are unbound */
|
||||
if (singles)
|
||||
*ptd0 = numbervar_singleton();
|
||||
*ptd0 = numbervar_singleton( PASS_REGS1 );
|
||||
else
|
||||
*ptd0 = numbervar(numbv++);
|
||||
*ptd0 = numbervar(numbv++ PASS_REGS);
|
||||
/* leave an empty slot to fill in later */
|
||||
if (H+1024 > ASP) {
|
||||
goto global_overflow;
|
||||
@ -4450,10 +4450,10 @@ Yap_NumberVars( Term inp, Int numbv, int handle_singles ) /* numbervariables in
|
||||
CELL *ptd0 = VarOfTerm(t);
|
||||
TrailTerm(TR++) = (CELL)ptd0;
|
||||
if (handle_singles) {
|
||||
*ptd0 = numbervar_singleton();
|
||||
*ptd0 = numbervar_singleton( PASS_REGS1 );
|
||||
return numbv;
|
||||
} else {
|
||||
*ptd0 = numbervar(numbv);
|
||||
*ptd0 = numbervar(numbv PASS_REGS);
|
||||
return numbv+1;
|
||||
}
|
||||
} else if (IsPrimitiveTerm(t)) {
|
||||
|
164
C/write.c
164
C/write.c
@ -66,7 +66,7 @@ typedef struct rewind_term {
|
||||
|
||||
typedef struct write_globs {
|
||||
void *stream;
|
||||
int Quote_illegal, Ignore_ops, Handle_vars, Use_portray;
|
||||
int Quote_illegal, Ignore_ops, Handle_vars, Use_portray, Portray_delays;
|
||||
int Keep_terms;
|
||||
int Write_Loops;
|
||||
int Write_strings;
|
||||
@ -90,28 +90,50 @@ STATIC_PROTO(void writeTerm, (Term, int, int, int, struct write_globs *, struct
|
||||
|
||||
#define wrputc(X,WF) Sputcode(X,WF) /* writes a character */
|
||||
|
||||
/*
|
||||
protect bracket from merging with previoous character.
|
||||
avoid stuff like not (2,3) -> not(2,3) or
|
||||
*/
|
||||
static void
|
||||
protect_open_number(struct write_globs *wglb, int minus_required)
|
||||
wropen_bracket(struct write_globs *wglb, int protect)
|
||||
{
|
||||
wrf stream = wglb->stream;
|
||||
|
||||
if (lastw == symbol && last_minus && !minus_required) {
|
||||
if (!wglb->Ignore_ops) {
|
||||
/* protect against collating - with number, and getting - 1 ^2 as (-(1))^2 */
|
||||
wrputc(' ', wglb->stream);
|
||||
}
|
||||
wrputc('(', wglb->stream);
|
||||
} else if (lastw == alphanum) {
|
||||
wrputc(' ', stream);
|
||||
}
|
||||
if (lastw != separator && protect)
|
||||
wrputc(' ', stream);
|
||||
wrputc('(', stream);
|
||||
lastw = separator;
|
||||
}
|
||||
|
||||
static void
|
||||
protect_close_number(struct write_globs *wglb, int minus_required)
|
||||
wrclose_bracket(struct write_globs *wglb, int protect)
|
||||
{
|
||||
if (lastw == symbol && last_minus && !minus_required) {
|
||||
wrputc(')', wglb->stream);
|
||||
lastw = separator;
|
||||
wrf stream = wglb->stream;
|
||||
|
||||
wrputc(')', stream);
|
||||
lastw = separator;
|
||||
}
|
||||
|
||||
static int
|
||||
protect_open_number(struct write_globs *wglb, int lm, int minus_required)
|
||||
{
|
||||
wrf stream = wglb->stream;
|
||||
|
||||
if (lastw == symbol && lm && !minus_required) {
|
||||
wropen_bracket(wglb, TRUE);
|
||||
return TRUE;
|
||||
} else if (lastw == alphanum ||
|
||||
(lastw == symbol && minus_required)) {
|
||||
wrputc(' ', stream);
|
||||
}
|
||||
return FALSE;
|
||||
}
|
||||
|
||||
static void
|
||||
protect_close_number(struct write_globs *wglb, int used_bracket)
|
||||
{
|
||||
if (used_bracket) {
|
||||
wrclose_bracket(wglb, TRUE);
|
||||
} else {
|
||||
lastw = alphanum;
|
||||
}
|
||||
@ -125,8 +147,9 @@ wrputn(Int n, struct write_globs *wglb) /* writes an integer */
|
||||
wrf stream = wglb->stream;
|
||||
char s[256], *s1=s; /* that should be enough for most integers */
|
||||
int has_minus = (n < 0);
|
||||
int ob;
|
||||
|
||||
protect_open_number(wglb, has_minus);
|
||||
ob = protect_open_number(wglb, last_minus, has_minus);
|
||||
#if HAVE_SNPRINTF
|
||||
snprintf(s, 256, Int_FORMAT, n);
|
||||
#else
|
||||
@ -134,7 +157,7 @@ wrputn(Int n, struct write_globs *wglb) /* writes an integer */
|
||||
#endif
|
||||
while (*s1)
|
||||
wrputc(*s1++, stream);
|
||||
protect_close_number(wglb, has_minus);
|
||||
protect_close_number(wglb, ob);
|
||||
}
|
||||
|
||||
#define wrputs(s, stream) Sfputs(s, stream)
|
||||
@ -190,9 +213,10 @@ static void
|
||||
write_mpint(MP_INT *big, struct write_globs *wglb) {
|
||||
char *s;
|
||||
int has_minus = mpz_sgn(big);
|
||||
int ob;
|
||||
|
||||
s = ensure_space(3+mpz_sizeinbase(big, 10));
|
||||
protect_open_number(wglb, has_minus);
|
||||
ob = protect_open_number(wglb, last_minus, has_minus);
|
||||
if (!s) {
|
||||
s = mpz_get_str(NULL, 10, big);
|
||||
if (!s)
|
||||
@ -203,7 +227,7 @@ write_mpint(MP_INT *big, struct write_globs *wglb) {
|
||||
mpz_get_str(s, 10, big);
|
||||
wrputs(s,wglb->stream);
|
||||
}
|
||||
protect_close_number(wglb, has_minus);
|
||||
protect_close_number(wglb, ob);
|
||||
}
|
||||
#endif
|
||||
|
||||
@ -271,6 +295,8 @@ wrputf(Float f, struct write_globs *wglb) /* writes a float */
|
||||
char s[256];
|
||||
wrf stream = wglb->stream;
|
||||
int sgn;
|
||||
int ob;
|
||||
|
||||
|
||||
#if HAVE_ISNAN || defined(__WIN32)
|
||||
if (isnan(f)) {
|
||||
@ -291,7 +317,7 @@ wrputf(Float f, struct write_globs *wglb) /* writes a float */
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
protect_open_number(wglb, sgn);
|
||||
ob = protect_open_number(wglb, last_minus, sgn);
|
||||
#if THREADS
|
||||
/* old style writing */
|
||||
int found_dot = FALSE, found_exp = FALSE;
|
||||
@ -343,7 +369,7 @@ wrputf(Float f, struct write_globs *wglb) /* writes a float */
|
||||
if (!buf) return;
|
||||
wrputs(buf, stream);
|
||||
#endif
|
||||
protect_close_number(wglb, sgn);
|
||||
protect_close_number(wglb, ob);
|
||||
}
|
||||
|
||||
/* writes a data base reference */
|
||||
@ -423,7 +449,7 @@ AtomIsSymbols(unsigned char *s) /* Is this atom just formed by symbols ? */
|
||||
return(separator);
|
||||
while ((ch = *s++) != '\0') {
|
||||
if (Yap_chtype[ch] != SY)
|
||||
return(alphanum);
|
||||
return alphanum;
|
||||
}
|
||||
return symbol;
|
||||
}
|
||||
@ -669,15 +695,14 @@ write_var(CELL *t, struct write_globs *wglb, struct rewind_term *rwt)
|
||||
/* make sure we don't get no creepy spaces where they shouldn't be */
|
||||
lastw = separator;
|
||||
if (IsAttVar(t)) {
|
||||
#if defined(COROUTINING) && defined(DEBUG)
|
||||
Int vcount = (t-H0);
|
||||
if (Yap_Portray_delays) {
|
||||
if (wglb->Portray_delays) {
|
||||
exts ext = ExtFromCell(t);
|
||||
struct rewind_term nrwt;
|
||||
nrwt.parent = rwt;
|
||||
nrwt.u.s.ptr = 0;
|
||||
|
||||
Yap_Portray_delays = FALSE;
|
||||
wglb->Portray_delays = FALSE;
|
||||
if (ext == attvars_ext) {
|
||||
attvar_record *attv = RepAttVar(t);
|
||||
CELL *l = &attv->Value; /* dirty low-level hack, check atts.h */
|
||||
@ -691,14 +716,13 @@ write_var(CELL *t, struct write_globs *wglb, struct rewind_term *rwt)
|
||||
l += 2;
|
||||
writeTerm(from_pointer(l, &nrwt, wglb), 999, 1, FALSE, wglb, &nrwt);
|
||||
restore_from_write(&nrwt, wglb);
|
||||
wrputc(')', wglb->stream);
|
||||
wrclose_bracket(wglb, TRUE);
|
||||
}
|
||||
Yap_Portray_delays = TRUE;
|
||||
wglb->Portray_delays = TRUE;
|
||||
return;
|
||||
}
|
||||
wrputc('D', wglb->stream);
|
||||
wrputn(vcount,wglb);
|
||||
#endif
|
||||
} else {
|
||||
wrputn(((Int) (t- H0)),wglb);
|
||||
}
|
||||
@ -790,6 +814,7 @@ write_list(Term t, int direction, int depth, struct write_globs *wglb, struct re
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
static void
|
||||
writeTerm(Term t, int p, int depth, int rinfixarg, struct write_globs *wglb, struct rewind_term *rwt)
|
||||
/* term to write */
|
||||
@ -823,8 +848,7 @@ writeTerm(Term t, int p, int depth, int rinfixarg, struct write_globs *wglb, str
|
||||
wrputs(",",wglb->stream);
|
||||
writeTerm(from_pointer(RepPair(t)+1, &nrwt, wglb), 999, depth + 1, FALSE, wglb, &nrwt);
|
||||
restore_from_write(&nrwt, wglb);
|
||||
wrputc(')', wglb->stream);
|
||||
lastw = separator;
|
||||
wrclose_bracket(wglb, TRUE);
|
||||
return;
|
||||
}
|
||||
if (wglb->Use_portray) {
|
||||
@ -886,7 +910,7 @@ writeTerm(Term t, int p, int depth, int rinfixarg, struct write_globs *wglb, str
|
||||
int argno = 1;
|
||||
CELL *p = ArgsOfSFTerm(t);
|
||||
putAtom(atom, wglb->Quote_illegal, wglb);
|
||||
wrputc('(', wglb->stream);
|
||||
wropen_bracket(wglb, FALSE);
|
||||
lastw = separator;
|
||||
while (*p) {
|
||||
Int sl = 0;
|
||||
@ -904,8 +928,7 @@ writeTerm(Term t, int p, int depth, int rinfixarg, struct write_globs *wglb, str
|
||||
wrputc(',', wglb->stream);
|
||||
argno++;
|
||||
}
|
||||
wrputc(')', wglb->stream);
|
||||
lastw = separator;
|
||||
wrclose_bracket(wglb, TRUE);
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
@ -934,28 +957,22 @@ writeTerm(Term t, int p, int depth, int rinfixarg, struct write_globs *wglb, str
|
||||
!IsVarTerm(tright) && IsAtomTerm(tright) &&
|
||||
Yap_IsOp(AtomOfTerm(tright));
|
||||
if (op > p) {
|
||||
/* avoid stuff such as \+ (a,b) being written as \+(a,b) */
|
||||
if (lastw != separator && !rinfixarg)
|
||||
wrputc(' ', wglb->stream);
|
||||
wrputc('(', wglb->stream);
|
||||
lastw = separator;
|
||||
wropen_bracket(wglb, TRUE);
|
||||
}
|
||||
putAtom(atom, wglb->Quote_illegal, wglb);
|
||||
if (bracket_right) {
|
||||
wrputc('(', wglb->stream);
|
||||
lastw = separator;
|
||||
/* avoid stuff such as \+ (a,b) being written as \+(a,b) */
|
||||
wropen_bracket(wglb, TRUE);
|
||||
} else if (atom == AtomMinus) {
|
||||
last_minus = TRUE;
|
||||
}
|
||||
writeTerm(from_pointer(RepAppl(t)+1, &nrwt, wglb), rp, depth + 1, FALSE, wglb, &nrwt);
|
||||
writeTerm(from_pointer(RepAppl(t)+1, &nrwt, wglb), rp, depth + 1, TRUE, wglb, &nrwt);
|
||||
restore_from_write(&nrwt, wglb);
|
||||
if (bracket_right) {
|
||||
wrputc(')', wglb->stream);
|
||||
lastw = separator;
|
||||
wrclose_bracket(wglb, TRUE);
|
||||
}
|
||||
if (op > p) {
|
||||
wrputc(')', wglb->stream);
|
||||
lastw = separator;
|
||||
wrclose_bracket(wglb, TRUE);
|
||||
}
|
||||
} else if (!wglb->Ignore_ops &&
|
||||
Arity == 1 &&
|
||||
@ -963,29 +980,24 @@ writeTerm(Term t, int p, int depth, int rinfixarg, struct write_globs *wglb, str
|
||||
Term tleft = ArgOfTerm(1, t);
|
||||
|
||||
int bracket_left =
|
||||
!IsVarTerm(tleft) && IsAtomTerm(tleft) &&
|
||||
!IsVarTerm(tleft) &&
|
||||
IsAtomTerm(tleft) &&
|
||||
Yap_IsOp(AtomOfTerm(tleft));
|
||||
if (op > p) {
|
||||
/* avoid stuff such as \+ (a,b) being written as \+(a,b) */
|
||||
if (lastw != separator && !rinfixarg)
|
||||
wrputc(' ', wglb->stream);
|
||||
wrputc('(', wglb->stream);
|
||||
lastw = separator;
|
||||
wropen_bracket(wglb, TRUE);
|
||||
}
|
||||
if (bracket_left) {
|
||||
wrputc('(', wglb->stream);
|
||||
lastw = separator;
|
||||
wropen_bracket(wglb, TRUE);
|
||||
}
|
||||
writeTerm(from_pointer(RepAppl(t)+1, &nrwt, wglb), lp, depth + 1, rinfixarg, wglb, &nrwt);
|
||||
restore_from_write(&nrwt, wglb);
|
||||
if (bracket_left) {
|
||||
wrputc(')', wglb->stream);
|
||||
lastw = separator;
|
||||
wrclose_bracket(wglb, TRUE);
|
||||
}
|
||||
putAtom(atom, wglb->Quote_illegal, wglb);
|
||||
if (op > p) {
|
||||
wrputc(')', wglb->stream);
|
||||
lastw = separator;
|
||||
wrclose_bracket(wglb, TRUE);
|
||||
}
|
||||
} else if (!wglb->Ignore_ops &&
|
||||
Arity == 2 && Yap_IsInfixOp(atom, &op, &lp,
|
||||
@ -1001,41 +1013,36 @@ writeTerm(Term t, int p, int depth, int rinfixarg, struct write_globs *wglb, str
|
||||
|
||||
if (op > p) {
|
||||
/* avoid stuff such as \+ (a,b) being written as \+(a,b) */
|
||||
if (lastw != separator && !rinfixarg)
|
||||
wrputc(' ', wglb->stream);
|
||||
wrputc('(', wglb->stream);
|
||||
wropen_bracket(wglb, TRUE);
|
||||
lastw = separator;
|
||||
}
|
||||
if (bracket_left) {
|
||||
wrputc('(', wglb->stream);
|
||||
lastw = separator;
|
||||
wropen_bracket(wglb, TRUE);
|
||||
}
|
||||
writeTerm(from_pointer(RepAppl(t)+1, &nrwt, wglb), lp, depth + 1, rinfixarg, wglb, &nrwt);
|
||||
t = AbsAppl(restore_from_write(&nrwt, wglb)-1);
|
||||
if (bracket_left) {
|
||||
wrputc(')', wglb->stream);
|
||||
lastw = separator;
|
||||
wrclose_bracket(wglb, TRUE);
|
||||
}
|
||||
/* avoid quoting commas */
|
||||
if (strcmp(RepAtom(atom)->StrOfAE,","))
|
||||
putAtom(atom, wglb->Quote_illegal, wglb);
|
||||
else {
|
||||
/* avoid quoting commas and bars */
|
||||
if (!strcmp(RepAtom(atom)->StrOfAE,",")) {
|
||||
wrputc(',', wglb->stream);
|
||||
lastw = separator;
|
||||
}
|
||||
if (bracket_right) {
|
||||
wrputc('(', wglb->stream);
|
||||
} else if (!strcmp(RepAtom(atom)->StrOfAE,"|")) {
|
||||
wrputc('|', wglb->stream);
|
||||
lastw = separator;
|
||||
} else
|
||||
putAtom(atom, wglb->Quote_illegal, wglb);
|
||||
if (bracket_right) {
|
||||
wropen_bracket(wglb, TRUE);
|
||||
}
|
||||
writeTerm(from_pointer(RepAppl(t)+2, &nrwt, wglb), rp, depth + 1, TRUE, wglb, &nrwt);
|
||||
restore_from_write(&nrwt, wglb);
|
||||
if (bracket_right) {
|
||||
wrputc(')', wglb->stream);
|
||||
lastw = separator;
|
||||
wrclose_bracket(wglb, TRUE);
|
||||
}
|
||||
if (op > p) {
|
||||
wrputc(')', wglb->stream);
|
||||
lastw = separator;
|
||||
wrclose_bracket(wglb, TRUE);
|
||||
}
|
||||
} else if (wglb->Handle_vars && functor == LOCAL_FunctorVar) {
|
||||
Term ti = ArgOfTerm(1, t);
|
||||
@ -1068,8 +1075,7 @@ writeTerm(Term t, int p, int depth, int rinfixarg, struct write_globs *wglb, str
|
||||
lastw = separator;
|
||||
writeTerm(from_pointer(RepAppl(t)+1, &nrwt, wglb), 999, depth + 1, FALSE, wglb, &nrwt);
|
||||
restore_from_write(&nrwt, wglb);
|
||||
wrputc(')', wglb->stream);
|
||||
lastw = separator;
|
||||
wrclose_bracket(wglb, TRUE);
|
||||
}
|
||||
} else if (!wglb->Ignore_ops && functor == FunctorBraces) {
|
||||
wrputc('{', wglb->stream);
|
||||
@ -1098,7 +1104,7 @@ writeTerm(Term t, int p, int depth, int rinfixarg, struct write_globs *wglb, str
|
||||
} else {
|
||||
putAtom(atom, wglb->Quote_illegal, wglb);
|
||||
lastw = separator;
|
||||
wrputc('(', wglb->stream);
|
||||
wropen_bracket(wglb, FALSE);
|
||||
for (op = 1; op <= Arity; ++op) {
|
||||
if (op == wglb->MaxArgs) {
|
||||
wrputc('.', wglb->stream);
|
||||
@ -1113,8 +1119,7 @@ writeTerm(Term t, int p, int depth, int rinfixarg, struct write_globs *wglb, str
|
||||
lastw = separator;
|
||||
}
|
||||
}
|
||||
wrputc(')', wglb->stream);
|
||||
lastw = separator;
|
||||
wrclose_bracket(wglb, TRUE);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1138,6 +1143,7 @@ Yap_plwrite(Term t, void *mywrite, int max_depth, int flags, int priority)
|
||||
wglb.Quote_illegal = flags & Quote_illegal_f;
|
||||
wglb.Handle_vars = flags & Handle_vars_f;
|
||||
wglb.Use_portray = flags & Use_portray_f;
|
||||
wglb.Portray_delays = flags & AttVar_Portray_f;
|
||||
wglb.MaxDepth = max_depth;
|
||||
wglb.MaxArgs = max_depth;
|
||||
/* notice: we must have ASP well set when using portray, otherwise
|
||||
|
@ -498,6 +498,7 @@ void STD_PROTO(Yap_init_optyap_preds,(void));
|
||||
|
||||
/* pl-file.c */
|
||||
struct PL_local_data *Yap_InitThreadIO(int wid);
|
||||
void Yap_flush(void);
|
||||
|
||||
static inline
|
||||
yamop *
|
||||
|
@ -159,6 +159,9 @@ typedef enum
|
||||
#ifdef HAVE_LOCALE_H
|
||||
#include <locale.h>
|
||||
#endif
|
||||
#ifdef HAVE_LIMITS_H /* get MAXPATHLEN */
|
||||
#include <limits.h>
|
||||
#endif
|
||||
#include <setjmp.h>
|
||||
#include <assert.h>
|
||||
#if HAVE_SYS_PARAM_H
|
||||
|
@ -323,12 +323,6 @@ int STD_PROTO(Yap_growtrail_in_parser, (tr_fr_ptr *, TokEntry **, VarEntry **)
|
||||
extern int errno;
|
||||
#endif
|
||||
|
||||
#ifdef DEBUG
|
||||
#if COROUTINING
|
||||
extern int Yap_Portray_delays;
|
||||
#endif
|
||||
#endif
|
||||
|
||||
EXTERN inline UInt STD_PROTO(HashFunction, (unsigned char *));
|
||||
EXTERN inline UInt STD_PROTO(WideHashFunction, (wchar_t *));
|
||||
|
||||
|
@ -710,6 +710,7 @@ all: startup.yss
|
||||
@ENABLE_CPLINT@ (cd packages/cplint; $(MAKE))
|
||||
@ENABLE_CPLINT@ (cd packages/cplint/slipcase; $(MAKE))
|
||||
@ENABLE_PRISM@ (cd packages/prism/src/c; $(MAKE))
|
||||
@ENABLE_BDDLIB@ (cd packages/bdd; $(MAKE))
|
||||
@ENABLE_CUDD@ (cd packages/ProbLog/simplecudd; $(MAKE))
|
||||
@ENABLE_CUDD@ (cd packages/ProbLog/simplecudd_lfi; $(MAKE))
|
||||
@ENABLE_JPL@ @INSTALL_DLLS@ (cd packages/jpl; $(MAKE))
|
||||
@ -788,6 +789,7 @@ install_unix: startup.yss libYap.a
|
||||
@ENABLE_CPLINT@ (cd packages/cplint/approx/simplecuddLPADs; $(MAKE) install)
|
||||
@ENABLE_PRISM@ (cd packages/prism/src/c; $(MAKE) install)
|
||||
@ENABLE_PRISM@ (cd packages/prism/src/prolog; $(MAKE) install)
|
||||
@ENABLE_BDDLIB@ (cd packages/bdd; $(MAKE) install)
|
||||
@ENABLE_CUDD@ (cd packages/ProbLog/simplecudd; $(MAKE) install)
|
||||
@ENABLE_CUDD@ (cd packages/ProbLog/simplecudd_lfi; $(MAKE) install)
|
||||
|
||||
@ -840,6 +842,7 @@ install_win32: startup.yss @ENABLE_WINCONSOLE@ pl-yap@EXEC_SUFFIX@
|
||||
@ENABLE_CPLINT@ (cd packages/cplint; $(MAKE) install)
|
||||
@ENABLE_PRISM@ (cd packages/prism/src/c; $(MAKE) install)
|
||||
@ENABLE_PRISM@ (cd packages/prism/src/prolog; $(MAKE) install)
|
||||
@ENABLE_BDDLIB@ (cd packages/bdd; $(MAKE) install)
|
||||
@ENABLE_CUDD@ (cd packages/ProbLog/simplecudd; $(MAKE) install)
|
||||
@ENABLE_CUDD@ (cd packages/ProbLog/simplecudd_lfi; $(MAKE) install)
|
||||
|
||||
@ -904,6 +907,7 @@ clean: clean_docs
|
||||
@ENABLE_PRISM@ (cd packages/prism/src/prolog; $(MAKE) clean)
|
||||
@ENABLE_CPLINT@ (cd packages/cplint/approx/simplecuddLPADs; $(MAKE) clean)
|
||||
@ENABLE_CPLINT@ (cd packages/cplint; $(MAKE) clean)
|
||||
@ENABLE_BDDLIB@ (cd packages/bdd; $(MAKE) clean)
|
||||
@ENABLE_CUDD@ (cd packages/ProbLog/simplecudd; $(MAKE) clean)
|
||||
@ENABLE_CUDD@ (cd packages/ProbLog/simplecudd_lfi; $(MAKE) clean)
|
||||
@ENABLE_JPL@ @INSTALL_DLLS@ (cd packages/jpl; $(MAKE) clean)
|
||||
|
@ -257,6 +257,10 @@
|
||||
#undef HAVE_WAITPID
|
||||
#undef HAVE_MPZ_XOR
|
||||
|
||||
#if HAVE_GETHOSTNAME==1
|
||||
#define HAS_GETHOSTNAME 1
|
||||
#endif
|
||||
|
||||
#undef HAVE_SIGINFO
|
||||
#undef HAVE_SIGSEGV
|
||||
#undef HAVE_SIGPROF
|
||||
|
22
configure
vendored
22
configure
vendored
@ -625,6 +625,7 @@ ENABLE_REAL
|
||||
ENABLE_MINISAT
|
||||
CUDD_CPPFLAGS
|
||||
CUDD_LDFLAGS
|
||||
ENABLE_BDDLIB
|
||||
ENABLE_CUDD
|
||||
EXTRA_INCLUDES_FOR_WIN32
|
||||
ENABLE_WINCONSOLE
|
||||
@ -788,6 +789,7 @@ enable_depth_limit
|
||||
enable_wam_profile
|
||||
enable_low_level_tracer
|
||||
enable_threads
|
||||
enable_bddlib
|
||||
enable_pthread_locking
|
||||
enable_max_performance
|
||||
enable_max_memory
|
||||
@ -1462,6 +1464,7 @@ Optional Features:
|
||||
--enable-wam-profile support low level profiling of abstract machine
|
||||
--enable-low-level-tracer support support for procedure-call tracing
|
||||
--enable-threads support system threads
|
||||
--enable-bddlib dynamic bdd library
|
||||
--enable-pthread-locking use pthread locking primitives for internal locking (requires threads)
|
||||
--enable-max-performance try using the best flags for specific architecture
|
||||
--enable-max-memory try using the best flags for using the memory to the most
|
||||
@ -4486,6 +4489,13 @@ else
|
||||
threads=no
|
||||
fi
|
||||
|
||||
# Check whether --enable-bddlib was given.
|
||||
if test "${enable_bddlib+set}" = set; then :
|
||||
enableval=$enable_bddlib; dynamic_bdd="$enableval"
|
||||
else
|
||||
dynamic_bdd=no
|
||||
fi
|
||||
|
||||
# Check whether --enable-pthread-locking was given.
|
||||
if test "${enable_pthread_locking+set}" = set; then :
|
||||
enableval=$enable_pthread_locking; pthreadlocking="$enableval"
|
||||
@ -5000,7 +5010,14 @@ fi
|
||||
if test "$yap_cv_cudd" = no
|
||||
then
|
||||
ENABLE_CUDD="@# "
|
||||
ENABLE_BDDLIB="@# "
|
||||
else
|
||||
if test "$dynamic_bdd" = yes
|
||||
then
|
||||
ENABLE_BDDLIB=""
|
||||
else
|
||||
ENABLE_BDDLIB="@# "
|
||||
fi
|
||||
ENABLE_CUDD=""
|
||||
fi
|
||||
|
||||
@ -9220,6 +9237,7 @@ fi
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
{ $as_echo "$as_me:${as_lineno-$LINENO}: checking for gcc threaded code" >&5
|
||||
@ -10454,6 +10472,7 @@ mkdir -p LGPL/clp
|
||||
mkdir -p LGPL/swi_console
|
||||
mkdir -p GPL
|
||||
mkdir -p packages/
|
||||
mkdir -p packages/bdd
|
||||
mkdir -p packages/clib
|
||||
mkdir -p packages/clib/sha1
|
||||
mkdir -p packages/clib/maildrop
|
||||
@ -10616,6 +10635,8 @@ ac_config_files="$ac_config_files packages/zlib/Makefile"
|
||||
fi
|
||||
|
||||
if test "$ENABLE_CUDD" = ""; then
|
||||
ac_config_files="$ac_config_files packages/bdd/Makefile"
|
||||
|
||||
ac_config_files="$ac_config_files packages/ProbLog/simplecudd/Makefile"
|
||||
|
||||
ac_config_files="$ac_config_files packages/ProbLog/simplecudd_lfi/Makefile"
|
||||
@ -11398,6 +11419,7 @@ do
|
||||
"packages/semweb/Makefile") CONFIG_FILES="$CONFIG_FILES packages/semweb/Makefile" ;;
|
||||
"packages/sgml/Makefile") CONFIG_FILES="$CONFIG_FILES packages/sgml/Makefile" ;;
|
||||
"packages/zlib/Makefile") CONFIG_FILES="$CONFIG_FILES packages/zlib/Makefile" ;;
|
||||
"packages/bdd/Makefile") CONFIG_FILES="$CONFIG_FILES packages/bdd/Makefile" ;;
|
||||
"packages/ProbLog/simplecudd/Makefile") CONFIG_FILES="$CONFIG_FILES packages/ProbLog/simplecudd/Makefile" ;;
|
||||
"packages/ProbLog/simplecudd_lfi/Makefile") CONFIG_FILES="$CONFIG_FILES packages/ProbLog/simplecudd_lfi/Makefile" ;;
|
||||
"packages/swi-minisat2/Makefile") CONFIG_FILES="$CONFIG_FILES packages/swi-minisat2/Makefile" ;;
|
||||
|
13
configure.in
13
configure.in
@ -158,6 +158,9 @@ AC_ARG_ENABLE(low-level-tracer,
|
||||
AC_ARG_ENABLE(threads,
|
||||
[ --enable-threads support system threads ],
|
||||
threads="$enableval", threads=no)
|
||||
AC_ARG_ENABLE(bddlib,
|
||||
[ --enable-bddlib dynamic bdd library ],
|
||||
dynamic_bdd="$enableval", dynamic_bdd=no)
|
||||
AC_ARG_ENABLE(pthread-locking,
|
||||
[ --enable-pthread-locking use pthread locking primitives for internal locking (requires threads) ],
|
||||
pthreadlocking="$enableval", pthreadlocking=no)
|
||||
@ -510,7 +513,14 @@ fi
|
||||
if test "$yap_cv_cudd" = no
|
||||
then
|
||||
ENABLE_CUDD="@# "
|
||||
ENABLE_BDDLIB="@# "
|
||||
else
|
||||
if test "$dynamic_bdd" = yes
|
||||
then
|
||||
ENABLE_BDDLIB=""
|
||||
else
|
||||
ENABLE_BDDLIB="@# "
|
||||
fi
|
||||
ENABLE_CUDD=""
|
||||
fi
|
||||
|
||||
@ -1789,6 +1799,7 @@ AC_SUBST(ENABLE_WINCONSOLE)
|
||||
AC_SUBST(EXTRA_INCLUDES_FOR_WIN32)
|
||||
|
||||
AC_SUBST(ENABLE_CUDD)
|
||||
AC_SUBST(ENABLE_BDDLIB)
|
||||
AC_SUBST(CUDD_LDFLAGS)
|
||||
AC_SUBST(CUDD_CPPFLAGS)
|
||||
AC_SUBST(ENABLE_MINISAT)
|
||||
@ -2269,6 +2280,7 @@ mkdir -p LGPL/clp
|
||||
mkdir -p LGPL/swi_console
|
||||
mkdir -p GPL
|
||||
mkdir -p packages/
|
||||
mkdir -p packages/bdd
|
||||
mkdir -p packages/clib
|
||||
mkdir -p packages/clib/sha1
|
||||
mkdir -p packages/clib/maildrop
|
||||
@ -2392,6 +2404,7 @@ AC_CONFIG_FILES([packages/zlib/Makefile])
|
||||
fi
|
||||
|
||||
if test "$ENABLE_CUDD" = ""; then
|
||||
AC_CONFIG_FILES([packages/bdd/Makefile])
|
||||
AC_CONFIG_FILES([packages/ProbLog/simplecudd/Makefile])
|
||||
AC_CONFIG_FILES([packages/ProbLog/simplecudd_lfi/Makefile])
|
||||
fi
|
||||
|
@ -75,6 +75,9 @@
|
||||
#if HAVE_STRING_H
|
||||
#include <string.h>
|
||||
#endif
|
||||
#if HAVE_IEEEFP_H
|
||||
#include <ieeefp.h>
|
||||
#endif
|
||||
|
||||
static void PROTO(do_top_goal,(YAP_Term));
|
||||
static void PROTO(exec_top_level,(int, YAP_init_args *));
|
||||
|
22
docs/yap.tex
22
docs/yap.tex
@ -12635,6 +12635,13 @@ The path @var{Path} is a path starting at vertex @var{Vertex} in graph
|
||||
The path @var{Path} is a path starting at vertex @var{Vertex} in graph
|
||||
@var{Graph}.
|
||||
|
||||
@item dgraph_leaves(+@var{Graph}, ?@var{Vertices})
|
||||
@findex dgraph_leaves/2
|
||||
@snindex dgraph_leaves/2
|
||||
@cnindex dgraph_leaves/2
|
||||
The vertices @var{Vertices} have no outgoing edge in graph
|
||||
@var{Graph}.
|
||||
|
||||
@end table
|
||||
|
||||
@node UnDGraphs, Lambda , DGraphs, Library
|
||||
@ -16464,6 +16471,21 @@ then allow one to construct functors, and to obtain their name and arity.
|
||||
Note that the functor is essentially a pair formed by an atom, and
|
||||
arity.
|
||||
|
||||
Constructing terms in the stack may lead to overflow. The routine
|
||||
@example
|
||||
int YAP_RequiresExtraStack(size_t @var{min})
|
||||
@end example
|
||||
verifies whether you have at least @var{min} cells free in the stack,
|
||||
and it returns true if it has to ensure enough memory by calling the
|
||||
garbage collector and or stack shifter. The routine returns false if no
|
||||
memory is needed, and a negative number if it cannot provide enough
|
||||
memory.
|
||||
|
||||
You can set @var{min} to zero if you do not know how much room you need
|
||||
but you do know you do not need a big chunk at a single go. Usually, the routine
|
||||
would usually be called together with a long-jump to restart the
|
||||
code. Slots can also be used if there is small state.
|
||||
|
||||
@node Unifying Terms, Manipulating Strings, Manipulating Terms, C-Interface
|
||||
@section Unification
|
||||
|
||||
|
@ -599,6 +599,8 @@ extern X_API size_t PROTO(YAP_SizeOfExportedTerm,(char *));
|
||||
|
||||
extern X_API YAP_Term PROTO(YAP_ImportTerm,(char *));
|
||||
|
||||
extern X_API int PROTO(YAP_RequiresExtraStack,(size_t));
|
||||
|
||||
#define YAP_InitCPred(N,A,F) YAP_UserCPredicate(N,F,A)
|
||||
|
||||
__END_DECLS
|
||||
|
@ -115,6 +115,7 @@ typedef enum {
|
||||
#define YAP_BOOT_FROM_SAVED_CODE 1
|
||||
#define YAP_BOOT_FROM_SAVED_STACKS 2
|
||||
#define YAP_FULL_BOOT_FROM_PROLOG 4
|
||||
#define YAP_BOOT_DONE_BEFOREHAND 8
|
||||
#define YAP_BOOT_ERROR -1
|
||||
|
||||
#define YAP_WRITE_QUOTED 1
|
||||
|
@ -32,6 +32,7 @@
|
||||
dgraph_min_paths/3,
|
||||
dgraph_isomorphic/4,
|
||||
dgraph_path/3,
|
||||
dgraph_leaves/2,
|
||||
dgraph_reachable/3
|
||||
]).
|
||||
|
||||
@ -414,3 +415,13 @@ reachable([V|Vertices], Done0, DoneF, G, [V|EdgesF], Edges0) :-
|
||||
rb_insert(Done0, V, [], Done1),
|
||||
reachable(Kids, Done1, DoneI, G, EdgesF, EdgesI),
|
||||
reachable(Vertices, DoneI, DoneF, G, EdgesI, Edges0).
|
||||
|
||||
dgraph_leaves(Graph, Vertices) :-
|
||||
rb_visit(Graph, Pairs),
|
||||
vertices_without_children(Pairs, Vertices).
|
||||
|
||||
vertices_without_children([], []).
|
||||
vertices_without_children((V-[]).Pairs, V.Vertices) :-
|
||||
vertices_without_children(Pairs, Vertices).
|
||||
vertices_without_children(_V-[_|_].Pairs, Vertices) :-
|
||||
vertices_without_children(Pairs, Vertices).
|
||||
|
@ -25,7 +25,10 @@
|
||||
|
||||
:- ensure_loaded(library(lists)).
|
||||
|
||||
:- load_foreign_files([matlab], ['eng','mx','ut'], init_matlab).
|
||||
tell_warning :-
|
||||
print_message(warning,functionality(matlab)).
|
||||
|
||||
:- ( catch(load_foreign_files([matlab], ['eng','mx','ut'], init_matlab),_,fail) -> true ; tell_warning).
|
||||
|
||||
matlab_eval_sequence(S) :-
|
||||
atomic_concat(S,S1),
|
||||
|
@ -4671,6 +4671,11 @@ EndPredDefs
|
||||
|
||||
#if __YAP_PROLOG__
|
||||
|
||||
void Yap_flush(void)
|
||||
{
|
||||
flush_output(0);
|
||||
}
|
||||
|
||||
void *
|
||||
Yap_GetStreamHandle(Atom at)
|
||||
{ GET_LD
|
||||
|
@ -16,6 +16,11 @@ YAPLIBDIR=@libdir@/Yap
|
||||
#
|
||||
SHAREDIR=$(ROOTDIR)/share/Yap
|
||||
#
|
||||
# where YAP should store documentation
|
||||
#
|
||||
DOCDIR=$(ROOTDIR)/share/doc/Yap
|
||||
EXDIR=$(DOCDIR)/examples/CLPBN
|
||||
#
|
||||
#
|
||||
# You shouldn't need to change what follows.
|
||||
#
|
||||
@ -35,6 +40,7 @@ CLPBN_EXDIR = $(srcdir)/examples
|
||||
|
||||
CLPBN_PROGRAMS= \
|
||||
$(CLPBN_SRCDIR)/aggregates.yap \
|
||||
$(CLPBN_SRCDIR)/bdd.yap \
|
||||
$(CLPBN_SRCDIR)/bnt.yap \
|
||||
$(CLPBN_SRCDIR)/bp.yap \
|
||||
$(CLPBN_SRCDIR)/connected.yap \
|
||||
@ -48,6 +54,7 @@ CLPBN_PROGRAMS= \
|
||||
$(CLPBN_SRCDIR)/graphviz.yap \
|
||||
$(CLPBN_SRCDIR)/ground_factors.yap \
|
||||
$(CLPBN_SRCDIR)/hmm.yap \
|
||||
$(CLPBN_SRCDIR)/horus.yap \
|
||||
$(CLPBN_SRCDIR)/jt.yap \
|
||||
$(CLPBN_SRCDIR)/matrix_cpt_utils.yap \
|
||||
$(CLPBN_SRCDIR)/pgrammar.yap \
|
||||
@ -72,6 +79,8 @@ CLPBN_SCHOOL_EXAMPLES= \
|
||||
$(CLPBN_EXDIR)/School/parschema.yap \
|
||||
$(CLPBN_EXDIR)/School/school_128.yap \
|
||||
$(CLPBN_EXDIR)/School/school_32.yap \
|
||||
$(CLPBN_EXDIR)/School/sch32.yap \
|
||||
$(CLPBN_EXDIR)/School/school32_data.yap \
|
||||
$(CLPBN_EXDIR)/School/school_64.yap \
|
||||
$(CLPBN_EXDIR)/School/tables.yap
|
||||
|
||||
@ -92,12 +101,13 @@ CLPBN_EXAMPLES= \
|
||||
install: $(CLBN_TOP) $(CLBN_PROGRAMS) $(CLPBN_PROGRAMS)
|
||||
mkdir -p $(DESTDIR)$(SHAREDIR)/clpbn
|
||||
mkdir -p $(DESTDIR)$(SHAREDIR)/clpbn/learning
|
||||
mkdir -p $(DESTDIR)$(SHAREDIR)/clpbn/examples/School
|
||||
mkdir -p $(DESTDIR)$(SHAREDIR)/clpbn/examples/HMMer
|
||||
mkdir -p $(DESTDIR)$(EXDIR)
|
||||
mkdir -p $(DESTDIR)$(EXDIR)/School
|
||||
mkdir -p $(DESTDIR)$(EXDIR)/HMMer
|
||||
for h in $(CLPBN_TOP); do $(INSTALL_DATA) $$h $(DESTDIR)$(SHAREDIR); done
|
||||
for h in $(CLPBN_PROGRAMS); do $(INSTALL_DATA) $$h $(DESTDIR)$(SHAREDIR)/clpbn; done
|
||||
for h in $(CLPBN_LEARNING_PROGRAMS); do $(INSTALL_DATA) $$h $(DESTDIR)$(SHAREDIR)/clpbn/learning; done
|
||||
for h in $(CLPBN_EXAMPLES); do $(INSTALL_DATA) $$h $(DESTDIR)$(SHAREDIR)/clpbn/examples; done
|
||||
for h in $(CLPBN_SCHOOL_EXAMPLES); do $(INSTALL_DATA) $$h $(DESTDIR)$(SHAREDIR)/clpbn/examples/School; done
|
||||
for h in $(CLPBN_HMMER_EXAMPLES); do $(INSTALL_DATA) $$h $(DESTDIR)$(SHAREDIR)/clpbn/examples/HMMer; done
|
||||
for h in $(CLPBN_EXAMPLES); do $(INSTALL_DATA) $$h $(DESTDIR)$(EXDIR); done
|
||||
for h in $(CLPBN_SCHOOL_EXAMPLES); do $(INSTALL_DATA) $$h $(DESTDIR)$(EXDIR)/School; done
|
||||
for h in $(CLPBN_HMMER_EXAMPLES); do $(INSTALL_DATA) $$h $(DESTDIR)$(EXDIR)/HMMer; done
|
||||
|
||||
|
@ -13,6 +13,7 @@
|
||||
clpbn_init_graph/1,
|
||||
probability/2,
|
||||
conditional_probability/3,
|
||||
use_parfactors/1,
|
||||
op( 500, xfy, with)]).
|
||||
|
||||
:- use_module(library(atts)).
|
||||
@ -43,6 +44,7 @@
|
||||
check_if_bp_done/1,
|
||||
init_bp_solver/4,
|
||||
run_bp_solver/3,
|
||||
call_bp_ground/6,
|
||||
finalize_bp_solver/1
|
||||
]).
|
||||
|
||||
@ -61,11 +63,17 @@
|
||||
run_jt_solver/3
|
||||
]).
|
||||
|
||||
:- use_module('clpbn/bnt',
|
||||
[do_bnt/3,
|
||||
check_if_bnt_done/1
|
||||
:- use_module('clpbn/bdd',
|
||||
[bdd/3,
|
||||
init_bdd_solver/4,
|
||||
run_bdd_solver/3
|
||||
]).
|
||||
|
||||
%% :- use_module('clpbn/bnt',
|
||||
%% [do_bnt/3,
|
||||
%% check_if_bnt_done/1
|
||||
%% ]).
|
||||
|
||||
:- use_module('clpbn/gibbs',
|
||||
[gibbs/3,
|
||||
check_if_gibbs_done/1,
|
||||
@ -111,7 +119,7 @@
|
||||
[clpbn2gviz/4]).
|
||||
|
||||
:- use_module(clpbn/ground_factors,
|
||||
[generate_bn/2]).
|
||||
[generate_network/5]).
|
||||
|
||||
|
||||
:- dynamic solver/1,output/1,use/1,suppress_attribute_display/1, parameter_softening/1, em_solver/1, use_parfactors/1.
|
||||
@ -223,9 +231,17 @@ clpbn_marginalise(V, Dist) :-
|
||||
% called by top-level
|
||||
% or by call_residue/2
|
||||
%
|
||||
project_attributes(GVars, AVars0) :-
|
||||
project_attributes(GVars, _AVars0) :-
|
||||
use_parfactors(on),
|
||||
clpbn_flag(solver, Solver), Solver \= fove, !,
|
||||
generate_network(GVars, GKeys, Keys, Factors, Evidence),
|
||||
(ground(GVars) ->
|
||||
true
|
||||
;
|
||||
call_ground_solver(Solver, GVars, GKeys, Keys, Factors, Evidence, _Avars0)
|
||||
).
|
||||
project_attributes(GVars, AVars) :-
|
||||
suppress_attribute_display(false),
|
||||
generate_vars(GVars, AVars0, AVars),
|
||||
AVars = [_|_],
|
||||
solver(Solver),
|
||||
( GVars = [_|_] ; Solver = graphs), !,
|
||||
@ -243,11 +259,6 @@ project_attributes(GVars, AVars0) :-
|
||||
).
|
||||
project_attributes(_, _).
|
||||
|
||||
generate_vars(GVars, _, NewAVars) :-
|
||||
use_parfactors(on), !,
|
||||
generate_bn(GVars, NewAVars).
|
||||
generate_vars(_GVars, AVars, AVars).
|
||||
|
||||
clpbn_vars(AVars, DiffVars, AllVars) :-
|
||||
sort_vars_by_key(AVars,SortedAVars,DiffVars),
|
||||
incorporate_evidence(SortedAVars, AllVars).
|
||||
@ -289,6 +300,8 @@ write_out(ve, GVars, AVars, DiffVars) :-
|
||||
ve(GVars, AVars, DiffVars).
|
||||
write_out(jt, GVars, AVars, DiffVars) :-
|
||||
jt(GVars, AVars, DiffVars).
|
||||
write_out(bdd, GVars, AVars, DiffVars) :-
|
||||
bdd(GVars, AVars, DiffVars).
|
||||
write_out(bp, GVars, AVars, DiffVars) :-
|
||||
bp(GVars, AVars, DiffVars).
|
||||
write_out(gibbs, GVars, AVars, DiffVars) :-
|
||||
@ -298,6 +311,11 @@ write_out(bnt, GVars, AVars, DiffVars) :-
|
||||
write_out(fove, GVars, AVars, DiffVars) :-
|
||||
fove(GVars, AVars, DiffVars).
|
||||
|
||||
% call a solver with keys, not actual variables
|
||||
call_ground_solver(bp, GVars, GoalKeys, Keys, Factors, Evidence, Answ) :-
|
||||
call_bp_ground(GVars, GoalKeys, Keys, Factors, Evidence, Answ).
|
||||
|
||||
|
||||
get_bnode(Var, Goal) :-
|
||||
get_atts(Var, [key(Key),dist(Dist,Parents)]),
|
||||
get_dist(Dist,_,Domain,CPT),
|
||||
@ -382,6 +400,9 @@ bind_clpbn(_, Var, _, _, _, _, []) :-
|
||||
bind_clpbn(_, Var, _, _, _, _, []) :-
|
||||
use(jt),
|
||||
check_if_ve_done(Var), !.
|
||||
bind_clpbn(_, Var, _, _, _, _, []) :-
|
||||
use(bdd),
|
||||
check_if_bdd_done(Var), !.
|
||||
bind_clpbn(T, Var, Key0, _, _, _, []) :-
|
||||
get_atts(Var, [key(Key)]), !,
|
||||
(
|
||||
@ -397,11 +418,12 @@ fresh_attvar(Var, NVar) :-
|
||||
|
||||
% I will now allow two CLPBN variables to be bound together.
|
||||
%bind_clpbns(Key, Dist, Parents, Key, Dist, Parents).
|
||||
bind_clpbns(Key, Dist, _Parents, Key1, Dist1, _Parents1) :-
|
||||
bind_clpbns(Key, Dist, Parents, Key1, Dist1, Parents1) :-
|
||||
Key == Key1, !,
|
||||
get_dist(Dist,_Type,_Domain,_Table),
|
||||
get_dist(Dist1,_Type1,_Domain1,_Table1),
|
||||
Dist = Dist1.
|
||||
Dist = Dist1,
|
||||
Parents = Parents1.
|
||||
bind_clpbns(Key, _, _, _, Key1, _, _, _) :-
|
||||
Key\=Key1, !, fail.
|
||||
bind_clpbns(_, _, _, _, _, _, _, _) :-
|
||||
@ -452,6 +474,8 @@ clpbn_init_solver(bp, LVs, Vs0, VarsWithUnboundKeys, State) :-
|
||||
init_bp_solver(LVs, Vs0, VarsWithUnboundKeys, State).
|
||||
clpbn_init_solver(jt, LVs, Vs0, VarsWithUnboundKeys, State) :-
|
||||
init_jt_solver(LVs, Vs0, VarsWithUnboundKeys, State).
|
||||
clpbn_init_solver(bdd, LVs, Vs0, VarsWithUnboundKeys, State) :-
|
||||
init_bdd_solver(LVs, Vs0, VarsWithUnboundKeys, State).
|
||||
clpbn_init_solver(pcg, LVs, Vs0, VarsWithUnboundKeys, State) :-
|
||||
init_pcg_solver(LVs, Vs0, VarsWithUnboundKeys, State).
|
||||
|
||||
@ -478,6 +502,9 @@ clpbn_run_solver(bp, LVs, LPs, State) :-
|
||||
clpbn_run_solver(jt, LVs, LPs, State) :-
|
||||
run_jt_solver(LVs, LPs, State).
|
||||
|
||||
clpbn_run_solver(bdd, LVs, LPs, State) :-
|
||||
run_bdd_solver(LVs, LPs, State).
|
||||
|
||||
clpbn_run_solver(pcg, LVs, LPs, State) :-
|
||||
run_pcg_solver(LVs, LPs, State).
|
||||
|
||||
@ -538,4 +565,5 @@ match_probability([p(V0=C)=Prob|_], C, V, Prob) :-
|
||||
match_probability([_|Probs], C, V, Prob) :-
|
||||
match_probability(Probs, C, V, Prob).
|
||||
|
||||
:- use_parfactors(on) -> true ; assert(use_parfactors(off)).
|
||||
|
||||
|
@ -27,7 +27,7 @@
|
||||
|
||||
:- use_module(library('clpbn/dists'),
|
||||
[
|
||||
dist/4,
|
||||
add_dist/6,
|
||||
get_dist_domain_size/2]).
|
||||
|
||||
:- use_module(library('clpbn/matrix_cpt_utils'),
|
||||
@ -44,8 +44,9 @@ check_for_agg_vars([_|Vs0], Vs1) :-
|
||||
% transform aggregate distribution into tree
|
||||
simplify_dist(avg(Domain), V, Key, Parents, Vs0, VsF) :- !,
|
||||
cpt_average([V|Parents], Key, Domain, NewDist, Vs0, VsF),
|
||||
dist(NewDist, Id, Key, ParentsF),
|
||||
clpbn:put_atts(V, [dist(Id,ParentsF)]).
|
||||
NewDist = p(Dom, Tab, Ps),
|
||||
add_dist(Dom, tab, Tab, Ps, Key, Id),
|
||||
clpbn:put_atts(V, [dist(Id,Ps)]).
|
||||
simplify_dist(_, _, _, _, Vs0, Vs0).
|
||||
|
||||
cpt_average(AllVars, Key, Els0, Tab, Vs, NewVs) :-
|
||||
|
802
packages/CLPBN/clpbn/bdd.yap
Normal file
802
packages/CLPBN/clpbn/bdd.yap
Normal file
@ -0,0 +1,802 @@
|
||||
|
||||
/************************************************
|
||||
|
||||
BDDs in CLP(BN)
|
||||
|
||||
A variable is represented by the N possible cases it can take
|
||||
|
||||
V = v(Va, Vb, Vc)
|
||||
|
||||
The generic formula is
|
||||
|
||||
V <- X, Y
|
||||
|
||||
Va <- P*X1*Y1 + Q*X2*Y2 + ...
|
||||
|
||||
|
||||
|
||||
**************************************************/
|
||||
|
||||
:- module(clpbn_bdd,
|
||||
[bdd/3,
|
||||
set_solver_parameter/2,
|
||||
init_bdd_solver/4,
|
||||
run_bdd_solver/3,
|
||||
finalize_bdd_solver/1,
|
||||
check_if_bdd_done/1
|
||||
]).
|
||||
|
||||
|
||||
:- use_module(library('clpbn/dists'),
|
||||
[dist/4,
|
||||
get_dist_domain/2,
|
||||
get_dist_domain_size/2,
|
||||
get_dist_all_sizes/2,
|
||||
get_dist_params/2
|
||||
]).
|
||||
|
||||
|
||||
:- use_module(library('clpbn/display'),
|
||||
[clpbn_bind_vals/3]).
|
||||
|
||||
:- use_module(library('clpbn/aggregates'),
|
||||
[check_for_agg_vars/2]).
|
||||
|
||||
|
||||
:- use_module(library(atts)).
|
||||
|
||||
:- use_module(library(hacks)).
|
||||
|
||||
:- use_module(library(lists)).
|
||||
|
||||
:- use_module(library(dgraphs)).
|
||||
|
||||
:- use_module(library(bdd)).
|
||||
|
||||
:- use_module(library(rbtrees)).
|
||||
|
||||
:- use_module(library(bhash)).
|
||||
|
||||
:- use_module(library(matrix)).
|
||||
|
||||
:- dynamic network_counting/1.
|
||||
|
||||
:- attribute order/1.
|
||||
|
||||
check_if_bdd_done(_Var).
|
||||
|
||||
bdd([[]],_,_) :- !.
|
||||
bdd([QueryVars], AllVars, AllDiffs) :-
|
||||
init_bdd_solver(_, AllVars, _, BayesNet),
|
||||
run_bdd_solver([QueryVars], LPs, BayesNet),
|
||||
finalize_bdd_solver(BayesNet),
|
||||
clpbn_bind_vals([QueryVars], [LPs], AllDiffs).
|
||||
|
||||
init_bdd_solver(_, AllVars0, _, bdd(Term, Leaves, Tops)) :-
|
||||
% check_for_agg_vars(AllVars0, AllVars1),
|
||||
sort_vars(AllVars0, AllVars, Leaves),
|
||||
order_vars(AllVars, 0),
|
||||
rb_new(Vars0),
|
||||
rb_new(Pars0),
|
||||
init_tops(Leaves,Tops),
|
||||
get_vars_info(AllVars, Vars0, _Vars, Pars0, _Pars, Leaves, Tops, Term, []).
|
||||
|
||||
order_vars([], _).
|
||||
order_vars([V|AllVars], I0) :-
|
||||
put_atts(V, [order(I0)]),
|
||||
I is I0+1,
|
||||
order_vars(AllVars, I).
|
||||
|
||||
|
||||
init_tops([],[]).
|
||||
init_tops(_.Leaves,_.Tops) :-
|
||||
init_tops(Leaves,Tops).
|
||||
|
||||
sort_vars(AllVars0, AllVars, Leaves) :-
|
||||
dgraph_new(Graph0),
|
||||
build_graph(AllVars0, Graph0, Graph),
|
||||
dgraph_leaves(Graph, Leaves),
|
||||
dgraph_top_sort(Graph, AllVars).
|
||||
|
||||
build_graph([], Graph, Graph).
|
||||
build_graph(V.AllVars0, Graph0, Graph) :-
|
||||
clpbn:get_atts(V, [dist(_DistId, Parents)]), !,
|
||||
dgraph_add_vertex(Graph0, V, Graph1),
|
||||
add_parents(Parents, V, Graph1, GraphI),
|
||||
build_graph(AllVars0, GraphI, Graph).
|
||||
build_graph(_V.AllVars0, Graph0, Graph) :-
|
||||
build_graph(AllVars0, Graph0, Graph).
|
||||
|
||||
add_parents([], _V, Graph, Graph).
|
||||
add_parents(V0.Parents, V, Graph0, GraphF) :-
|
||||
dgraph_add_edge(Graph0, V0, V, GraphI),
|
||||
add_parents(Parents, V, GraphI, GraphF).
|
||||
|
||||
get_vars_info([], Vs, Vs, Ps, Ps, _, _) --> [].
|
||||
get_vars_info([V|MoreVs], Vs, VsF, Ps, PsF, Lvs, Outs) -->
|
||||
{ clpbn:get_atts(V, [dist(DistId, Parents)]) }, !,
|
||||
%{writeln(v:DistId:Parents)},
|
||||
[DIST],
|
||||
{ get_var_info(V, DistId, Parents, Vs, Vs2, Ps, Ps1, Lvs, Outs, DIST) },
|
||||
get_vars_info(MoreVs, Vs2, VsF, Ps1, PsF, Lvs, Outs).
|
||||
get_vars_info([_|MoreVs], Vs0, VsF, Ps0, PsF, VarsInfo, Lvs, Outs) :-
|
||||
get_vars_info(MoreVs, Vs0, VsF, Ps0, PsF, VarsInfo, Lvs, Outs).
|
||||
|
||||
%
|
||||
% let's have some fun with avg
|
||||
%
|
||||
get_var_info(V, avg(Domain), Parents, Vs, Vs2, Ps, Ps, Lvs, Outs, DIST) :- !,
|
||||
length(Domain, DSize),
|
||||
% run_though_avg(V, DSize, Domain, Parents, Vs, Vs2, Lvs, Outs, DIST).
|
||||
top_down_with_tabling(V, DSize, Domain, Parents, Vs, Vs2, Lvs, Outs, DIST).
|
||||
% bup_avg(V, DSize, Domain, Parents, Vs, Vs2, Lvs, Outs, DIST).
|
||||
% standard random variable
|
||||
get_var_info(V, DistId, Parents0, Vs, Vs2, Ps, Ps1, Lvs, Outs, DIST) :-
|
||||
% clpbn:get_atts(V, [key(K)]), writeln(V:K:DistId:Parents),
|
||||
reorder_vars(Parents0, Parents, Map),
|
||||
check_p(DistId, Map, Parms, _ParmVars, Ps, Ps1),
|
||||
unbound_parms(Parms, ParmVars),
|
||||
check_v(V, DistId, DIST, Vs, Vs1),
|
||||
DIST = info(V, Tree, Ev, Values, Formula, ParmVars, Parms),
|
||||
% get a list of form [[P00,P01], [P10,P11], [P20,P21]]
|
||||
get_parents(Parents, PVars, Vs1, Vs2),
|
||||
cross_product(Values, Ev, PVars, ParmVars, Formula0),
|
||||
% (numbervars(Formula0,0,_),writeln(formula0:Ev:Formula0), fail ; true),
|
||||
get_evidence(V, Tree, Ev, Formula0, Formula, Lvs, Outs).
|
||||
%, (numbervars(Formula,0,_),writeln(formula:Formula), fail ; true)
|
||||
|
||||
%
|
||||
% reorder all variables and make sure we get a
|
||||
% map of how the transfer was done.
|
||||
%
|
||||
% position zero is output
|
||||
%
|
||||
reorder_vars(Vs, OVs, Map) :-
|
||||
add_pos(Vs, 1, PVs),
|
||||
keysort(PVs, SVs),
|
||||
remove_key(SVs, OVs, Map).
|
||||
|
||||
add_pos([], _, []).
|
||||
add_pos([V|Vs], I0, [K-(I0,V)|PVs]) :-
|
||||
get_atts(V,[order(K)]),
|
||||
I is I0+1,
|
||||
add_pos(Vs, I, PVs).
|
||||
|
||||
remove_key([], [], []).
|
||||
remove_key([_-(I,V)|SVs], [V|OVs], [I|Map]) :-
|
||||
remove_key(SVs, OVs, Map).
|
||||
|
||||
%%%%%%%%%%%%%%%%%%%%%%%%%
|
||||
%
|
||||
% use top-down to generate average
|
||||
%
|
||||
run_though_avg(V, 3, Domain, Parents0, Vs, Vs2, Lvs, Outs, DIST) :-
|
||||
reorder_vars(Parents0, Parents, _Map),
|
||||
check_v(V, avg(Domain,Parents0), DIST, Vs, Vs1),
|
||||
DIST = info(V, Tree, Ev, [V0,V1,V2], Formula, [], []),
|
||||
get_parents(Parents, PVars, Vs1, Vs2),
|
||||
length(Parents, N),
|
||||
generate_3tree(F00, PVars, 0, 0, 0, N, N0, N1, N2, R, (N1+2*N2 =< N/2), (N1+2*(N2+R) =< N/2)),
|
||||
simplify_exp(F00, F0),
|
||||
% generate_3tree(F1, PVars, 0, 0, 0, N, N0, N1, N2, R, ((N1+2*(N2+R) > N/2, N1+2*N2 < (3*N)/2))),
|
||||
generate_3tree(F20, PVars, 0, 0, 0, N, N0, N1, N2, R, (N1+2*(N2+R) >= (3*N)/2), N1+2*N2 >= (3*N)/2),
|
||||
% simplify_exp(F20, F2),
|
||||
F20=F2,
|
||||
Formula0 = [V0=F0*Ev0,V2=F2*Ev2,V1=not(F0+F2)*Ev1],
|
||||
Ev = [Ev0,Ev1,Ev2],
|
||||
get_evidence(V, Tree, Ev, Formula0, Formula, Lvs, Outs).
|
||||
|
||||
generate_3tree(OUT, _, I00, I10, I20, IR0, N0, N1, N2, R, _Exp, ExpF) :-
|
||||
IR is IR0-1,
|
||||
satisf(I00, I10, I20, IR, N0, N1, N2, R, ExpF),
|
||||
!,
|
||||
OUT = 1.
|
||||
generate_3tree(OUT, [[P0,P1,P2]], I00, I10, I20, IR0, N0, N1, N2, R, Exp, _ExpF) :-
|
||||
IR is IR0-1,
|
||||
( satisf(I00+1, I10, I20, IR, N0, N1, N2, R, Exp) ->
|
||||
L0 = [P0|L1]
|
||||
;
|
||||
L0 = L1
|
||||
),
|
||||
( satisf(I00, I10+1, I20, IR, N0, N1, N2, R, Exp) ->
|
||||
L1 = [P1|L2]
|
||||
;
|
||||
L1 = L2
|
||||
),
|
||||
( satisf(I00, I10, I20+1, IR, N0, N1, N2, R, Exp) ->
|
||||
L2 = [P2]
|
||||
;
|
||||
L2 = []
|
||||
),
|
||||
to_disj(L0, OUT).
|
||||
generate_3tree(OUT, [[P0,P1,P2]|Ps], I00, I10, I20, IR0, N0, N1, N2, R, Exp, ExpF) :-
|
||||
IR is IR0-1,
|
||||
( satisf(I00+1, I10, I20, IR, N0, N1, N2, R, Exp) ->
|
||||
I0 is I00+1, generate_3tree(O0, Ps, I0, I10, I20, IR, N0, N1, N2, R, Exp, ExpF)
|
||||
->
|
||||
L0 = [P0*O0|L1]
|
||||
;
|
||||
L0 = L1
|
||||
),
|
||||
( satisf(I00, I10+1, I20, IR0, N0, N1, N2, R, Exp) ->
|
||||
I1 is I10+1, generate_3tree(O1, Ps, I00, I1, I20, IR, N0, N1, N2, R, Exp, ExpF)
|
||||
->
|
||||
L1 = [P1*O1|L2]
|
||||
;
|
||||
L1 = L2
|
||||
),
|
||||
( satisf(I00, I10, I20+1, IR0, N0, N1, N2, R, Exp) ->
|
||||
I2 is I20+1, generate_3tree(O2, Ps, I00, I10, I2, IR, N0, N1, N2, R, Exp, ExpF)
|
||||
->
|
||||
L2 = [P2*O2]
|
||||
;
|
||||
L2 = []
|
||||
),
|
||||
to_disj(L0, OUT).
|
||||
|
||||
|
||||
satisf(I0, I1, I2, IR, N0, N1, N2, R, Exp) :-
|
||||
\+ \+ ( I0 = N0, I1=N1, I2=N2, IR=R, call(Exp) ).
|
||||
|
||||
not_satisf(I0, I1, I2, IR, N0, N1, N2, R, Exp) :-
|
||||
\+ ( I0 = N0, I1=N1, I2=N2, IR=R, call(Exp) ).
|
||||
|
||||
%%%%%%%%%%%%%%%%%%%%%%%%%
|
||||
%
|
||||
% use top-down to generate average
|
||||
%
|
||||
top_down_with_tabling(V, Size, Domain, Parents0, Vs, Vs2, Lvs, Outs, DIST) :-
|
||||
reorder_vars(Parents0, Parents, _Map),
|
||||
check_v(V, avg(Domain,Parents), DIST, Vs, Vs1),
|
||||
DIST = info(V, Tree, Ev, OVs, Formula, [], []),
|
||||
get_parents(Parents, PVars, Vs1, Vs2),
|
||||
length(Parents, N),
|
||||
Max is (Size-1)*N, % This should be true
|
||||
avg_borders(0, Size, Max, Borders),
|
||||
b_hash_new(H0),
|
||||
avg_trees(0, Max, PVars, Size, F1, 0, Borders, OVs, Ev, H0, H),
|
||||
generate_avg_code(H, Formula, F),
|
||||
% Formula0 = [V0=F0*Ev0,V2=F2*Ev2,V1=not(F0+F2)*Ev1],
|
||||
% Ev = [Ev0,Ev1,Ev2],
|
||||
get_evidence(V, Tree, Ev, F1, F, Lvs, Outs).
|
||||
|
||||
avg_trees(Size, _, _, Size, F0, _, F0, [], [], H, H) :- !.
|
||||
avg_trees(I0, Max, PVars, Size, [V=O*E|F0], Im, [IM|Borders], [V|OVs], [E|Ev], H0, H) :-
|
||||
I is I0+1,
|
||||
avg_tree(PVars, 0, Max, Im, IM, Size, O, H0, HI),
|
||||
Im1 is IM+1,
|
||||
avg_trees(I, Max, PVars, Size, F0, Im1, Borders, OVs, Ev, HI, H).
|
||||
|
||||
avg_tree( _PVars, P, _, Im, IM, _Size, O, H0, H0) :-
|
||||
b_hash_lookup(k(P,Im,IM), O=_Exp, H0), !.
|
||||
avg_tree([], _P, _Max, _Im, _IM, _Size, 1, H, H).
|
||||
avg_tree([Vals|PVars], P, Max, Im, IM, Size, O, H0, HF) :-
|
||||
b_hash_insert(H0, k(P,Im,IM), O=Simp, HI),
|
||||
MaxI is Max-(Size-1),
|
||||
avg_exp(Vals, PVars, 0, P, MaxI, Size, Im, IM, HI, HF, Exp),
|
||||
simplify_exp(Exp, Simp).
|
||||
|
||||
avg_exp([], _, _, _P, _Max, _Size, _Im, _IM, H, H, 0).
|
||||
avg_exp([Val|Vals], PVars, I0, P0, Max, Size, Im, IM, HI, HF, O) :-
|
||||
(Vals = [] -> O=O1 ; O = Val*O1+not(Val)*O2 ),
|
||||
Im1 is max(0, Im-I0),
|
||||
IM1 is IM-I0,
|
||||
( IM1 < 0 -> O1 = 0, H2 = HI; /* we have exceed maximum */
|
||||
Im1 > Max -> O1 = 0, H2 = HI; /* we cannot make to minimum */
|
||||
Im1 = 0, IM1 > Max -> O1 = 1, H2 = HI; /* we cannot exceed maximum */
|
||||
P is P0+1,
|
||||
avg_tree(PVars, P, Max, Im1, IM1, Size, O1, HI, H2)
|
||||
),
|
||||
I is I0+1,
|
||||
avg_exp(Vals, PVars, I, P0, Max, Size, Im, IM, H2, HF, O2).
|
||||
|
||||
generate_avg_code(H, Formula, Formula0) :-
|
||||
b_hash_to_list(H,L),
|
||||
sort(L, S),
|
||||
strip_and_add(S, Formula0, Formula).
|
||||
|
||||
strip_and_add([], F, F).
|
||||
strip_and_add([_-Exp|S], F0, F) :-
|
||||
strip_and_add(S, [Exp|F0], F).
|
||||
|
||||
%%%%%%%%%%%%%%%%%%%%%%%%%
|
||||
%
|
||||
% use bottom-up dynamic programming to generate average
|
||||
%
|
||||
bup_avg(V, Size, Domain, Parents0, Vs, Vs2, Lvs, Outs, DIST) :-
|
||||
reorder_vars(Parents0, Parents, _),
|
||||
check_v(V, avg(Domain,Parents), DIST, Vs, Vs1),
|
||||
DIST = info(V, Tree, Ev, OVs, Formula, [], []),
|
||||
get_parents(Parents, PVars, Vs1, Vs2),
|
||||
length(Parents, N),
|
||||
Max is (Size-1)*N, % This should be true
|
||||
ArraySize is Max+1,
|
||||
functor(Protected, protected, ArraySize),
|
||||
avg_domains(0, Size, 0, Max, LDomains),
|
||||
Domains =.. [d|LDomains],
|
||||
Reach is (Size-1),
|
||||
generate_sums(PVars, Size, Max, Reach, Protected, Domains, ArraySize, Sums, F0),
|
||||
% bin_sums(PVars, Sums, F00),
|
||||
% reverse(F00,F0),
|
||||
% easier to do recursion on lists
|
||||
Sums =.. [_|LSums],
|
||||
generate_avg(0, Size, 0, Max, LSums, OVs, Ev, F1, []),
|
||||
reverse(F0, RF0),
|
||||
get_evidence(V, Tree, Ev, F1, F2, Lvs, Outs),
|
||||
append(RF0, F2, Formula).
|
||||
|
||||
%
|
||||
% use binary approach, like what is standard
|
||||
%
|
||||
bin_sums(Vs, Sums, F) :-
|
||||
vs_to_sums(Vs, Sums0),
|
||||
bin_sums(Sums0, Sums, F, []).
|
||||
|
||||
vs_to_sums([], []).
|
||||
vs_to_sums([V|Vs], [Sum|Sums0]) :-
|
||||
Sum =.. [sum|V],
|
||||
vs_to_sums(Vs, Sums0).
|
||||
|
||||
bin_sums([Sum], Sum) --> !.
|
||||
bin_sums(LSums, Sum) -->
|
||||
{ halve(LSums, Sums1, Sums2) },
|
||||
bin_sums(Sums1, Sum1),
|
||||
bin_sums(Sums2, Sum2),
|
||||
sum(Sum1, Sum2, Sum).
|
||||
|
||||
halve(LSums, Sums1, Sums2) :-
|
||||
length(LSums, L),
|
||||
Take is L div 2,
|
||||
head(Take, LSums, Sums1, Sums2).
|
||||
|
||||
head(0, L, [], L) :- !.
|
||||
head(Take, [H|L], [H|Sums1], Sum2) :-
|
||||
Take1 is Take-1,
|
||||
head(Take1, L, Sums1, Sum2).
|
||||
|
||||
sum(Sum1, Sum2, Sum) -->
|
||||
{ functor(Sum1, _, M1),
|
||||
functor(Sum2, _, M2),
|
||||
Max is M1+M2-2,
|
||||
Max1 is Max+1,
|
||||
Max0 is M2-1,
|
||||
functor(Sum, sum, Max1),
|
||||
Sum1 =.. [_|PVals] },
|
||||
expand_sums(PVals, 0, Max0, Max1, M2, Sum2, Sum).
|
||||
|
||||
%
|
||||
% bottom up step by step
|
||||
%
|
||||
%
|
||||
generate_sums([PVals], Size, Max, _, _Protected, _Domains, _, Sum, []) :- !,
|
||||
Max is Size-1,
|
||||
Sum =.. [sum|PVals].
|
||||
generate_sums([PVals|Parents], Size, Max, Reach, Protected, Domains, ASize, NewSums, F) :-
|
||||
NewReach is Reach+(Size-1),
|
||||
generate_sums(Parents, Size, Max0, NewReach, Protected, Domains, ASize, Sums, F0),
|
||||
Max is Max0+(Size-1),
|
||||
Max1 is Max+1,
|
||||
functor(NewSums, sum, Max1),
|
||||
protect_avg(0, Max0, Protected, Domains, ASize, Reach),
|
||||
expand_sums(PVals, 0, Max0, Max1, Size, Sums, Protected, NewSums, F, F0).
|
||||
|
||||
protect_avg(Max0,Max0,_Protected, _Domains, _ASize, _Reach) :- !.
|
||||
protect_avg(I0, Max0, Protected, Domains, ASize, Reach) :-
|
||||
I is I0+1,
|
||||
Top is I+Reach,
|
||||
( Top > ASize ;
|
||||
arg(I, Domains, CD),
|
||||
arg(Top, Domains, CD)
|
||||
), !,
|
||||
arg(I, Protected, yes),
|
||||
protect_avg(I, Max0, Protected, Domains, ASize, Reach).
|
||||
protect_avg(I0, Max0, Protected, Domains, ASize, Reach) :-
|
||||
I is I0+1,
|
||||
protect_avg(I, Max0, Protected, Domains, ASize, Reach).
|
||||
|
||||
|
||||
%
|
||||
% outer loop: generate array of sums at level j= Sum[j0...jMax]
|
||||
%
|
||||
expand_sums(_Parents, Max, _, Max, _Size, _Sums, _P, _NewSums, F0, F0) :- !.
|
||||
expand_sums(Parents, I0, Max0, Max, Size, Sums, Prot, NewSums, [O=SUM|F], F0) :-
|
||||
I is I0+1,
|
||||
arg(I, Prot, P),
|
||||
var(P), !,
|
||||
arg(I, NewSums, O),
|
||||
sum_all(Parents, 0, I0, Max0, Sums, List),
|
||||
to_disj(List, SUM),
|
||||
expand_sums(Parents, I, Max0, Max, Size, Sums, Prot, NewSums, F, F0).
|
||||
expand_sums(Parents, I0, Max0, Max, Size, Sums, Prot, NewSums, F, F0) :-
|
||||
I is I0+1,
|
||||
arg(I, Sums, O),
|
||||
arg(I, NewSums, O),
|
||||
expand_sums(Parents, I, Max0, Max, Size, Sums, Prot, NewSums, F, F0).
|
||||
|
||||
%
|
||||
%inner loop: find all parents that contribute to A_ji,
|
||||
% that is generate Pk*Sum_(j-1)l and k+l st k+l = i
|
||||
%
|
||||
sum_all([], _, _, _, _, []).
|
||||
sum_all([V|Vs], Pos, I, Max0, Sums, [O|List]) :-
|
||||
J is I-Pos,
|
||||
J >= 0,
|
||||
J =< Max0, !,
|
||||
J1 is J+1,
|
||||
arg(J1, Sums, S0),
|
||||
( J < I -> O = V*S0 ; O = S0*V ),
|
||||
Pos1 is Pos+1,
|
||||
sum_all(Vs, Pos1, I, Max0, Sums, List).
|
||||
sum_all([_V|Vs], Pos, I, Max0, Sums, List) :-
|
||||
Pos1 is Pos+1,
|
||||
sum_all(Vs, Pos1, I, Max0, Sums, List).
|
||||
|
||||
gen_arg(J, Sums, Max, S0) :-
|
||||
gen_arg(0, Max, J, Sums, S0).
|
||||
|
||||
gen_arg(Max, Max, J, Sums, S0) :- !,
|
||||
I is Max+1,
|
||||
arg(I, Sums, A),
|
||||
( Max = J -> S0 = A ; S0 = not(A)).
|
||||
gen_arg(I0, Max, J, Sums, S) :-
|
||||
I is I0+1,
|
||||
arg(I, Sums, A),
|
||||
( I0 = J -> S = A*S0 ; S = not(A)*S0),
|
||||
gen_arg(I, Max, J, Sums, S0).
|
||||
|
||||
|
||||
avg_borders(Size, Size, _Max, []) :- !.
|
||||
avg_borders(I0, Size, Max, [J|Vals]) :-
|
||||
I is I0+1,
|
||||
Border is (I*Max)/Size,
|
||||
J is integer(round(Border)),
|
||||
avg_borders(I, Size, Max, Vals).
|
||||
|
||||
avg_domains(Size, Size, _J, _Max, []).
|
||||
avg_domains(I0, Size, J0, Max, Vals) :-
|
||||
I is I0+1,
|
||||
Border is (I*Max)/Size,
|
||||
fetch_domain_for_avg(J0, Border, J, I0, Vals, ValsI),
|
||||
avg_domains(I, Size, J, Max, ValsI).
|
||||
|
||||
fetch_domain_for_avg(J, Border, J, _, Vals, Vals) :-
|
||||
J > Border, !.
|
||||
fetch_domain_for_avg(J0, Border, J, I0, [I0|LVals], RLVals) :-
|
||||
J1 is J0+1,
|
||||
fetch_domain_for_avg(J1, Border, J, I0, LVals, RLVals).
|
||||
|
||||
generate_avg(Size, Size, _J, _Max, [], [], [], F, F).
|
||||
generate_avg(I0, Size, J0, Max, LSums, [O|OVs], [Ev|Evs], [O=Ev*Disj|F], F0) :-
|
||||
I is I0+1,
|
||||
Border is (I*Max)/Size,
|
||||
fetch_for_avg(J0, Border, J, LSums, MySums, RSums),
|
||||
to_disj(MySums, Disj),
|
||||
generate_avg(I, Size, J, Max, RSums, OVs, Evs, F, F0).
|
||||
|
||||
fetch_for_avg(J, Border, J, RSums, [], RSums) :-
|
||||
J > Border, !.
|
||||
fetch_for_avg(J0, Border, J, [S|LSums], [S|MySums], RSums) :-
|
||||
J1 is J0+1,
|
||||
fetch_for_avg(J1, Border, J, LSums, MySums, RSums).
|
||||
|
||||
|
||||
to_disj([], 0).
|
||||
to_disj([V], V).
|
||||
to_disj([V,V1|Vs], Out) :-
|
||||
to_disj2([V1|Vs], V, Out).
|
||||
|
||||
to_disj2([V], V0, V0+V).
|
||||
to_disj2([V,V1|Vs], V0, Out) :-
|
||||
to_disj2([V1|Vs], V0+V, Out).
|
||||
|
||||
|
||||
%
|
||||
% look for parameters in the rb-tree, or add a new.
|
||||
% distid is the key
|
||||
%
|
||||
check_p(DistId, Map, Parms, ParmVars, Ps, Ps) :-
|
||||
rb_lookup(DistId-Map, theta(Parms, ParmVars), Ps), !.
|
||||
check_p(DistId, Map, Parms, ParmVars, Ps, PsF) :-
|
||||
get_dist_params(DistId, Parms0),
|
||||
get_dist_all_sizes(DistId, Sizes),
|
||||
swap_parms(Parms0, Sizes, [0|Map], Parms1),
|
||||
length(Parms1, L0),
|
||||
get_dist_domain_size(DistId, Size),
|
||||
L1 is L0 div Size,
|
||||
L is L0-L1,
|
||||
initial_maxes(L1, Multipliers),
|
||||
copy(L, Multipliers, NextMults, NextMults, Parms1, Parms, ParmVars),
|
||||
%writeln(t:Size:Parms0:Parms:ParmVars),
|
||||
rb_insert(Ps, DistId-Map, theta(Parms, ParmVars), PsF).
|
||||
|
||||
swap_parms(Parms0, Sizes, Map, Parms1) :-
|
||||
matrix_new(floats, Sizes, Parms0, T0),
|
||||
matrix_shuffle(T0,Map,TF),
|
||||
matrix_to_list(TF, Parms1).
|
||||
|
||||
%
|
||||
% we are using switches by two
|
||||
%
|
||||
initial_maxes(0, []) :- !.
|
||||
initial_maxes(Size, [1.0|Multipliers]) :- !,
|
||||
Size1 is Size-1,
|
||||
initial_maxes(Size1, Multipliers).
|
||||
|
||||
copy(0, [], [], _, _Parms0, [], []) :- !.
|
||||
copy(N, [], [], Ms, Parms0, Parms, ParmVars) :-!,
|
||||
copy(N, Ms, NewMs, NewMs, Parms0, Parms, ParmVars).
|
||||
copy(N, D.Ds, ND.NDs, New, El.Parms0, NEl.Parms, V.ParmVars) :-
|
||||
N1 is N-1,
|
||||
(El == 0.0 ->
|
||||
NEl = 0,
|
||||
V = NEl,
|
||||
ND = D
|
||||
;El == 1.0 ->
|
||||
NEl = 1,
|
||||
V = NEl,
|
||||
ND = 0.0
|
||||
;El == 0 ->
|
||||
NEl = 0,
|
||||
V = NEl,
|
||||
ND = D
|
||||
;El =:= 1 ->
|
||||
NEl = 1,
|
||||
V = NEl,
|
||||
ND = 0.0,
|
||||
V = NEl
|
||||
;
|
||||
NEl is El/D,
|
||||
ND is D-El,
|
||||
V = NEl
|
||||
),
|
||||
copy(N1, Ds, NDs, New, Parms0, Parms, ParmVars).
|
||||
|
||||
unbound_parms([], []).
|
||||
unbound_parms(_.Parms, _.ParmVars) :-
|
||||
unbound_parms(Parms, ParmVars).
|
||||
|
||||
check_v(V, _, INFO, Vs, Vs) :-
|
||||
rb_lookup(V, INFO, Vs), !.
|
||||
check_v(V, DistId, INFO, Vs0, Vs) :-
|
||||
get_dist_domain_size(DistId, Size),
|
||||
length(Values, Size),
|
||||
length(Ev, Size),
|
||||
INFO = info(V, _Tree, Ev, Values, _Formula, _, _),
|
||||
rb_insert(Vs0, V, INFO, Vs).
|
||||
|
||||
get_parents([], [], Vs, Vs).
|
||||
get_parents(V.Parents, Values.PVars, Vs0, Vs) :-
|
||||
clpbn:get_atts(V, [dist(DistId, _)]),
|
||||
check_v(V, DistId, INFO, Vs0, Vs1),
|
||||
INFO = info(V, _Parent, _Ev, Values, _, _, _),
|
||||
get_parents(Parents, PVars, Vs1, Vs).
|
||||
|
||||
%
|
||||
% construct the formula, this is the key...
|
||||
%
|
||||
cross_product(Values, Ev, PVars, ParmVars, Formulas) :-
|
||||
arrangements(PVars, Arranges),
|
||||
apply_parents_first(Values, Ev, ParmCombos, ParmCombos, Arranges, Formulas, ParmVars).
|
||||
|
||||
%
|
||||
% if we have the parent variables with two values, we get
|
||||
% [[XP,YP],[XP,YN],[XN,YP],[XN,YN]]
|
||||
%
|
||||
arrangements([], [[]]).
|
||||
arrangements([L1|Ls],O) :-
|
||||
arrangements(Ls, LN),
|
||||
expand(L1, LN, O, []).
|
||||
|
||||
expand([], _LN) --> [].
|
||||
expand([H|L1], LN) -->
|
||||
concatenate_all(H, LN),
|
||||
expand(L1, LN).
|
||||
|
||||
concatenate_all(_H, []) --> [].
|
||||
concatenate_all(H, L.LN) -->
|
||||
[[H|L]],
|
||||
concatenate_all(H, LN).
|
||||
|
||||
%
|
||||
% core of algorithm
|
||||
%
|
||||
% Values -> Output Vars for BDD
|
||||
% Es -> Evidence variables
|
||||
% Previous -> top of difference list with parameters used so far
|
||||
% P0 -> end of difference list with parameters used so far
|
||||
% Pvars -> Parents
|
||||
% Eqs -> Output Equations
|
||||
% Pars -> Output Theta Parameters
|
||||
%
|
||||
apply_parents_first([Value], [E], Previous, [], PVars, [Value=Disj*E], Parameters) :- !,
|
||||
apply_last_parent(PVars, Previous, Disj),
|
||||
flatten(Previous, Parameters).
|
||||
apply_parents_first([Value|Values], [E|Ev], Previous, P0, PVars, (Value=Disj*E).Formulas, Parameters) :-
|
||||
P0 = [TheseParents|End],
|
||||
apply_first_parent(PVars, Disj, TheseParents),
|
||||
apply_parents_second(Values, Ev, Previous, End, PVars, Formulas, Parameters).
|
||||
|
||||
apply_parents_second([Value], [E], Previous, [], PVars, [Value=Disj*E], Parameters) :- !,
|
||||
apply_last_parent(PVars, Previous, Disj),
|
||||
flatten(Previous, Parameters).
|
||||
apply_parents_second([Value|Values], [E|Ev], Previous, P0, PVars, (Value=Disj*E).Formulas, Parameters) :-
|
||||
apply_middle_parent(PVars, Previous, Disj, TheseParents),
|
||||
% this must be done after applying middle parents because of the var
|
||||
% test.
|
||||
P0 = [TheseParents|End],
|
||||
apply_parents_second(Values, Ev, Previous, End, PVars, Formulas, Parameters).
|
||||
|
||||
apply_first_parent([Parents], Conj, [Theta]) :- !,
|
||||
parents_to_conj(Parents,Theta,Conj).
|
||||
apply_first_parent(Parents.PVars, Conj+Disj, Theta.TheseParents) :-
|
||||
parents_to_conj(Parents,Theta,Conj),
|
||||
apply_first_parent(PVars, Disj, TheseParents).
|
||||
|
||||
apply_middle_parent([Parents], Other, Conj, [ThetaPar]) :- !,
|
||||
skim_for_theta(Other, Theta, _, ThetaPar),
|
||||
parents_to_conj(Parents,Theta,Conj).
|
||||
apply_middle_parent(Parents.PVars, Other, Conj+Disj, ThetaPar.TheseParents) :-
|
||||
skim_for_theta(Other, Theta, Remaining, ThetaPar),
|
||||
parents_to_conj(Parents,(Theta),Conj),
|
||||
apply_middle_parent(PVars, Remaining, Disj, TheseParents).
|
||||
|
||||
apply_last_parent([Parents], Other, Conj) :- !,
|
||||
parents_to_conj(Parents,(Theta),Conj),
|
||||
skim_for_theta(Other, Theta, _, _).
|
||||
apply_last_parent(Parents.PVars, Other, Conj+Disj) :-
|
||||
parents_to_conj(Parents,(Theta),Conj),
|
||||
skim_for_theta(Other, Theta, Remaining, _),
|
||||
apply_last_parent(PVars, Remaining, Disj).
|
||||
|
||||
%
|
||||
%
|
||||
% simplify stuff, removing process that is cancelled by 0s
|
||||
%
|
||||
parents_to_conj([], Theta, Theta) :- !.
|
||||
parents_to_conj(Ps, Theta, Theta*Conj) :-
|
||||
parents_to_conj2(Ps, Conj).
|
||||
|
||||
parents_to_conj2([P],P) :- !.
|
||||
parents_to_conj2(P.Ps,P*Conj) :-
|
||||
parents_to_conj2(Ps,Conj).
|
||||
|
||||
%
|
||||
% first case we haven't reached the end of the list so we need
|
||||
% to create a new parameter variable
|
||||
%
|
||||
skim_for_theta([[P|Other]|V], not(P)*New, [Other|_], New) :- var(V), !.
|
||||
%
|
||||
% last theta, it is just negation of the other ones
|
||||
%
|
||||
skim_for_theta([[P|Other]], not(P), [Other], _) :- !.
|
||||
%
|
||||
% recursive case, build-up
|
||||
%
|
||||
skim_for_theta([[P|Other]|More], not(P)*Ps, [Other|Left], New ) :-
|
||||
skim_for_theta(More, Ps, Left, New ).
|
||||
|
||||
get_evidence(V, Tree, Ev, F0, F, Leaves, Finals) :-
|
||||
clpbn:get_atts(V, [evidence(Pos)]), !,
|
||||
zero_pos(0, Pos, Ev),
|
||||
insert_output(Leaves, V, Finals, Tree, Outs, SendOut),
|
||||
get_outs(F0, F, SendOut, Outs).
|
||||
% hidden deterministic node, can be removed.
|
||||
get_evidence(V, _Tree, Ev, F0, [], _Leaves, _Finals) :-
|
||||
clpbn:get_atts(V, [key(K)]),
|
||||
functor(K, Name, 2),
|
||||
( Name = 'AVG' ; Name = 'MAX' ; Name = 'MIN' ),
|
||||
!,
|
||||
one_list(Ev),
|
||||
eval_outs(F0).
|
||||
%% no evidence !!!
|
||||
get_evidence(V, Tree, _Values, F0, F1, Leaves, Finals) :-
|
||||
insert_output(Leaves, V, Finals, Tree, Outs, SendOut),
|
||||
get_outs(F0, F1, SendOut, Outs).
|
||||
|
||||
zero_pos(_, _Pos, []).
|
||||
zero_pos(Pos, Pos, 1.Values) :- !,
|
||||
I is Pos+1,
|
||||
zero_pos(I, Pos, Values).
|
||||
zero_pos(I0, Pos, 0.Values) :-
|
||||
I is I0+1,
|
||||
zero_pos(I, Pos, Values).
|
||||
|
||||
one_list([]).
|
||||
one_list(1.Ev) :-
|
||||
one_list(Ev).
|
||||
|
||||
%
|
||||
% insert a node with the disj of all alternatives, this is only done if node ends up to be in the output
|
||||
%
|
||||
insert_output([], _V, [], _Out, _Outs, []).
|
||||
insert_output(V._Leaves, V0, [Top|_], Top, Outs, [Top = Outs]) :- V == V0, !.
|
||||
insert_output(_.Leaves, V, _.Finals, Top, Outs, SendOut) :-
|
||||
insert_output(Leaves, V, Finals, Top, Outs, SendOut).
|
||||
|
||||
|
||||
get_outs([V=F], [V=NF|End], End, V) :- !,
|
||||
% writeln(f0:F),
|
||||
simplify_exp(F,NF).
|
||||
get_outs((V=F).Outs, (V=NF).NOuts, End, (F0 + V)) :-
|
||||
% writeln(f0:F),
|
||||
simplify_exp(F,NF),
|
||||
get_outs(Outs, NOuts, End, F0).
|
||||
|
||||
eval_outs([]).
|
||||
eval_outs((V=F).Outs) :-
|
||||
simplify_exp(F,NF),
|
||||
V = NF,
|
||||
eval_outs(Outs).
|
||||
|
||||
run_bdd_solver([[V]], LPs, bdd(Term, _Leaves, Nodes)) :-
|
||||
build_out_node(Nodes, Node),
|
||||
findall(Prob, get_prob(Term, Node, V, Prob),TermProbs),
|
||||
sumlist(TermProbs, Sum),
|
||||
writeln(TermProbs:Sum),
|
||||
normalise(TermProbs, Sum, LPs).
|
||||
|
||||
build_out_node([_Top], []).
|
||||
build_out_node([T,T1|Tops], [Top = T*Top]) :-
|
||||
build_out_node2(T1.Tops, Top).
|
||||
|
||||
build_out_node2([Top], Top).
|
||||
build_out_node2([T,T1|Tops], T*Top) :-
|
||||
build_out_node2(T1.Tops, Top).
|
||||
|
||||
|
||||
get_prob(Term, Node, V, SP) :-
|
||||
bind_all(Term, Node, Bindings, V, AllParms, AllParmValues),
|
||||
% reverse(AllParms, RAllParms),
|
||||
term_variables(AllParms, NVs),
|
||||
build_bdd(Bindings, NVs, AllParms, AllParmValues, Bdd),
|
||||
bdd_to_probability_sum_product(Bdd, SP),
|
||||
bdd_close(Bdd).
|
||||
|
||||
build_bdd(Bindings, NVs, VTheta, Theta, Bdd) :-
|
||||
bdd_from_list(Bindings, NVs, Bdd),
|
||||
bdd_size(Bdd, Len),
|
||||
number_codes(Len,Codes),
|
||||
atom_codes(Name,Codes),
|
||||
bdd_print(Bdd, Name),
|
||||
writeln(length=Len),
|
||||
VTheta = Theta.
|
||||
|
||||
bind_all([], End, End, _V, [], []).
|
||||
bind_all(info(V, _Tree, Ev, _Values, Formula, ParmVars, Parms).Term, End, BindsF, V0, ParmVars.AllParms, Parms.AllTheta) :-
|
||||
V0 == V, !,
|
||||
set_to_one_zeros(Ev),
|
||||
bind_formula(Formula, BindsF, BindsI),
|
||||
bind_all(Term, End, BindsI, V0, AllParms, AllTheta).
|
||||
bind_all(info(_V, _Tree, Ev, _Values, Formula, ParmVars, Parms).Term, End, BindsF, V0, ParmVars.AllParms, Parms.AllTheta) :-
|
||||
set_to_ones(Ev),!,
|
||||
bind_formula(Formula, BindsF, BindsI),
|
||||
bind_all(Term, End, BindsI, V0, AllParms, AllTheta).
|
||||
% evidence: no need to add any stuff.
|
||||
bind_all(info(_V, _Tree, _Ev, _Values, Formula, ParmVars, Parms).Term, End, BindsF, V0, ParmVars.AllParms, Parms.AllTheta) :-
|
||||
bind_formula(Formula, BindsF, BindsI),
|
||||
bind_all(Term, End, BindsI, V0, AllParms, AllTheta).
|
||||
|
||||
bind_formula([], L, L).
|
||||
bind_formula(B.Formula, B.BsF, Bs0) :-
|
||||
bind_formula(Formula, BsF, Bs0).
|
||||
|
||||
set_to_one_zeros([1|Values]) :-
|
||||
set_to_zeros(Values).
|
||||
set_to_one_zeros([0|Values]) :-
|
||||
set_to_one_zeros(Values).
|
||||
|
||||
set_to_zeros([]).
|
||||
set_to_zeros(0.Values) :-
|
||||
set_to_zeros(Values).
|
||||
|
||||
set_to_ones([]).
|
||||
set_to_ones(1.Values) :-
|
||||
set_to_ones(Values).
|
||||
|
||||
normalise([], _Sum, []).
|
||||
normalise(P.TermProbs, Sum, NP.LPs) :-
|
||||
NP is P/Sum,
|
||||
normalise(TermProbs, Sum, LPs).
|
||||
|
||||
finalize_bdd_solver(_).
|
||||
|
@ -1,16 +1,16 @@
|
||||
|
||||
/************************************************
|
||||
/*******************************************************
|
||||
|
||||
Belief Propagation in CLP(BN)
|
||||
Belief Propagation and Variable Elimination Interface
|
||||
|
||||
**************************************************/
|
||||
********************************************************/
|
||||
|
||||
:- module(clpbn_bp,
|
||||
[bp/3,
|
||||
check_if_bp_done/1,
|
||||
set_horus_flag/2,
|
||||
init_bp_solver/4,
|
||||
run_bp_solver/3,
|
||||
call_bp_ground/6,
|
||||
finalize_bp_solver/1
|
||||
]).
|
||||
|
||||
@ -24,154 +24,143 @@
|
||||
|
||||
|
||||
:- use_module(library('clpbn/display'),
|
||||
[clpbn_bind_vals/3]).
|
||||
[clpbn_bind_vals/3]).
|
||||
|
||||
|
||||
:- use_module(library('clpbn/aggregates'),
|
||||
[check_for_agg_vars/2]).
|
||||
[check_for_agg_vars/2]).
|
||||
|
||||
|
||||
:- use_module(library(charsio),
|
||||
[term_to_atom/2]).
|
||||
|
||||
|
||||
:- use_module(library(pfl),
|
||||
[skolem/2,
|
||||
get_pfl_parameters/2
|
||||
]).
|
||||
|
||||
|
||||
:- use_module(library(lists)).
|
||||
|
||||
:- use_module(library(atts)).
|
||||
:- use_module(library(lists)).
|
||||
:- use_module(library(charsio)).
|
||||
|
||||
:- load_foreign_files(['horus'], [], init_predicates).
|
||||
|
||||
:- attribute id/1.
|
||||
:- use_module(library(bhash)).
|
||||
|
||||
|
||||
%:- set_horus_flag(inf_alg, ve).
|
||||
:- set_horus_flag(inf_alg, bn_bp).
|
||||
%:- set_horus_flag(inf_alg, fg_bp).
|
||||
%: -set_horus_flag(inf_alg, cbp).
|
||||
:- use_module(horus,
|
||||
[create_ground_network/4,
|
||||
set_factors_params/2,
|
||||
run_ground_solver/3,
|
||||
set_vars_information/2,
|
||||
free_ground_network/1
|
||||
]).
|
||||
|
||||
:- set_horus_flag(schedule, seq_fixed).
|
||||
%:- set_horus_flag(schedule, seq_random).
|
||||
%:- set_horus_flag(schedule, parallel).
|
||||
%:- set_horus_flag(schedule, max_residual).
|
||||
|
||||
:- set_horus_flag(accuracy, 0.0001).
|
||||
call_bp_ground(QueryVars, QueryKeys, AllKeys, Factors, Evidence, Output) :-
|
||||
writeln(here:Factors),
|
||||
b_hash_new(Hash0),
|
||||
keys_to_ids(AllKeys, 0, Hash0, Hash),
|
||||
get_factors_type(Factors, Type),
|
||||
evidence_to_ids(Evidence, Hash, EvidenceIds),
|
||||
factors_to_ids(Factors, Hash, FactorIds),
|
||||
writeln(type:Type), writeln(''),
|
||||
writeln(allKeys:AllKeys), writeln(''),
|
||||
writeln(factors:Factors), writeln(''),
|
||||
writeln(factorIds:FactorIds), writeln(''),
|
||||
writeln(evidence:Evidence), writeln(''),
|
||||
writeln(evidenceIds:EvidenceIds), writeln(''),
|
||||
create_ground_network(Type, FactorIds, EvidenceIds, Network),
|
||||
%get_vars_information(AllKeys, StatesNames),
|
||||
%set_vars_information(AllKeys, StatesNames),
|
||||
run_solver(ground(Network,Hash), QueryKeys, Solutions),
|
||||
writeln(answer:Solutions),
|
||||
clpbn_bind_vals([QueryVars], Solutions, Output),
|
||||
free_ground_network(Network).
|
||||
|
||||
:- set_horus_flag(max_iter, 1000).
|
||||
|
||||
:- set_horus_flag(use_logarithms, false).
|
||||
%:- set_horus_flag(use_logarithms, true).
|
||||
run_solver(ground(Network,Hash), QueryKeys, Solutions) :-
|
||||
%get_dists_parameters(DistIds, DistsParams),
|
||||
%set_factors_params(Network, DistsParams),
|
||||
list_of_keys_to_ids(QueryKeys, Hash, QueryIds),
|
||||
writeln(queryKeys:QueryKeys), writeln(''),
|
||||
writeln(queryIds:QueryIds), writeln(''),
|
||||
list_of_keys_to_ids(QueryKeys, Hash, QueryIds),
|
||||
run_ground_solver(Network, [QueryIds], Solutions).
|
||||
|
||||
:- set_horus_flag(order_factor_variables, false).
|
||||
%:- set_horus_flag(order_factor_variables, true).
|
||||
|
||||
keys_to_ids([], _, Hash, Hash).
|
||||
keys_to_ids([Key|AllKeys], I0, Hash0, Hash) :-
|
||||
b_hash_insert(Hash0, Key, I0, HashI),
|
||||
I is I0+1,
|
||||
keys_to_ids(AllKeys, I, HashI, Hash).
|
||||
|
||||
|
||||
get_factors_type([f(bayes, _, _, _)|_], bayes) :- ! .
|
||||
get_factors_type([f(markov, _, _, _)|_], markov) :- ! .
|
||||
|
||||
|
||||
list_of_keys_to_ids([], _, []).
|
||||
list_of_keys_to_ids([Key|QueryKeys], Hash, [Id|QueryIds]) :-
|
||||
b_hash_lookup(Key, Id, Hash),
|
||||
list_of_keys_to_ids(QueryKeys, Hash, QueryIds).
|
||||
|
||||
|
||||
factors_to_ids([], _, []).
|
||||
factors_to_ids([f(_, DistId, Keys, CPT)|Fs], Hash, [f(Ids, Ranges, CPT, DistId)|NFs]) :-
|
||||
list_of_keys_to_ids(Keys, Hash, Ids),
|
||||
get_ranges(Keys, Ranges),
|
||||
factors_to_ids(Fs, Hash, NFs).
|
||||
|
||||
|
||||
get_ranges([],[]).
|
||||
get_ranges(K.Ks, Range.Rs) :- !,
|
||||
skolem(K,Domain),
|
||||
length(Domain,Range),
|
||||
get_ranges(Ks, Rs).
|
||||
|
||||
|
||||
evidence_to_ids([], _, []).
|
||||
evidence_to_ids([Key=Ev|QueryKeys], Hash, [Id=Ev|QueryIds]) :-
|
||||
b_hash_lookup(Key, Id, Hash),
|
||||
evidence_to_ids(QueryKeys, Hash, QueryIds).
|
||||
|
||||
|
||||
get_vars_information([], []).
|
||||
get_vars_information(Key.QueryKeys, Domain.StatesNames) :-
|
||||
pfl:skolem(Key, Domain),
|
||||
get_vars_information(QueryKeys, StatesNames).
|
||||
|
||||
|
||||
finalize_bp_solver(bp(Network, _)) :-
|
||||
free_ground_network(Network).
|
||||
|
||||
|
||||
bp([[]],_,_) :- !.
|
||||
bp([QueryVars], AllVars, Output) :-
|
||||
init_bp_solver(_, AllVars, _, Network),
|
||||
run_bp_solver([QueryVars], LPs, Network),
|
||||
finalize_bp_solver(Network),
|
||||
clpbn_bind_vals([QueryVars], LPs, Output).
|
||||
init_bp_solver(_, AllVars, _, Network),
|
||||
run_bp_solver([QueryVars], LPs, Network),
|
||||
finalize_bp_solver(Network),
|
||||
clpbn_bind_vals([QueryVars], LPs, Output).
|
||||
|
||||
|
||||
init_bp_solver(_, AllVars0, _, bp(BayesNet, DistIds)) :-
|
||||
check_for_agg_vars(AllVars0, AllVars),
|
||||
writeln('clpbn_vars:'),
|
||||
print_clpbn_vars(AllVars),
|
||||
assign_ids(AllVars, 0),
|
||||
get_vars_info(AllVars, VarsInfo, DistIds0),
|
||||
sort(DistIds0, DistIds),
|
||||
create_ground_network(VarsInfo, BayesNet).
|
||||
%get_extra_vars_info(AllVars, ExtraVarsInfo),
|
||||
%set_extra_vars_info(BayesNet, ExtraVarsInfo).
|
||||
%check_for_agg_vars(AllVars0, AllVars),
|
||||
get_vars_info(AllVars0, VarsInfo, DistIds0),
|
||||
sort(DistIds0, DistIds),
|
||||
create_ground_network(VarsInfo, BayesNet),
|
||||
true.
|
||||
|
||||
|
||||
run_bp_solver(QueryVars, Solutions, bp(Network, DistIds)) :-
|
||||
get_dists_parameters(DistIds, DistsParams),
|
||||
set_bayes_net_params(Network, DistsParams),
|
||||
flatten_1_element_sublists(QueryVars, QueryVars1),
|
||||
vars_to_ids(QueryVars1, QueryVarsIds),
|
||||
run_other_solvers(Network, QueryVarsIds, Solutions).
|
||||
|
||||
|
||||
finalize_bp_solver(bp(Network, _)) :-
|
||||
free_bayesian_network(Network).
|
||||
|
||||
|
||||
assign_ids([], _).
|
||||
assign_ids([V|Vs], Count) :-
|
||||
put_atts(V, [id(Count)]),
|
||||
Count1 is Count + 1,
|
||||
assign_ids(Vs, Count1).
|
||||
|
||||
|
||||
get_vars_info([], [], []).
|
||||
get_vars_info(V.Vs,
|
||||
var(VarId,DS,Ev,PIds,DistId).VarsInfo,
|
||||
DistId.DistIds) :-
|
||||
clpbn:get_atts(V, [dist(DistId, Parents)]), !,
|
||||
get_atts(V, [id(VarId)]),
|
||||
get_dist_domain_size(DistId, DS),
|
||||
get_evidence(V, Ev),
|
||||
vars_to_ids(Parents, PIds),
|
||||
get_vars_info(Vs, VarsInfo, DistIds).
|
||||
|
||||
|
||||
get_evidence(V, Ev) :-
|
||||
clpbn:get_atts(V, [evidence(Ev)]), !.
|
||||
get_evidence(_V, -1). % no evidence !!!
|
||||
|
||||
|
||||
vars_to_ids([], []).
|
||||
vars_to_ids([L|Vars], [LIds|Ids]) :-
|
||||
is_list(L), !,
|
||||
vars_to_ids(L, LIds),
|
||||
vars_to_ids(Vars, Ids).
|
||||
vars_to_ids([V|Vars], [VarId|Ids]) :-
|
||||
get_atts(V, [id(VarId)]),
|
||||
vars_to_ids(Vars, Ids).
|
||||
|
||||
|
||||
get_extra_vars_info([], []).
|
||||
get_extra_vars_info([V|Vs], [v(VarId, Label, Domain)|VarsInfo]) :-
|
||||
get_atts(V, [id(VarId)]), !,
|
||||
clpbn:get_atts(V, [key(Key),dist(DistId, _)]),
|
||||
term_to_atom(Key, Label),
|
||||
get_dist_domain(DistId, Domain0),
|
||||
numbers_to_atoms(Domain0, Domain),
|
||||
get_extra_vars_info(Vs, VarsInfo).
|
||||
get_extra_vars_info([_|Vs], VarsInfo) :-
|
||||
get_extra_vars_info(Vs, VarsInfo).
|
||||
get_dists_parameters(DistIds, DistsParams),
|
||||
set_factors_params(Network, DistsParams),
|
||||
vars_to_ids(QueryVars, QueryVarsIds),
|
||||
run_ground_solver(Network, QueryVarsIds, Solutions).
|
||||
|
||||
|
||||
get_dists_parameters([],[]).
|
||||
get_dists_parameters([Id|Ids], [dist(Id, Params)|DistsInfo]) :-
|
||||
get_dist_params(Id, Params),
|
||||
get_dists_parameters(Ids, DistsInfo).
|
||||
|
||||
|
||||
numbers_to_atoms([], []).
|
||||
numbers_to_atoms([Atom|L0], [Atom|L]) :-
|
||||
atom(Atom), !,
|
||||
numbers_to_atoms(L0, L).
|
||||
numbers_to_atoms([Number|L0], [Atom|L]) :-
|
||||
number_atom(Number, Atom),
|
||||
numbers_to_atoms(L0, L).
|
||||
|
||||
|
||||
flatten_1_element_sublists([],[]).
|
||||
flatten_1_element_sublists([[H|[]]|T],[H|R]) :- !,
|
||||
flatten_1_element_sublists(T,R).
|
||||
flatten_1_element_sublists([H|T],[H|R]) :-
|
||||
flatten_1_element_sublists(T,R).
|
||||
|
||||
|
||||
print_clpbn_vars(Var.AllVars) :-
|
||||
clpbn:get_atts(Var, [key(Key),dist(DistId,Parents)]),
|
||||
parents_to_keys(Parents, ParentKeys),
|
||||
writeln(Var:Key:ParentKeys:DistId),
|
||||
print_clpbn_vars(AllVars).
|
||||
print_clpbn_vars([]).
|
||||
|
||||
|
||||
parents_to_keys([], []).
|
||||
parents_to_keys(Var.Parents, Key.Keys) :-
|
||||
clpbn:get_atts(Var, [key(Key)]),
|
||||
parents_to_keys(Parents, Keys).
|
||||
get_dist_params(Id, Params),
|
||||
get_dists_parameters(Ids, DistsInfo).
|
||||
|
||||
|
77
packages/CLPBN/clpbn/bp/BayesBall.cpp
Normal file
77
packages/CLPBN/clpbn/bp/BayesBall.cpp
Normal file
@ -0,0 +1,77 @@
|
||||
#include <cstdlib>
|
||||
#include <cassert>
|
||||
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
|
||||
#include "BayesBall.h"
|
||||
#include "Util.h"
|
||||
|
||||
|
||||
|
||||
FactorGraph*
|
||||
BayesBall::getMinimalFactorGraph (const VarIds& queryIds)
|
||||
{
|
||||
assert (fg_.isFromBayesNetwork());
|
||||
|
||||
Scheduling scheduling;
|
||||
for (unsigned i = 0; i < queryIds.size(); i++) {
|
||||
assert (dag_.getNode (queryIds[i]));
|
||||
DAGraphNode* n = dag_.getNode (queryIds[i]);
|
||||
scheduling.push (ScheduleInfo (n, false, true));
|
||||
}
|
||||
|
||||
while (!scheduling.empty()) {
|
||||
ScheduleInfo& sch = scheduling.front();
|
||||
DAGraphNode* n = sch.node;
|
||||
n->setAsVisited();
|
||||
if (n->hasEvidence() == false && sch.visitedFromChild) {
|
||||
if (n->isMarkedOnTop() == false) {
|
||||
n->markOnTop();
|
||||
scheduleParents (n, scheduling);
|
||||
}
|
||||
if (n->isMarkedOnBottom() == false) {
|
||||
n->markOnBottom();
|
||||
scheduleChilds (n, scheduling);
|
||||
}
|
||||
}
|
||||
if (sch.visitedFromParent) {
|
||||
if (n->hasEvidence() && n->isMarkedOnTop() == false) {
|
||||
n->markOnTop();
|
||||
scheduleParents (n, scheduling);
|
||||
}
|
||||
if (n->hasEvidence() == false && n->isMarkedOnBottom() == false) {
|
||||
n->markOnBottom();
|
||||
scheduleChilds (n, scheduling);
|
||||
}
|
||||
}
|
||||
scheduling.pop();
|
||||
}
|
||||
|
||||
FactorGraph* fg = new FactorGraph();
|
||||
constructGraph (fg);
|
||||
return fg;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BayesBall::constructGraph (FactorGraph* fg) const
|
||||
{
|
||||
const FacNodes& facNodes = fg_.facNodes();
|
||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||
const DAGraphNode* n = dag_.getNode (
|
||||
facNodes[i]->factor().argument (0));
|
||||
if (n->isMarkedOnTop()) {
|
||||
fg->addFactor (Factor (facNodes[i]->factor()));
|
||||
} else if (n->hasEvidence() && n->isVisited()) {
|
||||
VarIds varIds = { facNodes[i]->factor().argument (0) };
|
||||
Ranges ranges = { facNodes[i]->factor().range (0) };
|
||||
Params params (ranges[0], LogAware::noEvidence());
|
||||
params[n->getEvidence()] = LogAware::withEvidence();
|
||||
fg->addFactor (Factor (varIds, ranges, params));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
85
packages/CLPBN/clpbn/bp/BayesBall.h
Normal file
85
packages/CLPBN/clpbn/bp/BayesBall.h
Normal file
@ -0,0 +1,85 @@
|
||||
#ifndef HORUS_BAYESBALL_H
|
||||
#define HORUS_BAYESBALL_H
|
||||
|
||||
#include <vector>
|
||||
#include <queue>
|
||||
#include <list>
|
||||
#include <map>
|
||||
|
||||
#include "FactorGraph.h"
|
||||
#include "BayesNet.h"
|
||||
#include "Horus.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
|
||||
struct ScheduleInfo
|
||||
{
|
||||
ScheduleInfo (DAGraphNode* n, bool vfp, bool vfc) :
|
||||
node(n), visitedFromParent(vfp), visitedFromChild(vfc) { }
|
||||
|
||||
DAGraphNode* node;
|
||||
bool visitedFromParent;
|
||||
bool visitedFromChild;
|
||||
};
|
||||
|
||||
|
||||
typedef queue<ScheduleInfo, list<ScheduleInfo>> Scheduling;
|
||||
|
||||
|
||||
class BayesBall
|
||||
{
|
||||
public:
|
||||
BayesBall (FactorGraph& fg)
|
||||
: fg_(fg) , dag_(fg.getStructure())
|
||||
{
|
||||
dag_.clear();
|
||||
}
|
||||
|
||||
FactorGraph* getMinimalFactorGraph (const VarIds&);
|
||||
|
||||
static FactorGraph* getMinimalFactorGraph (FactorGraph& fg, VarIds vids)
|
||||
{
|
||||
BayesBall bb (fg);
|
||||
return bb.getMinimalFactorGraph (vids);
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
void constructGraph (FactorGraph* fg) const;
|
||||
|
||||
void scheduleParents (const DAGraphNode* n, Scheduling& sch) const;
|
||||
|
||||
void scheduleChilds (const DAGraphNode* n, Scheduling& sch) const;
|
||||
|
||||
FactorGraph& fg_;
|
||||
|
||||
DAGraph& dag_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
inline void
|
||||
BayesBall::scheduleParents (const DAGraphNode* n, Scheduling& sch) const
|
||||
{
|
||||
const vector<DAGraphNode*>& ps = n->parents();
|
||||
for (vector<DAGraphNode*>::const_iterator it = ps.begin();
|
||||
it != ps.end(); it++) {
|
||||
sch.push (ScheduleInfo (*it, false, true));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline void
|
||||
BayesBall::scheduleChilds (const DAGraphNode* n, Scheduling& sch) const
|
||||
{
|
||||
const vector<DAGraphNode*>& cs = n->childs();
|
||||
for (vector<DAGraphNode*>::const_iterator it = cs.begin();
|
||||
it != cs.end(); it++) {
|
||||
sch.push (ScheduleInfo (*it, true, false));
|
||||
}
|
||||
}
|
||||
|
||||
#endif // HORUS_BAYESBALL_H
|
||||
|
@ -5,381 +5,57 @@
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
|
||||
#include "xmlParser/xmlParser.h"
|
||||
|
||||
#include "BayesNet.h"
|
||||
#include "Util.h"
|
||||
|
||||
|
||||
|
||||
BayesNet::~BayesNet (void)
|
||||
void
|
||||
DAGraph::addNode (DAGraphNode* n)
|
||||
{
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
delete nodes_[i];
|
||||
}
|
||||
assert (Util::contains (varMap_, n->varId()) == false);
|
||||
nodes_.push_back (n);
|
||||
varMap_[n->varId()] = n;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BayesNet::readFromBifFormat (const char* fileName)
|
||||
DAGraph::addEdge (VarId vid1, VarId vid2)
|
||||
{
|
||||
XMLNode xMainNode = XMLNode::openFileHelper (fileName, "BIF");
|
||||
// only the first network is parsed, others are ignored
|
||||
XMLNode xNode = xMainNode.getChildNode ("NETWORK");
|
||||
unsigned nVars = xNode.nChildNode ("VARIABLE");
|
||||
for (unsigned i = 0; i < nVars; i++) {
|
||||
XMLNode var = xNode.getChildNode ("VARIABLE", i);
|
||||
if (string (var.getAttribute ("TYPE")) != "nature") {
|
||||
cerr << "error: only \"nature\" variables are supported" << endl;
|
||||
abort();
|
||||
}
|
||||
States states;
|
||||
string label = var.getChildNode("NAME").getText();
|
||||
unsigned nrStates = var.nChildNode ("OUTCOME");
|
||||
for (unsigned j = 0; j < nrStates; j++) {
|
||||
if (var.getChildNode("OUTCOME", j).getText() == 0) {
|
||||
stringstream ss;
|
||||
ss << j + 1;
|
||||
states.push_back (ss.str());
|
||||
} else {
|
||||
states.push_back (var.getChildNode("OUTCOME", j).getText());
|
||||
}
|
||||
}
|
||||
addNode (label, states);
|
||||
}
|
||||
|
||||
unsigned nDefs = xNode.nChildNode ("DEFINITION");
|
||||
if (nVars != nDefs) {
|
||||
cerr << "error: different number of variables and definitions" << endl;
|
||||
abort();
|
||||
}
|
||||
for (unsigned i = 0; i < nDefs; i++) {
|
||||
XMLNode def = xNode.getChildNode ("DEFINITION", i);
|
||||
string label = def.getChildNode("FOR").getText();
|
||||
BayesNode* node = getBayesNode (label);
|
||||
if (!node) {
|
||||
cerr << "error: unknow variable `" << label << "'" << endl;
|
||||
abort();
|
||||
}
|
||||
BnNodeSet parents;
|
||||
unsigned nParams = node->nrStates();
|
||||
for (int j = 0; j < def.nChildNode ("GIVEN"); j++) {
|
||||
string parentLabel = def.getChildNode("GIVEN", j).getText();
|
||||
BayesNode* parentNode = getBayesNode (parentLabel);
|
||||
if (!parentNode) {
|
||||
cerr << "error: unknow variable `" << parentLabel << "'" << endl;
|
||||
abort();
|
||||
}
|
||||
nParams *= parentNode->nrStates();
|
||||
parents.push_back (parentNode);
|
||||
}
|
||||
node->setParents (parents);
|
||||
unsigned count = 0;
|
||||
Params params (nParams);
|
||||
stringstream s (def.getChildNode("TABLE").getText());
|
||||
while (!s.eof() && count < nParams) {
|
||||
s >> params[count];
|
||||
count ++;
|
||||
}
|
||||
if (count != nParams) {
|
||||
cerr << "error: invalid number of parameters " ;
|
||||
cerr << "for variable `" << label << "'" << endl;
|
||||
abort();
|
||||
}
|
||||
params = reorderParameters (params, node->nrStates());
|
||||
Distribution* dist = new Distribution (params);
|
||||
node->setDistribution (dist);
|
||||
addDistribution (dist);
|
||||
}
|
||||
|
||||
setIndexes();
|
||||
if (Globals::logDomain) {
|
||||
distributionsToLogs();
|
||||
}
|
||||
unordered_map<VarId, DAGraphNode*>::iterator it1;
|
||||
unordered_map<VarId, DAGraphNode*>::iterator it2;
|
||||
it1 = varMap_.find (vid1);
|
||||
it2 = varMap_.find (vid2);
|
||||
assert (it1 != varMap_.end());
|
||||
assert (it2 != varMap_.end());
|
||||
it1->second->addChild (it2->second);
|
||||
it2->second->addParent (it1->second);
|
||||
}
|
||||
|
||||
|
||||
|
||||
BayesNode*
|
||||
BayesNet::addNode (string label, const States& states)
|
||||
const DAGraphNode*
|
||||
DAGraph::getNode (VarId vid) const
|
||||
{
|
||||
VarId vid = nodes_.size();
|
||||
varMap_.insert (make_pair (vid, nodes_.size()));
|
||||
GraphicalModel::addVariableInformation (vid, label, states);
|
||||
BayesNode* node = new BayesNode (VarNode (vid, states.size()));
|
||||
nodes_.push_back (node);
|
||||
return node;
|
||||
unordered_map<VarId, DAGraphNode*>::const_iterator it;
|
||||
it = varMap_.find (vid);
|
||||
return it != varMap_.end() ? it->second : 0;
|
||||
}
|
||||
|
||||
|
||||
|
||||
BayesNode*
|
||||
BayesNet::addNode (VarId vid, unsigned dsize, int evidence, Distribution* dist)
|
||||
DAGraphNode*
|
||||
DAGraph::getNode (VarId vid)
|
||||
{
|
||||
varMap_.insert (make_pair (vid, nodes_.size()));
|
||||
nodes_.push_back (new BayesNode (vid, dsize, evidence, dist));
|
||||
return nodes_.back();
|
||||
}
|
||||
|
||||
|
||||
|
||||
BayesNode*
|
||||
BayesNet::getBayesNode (VarId vid) const
|
||||
{
|
||||
IndexMap::const_iterator it = varMap_.find (vid);
|
||||
if (it == varMap_.end()) {
|
||||
return 0;
|
||||
} else {
|
||||
return nodes_[it->second];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
BayesNode*
|
||||
BayesNet::getBayesNode (string label) const
|
||||
{
|
||||
BayesNode* node = 0;
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
if (nodes_[i]->label() == label) {
|
||||
node = nodes_[i];
|
||||
break;
|
||||
}
|
||||
}
|
||||
return node;
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
VarNode*
|
||||
BayesNet::getVariableNode (VarId vid) const
|
||||
{
|
||||
BayesNode* node = getBayesNode (vid);
|
||||
assert (node);
|
||||
return node;
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
VarNodes
|
||||
BayesNet::getVariableNodes (void) const
|
||||
{
|
||||
VarNodes vars;
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
vars.push_back (nodes_[i]);
|
||||
}
|
||||
return vars;
|
||||
unordered_map<VarId, DAGraphNode*>::const_iterator it;
|
||||
it = varMap_.find (vid);
|
||||
return it != varMap_.end() ? it->second : 0;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BayesNet::addDistribution (Distribution* dist)
|
||||
{
|
||||
dists_.push_back (dist);
|
||||
}
|
||||
|
||||
|
||||
|
||||
Distribution*
|
||||
BayesNet::getDistribution (unsigned distId) const
|
||||
{
|
||||
Distribution* dist = 0;
|
||||
for (unsigned i = 0; i < dists_.size(); i++) {
|
||||
if (dists_[i]->id == (int) distId) {
|
||||
dist = dists_[i];
|
||||
break;
|
||||
}
|
||||
}
|
||||
return dist;
|
||||
}
|
||||
|
||||
|
||||
|
||||
const BnNodeSet&
|
||||
BayesNet::getBayesNodes (void) const
|
||||
{
|
||||
return nodes_;
|
||||
}
|
||||
|
||||
|
||||
|
||||
unsigned
|
||||
BayesNet::nrNodes (void) const
|
||||
{
|
||||
return nodes_.size();
|
||||
}
|
||||
|
||||
|
||||
|
||||
BnNodeSet
|
||||
BayesNet::getRootNodes (void) const
|
||||
{
|
||||
BnNodeSet roots;
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
if (nodes_[i]->isRoot()) {
|
||||
roots.push_back (nodes_[i]);
|
||||
}
|
||||
}
|
||||
return roots;
|
||||
}
|
||||
|
||||
|
||||
|
||||
BnNodeSet
|
||||
BayesNet::getLeafNodes (void) const
|
||||
{
|
||||
BnNodeSet leafs;
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
if (nodes_[i]->isLeaf()) {
|
||||
leafs.push_back (nodes_[i]);
|
||||
}
|
||||
}
|
||||
return leafs;
|
||||
}
|
||||
|
||||
|
||||
|
||||
BayesNet*
|
||||
BayesNet::getMinimalRequesiteNetwork (VarId vid) const
|
||||
{
|
||||
return getMinimalRequesiteNetwork (VarIds() = {vid});
|
||||
}
|
||||
|
||||
|
||||
|
||||
BayesNet*
|
||||
BayesNet::getMinimalRequesiteNetwork (const VarIds& queryVarIds) const
|
||||
{
|
||||
BnNodeSet queryVars;
|
||||
Scheduling scheduling;
|
||||
for (unsigned i = 0; i < queryVarIds.size(); i++) {
|
||||
BayesNode* n = getBayesNode (queryVarIds[i]);
|
||||
assert (n);
|
||||
queryVars.push_back (n);
|
||||
scheduling.push (ScheduleInfo (n, false, true));
|
||||
}
|
||||
|
||||
vector<StateInfo*> states (nodes_.size(), 0);
|
||||
|
||||
while (!scheduling.empty()) {
|
||||
ScheduleInfo& sch = scheduling.front();
|
||||
StateInfo* state = states[sch.node->getIndex()];
|
||||
if (!state) {
|
||||
state = new StateInfo();
|
||||
states[sch.node->getIndex()] = state;
|
||||
} else {
|
||||
state->visited = true;
|
||||
}
|
||||
if (!sch.node->hasEvidence() && sch.visitedFromChild) {
|
||||
if (!state->markedOnTop) {
|
||||
state->markedOnTop = true;
|
||||
scheduleParents (sch.node, scheduling);
|
||||
}
|
||||
if (!state->markedOnBottom) {
|
||||
state->markedOnBottom = true;
|
||||
scheduleChilds (sch.node, scheduling);
|
||||
}
|
||||
}
|
||||
if (sch.visitedFromParent) {
|
||||
if (sch.node->hasEvidence() && !state->markedOnTop) {
|
||||
state->markedOnTop = true;
|
||||
scheduleParents (sch.node, scheduling);
|
||||
}
|
||||
if (!sch.node->hasEvidence() && !state->markedOnBottom) {
|
||||
state->markedOnBottom = true;
|
||||
scheduleChilds (sch.node, scheduling);
|
||||
}
|
||||
}
|
||||
scheduling.pop();
|
||||
}
|
||||
/*
|
||||
cout << "\t\ttop\tbottom" << endl;
|
||||
cout << "variable\t\tmarked\tmarked\tvisited\tobserved" << endl;
|
||||
cout << "----------------------------------------------------------" ;
|
||||
cout << endl;
|
||||
for (unsigned i = 0; i < states.size(); i++) {
|
||||
cout << nodes_[i]->label() << ":\t\t" ;
|
||||
if (states[i]) {
|
||||
states[i]->markedOnTop ? cout << "yes\t" : cout << "no\t" ;
|
||||
states[i]->markedOnBottom ? cout << "yes\t" : cout << "no\t" ;
|
||||
states[i]->visited ? cout << "yes\t" : cout << "no\t" ;
|
||||
nodes_[i]->hasEvidence() ? cout << "yes" : cout << "no" ;
|
||||
cout << endl;
|
||||
} else {
|
||||
cout << "no\tno\tno\t" ;
|
||||
nodes_[i]->hasEvidence() ? cout << "yes" : cout << "no" ;
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
cout << endl;
|
||||
*/
|
||||
BayesNet* bn = new BayesNet();
|
||||
constructGraph (bn, states);
|
||||
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
delete states[i];
|
||||
}
|
||||
return bn;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BayesNet::constructGraph (BayesNet* bn,
|
||||
const vector<StateInfo*>& states) const
|
||||
{
|
||||
BnNodeSet mrnNodes;
|
||||
vector<VarIds> parents;
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
bool isRequired = false;
|
||||
if (states[i]) {
|
||||
isRequired = (nodes_[i]->hasEvidence() && states[i]->visited)
|
||||
||
|
||||
states[i]->markedOnTop;
|
||||
}
|
||||
if (isRequired) {
|
||||
parents.push_back (VarIds());
|
||||
if (states[i]->markedOnTop) {
|
||||
const BnNodeSet& ps = nodes_[i]->getParents();
|
||||
for (unsigned j = 0; j < ps.size(); j++) {
|
||||
parents.back().push_back (ps[j]->varId());
|
||||
}
|
||||
}
|
||||
assert (bn->getBayesNode (nodes_[i]->varId()) == 0);
|
||||
BayesNode* mrnNode = bn->addNode (nodes_[i]->varId(),
|
||||
nodes_[i]->nrStates(),
|
||||
nodes_[i]->getEvidence(),
|
||||
nodes_[i]->getDistribution());
|
||||
mrnNodes.push_back (mrnNode);
|
||||
}
|
||||
}
|
||||
for (unsigned i = 0; i < mrnNodes.size(); i++) {
|
||||
BnNodeSet ps;
|
||||
for (unsigned j = 0; j < parents[i].size(); j++) {
|
||||
assert (bn->getBayesNode (parents[i][j]) != 0);
|
||||
ps.push_back (bn->getBayesNode (parents[i][j]));
|
||||
}
|
||||
mrnNodes[i]->setParents (ps);
|
||||
}
|
||||
bn->setIndexes();
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
BayesNet::isPolyTree (void) const
|
||||
{
|
||||
return !containsUndirectedCycle();
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BayesNet::setIndexes (void)
|
||||
DAGraph::setIndexes (void)
|
||||
{
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
nodes_[i]->setIndex (i);
|
||||
@ -389,233 +65,43 @@ BayesNet::setIndexes (void)
|
||||
|
||||
|
||||
void
|
||||
BayesNet::distributionsToLogs (void)
|
||||
{
|
||||
for (unsigned i = 0; i < dists_.size(); i++) {
|
||||
Util::toLog (dists_[i]->params);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BayesNet::freeDistributions (void)
|
||||
{
|
||||
for (unsigned i = 0; i < dists_.size(); i++) {
|
||||
delete dists_[i];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BayesNet::printGraphicalModel (void) const
|
||||
DAGraph::clear (void)
|
||||
{
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
cout << *nodes_[i];
|
||||
nodes_[i]->clear();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BayesNet::exportToGraphViz (const char* fileName,
|
||||
bool showNeighborless,
|
||||
const VarIds& highlightVarIds) const
|
||||
DAGraph::exportToGraphViz (const char* fileName)
|
||||
{
|
||||
ofstream out (fileName);
|
||||
if (!out.is_open()) {
|
||||
cerr << "error: cannot open file to write at " ;
|
||||
cerr << "BayesNet::exportToDotFile()" << endl;
|
||||
cerr << "DAGraph::exportToDotFile()" << endl;
|
||||
abort();
|
||||
}
|
||||
|
||||
out << "digraph {" << endl;
|
||||
out << "ranksep=1" << endl;
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
if (showNeighborless || nodes_[i]->hasNeighbors()) {
|
||||
out << nodes_[i]->varId() ;
|
||||
if (nodes_[i]->hasEvidence()) {
|
||||
out << " [" ;
|
||||
out << "label=\"" << nodes_[i]->label() << "\"," ;
|
||||
out << "style=filled, fillcolor=yellow" ;
|
||||
out << "]" ;
|
||||
} else {
|
||||
out << " [" ;
|
||||
out << "label=\"" << nodes_[i]->label() << "\"" ;
|
||||
out << "]" ;
|
||||
}
|
||||
out << endl;
|
||||
out << nodes_[i]->varId() ;
|
||||
out << " [" ;
|
||||
out << "label=\"" << nodes_[i]->label() << "\"" ;
|
||||
if (nodes_[i]->hasEvidence()) {
|
||||
out << ",style=filled, fillcolor=yellow" ;
|
||||
}
|
||||
out << "]" << endl;
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < highlightVarIds.size(); i++) {
|
||||
BayesNode* node = getBayesNode (highlightVarIds[i]);
|
||||
if (node) {
|
||||
out << node->varId() ;
|
||||
out << " [shape=box3d]" << endl;
|
||||
} else {
|
||||
cout << "error: invalid variable id: " << highlightVarIds[i] << endl;
|
||||
abort();
|
||||
}
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
const BnNodeSet& childs = nodes_[i]->getChilds();
|
||||
const vector<DAGraphNode*>& childs = nodes_[i]->childs();
|
||||
for (unsigned j = 0; j < childs.size(); j++) {
|
||||
out << nodes_[i]->varId() << " -> " << childs[j]->varId() << " [style=bold]" << endl ;
|
||||
out << nodes_[i]->varId() << " -> " << childs[j]->varId();
|
||||
out << " [style=bold]" << endl ;
|
||||
}
|
||||
}
|
||||
|
||||
out << "}" << endl;
|
||||
out.close();
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BayesNet::exportToBifFormat (const char* fileName) const
|
||||
{
|
||||
ofstream out (fileName);
|
||||
if(!out.is_open()) {
|
||||
cerr << "error: cannot open file to write at " ;
|
||||
cerr << "BayesNet::exportToBifFile()" << endl;
|
||||
abort();
|
||||
}
|
||||
out << "<?xml version=\"1.0\" encoding=\"US-ASCII\"?>" << endl;
|
||||
out << "<BIF VERSION=\"0.3\">" << endl;
|
||||
out << "<NETWORK>" << endl;
|
||||
out << "<NAME>" << fileName << "</NAME>" << endl << endl;
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
out << "<VARIABLE TYPE=\"nature\">" << endl;
|
||||
out << "\t<NAME>" << nodes_[i]->label() << "</NAME>" << endl;
|
||||
const States& states = nodes_[i]->states();
|
||||
for (unsigned j = 0; j < states.size(); j++) {
|
||||
out << "\t<OUTCOME>" << states[j] << "</OUTCOME>" << endl;
|
||||
}
|
||||
out << "</VARIABLE>" << endl << endl;
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
out << "<DEFINITION>" << endl;
|
||||
out << "\t<FOR>" << nodes_[i]->label() << "</FOR>" << endl;
|
||||
const BnNodeSet& parents = nodes_[i]->getParents();
|
||||
for (unsigned j = 0; j < parents.size(); j++) {
|
||||
out << "\t<GIVEN>" << parents[j]->label();
|
||||
out << "</GIVEN>" << endl;
|
||||
}
|
||||
Params params = revertParameterReorder (nodes_[i]->getParameters(),
|
||||
nodes_[i]->nrStates());
|
||||
out << "\t<TABLE>" ;
|
||||
for (unsigned j = 0; j < params.size(); j++) {
|
||||
out << " " << params[j];
|
||||
}
|
||||
out << " </TABLE>" << endl;
|
||||
out << "</DEFINITION>" << endl << endl;
|
||||
}
|
||||
out << "</NETWORK>" << endl;
|
||||
out << "</BIF>" << endl << endl;
|
||||
out.close();
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
BayesNet::containsUndirectedCycle (void) const
|
||||
{
|
||||
vector<bool> visited (nodes_.size(), false);
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
int v = nodes_[i]->getIndex();
|
||||
if (!visited[v]) {
|
||||
if (containsUndirectedCycle (v, -1, visited)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
BayesNet::containsUndirectedCycle (int v, int p, vector<bool>& visited) const
|
||||
{
|
||||
visited[v] = true;
|
||||
vector<int> adjacencies = getAdjacentNodes (v);
|
||||
for (unsigned i = 0; i < adjacencies.size(); i++) {
|
||||
int w = adjacencies[i];
|
||||
if (!visited[w]) {
|
||||
if (containsUndirectedCycle (w, v, visited)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
else if (visited[w] && w != p) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false; // no cycle detected in this component
|
||||
}
|
||||
|
||||
|
||||
|
||||
vector<int>
|
||||
BayesNet::getAdjacentNodes (int v) const
|
||||
{
|
||||
vector<int> adjacencies;
|
||||
const BnNodeSet& parents = nodes_[v]->getParents();
|
||||
const BnNodeSet& childs = nodes_[v]->getChilds();
|
||||
for (unsigned i = 0; i < parents.size(); i++) {
|
||||
adjacencies.push_back (parents[i]->getIndex());
|
||||
}
|
||||
for (unsigned i = 0; i < childs.size(); i++) {
|
||||
adjacencies.push_back (childs[i]->getIndex());
|
||||
}
|
||||
return adjacencies;
|
||||
}
|
||||
|
||||
|
||||
|
||||
Params
|
||||
BayesNet::reorderParameters (const Params& params, unsigned dsize) const
|
||||
{
|
||||
// the interchange format for bayesian networks keeps the probabilities
|
||||
// in the following order:
|
||||
// p(a1|b1,c1) p(a2|b1,c1) p(a1|b1,c2) p(a2|b1,c2) p(a1|b2,c1) p(a2|b2,c1)
|
||||
// p(a1|b2,c2) p(a2|b2,c2).
|
||||
//
|
||||
// however, in clpbn we keep the probabilities in this order:
|
||||
// p(a1|b1,c1) p(a1|b1,c2) p(a1|b2,c1) p(a1|b2,c2) p(a2|b1,c1) p(a2|b1,c2)
|
||||
// p(a2|b2,c1) p(a2|b2,c2).
|
||||
unsigned count = 0;
|
||||
unsigned rowSize = params.size() / dsize;
|
||||
Params reordered;
|
||||
while (reordered.size() < params.size()) {
|
||||
unsigned idx = count;
|
||||
for (unsigned i = 0; i < rowSize; i++) {
|
||||
reordered.push_back (params[idx]);
|
||||
idx += dsize ;
|
||||
}
|
||||
count++;
|
||||
}
|
||||
return reordered;
|
||||
}
|
||||
|
||||
|
||||
|
||||
Params
|
||||
BayesNet::revertParameterReorder (const Params& params, unsigned dsize) const
|
||||
{
|
||||
unsigned count = 0;
|
||||
unsigned rowSize = params.size() / dsize;
|
||||
Params reordered;
|
||||
while (reordered.size() < params.size()) {
|
||||
unsigned idx = count;
|
||||
for (unsigned i = 0; i < dsize; i++) {
|
||||
reordered.push_back (params[idx]);
|
||||
idx += rowSize;
|
||||
}
|
||||
count ++;
|
||||
}
|
||||
return reordered;
|
||||
}
|
||||
|
||||
|
@ -6,118 +6,83 @@
|
||||
#include <list>
|
||||
#include <map>
|
||||
|
||||
#include "GraphicalModel.h"
|
||||
#include "BayesNode.h"
|
||||
#include "Var.h"
|
||||
#include "Horus.h"
|
||||
|
||||
|
||||
using namespace std;
|
||||
|
||||
class Distribution;
|
||||
|
||||
struct ScheduleInfo
|
||||
{
|
||||
ScheduleInfo (BayesNode* n, bool vfp, bool vfc)
|
||||
{
|
||||
node = n;
|
||||
visitedFromParent = vfp;
|
||||
visitedFromChild = vfc;
|
||||
}
|
||||
BayesNode* node;
|
||||
bool visitedFromParent;
|
||||
bool visitedFromChild;
|
||||
};
|
||||
class Var;
|
||||
|
||||
|
||||
struct StateInfo
|
||||
{
|
||||
StateInfo (void)
|
||||
{
|
||||
visited = true;
|
||||
markedOnTop = false;
|
||||
markedOnBottom = false;
|
||||
}
|
||||
bool visited;
|
||||
bool markedOnTop;
|
||||
bool markedOnBottom;
|
||||
};
|
||||
|
||||
typedef vector<Distribution*> DistSet;
|
||||
typedef queue<ScheduleInfo, list<ScheduleInfo> > Scheduling;
|
||||
|
||||
|
||||
class BayesNet : public GraphicalModel
|
||||
class DAGraphNode : public Var
|
||||
{
|
||||
public:
|
||||
BayesNet (void) {};
|
||||
~BayesNet (void);
|
||||
DAGraphNode (Var* v) : Var (v) , visited_(false),
|
||||
markedOnTop_(false), markedOnBottom_(false) { }
|
||||
|
||||
void readFromBifFormat (const char*);
|
||||
BayesNode* addNode (string, const States&);
|
||||
// BayesNode* addNode (VarId, unsigned, int, BnNodeSet&, Distribution*);
|
||||
BayesNode* addNode (VarId, unsigned, int, Distribution*);
|
||||
BayesNode* getBayesNode (VarId) const;
|
||||
BayesNode* getBayesNode (string) const;
|
||||
VarNode* getVariableNode (VarId) const;
|
||||
VarNodes getVariableNodes (void) const;
|
||||
void addDistribution (Distribution*);
|
||||
Distribution* getDistribution (unsigned) const;
|
||||
const BnNodeSet& getBayesNodes (void) const;
|
||||
unsigned nrNodes (void) const;
|
||||
BnNodeSet getRootNodes (void) const;
|
||||
BnNodeSet getLeafNodes (void) const;
|
||||
BayesNet* getMinimalRequesiteNetwork (VarId) const;
|
||||
BayesNet* getMinimalRequesiteNetwork (const VarIds&) const;
|
||||
void constructGraph (
|
||||
BayesNet*, const vector<StateInfo*>&) const;
|
||||
bool isPolyTree (void) const;
|
||||
void setIndexes (void);
|
||||
void distributionsToLogs (void);
|
||||
void freeDistributions (void);
|
||||
void printGraphicalModel (void) const;
|
||||
void exportToGraphViz (const char*, bool = true,
|
||||
const VarIds& = VarIds()) const;
|
||||
void exportToBifFormat (const char*) const;
|
||||
const vector<DAGraphNode*>& childs (void) const { return childs_; }
|
||||
|
||||
vector<DAGraphNode*>& childs (void) { return childs_; }
|
||||
|
||||
const vector<DAGraphNode*>& parents (void) const { return parents_; }
|
||||
|
||||
vector<DAGraphNode*>& parents (void) { return parents_; }
|
||||
|
||||
void addParent (DAGraphNode* p) { parents_.push_back (p); }
|
||||
|
||||
void addChild (DAGraphNode* c) { childs_.push_back (c); }
|
||||
|
||||
bool isVisited (void) const { return visited_; }
|
||||
|
||||
void setAsVisited (void) { visited_ = true; }
|
||||
|
||||
bool isMarkedOnTop (void) const { return markedOnTop_; }
|
||||
|
||||
void markOnTop (void) { markedOnTop_ = true; }
|
||||
|
||||
bool isMarkedOnBottom (void) const { return markedOnBottom_; }
|
||||
|
||||
void markOnBottom (void) { markedOnBottom_ = true; }
|
||||
|
||||
void clear (void) { visited_ = markedOnTop_ = markedOnBottom_ = false; }
|
||||
|
||||
private:
|
||||
DISALLOW_COPY_AND_ASSIGN (BayesNet);
|
||||
bool visited_;
|
||||
bool markedOnTop_;
|
||||
bool markedOnBottom_;
|
||||
|
||||
bool containsUndirectedCycle (void) const;
|
||||
bool containsUndirectedCycle (int, int, vector<bool>&)const;
|
||||
vector<int> getAdjacentNodes (int) const;
|
||||
Params reorderParameters (const Params&, unsigned) const;
|
||||
Params revertParameterReorder (const Params&, unsigned) const;
|
||||
void scheduleParents (const BayesNode*, Scheduling&) const;
|
||||
void scheduleChilds (const BayesNode*, Scheduling&) const;
|
||||
|
||||
BnNodeSet nodes_;
|
||||
DistSet dists_;
|
||||
|
||||
typedef unordered_map<unsigned, unsigned> IndexMap;
|
||||
IndexMap varMap_;
|
||||
vector<DAGraphNode*> childs_;
|
||||
vector<DAGraphNode*> parents_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
inline void
|
||||
BayesNet::scheduleParents (const BayesNode* n, Scheduling& sch) const
|
||||
class DAGraph
|
||||
{
|
||||
const BnNodeSet& ps = n->getParents();
|
||||
for (BnNodeSet::const_iterator it = ps.begin(); it != ps.end(); it++) {
|
||||
sch.push (ScheduleInfo (*it, false, true));
|
||||
}
|
||||
}
|
||||
public:
|
||||
DAGraph (void) { }
|
||||
|
||||
void addNode (DAGraphNode* n);
|
||||
|
||||
void addEdge (VarId vid1, VarId vid2);
|
||||
|
||||
inline void
|
||||
BayesNet::scheduleChilds (const BayesNode* n, Scheduling& sch) const
|
||||
{
|
||||
const BnNodeSet& cs = n->getChilds();
|
||||
for (BnNodeSet::const_iterator it = cs.begin(); it != cs.end(); it++) {
|
||||
sch.push (ScheduleInfo (*it, true, false));
|
||||
}
|
||||
}
|
||||
const DAGraphNode* getNode (VarId vid) const;
|
||||
|
||||
DAGraphNode* getNode (VarId vid);
|
||||
|
||||
bool empty (void) const { return nodes_.empty(); }
|
||||
|
||||
void setIndexes (void);
|
||||
|
||||
void clear (void);
|
||||
|
||||
void exportToGraphViz (const char*);
|
||||
|
||||
private:
|
||||
vector<DAGraphNode*> nodes_;
|
||||
|
||||
unordered_map<VarId, DAGraphNode*> varMap_;
|
||||
};
|
||||
|
||||
#endif // HORUS_BAYESNET_H
|
||||
|
||||
|
@ -1,291 +0,0 @@
|
||||
#include <cstdlib>
|
||||
#include <cassert>
|
||||
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "BayesNode.h"
|
||||
|
||||
|
||||
BayesNode::BayesNode (VarId vid,
|
||||
unsigned dsize,
|
||||
int evidence,
|
||||
Distribution* dist)
|
||||
: VarNode (vid, dsize, evidence)
|
||||
{
|
||||
dist_ = dist;
|
||||
}
|
||||
|
||||
|
||||
|
||||
BayesNode::BayesNode (VarId vid,
|
||||
unsigned dsize,
|
||||
int evidence,
|
||||
const BnNodeSet& parents,
|
||||
Distribution* dist)
|
||||
: VarNode (vid, dsize, evidence)
|
||||
{
|
||||
parents_ = parents;
|
||||
dist_ = dist;
|
||||
for (unsigned int i = 0; i < parents.size(); i++) {
|
||||
parents[i]->addChild (this);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BayesNode::setParents (const BnNodeSet& parents)
|
||||
{
|
||||
parents_ = parents;
|
||||
for (unsigned int i = 0; i < parents.size(); i++) {
|
||||
parents[i]->addChild (this);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BayesNode::addChild (BayesNode* node)
|
||||
{
|
||||
childs_.push_back (node);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BayesNode::setDistribution (Distribution* dist)
|
||||
{
|
||||
assert (dist);
|
||||
dist_ = dist;
|
||||
}
|
||||
|
||||
|
||||
|
||||
Distribution*
|
||||
BayesNode::getDistribution (void)
|
||||
{
|
||||
return dist_;
|
||||
}
|
||||
|
||||
|
||||
|
||||
const Params&
|
||||
BayesNode::getParameters (void)
|
||||
{
|
||||
return dist_->params;
|
||||
}
|
||||
|
||||
|
||||
|
||||
Params
|
||||
BayesNode::getRow (int rowIndex) const
|
||||
{
|
||||
int rowSize = getRowSize();
|
||||
int offset = rowSize * rowIndex;
|
||||
Params row (rowSize);
|
||||
for (int i = 0; i < rowSize; i++) {
|
||||
row[i] = dist_->params[offset + i] ;
|
||||
}
|
||||
return row;
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
BayesNode::isRoot (void)
|
||||
{
|
||||
return getParents().empty();
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
BayesNode::isLeaf (void)
|
||||
{
|
||||
return getChilds().empty();
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
BayesNode::hasNeighbors (void) const
|
||||
{
|
||||
return childs_.size() != 0 || parents_.size() != 0;
|
||||
}
|
||||
|
||||
|
||||
int
|
||||
BayesNode::getCptSize (void)
|
||||
{
|
||||
return dist_->params.size();
|
||||
}
|
||||
|
||||
|
||||
|
||||
int
|
||||
BayesNode::getIndexOfParent (const BayesNode* parent) const
|
||||
{
|
||||
for (unsigned int i = 0; i < parents_.size(); i++) {
|
||||
if (parents_[i] == parent) {
|
||||
return i;
|
||||
}
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
|
||||
|
||||
string
|
||||
BayesNode::cptEntryToString (
|
||||
int row,
|
||||
const vector<unsigned>& stateConf) const
|
||||
{
|
||||
stringstream ss;
|
||||
ss << "p(" ;
|
||||
ss << states()[row];
|
||||
if (parents_.size() > 0) {
|
||||
ss << "|" ;
|
||||
for (unsigned int i = 0; i < stateConf.size(); i++) {
|
||||
if (i != 0) {
|
||||
ss << ",";
|
||||
}
|
||||
ss << parents_[i]->states()[stateConf[i]];
|
||||
}
|
||||
}
|
||||
ss << ")" ;
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
|
||||
|
||||
vector<string>
|
||||
BayesNode::getDomainHeaders (void) const
|
||||
{
|
||||
unsigned nParents = parents_.size();
|
||||
unsigned rowSize = getRowSize();
|
||||
unsigned nReps = 1;
|
||||
vector<string> headers (rowSize);
|
||||
for (int i = nParents - 1; i >= 0; i--) {
|
||||
States states = parents_[i]->states();
|
||||
unsigned index = 0;
|
||||
while (index < rowSize) {
|
||||
for (unsigned j = 0; j < parents_[i]->nrStates(); j++) {
|
||||
for (unsigned r = 0; r < nReps; r++) {
|
||||
if (headers[index] != "") {
|
||||
headers[index] = states[j] + "," + headers[index];
|
||||
} else {
|
||||
headers[index] = states[j];
|
||||
}
|
||||
index++;
|
||||
}
|
||||
}
|
||||
}
|
||||
nReps *= parents_[i]->nrStates();
|
||||
}
|
||||
return headers;
|
||||
}
|
||||
|
||||
|
||||
|
||||
ostream&
|
||||
operator << (ostream& o, const BayesNode& node)
|
||||
{
|
||||
o << "variable " << node.getIndex() << endl;
|
||||
o << "Var Id: " << node.varId() << endl;
|
||||
o << "Label: " << node.label() << endl;
|
||||
|
||||
o << "Evidence: " ;
|
||||
if (node.hasEvidence()) {
|
||||
o << node.getEvidence();
|
||||
}
|
||||
else {
|
||||
o << "no" ;
|
||||
}
|
||||
o << endl;
|
||||
|
||||
o << "Parents: " ;
|
||||
const BnNodeSet& parents = node.getParents();
|
||||
if (parents.size() != 0) {
|
||||
for (unsigned int i = 0; i < parents.size() - 1; i++) {
|
||||
o << parents[i]->label() << ", " ;
|
||||
}
|
||||
o << parents[parents.size() - 1]->label();
|
||||
}
|
||||
o << endl;
|
||||
|
||||
o << "Childs: " ;
|
||||
const BnNodeSet& childs = node.getChilds();
|
||||
if (childs.size() != 0) {
|
||||
for (unsigned int i = 0; i < childs.size() - 1; i++) {
|
||||
o << childs[i]->label() << ", " ;
|
||||
}
|
||||
o << childs[childs.size() - 1]->label();
|
||||
}
|
||||
o << endl;
|
||||
|
||||
o << "Domain: " ;
|
||||
States states = node.states();
|
||||
for (unsigned int i = 0; i < states.size() - 1; i++) {
|
||||
o << states[i] << ", " ;
|
||||
}
|
||||
if (states.size() != 0) {
|
||||
o << states[states.size() - 1];
|
||||
}
|
||||
o << endl;
|
||||
|
||||
// min width of first column
|
||||
const unsigned int MIN_DOMAIN_WIDTH = 4;
|
||||
// min width of following columns
|
||||
const unsigned int MIN_COMBO_WIDTH = 12;
|
||||
|
||||
unsigned int domainWidth = states[0].length();
|
||||
for (unsigned int i = 1; i < states.size(); i++) {
|
||||
if (states[i].length() > domainWidth) {
|
||||
domainWidth = states[i].length();
|
||||
}
|
||||
}
|
||||
domainWidth = (domainWidth < MIN_DOMAIN_WIDTH)
|
||||
? MIN_DOMAIN_WIDTH
|
||||
: domainWidth;
|
||||
|
||||
o << left << setw (domainWidth) << "cpt" << right;
|
||||
|
||||
vector<int> widths;
|
||||
int lineWidth = domainWidth;
|
||||
vector<string> headers = node.getDomainHeaders();
|
||||
|
||||
if (!headers.empty()) {
|
||||
for (unsigned int i = 0; i < headers.size(); i++) {
|
||||
unsigned int len = headers[i].length();
|
||||
int w = (len < MIN_COMBO_WIDTH) ? MIN_COMBO_WIDTH : len;
|
||||
widths.push_back (w);
|
||||
o << setw (w) << headers[i];
|
||||
lineWidth += w;
|
||||
}
|
||||
o << endl;
|
||||
} else {
|
||||
cout << endl;
|
||||
widths.push_back (domainWidth);
|
||||
lineWidth += MIN_COMBO_WIDTH;
|
||||
}
|
||||
|
||||
for (int i = 0; i < lineWidth; i++) {
|
||||
o << "-" ;
|
||||
}
|
||||
o << endl;
|
||||
|
||||
for (unsigned int i = 0; i < states.size(); i++) {
|
||||
Params row = node.getRow (i);
|
||||
o << left << setw (domainWidth) << states[i] << right;
|
||||
for (unsigned j = 0; j < node.getRowSize(); j++) {
|
||||
o << setw (widths[j]) << row[j];
|
||||
}
|
||||
o << endl;
|
||||
}
|
||||
o << endl;
|
||||
|
||||
return o;
|
||||
}
|
||||
|
@ -1,61 +0,0 @@
|
||||
#ifndef HORUS_BAYESNODE_H
|
||||
#define HORUS_BAYESNODE_H
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "VarNode.h"
|
||||
#include "Distribution.h"
|
||||
#include "Horus.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
|
||||
class BayesNode : public VarNode
|
||||
{
|
||||
public:
|
||||
BayesNode (const VarNode& v) : VarNode (v) {}
|
||||
BayesNode (VarId, unsigned, int, Distribution*);
|
||||
BayesNode (VarId, unsigned, int, const BnNodeSet&, Distribution*);
|
||||
|
||||
void setParents (const BnNodeSet&);
|
||||
void addChild (BayesNode*);
|
||||
void setDistribution (Distribution*);
|
||||
Distribution* getDistribution (void);
|
||||
const Params& getParameters (void);
|
||||
Params getRow (int) const;
|
||||
bool isRoot (void);
|
||||
bool isLeaf (void);
|
||||
bool hasNeighbors (void) const;
|
||||
int getCptSize (void);
|
||||
int getIndexOfParent (const BayesNode*) const;
|
||||
string cptEntryToString (int, const vector<unsigned>&) const;
|
||||
|
||||
const BnNodeSet& getParents (void) const { return parents_; }
|
||||
const BnNodeSet& getChilds (void) const { return childs_; }
|
||||
|
||||
unsigned getRowSize (void) const
|
||||
{
|
||||
return dist_->params.size() / nrStates();
|
||||
}
|
||||
|
||||
double getProbability (int row, unsigned col)
|
||||
{
|
||||
int idx = (row * getRowSize()) + col;
|
||||
return dist_->params[idx];
|
||||
}
|
||||
|
||||
private:
|
||||
DISALLOW_COPY_AND_ASSIGN (BayesNode);
|
||||
|
||||
States getDomainHeaders (void) const;
|
||||
friend ostream& operator << (ostream&, const BayesNode&);
|
||||
|
||||
BnNodeSet parents_;
|
||||
BnNodeSet childs_;
|
||||
Distribution* dist_;
|
||||
};
|
||||
|
||||
ostream& operator << (ostream&, const BayesNode&);
|
||||
|
||||
#endif // HORUS_BAYESNODE_H
|
||||
|
@ -1,803 +0,0 @@
|
||||
#include <cstdlib>
|
||||
#include <limits>
|
||||
#include <time.h>
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <iomanip>
|
||||
|
||||
#include "BnBpSolver.h"
|
||||
#include "Indexer.h"
|
||||
|
||||
BnBpSolver::BnBpSolver (const BayesNet& bn) : Solver (&bn)
|
||||
{
|
||||
bayesNet_ = &bn;
|
||||
}
|
||||
|
||||
|
||||
|
||||
BnBpSolver::~BnBpSolver (void)
|
||||
{
|
||||
for (unsigned i = 0; i < nodesI_.size(); i++) {
|
||||
delete nodesI_[i];
|
||||
}
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
delete links_[i];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BnBpSolver::runSolver (void)
|
||||
{
|
||||
clock_t start;
|
||||
if (COLLECT_STATISTICS) {
|
||||
start = clock();
|
||||
}
|
||||
initializeSolver();
|
||||
runLoopySolver();
|
||||
if (DL >= 2) {
|
||||
cout << endl;
|
||||
if (nIters_ < BpOptions::maxIter) {
|
||||
cout << "Belief propagation converged in " ;
|
||||
cout << nIters_ << " iterations" << endl;
|
||||
} else {
|
||||
cout << "The maximum number of iterations was hit, terminating..." ;
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
|
||||
unsigned size = bayesNet_->nrNodes();
|
||||
if (COLLECT_STATISTICS) {
|
||||
unsigned nIters = 0;
|
||||
bool loopy = bayesNet_->isPolyTree() == false;
|
||||
if (loopy) nIters = nIters_;
|
||||
double time = (double (clock() - start)) / CLOCKS_PER_SEC;
|
||||
Statistics::updateStatistics (size, loopy, nIters, time);
|
||||
}
|
||||
if (EXPORT_TO_GRAPHVIZ && size > EXPORT_MINIMAL_SIZE) {
|
||||
stringstream ss;
|
||||
ss << Statistics::getSolvedNetworksCounting() << "." << size << ".dot" ;
|
||||
bayesNet_->exportToGraphViz (ss.str().c_str());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
Params
|
||||
BnBpSolver::getPosterioriOf (VarId vid)
|
||||
{
|
||||
BayesNode* node = bayesNet_->getBayesNode (vid);
|
||||
assert (node);
|
||||
return nodesI_[node->getIndex()]->getBeliefs();
|
||||
}
|
||||
|
||||
|
||||
|
||||
Params
|
||||
BnBpSolver::getJointDistributionOf (const VarIds& jointVarIds)
|
||||
{
|
||||
if (DL >= 2) {
|
||||
cout << "calculating joint distribution on: " ;
|
||||
for (unsigned i = 0; i < jointVarIds.size(); i++) {
|
||||
VarNode* var = bayesNet_->getBayesNode (jointVarIds[i]);
|
||||
cout << var->label() << " " ;
|
||||
}
|
||||
cout << endl;
|
||||
}
|
||||
return getJointByConditioning (jointVarIds);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BnBpSolver::initializeSolver (void)
|
||||
{
|
||||
const BnNodeSet& nodes = bayesNet_->getBayesNodes();
|
||||
for (unsigned i = 0; i < nodesI_.size(); i++) {
|
||||
delete nodesI_[i];
|
||||
}
|
||||
nodesI_.clear();
|
||||
nodesI_.reserve (nodes.size());
|
||||
links_.clear();
|
||||
sortedOrder_.clear();
|
||||
linkMap_.clear();
|
||||
|
||||
for (unsigned i = 0; i < nodes.size(); i++) {
|
||||
nodesI_.push_back (new BpNodeInfo (nodes[i]));
|
||||
}
|
||||
|
||||
BnNodeSet roots = bayesNet_->getRootNodes();
|
||||
for (unsigned i = 0; i < roots.size(); i++) {
|
||||
const Params& params = roots[i]->getParameters();
|
||||
Params& piVals = ninf(roots[i])->getPiValues();
|
||||
for (unsigned ri = 0; ri < roots[i]->nrStates(); ri++) {
|
||||
piVals[ri] = params[ri];
|
||||
}
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < nodes.size(); i++) {
|
||||
const BnNodeSet& parents = nodes[i]->getParents();
|
||||
for (unsigned j = 0; j < parents.size(); j++) {
|
||||
BpLink* newLink = new BpLink (
|
||||
parents[j], nodes[i], LinkOrientation::DOWN);
|
||||
links_.push_back (newLink);
|
||||
ninf(nodes[i])->addIncomingParentLink (newLink);
|
||||
ninf(parents[j])->addOutcomingChildLink (newLink);
|
||||
}
|
||||
const BnNodeSet& childs = nodes[i]->getChilds();
|
||||
for (unsigned j = 0; j < childs.size(); j++) {
|
||||
BpLink* newLink = new BpLink (
|
||||
childs[j], nodes[i], LinkOrientation::UP);
|
||||
links_.push_back (newLink);
|
||||
ninf(nodes[i])->addIncomingChildLink (newLink);
|
||||
ninf(childs[j])->addOutcomingParentLink (newLink);
|
||||
}
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < nodes.size(); i++) {
|
||||
if (nodes[i]->hasEvidence()) {
|
||||
Params& piVals = ninf(nodes[i])->getPiValues();
|
||||
Params& ldVals = ninf(nodes[i])->getLambdaValues();
|
||||
for (unsigned xi = 0; xi < nodes[i]->nrStates(); xi++) {
|
||||
piVals[xi] = Util::noEvidence();
|
||||
ldVals[xi] = Util::noEvidence();
|
||||
}
|
||||
piVals[nodes[i]->getEvidence()] = Util::withEvidence();
|
||||
ldVals[nodes[i]->getEvidence()] = Util::withEvidence();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BnBpSolver::runLoopySolver()
|
||||
{
|
||||
nIters_ = 0;
|
||||
while (!converged() && nIters_ < BpOptions::maxIter) {
|
||||
|
||||
nIters_++;
|
||||
if (DL >= 2) {
|
||||
cout << "****************************************" ;
|
||||
cout << "****************************************" ;
|
||||
cout << endl;
|
||||
cout << " Iteration " << nIters_ << endl;
|
||||
cout << "****************************************" ;
|
||||
cout << "****************************************" ;
|
||||
cout << endl;
|
||||
}
|
||||
|
||||
switch (BpOptions::schedule) {
|
||||
|
||||
case BpOptions::Schedule::SEQ_RANDOM:
|
||||
random_shuffle (links_.begin(), links_.end());
|
||||
// no break
|
||||
|
||||
case BpOptions::Schedule::SEQ_FIXED:
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
calculateAndUpdateMessage (links_[i]);
|
||||
updateValues (links_[i]);
|
||||
}
|
||||
break;
|
||||
|
||||
case BpOptions::Schedule::PARALLEL:
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
calculateMessage (links_[i]);
|
||||
}
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
updateMessage (links_[i]);
|
||||
updateValues (links_[i]);
|
||||
}
|
||||
break;
|
||||
|
||||
case BpOptions::Schedule::MAX_RESIDUAL:
|
||||
maxResidualSchedule();
|
||||
break;
|
||||
|
||||
}
|
||||
if (DL >= 2) {
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
BnBpSolver::converged (void) const
|
||||
{
|
||||
// this can happen if the graph is fully disconnected
|
||||
if (links_.size() == 0) {
|
||||
return true;
|
||||
}
|
||||
if (nIters_ == 0 || nIters_ == 1) {
|
||||
return false;
|
||||
}
|
||||
bool converged = true;
|
||||
if (BpOptions::schedule == BpOptions::Schedule::MAX_RESIDUAL) {
|
||||
double maxResidual = (*(sortedOrder_.begin()))->getResidual();
|
||||
if (maxResidual < BpOptions::accuracy) {
|
||||
converged = true;
|
||||
} else {
|
||||
converged = false;
|
||||
}
|
||||
} else {
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
double residual = links_[i]->getResidual();
|
||||
if (DL >= 2) {
|
||||
cout << links_[i]->toString() + " residual change = " ;
|
||||
cout << residual << endl;
|
||||
}
|
||||
if (residual > BpOptions::accuracy) {
|
||||
converged = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
return converged;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BnBpSolver::maxResidualSchedule (void)
|
||||
{
|
||||
if (nIters_ == 1) {
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
calculateMessage (links_[i]);
|
||||
SortedOrder::iterator it = sortedOrder_.insert (links_[i]);
|
||||
linkMap_.insert (make_pair (links_[i], it));
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
for (unsigned c = 0; c < sortedOrder_.size(); c++) {
|
||||
if (DL >= 2) {
|
||||
cout << "current residuals:" << endl;
|
||||
for (SortedOrder::iterator it = sortedOrder_.begin();
|
||||
it != sortedOrder_.end(); it ++) {
|
||||
cout << " " << setw (30) << left << (*it)->toString();
|
||||
cout << "residual = " << (*it)->getResidual() << endl;
|
||||
}
|
||||
}
|
||||
|
||||
SortedOrder::iterator it = sortedOrder_.begin();
|
||||
BpLink* link = *it;
|
||||
if (link->getResidual() < BpOptions::accuracy) {
|
||||
sortedOrder_.erase (it);
|
||||
it = sortedOrder_.begin();
|
||||
return;
|
||||
}
|
||||
updateMessage (link);
|
||||
updateValues (link);
|
||||
link->clearResidual();
|
||||
sortedOrder_.erase (it);
|
||||
linkMap_.find (link)->second = sortedOrder_.insert (link);
|
||||
|
||||
const BpLinkSet& outParentLinks =
|
||||
ninf(link->getDestination())->getOutcomingParentLinks();
|
||||
for (unsigned i = 0; i < outParentLinks.size(); i++) {
|
||||
if (outParentLinks[i]->getDestination() != link->getSource()
|
||||
&& outParentLinks[i]->getDestination()->hasEvidence() == false) {
|
||||
calculateMessage (outParentLinks[i]);
|
||||
BpLinkMap::iterator iter = linkMap_.find (outParentLinks[i]);
|
||||
sortedOrder_.erase (iter->second);
|
||||
iter->second = sortedOrder_.insert (outParentLinks[i]);
|
||||
}
|
||||
}
|
||||
const BpLinkSet& outChildLinks =
|
||||
ninf(link->getDestination())->getOutcomingChildLinks();
|
||||
for (unsigned i = 0; i < outChildLinks.size(); i++) {
|
||||
if (outChildLinks[i]->getDestination() != link->getSource()) {
|
||||
calculateMessage (outChildLinks[i]);
|
||||
BpLinkMap::iterator iter = linkMap_.find (outChildLinks[i]);
|
||||
sortedOrder_.erase (iter->second);
|
||||
iter->second = sortedOrder_.insert (outChildLinks[i]);
|
||||
}
|
||||
}
|
||||
|
||||
if (DL >= 2) {
|
||||
cout << "----------------------------------------" ;
|
||||
cout << "----------------------------------------" << endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BnBpSolver::updatePiValues (BayesNode* x)
|
||||
{
|
||||
// π(Xi)
|
||||
if (DL >= 3) {
|
||||
cout << "updating " << PI_SYMBOL << " values for " << x->label() << endl;
|
||||
}
|
||||
Params& piValues = ninf(x)->getPiValues();
|
||||
const BpLinkSet& parentLinks = ninf(x)->getIncomingParentLinks();
|
||||
const BnNodeSet& ps = x->getParents();
|
||||
Ranges ranges;
|
||||
for (unsigned i = 0; i < ps.size(); i++) {
|
||||
ranges.push_back (ps[i]->nrStates());
|
||||
}
|
||||
StatesIndexer indexer (ranges, false);
|
||||
stringstream* calcs1 = 0;
|
||||
stringstream* calcs2 = 0;
|
||||
|
||||
Params messageProducts (indexer.size());
|
||||
for (unsigned k = 0; k < indexer.size(); k++) {
|
||||
if (DL >= 5) {
|
||||
calcs1 = new stringstream;
|
||||
calcs2 = new stringstream;
|
||||
}
|
||||
double messageProduct = Util::multIdenty();
|
||||
if (Globals::logDomain) {
|
||||
for (unsigned i = 0; i < parentLinks.size(); i++) {
|
||||
messageProduct += parentLinks[i]->getMessage()[indexer[i]];
|
||||
}
|
||||
} else {
|
||||
for (unsigned i = 0; i < parentLinks.size(); i++) {
|
||||
messageProduct *= parentLinks[i]->getMessage()[indexer[i]];
|
||||
if (DL >= 5) {
|
||||
if (i != 0) *calcs1 << " + " ;
|
||||
if (i != 0) *calcs2 << " + " ;
|
||||
*calcs1 << parentLinks[i]->toString (indexer[i]);
|
||||
*calcs2 << parentLinks[i]->getMessage()[indexer[i]];
|
||||
}
|
||||
}
|
||||
}
|
||||
messageProducts[k] = messageProduct;
|
||||
if (DL >= 5) {
|
||||
cout << " mp" << k;
|
||||
cout << " = " << (*calcs1).str();
|
||||
if (parentLinks.size() == 1) {
|
||||
cout << " = " << messageProduct << endl;
|
||||
} else {
|
||||
cout << " = " << (*calcs2).str();
|
||||
cout << " = " << messageProduct << endl;
|
||||
}
|
||||
delete calcs1;
|
||||
delete calcs2;
|
||||
}
|
||||
++ indexer;
|
||||
}
|
||||
|
||||
for (unsigned xi = 0; xi < x->nrStates(); xi++) {
|
||||
double sum = Util::addIdenty();
|
||||
if (DL >= 5) {
|
||||
calcs1 = new stringstream;
|
||||
calcs2 = new stringstream;
|
||||
}
|
||||
indexer.reset();
|
||||
if (Globals::logDomain) {
|
||||
for (unsigned k = 0; k < indexer.size(); k++) {
|
||||
Util::logSum (sum,
|
||||
x->getProbability(xi, indexer.linearIndex()) + messageProducts[k]);
|
||||
++ indexer;
|
||||
}
|
||||
} else {
|
||||
for (unsigned k = 0; k < indexer.size(); k++) {
|
||||
sum += x->getProbability (xi, indexer.linearIndex()) * messageProducts[k];
|
||||
if (DL >= 5) {
|
||||
if (k != 0) *calcs1 << " + " ;
|
||||
if (k != 0) *calcs2 << " + " ;
|
||||
*calcs1 << x->cptEntryToString (xi, indexer.indices());
|
||||
*calcs1 << ".mp" << k;
|
||||
*calcs2 << Util::fl (x->getProbability (xi, indexer.linearIndex()));
|
||||
*calcs2 << "*" << messageProducts[k];
|
||||
}
|
||||
++ indexer;
|
||||
}
|
||||
}
|
||||
|
||||
piValues[xi] = sum;
|
||||
if (DL >= 5) {
|
||||
cout << " " << PI_SYMBOL << "(" << x->label() << ")" ;
|
||||
cout << "[" << x->states()[xi] << "]" ;
|
||||
cout << " = " << (*calcs1).str();
|
||||
cout << " = " << (*calcs2).str();
|
||||
cout << " = " << piValues[xi] << endl;
|
||||
delete calcs1;
|
||||
delete calcs2;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BnBpSolver::updateLambdaValues (BayesNode* x)
|
||||
{
|
||||
// λ(Xi)
|
||||
if (DL >= 3) {
|
||||
cout << "updating " << LD_SYMBOL << " values for " << x->label() << endl;
|
||||
}
|
||||
Params& lambdaValues = ninf(x)->getLambdaValues();
|
||||
const BpLinkSet& childLinks = ninf(x)->getIncomingChildLinks();
|
||||
stringstream* calcs1 = 0;
|
||||
stringstream* calcs2 = 0;
|
||||
|
||||
for (unsigned xi = 0; xi < x->nrStates(); xi++) {
|
||||
if (DL >= 5) {
|
||||
calcs1 = new stringstream;
|
||||
calcs2 = new stringstream;
|
||||
}
|
||||
double product = Util::multIdenty();
|
||||
if (Globals::logDomain) {
|
||||
for (unsigned i = 0; i < childLinks.size(); i++) {
|
||||
product += childLinks[i]->getMessage()[xi];
|
||||
}
|
||||
} else {
|
||||
for (unsigned i = 0; i < childLinks.size(); i++) {
|
||||
product *= childLinks[i]->getMessage()[xi];
|
||||
if (DL >= 5) {
|
||||
if (i != 0) *calcs1 << "." ;
|
||||
if (i != 0) *calcs2 << "*" ;
|
||||
*calcs1 << childLinks[i]->toString (xi);
|
||||
*calcs2 << childLinks[i]->getMessage()[xi];
|
||||
}
|
||||
}
|
||||
}
|
||||
lambdaValues[xi] = product;
|
||||
if (DL >= 5) {
|
||||
cout << " " << LD_SYMBOL << "(" << x->label() << ")" ;
|
||||
cout << "[" << x->states()[xi] << "]" ;
|
||||
cout << " = " << (*calcs1).str();
|
||||
if (childLinks.size() == 1) {
|
||||
cout << " = " << product << endl;
|
||||
} else {
|
||||
cout << " = " << (*calcs2).str();
|
||||
cout << " = " << lambdaValues[xi] << endl;
|
||||
}
|
||||
delete calcs1;
|
||||
delete calcs2;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BnBpSolver::calculatePiMessage (BpLink* link)
|
||||
{
|
||||
// πX(Zi)
|
||||
BayesNode* z = link->getSource();
|
||||
BayesNode* x = link->getDestination();
|
||||
Params& zxPiNextMessage = link->getNextMessage();
|
||||
const BpLinkSet& zChildLinks = ninf(z)->getIncomingChildLinks();
|
||||
stringstream* calcs1 = 0;
|
||||
stringstream* calcs2 = 0;
|
||||
|
||||
const Params& zPiValues = ninf(z)->getPiValues();
|
||||
for (unsigned zi = 0; zi < z->nrStates(); zi++) {
|
||||
double product = zPiValues[zi];
|
||||
if (DL >= 5) {
|
||||
calcs1 = new stringstream;
|
||||
calcs2 = new stringstream;
|
||||
*calcs1 << PI_SYMBOL << "(" << z->label() << ")";
|
||||
*calcs1 << "[" << z->states()[zi] << "]" ;
|
||||
*calcs2 << product;
|
||||
}
|
||||
if (Globals::logDomain) {
|
||||
for (unsigned i = 0; i < zChildLinks.size(); i++) {
|
||||
if (zChildLinks[i]->getSource() != x) {
|
||||
product += zChildLinks[i]->getMessage()[zi];
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (unsigned i = 0; i < zChildLinks.size(); i++) {
|
||||
if (zChildLinks[i]->getSource() != x) {
|
||||
product *= zChildLinks[i]->getMessage()[zi];
|
||||
if (DL >= 5) {
|
||||
*calcs1 << "." << zChildLinks[i]->toString (zi);
|
||||
*calcs2 << " * " << zChildLinks[i]->getMessage()[zi];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
zxPiNextMessage[zi] = product;
|
||||
if (DL >= 5) {
|
||||
cout << " " << link->toString();
|
||||
cout << "[" << z->states()[zi] << "]" ;
|
||||
cout << " = " << (*calcs1).str();
|
||||
if (zChildLinks.size() == 1) {
|
||||
cout << " = " << product << endl;
|
||||
} else {
|
||||
cout << " = " << (*calcs2).str();
|
||||
cout << " = " << product << endl;
|
||||
}
|
||||
delete calcs1;
|
||||
delete calcs2;
|
||||
}
|
||||
}
|
||||
Util::normalize (zxPiNextMessage);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BnBpSolver::calculateLambdaMessage (BpLink* link)
|
||||
{
|
||||
// λY(Xi)
|
||||
BayesNode* y = link->getSource();
|
||||
BayesNode* x = link->getDestination();
|
||||
if (x->hasEvidence()) {
|
||||
return;
|
||||
}
|
||||
Params& yxLambdaNextMessage = link->getNextMessage();
|
||||
const BpLinkSet& yParentLinks = ninf(y)->getIncomingParentLinks();
|
||||
const Params& yLambdaValues = ninf(y)->getLambdaValues();
|
||||
int parentIndex = y->getIndexOfParent (x);
|
||||
stringstream* calcs1 = 0;
|
||||
stringstream* calcs2 = 0;
|
||||
|
||||
const BnNodeSet& ps = y->getParents();
|
||||
Ranges ranges;
|
||||
for (unsigned i = 0; i < ps.size(); i++) {
|
||||
ranges.push_back (ps[i]->nrStates());
|
||||
}
|
||||
StatesIndexer indexer (ranges, false);
|
||||
|
||||
|
||||
unsigned N = indexer.size() / x->nrStates();
|
||||
Params messageProducts (N);
|
||||
for (unsigned k = 0; k < N; k++) {
|
||||
while (indexer[parentIndex] != 0) {
|
||||
++ indexer;
|
||||
}
|
||||
if (DL >= 5) {
|
||||
calcs1 = new stringstream;
|
||||
calcs2 = new stringstream;
|
||||
}
|
||||
double messageProduct = Util::multIdenty();
|
||||
if (Globals::logDomain) {
|
||||
for (unsigned i = 0; i < yParentLinks.size(); i++) {
|
||||
if (yParentLinks[i]->getSource() != x) {
|
||||
messageProduct += yParentLinks[i]->getMessage()[indexer[i]];
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (unsigned i = 0; i < yParentLinks.size(); i++) {
|
||||
if (yParentLinks[i]->getSource() != x) {
|
||||
if (DL >= 5) {
|
||||
if (messageProduct != Util::multIdenty()) *calcs1 << "*" ;
|
||||
if (messageProduct != Util::multIdenty()) *calcs2 << "*" ;
|
||||
*calcs1 << yParentLinks[i]->toString (indexer[i]);
|
||||
*calcs2 << yParentLinks[i]->getMessage()[indexer[i]];
|
||||
}
|
||||
messageProduct *= yParentLinks[i]->getMessage()[indexer[i]];
|
||||
}
|
||||
}
|
||||
}
|
||||
messageProducts[k] = messageProduct;
|
||||
++ indexer;
|
||||
if (DL >= 5) {
|
||||
cout << " mp" << k;
|
||||
cout << " = " << (*calcs1).str();
|
||||
if (yParentLinks.size() == 1) {
|
||||
cout << 1 << endl;
|
||||
} else if (yParentLinks.size() == 2) {
|
||||
cout << " = " << messageProduct << endl;
|
||||
} else {
|
||||
cout << " = " << (*calcs2).str();
|
||||
cout << " = " << messageProduct << endl;
|
||||
}
|
||||
delete calcs1;
|
||||
delete calcs2;
|
||||
}
|
||||
}
|
||||
|
||||
for (unsigned xi = 0; xi < x->nrStates(); xi++) {
|
||||
if (DL >= 5) {
|
||||
calcs1 = new stringstream;
|
||||
calcs2 = new stringstream;
|
||||
}
|
||||
double outerSum = Util::addIdenty();
|
||||
for (unsigned yi = 0; yi < y->nrStates(); yi++) {
|
||||
if (DL >= 5) {
|
||||
(yi != 0) ? *calcs1 << " + {" : *calcs1 << "{" ;
|
||||
(yi != 0) ? *calcs2 << " + {" : *calcs2 << "{" ;
|
||||
}
|
||||
double innerSum = Util::addIdenty();
|
||||
indexer.reset();
|
||||
if (Globals::logDomain) {
|
||||
for (unsigned k = 0; k < N; k++) {
|
||||
while (indexer[parentIndex] != xi) {
|
||||
++ indexer;
|
||||
}
|
||||
Util::logSum (innerSum, y->getProbability (
|
||||
yi, indexer.linearIndex()) + messageProducts[k]);
|
||||
++ indexer;
|
||||
}
|
||||
Util::logSum (outerSum, innerSum + yLambdaValues[yi]);
|
||||
} else {
|
||||
for (unsigned k = 0; k < N; k++) {
|
||||
while (indexer[parentIndex] != xi) {
|
||||
++ indexer;
|
||||
}
|
||||
if (DL >= 5) {
|
||||
if (k != 0) *calcs1 << " + " ;
|
||||
if (k != 0) *calcs2 << " + " ;
|
||||
*calcs1 << y->cptEntryToString (yi, indexer.indices());
|
||||
*calcs1 << ".mp" << k;
|
||||
*calcs2 << y->getProbability (yi, indexer.linearIndex());
|
||||
*calcs2 << "*" << messageProducts[k];
|
||||
}
|
||||
innerSum += y->getProbability (
|
||||
yi, indexer.linearIndex()) * messageProducts[k];
|
||||
++ indexer;
|
||||
}
|
||||
outerSum += innerSum * yLambdaValues[yi];
|
||||
}
|
||||
if (DL >= 5) {
|
||||
*calcs1 << "}." << LD_SYMBOL << "(" << y->label() << ")" ;
|
||||
*calcs1 << "[" << y->states()[yi] << "]";
|
||||
*calcs2 << "}*" << yLambdaValues[yi];
|
||||
}
|
||||
}
|
||||
yxLambdaNextMessage[xi] = outerSum;
|
||||
if (DL >= 5) {
|
||||
cout << " " << link->toString();
|
||||
cout << "[" << x->states()[xi] << "]" ;
|
||||
cout << " = " << (*calcs1).str();
|
||||
cout << " = " << (*calcs2).str();
|
||||
cout << " = " << yxLambdaNextMessage[xi] << endl;
|
||||
delete calcs1;
|
||||
delete calcs2;
|
||||
}
|
||||
}
|
||||
Util::normalize (yxLambdaNextMessage);
|
||||
}
|
||||
|
||||
|
||||
|
||||
Params
|
||||
BnBpSolver::getJointByConditioning (const VarIds& jointVarIds) const
|
||||
{
|
||||
BnNodeSet jointVars;
|
||||
for (unsigned i = 0; i < jointVarIds.size(); i++) {
|
||||
assert (bayesNet_->getBayesNode (jointVarIds[i]));
|
||||
jointVars.push_back (bayesNet_->getBayesNode (jointVarIds[i]));
|
||||
}
|
||||
|
||||
BayesNet* mrn = bayesNet_->getMinimalRequesiteNetwork (jointVarIds[0]);
|
||||
BnBpSolver solver (*mrn);
|
||||
solver.runSolver();
|
||||
Params prevBeliefs = solver.getPosterioriOf (jointVarIds[0]);
|
||||
delete mrn;
|
||||
|
||||
VarIds observedVids = {jointVars[0]->varId()};
|
||||
|
||||
for (unsigned i = 1; i < jointVarIds.size(); i++) {
|
||||
assert (jointVars[i]->hasEvidence() == false);
|
||||
VarIds reqVars = {jointVarIds[i]};
|
||||
reqVars.insert (reqVars.end(), observedVids.begin(), observedVids.end());
|
||||
mrn = bayesNet_->getMinimalRequesiteNetwork (reqVars);
|
||||
Params newBeliefs;
|
||||
VarNodes observedVars;
|
||||
for (unsigned j = 0; j < observedVids.size(); j++) {
|
||||
observedVars.push_back (mrn->getBayesNode (observedVids[j]));
|
||||
}
|
||||
StatesIndexer idx (observedVars, false);
|
||||
while (idx.valid()) {
|
||||
for (unsigned j = 0; j < observedVars.size(); j++) {
|
||||
observedVars[j]->setEvidence (idx[j]);
|
||||
}
|
||||
BnBpSolver solver (*mrn);
|
||||
solver.runSolver();
|
||||
Params beliefs = solver.getPosterioriOf (jointVarIds[i]);
|
||||
for (unsigned k = 0; k < beliefs.size(); k++) {
|
||||
newBeliefs.push_back (beliefs[k]);
|
||||
}
|
||||
++ idx;
|
||||
}
|
||||
|
||||
int count = -1;
|
||||
for (unsigned j = 0; j < newBeliefs.size(); j++) {
|
||||
if (j % jointVars[i]->nrStates() == 0) {
|
||||
count ++;
|
||||
}
|
||||
newBeliefs[j] *= prevBeliefs[count];
|
||||
}
|
||||
prevBeliefs = newBeliefs;
|
||||
observedVids.push_back (jointVars[i]->varId());
|
||||
delete mrn;
|
||||
}
|
||||
return prevBeliefs;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BnBpSolver::printPiLambdaValues (const BayesNode* var) const
|
||||
{
|
||||
cout << left;
|
||||
cout << setw (10) << "states" ;
|
||||
cout << setw (20) << PI_SYMBOL << "(" + var->label() + ")" ;
|
||||
cout << setw (20) << LD_SYMBOL << "(" + var->label() + ")" ;
|
||||
cout << setw (16) << "belief" ;
|
||||
cout << endl;
|
||||
cout << "--------------------------------" ;
|
||||
cout << "--------------------------------" ;
|
||||
cout << endl;
|
||||
const States& states = var->states();
|
||||
const Params& piVals = ninf(var)->getPiValues();
|
||||
const Params& ldVals = ninf(var)->getLambdaValues();
|
||||
const Params& beliefs = ninf(var)->getBeliefs();
|
||||
for (unsigned xi = 0; xi < var->nrStates(); xi++) {
|
||||
cout << setw (10) << states[xi];
|
||||
cout << setw (19) << piVals[xi];
|
||||
cout << setw (19) << ldVals[xi];
|
||||
cout.precision (PRECISION);
|
||||
cout << setw (16) << beliefs[xi];
|
||||
cout << endl;
|
||||
}
|
||||
cout << endl;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BnBpSolver::printAllMessageStatus (void) const
|
||||
{
|
||||
const BnNodeSet& nodes = bayesNet_->getBayesNodes();
|
||||
for (unsigned i = 0; i < nodes.size(); i++) {
|
||||
printPiLambdaValues (nodes[i]);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
BpNodeInfo::BpNodeInfo (BayesNode* node)
|
||||
{
|
||||
node_ = node;
|
||||
piVals_.resize (node->nrStates(), Util::one());
|
||||
ldVals_.resize (node->nrStates(), Util::one());
|
||||
}
|
||||
|
||||
|
||||
|
||||
Params
|
||||
BpNodeInfo::getBeliefs (void) const
|
||||
{
|
||||
double sum = 0.0;
|
||||
Params beliefs (node_->nrStates());
|
||||
if (Globals::logDomain) {
|
||||
for (unsigned xi = 0; xi < node_->nrStates(); xi++) {
|
||||
beliefs[xi] = exp (piVals_[xi] + ldVals_[xi]);
|
||||
sum += beliefs[xi];
|
||||
}
|
||||
} else {
|
||||
for (unsigned xi = 0; xi < node_->nrStates(); xi++) {
|
||||
beliefs[xi] = piVals_[xi] * ldVals_[xi];
|
||||
sum += beliefs[xi];
|
||||
}
|
||||
}
|
||||
assert (sum);
|
||||
for (unsigned xi = 0; xi < node_->nrStates(); xi++) {
|
||||
beliefs[xi] /= sum;
|
||||
}
|
||||
return beliefs;
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
BpNodeInfo::receivedBottomInfluence (void) const
|
||||
{
|
||||
// if all lambda values are equal, then neither
|
||||
// this node neither its descendents have evidence,
|
||||
// we can use this to don't send lambda messages his parents
|
||||
bool childInfluenced = false;
|
||||
for (unsigned xi = 1; xi < node_->nrStates(); xi++) {
|
||||
if (ldVals_[xi] != ldVals_[0]) {
|
||||
childInfluenced = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
return childInfluenced;
|
||||
}
|
||||
|
@ -1,245 +0,0 @@
|
||||
#ifndef HORUS_BNBPSOLVER_H
|
||||
#define HORUS_BNBPSOLVER_H
|
||||
|
||||
#include <vector>
|
||||
#include <set>
|
||||
|
||||
#include "Solver.h"
|
||||
#include "BayesNet.h"
|
||||
#include "Horus.h"
|
||||
#include "Util.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
class BpNodeInfo;
|
||||
|
||||
static const string PI_SYMBOL = "pi" ;
|
||||
static const string LD_SYMBOL = "ld" ;
|
||||
|
||||
enum LinkOrientation {UP, DOWN};
|
||||
|
||||
class BpLink
|
||||
{
|
||||
public:
|
||||
BpLink (BayesNode* s, BayesNode* d, LinkOrientation o)
|
||||
{
|
||||
source_ = s;
|
||||
destin_ = d;
|
||||
orientation_ = o;
|
||||
if (orientation_ == LinkOrientation::DOWN) {
|
||||
v1_.resize (s->nrStates(), Util::tl (1.0 / s->nrStates()));
|
||||
v2_.resize (s->nrStates(), Util::tl (1.0 / s->nrStates()));
|
||||
} else {
|
||||
v1_.resize (d->nrStates(), Util::tl (1.0 / d->nrStates()));
|
||||
v2_.resize (d->nrStates(), Util::tl (1.0 / d->nrStates()));
|
||||
}
|
||||
currMsg_ = &v1_;
|
||||
nextMsg_ = &v2_;
|
||||
residual_ = 0;
|
||||
msgSended_ = false;
|
||||
}
|
||||
|
||||
void updateMessage (void)
|
||||
{
|
||||
swap (currMsg_, nextMsg_);
|
||||
msgSended_ = true;
|
||||
}
|
||||
|
||||
void updateResidual (void)
|
||||
{
|
||||
residual_ = Util::getMaxNorm (v1_, v2_);
|
||||
}
|
||||
|
||||
string toString (void) const
|
||||
{
|
||||
stringstream ss;
|
||||
if (orientation_ == LinkOrientation::DOWN) {
|
||||
ss << PI_SYMBOL;
|
||||
} else {
|
||||
ss << LD_SYMBOL;
|
||||
}
|
||||
ss << "(" << source_->label();
|
||||
ss << " --> " << destin_->label() << ")" ;
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
string toString (unsigned stateIndex) const
|
||||
{
|
||||
stringstream ss;
|
||||
ss << toString() << "[" ;
|
||||
if (orientation_ == LinkOrientation::DOWN) {
|
||||
ss << source_->states()[stateIndex] << "]" ;
|
||||
} else {
|
||||
ss << destin_->states()[stateIndex] << "]" ;
|
||||
}
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
BayesNode* getSource (void) const { return source_; }
|
||||
BayesNode* getDestination (void) const { return destin_; }
|
||||
LinkOrientation getOrientation (void) const { return orientation_; }
|
||||
const Params& getMessage (void) const { return *currMsg_; }
|
||||
Params& getNextMessage (void) { return *nextMsg_; }
|
||||
bool messageWasSended (void) const { return msgSended_; }
|
||||
double getResidual (void) const { return residual_; }
|
||||
void clearResidual (void) { residual_ = 0;}
|
||||
|
||||
private:
|
||||
BayesNode* source_;
|
||||
BayesNode* destin_;
|
||||
LinkOrientation orientation_;
|
||||
Params v1_;
|
||||
Params v2_;
|
||||
Params* currMsg_;
|
||||
Params* nextMsg_;
|
||||
bool msgSended_;
|
||||
double residual_;
|
||||
};
|
||||
|
||||
|
||||
typedef vector<BpLink*> BpLinkSet;
|
||||
|
||||
|
||||
class BpNodeInfo
|
||||
{
|
||||
public:
|
||||
BpNodeInfo (BayesNode*);
|
||||
|
||||
Params getBeliefs (void) const;
|
||||
bool receivedBottomInfluence (void) const;
|
||||
|
||||
Params& getPiValues (void) { return piVals_; }
|
||||
Params& getLambdaValues (void) { return ldVals_; }
|
||||
|
||||
const BpLinkSet& getIncomingParentLinks (void) { return inParentLinks_; }
|
||||
const BpLinkSet& getIncomingChildLinks (void) { return inChildLinks_; }
|
||||
const BpLinkSet& getOutcomingParentLinks (void) { return outParentLinks_; }
|
||||
const BpLinkSet& getOutcomingChildLinks (void) { return outChildLinks_; }
|
||||
|
||||
void addIncomingParentLink (BpLink* l) { inParentLinks_.push_back (l); }
|
||||
void addIncomingChildLink (BpLink* l) { inChildLinks_.push_back (l); }
|
||||
void addOutcomingParentLink (BpLink* l) { outParentLinks_.push_back (l); }
|
||||
void addOutcomingChildLink (BpLink* l) { outChildLinks_.push_back (l); }
|
||||
|
||||
private:
|
||||
DISALLOW_COPY_AND_ASSIGN (BpNodeInfo);
|
||||
|
||||
const BayesNode* node_;
|
||||
Params piVals_; // pi values
|
||||
Params ldVals_; // lambda values
|
||||
BpLinkSet inParentLinks_;
|
||||
BpLinkSet inChildLinks_;
|
||||
BpLinkSet outParentLinks_;
|
||||
BpLinkSet outChildLinks_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class BnBpSolver : public Solver
|
||||
{
|
||||
public:
|
||||
BnBpSolver (const BayesNet&);
|
||||
~BnBpSolver (void);
|
||||
|
||||
void runSolver (void);
|
||||
Params getPosterioriOf (VarId);
|
||||
Params getJointDistributionOf (const VarIds&);
|
||||
|
||||
|
||||
private:
|
||||
DISALLOW_COPY_AND_ASSIGN (BnBpSolver);
|
||||
|
||||
void initializeSolver (void);
|
||||
void runLoopySolver (void);
|
||||
void maxResidualSchedule (void);
|
||||
bool converged (void) const;
|
||||
void updatePiValues (BayesNode*);
|
||||
void updateLambdaValues (BayesNode*);
|
||||
void calculateLambdaMessage (BpLink*);
|
||||
void calculatePiMessage (BpLink*);
|
||||
Params getJointByJunctionNode (const VarIds&);
|
||||
Params getJointByConditioning (const VarIds&) const;
|
||||
void printPiLambdaValues (const BayesNode*) const;
|
||||
void printAllMessageStatus (void) const;
|
||||
|
||||
void calculateAndUpdateMessage (BpLink* link, bool calcResidual = true)
|
||||
{
|
||||
if (DL >= 3) {
|
||||
cout << "calculating & updating " << link->toString() << endl;
|
||||
}
|
||||
if (link->getOrientation() == LinkOrientation::DOWN) {
|
||||
calculatePiMessage (link);
|
||||
} else if (link->getOrientation() == LinkOrientation::UP) {
|
||||
calculateLambdaMessage (link);
|
||||
}
|
||||
if (calcResidual) {
|
||||
link->updateResidual();
|
||||
}
|
||||
link->updateMessage();
|
||||
}
|
||||
|
||||
void calculateMessage (BpLink* link, bool calcResidual = true)
|
||||
{
|
||||
if (DL >= 3) {
|
||||
cout << "calculating " << link->toString() << endl;
|
||||
}
|
||||
if (link->getOrientation() == LinkOrientation::DOWN) {
|
||||
calculatePiMessage (link);
|
||||
} else if (link->getOrientation() == LinkOrientation::UP) {
|
||||
calculateLambdaMessage (link);
|
||||
}
|
||||
if (calcResidual) {
|
||||
link->updateResidual();
|
||||
}
|
||||
}
|
||||
|
||||
void updateMessage (BpLink* link)
|
||||
{
|
||||
if (DL >= 3) {
|
||||
cout << "updating " << link->toString() << endl;
|
||||
}
|
||||
link->updateMessage();
|
||||
}
|
||||
|
||||
void updateValues (BpLink* link)
|
||||
{
|
||||
if (!link->getDestination()->hasEvidence()) {
|
||||
if (link->getOrientation() == LinkOrientation::DOWN) {
|
||||
updatePiValues (link->getDestination());
|
||||
} else if (link->getOrientation() == LinkOrientation::UP) {
|
||||
updateLambdaValues (link->getDestination());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
BpNodeInfo* ninf (const BayesNode* node) const
|
||||
{
|
||||
assert (node);
|
||||
assert (node == bayesNet_->getBayesNode (node->varId()));
|
||||
assert (node->getIndex() < nodesI_.size());
|
||||
return nodesI_[node->getIndex()];
|
||||
}
|
||||
|
||||
const BayesNet* bayesNet_;
|
||||
vector<BpLink*> links_;
|
||||
vector<BpNodeInfo*> nodesI_;
|
||||
unsigned nIters_;
|
||||
|
||||
struct compare
|
||||
{
|
||||
inline bool operator() (const BpLink* e1, const BpLink* e2)
|
||||
{
|
||||
return e1->getResidual() > e2->getResidual();
|
||||
}
|
||||
};
|
||||
|
||||
typedef multiset<BpLink*, compare> SortedOrder;
|
||||
SortedOrder sortedOrder_;
|
||||
|
||||
typedef unordered_map<BpLink*, SortedOrder::iterator> BpLinkMap;
|
||||
BpLinkMap linkMap_;
|
||||
|
||||
};
|
||||
|
||||
#endif // HORUS_BNBPSOLVER_H
|
||||
|
@ -5,21 +5,22 @@
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "FgBpSolver.h"
|
||||
#include "BpSolver.h"
|
||||
#include "FactorGraph.h"
|
||||
#include "Factor.h"
|
||||
#include "Indexer.h"
|
||||
#include "Horus.h"
|
||||
|
||||
|
||||
FgBpSolver::FgBpSolver (const FactorGraph& fg) : Solver (&fg)
|
||||
BpSolver::BpSolver (const FactorGraph& fg) : Solver (fg)
|
||||
{
|
||||
factorGraph_ = &fg;
|
||||
fg_ = &fg;
|
||||
runned_ = false;
|
||||
}
|
||||
|
||||
|
||||
|
||||
FgBpSolver::~FgBpSolver (void)
|
||||
BpSolver::~BpSolver (void)
|
||||
{
|
||||
for (unsigned i = 0; i < varsI_.size(); i++) {
|
||||
delete varsI_[i];
|
||||
@ -34,64 +35,45 @@ FgBpSolver::~FgBpSolver (void)
|
||||
|
||||
|
||||
|
||||
void
|
||||
FgBpSolver::runSolver (void)
|
||||
Params
|
||||
BpSolver::solveQuery (VarIds queryVids)
|
||||
{
|
||||
clock_t start;
|
||||
if (COLLECT_STATISTICS) {
|
||||
start = clock();
|
||||
}
|
||||
runLoopySolver();
|
||||
if (DL >= 2) {
|
||||
cout << endl;
|
||||
if (nIters_ < BpOptions::maxIter) {
|
||||
cout << "Sum-Product converged in " ;
|
||||
cout << nIters_ << " iterations" << endl;
|
||||
} else {
|
||||
cout << "The maximum number of iterations was hit, terminating..." ;
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
unsigned size = factorGraph_->getVarNodes().size();
|
||||
if (COLLECT_STATISTICS) {
|
||||
unsigned nIters = 0;
|
||||
bool loopy = factorGraph_->isTree() == false;
|
||||
if (loopy) nIters = nIters_;
|
||||
double time = (double (clock() - start)) / CLOCKS_PER_SEC;
|
||||
Statistics::updateStatistics (size, loopy, nIters, time);
|
||||
}
|
||||
if (EXPORT_TO_GRAPHVIZ && size > EXPORT_MINIMAL_SIZE) {
|
||||
stringstream ss;
|
||||
ss << Statistics::getSolvedNetworksCounting() << "." << size << ".dot" ;
|
||||
factorGraph_->exportToGraphViz (ss.str().c_str());
|
||||
assert (queryVids.empty() == false);
|
||||
if (queryVids.size() == 1) {
|
||||
return getPosterioriOf (queryVids[0]);
|
||||
} else {
|
||||
return getJointDistributionOf (queryVids);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
Params
|
||||
FgBpSolver::getPosterioriOf (VarId vid)
|
||||
BpSolver::getPosterioriOf (VarId vid)
|
||||
{
|
||||
assert (factorGraph_->getFgVarNode (vid));
|
||||
FgVarNode* var = factorGraph_->getFgVarNode (vid);
|
||||
if (runned_ == false) {
|
||||
runSolver();
|
||||
}
|
||||
assert (fg_->getVarNode (vid));
|
||||
VarNode* var = fg_->getVarNode (vid);
|
||||
Params probs;
|
||||
if (var->hasEvidence()) {
|
||||
probs.resize (var->nrStates(), Util::noEvidence());
|
||||
probs[var->getEvidence()] = Util::withEvidence();
|
||||
probs.resize (var->range(), LogAware::noEvidence());
|
||||
probs[var->getEvidence()] = LogAware::withEvidence();
|
||||
} else {
|
||||
probs.resize (var->nrStates(), Util::multIdenty());
|
||||
probs.resize (var->range(), LogAware::multIdenty());
|
||||
const SpLinkSet& links = ninf(var)->getLinks();
|
||||
if (Globals::logDomain) {
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
Util::add (probs, links[i]->getMessage());
|
||||
}
|
||||
Util::normalize (probs);
|
||||
LogAware::normalize (probs);
|
||||
Util::fromLog (probs);
|
||||
} else {
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
Util::multiply (probs, links[i]->getMessage());
|
||||
}
|
||||
Util::normalize (probs);
|
||||
LogAware::normalize (probs);
|
||||
}
|
||||
}
|
||||
return probs;
|
||||
@ -100,13 +82,16 @@ FgBpSolver::getPosterioriOf (VarId vid)
|
||||
|
||||
|
||||
Params
|
||||
FgBpSolver::getJointDistributionOf (const VarIds& jointVarIds)
|
||||
BpSolver::getJointDistributionOf (const VarIds& jointVarIds)
|
||||
{
|
||||
FgVarNode* vn = factorGraph_->getFgVarNode (jointVarIds[0]);
|
||||
const FgFacSet& factorNodes = vn->neighbors();
|
||||
if (runned_ == false) {
|
||||
runSolver();
|
||||
}
|
||||
int idx = -1;
|
||||
for (unsigned i = 0; i < factorNodes.size(); i++) {
|
||||
if (factorNodes[i]->factor()->contains (jointVarIds)) {
|
||||
VarNode* vn = fg_->getVarNode (jointVarIds[0]);
|
||||
const FacNodes& facNodes = vn->neighbors();
|
||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||
if (facNodes[i]->factor().contains (jointVarIds)) {
|
||||
idx = i;
|
||||
break;
|
||||
}
|
||||
@ -114,18 +99,18 @@ FgBpSolver::getJointDistributionOf (const VarIds& jointVarIds)
|
||||
if (idx == -1) {
|
||||
return getJointByConditioning (jointVarIds);
|
||||
} else {
|
||||
Factor r (*factorNodes[idx]->factor());
|
||||
const SpLinkSet& links = ninf(factorNodes[idx])->getLinks();
|
||||
Factor res (facNodes[idx]->factor());
|
||||
const SpLinkSet& links = ninf(facNodes[idx])->getLinks();
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
Factor msg (links[i]->getVariable()->varId(),
|
||||
links[i]->getVariable()->nrStates(),
|
||||
Factor msg ({links[i]->getVariable()->varId()},
|
||||
{links[i]->getVariable()->range()},
|
||||
getVar2FactorMsg (links[i]));
|
||||
r.multiply (msg);
|
||||
res.multiply (msg);
|
||||
}
|
||||
r.sumOutAllExcept (jointVarIds);
|
||||
r.reorderVariables (jointVarIds);
|
||||
r.normalize();
|
||||
Params jointDist = r.getParameters();
|
||||
res.sumOutAllExcept (jointVarIds);
|
||||
res.reorderArguments (jointVarIds);
|
||||
res.normalize();
|
||||
Params jointDist = res.params();
|
||||
if (Globals::logDomain) {
|
||||
Util::fromLog (jointDist);
|
||||
}
|
||||
@ -136,35 +121,29 @@ FgBpSolver::getJointDistributionOf (const VarIds& jointVarIds)
|
||||
|
||||
|
||||
void
|
||||
FgBpSolver::runLoopySolver (void)
|
||||
BpSolver::runSolver (void)
|
||||
{
|
||||
clock_t start;
|
||||
if (Constants::COLLECT_STATS) {
|
||||
start = clock();
|
||||
}
|
||||
initializeSolver();
|
||||
nIters_ = 0;
|
||||
|
||||
while (!converged() && nIters_ < BpOptions::maxIter) {
|
||||
|
||||
nIters_ ++;
|
||||
if (DL >= 2) {
|
||||
cout << "****************************************" ;
|
||||
cout << "****************************************" ;
|
||||
cout << endl;
|
||||
cout << " Iteration " << nIters_ << endl;
|
||||
cout << "****************************************" ;
|
||||
cout << "****************************************" ;
|
||||
cout << endl;
|
||||
if (Constants::DEBUG >= 2) {
|
||||
Util::printHeader (string ("Iteration ") + Util::toString (nIters_));
|
||||
// cout << endl;
|
||||
}
|
||||
|
||||
switch (BpOptions::schedule) {
|
||||
case BpOptions::Schedule::SEQ_RANDOM:
|
||||
random_shuffle (links_.begin(), links_.end());
|
||||
// no break
|
||||
|
||||
case BpOptions::Schedule::SEQ_RANDOM:
|
||||
random_shuffle (links_.begin(), links_.end());
|
||||
// no break
|
||||
case BpOptions::Schedule::SEQ_FIXED:
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
calculateAndUpdateMessage (links_[i]);
|
||||
}
|
||||
break;
|
||||
|
||||
case BpOptions::Schedule::PARALLEL:
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
calculateMessage (links_[i]);
|
||||
@ -173,61 +152,43 @@ FgBpSolver::runLoopySolver (void)
|
||||
updateMessage(links_[i]);
|
||||
}
|
||||
break;
|
||||
|
||||
case BpOptions::Schedule::MAX_RESIDUAL:
|
||||
maxResidualSchedule();
|
||||
break;
|
||||
}
|
||||
if (DL >= 2) {
|
||||
if (Constants::DEBUG >= 2) {
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
if (Constants::DEBUG >= 2) {
|
||||
cout << endl;
|
||||
if (nIters_ < BpOptions::maxIter) {
|
||||
cout << "Sum-Product converged in " ;
|
||||
cout << nIters_ << " iterations" << endl;
|
||||
} else {
|
||||
cout << "The maximum number of iterations was hit, terminating..." ;
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
unsigned size = fg_->varNodes().size();
|
||||
if (Constants::COLLECT_STATS) {
|
||||
unsigned nIters = 0;
|
||||
bool loopy = fg_->isTree() == false;
|
||||
if (loopy) nIters = nIters_;
|
||||
double time = (double (clock() - start)) / CLOCKS_PER_SEC;
|
||||
Statistics::updateStatistics (size, loopy, nIters, time);
|
||||
}
|
||||
runned_ = true;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FgBpSolver::initializeSolver (void)
|
||||
BpSolver::createLinks (void)
|
||||
{
|
||||
const FgVarSet& varNodes = factorGraph_->getVarNodes();
|
||||
for (unsigned i = 0; i < varsI_.size(); i++) {
|
||||
delete varsI_[i];
|
||||
}
|
||||
varsI_.reserve (varNodes.size());
|
||||
for (unsigned i = 0; i < varNodes.size(); i++) {
|
||||
varsI_.push_back (new SPNodeInfo());
|
||||
}
|
||||
|
||||
const FgFacSet& facNodes = factorGraph_->getFactorNodes();
|
||||
for (unsigned i = 0; i < facsI_.size(); i++) {
|
||||
delete facsI_[i];
|
||||
}
|
||||
facsI_.reserve (facNodes.size());
|
||||
const FacNodes& facNodes = fg_->facNodes();
|
||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||
facsI_.push_back (new SPNodeInfo());
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
delete links_[i];
|
||||
}
|
||||
createLinks();
|
||||
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
FgFacNode* src = links_[i]->getFactor();
|
||||
FgVarNode* dst = links_[i]->getVariable();
|
||||
ninf (dst)->addSpLink (links_[i]);
|
||||
ninf (src)->addSpLink (links_[i]);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FgBpSolver::createLinks (void)
|
||||
{
|
||||
const FgFacSet& facNodes = factorGraph_->getFactorNodes();
|
||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||
const FgVarSet& neighbors = facNodes[i]->neighbors();
|
||||
const VarNodes& neighbors = facNodes[i]->neighbors();
|
||||
for (unsigned j = 0; j < neighbors.size(); j++) {
|
||||
links_.push_back (new SpLink (facNodes[i], neighbors[j]));
|
||||
}
|
||||
@ -236,42 +197,8 @@ FgBpSolver::createLinks (void)
|
||||
|
||||
|
||||
|
||||
bool
|
||||
FgBpSolver::converged (void)
|
||||
{
|
||||
if (links_.size() == 0) {
|
||||
return true;
|
||||
}
|
||||
if (nIters_ == 0 || nIters_ == 1) {
|
||||
return false;
|
||||
}
|
||||
bool converged = true;
|
||||
if (BpOptions::schedule == BpOptions::Schedule::MAX_RESIDUAL) {
|
||||
double maxResidual = (*(sortedOrder_.begin()))->getResidual();
|
||||
if (maxResidual > BpOptions::accuracy) {
|
||||
converged = false;
|
||||
} else {
|
||||
converged = true;
|
||||
}
|
||||
} else {
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
double residual = links_[i]->getResidual();
|
||||
if (DL >= 2) {
|
||||
cout << links_[i]->toString() + " residual = " << residual << endl;
|
||||
}
|
||||
if (residual > BpOptions::accuracy) {
|
||||
converged = false;
|
||||
if (DL == 0) break;
|
||||
}
|
||||
}
|
||||
}
|
||||
return converged;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FgBpSolver::maxResidualSchedule (void)
|
||||
BpSolver::maxResidualSchedule (void)
|
||||
{
|
||||
if (nIters_ == 1) {
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
@ -283,7 +210,7 @@ FgBpSolver::maxResidualSchedule (void)
|
||||
}
|
||||
|
||||
for (unsigned c = 0; c < links_.size(); c++) {
|
||||
if (DL >= 2) {
|
||||
if (Constants::DEBUG >= 2) {
|
||||
cout << "current residuals:" << endl;
|
||||
for (SortedOrder::iterator it = sortedOrder_.begin();
|
||||
it != sortedOrder_.end(); it ++) {
|
||||
@ -303,7 +230,7 @@ FgBpSolver::maxResidualSchedule (void)
|
||||
linkMap_.find (link)->second = sortedOrder_.insert (link);
|
||||
|
||||
// update the messages that depend on message source --> destin
|
||||
const FgFacSet& factorNeighbors = link->getVariable()->neighbors();
|
||||
const FacNodes& factorNeighbors = link->getVariable()->neighbors();
|
||||
for (unsigned i = 0; i < factorNeighbors.size(); i++) {
|
||||
if (factorNeighbors[i] != link->getFactor()) {
|
||||
const SpLinkSet& links = ninf(factorNeighbors[i])->getLinks();
|
||||
@ -317,9 +244,8 @@ FgBpSolver::maxResidualSchedule (void)
|
||||
}
|
||||
}
|
||||
}
|
||||
if (DL >= 2) {
|
||||
cout << "----------------------------------------" ;
|
||||
cout << "----------------------------------------" << endl;
|
||||
if (Constants::DEBUG >= 2) {
|
||||
Util::printDashedLine();
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -327,26 +253,26 @@ FgBpSolver::maxResidualSchedule (void)
|
||||
|
||||
|
||||
void
|
||||
FgBpSolver::calculateFactor2VariableMsg (SpLink* link) const
|
||||
BpSolver::calculateFactor2VariableMsg (SpLink* link)
|
||||
{
|
||||
const FgFacNode* src = link->getFactor();
|
||||
const FgVarNode* dst = link->getVariable();
|
||||
FacNode* src = link->getFactor();
|
||||
const VarNode* dst = link->getVariable();
|
||||
const SpLinkSet& links = ninf(src)->getLinks();
|
||||
// calculate the product of messages that were sent
|
||||
// to factor `src', except from var `dst'
|
||||
unsigned msgSize = 1;
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
msgSize *= links[i]->getVariable()->nrStates();
|
||||
msgSize *= links[i]->getVariable()->range();
|
||||
}
|
||||
unsigned repetitions = 1;
|
||||
Params msgProduct (msgSize, Util::multIdenty());
|
||||
Params msgProduct (msgSize, LogAware::multIdenty());
|
||||
if (Globals::logDomain) {
|
||||
for (int i = links.size() - 1; i >= 0; i--) {
|
||||
if (links[i]->getVariable() != dst) {
|
||||
Util::add (msgProduct, getVar2FactorMsg (links[i]), repetitions);
|
||||
repetitions *= links[i]->getVariable()->nrStates();
|
||||
repetitions *= links[i]->getVariable()->range();
|
||||
} else {
|
||||
unsigned ds = links[i]->getVariable()->nrStates();
|
||||
unsigned ds = links[i]->getVariable()->range();
|
||||
Util::add (msgProduct, Params (ds, 1.0), repetitions);
|
||||
repetitions *= ds;
|
||||
}
|
||||
@ -354,70 +280,64 @@ FgBpSolver::calculateFactor2VariableMsg (SpLink* link) const
|
||||
} else {
|
||||
for (int i = links.size() - 1; i >= 0; i--) {
|
||||
if (links[i]->getVariable() != dst) {
|
||||
if (DL >= 5) {
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << " message from " << links[i]->getVariable()->label();
|
||||
cout << ": " << endl;
|
||||
}
|
||||
Util::multiply (msgProduct, getVar2FactorMsg (links[i]), repetitions);
|
||||
repetitions *= links[i]->getVariable()->nrStates();
|
||||
repetitions *= links[i]->getVariable()->range();
|
||||
} else {
|
||||
unsigned ds = links[i]->getVariable()->nrStates();
|
||||
unsigned ds = links[i]->getVariable()->range();
|
||||
Util::multiply (msgProduct, Params (ds, 1.0), repetitions);
|
||||
repetitions *= ds;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Factor result (src->factor()->getVarIds(),
|
||||
src->factor()->getRanges(),
|
||||
msgProduct);
|
||||
result.multiply (*(src->factor()));
|
||||
if (DL >= 5) {
|
||||
cout << " message product: " ;
|
||||
cout << Util::parametersToString (msgProduct) << endl;
|
||||
cout << " original factor: " ;
|
||||
cout << Util::parametersToString (src->getParameters()) << endl;
|
||||
cout << " factor product: " ;
|
||||
cout << Util::parametersToString (result.getParameters()) << endl;
|
||||
Factor result (src->factor().arguments(),
|
||||
src->factor().ranges(), msgProduct);
|
||||
result.multiply (src->factor());
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << " message product: " << msgProduct << endl;
|
||||
cout << " original factor: " << src->factor().params() << endl;
|
||||
cout << " factor product: " << result.params() << endl;
|
||||
}
|
||||
result.sumOutAllExcept (dst->varId());
|
||||
if (DL >= 5) {
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << " marginalized: " ;
|
||||
cout << Util::parametersToString (result.getParameters()) << endl;
|
||||
cout << result.params() << endl;
|
||||
}
|
||||
const Params& resultParams = result.getParameters();
|
||||
const Params& resultParams = result.params();
|
||||
Params& message = link->getNextMessage();
|
||||
for (unsigned i = 0; i < resultParams.size(); i++) {
|
||||
message[i] = resultParams[i];
|
||||
}
|
||||
Util::normalize (message);
|
||||
if (DL >= 5) {
|
||||
cout << " curr msg: " ;
|
||||
cout << Util::parametersToString (link->getMessage()) << endl;
|
||||
cout << " next msg: " ;
|
||||
cout << Util::parametersToString (message) << endl;
|
||||
LogAware::normalize (message);
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << " curr msg: " << link->getMessage() << endl;
|
||||
cout << " next msg: " << message << endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
Params
|
||||
FgBpSolver::getVar2FactorMsg (const SpLink* link) const
|
||||
BpSolver::getVar2FactorMsg (const SpLink* link) const
|
||||
{
|
||||
const FgVarNode* src = link->getVariable();
|
||||
const FgFacNode* dst = link->getFactor();
|
||||
const VarNode* src = link->getVariable();
|
||||
const FacNode* dst = link->getFactor();
|
||||
Params msg;
|
||||
if (src->hasEvidence()) {
|
||||
msg.resize (src->nrStates(), Util::noEvidence());
|
||||
msg[src->getEvidence()] = Util::withEvidence();
|
||||
if (DL >= 5) {
|
||||
cout << Util::parametersToString (msg);
|
||||
msg.resize (src->range(), LogAware::noEvidence());
|
||||
msg[src->getEvidence()] = LogAware::withEvidence();
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << msg;
|
||||
}
|
||||
} else {
|
||||
msg.resize (src->nrStates(), Util::one());
|
||||
msg.resize (src->range(), LogAware::one());
|
||||
}
|
||||
if (DL >= 5) {
|
||||
cout << Util::parametersToString (msg);
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << msg;
|
||||
}
|
||||
const SpLinkSet& links = ninf (src)->getLinks();
|
||||
if (Globals::logDomain) {
|
||||
@ -430,14 +350,14 @@ FgBpSolver::getVar2FactorMsg (const SpLink* link) const
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
if (links[i]->getFactor() != dst) {
|
||||
Util::multiply (msg, links[i]->getMessage());
|
||||
if (DL >= 5) {
|
||||
cout << " x " << Util::parametersToString (links[i]->getMessage());
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << " x " << links[i]->getMessage();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (DL >= 5) {
|
||||
cout << " = " << Util::parametersToString (msg);
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << " = " << msg;
|
||||
}
|
||||
return msg;
|
||||
}
|
||||
@ -445,16 +365,16 @@ FgBpSolver::getVar2FactorMsg (const SpLink* link) const
|
||||
|
||||
|
||||
Params
|
||||
FgBpSolver::getJointByConditioning (const VarIds& jointVarIds) const
|
||||
BpSolver::getJointByConditioning (const VarIds& jointVarIds) const
|
||||
{
|
||||
FgVarSet jointVars;
|
||||
VarNodes jointVars;
|
||||
for (unsigned i = 0; i < jointVarIds.size(); i++) {
|
||||
assert (factorGraph_->getFgVarNode (jointVarIds[i]));
|
||||
jointVars.push_back (factorGraph_->getFgVarNode (jointVarIds[i]));
|
||||
assert (fg_->getVarNode (jointVarIds[i]));
|
||||
jointVars.push_back (fg_->getVarNode (jointVarIds[i]));
|
||||
}
|
||||
|
||||
FactorGraph* fg = new FactorGraph (*factorGraph_);
|
||||
FgBpSolver solver (*fg);
|
||||
FactorGraph* fg = new FactorGraph (*fg_);
|
||||
BpSolver solver (*fg);
|
||||
solver.runSolver();
|
||||
Params prevBeliefs = solver.getPosterioriOf (jointVarIds[0]);
|
||||
|
||||
@ -463,9 +383,9 @@ FgBpSolver::getJointByConditioning (const VarIds& jointVarIds) const
|
||||
for (unsigned i = 1; i < jointVarIds.size(); i++) {
|
||||
assert (jointVars[i]->hasEvidence() == false);
|
||||
Params newBeliefs;
|
||||
VarNodes observedVars;
|
||||
Vars observedVars;
|
||||
for (unsigned j = 0; j < observedVids.size(); j++) {
|
||||
observedVars.push_back (fg->getFgVarNode (observedVids[j]));
|
||||
observedVars.push_back (fg->getVarNode (observedVids[j]));
|
||||
}
|
||||
StatesIndexer idx (observedVars, false);
|
||||
while (idx.valid()) {
|
||||
@ -473,7 +393,7 @@ FgBpSolver::getJointByConditioning (const VarIds& jointVarIds) const
|
||||
observedVars[j]->setEvidence (idx[j]);
|
||||
}
|
||||
++ idx;
|
||||
FgBpSolver solver (*fg);
|
||||
BpSolver solver (*fg);
|
||||
solver.runSolver();
|
||||
Params beliefs = solver.getPosterioriOf (jointVarIds[i]);
|
||||
for (unsigned k = 0; k < beliefs.size(); k++) {
|
||||
@ -483,7 +403,7 @@ FgBpSolver::getJointByConditioning (const VarIds& jointVarIds) const
|
||||
|
||||
int count = -1;
|
||||
for (unsigned j = 0; j < newBeliefs.size(); j++) {
|
||||
if (j % jointVars[i]->nrStates() == 0) {
|
||||
if (j % jointVars[i]->range() == 0) {
|
||||
count ++;
|
||||
}
|
||||
newBeliefs[j] *= prevBeliefs[count];
|
||||
@ -497,15 +417,76 @@ FgBpSolver::getJointByConditioning (const VarIds& jointVarIds) const
|
||||
|
||||
|
||||
void
|
||||
FgBpSolver::printLinkInformation (void) const
|
||||
BpSolver::initializeSolver (void)
|
||||
{
|
||||
const VarNodes& varNodes = fg_->varNodes();
|
||||
varsI_.reserve (varNodes.size());
|
||||
for (unsigned i = 0; i < varNodes.size(); i++) {
|
||||
varsI_.push_back (new SPNodeInfo());
|
||||
}
|
||||
const FacNodes& facNodes = fg_->facNodes();
|
||||
facsI_.reserve (facNodes.size());
|
||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||
facsI_.push_back (new SPNodeInfo());
|
||||
}
|
||||
createLinks();
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
FacNode* src = links_[i]->getFactor();
|
||||
VarNode* dst = links_[i]->getVariable();
|
||||
ninf (dst)->addSpLink (links_[i]);
|
||||
ninf (src)->addSpLink (links_[i]);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
BpSolver::converged (void)
|
||||
{
|
||||
if (links_.size() == 0) {
|
||||
return true;
|
||||
}
|
||||
if (nIters_ <= 1) {
|
||||
return false;
|
||||
}
|
||||
bool converged = true;
|
||||
if (BpOptions::schedule == BpOptions::Schedule::MAX_RESIDUAL) {
|
||||
double maxResidual = (*(sortedOrder_.begin()))->getResidual();
|
||||
if (maxResidual > BpOptions::accuracy) {
|
||||
converged = false;
|
||||
} else {
|
||||
converged = true;
|
||||
}
|
||||
} else {
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
double residual = links_[i]->getResidual();
|
||||
if (Constants::DEBUG >= 2) {
|
||||
cout << links_[i]->toString() + " residual = " << residual << endl;
|
||||
}
|
||||
if (residual > BpOptions::accuracy) {
|
||||
converged = false;
|
||||
if (Constants::DEBUG == 0) break;
|
||||
}
|
||||
}
|
||||
if (Constants::DEBUG >= 2) {
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
return converged;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BpSolver::printLinkInformation (void) const
|
||||
{
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
SpLink* l = links_[i];
|
||||
cout << l->toString() << ":" << endl;
|
||||
cout << " curr msg = " ;
|
||||
cout << Util::parametersToString (l->getMessage()) << endl;
|
||||
cout << l->getMessage() << endl;
|
||||
cout << " next msg = " ;
|
||||
cout << Util::parametersToString (l->getNextMessage()) << endl;
|
||||
cout << l->getNextMessage() << endl;
|
||||
cout << " residual = " << l->getResidual() << endl;
|
||||
}
|
||||
}
|
188
packages/CLPBN/clpbn/bp/BpSolver.h
Normal file
188
packages/CLPBN/clpbn/bp/BpSolver.h
Normal file
@ -0,0 +1,188 @@
|
||||
#ifndef HORUS_BPSOLVER_H
|
||||
#define HORUS_BPSOLVER_H
|
||||
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include <sstream>
|
||||
|
||||
#include "Solver.h"
|
||||
#include "Factor.h"
|
||||
#include "FactorGraph.h"
|
||||
#include "Util.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
|
||||
class SpLink
|
||||
{
|
||||
public:
|
||||
SpLink (FacNode* fn, VarNode* vn)
|
||||
{
|
||||
fac_ = fn;
|
||||
var_ = vn;
|
||||
v1_.resize (vn->range(), LogAware::tl (1.0 / vn->range()));
|
||||
v2_.resize (vn->range(), LogAware::tl (1.0 / vn->range()));
|
||||
currMsg_ = &v1_;
|
||||
nextMsg_ = &v2_;
|
||||
msgSended_ = false;
|
||||
residual_ = 0.0;
|
||||
}
|
||||
|
||||
virtual ~SpLink (void) { };
|
||||
|
||||
FacNode* getFactor (void) const { return fac_; }
|
||||
|
||||
VarNode* getVariable (void) const { return var_; }
|
||||
|
||||
const Params& getMessage (void) const { return *currMsg_; }
|
||||
|
||||
Params& getNextMessage (void) { return *nextMsg_; }
|
||||
|
||||
bool messageWasSended (void) const { return msgSended_; }
|
||||
|
||||
double getResidual (void) const { return residual_; }
|
||||
|
||||
void clearResidual (void) { residual_ = 0.0; }
|
||||
|
||||
void updateResidual (void)
|
||||
{
|
||||
residual_ = LogAware::getMaxNorm (v1_,v2_);
|
||||
}
|
||||
|
||||
virtual void updateMessage (void)
|
||||
{
|
||||
swap (currMsg_, nextMsg_);
|
||||
msgSended_ = true;
|
||||
}
|
||||
|
||||
string toString (void) const
|
||||
{
|
||||
stringstream ss;
|
||||
ss << fac_->getLabel();
|
||||
ss << " -- " ;
|
||||
ss << var_->label();
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
protected:
|
||||
FacNode* fac_;
|
||||
VarNode* var_;
|
||||
Params v1_;
|
||||
Params v2_;
|
||||
Params* currMsg_;
|
||||
Params* nextMsg_;
|
||||
bool msgSended_;
|
||||
double residual_;
|
||||
};
|
||||
|
||||
typedef vector<SpLink*> SpLinkSet;
|
||||
|
||||
|
||||
class SPNodeInfo
|
||||
{
|
||||
public:
|
||||
void addSpLink (SpLink* link) { links_.push_back (link); }
|
||||
const SpLinkSet& getLinks (void) { return links_; }
|
||||
private:
|
||||
SpLinkSet links_;
|
||||
};
|
||||
|
||||
|
||||
class BpSolver : public Solver
|
||||
{
|
||||
public:
|
||||
BpSolver (const FactorGraph&);
|
||||
|
||||
virtual ~BpSolver (void);
|
||||
|
||||
Params solveQuery (VarIds);
|
||||
|
||||
virtual Params getPosterioriOf (VarId);
|
||||
|
||||
virtual Params getJointDistributionOf (const VarIds&);
|
||||
|
||||
protected:
|
||||
void runSolver (void);
|
||||
|
||||
virtual void createLinks (void);
|
||||
|
||||
virtual void maxResidualSchedule (void);
|
||||
|
||||
virtual void calculateFactor2VariableMsg (SpLink*);
|
||||
|
||||
virtual Params getVar2FactorMsg (const SpLink*) const;
|
||||
|
||||
virtual Params getJointByConditioning (const VarIds&) const;
|
||||
|
||||
SPNodeInfo* ninf (const VarNode* var) const
|
||||
{
|
||||
return varsI_[var->getIndex()];
|
||||
}
|
||||
|
||||
SPNodeInfo* ninf (const FacNode* fac) const
|
||||
{
|
||||
return facsI_[fac->getIndex()];
|
||||
}
|
||||
|
||||
void calculateAndUpdateMessage (SpLink* link, bool calcResidual = true)
|
||||
{
|
||||
if (Constants::DEBUG >= 3) {
|
||||
cout << "calculating & updating " << link->toString() << endl;
|
||||
}
|
||||
calculateFactor2VariableMsg (link);
|
||||
if (calcResidual) {
|
||||
link->updateResidual();
|
||||
}
|
||||
link->updateMessage();
|
||||
}
|
||||
|
||||
void calculateMessage (SpLink* link, bool calcResidual = true)
|
||||
{
|
||||
if (Constants::DEBUG >= 3) {
|
||||
cout << "calculating " << link->toString() << endl;
|
||||
}
|
||||
calculateFactor2VariableMsg (link);
|
||||
if (calcResidual) {
|
||||
link->updateResidual();
|
||||
}
|
||||
}
|
||||
|
||||
void updateMessage (SpLink* link)
|
||||
{
|
||||
link->updateMessage();
|
||||
if (Constants::DEBUG >= 3) {
|
||||
cout << "updating " << link->toString() << endl;
|
||||
}
|
||||
}
|
||||
|
||||
struct CompareResidual
|
||||
{
|
||||
inline bool operator() (const SpLink* link1, const SpLink* link2)
|
||||
{
|
||||
return link1->getResidual() > link2->getResidual();
|
||||
}
|
||||
};
|
||||
|
||||
SpLinkSet links_;
|
||||
unsigned nIters_;
|
||||
vector<SPNodeInfo*> varsI_;
|
||||
vector<SPNodeInfo*> facsI_;
|
||||
bool runned_;
|
||||
const FactorGraph* fg_;
|
||||
|
||||
typedef multiset<SpLink*, CompareResidual> SortedOrder;
|
||||
SortedOrder sortedOrder_;
|
||||
|
||||
typedef unordered_map<SpLink*, SortedOrder::iterator> SpLinkMap;
|
||||
SpLinkMap linkMap_;
|
||||
|
||||
private:
|
||||
void initializeSolver (void);
|
||||
|
||||
bool converged (void);
|
||||
|
||||
void printLinkInformation (void) const;
|
||||
};
|
||||
|
||||
#endif // HORUS_BPSOLVER_H
|
||||
|
@ -1,7 +1,6 @@
|
||||
|
||||
#include "CFactorGraph.h"
|
||||
#include "Factor.h"
|
||||
#include "Distribution.h"
|
||||
|
||||
|
||||
bool CFactorGraph::checkForIdenticalFactors = true;
|
||||
@ -11,22 +10,22 @@ CFactorGraph::CFactorGraph (const FactorGraph& fg)
|
||||
groundFg_ = &fg;
|
||||
freeColor_ = 0;
|
||||
|
||||
const FgVarSet& varNodes = fg.getVarNodes();
|
||||
const VarNodes& varNodes = fg.varNodes();
|
||||
varSignatures_.reserve (varNodes.size());
|
||||
for (unsigned i = 0; i < varNodes.size(); i++) {
|
||||
unsigned c = (varNodes[i]->neighbors().size() * 2) + 1;
|
||||
varSignatures_.push_back (Signature (c));
|
||||
}
|
||||
|
||||
const FgFacSet& facNodes = fg.getFactorNodes();
|
||||
factorSignatures_.reserve (facNodes.size());
|
||||
const FacNodes& facNodes = fg.facNodes();
|
||||
facSignatures_.reserve (facNodes.size());
|
||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||
unsigned c = facNodes[i]->neighbors().size() + 1;
|
||||
factorSignatures_.push_back (Signature (c));
|
||||
facSignatures_.push_back (Signature (c));
|
||||
}
|
||||
|
||||
varColors_.resize (varNodes.size());
|
||||
factorColors_.resize (facNodes.size());
|
||||
facColors_.resize (facNodes.size());
|
||||
setInitialColors();
|
||||
createGroups();
|
||||
}
|
||||
@ -50,9 +49,9 @@ CFactorGraph::setInitialColors (void)
|
||||
{
|
||||
// create the initial variable colors
|
||||
VarColorMap colorMap;
|
||||
const FgVarSet& varNodes = groundFg_->getVarNodes();
|
||||
const VarNodes& varNodes = groundFg_->varNodes();
|
||||
for (unsigned i = 0; i < varNodes.size(); i++) {
|
||||
unsigned dsize = varNodes[i]->nrStates();
|
||||
unsigned dsize = varNodes[i]->range();
|
||||
VarColorMap::iterator it = colorMap.find (dsize);
|
||||
if (it == colorMap.end()) {
|
||||
it = colorMap.insert (make_pair (
|
||||
@ -71,29 +70,40 @@ CFactorGraph::setInitialColors (void)
|
||||
setColor (varNodes[i], stateColors[idx]);
|
||||
}
|
||||
|
||||
const FgFacSet& facNodes = groundFg_->getFactorNodes();
|
||||
if (checkForIdenticalFactors) {
|
||||
for (unsigned i = 0, s = facNodes.size(); i < s; i++) {
|
||||
Distribution* dist1 = facNodes[i]->getDistribution();
|
||||
for (unsigned j = 0; j < i; j++) {
|
||||
Distribution* dist2 = facNodes[j]->getDistribution();
|
||||
if (dist1 != dist2 && dist1->params == dist2->params) {
|
||||
if (facNodes[i]->factor()->getRanges() ==
|
||||
facNodes[j]->factor()->getRanges()) {
|
||||
facNodes[i]->factor()->setDistribution (dist2);
|
||||
}
|
||||
const FacNodes& facNodes = groundFg_->facNodes();
|
||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||
facNodes[i]->factor().setDistId (Util::maxUnsigned());
|
||||
}
|
||||
// FIXME FIXME FIXME : pfl should give correct dist ids.
|
||||
if (checkForIdenticalFactors || true) {
|
||||
unsigned groupCount = 1;
|
||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||
Factor& f1 = facNodes[i]->factor();
|
||||
if (f1.distId() != Util::maxUnsigned()) {
|
||||
continue;
|
||||
}
|
||||
f1.setDistId (groupCount);
|
||||
for (unsigned j = i + 1; j < facNodes.size(); j++) {
|
||||
Factor& f2 = facNodes[j]->factor();
|
||||
if (f2.distId() != Util::maxUnsigned()) {
|
||||
continue;
|
||||
}
|
||||
if (f1.size() == f2.size() &&
|
||||
f1.ranges() == f2.ranges() &&
|
||||
f1.params() == f2.params()) {
|
||||
f2.setDistId (groupCount);
|
||||
}
|
||||
}
|
||||
groupCount ++;
|
||||
}
|
||||
}
|
||||
|
||||
// create the initial factor colors
|
||||
DistColorMap distColors;
|
||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||
const Distribution* dist = facNodes[i]->getDistribution();
|
||||
DistColorMap::iterator it = distColors.find (dist);
|
||||
unsigned distId = facNodes[i]->factor().distId();
|
||||
DistColorMap::iterator it = distColors.find (distId);
|
||||
if (it == distColors.end()) {
|
||||
it = distColors.insert (make_pair (dist, getFreeColor())).first;
|
||||
it = distColors.insert (make_pair (distId, getFreeColor())).first;
|
||||
}
|
||||
setColor (facNodes[i], it->second);
|
||||
}
|
||||
@ -104,31 +114,31 @@ CFactorGraph::setInitialColors (void)
|
||||
void
|
||||
CFactorGraph::createGroups (void)
|
||||
{
|
||||
VarSignMap varGroups;
|
||||
FacSignMap factorGroups;
|
||||
VarSignMap varGroups;
|
||||
FacSignMap facGroups;
|
||||
unsigned nIters = 0;
|
||||
bool groupsHaveChanged = true;
|
||||
const FgVarSet& varNodes = groundFg_->getVarNodes();
|
||||
const FgFacSet& facNodes = groundFg_->getFactorNodes();
|
||||
const VarNodes& varNodes = groundFg_->varNodes();
|
||||
const FacNodes& facNodes = groundFg_->facNodes();
|
||||
|
||||
while (groupsHaveChanged || nIters == 1) {
|
||||
nIters ++;
|
||||
|
||||
unsigned prevFactorGroupsSize = factorGroups.size();
|
||||
factorGroups.clear();
|
||||
unsigned prevFactorGroupsSize = facGroups.size();
|
||||
facGroups.clear();
|
||||
// set a new color to the factors with the same signature
|
||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||
const Signature& signature = getSignature (facNodes[i]);
|
||||
FacSignMap::iterator it = factorGroups.find (signature);
|
||||
if (it == factorGroups.end()) {
|
||||
it = factorGroups.insert (make_pair (signature, FgFacSet())).first;
|
||||
FacSignMap::iterator it = facGroups.find (signature);
|
||||
if (it == facGroups.end()) {
|
||||
it = facGroups.insert (make_pair (signature, FacNodes())).first;
|
||||
}
|
||||
it->second.push_back (facNodes[i]);
|
||||
}
|
||||
for (FacSignMap::iterator it = factorGroups.begin();
|
||||
it != factorGroups.end(); it++) {
|
||||
for (FacSignMap::iterator it = facGroups.begin();
|
||||
it != facGroups.end(); it++) {
|
||||
Color newColor = getFreeColor();
|
||||
FgFacSet& groupMembers = it->second;
|
||||
FacNodes& groupMembers = it->second;
|
||||
for (unsigned i = 0; i < groupMembers.size(); i++) {
|
||||
setColor (groupMembers[i], newColor);
|
||||
}
|
||||
@ -141,36 +151,37 @@ CFactorGraph::createGroups (void)
|
||||
const Signature& signature = getSignature (varNodes[i]);
|
||||
VarSignMap::iterator it = varGroups.find (signature);
|
||||
if (it == varGroups.end()) {
|
||||
it = varGroups.insert (make_pair (signature, FgVarSet())).first;
|
||||
it = varGroups.insert (make_pair (signature, VarNodes())).first;
|
||||
}
|
||||
it->second.push_back (varNodes[i]);
|
||||
}
|
||||
for (VarSignMap::iterator it = varGroups.begin();
|
||||
it != varGroups.end(); it++) {
|
||||
Color newColor = getFreeColor();
|
||||
FgVarSet& groupMembers = it->second;
|
||||
VarNodes& groupMembers = it->second;
|
||||
for (unsigned i = 0; i < groupMembers.size(); i++) {
|
||||
setColor (groupMembers[i], newColor);
|
||||
}
|
||||
}
|
||||
|
||||
groupsHaveChanged = prevVarGroupsSize != varGroups.size()
|
||||
|| prevFactorGroupsSize != factorGroups.size();
|
||||
|| prevFactorGroupsSize != facGroups.size();
|
||||
}
|
||||
//printGroups (varGroups, factorGroups);
|
||||
createClusters (varGroups, factorGroups);
|
||||
printGroups (varGroups, facGroups);
|
||||
createClusters (varGroups, facGroups);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
CFactorGraph::createClusters (const VarSignMap& varGroups,
|
||||
const FacSignMap& factorGroups)
|
||||
CFactorGraph::createClusters (
|
||||
const VarSignMap& varGroups,
|
||||
const FacSignMap& facGroups)
|
||||
{
|
||||
varClusters_.reserve (varGroups.size());
|
||||
for (VarSignMap::const_iterator it = varGroups.begin();
|
||||
it != varGroups.end(); it++) {
|
||||
const FgVarSet& groupVars = it->second;
|
||||
const VarNodes& groupVars = it->second;
|
||||
VarCluster* vc = new VarCluster (groupVars);
|
||||
for (unsigned i = 0; i < groupVars.size(); i++) {
|
||||
vid2VarCluster_.insert (make_pair (groupVars[i]->varId(), vc));
|
||||
@ -178,12 +189,12 @@ CFactorGraph::createClusters (const VarSignMap& varGroups,
|
||||
varClusters_.push_back (vc);
|
||||
}
|
||||
|
||||
facClusters_.reserve (factorGroups.size());
|
||||
for (FacSignMap::const_iterator it = factorGroups.begin();
|
||||
it != factorGroups.end(); it++) {
|
||||
FgFacNode* groupFactor = it->second[0];
|
||||
const FgVarSet& neighs = groupFactor->neighbors();
|
||||
VarClusterSet varClusters;
|
||||
facClusters_.reserve (facGroups.size());
|
||||
for (FacSignMap::const_iterator it = facGroups.begin();
|
||||
it != facGroups.end(); it++) {
|
||||
FacNode* groupFactor = it->second[0];
|
||||
const VarNodes& neighs = groupFactor->neighbors();
|
||||
VarClusters varClusters;
|
||||
varClusters.reserve (neighs.size());
|
||||
for (unsigned i = 0; i < neighs.size(); i++) {
|
||||
VarId vid = neighs[i]->varId();
|
||||
@ -196,15 +207,15 @@ CFactorGraph::createClusters (const VarSignMap& varGroups,
|
||||
|
||||
|
||||
const Signature&
|
||||
CFactorGraph::getSignature (const FgVarNode* varNode)
|
||||
CFactorGraph::getSignature (const VarNode* varNode)
|
||||
{
|
||||
Signature& sign = varSignatures_[varNode->getIndex()];
|
||||
vector<Color>::iterator it = sign.colors.begin();
|
||||
const FgFacSet& neighs = varNode->neighbors();
|
||||
const FacNodes& neighs = varNode->neighbors();
|
||||
for (unsigned i = 0; i < neighs.size(); i++) {
|
||||
*it = getColor (neighs[i]);
|
||||
it ++;
|
||||
*it = neighs[i]->factor()->indexOf (varNode->varId());
|
||||
*it = neighs[i]->factor().indexOf (varNode->varId());
|
||||
it ++;
|
||||
}
|
||||
*it = getColor (varNode);
|
||||
@ -214,11 +225,11 @@ CFactorGraph::getSignature (const FgVarNode* varNode)
|
||||
|
||||
|
||||
const Signature&
|
||||
CFactorGraph::getSignature (const FgFacNode* facNode)
|
||||
CFactorGraph::getSignature (const FacNode* facNode)
|
||||
{
|
||||
Signature& sign = factorSignatures_[facNode->getIndex()];
|
||||
Signature& sign = facSignatures_[facNode->getIndex()];
|
||||
vector<Color>::iterator it = sign.colors.begin();
|
||||
const FgVarSet& neighs = facNode->neighbors();
|
||||
const VarNodes& neighs = facNode->neighbors();
|
||||
for (unsigned i = 0; i < neighs.size(); i++) {
|
||||
*it = getColor (neighs[i]);
|
||||
it ++;
|
||||
@ -230,55 +241,53 @@ CFactorGraph::getSignature (const FgFacNode* facNode)
|
||||
|
||||
|
||||
FactorGraph*
|
||||
CFactorGraph::getCompressedFactorGraph (void)
|
||||
CFactorGraph::getGroundFactorGraph (void) const
|
||||
{
|
||||
FactorGraph* fg = new FactorGraph();
|
||||
for (unsigned i = 0; i < varClusters_.size(); i++) {
|
||||
FgVarNode* var = varClusters_[i]->getGroundFgVarNodes()[0];
|
||||
FgVarNode* newVar = new FgVarNode (var);
|
||||
VarNode* var = varClusters_[i]->getGroundVarNodes()[0];
|
||||
VarNode* newVar = new VarNode (var);
|
||||
varClusters_[i]->setRepresentativeVariable (newVar);
|
||||
fg->addVariable (newVar);
|
||||
fg->addVarNode (newVar);
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < facClusters_.size(); i++) {
|
||||
const VarClusterSet& myVarClusters = facClusters_[i]->getVarClusters();
|
||||
VarNodes myGroundVars;
|
||||
const VarClusters& myVarClusters = facClusters_[i]->getVarClusters();
|
||||
Vars myGroundVars;
|
||||
myGroundVars.reserve (myVarClusters.size());
|
||||
for (unsigned j = 0; j < myVarClusters.size(); j++) {
|
||||
FgVarNode* v = myVarClusters[j]->getRepresentativeVariable();
|
||||
VarNode* v = myVarClusters[j]->getRepresentativeVariable();
|
||||
myGroundVars.push_back (v);
|
||||
}
|
||||
Factor* newFactor = new Factor (myGroundVars,
|
||||
facClusters_[i]->getGroundFactors()[0]->getDistribution());
|
||||
FgFacNode* fn = new FgFacNode (newFactor);
|
||||
FacNode* fn = new FacNode (Factor (myGroundVars,
|
||||
facClusters_[i]->getGroundFactors()[0]->factor().params()));
|
||||
facClusters_[i]->setRepresentativeFactor (fn);
|
||||
fg->addFactor (fn);
|
||||
fg->addFacNode (fn);
|
||||
for (unsigned j = 0; j < myGroundVars.size(); j++) {
|
||||
fg->addEdge (fn, static_cast<FgVarNode*> (myGroundVars[j]));
|
||||
fg->addEdge (static_cast<VarNode*> (myGroundVars[j]), fn);
|
||||
}
|
||||
}
|
||||
fg->setIndexes();
|
||||
return fg;
|
||||
}
|
||||
|
||||
|
||||
|
||||
unsigned
|
||||
CFactorGraph::getGroundEdgeCount (
|
||||
CFactorGraph::getEdgeCount (
|
||||
const FacCluster* fc,
|
||||
const VarCluster* vc) const
|
||||
{
|
||||
const FgFacSet& clusterGroundFactors = fc->getGroundFactors();
|
||||
FgVarNode* varNode = vc->getGroundFgVarNodes()[0];
|
||||
unsigned count = 0;
|
||||
VarId vid = vc->getGroundVarNodes().front()->varId();
|
||||
const FacNodes& clusterGroundFactors = fc->getGroundFactors();
|
||||
for (unsigned i = 0; i < clusterGroundFactors.size(); i++) {
|
||||
if (clusterGroundFactors[i]->factor()->indexOf (varNode->varId()) != -1) {
|
||||
if (clusterGroundFactors[i]->factor().contains (vid)) {
|
||||
count ++;
|
||||
}
|
||||
}
|
||||
// CFgVarSet vars = vc->getGroundFgVarNodes();
|
||||
// CVarNodes vars = vc->getGroundVarNodes();
|
||||
// for (unsigned i = 1; i < vars.size(); i++) {
|
||||
// FgVarNode* var = vc->getGroundFgVarNodes()[i];
|
||||
// VarNode* var = vc->getGroundVarNodes()[i];
|
||||
// unsigned count2 = 0;
|
||||
// for (unsigned i = 0; i < clusterGroundFactors.size(); i++) {
|
||||
// if (clusterGroundFactors[i]->getPosition (var) != -1) {
|
||||
@ -293,14 +302,15 @@ CFactorGraph::getGroundEdgeCount (
|
||||
|
||||
|
||||
void
|
||||
CFactorGraph::printGroups (const VarSignMap& varGroups,
|
||||
const FacSignMap& factorGroups) const
|
||||
CFactorGraph::printGroups (
|
||||
const VarSignMap& varGroups,
|
||||
const FacSignMap& facGroups) const
|
||||
{
|
||||
unsigned count = 1;
|
||||
cout << "variable groups:" << endl;
|
||||
for (VarSignMap::const_iterator it = varGroups.begin();
|
||||
it != varGroups.end(); it++) {
|
||||
const FgVarSet& groupMembers = it->second;
|
||||
const VarNodes& groupMembers = it->second;
|
||||
if (groupMembers.size() > 0) {
|
||||
cout << count << ": " ;
|
||||
for (unsigned i = 0; i < groupMembers.size(); i++) {
|
||||
@ -313,9 +323,9 @@ CFactorGraph::printGroups (const VarSignMap& varGroups,
|
||||
|
||||
count = 1;
|
||||
cout << endl << "factor groups:" << endl;
|
||||
for (FacSignMap::const_iterator it = factorGroups.begin();
|
||||
it != factorGroups.end(); it++) {
|
||||
const FgFacSet& groupMembers = it->second;
|
||||
for (FacSignMap::const_iterator it = facGroups.begin();
|
||||
it != facGroups.end(); it++) {
|
||||
const FacNodes& groupMembers = it->second;
|
||||
if (groupMembers.size() > 0) {
|
||||
cout << ++count << ": " ;
|
||||
for (unsigned i = 0; i < groupMembers.size(); i++) {
|
||||
|
@ -15,23 +15,25 @@ class Signature;
|
||||
class SignatureHash;
|
||||
|
||||
|
||||
typedef long Color;
|
||||
typedef unordered_map<unsigned, vector<Color> > VarColorMap;
|
||||
typedef unordered_map<const Distribution*, Color> DistColorMap;
|
||||
typedef unordered_map<VarId, VarCluster*> VarId2VarCluster;
|
||||
typedef vector<VarCluster*> VarClusterSet;
|
||||
typedef vector<FacCluster*> FacClusterSet;
|
||||
typedef unordered_map<Signature, FgVarSet, SignatureHash> VarSignMap;
|
||||
typedef unordered_map<Signature, FgFacSet, SignatureHash> FacSignMap;
|
||||
typedef long Color;
|
||||
|
||||
typedef unordered_map<unsigned, vector<Color>> VarColorMap;
|
||||
|
||||
typedef unordered_map<unsigned, Color> DistColorMap;
|
||||
typedef unordered_map<VarId, VarCluster*> VarId2VarCluster;
|
||||
|
||||
typedef vector<VarCluster*> VarClusters;
|
||||
typedef vector<FacCluster*> FacClusters;
|
||||
|
||||
typedef unordered_map<Signature, VarNodes, SignatureHash> VarSignMap;
|
||||
typedef unordered_map<Signature, FacNodes, SignatureHash> FacSignMap;
|
||||
|
||||
|
||||
|
||||
struct Signature
|
||||
{
|
||||
Signature (unsigned size)
|
||||
{
|
||||
colors.resize (size);
|
||||
}
|
||||
Signature (unsigned size) : colors(size) { }
|
||||
|
||||
bool operator< (const Signature& sig) const
|
||||
{
|
||||
if (colors.size() < sig.colors.size()) {
|
||||
@ -49,6 +51,7 @@ struct Signature
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool operator== (const Signature& sig) const
|
||||
{
|
||||
if (colors.size() != sig.colors.size()) {
|
||||
@ -61,12 +64,14 @@ struct Signature
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
vector<Color> colors;
|
||||
};
|
||||
|
||||
|
||||
|
||||
struct SignatureHash {
|
||||
struct SignatureHash
|
||||
{
|
||||
size_t operator() (const Signature &sig) const
|
||||
{
|
||||
size_t val = hash<size_t>()(sig.colors.size());
|
||||
@ -82,7 +87,7 @@ struct SignatureHash {
|
||||
class VarCluster
|
||||
{
|
||||
public:
|
||||
VarCluster (const FgVarSet& vs)
|
||||
VarCluster (const VarNodes& vs)
|
||||
{
|
||||
for (unsigned i = 0; i < vs.size(); i++) {
|
||||
groundVars_.push_back (vs[i]);
|
||||
@ -94,26 +99,28 @@ class VarCluster
|
||||
facClusters_.push_back (fc);
|
||||
}
|
||||
|
||||
const FacClusterSet& getFacClusters (void) const
|
||||
const FacClusters& getFacClusters (void) const
|
||||
{
|
||||
return facClusters_;
|
||||
}
|
||||
|
||||
FgVarNode* getRepresentativeVariable (void) const { return representVar_; }
|
||||
void setRepresentativeVariable (FgVarNode* v) { representVar_ = v; }
|
||||
const FgVarSet& getGroundFgVarNodes (void) const { return groundVars_; }
|
||||
VarNode* getRepresentativeVariable (void) const { return representVar_; }
|
||||
|
||||
void setRepresentativeVariable (VarNode* v) { representVar_ = v; }
|
||||
|
||||
const VarNodes& getGroundVarNodes (void) const { return groundVars_; }
|
||||
|
||||
private:
|
||||
FgVarSet groundVars_;
|
||||
FacClusterSet facClusters_;
|
||||
FgVarNode* representVar_;
|
||||
VarNodes groundVars_;
|
||||
FacClusters facClusters_;
|
||||
VarNode* representVar_;
|
||||
};
|
||||
|
||||
|
||||
class FacCluster
|
||||
{
|
||||
public:
|
||||
FacCluster (const FgFacSet& groundFactors, const VarClusterSet& vcs)
|
||||
FacCluster (const FacNodes& groundFactors, const VarClusters& vcs)
|
||||
{
|
||||
groundFactors_ = groundFactors;
|
||||
varClusters_ = vcs;
|
||||
@ -122,12 +129,12 @@ class FacCluster
|
||||
}
|
||||
}
|
||||
|
||||
const VarClusterSet& getVarClusters (void) const
|
||||
const VarClusters& getVarClusters (void) const
|
||||
{
|
||||
return varClusters_;
|
||||
}
|
||||
|
||||
bool containsGround (const FgFacNode* fn)
|
||||
bool containsGround (const FacNode* fn)
|
||||
{
|
||||
for (unsigned i = 0; i < groundFactors_.size(); i++) {
|
||||
if (groundFactors_[i] == fn) {
|
||||
@ -137,24 +144,26 @@ class FacCluster
|
||||
return false;
|
||||
}
|
||||
|
||||
FgFacNode* getRepresentativeFactor (void) const
|
||||
FacNode* getRepresentativeFactor (void) const
|
||||
{
|
||||
return representFactor_;
|
||||
}
|
||||
void setRepresentativeFactor (FgFacNode* fn)
|
||||
|
||||
void setRepresentativeFactor (FacNode* fn)
|
||||
{
|
||||
representFactor_ = fn;
|
||||
}
|
||||
const FgFacSet& getGroundFactors (void) const
|
||||
|
||||
const FacNodes& getGroundFactors (void) const
|
||||
{
|
||||
return groundFactors_;
|
||||
}
|
||||
|
||||
|
||||
private:
|
||||
FgFacSet groundFactors_;
|
||||
VarClusterSet varClusters_;
|
||||
FgFacNode* representFactor_;
|
||||
FacNodes groundFactors_;
|
||||
VarClusters varClusters_;
|
||||
FacNode* representFactor_;
|
||||
};
|
||||
|
||||
|
||||
@ -162,51 +171,48 @@ class CFactorGraph
|
||||
{
|
||||
public:
|
||||
CFactorGraph (const FactorGraph&);
|
||||
|
||||
~CFactorGraph (void);
|
||||
|
||||
FactorGraph* getCompressedFactorGraph (void);
|
||||
unsigned getGroundEdgeCount (const FacCluster*, const VarCluster*) const;
|
||||
const VarClusters& getVarClusters (void) { return varClusters_; }
|
||||
|
||||
FgVarNode* getEquivalentVariable (VarId vid)
|
||||
const FacClusters& getFacClusters (void) { return facClusters_; }
|
||||
|
||||
VarNode* getEquivalentVariable (VarId vid)
|
||||
{
|
||||
VarCluster* vc = vid2VarCluster_.find (vid)->second;
|
||||
return vc->getRepresentativeVariable();
|
||||
}
|
||||
|
||||
const VarClusterSet& getVarClusters (void) { return varClusters_; }
|
||||
const FacClusterSet& getFacClusters (void) { return facClusters_; }
|
||||
|
||||
FactorGraph* getGroundFactorGraph (void) const;
|
||||
|
||||
unsigned getEdgeCount (const FacCluster*, const VarCluster*) const;
|
||||
|
||||
static bool checkForIdenticalFactors;
|
||||
|
||||
private:
|
||||
void setInitialColors (void);
|
||||
void createGroups (void);
|
||||
void createClusters (const VarSignMap&, const FacSignMap&);
|
||||
const Signature& getSignature (const FgVarNode*);
|
||||
const Signature& getSignature (const FgFacNode*);
|
||||
void printGroups (const VarSignMap&, const FacSignMap&) const;
|
||||
|
||||
Color getFreeColor (void) {
|
||||
Color getFreeColor (void)
|
||||
{
|
||||
++ freeColor_;
|
||||
return freeColor_ - 1;
|
||||
}
|
||||
|
||||
Color getColor (const FgVarNode* vn) const
|
||||
Color getColor (const VarNode* vn) const
|
||||
{
|
||||
return varColors_[vn->getIndex()];
|
||||
}
|
||||
Color getColor (const FgFacNode* fn) const {
|
||||
return factorColors_[fn->getIndex()];
|
||||
Color getColor (const FacNode* fn) const {
|
||||
return facColors_[fn->getIndex()];
|
||||
}
|
||||
|
||||
void setColor (const FgVarNode* vn, Color c)
|
||||
void setColor (const VarNode* vn, Color c)
|
||||
{
|
||||
varColors_[vn->getIndex()] = c;
|
||||
}
|
||||
|
||||
void setColor (const FgFacNode* fn, Color c)
|
||||
void setColor (const FacNode* fn, Color c)
|
||||
{
|
||||
factorColors_[fn->getIndex()] = c;
|
||||
facColors_[fn->getIndex()] = c;
|
||||
}
|
||||
|
||||
VarCluster* getVariableCluster (VarId vid) const
|
||||
@ -214,14 +220,26 @@ class CFactorGraph
|
||||
return vid2VarCluster_.find (vid)->second;
|
||||
}
|
||||
|
||||
void setInitialColors (void);
|
||||
|
||||
void createGroups (void);
|
||||
|
||||
void createClusters (const VarSignMap&, const FacSignMap&);
|
||||
|
||||
const Signature& getSignature (const VarNode*);
|
||||
|
||||
const Signature& getSignature (const FacNode*);
|
||||
|
||||
void printGroups (const VarSignMap&, const FacSignMap&) const;
|
||||
|
||||
Color freeColor_;
|
||||
vector<Color> varColors_;
|
||||
vector<Color> factorColors_;
|
||||
vector<Color> facColors_;
|
||||
vector<Signature> varSignatures_;
|
||||
vector<Signature> factorSignatures_;
|
||||
VarClusterSet varClusters_;
|
||||
FacClusterSet facClusters_;
|
||||
VarId2VarCluster vid2VarCluster_;
|
||||
vector<Signature> facSignatures_;
|
||||
VarClusters varClusters_;
|
||||
FacClusters facClusters_;
|
||||
VarId2VarCluster vid2VarCluster_;
|
||||
const FactorGraph* groundFg_;
|
||||
};
|
||||
|
||||
|
@ -1,10 +1,41 @@
|
||||
#include "CbpSolver.h"
|
||||
|
||||
|
||||
CbpSolver::CbpSolver (const FactorGraph& fg) : BpSolver (fg)
|
||||
{
|
||||
unsigned nGroundVars, nGroundFacs, nWithoutNeighs;
|
||||
if (Constants::COLLECT_STATS) {
|
||||
nGroundVars = fg_->varNodes().size();
|
||||
nGroundFacs = fg_->facNodes().size();
|
||||
const VarNodes& vars = fg_->varNodes();
|
||||
nWithoutNeighs = 0;
|
||||
for (unsigned i = 0; i < vars.size(); i++) {
|
||||
const FacNodes& factors = vars[i]->neighbors();
|
||||
if (factors.size() == 1 && factors[0]->neighbors().size() == 1) {
|
||||
nWithoutNeighs ++;
|
||||
}
|
||||
}
|
||||
}
|
||||
cfg_ = new CFactorGraph (fg);
|
||||
fg_ = cfg_->getGroundFactorGraph();
|
||||
if (Constants::COLLECT_STATS) {
|
||||
unsigned nClusterVars = fg_->varNodes().size();
|
||||
unsigned nClusterFacs = fg_->facNodes().size();
|
||||
Statistics::updateCompressingStatistics (nGroundVars,
|
||||
nGroundFacs, nClusterVars, nClusterFacs, nWithoutNeighs);
|
||||
}
|
||||
Util::printHeader ("Uncompressed Factor Graph");
|
||||
fg.print();
|
||||
Util::printHeader ("Compressed Factor Graph");
|
||||
fg_->print();
|
||||
}
|
||||
|
||||
|
||||
|
||||
CbpSolver::~CbpSolver (void)
|
||||
{
|
||||
delete lfg_;
|
||||
delete factorGraph_;
|
||||
delete cfg_;
|
||||
delete fg_;
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
delete links_[i];
|
||||
}
|
||||
@ -16,28 +47,31 @@ CbpSolver::~CbpSolver (void)
|
||||
Params
|
||||
CbpSolver::getPosterioriOf (VarId vid)
|
||||
{
|
||||
assert (lfg_->getEquivalentVariable (vid));
|
||||
FgVarNode* var = lfg_->getEquivalentVariable (vid);
|
||||
if (runned_ == false) {
|
||||
runSolver();
|
||||
}
|
||||
assert (cfg_->getEquivalentVariable (vid));
|
||||
VarNode* var = cfg_->getEquivalentVariable (vid);
|
||||
Params probs;
|
||||
if (var->hasEvidence()) {
|
||||
probs.resize (var->nrStates(), Util::noEvidence());
|
||||
probs[var->getEvidence()] = Util::withEvidence();
|
||||
probs.resize (var->range(), LogAware::noEvidence());
|
||||
probs[var->getEvidence()] = LogAware::withEvidence();
|
||||
} else {
|
||||
probs.resize (var->nrStates(), Util::multIdenty());
|
||||
probs.resize (var->range(), LogAware::multIdenty());
|
||||
const SpLinkSet& links = ninf(var)->getLinks();
|
||||
if (Globals::logDomain) {
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
|
||||
Util::add (probs, l->getPoweredMessage());
|
||||
}
|
||||
Util::normalize (probs);
|
||||
Util::fromLog (probs);
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
|
||||
Util::add (probs, l->poweredMessage());
|
||||
}
|
||||
LogAware::normalize (probs);
|
||||
Util::fromLog (probs);
|
||||
} else {
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
|
||||
Util::multiply (probs, l->getPoweredMessage());
|
||||
Util::multiply (probs, l->poweredMessage());
|
||||
}
|
||||
Util::normalize (probs);
|
||||
LogAware::normalize (probs);
|
||||
}
|
||||
}
|
||||
return probs;
|
||||
@ -46,55 +80,14 @@ CbpSolver::getPosterioriOf (VarId vid)
|
||||
|
||||
|
||||
Params
|
||||
CbpSolver::getJointDistributionOf (const VarIds& jointVarIds)
|
||||
CbpSolver::getJointDistributionOf (const VarIds& jointVids)
|
||||
{
|
||||
VarIds eqVarIds;
|
||||
for (unsigned i = 0; i < jointVarIds.size(); i++) {
|
||||
eqVarIds.push_back (lfg_->getEquivalentVariable (jointVarIds[i])->varId());
|
||||
for (unsigned i = 0; i < jointVids.size(); i++) {
|
||||
VarNode* vn = cfg_->getEquivalentVariable (jointVids[i]);
|
||||
eqVarIds.push_back (vn->varId());
|
||||
}
|
||||
return FgBpSolver::getJointDistributionOf (eqVarIds);
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
void
|
||||
CbpSolver::initializeSolver (void)
|
||||
{
|
||||
unsigned nGroundVars, nGroundFacs, nWithoutNeighs;
|
||||
if (COLLECT_STATISTICS) {
|
||||
nGroundVars = factorGraph_->getVarNodes().size();
|
||||
nGroundFacs = factorGraph_->getFactorNodes().size();
|
||||
const FgVarSet& vars = factorGraph_->getVarNodes();
|
||||
nWithoutNeighs = 0;
|
||||
for (unsigned i = 0; i < vars.size(); i++) {
|
||||
const FgFacSet& factors = vars[i]->neighbors();
|
||||
if (factors.size() == 1 && factors[0]->neighbors().size() == 1) {
|
||||
nWithoutNeighs ++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
lfg_ = new CFactorGraph (*factorGraph_);
|
||||
|
||||
// cout << "Uncompressed Factor Graph" << endl;
|
||||
// factorGraph_->printGraphicalModel();
|
||||
// factorGraph_->exportToGraphViz ("uncompressed_fg.dot");
|
||||
factorGraph_ = lfg_->getCompressedFactorGraph();
|
||||
|
||||
if (COLLECT_STATISTICS) {
|
||||
unsigned nClusterVars = factorGraph_->getVarNodes().size();
|
||||
unsigned nClusterFacs = factorGraph_->getFactorNodes().size();
|
||||
Statistics::updateCompressingStatistics (nGroundVars, nGroundFacs,
|
||||
nClusterVars, nClusterFacs,
|
||||
nWithoutNeighs);
|
||||
}
|
||||
|
||||
// cout << "Compressed Factor Graph" << endl;
|
||||
// factorGraph_->printGraphicalModel();
|
||||
// factorGraph_->exportToGraphViz ("compressed_fg.dot");
|
||||
// abort();
|
||||
FgBpSolver::initializeSolver();
|
||||
return BpSolver::getJointDistributionOf (eqVarIds);
|
||||
}
|
||||
|
||||
|
||||
@ -102,12 +95,13 @@ CbpSolver::initializeSolver (void)
|
||||
void
|
||||
CbpSolver::createLinks (void)
|
||||
{
|
||||
const FacClusterSet fcs = lfg_->getFacClusters();
|
||||
const FacClusters& fcs = cfg_->getFacClusters();
|
||||
for (unsigned i = 0; i < fcs.size(); i++) {
|
||||
const VarClusterSet vcs = fcs[i]->getVarClusters();
|
||||
const VarClusters& vcs = fcs[i]->getVarClusters();
|
||||
for (unsigned j = 0; j < vcs.size(); j++) {
|
||||
unsigned c = lfg_->getGroundEdgeCount (fcs[i], vcs[j]);
|
||||
links_.push_back (new CbpSolverLink (fcs[i]->getRepresentativeFactor(),
|
||||
unsigned c = cfg_->getEdgeCount (fcs[i], vcs[j]);
|
||||
links_.push_back (new CbpSolverLink (
|
||||
fcs[i]->getRepresentativeFactor(),
|
||||
vcs[j]->getRepresentativeVariable(), c));
|
||||
}
|
||||
}
|
||||
@ -123,7 +117,7 @@ CbpSolver::maxResidualSchedule (void)
|
||||
calculateMessage (links_[i]);
|
||||
SortedOrder::iterator it = sortedOrder_.insert (links_[i]);
|
||||
linkMap_.insert (make_pair (links_[i], it));
|
||||
if (DL >= 2 && DL < 5) {
|
||||
if (Constants::DEBUG >= 2 && Constants::DEBUG < 5) {
|
||||
cout << "calculating " << links_[i]->toString() << endl;
|
||||
}
|
||||
}
|
||||
@ -131,7 +125,7 @@ CbpSolver::maxResidualSchedule (void)
|
||||
}
|
||||
|
||||
for (unsigned c = 0; c < links_.size(); c++) {
|
||||
if (DL >= 2) {
|
||||
if (Constants::DEBUG >= 2) {
|
||||
cout << endl << "current residuals:" << endl;
|
||||
for (SortedOrder::iterator it = sortedOrder_.begin();
|
||||
it != sortedOrder_.end(); it ++) {
|
||||
@ -142,7 +136,7 @@ CbpSolver::maxResidualSchedule (void)
|
||||
|
||||
SortedOrder::iterator it = sortedOrder_.begin();
|
||||
SpLink* link = *it;
|
||||
if (DL >= 2) {
|
||||
if (Constants::DEBUG >= 2) {
|
||||
cout << "updating " << (*sortedOrder_.begin())->toString() << endl;
|
||||
}
|
||||
if (link->getResidual() < BpOptions::accuracy) {
|
||||
@ -154,12 +148,12 @@ CbpSolver::maxResidualSchedule (void)
|
||||
linkMap_.find (link)->second = sortedOrder_.insert (link);
|
||||
|
||||
// update the messages that depend on message source --> destin
|
||||
const FgFacSet& factorNeighbors = link->getVariable()->neighbors();
|
||||
const FacNodes& factorNeighbors = link->getVariable()->neighbors();
|
||||
for (unsigned i = 0; i < factorNeighbors.size(); i++) {
|
||||
const SpLinkSet& links = ninf(factorNeighbors[i])->getLinks();
|
||||
for (unsigned j = 0; j < links.size(); j++) {
|
||||
if (links[j]->getVariable() != link->getVariable()) {
|
||||
if (DL >= 2 && DL < 5) {
|
||||
if (Constants::DEBUG >= 2 && Constants::DEBUG < 5) {
|
||||
cout << " calculating " << links[j]->toString() << endl;
|
||||
}
|
||||
calculateMessage (links[j]);
|
||||
@ -174,7 +168,7 @@ CbpSolver::maxResidualSchedule (void)
|
||||
const SpLinkSet& links = ninf(link->getFactor())->getLinks();
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
if (links[i]->getVariable() != link->getVariable()) {
|
||||
if (DL >= 2 && DL < 5) {
|
||||
if (Constants::DEBUG >= 2 && Constants::DEBUG < 5) {
|
||||
cout << " calculating " << links[i]->toString() << endl;
|
||||
}
|
||||
calculateMessage (links[i]);
|
||||
@ -192,43 +186,43 @@ Params
|
||||
CbpSolver::getVar2FactorMsg (const SpLink* link) const
|
||||
{
|
||||
Params msg;
|
||||
const FgVarNode* src = link->getVariable();
|
||||
const FgFacNode* dst = link->getFactor();
|
||||
const VarNode* src = link->getVariable();
|
||||
const FacNode* dst = link->getFactor();
|
||||
const CbpSolverLink* l = static_cast<const CbpSolverLink*> (link);
|
||||
if (src->hasEvidence()) {
|
||||
msg.resize (src->nrStates(), Util::noEvidence());
|
||||
msg.resize (src->range(), LogAware::noEvidence());
|
||||
double value = link->getMessage()[src->getEvidence()];
|
||||
msg[src->getEvidence()] = Util::pow (value, l->getNumberOfEdges() - 1);
|
||||
msg[src->getEvidence()] = LogAware::pow (value, l->nrEdges() - 1);
|
||||
} else {
|
||||
msg = link->getMessage();
|
||||
Util::pow (msg, l->getNumberOfEdges() - 1);
|
||||
LogAware::pow (msg, l->nrEdges() - 1);
|
||||
}
|
||||
if (DL >= 5) {
|
||||
cout << " " << "init: " << Util::parametersToString (msg) << endl;
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << " " << "init: " << msg << endl;
|
||||
}
|
||||
const SpLinkSet& links = ninf(src)->getLinks();
|
||||
if (Globals::logDomain) {
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
if (links[i]->getFactor() != dst) {
|
||||
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
|
||||
Util::add (msg, l->getPoweredMessage());
|
||||
Util::add (msg, l->poweredMessage());
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
if (links[i]->getFactor() != dst) {
|
||||
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
|
||||
Util::multiply (msg, l->getPoweredMessage());
|
||||
if (DL >= 5) {
|
||||
Util::multiply (msg, l->poweredMessage());
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << " msg from " << l->getFactor()->getLabel() << ": " ;
|
||||
cout << Util::parametersToString (l->getPoweredMessage()) << endl;
|
||||
cout << l->poweredMessage() << endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (DL >= 5) {
|
||||
cout << " result = " << Util::parametersToString (msg) << endl;
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << " result = " << msg << endl;
|
||||
}
|
||||
return msg;
|
||||
}
|
||||
@ -241,12 +235,9 @@ CbpSolver::printLinkInformation (void) const
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
CbpSolverLink* l = static_cast<CbpSolverLink*> (links_[i]);
|
||||
cout << l->toString() << ":" << endl;
|
||||
cout << " curr msg = " ;
|
||||
cout << Util::parametersToString (l->getMessage()) << endl;
|
||||
cout << " next msg = " ;
|
||||
cout << Util::parametersToString (l->getNextMessage()) << endl;
|
||||
cout << " powered = " ;
|
||||
cout << Util::parametersToString (l->getPoweredMessage()) << endl;
|
||||
cout << " curr msg = " << l->getMessage() << endl;
|
||||
cout << " next msg = " << l->getNextMessage() << endl;
|
||||
cout << " powered = " << l->poweredMessage() << endl;
|
||||
cout << " residual = " << l->getResidual() << endl;
|
||||
}
|
||||
}
|
||||
|
@ -1,7 +1,7 @@
|
||||
#ifndef HORUS_CBP_H
|
||||
#define HORUS_CBP_H
|
||||
|
||||
#include "FgBpSolver.h"
|
||||
#include "BpSolver.h"
|
||||
#include "CFactorGraph.h"
|
||||
|
||||
class Factor;
|
||||
@ -9,49 +9,51 @@ class Factor;
|
||||
class CbpSolverLink : public SpLink
|
||||
{
|
||||
public:
|
||||
CbpSolverLink (FgFacNode* fn, FgVarNode* vn, unsigned c) : SpLink (fn, vn)
|
||||
{
|
||||
edgeCount_ = c;
|
||||
poweredMsg_.resize (vn->nrStates(), Util::one());
|
||||
}
|
||||
CbpSolverLink (FacNode* fn, VarNode* vn, unsigned c)
|
||||
: SpLink (fn, vn), nrEdges_(c),
|
||||
pwdMsg_(vn->range(), LogAware::one()) { }
|
||||
|
||||
unsigned nrEdges (void) const { return nrEdges_; }
|
||||
|
||||
const Params& poweredMessage (void) const { return pwdMsg_; }
|
||||
|
||||
void updateMessage (void)
|
||||
{
|
||||
poweredMsg_ = *nextMsg_;
|
||||
pwdMsg_ = *nextMsg_;
|
||||
swap (currMsg_, nextMsg_);
|
||||
msgSended_ = true;
|
||||
Util::pow (poweredMsg_, edgeCount_);
|
||||
LogAware::pow (pwdMsg_, nrEdges_);
|
||||
}
|
||||
|
||||
unsigned getNumberOfEdges (void) const { return edgeCount_; }
|
||||
const Params& getPoweredMessage (void) const { return poweredMsg_; }
|
||||
|
||||
private:
|
||||
Params poweredMsg_;
|
||||
unsigned edgeCount_;
|
||||
unsigned nrEdges_;
|
||||
Params pwdMsg_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class CbpSolver : public FgBpSolver
|
||||
class CbpSolver : public BpSolver
|
||||
{
|
||||
public:
|
||||
CbpSolver (FactorGraph& fg) : FgBpSolver (fg) { }
|
||||
~CbpSolver (void);
|
||||
CbpSolver (const FactorGraph& fg);
|
||||
|
||||
Params getPosterioriOf (VarId);
|
||||
Params getJointDistributionOf (const VarIds&);
|
||||
~CbpSolver (void);
|
||||
|
||||
Params getPosterioriOf (VarId);
|
||||
|
||||
Params getJointDistributionOf (const VarIds&);
|
||||
|
||||
private:
|
||||
void initializeSolver (void);
|
||||
void createLinks (void);
|
||||
|
||||
void maxResidualSchedule (void);
|
||||
Params getVar2FactorMsg (const SpLink*) const;
|
||||
void printLinkInformation (void) const;
|
||||
void createLinks (void);
|
||||
|
||||
void maxResidualSchedule (void);
|
||||
|
||||
CFactorGraph* lfg_;
|
||||
Params getVar2FactorMsg (const SpLink*) const;
|
||||
|
||||
void printLinkInformation (void) const;
|
||||
|
||||
CFactorGraph* cfg_;
|
||||
};
|
||||
|
||||
#endif // HORUS_CBP_H
|
||||
|
@ -1,10 +1,11 @@
|
||||
#include <queue>
|
||||
|
||||
#include <fstream>
|
||||
|
||||
#include "ConstraintTree.h"
|
||||
#include "Util.h"
|
||||
|
||||
|
||||
|
||||
void
|
||||
CTNode::addChild (CTNode* child, bool updateLevels)
|
||||
{
|
||||
@ -42,6 +43,34 @@ CTNode::removeChild (CTNode* child)
|
||||
|
||||
|
||||
|
||||
void
|
||||
CTNode::removeChilds (void)
|
||||
{
|
||||
childs_.clear();
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
CTNode::removeAndDeleteChild (CTNode* child)
|
||||
{
|
||||
removeChild (child);
|
||||
CTNode::deleteSubtree (child);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
CTNode::removeAndDeleteAllChilds (void)
|
||||
{
|
||||
for (unsigned i = 0; i < childs_.size(); i++) {
|
||||
deleteSubtree (childs_[i]);
|
||||
}
|
||||
childs_.clear();
|
||||
}
|
||||
|
||||
|
||||
|
||||
SymbolSet
|
||||
CTNode::childSymbols (void) const
|
||||
{
|
||||
@ -66,6 +95,32 @@ CTNode::updateChildLevels (CTNode* n, unsigned level)
|
||||
|
||||
|
||||
|
||||
CTNode*
|
||||
CTNode::copySubtree (const CTNode* n)
|
||||
{
|
||||
CTNode* newNode = new CTNode (*n);
|
||||
const CTNodes& childs = n->childs();
|
||||
for (unsigned i = 0; i < childs.size(); i++) {
|
||||
newNode->addChild (copySubtree (childs[i]));
|
||||
}
|
||||
return newNode;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
CTNode::deleteSubtree (CTNode* n)
|
||||
{
|
||||
assert (n);
|
||||
const CTNodes& childs = n->childs();
|
||||
for (unsigned i = 0; i < childs.size(); i++) {
|
||||
deleteSubtree (childs[i]);
|
||||
}
|
||||
delete n;
|
||||
}
|
||||
|
||||
|
||||
|
||||
ostream& operator<< (ostream &out, const CTNode& n)
|
||||
{
|
||||
// out << "(" << n.level() << ") " ;
|
||||
@ -75,6 +130,17 @@ ostream& operator<< (ostream &out, const CTNode& n)
|
||||
|
||||
|
||||
|
||||
ConstraintTree::ConstraintTree (unsigned nrLvs)
|
||||
{
|
||||
for (unsigned i = 0; i < nrLvs; i++) {
|
||||
logVars_.push_back (LogVar (i));
|
||||
}
|
||||
root_ = new CTNode (0, 0);
|
||||
logVarSet_ = LogVarSet (logVars_);
|
||||
}
|
||||
|
||||
|
||||
|
||||
ConstraintTree::ConstraintTree (const LogVars& logVars)
|
||||
{
|
||||
root_ = new CTNode (0, 0);
|
||||
@ -99,7 +165,7 @@ ConstraintTree::ConstraintTree (const LogVars& logVars,
|
||||
|
||||
ConstraintTree::ConstraintTree (const ConstraintTree& ct)
|
||||
{
|
||||
root_ = copySubtree (ct.root_);
|
||||
root_ = CTNode::copySubtree (ct.root_);
|
||||
logVars_ = ct.logVars_;
|
||||
logVarSet_ = ct.logVarSet_;
|
||||
}
|
||||
@ -108,7 +174,7 @@ ConstraintTree::ConstraintTree (const ConstraintTree& ct)
|
||||
|
||||
ConstraintTree::~ConstraintTree (void)
|
||||
{
|
||||
deleteSubtree (root_);
|
||||
CTNode::deleteSubtree (root_);
|
||||
}
|
||||
|
||||
|
||||
@ -200,21 +266,28 @@ ConstraintTree::moveToBottom (const LogVars& lvs)
|
||||
|
||||
void
|
||||
ConstraintTree::join (ConstraintTree* ct, bool assertWhenNotFound)
|
||||
{
|
||||
{
|
||||
if (logVarSet_.empty()) {
|
||||
delete root_;
|
||||
root_ = CTNode::copySubtree (ct->root());
|
||||
logVars_ = ct->logVars();
|
||||
logVarSet_ = ct->logVarSet();
|
||||
return;
|
||||
}
|
||||
|
||||
LogVarSet intersect = logVarSet_ & ct->logVarSet_;
|
||||
if (intersect.empty()) {
|
||||
const CTNodes& childs = ct->root()->childs();
|
||||
CTNodes leafs = getNodesAtLevel (getLevel (logVars_.back()));
|
||||
for (unsigned i = 0; i < leafs.size(); i++) {
|
||||
for (unsigned j = 0; j < childs.size(); j++) {
|
||||
leafs[i]->addChild (copySubtree (childs[j]));
|
||||
leafs[i]->addChild (CTNode::copySubtree (childs[j]));
|
||||
}
|
||||
}
|
||||
logVars_.insert (logVars_.end(), ct->logVars_.begin(), ct->logVars_.end());
|
||||
Util::addToVector (logVars_, ct->logVars_);
|
||||
logVarSet_ |= ct->logVarSet_;
|
||||
|
||||
} else {
|
||||
|
||||
moveToBottom (intersect.elements());
|
||||
ct->moveToTop (intersect.elements());
|
||||
|
||||
@ -222,25 +295,27 @@ ConstraintTree::join (ConstraintTree* ct, bool assertWhenNotFound)
|
||||
CTNodes nodes = getNodesAtLevel (level);
|
||||
|
||||
Tuples tuples;
|
||||
CTNodes continuationNodes;
|
||||
CTNodes continNodes;
|
||||
getTuples (ct->root(),
|
||||
Tuples(),
|
||||
intersect.size(),
|
||||
tuples,
|
||||
continuationNodes);
|
||||
continNodes);
|
||||
|
||||
for (unsigned i = 0; i < tuples.size(); i++) {
|
||||
bool tupleFounded = false;
|
||||
for (unsigned j = 0; j < nodes.size(); j++) {
|
||||
tupleFounded |= join (nodes[j], tuples[i], 0, continuationNodes[i]);
|
||||
tupleFounded |= join (nodes[j], tuples[i], 0, continNodes[i]);
|
||||
}
|
||||
if (assertWhenNotFound) {
|
||||
assert (tupleFounded);
|
||||
}
|
||||
}
|
||||
LogVarSet newLvs = ct->logVarSet_ - intersect;
|
||||
logVars_.insert (logVars_.end(), newLvs.begin(), newLvs.end());
|
||||
logVarSet_ |= newLvs;
|
||||
|
||||
LogVars newLvs (ct->logVars().begin() + intersect.size(),
|
||||
ct->logVars().end());
|
||||
Util::addToVector (logVars_, newLvs);
|
||||
logVarSet_ |= LogVarSet (newLvs);
|
||||
}
|
||||
}
|
||||
|
||||
@ -280,6 +355,10 @@ ConstraintTree::rename (LogVar X_old, LogVar X_new)
|
||||
void
|
||||
ConstraintTree::applySubstitution (const Substitution& theta)
|
||||
{
|
||||
LogVars discardedLvs = theta.getDiscardedLogVars();
|
||||
for (unsigned i = 0; i < discardedLvs.size(); i++) {
|
||||
remove(discardedLvs[i]);
|
||||
}
|
||||
for (unsigned i = 0; i < logVars_.size(); i++) {
|
||||
logVars_[i] = theta.newNameFor (logVars_[i]);
|
||||
}
|
||||
@ -308,11 +387,7 @@ ConstraintTree::remove (const LogVarSet& X)
|
||||
unsigned level = getLevel (X.front()) - 1;
|
||||
CTNodes nodes = getNodesAtLevel (level);
|
||||
for (unsigned i = 0; i < nodes.size(); i++) {
|
||||
CTNodes childs = nodes[i]->childs();
|
||||
for (unsigned j = 0; j < childs.size(); j++) {
|
||||
nodes[i]->removeChild (childs[j]);
|
||||
deleteSubtree (childs[j]);
|
||||
}
|
||||
nodes[i]->removeAndDeleteAllChilds();
|
||||
}
|
||||
logVars_.resize (logVars_.size() - X.size());
|
||||
logVarSet_ -= X;
|
||||
@ -545,16 +620,16 @@ ConstraintTree::split (
|
||||
for (unsigned i = 0; i < commNodes.size(); i++) {
|
||||
commCt->root()->addChild (commNodes[i]);
|
||||
}
|
||||
//cout << commCt->tupleSet() << " + " ;
|
||||
//cout << exclCt->tupleSet() << " = " ;
|
||||
//cout << tupleSet() << endl << endl;
|
||||
// cout << commCt->tupleSet() << " + " ;
|
||||
// cout << exclCt->tupleSet() << " = " ;
|
||||
// cout << tupleSet() << endl << endl;
|
||||
// if (((commCt->tupleSet() | exclCt->tupleSet()) == tupleSet()) == false) {
|
||||
// exportToGraphViz ("_fail.dot", true);
|
||||
// commCt->exportToGraphViz ("_fail_comm.dot", true);
|
||||
// exclCt->exportToGraphViz ("_fail_excl.dot", true);
|
||||
// }
|
||||
assert ((commCt->tupleSet() | exclCt->tupleSet()) == tupleSet());
|
||||
assert ((exclCt->tupleSet (stopLevel) & ct->tupleSet (stopLevel)).empty());
|
||||
// assert ((commCt->tupleSet() | exclCt->tupleSet()) == tupleSet());
|
||||
// assert ((exclCt->tupleSet (stopLevel) & ct->tupleSet (stopLevel)).empty());
|
||||
return {commCt, exclCt};
|
||||
}
|
||||
|
||||
@ -601,36 +676,32 @@ ConstraintTree::jointCountNormalize (
|
||||
LogVar X_new1,
|
||||
LogVar X_new2)
|
||||
{
|
||||
exportToGraphViz ("C.dot", true);
|
||||
commCt->exportToGraphViz ("C_comm.dot", true);
|
||||
exclCt->exportToGraphViz ("C_exlc.dot", true);
|
||||
unsigned N = getConditionalCount (X);
|
||||
cout << "My tuples: " << tupleSet() << endl;
|
||||
cout << "CommCt tuples: " << commCt->tupleSet() << endl;
|
||||
cout << "ExclCt tuples: " << exclCt->tupleSet() << endl;
|
||||
cout << "Counted Lv: " << X << endl;
|
||||
cout << "Original N: " << N << endl;
|
||||
cout << endl;
|
||||
// cout << "My tuples: " << tupleSet() << endl;
|
||||
// cout << "CommCt tuples: " << commCt->tupleSet() << endl;
|
||||
// cout << "ExclCt tuples: " << exclCt->tupleSet() << endl;
|
||||
// cout << "Counted Lv: " << X << endl;
|
||||
// cout << "X_new1: " << X_new1 << endl;
|
||||
// cout << "X_new2: " << X_new2 << endl;
|
||||
// cout << "Original N: " << N << endl;
|
||||
// cout << endl;
|
||||
|
||||
ConstraintTrees normCts1 = commCt->countNormalize (X);
|
||||
vector<unsigned> counts1 (normCts1.size());
|
||||
for (unsigned i = 0; i < normCts1.size(); i++) {
|
||||
counts1[i] = normCts1[i]->getConditionalCount (X);
|
||||
cout << "normCts1[" << i << "] #" << counts1[i] ;
|
||||
cout << " " << normCts1[i]->tupleSet() << endl;
|
||||
// cout << "normCts1[" << i << "] #" << counts1[i] ;
|
||||
// cout << " " << normCts1[i]->tupleSet() << endl;
|
||||
}
|
||||
|
||||
ConstraintTrees normCts2 = exclCt->countNormalize (X);
|
||||
vector<unsigned> counts2 (normCts2.size());
|
||||
for (unsigned i = 0; i < normCts2.size(); i++) {
|
||||
counts2[i] = normCts2[i]->getConditionalCount (X);
|
||||
cout << "normCts2[" << i << "] #" << counts2[i] ;
|
||||
cout << " " << normCts2[i]->tupleSet() << endl;
|
||||
// cout << "normCts2[" << i << "] #" << counts2[i] ;
|
||||
// cout << " " << normCts2[i]->tupleSet() << endl;
|
||||
}
|
||||
cout << endl;
|
||||
|
||||
cout << "1###### " << normCts1.size() << endl;
|
||||
cout << "2###### " << normCts2.size() << endl;
|
||||
// cout << endl;
|
||||
|
||||
ConstraintTree* excl1 = 0;
|
||||
for (unsigned i = 0; i < normCts1.size(); i++) {
|
||||
@ -638,7 +709,7 @@ ConstraintTree::jointCountNormalize (
|
||||
excl1 = normCts1[i];
|
||||
normCts1.erase (normCts1.begin() + i);
|
||||
counts1.erase (counts1.begin() + i);
|
||||
cout << ">joint-count(" << N << ",0)" << endl;
|
||||
// cout << "joint-count(" << N << ",0)" << endl;
|
||||
break;
|
||||
}
|
||||
}
|
||||
@ -649,22 +720,21 @@ ConstraintTree::jointCountNormalize (
|
||||
excl2 = normCts2[i];
|
||||
normCts2.erase (normCts2.begin() + i);
|
||||
counts2.erase (counts2.begin() + i);
|
||||
cout << ">>joint-count(0," << N << ")" << endl;
|
||||
// cout << "joint-count(0," << N << ")" << endl;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
cout << "3###### " << normCts1.size() << endl;
|
||||
cout << "4###### " << normCts2.size() << endl;
|
||||
|
||||
for (unsigned i = 0; i < normCts1.size(); i++) {
|
||||
unsigned j;
|
||||
for (j = 0; counts1[i] + counts2[j] != N; j++) ;
|
||||
cout << "joint-count(" << counts1[i] << "," << counts2[j] << ")" << endl;
|
||||
// cout << "joint-count(" << counts1[i] ;
|
||||
// cout << "," << counts2[j] << ")" << endl;
|
||||
const CTNodes& childs = normCts2[j]->root_->childs();
|
||||
for (unsigned k = 0; k < childs.size(); k++) {
|
||||
normCts1[i]->root_->addChild (childs[k]);
|
||||
normCts1[i]->root_->addChild (CTNode::copySubtree (childs[k]));
|
||||
}
|
||||
delete normCts2[j];
|
||||
}
|
||||
|
||||
ConstraintTrees cts = normCts1;
|
||||
@ -683,11 +753,6 @@ ConstraintTree::jointCountNormalize (
|
||||
cts.push_back (excl2);
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < cts.size(); i++) {
|
||||
stringstream ss;
|
||||
ss << "aaacts_" << i + 1 << ".dot" ;
|
||||
cts[i]->exportToGraphViz (ss.str().c_str(), true);
|
||||
}
|
||||
return cts;
|
||||
}
|
||||
|
||||
@ -735,11 +800,11 @@ ConstraintTree::expand (LogVar X)
|
||||
unsigned nrSymbols = getConditionalCount (X);
|
||||
for (unsigned i = 0; i < nodes.size(); i++) {
|
||||
Symbols symbols;
|
||||
CTNodes childs = nodes[i]->childs();
|
||||
const CTNodes& childs = nodes[i]->childs();
|
||||
for (unsigned j = 0; j < childs.size(); j++) {
|
||||
symbols.push_back (childs[j]->symbol());
|
||||
nodes[i]->removeChild (childs[j]);
|
||||
}
|
||||
nodes[i]->removeAndDeleteAllChilds();
|
||||
CTNode* prev = nodes[i];
|
||||
assert (symbols.size() == nrSymbols);
|
||||
for (unsigned j = 0; j < nrSymbols; j++) {
|
||||
@ -768,7 +833,7 @@ ConstraintTree::ground (LogVar X)
|
||||
ConstraintTrees cts;
|
||||
const CTNodes& nodes = root_->childs();
|
||||
for (unsigned i = 0; i < nodes.size(); i++) {
|
||||
CTNode* copy = copySubtree (nodes[i]);
|
||||
CTNode* copy = CTNode::copySubtree (nodes[i]);
|
||||
copy->setSymbol (nodes[i]->symbol());
|
||||
ConstraintTree* newCt = new ConstraintTree (logVars_);
|
||||
newCt->root()->addChild (copy);
|
||||
@ -840,19 +905,19 @@ ConstraintTree::getNodesAtLevel (unsigned level) const
|
||||
void
|
||||
ConstraintTree::swapLogVar (LogVar X)
|
||||
{
|
||||
TupleSet before = tupleSet();
|
||||
LogVars::iterator it =
|
||||
std::find (logVars_.begin(),logVars_.end(), X);
|
||||
assert (it != logVars_.end());
|
||||
unsigned pos = std::distance (logVars_.begin(), it);
|
||||
|
||||
const CTNodes& nodes = getNodesAtLevel (pos);
|
||||
for (unsigned i = 0; i < nodes.size(); i++) {
|
||||
const CTNodes childs = nodes[i]->childs();
|
||||
for (unsigned j = 0; j < childs.size(); j++) {
|
||||
nodes[i]->removeChild (childs[j]);
|
||||
const CTNodes grandsons = childs[j]->childs();
|
||||
CTNodes childsCopy = nodes[i]->childs();
|
||||
nodes[i]->removeChilds();
|
||||
for (unsigned j = 0; j < childsCopy.size(); j++) {
|
||||
const CTNodes grandsons = childsCopy[j]->childs();
|
||||
for (unsigned k = 0; k < grandsons.size(); k++) {
|
||||
CTNode* childCopy = new CTNode (*childs[j]);
|
||||
CTNode* childCopy = new CTNode (*childsCopy[j]);
|
||||
const CTNodes greatGrandsons = grandsons[k]->childs();
|
||||
for (unsigned t = 0; t < greatGrandsons.size(); t++) {
|
||||
grandsons[k]->removeChild (greatGrandsons[t]);
|
||||
@ -863,10 +928,9 @@ ConstraintTree::swapLogVar (LogVar X)
|
||||
grandsons[k]->setLevel (grandsons[k]->level() - 1);
|
||||
nodes[i]->addChild (grandsons[k], false);
|
||||
}
|
||||
delete childs[j];
|
||||
delete childsCopy[j];
|
||||
}
|
||||
}
|
||||
|
||||
std::swap (logVars_[pos], logVars_[pos + 1]);
|
||||
}
|
||||
|
||||
@ -884,7 +948,7 @@ ConstraintTree::join (
|
||||
if (currIdx == tuple.size() - 1) {
|
||||
const CTNodes& childs = appendNode->childs();
|
||||
for (unsigned i = 0; i < childs.size(); i++) {
|
||||
n->addChild (copySubtree (childs[i]));
|
||||
n->addChild (CTNode::copySubtree (childs[i]));
|
||||
}
|
||||
return true;
|
||||
}
|
||||
@ -985,7 +1049,7 @@ ConstraintTree::countNormalize (
|
||||
{
|
||||
if (n->level() == stopLevel) {
|
||||
return vector<pair<CTNode*, unsigned>>() = {
|
||||
make_pair (copySubtree (n), countTuples (n))
|
||||
make_pair (CTNode::copySubtree (n), countTuples (n))
|
||||
};
|
||||
}
|
||||
|
||||
@ -1004,65 +1068,6 @@ ConstraintTree::countNormalize (
|
||||
}
|
||||
|
||||
|
||||
/*
|
||||
void
|
||||
ConstraintTree::split (
|
||||
CTNode* n1,
|
||||
CTNode* n2,
|
||||
CTNodes& nodes,
|
||||
unsigned stopLevel)
|
||||
{
|
||||
CTNodes& childs1 = n1->childs();
|
||||
CTNodes& childs2 = n2->childs();
|
||||
// cout << string (n1->level() * 8, '-') << "Level = " << n1->level() + 1;
|
||||
// cout << ", #I = " << childs1.size();
|
||||
// cout << ", #J = " << childs2.size() << endl;
|
||||
for (unsigned i = 0; i < childs1.size(); i++) {
|
||||
for (unsigned j = 0; j < childs2.size(); j++) {
|
||||
if (childs1[i]->symbol() != childs2[j]->symbol()) {
|
||||
continue;
|
||||
}
|
||||
if (childs1[i]->level() == stopLevel) {
|
||||
CTNode* newNode = copySubtree (childs1[i]);
|
||||
newNode->setSymbol (childs1[i]->symbol());
|
||||
nodes.push_back (newNode);
|
||||
childs1[i]->setSymbol (Symbol::invalid());
|
||||
break;
|
||||
} else {
|
||||
CTNodes lowerNodes;
|
||||
split (childs1[i], childs2[j], lowerNodes, stopLevel);
|
||||
if (lowerNodes.empty() == false) {
|
||||
CTNode* me = new CTNode (childs1[i]->symbol(), childs1[i]->level());
|
||||
for (unsigned k = 0; k < lowerNodes.size(); k++) {
|
||||
me->addChild (lowerNodes[k]);
|
||||
}
|
||||
nodes.push_back (me);
|
||||
}
|
||||
if (childs1[i]->isLeaf()) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < (int)childs1.size(); i++) {
|
||||
// cout << string (n1->level() * 8, '-') << childs1[i];
|
||||
if (childs1[i]->symbol() == Symbol::invalid()) {
|
||||
// cout << " empty, removing..." ;
|
||||
n1->removeChild (childs1[i]);
|
||||
i --;
|
||||
} else if (childs1[i]->isLeaf() &&
|
||||
childs1[i]->level() != stopLevel) {
|
||||
// cout << " leaf, removing..." ;
|
||||
n1->removeChild (childs1[i]);
|
||||
i --;
|
||||
}
|
||||
// cout << endl;
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
||||
|
||||
|
||||
void
|
||||
ConstraintTree::split (
|
||||
@ -1085,7 +1090,7 @@ ConstraintTree::split (
|
||||
continue;
|
||||
}
|
||||
if (childs1[i]->level() == stopLevel) {
|
||||
CTNode* newNode = copySubtree (childs1[i]);
|
||||
CTNode* newNode = CTNode::copySubtree (childs1[i]);
|
||||
nodes.push_back (newNode);
|
||||
childs1[i]->setSymbol (Symbol::invalid());
|
||||
} else {
|
||||
@ -1103,11 +1108,11 @@ ConstraintTree::split (
|
||||
|
||||
for (int i = 0; i < (int)childs1.size(); i++) {
|
||||
if (childs1[i]->symbol() == Symbol::invalid()) {
|
||||
n1->removeChild (childs1[i]);
|
||||
n1->removeAndDeleteChild (childs1[i]);
|
||||
i --;
|
||||
} else if (childs1[i]->isLeaf() &&
|
||||
childs1[i]->level() != stopLevel) {
|
||||
n1->removeChild (childs1[i]);
|
||||
n1->removeAndDeleteChild (childs1[i]);
|
||||
i --;
|
||||
}
|
||||
}
|
||||
@ -1141,29 +1146,3 @@ ConstraintTree::overlap (
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
|
||||
CTNode*
|
||||
ConstraintTree::copySubtree (const CTNode* n)
|
||||
{
|
||||
CTNode* newNode = new CTNode (*n);
|
||||
const CTNodes& childs = n->childs();
|
||||
for (unsigned i = 0; i < childs.size(); i++) {
|
||||
newNode->addChild (copySubtree (childs[i]));
|
||||
}
|
||||
return newNode;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
ConstraintTree::deleteSubtree (CTNode* n)
|
||||
{
|
||||
assert (n);
|
||||
const CTNodes& childs = n->childs();
|
||||
for (unsigned i = 0; i < childs.size(); i++) {
|
||||
deleteSubtree (childs[i]);
|
||||
}
|
||||
delete n;
|
||||
}
|
||||
|
||||
|
@ -21,7 +21,6 @@ typedef vector<ConstraintTree*> ConstraintTrees;
|
||||
|
||||
|
||||
|
||||
|
||||
class CTNode
|
||||
{
|
||||
public:
|
||||
@ -47,29 +46,44 @@ class CTNode
|
||||
|
||||
bool isLeaf (void) const { return childs_.empty(); }
|
||||
|
||||
void addChild (CTNode*, bool = true);
|
||||
void removeChild (CTNode*);
|
||||
SymbolSet childSymbols (void) const;
|
||||
void addChild (CTNode*, bool = true);
|
||||
|
||||
void removeChild (CTNode*);
|
||||
|
||||
void removeChilds (void);
|
||||
|
||||
void removeAndDeleteChild (CTNode*);
|
||||
|
||||
void removeAndDeleteAllChilds (void);
|
||||
|
||||
SymbolSet childSymbols (void) const;
|
||||
|
||||
static CTNode* copySubtree (const CTNode*);
|
||||
|
||||
static void deleteSubtree (CTNode*);
|
||||
|
||||
private:
|
||||
void updateChildLevels (CTNode*, unsigned);
|
||||
void updateChildLevels (CTNode*, unsigned);
|
||||
|
||||
Symbol symbol_;
|
||||
CTNodes childs_;
|
||||
unsigned level_;
|
||||
Symbol symbol_;
|
||||
CTNodes childs_;
|
||||
unsigned level_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
ostream& operator<< (ostream &out, const CTNode&);
|
||||
|
||||
|
||||
class ConstraintTree
|
||||
{
|
||||
public:
|
||||
ConstraintTree (unsigned);
|
||||
|
||||
ConstraintTree (const LogVars&);
|
||||
|
||||
ConstraintTree (const LogVars&, const Tuples&);
|
||||
|
||||
ConstraintTree (const ConstraintTree&);
|
||||
|
||||
~ConstraintTree (void);
|
||||
|
||||
CTNode* root (void) const { return root_; }
|
||||
@ -94,94 +108,95 @@ class ConstraintTree
|
||||
assert (LogVarSet (logVars_) == logVarSet_);
|
||||
}
|
||||
|
||||
void addTuple (const Tuple&);
|
||||
bool containsTuple (const Tuple&);
|
||||
void moveToTop (const LogVars&);
|
||||
void moveToBottom (const LogVars&);
|
||||
void join (ConstraintTree*, bool = false);
|
||||
unsigned getLevel (LogVar) const;
|
||||
void rename (LogVar, LogVar);
|
||||
void applySubstitution (const Substitution&);
|
||||
void project (const LogVarSet&);
|
||||
void remove (const LogVarSet&);
|
||||
bool isSingleton (LogVar);
|
||||
LogVarSet singletons (void);
|
||||
TupleSet tupleSet (unsigned = 0) const;
|
||||
TupleSet tupleSet (const LogVars&);
|
||||
unsigned size (void) const;
|
||||
unsigned nrSymbols (LogVar);
|
||||
void exportToGraphViz (const char*, bool = false) const;
|
||||
bool isCountNormalized (const LogVarSet&);
|
||||
unsigned getConditionalCount (const LogVarSet&);
|
||||
TinySet<unsigned> getConditionalCounts (const LogVarSet&);
|
||||
bool isCarteesianProduct (const LogVarSet&) const;
|
||||
void addTuple (const Tuple&);
|
||||
|
||||
bool containsTuple (const Tuple&);
|
||||
|
||||
void moveToTop (const LogVars&);
|
||||
|
||||
void moveToBottom (const LogVars&);
|
||||
|
||||
void join (ConstraintTree*, bool = false);
|
||||
|
||||
unsigned getLevel (LogVar) const;
|
||||
|
||||
void rename (LogVar, LogVar);
|
||||
|
||||
void applySubstitution (const Substitution&);
|
||||
|
||||
void project (const LogVarSet&);
|
||||
|
||||
void remove (const LogVarSet&);
|
||||
|
||||
bool isSingleton (LogVar);
|
||||
|
||||
LogVarSet singletons (void);
|
||||
|
||||
TupleSet tupleSet (unsigned = 0) const;
|
||||
|
||||
TupleSet tupleSet (const LogVars&);
|
||||
|
||||
unsigned size (void) const;
|
||||
|
||||
unsigned nrSymbols (LogVar);
|
||||
|
||||
void exportToGraphViz (const char*, bool = false) const;
|
||||
|
||||
bool isCountNormalized (const LogVarSet&);
|
||||
|
||||
unsigned getConditionalCount (const LogVarSet&);
|
||||
|
||||
TinySet<unsigned> getConditionalCounts (const LogVarSet&);
|
||||
|
||||
bool isCarteesianProduct (const LogVarSet&) const;
|
||||
|
||||
std::pair<ConstraintTree*, ConstraintTree*> split (
|
||||
const Tuple&,
|
||||
unsigned);
|
||||
const Tuple&, unsigned);
|
||||
|
||||
std::pair<ConstraintTree*, ConstraintTree*> split (
|
||||
const ConstraintTree*,
|
||||
unsigned) const;
|
||||
const ConstraintTree*, unsigned) const;
|
||||
|
||||
ConstraintTrees countNormalize (const LogVarSet&);
|
||||
ConstraintTrees countNormalize (const LogVarSet&);
|
||||
|
||||
ConstraintTrees jointCountNormalize (
|
||||
ConstraintTree*,
|
||||
ConstraintTree*,
|
||||
LogVar,
|
||||
LogVar,
|
||||
LogVar);
|
||||
ConstraintTree*, ConstraintTree*, LogVar, LogVar, LogVar);
|
||||
|
||||
static bool identical (
|
||||
const ConstraintTree*,
|
||||
const ConstraintTree*,
|
||||
unsigned);
|
||||
static bool identical (
|
||||
const ConstraintTree*, const ConstraintTree*, unsigned);
|
||||
|
||||
static bool overlap (
|
||||
const ConstraintTree*,
|
||||
const ConstraintTree*,
|
||||
unsigned);
|
||||
static bool overlap (
|
||||
const ConstraintTree*, const ConstraintTree*, unsigned);
|
||||
|
||||
LogVars expand (LogVar);
|
||||
ConstraintTrees ground (LogVar);
|
||||
LogVars expand (LogVar);
|
||||
ConstraintTrees ground (LogVar);
|
||||
|
||||
private:
|
||||
unsigned countTuples (const CTNode*) const;
|
||||
CTNodes getNodesBelow (CTNode*) const;
|
||||
CTNodes getNodesAtLevel (unsigned) const;
|
||||
void swapLogVar (LogVar);
|
||||
bool join (CTNode*, const Tuple&, unsigned, CTNode*);
|
||||
unsigned countTuples (const CTNode*) const;
|
||||
|
||||
bool indenticalSubtrees (
|
||||
const CTNode*,
|
||||
const CTNode*,
|
||||
bool) const;
|
||||
CTNodes getNodesBelow (CTNode*) const;
|
||||
|
||||
void getTuples (
|
||||
CTNode*,
|
||||
Tuples,
|
||||
unsigned,
|
||||
Tuples&,
|
||||
CTNodes&) const;
|
||||
CTNodes getNodesAtLevel (unsigned) const;
|
||||
|
||||
void swapLogVar (LogVar);
|
||||
|
||||
bool join (CTNode*, const Tuple&, unsigned, CTNode*);
|
||||
|
||||
bool indenticalSubtrees (
|
||||
const CTNode*, const CTNode*, bool) const;
|
||||
|
||||
void getTuples (CTNode*, Tuples, unsigned, Tuples&, CTNodes&) const;
|
||||
|
||||
vector<std::pair<CTNode*, unsigned>> countNormalize (
|
||||
const CTNode*,
|
||||
unsigned);
|
||||
const CTNode*, unsigned);
|
||||
|
||||
static void split (
|
||||
CTNode*,
|
||||
CTNode*,
|
||||
CTNodes&,
|
||||
unsigned);
|
||||
static void split (
|
||||
CTNode*, CTNode*, CTNodes&, unsigned);
|
||||
|
||||
static bool overlap (const CTNode*, const CTNode*, unsigned);
|
||||
static CTNode* copySubtree (const CTNode*);
|
||||
static void deleteSubtree (CTNode*);
|
||||
static bool overlap (const CTNode*, const CTNode*, unsigned);
|
||||
|
||||
CTNode* root_;
|
||||
LogVars logVars_;
|
||||
LogVarSet logVarSet_;
|
||||
CTNode* root_;
|
||||
LogVars logVars_;
|
||||
LogVarSet logVarSet_;
|
||||
};
|
||||
|
||||
|
||||
|
@ -1,45 +0,0 @@
|
||||
#ifndef HORUS_DISTRIBUTION_H
|
||||
#define HORUS_DISTRIBUTION_H
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "Horus.h"
|
||||
|
||||
//TODO die die die die die
|
||||
|
||||
using namespace std;
|
||||
|
||||
|
||||
struct Distribution
|
||||
{
|
||||
public:
|
||||
Distribution (int id)
|
||||
{
|
||||
this->id = id;
|
||||
}
|
||||
|
||||
Distribution (const Params& params, int id = -1)
|
||||
{
|
||||
this->id = id;
|
||||
this->params = params;
|
||||
}
|
||||
|
||||
void updateParameters (const Params& params)
|
||||
{
|
||||
this->params = params;
|
||||
}
|
||||
|
||||
bool shared (void)
|
||||
{
|
||||
return id != -1;
|
||||
}
|
||||
|
||||
int id;
|
||||
Params params;
|
||||
|
||||
private:
|
||||
DISALLOW_COPY_AND_ASSIGN (Distribution);
|
||||
};
|
||||
|
||||
#endif // HORUS_DISTRIBUTION_H
|
||||
|
@ -1,53 +1,39 @@
|
||||
#include <limits>
|
||||
|
||||
#include "ElimGraph.h"
|
||||
#include "BayesNet.h"
|
||||
#include <fstream>
|
||||
|
||||
#include "ElimGraph.h"
|
||||
|
||||
ElimHeuristic ElimGraph::elimHeuristic_ = MIN_NEIGHBORS;
|
||||
|
||||
|
||||
ElimGraph::ElimGraph (const BayesNet& bayesNet)
|
||||
ElimGraph::ElimGraph (const vector<Factor*>& factors)
|
||||
{
|
||||
const BnNodeSet& bnNodes = bayesNet.getBayesNodes();
|
||||
for (unsigned i = 0; i < bnNodes.size(); i++) {
|
||||
if (bnNodes[i]->hasEvidence() == false) {
|
||||
addNode (new EgNode (bnNodes[i]));
|
||||
}
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < bnNodes.size(); i++) {
|
||||
if (bnNodes[i]->hasEvidence() == false) {
|
||||
EgNode* n = getEgNode (bnNodes[i]->varId());
|
||||
const BnNodeSet& childs = bnNodes[i]->getChilds();
|
||||
for (unsigned j = 0; j < childs.size(); j++) {
|
||||
if (childs[j]->hasEvidence() == false) {
|
||||
addEdge (n, getEgNode (childs[j]->varId()));
|
||||
for (unsigned i = 0; i < factors.size(); i++) {
|
||||
const VarIds& vids = factors[i]->arguments();
|
||||
for (unsigned j = 0; j < vids.size() - 1; j++) {
|
||||
EgNode* n1 = getEgNode (vids[j]);
|
||||
if (n1 == 0) {
|
||||
n1 = new EgNode (vids[j], factors[i]->range (j));
|
||||
addNode (n1);
|
||||
}
|
||||
for (unsigned k = j + 1; k < vids.size(); k++) {
|
||||
EgNode* n2 = getEgNode (vids[k]);
|
||||
if (n2 == 0) {
|
||||
n2 = new EgNode (vids[k], factors[i]->range (k));
|
||||
addNode (n2);
|
||||
}
|
||||
if (neighbors (n1, n2) == false) {
|
||||
addEdge (n1, n2);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (vids.size() == 1) {
|
||||
if (getEgNode (vids[0]) == 0) {
|
||||
addNode (new EgNode (vids[0], factors[i]->range (0)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < bnNodes.size(); i++) {
|
||||
vector<EgNode*> neighs;
|
||||
const vector<BayesNode*>& parents = bnNodes[i]->getParents();
|
||||
for (unsigned i = 0; i < parents.size(); i++) {
|
||||
if (parents[i]->hasEvidence() == false) {
|
||||
neighs.push_back (getEgNode (parents[i]->varId()));
|
||||
}
|
||||
}
|
||||
if (neighs.size() > 0) {
|
||||
for (unsigned i = 0; i < neighs.size() - 1; i++) {
|
||||
for (unsigned j = i+1; j < neighs.size(); j++) {
|
||||
if (!neighbors (neighs[i], neighs[j])) {
|
||||
addEdge (neighs[i], neighs[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
setIndexes();
|
||||
}
|
||||
|
||||
|
||||
@ -61,40 +47,16 @@ ElimGraph::~ElimGraph (void)
|
||||
|
||||
|
||||
|
||||
void
|
||||
ElimGraph::addNode (EgNode* n)
|
||||
{
|
||||
nodes_.push_back (n);
|
||||
varMap_.insert (make_pair (n->varId(), n));
|
||||
}
|
||||
|
||||
|
||||
|
||||
EgNode*
|
||||
ElimGraph::getEgNode (VarId vid) const
|
||||
{
|
||||
unordered_map<VarId,EgNode*>::const_iterator it =varMap_.find (vid);
|
||||
if (it ==varMap_.end()) {
|
||||
return 0;
|
||||
} else {
|
||||
return it->second;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
VarIds
|
||||
ElimGraph::getEliminatingOrder (const VarIds& exclude)
|
||||
{
|
||||
VarIds elimOrder;
|
||||
marked_.resize (nodes_.size(), false);
|
||||
|
||||
for (unsigned i = 0; i < exclude.size(); i++) {
|
||||
assert (getEgNode (exclude[i]));
|
||||
EgNode* node = getEgNode (exclude[i]);
|
||||
assert (node);
|
||||
marked_[*node] = true;
|
||||
}
|
||||
|
||||
unsigned nVarsToEliminate = nodes_.size() - exclude.size();
|
||||
for (unsigned i = 0; i < nVarsToEliminate; i++) {
|
||||
EgNode* node = getLowestCostNode();
|
||||
@ -107,6 +69,100 @@ ElimGraph::getEliminatingOrder (const VarIds& exclude)
|
||||
|
||||
|
||||
|
||||
void
|
||||
ElimGraph::print (void) const
|
||||
{
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
cout << "node " << nodes_[i]->label() << " neighs:" ;
|
||||
vector<EgNode*> neighs = nodes_[i]->neighbors();
|
||||
for (unsigned j = 0; j < neighs.size(); j++) {
|
||||
cout << " " << neighs[j]->label();
|
||||
}
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
ElimGraph::exportToGraphViz (
|
||||
const char* fileName,
|
||||
bool showNeighborless,
|
||||
const VarIds& highlightVarIds) const
|
||||
{
|
||||
ofstream out (fileName);
|
||||
if (!out.is_open()) {
|
||||
cerr << "error: cannot open file to write at " ;
|
||||
cerr << "Markov::exportToDotFile()" << endl;
|
||||
abort();
|
||||
}
|
||||
|
||||
out << "strict graph {" << endl;
|
||||
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
if (showNeighborless || nodes_[i]->neighbors().size() != 0) {
|
||||
out << '"' << nodes_[i]->label() << '"' << endl;
|
||||
}
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < highlightVarIds.size(); i++) {
|
||||
EgNode* node =getEgNode (highlightVarIds[i]);
|
||||
if (node) {
|
||||
out << '"' << node->label() << '"' ;
|
||||
out << " [shape=box3d]" << endl;
|
||||
} else {
|
||||
cout << "error: invalid variable id: " << highlightVarIds[i] << endl;
|
||||
abort();
|
||||
}
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
vector<EgNode*> neighs = nodes_[i]->neighbors();
|
||||
for (unsigned j = 0; j < neighs.size(); j++) {
|
||||
out << '"' << nodes_[i]->label() << '"' << " -- " ;
|
||||
out << '"' << neighs[j]->label() << '"' << endl;
|
||||
}
|
||||
}
|
||||
|
||||
out << "}" << endl;
|
||||
out.close();
|
||||
}
|
||||
|
||||
|
||||
|
||||
VarIds
|
||||
ElimGraph::getEliminationOrder (
|
||||
const vector<Factor*> factors,
|
||||
VarIds excludedVids)
|
||||
{
|
||||
ElimGraph graph (factors);
|
||||
// graph.print();
|
||||
// graph.exportToGraphViz ("_egg.dot");
|
||||
return graph.getEliminatingOrder (excludedVids);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
ElimGraph::addNode (EgNode* n)
|
||||
{
|
||||
nodes_.push_back (n);
|
||||
n->setIndex (nodes_.size() - 1);
|
||||
varMap_.insert (make_pair (n->varId(), n));
|
||||
}
|
||||
|
||||
|
||||
|
||||
EgNode*
|
||||
ElimGraph::getEgNode (VarId vid) const
|
||||
{
|
||||
unordered_map<VarId, EgNode*>::const_iterator it;
|
||||
it = varMap_.find (vid);
|
||||
return (it != varMap_.end()) ? it->second : 0;
|
||||
}
|
||||
|
||||
|
||||
|
||||
EgNode*
|
||||
ElimGraph::getLowestCostNode (void) const
|
||||
{
|
||||
@ -164,7 +220,7 @@ ElimGraph::getWeightCost (const EgNode* n) const
|
||||
const vector<EgNode*>& neighs = n->neighbors();
|
||||
for (unsigned i = 0; i < neighs.size(); i++) {
|
||||
if (marked_[*neighs[i]] == false) {
|
||||
cost *= neighs[i]->nrStates();
|
||||
cost *= neighs[i]->range();
|
||||
}
|
||||
}
|
||||
return cost;
|
||||
@ -204,7 +260,7 @@ ElimGraph::getWeightedFillCost (const EgNode* n) const
|
||||
for (unsigned j = i+1; j < neighs.size(); j++) {
|
||||
if (marked_[*neighs[j]] == true) continue;
|
||||
if (!neighbors (neighs[i], neighs[j])) {
|
||||
cost += neighs[i]->nrStates() * neighs[j]->nrStates();
|
||||
cost += neighs[i]->range() * neighs[j]->range();
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -245,78 +301,3 @@ ElimGraph::neighbors (const EgNode* n1, const EgNode* n2) const
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
ElimGraph::setIndexes (void)
|
||||
{
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
nodes_[i]->setIndex (i);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
ElimGraph::printGraphicalModel (void) const
|
||||
{
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
cout << "node " << nodes_[i]->label() << " neighs:" ;
|
||||
vector<EgNode*> neighs = nodes_[i]->neighbors();
|
||||
for (unsigned j = 0; j < neighs.size(); j++) {
|
||||
cout << " " << neighs[j]->label();
|
||||
}
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
ElimGraph::exportToGraphViz (const char* fileName,
|
||||
bool showNeighborless,
|
||||
const VarIds& highlightVarIds) const
|
||||
{
|
||||
ofstream out (fileName);
|
||||
if (!out.is_open()) {
|
||||
cerr << "error: cannot open file to write at " ;
|
||||
cerr << "Markov::exportToDotFile()" << endl;
|
||||
abort();
|
||||
}
|
||||
|
||||
out << "strict graph {" << endl;
|
||||
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
if (showNeighborless || nodes_[i]->neighbors().size() != 0) {
|
||||
out << '"' << nodes_[i]->label() << '"' ;
|
||||
if (nodes_[i]->hasEvidence()) {
|
||||
out << " [style=filled, fillcolor=yellow]" << endl;
|
||||
} else {
|
||||
out << endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < highlightVarIds.size(); i++) {
|
||||
EgNode* node =getEgNode (highlightVarIds[i]);
|
||||
if (node) {
|
||||
out << '"' << node->label() << '"' ;
|
||||
out << " [shape=box3d]" << endl;
|
||||
} else {
|
||||
cout << "error: invalid variable id: " << highlightVarIds[i] << endl;
|
||||
abort();
|
||||
}
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
vector<EgNode*> neighs = nodes_[i]->neighbors();
|
||||
for (unsigned j = 0; j < neighs.size(); j++) {
|
||||
out << '"' << nodes_[i]->label() << '"' << " -- " ;
|
||||
out << '"' << neighs[j]->label() << '"' << endl;
|
||||
}
|
||||
}
|
||||
|
||||
out << "}" << endl;
|
||||
out.close();
|
||||
}
|
||||
|
||||
|
@ -17,15 +17,15 @@ enum ElimHeuristic
|
||||
};
|
||||
|
||||
|
||||
class EgNode : public VarNode {
|
||||
class EgNode : public Var
|
||||
{
|
||||
public:
|
||||
EgNode (VarNode* var) : VarNode (var) { }
|
||||
void addNeighbor (EgNode* n)
|
||||
{
|
||||
neighs_.push_back (n);
|
||||
}
|
||||
EgNode (VarId vid, unsigned range) : Var (vid, range) { }
|
||||
|
||||
void addNeighbor (EgNode* n) { neighs_.push_back (n); }
|
||||
|
||||
const vector<EgNode*>& neighbors (void) const { return neighs_; }
|
||||
|
||||
private:
|
||||
vector<EgNode*> neighs_;
|
||||
};
|
||||
@ -34,22 +34,18 @@ class EgNode : public VarNode {
|
||||
class ElimGraph
|
||||
{
|
||||
public:
|
||||
ElimGraph (const BayesNet&);
|
||||
ElimGraph (const vector<Factor*>&); // TODO
|
||||
|
||||
~ElimGraph (void);
|
||||
|
||||
void addEdge (EgNode* n1, EgNode* n2)
|
||||
{
|
||||
assert (n1 != n2);
|
||||
n1->addNeighbor (n2);
|
||||
n2->addNeighbor (n1);
|
||||
}
|
||||
void addNode (EgNode*);
|
||||
EgNode* getEgNode (VarId) const;
|
||||
VarIds getEliminatingOrder (const VarIds&);
|
||||
void printGraphicalModel (void) const;
|
||||
void exportToGraphViz (const char*, bool = true,
|
||||
const VarIds& = VarIds()) const;
|
||||
void setIndexes();
|
||||
VarIds getEliminatingOrder (const VarIds&);
|
||||
|
||||
void print (void) const;
|
||||
|
||||
void exportToGraphViz (const char*, bool = true,
|
||||
const VarIds& = VarIds()) const;
|
||||
|
||||
static VarIds getEliminationOrder (const vector<Factor*>, VarIds);
|
||||
|
||||
static void setEliminationHeuristic (ElimHeuristic h)
|
||||
{
|
||||
@ -57,18 +53,34 @@ class ElimGraph
|
||||
}
|
||||
|
||||
private:
|
||||
EgNode* getLowestCostNode (void) const;
|
||||
unsigned getNeighborsCost (const EgNode*) const;
|
||||
unsigned getWeightCost (const EgNode*) const;
|
||||
unsigned getFillCost (const EgNode*) const;
|
||||
unsigned getWeightedFillCost (const EgNode*) const;
|
||||
void connectAllNeighbors (const EgNode*);
|
||||
bool neighbors (const EgNode*, const EgNode*) const;
|
||||
|
||||
void addEdge (EgNode* n1, EgNode* n2)
|
||||
{
|
||||
assert (n1 != n2);
|
||||
n1->addNeighbor (n2);
|
||||
n2->addNeighbor (n1);
|
||||
}
|
||||
|
||||
void addNode (EgNode*);
|
||||
|
||||
EgNode* getEgNode (VarId) const;
|
||||
EgNode* getLowestCostNode (void) const;
|
||||
|
||||
unsigned getNeighborsCost (const EgNode*) const;
|
||||
|
||||
unsigned getWeightCost (const EgNode*) const;
|
||||
|
||||
unsigned getFillCost (const EgNode*) const;
|
||||
|
||||
unsigned getWeightedFillCost (const EgNode*) const;
|
||||
|
||||
void connectAllNeighbors (const EgNode*);
|
||||
|
||||
bool neighbors (const EgNode*, const EgNode*) const;
|
||||
|
||||
vector<EgNode*> nodes_;
|
||||
vector<bool> marked_;
|
||||
unordered_map<VarId,EgNode*> varMap_;
|
||||
unordered_map<VarId, EgNode*> varMap_;
|
||||
static ElimHeuristic elimHeuristic_;
|
||||
};
|
||||
|
||||
|
@ -8,7 +8,7 @@
|
||||
|
||||
#include "Factor.h"
|
||||
#include "Indexer.h"
|
||||
#include "Util.h"
|
||||
|
||||
|
||||
|
||||
Factor::Factor (const Factor& g)
|
||||
@ -18,206 +18,33 @@ Factor::Factor (const Factor& g)
|
||||
|
||||
|
||||
|
||||
Factor::Factor (VarId vid, unsigned nStates)
|
||||
Factor::Factor (
|
||||
const VarIds& vids,
|
||||
const Ranges& ranges,
|
||||
const Params& params,
|
||||
unsigned distId)
|
||||
{
|
||||
varids_.push_back (vid);
|
||||
ranges_.push_back (nStates);
|
||||
dist_ = new Distribution (Params (nStates, 1.0));
|
||||
}
|
||||
|
||||
|
||||
|
||||
Factor::Factor (const VarNodes& vars)
|
||||
{
|
||||
int nParams = 1;
|
||||
for (unsigned i = 0; i < vars.size(); i++) {
|
||||
varids_.push_back (vars[i]->varId());
|
||||
ranges_.push_back (vars[i]->nrStates());
|
||||
nParams *= vars[i]->nrStates();
|
||||
}
|
||||
// create a uniform distribution
|
||||
double val = 1.0 / nParams;
|
||||
dist_ = new Distribution (Params (nParams, val));
|
||||
}
|
||||
|
||||
|
||||
|
||||
Factor::Factor (VarId vid, unsigned nStates, const Params& params)
|
||||
{
|
||||
varids_.push_back (vid);
|
||||
ranges_.push_back (nStates);
|
||||
dist_ = new Distribution (params);
|
||||
}
|
||||
|
||||
|
||||
|
||||
Factor::Factor (VarNodes& vars, Distribution* dist)
|
||||
{
|
||||
for (unsigned i = 0; i < vars.size(); i++) {
|
||||
varids_.push_back (vars[i]->varId());
|
||||
ranges_.push_back (vars[i]->nrStates());
|
||||
}
|
||||
dist_ = dist;
|
||||
}
|
||||
|
||||
|
||||
|
||||
Factor::Factor (const VarNodes& vars, const Params& params)
|
||||
{
|
||||
for (unsigned i = 0; i < vars.size(); i++) {
|
||||
varids_.push_back (vars[i]->varId());
|
||||
ranges_.push_back (vars[i]->nrStates());
|
||||
}
|
||||
dist_ = new Distribution (params);
|
||||
}
|
||||
|
||||
|
||||
|
||||
Factor::Factor (const VarIds& vids,
|
||||
const Ranges& ranges,
|
||||
const Params& params)
|
||||
{
|
||||
varids_ = vids;
|
||||
args_ = vids;
|
||||
ranges_ = ranges;
|
||||
dist_ = new Distribution (params);
|
||||
params_ = params;
|
||||
distId_ = distId;
|
||||
assert (params_.size() == Util::expectedSize (ranges_));
|
||||
}
|
||||
|
||||
|
||||
|
||||
Factor::~Factor (void)
|
||||
Factor::Factor (
|
||||
const Vars& vars,
|
||||
const Params& params,
|
||||
unsigned distId)
|
||||
{
|
||||
if (dist_->shared() == false) {
|
||||
delete dist_;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Factor::setParameters (const Params& params)
|
||||
{
|
||||
assert (dist_->params.size() == params.size());
|
||||
dist_->params = params;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Factor::copyFromFactor (const Factor& g)
|
||||
{
|
||||
varids_ = g.getVarIds();
|
||||
ranges_ = g.getRanges();
|
||||
dist_ = new Distribution (g.getParameters());
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Factor::multiply (const Factor& g)
|
||||
{
|
||||
if (varids_.size() == 0) {
|
||||
copyFromFactor (g);
|
||||
return;
|
||||
}
|
||||
|
||||
const VarIds& g_varids = g.getVarIds();
|
||||
const Ranges& g_ranges = g.getRanges();
|
||||
const Params& g_params = g.getParameters();
|
||||
|
||||
if (varids_ == g_varids) {
|
||||
// optimization: if the factors contain the same set of variables,
|
||||
// we can do a 1 to 1 operation on the parameters
|
||||
if (Globals::logDomain) {
|
||||
Util::add (dist_->params, g_params);
|
||||
} else {
|
||||
Util::multiply (dist_->params, g_params);
|
||||
}
|
||||
} else {
|
||||
bool sharedVars = false;
|
||||
vector<unsigned> gvarpos;
|
||||
for (unsigned i = 0; i < g_varids.size(); i++) {
|
||||
int idx = indexOf (g_varids[i]);
|
||||
if (idx == -1) {
|
||||
insertVariable (g_varids[i], g_ranges[i]);
|
||||
gvarpos.push_back (varids_.size() - 1);
|
||||
} else {
|
||||
sharedVars = true;
|
||||
gvarpos.push_back (idx);
|
||||
}
|
||||
}
|
||||
if (sharedVars == false) {
|
||||
// optimization: if the original factors doesn't have common variables,
|
||||
// we don't need to marry the states of the common variables
|
||||
unsigned count = 0;
|
||||
for (unsigned i = 0; i < dist_->params.size(); i++) {
|
||||
if (Globals::logDomain) {
|
||||
dist_->params[i] += g_params[count];
|
||||
} else {
|
||||
dist_->params[i] *= g_params[count];
|
||||
}
|
||||
count ++;
|
||||
if (count >= g_params.size()) {
|
||||
count = 0;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
StatesIndexer indexer (ranges_, false);
|
||||
while (indexer.valid()) {
|
||||
unsigned g_li = 0;
|
||||
unsigned prod = 1;
|
||||
for (int j = gvarpos.size() - 1; j >= 0; j--) {
|
||||
g_li += indexer[gvarpos[j]] * prod;
|
||||
prod *= g_ranges[j];
|
||||
}
|
||||
if (Globals::logDomain) {
|
||||
dist_->params[indexer] += g_params[g_li];
|
||||
} else {
|
||||
dist_->params[indexer] *= g_params[g_li];
|
||||
}
|
||||
++ indexer;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Factor::insertVariable (VarId varId, unsigned nrStates)
|
||||
{
|
||||
assert (indexOf (varId) == -1);
|
||||
Params oldParams = dist_->params;
|
||||
dist_->params.clear();
|
||||
dist_->params.reserve (oldParams.size() * nrStates);
|
||||
for (unsigned i = 0; i < oldParams.size(); i++) {
|
||||
for (unsigned reps = 0; reps < nrStates; reps++) {
|
||||
dist_->params.push_back (oldParams[i]);
|
||||
}
|
||||
}
|
||||
varids_.push_back (varId);
|
||||
ranges_.push_back (nrStates);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Factor::insertVariables (const VarIds& varIds, const Ranges& ranges)
|
||||
{
|
||||
Params oldParams = dist_->params;
|
||||
unsigned nrStates = 1;
|
||||
for (unsigned i = 0; i < varIds.size(); i++) {
|
||||
assert (indexOf (varIds[i]) == -1);
|
||||
varids_.push_back (varIds[i]);
|
||||
ranges_.push_back (ranges[i]);
|
||||
nrStates *= ranges[i];
|
||||
}
|
||||
dist_->params.clear();
|
||||
dist_->params.reserve (oldParams.size() * nrStates);
|
||||
for (unsigned i = 0; i < oldParams.size(); i++) {
|
||||
for (unsigned reps = 0; reps < nrStates; reps++) {
|
||||
dist_->params.push_back (oldParams[i]);
|
||||
}
|
||||
for (unsigned i = 0; i < vars.size(); i++) {
|
||||
args_.push_back (vars[i]->varId());
|
||||
ranges_.push_back (vars[i]->range());
|
||||
}
|
||||
params_ = params;
|
||||
distId_ = distId;
|
||||
assert (params_.size() == Util::expectedSize (ranges_));
|
||||
}
|
||||
|
||||
|
||||
@ -226,10 +53,10 @@ void
|
||||
Factor::sumOutAllExcept (VarId vid)
|
||||
{
|
||||
assert (indexOf (vid) != -1);
|
||||
while (varids_.back() != vid) {
|
||||
while (args_.back() != vid) {
|
||||
sumOutLastVariable();
|
||||
}
|
||||
while (varids_.front() != vid) {
|
||||
while (args_.front() != vid) {
|
||||
sumOutFirstVariable();
|
||||
}
|
||||
}
|
||||
@ -239,9 +66,10 @@ Factor::sumOutAllExcept (VarId vid)
|
||||
void
|
||||
Factor::sumOutAllExcept (const VarIds& vids)
|
||||
{
|
||||
for (unsigned i = 0; i < varids_.size(); i++) {
|
||||
if (std::find (vids.begin(), vids.end(), varids_[i]) == vids.end()) {
|
||||
sumOut (varids_[i]);
|
||||
for (int i = 0; i < (int)args_.size(); i++) {
|
||||
if (Util::contains (vids, args_[i]) == false) {
|
||||
sumOut (args_[i]);
|
||||
i --;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -254,11 +82,11 @@ Factor::sumOut (VarId vid)
|
||||
int idx = indexOf (vid);
|
||||
assert (idx != -1);
|
||||
|
||||
if (vid == varids_.back()) {
|
||||
if (vid == args_.back()) {
|
||||
sumOutLastVariable(); // optimization
|
||||
return;
|
||||
}
|
||||
if (vid == varids_.front()) {
|
||||
if (vid == args_.front()) {
|
||||
sumOutFirstVariable(); // optimization
|
||||
return;
|
||||
}
|
||||
@ -271,7 +99,7 @@ Factor::sumOut (VarId vid)
|
||||
// on the left of `var', with the states of the remaining vars fixed
|
||||
unsigned leftVarOffset = 1;
|
||||
|
||||
for (int i = varids_.size() - 1; i > idx; i--) {
|
||||
for (int i = args_.size() - 1; i > idx; i--) {
|
||||
varOffset *= ranges_[i];
|
||||
leftVarOffset *= ranges_[i];
|
||||
}
|
||||
@ -280,25 +108,24 @@ Factor::sumOut (VarId vid)
|
||||
unsigned offset = 0;
|
||||
unsigned count1 = 0;
|
||||
unsigned count2 = 0;
|
||||
unsigned newpsSize = dist_->params.size() / ranges_[idx];
|
||||
unsigned newpsSize = params_.size() / ranges_[idx];
|
||||
|
||||
Params newps;
|
||||
newps.reserve (newpsSize);
|
||||
Params& params = dist_->params;
|
||||
|
||||
while (newps.size() < newpsSize) {
|
||||
double sum = Util::addIdenty();
|
||||
double sum = LogAware::addIdenty();
|
||||
for (unsigned i = 0; i < ranges_[idx]; i++) {
|
||||
if (Globals::logDomain) {
|
||||
Util::logSum (sum, params[offset]);
|
||||
sum = Util::logSum (sum, params_[offset]);
|
||||
} else {
|
||||
sum += params[offset];
|
||||
sum += params_[offset];
|
||||
}
|
||||
offset += varOffset;
|
||||
}
|
||||
newps.push_back (sum);
|
||||
count1 ++;
|
||||
if (idx == (int)varids_.size() - 1) {
|
||||
if (idx == (int)args_.size() - 1) {
|
||||
offset = count1 * ranges_[idx];
|
||||
} else {
|
||||
if (((offset - varOffset + 1) % leftVarOffset) == 0) {
|
||||
@ -308,9 +135,9 @@ Factor::sumOut (VarId vid)
|
||||
offset = (leftVarOffset * count2) + count1;
|
||||
}
|
||||
}
|
||||
varids_.erase (varids_.begin() + idx);
|
||||
args_.erase (args_.begin() + idx);
|
||||
ranges_.erase (ranges_.begin() + idx);
|
||||
dist_->params = newps;
|
||||
params_ = newps;
|
||||
}
|
||||
|
||||
|
||||
@ -318,20 +145,19 @@ Factor::sumOut (VarId vid)
|
||||
void
|
||||
Factor::sumOutFirstVariable (void)
|
||||
{
|
||||
Params& params = dist_->params;
|
||||
unsigned nStates = ranges_.front();
|
||||
unsigned sep = params.size() / nStates;
|
||||
unsigned range = ranges_.front();
|
||||
unsigned sep = params_.size() / range;
|
||||
if (Globals::logDomain) {
|
||||
for (unsigned i = sep; i < params.size(); i++) {
|
||||
Util::logSum (params[i % sep], params[i]);
|
||||
for (unsigned i = sep; i < params_.size(); i++) {
|
||||
params_[i % sep] = Util::logSum (params_[i % sep], params_[i]);
|
||||
}
|
||||
} else {
|
||||
for (unsigned i = sep; i < params.size(); i++) {
|
||||
params[i % sep] += params[i];
|
||||
for (unsigned i = sep; i < params_.size(); i++) {
|
||||
params_[i % sep] += params_[i];
|
||||
}
|
||||
}
|
||||
params.resize (sep);
|
||||
varids_.erase (varids_.begin());
|
||||
params_.resize (sep);
|
||||
args_.erase (args_.begin());
|
||||
ranges_.erase (ranges_.begin());
|
||||
}
|
||||
|
||||
@ -340,143 +166,55 @@ Factor::sumOutFirstVariable (void)
|
||||
void
|
||||
Factor::sumOutLastVariable (void)
|
||||
{
|
||||
Params& params = dist_->params;
|
||||
unsigned nStates = ranges_.back();
|
||||
unsigned range = ranges_.back();
|
||||
unsigned idx1 = 0;
|
||||
unsigned idx2 = 0;
|
||||
if (Globals::logDomain) {
|
||||
while (idx1 < params.size()) {
|
||||
params[idx2] = params[idx1];
|
||||
while (idx1 < params_.size()) {
|
||||
params_[idx2] = params_[idx1];
|
||||
idx1 ++;
|
||||
for (unsigned j = 1; j < nStates; j++) {
|
||||
Util::logSum (params[idx2], params[idx1]);
|
||||
for (unsigned j = 1; j < range; j++) {
|
||||
params_[idx2] = Util::logSum (params_[idx2], params_[idx1]);
|
||||
idx1 ++;
|
||||
}
|
||||
idx2 ++;
|
||||
}
|
||||
} else {
|
||||
while (idx1 < params.size()) {
|
||||
params[idx2] = params[idx1];
|
||||
while (idx1 < params_.size()) {
|
||||
params_[idx2] = params_[idx1];
|
||||
idx1 ++;
|
||||
for (unsigned j = 1; j < nStates; j++) {
|
||||
params[idx2] += params[idx1];
|
||||
for (unsigned j = 1; j < range; j++) {
|
||||
params_[idx2] += params_[idx1];
|
||||
idx1 ++;
|
||||
}
|
||||
idx2 ++;
|
||||
}
|
||||
}
|
||||
params.resize (idx2);
|
||||
varids_.pop_back();
|
||||
params_.resize (idx2);
|
||||
args_.pop_back();
|
||||
ranges_.pop_back();
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Factor::orderVariables (void)
|
||||
Factor::multiply (Factor& g)
|
||||
{
|
||||
VarIds sortedVarIds = varids_;
|
||||
sort (sortedVarIds.begin(), sortedVarIds.end());
|
||||
reorderVariables (sortedVarIds);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Factor::reorderVariables (const VarIds& newVarIds)
|
||||
{
|
||||
assert (newVarIds.size() == varids_.size());
|
||||
if (newVarIds == varids_) {
|
||||
if (args_.size() == 0) {
|
||||
copyFromFactor (g);
|
||||
return;
|
||||
}
|
||||
|
||||
Ranges newRanges;
|
||||
vector<unsigned> positions;
|
||||
for (unsigned i = 0; i < newVarIds.size(); i++) {
|
||||
unsigned idx = indexOf (newVarIds[i]);
|
||||
newRanges.push_back (ranges_[idx]);
|
||||
positions.push_back (idx);
|
||||
}
|
||||
|
||||
unsigned N = ranges_.size();
|
||||
Params newParams (dist_->params.size());
|
||||
for (unsigned i = 0; i < dist_->params.size(); i++) {
|
||||
unsigned li = i;
|
||||
// calculate vector index corresponding to linear index
|
||||
vector<unsigned> vi (N);
|
||||
for (int k = N-1; k >= 0; k--) {
|
||||
vi[k] = li % ranges_[k];
|
||||
li /= ranges_[k];
|
||||
}
|
||||
// convert permuted vector index to corresponding linear index
|
||||
unsigned prod = 1;
|
||||
unsigned new_li = 0;
|
||||
for (int k = N-1; k >= 0; k--) {
|
||||
new_li += vi[positions[k]] * prod;
|
||||
prod *= ranges_[positions[k]];
|
||||
}
|
||||
newParams[new_li] = dist_->params[i];
|
||||
}
|
||||
varids_ = newVarIds;
|
||||
ranges_ = newRanges;
|
||||
dist_->params = newParams;
|
||||
TFactor<VarId>::multiply (g);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Factor::absorveEvidence (VarId vid, unsigned evidence)
|
||||
Factor::reorderAccordingVarIds (void)
|
||||
{
|
||||
int idx = indexOf (vid);
|
||||
assert (idx != -1);
|
||||
|
||||
Params oldParams = dist_->params;
|
||||
dist_->params.clear();
|
||||
dist_->params.reserve (oldParams.size() / ranges_[idx]);
|
||||
StatesIndexer indexer (ranges_);
|
||||
for (unsigned i = 0; i < evidence; i++) {
|
||||
indexer.increment (idx);
|
||||
}
|
||||
while (indexer.valid()) {
|
||||
dist_->params.push_back (oldParams[indexer]);
|
||||
indexer.incrementExcluding (idx);
|
||||
}
|
||||
varids_.erase (varids_.begin() + idx);
|
||||
ranges_.erase (ranges_.begin() + idx);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Factor::normalize (void)
|
||||
{
|
||||
Util::normalize (dist_->params);
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
Factor::contains (const VarIds& vars) const
|
||||
{
|
||||
for (unsigned i = 0; i < vars.size(); i++) {
|
||||
if (indexOf (vars[i]) == -1) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
|
||||
int
|
||||
Factor::indexOf (VarId vid) const
|
||||
{
|
||||
for (unsigned i = 0; i < varids_.size(); i++) {
|
||||
if (varids_[i] == vid) {
|
||||
return i;
|
||||
}
|
||||
}
|
||||
return -1;
|
||||
VarIds sortedVarIds = args_;
|
||||
sort (sortedVarIds.begin(), sortedVarIds.end());
|
||||
reorderArguments (sortedVarIds);
|
||||
}
|
||||
|
||||
|
||||
@ -486,9 +224,9 @@ Factor::getLabel (void) const
|
||||
{
|
||||
stringstream ss;
|
||||
ss << "f(" ;
|
||||
for (unsigned i = 0; i < varids_.size(); i++) {
|
||||
for (unsigned i = 0; i < args_.size(); i++) {
|
||||
if (i != 0) ss << "," ;
|
||||
ss << VarNode (varids_[i], ranges_[i]).label();
|
||||
ss << Var (args_[i], ranges_[i]).label();
|
||||
}
|
||||
ss << ")" ;
|
||||
return ss.str();
|
||||
@ -499,14 +237,14 @@ Factor::getLabel (void) const
|
||||
void
|
||||
Factor::print (void) const
|
||||
{
|
||||
VarNodes vars;
|
||||
for (unsigned i = 0; i < varids_.size(); i++) {
|
||||
vars.push_back (new VarNode (varids_[i], ranges_[i]));
|
||||
Vars vars;
|
||||
for (unsigned i = 0; i < args_.size(); i++) {
|
||||
vars.push_back (new Var (args_[i], ranges_[i]));
|
||||
}
|
||||
vector<string> jointStrings = Util::getJointStateStrings (vars);
|
||||
for (unsigned i = 0; i < dist_->params.size(); i++) {
|
||||
cout << "f(" << jointStrings[i] << ")" ;
|
||||
cout << " = " << dist_->params[i] << endl;
|
||||
vector<string> jointStrings = Util::getStateLines (vars);
|
||||
for (unsigned i = 0; i < params_.size(); i++) {
|
||||
cout << "[" << distId_ << "] f(" << jointStrings[i] << ")" ;
|
||||
cout << " = " << params_[i] << endl;
|
||||
}
|
||||
cout << endl;
|
||||
for (unsigned i = 0; i < vars.size(); i++) {
|
||||
@ -515,3 +253,13 @@ Factor::print (void) const
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Factor::copyFromFactor (const Factor& g)
|
||||
{
|
||||
args_ = g.arguments();
|
||||
ranges_ = g.ranges();
|
||||
params_ = g.params();
|
||||
distId_ = g.distId();
|
||||
}
|
||||
|
||||
|
@ -3,64 +3,285 @@
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "Distribution.h"
|
||||
#include "VarNode.h"
|
||||
#include "Var.h"
|
||||
#include "Indexer.h"
|
||||
#include "Util.h"
|
||||
|
||||
|
||||
using namespace std;
|
||||
|
||||
class Distribution;
|
||||
|
||||
template <typename T>
|
||||
class TFactor
|
||||
{
|
||||
public:
|
||||
const vector<T>& arguments (void) const { return args_; }
|
||||
|
||||
vector<T>& arguments (void) { return args_; }
|
||||
|
||||
const Ranges& ranges (void) const { return ranges_; }
|
||||
|
||||
const Params& params (void) const { return params_; }
|
||||
|
||||
Params& params (void) { return params_; }
|
||||
|
||||
unsigned nrArguments (void) const { return args_.size(); }
|
||||
|
||||
unsigned size (void) const { return params_.size(); }
|
||||
|
||||
unsigned distId (void) const { return distId_; }
|
||||
|
||||
void setDistId (unsigned id) { distId_ = id; }
|
||||
|
||||
void normalize (void) { LogAware::normalize (params_); }
|
||||
|
||||
void setParams (const Params& newParams)
|
||||
{
|
||||
params_ = newParams;
|
||||
assert (params_.size() == Util::expectedSize (ranges_));
|
||||
}
|
||||
|
||||
int indexOf (const T& t) const
|
||||
{
|
||||
int idx = -1;
|
||||
for (unsigned i = 0; i < args_.size(); i++) {
|
||||
if (args_[i] == t) {
|
||||
idx = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
return idx;
|
||||
}
|
||||
|
||||
const T& argument (unsigned idx) const
|
||||
{
|
||||
assert (idx < args_.size());
|
||||
return args_[idx];
|
||||
}
|
||||
|
||||
T& argument (unsigned idx)
|
||||
{
|
||||
assert (idx < args_.size());
|
||||
return args_[idx];
|
||||
}
|
||||
|
||||
unsigned range (unsigned idx) const
|
||||
{
|
||||
assert (idx < ranges_.size());
|
||||
return ranges_[idx];
|
||||
}
|
||||
|
||||
void multiply (TFactor<T>& g)
|
||||
{
|
||||
const vector<T>& g_args = g.arguments();
|
||||
const Ranges& g_ranges = g.ranges();
|
||||
const Params& g_params = g.params();
|
||||
if (args_ == g_args) {
|
||||
// optimization: if the factors contain the same set of args,
|
||||
// we can do a 1 to 1 operation on the parameters
|
||||
if (Globals::logDomain) {
|
||||
Util::add (params_, g_params);
|
||||
} else {
|
||||
Util::multiply (params_, g_params);
|
||||
}
|
||||
} else {
|
||||
bool sharedArgs = false;
|
||||
vector<unsigned> gvarpos;
|
||||
for (unsigned i = 0; i < g_args.size(); i++) {
|
||||
int idx = indexOf (g_args[i]);
|
||||
if (idx == -1) {
|
||||
insertArgument (g_args[i], g_ranges[i]);
|
||||
gvarpos.push_back (args_.size() - 1);
|
||||
} else {
|
||||
sharedArgs = true;
|
||||
gvarpos.push_back (idx);
|
||||
}
|
||||
}
|
||||
if (sharedArgs == false) {
|
||||
// optimization: if the original factors doesn't have common args,
|
||||
// we don't need to marry the states of the common args
|
||||
unsigned count = 0;
|
||||
for (unsigned i = 0; i < params_.size(); i++) {
|
||||
if (Globals::logDomain) {
|
||||
params_[i] += g_params[count];
|
||||
} else {
|
||||
params_[i] *= g_params[count];
|
||||
}
|
||||
count ++;
|
||||
if (count >= g_params.size()) {
|
||||
count = 0;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
StatesIndexer indexer (ranges_, false);
|
||||
while (indexer.valid()) {
|
||||
unsigned g_li = 0;
|
||||
unsigned prod = 1;
|
||||
for (int j = gvarpos.size() - 1; j >= 0; j--) {
|
||||
g_li += indexer[gvarpos[j]] * prod;
|
||||
prod *= g_ranges[j];
|
||||
}
|
||||
if (Globals::logDomain) {
|
||||
params_[indexer] += g_params[g_li];
|
||||
} else {
|
||||
params_[indexer] *= g_params[g_li];
|
||||
}
|
||||
++ indexer;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void absorveEvidence (const T& arg, unsigned evidence)
|
||||
{
|
||||
int idx = indexOf (arg);
|
||||
assert (idx != -1);
|
||||
assert (evidence < ranges_[idx]);
|
||||
Params copy = params_;
|
||||
params_.clear();
|
||||
params_.reserve (copy.size() / ranges_[idx]);
|
||||
StatesIndexer indexer (ranges_);
|
||||
for (unsigned i = 0; i < evidence; i++) {
|
||||
indexer.increment (idx);
|
||||
}
|
||||
while (indexer.valid()) {
|
||||
params_.push_back (copy[indexer]);
|
||||
indexer.incrementExcluding (idx);
|
||||
}
|
||||
args_.erase (args_.begin() + idx);
|
||||
ranges_.erase (ranges_.begin() + idx);
|
||||
}
|
||||
|
||||
void reorderArguments (const vector<T> newArgs)
|
||||
{
|
||||
assert (newArgs.size() == args_.size());
|
||||
if (newArgs == args_) {
|
||||
return; // already in the wanted order
|
||||
}
|
||||
Ranges newRanges;
|
||||
vector<unsigned> positions;
|
||||
for (unsigned i = 0; i < newArgs.size(); i++) {
|
||||
unsigned idx = indexOf (newArgs[i]);
|
||||
newRanges.push_back (ranges_[idx]);
|
||||
positions.push_back (idx);
|
||||
}
|
||||
unsigned N = ranges_.size();
|
||||
Params newParams (params_.size());
|
||||
for (unsigned i = 0; i < params_.size(); i++) {
|
||||
unsigned li = i;
|
||||
// calculate vector index corresponding to linear index
|
||||
vector<unsigned> vi (N);
|
||||
for (int k = N-1; k >= 0; k--) {
|
||||
vi[k] = li % ranges_[k];
|
||||
li /= ranges_[k];
|
||||
}
|
||||
// convert permuted vector index to corresponding linear index
|
||||
unsigned prod = 1;
|
||||
unsigned new_li = 0;
|
||||
for (int k = N - 1; k >= 0; k--) {
|
||||
new_li += vi[positions[k]] * prod;
|
||||
prod *= ranges_[positions[k]];
|
||||
}
|
||||
newParams[new_li] = params_[i];
|
||||
}
|
||||
args_ = newArgs;
|
||||
ranges_ = newRanges;
|
||||
params_ = newParams;
|
||||
}
|
||||
|
||||
bool contains (const T& arg) const
|
||||
{
|
||||
return Util::contains (args_, arg);
|
||||
}
|
||||
|
||||
bool contains (const vector<T>& args) const
|
||||
{
|
||||
for (unsigned i = 0; i < args_.size(); i++) {
|
||||
if (contains (args[i]) == false) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
protected:
|
||||
vector<T> args_;
|
||||
Ranges ranges_;
|
||||
Params params_;
|
||||
unsigned distId_;
|
||||
|
||||
private:
|
||||
void insertArgument (const T& arg, unsigned range)
|
||||
{
|
||||
assert (indexOf (arg) == -1);
|
||||
Params copy = params_;
|
||||
params_.clear();
|
||||
params_.reserve (copy.size() * range);
|
||||
for (unsigned i = 0; i < copy.size(); i++) {
|
||||
for (unsigned reps = 0; reps < range; reps++) {
|
||||
params_.push_back (copy[i]);
|
||||
}
|
||||
}
|
||||
args_.push_back (arg);
|
||||
ranges_.push_back (range);
|
||||
}
|
||||
|
||||
void insertArguments (const vector<T>& args, const Ranges& ranges)
|
||||
{
|
||||
Params copy = params_;
|
||||
unsigned nrStates = 1;
|
||||
for (unsigned i = 0; i < args.size(); i++) {
|
||||
assert (indexOf (args[i]) == -1);
|
||||
args_.push_back (args[i]);
|
||||
ranges_.push_back (ranges[i]);
|
||||
nrStates *= ranges[i];
|
||||
}
|
||||
params_.clear();
|
||||
params_.reserve (copy.size() * nrStates);
|
||||
for (unsigned i = 0; i < copy.size(); i++) {
|
||||
for (unsigned reps = 0; reps < nrStates; reps++) {
|
||||
params_.push_back (copy[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
class Factor
|
||||
|
||||
class Factor : public TFactor<VarId>
|
||||
{
|
||||
public:
|
||||
Factor (void) { }
|
||||
|
||||
Factor (const Factor&);
|
||||
Factor (VarId, unsigned);
|
||||
Factor (const VarNodes&);
|
||||
Factor (VarId, unsigned, const Params&);
|
||||
Factor (VarNodes&, Distribution*);
|
||||
Factor (const VarNodes&, const Params&);
|
||||
Factor (const VarIds&, const Ranges&, const Params&);
|
||||
~Factor (void);
|
||||
|
||||
void setParameters (const Params&);
|
||||
void copyFromFactor (const Factor& f);
|
||||
void multiply (const Factor&);
|
||||
void insertVariable (VarId, unsigned);
|
||||
void insertVariables (const VarIds&, const Ranges&);
|
||||
void sumOutAllExcept (VarId);
|
||||
void sumOutAllExcept (const VarIds&);
|
||||
void sumOut (VarId);
|
||||
void sumOutFirstVariable (void);
|
||||
void sumOutLastVariable (void);
|
||||
void orderVariables (void);
|
||||
void reorderVariables (const VarIds&);
|
||||
void absorveEvidence (VarId, unsigned);
|
||||
void normalize (void);
|
||||
bool contains (const VarIds&) const;
|
||||
int indexOf (VarId) const;
|
||||
string getLabel (void) const;
|
||||
void print (void) const;
|
||||
Factor (const VarIds&, const Ranges&, const Params&,
|
||||
unsigned = Util::maxUnsigned());
|
||||
|
||||
const VarIds& getVarIds (void) const { return varids_; }
|
||||
const Ranges& getRanges (void) const { return ranges_; }
|
||||
const Params& getParameters (void) const { return dist_->params; }
|
||||
Distribution* getDistribution (void) const { return dist_; }
|
||||
unsigned nrVariables (void) const { return varids_.size(); }
|
||||
unsigned nrParameters() const { return dist_->params.size(); }
|
||||
Factor (const Vars&, const Params&,
|
||||
unsigned = Util::maxUnsigned());
|
||||
|
||||
void setDistribution (Distribution* dist)
|
||||
{
|
||||
dist_ = dist;
|
||||
}
|
||||
void sumOutAllExcept (VarId);
|
||||
|
||||
void sumOutAllExcept (const VarIds&);
|
||||
|
||||
void sumOut (VarId);
|
||||
|
||||
void sumOutFirstVariable (void);
|
||||
|
||||
void sumOutLastVariable (void);
|
||||
|
||||
void multiply (Factor&);
|
||||
|
||||
void reorderAccordingVarIds (void);
|
||||
|
||||
string getLabel (void) const;
|
||||
|
||||
void print (void) const;
|
||||
|
||||
private:
|
||||
void copyFromFactor (const Factor& f);
|
||||
|
||||
VarIds varids_;
|
||||
Ranges ranges_;
|
||||
Distribution* dist_;
|
||||
};
|
||||
|
||||
#endif // HORUS_FACTOR_H
|
||||
|
@ -9,6 +9,7 @@
|
||||
#include "FactorGraph.h"
|
||||
#include "Factor.h"
|
||||
#include "BayesNet.h"
|
||||
#include "BayesBall.h"
|
||||
#include "Util.h"
|
||||
|
||||
|
||||
@ -17,140 +18,92 @@ bool FactorGraph::orderFactorVariables = false;
|
||||
|
||||
FactorGraph::FactorGraph (const FactorGraph& fg)
|
||||
{
|
||||
const FgVarSet& vars = fg.getVarNodes();
|
||||
for (unsigned i = 0; i < vars.size(); i++) {
|
||||
FgVarNode* varNode = new FgVarNode (vars[i]);
|
||||
addVariable (varNode);
|
||||
const VarNodes& varNodes = fg.varNodes();
|
||||
for (unsigned i = 0; i < varNodes.size(); i++) {
|
||||
addVarNode (new VarNode (varNodes[i]));
|
||||
}
|
||||
|
||||
const FgFacSet& facs = fg.getFactorNodes();
|
||||
for (unsigned i = 0; i < facs.size(); i++) {
|
||||
FgFacNode* facNode = new FgFacNode (facs[i]);
|
||||
addFactor (facNode);
|
||||
const FgVarSet& neighs = facs[i]->neighbors();
|
||||
const FacNodes& facNodes = fg.facNodes();
|
||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||
FacNode* facNode = new FacNode (facNodes[i]->factor());
|
||||
addFacNode (facNode);
|
||||
const VarNodes& neighs = facNodes[i]->neighbors();
|
||||
for (unsigned j = 0; j < neighs.size(); j++) {
|
||||
addEdge (facNode, varNodes_[neighs[j]->getIndex()]);
|
||||
addEdge (varNodes_[neighs[j]->getIndex()], facNode);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
FactorGraph::FactorGraph (const BayesNet& bn)
|
||||
{
|
||||
const BnNodeSet& nodes = bn.getBayesNodes();
|
||||
for (unsigned i = 0; i < nodes.size(); i++) {
|
||||
FgVarNode* varNode = new FgVarNode (nodes[i]);
|
||||
addVariable (varNode);
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < nodes.size(); i++) {
|
||||
const BnNodeSet& parents = nodes[i]->getParents();
|
||||
if (!(nodes[i]->hasEvidence() && parents.size() == 0)) {
|
||||
VarNodes neighs;
|
||||
neighs.push_back (varNodes_[nodes[i]->getIndex()]);
|
||||
for (unsigned j = 0; j < parents.size(); j++) {
|
||||
neighs.push_back (varNodes_[parents[j]->getIndex()]);
|
||||
}
|
||||
FgFacNode* fn = new FgFacNode (
|
||||
new Factor (neighs, nodes[i]->getDistribution()));
|
||||
if (orderFactorVariables) {
|
||||
sort (neighs.begin(), neighs.end(), CompVarId());
|
||||
fn->factor()->orderVariables();
|
||||
}
|
||||
addFactor (fn);
|
||||
for (unsigned j = 0; j < neighs.size(); j++) {
|
||||
addEdge (fn, static_cast<FgVarNode*> (neighs[j]));
|
||||
}
|
||||
}
|
||||
}
|
||||
setIndexes();
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::readFromUaiFormat (const char* fileName)
|
||||
{
|
||||
ifstream is (fileName);
|
||||
std::ifstream is (fileName);
|
||||
if (!is.is_open()) {
|
||||
cerr << "error: cannot read from file " + std::string (fileName) << endl;
|
||||
cerr << "error: cannot read from file " << fileName << endl;
|
||||
abort();
|
||||
}
|
||||
|
||||
ignoreLines (is);
|
||||
string line;
|
||||
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
|
||||
getline (is, line);
|
||||
if (line != "MARKOV") {
|
||||
cerr << "error: the network must be a MARKOV network " << endl;
|
||||
abort();
|
||||
}
|
||||
|
||||
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
|
||||
unsigned nVars;
|
||||
is >> nVars;
|
||||
|
||||
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
|
||||
vector<int> domainSizes (nVars);
|
||||
for (unsigned i = 0; i < nVars; i++) {
|
||||
unsigned ds;
|
||||
is >> ds;
|
||||
domainSizes[i] = ds;
|
||||
// read the number of vars
|
||||
ignoreLines (is);
|
||||
unsigned nrVars;
|
||||
is >> nrVars;
|
||||
// read the range of each var
|
||||
ignoreLines (is);
|
||||
Ranges ranges (nrVars);
|
||||
for (unsigned i = 0; i < nrVars; i++) {
|
||||
is >> ranges[i];
|
||||
}
|
||||
|
||||
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
|
||||
for (unsigned i = 0; i < nVars; i++) {
|
||||
addVariable (new FgVarNode (i, domainSizes[i]));
|
||||
}
|
||||
|
||||
unsigned nFactors;
|
||||
is >> nFactors;
|
||||
for (unsigned i = 0; i < nFactors; i++) {
|
||||
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
|
||||
unsigned nFactorVars;
|
||||
is >> nFactorVars;
|
||||
VarNodes neighs;
|
||||
for (unsigned j = 0; j < nFactorVars; j++) {
|
||||
unsigned vid;
|
||||
unsigned nrFactors;
|
||||
unsigned nrArgs;
|
||||
unsigned vid;
|
||||
is >> nrFactors;
|
||||
vector<VarIds> factorVarIds;
|
||||
vector<Ranges> factorRanges;
|
||||
for (unsigned i = 0; i < nrFactors; i++) {
|
||||
ignoreLines (is);
|
||||
is >> nrArgs;
|
||||
factorVarIds.push_back ({ });
|
||||
factorRanges.push_back ({ });
|
||||
for (unsigned j = 0; j < nrArgs; j++) {
|
||||
is >> vid;
|
||||
FgVarNode* neigh = getFgVarNode (vid);
|
||||
if (!neigh) {
|
||||
cerr << "error: invalid variable identifier (" << vid << ")" << endl;
|
||||
if (vid >= ranges.size()) {
|
||||
cerr << "error: invalid variable identifier `" << vid << "'" << endl;
|
||||
cerr << "identifiers must be between 0 and " << ranges.size() - 1 ;
|
||||
cerr << endl;
|
||||
abort();
|
||||
}
|
||||
neighs.push_back (neigh);
|
||||
}
|
||||
FgFacNode* fn = new FgFacNode (new Factor (neighs));
|
||||
addFactor (fn);
|
||||
for (unsigned j = 0; j < neighs.size(); j++) {
|
||||
addEdge (fn, static_cast<FgVarNode*> (neighs[j]));
|
||||
factorVarIds.back().push_back (vid);
|
||||
factorRanges.back().push_back (ranges[vid]);
|
||||
}
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < nFactors; i++) {
|
||||
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
|
||||
unsigned nParams;
|
||||
is >> nParams;
|
||||
if (facNodes_[i]->getParameters().size() != nParams) {
|
||||
cerr << "error: invalid number of parameters for factor " ;
|
||||
cerr << facNodes_[i]->getLabel() ;
|
||||
cerr << ", expected: " << facNodes_[i]->getParameters().size();
|
||||
cerr << ", given: " << nParams << endl;
|
||||
// read the parameters
|
||||
unsigned nrParams;
|
||||
for (unsigned i = 0; i < nrFactors; i++) {
|
||||
ignoreLines (is);
|
||||
is >> nrParams;
|
||||
if (nrParams != Util::expectedSize (factorRanges[i])) {
|
||||
cerr << "error: invalid number of parameters for factor nº " << i ;
|
||||
cerr << ", expected: " << Util::expectedSize (factorRanges[i]);
|
||||
cerr << ", given: " << nrParams << endl;
|
||||
abort();
|
||||
}
|
||||
Params params (nParams);
|
||||
for (unsigned j = 0; j < nParams; j++) {
|
||||
double param;
|
||||
is >> param;
|
||||
params[j] = param;
|
||||
Params params (nrParams);
|
||||
for (unsigned j = 0; j < nrParams; j++) {
|
||||
is >> params[j];
|
||||
}
|
||||
if (Globals::logDomain) {
|
||||
Util::toLog (params);
|
||||
}
|
||||
facNodes_[i]->factor()->setParameters (params);
|
||||
addFactor (Factor (factorVarIds[i], factorRanges[i], params));
|
||||
}
|
||||
is.close();
|
||||
setIndexes();
|
||||
}
|
||||
|
||||
|
||||
@ -158,87 +111,58 @@ FactorGraph::readFromUaiFormat (const char* fileName)
|
||||
void
|
||||
FactorGraph::readFromLibDaiFormat (const char* fileName)
|
||||
{
|
||||
ifstream is (fileName);
|
||||
std::ifstream is (fileName);
|
||||
if (!is.is_open()) {
|
||||
cerr << "error: cannot read from file " + std::string (fileName) << endl;
|
||||
cerr << "error: cannot read from file " << fileName << endl;
|
||||
abort();
|
||||
}
|
||||
|
||||
string line;
|
||||
unsigned nFactors;
|
||||
|
||||
while ((is.peek()) == '#') getline (is, line);
|
||||
is >> nFactors;
|
||||
|
||||
if (is.fail()) {
|
||||
cerr << "error: cannot read the number of factors" << endl;
|
||||
abort();
|
||||
}
|
||||
|
||||
getline (is, line);
|
||||
if (is.fail() || line.size() > 0) {
|
||||
cerr << "error: cannot read the number of factors" << endl;
|
||||
abort();
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < nFactors; i++) {
|
||||
unsigned nVars;
|
||||
while ((is.peek()) == '#') getline (is, line);
|
||||
|
||||
is >> nVars;
|
||||
ignoreLines (is);
|
||||
unsigned nrFactors;
|
||||
unsigned nrArgs;
|
||||
VarId vid;
|
||||
is >> nrFactors;
|
||||
for (unsigned i = 0; i < nrFactors; i++) {
|
||||
ignoreLines (is);
|
||||
// read the factor arguments
|
||||
is >> nrArgs;
|
||||
VarIds vids;
|
||||
for (unsigned j = 0; j < nVars; j++) {
|
||||
VarId vid;
|
||||
while ((is.peek()) == '#') getline (is, line);
|
||||
for (unsigned j = 0; j < nrArgs; j++) {
|
||||
ignoreLines (is);
|
||||
is >> vid;
|
||||
vids.push_back (vid);
|
||||
}
|
||||
|
||||
VarNodes neighs;
|
||||
unsigned nParams = 1;
|
||||
for (unsigned j = 0; j < nVars; j++) {
|
||||
unsigned dsize;
|
||||
while ((is.peek()) == '#') getline (is, line);
|
||||
is >> dsize;
|
||||
FgVarNode* var = getFgVarNode (vids[j]);
|
||||
if (var == 0) {
|
||||
var = new FgVarNode (vids[j], dsize);
|
||||
addVariable (var);
|
||||
} else {
|
||||
if (var->nrStates() != dsize) {
|
||||
cerr << "error: variable `" << vids[j] << "' appears in two or " ;
|
||||
cerr << "more factors with different domain sizes" << endl;
|
||||
}
|
||||
// read ranges
|
||||
Ranges ranges (nrArgs);
|
||||
for (unsigned j = 0; j < nrArgs; j++) {
|
||||
ignoreLines (is);
|
||||
is >> ranges[j];
|
||||
VarNode* var = getVarNode (vids[j]);
|
||||
if (var != 0 && ranges[j] != var->range()) {
|
||||
cerr << "error: variable `" << vids[j] << "' appears in two or " ;
|
||||
cerr << "more factors with a different range" << endl;
|
||||
}
|
||||
neighs.push_back (var);
|
||||
nParams *= var->nrStates();
|
||||
}
|
||||
Params params (nParams, 0);
|
||||
// read parameters
|
||||
ignoreLines (is);
|
||||
unsigned nNonzeros;
|
||||
while ((is.peek()) == '#') getline (is, line);
|
||||
is >> nNonzeros;
|
||||
|
||||
Params params (Util::expectedSize (ranges), 0);
|
||||
for (unsigned j = 0; j < nNonzeros; j++) {
|
||||
ignoreLines (is);
|
||||
unsigned index;
|
||||
double val;
|
||||
while ((is.peek()) == '#') getline (is, line);
|
||||
is >> index;
|
||||
while ((is.peek()) == '#') getline (is, line);
|
||||
ignoreLines (is);
|
||||
double val;
|
||||
is >> val;
|
||||
params[index] = val;
|
||||
}
|
||||
reverse (neighs.begin(), neighs.end());
|
||||
reverse (vids.begin(), vids.end());
|
||||
if (Globals::logDomain) {
|
||||
Util::toLog (params);
|
||||
}
|
||||
FgFacNode* fn = new FgFacNode (new Factor (neighs, params));
|
||||
addFactor (fn);
|
||||
for (unsigned j = 0; j < neighs.size(); j++) {
|
||||
addEdge (fn, static_cast<FgVarNode*> (neighs[j]));
|
||||
}
|
||||
addFactor (Factor (vids, ranges, params));
|
||||
}
|
||||
is.close();
|
||||
setIndexes();
|
||||
}
|
||||
|
||||
|
||||
@ -256,17 +180,41 @@ FactorGraph::~FactorGraph (void)
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::addVariable (FgVarNode* vn)
|
||||
FactorGraph::addFactor (const Factor& factor)
|
||||
{
|
||||
varNodes_.push_back (vn);
|
||||
vn->setIndex (varNodes_.size() - 1);
|
||||
varMap_.insert (make_pair (vn->varId(), varNodes_.size() - 1));
|
||||
FacNode* fn = new FacNode (factor);
|
||||
addFacNode (fn);
|
||||
const VarIds& vids = factor.arguments();
|
||||
for (unsigned i = 0; i < vids.size(); i++) {
|
||||
bool found = false;
|
||||
for (unsigned j = 0; j < varNodes_.size(); j++) {
|
||||
if (varNodes_[j]->varId() == vids[i]) {
|
||||
addEdge (varNodes_[j], fn);
|
||||
found = true;
|
||||
}
|
||||
}
|
||||
if (found == false) {
|
||||
VarNode* vn = new VarNode (vids[i], factor.range (i));
|
||||
addVarNode (vn);
|
||||
addEdge (vn, fn);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::addFactor (FgFacNode* fn)
|
||||
FactorGraph::addVarNode (VarNode* vn)
|
||||
{
|
||||
varNodes_.push_back (vn);
|
||||
vn->setIndex (varNodes_.size() - 1);
|
||||
varMap_.insert (make_pair (vn->varId(), vn));
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::addFacNode (FacNode* fn)
|
||||
{
|
||||
facNodes_.push_back (fn);
|
||||
fn->setIndex (facNodes_.size() - 1);
|
||||
@ -275,7 +223,7 @@ FactorGraph::addFactor (FgFacNode* fn)
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::addEdge (FgVarNode* vn, FgFacNode* fn)
|
||||
FactorGraph::addEdge (VarNode* vn, FacNode* fn)
|
||||
{
|
||||
vn->addNeighbor (fn);
|
||||
fn->addNeighbor (vn);
|
||||
@ -283,37 +231,6 @@ FactorGraph::addEdge (FgVarNode* vn, FgFacNode* fn)
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::addEdge (FgFacNode* fn, FgVarNode* vn)
|
||||
{
|
||||
fn->addNeighbor (vn);
|
||||
vn->addNeighbor (fn);
|
||||
}
|
||||
|
||||
|
||||
|
||||
VarNode*
|
||||
FactorGraph::getVariableNode (VarId vid) const
|
||||
{
|
||||
FgVarNode* vn = getFgVarNode (vid);
|
||||
assert (vn);
|
||||
return vn;
|
||||
}
|
||||
|
||||
|
||||
|
||||
VarNodes
|
||||
FactorGraph::getVariableNodes (void) const
|
||||
{
|
||||
VarNodes vars;
|
||||
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
||||
vars.push_back (varNodes_[i]);
|
||||
}
|
||||
return vars;
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
FactorGraph::isTree (void) const
|
||||
{
|
||||
@ -322,51 +239,42 @@ FactorGraph::isTree (void) const
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::setIndexes (void)
|
||||
DAGraph&
|
||||
FactorGraph::getStructure (void)
|
||||
{
|
||||
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
||||
varNodes_[i]->setIndex (i);
|
||||
}
|
||||
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
||||
facNodes_[i]->setIndex (i);
|
||||
assert (fromBayesNet_);
|
||||
if (structure_.empty()) {
|
||||
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
||||
structure_.addNode (new DAGraphNode (varNodes_[i]));
|
||||
}
|
||||
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
||||
const VarIds& vids = facNodes_[i]->factor().arguments();
|
||||
for (unsigned j = 1; j < vids.size(); j++) {
|
||||
structure_.addEdge (vids[j], vids[0]);
|
||||
}
|
||||
}
|
||||
}
|
||||
return structure_;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::freeDistributions (void)
|
||||
{
|
||||
set<Distribution*> dists;
|
||||
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
||||
dists.insert (facNodes_[i]->factor()->getDistribution());
|
||||
}
|
||||
for (set<Distribution*>::iterator it = dists.begin();
|
||||
it != dists.end(); it++) {
|
||||
delete *it;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::printGraphicalModel (void) const
|
||||
FactorGraph::print (void) const
|
||||
{
|
||||
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
||||
cout << "VarId = " << varNodes_[i]->varId() << endl;
|
||||
cout << "Label = " << varNodes_[i]->label() << endl;
|
||||
cout << "Nr States = " << varNodes_[i]->nrStates() << endl;
|
||||
cout << "Evidence = " << varNodes_[i]->getEvidence() << endl;
|
||||
cout << "Factors = " ;
|
||||
cout << "var id = " << varNodes_[i]->varId() << endl;
|
||||
cout << "label = " << varNodes_[i]->label() << endl;
|
||||
cout << "range = " << varNodes_[i]->range() << endl;
|
||||
cout << "evidence = " << varNodes_[i]->getEvidence() << endl;
|
||||
cout << "factors = " ;
|
||||
for (unsigned j = 0; j < varNodes_[i]->neighbors().size(); j++) {
|
||||
cout << varNodes_[i]->neighbors()[j]->getLabel() << " " ;
|
||||
}
|
||||
cout << endl << endl;
|
||||
}
|
||||
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
||||
facNodes_[i]->factor()->print();
|
||||
cout << endl;
|
||||
facNodes_[i]->factor().print();
|
||||
}
|
||||
}
|
||||
|
||||
@ -381,31 +289,26 @@ FactorGraph::exportToGraphViz (const char* fileName) const
|
||||
cerr << "FactorGraph::exportToDotFile()" << endl;
|
||||
abort();
|
||||
}
|
||||
|
||||
out << "graph \"" << fileName << "\" {" << endl;
|
||||
|
||||
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
||||
if (varNodes_[i]->hasEvidence()) {
|
||||
out << '"' << varNodes_[i]->label() << '"' ;
|
||||
out << " [style=filled, fillcolor=yellow]" << endl;
|
||||
}
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
||||
out << '"' << facNodes_[i]->getLabel() << '"' ;
|
||||
out << " [label=\"" << facNodes_[i]->getLabel();
|
||||
out << "\"" << ", shape=box]" << endl;
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
||||
const FgVarSet& myVars = facNodes_[i]->neighbors();
|
||||
const VarNodes& myVars = facNodes_[i]->neighbors();
|
||||
for (unsigned j = 0; j < myVars.size(); j++) {
|
||||
out << '"' << facNodes_[i]->getLabel() << '"' ;
|
||||
out << " -- " ;
|
||||
out << '"' << myVars[j]->label() << '"' << endl;
|
||||
}
|
||||
}
|
||||
|
||||
out << "}" << endl;
|
||||
out.close();
|
||||
}
|
||||
@ -417,30 +320,26 @@ FactorGraph::exportToUaiFormat (const char* fileName) const
|
||||
{
|
||||
ofstream out (fileName);
|
||||
if (!out.is_open()) {
|
||||
cerr << "error: cannot open file to write at " ;
|
||||
cerr << "FactorGraph::exportToUaiFormat()" << endl;
|
||||
cerr << "error: cannot open file " << fileName << endl;
|
||||
abort();
|
||||
}
|
||||
|
||||
out << "MARKOV" << endl;
|
||||
out << varNodes_.size() << endl;
|
||||
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
||||
out << varNodes_[i]->nrStates() << " " ;
|
||||
out << varNodes_[i]->range() << " " ;
|
||||
}
|
||||
out << endl;
|
||||
|
||||
out << facNodes_.size() << endl;
|
||||
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
||||
const FgVarSet& factorVars = facNodes_[i]->neighbors();
|
||||
const VarNodes& factorVars = facNodes_[i]->neighbors();
|
||||
out << factorVars.size();
|
||||
for (unsigned j = 0; j < factorVars.size(); j++) {
|
||||
out << " " << factorVars[j]->getIndex();
|
||||
}
|
||||
out << endl;
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
||||
Params params = facNodes_[i]->getParameters();
|
||||
Params params = facNodes_[i]->factor().params();
|
||||
if (Globals::logDomain) {
|
||||
Util::fromLog (params);
|
||||
}
|
||||
@ -450,7 +349,6 @@ FactorGraph::exportToUaiFormat (const char* fileName) const
|
||||
}
|
||||
out << endl;
|
||||
}
|
||||
|
||||
out.close();
|
||||
}
|
||||
|
||||
@ -461,23 +359,22 @@ FactorGraph::exportToLibDaiFormat (const char* fileName) const
|
||||
{
|
||||
ofstream out (fileName);
|
||||
if (!out.is_open()) {
|
||||
cerr << "error: cannot open file to write at " ;
|
||||
cerr << "FactorGraph::exportToLibDaiFormat()" << endl;
|
||||
cerr << "error: cannot open file " << fileName << endl;
|
||||
abort();
|
||||
}
|
||||
out << facNodes_.size() << endl << endl;
|
||||
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
||||
const FgVarSet& factorVars = facNodes_[i]->neighbors();
|
||||
const VarNodes& factorVars = facNodes_[i]->neighbors();
|
||||
out << factorVars.size() << endl;
|
||||
for (int j = factorVars.size() - 1; j >= 0; j--) {
|
||||
out << factorVars[j]->varId() << " " ;
|
||||
}
|
||||
out << endl;
|
||||
for (unsigned j = 0; j < factorVars.size(); j++) {
|
||||
out << factorVars[j]->nrStates() << " " ;
|
||||
out << factorVars[j]->range() << " " ;
|
||||
}
|
||||
out << endl;
|
||||
Params params = facNodes_[i]->factor()->getParameters();
|
||||
Params params = facNodes_[i]->factor().params();
|
||||
if (Globals::logDomain) {
|
||||
Util::fromLog (params);
|
||||
}
|
||||
@ -492,6 +389,17 @@ FactorGraph::exportToLibDaiFormat (const char* fileName) const
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::ignoreLines (std::ifstream& is) const
|
||||
{
|
||||
string ignoreStr;
|
||||
while (is.peek() == '#' || is.peek() == '\n') {
|
||||
getline (is, ignoreStr);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
FactorGraph::containsCycle (void) const
|
||||
{
|
||||
@ -511,13 +419,14 @@ FactorGraph::containsCycle (void) const
|
||||
|
||||
|
||||
bool
|
||||
FactorGraph::containsCycle (const FgVarNode* v,
|
||||
const FgFacNode* p,
|
||||
vector<bool>& visitedVars,
|
||||
vector<bool>& visitedFactors) const
|
||||
FactorGraph::containsCycle (
|
||||
const VarNode* v,
|
||||
const FacNode* p,
|
||||
vector<bool>& visitedVars,
|
||||
vector<bool>& visitedFactors) const
|
||||
{
|
||||
visitedVars[v->getIndex()] = true;
|
||||
const FgFacSet& adjacencies = v->neighbors();
|
||||
const FacNodes& adjacencies = v->neighbors();
|
||||
for (unsigned i = 0; i < adjacencies.size(); i++) {
|
||||
int w = adjacencies[i]->getIndex();
|
||||
if (!visitedFactors[w]) {
|
||||
@ -535,13 +444,14 @@ FactorGraph::containsCycle (const FgVarNode* v,
|
||||
|
||||
|
||||
bool
|
||||
FactorGraph::containsCycle (const FgFacNode* v,
|
||||
const FgVarNode* p,
|
||||
vector<bool>& visitedVars,
|
||||
vector<bool>& visitedFactors) const
|
||||
FactorGraph::containsCycle (
|
||||
const FacNode* v,
|
||||
const VarNode* p,
|
||||
vector<bool>& visitedVars,
|
||||
vector<bool>& visitedFactors) const
|
||||
{
|
||||
visitedFactors[v->getIndex()] = true;
|
||||
const FgVarSet& adjacencies = v->neighbors();
|
||||
const VarNodes& adjacencies = v->neighbors();
|
||||
for (unsigned i = 0; i < adjacencies.size(); i++) {
|
||||
int w = adjacencies[i]->getIndex();
|
||||
if (!visitedVars[w]) {
|
||||
|
@ -3,135 +3,139 @@
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "GraphicalModel.h"
|
||||
#include "Distribution.h"
|
||||
#include "Factor.h"
|
||||
#include "BayesNet.h"
|
||||
#include "Horus.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
class BayesNet;
|
||||
class FgFacNode;
|
||||
|
||||
class FgVarNode : public VarNode
|
||||
class FacNode;
|
||||
|
||||
class VarNode : public Var
|
||||
{
|
||||
public:
|
||||
FgVarNode (VarId varId, unsigned nrStates) : VarNode (varId, nrStates) { }
|
||||
FgVarNode (const VarNode* v) : VarNode (v) { }
|
||||
VarNode (VarId varId, unsigned nrStates)
|
||||
: Var (varId, nrStates) { }
|
||||
|
||||
void addNeighbor (FgFacNode* fn) { neighs_.push_back (fn); }
|
||||
const FgFacSet& neighbors (void) const { return neighs_; }
|
||||
VarNode (const Var* v) : Var (v) { }
|
||||
|
||||
void addNeighbor (FacNode* fn) { neighs_.push_back (fn); }
|
||||
|
||||
const FacNodes& neighbors (void) const { return neighs_; }
|
||||
|
||||
private:
|
||||
DISALLOW_COPY_AND_ASSIGN (FgVarNode);
|
||||
// members
|
||||
FgFacSet neighs_;
|
||||
DISALLOW_COPY_AND_ASSIGN (VarNode);
|
||||
|
||||
FacNodes neighs_;
|
||||
};
|
||||
|
||||
|
||||
class FgFacNode
|
||||
class FacNode
|
||||
{
|
||||
public:
|
||||
FgFacNode (const FgFacNode* fn) {
|
||||
factor_ = new Factor (*fn->factor());
|
||||
index_ = -1;
|
||||
}
|
||||
FgFacNode (Factor* f) : factor_(new Factor(*f)), index_(-1) { }
|
||||
Factor* factor() const { return factor_; }
|
||||
void addNeighbor (FgVarNode* vn) { neighs_.push_back (vn); }
|
||||
const FgVarSet& neighbors (void) const { return neighs_; }
|
||||
FacNode (const Factor& f) : factor_(f), index_(-1) { }
|
||||
|
||||
const Factor& factor (void) const { return factor_; }
|
||||
|
||||
Factor& factor (void) { return factor_; }
|
||||
|
||||
void addNeighbor (VarNode* vn) { neighs_.push_back (vn); }
|
||||
|
||||
const VarNodes& neighbors (void) const { return neighs_; }
|
||||
|
||||
int getIndex (void) const { return index_; }
|
||||
|
||||
void setIndex (int index) { index_ = index; }
|
||||
|
||||
string getLabel (void) { return factor_.getLabel(); }
|
||||
|
||||
int getIndex (void) const
|
||||
{
|
||||
assert (index_ != -1);
|
||||
return index_;
|
||||
}
|
||||
void setIndex (int index)
|
||||
{
|
||||
index_ = index;
|
||||
}
|
||||
Distribution* getDistribution (void)
|
||||
{
|
||||
return factor_->getDistribution();
|
||||
}
|
||||
const Params& getParameters (void) const
|
||||
{
|
||||
return factor_->getParameters();
|
||||
}
|
||||
string getLabel (void)
|
||||
{
|
||||
return factor_->getLabel();
|
||||
}
|
||||
private:
|
||||
DISALLOW_COPY_AND_ASSIGN (FgFacNode);
|
||||
DISALLOW_COPY_AND_ASSIGN (FacNode);
|
||||
|
||||
Factor* factor_;
|
||||
int index_;
|
||||
FgVarSet neighs_;
|
||||
VarNodes neighs_;
|
||||
Factor factor_;
|
||||
int index_;
|
||||
};
|
||||
|
||||
|
||||
struct CompVarId
|
||||
{
|
||||
bool operator() (const VarNode* vn1, const VarNode* vn2) const
|
||||
bool operator() (const Var* v1, const Var* v2) const
|
||||
{
|
||||
return vn1->varId() < vn2->varId();
|
||||
return v1->varId() < v2->varId();
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
class FactorGraph : public GraphicalModel
|
||||
class FactorGraph
|
||||
{
|
||||
public:
|
||||
FactorGraph (void) {};
|
||||
FactorGraph (bool fbn = false) : fromBayesNet_(fbn) { }
|
||||
|
||||
FactorGraph (const FactorGraph&);
|
||||
FactorGraph (const BayesNet&);
|
||||
|
||||
~FactorGraph (void);
|
||||
|
||||
void readFromUaiFormat (const char*);
|
||||
void readFromLibDaiFormat (const char*);
|
||||
void addVariable (FgVarNode*);
|
||||
void addFactor (FgFacNode*);
|
||||
void addEdge (FgVarNode*, FgFacNode*);
|
||||
void addEdge (FgFacNode*, FgVarNode*);
|
||||
VarNode* getVariableNode (unsigned) const;
|
||||
VarNodes getVariableNodes (void) const;
|
||||
bool isTree (void) const;
|
||||
void setIndexes (void);
|
||||
void freeDistributions (void);
|
||||
void printGraphicalModel (void) const;
|
||||
void exportToGraphViz (const char*) const;
|
||||
void exportToUaiFormat (const char*) const;
|
||||
void exportToLibDaiFormat (const char*) const;
|
||||
|
||||
const FgVarSet& getVarNodes (void) const { return varNodes_; }
|
||||
const FgFacSet& getFactorNodes (void) const { return facNodes_; }
|
||||
const VarNodes& varNodes (void) const { return varNodes_; }
|
||||
|
||||
FgVarNode* getFgVarNode (VarId vid) const
|
||||
const FacNodes& facNodes (void) const { return facNodes_; }
|
||||
|
||||
bool isFromBayesNetwork (void) const { return fromBayesNet_ ; }
|
||||
|
||||
VarNode* getVarNode (VarId vid) const
|
||||
{
|
||||
IndexMap::const_iterator it = varMap_.find (vid);
|
||||
if (it == varMap_.end()) {
|
||||
return 0;
|
||||
} else {
|
||||
return varNodes_[it->second];
|
||||
}
|
||||
VarMap::const_iterator it = varMap_.find (vid);
|
||||
return it != varMap_.end() ? it->second : 0;
|
||||
}
|
||||
|
||||
void readFromUaiFormat (const char*);
|
||||
|
||||
void readFromLibDaiFormat (const char*);
|
||||
|
||||
void addFactor (const Factor& factor);
|
||||
|
||||
void addVarNode (VarNode*);
|
||||
|
||||
void addFacNode (FacNode*);
|
||||
|
||||
void addEdge (VarNode*, FacNode*);
|
||||
|
||||
bool isTree (void) const;
|
||||
|
||||
DAGraph& getStructure (void);
|
||||
|
||||
void print (void) const;
|
||||
|
||||
void exportToGraphViz (const char*) const;
|
||||
|
||||
void exportToUaiFormat (const char*) const;
|
||||
|
||||
void exportToLibDaiFormat (const char*) const;
|
||||
|
||||
static bool orderFactorVariables;
|
||||
|
||||
private:
|
||||
//DISALLOW_COPY_AND_ASSIGN (FactorGraph);
|
||||
bool containsCycle (void) const;
|
||||
bool containsCycle (const FgVarNode*, const FgFacNode*,
|
||||
vector<bool>&, vector<bool>&) const;
|
||||
bool containsCycle (const FgFacNode*, const FgVarNode*,
|
||||
vector<bool>&, vector<bool>&) const;
|
||||
// DISALLOW_COPY_AND_ASSIGN (FactorGraph);
|
||||
|
||||
FgVarSet varNodes_;
|
||||
FgFacSet facNodes_;
|
||||
void ignoreLines (std::ifstream&) const;
|
||||
|
||||
typedef unordered_map<unsigned, unsigned> IndexMap;
|
||||
IndexMap varMap_;
|
||||
bool containsCycle (void) const;
|
||||
|
||||
bool containsCycle (const VarNode*, const FacNode*,
|
||||
vector<bool>&, vector<bool>&) const;
|
||||
|
||||
bool containsCycle (const FacNode*, const VarNode*,
|
||||
vector<bool>&, vector<bool>&) const;
|
||||
|
||||
VarNodes varNodes_;
|
||||
FacNodes facNodes_;
|
||||
|
||||
DAGraph structure_;
|
||||
bool fromBayesNet_;
|
||||
|
||||
typedef unordered_map<unsigned, VarNode*> VarMap;
|
||||
VarMap varMap_;
|
||||
};
|
||||
|
||||
#endif // HORUS_FACTORGRAPH_H
|
||||
|
@ -1,175 +0,0 @@
|
||||
#ifndef HORUS_FGBPSOLVER_H
|
||||
#define HORUS_FGBPSOLVER_H
|
||||
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include <sstream>
|
||||
|
||||
#include "Solver.h"
|
||||
#include "Factor.h"
|
||||
#include "FactorGraph.h"
|
||||
#include "Util.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
|
||||
|
||||
class SpLink
|
||||
{
|
||||
public:
|
||||
SpLink (FgFacNode* fn, FgVarNode* vn)
|
||||
{
|
||||
fac_ = fn;
|
||||
var_ = vn;
|
||||
v1_.resize (vn->nrStates(), Util::tl (1.0 / vn->nrStates()));
|
||||
v2_.resize (vn->nrStates(), Util::tl (1.0 / vn->nrStates()));
|
||||
currMsg_ = &v1_;
|
||||
nextMsg_ = &v2_;
|
||||
msgSended_ = false;
|
||||
residual_ = 0.0;
|
||||
}
|
||||
|
||||
virtual ~SpLink (void) {};
|
||||
|
||||
virtual void updateMessage (void)
|
||||
{
|
||||
swap (currMsg_, nextMsg_);
|
||||
msgSended_ = true;
|
||||
}
|
||||
|
||||
void updateResidual (void)
|
||||
{
|
||||
residual_ = Util::getMaxNorm (v1_, v2_);
|
||||
}
|
||||
|
||||
string toString (void) const
|
||||
{
|
||||
stringstream ss;
|
||||
ss << fac_->getLabel();
|
||||
ss << " -- " ;
|
||||
ss << var_->label();
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
FgFacNode* getFactor (void) const { return fac_; }
|
||||
FgVarNode* getVariable (void) const { return var_; }
|
||||
const Params& getMessage (void) const { return *currMsg_; }
|
||||
Params& getNextMessage (void) { return *nextMsg_; }
|
||||
bool messageWasSended (void) const { return msgSended_; }
|
||||
double getResidual (void) const { return residual_; }
|
||||
void clearResidual (void) { residual_ = 0.0; }
|
||||
|
||||
protected:
|
||||
FgFacNode* fac_;
|
||||
FgVarNode* var_;
|
||||
Params v1_;
|
||||
Params v2_;
|
||||
Params* currMsg_;
|
||||
Params* nextMsg_;
|
||||
bool msgSended_;
|
||||
double residual_;
|
||||
};
|
||||
|
||||
|
||||
typedef vector<SpLink*> SpLinkSet;
|
||||
|
||||
|
||||
class SPNodeInfo
|
||||
{
|
||||
public:
|
||||
void addSpLink (SpLink* link) { links_.push_back (link); }
|
||||
const SpLinkSet& getLinks (void) { return links_; }
|
||||
|
||||
private:
|
||||
SpLinkSet links_;
|
||||
};
|
||||
|
||||
|
||||
class FgBpSolver : public Solver
|
||||
{
|
||||
public:
|
||||
FgBpSolver (const FactorGraph&);
|
||||
virtual ~FgBpSolver (void);
|
||||
|
||||
void runSolver (void);
|
||||
virtual Params getPosterioriOf (VarId);
|
||||
virtual Params getJointDistributionOf (const VarIds&);
|
||||
|
||||
protected:
|
||||
virtual void initializeSolver (void);
|
||||
virtual void createLinks (void);
|
||||
virtual void maxResidualSchedule (void);
|
||||
virtual void calculateFactor2VariableMsg (SpLink*) const;
|
||||
virtual Params getVar2FactorMsg (const SpLink*) const;
|
||||
virtual Params getJointByConditioning (const VarIds&) const;
|
||||
virtual void printLinkInformation (void) const;
|
||||
|
||||
void calculateAndUpdateMessage (SpLink* link, bool calcResidual = true)
|
||||
{
|
||||
if (DL >= 3) {
|
||||
cout << "calculating & updating " << link->toString() << endl;
|
||||
}
|
||||
calculateFactor2VariableMsg (link);
|
||||
if (calcResidual) {
|
||||
link->updateResidual();
|
||||
}
|
||||
link->updateMessage();
|
||||
}
|
||||
|
||||
void calculateMessage (SpLink* link, bool calcResidual = true)
|
||||
{
|
||||
if (DL >= 3) {
|
||||
cout << "calculating " << link->toString() << endl;
|
||||
}
|
||||
calculateFactor2VariableMsg (link);
|
||||
if (calcResidual) {
|
||||
link->updateResidual();
|
||||
}
|
||||
}
|
||||
|
||||
void updateMessage (SpLink* link)
|
||||
{
|
||||
link->updateMessage();
|
||||
if (DL >= 3) {
|
||||
cout << "updating " << link->toString() << endl;
|
||||
}
|
||||
}
|
||||
|
||||
SPNodeInfo* ninf (const FgVarNode* var) const
|
||||
{
|
||||
return varsI_[var->getIndex()];
|
||||
}
|
||||
|
||||
SPNodeInfo* ninf (const FgFacNode* fac) const
|
||||
{
|
||||
return facsI_[fac->getIndex()];
|
||||
}
|
||||
|
||||
struct CompareResidual {
|
||||
inline bool operator() (const SpLink* link1, const SpLink* link2)
|
||||
{
|
||||
return link1->getResidual() > link2->getResidual();
|
||||
}
|
||||
};
|
||||
|
||||
SpLinkSet links_;
|
||||
unsigned nIters_;
|
||||
vector<SPNodeInfo*> varsI_;
|
||||
vector<SPNodeInfo*> facsI_;
|
||||
const FactorGraph* factorGraph_;
|
||||
|
||||
typedef multiset<SpLink*, CompareResidual> SortedOrder;
|
||||
SortedOrder sortedOrder_;
|
||||
|
||||
typedef unordered_map<SpLink*, SortedOrder::iterator> SpLinkMap;
|
||||
SpLinkMap linkMap_;
|
||||
|
||||
private:
|
||||
void runLoopySolver (void);
|
||||
bool converged (void);
|
||||
|
||||
|
||||
};
|
||||
|
||||
#endif // HORUS_FGBPSOLVER_H
|
||||
|
@ -8,7 +8,9 @@
|
||||
|
||||
|
||||
vector<LiftedOperator*>
|
||||
LiftedOperator::getValidOps (ParfactorList& pfList, const Grounds& query)
|
||||
LiftedOperator::getValidOps (
|
||||
ParfactorList& pfList,
|
||||
const Grounds& query)
|
||||
{
|
||||
vector<LiftedOperator*> validOps;
|
||||
vector<SumOutOperator*> sumOutOps;
|
||||
@ -28,12 +30,15 @@ LiftedOperator::getValidOps (ParfactorList& pfList, const Grounds& query)
|
||||
|
||||
|
||||
void
|
||||
LiftedOperator::printValidOps (ParfactorList& pfList, const Grounds& query)
|
||||
LiftedOperator::printValidOps (
|
||||
ParfactorList& pfList,
|
||||
const Grounds& query)
|
||||
{
|
||||
vector<LiftedOperator*> validOps;
|
||||
validOps = LiftedOperator::getValidOps (pfList, query);
|
||||
for (unsigned i = 0; i < validOps.size(); i++) {
|
||||
cout << "-> " << validOps[i]->toString() << endl;
|
||||
delete validOps[i];
|
||||
}
|
||||
}
|
||||
|
||||
@ -56,14 +61,14 @@ SumOutOperator::getCost (void)
|
||||
pfIter = pfList_.begin();
|
||||
while (pfIter != pfList_.end()) {
|
||||
if ((*pfIter)->containsGroup (groupSet[i])) {
|
||||
int idx = (*pfIter)->indexOfFormulaWithGroup (groupSet[i]);
|
||||
int idx = (*pfIter)->indexOfGroup (groupSet[i]);
|
||||
cost *= (*pfIter)->range (idx);
|
||||
break;
|
||||
}
|
||||
++ pfIter;
|
||||
}
|
||||
}
|
||||
return cost;
|
||||
return cost;
|
||||
}
|
||||
|
||||
|
||||
@ -77,14 +82,13 @@ SumOutOperator::apply (void)
|
||||
pfList_.remove (iters[0]);
|
||||
for (unsigned i = 1; i < iters.size(); i++) {
|
||||
product->multiply (**(iters[i]));
|
||||
delete *(iters[i]);
|
||||
pfList_.remove (iters[i]);
|
||||
pfList_.removeAndDelete (iters[i]);
|
||||
}
|
||||
if (product->nrFormulas() == 1) {
|
||||
if (product->nrArguments() == 1) {
|
||||
delete product;
|
||||
return;
|
||||
}
|
||||
int fIdx = product->indexOfFormulaWithGroup (group_);
|
||||
int fIdx = product->indexOfGroup (group_);
|
||||
LogVarSet excl = product->exclusiveLogVars (fIdx);
|
||||
if (product->constr()->isCountNormalized (excl)) {
|
||||
product->sumOut (fIdx);
|
||||
@ -96,21 +100,21 @@ SumOutOperator::apply (void)
|
||||
pfList_.add (pfs[i]);
|
||||
}
|
||||
delete product;
|
||||
pfList_.shatter();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
vector<SumOutOperator*>
|
||||
SumOutOperator::getValidOps (ParfactorList& pfList, const Grounds& query)
|
||||
SumOutOperator::getValidOps (
|
||||
ParfactorList& pfList,
|
||||
const Grounds& query)
|
||||
{
|
||||
vector<SumOutOperator*> validOps;
|
||||
set<unsigned> allGroups;
|
||||
ParfactorList::const_iterator it = pfList.begin();
|
||||
while (it != pfList.end()) {
|
||||
assert (*it);
|
||||
const ProbFormulas& formulas = (*it)->formulas();
|
||||
const ProbFormulas& formulas = (*it)->arguments();
|
||||
for (unsigned i = 0; i < formulas.size(); i++) {
|
||||
allGroups.insert (formulas[i].group());
|
||||
}
|
||||
@ -134,8 +138,8 @@ SumOutOperator::toString (void)
|
||||
stringstream ss;
|
||||
vector<ParfactorList::iterator> pfIters;
|
||||
pfIters = parfactorsWithGroup (pfList_, group_);
|
||||
int idx = (*pfIters[0])->indexOfFormulaWithGroup (group_);
|
||||
ProbFormula f = (*pfIters[0])->formula (idx);
|
||||
int idx = (*pfIters[0])->indexOfGroup (group_);
|
||||
ProbFormula f = (*pfIters[0])->argument (idx);
|
||||
TupleSet tupleSet = (*pfIters[0])->constr()->tupleSet (f.logVars());
|
||||
ss << "sum out " << f.functor() << "/" << f.arity();
|
||||
ss << "|" << tupleSet << " (group " << group_ << ")";
|
||||
@ -158,9 +162,9 @@ SumOutOperator::validOp (
|
||||
}
|
||||
unordered_map<unsigned, unsigned> groupToRange;
|
||||
for (unsigned i = 0; i < pfIters.size(); i++) {
|
||||
int fIdx = (*pfIters[i])->indexOfFormulaWithGroup (group);
|
||||
if ((*pfIters[i])->formulas()[fIdx].contains (
|
||||
(*pfIters[i])->elimLogVars()) == false) {
|
||||
int fIdx = (*pfIters[i])->indexOfGroup (group);
|
||||
if ((*pfIters[i])->argument (fIdx).contains (
|
||||
(*pfIters[i])->elimLogVars()) == false) {
|
||||
return false;
|
||||
}
|
||||
vector<unsigned> ranges = (*pfIters[i])->ranges();
|
||||
@ -206,8 +210,8 @@ SumOutOperator::isToEliminate (
|
||||
unsigned group,
|
||||
const Grounds& query)
|
||||
{
|
||||
int fIdx = g->indexOfFormulaWithGroup (group);
|
||||
const ProbFormula& formula = g->formula (fIdx);
|
||||
int fIdx = g->indexOfGroup (group);
|
||||
const ProbFormula& formula = g->argument (fIdx);
|
||||
bool toElim = true;
|
||||
for (unsigned i = 0; i < query.size(); i++) {
|
||||
if (formula.functor() == query[i].functor() &&
|
||||
@ -228,7 +232,7 @@ unsigned
|
||||
CountingOperator::getCost (void)
|
||||
{
|
||||
unsigned cost = 0;
|
||||
int fIdx = (*pfIter_)->indexOfFormulaWithLogVar (X_);
|
||||
int fIdx = (*pfIter_)->indexOfLogVar (X_);
|
||||
unsigned range = (*pfIter_)->range (fIdx);
|
||||
unsigned size = (*pfIter_)->size() / range;
|
||||
TinySet<unsigned> counts;
|
||||
@ -247,18 +251,19 @@ CountingOperator::apply (void)
|
||||
if ((*pfIter_)->constr()->isCountNormalized (X_)) {
|
||||
(*pfIter_)->countConvert (X_);
|
||||
} else {
|
||||
Parfactors pfs = FoveSolver::countNormalize (*pfIter_, X_);
|
||||
Parfactor* pf = *pfIter_;
|
||||
pfList_.remove (pfIter_);
|
||||
Parfactors pfs = FoveSolver::countNormalize (pf, X_);
|
||||
for (unsigned i = 0; i < pfs.size(); i++) {
|
||||
unsigned condCount = pfs[i]->constr()->getConditionalCount (X_);
|
||||
bool cartProduct = pfs[i]->constr()->isCarteesianProduct (
|
||||
(*pfIter_)->countedLogVars() | X_);
|
||||
pfs[i]->countedLogVars() | X_);
|
||||
if (condCount > 1 && cartProduct) {
|
||||
pfs[i]->countConvert (X_);
|
||||
}
|
||||
pfList_.add (pfs[i]);
|
||||
}
|
||||
pfList_.deleteAndRemove (pfIter_);
|
||||
pfList_.shatter();
|
||||
delete pf;
|
||||
}
|
||||
}
|
||||
|
||||
@ -289,14 +294,17 @@ CountingOperator::toString (void)
|
||||
{
|
||||
stringstream ss;
|
||||
ss << "count convert " << X_ << " in " ;
|
||||
ss << (*pfIter_)->getHeaderString();
|
||||
ss << (*pfIter_)->getLabel();
|
||||
ss << " [cost=" << getCost() << "]" << endl;
|
||||
Parfactors pfs = FoveSolver::countNormalize (*pfIter_, X_);
|
||||
if ((*pfIter_)->constr()->isCountNormalized (X_) == false) {
|
||||
for (unsigned i = 0; i < pfs.size(); i++) {
|
||||
ss << " º " << pfs[i]->getHeaderString() << endl;
|
||||
ss << " º " << pfs[i]->getLabel() << endl;
|
||||
}
|
||||
}
|
||||
for (unsigned i = 0; i < pfs.size(); i++) {
|
||||
delete pfs[i];
|
||||
}
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
@ -308,8 +316,8 @@ CountingOperator::validOp (Parfactor* g, LogVar X)
|
||||
if (g->nrFormulas (X) != 1) {
|
||||
return false;
|
||||
}
|
||||
int fIdx = g->indexOfFormulaWithLogVar (X);
|
||||
if (g->formulas()[fIdx].isCounting()) {
|
||||
int fIdx = g->indexOfLogVar (X);
|
||||
if (g->argument (fIdx).isCounting()) {
|
||||
return false;
|
||||
}
|
||||
bool countNormalized = g->constr()->isCountNormalized (X);
|
||||
@ -332,10 +340,10 @@ GroundOperator::getCost (void)
|
||||
unsigned cost = 0;
|
||||
bool isCountingLv = (*pfIter_)->countedLogVars().contains (X_);
|
||||
if (isCountingLv) {
|
||||
int fIdx = (*pfIter_)->indexOfFormulaWithLogVar (X_);
|
||||
int fIdx = (*pfIter_)->indexOfLogVar (X_);
|
||||
unsigned currSize = (*pfIter_)->size();
|
||||
unsigned nrHists = (*pfIter_)->range (fIdx);
|
||||
unsigned range = (*pfIter_)->formula(fIdx).range();
|
||||
unsigned range = (*pfIter_)->argument (fIdx).range();
|
||||
unsigned nrSymbols = (*pfIter_)->constr()->getConditionalCount (X_);
|
||||
cost = (currSize / nrHists) * (std::pow (range, nrSymbols));
|
||||
} else {
|
||||
@ -350,18 +358,17 @@ void
|
||||
GroundOperator::apply (void)
|
||||
{
|
||||
bool countedLv = (*pfIter_)->countedLogVars().contains (X_);
|
||||
Parfactor* pf = *pfIter_;
|
||||
pfList_.remove (pfIter_);
|
||||
if (countedLv) {
|
||||
(*pfIter_)->fullExpand (X_);
|
||||
(*pfIter_)->setNewGroups();
|
||||
pfList_.shatter();
|
||||
pf->fullExpand (X_);
|
||||
pfList_.add (pf);
|
||||
} else {
|
||||
ConstraintTrees cts = (*pfIter_)->constr()->ground (X_);
|
||||
ConstraintTrees cts = pf->constr()->ground (X_);
|
||||
for (unsigned i = 0; i < cts.size(); i++) {
|
||||
Parfactor* newPf = new Parfactor (*pfIter_, cts[i]);
|
||||
pfList_.add (newPf);
|
||||
pfList_.add (new Parfactor (pf, cts[i]));
|
||||
}
|
||||
pfList_.deleteAndRemove (pfIter_);
|
||||
pfList_.shatter();
|
||||
delete pf;
|
||||
}
|
||||
}
|
||||
|
||||
@ -393,24 +400,13 @@ GroundOperator::toString (void)
|
||||
((*pfIter_)->countedLogVars().contains (X_))
|
||||
? ss << "full expanding "
|
||||
: ss << "grounding " ;
|
||||
ss << X_ << " in " << (*pfIter_)->getHeaderString();
|
||||
ss << X_ << " in " << (*pfIter_)->getLabel();
|
||||
ss << " [cost=" << getCost() << "]" << endl;
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
|
||||
|
||||
FoveSolver::FoveSolver (const ParfactorList* pfList)
|
||||
{
|
||||
for (ParfactorList::const_iterator it = pfList->begin();
|
||||
it != pfList->end();
|
||||
it ++) {
|
||||
pfList_.addShattered (new Parfactor (**it));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
Params
|
||||
FoveSolver::getPosterioriOf (const Ground& query)
|
||||
{
|
||||
@ -422,14 +418,12 @@ FoveSolver::getPosterioriOf (const Ground& query)
|
||||
Params
|
||||
FoveSolver::getJointDistributionOf (const Grounds& query)
|
||||
{
|
||||
shatterAgainstQuery (query);
|
||||
runSolver (query);
|
||||
(*pfList_.begin())->normalize();
|
||||
Params params = (*pfList_.begin())->params();
|
||||
if (Globals::logDomain) {
|
||||
Util::fromLog (params);
|
||||
}
|
||||
delete *pfList_.begin();
|
||||
return params;
|
||||
}
|
||||
|
||||
@ -438,32 +432,38 @@ FoveSolver::getJointDistributionOf (const Grounds& query)
|
||||
void
|
||||
FoveSolver::absorveEvidence (
|
||||
ParfactorList& pfList,
|
||||
const ObservedFormulas& obsFormulas)
|
||||
ObservedFormulas& obsFormulas)
|
||||
{
|
||||
ParfactorList::iterator it = pfList.begin();
|
||||
while (it != pfList.end()) {
|
||||
bool increment = true;
|
||||
for (unsigned i = 0; i < obsFormulas.size(); i++) {
|
||||
if (absorved (pfList, it, obsFormulas[i])) {
|
||||
it = pfList.deleteAndRemove (it);
|
||||
increment = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (increment) {
|
||||
++ it;
|
||||
for (unsigned i = 0; i < obsFormulas.size(); i++) {
|
||||
Parfactors newPfs;
|
||||
ParfactorList::iterator it = pfList.begin();
|
||||
while (it != pfList.end()) {
|
||||
Parfactor* pf = *it;
|
||||
it = pfList.remove (it);
|
||||
Parfactors absorvedPfs = absorve (obsFormulas[i], pf);
|
||||
if (absorvedPfs.empty() == false) {
|
||||
if (absorvedPfs.size() == 1 && absorvedPfs[0] == 0) {
|
||||
// just remove pf;
|
||||
} else {
|
||||
Util::addToVector (newPfs, absorvedPfs);
|
||||
}
|
||||
delete pf;
|
||||
} else {
|
||||
it = pfList.insertShattered (it, pf);
|
||||
++ it;
|
||||
}
|
||||
}
|
||||
pfList.add (newPfs);
|
||||
}
|
||||
pfList.shatter();
|
||||
if (obsFormulas.empty() == false) {
|
||||
cout << "*******************************************************" << endl;
|
||||
if (Constants::DEBUG >= 2 && obsFormulas.empty() == false) {
|
||||
Util::printAsteriskLine();
|
||||
cout << "AFTER EVIDENCE ABSORVED" << endl;
|
||||
for (unsigned i = 0; i < obsFormulas.size(); i++) {
|
||||
cout << " -> " << *obsFormulas[i] << endl;
|
||||
cout << " -> " << obsFormulas[i] << endl;
|
||||
}
|
||||
cout << "*******************************************************" << endl;
|
||||
Util::printAsteriskLine();
|
||||
pfList.print();
|
||||
}
|
||||
pfList.print();
|
||||
}
|
||||
|
||||
|
||||
@ -473,14 +473,14 @@ FoveSolver::countNormalize (
|
||||
Parfactor* g,
|
||||
const LogVarSet& set)
|
||||
{
|
||||
if (set.empty()) {
|
||||
assert (false); // TODO
|
||||
return {};
|
||||
}
|
||||
Parfactors normPfs;
|
||||
ConstraintTrees normCts = g->constr()->countNormalize (set);
|
||||
for (unsigned i = 0; i < normCts.size(); i++) {
|
||||
normPfs.push_back (new Parfactor (g, normCts[i]));
|
||||
if (set.empty()) {
|
||||
normPfs.push_back (new Parfactor (*g));
|
||||
} else {
|
||||
ConstraintTrees normCts = g->constr()->countNormalize (set);
|
||||
for (unsigned i = 0; i < normCts.size(); i++) {
|
||||
normPfs.push_back (new Parfactor (g, normCts[i]));
|
||||
}
|
||||
}
|
||||
return normPfs;
|
||||
}
|
||||
@ -490,17 +490,25 @@ FoveSolver::countNormalize (
|
||||
void
|
||||
FoveSolver::runSolver (const Grounds& query)
|
||||
{
|
||||
shatterAgainstQuery (query);
|
||||
runWeakBayesBall (query);
|
||||
while (true) {
|
||||
cout << "---------------------------------------------------" << endl;
|
||||
pfList_.print();
|
||||
LiftedOperator::printValidOps (pfList_, query);
|
||||
if (Constants::DEBUG >= 2) {
|
||||
Util::printDashedLine();
|
||||
pfList_.print();
|
||||
LiftedOperator::printValidOps (pfList_, query);
|
||||
}
|
||||
LiftedOperator* op = getBestOperation (query);
|
||||
if (op == 0) {
|
||||
break;
|
||||
}
|
||||
cout << "best operation: " << op->toString() << endl;
|
||||
if (Constants::DEBUG >= 2) {
|
||||
cout << "best operation: " << op->toString() << endl;
|
||||
}
|
||||
op->apply();
|
||||
delete op;
|
||||
}
|
||||
assert (pfList_.size() > 0);
|
||||
if (pfList_.size() > 1) {
|
||||
ParfactorList::iterator pfIter = pfList_.begin();
|
||||
pfIter ++;
|
||||
@ -514,26 +522,6 @@ FoveSolver::runSolver (const Grounds& query)
|
||||
|
||||
|
||||
|
||||
bool
|
||||
FoveSolver::allEliminated (const Grounds&)
|
||||
{
|
||||
ParfactorList::iterator pfIter = pfList_.begin();
|
||||
while (pfIter != pfList_.end()) {
|
||||
const ProbFormulas formulas = (*pfIter)->formulas();
|
||||
for (unsigned i = 0; i < formulas.size(); i++) {
|
||||
//bool toElim = false;
|
||||
//for (unsigned j = 0; j < queries.size(); j++) {
|
||||
// if ((*pfIter)->containsGround (queries[i]) == false) {
|
||||
// return
|
||||
// }
|
||||
}
|
||||
++ pfIter;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
|
||||
LiftedOperator*
|
||||
FoveSolver::getBestOperation (const Grounds& query)
|
||||
{
|
||||
@ -548,156 +536,176 @@ FoveSolver::getBestOperation (const Grounds& query)
|
||||
bestCost = cost;
|
||||
}
|
||||
}
|
||||
for (unsigned i = 0; i < validOps.size(); i++) {
|
||||
if (validOps[i] != bestOp) {
|
||||
delete validOps[i];
|
||||
}
|
||||
}
|
||||
return bestOp;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FoveSolver::runWeakBayesBall (const Grounds& query)
|
||||
{
|
||||
queue<unsigned> todo; // groups to process
|
||||
set<unsigned> done; // processed or in queue
|
||||
for (unsigned i = 0; i < query.size(); i++) {
|
||||
ParfactorList::iterator it = pfList_.begin();
|
||||
while (it != pfList_.end()) {
|
||||
int group = (*it)->findGroup (query[i]);
|
||||
if (group != -1) {
|
||||
todo.push (group);
|
||||
done.insert (group);
|
||||
break;
|
||||
}
|
||||
++ it;
|
||||
}
|
||||
}
|
||||
|
||||
set<Parfactor*> requiredPfs;
|
||||
while (todo.empty() == false) {
|
||||
unsigned group = todo.front();
|
||||
ParfactorList::iterator it = pfList_.begin();
|
||||
while (it != pfList_.end()) {
|
||||
if (Util::contains (requiredPfs, *it) == false &&
|
||||
(*it)->containsGroup (group)) {
|
||||
vector<unsigned> groups = (*it)->getAllGroups();
|
||||
for (unsigned i = 0; i < groups.size(); i++) {
|
||||
if (Util::contains (done, groups[i]) == false) {
|
||||
todo.push (groups[i]);
|
||||
done.insert (groups[i]);
|
||||
}
|
||||
}
|
||||
requiredPfs.insert (*it);
|
||||
}
|
||||
++ it;
|
||||
}
|
||||
todo.pop();
|
||||
}
|
||||
|
||||
ParfactorList::iterator it = pfList_.begin();
|
||||
while (it != pfList_.end()) {
|
||||
if (Util::contains (requiredPfs, *it) == false) {
|
||||
it = pfList_.removeAndDelete (it);
|
||||
} else {
|
||||
++ it;
|
||||
}
|
||||
}
|
||||
|
||||
if (Constants::DEBUG >= 2) {
|
||||
Util::printHeader ("REQUIRED PARFACTORS");
|
||||
pfList_.print();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FoveSolver::shatterAgainstQuery (const Grounds& query)
|
||||
{
|
||||
// return;
|
||||
for (unsigned i = 0; i < query.size(); i++) {
|
||||
if (query[i].isAtom()) {
|
||||
continue;
|
||||
}
|
||||
ParfactorList pfListCopy = pfList_;
|
||||
pfList_.clear();
|
||||
for (ParfactorList::iterator it = pfListCopy.begin();
|
||||
it != pfListCopy.end(); ++ it) {
|
||||
Parfactor* pf = *it;
|
||||
if (pf->containsGround (query[i])) {
|
||||
bool found = false;
|
||||
Parfactors newPfs;
|
||||
ParfactorList::iterator it = pfList_.begin();
|
||||
while (it != pfList_.end()) {
|
||||
if ((*it)->containsGround (query[i])) {
|
||||
found = true;
|
||||
std::pair<ConstraintTree*, ConstraintTree*> split =
|
||||
pf->constr()->split (query[i].args(), query[i].arity());
|
||||
(*it)->constr()->split (query[i].args(), query[i].arity());
|
||||
ConstraintTree* commCt = split.first;
|
||||
ConstraintTree* exclCt = split.second;
|
||||
pfList_.add (new Parfactor (pf, commCt));
|
||||
newPfs.push_back (new Parfactor (*it, commCt));
|
||||
if (exclCt->empty() == false) {
|
||||
pfList_.add (new Parfactor (pf, exclCt));
|
||||
newPfs.push_back (new Parfactor (*it, exclCt));
|
||||
} else {
|
||||
delete exclCt;
|
||||
}
|
||||
delete pf;
|
||||
it = pfList_.removeAndDelete (it);
|
||||
} else {
|
||||
pfList_.add (pf);
|
||||
++ it;
|
||||
}
|
||||
}
|
||||
pfList_.shatter();
|
||||
if (found == false) {
|
||||
cerr << "error: could not find a parfactor with ground " ;
|
||||
cerr << "`" << query[i] << "'" << endl;
|
||||
exit (0);
|
||||
}
|
||||
pfList_.add (newPfs);
|
||||
}
|
||||
cout << endl;
|
||||
cout << "*******************************************************" << endl;
|
||||
cout << "SHATTERED AGAINST THE QUERY" << endl;
|
||||
for (unsigned i = 0; i < query.size(); i++) {
|
||||
cout << " -> " << query[i] << endl;
|
||||
if (Constants::DEBUG >= 2) {
|
||||
cout << endl;
|
||||
Util::printAsteriskLine();
|
||||
cout << "SHATTERED AGAINST THE QUERY" << endl;
|
||||
for (unsigned i = 0; i < query.size(); i++) {
|
||||
cout << " -> " << query[i] << endl;
|
||||
}
|
||||
Util::printAsteriskLine();
|
||||
pfList_.print();
|
||||
}
|
||||
cout << "*******************************************************" << endl;
|
||||
pfList_.print();
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
FoveSolver::absorved (
|
||||
ParfactorList& pfList,
|
||||
ParfactorList::iterator pfIter,
|
||||
const ObservedFormula* obsFormula)
|
||||
Parfactors
|
||||
FoveSolver::absorve (
|
||||
ObservedFormula& obsFormula,
|
||||
Parfactor* g)
|
||||
{
|
||||
Parfactors absorvedPfs;
|
||||
Parfactor* g = *pfIter;
|
||||
const ProbFormulas& formulas = g->formulas();
|
||||
const ProbFormulas& formulas = g->arguments();
|
||||
for (unsigned i = 0; i < formulas.size(); i++) {
|
||||
if (obsFormula->functor() == formulas[i].functor() &&
|
||||
obsFormula->arity() == formulas[i].arity()) {
|
||||
if (obsFormula.functor() == formulas[i].functor() &&
|
||||
obsFormula.arity() == formulas[i].arity()) {
|
||||
|
||||
if (obsFormula->isAtom()) {
|
||||
if (obsFormula.isAtom()) {
|
||||
if (formulas.size() > 1) {
|
||||
g->absorveEvidence (i, obsFormula->evidence());
|
||||
g->absorveEvidence (formulas[i], obsFormula.evidence());
|
||||
} else {
|
||||
return true;
|
||||
// hack to erase parfactor g
|
||||
absorvedPfs.push_back (0);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
g->constr()->moveToTop (formulas[i].logVars());
|
||||
std::pair<ConstraintTree*, ConstraintTree*> res
|
||||
= g->constr()->split (obsFormula->constr(), formulas[i].arity());
|
||||
= g->constr()->split (&(obsFormula.constr()), formulas[i].arity());
|
||||
ConstraintTree* commCt = res.first;
|
||||
ConstraintTree* exclCt = res.second;
|
||||
|
||||
if (commCt->empty()) {
|
||||
delete commCt;
|
||||
delete exclCt;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (exclCt->empty() == false) {
|
||||
pfList.add (new Parfactor (g, exclCt));
|
||||
} else {
|
||||
delete exclCt;
|
||||
}
|
||||
|
||||
if (formulas.size() > 1) {
|
||||
LogVarSet excl = g->exclusiveLogVars (i);
|
||||
Parfactors countNormPfs = countNormalize (g, excl);
|
||||
for (unsigned j = 0; j < countNormPfs.size(); j++) {
|
||||
countNormPfs[j]->absorveEvidence (i, obsFormula->evidence());
|
||||
absorvedPfs.push_back (countNormPfs[j]);
|
||||
if (commCt->empty() == false) {
|
||||
if (formulas.size() > 1) {
|
||||
LogVarSet excl = g->exclusiveLogVars (i);
|
||||
Parfactors countNormPfs = countNormalize (g, excl);
|
||||
for (unsigned j = 0; j < countNormPfs.size(); j++) {
|
||||
countNormPfs[j]->absorveEvidence (
|
||||
formulas[i], obsFormula.evidence());
|
||||
absorvedPfs.push_back (countNormPfs[j]);
|
||||
}
|
||||
} else {
|
||||
delete commCt;
|
||||
}
|
||||
if (exclCt->empty() == false) {
|
||||
absorvedPfs.push_back (new Parfactor (g, exclCt));
|
||||
} else {
|
||||
delete exclCt;
|
||||
}
|
||||
if (absorvedPfs.empty()) {
|
||||
// hack to erase parfactor g
|
||||
absorvedPfs.push_back (0);
|
||||
}
|
||||
break;
|
||||
} else {
|
||||
delete commCt;
|
||||
delete exclCt;
|
||||
}
|
||||
return true;
|
||||
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
FoveSolver::proper (
|
||||
const ProbFormula& f1,
|
||||
ConstraintTree* c1,
|
||||
const ProbFormula& f2,
|
||||
ConstraintTree* c2)
|
||||
{
|
||||
return disjoint (f1, c1, f2, c2)
|
||||
|| identical (f1, c1, f2, c2);
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
FoveSolver::identical (
|
||||
const ProbFormula& f1,
|
||||
ConstraintTree* c1,
|
||||
const ProbFormula& f2,
|
||||
ConstraintTree* c2)
|
||||
{
|
||||
if (f1.sameSkeletonAs (f2) == false) {
|
||||
return false;
|
||||
}
|
||||
c1->moveToTop (f1.logVars());
|
||||
c2->moveToTop (f2.logVars());
|
||||
return ConstraintTree::identical (
|
||||
c1, c2, f1.logVars().size());
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
FoveSolver::disjoint (
|
||||
const ProbFormula& f1,
|
||||
ConstraintTree* c1,
|
||||
const ProbFormula& f2,
|
||||
ConstraintTree* c2)
|
||||
{
|
||||
if (f1.sameSkeletonAs (f2) == false) {
|
||||
return true;
|
||||
}
|
||||
c1->moveToTop (f1.logVars());
|
||||
c2->moveToTop (f2.logVars());
|
||||
return ConstraintTree::overlap (
|
||||
c1, c2, f1.arity()) == false;
|
||||
return absorvedPfs;
|
||||
}
|
||||
|
||||
|
@ -9,10 +9,14 @@ class LiftedOperator
|
||||
{
|
||||
public:
|
||||
virtual unsigned getCost (void) = 0;
|
||||
|
||||
virtual void apply (void) = 0;
|
||||
|
||||
virtual string toString (void) = 0;
|
||||
|
||||
static vector<LiftedOperator*> getValidOps (
|
||||
ParfactorList&, const Grounds&);
|
||||
|
||||
static void printValidOps (ParfactorList&, const Grounds&);
|
||||
};
|
||||
|
||||
@ -23,18 +27,26 @@ class SumOutOperator : public LiftedOperator
|
||||
public:
|
||||
SumOutOperator (unsigned group, ParfactorList& pfList)
|
||||
: group_(group), pfList_(pfList) { }
|
||||
|
||||
unsigned getCost (void);
|
||||
|
||||
void apply (void);
|
||||
|
||||
static vector<SumOutOperator*> getValidOps (
|
||||
ParfactorList&, const Grounds&);
|
||||
|
||||
string toString (void);
|
||||
|
||||
private:
|
||||
static bool validOp (unsigned, ParfactorList&, const Grounds&);
|
||||
|
||||
static vector<ParfactorList::iterator> parfactorsWithGroup (
|
||||
ParfactorList& pfList, unsigned group);
|
||||
|
||||
static bool isToEliminate (Parfactor*, unsigned, const Grounds&);
|
||||
unsigned group_;
|
||||
ParfactorList& pfList_;
|
||||
|
||||
unsigned group_;
|
||||
ParfactorList& pfList_;
|
||||
};
|
||||
|
||||
|
||||
@ -47,15 +59,21 @@ class CountingOperator : public LiftedOperator
|
||||
LogVar X,
|
||||
ParfactorList& pfList)
|
||||
: pfIter_(pfIter), X_(X), pfList_(pfList) { }
|
||||
|
||||
unsigned getCost (void);
|
||||
|
||||
void apply (void);
|
||||
|
||||
static vector<CountingOperator*> getValidOps (ParfactorList&);
|
||||
|
||||
string toString (void);
|
||||
|
||||
private:
|
||||
static bool validOp (Parfactor*, LogVar);
|
||||
ParfactorList::iterator pfIter_;
|
||||
LogVar X_;
|
||||
ParfactorList& pfList_;
|
||||
|
||||
ParfactorList::iterator pfIter_;
|
||||
LogVar X_;
|
||||
ParfactorList& pfList_;
|
||||
};
|
||||
|
||||
|
||||
@ -68,14 +86,19 @@ class GroundOperator : public LiftedOperator
|
||||
LogVar X,
|
||||
ParfactorList& pfList)
|
||||
: pfIter_(pfIter), X_(X), pfList_(pfList) { }
|
||||
|
||||
unsigned getCost (void);
|
||||
|
||||
void apply (void);
|
||||
|
||||
static vector<GroundOperator*> getValidOps (ParfactorList&);
|
||||
|
||||
string toString (void);
|
||||
|
||||
private:
|
||||
ParfactorList::iterator pfIter_;
|
||||
LogVar X_;
|
||||
ParfactorList& pfList_;
|
||||
ParfactorList::iterator pfIter_;
|
||||
LogVar X_;
|
||||
ParfactorList& pfList_;
|
||||
};
|
||||
|
||||
|
||||
@ -83,49 +106,29 @@ class GroundOperator : public LiftedOperator
|
||||
class FoveSolver
|
||||
{
|
||||
public:
|
||||
FoveSolver (const ParfactorList*);
|
||||
FoveSolver (const ParfactorList& pfList) : pfList_(pfList) { }
|
||||
|
||||
Params getPosterioriOf (const Ground&);
|
||||
Params getJointDistributionOf (const Grounds&);
|
||||
Params getPosterioriOf (const Ground&);
|
||||
|
||||
static void absorveEvidence (
|
||||
ParfactorList& pfList,
|
||||
const ObservedFormulas& obsFormulas);
|
||||
Params getJointDistributionOf (const Grounds&);
|
||||
|
||||
static Parfactors countNormalize (Parfactor*, const LogVarSet&);
|
||||
static void absorveEvidence (
|
||||
ParfactorList& pfList, ObservedFormulas& obsFormulas);
|
||||
|
||||
static Parfactors countNormalize (Parfactor*, const LogVarSet&);
|
||||
|
||||
private:
|
||||
void runSolver (const Grounds&);
|
||||
bool allEliminated (const Grounds&);
|
||||
LiftedOperator* getBestOperation (const Grounds&);
|
||||
void shatterAgainstQuery (const Grounds&);
|
||||
void runSolver (const Grounds&);
|
||||
|
||||
static bool absorved (
|
||||
ParfactorList& pfList,
|
||||
ParfactorList::iterator pfIter,
|
||||
const ObservedFormula*);
|
||||
LiftedOperator* getBestOperation (const Grounds&);
|
||||
|
||||
public:
|
||||
void runWeakBayesBall (const Grounds&);
|
||||
|
||||
static bool proper (
|
||||
const ProbFormula&,
|
||||
ConstraintTree*,
|
||||
const ProbFormula&,
|
||||
ConstraintTree*);
|
||||
void shatterAgainstQuery (const Grounds&);
|
||||
|
||||
static bool identical (
|
||||
const ProbFormula&,
|
||||
ConstraintTree*,
|
||||
const ProbFormula&,
|
||||
ConstraintTree*);
|
||||
static Parfactors absorve (ObservedFormula&, Parfactor*);
|
||||
|
||||
static bool disjoint (
|
||||
const ProbFormula&,
|
||||
ConstraintTree*,
|
||||
const ProbFormula&,
|
||||
ConstraintTree*);
|
||||
|
||||
ParfactorList pfList_;
|
||||
ParfactorList pfList_;
|
||||
};
|
||||
|
||||
#endif // HORUS_FOVESOLVER_H
|
||||
|
@ -1,67 +0,0 @@
|
||||
#ifndef HORUS_GRAPHICALMODEL_H
|
||||
#define HORUS_GRAPHICALMODEL_H
|
||||
|
||||
#include <sstream>
|
||||
|
||||
#include "VarNode.h"
|
||||
#include "Distribution.h"
|
||||
#include "Horus.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
|
||||
struct VariableInfo
|
||||
{
|
||||
VariableInfo (string l, const States& sts)
|
||||
{
|
||||
label = l;
|
||||
states = sts;
|
||||
}
|
||||
string label;
|
||||
States states;
|
||||
};
|
||||
|
||||
|
||||
class GraphicalModel
|
||||
{
|
||||
public:
|
||||
virtual ~GraphicalModel (void) {};
|
||||
virtual VarNode* getVariableNode (VarId) const = 0;
|
||||
virtual VarNodes getVariableNodes (void) const = 0;
|
||||
virtual void printGraphicalModel (void) const = 0;
|
||||
|
||||
static void addVariableInformation (VarId vid, string label,
|
||||
const States& states)
|
||||
{
|
||||
assert (varsInfo_.find (vid) == varsInfo_.end());
|
||||
varsInfo_.insert (make_pair (vid, VariableInfo (label, states)));
|
||||
}
|
||||
static VariableInfo getVariableInformation (VarId vid)
|
||||
{
|
||||
assert (varsInfo_.find (vid) != varsInfo_.end());
|
||||
return varsInfo_.find (vid)->second;
|
||||
}
|
||||
static bool variablesHaveInformation (void)
|
||||
{
|
||||
return varsInfo_.size() != 0;
|
||||
}
|
||||
static void clearVariablesInformation (void)
|
||||
{
|
||||
varsInfo_.clear();
|
||||
}
|
||||
static void addDistribution (unsigned id, Distribution* dist)
|
||||
{
|
||||
distsInfo_[id] = dist;
|
||||
}
|
||||
static void updateDistribution (unsigned id, const Params& params)
|
||||
{
|
||||
distsInfo_[id]->updateParameters (params);
|
||||
}
|
||||
|
||||
private:
|
||||
static unordered_map<VarId,VariableInfo> varsInfo_;
|
||||
static unordered_map<unsigned,Distribution*> distsInfo_;
|
||||
};
|
||||
|
||||
#endif // HORUS_GRAPHICALMODEL_H
|
||||
|
@ -84,16 +84,34 @@ HistogramSet::nrHistograms (unsigned N, unsigned R)
|
||||
|
||||
unsigned
|
||||
HistogramSet::findIndex (
|
||||
const Histogram& hist,
|
||||
const vector<Histogram>& histograms)
|
||||
const Histogram& h,
|
||||
const vector<Histogram>& hists)
|
||||
{
|
||||
vector<Histogram>::const_iterator it = std::lower_bound (
|
||||
histograms.begin(),
|
||||
histograms.end(),
|
||||
hist,
|
||||
std::greater<Histogram>());
|
||||
assert (it != histograms.end() && *it == hist);
|
||||
return std::distance (histograms.begin(), it);
|
||||
hists.begin(), hists.end(), h, std::greater<Histogram>());
|
||||
assert (it != hists.end() && *it == h);
|
||||
return std::distance (hists.begin(), it);
|
||||
}
|
||||
|
||||
|
||||
|
||||
vector<double>
|
||||
HistogramSet::getNumAssigns (unsigned N, unsigned R)
|
||||
{
|
||||
HistogramSet hs (N, R);
|
||||
unsigned N_factorial = Util::factorial (N);
|
||||
unsigned H = hs.nrHistograms();
|
||||
vector<double> numAssigns;
|
||||
numAssigns.reserve (H);
|
||||
for (unsigned h = 0; h < H; h++) {
|
||||
unsigned prod = 1;
|
||||
for (unsigned r = 0; r < R; r++) {
|
||||
prod *= Util::factorial (hs[r]);
|
||||
}
|
||||
numAssigns.push_back (LogAware::tl (N_factorial / prod));
|
||||
hs.nextHistogram();
|
||||
}
|
||||
return numAssigns;
|
||||
}
|
||||
|
||||
|
||||
|
@ -26,8 +26,9 @@ class HistogramSet
|
||||
static unsigned nrHistograms (unsigned, unsigned);
|
||||
|
||||
static unsigned findIndex (
|
||||
const Histogram&,
|
||||
const vector<Histogram>&);
|
||||
const Histogram&, const vector<Histogram>&);
|
||||
|
||||
static vector<double> getNumAssigns (unsigned, unsigned);
|
||||
|
||||
friend std::ostream& operator<< (ostream &os, const HistogramSet& hs);
|
||||
|
||||
|
@ -1,17 +1,9 @@
|
||||
#ifndef HORUS_HORUS_H
|
||||
#define HORUS_HORUS_H
|
||||
|
||||
#include <cmath>
|
||||
#include <cassert>
|
||||
#include <limits>
|
||||
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
|
||||
#define DISALLOW_COPY_AND_ASSIGN(TypeName) \
|
||||
TypeName(const TypeName&); \
|
||||
@ -19,55 +11,51 @@
|
||||
|
||||
using namespace std;
|
||||
|
||||
class VarNode;
|
||||
class BayesNode;
|
||||
class FgVarNode;
|
||||
class FgFacNode;
|
||||
class Var;
|
||||
class Factor;
|
||||
class VarNode;
|
||||
class FacNode;
|
||||
|
||||
typedef vector<double> Params;
|
||||
typedef unsigned VarId;
|
||||
typedef vector<VarId> VarIds;
|
||||
typedef vector<VarNode*> VarNodes;
|
||||
typedef vector<BayesNode*> BnNodeSet;
|
||||
typedef vector<FgVarNode*> FgVarSet;
|
||||
typedef vector<FgFacNode*> FgFacSet;
|
||||
typedef vector<Factor*> FactorSet;
|
||||
typedef vector<string> States;
|
||||
typedef vector<unsigned> Ranges;
|
||||
typedef vector<double> Params;
|
||||
typedef unsigned VarId;
|
||||
typedef vector<VarId> VarIds;
|
||||
typedef vector<Var*> Vars;
|
||||
typedef vector<VarNode*> VarNodes;
|
||||
typedef vector<FacNode*> FacNodes;
|
||||
typedef vector<Factor*> Factors;
|
||||
typedef vector<string> States;
|
||||
typedef vector<unsigned> Ranges;
|
||||
|
||||
|
||||
namespace Globals {
|
||||
extern bool logDomain;
|
||||
enum InfAlgorithms
|
||||
{
|
||||
VE, // variable elimination
|
||||
BP, // belief propagation
|
||||
CBP // counting belief propagation
|
||||
};
|
||||
|
||||
|
||||
// level of debug information
|
||||
static const unsigned DL = 1;
|
||||
namespace Globals {
|
||||
|
||||
static const int NO_EVIDENCE = -1;
|
||||
extern bool logDomain;
|
||||
|
||||
extern InfAlgorithms infAlgorithm;
|
||||
|
||||
};
|
||||
|
||||
|
||||
namespace Constants {
|
||||
|
||||
// level of debug information
|
||||
const unsigned DEBUG = 0;
|
||||
|
||||
const int NO_EVIDENCE = -1;
|
||||
|
||||
// number of digits to show when printing a parameter
|
||||
static const unsigned PRECISION = 5;
|
||||
const unsigned PRECISION = 5;
|
||||
|
||||
static const bool COLLECT_STATISTICS = false;
|
||||
const bool COLLECT_STATS = false;
|
||||
|
||||
static const bool EXPORT_TO_GRAPHVIZ = false;
|
||||
static const unsigned EXPORT_MINIMAL_SIZE = 100;
|
||||
|
||||
static const double INF = -numeric_limits<double>::infinity();
|
||||
|
||||
|
||||
|
||||
namespace InfAlgorithms {
|
||||
enum InfAlgs
|
||||
{
|
||||
VE, // variable elimination
|
||||
BN_BP, // bayesian network belief propagation
|
||||
FG_BP, // factor graph belief propagation
|
||||
CBP // counting bp solver
|
||||
};
|
||||
extern InfAlgs infAlgorithm;
|
||||
};
|
||||
|
||||
|
||||
|
@ -3,197 +3,89 @@
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "BayesNet.h"
|
||||
#include "FactorGraph.h"
|
||||
#include "VarElimSolver.h"
|
||||
#include "BnBpSolver.h"
|
||||
#include "FgBpSolver.h"
|
||||
#include "BpSolver.h"
|
||||
#include "CbpSolver.h"
|
||||
|
||||
//#include "TinySet.h"
|
||||
#include "LiftedUtils.h"
|
||||
|
||||
|
||||
using namespace std;
|
||||
|
||||
void processArguments (BayesNet&, int, const char* []);
|
||||
void processArguments (FactorGraph&, int, const char* []);
|
||||
void runSolver (Solver*, const VarNodes&);
|
||||
void runSolver (const FactorGraph&, const VarIds&);
|
||||
|
||||
const string USAGE = "usage: \
|
||||
./hcli FILE [VARIABLE | OBSERVED_VARIABLE=EVIDENCE]..." ;
|
||||
|
||||
|
||||
class Cenas
|
||||
{
|
||||
public:
|
||||
Cenas (int cc)
|
||||
{
|
||||
c = cc;
|
||||
}
|
||||
//operator int (void) const
|
||||
//{
|
||||
// cout << "return int" << endl;
|
||||
// return c;
|
||||
//}
|
||||
operator double (void) const
|
||||
{
|
||||
cout << "return double" << endl;
|
||||
return 0.0;
|
||||
}
|
||||
private:
|
||||
int c;
|
||||
};
|
||||
./hcli ve|bp|cbp NETWORK_FILE [VARIABLE | OBSERVED_VARIABLE=EVIDENCE]..." ;
|
||||
|
||||
|
||||
int
|
||||
main (int argc, const char* argv[])
|
||||
{
|
||||
LogVar X = 3;
|
||||
LogVarSet Xs = X;
|
||||
cout << "set: " << X << endl;
|
||||
Cenas c1 (1);
|
||||
Cenas c2 (3);
|
||||
cout << (c1 < c2) << endl;
|
||||
return 0;
|
||||
if (!argv[1]) {
|
||||
if (argc <= 1) {
|
||||
cerr << "error: no solver specified" << endl;
|
||||
cerr << "error: no graphical model specified" << endl;
|
||||
cerr << USAGE << endl;
|
||||
exit (0);
|
||||
}
|
||||
const string& fileName = argv[1];
|
||||
const string& extension = fileName.substr (fileName.find_last_of ('.') + 1);
|
||||
if (extension == "xml") {
|
||||
BayesNet bn;
|
||||
bn.readFromBifFormat (argv[1]);
|
||||
processArguments (bn, argc, argv);
|
||||
} else if (extension == "uai") {
|
||||
FactorGraph fg;
|
||||
fg.readFromUaiFormat (argv[1]);
|
||||
processArguments (fg, argc, argv);
|
||||
} else if (extension == "fg") {
|
||||
FactorGraph fg;
|
||||
fg.readFromLibDaiFormat (argv[1]);
|
||||
processArguments (fg, argc, argv);
|
||||
} else {
|
||||
cerr << "error: the graphical model must be defined either " ;
|
||||
cerr << "in a xml, uai or libDAI file" << endl;
|
||||
if (argc <= 2) {
|
||||
cerr << "error: no graphical model specified" << endl;
|
||||
cerr << USAGE << endl;
|
||||
exit (0);
|
||||
}
|
||||
string solver (argv[1]);
|
||||
if (solver == "ve") {
|
||||
Globals::infAlgorithm = InfAlgorithms::VE;
|
||||
} else if (solver == "bp") {
|
||||
Globals::infAlgorithm = InfAlgorithms::BP;
|
||||
} else if (solver == "cbp") {
|
||||
Globals::infAlgorithm = InfAlgorithms::CBP;
|
||||
} else {
|
||||
cerr << "error: unknow solver `" << solver << "'" << endl ;
|
||||
cerr << USAGE << endl;
|
||||
exit(0);
|
||||
}
|
||||
string fileName (argv[2]);
|
||||
string extension = fileName.substr (
|
||||
fileName.find_last_of ('.') + 1);
|
||||
FactorGraph fg;
|
||||
if (extension == "uai") {
|
||||
fg.readFromUaiFormat (fileName.c_str());
|
||||
} else if (extension == "fg") {
|
||||
fg.readFromLibDaiFormat (fileName.c_str());
|
||||
} else {
|
||||
cerr << "error: the graphical model must be defined either " ;
|
||||
cerr << "in a UAI or libDAI file" << endl;
|
||||
exit (0);
|
||||
}
|
||||
processArguments (fg, argc, argv);
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
processArguments (BayesNet& bn, int argc, const char* argv[])
|
||||
{
|
||||
VarNodes queryVars;
|
||||
for (int i = 2; i < argc; i++) {
|
||||
const string& arg = argv[i];
|
||||
if (arg.find ('=') == std::string::npos) {
|
||||
BayesNode* queryVar = bn.getBayesNode (arg);
|
||||
if (queryVar) {
|
||||
queryVars.push_back (queryVar);
|
||||
} else {
|
||||
cerr << "error: there isn't a variable labeled of " ;
|
||||
cerr << "`" << arg << "'" ;
|
||||
cerr << endl;
|
||||
bn.freeDistributions();
|
||||
exit (0);
|
||||
}
|
||||
} else {
|
||||
size_t pos = arg.find ('=');
|
||||
const string& label = arg.substr (0, pos);
|
||||
const string& state = arg.substr (pos + 1);
|
||||
if (label.empty()) {
|
||||
cerr << "error: missing left argument" << endl;
|
||||
cerr << USAGE << endl;
|
||||
bn.freeDistributions();
|
||||
exit (0);
|
||||
}
|
||||
if (state.empty()) {
|
||||
cerr << "error: missing right argument" << endl;
|
||||
cerr << USAGE << endl;
|
||||
bn.freeDistributions();
|
||||
exit (0);
|
||||
}
|
||||
BayesNode* node = bn.getBayesNode (label);
|
||||
if (node) {
|
||||
if (node->isValidState (state)) {
|
||||
node->setEvidence (state);
|
||||
} else {
|
||||
cerr << "error: `" << state << "' " ;
|
||||
cerr << "is not a valid state for " ;
|
||||
cerr << "`" << node->label() << "'" ;
|
||||
cerr << endl;
|
||||
bn.freeDistributions();
|
||||
exit (0);
|
||||
}
|
||||
} else {
|
||||
cerr << "error: there isn't a variable labeled of " ;
|
||||
cerr << "`" << label << "'" ;
|
||||
cerr << endl;
|
||||
bn.freeDistributions();
|
||||
exit (0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Solver* solver = 0;
|
||||
FactorGraph* fg = 0;
|
||||
switch (InfAlgorithms::infAlgorithm) {
|
||||
case InfAlgorithms::VE:
|
||||
fg = new FactorGraph (bn);
|
||||
solver = new VarElimSolver (*fg);
|
||||
break;
|
||||
case InfAlgorithms::BN_BP:
|
||||
solver = new BnBpSolver (bn);
|
||||
break;
|
||||
case InfAlgorithms::FG_BP:
|
||||
fg = new FactorGraph (bn);
|
||||
solver = new FgBpSolver (*fg);
|
||||
break;
|
||||
case InfAlgorithms::CBP:
|
||||
fg = new FactorGraph (bn);
|
||||
solver = new CbpSolver (*fg);
|
||||
break;
|
||||
default:
|
||||
assert (false);
|
||||
}
|
||||
runSolver (solver, queryVars);
|
||||
delete fg;
|
||||
bn.freeDistributions();
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
processArguments (FactorGraph& fg, int argc, const char* argv[])
|
||||
{
|
||||
VarNodes queryVars;
|
||||
for (int i = 2; i < argc; i++) {
|
||||
VarIds queryIds;
|
||||
for (int i = 3; i < argc; i++) {
|
||||
const string& arg = argv[i];
|
||||
if (arg.find ('=') == std::string::npos) {
|
||||
if (!Util::isInteger (arg)) {
|
||||
cerr << "error: `" << arg << "' " ;
|
||||
cerr << "is not a valid variable id" ;
|
||||
cerr << endl;
|
||||
fg.freeDistributions();
|
||||
exit (0);
|
||||
}
|
||||
VarId vid;
|
||||
stringstream ss;
|
||||
ss << arg;
|
||||
ss >> vid;
|
||||
VarNode* queryVar = fg.getFgVarNode (vid);
|
||||
VarNode* queryVar = fg.getVarNode (vid);
|
||||
if (queryVar) {
|
||||
queryVars.push_back (queryVar);
|
||||
queryIds.push_back (vid);
|
||||
} else {
|
||||
cerr << "error: there isn't a variable with " ;
|
||||
cerr << "`" << vid << "' as id" ;
|
||||
cerr << endl;
|
||||
fg.freeDistributions();
|
||||
exit (0);
|
||||
}
|
||||
} else {
|
||||
@ -201,33 +93,29 @@ processArguments (FactorGraph& fg, int argc, const char* argv[])
|
||||
if (arg.substr (0, pos).empty()) {
|
||||
cerr << "error: missing left argument" << endl;
|
||||
cerr << USAGE << endl;
|
||||
fg.freeDistributions();
|
||||
exit (0);
|
||||
}
|
||||
if (arg.substr (pos + 1).empty()) {
|
||||
cerr << "error: missing right argument" << endl;
|
||||
cerr << USAGE << endl;
|
||||
fg.freeDistributions();
|
||||
exit (0);
|
||||
}
|
||||
if (!Util::isInteger (arg.substr (0, pos))) {
|
||||
cerr << "error: `" << arg.substr (0, pos) << "' " ;
|
||||
cerr << "is not a variable id" ;
|
||||
cerr << endl;
|
||||
fg.freeDistributions();
|
||||
exit (0);
|
||||
}
|
||||
VarId vid;
|
||||
stringstream ss;
|
||||
ss << arg.substr (0, pos);
|
||||
ss >> vid;
|
||||
VarNode* var = fg.getFgVarNode (vid);
|
||||
VarNode* var = fg.getVarNode (vid);
|
||||
if (var) {
|
||||
if (!Util::isInteger (arg.substr (pos + 1))) {
|
||||
cerr << "error: `" << arg.substr (pos + 1) << "' " ;
|
||||
cerr << "is not a state index" ;
|
||||
cerr << endl;
|
||||
fg.freeDistributions();
|
||||
exit (0);
|
||||
}
|
||||
int stateIndex;
|
||||
@ -241,29 +129,31 @@ processArguments (FactorGraph& fg, int argc, const char* argv[])
|
||||
cerr << "is not a valid state index for variable " ;
|
||||
cerr << "`" << var->varId() << "'" ;
|
||||
cerr << endl;
|
||||
fg.freeDistributions();
|
||||
exit (0);
|
||||
}
|
||||
} else {
|
||||
cerr << "error: there isn't a variable with " ;
|
||||
cerr << "`" << vid << "' as id" ;
|
||||
cerr << endl;
|
||||
fg.freeDistributions();
|
||||
exit (0);
|
||||
}
|
||||
}
|
||||
}
|
||||
runSolver (fg, queryIds);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
runSolver (const FactorGraph& fg, const VarIds& queryIds)
|
||||
{
|
||||
Solver* solver = 0;
|
||||
switch (InfAlgorithms::infAlgorithm) {
|
||||
switch (Globals::infAlgorithm) {
|
||||
case InfAlgorithms::VE:
|
||||
solver = new VarElimSolver (fg);
|
||||
break;
|
||||
case InfAlgorithms::BN_BP:
|
||||
case InfAlgorithms::FG_BP:
|
||||
//cout << "here!" << endl;
|
||||
//fg.printGraphicalModel();
|
||||
//fg.exportToLibDaiFormat ("net.fg");
|
||||
solver = new FgBpSolver (fg);
|
||||
case InfAlgorithms::BP:
|
||||
solver = new BpSolver (fg);
|
||||
break;
|
||||
case InfAlgorithms::CBP:
|
||||
solver = new CbpSolver (fg);
|
||||
@ -271,28 +161,10 @@ processArguments (FactorGraph& fg, int argc, const char* argv[])
|
||||
default:
|
||||
assert (false);
|
||||
}
|
||||
runSolver (solver, queryVars);
|
||||
fg.freeDistributions();
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
runSolver (Solver* solver, const VarNodes& queryVars)
|
||||
{
|
||||
VarIds vids;
|
||||
for (unsigned i = 0; i < queryVars.size(); i++) {
|
||||
vids.push_back (queryVars[i]->varId());
|
||||
}
|
||||
if (queryVars.size() == 0) {
|
||||
solver->runSolver();
|
||||
if (queryIds.size() == 0) {
|
||||
solver->printAllPosterioris();
|
||||
} else if (queryVars.size() == 1) {
|
||||
solver->runSolver();
|
||||
solver->printPosterioriOf (vids[0]);
|
||||
} else {
|
||||
solver->runSolver();
|
||||
solver->printJointDistributionOf (vids);
|
||||
solver->printAnswer (queryIds);
|
||||
}
|
||||
delete solver;
|
||||
}
|
||||
|
@ -7,22 +7,50 @@
|
||||
|
||||
#include <YapInterface.h>
|
||||
|
||||
#include "BayesNet.h"
|
||||
#include "ParfactorList.h"
|
||||
#include "FactorGraph.h"
|
||||
#include "FoveSolver.h"
|
||||
#include "VarElimSolver.h"
|
||||
#include "BnBpSolver.h"
|
||||
#include "FgBpSolver.h"
|
||||
#include "BpSolver.h"
|
||||
#include "CbpSolver.h"
|
||||
#include "ElimGraph.h"
|
||||
#include "FoveSolver.h"
|
||||
#include "ParfactorList.h"
|
||||
#include "BayesBall.h"
|
||||
|
||||
|
||||
using namespace std;
|
||||
|
||||
|
||||
typedef std::pair<ParfactorList*, ObservedFormulas*> LiftedNetwork;
|
||||
|
||||
|
||||
Params readParameters (YAP_Term);
|
||||
|
||||
vector<unsigned> readUnsignedList (YAP_Term);
|
||||
|
||||
void readLiftedEvidence (YAP_Term, ObservedFormulas&);
|
||||
|
||||
Parfactor* readParfactor (YAP_Term);
|
||||
|
||||
void runVeSolver (FactorGraph* fg, const vector<VarIds>& tasks,
|
||||
vector<Params>& results);
|
||||
|
||||
void runBpSolver (FactorGraph* fg, const vector<VarIds>& tasks,
|
||||
vector<Params>& results);
|
||||
|
||||
|
||||
|
||||
|
||||
vector<unsigned>
|
||||
readUnsignedList (YAP_Term list)
|
||||
{
|
||||
vector<unsigned> vec;
|
||||
while (list != YAP_TermNil()) {
|
||||
vec.push_back ((unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (list)));
|
||||
list = YAP_TailOfTerm (list);
|
||||
}
|
||||
return vec;
|
||||
}
|
||||
|
||||
Params readParams (YAP_Term);
|
||||
|
||||
|
||||
int createLiftedNetwork (void)
|
||||
@ -30,107 +58,121 @@ int createLiftedNetwork (void)
|
||||
Parfactors parfactors;
|
||||
YAP_Term parfactorList = YAP_ARG1;
|
||||
while (parfactorList != YAP_TermNil()) {
|
||||
YAP_Term parfactor = YAP_HeadOfTerm (parfactorList);
|
||||
|
||||
// read dist id
|
||||
unsigned distId = YAP_IntOfTerm (YAP_ArgOfTerm (1, parfactor));
|
||||
|
||||
// read the ranges
|
||||
Ranges ranges;
|
||||
YAP_Term rangeList = YAP_ArgOfTerm (3, parfactor);
|
||||
while (rangeList != YAP_TermNil()) {
|
||||
unsigned range = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (rangeList));
|
||||
ranges.push_back (range);
|
||||
rangeList = YAP_TailOfTerm (rangeList);
|
||||
}
|
||||
|
||||
// read parametric random vars
|
||||
ProbFormulas formulas;
|
||||
unsigned count = 0;
|
||||
unordered_map<YAP_Term, LogVar> lvMap;
|
||||
YAP_Term pvList = YAP_ArgOfTerm (2, parfactor);
|
||||
while (pvList != YAP_TermNil()) {
|
||||
YAP_Term formulaTerm = YAP_HeadOfTerm (pvList);
|
||||
if (YAP_IsAtomTerm (formulaTerm)) {
|
||||
string name ((char*) YAP_AtomName (YAP_AtomOfTerm (formulaTerm)));
|
||||
Symbol functor = LiftedUtils::getSymbol (name);
|
||||
formulas.push_back (ProbFormula (functor, ranges[count]));
|
||||
} else {
|
||||
LogVars logVars;
|
||||
YAP_Functor yapFunctor = YAP_FunctorOfTerm (formulaTerm);
|
||||
string name ((char*) YAP_AtomName (YAP_NameOfFunctor (yapFunctor)));
|
||||
Symbol functor = LiftedUtils::getSymbol (name);
|
||||
unsigned arity = (unsigned) YAP_ArityOfFunctor (yapFunctor);
|
||||
for (unsigned i = 1; i <= arity; i++) {
|
||||
YAP_Term ti = YAP_ArgOfTerm (i, formulaTerm);
|
||||
unordered_map<YAP_Term, LogVar>::iterator it = lvMap.find (ti);
|
||||
if (it != lvMap.end()) {
|
||||
logVars.push_back (it->second);
|
||||
} else {
|
||||
unsigned newLv = lvMap.size();
|
||||
lvMap[ti] = newLv;
|
||||
logVars.push_back (newLv);
|
||||
}
|
||||
}
|
||||
formulas.push_back (ProbFormula (functor, logVars, ranges[count]));
|
||||
}
|
||||
count ++;
|
||||
pvList = YAP_TailOfTerm (pvList);
|
||||
}
|
||||
|
||||
// read the parameters
|
||||
const Params& params = readParams (YAP_ArgOfTerm (4, parfactor));
|
||||
|
||||
// read the constraint
|
||||
Tuples tuples;
|
||||
if (lvMap.size() >= 1) {
|
||||
YAP_Term tupleList = YAP_ArgOfTerm (5, parfactor);
|
||||
while (tupleList != YAP_TermNil()) {
|
||||
YAP_Term term = YAP_HeadOfTerm (tupleList);
|
||||
assert (YAP_IsApplTerm (term));
|
||||
YAP_Functor yapFunctor = YAP_FunctorOfTerm (term);
|
||||
unsigned arity = (unsigned) YAP_ArityOfFunctor (yapFunctor);
|
||||
assert (lvMap.size() == arity);
|
||||
Tuple tuple (arity);
|
||||
for (unsigned i = 1; i <= arity; i++) {
|
||||
YAP_Term ti = YAP_ArgOfTerm (i, term);
|
||||
if (YAP_IsAtomTerm (ti) == false) {
|
||||
cerr << "error: bad formed constraint" << endl;
|
||||
abort();
|
||||
}
|
||||
string name ((char*) YAP_AtomName (YAP_AtomOfTerm (ti)));
|
||||
tuple[i - 1] = LiftedUtils::getSymbol (name);
|
||||
}
|
||||
tuples.push_back (tuple);
|
||||
tupleList = YAP_TailOfTerm (tupleList);
|
||||
}
|
||||
}
|
||||
parfactors.push_back (new Parfactor (formulas, params, tuples, distId));
|
||||
YAP_Term pfTerm = YAP_HeadOfTerm (parfactorList);
|
||||
parfactors.push_back (readParfactor (pfTerm));
|
||||
parfactorList = YAP_TailOfTerm (parfactorList);
|
||||
}
|
||||
|
||||
// LiftedUtils::printSymbolDictionary();
|
||||
cout << "*******************************************************" << endl;
|
||||
cout << "INITIAL PARFACTORS" << endl;
|
||||
cout << "*******************************************************" << endl;
|
||||
for (unsigned i = 0; i < parfactors.size(); i++) {
|
||||
parfactors[i]->print();
|
||||
cout << endl;
|
||||
if (Constants::DEBUG > 2) {
|
||||
// Util::printHeader ("INITIAL PARFACTORS");
|
||||
// for (unsigned i = 0; i < parfactors.size(); i++) {
|
||||
// parfactors[i]->print();
|
||||
// }
|
||||
}
|
||||
ParfactorList* pfList = new ParfactorList();
|
||||
for (unsigned i = 0; i < parfactors.size(); i++) {
|
||||
pfList->add (parfactors[i]);
|
||||
}
|
||||
cout << endl;
|
||||
cout << "*******************************************************" << endl;
|
||||
cout << "SHATTERED PARFACTORS" << endl;
|
||||
cout << "*******************************************************" << endl;
|
||||
pfList->shatter();
|
||||
pfList->print();
|
||||
|
||||
// insert the evidence
|
||||
ObservedFormulas obsFormulas;
|
||||
YAP_Term observedList = YAP_ARG2;
|
||||
ParfactorList* pfList = new ParfactorList (parfactors);
|
||||
|
||||
if (Constants::DEBUG >= 2) {
|
||||
Util::printHeader ("SHATTERED PARFACTORS");
|
||||
pfList->print();
|
||||
}
|
||||
|
||||
// read evidence
|
||||
ObservedFormulas* obsFormulas = new ObservedFormulas();
|
||||
readLiftedEvidence (YAP_ARG2, *(obsFormulas));
|
||||
|
||||
LiftedNetwork* net = new LiftedNetwork (pfList, obsFormulas);
|
||||
YAP_Int p = (YAP_Int) (net);
|
||||
return YAP_Unify (YAP_MkIntTerm (p), YAP_ARG3);
|
||||
}
|
||||
|
||||
|
||||
|
||||
Parfactor* readParfactor (YAP_Term pfTerm)
|
||||
{
|
||||
// read dist id
|
||||
unsigned distId = YAP_IntOfTerm (YAP_ArgOfTerm (1, pfTerm));
|
||||
|
||||
// read the ranges
|
||||
Ranges ranges;
|
||||
YAP_Term rangeList = YAP_ArgOfTerm (3, pfTerm);
|
||||
while (rangeList != YAP_TermNil()) {
|
||||
unsigned range = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (rangeList));
|
||||
ranges.push_back (range);
|
||||
rangeList = YAP_TailOfTerm (rangeList);
|
||||
}
|
||||
|
||||
// read parametric random vars
|
||||
ProbFormulas formulas;
|
||||
unsigned count = 0;
|
||||
unordered_map<YAP_Term, LogVar> lvMap;
|
||||
YAP_Term pvList = YAP_ArgOfTerm (2, pfTerm);
|
||||
while (pvList != YAP_TermNil()) {
|
||||
YAP_Term formulaTerm = YAP_HeadOfTerm (pvList);
|
||||
if (YAP_IsAtomTerm (formulaTerm)) {
|
||||
string name ((char*) YAP_AtomName (YAP_AtomOfTerm (formulaTerm)));
|
||||
Symbol functor = LiftedUtils::getSymbol (name);
|
||||
formulas.push_back (ProbFormula (functor, ranges[count]));
|
||||
} else {
|
||||
LogVars logVars;
|
||||
YAP_Functor yapFunctor = YAP_FunctorOfTerm (formulaTerm);
|
||||
string name ((char*) YAP_AtomName (YAP_NameOfFunctor (yapFunctor)));
|
||||
Symbol functor = LiftedUtils::getSymbol (name);
|
||||
unsigned arity = (unsigned) YAP_ArityOfFunctor (yapFunctor);
|
||||
for (unsigned i = 1; i <= arity; i++) {
|
||||
YAP_Term ti = YAP_ArgOfTerm (i, formulaTerm);
|
||||
unordered_map<YAP_Term, LogVar>::iterator it = lvMap.find (ti);
|
||||
if (it != lvMap.end()) {
|
||||
logVars.push_back (it->second);
|
||||
} else {
|
||||
unsigned newLv = lvMap.size();
|
||||
lvMap[ti] = newLv;
|
||||
logVars.push_back (newLv);
|
||||
}
|
||||
}
|
||||
formulas.push_back (ProbFormula (functor, logVars, ranges[count]));
|
||||
}
|
||||
count ++;
|
||||
pvList = YAP_TailOfTerm (pvList);
|
||||
}
|
||||
|
||||
// read the parameters
|
||||
const Params& params = readParameters (YAP_ArgOfTerm (4, pfTerm));
|
||||
|
||||
// read the constraint
|
||||
Tuples tuples;
|
||||
if (lvMap.size() >= 1) {
|
||||
YAP_Term tupleList = YAP_ArgOfTerm (5, pfTerm);
|
||||
while (tupleList != YAP_TermNil()) {
|
||||
YAP_Term term = YAP_HeadOfTerm (tupleList);
|
||||
assert (YAP_IsApplTerm (term));
|
||||
YAP_Functor yapFunctor = YAP_FunctorOfTerm (term);
|
||||
unsigned arity = (unsigned) YAP_ArityOfFunctor (yapFunctor);
|
||||
assert (lvMap.size() == arity);
|
||||
Tuple tuple (arity);
|
||||
for (unsigned i = 1; i <= arity; i++) {
|
||||
YAP_Term ti = YAP_ArgOfTerm (i, term);
|
||||
if (YAP_IsAtomTerm (ti) == false) {
|
||||
cerr << "error: constraint has free variables" << endl;
|
||||
abort();
|
||||
}
|
||||
string name ((char*) YAP_AtomName (YAP_AtomOfTerm (ti)));
|
||||
tuple[i - 1] = LiftedUtils::getSymbol (name);
|
||||
}
|
||||
tuples.push_back (tuple);
|
||||
tupleList = YAP_TailOfTerm (tupleList);
|
||||
}
|
||||
}
|
||||
return new Parfactor (formulas, params, tuples, distId);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void readLiftedEvidence (
|
||||
YAP_Term observedList,
|
||||
ObservedFormulas& obsFormulas)
|
||||
{
|
||||
while (observedList != YAP_TermNil()) {
|
||||
YAP_Term pair = YAP_HeadOfTerm (observedList);
|
||||
YAP_Term ground = YAP_ArgOfTerm (1, pair);
|
||||
@ -155,22 +197,18 @@ int createLiftedNetwork (void)
|
||||
unsigned evidence = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (2, pair));
|
||||
bool found = false;
|
||||
for (unsigned i = 0; i < obsFormulas.size(); i++) {
|
||||
if (obsFormulas[i]->functor() == functor &&
|
||||
obsFormulas[i]->arity() == args.size() &&
|
||||
obsFormulas[i]->evidence() == evidence) {
|
||||
obsFormulas[i]->addTuple (args);
|
||||
if (obsFormulas[i].functor() == functor &&
|
||||
obsFormulas[i].arity() == args.size() &&
|
||||
obsFormulas[i].evidence() == evidence) {
|
||||
obsFormulas[i].addTuple (args);
|
||||
found = true;
|
||||
}
|
||||
}
|
||||
if (found == false) {
|
||||
obsFormulas.push_back (new ObservedFormula (functor, evidence, args));
|
||||
obsFormulas.push_back (ObservedFormula (functor, evidence, args));
|
||||
}
|
||||
observedList = YAP_TailOfTerm (observedList);
|
||||
}
|
||||
FoveSolver::absorveEvidence (*pfList, obsFormulas);
|
||||
|
||||
YAP_Int p = (YAP_Int) (pfList);
|
||||
return YAP_Unify (YAP_MkIntTerm (p), YAP_ARG3);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -178,93 +216,46 @@ int createLiftedNetwork (void)
|
||||
int
|
||||
createGroundNetwork (void)
|
||||
{
|
||||
Statistics::incrementPrimaryNetworksCounting();
|
||||
// cout << "creating network number " ;
|
||||
// cout << Statistics::getPrimaryNetworksCounting() << endl;
|
||||
// if (Statistics::getPrimaryNetworksCounting() > 98) {
|
||||
// Statistics::writeStatisticsToFile ("../../compressing.stats");
|
||||
// }
|
||||
BayesNet* bn = new BayesNet();
|
||||
YAP_Term varList = YAP_ARG1;
|
||||
BnNodeSet nodes;
|
||||
vector<VarIds> parents;
|
||||
while (varList != YAP_TermNil()) {
|
||||
YAP_Term var = YAP_HeadOfTerm (varList);
|
||||
VarId vid = (VarId) YAP_IntOfTerm (YAP_ArgOfTerm (1, var));
|
||||
unsigned dsize = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (2, var));
|
||||
int evidence = (int) YAP_IntOfTerm (YAP_ArgOfTerm (3, var));
|
||||
YAP_Term parentL = YAP_ArgOfTerm (4, var);
|
||||
unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (5, var));
|
||||
parents.push_back (VarIds());
|
||||
while (parentL != YAP_TermNil()) {
|
||||
unsigned parentId = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (parentL));
|
||||
parents.back().push_back (parentId);
|
||||
parentL = YAP_TailOfTerm (parentL);
|
||||
}
|
||||
Distribution* dist = bn->getDistribution (distId);
|
||||
if (!dist) {
|
||||
dist = new Distribution (distId);
|
||||
bn->addDistribution (dist);
|
||||
}
|
||||
assert (bn->getBayesNode (vid) == 0);
|
||||
nodes.push_back (bn->addNode (vid, dsize, evidence, dist));
|
||||
varList = YAP_TailOfTerm (varList);
|
||||
string factorsType ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG1)));
|
||||
bool fromBayesNet = factorsType == "bayes";
|
||||
FactorGraph* fg = new FactorGraph (fromBayesNet);
|
||||
YAP_Term factorList = YAP_ARG2;
|
||||
while (factorList != YAP_TermNil()) {
|
||||
YAP_Term factor = YAP_HeadOfTerm (factorList);
|
||||
// read the var ids
|
||||
VarIds varIds = readUnsignedList (YAP_ArgOfTerm (1, factor));
|
||||
// read the ranges
|
||||
Ranges ranges = readUnsignedList (YAP_ArgOfTerm (2, factor));
|
||||
// read the parameters
|
||||
Params params = readParameters (YAP_ArgOfTerm (3, factor));
|
||||
// read dist id
|
||||
unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (4, factor));
|
||||
fg->addFactor (Factor (varIds, ranges, params, distId));
|
||||
factorList = YAP_TailOfTerm (factorList);
|
||||
}
|
||||
for (unsigned i = 0; i < nodes.size(); i++) {
|
||||
BnNodeSet ps;
|
||||
for (unsigned j = 0; j < parents[i].size(); j++) {
|
||||
assert (bn->getBayesNode (parents[i][j]) != 0);
|
||||
ps.push_back (bn->getBayesNode (parents[i][j]));
|
||||
}
|
||||
nodes[i]->setParents (ps);
|
||||
|
||||
YAP_Term evidenceList = YAP_ARG3;
|
||||
while (evidenceList != YAP_TermNil()) {
|
||||
YAP_Term evTerm = YAP_HeadOfTerm (evidenceList);
|
||||
unsigned vid = (unsigned) YAP_IntOfTerm ((YAP_ArgOfTerm (1, evTerm)));
|
||||
unsigned ev = (unsigned) YAP_IntOfTerm ((YAP_ArgOfTerm (2, evTerm)));
|
||||
assert (fg->getVarNode (vid));
|
||||
fg->getVarNode (vid)->setEvidence (ev);
|
||||
evidenceList = YAP_TailOfTerm (evidenceList);
|
||||
}
|
||||
bn->setIndexes();
|
||||
YAP_Int p = (YAP_Int) (bn);
|
||||
return YAP_Unify (YAP_MkIntTerm (p), YAP_ARG2);
|
||||
}
|
||||
|
||||
|
||||
|
||||
int
|
||||
setBayesNetParams (void)
|
||||
{
|
||||
BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
|
||||
YAP_Term distList = YAP_ARG2;
|
||||
while (distList != YAP_TermNil()) {
|
||||
YAP_Term dist = YAP_HeadOfTerm (distList);
|
||||
unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (1, dist));
|
||||
const Params params = readParams (YAP_ArgOfTerm (2, dist));
|
||||
bn->getDistribution(distId)->updateParameters (params);
|
||||
distList = YAP_TailOfTerm (distList);
|
||||
}
|
||||
return TRUE;
|
||||
}
|
||||
|
||||
|
||||
|
||||
int
|
||||
setParfactorGraphParams (void)
|
||||
{
|
||||
// FIXME
|
||||
// ParfactorGraph* pfg = (ParfactorGraph*) YAP_IntOfTerm (YAP_ARG1);
|
||||
YAP_Term distList = YAP_ARG2;
|
||||
while (distList != YAP_TermNil()) {
|
||||
// YAP_Term dist = YAP_HeadOfTerm (distList);
|
||||
// unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (1, dist));
|
||||
// const Params params = readParams (YAP_ArgOfTerm (2, dist));
|
||||
// pfg->getDistribution(distId)->setData (params);
|
||||
distList = YAP_TailOfTerm (distList);
|
||||
}
|
||||
return TRUE;
|
||||
YAP_Int p = (YAP_Int) (fg);
|
||||
return YAP_Unify (YAP_MkIntTerm (p), YAP_ARG4);
|
||||
}
|
||||
|
||||
|
||||
|
||||
Params
|
||||
readParams (YAP_Term paramL)
|
||||
readParameters (YAP_Term paramL)
|
||||
{
|
||||
Params params;
|
||||
while (paramL!= YAP_TermNil()) {
|
||||
assert (YAP_IsPairTerm (paramL));
|
||||
while (paramL != YAP_TermNil()) {
|
||||
params.push_back ((double) YAP_FloatOfTerm (YAP_HeadOfTerm (paramL)));
|
||||
paramL = YAP_TailOfTerm (paramL);
|
||||
}
|
||||
@ -279,15 +270,14 @@ readParams (YAP_Term paramL)
|
||||
int
|
||||
runLiftedSolver (void)
|
||||
{
|
||||
ParfactorList* pfList = (ParfactorList*) YAP_IntOfTerm (YAP_ARG1);
|
||||
LiftedNetwork* network = (LiftedNetwork*) YAP_IntOfTerm (YAP_ARG1);
|
||||
YAP_Term taskList = YAP_ARG2;
|
||||
vector<Params> results;
|
||||
|
||||
ParfactorList pfListCopy (*network->first);
|
||||
FoveSolver::absorveEvidence (pfListCopy, *network->second);
|
||||
while (taskList != YAP_TermNil()) {
|
||||
YAP_Term jointList = YAP_HeadOfTerm (taskList);
|
||||
Grounds queryVars;
|
||||
assert (YAP_IsPairTerm (taskList));
|
||||
assert (YAP_IsPairTerm (jointList));
|
||||
YAP_Term jointList = YAP_HeadOfTerm (taskList);
|
||||
while (jointList != YAP_TermNil()) {
|
||||
YAP_Term ground = YAP_HeadOfTerm (jointList);
|
||||
if (YAP_IsAtomTerm (ground)) {
|
||||
@ -310,11 +300,11 @@ runLiftedSolver (void)
|
||||
}
|
||||
jointList = YAP_TailOfTerm (jointList);
|
||||
}
|
||||
FoveSolver solver (pfList);
|
||||
FoveSolver solver (pfListCopy);
|
||||
if (queryVars.size() == 1) {
|
||||
results.push_back (solver.getPosterioriOf (queryVars[0]));
|
||||
} else {
|
||||
assert (false); // TODO joint dist
|
||||
results.push_back (solver.getJointDistributionOf (queryVars));
|
||||
}
|
||||
taskList = YAP_TailOfTerm (taskList);
|
||||
}
|
||||
@ -339,77 +329,23 @@ runLiftedSolver (void)
|
||||
|
||||
|
||||
int
|
||||
runOtherSolvers (void)
|
||||
runGroundSolver (void)
|
||||
{
|
||||
BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
|
||||
YAP_Term taskList = YAP_ARG2;
|
||||
FactorGraph* fg = (FactorGraph*) YAP_IntOfTerm (YAP_ARG1);
|
||||
|
||||
vector<VarIds> tasks;
|
||||
std::set<VarId> vids;
|
||||
YAP_Term taskList = YAP_ARG2;
|
||||
while (taskList != YAP_TermNil()) {
|
||||
if (YAP_IsPairTerm (YAP_HeadOfTerm (taskList))) {
|
||||
tasks.push_back (VarIds());
|
||||
YAP_Term jointList = YAP_HeadOfTerm (taskList);
|
||||
while (jointList != YAP_TermNil()) {
|
||||
VarId vid = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (jointList));
|
||||
assert (bn->getBayesNode (vid));
|
||||
tasks.back().push_back (vid);
|
||||
vids.insert (vid);
|
||||
jointList = YAP_TailOfTerm (jointList);
|
||||
}
|
||||
} else {
|
||||
VarId vid = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (taskList));
|
||||
assert (bn->getBayesNode (vid));
|
||||
tasks.push_back (VarIds() = {vid});
|
||||
vids.insert (vid);
|
||||
}
|
||||
tasks.push_back (readUnsignedList (YAP_HeadOfTerm (taskList)));
|
||||
taskList = YAP_TailOfTerm (taskList);
|
||||
}
|
||||
|
||||
Solver* bpSolver = 0;
|
||||
GraphicalModel* graphicalModel = 0;
|
||||
CFactorGraph::checkForIdenticalFactors = false;
|
||||
if (InfAlgorithms::infAlgorithm != InfAlgorithms::VE) {
|
||||
BayesNet* mrn = bn->getMinimalRequesiteNetwork (
|
||||
VarIds (vids.begin(), vids.end()));
|
||||
if (InfAlgorithms::infAlgorithm == InfAlgorithms::BN_BP) {
|
||||
graphicalModel = mrn;
|
||||
bpSolver = new BnBpSolver (*static_cast<BayesNet*> (graphicalModel));
|
||||
} else if (InfAlgorithms::infAlgorithm == InfAlgorithms::FG_BP) {
|
||||
graphicalModel = new FactorGraph (*mrn);
|
||||
bpSolver = new FgBpSolver (*static_cast<FactorGraph*> (graphicalModel));
|
||||
delete mrn;
|
||||
} else if (InfAlgorithms::infAlgorithm == InfAlgorithms::CBP) {
|
||||
graphicalModel = new FactorGraph (*mrn);
|
||||
bpSolver = new CbpSolver (*static_cast<FactorGraph*> (graphicalModel));
|
||||
delete mrn;
|
||||
}
|
||||
bpSolver->runSolver();
|
||||
}
|
||||
|
||||
vector<Params> results;
|
||||
results.reserve (tasks.size());
|
||||
for (unsigned i = 0; i < tasks.size(); i++) {
|
||||
//if (i == 1) exit (0);
|
||||
if (InfAlgorithms::infAlgorithm == InfAlgorithms::VE) {
|
||||
BayesNet* mrn = bn->getMinimalRequesiteNetwork (tasks[i]);
|
||||
VarElimSolver* veSolver = new VarElimSolver (*mrn);
|
||||
if (tasks[i].size() == 1) {
|
||||
results.push_back (veSolver->getPosterioriOf (tasks[i][0]));
|
||||
} else {
|
||||
results.push_back (veSolver->getJointDistributionOf (tasks[i]));
|
||||
}
|
||||
delete mrn;
|
||||
delete veSolver;
|
||||
} else {
|
||||
if (tasks[i].size() == 1) {
|
||||
results.push_back (bpSolver->getPosterioriOf (tasks[i][0]));
|
||||
} else {
|
||||
results.push_back (bpSolver->getJointDistributionOf (tasks[i]));
|
||||
}
|
||||
}
|
||||
if (Globals::infAlgorithm == InfAlgorithms::VE) {
|
||||
runVeSolver (fg, tasks, results);
|
||||
} else {
|
||||
runBpSolver (fg, tasks, results);
|
||||
}
|
||||
delete bpSolver;
|
||||
delete graphicalModel;
|
||||
|
||||
YAP_Term list = YAP_TermNil();
|
||||
for (int i = results.size() - 1; i >= 0; i--) {
|
||||
@ -424,32 +360,142 @@ runOtherSolvers (void)
|
||||
}
|
||||
list = YAP_MkPairTerm (queryBeliefsL, list);
|
||||
}
|
||||
|
||||
return YAP_Unify (list, YAP_ARG3);
|
||||
}
|
||||
|
||||
|
||||
|
||||
int
|
||||
setExtraVarsInfo (void)
|
||||
void runVeSolver (
|
||||
FactorGraph* fg,
|
||||
const vector<VarIds>& tasks,
|
||||
vector<Params>& results)
|
||||
{
|
||||
// BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
|
||||
GraphicalModel::clearVariablesInformation();
|
||||
YAP_Term varsInfoL = YAP_ARG2;
|
||||
while (varsInfoL != YAP_TermNil()) {
|
||||
YAP_Term head = YAP_HeadOfTerm (varsInfoL);
|
||||
VarId vid = YAP_IntOfTerm (YAP_ArgOfTerm (1, head));
|
||||
YAP_Atom label = YAP_AtomOfTerm (YAP_ArgOfTerm (2, head));
|
||||
YAP_Term statesL = YAP_ArgOfTerm (3, head);
|
||||
States states;
|
||||
while (statesL != YAP_TermNil()) {
|
||||
YAP_Atom atom = YAP_AtomOfTerm (YAP_HeadOfTerm (statesL));
|
||||
states.push_back ((char*) YAP_AtomName (atom));
|
||||
statesL = YAP_TailOfTerm (statesL);
|
||||
results.reserve (tasks.size());
|
||||
for (unsigned i = 0; i < tasks.size(); i++) {
|
||||
FactorGraph* mfg = fg;
|
||||
if (fg->isFromBayesNetwork()) {
|
||||
mfg = BayesBall::getMinimalFactorGraph (*fg, tasks[i]);
|
||||
}
|
||||
GraphicalModel::addVariableInformation (vid,
|
||||
(char*) YAP_AtomName (label), states);
|
||||
varsInfoL = YAP_TailOfTerm (varsInfoL);
|
||||
VarElimSolver solver (*mfg);
|
||||
results.push_back (solver.solveQuery (tasks[i]));
|
||||
if (fg->isFromBayesNetwork()) {
|
||||
delete mfg;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void runBpSolver (
|
||||
FactorGraph* fg,
|
||||
const vector<VarIds>& tasks,
|
||||
vector<Params>& results)
|
||||
{
|
||||
std::set<VarId> vids;
|
||||
for (unsigned i = 0; i < tasks.size(); i++) {
|
||||
Util::addToSet (vids, tasks[i]);
|
||||
}
|
||||
Solver* solver = 0;
|
||||
FactorGraph* mfg = fg;
|
||||
if (fg->isFromBayesNetwork()) {
|
||||
mfg = BayesBall::getMinimalFactorGraph (
|
||||
*fg, VarIds (vids.begin(),vids.end()));
|
||||
}
|
||||
if (Globals::infAlgorithm == InfAlgorithms::BP) {
|
||||
solver = new BpSolver (*mfg);
|
||||
} else if (Globals::infAlgorithm == InfAlgorithms::CBP) {
|
||||
CFactorGraph::checkForIdenticalFactors = false;
|
||||
solver = new CbpSolver (*mfg);
|
||||
} else {
|
||||
cerr << "error: unknow solver" << endl;
|
||||
abort();
|
||||
}
|
||||
results.reserve (tasks.size());
|
||||
for (unsigned i = 0; i < tasks.size(); i++) {
|
||||
results.push_back (solver->solveQuery (tasks[i]));
|
||||
}
|
||||
if (fg->isFromBayesNetwork()) {
|
||||
delete mfg;
|
||||
}
|
||||
delete solver;
|
||||
}
|
||||
|
||||
|
||||
|
||||
int
|
||||
setParfactorsParams (void)
|
||||
{
|
||||
LiftedNetwork* network = (LiftedNetwork*) YAP_IntOfTerm (YAP_ARG1);
|
||||
ParfactorList* pfList = network->first;
|
||||
YAP_Term distList = YAP_ARG2;
|
||||
unordered_map<unsigned, Params> paramsMap;
|
||||
while (distList != YAP_TermNil()) {
|
||||
YAP_Term dist = YAP_HeadOfTerm (distList);
|
||||
unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (1, dist));
|
||||
assert (Util::contains (paramsMap, distId) == false);
|
||||
paramsMap[distId] = readParameters (YAP_ArgOfTerm (2, dist));
|
||||
distList = YAP_TailOfTerm (distList);
|
||||
}
|
||||
ParfactorList::iterator it = pfList->begin();
|
||||
while (it != pfList->end()) {
|
||||
assert (Util::contains (paramsMap, (*it)->distId()));
|
||||
// (*it)->setParams (paramsMap[(*it)->distId()]);
|
||||
++ it;
|
||||
}
|
||||
return TRUE;
|
||||
}
|
||||
|
||||
|
||||
|
||||
int
|
||||
setFactorsParams (void)
|
||||
{
|
||||
return TRUE; // TODO
|
||||
FactorGraph* fg = (FactorGraph*) YAP_IntOfTerm (YAP_ARG1);
|
||||
YAP_Term distList = YAP_ARG2;
|
||||
unordered_map<unsigned, Params> paramsMap;
|
||||
while (distList != YAP_TermNil()) {
|
||||
YAP_Term dist = YAP_HeadOfTerm (distList);
|
||||
unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (1, dist));
|
||||
assert (Util::contains (paramsMap, distId) == false);
|
||||
paramsMap[distId] = readParameters (YAP_ArgOfTerm (2, dist));
|
||||
distList = YAP_TailOfTerm (distList);
|
||||
}
|
||||
const FacNodes& facNodes = fg->facNodes();
|
||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||
unsigned distId = facNodes[i]->factor().distId();
|
||||
assert (Util::contains (paramsMap, distId));
|
||||
facNodes[i]->factor().setParams (paramsMap[distId]);
|
||||
}
|
||||
return TRUE;
|
||||
}
|
||||
|
||||
|
||||
|
||||
int
|
||||
setVarsInformation (void)
|
||||
{
|
||||
Var::clearVarsInfo();
|
||||
YAP_Term labelsL = YAP_ARG1;
|
||||
vector<string> labels;
|
||||
while (labelsL != YAP_TermNil()) {
|
||||
YAP_Atom atom = YAP_AtomOfTerm (YAP_HeadOfTerm (labelsL));
|
||||
labels.push_back ((char*) YAP_AtomName (atom));
|
||||
labelsL = YAP_TailOfTerm (labelsL);
|
||||
}
|
||||
unsigned count = 0;
|
||||
YAP_Term stateNamesL = YAP_ARG2;
|
||||
while (stateNamesL != YAP_TermNil()) {
|
||||
States states;
|
||||
YAP_Term namesL = YAP_HeadOfTerm (stateNamesL);
|
||||
while (namesL != YAP_TermNil()) {
|
||||
YAP_Atom atom = YAP_AtomOfTerm (YAP_HeadOfTerm (namesL));
|
||||
states.push_back ((char*) YAP_AtomName (atom));
|
||||
namesL = YAP_TailOfTerm (namesL);
|
||||
}
|
||||
Var::addVarInfo (count, labels[count], states);
|
||||
count ++;
|
||||
stateNamesL = YAP_TailOfTerm (stateNamesL);
|
||||
}
|
||||
return TRUE;
|
||||
}
|
||||
@ -463,13 +509,11 @@ setHorusFlag (void)
|
||||
if (key == "inf_alg") {
|
||||
string value ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG2)));
|
||||
if ( value == "ve") {
|
||||
InfAlgorithms::infAlgorithm = InfAlgorithms::VE;
|
||||
} else if (value == "bn_bp") {
|
||||
InfAlgorithms::infAlgorithm = InfAlgorithms::BN_BP;
|
||||
} else if (value == "fg_bp") {
|
||||
InfAlgorithms::infAlgorithm = InfAlgorithms::FG_BP;
|
||||
Globals::infAlgorithm = InfAlgorithms::VE;
|
||||
} else if (value == "bp") {
|
||||
Globals::infAlgorithm = InfAlgorithms::BP;
|
||||
} else if (value == "cbp") {
|
||||
InfAlgorithms::infAlgorithm = InfAlgorithms::CBP;
|
||||
Globals::infAlgorithm = InfAlgorithms::CBP;
|
||||
} else {
|
||||
cerr << "warning: invalid value `" << value << "' " ;
|
||||
cerr << "for `" << key << "'" << endl;
|
||||
@ -541,21 +585,21 @@ setHorusFlag (void)
|
||||
|
||||
|
||||
int
|
||||
freeBayesNetwork (void)
|
||||
freeGroundNetwork (void)
|
||||
{
|
||||
//Statistics::writeStatisticsToFile ("stats.txt");
|
||||
BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
|
||||
bn->freeDistributions();
|
||||
delete bn;
|
||||
delete (FactorGraph*) YAP_IntOfTerm (YAP_ARG1);
|
||||
return TRUE;
|
||||
}
|
||||
|
||||
|
||||
|
||||
int
|
||||
freeParfactorGraph (void)
|
||||
freeParfactors (void)
|
||||
{
|
||||
delete (ParfactorList*) YAP_IntOfTerm (YAP_ARG1);
|
||||
LiftedNetwork* network = (LiftedNetwork*) YAP_IntOfTerm (YAP_ARG1);
|
||||
delete network->first;
|
||||
delete network->second;
|
||||
delete network;
|
||||
return TRUE;
|
||||
}
|
||||
|
||||
@ -564,15 +608,15 @@ freeParfactorGraph (void)
|
||||
extern "C" void
|
||||
init_predicates (void)
|
||||
{
|
||||
YAP_UserCPredicate ("create_lifted_network", createLiftedNetwork, 3);
|
||||
YAP_UserCPredicate ("create_ground_network", createGroundNetwork, 2);
|
||||
YAP_UserCPredicate ("set_parfactor_graph_params", setParfactorGraphParams, 2);
|
||||
YAP_UserCPredicate ("set_bayes_net_params", setBayesNetParams, 2);
|
||||
YAP_UserCPredicate ("run_lifted_solver", runLiftedSolver, 3);
|
||||
YAP_UserCPredicate ("run_other_solvers", runOtherSolvers, 3);
|
||||
YAP_UserCPredicate ("set_extra_vars_info", setExtraVarsInfo, 2);
|
||||
YAP_UserCPredicate ("set_horus_flag", setHorusFlag, 2);
|
||||
YAP_UserCPredicate ("free_bayesian_network", freeBayesNetwork, 1);
|
||||
YAP_UserCPredicate ("free_parfactor_graph", freeParfactorGraph, 1);
|
||||
YAP_UserCPredicate ("create_lifted_network", createLiftedNetwork, 3);
|
||||
YAP_UserCPredicate ("create_ground_network", createGroundNetwork, 4);
|
||||
YAP_UserCPredicate ("run_lifted_solver", runLiftedSolver, 3);
|
||||
YAP_UserCPredicate ("run_ground_solver", runGroundSolver, 3);
|
||||
YAP_UserCPredicate ("set_parfactors_params", setParfactorsParams, 2);
|
||||
YAP_UserCPredicate ("set_factors_params", setFactorsParams, 2);
|
||||
YAP_UserCPredicate ("set_vars_information", setVarsInformation, 2);
|
||||
YAP_UserCPredicate ("set_horus_flag", setHorusFlag, 2);
|
||||
YAP_UserCPredicate ("free_parfactors", freeParfactors, 1);
|
||||
YAP_UserCPredicate ("free_ground_network", freeGroundNetwork, 1);
|
||||
}
|
||||
|
||||
|
@ -8,11 +8,13 @@
|
||||
#include <sstream>
|
||||
#include <iomanip>
|
||||
|
||||
#include "VarNode.h"
|
||||
#include "Var.h"
|
||||
#include "Util.h"
|
||||
|
||||
|
||||
class StatesIndexer {
|
||||
|
||||
class StatesIndexer
|
||||
{
|
||||
public:
|
||||
|
||||
StatesIndexer (const Ranges& ranges, bool calcOffsets = true)
|
||||
@ -29,14 +31,14 @@ class StatesIndexer {
|
||||
}
|
||||
}
|
||||
|
||||
StatesIndexer (const VarNodes& vars, bool calcOffsets = true)
|
||||
StatesIndexer (const Vars& vars, bool calcOffsets = true)
|
||||
{
|
||||
size_ = 1;
|
||||
indices_.resize (vars.size(), 0);
|
||||
ranges_.reserve (vars.size());
|
||||
for (unsigned i = 0; i < vars.size(); i++) {
|
||||
ranges_.push_back (vars[i]->nrStates());
|
||||
size_ *= vars[i]->nrStates();
|
||||
ranges_.push_back (vars[i]->range());
|
||||
size_ *= vars[i]->range();
|
||||
}
|
||||
li_ = 0;
|
||||
if (calcOffsets) {
|
||||
@ -134,11 +136,11 @@ class StatesIndexer {
|
||||
return size_ ;
|
||||
}
|
||||
|
||||
friend ostream& operator<< (ostream &out, const StatesIndexer& idx)
|
||||
friend ostream& operator<< (ostream &os, const StatesIndexer& idx)
|
||||
{
|
||||
out << "(" << std::setw (2) << std::setfill('0') << idx.li_ << ") " ;
|
||||
out << idx.indices_;
|
||||
return out;
|
||||
os << "(" << std::setw (2) << std::setfill('0') << idx.li_ << ") " ;
|
||||
os << idx.indices_;
|
||||
return os;
|
||||
}
|
||||
|
||||
private:
|
||||
@ -274,21 +276,14 @@ class MapIndexer
|
||||
index_ = 0;
|
||||
}
|
||||
|
||||
friend ostream& operator<< (ostream &out, const MapIndexer& idx)
|
||||
friend ostream& operator<< (ostream &os, const MapIndexer& idx)
|
||||
{
|
||||
out << "(" << std::setw (2) << std::setfill('0') << idx.index_ << ") " ;
|
||||
out << idx.indices_;
|
||||
return out;
|
||||
os << "(" << std::setw (2) << std::setfill('0') << idx.index_ << ") " ;
|
||||
os << idx.indices_;
|
||||
return os;
|
||||
}
|
||||
|
||||
private:
|
||||
MapIndexer (const Ranges& ranges) :
|
||||
ranges_(ranges),
|
||||
indices_(ranges.size(), 0),
|
||||
offsets_(ranges.size())
|
||||
{
|
||||
index_ = 0;
|
||||
}
|
||||
unsigned index_;
|
||||
bool valid_;
|
||||
vector<unsigned> ranges_;
|
||||
|
@ -95,26 +95,37 @@ ostream& operator<< (ostream &os, const Ground& gr)
|
||||
|
||||
|
||||
|
||||
void
|
||||
ObservedFormula::addTuple (const Tuple& t)
|
||||
LogVars
|
||||
Substitution::getDiscardedLogVars (void) const
|
||||
{
|
||||
if (constr_ == 0) {
|
||||
LogVars lvs (arity_);
|
||||
for (unsigned i = 0; i < arity_; i++) {
|
||||
lvs[i] = i;
|
||||
LogVars discardedLvs;
|
||||
set<LogVar> doneLvs;
|
||||
unordered_map<LogVar, LogVar>::const_iterator it;
|
||||
it = subs_.begin();
|
||||
while (it != subs_.end()) {
|
||||
if (Util::contains (doneLvs, it->second)) {
|
||||
discardedLvs.push_back (it->first);
|
||||
} else {
|
||||
doneLvs.insert (it->second);
|
||||
}
|
||||
constr_ = new ConstraintTree (lvs);
|
||||
it ++;
|
||||
}
|
||||
constr_->addTuple (t);
|
||||
return discardedLvs;
|
||||
}
|
||||
|
||||
|
||||
|
||||
ostream& operator<< (ostream &os, const ObservedFormula of)
|
||||
ostream& operator<< (ostream &os, const Substitution& theta)
|
||||
{
|
||||
os << of.functor_ << "/" << of.arity_;
|
||||
os << "|" << of.constr_->tupleSet();
|
||||
os << " [evidence=" << of.evidence_ << "]";
|
||||
unordered_map<LogVar, LogVar>::const_iterator it;
|
||||
os << "[" ;
|
||||
it = theta.subs_.begin();
|
||||
while (it != theta.subs_.end()) {
|
||||
if (it != theta.subs_.begin()) os << ", " ;
|
||||
os << it->first << "->" << it->second ;
|
||||
++ it;
|
||||
}
|
||||
os << "]" ;
|
||||
return os;
|
||||
}
|
||||
|
||||
|
@ -18,11 +18,17 @@ class Symbol
|
||||
{
|
||||
public:
|
||||
Symbol (void) : id_(numeric_limits<unsigned>::max()) { }
|
||||
|
||||
Symbol (unsigned id) : id_(id) { }
|
||||
|
||||
operator unsigned (void) const { return id_; }
|
||||
|
||||
bool valid (void) const { return id_ != numeric_limits<unsigned>::max(); }
|
||||
|
||||
static Symbol invalid (void) { return Symbol(); }
|
||||
|
||||
friend ostream& operator<< (ostream &os, const Symbol& s);
|
||||
|
||||
private:
|
||||
unsigned id_;
|
||||
};
|
||||
@ -32,7 +38,9 @@ class LogVar
|
||||
{
|
||||
public:
|
||||
LogVar (void) : id_(numeric_limits<unsigned>::max()) { }
|
||||
|
||||
LogVar (unsigned id) : id_(id) { }
|
||||
|
||||
operator unsigned (void) const { return id_; }
|
||||
|
||||
LogVar& operator++ (void)
|
||||
@ -48,6 +56,7 @@ class LogVar
|
||||
}
|
||||
|
||||
friend ostream& operator<< (ostream &os, const LogVar& X);
|
||||
|
||||
private:
|
||||
unsigned id_;
|
||||
};
|
||||
@ -79,8 +88,8 @@ ostream& operator<< (ostream &os, const Tuple& t);
|
||||
|
||||
|
||||
namespace LiftedUtils {
|
||||
Symbol getSymbol (const string&);
|
||||
void printSymbolDictionary (void);
|
||||
Symbol getSymbol (const string&);
|
||||
void printSymbolDictionary (void);
|
||||
}
|
||||
|
||||
|
||||
@ -89,71 +98,56 @@ class Ground
|
||||
{
|
||||
public:
|
||||
Ground (Symbol f) : functor_(f) { }
|
||||
|
||||
Ground (Symbol f, const Symbols& args) : functor_(f), args_(args) { }
|
||||
|
||||
Symbol functor (void) const { return functor_; }
|
||||
Symbols args (void) const { return args_; }
|
||||
unsigned arity (void) const { return args_.size(); }
|
||||
bool isAtom (void) const { return args_.size() == 0; }
|
||||
Symbol functor (void) const { return functor_; }
|
||||
|
||||
Symbols args (void) const { return args_; }
|
||||
|
||||
unsigned arity (void) const { return args_.size(); }
|
||||
|
||||
bool isAtom (void) const { return args_.size() == 0; }
|
||||
|
||||
friend ostream& operator<< (ostream &os, const Ground& gr);
|
||||
|
||||
private:
|
||||
Symbol functor_;
|
||||
Symbols args_;
|
||||
Symbol functor_;
|
||||
Symbols args_;
|
||||
};
|
||||
|
||||
typedef vector<Ground> Grounds;
|
||||
|
||||
|
||||
|
||||
class ConstraintTree;
|
||||
class ObservedFormula
|
||||
{
|
||||
public:
|
||||
ObservedFormula (Symbol f, unsigned a, unsigned ev)
|
||||
: functor_(f), arity_(a), evidence_(ev), constr_(0) { }
|
||||
|
||||
ObservedFormula (Symbol f, unsigned ev, const Tuple& tuple)
|
||||
: functor_(f), arity_(tuple.size()), evidence_(ev), constr_(0)
|
||||
{
|
||||
addTuple (tuple);
|
||||
}
|
||||
|
||||
Symbol functor (void) const { return functor_; }
|
||||
unsigned arity (void) const { return arity_; }
|
||||
unsigned evidence (void) const { return evidence_; }
|
||||
ConstraintTree* constr (void) const { return constr_; }
|
||||
bool isAtom (void) const { return arity_ == 0; }
|
||||
|
||||
void addTuple (const Tuple& t);
|
||||
friend ostream& operator<< (ostream &os, const ObservedFormula opv);
|
||||
private:
|
||||
Symbol functor_;
|
||||
unsigned arity_;
|
||||
unsigned evidence_;
|
||||
ConstraintTree* constr_;
|
||||
};
|
||||
typedef vector<ObservedFormula*> ObservedFormulas;
|
||||
|
||||
|
||||
|
||||
class Substitution
|
||||
{
|
||||
public:
|
||||
void add (LogVar X_old, LogVar X_new)
|
||||
{
|
||||
assert (Util::contains (subs_, X_old) == false);
|
||||
subs_.insert (make_pair (X_old, X_new));
|
||||
}
|
||||
|
||||
void rename (LogVar X_old, LogVar X_new)
|
||||
{
|
||||
assert (subs_.find (X_old) != subs_.end());
|
||||
assert (Util::contains (subs_, X_old));
|
||||
subs_.find (X_old)->second = X_new;
|
||||
}
|
||||
|
||||
LogVar newNameFor (LogVar X) const
|
||||
{
|
||||
assert (subs_.find (X) != subs_.end());
|
||||
assert (Util::contains (subs_, X));
|
||||
return subs_.find (X)->second;
|
||||
}
|
||||
|
||||
LogVars getDiscardedLogVars (void) const;
|
||||
|
||||
friend ostream& operator<< (ostream &os, const Substitution& theta);
|
||||
|
||||
private:
|
||||
unordered_map<LogVar, LogVar> subs_;
|
||||
|
||||
};
|
||||
|
||||
|
||||
|
@ -45,9 +45,8 @@ CWD=$(PWD)
|
||||
|
||||
|
||||
HEADERS = \
|
||||
$(srcdir)/GraphicalModel.h \
|
||||
$(srcdir)/BayesNet.h \
|
||||
$(srcdir)/BayesNode.h \
|
||||
$(srcdir)/BayesBall.h \
|
||||
$(srcdir)/ElimGraph.h \
|
||||
$(srcdir)/FactorGraph.h \
|
||||
$(srcdir)/Factor.h \
|
||||
@ -55,12 +54,10 @@ HEADERS = \
|
||||
$(srcdir)/ConstraintTree.h \
|
||||
$(srcdir)/Solver.h \
|
||||
$(srcdir)/VarElimSolver.h \
|
||||
$(srcdir)/BnBpSolver.h \
|
||||
$(srcdir)/FgBpSolver.h \
|
||||
$(srcdir)/BpSolver.h \
|
||||
$(srcdir)/CbpSolver.h \
|
||||
$(srcdir)/FoveSolver.h \
|
||||
$(srcdir)/VarNode.h \
|
||||
$(srcdir)/Distribution.h \
|
||||
$(srcdir)/Var.h \
|
||||
$(srcdir)/Indexer.h \
|
||||
$(srcdir)/Parfactor.h \
|
||||
$(srcdir)/ProbFormula.h \
|
||||
@ -69,22 +66,20 @@ HEADERS = \
|
||||
$(srcdir)/LiftedUtils.h \
|
||||
$(srcdir)/TinySet.h \
|
||||
$(srcdir)/Util.h \
|
||||
$(srcdir)/Horus.h \
|
||||
$(srcdir)/xmlParser/xmlParser.h
|
||||
$(srcdir)/Horus.h
|
||||
|
||||
CPP_SOURCES = \
|
||||
$(srcdir)/BayesNet.cpp \
|
||||
$(srcdir)/BayesNode.cpp \
|
||||
$(srcdir)/BayesBall.cpp \
|
||||
$(srcdir)/ElimGraph.cpp \
|
||||
$(srcdir)/FactorGraph.cpp \
|
||||
$(srcdir)/Factor.cpp \
|
||||
$(srcdir)/CFactorGraph.cpp \
|
||||
$(srcdir)/ConstraintTree.cpp \
|
||||
$(srcdir)/VarNode.cpp \
|
||||
$(srcdir)/Var.cpp \
|
||||
$(srcdir)/Solver.cpp \
|
||||
$(srcdir)/VarElimSolver.cpp \
|
||||
$(srcdir)/BnBpSolver.cpp \
|
||||
$(srcdir)/FgBpSolver.cpp \
|
||||
$(srcdir)/BpSolver.cpp \
|
||||
$(srcdir)/CbpSolver.cpp \
|
||||
$(srcdir)/FoveSolver.cpp \
|
||||
$(srcdir)/Parfactor.cpp \
|
||||
@ -94,22 +89,20 @@ CPP_SOURCES = \
|
||||
$(srcdir)/LiftedUtils.cpp \
|
||||
$(srcdir)/Util.cpp \
|
||||
$(srcdir)/HorusYap.cpp \
|
||||
$(srcdir)/HorusCli.cpp \
|
||||
$(srcdir)/xmlParser/xmlParser.cpp
|
||||
$(srcdir)/HorusCli.cpp
|
||||
|
||||
OBJS = \
|
||||
BayesNet.o \
|
||||
BayesNode.o \
|
||||
BayesBall.o \
|
||||
ElimGraph.o \
|
||||
FactorGraph.o \
|
||||
Factor.o \
|
||||
CFactorGraph.o \
|
||||
ConstraintTree.o \
|
||||
VarNode.o \
|
||||
Var.o \
|
||||
Solver.o \
|
||||
VarElimSolver.o \
|
||||
BnBpSolver.o \
|
||||
FgBpSolver.o \
|
||||
BpSolver.o \
|
||||
CbpSolver.o \
|
||||
FoveSolver.o \
|
||||
Parfactor.o \
|
||||
@ -122,17 +115,16 @@ OBJS = \
|
||||
|
||||
HCLI_OBJS = \
|
||||
BayesNet.o \
|
||||
BayesNode.o \
|
||||
BayesBall.o \
|
||||
ElimGraph.o \
|
||||
FactorGraph.o \
|
||||
Factor.o \
|
||||
CFactorGraph.o \
|
||||
ConstraintTree.o \
|
||||
VarNode.o \
|
||||
Var.o \
|
||||
Solver.o \
|
||||
VarElimSolver.o \
|
||||
BnBpSolver.o \
|
||||
FgBpSolver.o \
|
||||
BpSolver.o \
|
||||
CbpSolver.o \
|
||||
FoveSolver.o \
|
||||
Parfactor.o \
|
||||
@ -141,7 +133,6 @@ HCLI_OBJS = \
|
||||
ParfactorList.o \
|
||||
LiftedUtils.o \
|
||||
Util.o \
|
||||
xmlParser/xmlParser.o \
|
||||
HorusCli.o
|
||||
|
||||
SOBJS=horus.@SO@
|
||||
@ -154,10 +145,6 @@ all: $(SOBJS) hcli
|
||||
$(CXX) -c $(CXXFLAGS) $< -o $@
|
||||
|
||||
|
||||
xmlParser/xmlParser.o : $(srcdir)/xmlParser/xmlParser.cpp
|
||||
$(CXX) -c $(CXXFLAGS) $< -o $@
|
||||
|
||||
|
||||
@DO_SECOND_LD@horus.@SO@: $(OBJS)
|
||||
@DO_SECOND_LD@ @SHLIB_CXX_LD@ -o horus.@SO@ $(OBJS) @EXTRA_LIBS_FOR_SWIDLLS@
|
||||
|
||||
@ -171,7 +158,7 @@ install: all
|
||||
|
||||
|
||||
clean:
|
||||
rm -f *.o *~ $(OBJS) $(SOBJS) *.BAK hcli xmlParser/*.o
|
||||
rm -f *.o *~ $(OBJS) $(SOBJS) *.BAK hcli
|
||||
|
||||
|
||||
erase_dots:
|
||||
|
@ -2,6 +2,7 @@
|
||||
#include "Parfactor.h"
|
||||
#include "Histogram.h"
|
||||
#include "Indexer.h"
|
||||
#include "Util.h"
|
||||
#include "Horus.h"
|
||||
|
||||
|
||||
@ -11,55 +12,58 @@ Parfactor::Parfactor (
|
||||
const Tuples& tuples,
|
||||
unsigned distId)
|
||||
{
|
||||
formulas_ = formulas;
|
||||
params_ = params;
|
||||
distId_ = distId;
|
||||
args_ = formulas;
|
||||
params_ = params;
|
||||
distId_ = distId;
|
||||
|
||||
LogVars logVars;
|
||||
for (unsigned i = 0; i < formulas_.size(); i++) {
|
||||
ranges_.push_back (formulas_[i].range());
|
||||
const LogVars& lvs = formulas_[i].logVars();
|
||||
for (unsigned i = 0; i < args_.size(); i++) {
|
||||
ranges_.push_back (args_[i].range());
|
||||
const LogVars& lvs = args_[i].logVars();
|
||||
for (unsigned j = 0; j < lvs.size(); j++) {
|
||||
if (std::find (logVars.begin(), logVars.end(), lvs[j]) ==
|
||||
logVars.end()) {
|
||||
if (Util::contains (logVars, lvs[j]) == false) {
|
||||
logVars.push_back (lvs[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
constr_ = new ConstraintTree (logVars, tuples);
|
||||
assert (params_.size() == Util::expectedSize (ranges_));
|
||||
}
|
||||
|
||||
|
||||
|
||||
Parfactor::Parfactor (const Parfactor* g, const Tuple& tuple)
|
||||
{
|
||||
formulas_ = g->formulas();
|
||||
params_ = g->params();
|
||||
ranges_ = g->ranges();
|
||||
distId_ = g->distId();
|
||||
constr_ = new ConstraintTree (g->logVars(), {tuple});
|
||||
args_ = g->arguments();
|
||||
params_ = g->params();
|
||||
ranges_ = g->ranges();
|
||||
distId_ = g->distId();
|
||||
constr_ = new ConstraintTree (g->logVars(), {tuple});
|
||||
assert (params_.size() == Util::expectedSize (ranges_));
|
||||
}
|
||||
|
||||
|
||||
|
||||
Parfactor::Parfactor (const Parfactor* g, ConstraintTree* constr)
|
||||
{
|
||||
formulas_ = g->formulas();
|
||||
params_ = g->params();
|
||||
ranges_ = g->ranges();
|
||||
distId_ = g->distId();
|
||||
constr_ = constr;
|
||||
args_ = g->arguments();
|
||||
params_ = g->params();
|
||||
ranges_ = g->ranges();
|
||||
distId_ = g->distId();
|
||||
constr_ = constr;
|
||||
assert (params_.size() == Util::expectedSize (ranges_));
|
||||
}
|
||||
|
||||
|
||||
|
||||
Parfactor::Parfactor (const Parfactor& g)
|
||||
{
|
||||
formulas_ = g.formulas();
|
||||
params_ = g.params();
|
||||
ranges_ = g.ranges();
|
||||
distId_ = g.distId();
|
||||
constr_ = new ConstraintTree (*g.constr());
|
||||
args_ = g.arguments();
|
||||
params_ = g.params();
|
||||
ranges_ = g.ranges();
|
||||
distId_ = g.distId();
|
||||
constr_ = new ConstraintTree (*g.constr());
|
||||
assert (params_.size() == Util::expectedSize (ranges_));
|
||||
}
|
||||
|
||||
|
||||
@ -75,9 +79,9 @@ LogVarSet
|
||||
Parfactor::countedLogVars (void) const
|
||||
{
|
||||
LogVarSet set;
|
||||
for (unsigned i = 0; i < formulas_.size(); i++) {
|
||||
if (formulas_[i].isCounting()) {
|
||||
set.insert (formulas_[i].countedLogVar());
|
||||
for (unsigned i = 0; i < args_.size(); i++) {
|
||||
if (args_[i].isCounting()) {
|
||||
set.insert (args_[i].countedLogVar());
|
||||
}
|
||||
}
|
||||
return set;
|
||||
@ -107,14 +111,14 @@ Parfactor::elimLogVars (void) const
|
||||
LogVarSet
|
||||
Parfactor::exclusiveLogVars (unsigned fIdx) const
|
||||
{
|
||||
assert (fIdx < formulas_.size());
|
||||
assert (fIdx < args_.size());
|
||||
LogVarSet remaining;
|
||||
for (unsigned i = 0; i < formulas_.size(); i++) {
|
||||
for (unsigned i = 0; i < args_.size(); i++) {
|
||||
if (i != fIdx) {
|
||||
remaining |= formulas_[i].logVarSet();
|
||||
remaining |= args_[i].logVarSet();
|
||||
}
|
||||
}
|
||||
return formulas_[fIdx].logVarSet() - remaining;
|
||||
return args_[fIdx].logVarSet() - remaining;
|
||||
}
|
||||
|
||||
|
||||
@ -131,44 +135,51 @@ Parfactor::setConstraintTree (ConstraintTree* newTree)
|
||||
void
|
||||
Parfactor::sumOut (unsigned fIdx)
|
||||
{
|
||||
assert (fIdx < formulas_.size());
|
||||
assert (formulas_[fIdx].contains (elimLogVars()));
|
||||
assert (fIdx < args_.size());
|
||||
assert (args_[fIdx].contains (elimLogVars()));
|
||||
|
||||
LogVarSet excl = exclusiveLogVars (fIdx);
|
||||
unsigned condCount = constr_->getConditionalCount (excl);
|
||||
Util::pow (params_, condCount);
|
||||
if (args_[fIdx].isCounting()) {
|
||||
LogAware::pow (params_, constr_->getConditionalCount (
|
||||
excl - args_[fIdx].countedLogVar()));
|
||||
} else {
|
||||
LogAware::pow (params_, constr_->getConditionalCount (excl));
|
||||
}
|
||||
|
||||
vector<unsigned> numAssigns (ranges_[fIdx], 1);
|
||||
if (formulas_[fIdx].isCounting()) {
|
||||
if (args_[fIdx].isCounting()) {
|
||||
unsigned N = constr_->getConditionalCount (
|
||||
formulas_[fIdx].countedLogVar());
|
||||
unsigned R = formulas_[fIdx].range();
|
||||
unsigned H = ranges_[fIdx];
|
||||
HistogramSet hs (N, R);
|
||||
unsigned N_factorial = Util::factorial (N);
|
||||
for (unsigned h = 0; h < H; h++) {
|
||||
unsigned prod = 1;
|
||||
for (unsigned r = 0; r < R; r++) {
|
||||
prod *= Util::factorial (hs[r]);
|
||||
args_[fIdx].countedLogVar());
|
||||
unsigned R = args_[fIdx].range();
|
||||
vector<double> numAssigns = HistogramSet::getNumAssigns (N, R);
|
||||
StatesIndexer sindexer (ranges_, fIdx);
|
||||
while (sindexer.valid()) {
|
||||
unsigned h = sindexer[fIdx];
|
||||
if (Globals::logDomain) {
|
||||
params_[sindexer] += numAssigns[h];
|
||||
} else {
|
||||
params_[sindexer] *= numAssigns[h];
|
||||
}
|
||||
numAssigns[h] = N_factorial / prod;
|
||||
hs.nextHistogram();
|
||||
++ sindexer;
|
||||
}
|
||||
cout << endl;
|
||||
}
|
||||
|
||||
Params copy = params_;
|
||||
params_.clear();
|
||||
params_.resize (copy.size() / ranges_[fIdx], 0.0);
|
||||
|
||||
params_.resize (copy.size() / ranges_[fIdx], LogAware::addIdenty());
|
||||
MapIndexer indexer (ranges_, fIdx);
|
||||
for (unsigned i = 0; i < copy.size(); i++) {
|
||||
unsigned h = indexer[fIdx];
|
||||
// TODO NOT LOG DOMAIN AWARE :(
|
||||
params_[indexer] += numAssigns[h] * copy[i];
|
||||
++ indexer;
|
||||
if (Globals::logDomain) {
|
||||
for (unsigned i = 0; i < copy.size(); i++) {
|
||||
params_[indexer] = Util::logSum (params_[indexer], copy[i]);
|
||||
++ indexer;
|
||||
}
|
||||
} else {
|
||||
for (unsigned i = 0; i < copy.size(); i++) {
|
||||
params_[indexer] += copy[i];
|
||||
++ indexer;
|
||||
}
|
||||
}
|
||||
formulas_.erase (formulas_.begin() + fIdx);
|
||||
|
||||
args_.erase (args_.begin() + fIdx);
|
||||
ranges_.erase (ranges_.begin() + fIdx);
|
||||
constr_->remove (excl);
|
||||
}
|
||||
@ -179,55 +190,7 @@ void
|
||||
Parfactor::multiply (Parfactor& g)
|
||||
{
|
||||
alignAndExponentiate (this, &g);
|
||||
bool sharedVars = false;
|
||||
vector<unsigned> g_varpos;
|
||||
const ProbFormulas& g_formulas = g.formulas();
|
||||
const Params& g_params = g.params();
|
||||
const Ranges& g_ranges = g.ranges();
|
||||
|
||||
for (unsigned i = 0; i < g_formulas.size(); i++) {
|
||||
int group = g_formulas[i].group();
|
||||
if (indexOfFormulaWithGroup (group) == -1) {
|
||||
insertDimension (g.ranges()[i]);
|
||||
formulas_.push_back (g_formulas[i]);
|
||||
g_varpos.push_back (formulas_.size() - 1);
|
||||
} else {
|
||||
sharedVars = true;
|
||||
g_varpos.push_back (indexOfFormulaWithGroup (group));
|
||||
}
|
||||
}
|
||||
|
||||
if (sharedVars == false) {
|
||||
unsigned count = 0;
|
||||
for (unsigned i = 0; i < params_.size(); i++) {
|
||||
if (Globals::logDomain) {
|
||||
params_[i] += g_params[count];
|
||||
} else {
|
||||
params_[i] *= g_params[count];
|
||||
}
|
||||
count ++;
|
||||
if (count >= g_params.size()) {
|
||||
count = 0;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
StatesIndexer indexer (ranges_, false);
|
||||
while (indexer.valid()) {
|
||||
unsigned g_li = 0;
|
||||
unsigned prod = 1;
|
||||
for (int j = g_varpos.size() - 1; j >= 0; j--) {
|
||||
g_li += indexer[g_varpos[j]] * prod;
|
||||
prod *= g_ranges[j];
|
||||
}
|
||||
if (Globals::logDomain) {
|
||||
params_[indexer] += g_params[g_li];
|
||||
} else {
|
||||
params_[indexer] *= g_params[g_li];
|
||||
}
|
||||
++ indexer;
|
||||
}
|
||||
}
|
||||
|
||||
TFactor<ProbFormula>::multiply (g);
|
||||
constr_->join (g.constr(), true);
|
||||
}
|
||||
|
||||
@ -236,7 +199,7 @@ Parfactor::multiply (Parfactor& g)
|
||||
void
|
||||
Parfactor::countConvert (LogVar X)
|
||||
{
|
||||
int fIdx = indexOfFormulaWithLogVar (X);
|
||||
int fIdx = indexOfLogVar (X);
|
||||
assert (fIdx != -1);
|
||||
assert (constr_->isCountNormalized (X));
|
||||
assert (constr_->getConditionalCount (X) > 1);
|
||||
@ -248,12 +211,12 @@ Parfactor::countConvert (LogVar X)
|
||||
vector<Histogram> histograms = HistogramSet::getHistograms (N, R);
|
||||
|
||||
StatesIndexer indexer (ranges_);
|
||||
vector<Params> summout (params_.size() / R);
|
||||
vector<Params> sumout (params_.size() / R);
|
||||
unsigned count = 0;
|
||||
while (indexer.valid()) {
|
||||
summout[count].reserve (R);
|
||||
sumout[count].reserve (R);
|
||||
for (unsigned r = 0; r < R; r++) {
|
||||
summout[count].push_back (params_[indexer]);
|
||||
sumout[count].push_back (params_[indexer]);
|
||||
indexer.increment (fIdx);
|
||||
}
|
||||
count ++;
|
||||
@ -262,45 +225,42 @@ Parfactor::countConvert (LogVar X)
|
||||
}
|
||||
|
||||
params_.clear();
|
||||
params_.reserve (summout.size() * H);
|
||||
params_.reserve (sumout.size() * H);
|
||||
|
||||
vector<bool> mapDims (ranges_.size(), true);
|
||||
ranges_[fIdx] = H;
|
||||
mapDims[fIdx] = false;
|
||||
MapIndexer mapIndexer (ranges_, mapDims);
|
||||
MapIndexer mapIndexer (ranges_, fIdx);
|
||||
while (mapIndexer.valid()) {
|
||||
double prod = 1.0;
|
||||
double prod = LogAware::multIdenty();
|
||||
unsigned i = mapIndexer.mappedIndex();
|
||||
unsigned h = mapIndexer[fIdx];
|
||||
for (unsigned r = 0; r < R; r++) {
|
||||
// TODO not log domain aware
|
||||
prod *= Util::pow (summout[i][r], histograms[h][r]);
|
||||
if (Globals::logDomain) {
|
||||
prod += LogAware::pow (sumout[i][r], histograms[h][r]);
|
||||
} else {
|
||||
prod *= LogAware::pow (sumout[i][r], histograms[h][r]);
|
||||
}
|
||||
}
|
||||
params_.push_back (prod);
|
||||
++ mapIndexer;
|
||||
}
|
||||
formulas_[fIdx].setCountedLogVar (X);
|
||||
args_[fIdx].setCountedLogVar (X);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Parfactor::expandPotential (
|
||||
LogVar X,
|
||||
LogVar X_new1,
|
||||
LogVar X_new2)
|
||||
Parfactor::expand (LogVar X, LogVar X_new1, LogVar X_new2)
|
||||
{
|
||||
int fIdx = indexOfFormulaWithLogVar (X);
|
||||
int fIdx = indexOfLogVar (X);
|
||||
assert (fIdx != -1);
|
||||
assert (formulas_[fIdx].isCounting());
|
||||
assert (args_[fIdx].isCounting());
|
||||
|
||||
unsigned N1 = constr_->getConditionalCount (X_new1);
|
||||
unsigned N2 = constr_->getConditionalCount (X_new2);
|
||||
unsigned N = N1 + N2;
|
||||
unsigned R = formulas_[fIdx].range();
|
||||
unsigned R = args_[fIdx].range();
|
||||
unsigned H1 = HistogramSet::nrHistograms (N1, R);
|
||||
unsigned H2 = HistogramSet::nrHistograms (N2, R);
|
||||
unsigned H = ranges_[fIdx];
|
||||
|
||||
vector<Histogram> histograms = HistogramSet::getHistograms (N, R);
|
||||
vector<Histogram> histograms1 = HistogramSet::getHistograms (N1, R);
|
||||
@ -320,48 +280,11 @@ Parfactor::expandPotential (
|
||||
}
|
||||
}
|
||||
|
||||
unsigned size = (params_.size() / H) * H1 * H2;
|
||||
Params copy = params_;
|
||||
params_.clear();
|
||||
params_.reserve (size);
|
||||
expandPotential (fIdx, H1 * H2, sumIndexes);
|
||||
|
||||
unsigned prod = 1;
|
||||
vector<unsigned> offsets_ (ranges_.size());
|
||||
for (int i = ranges_.size() - 1; i >= 0; i--) {
|
||||
offsets_[i] = prod;
|
||||
prod *= ranges_[i];
|
||||
}
|
||||
|
||||
unsigned index = 0;
|
||||
ranges_[fIdx] = H1 * H2;
|
||||
vector<unsigned> indices (ranges_.size(), 0);
|
||||
for (unsigned k = 0; k < size; k++) {
|
||||
params_.push_back (copy[index]);
|
||||
for (int i = ranges_.size() - 1; i >= 0; i--) {
|
||||
indices[i] ++;
|
||||
if (i == fIdx) {
|
||||
int diff = sumIndexes[indices[i]] - sumIndexes[indices[i] - 1];
|
||||
index += diff * offsets_[i];
|
||||
} else {
|
||||
index += offsets_[i];
|
||||
}
|
||||
if (indices[i] != ranges_[i]) {
|
||||
break;
|
||||
} else {
|
||||
if (i == fIdx) {
|
||||
int diff = sumIndexes[0] - sumIndexes[indices[i]];
|
||||
index += diff * offsets_[i];
|
||||
} else {
|
||||
index -= offsets_[i] * ranges_[i];
|
||||
}
|
||||
indices[i] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
formulas_.insert (formulas_.begin() + fIdx + 1, formulas_[fIdx]);
|
||||
formulas_[fIdx].rename (X, X_new1);
|
||||
formulas_[fIdx + 1].rename (X, X_new2);
|
||||
args_.insert (args_.begin() + fIdx + 1, args_[fIdx]);
|
||||
args_[fIdx].rename (X, X_new1);
|
||||
args_[fIdx + 1].rename (X, X_new2);
|
||||
ranges_.insert (ranges_.begin() + fIdx + 1, H2);
|
||||
ranges_[fIdx] = H1;
|
||||
}
|
||||
@ -371,13 +294,12 @@ Parfactor::expandPotential (
|
||||
void
|
||||
Parfactor::fullExpand (LogVar X)
|
||||
{
|
||||
int fIdx = indexOfFormulaWithLogVar (X);
|
||||
int fIdx = indexOfLogVar (X);
|
||||
assert (fIdx != -1);
|
||||
assert (formulas_[fIdx].isCounting());
|
||||
assert (args_[fIdx].isCounting());
|
||||
|
||||
unsigned N = constr_->getConditionalCount (X);
|
||||
unsigned R = formulas_[fIdx].range();
|
||||
unsigned H = ranges_[fIdx];
|
||||
unsigned R = args_[fIdx].range();
|
||||
|
||||
vector<Histogram> originHists = HistogramSet::getHistograms (N, R);
|
||||
vector<Histogram> expandHists = HistogramSet::getHistograms (1, R);
|
||||
@ -400,54 +322,17 @@ Parfactor::fullExpand (LogVar X)
|
||||
++ indexer;
|
||||
}
|
||||
|
||||
unsigned size = (params_.size() / H) * std::pow (R, N);
|
||||
Params copy = params_;
|
||||
params_.clear();
|
||||
params_.reserve (size);
|
||||
expandPotential (fIdx, std::pow (R, N), sumIndexes);
|
||||
|
||||
unsigned prod = 1;
|
||||
vector<unsigned> offsets_ (ranges_.size());
|
||||
for (int i = ranges_.size() - 1; i >= 0; i--) {
|
||||
offsets_[i] = prod;
|
||||
prod *= ranges_[i];
|
||||
}
|
||||
|
||||
unsigned index = 0;
|
||||
ranges_[fIdx] = std::pow (R, N);
|
||||
vector<unsigned> indices (ranges_.size(), 0);
|
||||
for (unsigned k = 0; k < size; k++) {
|
||||
params_.push_back (copy[index]);
|
||||
for (int i = ranges_.size() - 1; i >= 0; i--) {
|
||||
indices[i] ++;
|
||||
if (i == fIdx) {
|
||||
int diff = sumIndexes[indices[i]] - sumIndexes[indices[i] - 1];
|
||||
index += diff * offsets_[i];
|
||||
} else {
|
||||
index += offsets_[i];
|
||||
}
|
||||
if (indices[i] != ranges_[i]) {
|
||||
break;
|
||||
} else {
|
||||
if (i == fIdx) {
|
||||
int diff = sumIndexes[0] - sumIndexes[indices[i]];
|
||||
index += diff * offsets_[i];
|
||||
} else {
|
||||
index -= offsets_[i] * ranges_[i];
|
||||
}
|
||||
indices[i] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ProbFormula f = formulas_[fIdx];
|
||||
formulas_.erase (formulas_.begin() + fIdx);
|
||||
ProbFormula f = args_[fIdx];
|
||||
args_.erase (args_.begin() + fIdx);
|
||||
ranges_.erase (ranges_.begin() + fIdx);
|
||||
LogVars newLvs = constr_->expand (X);
|
||||
assert (newLvs.size() == N);
|
||||
for (unsigned i = 0 ; i < N; i++) {
|
||||
ProbFormula newFormula (f.functor(), f.logVars(), f.range());
|
||||
newFormula.rename (X, newLvs[i]);
|
||||
formulas_.insert (formulas_.begin() + fIdx + i, newFormula);
|
||||
args_.insert (args_.begin() + fIdx + i, newFormula);
|
||||
ranges_.insert (ranges_.begin() + fIdx + i, R);
|
||||
}
|
||||
}
|
||||
@ -459,117 +344,43 @@ Parfactor::reorderAccordingGrounds (const Grounds& grounds)
|
||||
{
|
||||
ProbFormulas newFormulas;
|
||||
for (unsigned i = 0; i < grounds.size(); i++) {
|
||||
for (unsigned j = 0; j < formulas_.size(); j++) {
|
||||
if (grounds[i].functor() == formulas_[j].functor() &&
|
||||
grounds[i].arity() == formulas_[j].arity()) {
|
||||
constr_->moveToTop (formulas_[j].logVars());
|
||||
for (unsigned j = 0; j < args_.size(); j++) {
|
||||
if (grounds[i].functor() == args_[j].functor() &&
|
||||
grounds[i].arity() == args_[j].arity()) {
|
||||
constr_->moveToTop (args_[j].logVars());
|
||||
if (constr_->containsTuple (grounds[i].args())) {
|
||||
newFormulas.push_back (formulas_[j]);
|
||||
newFormulas.push_back (args_[j]);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
assert (newFormulas.size() == i + 1);
|
||||
}
|
||||
reorderFormulas (newFormulas);
|
||||
reorderArguments (newFormulas);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Parfactor::reorderFormulas (const ProbFormulas& newFormulas)
|
||||
{
|
||||
assert (newFormulas.size() == formulas_.size());
|
||||
if (newFormulas == formulas_) {
|
||||
return;
|
||||
}
|
||||
|
||||
Ranges newRanges;
|
||||
vector<unsigned> positions;
|
||||
for (unsigned i = 0; i < newFormulas.size(); i++) {
|
||||
unsigned idx = indexOf (newFormulas[i]);
|
||||
newRanges.push_back (ranges_[idx]);
|
||||
positions.push_back (idx);
|
||||
}
|
||||
|
||||
unsigned N = ranges_.size();
|
||||
Params newParams (params_.size());
|
||||
for (unsigned i = 0; i < params_.size(); i++) {
|
||||
unsigned li = i;
|
||||
// calculate vector index corresponding to linear index
|
||||
vector<unsigned> vi (N);
|
||||
for (int k = N-1; k >= 0; k--) {
|
||||
vi[k] = li % ranges_[k];
|
||||
li /= ranges_[k];
|
||||
}
|
||||
// convert permuted vector index to corresponding linear index
|
||||
unsigned prod = 1;
|
||||
unsigned new_li = 0;
|
||||
for (int k = N - 1; k >= 0; k--) {
|
||||
new_li += vi[positions[k]] * prod;
|
||||
prod *= ranges_[positions[k]];
|
||||
}
|
||||
newParams[new_li] = params_[i];
|
||||
}
|
||||
formulas_ = newFormulas;
|
||||
ranges_ = newRanges;
|
||||
params_ = newParams;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Parfactor::absorveEvidence (unsigned fIdx, unsigned evidence)
|
||||
Parfactor::absorveEvidence (const ProbFormula& formula, unsigned evidence)
|
||||
{
|
||||
int fIdx = indexOf (formula);
|
||||
assert (fIdx != -1);
|
||||
LogVarSet excl = exclusiveLogVars (fIdx);
|
||||
assert (fIdx < formulas_.size());
|
||||
assert (evidence < formulas_[fIdx].range());
|
||||
assert (formulas_[fIdx].isCounting() == false);
|
||||
assert (args_[fIdx].isCounting() == false);
|
||||
assert (constr_->isCountNormalized (excl));
|
||||
|
||||
Util::pow (params_, constr_->getConditionalCount (excl));
|
||||
|
||||
Params copy = params_;
|
||||
params_.clear();
|
||||
params_.reserve (copy.size() / formulas_[fIdx].range());
|
||||
|
||||
StatesIndexer indexer (ranges_);
|
||||
for (unsigned i = 0; i < evidence; i++) {
|
||||
indexer.increment (fIdx);
|
||||
}
|
||||
while (indexer.valid()) {
|
||||
params_.push_back (copy[indexer]);
|
||||
indexer.incrementExcluding (fIdx);
|
||||
}
|
||||
formulas_.erase (formulas_.begin() + fIdx);
|
||||
ranges_.erase (ranges_.begin() + fIdx);
|
||||
LogAware::pow (params_, constr_->getConditionalCount (excl));
|
||||
TFactor<ProbFormula>::absorveEvidence (formula, evidence);
|
||||
constr_->remove (excl);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Parfactor::normalize (void)
|
||||
{
|
||||
Util::normalize (params_);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Parfactor::setFormulaGroup (const ProbFormula& f, int group)
|
||||
{
|
||||
assert (indexOf (f) != -1);
|
||||
formulas_[indexOf (f)].setGroup (group);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Parfactor::setNewGroups (void)
|
||||
{
|
||||
for (unsigned i = 0; i < formulas_.size(); i++) {
|
||||
formulas_[i].setGroup (ProbFormula::getNewGroup());
|
||||
for (unsigned i = 0; i < args_.size(); i++) {
|
||||
args_[i].setGroup (ProbFormula::getNewGroup());
|
||||
}
|
||||
}
|
||||
|
||||
@ -578,14 +389,14 @@ Parfactor::setNewGroups (void)
|
||||
void
|
||||
Parfactor::applySubstitution (const Substitution& theta)
|
||||
{
|
||||
for (unsigned i = 0; i < formulas_.size(); i++) {
|
||||
LogVars& lvs = formulas_[i].logVars();
|
||||
for (unsigned i = 0; i < args_.size(); i++) {
|
||||
LogVars& lvs = args_[i].logVars();
|
||||
for (unsigned j = 0; j < lvs.size(); j++) {
|
||||
lvs[j] = theta.newNameFor (lvs[j]);
|
||||
}
|
||||
if (formulas_[i].isCounting()) {
|
||||
LogVar clv = formulas_[i].countedLogVar();
|
||||
formulas_[i].setCountedLogVar (theta.newNameFor (clv));
|
||||
if (args_[i].isCounting()) {
|
||||
LogVar clv = args_[i].countedLogVar();
|
||||
args_[i].setCountedLogVar (theta.newNameFor (clv));
|
||||
}
|
||||
}
|
||||
constr_->applySubstitution (theta);
|
||||
@ -593,19 +404,29 @@ Parfactor::applySubstitution (const Substitution& theta)
|
||||
|
||||
|
||||
|
||||
bool
|
||||
Parfactor::containsGround (const Ground& ground) const
|
||||
int
|
||||
Parfactor::findGroup (const Ground& ground) const
|
||||
{
|
||||
for (unsigned i = 0; i < formulas_.size(); i++) {
|
||||
if (formulas_[i].functor() == ground.functor() &&
|
||||
formulas_[i].arity() == ground.arity()) {
|
||||
constr_->moveToTop (formulas_[i].logVars());
|
||||
int group = -1;
|
||||
for (unsigned i = 0; i < args_.size(); i++) {
|
||||
if (args_[i].functor() == ground.functor() &&
|
||||
args_[i].arity() == ground.arity()) {
|
||||
constr_->moveToTop (args_[i].logVars());
|
||||
if (constr_->containsTuple (ground.args())) {
|
||||
return true;
|
||||
group = args_[i].group();
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
return group;
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
Parfactor::containsGround (const Ground& ground) const
|
||||
{
|
||||
return findGroup (ground) != -1;
|
||||
}
|
||||
|
||||
|
||||
@ -613,8 +434,8 @@ Parfactor::containsGround (const Ground& ground) const
|
||||
bool
|
||||
Parfactor::containsGroup (unsigned group) const
|
||||
{
|
||||
for (unsigned i = 0; i < formulas_.size(); i++) {
|
||||
if (formulas_[i].group() == group) {
|
||||
for (unsigned i = 0; i < args_.size(); i++) {
|
||||
if (args_[i].group() == group) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
@ -623,30 +444,12 @@ Parfactor::containsGroup (unsigned group) const
|
||||
|
||||
|
||||
|
||||
const ProbFormula&
|
||||
Parfactor::formula (unsigned fIdx) const
|
||||
{
|
||||
assert (fIdx < formulas_.size());
|
||||
return formulas_[fIdx];
|
||||
}
|
||||
|
||||
|
||||
|
||||
unsigned
|
||||
Parfactor::range (unsigned fIdx) const
|
||||
{
|
||||
assert (fIdx < ranges_.size());
|
||||
return ranges_[fIdx];
|
||||
}
|
||||
|
||||
|
||||
|
||||
unsigned
|
||||
Parfactor::nrFormulas (LogVar X) const
|
||||
{
|
||||
unsigned count = 0;
|
||||
for (unsigned i = 0; i < formulas_.size(); i++) {
|
||||
if (formulas_[i].contains (X)) {
|
||||
for (unsigned i = 0; i < args_.size(); i++) {
|
||||
if (args_[i].contains (X)) {
|
||||
count ++;
|
||||
}
|
||||
}
|
||||
@ -656,27 +459,12 @@ Parfactor::nrFormulas (LogVar X) const
|
||||
|
||||
|
||||
int
|
||||
Parfactor::indexOf (const ProbFormula& f) const
|
||||
{
|
||||
int idx = -1;
|
||||
for (unsigned i = 0; i < formulas_.size(); i++) {
|
||||
if (f == formulas_[i]) {
|
||||
idx = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
return idx;
|
||||
}
|
||||
|
||||
|
||||
|
||||
int
|
||||
Parfactor::indexOfFormulaWithLogVar (LogVar X) const
|
||||
Parfactor::indexOfLogVar (LogVar X) const
|
||||
{
|
||||
int idx = -1;
|
||||
assert (nrFormulas (X) == 1);
|
||||
for (unsigned i = 0; i < formulas_.size(); i++) {
|
||||
if (formulas_[i].contains (X)) {
|
||||
for (unsigned i = 0; i < args_.size(); i++) {
|
||||
if (args_[i].contains (X)) {
|
||||
idx = i;
|
||||
break;
|
||||
}
|
||||
@ -687,11 +475,11 @@ Parfactor::indexOfFormulaWithLogVar (LogVar X) const
|
||||
|
||||
|
||||
int
|
||||
Parfactor::indexOfFormulaWithGroup (unsigned group) const
|
||||
Parfactor::indexOfGroup (unsigned group) const
|
||||
{
|
||||
int pos = -1;
|
||||
for (unsigned i = 0; i < formulas_.size(); i++) {
|
||||
if (formulas_[i].group() == group) {
|
||||
for (unsigned i = 0; i < args_.size(); i++) {
|
||||
if (args_[i].group() == group) {
|
||||
pos = i;
|
||||
break;
|
||||
}
|
||||
@ -704,9 +492,9 @@ Parfactor::indexOfFormulaWithGroup (unsigned group) const
|
||||
vector<unsigned>
|
||||
Parfactor::getAllGroups (void) const
|
||||
{
|
||||
vector<unsigned> groups (formulas_.size());
|
||||
for (unsigned i = 0; i < formulas_.size(); i++) {
|
||||
groups[i] = formulas_[i].group();
|
||||
vector<unsigned> groups (args_.size());
|
||||
for (unsigned i = 0; i < args_.size(); i++) {
|
||||
groups[i] = args_[i].group();
|
||||
}
|
||||
return groups;
|
||||
}
|
||||
@ -714,13 +502,13 @@ Parfactor::getAllGroups (void) const
|
||||
|
||||
|
||||
string
|
||||
Parfactor::getHeaderString (void) const
|
||||
Parfactor::getLabel (void) const
|
||||
{
|
||||
stringstream ss;
|
||||
ss << "phi(" ;
|
||||
for (unsigned i = 0; i < formulas_.size(); i++) {
|
||||
for (unsigned i = 0; i < args_.size(); i++) {
|
||||
if (i != 0) ss << "," ;
|
||||
ss << formulas_[i];
|
||||
ss << args_[i];
|
||||
}
|
||||
ss << ")" ;
|
||||
ConstraintTree copy (*constr_);
|
||||
@ -735,32 +523,37 @@ void
|
||||
Parfactor::print (bool printParams) const
|
||||
{
|
||||
cout << "Formulas: " ;
|
||||
for (unsigned i = 0; i < formulas_.size(); i++) {
|
||||
for (unsigned i = 0; i < args_.size(); i++) {
|
||||
if (i != 0) cout << ", " ;
|
||||
cout << formulas_[i];
|
||||
cout << args_[i];
|
||||
}
|
||||
cout << endl;
|
||||
vector<string> groups;
|
||||
for (unsigned i = 0; i < formulas_.size(); i++) {
|
||||
groups.push_back (string ("g") + Util::toString (formulas_[i].group()));
|
||||
if (args_[0].group() != Util::maxUnsigned()) {
|
||||
vector<string> groups;
|
||||
for (unsigned i = 0; i < args_.size(); i++) {
|
||||
groups.push_back (string ("g") + Util::toString (args_[i].group()));
|
||||
}
|
||||
cout << "Groups: " << groups << endl;
|
||||
}
|
||||
cout << "Groups: " << groups << endl;
|
||||
cout << "LogVars: " << constr_->logVars() << endl;
|
||||
cout << "LogVars: " << constr_->logVarSet() << endl;
|
||||
cout << "Ranges: " << ranges_ << endl;
|
||||
if (printParams == false) {
|
||||
cout << "Params: " << params_ << endl;
|
||||
}
|
||||
cout << "Tuples: " << constr_->tupleSet() << endl;
|
||||
ConstraintTree copy (*constr_);
|
||||
copy.moveToTop (copy.logVarSet().elements());
|
||||
cout << "Tuples: " << copy.tupleSet() << endl;
|
||||
if (printParams) {
|
||||
vector<string> jointStrings;
|
||||
StatesIndexer indexer (ranges_);
|
||||
while (indexer.valid()) {
|
||||
stringstream ss;
|
||||
for (unsigned i = 0; i < formulas_.size(); i++) {
|
||||
for (unsigned i = 0; i < args_.size(); i++) {
|
||||
if (i != 0) ss << ", " ;
|
||||
if (formulas_[i].isCounting()) {
|
||||
unsigned N = constr_->getConditionalCount (formulas_[i].countedLogVar());
|
||||
HistogramSet hs (N, formulas_[i].range());
|
||||
if (args_[i].isCounting()) {
|
||||
unsigned N = constr_->getConditionalCount (
|
||||
args_[i].countedLogVar());
|
||||
HistogramSet hs (N, args_[i].range());
|
||||
unsigned c = 0;
|
||||
while (c < indexer[i]) {
|
||||
hs.nextHistogram();
|
||||
@ -779,22 +572,56 @@ Parfactor::print (bool printParams) const
|
||||
cout << " = " << params_[i] << endl;
|
||||
}
|
||||
}
|
||||
cout << endl;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Parfactor::insertDimension (unsigned range)
|
||||
Parfactor::expandPotential (
|
||||
int fIdx,
|
||||
unsigned newRange,
|
||||
const vector<unsigned>& sumIndexes)
|
||||
{
|
||||
unsigned size = (params_.size() / ranges_[fIdx]) * newRange;
|
||||
Params copy = params_;
|
||||
params_.clear();
|
||||
params_.reserve (copy.size() * range);
|
||||
for (unsigned i = 0; i < copy.size(); i++) {
|
||||
for (unsigned reps = 0; reps < range; reps++) {
|
||||
params_.push_back (copy[i]);
|
||||
params_.reserve (size);
|
||||
|
||||
unsigned prod = 1;
|
||||
vector<unsigned> offsets_ (ranges_.size());
|
||||
for (int i = ranges_.size() - 1; i >= 0; i--) {
|
||||
offsets_[i] = prod;
|
||||
prod *= ranges_[i];
|
||||
}
|
||||
|
||||
unsigned index = 0;
|
||||
ranges_[fIdx] = newRange;
|
||||
vector<unsigned> indices (ranges_.size(), 0);
|
||||
for (unsigned k = 0; k < size; k++) {
|
||||
params_.push_back (copy[index]);
|
||||
for (int i = ranges_.size() - 1; i >= 0; i--) {
|
||||
indices[i] ++;
|
||||
if (i == fIdx) {
|
||||
assert (indices[i] - 1 < sumIndexes.size());
|
||||
int diff = sumIndexes[indices[i]] - sumIndexes[indices[i] - 1];
|
||||
index += diff * offsets_[i];
|
||||
} else {
|
||||
index += offsets_[i];
|
||||
}
|
||||
if (indices[i] != ranges_[i]) {
|
||||
break;
|
||||
} else {
|
||||
if (i == fIdx) {
|
||||
int diff = sumIndexes[0] - sumIndexes[indices[i]];
|
||||
index += diff * offsets_[i];
|
||||
} else {
|
||||
index -= offsets_[i] * ranges_[i];
|
||||
}
|
||||
indices[i] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
ranges_.push_back (range);
|
||||
}
|
||||
|
||||
|
||||
@ -803,29 +630,27 @@ void
|
||||
Parfactor::alignAndExponentiate (Parfactor* g1, Parfactor* g2)
|
||||
{
|
||||
LogVars X_1, X_2;
|
||||
const ProbFormulas& formulas1 = g1->formulas();
|
||||
const ProbFormulas& formulas2 = g2->formulas();
|
||||
const ProbFormulas& formulas1 = g1->arguments();
|
||||
const ProbFormulas& formulas2 = g2->arguments();
|
||||
for (unsigned i = 0; i < formulas1.size(); i++) {
|
||||
for (unsigned j = 0; j < formulas2.size(); j++) {
|
||||
if (formulas1[i].group() == formulas2[j].group()) {
|
||||
X_1.insert (X_1.end(),
|
||||
formulas1[i].logVars().begin(),
|
||||
formulas1[i].logVars().end());
|
||||
X_2.insert (X_2.end(),
|
||||
formulas2[j].logVars().begin(),
|
||||
formulas2[j].logVars().end());
|
||||
Util::addToVector (X_1, formulas1[i].logVars());
|
||||
Util::addToVector (X_2, formulas2[j].logVars());
|
||||
}
|
||||
}
|
||||
}
|
||||
align (g1, X_1, g2, X_2);
|
||||
LogVarSet Y_1 = g1->logVarSet() - LogVarSet (X_1);
|
||||
LogVarSet Y_2 = g2->logVarSet() - LogVarSet (X_2);
|
||||
assert (g1->constr()->isCountNormalized (Y_1));
|
||||
assert (g2->constr()->isCountNormalized (Y_2));
|
||||
unsigned condCount1 = g1->constr()->getConditionalCount (Y_1);
|
||||
unsigned condCount2 = g2->constr()->getConditionalCount (Y_2);
|
||||
Util::pow (g1->params(), 1.0 / condCount2);
|
||||
Util::pow (g2->params(), 1.0 / condCount1);
|
||||
LogAware::pow (g1->params(), 1.0 / condCount2);
|
||||
LogAware::pow (g2->params(), 1.0 / condCount1);
|
||||
// this must be done in the end or else X_1 and X_2
|
||||
// will refer the old log var names in the code above
|
||||
align (g1, X_1, g2, X_2);
|
||||
}
|
||||
|
||||
|
||||
@ -838,7 +663,6 @@ Parfactor::align (
|
||||
LogVar freeLogVar = 0;
|
||||
Substitution theta1;
|
||||
Substitution theta2;
|
||||
|
||||
const LogVarSet& allLvs1 = g1->logVarSet();
|
||||
for (unsigned i = 0; i < allLvs1.size(); i++) {
|
||||
theta1.add (allLvs1[i], freeLogVar);
|
||||
@ -850,7 +674,7 @@ Parfactor::align (
|
||||
theta2.add (allLvs2[i], freeLogVar);
|
||||
++ freeLogVar;
|
||||
}
|
||||
|
||||
|
||||
assert (alignLvs1.size() == alignLvs2.size());
|
||||
for (unsigned i = 0; i < alignLvs1.size(); i++) {
|
||||
theta1.rename (alignLvs1[i], theta2.newNameFor (alignLvs2[i]));
|
||||
|
@ -9,8 +9,9 @@
|
||||
#include "LiftedUtils.h"
|
||||
#include "Horus.h"
|
||||
|
||||
#include "Factor.h"
|
||||
|
||||
class Parfactor
|
||||
class Parfactor : public TFactor<ProbFormula>
|
||||
{
|
||||
public:
|
||||
Parfactor (
|
||||
@ -18,27 +19,15 @@ class Parfactor
|
||||
const Params&,
|
||||
const Tuples&,
|
||||
unsigned);
|
||||
|
||||
Parfactor (const Parfactor*, const Tuple&);
|
||||
|
||||
Parfactor (const Parfactor*, ConstraintTree*);
|
||||
|
||||
Parfactor (const Parfactor&);
|
||||
|
||||
~Parfactor (void);
|
||||
|
||||
ProbFormulas& formulas (void) { return formulas_; }
|
||||
|
||||
const ProbFormulas& formulas (void) const { return formulas_; }
|
||||
|
||||
unsigned nrFormulas (void) const { return formulas_.size(); }
|
||||
|
||||
Params& params (void) { return params_; }
|
||||
|
||||
const Params& params (void) const { return params_; }
|
||||
|
||||
unsigned size (void) const { return params_.size(); }
|
||||
|
||||
const Ranges& ranges (void) const { return ranges_; }
|
||||
|
||||
unsigned distId (void) const { return distId_; }
|
||||
|
||||
ConstraintTree* constr (void) { return constr_; }
|
||||
|
||||
const ConstraintTree* constr (void) const { return constr_; }
|
||||
@ -57,64 +46,52 @@ class Parfactor
|
||||
|
||||
void setConstraintTree (ConstraintTree*);
|
||||
|
||||
void sumOut (unsigned);
|
||||
void sumOut (unsigned fIdx);
|
||||
|
||||
void multiply (Parfactor&);
|
||||
|
||||
void countConvert (LogVar);
|
||||
|
||||
void expandPotential (LogVar, LogVar, LogVar);
|
||||
void expand (LogVar, LogVar, LogVar);
|
||||
|
||||
void fullExpand (LogVar);
|
||||
|
||||
void reorderAccordingGrounds (const Grounds&);
|
||||
|
||||
void reorderFormulas (const ProbFormulas&);
|
||||
|
||||
void absorveEvidence (unsigned, unsigned);
|
||||
|
||||
void normalize (void);
|
||||
|
||||
void setFormulaGroup (const ProbFormula&, int);
|
||||
void absorveEvidence (const ProbFormula&, unsigned);
|
||||
|
||||
void setNewGroups (void);
|
||||
|
||||
void applySubstitution (const Substitution&);
|
||||
|
||||
int findGroup (const Ground&) const;
|
||||
|
||||
bool containsGround (const Ground&) const;
|
||||
|
||||
bool containsGroup (unsigned) const;
|
||||
|
||||
const ProbFormula& formula (unsigned) const;
|
||||
|
||||
unsigned range (unsigned) const;
|
||||
|
||||
|
||||
unsigned nrFormulas (LogVar) const;
|
||||
|
||||
int indexOf (const ProbFormula&) const;
|
||||
int indexOfLogVar (LogVar) const;
|
||||
|
||||
int indexOfFormulaWithLogVar (LogVar) const;
|
||||
|
||||
int indexOfFormulaWithGroup (unsigned) const;
|
||||
int indexOfGroup (unsigned) const;
|
||||
|
||||
vector<unsigned> getAllGroups (void) const;
|
||||
|
||||
void print (bool = false) const;
|
||||
|
||||
string getHeaderString (void) const;
|
||||
string getLabel (void) const;
|
||||
|
||||
private:
|
||||
void expandPotential (int fIdx, unsigned newRange,
|
||||
const vector<unsigned>& sumIndexes);
|
||||
|
||||
static void alignAndExponentiate (Parfactor*, Parfactor*);
|
||||
|
||||
static void align (
|
||||
Parfactor*, const LogVars&, Parfactor*, const LogVars&);
|
||||
|
||||
void insertDimension (unsigned);
|
||||
|
||||
ProbFormulas formulas_;
|
||||
Ranges ranges_;
|
||||
Params params_;
|
||||
unsigned distId_;
|
||||
ConstraintTree* constr_;
|
||||
ConstraintTree* constr_;
|
||||
};
|
||||
|
||||
|
||||
|
@ -3,9 +3,32 @@
|
||||
#include "ParfactorList.h"
|
||||
|
||||
|
||||
ParfactorList::ParfactorList (Parfactors& pfs)
|
||||
ParfactorList::ParfactorList (const ParfactorList& pfList)
|
||||
{
|
||||
pfList_.insert (pfList_.end(), pfs.begin(), pfs.end());
|
||||
ParfactorList::const_iterator it = pfList.begin();
|
||||
while (it != pfList.end()) {
|
||||
addShattered (new Parfactor (**it));
|
||||
++ it;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
ParfactorList::ParfactorList (const Parfactors& pfs)
|
||||
{
|
||||
add (pfs);
|
||||
}
|
||||
|
||||
|
||||
|
||||
ParfactorList::~ParfactorList (void)
|
||||
{
|
||||
ParfactorList::const_iterator it = pfList_.begin();
|
||||
while (it != pfList_.end()) {
|
||||
delete *it;
|
||||
++ it;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -14,17 +37,17 @@ void
|
||||
ParfactorList::add (Parfactor* pf)
|
||||
{
|
||||
pf->setNewGroups();
|
||||
pfList_.push_back (pf);
|
||||
addToShatteredList (pf);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
ParfactorList::add (Parfactors& pfs)
|
||||
ParfactorList::add (const Parfactors& pfs)
|
||||
{
|
||||
for (unsigned i = 0; i < pfs.size(); i++) {
|
||||
pfs[i]->setNewGroups();
|
||||
pfList_.push_back (pfs[i]);
|
||||
addToShatteredList (pfs[i]);
|
||||
}
|
||||
}
|
||||
|
||||
@ -33,7 +56,20 @@ ParfactorList::add (Parfactors& pfs)
|
||||
void
|
||||
ParfactorList::addShattered (Parfactor* pf)
|
||||
{
|
||||
assert (isAllShattered());
|
||||
pfList_.push_back (pf);
|
||||
assert (isAllShattered());
|
||||
}
|
||||
|
||||
|
||||
|
||||
list<Parfactor*>::iterator
|
||||
ParfactorList::insertShattered (
|
||||
list<Parfactor*>::iterator it,
|
||||
Parfactor* pf)
|
||||
{
|
||||
return pfList_.insert (it, pf);
|
||||
assert (isAllShattered());
|
||||
}
|
||||
|
||||
|
||||
@ -47,7 +83,7 @@ ParfactorList::remove (list<Parfactor*>::iterator it)
|
||||
|
||||
|
||||
list<Parfactor*>::iterator
|
||||
ParfactorList::deleteAndRemove (list<Parfactor*>::iterator it)
|
||||
ParfactorList::removeAndDelete (list<Parfactor*>::iterator it)
|
||||
{
|
||||
delete *it;
|
||||
return pfList_.erase (it);
|
||||
@ -55,58 +91,21 @@ ParfactorList::deleteAndRemove (list<Parfactor*>::iterator it)
|
||||
|
||||
|
||||
|
||||
void
|
||||
ParfactorList::shatter (void)
|
||||
bool
|
||||
ParfactorList::isAllShattered (void) const
|
||||
{
|
||||
list<Parfactor*> tempList;
|
||||
Parfactors newPfs;
|
||||
newPfs.insert (newPfs.end(), pfList_.begin(), pfList_.end());
|
||||
while (newPfs.empty() == false) {
|
||||
tempList.insert (tempList.end(), newPfs.begin(), newPfs.end());
|
||||
newPfs.clear();
|
||||
list<Parfactor*>::iterator iter1 = tempList.begin();
|
||||
while (tempList.size() > 1 && iter1 != -- tempList.end()) {
|
||||
list<Parfactor*>::iterator iter2 = iter1;
|
||||
++ iter2;
|
||||
bool incIter1 = true;
|
||||
while (iter2 != tempList.end()) {
|
||||
assert (iter1 != iter2);
|
||||
std::pair<Parfactors, Parfactors> res = shatter (
|
||||
(*iter1)->formulas(), *iter1, (*iter2)->formulas(), *iter2);
|
||||
bool incIter2 = true;
|
||||
if (res.second.empty() == false) {
|
||||
// cout << "second unshattered" << endl;
|
||||
delete *iter2;
|
||||
iter2 = tempList.erase (iter2);
|
||||
incIter2 = false;
|
||||
newPfs.insert (
|
||||
newPfs.begin(), res.second.begin(), res.second.end());
|
||||
}
|
||||
if (res.first.empty() == false) {
|
||||
// cout << "first unshattered" << endl;
|
||||
delete *iter1;
|
||||
iter1 = tempList.erase (iter1);
|
||||
newPfs.insert (
|
||||
newPfs.begin(), res.first.begin(), res.first.end());
|
||||
incIter1 = false;
|
||||
break;
|
||||
}
|
||||
if (incIter2) {
|
||||
++ iter2;
|
||||
}
|
||||
}
|
||||
if (incIter1) {
|
||||
++ iter1;
|
||||
if (pfList_.size() <= 1) {
|
||||
return true;
|
||||
}
|
||||
vector<Parfactor*> pfs (pfList_.begin(), pfList_.end());
|
||||
for (unsigned i = 0; i < pfs.size() - 1; i++) {
|
||||
for (unsigned j = i + 1; j < pfs.size(); j++) {
|
||||
if (isShattered (pfs[i], pfs[j]) == false) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
// cout << "|||||||||||||||||||||||||||||||||||||||||||||||||" << endl;
|
||||
// cout << "||||||||||||| SHATTERING ITERATION ||||||||||||||" << endl;
|
||||
// cout << "|||||||||||||||||||||||||||||||||||||||||||||||||" << endl;
|
||||
// printParfactors (newPfs);
|
||||
// cout << "|||||||||||||||||||||||||||||||||||||||||||||||||" << endl;
|
||||
}
|
||||
pfList_.clear();
|
||||
pfList_.insert (pfList_.end(), tempList.begin(), tempList.end());
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
@ -117,25 +116,88 @@ ParfactorList::print (void) const
|
||||
list<Parfactor*>::const_iterator it;
|
||||
for (it = pfList_.begin(); it != pfList_.end(); ++it) {
|
||||
(*it)->print();
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
std::pair<Parfactors, Parfactors>
|
||||
ParfactorList::shatter (
|
||||
ProbFormulas& formulas1,
|
||||
Parfactor* g1,
|
||||
ProbFormulas& formulas2,
|
||||
Parfactor* g2)
|
||||
bool
|
||||
ParfactorList::isShattered (
|
||||
const Parfactor* g1,
|
||||
const Parfactor* g2) const
|
||||
{
|
||||
assert (g1 != g2);
|
||||
const ProbFormulas& fms1 = g1->arguments();
|
||||
const ProbFormulas& fms2 = g2->arguments();
|
||||
for (unsigned i = 0; i < fms1.size(); i++) {
|
||||
for (unsigned j = 0; j < fms2.size(); j++) {
|
||||
if (fms1[i].group() == fms2[j].group()) {
|
||||
if (identical (
|
||||
fms1[i], *(g1->constr()),
|
||||
fms2[j], *(g2->constr())) == false) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
if (disjoint (
|
||||
fms1[i], *(g1->constr()),
|
||||
fms2[j], *(g2->constr())) == false) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
ParfactorList::addToShatteredList (Parfactor* g)
|
||||
{
|
||||
queue<Parfactor*> residuals;
|
||||
residuals.push (g);
|
||||
while (residuals.empty() == false) {
|
||||
Parfactor* pf = residuals.front();
|
||||
bool pfSplitted = false;
|
||||
list<Parfactor*>::iterator pfIter;
|
||||
pfIter = pfList_.begin();
|
||||
while (pfIter != pfList_.end()) {
|
||||
std::pair<Parfactors, Parfactors> shattRes;
|
||||
shattRes = shatter (*pfIter, pf);
|
||||
if (shattRes.first.empty() == false) {
|
||||
pfIter = removeAndDelete (pfIter);
|
||||
Util::addToQueue (residuals, shattRes.first);
|
||||
} else {
|
||||
++ pfIter;
|
||||
}
|
||||
if (shattRes.second.empty() == false) {
|
||||
delete pf;
|
||||
Util::addToQueue (residuals, shattRes.second);
|
||||
pfSplitted = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
residuals.pop();
|
||||
if (pfSplitted == false) {
|
||||
addShattered (pf);
|
||||
}
|
||||
}
|
||||
assert (isAllShattered());
|
||||
}
|
||||
|
||||
|
||||
|
||||
std::pair<Parfactors, Parfactors>
|
||||
ParfactorList::shatter (Parfactor* g1, Parfactor* g2)
|
||||
{
|
||||
ProbFormulas& formulas1 = g1->arguments();
|
||||
ProbFormulas& formulas2 = g2->arguments();
|
||||
assert (g1 != 0 && g2 != 0 && g1 != g2);
|
||||
for (unsigned i = 0; i < formulas1.size(); i++) {
|
||||
for (unsigned j = 0; j < formulas2.size(); j++) {
|
||||
if (formulas1[i].sameSkeletonAs (formulas2[j])) {
|
||||
std::pair<Parfactors, Parfactors> res
|
||||
= shatter (formulas1[i], g1, formulas2[j], g2);
|
||||
std::pair<Parfactors, Parfactors> res;
|
||||
res = shatter (i, g1, j, g2);
|
||||
if (res.first.empty() == false ||
|
||||
res.second.empty() == false) {
|
||||
return res;
|
||||
@ -150,21 +212,22 @@ ParfactorList::shatter (
|
||||
|
||||
std::pair<Parfactors, Parfactors>
|
||||
ParfactorList::shatter (
|
||||
ProbFormula& f1,
|
||||
Parfactor* g1,
|
||||
ProbFormula& f2,
|
||||
Parfactor* g2)
|
||||
unsigned fIdx1, Parfactor* g1,
|
||||
unsigned fIdx2, Parfactor* g2)
|
||||
{
|
||||
ProbFormula& f1 = g1->argument (fIdx1);
|
||||
ProbFormula& f2 = g2->argument (fIdx2);
|
||||
// cout << endl;
|
||||
// cout << "-------------------------------------------------" << endl;
|
||||
// Util::printDashedLine();
|
||||
// cout << "-> SHATTERING (#" << g1 << ", #" << g2 << ")" << endl;
|
||||
// g1->print();
|
||||
// cout << "-> WITH" << endl;
|
||||
// g2->print();
|
||||
// cout << "-> ON: " << f1.toString (g1->constr()) << endl;
|
||||
// cout << "-> ON: " << f2.toString (g2->constr()) << endl;
|
||||
// cout << "-------------------------------------------------" << endl;
|
||||
|
||||
// cout << "-> ON: " << f1 << "|" ;
|
||||
// cout << g1->constr()->tupleSet (f1.logVars()) << endl;
|
||||
// cout << "-> ON: " << f2 << "|" ;
|
||||
// cout << g2->constr()->tupleSet (f2.logVars()) << endl;
|
||||
// Util::printDashedLine();
|
||||
if (f1.isAtom()) {
|
||||
unsigned group = (f1.group() < f2.group()) ? f1.group() : f2.group();
|
||||
f1.setGroup (group);
|
||||
@ -174,7 +237,7 @@ ParfactorList::shatter (
|
||||
assert (g1->constr()->empty() == false);
|
||||
assert (g2->constr()->empty() == false);
|
||||
if (f1.group() == f2.group()) {
|
||||
// assert (identical (f1, g1->constr(), f2, g2->constr()));
|
||||
assert (identical (f1, *(g1->constr()), f2, *(g2->constr())));
|
||||
return { };
|
||||
}
|
||||
|
||||
@ -201,21 +264,24 @@ ParfactorList::shatter (
|
||||
assert (commCt1->tupleSet (f1.arity()) ==
|
||||
commCt2->tupleSet (f2.arity()));
|
||||
|
||||
// stringstream ss1; ss1 << "" << count << "_A.dot" ;
|
||||
// stringstream ss2; ss2 << "" << count << "_B.dot" ;
|
||||
// stringstream ss3; ss3 << "" << count << "_A_comm.dot" ;
|
||||
// stringstream ss4; ss4 << "" << count << "_A_excl.dot" ;
|
||||
// stringstream ss5; ss5 << "" << count << "_B_comm.dot" ;
|
||||
// stringstream ss6; ss6 << "" << count << "_B_excl.dot" ;
|
||||
// ct1->exportToGraphViz (ss1.str().c_str(), true);
|
||||
// ct2->exportToGraphViz (ss2.str().c_str(), true);
|
||||
// commCt1->exportToGraphViz (ss3.str().c_str(), true);
|
||||
// exclCt1->exportToGraphViz (ss4.str().c_str(), true);
|
||||
// commCt2->exportToGraphViz (ss5.str().c_str(), true);
|
||||
// exclCt2->exportToGraphViz (ss6.str().c_str(), true);
|
||||
// unsigned static count = 0; count ++;
|
||||
// stringstream ss1; ss1 << "" << count << "_A.dot" ;
|
||||
// stringstream ss2; ss2 << "" << count << "_B.dot" ;
|
||||
// stringstream ss3; ss3 << "" << count << "_A_comm.dot" ;
|
||||
// stringstream ss4; ss4 << "" << count << "_A_excl.dot" ;
|
||||
// stringstream ss5; ss5 << "" << count << "_B_comm.dot" ;
|
||||
// stringstream ss6; ss6 << "" << count << "_B_excl.dot" ;
|
||||
// g1->constr()->exportToGraphViz (ss1.str().c_str(), true);
|
||||
// g2->constr()->exportToGraphViz (ss2.str().c_str(), true);
|
||||
// commCt1->exportToGraphViz (ss3.str().c_str(), true);
|
||||
// exclCt1->exportToGraphViz (ss4.str().c_str(), true);
|
||||
// commCt2->exportToGraphViz (ss5.str().c_str(), true);
|
||||
// exclCt2->exportToGraphViz (ss6.str().c_str(), true);
|
||||
|
||||
if (exclCt1->empty() && exclCt2->empty()) {
|
||||
unsigned group = (f1.group() < f2.group()) ? f1.group() : f2.group();
|
||||
unsigned group = (f1.group() < f2.group())
|
||||
? f1.group()
|
||||
: f2.group();
|
||||
// identical
|
||||
f1.setGroup (group);
|
||||
f2.setGroup (group);
|
||||
@ -235,8 +301,8 @@ ParfactorList::shatter (
|
||||
} else {
|
||||
group = ProbFormula::getNewGroup();
|
||||
}
|
||||
Parfactors res1 = shatter (g1, f1, commCt1, exclCt1, group);
|
||||
Parfactors res2 = shatter (g2, f2, commCt2, exclCt2, group);
|
||||
Parfactors res1 = shatter (g1, fIdx1, commCt1, exclCt1, group);
|
||||
Parfactors res2 = shatter (g2, fIdx2, commCt2, exclCt2, group);
|
||||
return make_pair (res1, res2);
|
||||
}
|
||||
|
||||
@ -245,11 +311,19 @@ ParfactorList::shatter (
|
||||
Parfactors
|
||||
ParfactorList::shatter (
|
||||
Parfactor* g,
|
||||
const ProbFormula& f,
|
||||
unsigned fIdx,
|
||||
ConstraintTree* commCt,
|
||||
ConstraintTree* exclCt,
|
||||
unsigned commGroup)
|
||||
{
|
||||
ProbFormula& f = g->argument (fIdx);
|
||||
if (exclCt->empty()) {
|
||||
delete commCt;
|
||||
delete exclCt;
|
||||
f.setGroup (commGroup);
|
||||
return { };
|
||||
}
|
||||
|
||||
Parfactors result;
|
||||
if (f.isCounting()) {
|
||||
LogVar X_new1 = g->constr()->logVarSet().back() + 1;
|
||||
@ -259,7 +333,7 @@ ParfactorList::shatter (
|
||||
for (unsigned i = 0; i < cts.size(); i++) {
|
||||
Parfactor* newPf = new Parfactor (g, cts[i]);
|
||||
if (cts[i]->nrLogVars() == g->constr()->nrLogVars() + 1) {
|
||||
newPf->expandPotential (f.countedLogVar(), X_new1, X_new2);
|
||||
newPf->expand (f.countedLogVar(), X_new1, X_new2);
|
||||
assert (g->constr()->getConditionalCount (f.countedLogVar()) ==
|
||||
cts[i]->getConditionalCount (X_new1) +
|
||||
cts[i]->getConditionalCount (X_new2));
|
||||
@ -270,20 +344,16 @@ ParfactorList::shatter (
|
||||
newPf->setNewGroups();
|
||||
result.push_back (newPf);
|
||||
}
|
||||
delete commCt;
|
||||
delete exclCt;
|
||||
} else {
|
||||
if (exclCt->empty()) {
|
||||
delete commCt;
|
||||
delete exclCt;
|
||||
g->setFormulaGroup (f, commGroup);
|
||||
} else {
|
||||
Parfactor* newPf = new Parfactor (g, commCt);
|
||||
newPf->setNewGroups();
|
||||
newPf->setFormulaGroup (f, commGroup);
|
||||
result.push_back (newPf);
|
||||
newPf = new Parfactor (g, exclCt);
|
||||
newPf->setNewGroups();
|
||||
result.push_back (newPf);
|
||||
}
|
||||
Parfactor* newPf = new Parfactor (g, commCt);
|
||||
newPf->setNewGroups();
|
||||
newPf->argument (fIdx).setGroup (commGroup);
|
||||
result.push_back (newPf);
|
||||
newPf = new Parfactor (g, exclCt);
|
||||
newPf->setNewGroups();
|
||||
result.push_back (newPf);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
@ -296,7 +366,7 @@ ParfactorList::unifyGroups (unsigned group1, unsigned group2)
|
||||
unsigned newGroup = ProbFormula::getNewGroup();
|
||||
for (ParfactorList::iterator it = pfList_.begin();
|
||||
it != pfList_.end(); it++) {
|
||||
ProbFormulas& formulas = (*it)->formulas();
|
||||
ProbFormulas& formulas = (*it)->arguments();
|
||||
for (unsigned i = 0; i < formulas.size(); i++) {
|
||||
if (formulas[i].group() == group1 ||
|
||||
formulas[i].group() == group2) {
|
||||
@ -306,3 +376,52 @@ ParfactorList::unifyGroups (unsigned group1, unsigned group2)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
ParfactorList::proper (
|
||||
const ProbFormula& f1, ConstraintTree c1,
|
||||
const ProbFormula& f2, ConstraintTree c2) const
|
||||
{
|
||||
return disjoint (f1, c1, f2, c2)
|
||||
|| identical (f1, c1, f2, c2);
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
ParfactorList::identical (
|
||||
const ProbFormula& f1, ConstraintTree c1,
|
||||
const ProbFormula& f2, ConstraintTree c2) const
|
||||
{
|
||||
if (f1.sameSkeletonAs (f2) == false) {
|
||||
return false;
|
||||
}
|
||||
if (f1.isAtom()) {
|
||||
return true;
|
||||
}
|
||||
c1.moveToTop (f1.logVars());
|
||||
c2.moveToTop (f2.logVars());
|
||||
return ConstraintTree::identical (
|
||||
&c1, &c2, f1.logVars().size());
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
ParfactorList::disjoint (
|
||||
const ProbFormula& f1, ConstraintTree c1,
|
||||
const ProbFormula& f2, ConstraintTree c2) const
|
||||
{
|
||||
if (f1.sameSkeletonAs (f2) == false) {
|
||||
return true;
|
||||
}
|
||||
if (f1.isAtom()) {
|
||||
return true;
|
||||
}
|
||||
c1.moveToTop (f1.logVars());
|
||||
c2.moveToTop (f2.logVars());
|
||||
return ConstraintTree::overlap (
|
||||
&c1, &c2, f1.arity()) == false;
|
||||
}
|
||||
|
||||
|
@ -2,6 +2,7 @@
|
||||
#define HORUS_PARFACTORLIST_H
|
||||
|
||||
#include <list>
|
||||
#include <queue>
|
||||
|
||||
#include "Parfactor.h"
|
||||
#include "ProbFormula.h"
|
||||
@ -14,56 +15,82 @@ class ParfactorList
|
||||
{
|
||||
public:
|
||||
ParfactorList (void) { }
|
||||
ParfactorList (Parfactors&);
|
||||
list<Parfactor*>& getParfactors (void) { return pfList_; }
|
||||
const list<Parfactor*>& getParfactors (void) const { return pfList_; }
|
||||
|
||||
void add (Parfactor* pf);
|
||||
void add (Parfactors& pfs);
|
||||
void addShattered (Parfactor* pf);
|
||||
list<Parfactor*>::iterator remove (list<Parfactor*>::iterator);
|
||||
list<Parfactor*>::iterator deleteAndRemove (list<Parfactor*>::iterator);
|
||||
ParfactorList (const ParfactorList&);
|
||||
|
||||
void clear (void) { pfList_.clear(); }
|
||||
unsigned size (void) const { return pfList_.size(); }
|
||||
|
||||
ParfactorList (const Parfactors&);
|
||||
|
||||
void shatter (void);
|
||||
~ParfactorList (void);
|
||||
|
||||
const list<Parfactor*>& parfactors (void) const { return pfList_; }
|
||||
|
||||
void clear (void) { pfList_.clear(); }
|
||||
|
||||
unsigned size (void) const { return pfList_.size(); }
|
||||
|
||||
typedef std::list<Parfactor*>::iterator iterator;
|
||||
|
||||
iterator begin (void) { return pfList_.begin(); }
|
||||
iterator end (void) { return pfList_.end(); }
|
||||
|
||||
iterator end (void) { return pfList_.end(); }
|
||||
|
||||
typedef std::list<Parfactor*>::const_iterator const_iterator;
|
||||
|
||||
const_iterator begin (void) const { return pfList_.begin(); }
|
||||
const_iterator end (void) const { return pfList_.end(); }
|
||||
|
||||
const_iterator end (void) const { return pfList_.end(); }
|
||||
|
||||
void add (Parfactor* pf);
|
||||
|
||||
void add (const Parfactors& pfs);
|
||||
|
||||
void addShattered (Parfactor* pf);
|
||||
|
||||
list<Parfactor*>::iterator insertShattered (
|
||||
list<Parfactor*>::iterator, Parfactor*);
|
||||
|
||||
list<Parfactor*>::iterator remove (list<Parfactor*>::iterator);
|
||||
|
||||
list<Parfactor*>::iterator removeAndDelete (list<Parfactor*>::iterator);
|
||||
|
||||
bool isAllShattered (void) const;
|
||||
|
||||
void print (void) const;
|
||||
|
||||
private:
|
||||
|
||||
bool isShattered (const Parfactor*, const Parfactor*) const;
|
||||
|
||||
static std::pair<Parfactors, Parfactors> shatter (
|
||||
ProbFormulas&,
|
||||
Parfactor*,
|
||||
ProbFormulas&,
|
||||
Parfactor*);
|
||||
void addToShatteredList (Parfactor*);
|
||||
|
||||
std::pair<Parfactors, Parfactors> shatter (
|
||||
Parfactor*, Parfactor*);
|
||||
|
||||
static std::pair<Parfactors, Parfactors> shatter (
|
||||
ProbFormula&,
|
||||
Parfactor*,
|
||||
ProbFormula&,
|
||||
Parfactor*);
|
||||
std::pair<Parfactors, Parfactors> shatter (
|
||||
unsigned, Parfactor*, unsigned, Parfactor*);
|
||||
|
||||
static Parfactors shatter (
|
||||
Parfactor*,
|
||||
const ProbFormula&,
|
||||
ConstraintTree*,
|
||||
ConstraintTree*,
|
||||
unsigned);
|
||||
Parfactors shatter (
|
||||
Parfactor*,
|
||||
unsigned,
|
||||
ConstraintTree*,
|
||||
ConstraintTree*,
|
||||
unsigned);
|
||||
|
||||
void unifyGroups (unsigned group1, unsigned group2);
|
||||
void unifyGroups (unsigned group1, unsigned group2);
|
||||
|
||||
list<Parfactor*> pfList_;
|
||||
bool proper (
|
||||
const ProbFormula&, ConstraintTree,
|
||||
const ProbFormula&, ConstraintTree) const;
|
||||
|
||||
bool identical (
|
||||
const ProbFormula&, ConstraintTree,
|
||||
const ProbFormula&, ConstraintTree) const;
|
||||
|
||||
bool disjoint (
|
||||
const ProbFormula&, ConstraintTree,
|
||||
const ProbFormula&, ConstraintTree) const;
|
||||
|
||||
list<Parfactor*> pfList_;
|
||||
};
|
||||
|
||||
#endif // HORUS_PARFACTORLIST_H
|
||||
|
@ -16,8 +16,7 @@ ProbFormula::sameSkeletonAs (const ProbFormula& f) const
|
||||
bool
|
||||
ProbFormula::contains (LogVar lv) const
|
||||
{
|
||||
return std::find (logVars_.begin(), logVars_.end(), lv) !=
|
||||
logVars_.end();
|
||||
return Util::contains (logVars_, lv);
|
||||
}
|
||||
|
||||
|
||||
@ -77,16 +76,15 @@ ProbFormula::rename (LogVar oldName, LogVar newName)
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
ProbFormula::operator== (const ProbFormula& f) const
|
||||
bool operator== (const ProbFormula& f1, const ProbFormula& f2)
|
||||
{
|
||||
return functor_ == f.functor_ && logVars_ == f.logVars_ ;
|
||||
return f1.group_ == f2.group_;
|
||||
//return functor_ == f.functor_ && logVars_ == f.logVars_ ;
|
||||
}
|
||||
|
||||
|
||||
|
||||
ostream& operator<< (ostream &os, const ProbFormula& f)
|
||||
std::ostream& operator<< (ostream &os, const ProbFormula& f)
|
||||
{
|
||||
os << f.functor_;
|
||||
if (f.isAtom() == false) {
|
||||
@ -113,3 +111,13 @@ ProbFormula::getNewGroup (void)
|
||||
return freeGroup_;
|
||||
}
|
||||
|
||||
|
||||
|
||||
ostream& operator<< (ostream &os, const ObservedFormula& of)
|
||||
{
|
||||
os << of.functor_ << "/" << of.arity_;
|
||||
os << "|" << of.constr_.tupleSet();
|
||||
os << " [evidence=" << of.evidence_ << "]";
|
||||
return os;
|
||||
}
|
||||
|
||||
|
@ -8,14 +8,16 @@
|
||||
#include "Horus.h"
|
||||
|
||||
|
||||
|
||||
class ProbFormula
|
||||
{
|
||||
public:
|
||||
ProbFormula (Symbol f, const LogVars& lvs, unsigned range)
|
||||
: functor_(f), logVars_(lvs), range_(range),
|
||||
countedLogVar_() { }
|
||||
countedLogVar_(), group_(Util::maxUnsigned()) { }
|
||||
|
||||
ProbFormula (Symbol f, unsigned r) : functor_(f), range_(r) { }
|
||||
ProbFormula (Symbol f, unsigned r)
|
||||
: functor_(f), range_(r), group_(Util::maxUnsigned()) { }
|
||||
|
||||
Symbol functor (void) const { return functor_; }
|
||||
|
||||
@ -29,9 +31,9 @@ class ProbFormula
|
||||
|
||||
LogVarSet logVarSet (void) const { return LogVarSet (logVars_); }
|
||||
|
||||
unsigned group (void) const { return groupId_; }
|
||||
unsigned group (void) const { return group_; }
|
||||
|
||||
void setGroup (unsigned g) { groupId_ = g; }
|
||||
void setGroup (unsigned g) { group_ = g; }
|
||||
|
||||
bool sameSkeletonAs (const ProbFormula&) const;
|
||||
|
||||
@ -49,23 +51,58 @@ class ProbFormula
|
||||
|
||||
void rename (LogVar, LogVar);
|
||||
|
||||
bool operator== (const ProbFormula& f) const;
|
||||
|
||||
friend ostream& operator<< (ostream &out, const ProbFormula& f);
|
||||
|
||||
static unsigned getNewGroup (void);
|
||||
|
||||
friend std::ostream& operator<< (ostream &os, const ProbFormula& f);
|
||||
|
||||
friend bool operator== (const ProbFormula& f1, const ProbFormula& f2);
|
||||
|
||||
private:
|
||||
Symbol functor_;
|
||||
LogVars logVars_;
|
||||
unsigned range_;
|
||||
LogVar countedLogVar_;
|
||||
unsigned groupId_;
|
||||
static int freeGroup_;
|
||||
Symbol functor_;
|
||||
LogVars logVars_;
|
||||
unsigned range_;
|
||||
LogVar countedLogVar_;
|
||||
unsigned group_;
|
||||
static int freeGroup_;
|
||||
};
|
||||
|
||||
typedef vector<ProbFormula> ProbFormulas;
|
||||
|
||||
|
||||
class ObservedFormula
|
||||
{
|
||||
public:
|
||||
ObservedFormula (Symbol f, unsigned a, unsigned ev)
|
||||
: functor_(f), arity_(a), evidence_(ev), constr_(a) { }
|
||||
|
||||
ObservedFormula (Symbol f, unsigned ev, const Tuple& tuple)
|
||||
: functor_(f), arity_(tuple.size()), evidence_(ev), constr_(arity_)
|
||||
{
|
||||
constr_.addTuple (tuple);
|
||||
}
|
||||
|
||||
Symbol functor (void) const { return functor_; }
|
||||
|
||||
unsigned arity (void) const { return arity_; }
|
||||
|
||||
unsigned evidence (void) const { return evidence_; }
|
||||
|
||||
ConstraintTree& constr (void) { return constr_; }
|
||||
|
||||
bool isAtom (void) const { return arity_ == 0; }
|
||||
|
||||
void addTuple (const Tuple& tuple) { constr_.addTuple (tuple); }
|
||||
|
||||
friend ostream& operator<< (ostream &os, const ObservedFormula& of);
|
||||
|
||||
private:
|
||||
Symbol functor_;
|
||||
unsigned arity_;
|
||||
unsigned evidence_;
|
||||
ConstraintTree constr_;
|
||||
};
|
||||
|
||||
typedef vector<ObservedFormula> ObservedFormulas;
|
||||
|
||||
#endif // HORUS_PROBFORMULA_H
|
||||
|
||||
|
@ -3,51 +3,35 @@
|
||||
|
||||
|
||||
void
|
||||
Solver::printAllPosterioris (void)
|
||||
Solver::printAnswer (const VarIds& vids)
|
||||
{
|
||||
const VarNodes& vars = gm_->getVariableNodes();
|
||||
for (unsigned i = 0; i < vars.size(); i++) {
|
||||
printPosterioriOf (vars[i]->varId());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Solver::printPosterioriOf (VarId vid)
|
||||
{
|
||||
VarNode* var = gm_->getVariableNode (vid);
|
||||
const Params& posterioriDist = getPosterioriOf (vid);
|
||||
const States& states = var->states();
|
||||
for (unsigned i = 0; i < states.size(); i++) {
|
||||
cout << "P(" << var->label() << "=" << states[i] << ") = " ;
|
||||
cout << setprecision (PRECISION) << posterioriDist[i];
|
||||
cout << endl;
|
||||
}
|
||||
cout << endl;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Solver::printJointDistributionOf (const VarIds& vids)
|
||||
{
|
||||
VarNodes vars;
|
||||
VarIds vidsWithoutEvidence;
|
||||
Vars unobservedVars;
|
||||
VarIds unobservedVids;
|
||||
for (unsigned i = 0; i < vids.size(); i++) {
|
||||
VarNode* var = gm_->getVariableNode (vids[i]);
|
||||
if (var->hasEvidence() == false) {
|
||||
vars.push_back (var);
|
||||
vidsWithoutEvidence.push_back (vids[i]);
|
||||
VarNode* vn = fg.getVarNode (vids[i]);
|
||||
if (vn->hasEvidence() == false) {
|
||||
unobservedVars.push_back (vn);
|
||||
unobservedVids.push_back (vids[i]);
|
||||
}
|
||||
}
|
||||
const Params& jointDist = getJointDistributionOf (vidsWithoutEvidence);
|
||||
vector<string> jointStrings = Util::getJointStateStrings (vars);
|
||||
for (unsigned i = 0; i < jointDist.size(); i++) {
|
||||
cout << "P(" << jointStrings[i] << ") = " ;
|
||||
cout << setprecision (PRECISION) << jointDist[i];
|
||||
Params res = solveQuery (unobservedVids);
|
||||
vector<string> stateLines = Util::getStateLines (unobservedVars);
|
||||
for (unsigned i = 0; i < res.size(); i++) {
|
||||
cout << "P(" << stateLines[i] << ") = " ;
|
||||
cout << std::setprecision (Constants::PRECISION) << res[i];
|
||||
cout << endl;
|
||||
}
|
||||
cout << endl;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Solver::printAllPosterioris (void)
|
||||
{
|
||||
const VarNodes& vars = fg.varNodes();
|
||||
for (unsigned i = 0; i < vars.size(); i++) {
|
||||
printAnswer ({vars[i]->varId()});
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -3,29 +3,27 @@
|
||||
|
||||
#include <iomanip>
|
||||
|
||||
#include "GraphicalModel.h"
|
||||
#include "VarNode.h"
|
||||
#include "Var.h"
|
||||
#include "FactorGraph.h"
|
||||
|
||||
|
||||
using namespace std;
|
||||
|
||||
class Solver
|
||||
{
|
||||
public:
|
||||
Solver (const GraphicalModel* gm)
|
||||
{
|
||||
gm_ = gm;
|
||||
}
|
||||
virtual ~Solver() {} // to ensure that subclass destructor is called
|
||||
virtual void runSolver (void) = 0;
|
||||
virtual Params getPosterioriOf (VarId) = 0;
|
||||
virtual Params getJointDistributionOf (const VarIds&) = 0;
|
||||
Solver (const FactorGraph& factorGraph) : fg(factorGraph) { }
|
||||
|
||||
virtual ~Solver() { } // ensure that subclass destructor is called
|
||||
|
||||
virtual Params solveQuery (VarIds queryVids) = 0;
|
||||
|
||||
void printAnswer (const VarIds& vids);
|
||||
|
||||
void printAllPosterioris (void);
|
||||
void printPosterioriOf (VarId vid);
|
||||
void printJointDistributionOf (const VarIds& vids);
|
||||
|
||||
private:
|
||||
const GraphicalModel* gm_;
|
||||
protected:
|
||||
const FactorGraph& fg;
|
||||
};
|
||||
|
||||
#endif // HORUS_SOLVER_H
|
||||
|
4
packages/CLPBN/clpbn/bp/TODO
Normal file
4
packages/CLPBN/clpbn/bp/TODO
Normal file
@ -0,0 +1,4 @@
|
||||
TODO
|
||||
- add a way to calculate combinations and factorials with large numbers
|
||||
- refactor sumOut in parfactor -> is really ugly code
|
||||
- Indexer: start receiving ranges as constant reference
|
@ -1,21 +1,22 @@
|
||||
#include <limits>
|
||||
|
||||
#include <sstream>
|
||||
#include <fstream>
|
||||
|
||||
#include "Util.h"
|
||||
#include "Indexer.h"
|
||||
#include "GraphicalModel.h"
|
||||
|
||||
|
||||
namespace Globals {
|
||||
bool logDomain = false;
|
||||
bool logDomain = false;
|
||||
|
||||
//InfAlgs infAlgorithm = InfAlgorithms::VE;
|
||||
//InfAlgs infAlgorithm = InfAlgorithms::BN_BP;
|
||||
//InfAlgs infAlgorithm = InfAlgorithms::FG_BP;
|
||||
InfAlgorithms infAlgorithm = InfAlgorithms::CBP;
|
||||
};
|
||||
|
||||
|
||||
namespace InfAlgorithms {
|
||||
//InfAlgs infAlgorithm = InfAlgorithms::VE;
|
||||
//InfAlgs infAlgorithm = InfAlgorithms::BN_BP;
|
||||
InfAlgs infAlgorithm = InfAlgorithms::FG_BP;
|
||||
//InfAlgs infAlgorithm = InfAlgorithms::CBP;
|
||||
}
|
||||
|
||||
|
||||
namespace BpOptions {
|
||||
@ -23,13 +24,11 @@ Schedule schedule = BpOptions::Schedule::SEQ_FIXED;
|
||||
//Schedule schedule = BpOptions::Schedule::SEQ_RANDOM;
|
||||
//Schedule schedule = BpOptions::Schedule::PARALLEL;
|
||||
//Schedule schedule = BpOptions::Schedule::MAX_RESIDUAL;
|
||||
double accuracy = 0.0001;
|
||||
unsigned maxIter = 1000;
|
||||
double accuracy = 0.0001;
|
||||
unsigned maxIter = 1000;
|
||||
}
|
||||
|
||||
|
||||
unordered_map<VarId,VariableInfo> GraphicalModel::varsInfo_;
|
||||
unordered_map<unsigned,Distribution*> GraphicalModel::distsInfo_;
|
||||
|
||||
vector<NetInfo> Statistics::netInfo_;
|
||||
vector<CompressInfo> Statistics::compressInfo_;
|
||||
@ -58,76 +57,6 @@ fromLog (Params& v)
|
||||
|
||||
|
||||
|
||||
void
|
||||
normalize (Params& v)
|
||||
{
|
||||
double sum;
|
||||
if (Globals::logDomain) {
|
||||
sum = addIdenty();
|
||||
for (unsigned i = 0; i < v.size(); i++) {
|
||||
logSum (sum, v[i]);
|
||||
}
|
||||
assert (sum != -numeric_limits<double>::infinity());
|
||||
for (unsigned i = 0; i < v.size(); i++) {
|
||||
v[i] -= sum;
|
||||
}
|
||||
} else {
|
||||
sum = 0.0;
|
||||
for (unsigned i = 0; i < v.size(); i++) {
|
||||
sum += v[i];
|
||||
}
|
||||
assert (sum != 0.0);
|
||||
for (unsigned i = 0; i < v.size(); i++) {
|
||||
v[i] /= sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
pow (Params& v, double expoent)
|
||||
{
|
||||
if (Globals::logDomain) {
|
||||
for (unsigned i = 0; i < v.size(); i++) {
|
||||
v[i] *= expoent;
|
||||
}
|
||||
} else {
|
||||
for (unsigned i = 0; i < v.size(); i++) {
|
||||
v[i] = std::pow (v[i], expoent);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
pow (Params& v, unsigned expoent)
|
||||
{
|
||||
if (expoent == 1) {
|
||||
return;
|
||||
}
|
||||
if (Globals::logDomain) {
|
||||
for (unsigned i = 0; i < v.size(); i++) {
|
||||
v[i] *= expoent;
|
||||
}
|
||||
} else {
|
||||
for (unsigned i = 0; i < v.size(); i++) {
|
||||
v[i] = std::pow (v[i], expoent);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
double
|
||||
pow (double p, unsigned expoent)
|
||||
{
|
||||
return Globals::logDomain ? p * expoent : std::pow (p, expoent);
|
||||
}
|
||||
|
||||
|
||||
|
||||
double
|
||||
factorial (double num)
|
||||
{
|
||||
@ -153,6 +82,151 @@ nrCombinations (unsigned n, unsigned r)
|
||||
|
||||
|
||||
|
||||
unsigned
|
||||
expectedSize (const Ranges& ranges)
|
||||
{
|
||||
unsigned prod = 1;
|
||||
for (unsigned i = 0; i < ranges.size(); i++) {
|
||||
prod *= ranges[i];
|
||||
}
|
||||
return prod;
|
||||
}
|
||||
|
||||
|
||||
|
||||
unsigned
|
||||
getNumberOfDigits (int number)
|
||||
{
|
||||
unsigned count = 1;
|
||||
while (number >= 10) {
|
||||
number /= 10;
|
||||
count ++;
|
||||
}
|
||||
return count;
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
isInteger (const string& s)
|
||||
{
|
||||
stringstream ss1 (s);
|
||||
stringstream ss2;
|
||||
int integer;
|
||||
ss1 >> integer;
|
||||
ss2 << integer;
|
||||
return (ss1.str() == ss2.str());
|
||||
}
|
||||
|
||||
|
||||
|
||||
string
|
||||
parametersToString (const Params& v, unsigned precision)
|
||||
{
|
||||
stringstream ss;
|
||||
ss.precision (precision);
|
||||
ss << "[" ;
|
||||
for (unsigned i = 0; i < v.size(); i++) {
|
||||
if (i != 0) ss << ", " ;
|
||||
ss << v[i];
|
||||
}
|
||||
ss << "]" ;
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
|
||||
|
||||
vector<string>
|
||||
getStateLines (const Vars& vars)
|
||||
{
|
||||
StatesIndexer idx (vars);
|
||||
vector<string> jointStrings;
|
||||
while (idx.valid()) {
|
||||
stringstream ss;
|
||||
for (unsigned i = 0; i < vars.size(); i++) {
|
||||
if (i != 0) ss << ", " ;
|
||||
ss << vars[i]->label() << "=" << vars[i]->states()[(idx[i])];
|
||||
}
|
||||
jointStrings.push_back (ss.str());
|
||||
++ idx;
|
||||
}
|
||||
return jointStrings;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
printHeader (string header, std::ostream& os)
|
||||
{
|
||||
printAsteriskLine (os);
|
||||
os << header << endl;
|
||||
printAsteriskLine (os);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
printSubHeader (string header, std::ostream& os)
|
||||
{
|
||||
printDashedLine (os);
|
||||
os << header << endl;
|
||||
printDashedLine (os);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
printAsteriskLine (std::ostream& os)
|
||||
{
|
||||
os << "********************************" ;
|
||||
os << "********************************" ;
|
||||
os << endl;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
printDashedLine (std::ostream& os)
|
||||
{
|
||||
os << "--------------------------------" ;
|
||||
os << "--------------------------------" ;
|
||||
os << endl;
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
namespace LogAware {
|
||||
|
||||
void
|
||||
normalize (Params& v)
|
||||
{
|
||||
double sum;
|
||||
if (Globals::logDomain) {
|
||||
sum = LogAware::addIdenty();
|
||||
for (unsigned i = 0; i < v.size(); i++) {
|
||||
sum = Util::logSum (sum, v[i]);
|
||||
}
|
||||
assert (sum != -numeric_limits<double>::infinity());
|
||||
for (unsigned i = 0; i < v.size(); i++) {
|
||||
v[i] -= sum;
|
||||
}
|
||||
} else {
|
||||
sum = 0.0;
|
||||
for (unsigned i = 0; i < v.size(); i++) {
|
||||
sum += v[i];
|
||||
}
|
||||
assert (sum != 0.0);
|
||||
for (unsigned i = 0; i < v.size(); i++) {
|
||||
v[i] /= sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
double
|
||||
getL1Distance (const Params& v1, const Params& v2)
|
||||
{
|
||||
@ -196,67 +270,57 @@ getMaxNorm (const Params& v1, const Params& v2)
|
||||
}
|
||||
|
||||
|
||||
double
|
||||
pow (double p, unsigned expoent)
|
||||
{
|
||||
return Globals::logDomain ? p * expoent : std::pow (p, expoent);
|
||||
}
|
||||
|
||||
unsigned
|
||||
getNumberOfDigits (int number) {
|
||||
unsigned count = 1;
|
||||
while (number >= 10) {
|
||||
number /= 10;
|
||||
count ++;
|
||||
|
||||
|
||||
double
|
||||
pow (double p, double expoent)
|
||||
{
|
||||
// assumes that `expoent' is never in log domain
|
||||
return Globals::logDomain ? p * expoent : std::pow (p, expoent);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
pow (Params& v, unsigned expoent)
|
||||
{
|
||||
if (expoent == 1) {
|
||||
return;
|
||||
}
|
||||
return count;
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
isInteger (const string& s)
|
||||
{
|
||||
stringstream ss1 (s);
|
||||
stringstream ss2;
|
||||
int integer;
|
||||
ss1 >> integer;
|
||||
ss2 << integer;
|
||||
return (ss1.str() == ss2.str());
|
||||
}
|
||||
|
||||
|
||||
|
||||
string
|
||||
parametersToString (const Params& v, unsigned precision)
|
||||
{
|
||||
stringstream ss;
|
||||
ss.precision (precision);
|
||||
ss << "[" ;
|
||||
for (unsigned i = 0; i < v.size(); i++) {
|
||||
if (i != 0) ss << ", " ;
|
||||
ss << v[i];
|
||||
}
|
||||
ss << "]" ;
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
|
||||
|
||||
vector<string>
|
||||
getJointStateStrings (const VarNodes& vars)
|
||||
{
|
||||
StatesIndexer idx (vars);
|
||||
vector<string> jointStrings;
|
||||
while (idx.valid()) {
|
||||
stringstream ss;
|
||||
for (unsigned i = 0; i < vars.size(); i++) {
|
||||
if (i != 0) ss << ", " ;
|
||||
ss << vars[i]->label() << "=" << vars[i]->states()[(idx[i])];
|
||||
if (Globals::logDomain) {
|
||||
for (unsigned i = 0; i < v.size(); i++) {
|
||||
v[i] *= expoent;
|
||||
}
|
||||
} else {
|
||||
for (unsigned i = 0; i < v.size(); i++) {
|
||||
v[i] = std::pow (v[i], expoent);
|
||||
}
|
||||
jointStrings.push_back (ss.str());
|
||||
++ idx;
|
||||
}
|
||||
return jointStrings;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
pow (Params& v, double expoent)
|
||||
{
|
||||
// assumes that `expoent' is never in log domain
|
||||
if (Globals::logDomain) {
|
||||
for (unsigned i = 0; i < v.size(); i++) {
|
||||
v[i] *= expoent;
|
||||
}
|
||||
} else {
|
||||
for (unsigned i = 0; i < v.size(); i++) {
|
||||
v[i] = std::pow (v[i], expoent);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
@ -286,8 +350,11 @@ Statistics::getPrimaryNetworksCounting (void)
|
||||
|
||||
|
||||
void
|
||||
Statistics::updateStatistics (unsigned size, bool loopy,
|
||||
unsigned nIters, double time)
|
||||
Statistics::updateStatistics (
|
||||
unsigned size,
|
||||
bool loopy,
|
||||
unsigned nIters,
|
||||
double time)
|
||||
{
|
||||
netInfo_.push_back (NetInfo (size, loopy, nIters, time));
|
||||
}
|
||||
@ -303,12 +370,12 @@ Statistics::printStatistics (void)
|
||||
|
||||
|
||||
void
|
||||
Statistics::writeStatisticsToFile (const char* fileName)
|
||||
Statistics::writeStatistics (const char* fileName)
|
||||
{
|
||||
ofstream out (fileName);
|
||||
if (!out.is_open()) {
|
||||
cerr << "error: cannot open file to write at " ;
|
||||
cerr << "Statistics::writeStatisticsToFile()" << endl;
|
||||
cerr << "Statistics::writeStats()" << endl;
|
||||
abort();
|
||||
}
|
||||
out << getStatisticString();
|
||||
@ -318,13 +385,14 @@ Statistics::writeStatisticsToFile (const char* fileName)
|
||||
|
||||
|
||||
void
|
||||
Statistics::updateCompressingStatistics (unsigned nGroundVars,
|
||||
unsigned nGroundFactors,
|
||||
unsigned nClusterVars,
|
||||
unsigned nClusterFactors,
|
||||
unsigned nWithoutNeighs) {
|
||||
compressInfo_.push_back (CompressInfo (nGroundVars, nGroundFactors,
|
||||
nClusterVars, nClusterFactors, nWithoutNeighs));
|
||||
Statistics::updateCompressingStatistics (
|
||||
unsigned nrGroundVars,
|
||||
unsigned nrGroundFactors,
|
||||
unsigned nrClusterVars,
|
||||
unsigned nrClusterFactors,
|
||||
unsigned nrNeighborless) {
|
||||
compressInfo_.push_back (CompressInfo (nrGroundVars, nrGroundFactors,
|
||||
nrClusterVars, nrClusterFactors, nrNeighborless));
|
||||
}
|
||||
|
||||
|
||||
@ -334,26 +402,30 @@ Statistics::getStatisticString (void)
|
||||
{
|
||||
stringstream ss2, ss3, ss4, ss1;
|
||||
ss1 << "running mode: " ;
|
||||
switch (InfAlgorithms::infAlgorithm) {
|
||||
case InfAlgorithms::VE: ss1 << "ve" << endl; break;
|
||||
case InfAlgorithms::BN_BP: ss1 << "bn_bp" << endl; break;
|
||||
case InfAlgorithms::FG_BP: ss1 << "fg_bp" << endl; break;
|
||||
case InfAlgorithms::CBP: ss1 << "cbp" << endl; break;
|
||||
switch (Globals::infAlgorithm) {
|
||||
case InfAlgorithms::VE: ss1 << "ve" << endl; break;
|
||||
case InfAlgorithms::BP: ss1 << "bp" << endl; break;
|
||||
case InfAlgorithms::CBP: ss1 << "cbp" << endl; break;
|
||||
}
|
||||
ss1 << "message schedule: " ;
|
||||
switch (BpOptions::schedule) {
|
||||
case BpOptions::Schedule::SEQ_FIXED: ss1 << "sequential fixed" << endl; break;
|
||||
case BpOptions::Schedule::SEQ_RANDOM: ss1 << "sequential random" << endl; break;
|
||||
case BpOptions::Schedule::PARALLEL: ss1 << "parallel" << endl; break;
|
||||
case BpOptions::Schedule::MAX_RESIDUAL: ss1 << "max residual" << endl; break;
|
||||
case BpOptions::Schedule::SEQ_FIXED:
|
||||
ss1 << "sequential fixed" << endl;
|
||||
break;
|
||||
case BpOptions::Schedule::SEQ_RANDOM:
|
||||
ss1 << "sequential random" << endl;
|
||||
break;
|
||||
case BpOptions::Schedule::PARALLEL:
|
||||
ss1 << "parallel" << endl;
|
||||
break;
|
||||
case BpOptions::Schedule::MAX_RESIDUAL:
|
||||
ss1 << "max residual" << endl;
|
||||
break;
|
||||
}
|
||||
ss1 << "max iterations: " << BpOptions::maxIter << endl;
|
||||
ss1 << "accuracy " << BpOptions::accuracy << endl;
|
||||
ss1 << endl << endl;
|
||||
|
||||
ss2 << "---------------------------------------------------" << endl;
|
||||
ss2 << " Network information" << endl;
|
||||
ss2 << "---------------------------------------------------" << endl;
|
||||
Util::printSubHeader ("Network information", ss2);
|
||||
ss2 << left;
|
||||
ss2 << setw (15) << "Network Size" ;
|
||||
ss2 << setw (9) << "Loopy" ;
|
||||
@ -387,24 +459,22 @@ Statistics::getStatisticString (void)
|
||||
|
||||
unsigned c1 = 0, c2 = 0, c3 = 0, c4 = 0;
|
||||
if (compressInfo_.size() > 0) {
|
||||
ss3 << "---------------------------------------------------" << endl;
|
||||
ss3 << " Compression information" << endl;
|
||||
ss3 << "---------------------------------------------------" << endl;
|
||||
Util::printSubHeader ("Compress information", ss3);
|
||||
ss3 << left;
|
||||
ss3 << "Ground Cluster Ground Cluster Neighborless" << endl;
|
||||
ss3 << "Vars Vars Factors Factors Vars" << endl;
|
||||
for (unsigned i = 0; i < compressInfo_.size(); i++) {
|
||||
ss3 << setw (9) << compressInfo_[i].nGroundVars;
|
||||
ss3 << setw (10) << compressInfo_[i].nClusterVars;
|
||||
ss3 << setw (10) << compressInfo_[i].nGroundFactors;
|
||||
ss3 << setw (10) << compressInfo_[i].nClusterFactors;
|
||||
ss3 << setw (10) << compressInfo_[i].nWithoutNeighs;
|
||||
ss3 << setw (9) << compressInfo_[i].nrGroundVars;
|
||||
ss3 << setw (10) << compressInfo_[i].nrClusterVars;
|
||||
ss3 << setw (10) << compressInfo_[i].nrGroundFactors;
|
||||
ss3 << setw (10) << compressInfo_[i].nrClusterFactors;
|
||||
ss3 << setw (10) << compressInfo_[i].nrNeighborless;
|
||||
ss3 << endl;
|
||||
c1 += compressInfo_[i].nGroundVars - compressInfo_[i].nWithoutNeighs;
|
||||
c2 += compressInfo_[i].nClusterVars;
|
||||
c3 += compressInfo_[i].nGroundFactors - compressInfo_[i].nWithoutNeighs;
|
||||
c4 += compressInfo_[i].nClusterFactors;
|
||||
if (compressInfo_[i].nWithoutNeighs != 0) {
|
||||
c1 += compressInfo_[i].nrGroundVars - compressInfo_[i].nrNeighborless;
|
||||
c2 += compressInfo_[i].nrClusterVars;
|
||||
c3 += compressInfo_[i].nrGroundFactors - compressInfo_[i].nrNeighborless;
|
||||
c4 += compressInfo_[i].nrClusterFactors;
|
||||
if (compressInfo_[i].nrNeighborless != 0) {
|
||||
c2 --;
|
||||
c4 --;
|
||||
}
|
||||
|
@ -1,53 +1,141 @@
|
||||
#ifndef HORUS_UTIL_H
|
||||
#define HORUS_UTIL_H
|
||||
|
||||
#include <cmath>
|
||||
#include <cassert>
|
||||
#include <limits>
|
||||
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
#include <set>
|
||||
#include <queue>
|
||||
#include <unordered_map>
|
||||
|
||||
#include <sstream>
|
||||
#include <iostream>
|
||||
|
||||
#include "Horus.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
|
||||
namespace Util {
|
||||
|
||||
void toLog (Params&);
|
||||
void fromLog (Params&);
|
||||
void normalize (Params&);
|
||||
void logSum (double&, double);
|
||||
void multiply (Params&, const Params&);
|
||||
void multiply (Params&, const Params&, unsigned);
|
||||
void add (Params&, const Params&);
|
||||
void add (Params&, const Params&, unsigned);
|
||||
void pow (Params&, double);
|
||||
void pow (Params&, unsigned);
|
||||
double pow (double, unsigned);
|
||||
double factorial (double);
|
||||
unsigned nrCombinations (unsigned, unsigned);
|
||||
double getL1Distance (const Params&, const Params&);
|
||||
double getMaxNorm (const Params&, const Params&);
|
||||
unsigned getNumberOfDigits (int);
|
||||
bool isInteger (const string&);
|
||||
string parametersToString (const Params&, unsigned = PRECISION);
|
||||
vector<string> getJointStateStrings (const VarNodes&);
|
||||
double tl (double);
|
||||
double fl (double);
|
||||
double multIdenty();
|
||||
double addIdenty();
|
||||
double withEvidence();
|
||||
double noEvidence();
|
||||
double one();
|
||||
double zero();
|
||||
template <typename T> void addToVector (vector<T>&, const vector<T>&);
|
||||
|
||||
template <typename T> void addToSet (set<T>&, const vector<T>&);
|
||||
|
||||
template <typename T> void addToQueue (queue<T>&, const vector<T>&);
|
||||
|
||||
template <typename T> bool contains (const vector<T>&, const T&);
|
||||
|
||||
template <typename T> bool contains (const set<T>&, const T&);
|
||||
|
||||
template <typename K, typename V> bool contains (
|
||||
const unordered_map<K, V>&, const K&);
|
||||
|
||||
template <typename T> std::string toString (const T&);
|
||||
|
||||
void toLog (Params&);
|
||||
|
||||
void fromLog (Params&);
|
||||
|
||||
double logSum (double, double);
|
||||
|
||||
void multiply (Params&, const Params&);
|
||||
|
||||
void multiply (Params&, const Params&, unsigned);
|
||||
|
||||
void add (Params&, const Params&);
|
||||
|
||||
void add (Params&, const Params&, unsigned);
|
||||
|
||||
double factorial (double);
|
||||
|
||||
unsigned nrCombinations (unsigned, unsigned);
|
||||
|
||||
unsigned expectedSize (const Ranges&);
|
||||
|
||||
unsigned getNumberOfDigits (int);
|
||||
|
||||
bool isInteger (const string&);
|
||||
|
||||
string parametersToString (const Params&, unsigned = Constants::PRECISION);
|
||||
|
||||
vector<string> getStateLines (const Vars&);
|
||||
|
||||
void printHeader (string, std::ostream& os = std::cout);
|
||||
|
||||
void printSubHeader (string, std::ostream& os = std::cout);
|
||||
|
||||
void printAsteriskLine (std::ostream& os = std::cout);
|
||||
|
||||
void printDashedLine (std::ostream& os = std::cout);
|
||||
|
||||
unsigned maxUnsigned (void);
|
||||
|
||||
};
|
||||
|
||||
|
||||
template <class T>
|
||||
std::string toString (const T& t)
|
||||
|
||||
template <typename T> void
|
||||
Util::addToVector (vector<T>& v, const vector<T>& elements)
|
||||
{
|
||||
v.insert (v.end(), elements.begin(), elements.end());
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T> void
|
||||
Util::addToSet (set<T>& s, const vector<T>& elements)
|
||||
{
|
||||
s.insert (elements.begin(), elements.end());
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T> void
|
||||
Util::addToQueue (queue<T>& q, const vector<T>& elements)
|
||||
{
|
||||
for (unsigned i = 0; i < elements.size(); i++) {
|
||||
q.push (elements[i]);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T> bool
|
||||
Util::contains (const vector<T>& v, const T& e)
|
||||
{
|
||||
return std::find (v.begin(), v.end(), e) != v.end();
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T> bool
|
||||
Util::contains (const set<T>& s, const T& e)
|
||||
{
|
||||
return s.find (e) != s.end();
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename K, typename V> bool
|
||||
Util::contains (const unordered_map<K, V>& m, const K& k)
|
||||
{
|
||||
return m.find (k) != m.end();
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T> std::string
|
||||
Util::toString (const T& t)
|
||||
{
|
||||
std::stringstream ss;
|
||||
ss << t;
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
@ -62,28 +150,31 @@ std::ostream& operator << (std::ostream& os, const vector<T>& v)
|
||||
}
|
||||
|
||||
|
||||
namespace {
|
||||
const double INF = -numeric_limits<double>::infinity();
|
||||
};
|
||||
|
||||
|
||||
inline void
|
||||
Util::logSum (double& x, double y)
|
||||
inline double
|
||||
Util::logSum (double x, double y)
|
||||
{
|
||||
x = log (exp (x) + exp (y)); return;
|
||||
return log (exp (x) + exp (y));
|
||||
assert (isfinite (x) && isfinite (y));
|
||||
// If one value is much smaller than the other, keep the larger value.
|
||||
if (x < (y - log (1e200))) {
|
||||
x = y;
|
||||
return;
|
||||
return y;
|
||||
}
|
||||
if (y < (x - log (1e200))) {
|
||||
return;
|
||||
return x;
|
||||
}
|
||||
double diff = x - y;
|
||||
assert (isfinite (diff) && isfinite (x) && isfinite (y));
|
||||
if (!isfinite (exp (diff))) { // difference is too large
|
||||
x = x > y ? x : y;
|
||||
} else { // otherwise return the sum.
|
||||
x = y + log (static_cast<double>(1.0) + exp (diff));
|
||||
if (!isfinite (exp (diff))) {
|
||||
// difference is too large
|
||||
return x > y ? x : y;
|
||||
}
|
||||
// otherwise return the sum.
|
||||
return y + log (static_cast<double>(1.0) + exp (diff));
|
||||
}
|
||||
|
||||
|
||||
@ -140,52 +231,87 @@ Util::add (Params& v1, const Params& v2, unsigned repetitions)
|
||||
|
||||
|
||||
|
||||
inline double
|
||||
Util::tl (double v)
|
||||
inline unsigned
|
||||
Util::maxUnsigned (void)
|
||||
{
|
||||
return Globals::logDomain ? log(v) : v;
|
||||
return numeric_limits<unsigned>::max();
|
||||
}
|
||||
|
||||
inline double
|
||||
Util::fl (double v)
|
||||
{
|
||||
return Globals::logDomain ? exp(v) : v;
|
||||
}
|
||||
|
||||
|
||||
namespace LogAware {
|
||||
|
||||
inline double
|
||||
Util::multIdenty() {
|
||||
return Globals::logDomain ? 0.0 : 1.0;
|
||||
}
|
||||
|
||||
inline double
|
||||
Util::addIdenty()
|
||||
{
|
||||
return Globals::logDomain ? INF : 0.0;
|
||||
}
|
||||
|
||||
inline double
|
||||
Util::withEvidence()
|
||||
one()
|
||||
{
|
||||
return Globals::logDomain ? 0.0 : 1.0;
|
||||
}
|
||||
|
||||
inline double
|
||||
Util::noEvidence() {
|
||||
return Globals::logDomain ? INF : 0.0;
|
||||
}
|
||||
|
||||
inline double
|
||||
Util::one()
|
||||
{
|
||||
return Globals::logDomain ? 0.0 : 1.0;
|
||||
}
|
||||
|
||||
inline double
|
||||
Util::zero() {
|
||||
zero() {
|
||||
return Globals::logDomain ? INF : 0.0 ;
|
||||
}
|
||||
|
||||
|
||||
inline double
|
||||
addIdenty()
|
||||
{
|
||||
return Globals::logDomain ? INF : 0.0;
|
||||
}
|
||||
|
||||
|
||||
inline double
|
||||
multIdenty()
|
||||
{
|
||||
return Globals::logDomain ? 0.0 : 1.0;
|
||||
}
|
||||
|
||||
|
||||
inline double
|
||||
withEvidence()
|
||||
{
|
||||
return Globals::logDomain ? 0.0 : 1.0;
|
||||
}
|
||||
|
||||
|
||||
inline double
|
||||
noEvidence() {
|
||||
return Globals::logDomain ? INF : 0.0;
|
||||
}
|
||||
|
||||
|
||||
inline double
|
||||
tl (double v)
|
||||
{
|
||||
return Globals::logDomain ? log (v) : v;
|
||||
}
|
||||
|
||||
|
||||
inline double
|
||||
fl (double v)
|
||||
{
|
||||
return Globals::logDomain ? exp (v) : v;
|
||||
}
|
||||
|
||||
|
||||
void normalize (Params&);
|
||||
|
||||
double getL1Distance (const Params&, const Params&);
|
||||
|
||||
double getMaxNorm (const Params&, const Params&);
|
||||
|
||||
double pow (double, unsigned);
|
||||
|
||||
double pow (double, double);
|
||||
|
||||
void pow (Params&, unsigned);
|
||||
|
||||
void pow (Params&, double);
|
||||
|
||||
};
|
||||
|
||||
|
||||
struct NetInfo
|
||||
{
|
||||
NetInfo (unsigned size, bool loopy, unsigned nIters, double time)
|
||||
@ -206,17 +332,17 @@ struct CompressInfo
|
||||
{
|
||||
CompressInfo (unsigned a, unsigned b, unsigned c, unsigned d, unsigned e)
|
||||
{
|
||||
nGroundVars = a;
|
||||
nGroundFactors = b;
|
||||
nClusterVars = c;
|
||||
nClusterFactors = d;
|
||||
nWithoutNeighs = e;
|
||||
nrGroundVars = a;
|
||||
nrGroundFactors = b;
|
||||
nrClusterVars = c;
|
||||
nrClusterFactors = d;
|
||||
nrNeighborless = e;
|
||||
}
|
||||
unsigned nGroundVars;
|
||||
unsigned nGroundFactors;
|
||||
unsigned nClusterVars;
|
||||
unsigned nClusterFactors;
|
||||
unsigned nWithoutNeighs;
|
||||
unsigned nrGroundVars;
|
||||
unsigned nrGroundFactors;
|
||||
unsigned nrClusterVars;
|
||||
unsigned nrClusterFactors;
|
||||
unsigned nrNeighborless;
|
||||
};
|
||||
|
||||
|
||||
@ -224,11 +350,17 @@ class Statistics
|
||||
{
|
||||
public:
|
||||
static unsigned getSolvedNetworksCounting (void);
|
||||
|
||||
static void incrementPrimaryNetworksCounting (void);
|
||||
|
||||
static unsigned getPrimaryNetworksCounting (void);
|
||||
|
||||
static void updateStatistics (unsigned, bool, unsigned, double);
|
||||
|
||||
static void printStatistics (void);
|
||||
static void writeStatisticsToFile (const char*);
|
||||
|
||||
static void writeStatistics (const char*);
|
||||
|
||||
static void updateCompressingStatistics (
|
||||
unsigned, unsigned, unsigned, unsigned, unsigned);
|
||||
|
||||
|
102
packages/CLPBN/clpbn/bp/Var.cpp
Normal file
102
packages/CLPBN/clpbn/bp/Var.cpp
Normal file
@ -0,0 +1,102 @@
|
||||
#include <algorithm>
|
||||
#include <sstream>
|
||||
|
||||
#include "Var.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
|
||||
unordered_map<VarId, VarInfo> Var::varsInfo_;
|
||||
|
||||
|
||||
Var::Var (const Var* v)
|
||||
{
|
||||
varId_ = v->varId();
|
||||
range_ = v->range();
|
||||
evidence_ = v->getEvidence();
|
||||
index_ = std::numeric_limits<unsigned>::max();
|
||||
}
|
||||
|
||||
|
||||
|
||||
Var::Var (VarId varId, unsigned range, int evidence)
|
||||
{
|
||||
assert (range != 0);
|
||||
assert (evidence < (int) range);
|
||||
varId_ = varId;
|
||||
range_ = range;
|
||||
evidence_ = evidence;
|
||||
index_ = std::numeric_limits<unsigned>::max();
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
Var::isValidState (int stateIndex)
|
||||
{
|
||||
return stateIndex >= 0 && stateIndex < (int) range_;
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
Var::isValidState (const string& stateName)
|
||||
{
|
||||
States states = Var::getVarInfo (varId_).states;
|
||||
return Util::contains (states, stateName);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Var::setEvidence (int ev)
|
||||
{
|
||||
assert (ev < (int) range_);
|
||||
evidence_ = ev;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Var::setEvidence (const string& ev)
|
||||
{
|
||||
States states = Var::getVarInfo (varId_).states;
|
||||
for (unsigned i = 0; i < states.size(); i++) {
|
||||
if (states[i] == ev) {
|
||||
evidence_ = i;
|
||||
return;
|
||||
}
|
||||
}
|
||||
assert (false);
|
||||
}
|
||||
|
||||
|
||||
|
||||
string
|
||||
Var::label (void) const
|
||||
{
|
||||
if (Var::varsHaveInfo()) {
|
||||
return Var::getVarInfo (varId_).label;
|
||||
}
|
||||
stringstream ss;
|
||||
ss << "x" << varId_;
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
|
||||
|
||||
States
|
||||
Var::states (void) const
|
||||
{
|
||||
if (Var::varsHaveInfo()) {
|
||||
return Var::getVarInfo (varId_).states;
|
||||
}
|
||||
States states;
|
||||
for (unsigned i = 0; i < range_; i++) {
|
||||
stringstream ss;
|
||||
ss << i ;
|
||||
states.push_back (ss.str());
|
||||
}
|
||||
return states;
|
||||
}
|
||||
|
108
packages/CLPBN/clpbn/bp/Var.h
Normal file
108
packages/CLPBN/clpbn/bp/Var.h
Normal file
@ -0,0 +1,108 @@
|
||||
#ifndef HORUS_Var_H
|
||||
#define HORUS_Var_H
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "Util.h"
|
||||
#include "Horus.h"
|
||||
|
||||
|
||||
using namespace std;
|
||||
|
||||
|
||||
struct VarInfo
|
||||
{
|
||||
VarInfo (string l, const States& sts) : label(l), states(sts) { }
|
||||
string label;
|
||||
States states;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class Var
|
||||
{
|
||||
public:
|
||||
Var (const Var*);
|
||||
|
||||
Var (VarId, unsigned, int = Constants::NO_EVIDENCE);
|
||||
|
||||
virtual ~Var (void) { };
|
||||
|
||||
unsigned varId (void) const { return varId_; }
|
||||
|
||||
unsigned range (void) const { return range_; }
|
||||
|
||||
int getEvidence (void) const { return evidence_; }
|
||||
|
||||
unsigned getIndex (void) const { return index_; }
|
||||
|
||||
void setIndex (unsigned idx) { index_ = idx; }
|
||||
|
||||
operator unsigned () const { return index_; }
|
||||
|
||||
bool hasEvidence (void) const
|
||||
{
|
||||
return evidence_ != Constants::NO_EVIDENCE;
|
||||
}
|
||||
|
||||
bool operator== (const Var& var) const
|
||||
{
|
||||
assert (!(varId_ == var.varId() && range_ != var.range()));
|
||||
return varId_ == var.varId();
|
||||
}
|
||||
|
||||
bool operator!= (const Var& var) const
|
||||
{
|
||||
assert (!(varId_ == var.varId() && range_ != var.range()));
|
||||
return varId_ != var.varId();
|
||||
}
|
||||
|
||||
bool isValidState (int);
|
||||
|
||||
bool isValidState (const string&);
|
||||
|
||||
void setEvidence (int);
|
||||
|
||||
void setEvidence (const string&);
|
||||
|
||||
string label (void) const;
|
||||
|
||||
States states (void) const;
|
||||
|
||||
static void addVarInfo (
|
||||
VarId vid, string label, const States& states)
|
||||
{
|
||||
assert (Util::contains (varsInfo_, vid) == false);
|
||||
varsInfo_.insert (make_pair (vid, VarInfo (label, states)));
|
||||
}
|
||||
|
||||
static VarInfo getVarInfo (VarId vid)
|
||||
{
|
||||
assert (Util::contains (varsInfo_, vid));
|
||||
return varsInfo_.find (vid)->second;
|
||||
}
|
||||
|
||||
static bool varsHaveInfo (void)
|
||||
{
|
||||
return varsInfo_.size() != 0;
|
||||
}
|
||||
|
||||
static void clearVarsInfo (void)
|
||||
{
|
||||
varsInfo_.clear();
|
||||
}
|
||||
|
||||
private:
|
||||
VarId varId_;
|
||||
unsigned range_;
|
||||
int evidence_;
|
||||
unsigned index_;
|
||||
|
||||
static unordered_map<VarId, VarInfo> varsInfo_;
|
||||
|
||||
};
|
||||
|
||||
#endif // BP_Var_H
|
||||
|
@ -6,61 +6,27 @@
|
||||
#include "Util.h"
|
||||
|
||||
|
||||
VarElimSolver::VarElimSolver (const BayesNet& bn) : Solver (&bn)
|
||||
{
|
||||
bayesNet_ = &bn;
|
||||
factorGraph_ = new FactorGraph (bn);
|
||||
}
|
||||
|
||||
|
||||
|
||||
VarElimSolver::VarElimSolver (const FactorGraph& fg) : Solver (&fg)
|
||||
{
|
||||
bayesNet_ = 0;
|
||||
factorGraph_ = &fg;
|
||||
}
|
||||
|
||||
|
||||
|
||||
VarElimSolver::~VarElimSolver (void)
|
||||
{
|
||||
if (bayesNet_) {
|
||||
delete factorGraph_;
|
||||
}
|
||||
delete factorList_.back();
|
||||
}
|
||||
|
||||
|
||||
|
||||
Params
|
||||
VarElimSolver::getPosterioriOf (VarId vid)
|
||||
{
|
||||
assert (factorGraph_->getFgVarNode (vid));
|
||||
FgVarNode* vn = factorGraph_->getFgVarNode (vid);
|
||||
if (vn->hasEvidence()) {
|
||||
Params params (vn->nrStates(), 0.0);
|
||||
params[vn->getEvidence()] = 1.0;
|
||||
return params;
|
||||
}
|
||||
return getJointDistributionOf (VarIds() = {vid});
|
||||
}
|
||||
|
||||
|
||||
|
||||
Params
|
||||
VarElimSolver::getJointDistributionOf (const VarIds& vids)
|
||||
VarElimSolver::solveQuery (VarIds queryVids)
|
||||
{
|
||||
factorList_.clear();
|
||||
varFactors_.clear();
|
||||
elimOrder_.clear();
|
||||
createFactorList();
|
||||
introduceEvidence();
|
||||
chooseEliminationOrder (vids);
|
||||
processFactorList (vids);
|
||||
Params params = factorList_.back()->getParameters();
|
||||
absorveEvidence();
|
||||
findEliminationOrder (queryVids);
|
||||
processFactorList (queryVids);
|
||||
Params params = factorList_.back()->params();
|
||||
if (Globals::logDomain) {
|
||||
Util::fromLog (params);
|
||||
}
|
||||
delete factorList_.back();
|
||||
return params;
|
||||
}
|
||||
|
||||
@ -69,11 +35,11 @@ VarElimSolver::getJointDistributionOf (const VarIds& vids)
|
||||
void
|
||||
VarElimSolver::createFactorList (void)
|
||||
{
|
||||
const FgFacSet& factorNodes = factorGraph_->getFactorNodes();
|
||||
factorList_.reserve (factorNodes.size() * 2);
|
||||
for (unsigned i = 0; i < factorNodes.size(); i++) {
|
||||
factorList_.push_back (new Factor (*factorNodes[i]->factor()));
|
||||
const FgVarSet& neighs = factorNodes[i]->neighbors();
|
||||
const FacNodes& facNodes = fg.facNodes();
|
||||
factorList_.reserve (facNodes.size() * 2);
|
||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||
factorList_.push_back (new Factor (facNodes[i]->factor()));
|
||||
const VarNodes& neighs = facNodes[i]->neighbors();
|
||||
for (unsigned j = 0; j < neighs.size(); j++) {
|
||||
unordered_map<VarId,vector<unsigned> >::iterator it
|
||||
= varFactors_.find (neighs[j]->varId());
|
||||
@ -89,16 +55,16 @@ VarElimSolver::createFactorList (void)
|
||||
|
||||
|
||||
void
|
||||
VarElimSolver::introduceEvidence (void)
|
||||
VarElimSolver::absorveEvidence (void)
|
||||
{
|
||||
const FgVarSet& varNodes = factorGraph_->getVarNodes();
|
||||
const VarNodes& varNodes = fg.varNodes();
|
||||
for (unsigned i = 0; i < varNodes.size(); i++) {
|
||||
if (varNodes[i]->hasEvidence()) {
|
||||
const vector<unsigned>& idxs =
|
||||
varFactors_.find (varNodes[i]->varId())->second;
|
||||
for (unsigned j = 0; j < idxs.size(); j++) {
|
||||
Factor* factor = factorList_[idxs[j]];
|
||||
if (factor->nrVariables() == 1) {
|
||||
if (factor->nrArguments() == 1) {
|
||||
factorList_[idxs[j]] = 0;
|
||||
} else {
|
||||
factorList_[idxs[j]]->absorveEvidence (
|
||||
@ -112,21 +78,9 @@ VarElimSolver::introduceEvidence (void)
|
||||
|
||||
|
||||
void
|
||||
VarElimSolver::chooseEliminationOrder (const VarIds& vids)
|
||||
VarElimSolver::findEliminationOrder (const VarIds& vids)
|
||||
{
|
||||
if (bayesNet_) {
|
||||
ElimGraph graph (*bayesNet_);
|
||||
elimOrder_ = graph.getEliminatingOrder (vids);
|
||||
} else {
|
||||
const FgVarSet& varNodes = factorGraph_->getVarNodes();
|
||||
for (unsigned i = 0; i < varNodes.size(); i++) {
|
||||
VarId vid = varNodes[i]->varId();
|
||||
if (std::find (vids.begin(), vids.end(), vid) == vids.end()
|
||||
&& !varNodes[i]->hasEvidence()) {
|
||||
elimOrder_.push_back (vid);
|
||||
}
|
||||
}
|
||||
}
|
||||
elimOrder_ = ElimGraph::getEliminationOrder (factorList_, vids);
|
||||
}
|
||||
|
||||
|
||||
@ -149,12 +103,12 @@ VarElimSolver::processFactorList (const VarIds& vids)
|
||||
|
||||
VarIds unobservedVids;
|
||||
for (unsigned i = 0; i < vids.size(); i++) {
|
||||
if (factorGraph_->getFgVarNode (vids[i])->hasEvidence() == false) {
|
||||
if (fg.getVarNode (vids[i])->hasEvidence() == false) {
|
||||
unobservedVids.push_back (vids[i]);
|
||||
}
|
||||
}
|
||||
|
||||
finalFactor->reorderVariables (unobservedVids);
|
||||
finalFactor->reorderArguments (unobservedVids);
|
||||
finalFactor->normalize();
|
||||
factorList_.push_back (finalFactor);
|
||||
}
|
||||
@ -165,13 +119,12 @@ void
|
||||
VarElimSolver::eliminate (VarId elimVar)
|
||||
{
|
||||
Factor* result = 0;
|
||||
FgVarNode* vn = factorGraph_->getFgVarNode (elimVar);
|
||||
vector<unsigned>& idxs = varFactors_.find (elimVar)->second;
|
||||
for (unsigned i = 0; i < idxs.size(); i++) {
|
||||
unsigned idx = idxs[i];
|
||||
if (factorList_[idx]) {
|
||||
if (result == 0) {
|
||||
result = new Factor(*factorList_[idx]);
|
||||
result = new Factor (*factorList_[idx]);
|
||||
} else {
|
||||
result->multiply (*factorList_[idx]);
|
||||
}
|
||||
@ -179,10 +132,10 @@ VarElimSolver::eliminate (VarId elimVar)
|
||||
factorList_[idx] = 0;
|
||||
}
|
||||
}
|
||||
if (result != 0 && result->nrVariables() != 1) {
|
||||
result->sumOut (vn->varId());
|
||||
if (result != 0 && result->nrArguments() != 1) {
|
||||
result->sumOut (elimVar);
|
||||
factorList_.push_back (result);
|
||||
const VarIds& resultVarIds = result->getVarIds();
|
||||
const VarIds& resultVarIds = result->arguments();
|
||||
for (unsigned i = 0; i < resultVarIds.size(); i++) {
|
||||
vector<unsigned>& idxs =
|
||||
varFactors_.find (resultVarIds[i])->second;
|
||||
@ -199,7 +152,6 @@ VarElimSolver::printActiveFactors (void)
|
||||
for (unsigned i = 0; i < factorList_.size(); i++) {
|
||||
if (factorList_[i] != 0) {
|
||||
factorList_[i]->print();
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -5,7 +5,6 @@
|
||||
|
||||
#include "Solver.h"
|
||||
#include "FactorGraph.h"
|
||||
#include "BayesNet.h"
|
||||
#include "Horus.h"
|
||||
|
||||
|
||||
@ -15,23 +14,25 @@ using namespace std;
|
||||
class VarElimSolver : public Solver
|
||||
{
|
||||
public:
|
||||
VarElimSolver (const BayesNet&);
|
||||
VarElimSolver (const FactorGraph&);
|
||||
VarElimSolver (const FactorGraph& fg) : Solver (fg) { }
|
||||
|
||||
~VarElimSolver (void);
|
||||
void runSolver (void) { }
|
||||
Params getPosterioriOf (VarId);
|
||||
Params getJointDistributionOf (const VarIds&);
|
||||
|
||||
Params solveQuery (VarIds);
|
||||
|
||||
private:
|
||||
void createFactorList (void);
|
||||
void introduceEvidence (void);
|
||||
void chooseEliminationOrder (const VarIds&);
|
||||
|
||||
void absorveEvidence (void);
|
||||
|
||||
void findEliminationOrder (const VarIds&);
|
||||
|
||||
void processFactorList (const VarIds&);
|
||||
|
||||
void eliminate (VarId);
|
||||
|
||||
void printActiveFactors (void);
|
||||
|
||||
const BayesNet* bayesNet_;
|
||||
const FactorGraph* factorGraph_;
|
||||
vector<Factor*> factorList_;
|
||||
VarIds elimOrder_;
|
||||
unordered_map<VarId, vector<unsigned>> varFactors_;
|
||||
|
@ -1,100 +0,0 @@
|
||||
#include <algorithm>
|
||||
#include <sstream>
|
||||
|
||||
#include "VarNode.h"
|
||||
#include "GraphicalModel.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
|
||||
VarNode::VarNode (const VarNode* v)
|
||||
{
|
||||
varId_ = v->varId();
|
||||
nrStates_ = v->nrStates();
|
||||
evidence_ = v->getEvidence();
|
||||
index_ = std::numeric_limits<unsigned>::max();
|
||||
}
|
||||
|
||||
|
||||
|
||||
VarNode::VarNode (VarId varId, unsigned nrStates, int evidence)
|
||||
{
|
||||
assert (nrStates != 0);
|
||||
assert (evidence < (int) nrStates);
|
||||
varId_ = varId;
|
||||
nrStates_ = nrStates;
|
||||
evidence_ = evidence;
|
||||
index_ = std::numeric_limits<unsigned>::max();
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
VarNode::isValidState (int stateIndex)
|
||||
{
|
||||
return stateIndex >= 0 && stateIndex < (int) nrStates_;
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
VarNode::isValidState (const string& stateName)
|
||||
{
|
||||
States states = GraphicalModel::getVariableInformation (varId_).states;
|
||||
return find (states.begin(), states.end(), stateName) != states.end();
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
VarNode::setEvidence (int ev)
|
||||
{
|
||||
assert (ev < (int) nrStates_);
|
||||
evidence_ = ev;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
VarNode::setEvidence (const string& ev)
|
||||
{
|
||||
States states = GraphicalModel::getVariableInformation (varId_).states;
|
||||
for (unsigned i = 0; i < states.size(); i++) {
|
||||
if (states[i] == ev) {
|
||||
evidence_ = i;
|
||||
return;
|
||||
}
|
||||
}
|
||||
assert (false);
|
||||
}
|
||||
|
||||
|
||||
|
||||
string
|
||||
VarNode::label (void) const
|
||||
{
|
||||
if (GraphicalModel::variablesHaveInformation()) {
|
||||
return GraphicalModel::getVariableInformation (varId_).label;
|
||||
}
|
||||
stringstream ss;
|
||||
ss << "x" << varId_;
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
|
||||
|
||||
States
|
||||
VarNode::states (void) const
|
||||
{
|
||||
if (GraphicalModel::variablesHaveInformation()) {
|
||||
return GraphicalModel::getVariableInformation (varId_).states;
|
||||
}
|
||||
States states;
|
||||
for (unsigned i = 0; i < nrStates_; i++) {
|
||||
stringstream ss;
|
||||
ss << i ;
|
||||
states.push_back (ss.str());
|
||||
}
|
||||
return states;
|
||||
}
|
||||
|
@ -1,54 +0,0 @@
|
||||
#ifndef HORUS_VARNODE_H
|
||||
#define HORUS_VARNODE_H
|
||||
|
||||
#include "Horus.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
class VarNode
|
||||
{
|
||||
public:
|
||||
VarNode (const VarNode*);
|
||||
VarNode (VarId, unsigned, int = NO_EVIDENCE);
|
||||
virtual ~VarNode (void) {};
|
||||
|
||||
bool isValidState (int);
|
||||
bool isValidState (const string&);
|
||||
void setEvidence (int);
|
||||
void setEvidence (const string&);
|
||||
string label (void) const;
|
||||
States states (void) const;
|
||||
|
||||
unsigned varId (void) const { return varId_; }
|
||||
unsigned nrStates (void) const { return nrStates_; }
|
||||
bool hasEvidence (void) const { return evidence_ != NO_EVIDENCE; }
|
||||
int getEvidence (void) const { return evidence_; }
|
||||
unsigned getIndex (void) const { return index_; }
|
||||
void setIndex (unsigned idx) { index_ = idx; }
|
||||
|
||||
operator unsigned () const { return index_; }
|
||||
|
||||
bool operator== (const VarNode& var) const
|
||||
{
|
||||
cout << "equal operator called" << endl;
|
||||
assert (!(varId_ == var.varId() && nrStates_ != var.nrStates()));
|
||||
return varId_ == var.varId();
|
||||
}
|
||||
|
||||
bool operator!= (const VarNode& var) const
|
||||
{
|
||||
cout << "diff operator called" << endl;
|
||||
assert (!(varId_ == var.varId() && nrStates_ != var.nrStates()));
|
||||
return varId_ != var.varId();
|
||||
}
|
||||
|
||||
private:
|
||||
VarId varId_;
|
||||
unsigned nrStates_;
|
||||
int evidence_;
|
||||
unsigned index_;
|
||||
|
||||
};
|
||||
|
||||
#endif // BP_VARNODE_H
|
||||
|
35
packages/CLPBN/clpbn/bp/benchmarks/benchs.sh
Executable file
35
packages/CLPBN/clpbn/bp/benchmarks/benchs.sh
Executable file
@ -0,0 +1,35 @@
|
||||
|
||||
if [ $1 ] && [ $1 == "clear" ]; then
|
||||
rm *~
|
||||
rm -f school/*.log school/*~
|
||||
rm -f city/*.log city/*~
|
||||
rm -f workshop_attrs/*.log workshop_attrs/*~
|
||||
fi
|
||||
|
||||
function run_solver
|
||||
{
|
||||
constraint=$1
|
||||
solver_flag=true
|
||||
if [ -n "$2" ]; then
|
||||
if [ $SOLVER = hve ]; then
|
||||
extra_flag=clpbn_horus:set_horus_flag\(elim_heuristic,$2\)
|
||||
elif [ $SOLVER = bp ]; then
|
||||
extra_flag=clpbn_horus:set_horus_flag\(schedule,$2\)
|
||||
elif [ $SOLVER = cbp ]; then
|
||||
extra_flag=clpbn_horus:set_horus_flag\(schedule,$2\)
|
||||
else
|
||||
echo "unknow flag $2"
|
||||
fi
|
||||
fi
|
||||
/usr/bin/time -o $LOG_FILE -a -f "real:%E\tuser:%U\tsys:%S" \
|
||||
$YAP << EOF >> $LOG_FILE 2>> ignore.$LOG_FILE
|
||||
[$NETWORK].
|
||||
[$constraint].
|
||||
clpbn_horus:set_solver($SOLVER).
|
||||
clpbn_horus:set_horus_flag(use_logarithms, true).
|
||||
$solver_flag.
|
||||
$QUERY.
|
||||
open("$LOG_FILE", 'append', S), format(S, '$constraint: ~15+ ', []), close(S).
|
||||
EOF
|
||||
}
|
||||
|
@ -1,50 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
cp ~/bin/yap ~/bin/town_bnbp
|
||||
YAP=~/bin/town_bnbp
|
||||
|
||||
#OUT_FILE_NAME=results`date "+ %H:%M:%S %d-%m-%Y"`.log
|
||||
OUT_FILE_NAME=bnbp.log
|
||||
rm -f $OUT_FILE_NAME
|
||||
rm -f ignore.$OUT_FILE_NAME
|
||||
|
||||
|
||||
function run_solver
|
||||
{
|
||||
if [ $2 = bp ]
|
||||
then
|
||||
extra_flag1=clpbn_bp:set_horus_flag\(inf_alg,$4\)
|
||||
extra_flag2=clpbn_bp:set_horus_flag\(schedule,$5\)
|
||||
else
|
||||
extra_flag1=true
|
||||
extra_flag2=true
|
||||
fi
|
||||
/usr/bin/time -o $OUT_FILE_NAME -a -f "real:%E\tuser:%U\tsys:%S" $YAP << EOF >> $OUT_FILE_NAME 2>> ignore.$OUT_FILE_NAME
|
||||
[$1].
|
||||
clpbn:set_clpbn_flag(solver,$2),
|
||||
clpbn_bp:set_horus_flag(use_logarithms, true),
|
||||
$extra_flag1, $extra_flag2,
|
||||
run_query(_R),
|
||||
open("$OUT_FILE_NAME", 'append',S),
|
||||
format(S, '$3: ~15+ ',[]),
|
||||
close(S).
|
||||
EOF
|
||||
}
|
||||
|
||||
|
||||
function run_all_graphs
|
||||
{
|
||||
echo "*******************************************************************" >> "$OUT_FILE_NAME"
|
||||
echo "results for solver $2" >> $OUT_FILE_NAME
|
||||
echo "*******************************************************************" >> "$OUT_FILE_NAME"
|
||||
run_solver town_1000 $1 town_1000 $3 $4 $5
|
||||
run_solver town_5000 $1 town_5000 $3 $4 $5
|
||||
run_solver town_10000 $1 town_10000 $3 $4 $5
|
||||
run_solver town_50000 $1 town_50000 $3 $4 $5
|
||||
run_solver town_100000 $1 town_100000 $3 $4 $5
|
||||
run_solver town_500000 $1 town_500000 $3 $4 $5
|
||||
run_solver town_1000000 $1 town_1000000 $3 $4 $5
|
||||
}
|
||||
|
||||
run_all_graphs bp "bn_bp(seq_fixed) " bn_bp seq_fixed
|
||||
|
17
packages/CLPBN/clpbn/bp/benchmarks/city/bp_tests.sh
Executable file
17
packages/CLPBN/clpbn/bp/benchmarks/city/bp_tests.sh
Executable file
@ -0,0 +1,17 @@
|
||||
#!/bin/bash
|
||||
|
||||
source city.sh
|
||||
source ../benchs.sh
|
||||
|
||||
SOLVER="bp"
|
||||
|
||||
YAP=~/bin/$SHORTNAME-$SOLVER
|
||||
|
||||
LOG_FILE=$SOLVER.log
|
||||
#LOG_FILE=results`date "+ %H:%M:%S %d-%m-%Y"`.
|
||||
|
||||
rm -f $LOG_FILE
|
||||
rm -f ignore.$LOG_FILE
|
||||
|
||||
run_all_graphs "bp(shedule=seq_fixed) " seq_fixed
|
||||
|
@ -1,54 +1,17 @@
|
||||
#!/bin/bash
|
||||
|
||||
cp ~/bin/yap ~/bin/town_cbp
|
||||
YAP=~/bin/town_cbp
|
||||
source city.sh
|
||||
source ../benchs.sh
|
||||
|
||||
#OUT_FILE_NAME=results`date "+ %H:%M:%S %d-%m-%Y"`.log
|
||||
OUT_FILE_NAME=cbp.log
|
||||
rm -f $OUT_FILE_NAME
|
||||
rm -f ignore.$OUT_FILE_NAME
|
||||
SOLVER="cbp"
|
||||
|
||||
YAP=~/bin/$SHORTNAME-$SOLVER
|
||||
|
||||
function run_solver
|
||||
{
|
||||
if [ $2 = bp ]
|
||||
then
|
||||
extra_flag1=clpbn_bp:set_horus_flag\(inf_alg,$4\)
|
||||
extra_flag2=clpbn_bp:set_horus_flag\(schedule,$5\)
|
||||
else
|
||||
extra_flag1=true
|
||||
extra_flag2=true
|
||||
fi
|
||||
/usr/bin/time -o $OUT_FILE_NAME -a -f "real:%E\tuser:%U\tsys:%S" $YAP << EOF >> $OUT_FILE_NAME 2>> ignore.$OUT_FILE_NAME
|
||||
[$1].
|
||||
clpbn:set_clpbn_flag(solver,$2),
|
||||
clpbn_bp:set_horus_flag(use_logarithms, true),
|
||||
$extra_flag1, $extra_flag2,
|
||||
run_query(_R),
|
||||
open("$OUT_FILE_NAME", 'append',S),
|
||||
format(S, '$3: ~15+ ',[]),
|
||||
close(S).
|
||||
EOF
|
||||
}
|
||||
LOG_FILE=$SOLVER.log
|
||||
#LOG_FILE=results`date "+ %H:%M:%S %d-%m-%Y"`.
|
||||
|
||||
rm -f $LOG_FILE
|
||||
rm -f ignore.$LOG_FILE
|
||||
|
||||
function run_all_graphs
|
||||
{
|
||||
echo "*******************************************************************" >> "$OUT_FILE_NAME"
|
||||
echo "results for solver $2" >> $OUT_FILE_NAME
|
||||
echo "*******************************************************************" >> "$OUT_FILE_NAME"
|
||||
run_solver town_1000 $1 town_1000 $3 $4 $5
|
||||
run_solver town_5000 $1 town_5000 $3 $4 $5
|
||||
run_solver town_10000 $1 town_10000 $3 $4 $5
|
||||
run_solver town_50000 $1 town_50000 $3 $4 $5
|
||||
run_solver town_100000 $1 town_100000 $3 $4 $5
|
||||
run_solver town_500000 $1 town_500000 $3 $4 $5
|
||||
run_solver town_1000000 $1 town_1000000 $3 $4 $5
|
||||
run_solver town_2500000 $1 town_2500000 $3 $4 $5
|
||||
run_solver town_5000000 $1 town_5000000 $3 $4 $5
|
||||
run_solver town_7500000 $1 town_7500000 $3 $4 $5
|
||||
run_solver town_10000000 $1 town_10000000 $3 $4 $5
|
||||
}
|
||||
|
||||
run_all_graphs bp "cbp(seq_fixed) " cbp seq_fixed
|
||||
run_all_graphs "cbp(shedule=seq_fixed) " seq_fixed
|
||||
|
||||
|
25
packages/CLPBN/clpbn/bp/benchmarks/city/city.sh
Executable file
25
packages/CLPBN/clpbn/bp/benchmarks/city/city.sh
Executable file
@ -0,0 +1,25 @@
|
||||
#!/bin/bash
|
||||
|
||||
NETWORK="'../../examples/city'"
|
||||
SHORTNAME="city"
|
||||
QUERY="is_joe_guilty(X)"
|
||||
|
||||
|
||||
function run_all_graphs
|
||||
{
|
||||
cp ~/bin/yap $YAP
|
||||
echo -n "**********************************" >> $LOG_FILE
|
||||
echo "**********************************" >> $LOG_FILE
|
||||
echo "results for solver $1" >> $LOG_FILE
|
||||
echo -n "**********************************" >> $LOG_FILE
|
||||
echo "**********************************" >> $LOG_FILE
|
||||
run_solver city_5 $2
|
||||
#run_solver city_1000 $2
|
||||
#run_solver city_5000 $2
|
||||
#run_solver city_10000 $2
|
||||
#run_solver city_50000 $2
|
||||
#run_solver city_100000 $2
|
||||
#run_solver city_500000 $2
|
||||
#run_solver city_1000000 $2
|
||||
}
|
||||
|
37
packages/CLPBN/clpbn/bp/benchmarks/city/city_generator.sh
Executable file
37
packages/CLPBN/clpbn/bp/benchmarks/city/city_generator.sh
Executable file
@ -0,0 +1,37 @@
|
||||
#!/home/tiago/bin/yap -L --
|
||||
|
||||
|
||||
:- initialization(main).
|
||||
|
||||
|
||||
main :-
|
||||
unix(argv([H])),
|
||||
generate_town(H).
|
||||
|
||||
|
||||
generate_town(N) :-
|
||||
atomic_concat(['city_', N, '.yap'], FileName),
|
||||
open(FileName, 'write', S),
|
||||
atom_number(N, N2),
|
||||
generate_people(S, N2, 4),
|
||||
write(S, '\n'),
|
||||
generate_query(S, N2, 4),
|
||||
write(S, '\n'),
|
||||
close(S).
|
||||
|
||||
|
||||
generate_people(S, N, Counting) :-
|
||||
Counting > N, !.
|
||||
generate_people(S, N, Counting) :-
|
||||
format(S, 'people(p~w, nyc).~n', [Counting]),
|
||||
Counting1 is Counting + 1,
|
||||
generate_people(S, N, Counting1).
|
||||
|
||||
|
||||
generate_query(S, N, Counting) :-
|
||||
Counting > N, !.
|
||||
generate_query(S, N, Counting) :- !,
|
||||
format(S, 'ev(descn(p~w, t)).~n', [Counting]),
|
||||
Counting1 is Counting + 1,
|
||||
generate_query(S, N, Counting1).
|
||||
|
@ -1,50 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
cp ~/bin/yap ~/bin/town_fgbp
|
||||
YAP=~/bin/town_fgbp
|
||||
|
||||
#OUT_FILE_NAME=results`date "+ %H:%M:%S %d-%m-%Y"`.log
|
||||
OUT_FILE_NAME=fb_bp.log
|
||||
rm -f $OUT_FILE_NAME
|
||||
rm -f ignore.$OUT_FILE_NAME
|
||||
|
||||
|
||||
function run_solver
|
||||
{
|
||||
if [ $2 = bp ]
|
||||
then
|
||||
extra_flag1=clpbn_bp:set_horus_flag\(inf_alg,$4\)
|
||||
extra_flag2=clpbn_bp:set_horus_flag\(schedule,$5\)
|
||||
else
|
||||
extra_flag1=true
|
||||
extra_flag2=true
|
||||
fi
|
||||
/usr/bin/time -o $OUT_FILE_NAME -a -f "real:%E\tuser:%U\tsys:%S" $YAP << EOF >> $OUT_FILE_NAME 2>> ignore.$OUT_FILE_NAME
|
||||
[$1].
|
||||
clpbn:set_clpbn_flag(solver,$2),
|
||||
clpbn_bp:set_horus_flag(use_logarithms, true),
|
||||
$extra_flag1, $extra_flag2,
|
||||
run_query(_R),
|
||||
open("$OUT_FILE_NAME", 'append',S),
|
||||
format(S, '$3: ~15+ ',[]),
|
||||
close(S).
|
||||
EOF
|
||||
}
|
||||
|
||||
|
||||
function run_all_graphs
|
||||
{
|
||||
echo "*******************************************************************" >> "$OUT_FILE_NAME"
|
||||
echo "results for solver $2" >> $OUT_FILE_NAME
|
||||
echo "*******************************************************************" >> "$OUT_FILE_NAME"
|
||||
run_solver town_1000 $1 town_1000 $3 $4 $5
|
||||
#run_solver town_5000 $1 town_5000 $3 $4 $5
|
||||
#run_solver town_10000 $1 town_10000 $3 $4 $5
|
||||
#run_solver town_50000 $1 town_50000 $3 $4 $5
|
||||
#run_solver town_100000 $1 town_100000 $3 $4 $5
|
||||
#run_solver town_500000 $1 town_500000 $3 $4 $5
|
||||
#run_solver town_1000000 $1 town_1000000 $3 $4 $5
|
||||
}
|
||||
|
||||
run_all_graphs bp "fg_bp(seq_fixed) " fg_bp seq_fixed
|
||||
|
17
packages/CLPBN/clpbn/bp/benchmarks/city/fove_tests.sh
Executable file
17
packages/CLPBN/clpbn/bp/benchmarks/city/fove_tests.sh
Executable file
@ -0,0 +1,17 @@
|
||||
#!/bin/bash
|
||||
|
||||
source city.sh
|
||||
source ../benchs.sh
|
||||
|
||||
SOLVER="fove"
|
||||
|
||||
YAP=~/bin/$SHORTNAME-$SOLVER
|
||||
|
||||
LOG_FILE=$SOLVER.log
|
||||
#LOG_FILE=results`date "+ %H:%M:%S %d-%m-%Y"`.
|
||||
|
||||
rm -f $LOG_FILE
|
||||
rm -f ignore.$LOG_FILEE
|
||||
|
||||
run_all_graphs "fove "
|
||||
|
@ -1,50 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
cp ~/bin/yap ~/bin/town_gibbs
|
||||
YAP=~/bin/town_gibbs
|
||||
|
||||
#OUT_FILE_NAME=results`date "+ %H:%M:%S %d-%m-%Y"`.log
|
||||
OUT_FILE_NAME=gibbs.log
|
||||
rm -f $OUT_FILE_NAME
|
||||
rm -f ignore.$OUT_FILE_NAME
|
||||
|
||||
|
||||
function run_solver
|
||||
{
|
||||
if [ $2 = bp ]
|
||||
then
|
||||
extra_flag1=clpbn_bp:set_horus_flag\(inf_alg,$4\)
|
||||
extra_flag2=clpbn_bp:set_horus_flag\(schedule,$5\)
|
||||
else
|
||||
extra_flag1=true
|
||||
extra_flag2=true
|
||||
fi
|
||||
/usr/bin/time -o $OUT_FILE_NAME -a -f "real:%E\tuser:%U\tsys:%S" $YAP << EOF >> $OUT_FILE_NAME 2>> ignore.$OUT_FILE_NAME
|
||||
[$1].
|
||||
clpbn:set_clpbn_flag(solver,$2),
|
||||
clpbn_bp:set_horus_flag(use_logarithms, true),
|
||||
$extra_flag1, $extra_flag2,
|
||||
run_query(_R),
|
||||
open("$OUT_FILE_NAME", 'append',S),
|
||||
format(S, '$3: ~15+ ',[]),
|
||||
close(S).
|
||||
EOF
|
||||
}
|
||||
|
||||
|
||||
function run_all_graphs
|
||||
{
|
||||
echo "*******************************************************************" >> "$OUT_FILE_NAME"
|
||||
echo "results for solver $2" >> $OUT_FILE_NAME
|
||||
echo "*******************************************************************" >> "$OUT_FILE_NAME"
|
||||
run_solver town_1000 $1 town_1000 $3 $4 $5
|
||||
run_solver town_5000 $1 town_5000 $3 $4 $5
|
||||
run_solver town_10000 $1 town_10000 $3 $4 $5
|
||||
run_solver town_50000 $1 town_50000 $3 $4 $5
|
||||
run_solver town_100000 $1 town_100000 $3 $4 $5
|
||||
run_solver town_500000 $1 town_500000 $3 $4 $5
|
||||
run_solver town_1000000 $1 town_1000000 $3 $4 $5
|
||||
}
|
||||
|
||||
run_all_graphs gibbs "gibbs "
|
||||
|
17
packages/CLPBN/clpbn/bp/benchmarks/city/hve_tests.sh
Executable file
17
packages/CLPBN/clpbn/bp/benchmarks/city/hve_tests.sh
Executable file
@ -0,0 +1,17 @@
|
||||
#!/bin/bash
|
||||
|
||||
source city.sh
|
||||
source ../benchs.sh
|
||||
|
||||
SOLVER="hve"
|
||||
|
||||
YAP=~/bin/$SHORTNAME-$SOLVER
|
||||
|
||||
LOG_FILE=$SOLVER.log
|
||||
#LOG_FILE=results`date "+ %H:%M:%S %d-%m-%Y"`.
|
||||
|
||||
rm -f $LOG_FILE
|
||||
rm -f ignore.$LOG_FILE
|
||||
|
||||
run_all_graphs "hve(elim_heuristic=min_neighbors) " min_neighbors
|
||||
|
@ -1,50 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
cp ~/bin/yap ~/bin/town_jt
|
||||
YAP=~/bin/town_jt
|
||||
|
||||
#OUT_FILE_NAME=results`date "+ %H:%M:%S %d-%m-%Y"`.log
|
||||
OUT_FILE_NAME=jt.log
|
||||
rm -f $OUT_FILE_NAME
|
||||
rm -f ignore.$OUT_FILE_NAME
|
||||
|
||||
|
||||
function run_solver
|
||||
{
|
||||
if [ $2 = bp ]
|
||||
then
|
||||
extra_flag1=clpbn_bp:set_horus_flag\(inf_alg,$4\)
|
||||
extra_flag2=clpbn_bp:set_horus_flag\(schedule,$5\)
|
||||
else
|
||||
extra_flag1=true
|
||||
extra_flag2=true
|
||||
fi
|
||||
/usr/bin/time -o $OUT_FILE_NAME -a -f "real:%E\tuser:%U\tsys:%S" $YAP << EOF >> $OUT_FILE_NAME 2>> ignore.$OUT_FILE_NAME
|
||||
[$1].
|
||||
clpbn:set_clpbn_flag(solver,$2),
|
||||
clpbn_bp:set_horus_flag(use_logarithms, true),
|
||||
$extra_flag1, $extra_flag2,
|
||||
run_query(_R),
|
||||
open("$OUT_FILE_NAME", 'append',S),
|
||||
format(S, '$3: ~15+ ',[]),
|
||||
close(S).
|
||||
EOF
|
||||
}
|
||||
|
||||
|
||||
function run_all_graphs
|
||||
{
|
||||
echo "*******************************************************************" >> "$OUT_FILE_NAME"
|
||||
echo "results for solver $2" >> $OUT_FILE_NAME
|
||||
echo "*******************************************************************" >> "$OUT_FILE_NAME"
|
||||
run_solver town_1000 $1 town_1000 $3 $4 $5
|
||||
run_solver town_5000 $1 town_5000 $3 $4 $5
|
||||
run_solver town_10000 $1 town_10000 $3 $4 $5
|
||||
run_solver town_50000 $1 town_50000 $3 $4 $5
|
||||
run_solver town_100000 $1 town_100000 $3 $4 $5
|
||||
run_solver town_500000 $1 town_500000 $3 $4 $5
|
||||
run_solver town_1000000 $1 town_1000000 $3 $4 $5
|
||||
}
|
||||
|
||||
run_all_graphs jt "jt "
|
||||
|
@ -1,65 +0,0 @@
|
||||
|
||||
conservative_city(City, Cons) :-
|
||||
cons_table(City, ConsDist),
|
||||
{ Cons = conservative_city(City) with p([y,n], ConsDist) }.
|
||||
|
||||
|
||||
gender(X, Gender) :-
|
||||
gender_table(X, GenderDist),
|
||||
{ Gender = gender(X) with p([m,f], GenderDist) }.
|
||||
|
||||
|
||||
hair_color(X, Color) :-
|
||||
lives(X, City),
|
||||
conservative_city(City, Cons),
|
||||
hair_color_table(X,ColorTable),
|
||||
{ Color = hair_color(X) with
|
||||
p([t,f], ColorTable,[Cons]) }.
|
||||
|
||||
|
||||
car_color(X, Color) :-
|
||||
hair_color(X, HColor),
|
||||
car_color_table(X,CColorTable),
|
||||
{ Color = car_color(X) with
|
||||
p([t,f], CColorTable,[HColor]) }.
|
||||
|
||||
|
||||
height(X, Height) :-
|
||||
gender(X, Gender),
|
||||
height_table(X,HeightTable),
|
||||
{ Height = height(X) with
|
||||
p([t,f], HeightTable,[Gender]) }.
|
||||
|
||||
|
||||
shoe_size(X, Shoesize) :-
|
||||
height(X, Height),
|
||||
shoe_size_table(X,ShoesizeTable),
|
||||
{ Shoesize = shoe_size(X) with
|
||||
p([t,f], ShoesizeTable,[Height]) }.
|
||||
|
||||
|
||||
guilty(X, Guilt) :-
|
||||
guilty_table(X, GuiltDist),
|
||||
{ Guilt = guilty(X) with p([y,n], GuiltDist) }.
|
||||
|
||||
|
||||
descn(X, Descn) :-
|
||||
car_color(X, Car),
|
||||
hair_color(X, Hair),
|
||||
height(X, Height),
|
||||
guilty(X, Guilt),
|
||||
descn_table(X, DescTable),
|
||||
{ Descn = descn(X) with
|
||||
p([t,f], DescTable,[Car,Hair,Height,Guilt]) }.
|
||||
|
||||
|
||||
witness(City, Witness) :-
|
||||
descn(joe, DescnJ),
|
||||
descn(p2, Descn2),
|
||||
wit_table(WitTable),
|
||||
{ Witness = witness(City) with
|
||||
p([t,f], WitTable,[DescnJ, Descn2]) }.
|
||||
|
||||
|
||||
:- ensure_loaded(tables).
|
||||
|
@ -1,46 +0,0 @@
|
||||
|
||||
cons_table(amsterdam, [0.2, 0.8]) :- !.
|
||||
cons_table(_, [0.8, 0.2]).
|
||||
|
||||
|
||||
gender_table(_, [0.55, 0.44]).
|
||||
|
||||
|
||||
hair_color_table(_,
|
||||
/* conservative_city */
|
||||
/* y n */
|
||||
[ 0.05, 0.1,
|
||||
0.95, 0.9 ]).
|
||||
|
||||
|
||||
car_color_table(_,
|
||||
/* t f */
|
||||
[ 0.9, 0.2,
|
||||
0.1, 0.8 ]).
|
||||
|
||||
|
||||
height_table(_,
|
||||
/* m f */
|
||||
[ 0.6, 0.4,
|
||||
0.4, 0.6 ]).
|
||||
|
||||
|
||||
shoe_size_table(_,
|
||||
/* t f */
|
||||
[ 0.9, 0.1,
|
||||
0.1, 0.9 ]).
|
||||
|
||||
|
||||
guilty_table(_, [0.23, 0.77]).
|
||||
|
||||
|
||||
descn_table(_,
|
||||
/* color, hair, height, guilt */
|
||||
/* ttttt tttf ttft ttff tfttt tftf tfft tfff ttttt fttf ftft ftff ffttt fftf ffft ffff */
|
||||
[ 0.99, 0.5, 0.23, 0.88, 0.41, 0.3, 0.76, 0.87, 0.44, 0.43, 0.29, 0.72, 0.33, 0.91, 0.95, 0.92,
|
||||
0.01, 0.5, 0.77, 0.12, 0.59, 0.7, 0.24, 0.13, 0.56, 0.57, 0.61, 0.28, 0.77, 0.09, 0.05, 0.08]).
|
||||
|
||||
|
||||
wit_table([0.2, 0.45, 0.24, 0.34,
|
||||
0.8, 0.55, 0.76, 0.66]).
|
||||
|
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user